├── .gitignore ├── LICENSE ├── README ├── README.md ├── __init__.py ├── configs ├── create_configs.py ├── flow_config.py ├── inj_config.py ├── pivot_config.py ├── rave_config.py ├── sca_config.py └── temporal_config.py ├── constants.py ├── example_workflows ├── vv_sd15_example_workflow.json └── vv_sdxl_example_workflow.json ├── modules ├── vv_attention.py ├── vv_block.py ├── vv_resnet.py └── vv_unet.py ├── nodes ├── config_flow_attn_node.py ├── config_flow_inj_node.py ├── flow_config_node.py ├── get_flow_node.py ├── get_raft_flow_node.py ├── inj_config_node.py ├── pivot_config_node.py ├── rave_config_node.py ├── sca_config_node.py ├── temporal_config_node.py ├── vv_apply_node.py ├── vv_sampler_node.py └── vv_unsampler_node.py ├── requirements.txt ├── sea_raft ├── .gitignore ├── LICENSE ├── README.md ├── assets │ └── visualization.png ├── chairs_split.txt ├── config │ ├── eval │ │ ├── kitti-L.json │ │ ├── kitti-M.json │ │ ├── kitti-S.json │ │ ├── sintel-L.json │ │ ├── sintel-M.json │ │ ├── sintel-S.json │ │ ├── spring-L.json │ │ ├── spring-M.json │ │ └── spring-S.json │ ├── parser.py │ └── train │ │ ├── Tartan-C-T-TSKH-kitti432x960-M.json │ │ ├── Tartan-C-T-TSKH-kitti432x960-S.json │ │ ├── Tartan-C-T-TSKH-spring540x960-M.json │ │ ├── Tartan-C-T-TSKH-spring540x960-S.json │ │ ├── Tartan-C-T-TSKH432x960-M.json │ │ ├── Tartan-C-T-TSKH432x960-S.json │ │ ├── Tartan-C-T432x960-M.json │ │ ├── Tartan-C-T432x960-S.json │ │ ├── Tartan-C368x496-M.json │ │ ├── Tartan-C368x496-S.json │ │ ├── Tartan480x640-M.json │ │ └── Tartan480x640-S.json ├── core │ ├── __init__.py │ ├── corr.py │ ├── extractor.py │ ├── layer.py │ ├── loss.py │ ├── raft.py │ ├── update.py │ └── utils │ │ ├── __init__.py │ │ ├── augmentor.py │ │ ├── flow_transforms.py │ │ ├── flow_viz.py │ │ ├── frame_utils.py │ │ └── utils.py ├── custom.py ├── custom │ ├── flow.jpg │ ├── heatmap.jpg │ ├── image1.jpg │ └── image2.jpg ├── ddp_utils.py ├── demo.py ├── eval_ptlflow.py ├── evaluate.py ├── profile_ptlflow.py ├── profiler.py ├── requirements.txt ├── scripts │ ├── eval.sh │ ├── submission.sh │ └── train.sh ├── submission.py └── train.py ├── unimatch ├── DATASETS.md ├── LICENSE ├── MODEL_ZOO.md ├── README.md ├── conda_environment.yml ├── dataloader │ ├── __init__.py │ ├── depth │ │ ├── __init__.py │ │ ├── augmentation.py │ │ ├── datasets.py │ │ ├── download_demon_test.sh │ │ ├── download_demon_train.sh │ │ ├── prepare_demon_test.py │ │ ├── prepare_demon_train.py │ │ ├── scannet_banet_test_pairs.txt │ │ └── scannet_banet_train_pairs.txt │ ├── flow │ │ ├── __init__.py │ │ ├── chairs_split.txt │ │ ├── datasets.py │ │ └── transforms.py │ └── stereo │ │ ├── __init__.py │ │ ├── datasets.py │ │ └── transforms.py ├── demo │ ├── depth-scannet │ │ ├── color │ │ │ ├── 0048.png │ │ │ ├── 0054.png │ │ │ ├── 0060.png │ │ │ └── 0066.png │ │ ├── intrinsic │ │ │ └── intrinsic_depth.txt │ │ └── pose │ │ │ ├── 0048.txt │ │ │ ├── 0054.txt │ │ │ ├── 0060.txt │ │ │ └── 0066.txt │ ├── flow-davis │ │ ├── 00000.jpg │ │ ├── 00001.jpg │ │ └── 00002.jpg │ ├── kitti.mp4 │ └── stereo-middlebury │ │ ├── im0.png │ │ └── im1.png ├── evaluate_depth.py ├── evaluate_flow.py ├── evaluate_stereo.py ├── loss │ ├── __init__.py │ ├── depth_loss.py │ ├── flow_loss.py │ └── stereo_metric.py ├── main_depth.py ├── main_flow.py ├── main_stereo.py ├── pip_install.sh ├── scripts │ ├── gmdepth_demo.sh │ ├── gmdepth_evaluate.sh │ ├── gmdepth_scale1_regrefine1_train.sh │ ├── gmdepth_scale1_train.sh │ ├── gmflow_demo.sh │ ├── gmflow_evaluate.sh │ ├── gmflow_scale1_train.sh │ ├── 
gmflow_scale2_regrefine6_train.sh │ ├── gmflow_scale2_train.sh │ ├── gmflow_submission.sh │ ├── gmstereo_demo.sh │ ├── gmstereo_evaluate.sh │ ├── gmstereo_scale1_train.sh │ ├── gmstereo_scale2_regrefine3_train.sh │ ├── gmstereo_scale2_train.sh │ └── gmstereo_submission.sh ├── unimatch │ ├── __init__.py │ ├── attention.py │ ├── backbone.py │ ├── geometry.py │ ├── matching.py │ ├── position.py │ ├── reg_refine.py │ ├── transformer.py │ ├── trident_conv.py │ ├── unimatch.py │ └── utils.py └── utils │ ├── dist_utils.py │ ├── file_io.py │ ├── flow_viz.py │ ├── frame_utils.py │ ├── logger.py │ ├── misc.py │ ├── utils.py │ └── visualization.py ├── utils ├── attention_utils.py ├── batching_utils.py ├── feature_utils.py ├── flow_utils.py ├── flow_viz.py ├── module_utils.py ├── noise_utils.py ├── rave_utils.py └── sampler_utils.py └── vv_defaults.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | 3 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes.vv_apply_node import ApplyVVModel 2 | from .nodes.get_flow_node import FlowGetFlowNode 3 | from .nodes.vv_unsampler_node import VVUnsamplerSamplerNode 4 | from .nodes.vv_sampler_node import VVSamplerSamplerNode 5 | from .nodes.inj_config_node import InjectionConfigNode 6 | from .nodes.pivot_config_node import PivotConfigNode 7 | from .nodes.rave_config_node import RaveConfigNode 8 | from .nodes.sca_config_node import SCAConfigNode 9 | from .nodes.flow_config_node import FlowConfigNode 10 | from .nodes.temporal_config_node import TemporalConfigNode 11 | from .nodes.get_raft_flow_node import GetRaftFlowNode 12 | 13 | 14 | NODE_CLASS_MAPPINGS = { 15 | "ApplyVVModel": ApplyVVModel, 16 | "FlowGetFlow": FlowGetFlowNode, 17 | "VVUnsamplerSampler": VVUnsamplerSamplerNode, 18 | "VVSamplerSampler": VVSamplerSamplerNode, 19 | "InjectionConfig": InjectionConfigNode, 20 | "FlowConfig": FlowConfigNode, 21 | "PivotConfig": PivotConfigNode, 22 | "RaveConfig": RaveConfigNode, 23 | "SCAConfig": SCAConfigNode, 24 | "TemporalConfig": TemporalConfigNode, 25 | "GetRaftFlow": GetRaftFlowNode, 26 | } 27 | 28 | NODE_DISPLAY_NAME_MAPPINGS = { 29 | "ApplyVVModel": "VV] Apply Model", 30 | "FlowGetFlow": "VV] Get Flow", 31 | "VVUnsamplerSampler": "VV] Unsampler", 32 | "VVSamplerSampler": "VV] Sampler", 33 | "InjectionConfig": "VV] Injection Config", 34 | "FlowConfig": "VV] Flow Attn Config", 35 | "PivotConfig": "VV] Pivot Attn Config", 36 | "RaveConfig": "VV] Rave Attn Config", 37 | "SCAConfig": "VV] Sparse Causal Attn Config", 38 | "TemporalConfig": "VV] Temporal Attn Config", 39 | "GetRaftFlow": "VV] Get Raft Flow" 40 | } 41 | -------------------------------------------------------------------------------- /configs/create_configs.py: -------------------------------------------------------------------------------- 1 | from ..configs.rave_config import RaveAttentionConfig 2 | from ..configs.pivot_config import PivotAttentionConfig 3 | from ..configs.sca_config import SparseCasualAttentionConfig 4 | from ..configs.flow_config import FlowAttentionConfig 5 | from ..configs.inj_config import InjectionConfig 6 | from ..configs.temporal_config import TemporalAttentionConfig 7 | from ..vv_defaults import SD1_MAPS, SDXL_MAPS, SD1_ATTN_INJ_DEFAULTS, SD1_RES_INJ_DEFAULTS, SD1_FLOW_MAP, SDXL_ATTN_INJ_DEFAULTS, SDXL_RES_INJ_DEFAULTS, SDXL_FLOW_MAP 8 | 9 | 10 | def create_configs( 11 | model, 12
| inj_config=None, 13 | flow_config=None, 14 | pivot_config=None, 15 | rave_config=None, 16 | sca_config=None, 17 | temporal_config=None, 18 | ): 19 | model_type = 'SD1.5' #TODO 20 | if model_type == 'SD1.5': 21 | if inj_config is not None: 22 | inj_config = InjectionConfig( 23 | SD1_ATTN_INJ_DEFAULTS, 24 | SD1_RES_INJ_DEFAULTS, 25 | inj_config['attn_injections'], 26 | inj_config['res_injections'], 27 | inj_config['attn_save_steps'], 28 | inj_config['res_save_steps'], 29 | inj_config['attn_sigmas'], 30 | inj_config['res_sigmas'] 31 | ) 32 | if flow_config is not None: 33 | flow_config = FlowAttentionConfig( 34 | SD1_MAPS[flow_config['targets']], 35 | flow_config['flow'], 36 | flow_config['start_percent'], 37 | flow_config['end_percent'] 38 | ) 39 | if rave_config is not None: 40 | rave_config = RaveAttentionConfig( 41 | SD1_MAPS[rave_config['targets']], 42 | rave_config['grid_size'], 43 | rave_config['seed'], 44 | rave_config['start_percent'], 45 | rave_config['end_percent']) 46 | if pivot_config is not None: 47 | pivot_config = PivotAttentionConfig( 48 | SD1_MAPS[pivot_config['targets']], 49 | pivot_config['batch_size'], 50 | pivot_config['seed'], 51 | pivot_config['start_percent'], 52 | pivot_config['end_percent']) 53 | if sca_config is not None: 54 | sca_config = SparseCasualAttentionConfig( 55 | SD1_MAPS[sca_config['targets']], 56 | sca_config['direction'], 57 | sca_config['start_percent'], 58 | sca_config['end_percent']) 59 | if temporal_config is not None: 60 | temporal_config = TemporalAttentionConfig( 61 | temporal_config['start_percent'], 62 | temporal_config['end_percent'], 63 | ) 64 | else: 65 | if inj_config is not None: 66 | inj_config = InjectionConfig( 67 | SDXL_ATTN_INJ_DEFAULTS, 68 | SDXL_RES_INJ_DEFAULTS, 69 | inj_config['attn_injections'], 70 | inj_config['res_injections'], 71 | inj_config['attn_save_steps'], 72 | inj_config['res_save_steps'], 73 | inj_config['attn_sigmas'], 74 | inj_config['res_sigmas'] 75 | ) 76 | if flow_config is not None: 77 | flow_config = FlowAttentionConfig( 78 | SDXL_MAPS[flow_config['targets']], 79 | flow_config['flow'], 80 | flow_config['start_percent'], 81 | flow_config['end_percent'] 82 | ) 83 | if rave_config is not None: 84 | rave_config = RaveAttentionConfig( 85 | SDXL_MAPS[rave_config['targets']], 86 | rave_config['grid_size'], 87 | rave_config['seed'], 88 | rave_config['start_percent'], 89 | rave_config['end_percent']) 90 | if pivot_config is not None: 91 | pivot_config = PivotAttentionConfig( 92 | SDXL_MAPS[pivot_config['targets']], 93 | pivot_config['batch_size'], 94 | pivot_config['seed'], 95 | pivot_config['start_percent'], 96 | pivot_config['end_percent']) 97 | if sca_config is not None: 98 | sca_config = SparseCasualAttentionConfig( 99 | SDXL_MAPS[sca_config['targets']], 100 | sca_config['direction'], 101 | sca_config['start_percent'], 102 | sca_config['end_percent']) 103 | if temporal_config is not None: 104 | temporal_config = TemporalAttentionConfig( 105 | temporal_config['start_percent'], 106 | temporal_config['end_percent'], 107 | ) 108 | return { 109 | 'INJ_CONFIG': inj_config, 110 | 'FLOW_CONFIG': flow_config, 111 | 'RAVE_CONFIG': rave_config, 112 | 'PIVOT_CONFIG': pivot_config, 113 | 'SCA_CONFIG': sca_config, 114 | 'TEMPORAL_CONFIG': temporal_config 115 | } -------------------------------------------------------------------------------- /configs/flow_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict 3 | 4 | 5 | 
@dataclass 6 | class FlowAttentionConfig: 7 | targets: set[tuple[str, int]] 8 | flow: Dict[Any, Any] 9 | start_percent: float 10 | end_percent: float 11 | -------------------------------------------------------------------------------- /configs/inj_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, List, Union 3 | 4 | import torch 5 | 6 | 7 | @dataclass 8 | class InjectionConfig: 9 | attn_map: set[tuple[str, int]] 10 | res_map: set[tuple[str, int]] 11 | attn_injections: Dict[float, Dict[tuple[str, int], Union[torch.Tensor, List[torch.Tensor]]]] 12 | res_injections: Dict[float, Dict[tuple[str, int], Union[torch.Tensor, List[torch.Tensor]]]] 13 | attn_save_steps: int 14 | res_save_steps: int 15 | attn_sigmas: List[float] 16 | res_sigmas: List[float] 17 | -------------------------------------------------------------------------------- /configs/pivot_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class PivotAttentionConfig: 6 | targets: set[tuple[str, int]] 7 | batch_size: int 8 | seed: int 9 | start_percent: float 10 | end_percent: float -------------------------------------------------------------------------------- /configs/rave_config.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | from dataclasses import dataclass 5 | 6 | 7 | @dataclass 8 | class RaveAttentionConfig: 9 | targets: set[tuple[str, int]] 10 | grid_size: int 11 | seed: int 12 | start_percent: float 13 | end_percent: float 14 | 15 | -------------------------------------------------------------------------------- /configs/sca_config.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | from dataclasses import dataclass 5 | 6 | 7 | class SCADirection: 8 | PREVIOUS = 'PREVIOUS' 9 | NEXT = 'NEXT' 10 | BOTH = 'BOTH' 11 | 12 | 13 | @dataclass 14 | class SparseCasualAttentionConfig: 15 | targets: set[tuple[str, int]] 16 | direction: SCADirection 17 | start_percent: float 18 | end_percent: float 19 | 20 | -------------------------------------------------------------------------------- /configs/temporal_config.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | from dataclasses import dataclass 5 | 6 | 7 | @dataclass 8 | class TemporalAttentionConfig: 9 | start_percent: float 10 | end_percent: float 11 | 12 | -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | 2 | class ModelOptionKey: 3 | INJECTION_TYPE = "INJECTION_KEY" 4 | FLOW = "FLOW" 5 | STEP = "STEP" 6 | BANK = "BANK" 7 | 8 | 9 | class InjectionType: 10 | UNSAMPLING = "UNSAMPLING" 11 | SAMPLING = "SAMPLING" 12 | REFERENCE = "REFERENCE" -------------------------------------------------------------------------------- /modules/vv_block.py: -------------------------------------------------------------------------------- 1 | from comfy.ldm.modules.attention import BasicTransformerBlock 2 | from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel 3 | 4 | from ..utils.module_utils import isinstance_str 5 | 6 | 7 | 8 | class VVTransformerBlock(BasicTransformerBlock): 9 | def forward(self, x, context=None, transformer_options={}): 10 | extra_options = {} 11 | block = 
transformer_options.get("block", None) 12 | block_index = transformer_options.get("block_index", 0) 13 | transformer_patches = {} 14 | transformer_patches_replace = {} 15 | 16 | for k in transformer_options: 17 | if k == "patches": 18 | transformer_patches = transformer_options[k] 19 | elif k == "patches_replace": 20 | transformer_patches_replace = transformer_options[k] 21 | else: 22 | extra_options[k] = transformer_options[k] 23 | 24 | extra_options["n_heads"] = self.n_heads 25 | extra_options["dim_head"] = self.d_head 26 | extra_options["attn_precision"] = self.attn_precision 27 | 28 | if self.ff_in: 29 | x_skip = x 30 | x = self.ff_in(self.norm_in(x)) 31 | if self.is_res: 32 | x += x_skip 33 | 34 | n = self.norm1(x) 35 | 36 | if self.disable_self_attn: 37 | context_attn1 = context 38 | else: 39 | context_attn1 = None 40 | value_attn1 = None 41 | 42 | if "attn1_patch" in transformer_patches: 43 | patch = transformer_patches["attn1_patch"] 44 | if context_attn1 is None: 45 | context_attn1 = n 46 | value_attn1 = context_attn1 47 | for p in patch: 48 | n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options) 49 | 50 | if block is not None: 51 | transformer_block = (block[0], block[1], block_index) 52 | else: 53 | transformer_block = None 54 | attn1_replace_patch = transformer_patches_replace.get("attn1", {}) 55 | block_attn1 = transformer_block 56 | if block_attn1 not in attn1_replace_patch: 57 | block_attn1 = block 58 | 59 | if block_attn1 in attn1_replace_patch: 60 | if context_attn1 is None: 61 | context_attn1 = n 62 | value_attn1 = n 63 | n = self.attn1.to_q(n) 64 | context_attn1 = self.attn1.to_k(context_attn1) 65 | value_attn1 = self.attn1.to_v(value_attn1) 66 | n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options) 67 | n = self.attn1.to_out(n) 68 | else: 69 | n = self.attn1(n, extra_options=extra_options) 70 | 71 | if "attn1_output_patch" in transformer_patches: 72 | patch = transformer_patches["attn1_output_patch"] 73 | for p in patch: 74 | n = p(n, extra_options) 75 | 76 | x += n 77 | if "middle_patch" in transformer_patches: 78 | patch = transformer_patches["middle_patch"] 79 | for p in patch: 80 | x = p(x, extra_options) 81 | 82 | if self.attn2 is not None: 83 | n = self.norm2(x) 84 | if self.switch_temporal_ca_to_sa: 85 | context_attn2 = n 86 | else: 87 | context_attn2 = context 88 | value_attn2 = None 89 | if "attn2_patch" in transformer_patches: 90 | patch = transformer_patches["attn2_patch"] 91 | value_attn2 = context_attn2 92 | for p in patch: 93 | n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options) 94 | 95 | attn2_replace_patch = transformer_patches_replace.get("attn2", {}) 96 | block_attn2 = transformer_block 97 | if block_attn2 not in attn2_replace_patch: 98 | block_attn2 = block 99 | 100 | if block_attn2 in attn2_replace_patch: 101 | if value_attn2 is None: 102 | value_attn2 = context_attn2 103 | n = self.attn2.to_q(n) 104 | context_attn2 = self.attn2.to_k(context_attn2) 105 | value_attn2 = self.attn2.to_v(value_attn2) 106 | n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) 107 | n = self.attn2.to_out(n) 108 | else: 109 | n = self.attn2(n, context=context_attn2, value=value_attn2) 110 | 111 | if "attn2_output_patch" in transformer_patches: 112 | patch = transformer_patches["attn2_output_patch"] 113 | for p in patch: 114 | n = p(n, extra_options) 115 | 116 | x += n 117 | if self.is_res: 118 | x_skip = x 119 | x = self.ff(self.norm3(x)) 120 | if self.is_res: 
121 | x += x_skip 122 | 123 | return x 124 | 125 | 126 | 127 | def _get_block_modules(module): 128 | blocks = list(filter(lambda x: isinstance_str(x[1], 'BasicTransformerBlock'), module.named_modules())) 129 | return [block for _, block in blocks] 130 | 131 | 132 | def inject_vv_block(diffusion_model: UNetModel): 133 | blocks = _get_block_modules(diffusion_model) 134 | 135 | for block in blocks: 136 | block.__class__ = VVTransformerBlock 137 | -------------------------------------------------------------------------------- /modules/vv_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from comfy.ldm.modules.diffusionmodules.openaimodel import ResBlock as ComfyResBlock, UNetModel 4 | 5 | from ..utils.module_utils import isinstance_str 6 | 7 | 8 | class ResBlock(ComfyResBlock): 9 | def init_module(self, block, idx): 10 | self.block = block 11 | self.idx = idx 12 | self.block_idx = (block, idx) 13 | 14 | def forward(self, x, emb, extra_options): 15 | if self.updown: 16 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 17 | h = in_rest(x) 18 | h = self.h_upd(h) 19 | x = self.x_upd(x) 20 | h = in_conv(h) 21 | else: 22 | h = self.in_layers(x) 23 | 24 | emb_out = None 25 | if not self.skip_t_emb: 26 | emb_out = self.emb_layers(emb).type(h.dtype) 27 | while len(emb_out.shape) < len(h.shape): 28 | emb_out = emb_out[..., None] 29 | if self.use_scale_shift_norm: 30 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 31 | h = out_norm(h) 32 | if emb_out is not None: 33 | scale, shift = torch.chunk(emb_out, 2, dim=1) 34 | h *= (1 + scale) 35 | h += shift 36 | h = out_rest(h) 37 | else: 38 | if emb_out is not None: 39 | if self.exchange_temb_dims: 40 | emb_out = emb_out.movedim(1, 2) 41 | h = h + emb_out 42 | h = self.out_layers(h) 43 | 44 | ad_params = extra_options.get('ad_params', {}) 45 | sub_idxs = ad_params.get('sub_idxs', [0]) 46 | if sub_idxs is None: 47 | sub_idx = 0 48 | else: 49 | sub_idx = sub_idxs[0] 50 | 51 | res_inj_steps = extra_options.get('RES_INJECTION_STEPS', 0) 52 | step = extra_options.get('STEP', 999) 53 | inj_config = extra_options.get('INJ_CONFIG', None) 54 | if inj_config and self.block_idx in inj_config.res_map: 55 | if extra_options['INJECTION_KEY'] == 'SAMPLING' and step < res_inj_steps and step < inj_config.res_save_steps: 56 | if extra_options['INJECTION_KEY'] == 'SAMPLING': 57 | len_cond = len(extra_options['cond_or_uncond']) 58 | sigma_key = inj_config.res_sigmas[step] 59 | res_inj = inj_config.res_injections[sigma_key][self.block_idx] 60 | if sub_idxs is not None: 61 | h = res_inj[sub_idxs].to(x.device) 62 | else: 63 | h = res_inj.to(x.device) 64 | if len_cond > 1: 65 | h = torch.cat([h]*len_cond) 66 | elif extra_options['INJECTION_KEY'] == 'UNSAMPLING' and step < inj_config.res_save_steps: 67 | overlap = extra_options.get('OVERLAP', None) 68 | sigma_key = inj_config.res_sigmas[step] 69 | res_inj = inj_config.res_injections[sigma_key][self.block_idx] 70 | if overlap == 0 or sub_idx == 0: 71 | res_inj.append(h.clone().detach().cpu()) 72 | else: 73 | res_inj.append(h[overlap:].clone().detach().cpu()) 74 | 75 | return self.skip_connection(x) + h 76 | 77 | 78 | 79 | def _get_resnet_modules(module): 80 | blocks = list(filter(lambda x: isinstance_str(x[1], 'ResBlock'), module.named_modules())) 81 | return [block for _, block in blocks] 82 | 83 | 84 | def inject_vv_resblock(diffusion_model: UNetModel): 85 | input = _get_resnet_modules(diffusion_model.input_blocks) 86 | middle = 
_get_resnet_modules(diffusion_model.middle_block) 87 | output = _get_resnet_modules(diffusion_model.output_blocks) 88 | 89 | for idx, resnet in enumerate(input): 90 | resnet.__class__ = ResBlock 91 | resnet.init_module('input', idx) 92 | 93 | for idx, resnet in enumerate(middle): 94 | resnet.__class__ = ResBlock 95 | resnet.init_module('middle', idx) 96 | 97 | for idx, resnet in enumerate(output): 98 | resnet.__class__ = ResBlock 99 | resnet.init_module('output', idx) -------------------------------------------------------------------------------- /modules/vv_unet.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import List 3 | import torch 4 | 5 | from comfy.ldm.modules.attention import SpatialTransformer 6 | from comfy.ldm.modules.diffusionmodules.openaimodel import TimestepBlock, UNetModel, Upsample, apply_control 7 | from comfy.ldm.modules.diffusionmodules.util import timestep_embedding 8 | 9 | 10 | def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None): 11 | for layer in ts: 12 | if isinstance(layer, TimestepBlock): 13 | x = layer(x, emb, transformer_options) 14 | elif isinstance(layer, SpatialTransformer): 15 | x = layer(x, context, transformer_options) 16 | if "transformer_index" in transformer_options: 17 | transformer_options["transformer_index"] += 1 18 | elif isinstance(layer, Upsample): 19 | x = layer(x, output_shape=output_shape) 20 | elif 'raunet' in str(layer.__class__): 21 | x = layer(x, output_shape=output_shape, transformer_options=transformer_options) 22 | elif 'Temporal' in str(layer.__class__): 23 | temporal_config = transformer_options.get('TEMPORAL_CONFIG', None) 24 | step_percent = transformer_options.get('STEP_PERCENT', 999) 25 | if temporal_config is None or (temporal_config.start_percent <= step_percent <= temporal_config.end_percent): 26 | x = layer(x, context) 27 | else: 28 | x = layer(x) 29 | return x 30 | 31 | 32 | class VVUNetModel(UNetModel): 33 | def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): 34 | return self._forward_sample(x, timesteps, context, y, control, transformer_options=transformer_options) 35 | 36 | def _forward_sample(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): 37 | transformer_options["original_shape"] = list(x.shape) 38 | transformer_options["transformer_index"] = 0 39 | transformer_patches = transformer_options.get("patches", {}) 40 | 41 | hs = [] 42 | t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) 43 | emb = self.time_embed(t_emb) 44 | 45 | if self.num_classes is not None: 46 | assert y.shape[0] == x.shape[0] 47 | emb = emb + self.label_emb(y) 48 | 49 | h = x 50 | for id, module in enumerate(self.input_blocks): 51 | transformer_options["block"] = ("input", id) 52 | h = forward_timestep_embed(module, h, emb, context, transformer_options) 53 | if control is not None: 54 | h = apply_control(h, control, 'input') 55 | if "input_block_patch" in transformer_patches: 56 | patch = transformer_patches["input_block_patch"] 57 | for p in patch: 58 | h = p(h, transformer_options) 59 | 60 | hs.append(h) 61 | if "input_block_patch_after_skip" in transformer_patches: 62 | patch = transformer_patches["input_block_patch_after_skip"] 63 | for p in patch: 64 | h = p(h, transformer_options) 65 | 66 | transformer_options["block"] = ("middle", 0) 67 | h = 
forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) 68 | h = apply_control(h, control, 'middle') 69 | 70 | for id, module in enumerate(self.output_blocks): 71 | transformer_options["block"] = ("output", id) 72 | hsp = hs.pop() 73 | if control is not None: 74 | hsp = apply_control(hsp, control, 'output') 75 | 76 | if "output_block_patch" in transformer_patches: 77 | patch = transformer_patches["output_block_patch"] 78 | for p in patch: 79 | h, hsp = p(h, hsp, transformer_options) 80 | 81 | h = torch.cat([h, hsp], dim=1) 82 | del hsp 83 | if len(hs) > 0: 84 | output_shape = hs[-1].shape 85 | else: 86 | output_shape = None 87 | h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape) 88 | 89 | h = h.type(x.dtype) 90 | return self.out(h) 91 | 92 | 93 | def inject_unet(diffusion_model): 94 | diffusion_model.__class__ = VVUNetModel 95 | -------------------------------------------------------------------------------- /nodes/config_flow_attn_node.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class ConfigFlowAttnFlowSD1Node: 4 | @classmethod 5 | def INPUT_TYPES(s): 6 | base = {"required": { 7 | }} 8 | for i in range(8): 9 | base['required'][f'input_{i}'] = ("BOOLEAN", { "default": i < 2 }) 10 | 11 | base['required'][f'middle_0'] = ("BOOLEAN", { "default": False }) 12 | 13 | for i in range(9): 14 | base['required'][f'output_{i}'] = ("BOOLEAN", { "default": i > 6 }) 15 | 16 | return base 17 | RETURN_TYPES = ("ATTN_MAP",) 18 | FUNCTION = "apply" 19 | 20 | CATEGORY = "vv" 21 | 22 | def apply(self, **kwargs): 23 | 24 | attention_map = set() 25 | for key, value in kwargs.items(): 26 | if value: 27 | block, idx = key.split('_') 28 | attention_map.add((block, int(idx))) 29 | 30 | return (attention_map, ) 31 | 32 | 33 | SDXL_DEFAULT_ALL_ATTNS = "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35" 34 | 35 | class ConfigFlowAttnFlowSDXLNode: 36 | @classmethod 37 | def INPUT_TYPES(s): 38 | base = {"required": { 39 | "input_attns": ("STRING", {"multiline": True, "default": SDXL_DEFAULT_ALL_ATTNS }), 40 | "output_attns": ("STRING", {"multiline": True, "default": SDXL_DEFAULT_ALL_ATTNS }), 41 | }} 42 | return base 43 | RETURN_TYPES = ("ATTN_MAP",) 44 | FUNCTION = "apply" 45 | 46 | CATEGORY = "flow" 47 | 48 | def apply(self, input_attns, output_attns): 49 | 50 | attention_map = set() 51 | if input_attns != '' and input_attns is not None: 52 | for idx in input_attns.split(','): 53 | idx = idx.strip() 54 | if idx == '': 55 | continue 56 | attention_map.add(('input', int(idx))) 57 | if output_attns != '' and output_attns is not None: 58 | for idx in output_attns.split(','): 59 | idx = idx.strip() 60 | if idx == '': 61 | continue 62 | attention_map.add(('output', int(idx))) 63 | 64 | return (attention_map, ) 65 | 66 | 67 | class ConfigFlowAttnCrossFrameSDXLNode: 68 | @classmethod 69 | def INPUT_TYPES(s): 70 | base = {"required": { 71 | "input_attns": ("STRING", {"multiline": True, "default": SDXL_DEFAULT_ALL_ATTNS }), 72 | "output_attns": ("STRING", {"multiline": True, "default": SDXL_DEFAULT_ALL_ATTNS }), 73 | }} 74 | return base 75 | RETURN_TYPES = ("ATTN_MAP",) 76 | FUNCTION = "apply" 77 | 78 | CATEGORY = "flow" 79 | 80 | def apply(self, input_attns, output_attns): 81 | 82 | attention_map = set() 83 | if input_attns != '' and input_attns is not None: 84 | for idx in input_attns.split(','): 85 | idx = idx.strip() 86 | if idx == '' or idx is None: 87 | continue 88 | 
attention_map.add(('input', int(idx))) 89 | if output_attns != '' and output_attns is not None: 90 | for idx in output_attns.split(','): 91 | idx = idx.strip() 92 | if idx == '' or idx is None: 93 | continue 94 | attention_map.add(('output', int(idx))) 95 | 96 | return (attention_map, ) -------------------------------------------------------------------------------- /nodes/config_flow_inj_node.py: -------------------------------------------------------------------------------- 1 | SD1_ATTN_DEFAULTS = set([1,2,3,4,5,6]) 2 | SD1_RES_DEFAULTS = set([3,4,6]) 3 | 4 | 5 | class ConfigFlowAttnInjSD1Node: 6 | @classmethod 7 | def INPUT_TYPES(s): 8 | base = {"required": { 9 | }} 10 | 11 | for i in range(9): 12 | base['required'][f'output_{i}'] = ("BOOLEAN", { "default": i in SD1_ATTN_DEFAULTS }) 13 | 14 | return base 15 | RETURN_TYPES = ("ATTN_INJ_MAP",) 16 | FUNCTION = "apply" 17 | 18 | CATEGORY = "vv" 19 | 20 | def apply(self, **kwargs): 21 | 22 | attention_map = set() 23 | for key, value in kwargs.items(): 24 | if value: 25 | block, idx = key.split('_') 26 | attention_map.add((block, int(idx))) 27 | 28 | return (attention_map, ) 29 | 30 | 31 | # 0-35 32 | class ConfigFlowAttnInjSDXLNode: 33 | @classmethod 34 | def INPUT_TYPES(s): 35 | base = {"required": { 36 | "attns": ("STRING", {"multiline": True }), 37 | }} 38 | return base 39 | RETURN_TYPES = ("ATTN_INJ_MAP",) 40 | FUNCTION = "apply" 41 | 42 | CATEGORY = "flow" 43 | 44 | def apply(self, attns): 45 | 46 | attention_map = set() 47 | for idx in attns.split(','): 48 | idx = int(idx.strip()) 49 | attention_map.add(('output', int(idx))) 50 | 51 | return (attention_map, ) 52 | 53 | 54 | class ConfigFlowResInjSD1Node: 55 | @classmethod 56 | def INPUT_TYPES(s): 57 | base = {"required": { 58 | }} 59 | 60 | for i in range(9): 61 | base['required'][f'output_{i}'] = ("BOOLEAN", { "default": i in SD1_RES_DEFAULTS }) 62 | 63 | return base 64 | RETURN_TYPES = ("RES_INJ_MAP",) 65 | FUNCTION = "apply" 66 | 67 | CATEGORY = "flow" 68 | 69 | def apply(self, **kwargs): 70 | 71 | resnet_map = set() 72 | for key, value in kwargs.items(): 73 | if value: 74 | block, idx = key.split('_') 75 | resnet_map.add((block, int(idx))) 76 | 77 | return (resnet_map, ) 78 | 79 | 80 | class ConfigFlowResInjSDXLNode: 81 | @classmethod 82 | def INPUT_TYPES(s): 83 | base = {"required": { 84 | }} 85 | 86 | for i in range(9): 87 | base['required'][f'output_{i}'] = ("BOOLEAN", { "default": i in SD1_RES_DEFAULTS }) 88 | 89 | return base 90 | RETURN_TYPES = ("RES_INJ_MAP",) 91 | FUNCTION = "apply" 92 | 93 | CATEGORY = "flow" 94 | 95 | def apply(self, **kwargs): 96 | 97 | resnet_map = set() 98 | for key, value in kwargs.items(): 99 | if value: 100 | block, idx = key.split('_') 101 | resnet_map.add((block, int(idx))) 102 | 103 | return (resnet_map, ) -------------------------------------------------------------------------------- /nodes/flow_config_node.py: -------------------------------------------------------------------------------- 1 | 2 | class FlowConfigNode: 3 | @classmethod 4 | def INPUT_TYPES(s): 5 | return {"required": { 6 | "flow": ("FLOW",), 7 | "targets": (['full', 'inner', 'outer', 'none'],), 8 | "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), 9 | "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), 10 | }} 11 | RETURN_TYPES = ("FLOW_CONFIG",) 12 | FUNCTION = "build" 13 | 14 | CATEGORY = "vv/configs" 15 | 16 | def build(self, flow, targets, start_percent, end_percent): 17 | return ({ 'targets': targets, 'flow': 
flow, 'start_percent': start_percent, 'end_percent': end_percent},) -------------------------------------------------------------------------------- /nodes/inj_config_node.py: -------------------------------------------------------------------------------- 1 | 2 | def find_closest_index(lst, target): 3 | return min(range(len(lst)), key=lambda i: abs(lst[i] - target)) 4 | 5 | 6 | class InjectionConfigNode: 7 | @classmethod 8 | def INPUT_TYPES(s): 9 | return {"required": { 10 | "unsampler_sigmas": ("SIGMAS",), 11 | "sampler_sigmas": ("SIGMAS",), 12 | "save_attn_steps": ("INT", {"default": 3, "min": 0, "max": 999, "step": 1}), 13 | "save_res_steps": ("INT", {"default": 3, "min": 0, "max": 999, "step": 1}), 14 | }} 15 | RETURN_TYPES = ("INJ_CONFIG",) 16 | FUNCTION = "create" 17 | 18 | CATEGORY = "vv/configs" 19 | 20 | def create(self, unsampler_sigmas, sampler_sigmas, save_attn_steps, save_res_steps): 21 | attn_injections = {} 22 | attn_sigmas = [] 23 | for i in range(save_attn_steps): 24 | sampler_sigma = sampler_sigmas[i].item() 25 | unsampler_idx = find_closest_index(unsampler_sigmas, sampler_sigma) 26 | unsampler_sigma = unsampler_sigmas[unsampler_idx].item() 27 | attn_injections[unsampler_sigma] = {} 28 | attn_sigmas.append(unsampler_sigma) 29 | 30 | res_injections = {} 31 | res_sigmas = [] 32 | for i in range(save_res_steps): 33 | sampler_sigma = sampler_sigmas[i].item() 34 | unsampler_idx = find_closest_index(unsampler_sigmas, sampler_sigma) 35 | unsampler_sigma = unsampler_sigmas[unsampler_idx].item() 36 | res_injections[unsampler_sigma] = {} 37 | res_sigmas.append(unsampler_sigma) 38 | 39 | config = { 40 | 'attn_injections': attn_injections, 41 | 'res_injections': res_injections, 42 | 'attn_save_steps': save_attn_steps, 43 | 'res_save_steps': save_res_steps, 44 | 'attn_sigmas': attn_sigmas, 45 | 'res_sigmas': res_sigmas 46 | } 47 | return (config, ) 48 | -------------------------------------------------------------------------------- /nodes/pivot_config_node.py: -------------------------------------------------------------------------------- 1 | from ..vv_defaults import MAP_TYPES 2 | 3 | class PivotConfigNode: 4 | @classmethod 5 | def INPUT_TYPES(s): 6 | return {"required": { 7 | "batch_size": ("INT", {"default": 3, "min": 2, "max": 9, "step": 1}), 8 | "targets": (MAP_TYPES,), 9 | "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), 10 | "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), 11 | "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), 12 | }} 13 | RETURN_TYPES = ("PIVOT_CONFIG",) 14 | FUNCTION = "build" 15 | 16 | CATEGORY = "vv/configs" 17 | 18 | def build(self, batch_size, targets, seed, start_percent, end_percent): 19 | return ({ 'batch_size': batch_size, 'targets': targets, 'seed': seed, 'start_percent': start_percent, 'end_percent': end_percent},) 20 | -------------------------------------------------------------------------------- /nodes/rave_config_node.py: -------------------------------------------------------------------------------- 1 | from ..vv_defaults import MAP_TYPES 2 | 3 | 4 | class RaveConfigNode: 5 | @classmethod 6 | def INPUT_TYPES(s): 7 | return {"required": { 8 | "grid_size": ("INT", {"default": 3, "min": 2, "max": 9, "step": 1}), 9 | "targets": (MAP_TYPES,), 10 | "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), 11 | "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), 12 | "end_percent": ("FLOAT", {"default": 1.0, "min": 
0.0, "max": 1.0, "step": 0.01}), 13 | }} 14 | RETURN_TYPES = ("RAVE_CONFIG",) 15 | FUNCTION = "build" 16 | 17 | CATEGORY = "vv/configs" 18 | 19 | def build(self, grid_size, targets, seed, start_percent, end_percent): 20 | return ({ 'grid_size': grid_size, 'targets': targets, 'seed': seed, 'start_percent': start_percent, 'end_percent': end_percent},) -------------------------------------------------------------------------------- /nodes/sca_config_node.py: -------------------------------------------------------------------------------- 1 | from ..vv_defaults import MAP_TYPES 2 | 3 | 4 | class SCAConfigNode: 5 | @classmethod 6 | def INPUT_TYPES(s): 7 | return {"required": { 8 | "direction": (['PREVIOUS', 'NEXT', 'BOTH'],), 9 | "targets": (MAP_TYPES,), 10 | "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), 11 | "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), 12 | }} 13 | RETURN_TYPES = ("SCA_CONFIG",) 14 | FUNCTION = "build" 15 | 16 | CATEGORY = "vv/configs" 17 | 18 | def build(self, direction, targets, start_percent, end_percent): 19 | 20 | return ({ 'direction': direction, 'targets': targets, 'start_percent': start_percent, 'end_percent': end_percent},) -------------------------------------------------------------------------------- /nodes/temporal_config_node.py: -------------------------------------------------------------------------------- 1 | 2 | class TemporalConfigNode: 3 | @classmethod 4 | def INPUT_TYPES(s): 5 | return {"required": { 6 | "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), 7 | "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), 8 | }} 9 | RETURN_TYPES = ("TEMPORAL_CONFIG",) 10 | FUNCTION = "build" 11 | 12 | CATEGORY = "vv/configs" 13 | 14 | def build(self, start_percent, end_percent): 15 | 16 | return ({ 'start_percent': start_percent, 'end_percent': end_percent},) -------------------------------------------------------------------------------- /nodes/vv_apply_node.py: -------------------------------------------------------------------------------- 1 | from ..modules.vv_attention import inject_vv_atn 2 | from ..modules.vv_block import inject_vv_block 3 | from ..modules.vv_resnet import inject_vv_resblock 4 | from ..modules.vv_unet import inject_unet 5 | 6 | 7 | class ApplyVVModel: 8 | @classmethod 9 | def INPUT_TYPES(s): 10 | return {"required": { 11 | "model": ("MODEL",), 12 | }} 13 | RETURN_TYPES = ("MODEL",) 14 | FUNCTION = "apply" 15 | 16 | CATEGORY = "vv" 17 | 18 | def apply(self, model): 19 | inject_vv_atn(model.model.diffusion_model) 20 | inject_vv_block(model.model.diffusion_model) 21 | inject_vv_resblock(model.model.diffusion_model) 22 | inject_unet(model.model.diffusion_model) 23 | return (model, ) 24 | -------------------------------------------------------------------------------- /nodes/vv_sampler_node.py: -------------------------------------------------------------------------------- 1 | import comfy.samplers 2 | from comfy.samplers import KSAMPLER 3 | 4 | from ..utils.sampler_utils import get_sampler_fn, create_sampler 5 | 6 | 7 | 8 | class VVSamplerSamplerNode: 9 | @classmethod 10 | def INPUT_TYPES(s): 11 | return {"required": { 12 | "sampler_name": (comfy.samplers.SAMPLER_NAMES, ), 13 | "attn_injection_steps": ("INT", {"default": 3, "min": 0, "max": 999, "step": 1}), 14 | "res_injection_steps": ("INT", {"default": 3, "min": 0, "max": 999, "step": 1}), 15 | "overlap": ("INT", {"default": 5, "min": 0, "max": 999, "step": 
1}), 16 | }, "optional": { 17 | "flow_config": ("FLOW_CONFIG",), 18 | "inj_config": ("INJ_CONFIG",), 19 | "sca_config": ('SCA_CONFIG',), 20 | "pivot_config": ('PIVOT_CONFIG',), 21 | "rave_config": ('RAVE_CONFIG',), 22 | "temporal_config": ('TEMPORAL_CONFIG',), 23 | "sampler": ("SAMPLER",), 24 | }} 25 | RETURN_TYPES = ("SAMPLER",) 26 | FUNCTION = "build" 27 | 28 | CATEGORY = "vv/sampling" 29 | 30 | def build(self, 31 | sampler_name, 32 | attn_injection_steps, 33 | res_injection_steps, 34 | overlap, 35 | flow_config=None, 36 | inj_config=None, 37 | sca_config=None, 38 | pivot_config=None, 39 | rave_config=None, 40 | temporal_config=None, 41 | sampler=None): 42 | 43 | sampler_fn = get_sampler_fn(sampler_name) 44 | sampler_fn = create_sampler(sampler_fn, attn_injection_steps, res_injection_steps, overlap, 45 | flow_config, inj_config,pivot_config, rave_config, sca_config, 46 | temporal_config) 47 | 48 | if sampler is None: 49 | sampler = KSAMPLER(sampler_fn) 50 | else: 51 | sampler.sampler_function = sampler_fn 52 | 53 | return (sampler, ) -------------------------------------------------------------------------------- /nodes/vv_unsampler_node.py: -------------------------------------------------------------------------------- 1 | import comfy.samplers 2 | from comfy.samplers import KSAMPLER 3 | 4 | from ..utils.sampler_utils import get_sampler_fn, create_unsampler 5 | 6 | 7 | class VVUnsamplerSamplerNode: 8 | @classmethod 9 | def INPUT_TYPES(s): 10 | return {"required": { 11 | "sampler_name": (comfy.samplers.SAMPLER_NAMES, ), 12 | "overlap": ("INT", {"default": 5, "min": 0, "max": 999, "step": 1}), 13 | }, "optional": { 14 | "flow_config": ("FLOW_CONFIG",), 15 | "inj_config": ("INJ_CONFIG",), 16 | "sampler": ("SAMPLER",) 17 | }} 18 | RETURN_TYPES = ("SAMPLER",) 19 | FUNCTION = "build" 20 | 21 | CATEGORY = "vv/sampling" 22 | 23 | def build(self, 24 | sampler_name, 25 | overlap, 26 | flow_config=None, 27 | inj_config=None, 28 | sampler=None): 29 | 30 | sampler_fn = get_sampler_fn(sampler_name) 31 | sampler_fn = create_unsampler(sampler_fn, overlap, flow_config, inj_config) 32 | 33 | if sampler is None: 34 | sampler = KSAMPLER(sampler_fn) 35 | else: 36 | sampler.sampler_function = sampler_fn 37 | 38 | return (sampler, ) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einshape 2 | timm -------------------------------------------------------------------------------- /sea_raft/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | dist 3 | datasets 4 | pytorch_env 5 | models 6 | ckpt* 7 | build 8 | demo 9 | runs 10 | results 11 | */__pycache__/* 12 | checkpoints 13 | weights 14 | *.ckpt 15 | *.pth 16 | *.ppm 17 | *.flo 18 | *.egg-info 19 | *.npy 20 | *.npz -------------------------------------------------------------------------------- /sea_raft/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, Princeton Vision & Learning Lab 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. 
Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /sea_raft/README.md: -------------------------------------------------------------------------------- 1 | # [ECCV24] SEA-RAFT 2 | 3 | We introduce SEA-RAFT, a more simple, efficient, and accurate [RAFT](https://github.com/princeton-vl/RAFT) for optical flow. Compared with RAFT, SEA-RAFT is trained with a new loss (mixture of Laplace). It directly regresses an initial flow for faster convergence in iterative refinements and introduces rigid-motion pre-training to improve generalization. SEA-RAFT achieves state-of-the-art accuracy on the [Spring benchmark](https://spring-benchmark.org/) with a 3.69 endpoint-error (EPE) and a 0.36 1-pixel outlier rate (1px), representing 22.9\% and 17.8\% error reduction from best-published results. In addition, SEA-RAFT obtains the best cross-dataset generalization on KITTI and Spring. With its high efficiency, SEA-RAFT operates at least 2.3x faster than existing methods while maintaining competitive performance. 4 | 5 | 6 | 7 | If you find SEA-RAFT useful for your work, please consider citing our academic paper: 8 | 9 |

 10 | 11 | SEA-RAFT: Simple, Efficient, Accurate RAFT for Optical Flow 12 | 13 | 14 | 15 | Yihan Wang, 16 | Lahav Lipson, 17 | Jia Deng 18 |
19 | 20 | ``` 21 | @article{wang2024sea, 22 | title={SEA-RAFT: Simple, Efficient, Accurate RAFT for Optical Flow}, 23 | author={Wang, Yihan and Lipson, Lahav and Deng, Jia}, 24 | journal={arXiv preprint arXiv:2405.14793}, 25 | year={2024} 26 | } 27 | ``` 28 | 29 | ## Requirements 30 | Our code is developed with pytorch 2.2.0, CUDA 12.2 and python 3.10. 31 | ```Shell 32 | conda create --name SEA-RAFT python=3.10.13 33 | conda activate SEA-RAFT 34 | pip install -r requirements.txt 35 | ``` 36 | 37 | ## Model Zoo 38 | Please download the models from [google drive](https://drive.google.com/drive/folders/1YLovlvUW94vciWvTyLf-p3uWscbOQRWW?usp=sharing) and put them into the `models` folder. 39 | 40 | ## Custom Usage 41 | 42 | We provide an example in `custom.py`. By default, this file will take two RGB images as the input and provide visualizations of the optical flow and the uncertainty. 43 | ```Shell 44 | python custom.py --cfg config/eval/spring-M.json --model models/Tartan-C-T-TSKH-spring540x960-M.pth 45 | ``` 46 | 47 | ## Datasets 48 | To evaluate/train SEA-RAFT, you will need to download the required datasets: [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs), [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html), [Sintel](http://sintel.is.tue.mpg.de/), [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow), [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/), [TartanAir](https://theairlab.org/tartanair-dataset/), and [Spring](https://spring-benchmark.org/). 49 | 50 | By default `datasets.py` will search for the datasets in these locations. You can create symbolic links to wherever the datasets were downloaded in the `datasets` folder. Please check [RAFT](https://github.com/princeton-vl/RAFT) for more details. 51 | 52 | ```Shell 53 | ├── datasets 54 | ├── Sintel 55 | ├── KITTI 56 | ├── FlyingChairs/FlyingChairs_release 57 | ├── FlyingThings3D 58 | ├── HD1K 59 | ├── spring 60 | ├── test 61 | ├── train 62 | ├── val 63 | ├── tartanair 64 | ``` 65 | 66 | ## Training, Evaluation, and Submission 67 | 68 | Please refer to [scripts/train.sh](scripts/train.sh), [scripts/eval.sh](scripts/eval.sh), and [scripts/submission.sh](scripts/submission.sh) for more details. 69 | 70 | ## Acknowledgements 71 | 72 | This project relies on code from existing repositories: [RAFT](https://github.com/princeton-vl/RAFT), [unimatch](https://github.com/autonomousvision/unimatch/tree/master), [Flowformer](https://github.com/drinkingcoder/FlowFormer-Official), [ptlflow](https://github.com/hmorimitsu/ptlflow), and [LoFTR](https://github.com/zju3dv/LoFTR). We thank the original authors for their excellent work. 
73 | -------------------------------------------------------------------------------- /sea_raft/assets/visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/sea_raft/assets/visualization.png -------------------------------------------------------------------------------- /sea_raft/config/eval/kitti-L.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "kitti-L", 3 | "dataset": "kitti", 4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7], 5 | 6 | "use_var": true, 7 | "var_min": 0, 8 | "var_max": 10, 9 | "pretrain": "resnet34", 10 | "initial_dim": 64, 11 | "block_dims": [64, 128, 256], 12 | "radius": 4, 13 | "dim": 128, 14 | "num_blocks": 2, 15 | "iters": 12, 16 | 17 | "image_size": [432, 960], 18 | "scale": 0, 19 | "batch_size": 16, 20 | "epsilon": 1e-8, 21 | "lr": 1e-4, 22 | "wdecay": 1e-5, 23 | "dropout": 0, 24 | "clip": 1.0, 25 | "gamma": 0.85, 26 | "num_steps": 10000, 27 | 28 | "restore_ckpt": null, 29 | "coarse_config": null 30 | } -------------------------------------------------------------------------------- /sea_raft/config/eval/kitti-M.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "kitti-M", 3 | "dataset": "kitti", 4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7], 5 | 6 | "use_var": true, 7 | "var_min": 0, 8 | "var_max": 10, 9 | "pretrain": "resnet34", 10 | "initial_dim": 64, 11 | "block_dims": [64, 128, 256], 12 | "radius": 4, 13 | "dim": 128, 14 | "num_blocks": 2, 15 | "iters": 4, 16 | 17 | "image_size": [432, 960], 18 | "scale": 0, 19 | "batch_size": 16, 20 | "epsilon": 1e-8, 21 | "lr": 1e-4, 22 | "wdecay": 1e-5, 23 | "dropout": 0, 24 | "clip": 1.0, 25 | "gamma": 0.85, 26 | "num_steps": 10000, 27 | 28 | "restore_ckpt": null, 29 | "coarse_config": null 30 | } -------------------------------------------------------------------------------- /sea_raft/config/eval/kitti-S.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "kitti-S", 3 | "dataset": "kitti", 4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7], 5 | 6 | "use_var": true, 7 | "var_min": 0, 8 | "var_max": 10, 9 | "pretrain": "resnet18", 10 | "initial_dim": 64, 11 | "block_dims": [64, 128, 256], 12 | "radius": 4, 13 | "dim": 128, 14 | "num_blocks": 2, 15 | "iters": 4, 16 | 17 | "image_size": [432, 960], 18 | "scale": 0, 19 | "batch_size": 16, 20 | "epsilon": 1e-8, 21 | "lr": 1e-4, 22 | "wdecay": 1e-5, 23 | "dropout": 0, 24 | "clip": 1.0, 25 | "gamma": 0.85, 26 | "num_steps": 10000, 27 | 28 | "restore_ckpt": null, 29 | "coarse_config": null 30 | } -------------------------------------------------------------------------------- /sea_raft/config/eval/sintel-L.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "sintel-M", 3 | "dataset": "sintel", 4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7], 5 | 6 | "use_var": true, 7 | "var_min": 0, 8 | "var_max": 10, 9 | "pretrain": "resnet34", 10 | "initial_dim": 64, 11 | "block_dims": [64, 128, 256], 12 | "radius": 4, 13 | "dim": 128, 14 | "num_blocks": 2, 15 | "iters": 12, 16 | 17 | "image_size": [432, 960], 18 | "scale": 0, 19 | "batch_size": 32, 20 | "epsilon": 1e-8, 21 | "lr": 4e-4, 22 | "wdecay": 1e-5, 23 | "dropout": 0, 24 | "clip": 1.0, 25 | "gamma": 0.85, 26 | "num_steps": 300000, 27 | 28 | "restore_ckpt": null, 29 | "coarse_config": null 30 | } 
-------------------------------------------------------------------------------- /sea_raft/config/eval/sintel-M.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "sintel-M", 3 | "dataset": "sintel", 4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7], 5 | 6 | "use_var": true, 7 | "var_min": 0, 8 | "var_max": 10, 9 | "pretrain": "resnet34", 10 | "initial_dim": 64, 11 | "block_dims": [64, 128, 256], 12 | "radius": 4, 13 | "dim": 128, 14 | "num_blocks": 2, 15 | "iters": 4, 16 | 17 | "image_size": [432, 960], 18 | "scale": 0, 19 | "batch_size": 32, 20 | "epsilon": 1e-8, 21 | "lr": 4e-4, 22 | "wdecay": 1e-5, 23 | "dropout": 0, 24 | "clip": 1.0, 25 | "gamma": 0.85, 26 | "num_steps": 300000, 27 | 28 | "restore_ckpt": null, 29 | "coarse_config": null 30 | } -------------------------------------------------------------------------------- /sea_raft/config/eval/sintel-S.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "sintel-S", 3 | "dataset": "sintel", 4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7], 5 | 6 | "use_var": true, 7 | "var_min": 0, 8 | "var_max": 10, 9 | "pretrain": "resnet18", 10 | "initial_dim": 64, 11 | "block_dims": [64, 128, 256], 12 | "radius": 4, 13 | "dim": 128, 14 | "num_blocks": 2, 15 | "iters": 4, 16 | 17 | "image_size": [432, 960], 18 | "scale": 0, 19 | "batch_size": 32, 20 | "epsilon": 1e-8, 21 | "lr": 4e-4, 22 | "wdecay": 1e-5, 23 | "dropout": 0, 24 | "clip": 1.0, 25 | "gamma": 0.85, 26 | "num_steps": 300000, 27 | 28 | "restore_ckpt": null, 29 | "coarse_config": null 30 | } -------------------------------------------------------------------------------- /sea_raft/config/eval/spring-L.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "spring-L", 3 | "dataset": "spring", 4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7], 5 | 6 | "use_var": true, 7 | "var_min": 0, 8 | "var_max": 10, 9 | "pretrain": "resnet34", 10 | "initial_dim": 64, 11 | "block_dims": [64, 128, 256], 12 | "radius": 4, 13 | "dim": 128, 14 | "num_blocks": 2, 15 | "iters": 12, 16 | 17 | "image_size": [540, 960], 18 | "scale": -1, 19 | "batch_size": 32, 20 | "epsilon": 1e-8, 21 | "lr": 4e-4, 22 | "wdecay": 1e-5, 23 | "dropout": 0, 24 | "clip": 1.0, 25 | "gamma": 0.85, 26 | "num_steps": 120000, 27 | 28 | "restore_ckpt": null, 29 | "coarse_config": null 30 | } -------------------------------------------------------------------------------- /sea_raft/config/eval/spring-M.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "spring-M", 3 | "dataset": "spring", 4 | "gpus": [ 5 | 0, 6 | 1, 7 | 2, 8 | 3, 9 | 4, 10 | 5, 11 | 6, 12 | 7 13 | ], 14 | "use_var": true, 15 | "var_min": 0, 16 | "var_max": 10, 17 | "pretrain": "resnet34", 18 | "initial_dim": 64, 19 | "block_dims": [ 20 | 64, 21 | 128, 22 | 256 23 | ], 24 | "radius": 4, 25 | "dim": 128, 26 | "num_blocks": 2, 27 | "iters": 4, 28 | "image_size": [ 29 | 540, 30 | 960 31 | ], 32 | "scale": -1, 33 | "batch_size": 32, 34 | "epsilon": 1e-8, 35 | "lr": 4e-4, 36 | "wdecay": 1e-5, 37 | "dropout": 0, 38 | "clip": 1.0, 39 | "gamma": 0.85, 40 | "num_steps": 120000, 41 | "restore_ckpt": null, 42 | "coarse_config": null 43 | } -------------------------------------------------------------------------------- /sea_raft/config/eval/spring-S.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "spring-S", 3 | "dataset": "spring", 4 | "gpus": [0, 1, 2, 3, 
4, 5, 6, 7], 5 | 6 | "use_var": true, 7 | "var_min": 0, 8 | "var_max": 10, 9 | "pretrain": "resnet18", 10 | "initial_dim": 64, 11 | "block_dims": [64, 128, 256], 12 | "radius": 4, 13 | "dim": 128, 14 | "num_blocks": 2, 15 | "iters": 4, 16 | 17 | "image_size": [540, 960], 18 | "scale": -1, 19 | "batch_size": 32, 20 | "epsilon": 1e-8, 21 | "lr": 4e-4, 22 | "wdecay": 1e-5, 23 | "dropout": 0, 24 | "clip": 1.0, 25 | "gamma": 0.85, 26 | "num_steps": 120000, 27 | 28 | "restore_ckpt": null, 29 | "coarse_config": null 30 | } -------------------------------------------------------------------------------- /sea_raft/config/parser.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import json 3 | from typing import Any, List 4 | 5 | def get_config_data(json_path): 6 | with open(json_path, 'r') as f: 7 | data = json.load(f) 8 | return data 9 | 10 | 11 | def get_model_config(config_path): 12 | args = get_config_data(config_path) 13 | model_args = ModelArgs(**args) 14 | return model_args -------------------------------------------------------------------------------- /sea_raft/config/train/Tartan-C-T-TSKH-kitti432x960-M.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Tartan-C-T-TSKH-kitti432x960-M", 3 | "dataset": "kitti", 4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7], 5 | 6 | "use_var": true, 7 | "var_min": 0, 8 | "var_max": 10, 9 | "pretrain": "resnet34", 10 | "initial_dim": 64, 11 | "block_dims": [64, 128, 256], 12 | "radius": 4, 13 | "dim": 128, 14 | "num_blocks": 2, 15 | "iters": 4, 16 | 17 | "image_size": [432, 960], 18 | "scale": 0, 19 | "batch_size": 16, 20 | "epsilon": 1e-8, 21 | "lr": 1e-4, 22 | "wdecay": 1e-5, 23 | "dropout": 0, 24 | "clip": 1.0, 25 | "gamma": 0.85, 26 | "num_steps": 10000, 27 | 28 | "restore_ckpt": null, 29 | "coarse_config": null 30 | } -------------------------------------------------------------------------------- /sea_raft/config/train/Tartan-C-T-TSKH-kitti432x960-S.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Tartan-C-T-TSKH-kitti432x960-S", 3 | "dataset": "kitti", 4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7], 5 | 6 | "use_var": true, 7 | "var_min": 0, 8 | "var_max": 10, 9 | "pretrain": "resnet18", 10 | "initial_dim": 64, 11 | "block_dims": [64, 128, 256], 12 | "radius": 4, 13 | "dim": 128, 14 | "num_blocks": 2, 15 | "iters": 4, 16 | 17 | "image_size": [432, 960], 18 | "scale": 0, 19 | "batch_size": 16, 20 | "epsilon": 1e-8, 21 | "lr": 1e-4, 22 | "wdecay": 1e-5, 23 | "dropout": 0, 24 | "clip": 1.0, 25 | "gamma": 0.85, 26 | "num_steps": 10000, 27 | 28 | "restore_ckpt": null, 29 | "coarse_config": null 30 | } -------------------------------------------------------------------------------- /sea_raft/config/train/Tartan-C-T-TSKH-spring540x960-M.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Tartan-C-T-TSKH-spring540x960-M", 3 | "dataset": "spring", 4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7], 5 | 6 | "use_var": true, 7 | "var_min": 0, 8 | "var_max": 10, 9 | "pretrain": "resnet34", 10 | "initial_dim": 64, 11 | "block_dims": [64, 128, 256], 12 | "radius": 4, 13 | "dim": 128, 14 | "num_blocks": 2, 15 | "iters": 4, 16 | 17 | "image_size": [540, 960], 18 | "scale": -1, 19 | "batch_size": 32, 20 | "epsilon": 1e-8, 21 | "lr": 4e-4, 22 | "wdecay": 1e-5, 23 | "dropout": 0, 24 | "clip": 1.0, 25 | "gamma": 0.85, 26 | "num_steps": 120000, 27 | 28 | 
"restore_ckpt": null, 29 | "coarse_config": null 30 | } -------------------------------------------------------------------------------- /sea_raft/config/train/Tartan-C-T-TSKH-spring540x960-S.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Tartan-C-T-TSKH-spring540x960-S", 3 | "dataset": "spring", 4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7], 5 | 6 | "use_var": true, 7 | "var_min": 0, 8 | "var_max": 10, 9 | "pretrain": "resnet18", 10 | "initial_dim": 64, 11 | "block_dims": [64, 128, 256], 12 | "radius": 4, 13 | "dim": 128, 14 | "num_blocks": 2, 15 | "iters": 4, 16 | 17 | "image_size": [540, 960], 18 | "scale": -1, 19 | "batch_size": 32, 20 | "epsilon": 1e-8, 21 | "lr": 4e-4, 22 | "wdecay": 1e-5, 23 | "dropout": 0, 24 | "clip": 1.0, 25 | "gamma": 0.85, 26 | "num_steps": 120000, 27 | 28 | "restore_ckpt": null, 29 | "coarse_config": null 30 | } -------------------------------------------------------------------------------- /sea_raft/config/train/Tartan-C-T-TSKH432x960-M.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Tartan-C-T-TSKH432x960-M", 3 | "dataset": "TSKH", 4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7], 5 | 6 | "use_var": true, 7 | "var_min": 0, 8 | "var_max": 10, 9 | "pretrain": "resnet34", 10 | "initial_dim": 64, 11 | "block_dims": [64, 128, 256], 12 | "radius": 4, 13 | "dim": 128, 14 | "num_blocks": 2, 15 | "iters": 4, 16 | 17 | "image_size": [432, 960], 18 | "scale": 0, 19 | "batch_size": 32, 20 | "epsilon": 1e-8, 21 | "lr": 4e-4, 22 | "wdecay": 1e-5, 23 | "dropout": 0, 24 | "clip": 1.0, 25 | "gamma": 0.85, 26 | "num_steps": 300000, 27 | 28 | "restore_ckpt": null, 29 | "coarse_config": null 30 | } -------------------------------------------------------------------------------- /sea_raft/config/train/Tartan-C-T-TSKH432x960-S.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Tartan-C-T-TSKH432x960-S", 3 | "dataset": "TSKH", 4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7], 5 | 6 | "use_var": true, 7 | "var_min": 0, 8 | "var_max": 10, 9 | "pretrain": "resnet18", 10 | "initial_dim": 64, 11 | "block_dims": [64, 128, 256], 12 | "radius": 4, 13 | "dim": 128, 14 | "num_blocks": 2, 15 | "iters": 4, 16 | 17 | "image_size": [432, 960], 18 | "scale": 0, 19 | "batch_size": 32, 20 | "epsilon": 1e-8, 21 | "lr": 4e-4, 22 | "wdecay": 1e-5, 23 | "dropout": 0, 24 | "clip": 1.0, 25 | "gamma": 0.85, 26 | "num_steps": 300000, 27 | 28 | "restore_ckpt": null, 29 | "coarse_config": null 30 | } -------------------------------------------------------------------------------- /sea_raft/config/train/Tartan-C-T432x960-M.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Tartan-C-T432x960-M", 3 | "dataset": "things", 4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7], 5 | 6 | "use_var": true, 7 | "var_min": 0, 8 | "var_max": 10, 9 | "pretrain": "resnet34", 10 | "initial_dim": 64, 11 | "block_dims": [64, 128, 256], 12 | "radius": 4, 13 | "dim": 128, 14 | "num_blocks": 2, 15 | "iters": 4, 16 | 17 | "image_size": [432, 960], 18 | "scale": 0, 19 | "batch_size": 32, 20 | "epsilon": 1e-8, 21 | "lr": 4e-4, 22 | "wdecay": 1e-5, 23 | "dropout": 0, 24 | "clip": 1.0, 25 | "gamma": 0.85, 26 | "num_steps": 120000, 27 | 28 | "restore_ckpt": null, 29 | "coarse_config": null 30 | } -------------------------------------------------------------------------------- /sea_raft/config/train/Tartan-C-T432x960-S.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "name": "Tartan-C-T432x960-S", 3 | "dataset": "things", 4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7], 5 | 6 | "use_var": true, 7 | "var_min": 0, 8 | "var_max": 10, 9 | "pretrain": "resnet18", 10 | "initial_dim": 64, 11 | "block_dims": [64, 128, 256], 12 | "radius": 4, 13 | "dim": 128, 14 | "num_blocks": 2, 15 | "iters": 4, 16 | 17 | "image_size": [432, 960], 18 | "scale": 0, 19 | "batch_size": 32, 20 | "epsilon": 1e-8, 21 | "lr": 4e-4, 22 | "wdecay": 1e-5, 23 | "dropout": 0, 24 | "clip": 1.0, 25 | "gamma": 0.85, 26 | "num_steps": 120000, 27 | 28 | "restore_ckpt": null, 29 | "coarse_config": null 30 | } -------------------------------------------------------------------------------- /sea_raft/config/train/Tartan-C368x496-M.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Tartan-C368x496-M", 3 | "dataset": "chairs", 4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7, 8], 5 | 6 | "use_var": true, 7 | "var_min": 0, 8 | "var_max": 10, 9 | "pretrain": "resnet34", 10 | "initial_dim": 64, 11 | "block_dims": [64, 128, 256], 12 | "radius": 4, 13 | "dim": 128, 14 | "num_blocks": 2, 15 | "iters": 4, 16 | 17 | "image_size": [368, 496], 18 | "scale": 0, 19 | "batch_size": 16, 20 | "epsilon": 1e-8, 21 | "lr": 2.5e-4, 22 | "wdecay": 1e-4, 23 | "dropout": 0, 24 | "clip": 1.0, 25 | "gamma": 0.8, 26 | "num_steps": 100000, 27 | 28 | "restore_ckpt": null, 29 | "coarse_config": null 30 | } -------------------------------------------------------------------------------- /sea_raft/config/train/Tartan-C368x496-S.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Tartan-C368x496-S", 3 | "dataset": "chairs", 4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7, 8], 5 | 6 | "use_var": true, 7 | "var_min": 0, 8 | "var_max": 10, 9 | "pretrain": "resnet18", 10 | "initial_dim": 64, 11 | "block_dims": [64, 128, 256], 12 | "radius": 4, 13 | "dim": 128, 14 | "num_blocks": 2, 15 | "iters": 4, 16 | 17 | "image_size": [368, 496], 18 | "scale": 0, 19 | "batch_size": 16, 20 | "epsilon": 1e-8, 21 | "lr": 2.5e-4, 22 | "wdecay": 1e-4, 23 | "dropout": 0, 24 | "clip": 1.0, 25 | "gamma": 0.8, 26 | "num_steps": 100000, 27 | 28 | "restore_ckpt": null, 29 | "coarse_config": null 30 | } -------------------------------------------------------------------------------- /sea_raft/config/train/Tartan480x640-M.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Tartan480x640-M", 3 | "dataset": "TartanAir", 4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7], 5 | 6 | "use_var": true, 7 | "var_min": 0, 8 | "var_max": 10, 9 | "pretrain": "resnet34", 10 | "initial_dim": 64, 11 | "block_dims": [64, 128, 256], 12 | "radius": 4, 13 | "dim": 128, 14 | "num_blocks": 2, 15 | "iters": 4, 16 | 17 | "image_size": [480, 640], 18 | "scale": 0, 19 | "batch_size": 32, 20 | "epsilon": 1e-8, 21 | "lr": 4e-4, 22 | "wdecay": 1e-5, 23 | "dropout": 0, 24 | "clip": 1.0, 25 | "gamma": 0.85, 26 | "num_steps": 300000, 27 | 28 | "restore_ckpt": null, 29 | "coarse_config": null 30 | } -------------------------------------------------------------------------------- /sea_raft/config/train/Tartan480x640-S.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Tartan480x640-S", 3 | "dataset": "TartanAir", 4 | "gpus": [0, 1, 2, 3, 4, 5, 6, 7], 5 | 6 | "use_var": true, 7 | "var_min": 0, 8 | "var_max": 10, 9 | "pretrain": 
"resnet18", 10 | "initial_dim": 64, 11 | "block_dims": [64, 128, 256], 12 | "radius": 4, 13 | "dim": 128, 14 | "num_blocks": 2, 15 | "iters": 4, 16 | 17 | "image_size": [480, 640], 18 | "scale": 0, 19 | "batch_size": 32, 20 | "epsilon": 1e-8, 21 | "lr": 4e-4, 22 | "wdecay": 1e-5, 23 | "dropout": 0, 24 | "clip": 1.0, 25 | "gamma": 0.85, 26 | "num_steps": 300000, 27 | 28 | "restore_ckpt": null, 29 | "coarse_config": null 30 | } -------------------------------------------------------------------------------- /sea_raft/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/sea_raft/core/__init__.py -------------------------------------------------------------------------------- /sea_raft/core/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .layer import BasicBlock, conv1x1, conv3x3 5 | 6 | class ResNetFPN(nn.Module): 7 | """ 8 | ResNet18, output resolution is 1/8. 9 | Each block has 2 layers. 10 | """ 11 | def __init__(self, args, input_dim=3, output_dim=256, ratio=1.0, norm_layer=nn.BatchNorm2d, init_weight=False): 12 | super().__init__() 13 | # Config 14 | block = BasicBlock 15 | block_dims = args.block_dims 16 | initial_dim = args.initial_dim 17 | self.init_weight = init_weight 18 | self.input_dim = input_dim 19 | # Class Variable 20 | self.in_planes = initial_dim 21 | for i in range(len(block_dims)): 22 | block_dims[i] = int(block_dims[i] * ratio) 23 | # Networks 24 | self.conv1 = nn.Conv2d(input_dim, initial_dim, kernel_size=7, stride=2, padding=3) 25 | self.bn1 = norm_layer(initial_dim) 26 | self.relu = nn.ReLU(inplace=True) 27 | if args.pretrain == 'resnet34': 28 | n_block = [3, 4, 6] 29 | elif args.pretrain == 'resnet18': 30 | n_block = [2, 2, 2] 31 | else: 32 | raise NotImplementedError 33 | self.layer1 = self._make_layer(block, block_dims[0], stride=1, norm_layer=norm_layer, num=n_block[0]) # 1/2 34 | self.layer2 = self._make_layer(block, block_dims[1], stride=2, norm_layer=norm_layer, num=n_block[1]) # 1/4 35 | self.layer3 = self._make_layer(block, block_dims[2], stride=2, norm_layer=norm_layer, num=n_block[2]) # 1/8 36 | self.final_conv = conv1x1(block_dims[2], output_dim) 37 | self._init_weights(args) 38 | 39 | def _init_weights(self, args): 40 | 41 | for m in self.modules(): 42 | if isinstance(m, nn.Conv2d): 43 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 44 | if m.bias is not None: 45 | nn.init.constant_(m.bias, 0) 46 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 47 | if m.weight is not None: 48 | nn.init.constant_(m.weight, 1) 49 | if m.bias is not None: 50 | nn.init.constant_(m.bias, 0) 51 | 52 | if self.init_weight: 53 | from torchvision.models import resnet18, ResNet18_Weights, resnet34, ResNet34_Weights 54 | if args.pretrain == 'resnet18': 55 | pretrained_dict = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1).state_dict() 56 | else: 57 | pretrained_dict = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1).state_dict() 58 | model_dict = self.state_dict() 59 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 60 | if self.input_dim == 6: 61 | for k, v in pretrained_dict.items(): 62 | if k == 'conv1.weight': 63 | pretrained_dict[k] = torch.cat((v, v), dim=1) 64 | model_dict.update(pretrained_dict) 65 | 
self.load_state_dict(model_dict, strict=False) 66 | 67 | 68 | def _make_layer(self, block, dim, stride=1, norm_layer=nn.BatchNorm2d, num=2): 69 | layers = [] 70 | layers.append(block(self.in_planes, dim, stride=stride, norm_layer=norm_layer)) 71 | for i in range(num - 1): 72 | layers.append(block(dim, dim, stride=1, norm_layer=norm_layer)) 73 | self.in_planes = dim 74 | return nn.Sequential(*layers) 75 | 76 | def forward(self, x): 77 | # ResNet Backbone 78 | x = self.relu(self.bn1(self.conv1(x))) 79 | for i in range(len(self.layer1)): 80 | x = self.layer1[i](x) 81 | for i in range(len(self.layer2)): 82 | x = self.layer2[i](x) 83 | for i in range(len(self.layer3)): 84 | x = self.layer3[i](x) 85 | # Output 86 | output = self.final_conv(x) 87 | return output -------------------------------------------------------------------------------- /sea_raft/core/layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import torch 6 | import math 7 | from torch.nn import Module, Dropout 8 | 9 | ### Gradient Clipping and Zeroing Operations ### 10 | 11 | GRAD_CLIP = 0.1 12 | 13 | class GradClip(torch.autograd.Function): 14 | @staticmethod 15 | def forward(ctx, x): 16 | return x 17 | 18 | @staticmethod 19 | def backward(ctx, grad_x): 20 | grad_x = torch.where(torch.isnan(grad_x), torch.zeros_like(grad_x), grad_x) 21 | return grad_x.clamp(min=-0.01, max=0.01) 22 | 23 | class GradientClip(nn.Module): 24 | def __init__(self): 25 | super(GradientClip, self).__init__() 26 | 27 | def forward(self, x): 28 | return GradClip.apply(x) 29 | 30 | def _make_divisible(v, divisor, min_value=None): 31 | if min_value is None: 32 | min_value = divisor 33 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 34 | # Make sure that round down does not go down by more than 10%. 35 | if new_v < 0.9 * v: 36 | new_v += divisor 37 | return new_v 38 | 39 | class ConvNextBlock(nn.Module): 40 | r""" ConvNeXt Block. There are two equivalent implementations: 41 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 42 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 43 | We use (2) as we find it slightly faster in PyTorch 44 | 45 | Args: 46 | dim (int): Number of input channels. 47 | drop_path (float): Stochastic depth rate. Default: 0.0 48 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 
49 | """ 50 | def __init__(self, dim, output_dim, layer_scale_init_value=1e-6): 51 | super().__init__() 52 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 53 | self.norm = LayerNorm(dim, eps=1e-6) 54 | self.pwconv1 = nn.Linear(dim, 4 * output_dim) # pointwise/1x1 convs, implemented with linear layers 55 | self.act = nn.GELU() 56 | self.pwconv2 = nn.Linear(4 * output_dim, dim) 57 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 58 | requires_grad=True) if layer_scale_init_value > 0 else None 59 | self.final = nn.Conv2d(dim, output_dim, kernel_size=1, padding=0) 60 | 61 | def forward(self, x): 62 | input = x 63 | x = self.dwconv(x) 64 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 65 | x = self.norm(x) 66 | x = self.pwconv1(x) 67 | x = self.act(x) 68 | x = self.pwconv2(x) 69 | if self.gamma is not None: 70 | x = self.gamma * x 71 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 72 | x = self.final(input + x) 73 | return x 74 | 75 | class LayerNorm(nn.Module): 76 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 77 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 78 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 79 | with shape (batch_size, channels, height, width). 80 | """ 81 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 82 | super().__init__() 83 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 84 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 85 | self.eps = eps 86 | self.data_format = data_format 87 | if self.data_format not in ["channels_last", "channels_first"]: 88 | raise NotImplementedError 89 | self.normalized_shape = (normalized_shape, ) 90 | 91 | def forward(self, x): 92 | if self.data_format == "channels_last": 93 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 94 | elif self.data_format == "channels_first": 95 | u = x.mean(1, keepdim=True) 96 | s = (x - u).pow(2).mean(1, keepdim=True) 97 | x = (x - u) / torch.sqrt(s + self.eps) 98 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 99 | return x 100 | 101 | def conv1x1(in_planes, out_planes, stride=1): 102 | """1x1 convolution without padding""" 103 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0) 104 | 105 | 106 | def conv3x3(in_planes, out_planes, stride=1): 107 | """3x3 convolution with padding""" 108 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1) 109 | 110 | class BasicBlock(nn.Module): 111 | def __init__(self, in_planes, planes, stride=1, norm_layer=nn.BatchNorm2d): 112 | super().__init__() 113 | 114 | # self.sparse = sparse 115 | self.conv1 = conv3x3(in_planes, planes, stride) 116 | self.conv2 = conv3x3(planes, planes) 117 | self.bn1 = norm_layer(planes) 118 | self.bn2 = norm_layer(planes) 119 | self.relu = nn.ReLU(inplace=True) 120 | if stride == 1 and in_planes == planes: 121 | self.downsample = None 122 | else: 123 | self.bn3 = norm_layer(planes) 124 | self.downsample = nn.Sequential( 125 | conv1x1(in_planes, planes, stride=stride), 126 | self.bn3 127 | ) 128 | 129 | def forward(self, x): 130 | y = x 131 | y = self.relu(self.bn1(self.conv1(y))) 132 | y = self.relu(self.bn2(self.conv2(y))) 133 | if self.downsample is not None: 134 | x = self.downsample(x) 135 | return self.relu(x+y) 
-------------------------------------------------------------------------------- /sea_raft/core/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | # exclude extremly large displacements 7 | MAX_FLOW = 400 8 | SUM_FREQ = 100 9 | VAL_FREQ = 5000 10 | 11 | def sequence_loss(output, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW): 12 | """ Loss function defined over sequence of flow predictions """ 13 | n_predictions = len(output['flow']) 14 | flow_loss = 0.0 15 | # exlude invalid pixels and extremely large diplacements 16 | mag = torch.sum(flow_gt**2, dim=1).sqrt() 17 | valid = (valid >= 0.5) & (mag < max_flow) 18 | for i in range(n_predictions): 19 | i_weight = gamma ** (n_predictions - i - 1) 20 | loss_i = output['nf'][i] 21 | final_mask = (~torch.isnan(loss_i.detach())) & (~torch.isinf(loss_i.detach())) & valid[:, None] 22 | flow_loss += i_weight * ((final_mask * loss_i).sum() / final_mask.sum()) 23 | 24 | return flow_loss -------------------------------------------------------------------------------- /sea_raft/core/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .layer import ConvNextBlock 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256, output_dim=4): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, output_dim, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class BasicMotionEncoder(nn.Module): 17 | def __init__(self, args, dim=128): 18 | super(BasicMotionEncoder, self).__init__() 19 | cor_planes = args.corr_channel 20 | self.convc1 = nn.Conv2d(cor_planes, dim*2, 1, padding=0) 21 | self.convc2 = nn.Conv2d(dim*2, dim+dim//2, 3, padding=1) 22 | self.convf1 = nn.Conv2d(2, dim, 7, padding=3) 23 | self.convf2 = nn.Conv2d(dim, dim//2, 3, padding=1) 24 | self.conv = nn.Conv2d(dim*2, dim-2, 3, padding=1) 25 | 26 | def forward(self, flow, corr): 27 | cor = F.relu(self.convc1(corr)) 28 | cor = F.relu(self.convc2(cor)) 29 | flo = F.relu(self.convf1(flow)) 30 | flo = F.relu(self.convf2(flo)) 31 | 32 | cor_flo = torch.cat([cor, flo], dim=1) 33 | out = F.relu(self.conv(cor_flo)) 34 | return torch.cat([out, flow], dim=1) 35 | 36 | class BasicUpdateBlock(nn.Module): 37 | def __init__(self, args, hdim=128, cdim=128): 38 | #net: hdim, inp: cdim 39 | super(BasicUpdateBlock, self).__init__() 40 | self.args = args 41 | self.encoder = BasicMotionEncoder(args, dim=cdim) 42 | self.refine = [] 43 | for i in range(args.num_blocks): 44 | self.refine.append(ConvNextBlock(2*cdim+hdim, hdim)) 45 | self.refine = nn.ModuleList(self.refine) 46 | 47 | def forward(self, net, inp, corr, flow, upsample=True): 48 | motion_features = self.encoder(flow, corr) 49 | inp = torch.cat([inp, motion_features], dim=1) 50 | for blk in self.refine: 51 | net = blk(torch.cat([net, inp], dim=1)) 52 | return net -------------------------------------------------------------------------------- /sea_raft/core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/sea_raft/core/utils/__init__.py 
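`sequence_loss` in `sea_raft/core/loss.py` above weights each refinement iteration by `gamma ** (n_predictions - i - 1)`, so later (more refined) flow predictions contribute most to the loss. A short worked sketch, not part of the repository, using the default `gamma = 0.8`:

```
# Worked sketch (not part of the repository) of the exponential iteration
# weighting used in sequence_loss above.
gamma, n_predictions = 0.8, 4
weights = [gamma ** (n_predictions - i - 1) for i in range(n_predictions)]
print([round(w, 3) for w in weights])  # [0.512, 0.64, 0.8, 1.0] -> final iteration weighted highest
```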
-------------------------------------------------------------------------------- /sea_raft/core/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /sea_raft/core/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | import h5py 6 | 7 | import cv2 8 | cv2.setNumThreads(0) 9 | cv2.ocl.setUseOpenCL(False) 10 | 11 | TAG_CHAR = np.array([202021.25], np.float32) 12 | 13 | def readFlow(fn): 14 | """ Read .flo file in Middlebury format""" 15 | # Code adapted from: 16 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 17 | 18 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 19 | # print 'fn = %s'%(fn) 20 | with open(fn, 'rb') as f: 21 | magic = np.fromfile(f, np.float32, count=1) 22 | if 202021.25 != magic: 23 | print('Magic number incorrect. 
Invalid .flo file') 24 | return None 25 | else: 26 | w = np.fromfile(f, np.int32, count=1) 27 | h = np.fromfile(f, np.int32, count=1) 28 | # print 'Reading %d x %d flo file\n' % (w, h) 29 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 30 | # Reshape data into 3D array (columns, rows, bands) 31 | # The reshape here is for visualization, the original code is (w,h,2) 32 | return np.resize(data, (int(h), int(w), 2)) 33 | 34 | def readPFM(file): 35 | file = open(file, 'rb') 36 | 37 | color = None 38 | width = None 39 | height = None 40 | scale = None 41 | endian = None 42 | 43 | header = file.readline().rstrip() 44 | if header == b'PF': 45 | color = True 46 | elif header == b'Pf': 47 | color = False 48 | else: 49 | raise Exception('Not a PFM file.') 50 | 51 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 52 | if dim_match: 53 | width, height = map(int, dim_match.groups()) 54 | else: 55 | raise Exception('Malformed PFM header.') 56 | 57 | scale = float(file.readline().rstrip()) 58 | if scale < 0: # little-endian 59 | endian = '<' 60 | scale = -scale 61 | else: 62 | endian = '>' # big-endian 63 | 64 | data = np.fromfile(file, endian + 'f') 65 | shape = (height, width, 3) if color else (height, width) 66 | 67 | data = np.reshape(data, shape) 68 | data = np.flipud(data) 69 | return data 70 | 71 | def writeFlow(filename,uv,v=None): 72 | """ Write optical flow to file. 73 | 74 | If v is None, uv is assumed to contain both u and v channels, 75 | stacked in depth. 76 | Original code by Deqing Sun, adapted from Daniel Scharstein. 77 | """ 78 | nBands = 2 79 | 80 | if v is None: 81 | assert(uv.ndim == 3) 82 | assert(uv.shape[2] == 2) 83 | u = uv[:,:,0] 84 | v = uv[:,:,1] 85 | else: 86 | u = uv 87 | 88 | assert(u.shape == v.shape) 89 | height,width = u.shape 90 | f = open(filename,'wb') 91 | # write the header 92 | f.write(TAG_CHAR) 93 | np.array(width).astype(np.int32).tofile(f) 94 | np.array(height).astype(np.int32).tofile(f) 95 | # arrange into matrix form 96 | tmp = np.zeros((height, width*nBands)) 97 | tmp[:,np.arange(width)*2] = u 98 | tmp[:,np.arange(width)*2 + 1] = v 99 | tmp.astype(np.float32).tofile(f) 100 | f.close() 101 | 102 | 103 | def readFlowKITTI(filename): 104 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 105 | flow = flow[:,:,::-1].astype(np.float32) 106 | flow, valid = flow[:, :, :2], flow[:, :, 2] 107 | flow = (flow - 2**15) / 64.0 108 | return flow, valid 109 | 110 | def readDispKITTI(filename): 111 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 112 | valid = disp > 0.0 113 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 114 | return flow, valid 115 | 116 | 117 | def writeFlowKITTI(filename, uv): 118 | uv = 64.0 * uv + 2**15 119 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 120 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 121 | cv2.imwrite(filename, uv[..., ::-1]) 122 | 123 | def readFlo5Flow(filename): 124 | with h5py.File(filename, "r") as f: 125 | if "flow" not in f.keys(): 126 | raise IOError(f"File {filename} does not have a 'flow' key. 
Is this a valid flo5 file?") 127 | return f["flow"][()] 128 | 129 | def writeFlo5File(flow, filename): 130 | with h5py.File(filename, "w") as f: 131 | f.create_dataset("flow", data=flow, compression="gzip", compression_opts=5) 132 | 133 | def read_gen(file_name, pil=False): 134 | ext = splitext(file_name)[-1] 135 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 136 | return Image.open(file_name) 137 | elif ext == '.bin' or ext == '.raw': 138 | return np.load(file_name) 139 | elif ext == '.flo': 140 | return readFlow(file_name).astype(np.float32) 141 | elif ext == '.pfm': 142 | flow = readPFM(file_name).astype(np.float32) 143 | if len(flow.shape) == 2: 144 | return flow 145 | else: 146 | return flow[:, :, :-1] 147 | elif ext == '.flo5': 148 | return readFlo5Flow(file_name) 149 | return [] -------------------------------------------------------------------------------- /sea_raft/core/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | def load_ckpt(model, path): 7 | """ Load checkpoint """ 8 | state_dict = torch.load(path, map_location=torch.device('cpu')) 9 | model.load_state_dict(state_dict, strict=False) 10 | 11 | def resize_data(img1, img2, flow, factor=1.0): 12 | _, _, h, w = img1.shape 13 | h = int(h * factor) 14 | w = int(w * factor) 15 | img1 = F.interpolate(img1, (h, w), mode='area') 16 | img2 = F.interpolate(img2, (h, w), mode='area') 17 | flow = F.interpolate(flow, (h, w), mode='area') * factor 18 | return img1, img2, flow 19 | 20 | class InputPadder: 21 | """ Pads images such that dimensions are divisible by 8 """ 22 | def __init__(self, dims, mode='sintel'): 23 | self.ht, self.wd = dims[-2:] 24 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 25 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 26 | if mode == 'sintel': 27 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 28 | else: 29 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 30 | 31 | def pad(self, *inputs): 32 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 33 | 34 | def unpad(self, x): 35 | ht, wd = x.shape[-2:] 36 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 37 | return x[..., c[0]:c[1], c[2]:c[3]] 38 | 39 | def forward_interpolate(flow): 40 | flow = flow.detach().cpu().numpy() 41 | dx, dy = flow[0], flow[1] 42 | 43 | ht, wd = dx.shape 44 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 45 | 46 | x1 = x0 + dx 47 | y1 = y0 + dy 48 | 49 | x1 = x1.reshape(-1) 50 | y1 = y1.reshape(-1) 51 | dx = dx.reshape(-1) 52 | dy = dy.reshape(-1) 53 | 54 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 55 | x1 = x1[valid] 56 | y1 = y1[valid] 57 | dx = dx[valid] 58 | dy = dy[valid] 59 | 60 | flow_x = interpolate.griddata( 61 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 62 | 63 | flow_y = interpolate.griddata( 64 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 65 | 66 | flow = np.stack([flow_x, flow_y], axis=0) 67 | return torch.from_numpy(flow).float() 68 | 69 | 70 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 71 | """ Wrapper for grid_sample, uses pixel coordinates """ 72 | H, W = img.shape[-2:] 73 | xgrid, ygrid = coords.split([1,1], dim=-1) 74 | xgrid = 2*xgrid/(W-1) - 1 75 | ygrid = 2*ygrid/(H-1) - 1 76 | 77 | grid = torch.cat([xgrid, ygrid], dim=-1) 78 | img = F.grid_sample(img, grid, align_corners=True) 79 | 80 | if 
mask: 81 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 82 | return img, mask.float() 83 | 84 | return img 85 | 86 | def coords_grid(batch, ht, wd, device): 87 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) 88 | coords = torch.stack(coords[::-1], dim=0).float() 89 | return coords[None].repeat(batch, 1, 1, 1) 90 | 91 | 92 | def upflow8(flow, mode='bilinear'): 93 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 94 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 95 | 96 | def transform(T, p): 97 | assert T.shape == (4,4) 98 | return np.einsum('H W j, i j -> H W i', p, T[:3,:3]) + T[:3, 3] 99 | 100 | def from_homog(x): 101 | return x[...,:-1] / x[...,[-1]] 102 | 103 | def reproject(depth1, pose1, pose2, K1, K2): 104 | H, W = depth1.shape 105 | x, y = np.meshgrid(np.arange(W), np.arange(H), indexing='xy') 106 | img_1_coords = np.stack((x, y, np.ones_like(x)), axis=-1).astype(np.float64) 107 | cam1_coords = np.einsum('H W, H W j, i j -> H W i', depth1, img_1_coords, np.linalg.inv(K1)) 108 | rel_pose = np.linalg.inv(pose2) @ pose1 109 | cam2_coords = transform(rel_pose, cam1_coords) 110 | return from_homog(np.einsum('H W j, i j -> H W i', cam2_coords, K2)) 111 | 112 | def induced_flow(depth0, depth1, data): 113 | H, W = depth0.shape 114 | coords1 = reproject(depth0, data['T0'], data['T1'], data['K0'], data['K1']) 115 | x, y = np.meshgrid(np.arange(W), np.arange(H), indexing='xy') 116 | coords0 = np.stack([x, y], axis=-1) 117 | flow_01 = coords1 - coords0 118 | 119 | H, W = depth1.shape 120 | coords1 = reproject(depth1, data['T1'], data['T0'], data['K1'], data['K0']) 121 | x, y = np.meshgrid(np.arange(W), np.arange(H), indexing='xy') 122 | coords0 = np.stack([x, y], axis=-1) 123 | flow_10 = coords1 - coords0 124 | 125 | return flow_01, flow_10 126 | 127 | def check_cycle_consistency(flow_01, flow_10): 128 | flow_01 = torch.from_numpy(flow_01).permute(2, 0, 1)[None] 129 | flow_10 = torch.from_numpy(flow_10).permute(2, 0, 1)[None] 130 | H, W = flow_01.shape[-2:] 131 | coords = coords_grid(1, H, W, flow_01.device) 132 | coords1 = coords + flow_01 133 | flow_reprojected = bilinear_sampler(flow_10, coords1.permute(0, 2, 3, 1)) 134 | cycle = flow_reprojected + flow_01 135 | cycle = torch.norm(cycle, dim=1) 136 | mask = (cycle < 0.1 * min(H, W)).float() 137 | return mask[0].numpy() -------------------------------------------------------------------------------- /sea_raft/custom.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, List 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from .core.raft import RAFT 7 | 8 | 9 | def get_model_config(): 10 | return { 11 | "name": "spring-M", 12 | "dataset": "spring", 13 | "gpus": [ 14 | 0, 15 | 1, 16 | 2, 17 | 3, 18 | 4, 19 | 5, 20 | 6, 21 | 7 22 | ], 23 | "use_var": True, 24 | "var_min": 0, 25 | "var_max": 10, 26 | "pretrain": "resnet34", 27 | "initial_dim": 64, 28 | "block_dims": [ 29 | 64, 30 | 128, 31 | 256 32 | ], 33 | "radius": 4, 34 | "dim": 128, 35 | "num_blocks": 2, 36 | "iters": 4, 37 | "image_size": [ 38 | 540, 39 | 960 40 | ], 41 | "scale": -1, 42 | "batch_size": 32, 43 | "epsilon": 1e-8, 44 | "lr": 4e-4, 45 | "wdecay": 1e-5, 46 | "dropout": 0, 47 | "clip": 1.0, 48 | "gamma": 0.85, 49 | "num_steps": 120000, 50 | "restore_ckpt": None, 51 | "coarse_config": None 52 | } 53 | 54 | 55 | @dataclass 56 | class ModelArgs: 57 | name: str = None 58 | 
dataset: str = None 59 | gpus: List[int] = None 60 | use_var: bool = None 61 | var_min: int = None 62 | var_max: int = None 63 | pretrain: str = None 64 | initial_dim: int = None 65 | block_dims: List[int] = None 66 | radius: int = None 67 | dim: int = None 68 | num_blocks: int = None 69 | iters: int = None 70 | image_size: List[int] = None 71 | scale: int = None 72 | batch_size: int = None 73 | epsilon: float = None 74 | lr: float = None 75 | wdecay: float = None 76 | dropout: float = None 77 | clip: float = None 78 | gamma: float = None 79 | num_steps: int = None 80 | restore_ckpt: Any = None 81 | coarse_config: Any = None 82 | 83 | 84 | def forward_flow(args, model, image1, image2): 85 | output = model(image1, image2, iters=args.iters, test_mode=True) 86 | flow_final = output['flow'][-1] 87 | info_final = output['info'][-1] 88 | return flow_final, info_final 89 | 90 | 91 | @torch.no_grad 92 | def calc_flow(args, model, image1, image2): 93 | img1 = F.interpolate(image1, scale_factor=2 ** args.scale, mode='bilinear', align_corners=False) 94 | img2 = F.interpolate(image2, scale_factor=2 ** args.scale, mode='bilinear', align_corners=False) 95 | H, W = img1.shape[2:] 96 | flow, info = forward_flow(args, model, img1, img2) 97 | flow_down = F.interpolate(flow, scale_factor=0.5 ** args.scale, mode='bilinear', align_corners=False) * (0.5 ** args.scale) 98 | info_down = F.interpolate(info, scale_factor=0.5 ** args.scale, mode='area') 99 | return flow_down, info_down 100 | 101 | 102 | class RaftWrapper(torch.nn.Module): 103 | def __init__(self, model, model_args): 104 | super(RaftWrapper, self).__init__() 105 | self.model = model 106 | self.args = model_args 107 | 108 | def forward(self, image1, image2): 109 | flow_down, info_down = calc_flow(self.args, self.model, image1, image2) 110 | return flow_down, info_down 111 | 112 | def eval(self): 113 | self.model.eval() 114 | 115 | 116 | def get_model(checkpoint_path): 117 | model_args = ModelArgs(**get_model_config()) 118 | model = RAFT(model_args) 119 | state_dict = torch.load(checkpoint_path) 120 | model.load_state_dict(state_dict) 121 | 122 | return RaftWrapper(model, model_args) 123 | 124 | -------------------------------------------------------------------------------- /sea_raft/custom/flow.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/sea_raft/custom/flow.jpg -------------------------------------------------------------------------------- /sea_raft/custom/heatmap.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/sea_raft/custom/heatmap.jpg -------------------------------------------------------------------------------- /sea_raft/custom/image1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/sea_raft/custom/image1.jpg -------------------------------------------------------------------------------- /sea_raft/custom/image2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/sea_raft/custom/image2.jpg -------------------------------------------------------------------------------- /sea_raft/ddp_utils.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from time import sleep 3 | import numpy as np 4 | import torch 5 | import torch.distributed as dist 6 | import torch.multiprocessing as mp 7 | from datetime import timedelta 8 | import random 9 | 10 | def init_fn(worker_id): 11 | np.random.seed(987) 12 | random.seed(987) 13 | 14 | def process_group_initialized(): 15 | try: 16 | dist.get_world_size() 17 | return True 18 | except: 19 | return False 20 | 21 | def calc_num_workers(): 22 | try: 23 | world_size = dist.get_world_size() 24 | except: 25 | world_size = 1 26 | return len(os.sched_getaffinity(0)) // world_size 27 | 28 | def setup_ddp(rank, world_size): 29 | dist.init_process_group(backend='nccl', rank=rank, world_size=world_size) 30 | torch.manual_seed(987) 31 | torch.cuda.set_device(rank) 32 | 33 | def init_ddp(): 34 | os.environ['MASTER_ADDR'] = 'localhost' 35 | os.environ['MASTER_PORT'] = str(11451 + np.random.randint(100)) 36 | world_size = torch.cuda.device_count() 37 | assert world_size > 0, "You need a GPU!" 38 | smp = mp.get_context('spawn') 39 | return smp, world_size 40 | 41 | def wait_for_world(state: mp.Queue, world_size): 42 | state.put(1) 43 | while state.qsize() < world_size: 44 | pass 45 | for _ in range(world_size): 46 | state.get() -------------------------------------------------------------------------------- /sea_raft/eval_ptlflow.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | import argparse 4 | import os 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.utils.data as data 11 | 12 | import datasets 13 | from raft import RAFT 14 | from tqdm import tqdm 15 | 16 | from utils import flow_viz 17 | from utils import frame_utils 18 | from utils.utils import resize_data, load_ckpt 19 | 20 | import ptlflow 21 | from ptlflow.utils import flow_utils 22 | 23 | def forward_flow(model, image1, image2, scale=0, mode='downsample'): 24 | if mode == 'downsample': 25 | dlt = 0 # avoid edge effects 26 | image1 = image1 / 255. 27 | image2 = image2 / 255. 
28 | img1 = F.interpolate(image1, scale_factor=2 ** scale, mode='bilinear', align_corners=False) 29 | img2 = F.interpolate(image2, scale_factor=2 ** scale, mode='bilinear', align_corners=False) 30 | img1 = F.pad(img1, (dlt, dlt, dlt, dlt), "constant", 0) 31 | img2 = F.pad(img2, (dlt, dlt, dlt, dlt), "constant", 0) 32 | H, W = img1.shape[2:] 33 | inputs = {"images": torch.stack([img1, img2], dim=1)} 34 | predictions = model(inputs) 35 | flow = predictions['flows'][:, 0] 36 | flow = flow[..., dlt: H-dlt, dlt: W-dlt] 37 | flow = F.interpolate(flow, scale_factor=0.5 ** scale, mode='bilinear', align_corners=False) * (0.5 ** scale) 38 | else: 39 | raise NotImplementedError 40 | return flow 41 | 42 | @torch.no_grad() 43 | def validate_spring(model, mode='downsample'): 44 | """ Peform validation using the Spring (val) split """ 45 | val_dataset = datasets.SpringFlowDataset(split='val') + datasets.SpringFlowDataset(split='train') 46 | val_loader = data.DataLoader(val_dataset, batch_size=4, 47 | pin_memory=False, shuffle=False, num_workers=16, drop_last=False) 48 | 49 | epe_list = np.array([], dtype=np.float32) 50 | px1_list = np.array([], dtype=np.float32) 51 | px3_list = np.array([], dtype=np.float32) 52 | px5_list = np.array([], dtype=np.float32) 53 | for i_batch, data_blob in enumerate(val_loader): 54 | image1, image2, flow_gt, valid = [x.cuda(non_blocking=True) for x in data_blob] 55 | flow = forward_flow(model, image1, image2, scale=-1, mode=mode) 56 | epe = torch.sum((flow - flow_gt)**2, dim=1).sqrt() 57 | px1 = (epe < 1.0).float().mean(dim=[1, 2]).cpu().numpy() 58 | px3 = (epe < 3.0).float().mean(dim=[1, 2]).cpu().numpy() 59 | px5 = (epe < 5.0).float().mean(dim=[1, 2]).cpu().numpy() 60 | epe = epe.mean(dim=[1, 2]).cpu().numpy() 61 | epe_list = np.append(epe_list, epe) 62 | px1_list = np.append(px1_list, px1) 63 | px3_list = np.append(px3_list, px3) 64 | px5_list = np.append(px5_list, px5) 65 | 66 | epe = np.mean(epe_list) 67 | px1 = np.mean(px1_list) 68 | px3 = np.mean(px3_list) 69 | px5 = np.mean(px5_list) 70 | 71 | print(f"Validation Spring EPE: {epe}, 1px: {100 * (1 - px1)}") 72 | 73 | @torch.no_grad() 74 | def validate_middlebury(model, mode='downsample'): 75 | """ Peform validation using the Middlebury (public) split """ 76 | val_dataset = datasets.Middlebury() 77 | val_loader = data.DataLoader(val_dataset, batch_size=1, 78 | pin_memory=False, shuffle=False, num_workers=16, drop_last=False) 79 | epe_list = np.array([], dtype=np.float32) 80 | num_valid_pixels = 0 81 | out_valid_pixels = 0 82 | for i_batch, data_blob in enumerate(val_loader): 83 | image1, image2, flow_gt, valid_gt = [x.cuda(non_blocking=True) for x in data_blob] 84 | flow = forward_flow(model, image1, image2, scale=0, mode=mode) 85 | epe = torch.sum((flow - flow_gt)**2, dim=1).sqrt() 86 | mag = torch.sum(flow_gt**2, dim=1).sqrt() 87 | val = valid_gt >= 0.5 88 | out = ((epe > 3.0) & ((epe/mag) > 0.05)).float() 89 | for b in range(out.shape[0]): 90 | epe_list = np.append(epe_list, epe[b][val[b]].mean().cpu().numpy()) 91 | out_valid_pixels += out[b][val[b]].sum().cpu().numpy() 92 | num_valid_pixels += val[b].sum().cpu().numpy() 93 | 94 | epe = np.mean(epe_list) 95 | f1 = 100 * out_valid_pixels / num_valid_pixels 96 | print("Validation middlebury: %f, %f" % (epe, f1)) 97 | 98 | def eval(args): 99 | 100 | # Get an initialized model from PTLFlow 101 | device = torch.device('cuda') 102 | model = ptlflow.get_model(args.model, 'mixed').to(device) 103 | if "use_tile_input" in model.args: 104 | model.args.use_tile_input = False 105 
| model.eval() 106 | print(args.model) 107 | with torch.no_grad(): 108 | try: 109 | validate_middlebury(model, mode='downsample') 110 | except: 111 | print('Middlebury validation failed') 112 | validate_spring(model, mode='downsample') 113 | 114 | def main(): 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument('--model', help='experiment configure file name', required=True, type=str) 117 | args = parser.parse_args() 118 | eval(args) 119 | 120 | if __name__ == '__main__': 121 | main() -------------------------------------------------------------------------------- /sea_raft/profile_ptlflow.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | import argparse 4 | import os 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.utils.data as data 11 | 12 | import datasets 13 | from raft import RAFT 14 | from tqdm import tqdm 15 | 16 | from utils import flow_viz 17 | from utils import frame_utils 18 | from utils.profile import profile_model 19 | from utils.utils import resize_data, load_ckpt 20 | 21 | import ptlflow 22 | from ptlflow.utils import flow_utils 23 | 24 | @torch.no_grad() 25 | def eval(args): 26 | # Get an initialized model from PTLFlow 27 | model = ptlflow.get_model(args.model, 'mixed').cuda() 28 | if "use_tile_input" in model.args: 29 | model.args.use_tile_input = False 30 | model.eval() 31 | h, w = 540, 960 32 | inputs = {"images": torch.zeros(1, 2, 3, h, w).cuda()} 33 | with torch.profiler.profile( 34 | activities=[ 35 | torch.profiler.ProfilerActivity.CUDA, 36 | torch.profiler.ProfilerActivity.CPU 37 | ], 38 | with_flops=True) as prof: 39 | output = model(inputs) 40 | events = prof.events() 41 | print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cuda_time_total', row_limit=5)) 42 | forward_MACs = sum([int(evt.flops) for evt in events]) 43 | print("forward MACs: ", forward_MACs / 2 / 1e9, "G") 44 | 45 | def main(): 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument('--model', help='experiment configure file name', required=True, type=str) 48 | args = parser.parse_args() 49 | eval(args) 50 | 51 | if __name__ == '__main__': 52 | main() -------------------------------------------------------------------------------- /sea_raft/profiler.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | import argparse 4 | import torch 5 | from config.parser import parse_args 6 | from raft import RAFT 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str) 11 | args = parse_args(parser) 12 | model = RAFT(args) 13 | model.eval() 14 | h, w = [540, 960] 15 | input = torch.zeros(1, 3, h, w) 16 | model = model.cuda() 17 | input = input.cuda() 18 | with torch.profiler.profile( 19 | activities=[ 20 | torch.profiler.ProfilerActivity.CPU, 21 | torch.profiler.ProfilerActivity.CUDA 22 | ], 23 | with_flops=True) as prof: 24 | output = model(input, input, iters=args.iters, test_mode=True) 25 | 26 | print(prof.key_averages(group_by_stack_n=5).table(sort_by='self_cuda_time_total', row_limit=5)) 27 | events = prof.events() 28 | forward_MACs = sum([int(evt.flops) for evt in events]) 29 | print("forward MACs: ", forward_MACs / 2 / 1e9, "G") 30 | 31 | if __name__ == '__main__': 32 | main() 
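The validation routines in `sea_raft/eval_ptlflow.py` above report end-point error (EPE) and n-px accuracy per batch. The toy sketch below, not part of the repository, shows how those two quantities fall out of a predicted/ground-truth flow pair:

```
# Toy sketch (not part of the repository) of the EPE and 1px metrics computed
# in validate_spring / validate_middlebury above.
import torch

flow_pred = torch.zeros(1, 2, 4, 4)      # predicted flow, (B, 2, H, W)
flow_gt = torch.ones(1, 2, 4, 4)         # ground truth: every pixel off by (1, 1)

epe = torch.sum((flow_pred - flow_gt) ** 2, dim=1).sqrt()  # per-pixel error, (B, H, W)
px1 = (epe < 1.0).float().mean(dim=[1, 2])                 # fraction of pixels under 1 px
print(epe.mean().item(), px1.item())     # ~1.414 (= sqrt(2)), 0.0
```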
-------------------------------------------------------------------------------- /sea_raft/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | torchaudio 4 | numpy 5 | matplotlib 6 | scipy 7 | opencv-python 8 | tensorboard 9 | h5py 10 | tqdm 11 | einops -------------------------------------------------------------------------------- /sea_raft/scripts/eval.sh: -------------------------------------------------------------------------------- 1 | # evaluate on sintel 2 | python evaluate.py --cfg config/eval/sintel-M.json --model models/Tartan-C-T-TSKH432x960-M.pth 3 | # evaluate on kitti 4 | python evaluate.py --cfg config/eval/kitti-M.json --model models/Tartan-C-T-TSKH-kitti432x960-M.pth 5 | # evaluate on spring 6 | python evaluate.py --cfg config/eval/spring-M.json --model models/Tartan-C-T-TSKH-spring540x960-M.pth -------------------------------------------------------------------------------- /sea_raft/scripts/submission.sh: -------------------------------------------------------------------------------- 1 | # submit to sintel public leaderboard 2 | python submission.py --cfg config/eval/sintel-M.json --model models/Tartan-C-T-TSKH432x960-M.pth 3 | # submit to kitti public leaderboard 4 | python submission.py --cfg config/eval/kitti-M.json --model models/Tartan-C-T-TSKH-kitti432x960-M.pth 5 | # submit to spring public leaderboard 6 | python submission.py --cfg config/eval/spring-M.json --model models/Tartan-C-T-TSKH-spring540x960-M.pth -------------------------------------------------------------------------------- /sea_raft/scripts/train.sh: -------------------------------------------------------------------------------- 1 | # The training scripts are tested on 8 NVIDIA L40 GPUs 2 | 3 | # Stage0: Stereo pretraining on TartanAir 4 | python -u train.py --cfg config/train/Tartan480x640-M.json 5 | # Stage1: FlyingChairs, please update your ckpts path (from Stage0 results) in the config file (restore_ckpt) 6 | python -u train.py --cfg config/train/Tartan-C368x496-M.json 7 | # Stage2: FlyingThings3D, please update your ckpts path (from Stage1 results) in the config file (restore_ckpt) 8 | python -u train.py --cfg config/train/Tartan-C-T432x960-M.json 9 | # Stage3: FlyingThings + Sintel + KITTI + HD1K, please update your ckpts path (from Stage2 results) in the config file (restore_ckpt) 10 | # The ckpts from this stage are used for sintel submission 11 | python -u train.py --cfg config/train/Tartan-C-T-TSKH432x960-M.json 12 | 13 | # Stage4 (finetune for KITTI submission): KITTI, please update your ckpts path (from Stage3 results) in the config file (restore_ckpt) 14 | python -u train.py --cfg config/train/Tartan-C-T-TSKH-kitti432x960-M.json 15 | 16 | # Stage5 (finetune for Spring submission): Spring, please update your ckpts path (from Stage3 results) in the config file (restore_ckpt 17 | python -u train.py --cfg config/train/Tartan-C-T-TSKH-spring540x960-M.json 18 | -------------------------------------------------------------------------------- /sea_raft/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | 4 | import argparse 5 | import numpy as np 6 | 7 | from config.parser import parse_args 8 | 9 | import torch 10 | import torch.optim as optim 11 | 12 | from raft import RAFT 13 | from datasets import fetch_dataloader 14 | from utils.utils import load_ckpt 15 | from loss import sequence_loss 16 | from ddp_utils import * 17 | 18 | 
os.system("export KMP_INIT_AT_FORK=FALSE") 19 | 20 | def fetch_optimizer(args, model): 21 | """ Create the optimizer and learning rate scheduler """ 22 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon) 23 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps + 100, 24 | pct_start=0.05, cycle_momentum=False, anneal_strategy='linear') 25 | 26 | return optimizer, scheduler 27 | 28 | def train(args, rank=0, world_size=1, use_ddp=False): 29 | """ Full training loop """ 30 | device_id = rank 31 | model = RAFT(args).to(device_id) 32 | if args.restore_ckpt is not None: 33 | load_ckpt(model, args.restore_ckpt) 34 | print(f"restore ckpt from {args.restore_ckpt}") 35 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) # there might not be any, actually 36 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], static_graph=True) 37 | 38 | model.train() 39 | train_loader = fetch_dataloader(args, rank=rank, world_size=world_size, use_ddp=use_ddp) 40 | optimizer, scheduler = fetch_optimizer(args, model) 41 | total_steps = 0 42 | VAL_FREQ = 10000 43 | epoch = 0 44 | should_keep_training = True 45 | # torch.autograd.set_detect_anomaly(True) 46 | while should_keep_training: 47 | # shuffle sampler 48 | train_loader.sampler.set_epoch(epoch) 49 | epoch += 1 50 | for i_batch, data_blob in enumerate(train_loader): 51 | optimizer.zero_grad() 52 | image1, image2, flow, valid = [x.cuda(non_blocking=True) for x in data_blob] 53 | output = model(image1, image2, flow_gt=flow, iters=args.iters) 54 | loss = sequence_loss(output, flow, valid, args.gamma) 55 | loss.backward() 56 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 57 | optimizer.step() 58 | scheduler.step() 59 | 60 | if total_steps % VAL_FREQ == VAL_FREQ - 1 and rank == 0: 61 | PATH = 'checkpoints/%d_%s.pth' % (total_steps+1, args.name) 62 | torch.save(model.module.state_dict(), PATH) 63 | 64 | if total_steps > args.num_steps: 65 | should_keep_training = False 66 | break 67 | 68 | total_steps += 1 69 | 70 | PATH = 'checkpoints/%s.pth' % args.name 71 | if rank == 0: 72 | torch.save(model.module.state_dict(), PATH) 73 | 74 | return PATH 75 | 76 | def main(rank, world_size, args, use_ddp): 77 | 78 | if use_ddp: 79 | print(f"Using DDP [{rank=} {world_size=}]") 80 | setup_ddp(rank, world_size) 81 | 82 | train(args, rank=rank, world_size=world_size, use_ddp=use_ddp) 83 | 84 | if __name__ == '__main__': 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str) 87 | args = parse_args(parser) 88 | args.name += f"_exp{str(np.random.randint(100))}" 89 | smp, world_size = init_ddp() 90 | if world_size > 1: 91 | spwn_ctx = mp.spawn(main, nprocs=world_size, args=(world_size, args, True), join=False) 92 | spwn_ctx.join() 93 | else: 94 | main(0, 1, args, False) 95 | print("Done!") -------------------------------------------------------------------------------- /unimatch/DATASETS.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | 3 | 4 | 5 | ## Optical Flow 6 | 7 | The datasets used to train and evaluate our GMFlow model are as follows: 8 | 9 | - [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) 10 | - [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) 11 | - [Sintel](http://sintel.is.tue.mpg.de/) 12 | - [Virtual 
KITTI 2](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/) 13 | - [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) 14 | - [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) 15 | 16 | By default the dataloader [dataloader/flow/datasets.py](dataloader/flow/datasets.py) assumes the datasets are located in the `datasets` directory. 17 | 18 | It is recommended to symlink your dataset root to `datasets`: 19 | 20 | ``` 21 | ln -s $YOUR_DATASET_ROOT datasets 22 | ``` 23 | 24 | Otherwise, you may need to change the corresponding paths in [dataloader/flow/datasets.py](dataloader/flow/datasets.py). 25 | 26 | 27 | 28 | ## Stereo Matching 29 | 30 | The datasets used to train and evaluate our GMStereo model are as follows: 31 | 32 | - [Scene Flow](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) 33 | - [Virtual KITTI 2](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/) 34 | - [KITTI](https://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=stereo) 35 | - [TartanAir](https://github.com/castacks/tartanair_tools) 36 | - [Falling Things](https://research.nvidia.com/publication/2018-06_Falling-Things) 37 | - [HR-VS](https://drive.google.com/file/d/1SgEIrH_IQTKJOToUwR1rx4-237sThUqX/view) 38 | - [CREStereo Dataset](https://github.com/megvii-research/CREStereo/blob/master/dataset_download.sh) 39 | - [InStereo2K](https://github.com/YuhuaXu/StereoDataset) 40 | - [Middlebury](https://vision.middlebury.edu/stereo/data/) 41 | - [Sintel Stereo](http://sintel.is.tue.mpg.de/stereo) 42 | - [ETH3D](https://www.eth3d.net/datasets#low-res-two-view-training-data) 43 | 44 | By default the dataloader [dataloader/stereo/datasets.py](dataloader/stereo/datasets.py) assumes the datasets are located in the `datasets` directory. 45 | 46 | It is recommended to symlink your dataset root to `datasets`: 47 | 48 | ``` 49 | ln -s $YOUR_DATASET_ROOT datasets 50 | ``` 51 | 52 | Otherwise, you may need to change the corresponding paths in [dataloader/stereo/datasets.py](dataloader/flow/datasets.py). 53 | 54 | 55 | 56 | ## Depth Estimation 57 | 58 | The datasets used to train and evaluate our GMDepth model are as follows: 59 | 60 | - [DeMoN](https://github.com/lmb-freiburg/demon) 61 | - [ScanNet](http://www.scan-net.org/) 62 | 63 | We support downloading and extracting the DeMoN dataset in our code: [dataloader/depth/download_demon_train.sh](dataloader/depth/download_demon_train.sh), [dataloader/depth/download_demon_test.sh](dataloader/depth/download_demon_test.sh), [dataloader/depth/prepare_demon_train.sh](dataloader/depth/prepare_demon_train.sh) and [dataloader/depth/prepare_demon_test.sh](dataloader/depth/prepare_demon_test.sh). 64 | 65 | By default the dataloader [dataloader/depth/datasets.py](dataloader/depth/datasets.py) assumes the datasets are located in the `datasets` directory. 66 | 67 | It is recommended to symlink your dataset root to `datasets`: 68 | 69 | ``` 70 | ln -s $YOUR_DATASET_ROOT datasets 71 | ``` 72 | 73 | Otherwise, you may need to change the corresponding paths in [dataloader/depth/datasets.py](dataloader/depth/datasets.py). 
74 | 75 | -------------------------------------------------------------------------------- /unimatch/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 autonomousvision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /unimatch/conda_environment.yml: -------------------------------------------------------------------------------- 1 | name: unimatch 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1 7 | - _openmp_mutex=5.1 8 | - blas=1.0 9 | - brotli=1.0.9 10 | - brotli-bin=1.0.9 11 | - bzip2=1.0.8 12 | - ca-certificates=2022.10.11 13 | - certifi=2022.9.24 14 | - cloudpickle=2.0.0 15 | - cudatoolkit=10.2.89 16 | - cycler=0.11.0 17 | - cytoolz=0.12.0 18 | - dask-core=2022.7.0 19 | - dbus=1.13.18 20 | - expat=2.4.9 21 | - ffmpeg=4.3 22 | - fftw=3.3.9 23 | - fontconfig=2.13.1 24 | - fonttools=4.25.0 25 | - freetype=2.12.1 26 | - fsspec=2022.10.0 27 | - giflib=5.2.1 28 | - glib=2.69.1 29 | - gmp=6.2.1 30 | - gnutls=3.6.15 31 | - gst-plugins-base=1.14.0 32 | - gstreamer=1.14.0 33 | - icu=58.2 34 | - imageio=2.9.0 35 | - intel-openmp=2021.4.0 36 | - jpeg=9b 37 | - kiwisolver=1.4.2 38 | - lame=3.100 39 | - lcms2=2.12 40 | - ld_impl_linux-64=2.38 41 | - libbrotlicommon=1.0.9 42 | - libbrotlidec=1.0.9 43 | - libbrotlienc=1.0.9 44 | - libffi=3.3 45 | - libgcc-ng=11.2.0 46 | - libgfortran-ng=11.2.0 47 | - libgfortran5=11.2.0 48 | - libgomp=11.2.0 49 | - libiconv=1.16 50 | - libidn2=2.3.2 51 | - libpng=1.6.37 52 | - libstdcxx-ng=11.2.0 53 | - libtasn1=4.16.0 54 | - libtiff=4.1.0 55 | - libunistring=0.9.10 56 | - libuuid=1.41.5 57 | - libuv=1.40.0 58 | - libwebp=1.2.0 59 | - libxcb=1.15 60 | - libxml2=2.9.14 61 | - locket=1.0.0 62 | - lz4-c=1.9.3 63 | - matplotlib=3.5.1 64 | - matplotlib-base=3.5.1 65 | - mkl=2021.4.0 66 | - mkl-service=2.4.0 67 | - mkl_fft=1.3.1 68 | - mkl_random=1.2.2 69 | - munkres=1.1.4 70 | - ncurses=6.3 71 | - nettle=3.7.3 72 | - networkx=2.8.4 73 | - ninja=1.10.2 74 | - ninja-base=1.10.2 75 | - numpy=1.19.2 76 | - numpy-base=1.19.2 77 | - openh264=2.1.1 78 | - openssl=1.1.1s 79 | - packaging=21.3 80 | - partd=1.2.0 81 | - pcre=8.45 82 | - pillow=9.0.1 83 | - pip=22.2.2 84 | - pyparsing=3.0.9 85 | - pyqt=5.9.2 86 | - python=3.8.13 87 | - python-dateutil=2.8.2 88 | - pytorch=1.9.0 89 | - pywavelets=1.3.0 90 | - pyyaml=6.0 91 | - qt=5.9.7 92 | - 
readline=8.2 93 | - scikit-image=0.19.2 94 | - scipy=1.9.3 95 | - sip=4.19.13 96 | - six=1.16.0 97 | - sqlite=3.39.3 98 | - tifffile=2020.10.1 99 | - tk=8.6.12 100 | - toolz=0.12.0 101 | - torchvision=0.10.0 102 | - tornado=6.2 103 | - typing_extensions=4.3.0 104 | - wheel=0.37.1 105 | - xz=5.2.6 106 | - yaml=0.2.5 107 | - zlib=1.2.13 108 | - zstd=1.4.9 109 | - pip: 110 | - absl-py==1.3.0 111 | - cachetools==5.2.0 112 | - charset-normalizer==2.1.1 113 | - google-auth==2.14.1 114 | - google-auth-oauthlib==0.4.6 115 | - grpcio==1.50.0 116 | - h5py==3.7.0 117 | - idna==3.4 118 | - imageio-ffmpeg==0.4.7 119 | - importlib-metadata==5.0.0 120 | - joblib==1.2.0 121 | - lz4==4.0.2 122 | - markdown==3.4.1 123 | - markupsafe==2.1.1 124 | - oauthlib==3.2.2 125 | - opencv-python==4.6.0.66 126 | - path==16.5.0 127 | - protobuf==3.19.6 128 | - pyasn1==0.4.8 129 | - pyasn1-modules==0.2.8 130 | - requests==2.28.1 131 | - requests-oauthlib==1.3.1 132 | - rsa==4.9 133 | - setuptools==59.5.0 134 | - tensorboard==2.9.1 135 | - tensorboard-data-server==0.6.1 136 | - tensorboard-plugin-wit==1.8.1 137 | - tqdm==4.64.1 138 | - urllib3==1.26.12 139 | - werkzeug==2.2.2 140 | - zipp==3.10.0 141 | -------------------------------------------------------------------------------- /unimatch/dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/dataloader/__init__.py -------------------------------------------------------------------------------- /unimatch/dataloader/depth/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/dataloader/depth/__init__.py -------------------------------------------------------------------------------- /unimatch/dataloader/depth/download_demon_test.sh: -------------------------------------------------------------------------------- 1 | # Source from https://github.com/lmb-freiburg/demon 2 | #!/bin/bash 3 | clear 4 | cat << EOF 5 | ================================================================================ 6 | The test datasets are provided for research purposes only. 7 | Some of the test datasets build upon other publicly available data. 8 | Make sure to cite the respective original source of the data if you use the 9 | provided files for your research. 10 | * sun3d_test.h5 is based on the SUN3D dataset http://sun3d.cs.princeton.edu/ 11 | J. Xiao, A. Owens, and A. Torralba, “SUN3D: A Database of Big Spaces Reconstructed Using SfM and Object Labels,” in 2013 IEEE International Conference on Computer Vision (ICCV), 2013, pp. 1625–1632. 12 | 13 | * rgbd_test.h5 is based on the RGBD SLAM benchmark http://vision.in.tum.de/data/datasets/rgbd-dataset (licensed under CC-BY 3.0) 14 | 15 | J. Sturm, N. Engelhard, F. Endres, W. Burgard, and D. Cremers, “A benchmark for the evaluation of RGB-D SLAM systems,” in 2012 IEEE/RSJ International Conference on Intelligent Robots and Systems, 2012, pp. 573–580. 16 | * scenes11_test.h5 uses objects from shapenet https://www.shapenet.org/ 17 | 18 | A. X. Chang et al., “ShapeNet: An Information-Rich 3D Model Repository,” arXiv:1512.03012 [cs], Dec. 2015. 19 | * mvs_test.h5 contains scenes from https://colmap.github.io/datasets.html 20 | 21 | J. L. Schönberger and J. M. 
Frahm, “Structure-from-Motion Revisited,” in 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016, pp. 4104–4113. 22 | J. L. Schönberger, E. Zheng, J.-M. Frahm, and M. Pollefeys, “Pixelwise View Selection for Unstructured Multi-View Stereo,” in Computer Vision – ECCV 2016, 2016, pp. 501–518. 23 | * nyu2_test.h5 is based on http://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html 24 | 25 | N. Silberman, D. Hoiem, P. Kohli, and R. Fergus, “Indoor Segmentation and Support Inference from RGBD Images,” in Computer Vision – ECCV 2012, 2012, pp. 746–760. 26 | ================================================================================ 27 | type Y to start the download. 28 | EOF 29 | 30 | read -s -n 1 answer 31 | if [ "$answer" != "Y" -a "$answer" != "y" ]; then 32 | exit 0 33 | fi 34 | echo 35 | 36 | datasets=(sun3d rgbd mvs scenes11) 37 | 38 | OLD_PWD="$PWD" 39 | DESTINATION=testdata 40 | mkdir $DESTINATION 41 | cd $DESTINATION 42 | 43 | for ds in ${datasets[@]}; do 44 | if [ -e "${ds}_test.h5" ]; then 45 | echo "${ds}_test.h5 already exists, skipping ${ds}" 46 | else 47 | wget "https://lmb.informatik.uni-freiburg.de/data/demon/testdata/${ds}_test.tgz" 48 | tar -xvf "${ds}_test.tgz" 49 | fi 50 | done 51 | 52 | cd "$OLD_PWD" -------------------------------------------------------------------------------- /unimatch/dataloader/depth/download_demon_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | clear 3 | cat << EOF 4 | 5 | ================================================================================ 6 | 7 | 8 | The train datasets are provided for research purposes only. 9 | 10 | Some of the test datasets build upon other publicly available data. 11 | Make sure to cite the respective original source of the data if you use the 12 | provided files for your research. 13 | 14 | * sun3d_train.h5 is based on the SUN3D dataset http://sun3d.cs.princeton.edu/ 15 | 16 | J. Xiao, A. Owens, and A. Torralba, “SUN3D: A Database of Big Spaces Reconstructed Using SfM and Object Labels,” in 2013 IEEE International Conference on Computer Vision (ICCV), 2013, pp. 1625–1632. 17 | 18 | 19 | 20 | 21 | * rgbd_bugfix_train.h5 is based on the RGBD SLAM benchmark http://vision.in.tum.de/data/datasets/rgbd-dataset (licensed under CC-BY 3.0) 22 | 23 | J. Sturm, N. Engelhard, F. Endres, W. Burgard, and D. Cremers, “A benchmark for the evaluation of RGB-D SLAM systems,” in 2012 IEEE/RSJ International Conference on Intelligent Robots and Systems, 2012, pp. 573–580. 24 | 25 | 26 | 27 | * scenes11_train.h5 uses objects from shapenet https://www.shapenet.org/ 28 | 29 | A. X. Chang et al., “ShapeNet: An Information-Rich 3D Model Repository,” arXiv:1512.03012 [cs], Dec. 2015. 30 | 31 | 32 | 33 | * mvs_train.h5 contains the Citywall and Achteck-Turm scenes from MVE (Multi-View Environment) http://www.gcc.tu-darmstadt.de/home/proj/mve/ 34 | 35 | S. Fuhrmann, F. Langguth, and M. Goesele, “MVE: A Multi-view Reconstruction Environment,” in Proceedings of the Eurographics Workshop on Graphics and Cultural Heritage, Aire-la-Ville, Switzerland, Switzerland, 2014, pp. 11–18. 36 | 37 | 38 | 39 | ================================================================================ 40 | 41 | type Y to start the download. 
42 | 43 | EOF 44 | 45 | read -s -n 1 answer 46 | if [ "$answer" != "Y" -a "$answer" != "y" ]; then 47 | exit 0 48 | fi 49 | echo 50 | 51 | datasets=(sun3d rgbd mvs scenes11) 52 | 53 | OLD_PWD="$PWD" 54 | DESTINATION=traindata 55 | mkdir $DESTINATION 56 | cd $DESTINATION 57 | 58 | if [ ! -e "README_traindata" ]; then 59 | wget --no-check-certificate "https://lmb.informatik.uni-freiburg.de/data/demon/traindata/README_traindata" 60 | fi 61 | 62 | for ds in ${datasets[@]}; do 63 | if [ -e "${ds}_train.h5" ]; then 64 | echo "${ds}_train.h5 already exists, skipping ${ds}" 65 | else 66 | wget --no-check-certificate "https://lmb.informatik.uni-freiburg.de/data/demon/traindata/${ds}_train.tgz" 67 | tar -xvf "${ds}_train.tgz" 68 | fi 69 | done 70 | 71 | cd "$OLD_PWD" -------------------------------------------------------------------------------- /unimatch/dataloader/depth/prepare_demon_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess 4 | 5 | from joblib import Parallel, delayed 6 | import numpy as np 7 | import imageio 8 | 9 | imageio.plugins.freeimage.download() 10 | from imageio.plugins import freeimage 11 | import h5py 12 | from lz4.block import decompress 13 | import scipy.misc 14 | 15 | import cv2 16 | 17 | from path import Path 18 | 19 | path = os.path.join(os.path.dirname(os.path.abspath(__file__))) 20 | 21 | 22 | def dump_example(dataset_name): 23 | print("Converting {:}.h5 ...".format(dataset_name)) 24 | file = h5py.File(os.path.join(path, "testdata", "{:}.h5".format(dataset_name)), "r") 25 | 26 | for (seq_idx, seq_name) in enumerate(file): 27 | if dataset_name == 'scenes11_test': 28 | scale = 0.4 29 | else: 30 | scale = 1 31 | 32 | print("Processing sequence {:d}/{:d}".format(seq_idx, len(file))) 33 | dump_dir = os.path.join(path, 'test', dataset_name + "_" + "{:05d}".format(seq_idx)) 34 | if not os.path.isdir(dump_dir): 35 | os.mkdir(dump_dir) 36 | dump_dir = Path(dump_dir) 37 | sequence = file[seq_name]["frames"]["t0"] 38 | poses = [] 39 | for (f_idx, f_name) in enumerate(sequence): 40 | frame = sequence[f_name] 41 | for dt_type in frame: 42 | dataset = frame[dt_type] 43 | img = dataset[...] 
44 | if dt_type == "camera": 45 | if f_idx == 0: 46 | intrinsics = np.array([[img[0], 0, img[3]], [0, img[1], img[4]], [0, 0, 1]]) 47 | pose = np.array( 48 | [[img[5], img[8], img[11], img[14] * scale], [img[6], img[9], img[12], img[15] * scale], 49 | [img[7], img[10], img[13], img[16] * scale]]) 50 | poses.append(pose.tolist()) 51 | elif dt_type == "depth": 52 | dimension = dataset.attrs["extents"] 53 | depth = np.array(np.frombuffer(decompress(img.tobytes(), dimension[0] * dimension[1] * 2), 54 | dtype=np.float16)).astype(np.float32) 55 | depth = depth.reshape(dimension[0], dimension[1]) * scale 56 | 57 | dump_depth_file = dump_dir / '{:04d}.npy'.format(f_idx) 58 | np.save(dump_depth_file, depth) 59 | elif dt_type == "image": 60 | img = imageio.imread(img.tobytes()) 61 | dump_img_file = dump_dir / '{:04d}.jpg'.format(f_idx) 62 | imageio.imsave(dump_img_file, img) 63 | 64 | dump_cam_file = dump_dir / 'cam.txt' 65 | np.savetxt(dump_cam_file, intrinsics) 66 | poses_file = dump_dir / 'poses.txt' 67 | np.savetxt(poses_file, np.array(poses).reshape(-1, 12), fmt='%.6e') 68 | 69 | if len(dump_dir.files('*.jpg')) < 2: 70 | dump_dir.rmtree() 71 | 72 | 73 | def preparedata(): 74 | num_threads = 1 75 | SUB_DATASET_NAMES = (["rgbd_test", "scenes11_test", "sun3d_test"]) 76 | 77 | dump_root = os.path.join(path, 'test') 78 | if not os.path.isdir(dump_root): 79 | os.mkdir(dump_root) 80 | 81 | if num_threads == 1: 82 | for scene in SUB_DATASET_NAMES: 83 | dump_example(scene) 84 | else: 85 | Parallel(n_jobs=num_threads)(delayed(dump_example)(scene) for scene in SUB_DATASET_NAMES) 86 | 87 | dump_root = Path(dump_root) 88 | subdirs = dump_root.dirs() 89 | subdirs = [subdir.basename() for subdir in subdirs] 90 | subdirs = sorted(subdirs) 91 | with open(dump_root / 'test.txt', 'w') as tf: 92 | for subdir in subdirs: 93 | tf.write('{}\n'.format(subdir)) 94 | 95 | print("Finished Converting Data.") 96 | 97 | 98 | if __name__ == "__main__": 99 | preparedata() 100 | -------------------------------------------------------------------------------- /unimatch/dataloader/depth/prepare_demon_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess 4 | 5 | from joblib import Parallel, delayed 6 | import numpy as np 7 | import imageio 8 | 9 | imageio.plugins.freeimage.download() 10 | from imageio.plugins import freeimage 11 | import h5py 12 | from lz4.block import decompress 13 | import scipy.misc 14 | import cv2 15 | 16 | from path import Path 17 | 18 | path = os.path.join(os.path.dirname(os.path.abspath(__file__))) 19 | 20 | 21 | def dump_example(dataset_name): 22 | print("Converting {:}.h5 ...".format(dataset_name)) 23 | file = h5py.File(os.path.join(path, "traindata", "{:}.h5".format(dataset_name)), "r") 24 | 25 | for (seq_idx, seq_name) in enumerate(file): 26 | if dataset_name == 'scenes11_train': 27 | scale = 0.4 28 | else: 29 | scale = 1 30 | 31 | if ((dataset_name == 'sun3d_train_1.6m_to_infm' and seq_idx == 7) or \ 32 | (dataset_name == 'sun3d_train_0.4m_to_0.8m' and seq_idx == 15) or \ 33 | (dataset_name == 'scenes11_train' and ( 34 | seq_idx == 2758 or seq_idx == 4691 or seq_idx == 7023 or seq_idx == 11157 or seq_idx == 17168 or seq_idx == 19595))): 35 | continue # Skip error files 36 | 37 | print("Processing sequence {:d}/{:d}".format(seq_idx, len(file))) 38 | dump_dir = os.path.join(path, '../train', dataset_name + "_" + "{:05d}".format(seq_idx)) 39 | if not os.path.isdir(dump_dir): 40 | os.mkdir(dump_dir) 41 | dump_dir = 
Path(dump_dir) 42 | sequence = file[seq_name]["frames"]["t0"] 43 | poses = [] 44 | for (f_idx, f_name) in enumerate(sequence): 45 | frame = sequence[f_name] 46 | for dt_type in frame: 47 | dataset = frame[dt_type] 48 | img = dataset[...] 49 | if dt_type == "camera": 50 | if f_idx == 0: 51 | intrinsics = np.array([[img[0], 0, img[3]], [0, img[1], img[4]], [0, 0, 1]]) 52 | pose = np.array( 53 | [[img[5], img[8], img[11], img[14] * scale], [img[6], img[9], img[12], img[15] * scale], 54 | [img[7], img[10], img[13], img[16] * scale]]) 55 | poses.append(pose.tolist()) 56 | elif dt_type == "depth": 57 | dimension = dataset.attrs["extents"] 58 | depth = np.array(np.frombuffer(decompress(img.tobytes(), dimension[0] * dimension[1] * 2), 59 | dtype=np.float16)).astype(np.float32) 60 | depth = depth.reshape(dimension[0], dimension[1]) * scale 61 | 62 | dump_depth_file = dump_dir / '{:04d}.npy'.format(f_idx) 63 | np.save(dump_depth_file, depth) 64 | elif dt_type == "image": 65 | img = imageio.imread(img.tobytes()) 66 | dump_img_file = dump_dir / '{:04d}.jpg'.format(f_idx) 67 | imageio.imsave(dump_img_file, img) 68 | 69 | dump_cam_file = dump_dir / 'cam.txt' 70 | np.savetxt(dump_cam_file, intrinsics) 71 | poses_file = dump_dir / 'poses.txt' 72 | np.savetxt(poses_file, np.array(poses).reshape(-1, 12), fmt='%.6e') 73 | 74 | if len(dump_dir.files('*.jpg')) < 2: 75 | dump_dir.rmtree() 76 | 77 | 78 | def preparedata(): 79 | num_threads = 1 80 | SUB_DATASET_NAMES = ([ 81 | "rgbd_10_to_20_3d_train", "rgbd_10_to_20_handheld_train", "rgbd_10_to_20_simple_train", 82 | "rgbd_20_to_inf_3d_train", "rgbd_20_to_inf_handheld_train", "rgbd_20_to_inf_simple_train", 83 | "sun3d_train_0.01m_to_0.1m", "sun3d_train_0.1m_to_0.2m", "sun3d_train_0.2m_to_0.4m", "sun3d_train_0.4m_to_0.8m", 84 | "sun3d_train_0.8m_to_1.6m", "sun3d_train_1.6m_to_infm", 85 | "scenes11_train", 86 | ]) 87 | 88 | dump_root = os.path.join(path, 'train') 89 | if not os.path.isdir(dump_root): 90 | os.mkdir(dump_root) 91 | 92 | if num_threads == 1: 93 | for scene in SUB_DATASET_NAMES: 94 | dump_example(scene) 95 | else: 96 | Parallel(n_jobs=num_threads)(delayed(dump_example)(scene) for scene in SUB_DATASET_NAMES) 97 | 98 | np.random.seed(8964) 99 | dump_root = Path(dump_root) 100 | subdirs = dump_root.dirs() 101 | canonic_prefixes = set([subdir.basename()[:-2] for subdir in subdirs]) 102 | with open(dump_root / 'train.txt', 'w') as tf: 103 | with open(dump_root / 'val.txt', 'w') as vf: 104 | for pr in canonic_prefixes: 105 | corresponding_dirs = dump_root.dirs('{}*'.format(pr)) 106 | if np.random.random() < 0.1: 107 | for s in corresponding_dirs: 108 | vf.write('{}\n'.format(s.name)) 109 | else: 110 | for s in corresponding_dirs: 111 | tf.write('{}\n'.format(s.name)) 112 | 113 | print("Finished Converting Data.") 114 | 115 | 116 | if __name__ == "__main__": 117 | preparedata() 118 | -------------------------------------------------------------------------------- /unimatch/dataloader/flow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/dataloader/flow/__init__.py -------------------------------------------------------------------------------- /unimatch/dataloader/stereo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/dataloader/stereo/__init__.py 
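For orientation, the `prepare_demon_*.py` scripts above write each DeMoN sequence into its own folder containing numbered `.jpg` frames, matching `.npy` depth maps, a `cam.txt` with the 3x3 intrinsics, and a `poses.txt` with one flattened 3x4 pose per row. A minimal sketch of loading one prepared sequence back might look like the following; the sequence directory name is a placeholder.

```python
import glob
import os

import imageio
import numpy as np

# Placeholder path to a sequence dumped by prepare_demon_train.py / prepare_demon_test.py.
seq_dir = "train/sun3d_train_0.4m_to_0.8m_00000"

intrinsics = np.loadtxt(os.path.join(seq_dir, "cam.txt"))                  # 3x3 camera matrix
poses = np.loadtxt(os.path.join(seq_dir, "poses.txt")).reshape(-1, 3, 4)   # one 3x4 pose per frame

frames = sorted(glob.glob(os.path.join(seq_dir, "*.jpg")))
depths = sorted(glob.glob(os.path.join(seq_dir, "*.npy")))

for img_path, depth_path, pose in zip(frames, depths, poses):
    img = imageio.imread(img_path)   # H x W x 3 uint8 image
    depth = np.load(depth_path)      # H x W float32 depth, already scaled by the prep script
    print(img_path, img.shape, depth.shape, pose.shape)
```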
-------------------------------------------------------------------------------- /unimatch/demo/depth-scannet/color/0048.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/demo/depth-scannet/color/0048.png -------------------------------------------------------------------------------- /unimatch/demo/depth-scannet/color/0054.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/demo/depth-scannet/color/0054.png -------------------------------------------------------------------------------- /unimatch/demo/depth-scannet/color/0060.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/demo/depth-scannet/color/0060.png -------------------------------------------------------------------------------- /unimatch/demo/depth-scannet/color/0066.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/demo/depth-scannet/color/0066.png -------------------------------------------------------------------------------- /unimatch/demo/depth-scannet/intrinsic/intrinsic_depth.txt: -------------------------------------------------------------------------------- 1 | 577.590698 0.000000 318.905426 0.000000 2 | 0.000000 578.729797 242.683609 0.000000 3 | 0.000000 0.000000 1.000000 0.000000 4 | 0.000000 0.000000 0.000000 1.000000 5 | -------------------------------------------------------------------------------- /unimatch/demo/depth-scannet/pose/0048.txt: -------------------------------------------------------------------------------- 1 | 0.703694 -0.367391 0.608144 1.896290 2 | -0.708482 -0.427345 0.561630 2.467417 3 | 0.053549 -0.826075 -0.561010 1.399475 4 | 0.000000 0.000000 0.000000 1.000000 5 | -------------------------------------------------------------------------------- /unimatch/demo/depth-scannet/pose/0054.txt: -------------------------------------------------------------------------------- 1 | 0.750884 -0.329503 0.572363 1.915689 2 | -0.658300 -0.443024 0.608580 2.368469 3 | 0.053042 -0.833760 -0.549572 1.413484 4 | 0.000000 0.000000 0.000000 1.000000 5 | -------------------------------------------------------------------------------- /unimatch/demo/depth-scannet/pose/0060.txt: -------------------------------------------------------------------------------- 1 | 0.776779 -0.282017 0.563098 1.923212 2 | -0.625761 -0.446388 0.639656 2.259246 3 | 0.070966 -0.849237 -0.523221 1.407526 4 | 0.000000 0.000000 0.000000 1.000000 5 | -------------------------------------------------------------------------------- /unimatch/demo/depth-scannet/pose/0066.txt: -------------------------------------------------------------------------------- 1 | 0.794580 -0.275900 0.540852 1.915812 2 | -0.604505 -0.442681 0.662273 2.192295 3 | 0.056703 -0.853178 -0.518529 1.423975 4 | 0.000000 0.000000 0.000000 1.000000 5 | -------------------------------------------------------------------------------- /unimatch/demo/flow-davis/00000.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/demo/flow-davis/00000.jpg -------------------------------------------------------------------------------- /unimatch/demo/flow-davis/00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/demo/flow-davis/00001.jpg -------------------------------------------------------------------------------- /unimatch/demo/flow-davis/00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/demo/flow-davis/00002.jpg -------------------------------------------------------------------------------- /unimatch/demo/kitti.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/demo/kitti.mp4 -------------------------------------------------------------------------------- /unimatch/demo/stereo-middlebury/im0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/demo/stereo-middlebury/im0.png -------------------------------------------------------------------------------- /unimatch/demo/stereo-middlebury/im1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/demo/stereo-middlebury/im1.png -------------------------------------------------------------------------------- /unimatch/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/loss/__init__.py -------------------------------------------------------------------------------- /unimatch/loss/depth_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import numpy as np 4 | 5 | 6 | def compute_errors(gt, pred): 7 | """Computation of error metrics between predicted and ground truth depths 8 | """ 9 | thresh = np.maximum((gt / pred), (pred / gt)) 10 | a1 = (thresh < 1.25).mean() 11 | a2 = (thresh < 1.25 ** 2).mean() 12 | a3 = (thresh < 1.25 ** 3).mean() 13 | 14 | rmse = (gt - pred) ** 2 15 | rmse = np.sqrt(rmse.mean()) 16 | 17 | rmse_log = (np.log(gt) - np.log(pred)) ** 2 18 | rmse_log = np.sqrt(rmse_log.mean()) 19 | 20 | abs_rel = np.mean(np.abs(gt - pred) / gt) 21 | 22 | sq_rel = np.mean(((gt - pred) ** 2) / gt) 23 | 24 | return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3 25 | 26 | 27 | def get_depth_grad_loss(depth_pred, depth_gt, valid, inverse_depth_loss=True): 28 | # default is on inverse depth 29 | # both: [B, H, W] 30 | assert depth_pred.dim() == 3 and depth_gt.dim() == 3 and valid.dim() == 3 31 | 32 | valid = valid > 0.5 33 | valid_x = valid[:, :, :-1] & valid[:, :, 1:] 34 | valid_y = valid[:, :-1, :] & valid[:, 1:, :] 35 | 36 | if valid_x.max() < 0.5 or valid_y.max() < 0.5: # no valid pixel 37 | return 0. 38 | 39 | if inverse_depth_loss: 40 | grad_pred_x = torch.abs(1. / depth_pred[:, :, :-1][valid_x] - 1. 
/ depth_pred[:, :, 1:][valid_x]) 41 | grad_pred_y = torch.abs(1. / depth_pred[:, :-1, :][valid_y] - 1. / depth_pred[:, 1:, :][valid_y]) 42 | 43 | grad_gt_x = torch.abs(1. / depth_gt[:, :, :-1][valid_x] - 1. / depth_gt[:, :, 1:][valid_x]) 44 | grad_gt_y = torch.abs(1. / depth_gt[:, :-1, :][valid_y] - 1. / depth_gt[:, 1:, :][valid_y]) 45 | else: 46 | grad_pred_x = torch.abs((depth_pred[:, :, :-1] - depth_pred[:, :, 1:])[valid_x]) 47 | grad_pred_y = torch.abs((depth_pred[:, :-1, :] - depth_pred[:, 1:, :])[valid_y]) 48 | 49 | grad_gt_x = torch.abs((depth_gt[:, :, :-1] - depth_gt[:, :, 1:])[valid_x]) 50 | grad_gt_y = torch.abs((depth_gt[:, :-1, :] - depth_gt[:, 1:, :])[valid_y]) 51 | 52 | loss_grad_x = torch.abs(grad_pred_x - grad_gt_x).mean() 53 | loss_grad_y = torch.abs(grad_pred_y - grad_gt_y).mean() 54 | 55 | return loss_grad_x + loss_grad_y 56 | 57 | 58 | def depth_grad_loss_func(depth_preds, depth_gt, valid, 59 | inverse_depth_loss=True, 60 | gamma=0.9): 61 | num = len(depth_preds) 62 | loss = 0. 63 | 64 | for i in range(num): 65 | weight = gamma ** (num - i - 1) 66 | loss += weight * get_depth_grad_loss(depth_preds[i], depth_gt, valid, 67 | inverse_depth_loss=inverse_depth_loss) 68 | 69 | return loss 70 | 71 | 72 | def depth_loss_func(depth_preds, depth_gt, valid, gamma=0.9, 73 | ): 74 | """ loss function defined over multiple depth predictions """ 75 | 76 | n_predictions = len(depth_preds) 77 | depth_loss = 0.0 78 | 79 | for i in range(n_predictions): 80 | i_weight = gamma ** (n_predictions - i - 1) 81 | 82 | # inverse depth loss 83 | valid_bool = valid > 0.5 84 | if valid_bool.max() < 0.5: # no valid pixel 85 | i_loss = 0. 86 | else: 87 | i_loss = (1. / depth_preds[i][valid_bool] - 1. / depth_gt[valid_bool]).abs().mean() 88 | 89 | depth_loss += i_weight * i_loss 90 | 91 | return depth_loss 92 | -------------------------------------------------------------------------------- /unimatch/loss/flow_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def flow_loss_func(flow_preds, flow_gt, valid, 5 | gamma=0.9, 6 | max_flow=400, 7 | **kwargs, 8 | ): 9 | n_predictions = len(flow_preds) 10 | flow_loss = 0.0 11 | 12 | # exlude invalid pixels and extremely large diplacements 13 | mag = torch.sum(flow_gt ** 2, dim=1).sqrt() # [B, H, W] 14 | valid = (valid >= 0.5) & (mag < max_flow) 15 | 16 | for i in range(n_predictions): 17 | i_weight = gamma ** (n_predictions - i - 1) 18 | 19 | i_loss = (flow_preds[i] - flow_gt).abs() 20 | 21 | flow_loss += i_weight * (valid[:, None] * i_loss).mean() 22 | 23 | epe = torch.sum((flow_preds[-1] - flow_gt) ** 2, dim=1).sqrt() 24 | 25 | if valid.max() < 0.5: 26 | pass 27 | 28 | epe = epe.view(-1)[valid.view(-1)] 29 | 30 | metrics = { 31 | 'epe': epe.mean().item(), 32 | '1px': (epe > 1).float().mean().item(), 33 | '3px': (epe > 3).float().mean().item(), 34 | '5px': (epe > 5).float().mean().item(), 35 | } 36 | 37 | return flow_loss, metrics 38 | -------------------------------------------------------------------------------- /unimatch/loss/stereo_metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def epe_metric(d_est, d_gt, mask, use_np=False): 6 | d_est, d_gt = d_est[mask], d_gt[mask] 7 | if use_np: 8 | epe = np.mean(np.abs(d_est - d_gt)) 9 | else: 10 | epe = torch.mean(torch.abs(d_est - d_gt)) 11 | 12 | return epe 13 | 14 | 15 | def d1_metric(d_est, d_gt, mask, use_np=False): 16 | d_est, d_gt = d_est[mask], 
d_gt[mask] 17 | if use_np: 18 | e = np.abs(d_gt - d_est) 19 | else: 20 | e = torch.abs(d_gt - d_est) 21 | err_mask = (e > 3) & (e / d_gt > 0.05) 22 | 23 | if use_np: 24 | mean = np.mean(err_mask.astype('float')) 25 | else: 26 | mean = torch.mean(err_mask.float()) 27 | 28 | return mean 29 | 30 | 31 | def bad_pixel_metric(d_est, d_gt, mask, 32 | abs_threshold=10, 33 | rel_threshold=0.1, 34 | use_np=False): 35 | d_est, d_gt = d_est[mask], d_gt[mask] 36 | if use_np: 37 | e = np.abs(d_gt - d_est) 38 | else: 39 | e = torch.abs(d_gt - d_est) 40 | 41 | err_mask = (e > abs_threshold) & (e / torch.maximum(d_gt, torch.ones_like(d_gt)) > rel_threshold) 42 | 43 | if use_np: 44 | mean = np.mean(err_mask.astype('float')) 45 | else: 46 | mean = torch.mean(err_mask.float()) 47 | 48 | return mean 49 | 50 | 51 | def thres_metric(d_est, d_gt, mask, thres, use_np=False): 52 | assert isinstance(thres, (int, float)) 53 | d_est, d_gt = d_est[mask], d_gt[mask] 54 | if use_np: 55 | e = np.abs(d_gt - d_est) 56 | else: 57 | e = torch.abs(d_gt - d_est) 58 | err_mask = e > thres 59 | 60 | if use_np: 61 | mean = np.mean(err_mask.astype('float')) 62 | else: 63 | mean = torch.mean(err_mask.float()) 64 | 65 | return mean 66 | -------------------------------------------------------------------------------- /unimatch/pip_install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | pip install torch==1.9.0+cu102 torchvision==0.10.0+cu102 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html 5 | 6 | pip install imageio==2.9.0 imageio-ffmpeg matplotlib opencv-python pillow scikit-image scipy tensorboard==2.9.1 setuptools==59.5.0 7 | -------------------------------------------------------------------------------- /unimatch/scripts/gmdepth_demo.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # gmdepth-scale1-regrefine1 5 | CUDA_VISIBLE_DEVICES=0 python main_depth.py \ 6 | --inference_dir demo/depth-scannet \ 7 | --output_path output/gmdepth-scale1-regrefine1-scannet \ 8 | --resume pretrained/gmdepth-scale1-regrefine1-resumeflowthings-scannet-90325722.pth \ 9 | --reg_refine \ 10 | --num_reg_refine 1 11 | 12 | # --pred_bidir_depth 13 | 14 | -------------------------------------------------------------------------------- /unimatch/scripts/gmdepth_evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # gmdepth-scale1 5 | CUDA_VISIBLE_DEVICES=0 python main_depth.py \ 6 | --eval \ 7 | --resume pretrained/gmdepth-scale1-resumeflowthings-demon-a2fe127b.pth \ 8 | --val_dataset demon \ 9 | --demon_split scenes11 10 | 11 | 12 | # gmdepth-scale1-regrefine1, this is our final model 13 | CUDA_VISIBLE_DEVICES=0 python main_depth.py \ 14 | --eval \ 15 | --resume pretrained/gmdepth-scale1-regrefine1-resumeflowthings-demon-7c23f230.pth \ 16 | --val_dataset demon \ 17 | --demon_split scenes11 \ 18 | --reg_refine \ 19 | --num_reg_refine 1 20 | 21 | -------------------------------------------------------------------------------- /unimatch/scripts/gmdepth_scale1_regrefine1_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # GMDepth (1/8 feature only), with additional 1 local regression refinement at 1/8 resolution 4 | 5 | # number of gpus for training, please set according to your hardware 6 | # trained on 8x 40GB A100 gpus 7 | NUM_GPUS=8 8 | 9 | 10 
| # scannet 11 | CHECKPOINT_DIR=checkpoints_depth/scannet-gmdepth-scale1-regrefine1-resumeflowthings && \ 12 | mkdir -p ${CHECKPOINT_DIR} && \ 13 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_depth.py \ 14 | --launcher pytorch \ 15 | --checkpoint_dir ${CHECKPOINT_DIR} \ 16 | --resume pretrained/gmdepth-scale1-resumeflowthings-scannet-5d9d7964.pth \ 17 | --no_resume_optimizer \ 18 | --dataset scannet \ 19 | --val_dataset scannet \ 20 | --image_size 480 640 \ 21 | --batch_size 64 \ 22 | --lr 4e-4 \ 23 | --reg_refine \ 24 | --num_reg_refine 1 \ 25 | --summary_freq 100 \ 26 | --val_freq 5000 \ 27 | --save_ckpt_freq 5000 \ 28 | --num_steps 100000 \ 29 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 30 | 31 | 32 | # demon 33 | CHECKPOINT_DIR=checkpoints_depth/demon-gmdepth-scale1-regrefine1-resumeflowthings && \ 34 | mkdir -p ${CHECKPOINT_DIR} && \ 35 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_depth.py \ 36 | --launcher pytorch \ 37 | --checkpoint_dir ${CHECKPOINT_DIR} \ 38 | --resume pretrained/gmdepth-scale1-resumeflowthings-demon-a2fe127b.pth \ 39 | --no_resume_optimizer \ 40 | --dataset demon \ 41 | --val_dataset demon \ 42 | --demon_split rgbd \ 43 | --image_size 448 576 \ 44 | --batch_size 64 \ 45 | --lr 4e-4 \ 46 | --reg_refine \ 47 | --num_reg_refine 1 \ 48 | --summary_freq 100 \ 49 | --val_freq 5000 \ 50 | --save_ckpt_freq 5000 \ 51 | --num_steps 100000 \ 52 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 53 | 54 | 55 | -------------------------------------------------------------------------------- /unimatch/scripts/gmdepth_scale1_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # basic GMDepth without any refinement (1/8 feature only) 4 | 5 | # number of gpus for training, please set according to your hardware 6 | # trained on 8x 40GB A100 gpus 7 | NUM_GPUS=8 8 | 9 | 10 | # scannet (our final model is trained for 100K steps, for ablation, we train for 50K) 11 | # resume flow things model (our ablations are trained from random init) 12 | CHECKPOINT_DIR=checkpoints_depth/scannet-gmdepth-scale1-resumeflowthings && \ 13 | mkdir -p ${CHECKPOINT_DIR} && \ 14 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_depth.py \ 15 | --launcher pytorch \ 16 | --checkpoint_dir ${CHECKPOINT_DIR} \ 17 | --resume pretrained/gmflow-scale1-things-e9887eda.pth \ 18 | --no_resume_optimizer \ 19 | --dataset scannet \ 20 | --val_dataset scannet \ 21 | --image_size 480 640 \ 22 | --batch_size 80 \ 23 | --lr 4e-4 \ 24 | --summary_freq 100 \ 25 | --val_freq 5000 \ 26 | --save_ckpt_freq 5000 \ 27 | --num_steps 100000 \ 28 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 29 | 30 | 31 | # demon, resume flow things model 32 | CHECKPOINT_DIR=checkpoints_depth/demon-gmdepth-scale1-resumeflowthings && \ 33 | mkdir -p ${CHECKPOINT_DIR} && \ 34 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_depth.py \ 35 | --launcher pytorch \ 36 | --checkpoint_dir ${CHECKPOINT_DIR} \ 37 | --resume pretrained/gmflow-scale1-things-e9887eda.pth \ 38 | --no_resume_optimizer \ 39 | --dataset demon \ 40 | --val_dataset demon \ 41 | --demon_split rgbd \ 42 | --image_size 448 576 \ 43 | --batch_size 80 \ 44 | --lr 4e-4 \ 45 | --summary_freq 100 \ 46 | --val_freq 5000 \ 47 | --save_ckpt_freq 5000 \ 48 | --num_steps 100000 \ 49 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 50 | 51 | 52 | 
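The training scripts in this directory all go through `python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} ... --launcher pytorch`, which starts one process per GPU and hands each process its rank. The snippet below is a generic sketch of the standard DDP bootstrap such an entry point performs, not the exact code in `main_depth.py`; whether the rank arrives as a `LOCAL_RANK` environment variable or a `--local_rank` argument depends on the PyTorch version.

```python
import argparse
import os

import torch
import torch.distributed as dist

# Generic sketch: run under `python -m torch.distributed.launch --nproc_per_node=N this_script.py`,
# which sets MASTER_ADDR/MASTER_PORT and the per-process rank consumed by init_process_group.
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=int(os.environ.get("LOCAL_RANK", 0)))
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend="nccl", init_method="env://")

model = torch.nn.Linear(8, 8).cuda()  # placeholder model standing in for GMDepth
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
print(f"rank {dist.get_rank()} / world size {dist.get_world_size()} ready")
```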
-------------------------------------------------------------------------------- /unimatch/scripts/gmflow_demo.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # gmflow-scale2-regrefine6, inference on image dir 5 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \ 6 | --inference_dir demo/flow-davis \ 7 | --resume pretrained/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth \ 8 | --output_path output/gmflow-scale2-regrefine6-davis \ 9 | --padding_factor 32 \ 10 | --upsample_factor 4 \ 11 | --num_scales 2 \ 12 | --attn_splits_list 2 8 \ 13 | --corr_radius_list -1 4 \ 14 | --prop_radius_list -1 1 \ 15 | --reg_refine \ 16 | --num_reg_refine 6 17 | 18 | 19 | # gmflow-scale2-regrefine6, inference on video, save as video 20 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \ 21 | --inference_video demo/kitti.mp4 \ 22 | --resume pretrained/gmflow-scale2-regrefine6-kitti15-25b554d7.pth \ 23 | --output_path output/kitti \ 24 | --padding_factor 32 \ 25 | --upsample_factor 4 \ 26 | --num_scales 2 \ 27 | --attn_splits_list 2 8 \ 28 | --corr_radius_list -1 4 \ 29 | --prop_radius_list -1 1 \ 30 | --reg_refine \ 31 | --num_reg_refine 6 \ 32 | --save_video \ 33 | --concat_flow_img 34 | 35 | 36 | 37 | # gmflow-scale1, inference on image dir 38 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \ 39 | --inference_dir demo/flow-davis \ 40 | --resume pretrained/gmflow-scale1-mixdata-train320x576-4c3a6e9a.pth \ 41 | --output_path output/gmflow-scale1-davis 42 | 43 | # optional predict bidirection flow and forward-backward consistency check 44 | #--pred_bidir_flow 45 | #--fwd_bwd_check 46 | 47 | 48 | # gmflow-scale2, inference on image dir 49 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \ 50 | --inference_dir demo/flow-davis \ 51 | --resume pretrained/gmflow-scale2-mixdata-train320x576-9ff1c094.pth \ 52 | --output_path output/gmflow-scale2-davis \ 53 | --padding_factor 32 \ 54 | --upsample_factor 4 \ 55 | --num_scales 2 \ 56 | --attn_splits_list 2 8 \ 57 | --corr_radius_list -1 4 \ 58 | --prop_radius_list -1 1 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /unimatch/scripts/gmflow_evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # gmflow-scale1 5 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \ 6 | --eval \ 7 | --resume pretrained/gmflow-scale1-things-e9887eda.pth \ 8 | --val_dataset sintel \ 9 | --with_speed_metric 10 | 11 | 12 | # gmflow-scale2 13 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \ 14 | --eval \ 15 | --resume pretrained/gmflow-scale2-things-36579974.pth \ 16 | --val_dataset kitti \ 17 | --padding_factor 32 \ 18 | --upsample_factor 4 \ 19 | --num_scales 2 \ 20 | --attn_splits_list 2 8 \ 21 | --corr_radius_list -1 4 \ 22 | --prop_radius_list -1 1 \ 23 | --with_speed_metric 24 | 25 | 26 | # gmflow-scale2-regrefine6 27 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \ 28 | --eval \ 29 | --resume pretrained/gmflow-scale2-regrefine6-things-776ed612.pth \ 30 | --val_dataset kitti \ 31 | --padding_factor 32 \ 32 | --upsample_factor 4 \ 33 | --num_scales 2 \ 34 | --attn_splits_list 2 8 \ 35 | --corr_radius_list -1 4 \ 36 | --prop_radius_list -1 1 \ 37 | --reg_refine \ 38 | --num_reg_refine 6 \ 39 | --with_speed_metric 40 | 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /unimatch/scripts/gmflow_scale1_train.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # basic GMFlow without any refinement (1/8 feature only) 4 | 5 | # number of gpus for training, please set according to your hardware 6 | # can be trained on 4x 16GB V100 or 2x 32GB V100 or 2x 40GB A100 gpus 7 | NUM_GPUS=4 8 | 9 | # chairs 10 | CHECKPOINT_DIR=checkpoints_flow/chairs-gmflow-scale1 && \ 11 | mkdir -p ${CHECKPOINT_DIR} && \ 12 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 13 | --launcher pytorch \ 14 | --checkpoint_dir ${CHECKPOINT_DIR} \ 15 | --stage chairs \ 16 | --batch_size 16 \ 17 | --val_dataset chairs sintel kitti \ 18 | --lr 4e-4 \ 19 | --image_size 384 512 \ 20 | --padding_factor 16 \ 21 | --upsample_factor 8 \ 22 | --with_speed_metric \ 23 | --val_freq 10000 \ 24 | --save_ckpt_freq 10000 \ 25 | --num_steps 100000 \ 26 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 27 | 28 | # things (our final model is trained for 800K iterations, for ablation study, you can train for 200K) 29 | CHECKPOINT_DIR=checkpoints_flow/things-gmflow-scale1 && \ 30 | mkdir -p ${CHECKPOINT_DIR} && \ 31 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 32 | --launcher pytorch \ 33 | --checkpoint_dir ${CHECKPOINT_DIR} \ 34 | --resume checkpoints_flow/chairs-gmflow-scale1/step_100000.pth \ 35 | --stage things \ 36 | --batch_size 8 \ 37 | --val_dataset things sintel kitti \ 38 | --lr 2e-4 \ 39 | --image_size 384 768 \ 40 | --padding_factor 16 \ 41 | --upsample_factor 8 \ 42 | --with_speed_metric \ 43 | --val_freq 40000 \ 44 | --save_ckpt_freq 50000 \ 45 | --num_steps 800000 \ 46 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 47 | 48 | # a final note: if your training is terminated unexpectedly, you can resume from the latest checkpoint 49 | # an example: resume chairs training 50 | # CHECKPOINT_DIR=checkpoints_flow/chairs-gmflow-scale1 && \ 51 | # mkdir -p ${CHECKPOINT_DIR} && \ 52 | # python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 53 | # --launcher pytorch \ 54 | # --checkpoint_dir ${CHECKPOINT_DIR} \ 55 | # --resume checkpoints_flow/chairs-gmflow-scale1/checkpoint_latest.pth \ 56 | # --stage chairs \ 57 | # --batch_size 16 \ 58 | # --val_dataset chairs sintel kitti \ 59 | # --lr 4e-4 \ 60 | # --image_size 384 512 \ 61 | # --padding_factor 16 \ 62 | # --upsample_factor 8 \ 63 | # --with_speed_metric \ 64 | # --val_freq 10000 \ 65 | # --save_ckpt_freq 10000 \ 66 | # --num_steps 100000 \ 67 | # 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 68 | 69 | 70 | -------------------------------------------------------------------------------- /unimatch/scripts/gmflow_scale2_regrefine6_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # GMFlow with hierarchical matching refinement (1/8 + 1/4 features) 4 | # with additional 6 local regression refinements 5 | 6 | # number of gpus for training, please set according to your hardware 7 | # can be trained on 8x 32G V100 or 8x 40GB A100 gpus 8 | NUM_GPUS=8 9 | 10 | # chairs, resume from scale2 model 11 | CHECKPOINT_DIR=checkpoints_flow/chairs-gmflow-scale2-regrefine6 && \ 12 | mkdir -p ${CHECKPOINT_DIR} && \ 13 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 14 | --launcher pytorch \ 15 | --checkpoint_dir ${CHECKPOINT_DIR} \ 16 | --resume pretrained/gmflow-scale2-chairs-020cc9be.pth \ 17 | 
--no_resume_optimizer \ 18 | --stage chairs \ 19 | --batch_size 16 \ 20 | --val_dataset chairs sintel kitti \ 21 | --lr 4e-4 \ 22 | --image_size 384 512 \ 23 | --padding_factor 32 \ 24 | --upsample_factor 4 \ 25 | --num_scales 2 \ 26 | --attn_splits_list 2 8 \ 27 | --corr_radius_list -1 4 \ 28 | --prop_radius_list -1 1 \ 29 | --reg_refine \ 30 | --num_reg_refine 6 \ 31 | --with_speed_metric \ 32 | --val_freq 10000 \ 33 | --save_ckpt_freq 10000 \ 34 | --num_steps 100000 \ 35 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 36 | 37 | # things 38 | CHECKPOINT_DIR=checkpoints_flow/things-gmflow-scale2-regrefine6 && \ 39 | mkdir -p ${CHECKPOINT_DIR} && \ 40 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 41 | --launcher pytorch \ 42 | --checkpoint_dir ${CHECKPOINT_DIR} \ 43 | --resume checkpoints_flow/chairs-gmflow-scale2-regrefine6/step_100000.pth \ 44 | --stage things \ 45 | --batch_size 8 \ 46 | --val_dataset things sintel kitti \ 47 | --lr 2e-4 \ 48 | --image_size 384 768 \ 49 | --padding_factor 32 \ 50 | --upsample_factor 4 \ 51 | --num_scales 2 \ 52 | --attn_splits_list 2 8 \ 53 | --corr_radius_list -1 4 \ 54 | --prop_radius_list -1 1 \ 55 | --reg_refine \ 56 | --num_reg_refine 6 \ 57 | --with_speed_metric \ 58 | --val_freq 40000 \ 59 | --save_ckpt_freq 50000 \ 60 | --num_steps 800000 \ 61 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 62 | 63 | # sintel, resume from things model 64 | CHECKPOINT_DIR=checkpoints_flow/sintel-gmflow-scale2-regrefine6 && \ 65 | mkdir -p ${CHECKPOINT_DIR} && \ 66 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 67 | --launcher pytorch \ 68 | --checkpoint_dir ${CHECKPOINT_DIR} \ 69 | --resume checkpoints_flow/things-gmflow-scale2-regrefine6/step_800000.pth \ 70 | --stage sintel \ 71 | --batch_size 8 \ 72 | --val_dataset sintel kitti \ 73 | --lr 2e-4 \ 74 | --image_size 320 896 \ 75 | --padding_factor 32 \ 76 | --upsample_factor 4 \ 77 | --num_scales 2 \ 78 | --attn_splits_list 2 8 \ 79 | --corr_radius_list -1 4 \ 80 | --prop_radius_list -1 1 \ 81 | --reg_refine \ 82 | --num_reg_refine 6 \ 83 | --with_speed_metric \ 84 | --val_freq 20000 \ 85 | --save_ckpt_freq 20000 \ 86 | --num_steps 200000 \ 87 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 88 | 89 | 90 | # sintel finetune, resume from sintel model, this is our final model for sintel benchmark submission 91 | CHECKPOINT_DIR=checkpoints_flow/sintel-gmflow-scale2-regrefine6-ft && \ 92 | mkdir -p ${CHECKPOINT_DIR} && \ 93 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 94 | --launcher pytorch \ 95 | --checkpoint_dir ${CHECKPOINT_DIR} \ 96 | --resume checkpoints_flow/sintel-gmflow-scale2-regrefine6/step_200000.pth \ 97 | --stage sintel_ft \ 98 | --batch_size 8 \ 99 | --val_dataset sintel \ 100 | --lr 1e-4 \ 101 | --image_size 416 1024 \ 102 | --padding_factor 32 \ 103 | --upsample_factor 4 \ 104 | --num_scales 2 \ 105 | --attn_splits_list 2 8 \ 106 | --corr_radius_list -1 4 \ 107 | --prop_radius_list -1 1 \ 108 | --reg_refine \ 109 | --num_reg_refine 6 \ 110 | --with_speed_metric \ 111 | --val_freq 1000 \ 112 | --save_ckpt_freq 1000 \ 113 | --num_steps 5000 \ 114 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 115 | 116 | 117 | # vkitti2, resume from things model 118 | CHECKPOINT_DIR=checkpoints_flow/vkitti2-gmflow-scale2-regrefine6 && \ 119 | mkdir -p ${CHECKPOINT_DIR} && \ 120 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py 
\ 121 | --launcher pytorch \ 122 | --checkpoint_dir ${CHECKPOINT_DIR} \ 123 | --resume checkpoints_flow/things-gmflow-scale2-regrefine6/step_800000.pth \ 124 | --stage vkitti2 \ 125 | --batch_size 16 \ 126 | --val_dataset kitti \ 127 | --lr 2e-4 \ 128 | --image_size 320 832 \ 129 | --padding_factor 32 \ 130 | --upsample_factor 4 \ 131 | --num_scales 2 \ 132 | --attn_splits_list 2 8 \ 133 | --corr_radius_list -1 4 \ 134 | --prop_radius_list -1 1 \ 135 | --reg_refine \ 136 | --num_reg_refine 6 \ 137 | --with_speed_metric \ 138 | --val_freq 10000 \ 139 | --save_ckpt_freq 10000 \ 140 | --num_steps 40000 \ 141 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 142 | 143 | 144 | # kitti, resume from vkitti2 model, this is our final model for kitti benchmark submission 145 | CHECKPOINT_DIR=checkpoints_flow/kitti-gmflow-scale2-regrefine6 && \ 146 | mkdir -p ${CHECKPOINT_DIR} && \ 147 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 148 | --launcher pytorch \ 149 | --checkpoint_dir ${CHECKPOINT_DIR} \ 150 | --resume checkpoints_flow/vkitti2-gmflow-scale2-regrefine6/step_040000.pth \ 151 | --stage kitti_mix \ 152 | --batch_size 8 \ 153 | --val_dataset kitti \ 154 | --lr 2e-4 \ 155 | --image_size 352 1216 \ 156 | --padding_factor 32 \ 157 | --upsample_factor 4 \ 158 | --num_scales 2 \ 159 | --attn_splits_list 2 8 \ 160 | --corr_radius_list -1 4 \ 161 | --prop_radius_list -1 1 \ 162 | --reg_refine \ 163 | --num_reg_refine 6 \ 164 | --with_speed_metric \ 165 | --val_freq 5000 \ 166 | --save_ckpt_freq 10000 \ 167 | --num_steps 30000 \ 168 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 169 | 170 | 171 | -------------------------------------------------------------------------------- /unimatch/scripts/gmflow_scale2_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # GMFlow with hierarchical matching refinement (1/8 + 1/4 features) 4 | 5 | # number of gpus for training, please set according to your hardware 6 | # can be trained on 4x 32G V100 or 4x 40GB A100 or 8x 16G V100 gpus 7 | NUM_GPUS=4 8 | 9 | # chairs 10 | CHECKPOINT_DIR=checkpoints_flow/chairs-gmflow-scale2 && \ 11 | mkdir -p ${CHECKPOINT_DIR} && \ 12 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 13 | --launcher pytorch \ 14 | --checkpoint_dir ${CHECKPOINT_DIR} \ 15 | --stage chairs \ 16 | --batch_size 16 \ 17 | --val_dataset chairs sintel kitti \ 18 | --lr 4e-4 \ 19 | --image_size 384 512 \ 20 | --padding_factor 32 \ 21 | --upsample_factor 4 \ 22 | --num_scales 2 \ 23 | --attn_splits_list 2 8 \ 24 | --corr_radius_list -1 4 \ 25 | --prop_radius_list -1 1 \ 26 | --with_speed_metric \ 27 | --val_freq 10000 \ 28 | --save_ckpt_freq 10000 \ 29 | --num_steps 100000 \ 30 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 31 | 32 | # things (our final model is trained for 800K iterations, for ablation study, you can train for 200K) 33 | CHECKPOINT_DIR=checkpoints_flow/things-gmflow-scale2 && \ 34 | mkdir -p ${CHECKPOINT_DIR} && \ 35 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 36 | --launcher pytorch \ 37 | --checkpoint_dir ${CHECKPOINT_DIR} \ 38 | --resume checkpoints_flow/chairs-gmflow-scale2/step_100000.pth \ 39 | --stage things \ 40 | --batch_size 8 \ 41 | --val_dataset things sintel kitti \ 42 | --lr 2e-4 \ 43 | --image_size 384 768 \ 44 | --padding_factor 32 \ 45 | --upsample_factor 4 \ 46 | --num_scales 2 \ 47 | --attn_splits_list 2 8 \ 
48 | --corr_radius_list -1 4 \ 49 | --prop_radius_list -1 1 \ 50 | --with_speed_metric \ 51 | --val_freq 40000 \ 52 | --save_ckpt_freq 50000 \ 53 | --num_steps 800000 \ 54 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 55 | 56 | # sintel 57 | CHECKPOINT_DIR=checkpoints_flow/sintel-gmflow-scale2 && \ 58 | mkdir -p ${CHECKPOINT_DIR} && \ 59 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 60 | --launcher pytorch \ 61 | --checkpoint_dir ${CHECKPOINT_DIR} \ 62 | --resume checkpoints_flow/things-gmflow-scale2/step_800000.pth \ 63 | --stage sintel \ 64 | --batch_size 8 \ 65 | --val_dataset sintel kitti \ 66 | --lr 2e-4 \ 67 | --image_size 320 896 \ 68 | --padding_factor 32 \ 69 | --upsample_factor 4 \ 70 | --num_scales 2 \ 71 | --attn_splits_list 2 8 \ 72 | --corr_radius_list -1 4 \ 73 | --prop_radius_list -1 1 \ 74 | --with_speed_metric \ 75 | --val_freq 20000 \ 76 | --save_ckpt_freq 20000 \ 77 | --num_steps 200000 \ 78 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 79 | 80 | # kitti 81 | CHECKPOINT_DIR=checkpoints_flow/kitti-gmflow-scale2 && \ 82 | mkdir -p ${CHECKPOINT_DIR} && \ 83 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_flow.py \ 84 | --launcher pytorch \ 85 | --checkpoint_dir ${CHECKPOINT_DIR} \ 86 | --resume checkpoints_flow/sintel-gmflow-scale2/step_200000.pth \ 87 | --stage kitti \ 88 | --batch_size 8 \ 89 | --val_dataset kitti \ 90 | --lr 2e-4 \ 91 | --image_size 320 1152 \ 92 | --padding_factor 32 \ 93 | --upsample_factor 4 \ 94 | --num_scales 2 \ 95 | --attn_splits_list 2 8 \ 96 | --corr_radius_list -1 4 \ 97 | --prop_radius_list -1 1 \ 98 | --with_speed_metric \ 99 | --val_freq 10000 \ 100 | --save_ckpt_freq 10000 \ 101 | --num_steps 100000 \ 102 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 103 | 104 | -------------------------------------------------------------------------------- /unimatch/scripts/gmflow_submission.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # generate prediction results for submission on sintel and kitti online servers 5 | 6 | 7 | # submission to sintel 8 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \ 9 | --submission \ 10 | --output_path submission/sintel-gmflow-scale2-regrefine6-sintelft \ 11 | --val_dataset sintel \ 12 | --resume pretrained/gmflow-scale2-regrefine6-sintelft-6e39e2b9.pth \ 13 | --inference_size 416 1024 \ 14 | --padding_factor 32 \ 15 | --upsample_factor 4 \ 16 | --num_scales 2 \ 17 | --attn_splits_list 2 8 \ 18 | --corr_radius_list -1 4 \ 19 | --prop_radius_list -1 1 \ 20 | --reg_refine \ 21 | --num_reg_refine 6 22 | 23 | 24 | # you can also visualize the predictions before submission 25 | #CUDA_VISIBLE_DEVICES=0 python main_flow.py \ 26 | #--submission \ 27 | #--output_path submission/sintel-gmflow-scale2-regrefine6-sintelft-vis \ 28 | #--val_dataset sintel \ 29 | #--resume pretrained/gmflow-scale2-regrefine6-sintelft-6e39e2b9.pth \ 30 | #--inference_size 416 1024 \ 31 | #--save_vis_flow \ 32 | #--no_save_flo \ 33 | #--padding_factor 32 \ 34 | #--upsample_factor 4 \ 35 | #--num_scales 2 \ 36 | #--attn_splits_list 2 8 \ 37 | #--corr_radius_list -1 4 \ 38 | #--prop_radius_list -1 1 \ 39 | #--reg_refine \ 40 | #--num_reg_refine 6 41 | 42 | 43 | # submission to kitti 44 | CUDA_VISIBLE_DEVICES=0 python main_flow.py \ 45 | --submission \ 46 | --output_path submission/kitti-gmflow-scale2-regrefine6 \ 47 | --val_dataset kitti \ 48 | --resume 
pretrained/gmflow-scale2-regrefine6-kitti15-25b554d7.pth \ 49 | --inference_size 352 1216 \ 50 | --padding_factor 32 \ 51 | --upsample_factor 4 \ 52 | --num_scales 2 \ 53 | --attn_splits_list 2 8 \ 54 | --corr_radius_list -1 4 \ 55 | --prop_radius_list -1 1 \ 56 | --reg_refine \ 57 | --num_reg_refine 6 58 | 59 | 60 | -------------------------------------------------------------------------------- /unimatch/scripts/gmstereo_demo.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # gmstereo-scale2-regrefine3 model 5 | CUDA_VISIBLE_DEVICES=0 python main_stereo.py \ 6 | --inference_dir demo/stereo-middlebury \ 7 | --inference_size 1024 1536 \ 8 | --output_path output/gmstereo-scale2-regrefine3-middlebury \ 9 | --resume pretrained/gmstereo-scale2-regrefine3-resumeflowthings-middleburyfthighres-a82bec03.pth \ 10 | --padding_factor 32 \ 11 | --upsample_factor 4 \ 12 | --num_scales 2 \ 13 | --attn_type self_swin2d_cross_swin1d \ 14 | --attn_splits_list 2 8 \ 15 | --corr_radius_list -1 4 \ 16 | --prop_radius_list -1 1 \ 17 | --reg_refine \ 18 | --num_reg_refine 3 19 | 20 | # optionally predict both left and right disparities 21 | #--pred_bidir_disp 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /unimatch/scripts/gmstereo_evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | # gmstereo-scale1 5 | CUDA_VISIBLE_DEVICES=0 python main_stereo.py \ 6 | --eval \ 7 | --resume pretrained/gmstereo-scale1-resumeflowthings-sceneflow-16e38788.pth \ 8 | --val_dataset kitti15 9 | 10 | 11 | # gmstereo-scale2 12 | CUDA_VISIBLE_DEVICES=0 python main_stereo.py \ 13 | --eval \ 14 | --resume pretrained/gmstereo-scale2-resumeflowthings-sceneflow-48020649.pth \ 15 | --val_dataset kitti15 \ 16 | --padding_factor 32 \ 17 | --upsample_factor 4 \ 18 | --num_scales 2 \ 19 | --attn_type self_swin2d_cross_swin1d \ 20 | --attn_splits_list 2 8 \ 21 | --corr_radius_list -1 4 \ 22 | --prop_radius_list -1 1 23 | 24 | 25 | # gmstereo-scale2-regrefine3 26 | CUDA_VISIBLE_DEVICES=0 python main_stereo.py \ 27 | --eval \ 28 | --resume pretrained/gmstereo-scale2-regrefine3-resumeflowthings-sceneflow-f724fee6.pth \ 29 | --val_dataset kitti15 \ 30 | --padding_factor 32 \ 31 | --upsample_factor 4 \ 32 | --num_scales 2 \ 33 | --attn_type self_swin2d_cross_swin1d \ 34 | --attn_splits_list 2 8 \ 35 | --corr_radius_list -1 4 \ 36 | --prop_radius_list -1 1 \ 37 | --reg_refine \ 38 | --num_reg_refine 3 39 | 40 | 41 | -------------------------------------------------------------------------------- /unimatch/scripts/gmstereo_scale1_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # basic GMStereo without any refinement (1/8 feature only) 4 | 5 | # number of gpus for training, please set according to your hardware 6 | # trained on 8x 40GB A100 gpus 7 | NUM_GPUS=8 8 | 9 | # sceneflow (our final model is trained for 100K steps, for ablation, we train for 50K) 10 | # resume flow things model (our ablations are trained from random init) 11 | CHECKPOINT_DIR=checkpoints_stereo/sceneflow-gmstereo-scale1-resumeflowthings && \ 12 | mkdir -p ${CHECKPOINT_DIR} && \ 13 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_stereo.py \ 14 | --launcher pytorch \ 15 | --checkpoint_dir ${CHECKPOINT_DIR} \ 16 | --resume 
pretrained/gmflow-scale1-things-e9887eda.pth \ 17 | --no_resume_optimizer \ 18 | --stage sceneflow \ 19 | --batch_size 64 \ 20 | --val_dataset things kitti15 \ 21 | --img_height 384 \ 22 | --img_width 768 \ 23 | --padding_factor 16 \ 24 | --upsample_factor 8 \ 25 | --attn_type self_swin2d_cross_1d \ 26 | --summary_freq 1000 \ 27 | --val_freq 10000 \ 28 | --save_ckpt_freq 1000 \ 29 | --save_latest_ckpt_freq 1000 \ 30 | --num_steps 100000 \ 31 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 32 | 33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /unimatch/scripts/gmstereo_scale2_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # GMFlow with hierarchical matching refinement (1/8 + 1/4 features) 4 | 5 | # number of gpus for training, please set according to your hardware 6 | # trained on 8x 40GB A100 gpus 7 | NUM_GPUS=8 8 | 9 | # sceneflow 10 | # resume flow things model 11 | CHECKPOINT_DIR=checkpoints_stereo/sceneflow-gmstereo-scale2-resumeflowthings && \ 12 | mkdir -p ${CHECKPOINT_DIR} && \ 13 | python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} --master_port=9989 main_stereo.py \ 14 | --launcher pytorch \ 15 | --checkpoint_dir ${CHECKPOINT_DIR} \ 16 | --resume pretrained/gmflow-scale2-things-36579974.pth \ 17 | --no_resume_optimizer \ 18 | --stage sceneflow \ 19 | --batch_size 32 \ 20 | --val_dataset things kitti15 \ 21 | --img_height 384 \ 22 | --img_width 768 \ 23 | --padding_factor 32 \ 24 | --upsample_factor 4 \ 25 | --num_scales 2 \ 26 | --attn_type self_swin2d_cross_swin1d \ 27 | --attn_splits_list 2 8 \ 28 | --corr_radius_list -1 4 \ 29 | --prop_radius_list -1 1 \ 30 | --summary_freq 100 \ 31 | --val_freq 10000 \ 32 | --save_ckpt_freq 1000 \ 33 | --save_latest_ckpt_freq 1000 \ 34 | --num_steps 100000 \ 35 | 2>&1 | tee -a ${CHECKPOINT_DIR}/train.log 36 | 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /unimatch/scripts/gmstereo_submission.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # generate prediction results for submission on kitti, middlebury and eth3d online servers 4 | 5 | 6 | # submission to kitti 7 | CUDA_VISIBLE_DEVICES=0 python main_stereo.py \ 8 | --submission \ 9 | --val_dataset kitti15 \ 10 | --inference_size 352 1216 \ 11 | --output_path submission/kitti-gmstereo-scale2-regrefine3 \ 12 | --resume pretrained/gmstereo-scale2-regrefine3-resumeflowthings-kitti15-04487ebf.pth \ 13 | --padding_factor 32 \ 14 | --upsample_factor 4 \ 15 | --num_scales 2 \ 16 | --attn_type self_swin2d_cross_swin1d \ 17 | --attn_splits_list 2 8 \ 18 | --corr_radius_list -1 4 \ 19 | --prop_radius_list -1 1 \ 20 | --reg_refine \ 21 | --num_reg_refine 3 22 | 23 | 24 | # submission to middlebury 25 | # set --eth_submission_mode to train and test to generate results on both train and test sets 26 | # use --save_vis_disp to visualize disparity 27 | CUDA_VISIBLE_DEVICES=0 python main_stereo.py \ 28 | --submission \ 29 | --val_dataset middlebury \ 30 | --middlebury_resolution F \ 31 | --middlebury_submission_mode test \ 32 | --inference_size 1024 1536 \ 33 | --output_path submission/middlebury-test-gmstereo-scale2-regrefine3 \ 34 | --resume pretrained/gmstereo-scale2-regrefine3-resumeflowthings-middleburyfthighres-a82bec03.pth \ 35 | --padding_factor 32 \ 36 | --upsample_factor 4 \ 37 | --num_scales 2 \ 38 | --attn_type 
self_swin2d_cross_swin1d \ 39 | --attn_splits_list 2 8 \ 40 | --corr_radius_list -1 4 \ 41 | --prop_radius_list -1 1 \ 42 | --reg_refine \ 43 | --num_reg_refine 3 44 | 45 | 46 | # submission to eth3d 47 | # set --eth_submission_mode to train and test to generate results on both train and test sets 48 | # use --save_vis_disp to visualize disparity 49 | CUDA_VISIBLE_DEVICES=0 python main_stereo.py \ 50 | --submission \ 51 | --eth_submission_mode test \ 52 | --val_dataset eth3d \ 53 | --inference_size 512 768 \ 54 | --output_path submission/eth3d-test-gmstereo-scale2-regrefine3 \ 55 | --resume pretrained/gmstereo-scale2-regrefine3-resumeflowthings-eth3dft-46effc13.pth \ 56 | --padding_factor 32 \ 57 | --upsample_factor 4 \ 58 | --num_scales 2 \ 59 | --attn_type self_swin2d_cross_swin1d \ 60 | --attn_splits_list 2 8 \ 61 | --corr_radius_list -1 4 \ 62 | --prop_radius_list -1 1 \ 63 | --reg_refine \ 64 | --num_reg_refine 3 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /unimatch/unimatch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/logtd/ComfyUI-Veevee/19ff86a9d88cb6a584bca059987fab42a38441a3/unimatch/unimatch/__init__.py -------------------------------------------------------------------------------- /unimatch/unimatch/backbone.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .trident_conv import MultiScaleTridentConv 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1, 8 | ): 9 | super(ResidualBlock, self).__init__() 10 | 11 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, 12 | dilation=dilation, padding=dilation, stride=stride, bias=False) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 14 | dilation=dilation, padding=dilation, bias=False) 15 | self.relu = nn.ReLU(inplace=True) 16 | 17 | self.norm1 = norm_layer(planes) 18 | self.norm2 = norm_layer(planes) 19 | if not stride == 1 or in_planes != planes: 20 | self.norm3 = norm_layer(planes) 21 | 22 | if stride == 1 and in_planes == planes: 23 | self.downsample = None 24 | else: 25 | self.downsample = nn.Sequential( 26 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 27 | 28 | def forward(self, x): 29 | y = x 30 | y = self.relu(self.norm1(self.conv1(y))) 31 | y = self.relu(self.norm2(self.conv2(y))) 32 | 33 | if self.downsample is not None: 34 | x = self.downsample(x) 35 | 36 | return self.relu(x + y) 37 | 38 | 39 | class CNNEncoder(nn.Module): 40 | def __init__(self, output_dim=128, 41 | norm_layer=nn.InstanceNorm2d, 42 | num_output_scales=1, 43 | **kwargs, 44 | ): 45 | super(CNNEncoder, self).__init__() 46 | self.num_branch = num_output_scales 47 | 48 | feature_dims = [64, 96, 128] 49 | 50 | self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2 51 | self.norm1 = norm_layer(feature_dims[0]) 52 | self.relu1 = nn.ReLU(inplace=True) 53 | 54 | self.in_planes = feature_dims[0] 55 | self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2 56 | self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4 57 | 58 | # highest resolution 1/4 or 1/8 59 | stride = 2 if num_output_scales == 1 else 1 60 | self.layer3 = self._make_layer(feature_dims[2], stride=stride, 61 
| norm_layer=norm_layer, 62 | ) # 1/4 or 1/8 63 | 64 | self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0) 65 | 66 | if self.num_branch > 1: 67 | if self.num_branch == 4: 68 | strides = (1, 2, 4, 8) 69 | elif self.num_branch == 3: 70 | strides = (1, 2, 4) 71 | elif self.num_branch == 2: 72 | strides = (1, 2) 73 | else: 74 | raise ValueError 75 | 76 | self.trident_conv = MultiScaleTridentConv(output_dim, output_dim, 77 | kernel_size=3, 78 | strides=strides, 79 | paddings=1, 80 | num_branch=self.num_branch, 81 | ) 82 | 83 | for m in self.modules(): 84 | if isinstance(m, nn.Conv2d): 85 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 86 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 87 | if m.weight is not None: 88 | nn.init.constant_(m.weight, 1) 89 | if m.bias is not None: 90 | nn.init.constant_(m.bias, 0) 91 | 92 | def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d): 93 | layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation) 94 | layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation) 95 | 96 | layers = (layer1, layer2) 97 | 98 | self.in_planes = dim 99 | return nn.Sequential(*layers) 100 | 101 | def forward(self, x): 102 | x = self.conv1(x) 103 | x = self.norm1(x) 104 | x = self.relu1(x) 105 | 106 | x = self.layer1(x) # 1/2 107 | x = self.layer2(x) # 1/4 108 | x = self.layer3(x) # 1/8 or 1/4 109 | 110 | x = self.conv2(x) 111 | 112 | if self.num_branch > 1: 113 | out = self.trident_conv([x] * self.num_branch) # high to low res 114 | else: 115 | out = [x] 116 | 117 | return out 118 | -------------------------------------------------------------------------------- /unimatch/unimatch/position.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py 3 | 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | 8 | 9 | class PositionEmbeddingSine(nn.Module): 10 | """ 11 | This is a more standard version of the position embedding, very similar to the one 12 | used by the Attention is all you need paper, generalized to work on images. 
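    Even feature channels carry sin() and odd channels carry cos() of the (optionally
    normalized) y/x coordinates, at geometrically increasing wavelengths set by
    `temperature`; the first half of the output channels encodes y, the second half x.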
13 | """ 14 | 15 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): 16 | super().__init__() 17 | self.num_pos_feats = num_pos_feats 18 | self.temperature = temperature 19 | self.normalize = normalize 20 | if scale is not None and normalize is False: 21 | raise ValueError("normalize should be True if scale is passed") 22 | if scale is None: 23 | scale = 2 * math.pi 24 | self.scale = scale 25 | 26 | def forward(self, x): 27 | # x = tensor_list.tensors # [B, C, H, W] 28 | # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 29 | b, c, h, w = x.size() 30 | mask = torch.ones((b, h, w), device=x.device) # [B, H, W] 31 | y_embed = mask.cumsum(1, dtype=torch.float32) 32 | x_embed = mask.cumsum(2, dtype=torch.float32) 33 | if self.normalize: 34 | eps = 1e-6 35 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 36 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 37 | 38 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 39 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 40 | 41 | pos_x = x_embed[:, :, :, None] / dim_t 42 | pos_y = y_embed[:, :, :, None] / dim_t 43 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 44 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 45 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 46 | return pos 47 | -------------------------------------------------------------------------------- /unimatch/unimatch/reg_refine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256, 8 | out_dim=2, 9 | ): 10 | super(FlowHead, self).__init__() 11 | 12 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 13 | self.conv2 = nn.Conv2d(hidden_dim, out_dim, 3, padding=1) 14 | self.relu = nn.ReLU(inplace=True) 15 | 16 | def forward(self, x): 17 | out = self.conv2(self.relu(self.conv1(x))) 18 | 19 | return out 20 | 21 | 22 | class SepConvGRU(nn.Module): 23 | def __init__(self, hidden_dim=128, input_dim=192 + 128, 24 | kernel_size=5, 25 | ): 26 | padding = (kernel_size - 1) // 2 27 | 28 | super(SepConvGRU, self).__init__() 29 | self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) 30 | self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) 31 | self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) 32 | 33 | self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) 34 | self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) 35 | self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) 36 | 37 | def forward(self, h, x): 38 | # horizontal 39 | hx = torch.cat([h, x], dim=1) 40 | z = torch.sigmoid(self.convz1(hx)) 41 | r = torch.sigmoid(self.convr1(hx)) 42 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) 43 | h = (1 - z) * h + z * q 44 | 45 | # vertical 46 | hx = torch.cat([h, x], dim=1) 47 | z = torch.sigmoid(self.convz2(hx)) 48 | r = torch.sigmoid(self.convr2(hx)) 49 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) 50 | h = (1 - z) * h + z * q 51 | 52 | return h 53 | 54 | 
55 | class BasicMotionEncoder(nn.Module): 56 | def __init__(self, corr_channels=324, 57 | flow_channels=2, 58 | ): 59 | super(BasicMotionEncoder, self).__init__() 60 | 61 | self.convc1 = nn.Conv2d(corr_channels, 256, 1, padding=0) 62 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 63 | self.convf1 = nn.Conv2d(flow_channels, 128, 7, padding=3) 64 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 65 | self.conv = nn.Conv2d(64 + 192, 128 - flow_channels, 3, padding=1) 66 | 67 | def forward(self, flow, corr): 68 | cor = F.relu(self.convc1(corr)) 69 | cor = F.relu(self.convc2(cor)) 70 | flo = F.relu(self.convf1(flow)) 71 | flo = F.relu(self.convf2(flo)) 72 | 73 | cor_flo = torch.cat([cor, flo], dim=1) 74 | out = F.relu(self.conv(cor_flo)) 75 | return torch.cat([out, flow], dim=1) 76 | 77 | 78 | class BasicUpdateBlock(nn.Module): 79 | def __init__(self, corr_channels=324, 80 | hidden_dim=128, 81 | context_dim=128, 82 | downsample_factor=8, 83 | flow_dim=2, 84 | bilinear_up=False, 85 | ): 86 | super(BasicUpdateBlock, self).__init__() 87 | 88 | self.encoder = BasicMotionEncoder(corr_channels=corr_channels, 89 | flow_channels=flow_dim, 90 | ) 91 | 92 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=context_dim + hidden_dim) 93 | 94 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256, 95 | out_dim=flow_dim, 96 | ) 97 | 98 | if bilinear_up: 99 | self.mask = None 100 | else: 101 | self.mask = nn.Sequential( 102 | nn.Conv2d(hidden_dim, 256, 3, padding=1), 103 | nn.ReLU(inplace=True), 104 | nn.Conv2d(256, downsample_factor ** 2 * 9, 1, padding=0)) 105 | 106 | def forward(self, net, inp, corr, flow): 107 | motion_features = self.encoder(flow, corr) 108 | 109 | inp = torch.cat([inp, motion_features], dim=1) 110 | 111 | net = self.gru(net, inp) 112 | delta_flow = self.flow_head(net) 113 | 114 | if self.mask is not None: 115 | mask = self.mask(net) 116 | else: 117 | mask = None 118 | 119 | return net, mask, delta_flow 120 | -------------------------------------------------------------------------------- /unimatch/unimatch/trident_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | from torch.nn.modules.utils import _pair 8 | 9 | 10 | class MultiScaleTridentConv(nn.Module): 11 | def __init__( 12 | self, 13 | in_channels, 14 | out_channels, 15 | kernel_size, 16 | stride=1, 17 | strides=1, 18 | paddings=0, 19 | dilations=1, 20 | dilation=1, 21 | groups=1, 22 | num_branch=1, 23 | test_branch_idx=-1, 24 | bias=False, 25 | norm=None, 26 | activation=None, 27 | ): 28 | super(MultiScaleTridentConv, self).__init__() 29 | self.in_channels = in_channels 30 | self.out_channels = out_channels 31 | self.kernel_size = _pair(kernel_size) 32 | self.num_branch = num_branch 33 | self.stride = _pair(stride) 34 | self.groups = groups 35 | self.with_bias = bias 36 | self.dilation = dilation 37 | if isinstance(paddings, int): 38 | paddings = [paddings] * self.num_branch 39 | if isinstance(dilations, int): 40 | dilations = [dilations] * self.num_branch 41 | if isinstance(strides, int): 42 | strides = [strides] * self.num_branch 43 | self.paddings = [_pair(padding) for padding in paddings] 44 | self.dilations = [_pair(dilation) for dilation in dilations] 45 | self.strides = [_pair(stride) for stride in strides] 46 | self.test_branch_idx = test_branch_idx 47 | self.norm = norm 48 | self.activation = activation 49 | 50 | assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1 51 | 52 | self.weight = nn.Parameter( 53 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) 54 | ) 55 | if bias: 56 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 57 | else: 58 | self.bias = None 59 | 60 | nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") 61 | if self.bias is not None: 62 | nn.init.constant_(self.bias, 0) 63 | 64 | def forward(self, inputs): 65 | num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 66 | assert len(inputs) == num_branch 67 | 68 | if self.training or self.test_branch_idx == -1: 69 | outputs = [ 70 | F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups) 71 | for input, stride, padding in zip(inputs, self.strides, self.paddings) 72 | ] 73 | else: 74 | outputs = [ 75 | F.conv2d( 76 | inputs[0], 77 | self.weight, 78 | self.bias, 79 | self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1], 80 | self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1], 81 | self.dilation, 82 | self.groups, 83 | ) 84 | ] 85 | 86 | if self.norm is not None: 87 | outputs = [self.norm(x) for x in outputs] 88 | if self.activation is not None: 89 | outputs = [self.activation(x) for x in outputs] 90 | return outputs 91 | -------------------------------------------------------------------------------- /unimatch/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | # https://github.com/open-mmlab/mmcv/blob/7540cf73ac7e5d1e14d0ffbd9b6759e83929ecfc/mmcv/runner/dist_utils.py 3 | 4 | import os 5 | import subprocess 6 | 7 | import torch 8 | import torch.multiprocessing as mp 9 | from torch import distributed as dist 10 | 11 | 12 | def init_dist(launcher, backend='nccl', **kwargs): 13 | if mp.get_start_method(allow_none=True) is None: 14 | mp.set_start_method('spawn') 15 | if launcher == 'pytorch': 16 | _init_dist_pytorch(backend, **kwargs) 17 | elif launcher == 'mpi': 18 | _init_dist_mpi(backend, **kwargs) 19 | elif launcher == 'slurm': 20 | _init_dist_slurm(backend, **kwargs) 21 | else: 22 | raise ValueError(f'Invalid launcher type: {launcher}') 23 | 24 | 25 | def _init_dist_pytorch(backend, **kwargs): 26 | # TODO: use local_rank instead of rank % num_gpus 27 | rank = int(os.environ['RANK']) 28 | num_gpus = torch.cuda.device_count() 29 | torch.cuda.set_device(rank % num_gpus) 30 | dist.init_process_group(backend=backend, **kwargs) 31 | 32 | 33 | def _init_dist_mpi(backend, **kwargs): 34 | # TODO: use local_rank instead of rank % num_gpus 35 | rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 36 | num_gpus = torch.cuda.device_count() 37 | torch.cuda.set_device(rank % num_gpus) 38 | dist.init_process_group(backend=backend, **kwargs) 39 | 40 | 41 | def _init_dist_slurm(backend, port=None): 42 | """Initialize slurm distributed training environment. 43 | If argument ``port`` is not specified, then the master port will be system 44 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 45 | environment variable, then a default port ``29500`` will be used. 46 | Args: 47 | backend (str): Backend of torch.distributed. 48 | port (int, optional): Master port. Defaults to None. 49 | """ 50 | proc_id = int(os.environ['SLURM_PROCID']) 51 | ntasks = int(os.environ['SLURM_NTASKS']) 52 | node_list = os.environ['SLURM_NODELIST'] 53 | num_gpus = torch.cuda.device_count() 54 | torch.cuda.set_device(proc_id % num_gpus) 55 | addr = subprocess.getoutput( 56 | f'scontrol show hostname {node_list} | head -n1') 57 | # specify master port 58 | if port is not None: 59 | os.environ['MASTER_PORT'] = str(port) 60 | elif 'MASTER_PORT' in os.environ: 61 | pass # use MASTER_PORT in the environment variable 62 | else: 63 | # 29500 is torch.distributed default port 64 | os.environ['MASTER_PORT'] = '29500' 65 | # use MASTER_ADDR in the environment variable if it already exists 66 | if 'MASTER_ADDR' not in os.environ: 67 | os.environ['MASTER_ADDR'] = addr 68 | os.environ['WORLD_SIZE'] = str(ntasks) 69 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 70 | os.environ['RANK'] = str(proc_id) 71 | dist.init_process_group(backend=backend) 72 | 73 | 74 | def get_dist_info(): 75 | # if (TORCH_VERSION != 'parrots' 76 | # and digit_version(TORCH_VERSION) < digit_version('1.0')): 77 | # initialized = dist._initialized 78 | # else: 79 | if dist.is_available(): 80 | initialized = dist.is_initialized() 81 | else: 82 | initialized = False 83 | if initialized: 84 | rank = dist.get_rank() 85 | world_size = dist.get_world_size() 86 | else: 87 | rank = 0 88 | world_size = 1 89 | return rank, world_size 90 | 91 | 92 | # from DETR repo 93 | def setup_for_distributed(is_master): 94 | """ 95 | This function disables printing when not in master process 96 | """ 97 | import builtins as __builtin__ 98 | builtin_print = __builtin__.print 99 | 100 | def print(*args, **kwargs): 101 | force = kwargs.pop('force', False) 102 | if is_master or force: 103 | builtin_print(*args, **kwargs) 104 | 105 | 
__builtin__.print = print 106 | -------------------------------------------------------------------------------- /unimatch/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | import cv2 6 | 7 | TAG_CHAR = np.array([202021.25], np.float32) 8 | 9 | 10 | def readFlow(fn): 11 | """ Read .flo file in Middlebury format""" 12 | # Code adapted from: 13 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 14 | 15 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 16 | # print 'fn = %s'%(fn) 17 | with open(fn, 'rb') as f: 18 | magic = np.fromfile(f, np.float32, count=1) 19 | if 202021.25 != magic: 20 | print('Magic number incorrect. Invalid .flo file') 21 | return None 22 | else: 23 | w = np.fromfile(f, np.int32, count=1) 24 | h = np.fromfile(f, np.int32, count=1) 25 | # print 'Reading %d x %d flo file\n' % (w, h) 26 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) 27 | # Reshape testdata into 3D array (columns, rows, bands) 28 | # The reshape here is for visualization, the original code is (w,h,2) 29 | return np.resize(data, (int(h), int(w), 2)) 30 | 31 | 32 | def readPFM(file): 33 | file = open(file, 'rb') 34 | 35 | color = None 36 | width = None 37 | height = None 38 | scale = None 39 | endian = None 40 | 41 | header = file.readline().rstrip() 42 | if header == b'PF': 43 | color = True 44 | elif header == b'Pf': 45 | color = False 46 | else: 47 | raise Exception('Not a PFM file.') 48 | 49 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 50 | if dim_match: 51 | width, height = map(int, dim_match.groups()) 52 | else: 53 | raise Exception('Malformed PFM header.') 54 | 55 | scale = float(file.readline().rstrip()) 56 | if scale < 0: # little-endian 57 | endian = '<' 58 | scale = -scale 59 | else: 60 | endian = '>' # big-endian 61 | 62 | data = np.fromfile(file, endian + 'f') 63 | shape = (height, width, 3) if color else (height, width) 64 | 65 | data = np.reshape(data, shape) 66 | data = np.flipud(data) 67 | return data 68 | 69 | 70 | def writeFlow(filename, uv, v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 
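    File layout written below: a 4-byte magic tag (202021.25), int32 width, int32
    height, then row-major interleaved float32 (u, v) values, one pair per pixel.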
76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert (uv.ndim == 3) 81 | assert (uv.shape[2] == 2) 82 | u = uv[:, :, 0] 83 | v = uv[:, :, 1] 84 | else: 85 | u = uv 86 | 87 | assert (u.shape == v.shape) 88 | height, width = u.shape 89 | f = open(filename, 'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width * nBands)) 96 | tmp[:, np.arange(width) * 2] = u 97 | tmp[:, np.arange(width) * 2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) 104 | flow = flow[:, :, ::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2 ** 15) / 64.0 107 | return flow, valid 108 | 109 | 110 | def readDispKITTI(filename): 111 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 112 | valid = disp > 0.0 113 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 114 | return flow, valid 115 | 116 | 117 | def writeFlowKITTI(filename, uv): 118 | uv = 64.0 * uv + 2 ** 15 119 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 120 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 121 | cv2.imwrite(filename, uv[..., ::-1]) 122 | 123 | 124 | def read_gen(file_name, pil=False): 125 | ext = splitext(file_name)[-1] 126 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 127 | return Image.open(file_name) 128 | elif ext == '.bin' or ext == '.raw': 129 | return np.load(file_name) 130 | elif ext == '.flo': 131 | return readFlow(file_name).astype(np.float32) 132 | elif ext == '.pfm': 133 | flow = readPFM(file_name).astype(np.float32) 134 | if len(flow.shape) == 2: 135 | return flow 136 | else: 137 | return flow[:, :, :-1] 138 | return [] 139 | 140 | 141 | def read_vkitti2_flow(filename): 142 | # In R, flow along x-axis normalized by image width and quantized to [0;2^16 – 1] 143 | # In G, flow along x-axis normalized by image width and quantized to [0;2^16 – 1] 144 | # B = 0 for invalid flow (e.g., sky pixels) 145 | bgr = cv2.imread(filename, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) 146 | h, w, _c = bgr.shape 147 | assert bgr.dtype == np.uint16 and _c == 3 148 | # b == invalid flow flag == 0 for sky or other invalid flow 149 | invalid = bgr[:, :, 0] == 0 150 | # g,r == flow_y,x normalized by height,width and scaled to [0;2**16 – 1] 151 | out_flow = 2.0 / (2 ** 16 - 1.0) * bgr[:, :, 2:0:-1].astype('f4') - 1 # [H, W, 2] 152 | out_flow[..., 0] *= (w - 1) 153 | out_flow[..., 1] *= (h - 1) 154 | 155 | out_flow[invalid] = 0.000001 # invalid as very small value to add supervison on the sky 156 | valid = (np.logical_or(invalid, ~invalid)).astype(np.float32) 157 | 158 | return out_flow, valid 159 | -------------------------------------------------------------------------------- /unimatch/utils/logger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from utils.flow_viz import flow_tensor_to_image 4 | from .visualization import viz_depth_tensor 5 | 6 | 7 | class Logger: 8 | def __init__(self, lr_scheduler, 9 | summary_writer, 10 | summary_freq=100, 11 | start_step=0, 12 | img_mean=None, 13 | img_std=None, 14 | ): 15 | self.lr_scheduler = lr_scheduler 16 | self.total_steps = start_step 17 | self.running_loss = {} 18 | self.summary_writer = summary_writer 19 | self.summary_freq = summary_freq 20 | 21 | 
self.img_mean = img_mean 22 | self.img_std = img_std 23 | 24 | def print_training_status(self, mode='train', is_depth=False): 25 | if is_depth: 26 | print('step: %06d \t loss: %.3f' % (self.total_steps, self.running_loss['total_loss'] / self.summary_freq)) 27 | else: 28 | print('step: %06d \t epe: %.3f' % (self.total_steps, self.running_loss['epe'] / self.summary_freq)) 29 | 30 | for k in self.running_loss: 31 | self.summary_writer.add_scalar(mode + '/' + k, 32 | self.running_loss[k] / self.summary_freq, self.total_steps) 33 | self.running_loss[k] = 0.0 34 | 35 | def lr_summary(self): 36 | lr = self.lr_scheduler.get_last_lr()[0] 37 | self.summary_writer.add_scalar('lr', lr, self.total_steps) 38 | 39 | def add_image_summary(self, img1, img2, flow_preds=None, flow_gt=None, mode='train', 40 | is_depth=False, 41 | ): 42 | if self.total_steps % self.summary_freq == 0: 43 | if is_depth: 44 | img1 = self.unnormalize_image(img1.detach().cpu()) # [3, H, W], range [0, 1] 45 | img2 = self.unnormalize_image(img2.detach().cpu()) 46 | 47 | concat = torch.cat((img1, img2), dim=-1) # [3, H, W*2] 48 | 49 | self.summary_writer.add_image(mode + '/img', concat, self.total_steps) 50 | else: 51 | img_concat = torch.cat((img1[0].detach().cpu(), img2[0].detach().cpu()), dim=-1) 52 | img_concat = img_concat.type(torch.uint8) # convert to uint8 to visualize in tensorboard 53 | 54 | flow_pred = flow_tensor_to_image(flow_preds[-1][0]) 55 | forward_flow_gt = flow_tensor_to_image(flow_gt[0]) 56 | flow_concat = torch.cat((torch.from_numpy(flow_pred), 57 | torch.from_numpy(forward_flow_gt)), dim=-1) 58 | 59 | concat = torch.cat((img_concat, flow_concat), dim=-2) 60 | 61 | self.summary_writer.add_image(mode + '/img_pred_gt', concat, self.total_steps) 62 | 63 | def add_depth_summary(self, depth_pred, depth_gt, mode='train'): 64 | # assert depth_pred.dim() == 2 # [H, W] 65 | if self.total_steps % self.summary_freq == 0 or 'val' in mode: 66 | pred_viz = viz_depth_tensor(depth_pred.detach().cpu()) # [3, H, W] 67 | gt_viz = viz_depth_tensor(depth_gt.detach().cpu()) 68 | 69 | concat = torch.cat((pred_viz, gt_viz), dim=-1) # [3, H, W*2] 70 | 71 | self.summary_writer.add_image(mode + '/depth_pred_gt', concat, self.total_steps) 72 | 73 | def unnormalize_image(self, img): 74 | # img: [3, H, W], used for visualizing image 75 | mean = torch.tensor(self.img_mean).view(3, 1, 1).type_as(img) 76 | std = torch.tensor(self.img_std).view(3, 1, 1).type_as(img) 77 | 78 | out = img * std + mean 79 | 80 | return out 81 | 82 | def push(self, metrics, mode='train', is_depth=False, ): 83 | self.total_steps += 1 84 | 85 | self.lr_summary() 86 | 87 | for key in metrics: 88 | if key not in self.running_loss: 89 | self.running_loss[key] = 0.0 90 | 91 | self.running_loss[key] += metrics[key] 92 | 93 | if self.total_steps % self.summary_freq == 0: 94 | self.print_training_status(mode, is_depth=is_depth) 95 | self.running_loss = {} 96 | 97 | def write_dict(self, results): 98 | for key in results: 99 | tag = key.split('_')[0] 100 | tag = tag + '/' + key 101 | self.summary_writer.add_scalar(tag, results[key], self.total_steps) 102 | 103 | def close(self): 104 | self.summary_writer.close() 105 | -------------------------------------------------------------------------------- /unimatch/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | 5 | 6 | def read_text_lines(filepath): 7 | with open(filepath, 'r') as f: 8 | lines = f.readlines() 9 | lines = [l.rstrip() for l in 
lines] 10 | return lines 11 | 12 | 13 | def check_path(path): 14 | if not os.path.exists(path): 15 | os.makedirs(path, exist_ok=True) # explicitly set exist_ok when multi-processing 16 | 17 | 18 | def save_command(save_path, filename='command_train.txt'): 19 | check_path(save_path) 20 | command = sys.argv 21 | save_file = os.path.join(save_path, filename) 22 | # Save all training commands when resuming training 23 | with open(save_file, 'a') as f: 24 | f.write(' '.join(command)) 25 | f.write('\n\n') 26 | 27 | 28 | def save_args(args, filename='args.json'): 29 | args_dict = vars(args) 30 | check_path(args.checkpoint_dir) 31 | save_path = os.path.join(args.checkpoint_dir, filename) 32 | 33 | # save all training args when resuming training 34 | with open(save_path, 'a') as f: 35 | json.dump(args_dict, f, indent=4, sort_keys=False) 36 | f.write('\n\n') 37 | -------------------------------------------------------------------------------- /unimatch/utils/visualization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import numpy as np 4 | import torchvision.utils as vutils 5 | import cv2 6 | from matplotlib.cm import get_cmap 7 | import matplotlib as mpl 8 | import matplotlib.cm as cm 9 | 10 | 11 | def vis_disparity(disp): 12 | disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0 13 | disp_vis = disp_vis.astype("uint8") 14 | disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO) 15 | 16 | return disp_vis 17 | 18 | 19 | def gen_error_colormap(): 20 | cols = np.array( 21 | [[0 / 3.0, 0.1875 / 3.0, 49, 54, 149], 22 | [0.1875 / 3.0, 0.375 / 3.0, 69, 117, 180], 23 | [0.375 / 3.0, 0.75 / 3.0, 116, 173, 209], 24 | [0.75 / 3.0, 1.5 / 3.0, 171, 217, 233], 25 | [1.5 / 3.0, 3 / 3.0, 224, 243, 248], 26 | [3 / 3.0, 6 / 3.0, 254, 224, 144], 27 | [6 / 3.0, 12 / 3.0, 253, 174, 97], 28 | [12 / 3.0, 24 / 3.0, 244, 109, 67], 29 | [24 / 3.0, 48 / 3.0, 215, 48, 39], 30 | [48 / 3.0, np.inf, 165, 0, 38]], dtype=np.float32) 31 | cols[:, 2: 5] /= 255. 32 | return cols 33 | 34 | 35 | def disp_error_img(D_est_tensor, D_gt_tensor, abs_thres=3., rel_thres=0.05, dilate_radius=1): 36 | D_gt_np = D_gt_tensor.detach().cpu().numpy() 37 | D_est_np = D_est_tensor.detach().cpu().numpy() 38 | B, H, W = D_gt_np.shape 39 | # valid mask 40 | mask = D_gt_np > 0 41 | # error in percentage. When error <= 1, the pixel is valid since <= 3px & 5% 42 | error = np.abs(D_gt_np - D_est_np) 43 | error[np.logical_not(mask)] = 0 44 | error[mask] = np.minimum(error[mask] / abs_thres, (error[mask] / D_gt_np[mask]) / rel_thres) 45 | # get colormap 46 | cols = gen_error_colormap() 47 | # create error image 48 | error_image = np.zeros([B, H, W, 3], dtype=np.float32) 49 | for i in range(cols.shape[0]): 50 | error_image[np.logical_and(error >= cols[i][0], error < cols[i][1])] = cols[i, 2:] 51 | # TODO: imdilate 52 | # error_image = cv2.imdilate(D_err, strel('disk', dilate_radius)); 53 | error_image[np.logical_not(mask)] = 0. 
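# (pixels without ground-truth disparity remain black in the error image)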
54 | # show color tag in the top-left cornor of the image 55 | for i in range(cols.shape[0]): 56 | distance = 20 57 | error_image[:, :10, i * distance:(i + 1) * distance, :] = cols[i, 2:] 58 | 59 | return torch.from_numpy(np.ascontiguousarray(error_image.transpose([0, 3, 1, 2]))) 60 | 61 | 62 | def save_images(logger, mode_tag, images_dict, global_step): 63 | images_dict = tensor2numpy(images_dict) 64 | for tag, values in images_dict.items(): 65 | if not isinstance(values, list) and not isinstance(values, tuple): 66 | values = [values] 67 | for idx, value in enumerate(values): 68 | if len(value.shape) == 3: 69 | value = value[:, np.newaxis, :, :] 70 | value = value[:1] 71 | value = torch.from_numpy(value) 72 | 73 | image_name = '{}/{}'.format(mode_tag, tag) 74 | if len(values) > 1: 75 | image_name = image_name + "_" + str(idx) 76 | logger.add_image(image_name, vutils.make_grid(value, padding=0, nrow=1, normalize=True, scale_each=True), 77 | global_step) 78 | 79 | 80 | def tensor2numpy(var_dict): 81 | for key, vars in var_dict.items(): 82 | if isinstance(vars, np.ndarray): 83 | var_dict[key] = vars 84 | elif isinstance(vars, torch.Tensor): 85 | var_dict[key] = vars.data.cpu().numpy() 86 | else: 87 | raise NotImplementedError("invalid input type for tensor2numpy") 88 | 89 | return var_dict 90 | 91 | 92 | def viz_depth_tensor(disp, return_numpy=False, colormap='plasma'): 93 | # visualize inverse depth 94 | assert isinstance(disp, torch.Tensor) 95 | 96 | disp = disp.numpy() 97 | vmax = np.percentile(disp, 95) 98 | normalizer = mpl.colors.Normalize(vmin=disp.min(), vmax=vmax) 99 | mapper = cm.ScalarMappable(norm=normalizer, cmap=colormap) 100 | colormapped_im = (mapper.to_rgba(disp)[:, :, :3] * 255).astype(np.uint8) # [H, W, 3] 101 | 102 | if return_numpy: 103 | return colormapped_im 104 | 105 | viz = torch.from_numpy(colormapped_im).permute(2, 0, 1) # [3, H, W] 106 | 107 | return viz 108 | -------------------------------------------------------------------------------- /utils/attention_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | 7 | 8 | def reshape_heads_to_batch_dim3(tensor, head_size): 9 | batch_size1, batch_size2, seq_len, dim = tensor.shape 10 | tensor = tensor.reshape(batch_size1, batch_size2, 11 | seq_len, head_size, dim // head_size) 12 | tensor = tensor.permute(0, 3, 1, 2, 4) 13 | return tensor 14 | 15 | 16 | def trajectory_hidden_states(query, 17 | key, 18 | value, 19 | trajectories, 20 | use_old_qk, 21 | extra_options, 22 | n_heads): 23 | if not use_old_qk: 24 | query = value 25 | key = value 26 | # TODO: Hardcoded for SD1.5 27 | _,_, oh, ow = extra_options['original_shape'] 28 | height = 64 # int(value.shape[1]**0.5) 29 | width = height 30 | cond_size = len(extra_options['cond_or_uncond']) 31 | video_length = len(query) // cond_size 32 | 33 | sub_idxs = extra_options.get('ad_params', {}).get('sub_idxs', None) 34 | idx = 0 35 | if sub_idxs is not None: 36 | idx = sub_idxs[0] 37 | 38 | traj_window = trajectories['trajectory_windows'][idx] 39 | if f'traj{height}' not in traj_window: 40 | return value 41 | trajs = traj_window[f'traj{height}'] 42 | traj_mask = traj_window[f'mask{height}'] 43 | 44 | start = -video_length+1 45 | end = trajs.shape[2] 46 | 47 | traj_key_sequence_inds = torch.cat( 48 | [trajs[:, :, 0, :].unsqueeze(-2), trajs[:, :, start:end, :]], dim=-2) 49 | traj_mask = torch.cat([traj_mask[:, :, 
0].unsqueeze(-1), 50 | traj_mask[:, :, start:end]], dim=-1) 51 | 52 | t_inds = traj_key_sequence_inds[:, :, :, 0] 53 | x_inds = traj_key_sequence_inds[:, :, :, 1] 54 | y_inds = traj_key_sequence_inds[:, :, :, 2] 55 | 56 | query_tempo = query.unsqueeze(-2) 57 | _key = rearrange(key, '(b f) (h w) d -> b f h w d', 58 | b=cond_size, h=height) 59 | _value = rearrange(value, '(b f) (h w) d -> b f h w d', 60 | b=cond_size, h=height) 61 | key_tempo = _key[:, t_inds, x_inds, y_inds] 62 | value_tempo = _value[:, t_inds, x_inds, y_inds] 63 | key_tempo = rearrange(key_tempo, 'b f n l d -> (b f) n l d') 64 | value_tempo = rearrange(value_tempo, 'b f n l d -> (b f) n l d') 65 | 66 | traj_mask = rearrange(torch.stack( 67 | [traj_mask] * cond_size), 'b f n l -> (b f) n l') 68 | traj_mask = traj_mask[:, None].repeat( 69 | 1, n_heads, 1, 1).unsqueeze(-2) 70 | attn_bias = torch.zeros_like( 71 | traj_mask, dtype=key_tempo.dtype, device=query.device) # regular zeros_like 72 | attn_bias[~traj_mask] = -torch.inf 73 | 74 | # flow attention 75 | query_tempo = reshape_heads_to_batch_dim3(query_tempo, n_heads) 76 | key_tempo = reshape_heads_to_batch_dim3(key_tempo, n_heads) 77 | value_tempo = reshape_heads_to_batch_dim3(value_tempo, n_heads) 78 | 79 | attn_matrix2 = query_tempo @ key_tempo.transpose(-2, -1) / math.sqrt( 80 | query_tempo.size(-1)) + attn_bias 81 | attn_matrix2 = F.softmax(attn_matrix2, dim=-1) 82 | out = (attn_matrix2@value_tempo).squeeze(-2) 83 | 84 | hidden_states = rearrange(out, 'b k r d -> b r (k d)') 85 | 86 | return hidden_states 87 | 88 | 89 | def select_pivot_indexes(length, batch_size, seed=None): 90 | # Create a new Random object with the given seed 91 | rng = random.Random(seed) 92 | 93 | # Use the seeded Random object to generate the random index 94 | rnd_idx = rng.randint(0, batch_size - 1) 95 | 96 | return [min(i, length-1) for i in range(rnd_idx, length, batch_size)] + [length-1] -------------------------------------------------------------------------------- /utils/batching_utils.py: -------------------------------------------------------------------------------- 1 | 2 | # Adjusted from ADE: https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved 3 | 4 | def create_windows_static_standard(num_frames, context_length, overlap): 5 | windows = [] 6 | if num_frames <= context_length or context_length == 0: 7 | windows.append(list(range(num_frames))) 8 | return windows 9 | # always return the same set of windows 10 | delta = context_length - overlap 11 | for start_idx in range(0, num_frames, delta): 12 | # if past the end of frames, move start_idx back to allow same context_length 13 | ending = start_idx + context_length 14 | if ending >= num_frames: 15 | final_delta = ending - num_frames 16 | final_start_idx = start_idx - final_delta 17 | windows.append( 18 | list(range(final_start_idx, final_start_idx + context_length))) 19 | break 20 | windows.append(list(range(start_idx, start_idx + context_length))) 21 | return windows 22 | -------------------------------------------------------------------------------- /utils/module_utils.py: -------------------------------------------------------------------------------- 1 | def isinstance_str(x: object, cls_name: str): 2 | for _cls in x.__class__.__mro__: 3 | if _cls.__name__ == cls_name: 4 | return True 5 | 6 | return False 7 | -------------------------------------------------------------------------------- /utils/noise_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def 
get_alphacumprod(sigma): 5 | return 1 / ((sigma * sigma) + 1) 6 | 7 | 8 | def add_noise(src_latent, noise, sigma): 9 | alphas_cumprod = get_alphacumprod(sigma) 10 | 11 | sqrt_alpha_prod = alphas_cumprod ** 0.5 12 | sqrt_alpha_prod = sqrt_alpha_prod.flatten() 13 | while len(sqrt_alpha_prod.shape) < len(src_latent.shape): 14 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) 15 | 16 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod) ** 0.5 17 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() 18 | while len(sqrt_one_minus_alpha_prod.shape) < len(src_latent.shape): 19 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) 20 | 21 | noisy_samples = sqrt_alpha_prod * src_latent + sqrt_one_minus_alpha_prod * noise 22 | return noisy_samples 23 | 24 | 25 | def add_noise_test(latents, sigma, noise=None): 26 | alpha_cumprod = 1/ ((sigma * sigma) + 1) 27 | sqrt_alpha_prod = alpha_cumprod ** 0.5 28 | sqrt_one_minus_alpha_prod = (1 - alpha_cumprod) ** 0.5 29 | if noise is None: 30 | generator = torch.Generator(device='cuda') 31 | # generator.manual_seed(0) 32 | noise = torch.empty_like(latents).normal_(generator=generator).to(latents.device) 33 | 34 | return sqrt_alpha_prod * latents + sqrt_one_minus_alpha_prod * noise -------------------------------------------------------------------------------- /vv_defaults.py: -------------------------------------------------------------------------------- 1 | SD1_ATTN_INJ_DEFAULTS = set([]) 2 | for idx in [1,2,3,4,5,6]: 3 | SD1_ATTN_INJ_DEFAULTS.add(('output', idx)) 4 | 5 | SD1_RES_INJ_DEFAULTS = set() 6 | for idx in [3,4,6]: 7 | SD1_RES_INJ_DEFAULTS.add(('output', idx)) 8 | 9 | SD1_FLOW_MAP = set([('input', 0), ('input', 1), ('output', 6), ('output', 7), ('output', 8)]) 10 | 11 | SD1_OUTER_MAP = SD1_FLOW_MAP 12 | SD1_INNER_MAP = set([]) 13 | for i in range(2, 12): 14 | SD1_INNER_MAP.add(('input', i)) 15 | for i in range(6): 16 | SD1_INNER_MAP.add(('output', i)) 17 | 18 | SD_FULL_MAP = set([]) 19 | for i in range(40): 20 | SD_FULL_MAP.add(('input', i)) 21 | SD_FULL_MAP.add(('output', i)) 22 | 23 | SD1_INPUT_MAP = set([]) 24 | SD1_OUTPUT_MAP = set([]) 25 | for i in range(12): 26 | SD1_INPUT_MAP.add(('input', i)) 27 | SD1_OUTPUT_MAP.add(('output', i)) 28 | 29 | # TODO get actual adj upsampler indexes 30 | SDXL_ATTN_INJ_DEFAULTS = set([]) 31 | for idx in [10, 15, 20, 25, 30]: 32 | SDXL_ATTN_INJ_DEFAULTS.add(('output', idx)) 33 | 34 | SDXL_RES_INJ_DEFAULTS = set([]) 35 | for idx in [1,2,4]: 36 | SDXL_RES_INJ_DEFAULTS.add(('output', idx)) 37 | 38 | SDXL_FLOW_MAP = set([]) 39 | for idx in range(36): 40 | SDXL_FLOW_MAP.add(('input', idx)) 41 | SDXL_FLOW_MAP.add(('output', idx)) 42 | 43 | # These are approximate 44 | SDXL_OUTER_MAP = set([]) 45 | for i in range(30, 40): 46 | SDXL_OUTER_MAP.add(('output', i)) 47 | for i in range(5): 48 | SDXL_OUTER_MAP.add(('input', i)) 49 | 50 | SDXL_INNER_MAP =set([]) 51 | for i in range(30): 52 | SDXL_INNER_MAP.add(('output', i)) 53 | for i in range(5, 40): 54 | SDXL_INNER_MAP.add(('input', i)) 55 | 56 | 57 | SDXL_INPUT_MAP = set([]) 58 | SDXL_OUTPUT_MAP = set([]) 59 | for i in range(40): 60 | SDXL_INPUT_MAP.add(('input', i)) 61 | SDXL_OUTER_MAP.add(('output', i)) 62 | 63 | MAP_TYPES = ['none', 'inner', 'outer', 'full', 'input', 'output'] 64 | 65 | SD1_MAPS = { 66 | 'none': set(), 67 | 'inner': SD1_INNER_MAP, 68 | 'outer': SD1_OUTER_MAP, 69 | 'full': SD_FULL_MAP, 70 | 'output': SD1_OUTPUT_MAP, 71 | 'input': SD1_INPUT_MAP, 72 | } 73 | 74 | SDXL_MAPS = { 75 | 'none': set(), 76 | 'inner': SDXL_INNER_MAP, 77 | 
'outer': SDXL_OUTER_MAP, 78 | 'full': SD_FULL_MAP, 79 | 'output': SDXL_OUTPUT_MAP, 80 | 'input': SDXL_INPUT_MAP, 81 | } 82 | --------------------------------------------------------------------------------