├── models ├── scv │ ├── __init__.py │ ├── utils │ │ ├── __init__.py │ │ ├── frame_utils.py │ │ ├── flow_viz.py │ │ └── utils.py │ ├── compute_sparse_correlation.py │ ├── knn.py │ ├── update.py │ ├── sparsenet.py │ └── extractor.py ├── gma │ ├── utils │ │ ├── __init__.py │ │ ├── frame_utils.py │ │ ├── flow_viz.py │ │ └── utils.py │ ├── gma.py │ ├── update.py │ ├── network.py │ ├── corr.py │ └── extractor.py ├── raft │ ├── utils │ │ ├── __init__.py │ │ ├── frame_utils.py │ │ ├── flow_viz.py │ │ └── utils.py │ ├── corrector.py │ └── raft.py └── __init__.py ├── requirements.txt ├── losses ├── __init__.py ├── loss_utils │ ├── selfsup_loss_utils.py │ └── smoothness_loss_utils.py ├── losses.py └── losses_corrections.py ├── runners ├── __init__.py ├── runners.py └── runners_corrections.py ├── datasets ├── __init__.py ├── utils_data.py └── datasets.py ├── augmentations ├── __init__.py └── augmentations_corrections.py ├── utils ├── distances.py ├── metrics.py ├── utils_corrections.py ├── coords_and_warp.py ├── config.py ├── masks_and_occlusions.py ├── utils.py └── argument_parser.py ├── scripts ├── train_baseline.sh ├── eval.sh └── train_brightflow.sh ├── README.md ├── evaluate.py └── train.py /models/scv/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/gma/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/scv/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/raft/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch 2 | torchvision 3 | opencv-python 4 | einops 5 | scikit-image 6 | faiss-gpu 7 | tensorboard -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .losses import Loss 2 | from .losses_corrections import LossCorrections 3 | 4 | Loss = Loss 5 | LossCorrections = LossCorrections -------------------------------------------------------------------------------- /runners/__init__.py: -------------------------------------------------------------------------------- 1 | from .runners import Runner 2 | from .runners_corrections import RunnerCorrection 3 | 4 | Runner = Runner 5 | RunnerCorrection = RunnerCorrection -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from datasets import datasets 2 | 3 | Chairs = datasets.Chairs 4 | KITTI = datasets.KITTI 5 | Sintel = datasets.MpiSintel 6 | HD1K = datasets.HD1K 7 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .raft import raft 2 | from .gma import network 3 | # from .scv import sparsenet 4 | 5 | raft = raft.RAFT 6 | gma = network.RAFTGMA 7 | # scv = sparsenet.SparseNet 
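The `models/__init__.py` and `datasets/__init__.py` files above act as small registries that expose each architecture and dataset under the short names used by the `--model` and `--dataset_train`/`--dataset_test` flags of the shell scripts. A minimal sketch of how such a registry can be consumed (the `get_model`/`get_dataset` helpers below are illustrative stand-ins, not functions of this repository):

```python
# Illustrative sketch only: resolving the names exposed by models/__init__.py
# and datasets/__init__.py (get_model / get_dataset are hypothetical helpers).
import models
import datasets

_MODELS = {'raft': models.raft, 'gma': models.gma}  # 'scv' is commented out above
_DATASETS = {'Chairs': datasets.Chairs, 'KITTI': datasets.KITTI,
             'Sintel': datasets.Sintel, 'HD1K': datasets.HD1K}

def get_model(name, args):
    """Instantiate the optical flow network selected by --model."""
    return _MODELS[name](args)

def get_dataset(name, **kwargs):
    """Instantiate the dataset selected by --dataset_train / --dataset_test."""
    return _DATASETS[name](**kwargs)
```

For example, `get_model('raft', args)` mirrors the `--model raft` setting used in the training scripts.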
-------------------------------------------------------------------------------- /augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | from .augmentations import Augmentor 2 | from .augmentations_corrections import AugmentorCorrections 3 | 4 | Augmentor = Augmentor 5 | AugmentorCorrections = AugmentorCorrections 6 | -------------------------------------------------------------------------------- /utils/distances.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | def robust_l1(x): 8 | """Robust L1 metric.""" 9 | return (x**2 + 0.001**2)**0.5 10 | 11 | 12 | def abs_robust_loss(diff, eps=0.01, q=0.4): 13 | """The so-called robust loss used by DDFlow.""" 14 | return (torch.abs(diff) + eps) ** q 15 | 16 | 17 | def l1(x): 18 | """L1 metric.""" 19 | return torch.norm(x, p=1, dim=-3, keepdim=True) 20 | -------------------------------------------------------------------------------- /scripts/train_baseline.sh: -------------------------------------------------------------------------------- 1 | python train.py \ 2 | --ckpt_dir /path/to/checkpoint/dir/ \ 3 | --log_dir /path/to/log/dir/ \ 4 | --restore_ckpt /path/to/checkpoint.pth \ 5 | --name baseline \ 6 | --dataset_train Sintel \ 7 | --dataset_test Sintel \ 8 | --model raft \ 9 | --batch_size 8 \ 10 | --num_steps 75000 \ 11 | --lr 0.0002 \ 12 | --lr_decay_step 15000 \ 13 | --census_weight_flow 1. \ 14 | --selfsup_starting_step 30000 \ 15 | --selfsup_end_rising_step 37500 \ 16 | --selfsup_weight_max 0.3 \ 17 | --mode flow_only \ -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | ~/miniconda3/envs/brightflow2/bin/python -u evaluate.py \ 2 | --name KITTI \ 3 | --dataset_train KITTI \ 4 | --dataset_test KITTI \ 5 | --model raft \ 6 | --batch_size 1 \ 7 | --num_steps 75000 \ 8 | --lr 0.0002 \ 9 | --lr_decay_step 15000 \ 10 | --crop_size 296 696 \ 11 | --occlusions wang \ 12 | --census_weight_flow 0.5 \ 13 | --unflow_weight_flow 0. \ 14 | --l1_weight_flow 0. \ 15 | --census_weight_correc 0. \ 16 | --unflow_weight_correc 0. \ 17 | --l1_weight_correc 1. \ 18 | --selfsup_starting_step 30000 \ 19 | --selfsup_end_rising_step 37500 \ 20 | --selfsup_weight_max 0.3 \ 21 | --restore_ckpt /path/to/checkpoint.pth \ -------------------------------------------------------------------------------- /scripts/train_brightflow.sh: -------------------------------------------------------------------------------- 1 | python train.py \ 2 | --ckpt_dir /path/to/checkpoint/dir/ \ 3 | --log_dir /path/to/log/dir/ \ 4 | --restore_ckpt /path/to/checkpoint.pth \ 5 | --name brightflow \ 6 | --dataset_train Sintel \ 7 | --dataset_test Sintel \ 8 | --model raft \ 9 | --batch_size 8 \ 10 | --num_steps 75000 \ 11 | --lr 0.0002 \ 12 | --lr_decay_step 15000 \ 13 | --census_weight_flow 1. \ 14 | --l1_weight_correc 1. 
\ 15 | --selfsup_starting_step 30000 \ 16 | --selfsup_end_rising_step 37500 \ 17 | --selfsup_weight_max 0.3 \ 18 | --mode flow_correc \ 19 | --correc_weight 0.1 \ 20 | --correc_in_photo_starting_step 25000 \ 21 | --correc_starting_step 20000 \ 22 | --use_full_size_warping \ 23 | --smart_clamp \ 24 | --occ_in_correc_inputs \ 25 | --sequentially \ 26 | --keep_good_corrections_only \ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BrightFlow (WACV 2023) 2 | This repository contains the official implementation of [BrightFlow: Brightness-Change-Aware Unsupervised Learning of Optical Flow](https://openaccess.thecvf.com/content/WACV2023/html/Marsal_BrightFlow_Brightness-Change-Aware_Unsupervised_Learning_of_Optical_Flow_WACV_2023_paper.html), published at the **IEEE Winter Conference on Applications of Computer Vision (WACV) 2023**. 3 | 4 | ## Requirements 5 | 6 | ``` 7 | pip install -r requirements.txt 8 | ``` 9 | 10 | ## Datasets 11 | 12 | To train/evaluate BrightFlow or the baseline without BrightFlow, please download the required datasets: 13 | * [Sintel](http://sintel.is.tue.mpg.de/) 14 | * [KITTI](https://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) 15 | * [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) 16 | 17 | ## Training 18 | 19 | #### Baseline 20 | 21 | 22 | 23 | ``` 24 | sh scripts/train_baseline.sh 25 | ``` 26 | 27 | #### BrightFlow 28 | 29 | ``` 30 | sh scripts/train_brightflow.sh 31 | ``` 32 | 33 | ## Evaluation 34 | 35 | The checkpoints of trained models are available [here](https://drive.google.com/drive/folders/1r2LrW3svWW1kQ98u2D8CaapHXjrV43zD?usp=drive_link). 36 | 37 | ``` 38 | sh scripts/eval.sh 39 | ``` 40 | 41 | ## Acknowledgements 42 | 43 | We thank the authors of [RAFT](https://github.com/princeton-vl/RAFT/), [GMA](https://github.com/zacjiang/GMA), [SCV](https://github.com/zacjiang/SCV) and [SMURF](https://github.com/google-research/google-research/tree/master/smurf) for their great work and for sharing their code. 44 | 45 | ## Citation 46 | 47 | ``` 48 | @inproceedings{marsal2023brightflow, 49 | title={BrightFlow: Brightness-Change-Aware Unsupervised Learning of Optical Flow}, 50 | author={Marsal, Remi and Chabot, Florian and Loesch, Angelique and Sahbi, Hichem}, 51 | booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision}, 52 | pages={2061--2070}, 53 | year={2023} 54 | } 55 | ``` 56 | 57 | ## License 58 | 59 | This project is under the CeCILL 2.1 license. -------------------------------------------------------------------------------- /models/scv/compute_sparse_correlation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .knn import knn_faiss_raw 3 | from .utils.utils import coords_grid, coords_grid_y_first 4 | 5 | 6 | def normalize_coords(coords, H, W): 7 | """ Normalize coordinates based on feature map shape. coords: [B, 2, N]""" 8 | one = coords.new_tensor(1) 9 | size = torch.stack([one*W, one*H])[None] 10 | center = size / 2 11 | scaling = size.max(1, keepdim=True).values * 0.5 12 | return (coords - center[:, :, None]) / scaling[:, :, None] 13 | 14 | 15 | def compute_sparse_corr_init(fmap1, fmap2, k=32): 16 | """ 17 | Compute a cost volume containing the k-largest hypotheses for each pixel.
18 | Output: corr_mink 19 | """ 20 | B, C, H1, W1 = fmap1.shape 21 | H2, W2 = fmap2.shape[2:] 22 | N = H1 * W1 23 | 24 | fmap1, fmap2 = fmap1.view(B, C, -1), fmap2.view(B, C, -1) 25 | 26 | with torch.no_grad(): 27 | _, indices = knn_faiss_raw(fmap1, fmap2, k) # [B, k, H1*W1] 28 | 29 | indices_coord = indices.unsqueeze(1).expand(-1, 2, -1, -1) # [B, 2, k, H1*W1] 30 | coords0 = coords_grid_y_first(B, H2, W2).view(B, 2, 1, -1).expand(-1, -1, k, -1).to(fmap1.device) # [B, 2, k, H1*W1] 31 | coords1 = coords0.gather(3, indices_coord) # [B, 2, k, H1*W1] 32 | coords1 = coords1 - coords0 33 | 34 | # Append batch index 35 | batch_index = torch.arange(B).view(B, 1, 1, 1).expand(-1, -1, k, N).type_as(coords1) 36 | 37 | # Gather by indices from map2 and compute correlation volume 38 | fmap2 = fmap2.gather(2, indices.view(B, 1, -1).expand(-1, C, -1)).view(B, C, k, N) 39 | me_corr = torch.einsum('bcn,bckn->bkn', fmap1, fmap2).contiguous() / torch.sqrt(torch.tensor(C).float()) # [B, k, H1*W1] 40 | 41 | return me_corr, coords0, coords1, batch_index # coords: [B, 2, k, H1*W1] 42 | 43 | 44 | if __name__ == "__main__": 45 | torch.manual_seed(0) 46 | 47 | for _ in range(100): 48 | fmap1 = torch.randn(8, 256, 92, 124).cuda() 49 | fmap2 = torch.randn(8, 256, 92, 124).cuda() 50 | corr_me = compute_sparse_corr_init(fmap1, fmap2, k=16) 51 | 52 | # corr_dense = corr(fmap1, fmap2) 53 | # corr_max = torch.max(corr_dense, dim=3) 54 | 55 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class Metrics(object): 8 | def __init__(self, args, dataset_name): 9 | self.args = args 10 | if 'kitti' in dataset_name.lower(): 11 | self.data_metric = self.kitti_metrics 12 | elif 'hd1k' in dataset_name.lower(): 13 | self.data_metric = self.kitti_metrics 14 | else: 15 | self.data_metric = self.standard_metrics 16 | 17 | self.flow_metrics = self.epe_f1 18 | 19 | 20 | def epe_f1(self, flow_pred, flow_gt, mask, suffix=''): 21 | ''' Compute EPE and F1 (or %ER) metrics''' 22 | 23 | epe = torch.sum((flow_pred - flow_gt)**2, dim=1, keepdims=True).sqrt() 24 | f1 = torch.logical_and(epe > 3, epe / torch.sum(flow_gt**2, dim=1, keepdims=True).sqrt() > 0.05).float() 25 | epe = (epe * mask).sum(dim=(1, 2, 3)) / mask.sum(dim=(1, 2, 3)) 26 | f1 = (f1 * mask).sum(dim=(1, 2, 3)) / mask.sum(dim=(1, 2, 3)) * 100 27 | 28 | return { 29 | 'epe'+suffix: epe.mean(), 30 | 'f1'+suffix: f1.mean() 31 | } 32 | 33 | def kitti_metrics(self, example, flow_pred): 34 | 35 | flow_occ = example['flow_occ'] 36 | valid_occ = example['valid_occ'] 37 | 38 | flow_noc = example['flow_noc'] 39 | valid_noc = example['valid_noc'] 40 | 41 | metrics_dict = {} 42 | metrics_dict.update(self.flow_metrics(flow_pred, flow_occ, valid_occ, suffix='_occ')) 43 | metrics_dict.update(self.flow_metrics(flow_pred, flow_noc, valid_noc, suffix='_noc')) 44 | 45 | return metrics_dict 46 | 47 | def standard_metrics(self, example, flow_pred): 48 | 49 | flow_gt = example['flow'] 50 | valid = example['valid'] 51 | 52 | metrics_dict = self.flow_metrics(flow_pred, flow_gt, valid) 53 | 54 | return metrics_dict 55 | 56 | def resize_flow(self, flow, new_size): 57 | H, W = flow.size()[2:] 58 | new_H, new_W = new_size 59 | flow = F.interpolate(flow, new_size, mode='bilinear', align_corners=True) 60 | return flow * torch.tensor([new_W/W, new_H/H], 
device=flow.device).reshape(1, 2, 1, 1) 61 | 62 | def __call__(self, example, output_dict, padder): 63 | 64 | metrics = {} 65 | flow_pred = padder.unpad(output_dict['flow_f_pred']) 66 | metrics.update(self.data_metric(example, flow_pred)) 67 | 68 | return metrics 69 | 70 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | import utils.argument_parser as argument_parser 7 | from utils.utils import InputPadder, to_cuda, list_all_gather_without_backprop 8 | import utils.config as cfg 9 | 10 | 11 | class Validation(object): 12 | def __init__(self, args): 13 | self.args = args 14 | self.dataset_test = args.dataset_test 15 | self.input_padder = InputPadder 16 | 17 | 18 | @torch.no_grad() 19 | def validate(self, runner, loader, metrics, total_step=0): 20 | 21 | dict_cumul = {} 22 | 23 | for _, example in enumerate(loader): 24 | 25 | ## Transfer to cuda 26 | to_cuda(example, excluded_keys=['orig_dims']) 27 | 28 | padder = self.input_padder(example['ims'].shape) 29 | example['ims'] = padder.pad(example['ims'].flatten(end_dim=1)).unflatten(dim=0, sizes=(-1, 2)) 30 | 31 | output_dict, _ = runner(example=example, total_step=total_step, val_mode=True) 32 | 33 | metrics_dict = metrics(example, output_dict, padder) 34 | 35 | for key, value in metrics_dict.items(): 36 | value_gathered = list_all_gather_without_backprop(value) 37 | if key in dict_cumul: 38 | dict_cumul[key].extend(value_gathered) 39 | else: 40 | dict_cumul[key] = value_gathered 41 | 42 | for key, values in dict_cumul.items(): 43 | dict_cumul[key] = np.mean(torch.tensor(values).cpu().numpy()) 44 | 45 | return dict_cumul 46 | 47 | 48 | if __name__ == '__main__': 49 | 50 | ## get arguments 51 | args = argument_parser.get_arguments() 52 | args.gpu = 0 53 | 54 | assert args.batch_size == 1 55 | print(args) 56 | 57 | # set random seeds 58 | cfg.configure_random_seed(args.seed) 59 | 60 | ## get dataloaders 61 | test_loader = cfg.get_test_dataloaders(args) 62 | 63 | runner = cfg.get_runner(args, val_mode=True) 64 | print(sum(p.numel() for p in runner.parameters() if p.requires_grad)) 65 | 66 | runner = nn.DataParallel(runner) 67 | runner.cuda() 68 | runner.eval() 69 | ckpt = torch.load(args.restore_ckpt) 70 | ckpt = {(k.replace('photometric_loss_function', 'census_loss') if 'photometric_loss_function' in k else k): v for k, v in ckpt.items()} 71 | missing_keys, unexpected_keys = runner.load_state_dict(ckpt, strict=False) 72 | print('missing_keys:', missing_keys) 73 | print('unexpected_keys:', unexpected_keys) 74 | 75 | metrics = cfg.get_metrics(args) 76 | 77 | validator = Validation(args) 78 | 79 | with torch.no_grad(): 80 | res_dict = validator.validate(runner, test_loader, metrics) 81 | print(res_dict) -------------------------------------------------------------------------------- /losses/loss_utils/selfsup_loss_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from functools import partial 6 | from enum import Enum 7 | 8 | from utils.distances import robust_l1 9 | 10 | def charbonnier(x, y, eps=0.01, q=0.4): 11 | """The so-called robust loss used by DDFlow.""" 12 | return (robust_l1(x, y) + eps) ** q 13 | 14 | def robust_l1(x, y): 15 | """Robust L1 metric.""" 16 | return ((x - y)**2 + 
0.001**2)**0.5 17 | 18 | def huber_charbonnier(x, y, eps=1., q=0.4): 19 | """Home made mix between Charbonnier and Huber loss""" 20 | return ((x - y)**2 + eps)**0.2 21 | 22 | 23 | class Distances(Enum): 24 | L1 = 'l1' 25 | CHARBONNIER = 'charbonnier' 26 | HUBER = 'huber' 27 | CHARBONNIER_HUBER = 'huber_charbonnier' 28 | 29 | 30 | class SelfSupLoss(nn.Module): 31 | def __init__(self, args): 32 | super().__init__() 33 | self.args = args 34 | self.sequence_weights = args.sequence_weight ** torch.arange(args.iters, dtype=torch.float, device=torch.device('cuda:'+str(args.gpu))).view(1, 1, args.iters) 35 | self.selfsup_loss = self.selfsup_loss_sequence_flows 36 | if args.selfsup_distance == 'l1': 37 | self.distance = robust_l1 38 | elif args.selfsup_distance == 'charbonnier': 39 | self.distance = charbonnier 40 | elif args.selfsup_distance == 'huber': 41 | self.distance = partial(F.huber_loss, reduction='none') 42 | elif args.selfsup_distance == 'huber_charbonnier': 43 | self.distance = huber_charbonnier 44 | 45 | def selfsup_transform_flow(self, x): 46 | dims = x.size() 47 | x = x.view(-1, *dims[-3:]) 48 | H, W = dims[-2], dims[-1] 49 | x = x[..., 64:-64, 64:-64] 50 | _, _, H_down, W_down = x.size() 51 | x_up = F.interpolate(x, (H, W), mode='bilinear', align_corners=True) 52 | x_up[:, 0] *= W / W_down 53 | x_up[:, 1] *= H / H_down 54 | return x_up.unflatten(dim=0, sizes=dims[:-3]) 55 | 56 | def selfsup_loss_sequence_flows(self, flows_stud, flows_teacher, masks_eraser, weights=1.): 57 | flows_teacher = self.selfsup_transform_flow(flows_teacher) 58 | flows_teacher = flows_teacher.unsqueeze(2) 59 | masks_eraser = masks_eraser.unsqueeze(2) 60 | selfsup_loss = (self.distance(flows_teacher, flows_stud) * masks_eraser * weights).sum((-1, -2, -3)) / (masks_eraser.sum((-1, -2, -3)) + 1e-6) 61 | return ((selfsup_loss * self.sequence_weights).sum(-1)).mean() 62 | 63 | def selfsup_loss_last_flow(self, flows_stud, flows_teacher, masks_eraser, weights=1.): 64 | flows_teacher = self.selfsup_transform_flow(flows_teacher) 65 | selfsup_loss = (self.distance(flows_teacher, flows_stud[..., 0, :, :, :]) * masks_eraser).sum((-1, -2, -3)) / (masks_eraser.sum((-1, -2, -3)) + 1e-6) 66 | return selfsup_loss.mean() 67 | 68 | def forward(self, flows_stud, flows_teacher, masks_eraser, total_step, weights=1): 69 | return self.selfsup_loss_sequence_flows(flows_stud, flows_teacher, masks_eraser, weights) 70 | -------------------------------------------------------------------------------- /losses/losses.py: -------------------------------------------------------------------------------- 1 | from turtle import forward 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | from utils.coords_and_warp import WarpMulti, WarpFullSizeMulti 8 | from .loss_utils.photometric_loss_utils import PhotometricLossSequential, PhotometricLossParallel 9 | from .loss_utils.smoothness_loss_utils import SmoothnessLoss 10 | from .loss_utils.selfsup_loss_utils import SelfSupLoss 11 | 12 | 13 | class LossBasic(nn.Module): 14 | def __init__(self, args): 15 | super(LossBasic, self).__init__() 16 | self.args = args 17 | self.bwd = torch.tensor([1, 0], device=torch.device('cuda:'+str(args.gpu))) 18 | 19 | self.warp = WarpFullSizeMulti() if args.use_full_size_warping else WarpMulti() 20 | 21 | self.smoothness_loss = SmoothnessLoss(args) 22 | self.smoothness_weight = self.args.smoothness_weight 23 | 24 | self.selfsup_weight = 0. 
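# (added note) selfsup_weight is initialised to 0 here and is raised linearly to
# args.selfsup_weight_max by update_selfsup_weight() below, between
# args.selfsup_starting_step and args.selfsup_end_rising_step.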
25 | self.selfsup_loss = SelfSupLoss(args) 26 | 27 | def update_selfsup_weight(self, total_step): 28 | if total_step >= self.args.selfsup_starting_step: 29 | self.selfsup_weight = min(self.args.selfsup_weight_max, (total_step - self.args.selfsup_starting_step) / \ 30 | (self.args.selfsup_end_rising_step - self.args.selfsup_starting_step) * self.args.selfsup_weight_max) 31 | 32 | 33 | class Loss(LossBasic): 34 | def __init__(self, args): 35 | super(Loss, self).__init__(args) 36 | 37 | self.photometric_loss = PhotometricLossSequential(args) if args.sequentially else PhotometricLossParallel(args) 38 | 39 | def forward(self, example, outputs, total_step): 40 | 41 | self.update_selfsup_weight(total_step) 42 | 43 | loss_dict = {} 44 | 45 | # Reconstruction: warping of images with flows predictions 46 | if self.args.use_full_size_warping: 47 | pad_params = example['pad_params'].int() 48 | orig_dims = example['orig_dims'].int() 49 | example['ims_warp'] = self.warp(example['ims_uncropped'][:, self.bwd], outputs['flows_aug'], pad_params, orig_dims) 50 | else: 51 | example['ims_warp'] = self.warp(example['ims'][:, self.bwd], outputs['coords']) 52 | 53 | # Computation of the photometric loss 54 | loss_photo = self.photometric_loss(example, outputs) 55 | 56 | # Computation of the smoothness loss 57 | loss_smooth = self.smoothness_loss(outputs['flows_aug'], example['ims'], example['masks_eraser'], total_step) 58 | 59 | # Computation of the selfsup loss 60 | loss_selfsup = torch.tensor(0.0, device=torch.device('cuda')) 61 | if total_step >= self.args.selfsup_starting_step: 62 | loss_selfsup += self.selfsup_loss(outputs['flows_stud'], outputs['flows_teacher'], example['masks_eraser_stud'], total_step) 63 | 64 | loss_dict['photo'] = loss_photo 65 | loss_dict['smoothness'] = loss_smooth 66 | loss_dict['selfsup'] = loss_selfsup 67 | loss_dict['mask_sum'] = outputs['masks'][:, :, 0].mean() 68 | loss_dict['mean_abs_flow'] = outputs['flows_aug'][:, 0, 0].abs().mean() 69 | 70 | loss_dict['loss_total'] = loss_photo + self.selfsup_weight * loss_selfsup.clamp(0., 100.) 
+ self.args.smoothness_weight * loss_smooth 71 | 72 | return loss_dict 73 | -------------------------------------------------------------------------------- /losses/loss_utils/smoothness_loss_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from functools import partial 6 | 7 | from utils.distances import robust_l1 8 | 9 | 10 | class SmoothnessLoss(nn.Module): 11 | def __init__(self, args): 12 | super().__init__() 13 | self.args = args 14 | self.sequence_weights = args.sequence_weight ** torch.arange(args.iters, dtype=torch.float, device=torch.device('cuda:'+str(args.gpu))).view(1, 1, args.iters) 15 | self.smooth_const = -150 16 | if args.smoothness_order == 1: 17 | self.flow_grad = self.flow_grad_1st_order 18 | elif args.smoothness_order == 2: 19 | self.flow_grad = self.flow_grad_2nd_order 20 | 21 | def grad(self, img, stride=1): 22 | gx = img[..., :, :-stride] - img[..., :, stride:] # NCHW 23 | gy = img[..., :-stride, :] - img[..., stride:, :] # NCHW 24 | return gx, gy 25 | 26 | def grad_img(self, im, stride): 27 | im_grad_x, im_grad_y = self.grad(im, stride) 28 | im_grad_x = im_grad_x.abs().mean(-3, keepdim=True) 29 | im_grad_y = im_grad_y.abs().mean(-3, keepdim=True) 30 | return im_grad_x, im_grad_y 31 | 32 | def get_smoothness_mask(self, mask, stride=1): 33 | mask_x = mask[..., :-stride] * mask[..., stride:] 34 | mask_y = mask[..., :-stride, :] * mask[..., stride:, :] 35 | return mask_x, mask_y 36 | 37 | def flow_grad_1st_order(self, flows): 38 | return self.grad(flows) 39 | 40 | def flow_grad_2nd_order(self, flows): 41 | flows_grad_x, flows_grad_y = self.grad(flows) 42 | flows_grad_xx, _ = self.grad(flows_grad_x) 43 | _, flows_grad_yy = self.grad(flows_grad_y) 44 | return flows_grad_xx, flows_grad_yy 45 | 46 | def smoothness_loss_last_flow(self, flows, ims, masks): 47 | 48 | ims_grad_x, ims_grad_y = self.grad_img(ims, self.args.smoothness_order) 49 | mask_x, mask_y = self.get_smoothness_mask(masks, stride=self.args.smoothness_order) 50 | 51 | flows_grad_x, flows_grad_y = self.flow_grad(flows[..., 0, :, :, :]) 52 | smoothness_loss = ((torch.exp(self.smooth_const * ims_grad_x.abs().mean(-3, keepdim=True)) * robust_l1(flows_grad_x) * mask_x).sum((-1, -2, -3)) / (mask_x.sum((-1, -2, -3)) + 1e-6) + 53 | (torch.exp(self.smooth_const * ims_grad_y.abs().mean(-3, keepdim=True)) * robust_l1(flows_grad_y) * mask_y).sum((-1, -2, -3)) / (mask_y.sum((-1, -2, -3)) + 1e-6)) / 2 54 | return smoothness_loss.mean() 55 | 56 | def smoothness_loss_sequence_flow(self, flows, ims, masks): 57 | ims_grad_x, ims_grad_y = self.grad_img(ims.unsqueeze(2), self.args.smoothness_order) 58 | 59 | mask_x, mask_y = self.get_smoothness_mask(masks.unsqueeze(2), stride=self.args.smoothness_order) 60 | 61 | flows_grad_x, flows_grad_y = self.flow_grad(flows) 62 | smoothness_loss = ((torch.exp(self.smooth_const * ims_grad_x.abs().mean(-3, keepdim=True)) * robust_l1(flows_grad_x) * mask_x).sum((-1, -2, -3)) / (mask_x.sum((-1, -2, -3)) + 1e-6) + 63 | (torch.exp(self.smooth_const * ims_grad_y.abs().mean(-3, keepdim=True)) * robust_l1(flows_grad_y) * mask_y).sum((-1, -2, -3)) / (mask_y.sum((-1, -2, -3)) + 1e-6)) / 2 64 | return ((smoothness_loss * self.sequence_weights).sum(2)).mean() 65 | 66 | def forward(self, flows, ims, masks, total_step): 67 | return self.smoothness_loss_sequence_flow(flows, ims, masks) 68 | 
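The edge-aware smoothness term implemented above down-weights the flow-gradient penalty wherever the image itself has strong gradients, through the factor exp(-150 · |∇I|). A minimal standalone sketch of the first-order variant, with occlusion masking and per-iteration sequence weighting dropped for clarity (tensor shapes are assumptions: B×2×H×W for the flow, B×3×H×W for the image):

```python
# Minimal sketch of the first-order edge-aware smoothness term used above
# (no masks, no sequence weighting; shapes are assumptions, not repo contracts).
import torch

def edge_aware_smoothness(flow, image, smooth_const=-150.0):
    # Image gradients, averaged over channels -> B x 1 x H x (W-1) and B x 1 x (H-1) x W
    img_gx = (image[..., :, :-1] - image[..., :, 1:]).abs().mean(-3, keepdim=True)
    img_gy = (image[..., :-1, :] - image[..., 1:, :]).abs().mean(-3, keepdim=True)
    # First-order flow gradients
    flow_gx = flow[..., :, :-1] - flow[..., :, 1:]
    flow_gy = flow[..., :-1, :] - flow[..., 1:, :]
    # Charbonnier-like robust penalty, down-weighted at image edges
    robust = lambda x: (x ** 2 + 0.001 ** 2) ** 0.5
    loss_x = (torch.exp(smooth_const * img_gx) * robust(flow_gx)).mean()
    loss_y = (torch.exp(smooth_const * img_gy) * robust(flow_gy)).mean()
    return 0.5 * (loss_x + loss_y)
```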
-------------------------------------------------------------------------------- /models/scv/knn.py: -------------------------------------------------------------------------------- 1 | import faiss 2 | import torch 3 | 4 | res = faiss.StandardGpuResources() 5 | res.setDefaultNullStreamAllDevices() 6 | 7 | 8 | def swig_ptr_from_Tensor(x): 9 | """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """ 10 | assert x.is_contiguous() 11 | 12 | if x.dtype == torch.float32: 13 | return faiss.cast_integer_to_float_ptr(x.storage().data_ptr() + x.storage_offset() * 4) 14 | 15 | if x.dtype == torch.int64: 16 | return faiss.cast_integer_to_idx_t_ptr(x.storage().data_ptr() + x.storage_offset() * 8) 17 | 18 | raise Exception("tensor type not supported: {}".format(x.dtype)) 19 | 20 | 21 | def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None, 22 | metric=faiss.METRIC_L2): 23 | """search xq in xb, without building an index""" 24 | assert xb.device == xq.device 25 | 26 | nq, d = xq.size() 27 | if xq.is_contiguous(): 28 | xq_row_major = True 29 | elif xq.t().is_contiguous(): 30 | xq = xq.t() # I initially wrote xq:t(), Lua is still haunting me :-) 31 | xq_row_major = False 32 | else: 33 | raise TypeError('matrix should be row or column-major') 34 | 35 | xq_ptr = swig_ptr_from_Tensor(xq) 36 | 37 | nb, d2 = xb.size() 38 | assert d2 == d 39 | if xb.is_contiguous(): 40 | xb_row_major = True 41 | elif xb.t().is_contiguous(): 42 | xb = xb.t() 43 | xb_row_major = False 44 | else: 45 | raise TypeError('matrix should be row or column-major') 46 | xb_ptr = swig_ptr_from_Tensor(xb) 47 | 48 | if D is None: 49 | D = torch.empty(nq, k, device=xb.device, dtype=torch.float32) 50 | else: 51 | assert D.shape == (nq, k) 52 | assert D.device == xb.device 53 | 54 | if I is None: 55 | I = torch.empty(nq, k, device=xb.device, dtype=torch.int64) 56 | else: 57 | assert I.shape == (nq, k) 58 | assert I.device == xb.device 59 | 60 | D_ptr = swig_ptr_from_Tensor(D) 61 | I_ptr = swig_ptr_from_Tensor(I) 62 | 63 | args = faiss.GpuDistanceParams() 64 | args.metric = metric 65 | args.k = k 66 | args.dims = d 67 | args.vectors = xb_ptr 68 | args.vectorsRowMajor = xb_row_major 69 | args.numVectors = nb 70 | args.queries = xq_ptr 71 | args.queriesRowMajor = xq_row_major 72 | args.numQueries = nq 73 | args.outDistances = D_ptr 74 | args.outIndices = I_ptr 75 | faiss.bfKnn(res, args) 76 | 77 | return D, I 78 | 79 | 80 | def knn_faiss_raw(fmap1, fmap2, k): 81 | 82 | b, ch, _ = fmap1.shape 83 | 84 | if b == 1: 85 | fmap1 = fmap1.view(ch, -1).t().contiguous() 86 | fmap2 = fmap2.view(ch, -1).t().contiguous() 87 | 88 | dist, indx = search_raw_array_pytorch(res, fmap2, fmap1, k, metric=faiss.METRIC_INNER_PRODUCT) 89 | 90 | dist = dist.t().unsqueeze(0).contiguous() 91 | indx = indx.t().unsqueeze(0).contiguous() 92 | else: 93 | fmap1 = fmap1.view(b, ch, -1).permute(0, 2, 1).contiguous() 94 | fmap2 = fmap2.view(b, ch, -1).permute(0, 2, 1).contiguous() 95 | dist = [] 96 | indx = [] 97 | for i in range(b): 98 | dist_i, indx_i = search_raw_array_pytorch(res, fmap2[i], fmap1[i], k, metric=faiss.METRIC_INNER_PRODUCT) 99 | dist_i = dist_i.t().unsqueeze(0).contiguous() 100 | indx_i = indx_i.t().unsqueeze(0).contiguous() 101 | dist.append(dist_i) 102 | indx.append(indx_i) 103 | dist = torch.cat(dist, dim=0) 104 | indx = torch.cat(indx, dim=0) 105 | return dist, indx 106 | -------------------------------------------------------------------------------- /runners/runners.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from functools import partial 6 | 7 | from utils.utils import InputPadder 8 | from utils.coords_and_warp import Coords 9 | from utils.masks_and_occlusions import Occlusions 10 | from losses import Loss 11 | 12 | 13 | class BasicRunner(nn.Module): 14 | def __init__(self, model, args): 15 | super(BasicRunner, self).__init__() 16 | self.args = args 17 | 18 | self.model = model(args) 19 | self.get_coords = Coords() 20 | self.get_occ_mask = Occlusions(args.occlusions, args.use_full_size_warping) 21 | 22 | self.input_padder = InputPadder 23 | 24 | self.fwd = torch.tensor([0, 1], device=torch.device('cuda:'+str(args.gpu))) 25 | self.bwd = torch.tensor([1, 0], device=torch.device('cuda:'+str(args.gpu))) 26 | 27 | 28 | @torch.no_grad() 29 | def run_inference_step(self, example, fwd_bwd, suffix='_no_suffix_provided'): 30 | return self.model(example['ims'], return_last_flow_only=True, fwd_bwd=fwd_bwd, suffix=suffix) 31 | 32 | def forward(self, example, total_step=None, val_mode=False): 33 | 34 | if val_mode: 35 | outputs = self.run_inference_step(example, fwd_bwd=False, suffix='_pred') 36 | 37 | return outputs, {} 38 | 39 | else: 40 | if self.training: 41 | 42 | outputs = self.run_training_step(example, total_step) 43 | loss_dict = self.loss(example, outputs, total_step) 44 | 45 | return outputs, loss_dict 46 | 47 | else: 48 | with torch.no_grad(): 49 | 50 | padder = self.input_padder(example['ims'].shape) 51 | example['ims'] = padder.pad(example['ims'].flatten(end_dim=1)).unflatten(dim=0, sizes=(-1, 2)) 52 | 53 | outputs = self.run_training_step(example, total_step) 54 | 55 | outputs['flow_f_pred'] = outputs['flows_aug'][:, 0, 0] 56 | 57 | loss_dict = self.loss(example, outputs, total_step) 58 | 59 | return outputs, loss_dict 60 | 61 | 62 | class Runner(BasicRunner): 63 | def __init__(self, model, args): 64 | super(Runner, self).__init__(model, args) 65 | 66 | self.loss = Loss(args) 67 | 68 | def run_training_step(self, example, total_step=None): 69 | 70 | pad_params = example['pad_params'].int() if self.args.use_full_size_warping else None 71 | 72 | # Forward pass of augmented images in the optical flow model 73 | outputs = self.model(example['ims_aug'], fwd_bwd=True, suffix='_aug') 74 | 75 | # Absolute coordinate computations 76 | outputs['coords'] = self.get_coords(outputs['flows_aug']) 77 | 78 | # Compute occlusion mask, boundary (out) mask with and without full-size wapring 79 | outputs['masks'], outputs['outs'], outputs['outs_full_size'] = \ 80 | self.get_occ_mask(outputs['flows_aug'], outputs['coords'], 81 | outputs['flows_aug'][:, self.bwd], outputs['coords'][:, self.bwd], 82 | example['masks_eraser'], pad_params=pad_params) 83 | 84 | if total_step >= self.args.selfsup_starting_step: 85 | # Forward pass of non-augmented images in the optical flow model to get teacher flow 86 | outputs.update(self.run_inference_step(example, fwd_bwd=True, suffix='_teacher')) 87 | # Forward pass of cropped augmented images in the optical flow model to get student flow 88 | outputs.update(self.model(example['ims_aug_stud'], fwd_bwd=True, suffix='_stud')) 89 | 90 | return outputs 91 | 92 | 93 | -------------------------------------------------------------------------------- /utils/utils_corrections.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 
import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def apply_corrections(img, correc, clamp_corrected_images=False): 8 | ''' Rescale corrections between 0 and 255 then apply corrections ''' 9 | img_correc = img + correc*255 10 | if clamp_corrected_images: 11 | return torch.clamp(img_correc, 0, 255) 12 | else: 13 | return img_correc 14 | 15 | 16 | def apply_corrections_uncropped(img_uncropped, correc, pad_params, clamp_corrected_images=False): 17 | ''' Rescale corrections between 0 and 255 then apply corrections on an uncropped image when full-size warping is used ''' 18 | dims = correc.shape 19 | H, W = dims [-2:] 20 | H_uncropped, W_uncropped = img_uncropped.size()[-2:] 21 | img_correc = torch.empty_like(img_uncropped) 22 | for b in range(dims[0]): 23 | pad_left, _, pad_top, _ = pad_params[b].data 24 | correc_pad = F.pad(correc[b:b+1] * 255, (pad_left, W_uncropped - W - pad_left, pad_top, H_uncropped - H - pad_top)) 25 | img_correc[b:b+1] = correc_pad + img_uncropped[b:b+1] 26 | 27 | if clamp_corrected_images: 28 | return torch.clamp(img_correc, 0, 255) 29 | else: 30 | return img_correc 31 | 32 | 33 | @torch.no_grad() 34 | def get_good_correction(im1, im2_warp, im2_warp_correc): 35 | '''Get a binary mask of the well-estimated corrections relative to L1-norm''' 36 | im1_diff = (im1 - im2_warp).abs().mean(dim=-3, keepdim=True) 37 | im1_diff_correc = (im1 - im2_warp_correc).abs().mean(dim=-3, keepdim=True) 38 | good_correc = im1_diff > im1_diff_correc 39 | return good_correc.float() 40 | 41 | 42 | @torch.no_grad() 43 | def get_best_index(coords, good_corrections): 44 | ''' Unwarp the good correction mask ''' 45 | B, _, H, W = coords.size() 46 | coords_floor = torch.floor(coords) 47 | coords_offset = coords - coords_floor 48 | coords_floor = coords_floor.long() 49 | 50 | idx_batch_offset = torch.arange(B, device=coords.device).view(B, 1, 1).expand(-1, H, W) * H * W 51 | 52 | coords_floor_flattened = coords_floor.permute(0, 2, 3, 1).reshape(-1, 2) 53 | coords_offset_flattened = coords_offset.permute(0, 2, 3, 1).reshape(-1, 2) 54 | idx_batch_offset_flattened = idx_batch_offset.reshape(-1) 55 | good_corrections = good_corrections.reshape(-1) 56 | 57 | # Initialize results. 58 | idxs_list = [] 59 | weights_list = [] 60 | 61 | # Loop over differences di and dj to the four neighboring pixels. 62 | for di in range(2): 63 | for dj in range(2): 64 | idxs_i = coords_floor_flattened[:, 1] + di 65 | idxs_j = coords_floor_flattened[:, 0] + dj 66 | 67 | 68 | idxs = idx_batch_offset_flattened + idxs_i * W + idxs_j 69 | 70 | mask = torch.where(torch.logical_and(good_corrections.bool(), torch.logical_and( 71 | torch.logical_and(idxs_i >= 0, idxs_i < H), 72 | torch.logical_and(idxs_j >= 0, idxs_j < W))))[0] 73 | 74 | valid_idxs = torch.gather(idxs, 0, mask) 75 | valid_offsets = torch.stack([torch.gather(coords_offset_flattened[:, 0], 0, mask), torch.gather(coords_offset_flattened[:, 1], 0, mask)], 1) 76 | 77 | # Compute weights according to bilinear interpolation. 78 | weights_i = (1. - di) - (-1)**di * valid_offsets[:, 1] 79 | weights_j = (1. - dj) - (-1)**dj * valid_offsets[:, 0] 80 | weights = weights_i * weights_j 81 | 82 | # Append indices and weights to the corresponding list. 83 | idxs_list.append(valid_idxs) 84 | weights_list.append(weights) 85 | 86 | # Concatenate everything. 
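# (added note) The four bilinear taps gathered above are splatted into a flat
# B*H*W accumulator by the scatter_add below; any target pixel that receives a
# non-zero total weight from a "good correction" source pixel becomes True in
# the returned unwarped mask.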
87 | idxs = torch.cat(idxs_list, 0) 88 | weights = torch.cat(weights_list, 0) 89 | 90 | counts = torch.zeros(B * H * W, device=coords.device).scatter_add(0, idxs, weights).reshape(B, 1, H, W) 91 | 92 | return counts > 0. -------------------------------------------------------------------------------- /models/gma/gma.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from einops import rearrange 4 | 5 | 6 | class RelPosEmb(nn.Module): 7 | def __init__( 8 | self, 9 | max_pos_size, 10 | dim_head 11 | ): 12 | super().__init__() 13 | self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head) 14 | self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head) 15 | 16 | deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange(max_pos_size).view(-1, 1) 17 | rel_ind = deltas + max_pos_size - 1 18 | self.register_buffer('rel_ind', rel_ind) 19 | 20 | def forward(self, q): 21 | batch, heads, h, w, c = q.shape 22 | height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1)) 23 | width_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1)) 24 | 25 | height_emb = rearrange(height_emb, '(x u) d -> x u () d', x=h) 26 | width_emb = rearrange(width_emb, '(y v) d -> y () v d', y=w) 27 | 28 | height_score = einsum('b h x y d, x u v d -> b h x y u v', q, height_emb) 29 | width_score = einsum('b h x y d, y u v d -> b h x y u v', q, width_emb) 30 | 31 | return height_score + width_score 32 | 33 | 34 | class Attention(nn.Module): 35 | def __init__( 36 | self, 37 | *, 38 | args, 39 | dim, 40 | max_pos_size = 100, 41 | heads = 4, 42 | dim_head = 128, 43 | ): 44 | super().__init__() 45 | self.args = args 46 | self.heads = heads 47 | self.scale = dim_head ** -0.5 48 | inner_dim = heads * dim_head 49 | 50 | self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False) 51 | 52 | if self.args.position_only or self.args.position_and_content: 53 | self.pos_emb = RelPosEmb(max_pos_size, dim_head) 54 | 55 | def forward(self, fmap): 56 | heads, b, c, h, w = self.heads, *fmap.shape 57 | 58 | q, k = self.to_qk(fmap).chunk(2, dim=1) 59 | 60 | q, k = map(lambda t: rearrange(t, 'b (h d) x y -> b h x y d', h=heads), (q, k)) 61 | q = self.scale * q 62 | 63 | if self.args.position_only: 64 | sim = self.pos_emb(q) 65 | 66 | elif self.args.position_and_content: 67 | sim_content = einsum('b h x y d, b h u v d -> b h x y u v', q, k) 68 | sim_pos = self.pos_emb(q) 69 | sim = sim_content + sim_pos 70 | 71 | else: 72 | sim = einsum('b h x y d, b h u v d -> b h x y u v', q, k) 73 | 74 | sim = rearrange(sim, 'b h x y u v -> b h (x y) (u v)') 75 | attn = sim.softmax(dim=-1) 76 | 77 | return attn 78 | 79 | 80 | class Aggregate(nn.Module): 81 | def __init__( 82 | self, 83 | args, 84 | dim, 85 | heads = 4, 86 | dim_head = 128, 87 | ): 88 | super().__init__() 89 | self.args = args 90 | self.heads = heads 91 | self.scale = dim_head ** -0.5 92 | inner_dim = heads * dim_head 93 | 94 | self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False) 95 | 96 | self.gamma = nn.Parameter(torch.zeros(1)) 97 | 98 | if dim != inner_dim: 99 | self.project = nn.Conv2d(inner_dim, dim, 1, bias=False) 100 | else: 101 | self.project = None 102 | 103 | def forward(self, attn, fmap): 104 | heads, b, c, h, w = self.heads, *fmap.shape 105 | 106 | v = self.to_v(fmap) 107 | v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads) 108 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 109 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w).contiguous() 110 | 111 | if 
self.project is not None: 112 | out = self.project(out) 113 | 114 | out = fmap + self.gamma * out 115 | 116 | return out 117 | 118 | 119 | if __name__ == "__main__": 120 | att = Attention(dim=128, heads=1) 121 | fmap = torch.randn(2, 128, 40, 90) 122 | out = att(fmap) 123 | 124 | print(out.shape) 125 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | import utils.argument_parser as argument_parser 7 | import utils.config as cfg 8 | from utils.utils import to_cuda, Logging, stack_all_gather_without_backprop, CheckpointSavior 9 | from evaluate import Validation 10 | 11 | import torch.multiprocessing as mp 12 | 13 | 14 | def train(gpu, args): 15 | print(gpu) 16 | 17 | # set random seeds 18 | cfg.configure_random_seed(args.seed, gpu=gpu) 19 | 20 | args.gpu = gpu 21 | 22 | torch.distributed.init_process_group( 23 | backend='nccl', 24 | init_method='env://', 25 | world_size=args.num_gpus, 26 | rank=gpu 27 | ) 28 | 29 | ## get dataloaders 30 | train_loader = cfg.get_train_dataloaders(args, gpu) 31 | test_loader = cfg.get_test_dataloaders(args, gpu) 32 | 33 | ## get runner and loss 34 | runner = cfg.get_runner(args) 35 | 36 | torch.cuda.set_device(gpu) 37 | runner.cuda(gpu) 38 | runner = nn.parallel.DistributedDataParallel(runner, device_ids=[gpu], find_unused_parameters=('correc' in args.mode and args.init_step < args.correc_starting_step)) 39 | runner.train() 40 | 41 | if args.restore_ckpt is not None: 42 | missing_keys, unexpected_keys = runner.load_state_dict(torch.load(args.restore_ckpt), strict=False) 43 | if gpu == 0: 44 | print('missing_keys:', missing_keys) 45 | print('unexpected_keys:', unexpected_keys) 46 | 47 | ## get metrics 48 | metrics = cfg.get_metrics(args) 49 | 50 | ## Init validator 51 | validator = Validation(args) 52 | 53 | ## get optimizer 54 | optimizer, scheduler = cfg.get_optimizer(args, runner) 55 | 56 | save_checkpoint = CheckpointSavior(args) 57 | logger = Logging(runner, scheduler, args) 58 | 59 | total_steps = args.init_step 60 | prev_loss_dict = {} 61 | should_keep_training = True 62 | while should_keep_training: 63 | 64 | for _, example in enumerate(train_loader): 65 | 66 | ## Reset gradients 67 | runner.zero_grad(set_to_none=True) 68 | 69 | ## Transfer to cuda 70 | to_cuda(example) 71 | 72 | ## Run forwad pass 73 | _, loss_dict = runner(example, total_steps) 74 | 75 | ## Check total_loss for NaNs 76 | training_loss = loss_dict['loss_total'] 77 | if np.isnan(training_loss.item()): 78 | print('Current loss dict:', gpu, loss_dict) 79 | print() 80 | print('Previous loss dict:', gpu, prev_loss_dict) 81 | print() 82 | raise ValueError("training_loss is NaN") 83 | else: 84 | prev_loss_dict = loss_dict 85 | 86 | training_loss.backward() 87 | torch.nn.utils.clip_grad_norm_(runner.parameters(), args.clip) 88 | 89 | # else: 90 | optimizer.step() 91 | scheduler.step() 92 | 93 | ## increment total step 94 | total_steps += 1 95 | 96 | for key, value in loss_dict.items(): 97 | loss_dict[key] = stack_all_gather_without_backprop(value).mean() 98 | logger.push(loss_dict) 99 | 100 | if total_steps % args.VAL_FREQ == args.VAL_FREQ - 1: 101 | runner.eval() 102 | results = validator.validate(runner, test_loader, metrics, total_steps) 103 | runner.train() 104 | 105 | logger.write_dict(results) 106 | save_checkpoint(results, runner) 107 | 108 | if total_steps >= args.num_steps: 
109 | should_keep_training = False 110 | save_checkpoint(results, runner) 111 | break 112 | 113 | 114 | if __name__ == '__main__': 115 | 116 | ## get arguments 117 | args = argument_parser.get_arguments() 118 | print(args) 119 | 120 | ## get checkpoints 121 | if not args.debug: 122 | os.makedirs(args.ckpt_dir, exist_ok=True) 123 | log_dir = os.path.join(args.ckpt_dir, args.name) 124 | os.makedirs(log_dir, exist_ok=True) 125 | argument_parser.save_args(args, log_dir) 126 | 127 | os.environ['MASTER_ADDR'] = 'localhost' 128 | os.environ['MASTER_PORT'] = str(12000 + np.round(np.random.random() * 1000).astype(int)) 129 | 130 | mp.spawn(train, nprocs=args.num_gpus, args=(args,)) -------------------------------------------------------------------------------- /utils/coords_and_warp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def mesh_grid(B, H, W, device): 8 | # mesh grid 9 | x_base = torch.arange(0, W, device=device).view(1, 1, -1).expand(B, H, -1) # BHW 10 | y_base = torch.arange(0, H, device=device).view(1, -1, 1).expand(B, -1, W) # BHW 11 | 12 | base_grid = torch.stack([x_base, y_base], 1) # B2HW 13 | return base_grid 14 | 15 | 16 | class Coords(nn.Module): 17 | """ Get absolute coordinates from optical flow """ 18 | def __init__(self) -> None: 19 | super(Coords, self).__init__() 20 | 21 | def get_coords(self, flow): 22 | _, _, H, W = flow.size() 23 | return mesh_grid(1, H, W, device=flow.device) + flow 24 | 25 | def forward(self, flows): 26 | dims = flows.size() 27 | return self.get_coords(flows.view(-1, 2, *dims[-2:])).view(dims) 28 | 29 | 30 | class Warp(nn.Module): 31 | """ Warping module """ 32 | def __init__(self) -> None: 33 | super(Warp, self).__init__() 34 | 35 | def warp(self, x, coords): 36 | B, _, H, W = coords.size() 37 | _coords = 2.0 * coords / torch.tensor([[[[max(W-1, 1)]], [[max(H-1, 1)]]]], device=coords.device).float() - 1.0 38 | _coords = _coords.permute(0, 2, 3, 1) 39 | 40 | x_warp = F.grid_sample(x, _coords, align_corners=True, padding_mode="zeros") 41 | 42 | mask = torch.ones(B, 1, H, W, requires_grad=False, device=_coords.device) 43 | mask = F.grid_sample(mask, _coords, align_corners=True, padding_mode="zeros") 44 | mask = (mask >= 1.0).float() 45 | 46 | return x_warp * mask 47 | 48 | def forward(self, x, coords): 49 | dims_x = x.size() 50 | x_warped = self.warp(x.view(-1, *dims_x[-3:]), coords.reshape(-1, 2, *dims_x[-2:])) 51 | return x_warped.view(*dims_x) 52 | 53 | 54 | class WarpFullSize(Warp): 55 | """ Warping module implementing full-size warping """ 56 | def __init__(self) -> None: 57 | super(WarpFullSize, self).__init__() 58 | self.coords = Coords() 59 | 60 | def forward(self, x, flows, pad_params, orig_dims): 61 | flow_dims = flows.size() 62 | H, W = flow_dims[-2:] 63 | x_warp = torch.empty((*flow_dims[:-3], x.size(-3), H, W), device=x.device) 64 | 65 | for b in range(flow_dims[0]): 66 | x_orig_dims = x[b, ..., :orig_dims[b, 0].data, :orig_dims[b, 1].data] 67 | coords_f_pad = self.coords(F.pad(flows[b], pad_params[b].tolist())) 68 | 69 | pad_left, _, pad_top, _ = pad_params[b].data 70 | x_warp[b] = self.warp(x_orig_dims, coords_f_pad)[..., pad_top:pad_top+H, pad_left:pad_left+W] 71 | 72 | return x_warp 73 | 74 | 75 | class WarpMulti(Warp): 76 | """ Warp images multiple times with several flows """ 77 | def __init__(self) -> None: 78 | super(WarpMulti, self).__init__() 79 | 80 | def forward(self, x, coords): 
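# (added note) x, of shape B x 2 x C x H x W, is expanded over the S (or S*N)
# flow fields contained in coords so that all warps are performed in a single
# batched grid_sample call.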
81 | coords_dims = coords.size() # B 2 S 2 H W or B 2 S N 2 H W 82 | H, W = coords_dims[-2:] 83 | C = x.size(-3) 84 | x_resized = x.unsqueeze(-4).expand(-1, -1, np.prod(coords_dims[2:-3]), -1, -1, -1).reshape(-1, C, H, W) 85 | coords_resized = coords.reshape(-1, *coords_dims[-3:]) 86 | return self.warp(x_resized, coords_resized).view(*coords_dims[:-3], C, H, W) 87 | 88 | 89 | class WarpFullSizeMulti(Warp): 90 | """ Warp images multiple times with several flows using full-size warping """ 91 | def __init__(self) -> None: 92 | super(WarpFullSizeMulti, self).__init__() 93 | self.coords = Coords() 94 | 95 | def forward(self, x, flows, pad_params, orig_dims): 96 | flows_dims = flows.size() # B 2 S 2 H W or B 2 S N 2 H W 97 | S = np.prod(flows_dims[2:-3]).astype(int) # S or S*N 98 | B = flows_dims[0] 99 | H, W = flows_dims[-2:] 100 | C = x.size(-3) 101 | x_warp = torch.empty((B, 2*S, C, H, W), device=x.device) 102 | 103 | for b in range(B): 104 | x_orig_dims = x[b, ..., :orig_dims[b, 0].data, :orig_dims[b, 1].data].unsqueeze(-4).expand(-1, S, -1, -1, -1).flatten(end_dim=-4) 105 | coords_f_pad = self.coords(F.pad(flows[b], pad_params[b].tolist()).flatten(end_dim=-4)) 106 | 107 | pad_left, _, pad_top, _ = pad_params[b].data 108 | x_warp[b] = self.warp(x_orig_dims, coords_f_pad)[..., pad_top:pad_top+H, pad_left:pad_left+W] 109 | 110 | return x_warp.view(*flows_dims[:-3], C, H, W) 111 | -------------------------------------------------------------------------------- /datasets/utils_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. 
Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] 138 | -------------------------------------------------------------------------------- 
/models/gma/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 
76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- /models/scv/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. 
Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- 
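The frame_utils.py modules under models/gma, models/scv and models/raft are verbatim copies of the same I/O helpers. A minimal sketch of the Middlebury .flo round trip they provide (the import path, file name and array shape are illustrative, assuming the repository root is on PYTHONPATH):

import numpy as np
from models.raft.utils.frame_utils import readFlow, writeFlow  # any of the three identical copies

flow = np.random.rand(375, 1242, 2).astype(np.float32)  # synthetic (H, W, 2) flow field
writeFlow('example.flo', flow)       # writes the 202021.25 tag, int32 W/H, then float32 data
flow_back = readFlow('example.flo')  # returns (H, W, 2) float32, or None on a bad magic number
assert np.allclose(flow_back, flow)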
/models/raft/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 
76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- /models/gma/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 
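    The wheel concatenates six hue ramps (RY, YG, GC, CB, BM, MR), i.e. 15+6+4+11+13+6 = 55 colors in total.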
28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
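        Note: u and v are normalised by the largest flow magnitude in the frame (plus a small epsilon), so colors encode direction and relative magnitude only.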
117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /models/raft/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
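        Note: vectors with magnitude greater than 1 are drawn desaturated (scaled by 0.75) to flag out-of-range flow.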
81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /models/scv/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 
28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.data as data 4 | import torch.optim as optim 5 | import numpy as np 6 | import random 7 | 8 | import datasets 9 | import augmentations 10 | import models 11 | import runners 12 | from utils.metrics import Metrics 13 | 14 | 15 | def configure_random_seed(seed, gpu=0): 16 | '''Set seeds''' 17 | seed = seed + gpu 18 | 19 | # python 20 | random.seed(seed) 21 | 22 | # numpy 23 | seed += 1 24 | np.random.seed(seed) 25 | 26 | # torch 27 | seed += 1 28 | torch.manual_seed(seed) 29 | 30 | # torch cuda 31 | seed += 1 32 | torch.cuda.manual_seed(seed) 33 | 34 | 35 | def get_train_dataloaders(args, rank=None): 36 | '''Set train dataloader''' 37 | if args.mode == 'flow_correc': 38 | augmentor = augmentations.AugmentorCorrections 39 | elif args.mode == 'flow_only': 40 | augmentor = augmentations.Augmentor 41 | else: 42 | raise NotImplementedError 43 | 44 | train_dataset = getattr(datasets, args.dataset_train)(args, augmentor, is_training=True, split='training') 45 | if rank == None: 46 | train_loader = data.DataLoader(train_dataset, 47 | batch_size=args.batch_size, 48 | pin_memory=True, 49 | shuffle=True, 50 | num_workers=args.num_workers, 51 | drop_last=True) 52 | else: 53 | train_loader = data.DataLoader(train_dataset, 54 | batch_size=args.batch_size, 55 | pin_memory=True, 56 | shuffle=False, 57 | num_workers=args.num_workers, 58 | drop_last=True, 59 | sampler=torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=args.num_gpus, rank=rank)) 60 | return train_loader 61 | 62 | def get_test_dataloaders(args, rank=None): 63 | '''Set test dataloader''' 64 | test_dataset = getattr(datasets, args.dataset_test)(args, augmentor=None, is_training=False, split='training' if args.eval_on_train else 'validation') 65 | 66 | if rank == None: 67 | test_loader = data.DataLoader(test_dataset, 68 | batch_size=1, 69 | pin_memory=True, 70 | shuffle=False, 71 | num_workers=2, 72 | drop_last=False) 73 | else: 74 | test_loader = data.DataLoader(test_dataset, 75 | batch_size=1, 76 | pin_memory=True, 77 | shuffle=False, 78 | num_workers=2, 79 | drop_last=False, 80 | sampler=torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas=args.num_gpus, rank=rank, shuffle=False)) 81 | return test_loader 82 | 83 | def get_runner(args): 84 | '''Set runner (model(s) + losses)''' 85 | model = getattr(models, args.model) 86 | if args.mode == 'flow_correc': 87 | runner = runners.RunnerCorrection(model, args) 88 | elif args.mode == 'flow_only': 89 | runner = runners.Runner(model, args) 90 | else: 91 | raise NotImplementedError 92 | return runner 93 | 94 | def get_optimizer(args, model): 95 | """ Create the optimizer and learning rate scheduler """ 96 | if args.optimizer == 'adam': 97 | 
optimizer = optim.Adam(model.parameters(), lr=args.lr) 98 | elif args.optimizer == 'adamw': 99 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon) 100 | elif args.optimizer == 'sgd': 101 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9) 102 | else: 103 | raise NotImplementedError 104 | 105 | if args.scheduler == 'smurf': 106 | # SMURF scheduler with optional warmup 107 | lambda_lr = lambda step: (0.001 + step/args.end_warmup_step if step < args.end_warmup_step else 1) * \ 108 | (0.5 ** ((step + args.init_step - (args.num_steps - args.lr_decay_step)) / (np.log(0.5)/np.log(args.lr_decay_max)*args.lr_decay_step)) if step + args.init_step > (args.num_steps - args.lr_decay_step) else 1.) 109 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda_lr) 110 | elif args.scheduler == 'raft': 111 | scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100, 112 | pct_start=0.05, cycle_momentum=False, anneal_strategy='linear') 113 | else: 114 | raise NotImplementedError 115 | 116 | return optimizer, scheduler 117 | 118 | 119 | def get_metrics(args): 120 | '''Set evaluator''' 121 | return Metrics(args, args.dataset_test) -------------------------------------------------------------------------------- /models/gma/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .gma import Aggregate 5 | 6 | 7 | class FlowHead(nn.Module): 8 | def __init__(self, input_dim=128, hidden_dim=256): 9 | super(FlowHead, self).__init__() 10 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 11 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | def forward(self, x): 15 | return self.conv2(self.relu(self.conv1(x))) 16 | 17 | 18 | class ConvGRU(nn.Module): 19 | def __init__(self, hidden_dim=128, input_dim=128+128): 20 | super(ConvGRU, self).__init__() 21 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 23 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 24 | 25 | def forward(self, h, x): 26 | hx = torch.cat([h, x], dim=1) 27 | 28 | z = torch.sigmoid(self.convz(hx)) 29 | r = torch.sigmoid(self.convr(hx)) 30 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 31 | 32 | h = (1-z) * h + z * q 33 | return h 34 | 35 | 36 | class SepConvGRU(nn.Module): 37 | def __init__(self, hidden_dim=128, input_dim=192+128): 38 | super(SepConvGRU, self).__init__() 39 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 40 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 41 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 42 | 43 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 44 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 45 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 46 | 47 | 48 | def forward(self, h, x): 49 | # horizontal 50 | hx = torch.cat([h, x], dim=1) 51 | z = torch.sigmoid(self.convz1(hx)) 52 | r = torch.sigmoid(self.convr1(hx)) 53 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 54 | h = (1-z) * h + z * q 55 | 56 | # vertical 57 | hx = torch.cat([h, x], dim=1) 58 | z = torch.sigmoid(self.convz2(hx)) 59 | r = 
torch.sigmoid(self.convr2(hx)) 60 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 61 | h = (1-z) * h + z * q 62 | 63 | return h 64 | 65 | 66 | class BasicMotionEncoder(nn.Module): 67 | def __init__(self, args): 68 | super(BasicMotionEncoder, self).__init__() 69 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 70 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 71 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 72 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 73 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 74 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 75 | 76 | def forward(self, flow, corr): 77 | cor = F.relu(self.convc1(corr)) 78 | cor = F.relu(self.convc2(cor)) 79 | flo = F.relu(self.convf1(flow)) 80 | flo = F.relu(self.convf2(flo)) 81 | 82 | cor_flo = torch.cat([cor, flo], dim=1) 83 | out = F.relu(self.conv(cor_flo)) 84 | return torch.cat([out, flow], dim=1) 85 | 86 | 87 | class BasicUpdateBlock(nn.Module): 88 | def __init__(self, args, hidden_dim=128, input_dim=128): 89 | super(BasicUpdateBlock, self).__init__() 90 | self.args = args 91 | self.encoder = BasicMotionEncoder(args) 92 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 93 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 94 | 95 | self.mask = nn.Sequential( 96 | nn.Conv2d(128, 256, 3, padding=1), 97 | nn.ReLU(inplace=True), 98 | nn.Conv2d(256, 64*9, 1, padding=0)) 99 | 100 | def forward(self, net, inp, corr, flow, upsample=True): 101 | motion_features = self.encoder(flow, corr) 102 | inp = torch.cat([inp, motion_features], dim=1) 103 | 104 | net = self.gru(net, inp) 105 | delta_flow = self.flow_head(net) 106 | 107 | # scale mask to balence gradients 108 | mask = .25 * self.mask(net) 109 | return net, mask, delta_flow 110 | 111 | 112 | class GMAUpdateBlock(nn.Module): 113 | def __init__(self, args, hidden_dim=128): 114 | super().__init__() 115 | self.args = args 116 | self.encoder = BasicMotionEncoder(args) 117 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim+hidden_dim) 118 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 119 | 120 | self.mask = nn.Sequential( 121 | nn.Conv2d(128, 256, 3, padding=1), 122 | nn.ReLU(inplace=True), 123 | nn.Conv2d(256, 64*9, 1, padding=0)) 124 | 125 | self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=self.args.num_heads) 126 | 127 | def forward(self, net, inp, corr, flow, attention): 128 | motion_features = self.encoder(flow, corr) 129 | motion_features_global = self.aggregator(attention, motion_features) 130 | inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1) 131 | 132 | # Attentional update 133 | net = self.gru(net, inp_cat) 134 | 135 | delta_flow = self.flow_head(net) 136 | 137 | # scale mask to balence gradients 138 | mask = .25 * self.mask(net) 139 | return net, mask, delta_flow 140 | 141 | 142 | 143 | -------------------------------------------------------------------------------- /losses/losses_corrections.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from .losses import LossBasic 7 | from utils.coords_and_warp import WarpMulti, WarpFullSizeMulti 8 | from utils.utils_corrections import apply_corrections_uncropped, apply_corrections 9 | from .loss_utils.photometric_loss_utils import PhotometricLossSequential, PhotometricLossCorrec 10 | 11 | 12 | class LossCorrections(LossBasic): 13 | def 
__init__(self, args): 14 | super(LossCorrections, self).__init__(args) 15 | 16 | self.warp = WarpFullSizeMulti() if args.use_full_size_warping else WarpMulti() 17 | 18 | assert self.args.correc_starting_step <= self.args.correc_in_photo_starting_step, 'corrections should not be used if the corrector is not trained' 19 | 20 | self.photometric_loss = PhotometricLossSequential(args) 21 | self.photometric_loss_correc = PhotometricLossCorrec(args) 22 | 23 | self.apply_corrections = apply_corrections_uncropped if self.args.use_full_size_warping else apply_corrections 24 | 25 | def forward(self, example, outputs, total_step): 26 | 27 | self.update_selfsup_weight(total_step) 28 | 29 | loss_dict = {} 30 | 31 | if total_step >= self.args.correc_starting_step: 32 | 33 | masks = outputs['masks'][:, :, 0] * example['masks_eraser'] 34 | 35 | # Apply correction on augmented images used in the correction loss 36 | # Then warp the corrected augmented images 37 | if self.args.use_full_size_warping: 38 | pad_params = example['pad_params'].int() 39 | orig_dims = example['orig_dims'].int() 40 | ims_aug_correc = self.apply_corrections(example['ims_aug_uncropped'], outputs['correcs_aug'] * masks, pad_params) 41 | ims_aug_warp_correc = self.warp(ims_aug_correc[:, self.bwd], outputs['flows_aug'][:, :, 0].detach(), pad_params, orig_dims) 42 | 43 | else: 44 | ims_aug_correc = self.apply_corrections(example['ims_aug'], outputs['correcs_aug'] * masks) 45 | ims_aug_warp_correc = self.warp(ims_aug_correc[:, self.bwd], outputs['coords'][:, :, 0].detach()) 46 | 47 | # Apply correction losses 48 | loss_correc = self.photometric_loss_correc(example['ims_aug'], ims_aug_warp_correc, masks) 49 | 50 | else: 51 | loss_correc = torch.tensor(0.0, device=torch.device('cuda')) 52 | 53 | if total_step >= self.args.correc_in_photo_starting_step: 54 | # Apply correction on non-augmented images used in the photometric loss 55 | with torch.no_grad(): 56 | masks = outputs['masks'][:, :, 0] * example['masks_eraser'] * outputs['best_indices'] 57 | if self.args.use_full_size_warping: 58 | pad_params = example['pad_params'].int() 59 | orig_dims = example['orig_dims'].int() 60 | example['ims_uncropped_correc'] = self.apply_corrections(example['ims_uncropped'], outputs['correcs'] * masks, pad_params, clamp_corrected_images=self.args.smart_clamp) 61 | 62 | else: 63 | example['ims_correc'] = self.apply_corrections(example['ims'], outputs['correcs'] * masks, clamp_corrected_images=self.args.smart_clamp) 64 | 65 | # Reconstruction: warping of corrected images with flows predictions 66 | if self.args.use_full_size_warping: 67 | pad_params = example['pad_params'].int() 68 | orig_dims = example['orig_dims'].int() 69 | example['ims_warp'] = self.warp(example['ims_uncropped_correc'][:, self.bwd], outputs['flows_aug'], pad_params, orig_dims) 70 | else: 71 | example['ims_warp'] = self.warp(example['ims_correc'][:, self.bwd], outputs['coords']) 72 | 73 | else: 74 | 75 | # Reconstruction: warping of images with flows predictions 76 | if self.args.use_full_size_warping: 77 | pad_params = example['pad_params'].int() 78 | orig_dims = example['orig_dims'].int() 79 | example['ims_warp'] = self.warp(example['ims_uncropped'][:, self.bwd], outputs['flows_aug'], pad_params, orig_dims) 80 | else: 81 | example['ims_warp'] = self.warp(example['ims'][:, self.bwd], outputs['coords']) 82 | 83 | # Computation of the photometric loss 84 | loss_photo = self.photometric_loss(example, outputs) 85 | 86 | # Computation of the smoothness loss 87 | loss_smooth = 
self.smoothness_loss(outputs['flows_aug'], example['ims'], example['masks_eraser'], total_step) 88 | 89 | # Computation of the selfsup loss 90 | loss_selfsup = torch.tensor(0.0, device=torch.device('cuda')) 91 | if total_step >= self.args.selfsup_starting_step: 92 | loss_selfsup += self.selfsup_loss(outputs['flows_stud'], outputs['flows_teacher'], example['masks_eraser_stud'], total_step) 93 | 94 | loss_dict['photo'] = loss_photo 95 | loss_dict['smooth'] = loss_smooth 96 | loss_dict['self'] = loss_selfsup 97 | loss_dict['correc'] = loss_correc 98 | loss_dict['mean_mask'] = outputs['masks'][:, :, 0].mean() 99 | loss_dict['mean_flow'] = outputs['flows_aug'][:, 0, 0].abs().mean() 100 | if 'correcs_aug' in outputs: 101 | loss_dict['mean_correc'] = outputs['correcs_aug'][:, 0].mean() 102 | if 'best_indices' in outputs: 103 | loss_dict['best_indices'] = outputs['best_indices'][:, 0].float().mean() 104 | 105 | loss_dict['loss_total'] = loss_photo + self.selfsup_weight * loss_selfsup.clamp(0., 100.) + self.args.smoothness_weight * loss_smooth + self.args.correc_weight * loss_correc 106 | 107 | return loss_dict 108 | -------------------------------------------------------------------------------- /models/gma/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .update import GMAUpdateBlock 6 | from .extractor import BasicEncoder 7 | from .corr import CorrBlock, CorrBlock_fb 8 | from .utils.utils import bilinear_sampler, coords_grid, upflow8 9 | from .gma import Attention, Aggregate 10 | 11 | try: 12 | autocast = torch.cuda.amp.autocast 13 | except: 14 | # dummy autocast for PyTorch < 1.6 15 | class autocast: 16 | def __init__(self, enabled): 17 | pass 18 | 19 | def __enter__(self): 20 | pass 21 | 22 | def __exit__(self, *args): 23 | pass 24 | 25 | 26 | class RAFTGMA(nn.Module): 27 | def __init__(self, args): 28 | super().__init__() 29 | self.args = args 30 | self.iters = args.iters 31 | 32 | self.hidden_dim = hdim = 128 33 | self.context_dim = cdim = 128 34 | args.corr_levels = 4 35 | args.corr_radius = 4 36 | 37 | if 'dropout' not in self.args: 38 | self.args.dropout = 0 39 | 40 | # feature network, context network, and update block 41 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) 42 | self.cnet = BasicEncoder(output_dim=hdim + cdim, norm_fn='instance', dropout=args.dropout) 43 | self.update_block = GMAUpdateBlock(self.args, hidden_dim=hdim) 44 | self.att = Attention(args=self.args, dim=cdim, heads=self.args.num_heads, max_pos_size=160, dim_head=cdim) 45 | 46 | def freeze_bn(self): 47 | for m in self.modules(): 48 | if isinstance(m, nn.BatchNorm2d): 49 | m.eval() 50 | 51 | def initialize_flow(self, img, ctxt): 52 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 53 | _, _, H, W = img.shape 54 | N = ctxt.size(0) 55 | coords0 = coords_grid(N, H // 8, W // 8, device=img.device) 56 | coords1 = coords_grid(N, H // 8, W // 8, device=img.device) 57 | 58 | # optical flow computed as difference: flow = coords1 - coords0 59 | return coords0, coords1 60 | 61 | def upsample_flow(self, flow, mask): 62 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 63 | N, _, H, W = flow.shape 64 | mask = mask.view(N, 1, 9, 8, 8, H, W) 65 | mask = torch.softmax(mask, dim=2) 66 | 67 | up_flow = F.unfold(8 * flow, [3, 3], padding=1) 68 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 
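        # The mask provides, for every coarse cell and every position of its 8x8 output block,
        # a softmax-normalised (convex) weighting over the 3x3 neighbourhood of coarse flow
        # vectors; flow is pre-multiplied by 8 because it is expressed in 1/8-resolution pixels.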
69 | 70 | up_flow = torch.sum(mask * up_flow, dim=2) 71 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 72 | return up_flow.reshape(N, 2, 8 * H, 8 * W) 73 | 74 | def compute_flow(self, image, ctxt, corr_fn, flow_init=None, return_last_flow_only=False): 75 | 76 | net, inp = torch.split(ctxt, [self.hidden_dim, self.context_dim], dim=1) 77 | net = torch.tanh(net) 78 | inp = torch.relu(inp) 79 | # attention, att_c, att_p = self.att(inp) 80 | attention = self.att(inp) 81 | 82 | coords0, coords1 = self.initialize_flow(image, ctxt) 83 | 84 | if flow_init is not None: 85 | coords1 = coords1 + flow_init 86 | 87 | flow_predictions = [] 88 | for _ in range(self.iters): 89 | coords1 = coords1.detach() 90 | corr = corr_fn(coords1) # index correlation volume 91 | 92 | flow = coords1 - coords0 93 | with autocast(enabled=self.args.mixed_precision): 94 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow, attention) 95 | 96 | # F(t+1) = F(t) + \Delta(t) 97 | coords1 = coords1 + delta_flow 98 | flow = coords1 - coords0 99 | 100 | # upsample predictions 101 | if up_mask is None: 102 | flow_up = upflow8(flow, mode=self.args.upsampling_mode) 103 | else: 104 | flow_up = self.upsample_flow(flow, up_mask) 105 | 106 | flow_predictions.append(flow_up) 107 | 108 | if return_last_flow_only: 109 | return flow_up 110 | 111 | flow_predictions.reverse() 112 | return flow_predictions 113 | 114 | 115 | def forward(self, images, flow_init=None, return_last_flow_only=False, fwd_bwd=True, suffix=''): 116 | """ Estimate optical flow between pair of frames """ 117 | outputs = {} 118 | 119 | dims = images.size() 120 | images_norm = (2 * (images / 255.0) - 1.0).contiguous() 121 | images_norm_flatten = images_norm.flatten(end_dim=-4) 122 | 123 | # run the feature network 124 | with autocast(enabled=self.args.mixed_precision): 125 | fmaps = self.fnet(images_norm_flatten).unflatten(dim=0, sizes=dims[:2]) 126 | 127 | corrs_fn = CorrBlock_fb(fmaps[:, 0], fmaps[:, 1], radius=self.args.corr_radius, fwd_bwd=fwd_bwd) 128 | 129 | if fwd_bwd: 130 | with autocast(enabled=self.args.mixed_precision): 131 | ctxts = self.cnet(images_norm_flatten) 132 | 133 | flows = self.compute_flow(images_norm_flatten, ctxts, corrs_fn, return_last_flow_only=return_last_flow_only) 134 | 135 | if return_last_flow_only: 136 | outputs['flows' + suffix] = flows.unflatten(dim=0, sizes=dims[:2]) 137 | else: 138 | outputs['flows' + suffix] = torch.stack(flows, dim=-4).unflatten(dim=0, sizes=dims[:2]) 139 | 140 | else: 141 | images_norm1 = images_norm[:, 0] 142 | with autocast(enabled=self.args.mixed_precision): 143 | ctxt1 = self.cnet(images_norm1) 144 | 145 | outputs['flow_f' + suffix] = self.compute_flow(images_norm1, ctxt1, corrs_fn, flow_init=flow_init, return_last_flow_only=True) 146 | 147 | return outputs 148 | -------------------------------------------------------------------------------- /models/scv/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | 17 | class ConvGRU(nn.Module): 18 | def __init__(self, hidden_dim=128, input_dim=192+128): 19 | 
super(ConvGRU, self).__init__() 20 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 23 | 24 | def forward(self, h, x): 25 | hx = torch.cat([h, x], dim=1) 26 | 27 | z = torch.sigmoid(self.convz(hx)) 28 | r = torch.sigmoid(self.convr(hx)) 29 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 30 | 31 | h = (1-z) * h + z * q 32 | return h 33 | 34 | 35 | class SepConvGRU(nn.Module): 36 | def __init__(self, hidden_dim=128, input_dim=192+128): 37 | super(SepConvGRU, self).__init__() 38 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 40 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 41 | 42 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 44 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 45 | 46 | def forward(self, h, x): 47 | # horizontal 48 | hx = torch.cat([h, x], dim=1) 49 | z = torch.sigmoid(self.convz1(hx)) 50 | r = torch.sigmoid(self.convr1(hx)) 51 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 52 | h = (1-z) * h + z * q 53 | 54 | # vertical 55 | hx = torch.cat([h, x], dim=1) 56 | z = torch.sigmoid(self.convz2(hx)) 57 | r = torch.sigmoid(self.convr2(hx)) 58 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 59 | h = (1-z) * h + z * q 60 | 61 | return h 62 | 63 | 64 | class ResidualBlock(nn.Module): 65 | def __init__(self, in_planes, planes, dilation=(1, 1), kernel_size=3): 66 | super(ResidualBlock, self).__init__() 67 | 68 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=kernel_size, padding=(kernel_size-1)//2, dilation=dilation[0]) 69 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=kernel_size, padding=(kernel_size-1)//2, dilation=dilation[1]) 70 | self.relu = nn.ReLU(inplace=True) 71 | 72 | self.projector = nn.Conv2d(in_planes, planes, kernel_size=1) 73 | 74 | def forward(self, x): 75 | y = x 76 | y = self.relu(self.conv1(y)) 77 | y = self.relu(self.conv2(y)) 78 | 79 | if self.projector is not None: 80 | x = self.projector(x) 81 | 82 | return self.relu(x + y) 83 | 84 | 85 | class BasicMotionEncoder(nn.Module): 86 | def __init__(self, args, input_dim=128): 87 | super().__init__() 88 | self.convc1 = nn.Conv2d(input_dim, 256, 1) 89 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 90 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 91 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 92 | self.conv = nn.Conv2d(192+64, 128-2, 3, padding=1) 93 | 94 | def forward(self, flow, corr): 95 | cor = F.relu(self.convc1(corr)) 96 | cor = F.relu(self.convc2(cor)) 97 | flo = F.relu(self.convf1(flow)) 98 | flo = F.relu(self.convf2(flo)) 99 | 100 | cor_flo = torch.cat([cor, flo], dim=1) 101 | out = F.relu(self.conv(cor_flo)) 102 | return torch.cat([out, flow], dim=1) 103 | 104 | 105 | class BasicUpdateBlock(nn.Module): 106 | def __init__(self, args, hidden_dim=128, input_dim=128): 107 | super(BasicUpdateBlock, self).__init__() 108 | self.args = args 109 | self.encoder = BasicMotionEncoder(args, input_dim=input_dim) 110 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 111 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 112 | 113 | self.mask = nn.Sequential( 114 | 
nn.Conv2d(128, 256, 3, padding=1), 115 | nn.ReLU(inplace=True), 116 | nn.Conv2d(256, 64*9, 1, padding=0)) 117 | 118 | def forward(self, net, inp, corr, flow): 119 | motion_features = self.encoder(flow, corr) 120 | inp = torch.cat([inp, motion_features], dim=1) 121 | 122 | net = self.gru(net, inp) 123 | delta_flow = self.flow_head(net) 124 | 125 | # scale mask to balence gradients 126 | mask = .25 * self.mask(net) 127 | return net, mask, delta_flow 128 | 129 | 130 | class BasicUpdateBlockQuarter(nn.Module): 131 | def __init__(self, args, hidden_dim=128, input_dim=128): 132 | super(BasicUpdateBlockQuarter, self).__init__() 133 | self.args = args 134 | self.encoder = BasicMotionEncoder(args, input_dim=input_dim) 135 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 136 | self.flow_head = FlowHead(input_dim=hidden_dim, hidden_dim=256) 137 | 138 | self.mask = nn.Sequential( 139 | nn.Conv2d(128, 256, 3, padding=1), 140 | nn.ReLU(inplace=True), 141 | nn.Conv2d(256, 16*9, 1, padding=0)) 142 | 143 | def forward(self, net, inp, corr, flow): 144 | motion_features = self.encoder(flow, corr) 145 | inp = torch.cat([inp, motion_features], dim=1) 146 | 147 | net = self.gru(net, inp) 148 | delta_flow = self.flow_head(net) 149 | 150 | # scale mask to balence gradients 151 | mask = .25 * self.mask(net) 152 | return net, mask, delta_flow 153 | 154 | -------------------------------------------------------------------------------- /models/raft/corrector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | if not stride == 1: 20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | if not stride == 1: 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | if not stride == 1: 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = nn.Sequential() 37 | if not stride == 1: 38 | self.norm3 = nn.Sequential() 39 | 40 | if stride == 1: 41 | self.downsample = None 42 | 43 | else: 44 | self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | 48 | def forward(self, x): 49 | y = x 50 | y = self.relu(self.norm1(self.conv1(y))) 51 | y = self.relu(self.norm2(self.conv2(y))) 52 | 53 | if self.downsample is not None: 54 | x = self.downsample(x) 55 | 56 | return self.relu(x+y) 57 | 58 | class OffsetHead(nn.Module): 59 | def __init__(self, input_dim=128, hidden_dim=256, output_dim=1): 60 | super(OffsetHead, self).__init__() 61 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 62 | self.conv2 = nn.Conv2d(hidden_dim, output_dim, 3, 
padding=1) 63 | self.relu = nn.ReLU(inplace=True) 64 | 65 | def forward(self, x): 66 | return torch.tanh(self.conv2(self.relu(self.conv1(x)))) 67 | 68 | class Corrector(nn.Module): 69 | def __init__(self, args, intput_dim=7, output_dim=128, norm_fn='instance'): 70 | super(Corrector, self).__init__() 71 | self.norm_fn = norm_fn 72 | self.args = args 73 | 74 | if self.norm_fn == 'group': 75 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 76 | 77 | elif self.norm_fn == 'batch': 78 | self.norm1 = nn.BatchNorm2d(64) 79 | 80 | elif self.norm_fn == 'instance': 81 | self.norm1 = nn.InstanceNorm2d(64) 82 | 83 | elif self.norm_fn == 'none': 84 | self.norm1 = nn.Sequential() 85 | 86 | self.conv1 = nn.Conv2d(intput_dim, 64, kernel_size=7, stride=2, padding=3) 87 | self.relu1 = nn.ReLU(inplace=True) 88 | 89 | self.in_planes = 64 90 | self.layer1 = self._make_layer(64, stride=1) 91 | self.layer2 = self._make_layer(96, stride=2) 92 | self.layer3 = self._make_layer(128, stride=2) 93 | 94 | # output convolution 95 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 96 | 97 | for m in self.modules(): 98 | if isinstance(m, nn.Conv2d): 99 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 100 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 101 | if m.weight is not None: 102 | nn.init.constant_(m.weight, 1) 103 | if m.bias is not None: 104 | nn.init.constant_(m.bias, 0) 105 | 106 | self.offset_head = OffsetHead(output_dim=3) 107 | 108 | self.mask = nn.Sequential( 109 | nn.Conv2d(128, 256, 3, padding=1), 110 | nn.ReLU(inplace=True), 111 | nn.Conv2d(256, 64*9, 1, padding=0)) 112 | 113 | def _make_layer(self, dim, stride=1): 114 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 115 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 116 | layers = (layer1, layer2) 117 | 118 | self.in_planes = dim 119 | return nn.Sequential(*layers) 120 | 121 | def upsample_offsets(self, offsets, mask): 122 | """ Upsample offsets field [H/8, W/8, 1] -> [H, W, 1] using convex combination """ 123 | N, _, H, W = offsets.shape 124 | mask = mask.view(N, 1, 9, 8, 8, H, W) 125 | mask = torch.softmax(mask, dim=2) 126 | 127 | up_offsets = F.unfold(offsets, [3,3], padding=1) 128 | up_offsets = up_offsets.view(N, 3, 9, 1, 1, H, W) 129 | 130 | up_offsets = torch.sum(mask * up_offsets, dim=2) 131 | up_offsets = up_offsets.permute(0, 1, 4, 2, 5, 3) 132 | return up_offsets.reshape(N, 3, 8*H, 8*W) 133 | 134 | 135 | def forward(self, x): 136 | # is_list = isinstance(x, tuple) or isinstance(x, list) 137 | # if is_list: 138 | # x = torch.cat(x, dim=0) 139 | shape_5 = len(x.shape) == 5 140 | if shape_5: 141 | x = self.conv1(x.flatten(end_dim=-4)) 142 | else: 143 | x = self.conv1(x) 144 | 145 | x = self.norm1(x) 146 | x = self.relu1(x) 147 | 148 | x = self.layer1(x) 149 | x = self.layer2(x) 150 | x = self.layer3(x) 151 | 152 | x = self.conv2(x) 153 | 154 | offsets = self.offset_head(x) 155 | 156 | mask = .25 * self.mask(x) 157 | 158 | up_offsets = self.upsample_offsets(offsets, mask) 159 | 160 | # if is_list: 161 | # return torch.split(up_offsets, self.args.batch_size, dim=0) 162 | if shape_5: 163 | return up_offsets.unflatten(dim=0, sizes=(-1, 2)) 164 | 165 | return up_offsets -------------------------------------------------------------------------------- /models/gma/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 
from .utils.utils import bilinear_sampler, coords_grid 6 | # from compute_sparse_correlation import compute_sparse_corr, compute_sparse_corr_torch, compute_sparse_corr_mink 7 | 8 | try: 9 | import alt_cuda_corr 10 | except: 11 | # alt_cuda_corr is not compiled 12 | pass 13 | 14 | 15 | class CorrBlock: 16 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 17 | self.num_levels = num_levels 18 | self.radius = radius 19 | self.corr_pyramid = [] 20 | 21 | # all pairs correlation 22 | corr = CorrBlock.corr(fmap1, fmap2) 23 | 24 | batch, h1, w1, dim, h2, w2 = corr.shape 25 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 26 | 27 | self.corr_pyramid.append(corr) 28 | for i in range(self.num_levels - 1): 29 | corr = F.avg_pool2d(corr, 2, stride=2) 30 | self.corr_pyramid.append(corr) 31 | 32 | def __call__(self, coords): 33 | r = self.radius 34 | coords = coords.permute(0, 2, 3, 1) 35 | batch, h1, w1, _ = coords.shape 36 | 37 | out_pyramid = [] 38 | for i in range(self.num_levels): 39 | corr = self.corr_pyramid[i] 40 | dx = torch.linspace(-r, r, 2 * r + 1) 41 | dy = torch.linspace(-r, r, 2 * r + 1) 42 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 43 | 44 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2 ** i 45 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 46 | coords_lvl = centroid_lvl + delta_lvl 47 | 48 | corr = bilinear_sampler(corr, coords_lvl) 49 | corr = corr.view(batch, h1, w1, -1) 50 | out_pyramid.append(corr) 51 | 52 | out = torch.cat(out_pyramid, dim=-1) 53 | return out.permute(0, 3, 1, 2).contiguous().float() 54 | 55 | @staticmethod 56 | def corr(fmap1, fmap2): 57 | batch, dim, ht, wd = fmap1.shape 58 | fmap1 = fmap1.view(batch, dim, ht * wd) 59 | fmap2 = fmap2.view(batch, dim, ht * wd) 60 | 61 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 62 | corr = corr.view(batch, ht, wd, 1, ht, wd) 63 | return corr / torch.sqrt(torch.tensor(dim).float()) 64 | 65 | 66 | class CorrBlockSingleScale(nn.Module): 67 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 68 | super().__init__() 69 | self.radius = radius 70 | 71 | # all pairs correlation 72 | corr = CorrBlock.corr(fmap1, fmap2) 73 | batch, h1, w1, dim, h2, w2 = corr.shape 74 | self.corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 75 | 76 | def __call__(self, coords): 77 | r = self.radius 78 | coords = coords.permute(0, 2, 3, 1) 79 | batch, h1, w1, _ = coords.shape 80 | 81 | corr = self.corr 82 | dx = torch.linspace(-r, r, 2 * r + 1) 83 | dy = torch.linspace(-r, r, 2 * r + 1) 84 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 85 | 86 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) 87 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 88 | coords_lvl = centroid_lvl + delta_lvl 89 | 90 | corr = bilinear_sampler(corr, coords_lvl) 91 | out = corr.view(batch, h1, w1, -1) 92 | out = out.permute(0, 3, 1, 2).contiguous().float() 93 | return out 94 | 95 | @staticmethod 96 | def corr(fmap1, fmap2): 97 | batch, dim, ht, wd = fmap1.shape 98 | fmap1 = fmap1.view(batch, dim, ht * wd) 99 | fmap2 = fmap2.view(batch, dim, ht * wd) 100 | 101 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 102 | corr = corr.view(batch, ht, wd, 1, ht, wd) 103 | return corr / torch.sqrt(torch.tensor(dim).float()) 104 | 105 | 106 | class CorrBlock_fb: 107 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4, fwd_bwd=False): 108 | self.num_levels = num_levels 109 | self.radius = radius 110 | self.corrs_pyramid = [] 111 | 112 | # all pairs correlation 113 | corr1 = 
CorrBlock_fb.corr(fmap1, fmap2) 114 | batch, h1, w1, dim, h2, w2 = corr1.shape 115 | 116 | if fwd_bwd: 117 | corr2 = corr1.permute(0, 4, 5, 3, 1, 2) 118 | corrs = torch.stack([corr1, corr2], 1).reshape(2*batch*h1*w1, dim, h2, w2) 119 | self.correlations = torch.stack([corr1.view(batch, h1*w1, h2*w2), corr2.view(batch, h2*w2, h1*w1)], dim=1) 120 | 121 | else: 122 | corrs = corr1.reshape(batch*h1*w1, dim, h2, w2) 123 | self.correlations = corr1.view(batch, h1*w1, h2, w2) 124 | 125 | self.corrs_pyramid.append(corrs) 126 | for i in range(self.num_levels-1): 127 | corrs = F.avg_pool2d(corrs, 2, stride=2) 128 | self.corrs_pyramid.append(corrs) 129 | 130 | def __call__(self, coords): 131 | r = self.radius 132 | coords = coords.permute(0, 2, 3, 1) 133 | batch, h1, w1, _ = coords.shape 134 | 135 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device) 136 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device) 137 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) 138 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 139 | 140 | out_pyramid = [] 141 | for i in range(self.num_levels): 142 | corr = self.corrs_pyramid[i] 143 | 144 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 145 | coords_lvl = centroid_lvl + delta_lvl 146 | 147 | corr = bilinear_sampler(corr, coords_lvl) 148 | corr = corr.view(batch, h1, w1, -1) 149 | out_pyramid.append(corr) 150 | 151 | out = torch.cat(out_pyramid, dim=-1) 152 | return out.permute(0, 3, 1, 2).contiguous().float() 153 | 154 | @staticmethod 155 | def corr(fmap1, fmap2): 156 | batch, dim, ht1, wd1 = fmap1.shape 157 | _, _, ht2, wd2 = fmap2.shape 158 | fmap1 = fmap1.view(batch, dim, ht1*wd1) 159 | fmap2 = fmap2.view(batch, dim, ht2*wd2) 160 | 161 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 162 | corr = corr.view(batch, ht1, wd1, 1, ht2, wd2) 163 | return corr / torch.sqrt(torch.tensor(dim).float()) 164 | -------------------------------------------------------------------------------- /models/raft/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from scipy import interpolate 6 | 7 | 8 | class InputPadder: 9 | """ Pads images such that dimensions are divisible by 8 """ 10 | def __init__(self, dims, mode='sintel'): 11 | self.ht, self.wd = dims[-2:] 12 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 13 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 14 | if mode == 'sintel': 15 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 16 | else: 17 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 18 | 19 | def pad(self, *inputs): 20 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 21 | 22 | def unpad(self,x): 23 | ht, wd = x.shape[-2:] 24 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 25 | return x[..., c[0]:c[1], c[2]:c[3]] 26 | 27 | def forward_interpolate(flow): 28 | flow = flow.detach().cpu().numpy() 29 | dx, dy = flow[0], flow[1] 30 | 31 | ht, wd = dx.shape 32 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 33 | 34 | x1 = x0 + dx 35 | y1 = y0 + dy 36 | 37 | x1 = x1.reshape(-1) 38 | y1 = y1.reshape(-1) 39 | dx = dx.reshape(-1) 40 | dy = dy.reshape(-1) 41 | 42 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 43 | x1 = x1[valid] 44 | y1 = y1[valid] 45 | dx = dx[valid] 46 | dy = dy[valid] 47 | 48 | flow_x = interpolate.griddata( 49 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 50 | 51 | flow_y = 
interpolate.griddata( 52 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 53 | 54 | flow = np.stack([flow_x, flow_y], axis=0) 55 | return torch.from_numpy(flow).float() 56 | 57 | 58 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 59 | """ Wrapper for grid_sample, uses pixel coordinates """ 60 | H, W = img.shape[-2:] 61 | xgrid, ygrid = coords.split([1,1], dim=-1) 62 | xgrid = 2*xgrid/(W-1) - 1 63 | ygrid = 2*ygrid/(H-1) - 1 64 | 65 | grid = torch.cat([xgrid, ygrid], dim=-1) 66 | img = F.grid_sample(img, grid, align_corners=True) 67 | 68 | if mask: 69 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 70 | return img, mask.float() 71 | 72 | return img 73 | 74 | 75 | def coords_grid(batch, ht, wd, device): 76 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) 77 | coords = torch.stack(coords[::-1], dim=0).float() 78 | return coords[None].expand(batch, -1, -1, -1) 79 | 80 | 81 | def upflow8(flow, mode='bilinear'): 82 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 83 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 84 | 85 | 86 | def positionalencoding2d(d_model, height, width): 87 | """ 88 | :param d_model: dimension of the model 89 | :param height: height of the positions 90 | :param width: width of the positions 91 | :return: d_model*height*width position matrix 92 | """ 93 | if d_model % 4 != 0: 94 | raise ValueError("Cannot use sin/cos positional encoding with " 95 | "odd dimension (got dim={:d})".format(d_model)) 96 | pe = torch.zeros(d_model, height, width).cuda() 97 | # Each dimension use half of d_model 98 | d_model = int(d_model / 2) 99 | div_term = torch.exp(torch.arange(0., d_model, 2) * 100 | -(np.log(10000.0) / d_model)) 101 | pos_w = torch.arange(0., width).unsqueeze(1) 102 | pos_h = torch.arange(0., height).unsqueeze(1) 103 | pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).expand(-1, height, -1) 104 | pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).expand(-1, height, -1) 105 | pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).expand(-1, -1, width) 106 | pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).expand(-1, -1, width) 107 | idx = torch.cat([torch.LongTensor(range(d_model//4)), torch.LongTensor(range(d_model//2,d_model//2 + d_model//4))]) 108 | 109 | return pe[idx] 110 | 111 | 112 | def mask_corr(pad_size, dims): 113 | H, W = dims 114 | top = torch.zeros(1, pad_size, H, W, device=torch.device('cuda')) 115 | bottom = torch.zeros(1, pad_size, H, W, device=torch.device('cuda')) 116 | left = torch.zeros(1, pad_size, H, W, device=torch.device('cuda')) 117 | right = torch.zeros(1, pad_size, H, W, device=torch.device('cuda')) 118 | 119 | for w in range(W): 120 | for i in range(pad_size): 121 | if w <= i: 122 | left[:, i, :, w] = 1 123 | if W-w <= i+1: 124 | right[:, i, :, w] = 1 125 | 126 | for h in range(H): 127 | for i in range(pad_size): 128 | if h <= i: 129 | top[:, i, h] = 1 130 | if H-h <= i+1: 131 | bottom[:, i, h] = 1 132 | 133 | return torch.cat([top, bottom, left, right], dim=1) 134 | 135 | 136 | def mask_padding(pad_size, like): 137 | zeros = torch.zeros_like(like) 138 | return F.pad(zeros, (pad_size, pad_size, pad_size, pad_size), value=1.) 
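# Worked example (sketch): with pad_size=1 and a (1, C, 2, 2) input, mask_padding
# returns a (1, C, 4, 4) mask that is 1 on the one-pixel padded border and 0 on
# the original 2x2 interior; PadFmaps below uses it to keep the original features
# in the interior and fill only the padded ring with the learned padding.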
139 | 140 | 141 | class PadFmaps(nn.Module): 142 | def __init__(self, pad_size, ch, dims): 143 | super(PadFmaps, self).__init__() 144 | self.pad_size = pad_size 145 | self.smart_pad = nn.Sequential(*[nn.ConvTranspose2d(ch, ch, 3, stride=1), nn.LeakyReLU(0.1, inplace=True), 146 | nn.ConvTranspose2d(ch, ch, 3, stride=1), nn.LeakyReLU(0.1, inplace=True)]*(pad_size//2)) 147 | self.zero_pad = nn.ZeroPad2d(pad_size) 148 | 149 | def forward(self, fmap): 150 | zeros_pad = self.zero_pad(fmap) 151 | smart_pad = self.smart_pad(fmap) 152 | return zeros_pad + mask_padding(self.pad_size, fmap).to(smart_pad.device) * smart_pad 153 | 154 | 155 | def get_out_corrs(coords, pad_size, radius): 156 | B, _, H, W = coords.size() 157 | coords = coords - pad_size 158 | coords_h = coords[:, 1] 159 | coords_w = coords[:, 0] 160 | coords_h[coords_h==torch.clamp(coords_h, radius , H - radius)]=0 161 | coords_w[coords_w==torch.clamp(coords_w, radius , W - radius)]=0 162 | coords = coords - coords_grid(B, H, W, coords.device) * (coords != 0) 163 | return F.hardtanh(coords/pad_size) 164 | -------------------------------------------------------------------------------- /models/raft/raft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .update import BasicUpdateBlock, SmallUpdateBlock 6 | from .extractor import BasicEncoder, SmallEncoder 7 | from .corr import CorrBlock, CorrBlock_fb, AlternateCorrBlock 8 | from .utils.utils import bilinear_sampler, coords_grid, upflow8 9 | 10 | try: 11 | autocast = torch.cuda.amp.autocast 12 | except: 13 | # dummy autocast for PyTorch < 1.6 14 | class autocast: 15 | def __init__(self, enabled): 16 | pass 17 | 18 | def __enter__(self): 19 | pass 20 | 21 | def __exit__(self, *args): 22 | pass 23 | 24 | 25 | class RAFT(nn.Module): 26 | def __init__(self, args): 27 | super().__init__() 28 | self.args = args 29 | self.iters = args.iters 30 | self.enc_dims = [32, 32, 64, 96] if args.small else [64, 64, 96, 128, 256] 31 | self.out_dim_f = 128 if args.small else 256 32 | # self.num_samples = self.out_dim_f if args.num_samples is None else args.num_samples 33 | 34 | if args.small: 35 | self.hidden_dim = hdim = 96 36 | self.context_dim = cdim = 64 37 | args.corr_levels = 4 38 | args.corr_radius = 3 39 | 40 | else: 41 | self.hidden_dim = hdim = 128 42 | self.context_dim = cdim = 128 43 | args.corr_levels = 4 44 | args.corr_radius = 4 45 | 46 | self.out_dim_c = hdim+cdim 47 | 48 | if 'dropout' not in self.args: 49 | self.args.dropout = 0 50 | 51 | if 'alternate_corr' not in self.args: 52 | self.args.alternate_corr = False 53 | 54 | # feature network, context network, and update block 55 | if args.small: 56 | self.fnet = SmallEncoder(layers_dims=self.enc_dims, output_dim=self.out_dim_f, norm_fn='instance', dropout=args.dropout, pad_mode=args.pad_mode) 57 | self.cnet = SmallEncoder(layers_dims=self.enc_dims, output_dim=self.out_dim_c, norm_fn='none', dropout=args.dropout) 58 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim, out_dim=args.dim_out_flow) 59 | 60 | else: 61 | self.fnet = BasicEncoder(layers_dims=self.enc_dims, output_dim=self.out_dim_f, norm_fn='instance', dropout=args.dropout, pad_mode=args.pad_mode) 62 | self.cnet = BasicEncoder(layers_dims=self.enc_dims, output_dim=self.out_dim_c, norm_fn='instance', dropout=args.dropout) 63 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim, out_dim=args.dim_out_flow) 64 | 65 | 66 | def 
freeze_bn(self): 67 | for m in self.modules(): 68 | if isinstance(m, nn.BatchNorm2d): 69 | m.eval() 70 | 71 | def initialize_flow(self, img, ctxt): 72 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 73 | _, _, H, W = img.shape 74 | N = ctxt.size(0) 75 | coords0 = coords_grid(N, H // 8, W // 8, device=img.device) 76 | coords1 = coords_grid(N, H // 8, W // 8, device=img.device) 77 | 78 | # optical flow computed as difference: flow = coords1 - coords0 79 | return coords0, coords1 80 | 81 | def upsample_flow(self, flow, mask): 82 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 83 | N, _, H, W = flow.shape 84 | mask = mask.view(N, 1, 9, 8, 8, H, W) 85 | mask = torch.softmax(mask, dim=2) 86 | 87 | up_flow = F.unfold(8 * flow, [3, 3], padding=1) 88 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 89 | 90 | up_flow = torch.einsum('bdn..., bdn...->bd...', mask, up_flow) 91 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 92 | return up_flow.reshape(N, 2, 8 * H, 8 * W) 93 | 94 | def compute_flow(self, image, ctxt, corr_fn, flow_init=None, return_last_flow_only=False): 95 | 96 | net, inp = torch.split(ctxt, [self.hidden_dim, self.context_dim], dim=1) 97 | net = torch.tanh(net) 98 | inp = torch.relu(inp) 99 | 100 | coords0, coords1 = self.initialize_flow(image, ctxt) 101 | 102 | if flow_init is not None: 103 | coords1 = coords1 + flow_init 104 | 105 | flow_predictions = [] 106 | for _ in range(self.iters): 107 | coords1 = coords1.detach() 108 | corr = corr_fn(coords1) # index correlation volume 109 | 110 | flow = coords1 - coords0 111 | with autocast(enabled=self.args.mixed_precision): 112 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 113 | 114 | # F(t+1) = F(t) + \Delta(t) 115 | coords1 = coords1 + delta_flow 116 | flow = coords1 - coords0 117 | 118 | # upsample predictions 119 | if up_mask is None: 120 | flow_up = upflow8(flow, mode=self.args.upsampling_mode) 121 | else: 122 | flow_up = self.upsample_flow(flow, up_mask) 123 | 124 | flow_predictions.append(flow_up) 125 | 126 | if return_last_flow_only: 127 | return flow_up 128 | 129 | flow_predictions.reverse() 130 | return flow_predictions 131 | 132 | 133 | def forward(self, images, flow_init=None, return_last_flow_only=False, fwd_bwd=True, suffix=''): 134 | """ Estimate optical flow between pair of frames """ 135 | outputs = {} 136 | 137 | dims = images.size() 138 | images_norm = (2 * (images / 255.0) - 1.0).contiguous() 139 | images_norm_flatten = images_norm.flatten(end_dim=-4) 140 | 141 | # run the feature network 142 | with autocast(enabled=self.args.mixed_precision): 143 | fmaps = self.fnet(images_norm_flatten).unflatten(dim=0, sizes=dims[:2]) 144 | 145 | corrs_fn = CorrBlock_fb(fmaps[:, 0], fmaps[:, 1], radius=self.args.corr_radius, fwd_bwd=fwd_bwd) 146 | 147 | if fwd_bwd: 148 | with autocast(enabled=self.args.mixed_precision): 149 | ctxts = self.cnet(images_norm_flatten) 150 | 151 | flows = self.compute_flow(images_norm_flatten, ctxts, corrs_fn, return_last_flow_only=return_last_flow_only) 152 | 153 | if return_last_flow_only: 154 | outputs['flows' + suffix] = flows.unflatten(dim=0, sizes=dims[:2]) 155 | else: 156 | outputs['flows' + suffix] = torch.stack(flows, dim=-4).unflatten(dim=0, sizes=dims[:2]) 157 | 158 | else: 159 | images_norm1 = images_norm[:, 0] 160 | with autocast(enabled=self.args.mixed_precision): 161 | ctxt1 = self.cnet(images_norm1) 162 | 163 | outputs['flow_f' + suffix] = self.compute_flow(images_norm1, ctxt1, corrs_fn, 
flow_init=flow_init, return_last_flow_only=True) 164 | 165 | return outputs 166 | -------------------------------------------------------------------------------- /augmentations/augmentations_corrections.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import random 4 | from PIL import Image 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | from .augmentations import BasicAugmentor 11 | 12 | class AugmentorCorrections(BasicAugmentor): 13 | def __init__(self, args): 14 | super(AugmentorCorrections, self).__init__(args) 15 | 16 | def spatial_transform_correc_full_size_warping(self, img1, img2, img1_correc, img2_correc): 17 | 18 | # randomly sample scale 19 | if np.random.rand() < self.spatial_aug_prob: 20 | ht, wd = img1.shape[:2] 21 | min_scale = np.maximum( 22 | (self.crop_size[0] + 8) / float(ht), 23 | (self.crop_size[1] + 8) / float(wd)) 24 | 25 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 26 | scale_x = scale 27 | scale_y = scale 28 | if np.random.rand() < self.stretch_prob: 29 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 30 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 31 | 32 | scale_x = np.clip(scale_x, min_scale, None) 33 | scale_y = np.clip(scale_y, min_scale, None) 34 | 35 | # rescale the images 36 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 37 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 38 | img1_correc = cv2.resize(img1_correc, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 39 | img2_correc = cv2.resize(img2_correc, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 40 | 41 | if self.do_flip: 42 | if np.random.rand() < self.h_flip_prob: # h-flip 43 | img1 = img1[:, ::-1] 44 | img2 = img2[:, ::-1] 45 | img1_correc = img1_correc[:, ::-1] 46 | img2_correc = img2_correc[:, ::-1] 47 | 48 | if np.random.rand() < self.v_flip_prob: # v-flip 49 | img1 = img1[::-1, :] 50 | img2 = img2[::-1, :] 51 | img1_correc = img1_correc[::-1, :] 52 | img2_correc = img2_correc[::-1, :] 53 | 54 | uncropped_img1, uncropped_img2 = img1.copy(), img2.copy() 55 | uncropped_img1_correc, uncropped_img2_correc = img1_correc.copy(), img2_correc.copy() 56 | 57 | H, W, _ = uncropped_img1.shape 58 | orig_dims = np.array([H, W]) 59 | 60 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 61 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 62 | 63 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 64 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 65 | img1_correc = img1_correc[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 66 | img2_correc = img2_correc[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 67 | 68 | pad_right = uncropped_img1.shape[1] - (x0 + self.crop_size[1]) 69 | pad_bottom = uncropped_img1.shape[0] - (y0 + self.crop_size[0]) 70 | pad_params = np.array([x0, pad_right, y0, pad_bottom]) 71 | 72 | uncropped_img1 = np.pad(uncropped_img1, ((0, self.h_max - H), (0, self.w_max - W), (0, 0))) 73 | uncropped_img2 = np.pad(uncropped_img2, ((0, self.h_max - H), (0, self.w_max - W), (0, 0))) 74 | uncropped_img1_correc = np.pad(uncropped_img1_correc, ((0, self.h_max - H), (0, self.w_max - W), (0, 0))) 75 | uncropped_img2_correc = np.pad(uncropped_img2_correc, ((0, self.h_max - H), (0, self.w_max - W), (0, 0))) 76 | 77 | return img1, img2, uncropped_img1, uncropped_img2, 
img1_correc, img2_correc, uncropped_img1_correc, uncropped_img2_correc, pad_params, orig_dims 78 | 79 | def __call__(self, img1, img2): 80 | example = {} 81 | if self.args.use_full_size_warping: 82 | 83 | if self.args.no_photo_aug: 84 | img1_aug, img2_aug = img1, img2 85 | else: 86 | img1_aug, img2_aug = self.global_color_transform(img1, img2) 87 | 88 | img1, img2, uncropped_img1, uncropped_img2, img1_aug, img2_aug, uncropped_img1_aug, uncropped_img2_aug, pad_params, orig_dims \ 89 | = self.spatial_transform_correc_full_size_warping(img1, img2, img1_aug, img2_aug) 90 | 91 | else: 92 | img1, img2 = self.spatial_transform(img1, img2) 93 | if self.args.no_photo_aug: 94 | img1_aug, img2_aug = img1, img2 95 | else: 96 | img1_aug, img2_aug = self.global_color_transform(img1, img2) 97 | 98 | if self.args.no_photo_aug: 99 | ht, wd = img1_aug.shape[:2] 100 | mask_eraser1 = np.ones((ht, wd, 1)) 101 | mask_eraser2 = np.ones_like(mask_eraser1) 102 | else: 103 | if self.args.random_eraser: 104 | img1_aug, img2_aug, mask_eraser1, mask_eraser2 = self.raft_eraser_bidirectional(img1_aug, img2_aug) 105 | else: 106 | ht, wd = img1_aug.shape[:2] 107 | mask_eraser1 = np.ones((ht, wd, 1)) 108 | mask_eraser2 = np.ones_like(mask_eraser1) 109 | 110 | if self.args.use_full_size_warping: 111 | x0, _, y0, _ = pad_params 112 | uncropped_img1_aug[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] = img1_aug 113 | uncropped_img2_aug[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] = img2_aug 114 | 115 | example.update({ 116 | 'ims_uncropped': np.stack([np.ascontiguousarray(uncropped_img1), np.ascontiguousarray(uncropped_img2)]), 117 | 'ims_aug_uncropped': np.stack([np.ascontiguousarray(uncropped_img1_aug), np.ascontiguousarray(uncropped_img2_aug)]), 118 | 'pad_params': pad_params, 119 | 'orig_dims': orig_dims, 120 | }) 121 | 122 | img1_aug_stud = self.selfsup_transform(img1_aug) 123 | img2_aug_stud = self.selfsup_transform(img2_aug) 124 | mask_eraser1_stud = self.selfsup_transform(mask_eraser1) 125 | mask_eraser2_stud = self.selfsup_transform(mask_eraser2) 126 | valid = np.ones((1, 1, *self.crop_size)) 127 | 128 | example.update({ 129 | 'ims': np.stack([np.ascontiguousarray(img1), np.ascontiguousarray(img2)]), 130 | 'ims_aug': np.stack([np.ascontiguousarray(img1_aug), np.ascontiguousarray(img2_aug)]), 131 | 'ims_aug_stud': np.stack([np.ascontiguousarray(img1_aug_stud), np.ascontiguousarray(img2_aug_stud)]), 132 | 'masks_eraser': np.stack([np.ascontiguousarray(mask_eraser1), np.ascontiguousarray(mask_eraser2)]), 133 | 'masks_eraser_stud': np.stack([np.ascontiguousarray(mask_eraser1_stud), np.ascontiguousarray(mask_eraser2_stud)]), 134 | 'valid': np.ascontiguousarray(valid), 135 | }) 136 | 137 | return example 138 | -------------------------------------------------------------------------------- /utils/masks_and_occlusions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from functools import partial 6 | 7 | from .coords_and_warp import Warp 8 | 9 | 10 | class Occlusions(Warp): 11 | def __init__(self, occlusions, use_full_size_warping): 12 | super(Occlusions, self).__init__() 13 | 14 | if occlusions == 'brox': 15 | self.occlusions_estimator = self.occlusions_brox 16 | elif occlusions == 'wang': 17 | self.occlusions_estimator = self.occlusions_wang 18 | elif occlusions == 'none': 19 | self.occlusions_estimator = self.occlusions_none 20 | else: 21 | raise NotImplementedError 22 | 23 | 
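# Estimator summary: 'brox' checks forward-backward flow consistency, 'wang'
# counts how often each pixel is reached by the (bilinearly splatted) backward
# flow, and 'none' disables occlusion masking (see the methods below).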
if use_full_size_warping: 24 | self.get_occlusions_masks = self.occlusions_masks_full_size 25 | else: 26 | self.get_occlusions_masks = self.occlusions_masks 27 | 28 | def mask_out_flow(self, coords, margin=0): 29 | '''Mask boundary occlusions''' 30 | H, W = coords.size()[-2:] 31 | max_height, max_width = H-1-margin, W-1-margin 32 | mask = torch.logical_and( 33 | torch.logical_and(coords[..., 0:1, :, :] >= margin, coords[..., 0:1, :, :] <= max_width), 34 | torch.logical_and(coords[..., 1:2, :, :] >= margin, coords[..., 1:2, :, :] <= max_height)) 35 | return mask 36 | 37 | 38 | def mask_out_flow_full_size_warp(self, coords, pad_params, margin=0): 39 | '''Mask boundary occlusions when full-size warping is available''' 40 | 41 | coords_dims = coords.size() 42 | pad_left, pad_right, pad_top, pad_bottom = torch.split(-pad_params.view(-1, 4, *(1,)*(len(coords_dims)-2)), 1, 1) 43 | pad_left += margin 44 | pad_right += margin 45 | pad_top += margin 46 | pad_bottom += margin 47 | H, W = coords_dims[-2:] 48 | max_height, max_width = H-1-pad_bottom, W-1-pad_right 49 | mask = torch.logical_and( 50 | torch.logical_and(coords[..., 0:1, :, :] >= pad_left, coords[..., 0:1, :, :] <= max_width), 51 | torch.logical_and(coords[..., 1:2, :, :] >= pad_top, coords[..., 1:2, :, :] <= max_height)) 52 | return mask 53 | 54 | def occlusions_none(self, *args, **kwargs): 55 | return 1. 56 | 57 | def occlusions_brox(self, forward_flow, forward_coords, backward_flow, backward_coords): 58 | # Warp backward flow 59 | dims = backward_flow.size() 60 | backward_flow_warp = self.warp(backward_flow.reshape(-1, *dims[-3:]), forward_coords.reshape(-1, *dims[-3:])).view(*dims) 61 | 62 | # Compute occlusions based on forward-backward consistency. 63 | fb_sq_diff = torch.sum((forward_flow + backward_flow_warp)**2, axis=-3, keepdims=True) 64 | fb_sum_sq = torch.sum(forward_flow**2 + backward_flow_warp**2, axis=-3, keepdims=True) 65 | 66 | occ = fb_sq_diff < 0.01 * fb_sum_sq + 0.5 67 | 68 | return occ.view(*dims[:-3], 1, *dims[-2:]) 69 | 70 | def occlusions_wang(self, forward_flow, forward_coords, backward_flow, backward_coords): 71 | dims = backward_coords.size() 72 | B, H, W = np.prod(dims[:-3]), dims[-2], dims[-1] 73 | backward_coords = backward_coords.reshape(-1, *dims[-3:]) 74 | coords_floor = torch.floor(backward_coords) 75 | coords_offset = backward_coords - coords_floor 76 | coords_floor = coords_floor.long() 77 | 78 | idx_batch_offset = torch.arange(B, device=backward_coords.device).view(B, 1, 1).expand(-1, H, W) * H * W 79 | 80 | coords_floor_flattened = coords_floor.permute(0, 2, 3, 1).reshape(-1, 2) 81 | coords_offset_flattened = coords_offset.permute(0, 2, 3, 1).reshape(-1, 2) 82 | idx_batch_offset_flattened = idx_batch_offset.reshape(-1) 83 | 84 | # Initialize results. 85 | idxs_list = [] 86 | weights_list = [] 87 | 88 | # Loop over differences di and dj to the four neighboring pixels. 89 | for di in range(2): 90 | for dj in range(2): 91 | idxs_i = coords_floor_flattened[:, 1] + di 92 | idxs_j = coords_floor_flattened[:, 0] + dj 93 | 94 | 95 | idxs = idx_batch_offset_flattened + idxs_i * W + idxs_j 96 | 97 | mask = torch.where(torch.logical_and(torch.logical_and(idxs_i >= 0, idxs_i < H), 98 | torch.logical_and(idxs_j >= 0, idxs_j < W)))[0] 99 | 100 | valid_idxs = torch.gather(idxs, 0, mask) 101 | valid_offsets = torch.stack([torch.gather(coords_offset_flattened[:, 0], 0, mask), torch.gather(coords_offset_flattened[:, 1], 0, mask)], 1) 102 | 103 | # Compute weights according to bilinear interpolation. 
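# For the corner at offset (di, dj) from the floored target coordinate, the
# bilinear weight is (1 - dy) * (1 - dx) when di = dj = 0, dy * dx when
# di = dj = 1, and the mixed products otherwise; the (-1)**di / (-1)**dj terms
# below express all four cases in one formula.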
104 | weights_i = (1. - di) - (-1)**di * valid_offsets[:, 1] 105 | weights_j = (1. - dj) - (-1)**dj * valid_offsets[:, 0] 106 | weights = weights_i * weights_j 107 | 108 | # Append indices and weights to the corresponding list. 109 | idxs_list.append(valid_idxs) 110 | weights_list.append(weights) 111 | 112 | # Concatenate everything. 113 | idxs = torch.cat(idxs_list, 0) 114 | weights = torch.cat(weights_list, 0) 115 | 116 | counts = torch.zeros(B * H * W, device=backward_coords.device).scatter_add(0, idxs, weights).reshape(B, 1, H, W).clamp(0, 1) 117 | 118 | return counts.view(*dims[:-3], 1, H, W) 119 | 120 | def occlusions_masks(self, flow_f, coords_f, flow_b, coords_b, masks_eraser=None, pad_params=None): 121 | '''Get final occlusion masks''' 122 | mask_out = self.mask_out_flow(coords_f) 123 | occ = self.occlusions_estimator(forward_flow=flow_f, forward_coords=coords_f, backward_flow=flow_b, backward_coords=coords_b) 124 | if len(masks_eraser.size()) + 1 == len(mask_out.size()): 125 | masks_eraser = masks_eraser.unsqueeze(-4) 126 | return mask_out * occ * masks_eraser, mask_out, None 127 | 128 | def occlusions_masks_full_size(self, flow_f, coords_f, flow_b, coords_b, masks_eraser=None, pad_params=None): 129 | '''Get final occlusion masks when full-size warping is available''' 130 | mask_out = self.mask_out_flow(coords_f) 131 | mask_out_pad = self.mask_out_flow_full_size_warp(coords_f, pad_params) 132 | occ = self.occlusions_estimator(forward_flow=flow_f, forward_coords=coords_f, backward_flow=flow_b, backward_coords=coords_b) + ~mask_out 133 | if len(masks_eraser.size()) + 1 == len(mask_out.size()): 134 | masks_eraser = masks_eraser.unsqueeze(-4) 135 | return mask_out_pad * occ.float().clamp(0,1) * masks_eraser, mask_out, mask_out_pad 136 | 137 | @torch.no_grad() 138 | def forward(self, flow_f, coords_f, flow_b, coords_b, masks_eraser=None, pad_params=None): 139 | return self.get_occlusions_masks(flow_f, coords_f, flow_b, coords_b, masks_eraser, pad_params) 140 | -------------------------------------------------------------------------------- /models/gma/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | if not stride == 1: 20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | if not stride == 1: 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | if not stride == 1: 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = nn.Sequential() 37 | if not stride == 1: 38 | self.norm3 = nn.Sequential() 39 | 40 | if stride == 1: 41 | self.downsample = None 42 | 43 | else: 44 
| self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | def forward(self, x): 48 | y = x 49 | y = self.relu(self.norm1(self.conv1(y))) 50 | y = self.relu(self.norm2(self.conv2(y))) 51 | 52 | if self.downsample is not None: 53 | x = self.downsample(x) 54 | 55 | return self.relu(x + y) 56 | 57 | 58 | class BottleneckBlock(nn.Module): 59 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 60 | super(BottleneckBlock, self).__init__() 61 | 62 | self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0) 63 | self.conv2 = nn.Conv2d(planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride) 64 | self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) 65 | self.relu = nn.ReLU(inplace=True) 66 | 67 | num_groups = planes // 8 68 | 69 | if norm_fn == 'group': 70 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) 71 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) 72 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 73 | if not stride == 1: 74 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 75 | 76 | elif norm_fn == 'batch': 77 | self.norm1 = nn.BatchNorm2d(planes // 4) 78 | self.norm2 = nn.BatchNorm2d(planes // 4) 79 | self.norm3 = nn.BatchNorm2d(planes) 80 | if not stride == 1: 81 | self.norm4 = nn.BatchNorm2d(planes) 82 | 83 | elif norm_fn == 'instance': 84 | self.norm1 = nn.InstanceNorm2d(planes // 4) 85 | self.norm2 = nn.InstanceNorm2d(planes // 4) 86 | self.norm3 = nn.InstanceNorm2d(planes) 87 | if not stride == 1: 88 | self.norm4 = nn.InstanceNorm2d(planes) 89 | 90 | elif norm_fn == 'none': 91 | self.norm1 = nn.Sequential() 92 | self.norm2 = nn.Sequential() 93 | self.norm3 = nn.Sequential() 94 | if not stride == 1: 95 | self.norm4 = nn.Sequential() 96 | 97 | if stride == 1: 98 | self.downsample = None 99 | 100 | else: 101 | self.downsample = nn.Sequential( 102 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 103 | 104 | def forward(self, x): 105 | y = x 106 | y = self.relu(self.norm1(self.conv1(y))) 107 | y = self.relu(self.norm2(self.conv2(y))) 108 | y = self.relu(self.norm3(self.conv3(y))) 109 | 110 | if self.downsample is not None: 111 | x = self.downsample(x) 112 | 113 | return self.relu(x + y) 114 | 115 | 116 | class BasicEncoder(nn.Module): 117 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 118 | super(BasicEncoder, self).__init__() 119 | self.norm_fn = norm_fn 120 | 121 | if self.norm_fn == 'group': 122 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 123 | 124 | elif self.norm_fn == 'batch': 125 | self.norm1 = nn.BatchNorm2d(64) 126 | 127 | elif self.norm_fn == 'instance': 128 | self.norm1 = nn.InstanceNorm2d(64) 129 | 130 | elif self.norm_fn == 'none': 131 | self.norm1 = nn.Sequential() 132 | 133 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 134 | self.relu1 = nn.ReLU(inplace=True) 135 | 136 | self.in_planes = 64 137 | self.layer1 = self._make_layer(64, stride=1) 138 | self.layer2 = self._make_layer(96, stride=2) 139 | self.layer3 = self._make_layer(128, stride=2) 140 | 141 | # output convolution 142 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 143 | 144 | self.dropout = None 145 | if dropout > 0: 146 | self.dropout = nn.Dropout2d(p=dropout) 147 | 148 | for m in self.modules(): 149 | if isinstance(m, nn.Conv2d): 150 | nn.init.kaiming_normal_(m.weight, mode='fan_out', 
nonlinearity='relu') 151 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 152 | if m.weight is not None: 153 | nn.init.constant_(m.weight, 1) 154 | if m.bias is not None: 155 | nn.init.constant_(m.bias, 0) 156 | 157 | def _make_layer(self, dim, stride=1): 158 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 159 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 160 | layers = (layer1, layer2) 161 | 162 | self.in_planes = dim 163 | return nn.Sequential(*layers) 164 | 165 | def forward(self, x): 166 | 167 | # if input is list, combine batch dimension 168 | is_list = isinstance(x, tuple) or isinstance(x, list) 169 | if is_list: 170 | batch_dim = x[0].shape[0] 171 | x = torch.cat(x, dim=0) 172 | 173 | x = self.conv1(x) 174 | x = self.norm1(x) 175 | x = self.relu1(x) 176 | 177 | x = self.layer1(x) 178 | x = self.layer2(x) 179 | x = self.layer3(x) 180 | 181 | x = self.conv2(x) 182 | 183 | if self.training and self.dropout is not None: 184 | x = self.dropout(x) 185 | 186 | if is_list: 187 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 188 | 189 | return x 190 | -------------------------------------------------------------------------------- /models/scv/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | from torch_scatter import scatter_softmax, scatter_add 6 | 7 | 8 | class InputPadder: 9 | """ Pads images such that dimensions are divisible by 8 """ 10 | def __init__(self, dims, mode='sintel'): 11 | self.ht, self.wd = dims[-2:] 12 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 13 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 14 | if mode == 'sintel': 15 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 16 | else: 17 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 18 | 19 | def pad(self, *inputs): 20 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 21 | 22 | def unpad(self,x): 23 | ht, wd = x.shape[-2:] 24 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 25 | return x[..., c[0]:c[1], c[2]:c[3]] 26 | 27 | 28 | def forward_interpolate(flow): 29 | flow = flow.detach().cpu().numpy() 30 | dx, dy = flow[0], flow[1] 31 | 32 | ht, wd = dx.shape 33 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 34 | 35 | x1 = x0 + dx 36 | y1 = y0 + dy 37 | 38 | x1 = x1.reshape(-1) 39 | y1 = y1.reshape(-1) 40 | dx = dx.reshape(-1) 41 | dy = dy.reshape(-1) 42 | 43 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 44 | x1 = x1[valid] 45 | y1 = y1[valid] 46 | dx = dx[valid] 47 | dy = dy[valid] 48 | 49 | flow_x = interpolate.griddata( 50 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 51 | 52 | flow_y = interpolate.griddata( 53 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 54 | 55 | flow = np.stack([flow_x, flow_y], axis=0) 56 | return torch.from_numpy(flow).float() 57 | 58 | 59 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 60 | """ Wrapper for grid_sample, uses pixel coordinates """ 61 | H, W = img.shape[-2:] 62 | xgrid, ygrid = coords.split([1,1], dim=-1) 63 | xgrid = 2*xgrid/(W-1) - 1 64 | ygrid = 2*ygrid/(H-1) - 1 65 | 66 | grid = torch.cat([xgrid, ygrid], dim=-1) 67 | img = F.grid_sample(img, grid, align_corners=True) 68 | 69 | if mask: 70 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 71 | return img, mask.float() 72 | 73 | return img 74 | 75 | 76 | def 
coords_grid(batch, ht, wd): 77 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 78 | coords = torch.stack(coords[::-1], dim=0).float() 79 | return coords[None].expand(batch, -1, -1, -1) 80 | 81 | 82 | def coords_grid_y_first(batch, ht, wd): 83 | """Place y grid before x grid""" 84 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 85 | coords = torch.stack(coords, dim=0).int() 86 | return coords[None].expand(batch, -1, -1, -1) 87 | 88 | 89 | def soft_argmax(corr_me, B, H1, W1): 90 | # Implement soft argmin 91 | coords, feats = corr_me.decomposed_coordinates_and_features 92 | 93 | # Computing soft argmin 94 | flow_pred = torch.zeros(B, 2, H1, W1).to(corr_me.device) 95 | for batch, (coord, feat) in enumerate(zip(coords, feats)): 96 | coord_img_1 = coord[:, :2].to(corr_me.device) 97 | coord_img_2 = coord[:, 2:].to(corr_me.device) 98 | # relative positions (flow hypotheses) 99 | rel_pos = (coord_img_2 - coord_img_1) 100 | # augmented indices 101 | aug_coord_img_1 = (coord_img_1[:, 0:1] * W1 + coord_img_1[:, 1:2]).long() 102 | # run softmax on the score 103 | weight = scatter_softmax(feat, aug_coord_img_1, dim=0) 104 | rel_pos_weighted = weight * rel_pos 105 | out = scatter_add(rel_pos_weighted, aug_coord_img_1, dim=0) 106 | # Need to permute (y, x) to (x, y) for flow 107 | flow_pred[batch] = out[:, [1,0]].view(H1, W1, 2).permute(2, 0, 1) 108 | return flow_pred 109 | 110 | 111 | def upflow8(flow, mode='bilinear'): 112 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 113 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 114 | 115 | 116 | def upflow4(flow, mode='bilinear'): 117 | new_size = (4 * flow.shape[2], 4 * flow.shape[3]) 118 | return 4 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 119 | 120 | 121 | def upflow2(flow, mode='bilinear'): 122 | new_size = (2 * flow.shape[2], 2 * flow.shape[3]) 123 | return 2 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 124 | 125 | 126 | def downflow8(flow, mode='bilinear'): 127 | new_size = (flow.shape[2] // 8, flow.shape[3] // 8) 128 | return F.interpolate(flow, size=new_size, mode=mode, align_corners=True) / 8 129 | 130 | 131 | def downflow4(flow, mode='bilinear'): 132 | new_size = (flow.shape[2] // 4, flow.shape[3] // 4) 133 | return F.interpolate(flow, size=new_size, mode=mode, align_corners=True) / 4 134 | 135 | 136 | def compute_interpolation_weights(yx_warped): 137 | # yx_warped: [N, 2] 138 | y_warped = yx_warped[:, 0] 139 | x_warped = yx_warped[:, 1] 140 | 141 | # elementwise operations below 142 | y_f = torch.floor(y_warped) 143 | y_c = y_f + 1 144 | x_f = torch.floor(x_warped) 145 | x_c = x_f + 1 146 | 147 | w0 = (y_c - y_warped) * (x_c - x_warped) 148 | w1 = (y_warped - y_f) * (x_c - x_warped) 149 | w2 = (y_c - y_warped) * (x_warped - x_f) 150 | w3 = (y_warped - y_f) * (x_warped - x_f) 151 | 152 | weights = [w0, w1, w2, w3] 153 | indices = [torch.stack([y_f, x_f], dim=1), torch.stack([y_c, x_f], dim=1), 154 | torch.stack([y_f, x_c], dim=1), torch.stack([y_c, x_c], dim=1)] 155 | weights = torch.cat(weights, dim=1) 156 | indices = torch.cat(indices, dim=2) 157 | # indices = torch.cat(indices, dim=0) # [4*N, 2] 158 | 159 | return weights, indices 160 | 161 | # weights, indices = compute_interpolation_weights(xy_warped, b, h_i, w_i) 162 | 163 | 164 | def compute_inverse_interpolation_img(weights, indices, img, b, h_i, w_i): 165 | """ 166 | weights: [b, h*w] 167 | indices: [b, h*w] 168 | img: [b, h*w, a, b, c, ...] 
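Each value of img is splatted onto its four neighbouring integer grid cells
using the bilinear weights computed by compute_interpolation_weights, and the
contributions are accumulated with scatter_add_ (forward warping by splatting
rather than backward sampling).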
169 | """ 170 | w0, w1, w2, w3 = weights 171 | ff_idx, cf_idx, fc_idx, cc_idx = indices 172 | 173 | k = len(img.size()) - len(w0.size()) 174 | img_0 = w0[(...,) + (None,) * k] * img 175 | img_1 = w1[(...,) + (None,) * k] * img 176 | img_2 = w2[(...,) + (None,) * k] * img 177 | img_3 = w3[(...,) + (None,) * k] * img 178 | 179 | img_out = torch.zeros(b, h_i * w_i, *img.shape[2:]).type_as(img) 180 | 181 | ff_idx = torch.clamp(ff_idx, min=0, max=h_i * w_i - 1) 182 | cf_idx = torch.clamp(cf_idx, min=0, max=h_i * w_i - 1) 183 | fc_idx = torch.clamp(fc_idx, min=0, max=h_i * w_i - 1) 184 | cc_idx = torch.clamp(cc_idx, min=0, max=h_i * w_i - 1) 185 | 186 | img_out.scatter_add_(1, ff_idx[(...,) + (None,) * k].expand_as(img_0), img_0) 187 | img_out.scatter_add_(1, cf_idx[(...,) + (None,) * k].expand_as(img_1), img_1) 188 | img_out.scatter_add_(1, fc_idx[(...,) + (None,) * k].expand_as(img_2), img_2) 189 | img_out.scatter_add_(1, cc_idx[(...,) + (None,) * k].expand_as(img_3), img_3) 190 | 191 | return img_out # [b, h_i*w_i, ...] 192 | -------------------------------------------------------------------------------- /models/gma/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | # from torch_scatter import scatter_softmax, scatter_add 6 | 7 | 8 | class InputPadder: 9 | """ Pads images such that dimensions are divisible by 8 """ 10 | def __init__(self, dims, mode='sintel'): 11 | self.ht, self.wd = dims[-2:] 12 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 13 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 14 | if mode == 'sintel': 15 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 16 | else: 17 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 18 | 19 | def pad(self, *inputs): 20 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 21 | 22 | def unpad(self,x): 23 | ht, wd = x.shape[-2:] 24 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 25 | return x[..., c[0]:c[1], c[2]:c[3]] 26 | 27 | 28 | def forward_interpolate(flow): 29 | flow = flow.detach().cpu().numpy() 30 | dx, dy = flow[0], flow[1] 31 | 32 | ht, wd = dx.shape 33 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 34 | 35 | x1 = x0 + dx 36 | y1 = y0 + dy 37 | 38 | x1 = x1.reshape(-1) 39 | y1 = y1.reshape(-1) 40 | dx = dx.reshape(-1) 41 | dy = dy.reshape(-1) 42 | 43 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 44 | x1 = x1[valid] 45 | y1 = y1[valid] 46 | dx = dx[valid] 47 | dy = dy[valid] 48 | 49 | flow_x = interpolate.griddata( 50 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 51 | 52 | flow_y = interpolate.griddata( 53 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 54 | 55 | flow = np.stack([flow_x, flow_y], axis=0) 56 | return torch.from_numpy(flow).float() 57 | 58 | 59 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 60 | """ Wrapper for grid_sample, uses pixel coordinates """ 61 | H, W = img.shape[-2:] 62 | xgrid, ygrid = coords.split([1,1], dim=-1) 63 | xgrid = 2*xgrid/(W-1) - 1 64 | ygrid = 2*ygrid/(H-1) - 1 65 | 66 | grid = torch.cat([xgrid, ygrid], dim=-1) 67 | img = F.grid_sample(img, grid, align_corners=True) 68 | 69 | if mask: 70 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 71 | return img, mask.float() 72 | 73 | return img 74 | 75 | 76 | def coords_grid(batch, ht, wd, device): 77 | coords = torch.meshgrid(torch.arange(ht, 
device=device), torch.arange(wd, device=device)) 78 | coords = torch.stack(coords[::-1], dim=0).float() 79 | return coords[None].expand(batch, -1, -1, -1) 80 | 81 | 82 | def coords_grid_y_first(batch, ht, wd): 83 | """Place y grid before x grid""" 84 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 85 | coords = torch.stack(coords, dim=0).int() 86 | return coords[None].expand(batch, -1, -1, -1) 87 | 88 | 89 | def soft_argmax(corr_me, B, H1, W1): 90 | # Implement soft argmin 91 | coords, feats = corr_me.decomposed_coordinates_and_features 92 | 93 | # Computing soft argmin 94 | flow_pred = torch.zeros(B, 2, H1, W1).to(corr_me.device) 95 | for batch, (coord, feat) in enumerate(zip(coords, feats)): 96 | coord_img_1 = coord[:, :2].to(corr_me.device) 97 | coord_img_2 = coord[:, 2:].to(corr_me.device) 98 | # relative positions (flow hypotheses) 99 | rel_pos = (coord_img_2 - coord_img_1) 100 | # augmented indices 101 | aug_coord_img_1 = (coord_img_1[:, 0:1] * W1 + coord_img_1[:, 1:2]).long() 102 | # run softmax on the score 103 | weight = scatter_softmax(feat, aug_coord_img_1, dim=0) 104 | rel_pos_weighted = weight * rel_pos 105 | out = scatter_add(rel_pos_weighted, aug_coord_img_1, dim=0) 106 | # Need to permute (y, x) to (x, y) for flow 107 | flow_pred[batch] = out[:, [1,0]].view(H1, W1, 2).permute(2, 0, 1) 108 | return flow_pred 109 | 110 | 111 | def upflow8(flow, mode='bilinear'): 112 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 113 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 114 | 115 | 116 | def upflow4(flow, mode='bilinear'): 117 | new_size = (4 * flow.shape[2], 4 * flow.shape[3]) 118 | return 4 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 119 | 120 | 121 | def upflow2(flow, mode='bilinear'): 122 | new_size = (2 * flow.shape[2], 2 * flow.shape[3]) 123 | return 2 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 124 | 125 | 126 | def downflow8(flow, mode='bilinear'): 127 | new_size = (flow.shape[2] // 8, flow.shape[3] // 8) 128 | return F.interpolate(flow, size=new_size, mode=mode, align_corners=True) / 8 129 | 130 | 131 | def downflow4(flow, mode='bilinear'): 132 | new_size = (flow.shape[2] // 4, flow.shape[3] // 4) 133 | return F.interpolate(flow, size=new_size, mode=mode, align_corners=True) / 4 134 | 135 | 136 | def compute_interpolation_weights(yx_warped): 137 | # yx_warped: [N, 2] 138 | y_warped = yx_warped[:, 0] 139 | x_warped = yx_warped[:, 1] 140 | 141 | # elementwise operations below 142 | y_f = torch.floor(y_warped) 143 | y_c = y_f + 1 144 | x_f = torch.floor(x_warped) 145 | x_c = x_f + 1 146 | 147 | w0 = (y_c - y_warped) * (x_c - x_warped) 148 | w1 = (y_warped - y_f) * (x_c - x_warped) 149 | w2 = (y_c - y_warped) * (x_warped - x_f) 150 | w3 = (y_warped - y_f) * (x_warped - x_f) 151 | 152 | weights = [w0, w1, w2, w3] 153 | indices = [torch.stack([y_f, x_f], dim=1), torch.stack([y_c, x_f], dim=1), 154 | torch.stack([y_f, x_c], dim=1), torch.stack([y_c, x_c], dim=1)] 155 | weights = torch.cat(weights, dim=1) 156 | indices = torch.cat(indices, dim=2) 157 | # indices = torch.cat(indices, dim=0) # [4*N, 2] 158 | 159 | return weights, indices 160 | 161 | # weights, indices = compute_interpolation_weights(xy_warped, b, h_i, w_i) 162 | 163 | 164 | def compute_inverse_interpolation_img(weights, indices, img, b, h_i, w_i): 165 | """ 166 | weights: [b, h*w] 167 | indices: [b, h*w] 168 | img: [b, h*w, a, b, c, ...] 
169 | """ 170 | w0, w1, w2, w3 = weights 171 | ff_idx, cf_idx, fc_idx, cc_idx = indices 172 | 173 | k = len(img.size()) - len(w0.size()) 174 | img_0 = w0[(...,) + (None,) * k] * img 175 | img_1 = w1[(...,) + (None,) * k] * img 176 | img_2 = w2[(...,) + (None,) * k] * img 177 | img_3 = w3[(...,) + (None,) * k] * img 178 | 179 | img_out = torch.zeros(b, h_i * w_i, *img.shape[2:]).type_as(img) 180 | 181 | ff_idx = torch.clamp(ff_idx, min=0, max=h_i * w_i - 1) 182 | cf_idx = torch.clamp(cf_idx, min=0, max=h_i * w_i - 1) 183 | fc_idx = torch.clamp(fc_idx, min=0, max=h_i * w_i - 1) 184 | cc_idx = torch.clamp(cc_idx, min=0, max=h_i * w_i - 1) 185 | 186 | img_out.scatter_add_(1, ff_idx[(...,) + (None,) * k].expand_as(img_0), img_0) 187 | img_out.scatter_add_(1, cf_idx[(...,) + (None,) * k].expand_as(img_1), img_1) 188 | img_out.scatter_add_(1, fc_idx[(...,) + (None,) * k].expand_as(img_2), img_2) 189 | img_out.scatter_add_(1, cc_idx[(...,) + (None,) * k].expand_as(img_3), img_3) 190 | 191 | return img_out # [b, h_i*w_i, ...] 192 | -------------------------------------------------------------------------------- /runners/runners_corrections.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from torch.cuda.amp import autocast 6 | 7 | from .runners import BasicRunner 8 | from models.raft.corrector import Corrector 9 | from losses import LossCorrections 10 | from utils.utils_corrections import apply_corrections_uncropped, apply_corrections, get_best_index, get_good_correction 11 | from utils.coords_and_warp import Warp, WarpFullSize 12 | 13 | 14 | class RunnerCorrection(BasicRunner): 15 | def __init__(self, model, args): 16 | super(RunnerCorrection, self).__init__(model, args) 17 | 18 | self.warp = WarpFullSize() if args.use_full_size_warping else Warp() 19 | 20 | self.corrector = Corrector(args, intput_dim=args.input_dim_corrector, norm_fn='instance') 21 | self.loss = LossCorrections(args) 22 | self.apply_corrections = apply_corrections_uncropped if self.args.use_full_size_warping else apply_corrections 23 | 24 | 25 | @torch.no_grad() 26 | def get_corrector_inputs(self, ims, ims_warp, masks=None, outs=None, outs_full_size=None): 27 | '''Set the inputs for the correction estimator''' 28 | inputs_correc = [2*(ims/255.) - 1., 2*(ims_warp/255.) - 1.] 
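# The frames and their warped counterparts are rescaled from [0, 255] to [-1, 1],
# matching the normalization applied to the flow network inputs; occlusion and
# out-of-frame masks are optionally appended below depending on the arguments.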
29 | 30 | if self.args.occ_in_correc_inputs: 31 | inputs_correc.append(masks) 32 | elif self.args.no_FSW_occ_in_correc_inputs: 33 | inputs_correc.append(masks * outs) 34 | elif self.args.occ_and_out_in_correc_inputs: 35 | inputs_correc.extend([masks, outs]) 36 | 37 | return torch.cat(inputs_correc, dim=-3) 38 | 39 | def run_training_step(self, example, total_step=None): 40 | outputs = {} 41 | 42 | B, _, _, H, W = example['ims'].size() 43 | 44 | if self.args.use_full_size_warping: 45 | pad_params = example['pad_params'].int() 46 | orig_dims = example['orig_dims'].int() 47 | else: 48 | pad_params = None 49 | 50 | # Forward pass of augmented images in the optical flow model 51 | outputs = self.model(example['ims_aug'], fwd_bwd=True, suffix='_aug') 52 | 53 | # Absolute coordinate computations 54 | outputs['coords'] = self.get_coords(outputs['flows_aug']) 55 | 56 | # Compute occlusion mask, boundary (out) mask with and without full-size wapring 57 | outputs['masks'], outputs['outs'], outs_full_size = \ 58 | self.get_occ_mask(outputs['flows_aug'][:, self.fwd], outputs['coords'][:, self.fwd], 59 | outputs['flows_aug'][:, self.bwd], outputs['coords'][:, self.bwd], 60 | example['masks_eraser'], pad_params=pad_params) 61 | 62 | if total_step >= self.args.selfsup_starting_step: 63 | # Forward pass of non-augmented images in the optical flow model to get teacher flow 64 | outputs.update(self.run_inference_step(example, fwd_bwd=True, suffix='_teacher')) 65 | # Forward pass of cropped augmented images in the optical flow model to get student flow 66 | outputs.update(self.model(example['ims_aug_stud'], fwd_bwd=True, suffix='_stud')) 67 | 68 | if total_step >= self.args.correc_starting_step: 69 | 70 | with torch.no_grad(): 71 | 72 | # Warp augmented image 73 | if self.args.use_full_size_warping: 74 | ims_aug_warp_0 = self.warp(example['ims_aug_uncropped'][:, self.bwd], outputs['flows_aug'][:, :, 0], pad_params, orig_dims) 75 | 76 | else: 77 | ims_aug_warp_0 = self.warp(example['ims_aug'][:, self.bwd], outputs['coords'][:, :, 0]) 78 | 79 | # Set the augmented inputs for the corrector 80 | inputs_correc_aug = self.get_corrector_inputs(ims=example['ims_aug'], ims_warp=ims_aug_warp_0, 81 | masks=outputs['masks'][:, :, 0], outs=outputs['outs'][:, :, 0], outs_full_size=outs_full_size[:, :, 0]) 82 | 83 | with autocast(enabled=self.args.mixed_precision): 84 | # Forward pass in the corrector with augmented inputs 85 | outputs['correcs_aug'] = self.corrector(inputs_correc_aug.detach()) 86 | 87 | if total_step >= self.args.correc_in_photo_starting_step: 88 | 89 | with torch.no_grad(): 90 | flows0 = outputs['flows_aug'][:, :, 0] 91 | coords0 = outputs['coords'][:, :, 0] 92 | masks0 = outputs['masks'][:, :, 0] 93 | outs0 = outputs['outs'][:, :, 0] 94 | outs_full_size0 = outs_full_size[:, :, 0] 95 | 96 | # Warp non-augmented image 97 | if self.args.use_full_size_warping: 98 | ims_warp_0 = self.warp(example['ims_uncropped'][:, self.bwd], flows0, pad_params, orig_dims) 99 | 100 | else: 101 | ims_warp_0 = self.warp(example['ims'][:, self.bwd], coords0) 102 | 103 | # Set the non-augmented inputs for the corrector 104 | inputs_correc = self.get_corrector_inputs(ims=example['ims'], ims_warp=ims_warp_0, 105 | masks=masks0, outs=outs0, outs_full_size=outs_full_size0) 106 | 107 | with autocast(enabled=self.args.mixed_precision): 108 | # Forward pass in the corrector with non-augmented inputs 109 | correcs = self.corrector(inputs_correc.detach()) 110 | 111 | outputs['correcs'] = correcs 112 | 113 | if 
self.args.keep_good_corrections_only: 114 | 115 | with torch.no_grad(): 116 | 117 | # Warp non-augmented image 118 | if self.args.use_full_size_warping: 119 | pad_params = example['pad_params'].int() 120 | orig_dims = example['orig_dims'].int() 121 | ims_correc = self.apply_corrections(example['ims_uncropped'], outputs['correcs'], pad_params, clamp_corrected_images=self.args.smart_clamp) 122 | ims_warp_correc_0 = self.warp(ims_correc[:, self.bwd], outputs['flows_aug'][:, :, 0].detach(), pad_params, orig_dims) 123 | 124 | else: 125 | ims_correc = self.apply_corrections(example['ims'], outputs['correcs'], clamp_corrected_images=self.args.smart_clamp) 126 | ims_warp_correc_0 = self.warp(ims_correc[:, self.bwd], outputs['coords'][:, :, 0].detach()) 127 | 128 | # Get well-estimated correction mask 129 | good_correcs = get_good_correction(example['ims'], ims_warp_0, ims_warp_correc_0) 130 | 131 | # Unwarp the well-estimated correction mask 132 | outputs['best_indices'] = get_best_index(coords0.flatten(end_dim=-4), good_correcs.flatten(end_dim=-4)).unflatten(dim=0, sizes=(B, 2))[:, self.bwd] 133 | 134 | else: 135 | outputs['best_indices'] = torch.ones_like(masks0) 136 | 137 | return outputs 138 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | from turtle import forward 2 | import os 3 | import warnings 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | from torch.utils.tensorboard import SummaryWriter 10 | from datetime import datetime 11 | import torch.distributed as dist 12 | from torch import Tensor 13 | 14 | 15 | def to_cuda(example, excluded_keys=[]): 16 | for key, value in example.items(): 17 | if key not in excluded_keys: 18 | if torch.is_tensor(value): 19 | example[key] = value.float().cuda() 20 | 21 | class InputPadder: 22 | """ Pads images such that dimensions are divisible by 8 """ 23 | def __init__(self, dims, mode='sintel', div=8): 24 | self.ht, self.wd = dims[-2:] 25 | pad_ht = (((self.ht // div) + 1) * div - self.ht) % div 26 | pad_wd = (((self.wd // div) + 1) * div - self.wd) % div 27 | if mode == 'sintel': 28 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 29 | else: 30 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 31 | 32 | def pad(self, inputs): 33 | return F.pad(inputs, self._pad, mode='replicate') 34 | 35 | def unpad(self,x): 36 | ht, wd = x.shape[-2:] 37 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 38 | return x[..., c[0]:c[1], c[2]:c[3]] 39 | 40 | 41 | def selfsup_transform_flow(x): 42 | '''Center-crop, resize and rescale flow for the selfsup loss''' 43 | _, _, H, W = x.size() 44 | x = x[..., 64:-64, 64:-64] 45 | _, _, H_down, W_down = x.size() 46 | x_up = F.interpolate(x, (H, W), mode='bilinear', align_corners=True) 47 | x_up[:, 0] *= W / W_down 48 | x_up[:, 1] *= H / H_down 49 | return x_up 50 | 51 | 52 | @torch.no_grad() 53 | def to_low_res(x, scale_factor=8): 54 | dims = x.size() 55 | H_new, W_new = dims[-2]//scale_factor, dims[-1]//scale_factor 56 | return torch.ceil(F.interpolate(x.view(-1, *dims[-3:]), (H_new, W_new), mode='area').view(*dims[:-2], H_new, W_new)) 57 | 58 | 59 | def l1(x): 60 | return torch.norm(x, p=1, dim=1, keepdim=True) 61 | 62 | 63 | def rgb2gray(img): 64 | img_gray = img[..., 0, :, :]*0.2989 + img[..., 1, :, :]*0.1140 + img[..., 2, :, :]*0.5870 65 | return 
img_gray.unsqueeze(-3) 66 | 67 | 68 | def count_parameters(model): 69 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 70 | 71 | class Logging: 72 | def __init__(self, runner, scheduler, args): 73 | self.args = args 74 | if args.gpu == 0: 75 | self.runner = runner 76 | self.scheduler = scheduler 77 | self.total_steps = args.init_step 78 | self.running_loss = {} 79 | self.LOG_FREQ = args.LOG_FREQ 80 | if not args.debug: 81 | log_dir = args.log_dir + datetime.now().strftime("%Y-%m-%d_%H-%M-%S" + '_' + args.name) 82 | print('log_dir:', log_dir) 83 | self.writer = SummaryWriter(log_dir=log_dir) 84 | else: 85 | warnings.warn("WARNING: debug mode activated no checkpoint will be saved") 86 | 87 | print("Parameter Count: %d" % count_parameters(runner)) 88 | print(datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), 'start training') 89 | 90 | def _print_training_status(self): 91 | metrics_data = {k: round((self.running_loss[k].item() if type(self.running_loss[k]) == torch.Tensor else self.running_loss[k])/self.LOG_FREQ, 4) for k,v in self.running_loss.items()} 92 | training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, self.scheduler.get_last_lr()[0]) 93 | 94 | # print the training status 95 | print(datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), training_str, metrics_data) 96 | 97 | if not self.args.debug: 98 | for k in self.running_loss: 99 | self.writer.add_scalar(k, self.running_loss[k]/self.LOG_FREQ, self.total_steps) 100 | self.running_loss[k] = 0.0 101 | else: 102 | for k in self.running_loss: 103 | self.running_loss[k] = 0.0 104 | 105 | def push(self, metrics): 106 | if self.args.gpu == 0: 107 | self.total_steps += 1 108 | 109 | for key in metrics: 110 | if key not in self.running_loss: 111 | self.running_loss[key] = 0.0 112 | 113 | self.running_loss[key] += metrics[key] 114 | 115 | if self.total_steps % self.LOG_FREQ == self.LOG_FREQ-1: 116 | self._print_training_status() 117 | self.running_loss = {} 118 | 119 | def write_dict(self, results): 120 | if self.args.gpu == 0: 121 | print(datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), results) 122 | if not self.args.debug: 123 | for key in results: 124 | self.writer.add_scalar(key, results[key], self.total_steps) 125 | 126 | def close(self): 127 | if not self.args.debug and self.args.gpu == 0: 128 | self.writer.close() 129 | 130 | 131 | @torch.no_grad() 132 | def stack_all_gather_without_backprop(x: Tensor, dim: int = 0) -> Tensor: 133 | """Gather tensor across devices without grad. 134 | 135 | Args: 136 | x (Tensor): Tensor to gather. 137 | dim (int): Dimension to concat. Defaults to 0. 138 | 139 | Returns: 140 | Tensor: Gathered tensor. 141 | """ 142 | if dist.is_available() and dist.is_initialized(): 143 | tensors_gather = [torch.ones_like(x) 144 | for _ in range(dist.get_world_size())] 145 | dist.all_gather(tensors_gather, x, async_op=False) 146 | output = torch.stack(tensors_gather, dim=dim) 147 | else: 148 | output = x 149 | return output 150 | 151 | @torch.no_grad() 152 | def list_all_gather_without_backprop(x: Tensor, dim: int = 0) -> Tensor: 153 | """Gather tensor across devices without grad. 154 | 155 | Args: 156 | x (Tensor): Tensor to gather. 157 | dim (int): Dimension to concat. Defaults to 0. 158 | 159 | Returns: 160 | Tensor: Gathered tensor. 
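Note: unlike stack_all_gather_without_backprop above, this variant returns a
Python list with one tensor per rank (a single-element list when distributed
training is not initialized).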
161 | """ 162 | if dist.is_available() and dist.is_initialized(): 163 | tensors_gather = [torch.ones_like(x) 164 | for _ in range(dist.get_world_size())] 165 | dist.all_gather(tensors_gather, x, async_op=False) 166 | output = tensors_gather 167 | else: 168 | output = [x] 169 | return output 170 | 171 | 172 | class CheckpointSavior(object): 173 | def __init__(self, args): 174 | self.args = args 175 | self.main_metric = self.set_main_metric() 176 | self.best_metric = np.float32('inf') 177 | 178 | def set_main_metric(self): 179 | if 'kitti' in self.args.dataset_test.lower(): 180 | return 'epe_occ' 181 | else: 182 | return 'epe' 183 | 184 | def __call__(self, results, runner, custom_name=None): 185 | if self.args.gpu == 0: 186 | save_best_checkpoint = False 187 | 188 | if results[self.main_metric] < self.best_metric: 189 | self.best_metric = results[self.main_metric] 190 | save_best_checkpoint = True 191 | print(f'New best {self.main_metric}:', self.best_metric) 192 | 193 | else: 194 | print(f'Best {self.main_metric}:', self.best_metric) 195 | assert results[self.main_metric] <= 1e6, "Training interrupted because the model has diverged" 196 | 197 | if not self.args.debug: 198 | if save_best_checkpoint: 199 | PATH = os.path.join(self.args.ckpt_dir, self.args.name, 'best_checkpoint.pth') 200 | torch.save(runner.state_dict(), PATH) 201 | 202 | PATH = os.path.join(self.args.ckpt_dir, self.args.name, 'last_checkpoint.pth') 203 | torch.save(runner.state_dict(), PATH) 204 | 205 | if custom_name is not None: 206 | PATH = os.path.join(self.args.ckpt_dir, self.args.name, custom_name + '.pth') 207 | torch.save(runner.state_dict(), PATH) 208 | 209 | else: 210 | warnings.warn("WARNING: debug mode activated no checkpoint will be saved") 211 | -------------------------------------------------------------------------------- /models/scv/sparsenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .extractor import BasicEncoder, BasicEncoderQuarter 6 | from .update import BasicUpdateBlock, BasicUpdateBlockQuarter 7 | from .utils.utils import bilinear_sampler, coords_grid, coords_grid_y_first,\ 8 | upflow4, compute_interpolation_weights 9 | from .knn import knn_faiss_raw 10 | 11 | autocast = torch.cuda.amp.autocast 12 | 13 | 14 | def compute_sparse_corr(fmap1, fmap2, k=32): 15 | """ 16 | Compute a cost volume containing the k-largest hypotheses for each pixel. 
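For every pixel of fmap1, a k-nearest-neighbour search (knn_faiss_raw) over
fmap2 keeps only the k best-matching locations, so the cost volume stays
sparse: corr_sp has shape [B, k, H1*W1] and coords0/coords1 [B, 2, k, H1*W1].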
17 | Output: corr_mink 18 | """ 19 | B, C, H1, W1 = fmap1.shape 20 | H2, W2 = fmap2.shape[2:] 21 | N = H1 * W1 22 | 23 | fmap1, fmap2 = fmap1.view(B, C, -1), fmap2.view(B, C, -1) 24 | 25 | with torch.no_grad(): 26 | _, indices = knn_faiss_raw(fmap1, fmap2, k) # [B, k, H1*W1] 27 | 28 | indices_coord = indices.unsqueeze(1).expand(-1, 2, -1, -1) # [B, 2, k, H1*W1] 29 | coords0 = coords_grid_y_first(B, H2, W2).view(B, 2, 1, -1).expand(-1, -1, k, -1).to(fmap1.device) # [B, 2, k, H1*W1] 30 | coords1 = coords0.gather(3, indices_coord) # [B, 2, k, H1*W1] 31 | coords1 = coords1 - coords0 32 | 33 | # Append batch index 34 | batch_index = torch.arange(B).view(B, 1, 1, 1).expand(-1, -1, k, N).type_as(coords1) 35 | 36 | # Gather by indices from map2 and compute correlation volume 37 | fmap2 = fmap2.gather(2, indices.view(B, 1, -1).expand(-1, C, -1)).view(B, C, k, N) 38 | corr_sp = torch.einsum('bcn,bckn->bkn', fmap1, fmap2).contiguous() / torch.sqrt(torch.tensor(C).float()) # [B, k, H1*W1] 39 | 40 | return corr_sp, coords0, coords1, batch_index # coords: [B, 2, k, H1*W1] 41 | 42 | 43 | class FlowHead(nn.Module): 44 | def __init__(self, input_dim=256, batch_norm=True): 45 | super().__init__() 46 | if batch_norm: 47 | self.flowpredictor = nn.Sequential( 48 | nn.Conv2d(input_dim, 128, 3, padding=1), 49 | nn.BatchNorm2d(128), 50 | nn.ReLU(inplace=True), 51 | nn.Conv2d(128, 64, 3, padding=1), 52 | nn.BatchNorm2d(64), 53 | nn.ReLU(inplace=True), 54 | nn.Conv2d(64, 2, 3, padding=1) 55 | ) 56 | else: 57 | self.flowpredictor = nn.Sequential( 58 | nn.Conv2d(input_dim, 128, 3, padding=1), 59 | nn.ReLU(inplace=True), 60 | nn.Conv2d(128, 64, 3, padding=1), 61 | nn.ReLU(inplace=True), 62 | nn.Conv2d(64, 2, 3, padding=1) 63 | ) 64 | 65 | def forward(self, x): 66 | return self.flowpredictor(x) 67 | 68 | 69 | class SparseNet(nn.Module): 70 | def __init__(self, args): 71 | super().__init__() 72 | self.args = args 73 | self.iters = 8 #args.iters 74 | 75 | # feature network, context network, and update block 76 | self.fnet = BasicEncoderQuarter(output_dim=256, norm_fn='instance', dropout=False) 77 | self.cnet = BasicEncoderQuarter(output_dim=256, norm_fn='batch', dropout=False) 78 | 79 | # correlation volume encoder 80 | self.update_block = BasicUpdateBlockQuarter(self.args, hidden_dim=128, input_dim=405) 81 | 82 | def initialize_flow(self, img): 83 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 84 | N, C, H, W = img.shape 85 | coords0 = coords_grid(N, H//4, W//4).to(img.device) 86 | coords1 = coords_grid(N, H//4, W//4).to(img.device) 87 | 88 | # optical flow computed as difference: flow = coords1 - coords0 89 | return coords0, coords1 90 | 91 | def upsample_flow_quarter(self, flow, mask): 92 | """ Upsample flow field [H/4, W/4, 2] -> [H, W, 2] using convex combination """ 93 | N, _, H, W = flow.shape 94 | mask = mask.view(N, 1, 9, 4, 4, H, W) 95 | mask = torch.softmax(mask, dim=2) 96 | 97 | up_flow = F.unfold(4 * flow, [3,3], padding=1) 98 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 99 | 100 | up_flow = torch.sum(mask * up_flow, dim=2) 101 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 102 | return up_flow.reshape(N, 2, 4*H, 4*W) 103 | 104 | def compute_flow(self, image1, fmap1, fmap2, flow_init=None, test_mode=False): 105 | 106 | # run the feature and context network 107 | with autocast(enabled=self.args.mixed_precision): 108 | cnet = self.cnet(image1) 109 | net, inp = torch.split(cnet, [128, 128], dim=1) 110 | net = torch.tanh(net) 111 | inp = torch.relu(inp) 112 | 113 
| fmap1 = fmap1.float() 114 | fmap2 = fmap2.float() 115 | 116 | B, _, H1, W1 = fmap1.shape 117 | 118 | # GRU 119 | coords0, coords1 = self.initialize_flow(image1) 120 | 121 | if flow_init is not None: 122 | coords1 = coords1 + flow_init 123 | 124 | # Generate sparse cost volume for GRU 125 | corr_val, coords0_cv, coords1_cv, batch_index_cv = compute_sparse_corr(fmap1, fmap2, k=self.args.num_k) 126 | 127 | delta_flow = torch.zeros_like(coords0) 128 | 129 | flow_predictions = [] 130 | 131 | search_range = 4 132 | corr_val = corr_val.expand(-1, 4, -1) 133 | 134 | for itr in range(self.iters): 135 | with torch.no_grad(): 136 | 137 | # need to switch order of delta_flow, also note the minus sign 138 | coords1_cv = coords1_cv - delta_flow[:, [1, 0], :, :].view(B, 2, 1, -1) # [B, 2, k, H1*W1] 139 | 140 | mask_pyramid = [] 141 | weights_pyramid = [] 142 | coords_sparse_pyramid = [] 143 | 144 | # Create multi-scale displacements 145 | for i in range(5): 146 | coords1_sp = coords1_cv * 0.5**i 147 | weights, coords1_sp = compute_interpolation_weights(coords1_sp) 148 | mask = (coords1_sp[:, 0].abs() <= search_range) & (coords1_sp[:, 1].abs() <= search_range) 149 | batch_ind = batch_index_cv.permute(0, 2, 3, 1).expand(-1, 4, -1, -1)[mask] 150 | coords0_sp = coords0_cv.permute(0, 2, 3, 1).expand(-1, 4, -1, -1)[mask] 151 | coords1_sp = coords1_sp.permute(0, 2, 3, 1)[mask] 152 | 153 | coords1_sp = coords1_sp + search_range 154 | coords_sp = torch.cat([batch_ind, coords0_sp, coords1_sp], dim=1) 155 | coords_sparse_pyramid.append(coords_sp) 156 | 157 | mask_pyramid.append(mask) 158 | weights_pyramid.append(weights) 159 | 160 | corr_val_pyramid = [] 161 | for mask, weights in zip(mask_pyramid, weights_pyramid): 162 | corr_masked = (weights * corr_val)[mask].unsqueeze(1) 163 | corr_val_pyramid.append(corr_masked) 164 | 165 | sparse_tensor_pyramid = [torch.sparse.FloatTensor(coords_sp.t().long(), corr_resample, torch.Size([B, H1, W1, 9, 9, 1])).coalesce() 166 | for coords_sp, corr_resample in zip(coords_sparse_pyramid, corr_val_pyramid)] 167 | 168 | corr = torch.cat([sp.to_dense().view(B, H1, W1, -1) for sp in sparse_tensor_pyramid], dim=3).permute(0, 3, 1, 2) 169 | 170 | coords1 = coords1.detach() 171 | 172 | flow = coords1 - coords0 173 | 174 | # GRU Update 175 | with autocast(enabled=self.args.mixed_precision): 176 | 177 | # 4D net map to 2D dense vector 178 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 179 | 180 | # F(t+1) = F(t) + \Delta(t) 181 | coords1 = coords1 + delta_flow 182 | 183 | # upsample predictions 184 | if up_mask is None: 185 | flow_up = upflow4(coords1 - coords0) 186 | else: 187 | flow_up = self.upsample_flow_quarter(coords1 - coords0, up_mask) 188 | 189 | flow_predictions.append(flow_up) 190 | 191 | if test_mode: 192 | return flow_up 193 | 194 | return flow_predictions[::-1] 195 | 196 | 197 | def forward(self, images, flow_init=None, return_last_flow_only=False, fwd_bwd=True, suffix=''): 198 | """ Estimate optical flow between pair of frames """ 199 | outputs = {} 200 | 201 | dims = images.size() 202 | images_norm = (2 * (images / 255.0) - 1.0).contiguous() 203 | images_norm_flatten = images_norm.flatten(end_dim=-4) 204 | 205 | # run the feature network 206 | with autocast(enabled=self.args.mixed_precision): 207 | fmaps = self.fnet(images_norm_flatten).unflatten(dim=0, sizes=dims[:2]) 208 | 209 | fmap1, fmap2 = torch.split(fmaps.float(), 1, dim=1) 210 | 211 | if fwd_bwd: 212 | flows_f = self.compute_flow(images_norm[:, 0], fmap1, fmap2, 
test_mode=return_last_flow_only) 213 | flows_b = self.compute_flow(images_norm[:, 1], fmap2, fmap1, test_mode=return_last_flow_only) 214 | if return_last_flow_only: 215 | outputs['flows' + suffix] = torch.stack([flows_f, flows_b], dim=1) 216 | else: 217 | outputs['flows' + suffix] = torch.stack([torch.stack(flows_f, dim=1), torch.stack(flows_b, dim=1)], dim=1) 218 | 219 | else: 220 | outputs['flow_f' + suffix] = self.compute_flow(images_norm[:, 0], fmap1, fmap2, test_mode=True) 221 | 222 | return outputs 223 | -------------------------------------------------------------------------------- /utils/argument_parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import torch 5 | 6 | def get_arguments(): 7 | parser = argparse.ArgumentParser() 8 | add = parser.add_argument 9 | 10 | add('--name', type=str, default='raft', help="name your experiment") 11 | add('--ckpt_dir', type=str) 12 | add('--log_dir', type=str) 13 | add('--restore_ckpt', type=str, default=None) 14 | add('--init_step', type=int, default=0) 15 | add('--dataset_train', type=str) #, choices=['Chairs', 'Sintel', 'KITTI', 'HD1K'] 16 | add('--dataset_test', type=str) #, choices=['Chairs', 'Sintel', 'KITTI', 'HD1K'] 17 | add('--data_in_grayscale', action='store_true', help='indicate inputs are in grayscale') 18 | add('--num_workers', type=int, default=4) 19 | add('--seed', type=int, default=1) 20 | add('--debug', action='store_true', help='if True, no checkpoint file will be created') 21 | add('--sequentially', action='store_true', help='Apply photometric loss function in a sequential way to reduce memory cost') 22 | add('--VAL_FREQ', type=int, default=1000, help='validation frequency') 23 | add('--LOG_FREQ', type=int, default=100, help='log frequency') 24 | 25 | add('--batch_size', type=int) 26 | add('--num_steps', type=int) 27 | 28 | #Evaluation 29 | add('--eval', type=str, nargs='+', choices=['flow', 'match'], default='flow') 30 | add('--eval_on_train', action='store_true', help='Evaluate on train set') 31 | 32 | #Optimization 33 | add('--optimizer', type=str, choices=['adam', 'adamw', 'sgd'], default='adam') 34 | add('--lr', type=float, default=0.0002) 35 | add('--lr_decay_max', type=float, default=0.001) 36 | add('--scheduler', type=str, choices=['smurf', 'raft'], default='smurf') 37 | add('--end_warmup_step', type=int, default=0) 38 | add('--lr_decay_step', type=int) 39 | add('--wdecay', type=float, default=.00005) 40 | add('--epsilon', type=float, default=1e-8) 41 | add('--clip', type=float, default=1.0) 42 | add('--mixed_precision', action='store_true', help='use mixed precision') 43 | add('--sequence_weight', type=float, default=0.8) 44 | 45 | #Model 46 | add('--mode', type=str, choices=['flow_correc', 'flow_only']) 47 | add('--model', type=str, choices=['raft', 'gma', 'scv']) 48 | add('--small', action='store_true', help='use small model if model is raft') 49 | add('--iters', type=int, default=12) 50 | add('--dim_out_flow', type=int, default=2) 51 | add('--dropout', type=float, default=0.0) 52 | add('--crop_size', type=int, nargs='+', default=None) 53 | add('--occlusions', type=str, choices=['wang', 'brox', 'none'], default=None) 54 | add('--pad_mode', type=str, choices=['zeros', 'replicate', 'reflect'], default='zeros') 55 | add('--upsampling_mode', type=str, choices=['bilinear', 'bicubic', 'convex'], help='flow upsampling', default='convex') 56 | 57 | #Augmentations 58 | add('--no_photo_aug', action='store_true', help='To train the 
network without data augmentation (only works with corrections)') 59 | add('--random_eraser', action='store_true', help='Use random eraser') 60 | 61 | #Photometric loss 62 | add('--census_weight_flow', type=float, default=0.0) 63 | add('--census_weight_correc', type=float, default=0.0) 64 | add('--unflow_weight_flow', type=float, default=0.0) 65 | add('--unflow_weight_correc', type=float, default=0.0) 66 | add('--l1_weight_flow', type=float, default=0.0) 67 | add('--l1_weight_correc', type=float, default=0.0) 68 | add('--census_patch_size', type=int, default=7) 69 | add('--use_full_size_warping', action='store_true') 70 | 71 | #Smoothness loss 72 | add('--smoothness_order', type=int, choices=[1, 2], help='Order of the gradient of the flow used in the smoothness loss') 73 | add('--smoothness_weight', type=float) 74 | 75 | #Selfsup loss 76 | add('--selfsup_starting_step', type=int) 77 | add('--selfsup_end_rising_step', type=int) 78 | add('--selfsup_weight_max', type=float) 79 | add('--selfsup_distance', type=str, choices=['l1', 'charbonnier', 'huber', 'huber_charbonnier'], default='l1', help='Distance used to compare the flow from the teacher and the flow from the student') 80 | 81 | #Correction 82 | add('--correc_in_photo_starting_step', type=int, default=0) 83 | add('--correc_starting_step', type=int, default=0) 84 | add('--correc_weight', type=float, default=0.0) 85 | add('--smart_clamp', action='store_true', help='clip corrected images in the photometric loss') 86 | add('--keep_good_corrections_only', action='store_true', help='Keep in the photometric loss of the flow only well-estimated corrections') 87 | add('--input_dim_corrector', type=int, default=6) 88 | add('--occ_in_correc_inputs', action='store_true', help='Include occlusions in corrector inputs') 89 | add('--no_FSW_occ_in_correc_inputs', action='store_true', help='Include true occlusions in corrector inputs when training with full-size warping') 90 | add('--occ_and_out_in_correc_inputs', action='store_true', help='Include occlusions and true boundary occlusions in corrector inputs') 91 | add('--flows_in_correc_inputs', action='store_true', help='Include forward flow and warped backward flow in corrector inputs') 92 | 93 | # GMA args 94 | add('--position_only', default=False, action='store_true', help='only use position-wise attention') 95 | add('--position_and_content', default=False, action='store_true', help='use position and content-wise attention') 96 | add('--num_heads', default=1, type=int, help='number of heads in attention and aggregation') 97 | 98 | # SCV args 99 | add('--upsample-learn', action='store_true', default=False, help='If True, use learned upsampling, otherwise, use bilinear upsampling.') 100 | add('--gamma', type=float, default=0.8, help='exponential weighting') 101 | add('--num_k', type=int, default=32, help='number of hypotheses to compute for knn Faiss') 102 | add('--max_search_range', type=int, default=100, help='maximum search range for hypotheses in quarter resolution') 103 | 104 | args = parser.parse_args() 105 | args.num_gpus = torch.cuda.device_count() 106 | args.batch_size //= args.num_gpus 107 | 108 | args.upsampling_mode = 'bilinear' if args.small else args.upsampling_mode 109 | 110 | set_photometric_loss_weights(args) 111 | 112 | if args.crop_size is None: 113 | if 'kitti' in args.dataset_train.lower(): 114 | args.crop_size = [296, 696] 115 | elif 'chairs' in args.dataset_train.lower(): 116 | args.crop_size = [368, 496] 117 | elif 'sintel' in args.dataset_train.lower(): 118 |
args.crop_size = [368, 496] 119 | else: 120 | raise NotImplementedError 121 | 122 | if args.occlusions is None: 123 | if 'kitti' in args.dataset_train.lower(): 124 | args.occlusions = 'brox' 125 | elif 'chairs' in args.dataset_train.lower(): 126 | args.occlusions = 'wang' 127 | elif 'sintel' in args.dataset_train.lower(): 128 | args.occlusions = 'wang' 129 | else: 130 | raise NotImplementedError 131 | 132 | if args.smoothness_weight is None: 133 | if 'kitti' in args.dataset_train.lower(): 134 | args.smoothness_weight = 2. 135 | elif 'chairs' in args.dataset_train.lower(): 136 | args.smoothness_weight = 2. 137 | elif 'sintel' in args.dataset_train.lower(): 138 | args.smoothness_weight = 2.5 139 | else: 140 | raise NotImplementedError 141 | 142 | if args.smoothness_order is None: 143 | if 'kitti' in args.dataset_train.lower(): 144 | args.smoothness_order = 2 145 | elif 'chairs' in args.dataset_train.lower(): 146 | args.smoothness_order = 1 147 | elif 'sintel' in args.dataset_train.lower(): 148 | args.smoothness_order = 1 149 | else: 150 | raise NotImplementedError 151 | 152 | if 'correc' in args.mode: 153 | if args.occ_in_correc_inputs or args.no_FSW_occ_in_correc_inputs: 154 | args.input_dim_corrector += 1 155 | elif args.occ_and_out_in_correc_inputs: 156 | args.input_dim_corrector += 2 157 | if args.flows_in_correc_inputs: 158 | args.input_dim_corrector += 4 159 | 160 | return args 161 | 162 | def assert_l1(census_weight, unflow_weight, l1_weight): 163 | if (unflow_weight > 0. and l1_weight > 0.) or (census_weight > 0. and l1_weight > 0): 164 | return False 165 | else: 166 | return True 167 | 168 | def set_photometric_loss_weights(args): 169 | assert assert_l1(args.census_weight_flow, args.unflow_weight_flow, args.l1_weight_flow), 'You have to choose between L1 and Census or Unflow loss' 170 | 171 | if args.unflow_weight_flow > 0.: 172 | args.ssim_weight_flow = 0.85 * args.unflow_weight_flow 173 | args.l1_weight_flow = 0.15 * args.unflow_weight_flow 174 | else: 175 | args.ssim_weight_flow = 0. 176 | 177 | if 'correc' in args.mode: 178 | assert assert_l1(args.census_weight_correc, args.unflow_weight_correc, args.l1_weight_correc), 'You have to choose between L1 and Census or Unflow loss' 179 | if args.unflow_weight_correc > 0.: 180 | args.ssim_weight_correc = 0.85 * args.unflow_weight_correc 181 | args.l1_weight_correc = 0.15 * args.unflow_weight_correc 182 | else: 183 | args.ssim_weight_correc = 0. 
184 | 185 | del args.unflow_weight_flow 186 | del args.unflow_weight_correc 187 | 188 | 189 | def save_args(args, dir): 190 | with open(os.path.join(dir, 'config.json'), 'w') as f: 191 | json.dump(args.__dict__, f, indent=2) -------------------------------------------------------------------------------- /models/scv/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | # if not stride == 1: 20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | # if not stride == 1: 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | # if not stride == 1: 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = nn.Sequential() 37 | # if not stride == 1: 38 | self.norm3 = nn.Sequential() 39 | 40 | # if stride == 1: 41 | # self.downsample = None 42 | # 43 | # else: 44 | self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | 48 | def forward(self, x): 49 | y = x 50 | y = self.relu(self.norm1(self.conv1(y))) 51 | y = self.relu(self.norm2(self.conv2(y))) 52 | 53 | if self.downsample is not None: 54 | x = self.downsample(x) 55 | 56 | return self.relu(x+y) 57 | 58 | 59 | class BottleneckBlock(nn.Module): 60 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 61 | super(BottleneckBlock, self).__init__() 62 | 63 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) 64 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) 65 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) 66 | self.relu = nn.ReLU(inplace=True) 67 | 68 | num_groups = planes // 8 69 | 70 | if norm_fn == 'group': 71 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 72 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 73 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 74 | if not stride == 1: 75 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 76 | 77 | elif norm_fn == 'batch': 78 | self.norm1 = nn.BatchNorm2d(planes//4) 79 | self.norm2 = nn.BatchNorm2d(planes//4) 80 | self.norm3 = nn.BatchNorm2d(planes) 81 | if not stride == 1: 82 | self.norm4 = nn.BatchNorm2d(planes) 83 | 84 | elif norm_fn == 'instance': 85 | self.norm1 = nn.InstanceNorm2d(planes//4) 86 | self.norm2 = nn.InstanceNorm2d(planes//4) 87 | self.norm3 = nn.InstanceNorm2d(planes) 88 | if not stride == 1: 89 | self.norm4 = nn.InstanceNorm2d(planes) 90 | 91 | elif norm_fn == 'none': 92 | self.norm1 = nn.Sequential() 93 | 
self.norm2 = nn.Sequential() 94 | self.norm3 = nn.Sequential() 95 | if not stride == 1: 96 | self.norm4 = nn.Sequential() 97 | 98 | if stride == 1: 99 | self.downsample = None 100 | 101 | else: 102 | self.downsample = nn.Sequential( 103 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 104 | 105 | def forward(self, x): 106 | y = x 107 | y = self.relu(self.norm1(self.conv1(y))) 108 | y = self.relu(self.norm2(self.conv2(y))) 109 | y = self.relu(self.norm3(self.conv3(y))) 110 | 111 | if self.downsample is not None: 112 | x = self.downsample(x) 113 | 114 | return self.relu(x+y) 115 | 116 | 117 | class BasicEncoder(nn.Module): 118 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 119 | super(BasicEncoder, self).__init__() 120 | self.norm_fn = norm_fn 121 | 122 | if self.norm_fn == 'group': 123 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 124 | 125 | elif self.norm_fn == 'batch': 126 | self.norm1 = nn.BatchNorm2d(64) 127 | 128 | elif self.norm_fn == 'instance': 129 | self.norm1 = nn.InstanceNorm2d(64) 130 | 131 | elif self.norm_fn == 'none': 132 | self.norm1 = nn.Sequential() 133 | 134 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 135 | self.relu1 = nn.ReLU(inplace=True) 136 | 137 | self.in_planes = 64 138 | self.layer1 = self._make_layer(64, stride=1) 139 | self.layer2 = self._make_layer(96, stride=2) 140 | self.layer3 = self._make_layer(128, stride=2) 141 | 142 | # output convolution 143 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 144 | 145 | self.dropout = None 146 | if dropout > 0: 147 | self.dropout = nn.Dropout2d(p=dropout) 148 | 149 | for m in self.modules(): 150 | if isinstance(m, nn.Conv2d): 151 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 152 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 153 | if m.weight is not None: 154 | nn.init.constant_(m.weight, 1) 155 | if m.bias is not None: 156 | nn.init.constant_(m.bias, 0) 157 | 158 | def _make_layer(self, dim, stride=1): 159 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 160 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 161 | layers = (layer1, layer2) 162 | 163 | self.in_planes = dim 164 | return nn.Sequential(*layers) 165 | 166 | def forward(self, x): 167 | 168 | # if input is list, combine batch dimension 169 | is_list = isinstance(x, tuple) or isinstance(x, list) 170 | if is_list: 171 | batch_dim = x[0].shape[0] 172 | x = torch.cat(x, dim=0) 173 | 174 | x = self.conv1(x) 175 | x = self.norm1(x) 176 | x = self.relu1(x) 177 | 178 | x = self.layer1(x) 179 | x = self.layer2(x) 180 | x = self.layer3(x) 181 | 182 | x = self.conv2(x) 183 | 184 | if self.training and self.dropout is not None: 185 | x = self.dropout(x) 186 | 187 | if is_list: 188 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 189 | 190 | return x 191 | 192 | 193 | class BasicEncoderQuarter(nn.Module): 194 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 195 | super(BasicEncoderQuarter, self).__init__() 196 | self.norm_fn = norm_fn 197 | 198 | if self.norm_fn == 'group': 199 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 200 | 201 | elif self.norm_fn == 'batch': 202 | self.norm1 = nn.BatchNorm2d(64) 203 | 204 | elif self.norm_fn == 'instance': 205 | self.norm1 = nn.InstanceNorm2d(64) 206 | 207 | elif self.norm_fn == 'none': 208 | self.norm1 = nn.Sequential() 209 | 210 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 211 | self.relu1 = 
nn.ReLU(inplace=True) 212 | 213 | self.in_planes = 64 214 | self.layer1 = self._make_layer(64, stride=1) 215 | self.layer2 = self._make_layer(96, stride=2) 216 | self.layer3 = self._make_layer(128, stride=1) 217 | 218 | # output convolution 219 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 220 | 221 | self.dropout = None 222 | if dropout > 0: 223 | self.dropout = nn.Dropout2d(p=dropout) 224 | 225 | for m in self.modules(): 226 | if isinstance(m, nn.Conv2d): 227 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 228 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 229 | if m.weight is not None: 230 | nn.init.constant_(m.weight, 1) 231 | if m.bias is not None: 232 | nn.init.constant_(m.bias, 0) 233 | 234 | def _make_layer(self, dim, stride=1): 235 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 236 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 237 | layers = (layer1, layer2) 238 | 239 | self.in_planes = dim 240 | return nn.Sequential(*layers) 241 | 242 | def forward(self, x): 243 | 244 | # if input is list, combine batch dimension 245 | is_list = isinstance(x, tuple) or isinstance(x, list) 246 | if is_list: 247 | batch_dim = x[0].shape[0] 248 | x = torch.cat(x, dim=0) 249 | 250 | x = self.conv1(x) 251 | x = self.norm1(x) 252 | x = self.relu1(x) 253 | 254 | x = self.layer1(x) 255 | x = self.layer2(x) 256 | x = self.layer3(x) 257 | 258 | x = self.conv2(x) 259 | 260 | if self.training and self.dropout is not None: 261 | x = self.dropout(x) 262 | 263 | if is_list: 264 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 265 | 266 | return x 267 | -------------------------------------------------------------------------------- /datasets/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from glob import glob 4 | import numpy as np 5 | import torch 6 | import torch.utils.data as data 7 | from skimage.color import rgb2gray 8 | 9 | from datasets.utils_data import read_gen, readFlowKITTI 10 | 11 | 12 | class FlowDataset(data.Dataset): 13 | def __init__(self, args, augmentor=None, sparse=False, is_training=True): 14 | self.args = args 15 | if augmentor is not None: 16 | self.augmentor = augmentor(args) 17 | else: 18 | self.augmentor = None 19 | self.sparse = sparse 20 | self.is_training = is_training 21 | 22 | self.flow_list = [] 23 | self.image_list = [] 24 | self.extra_info = [] 25 | self.pseudo_gt_list = [] 26 | 27 | def __getitem__(self, index): 28 | 29 | index = index % len(self.image_list) 30 | 31 | im1 = read_gen(self.image_list[index][0]) 32 | im2 = read_gen(self.image_list[index][1]) 33 | im1 = np.array(im1) 34 | im2 = np.array(im2) 35 | 36 | # grayscale images 37 | if len(im1.shape) == 2: 38 | im1 = np.tile(im1.astype(np.uint8)[...,None], (1, 1, 3)) 39 | im2 = np.tile(im2.astype(np.uint8)[...,None], (1, 1, 3)) 40 | 41 | elif self.args.data_in_grayscale: 42 | im1 = np.tile(rgb2gray(im1)[...,None] * 255, (1, 1, 3)).astype(np.uint8) 43 | im2 = np.tile(rgb2gray(im2)[...,None] * 255, (1, 1, 3)).astype(np.uint8) 44 | 45 | else: 46 | im1 = im1[..., :3].astype(np.uint8) 47 | im2 = im2[..., :3].astype(np.uint8) 48 | # else: 49 | if self.augmentor is not None: 50 | example = self.augmentor(im1, im2) 51 | 52 | else: 53 | example = { 54 | 'ims': np.stack([im1, im2]) 55 | } 56 | 57 | for key, value in example.items(): 58 | if key in ['pad_params', 'orig_dims', 'offsets', 'valid', 'pseudo_gt']: 59 | example[key] = 
torch.from_numpy(value) 60 | else: 61 | example[key] = torch.from_numpy(value).permute(0, 3, 1, 2).float() 62 | 63 | example['index'] = [index, self.extra_info[index]] 64 | 65 | if not self.is_training: 66 | 67 | if self.sparse: 68 | flow_occ, valid_occ = readFlowKITTI(self.flow_occ_list[index]) 69 | flow_noc, valid_noc = readFlowKITTI(self.flow_noc_list[index]) 70 | 71 | flow_occ = np.array(flow_occ).astype(np.float32) 72 | flow_noc = np.array(flow_noc).astype(np.float32) 73 | valid_occ = np.expand_dims(valid_occ, axis=0) 74 | valid_noc = np.expand_dims(valid_noc, axis=0) 75 | 76 | example['flow_occ'] = torch.from_numpy(flow_occ).permute(2, 0, 1).float() 77 | example['valid_occ'] = torch.from_numpy(valid_occ) 78 | example['flow_noc'] = torch.from_numpy(flow_noc).permute(2, 0, 1).float() 79 | example['valid_noc'] = torch.from_numpy(valid_noc) 80 | 81 | else: 82 | 83 | flow = read_gen(self.flow_list[index]) 84 | flow = np.array(flow).astype(np.float32) 85 | flow = torch.from_numpy(flow).permute(2, 0, 1).float() 86 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000).unsqueeze(0) 87 | 88 | example['flow'] = flow 89 | example['valid'] = valid 90 | 91 | return example 92 | 93 | def __rmul__(self, v): 94 | self.flow_list = v * self.flow_list 95 | self.image_list = v * self.image_list 96 | return self 97 | 98 | def __len__(self): 99 | return len(self.image_list) 100 | 101 | 102 | class MpiSintel(FlowDataset): 103 | def __init__(self, args, augmentor, is_training=True, split='training', root='/path/to/sintel/', dstype='final'): 104 | super(MpiSintel, self).__init__(args, augmentor=augmentor, sparse=False, is_training=is_training) 105 | 106 | if split == 'training': 107 | image_root = osp.join(root, 'test', dstype) 108 | 109 | for scene in os.listdir(image_root): 110 | image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) 111 | for i in range(len(image_list)-1): 112 | self.image_list += [ [image_list[i], image_list[i+1]] ] 113 | frame1_id = image_list[i].split('/')[-1].split('.')[0] 114 | frame2_id = image_list[i+1].split('/')[-1].split('.')[0] 115 | self.extra_info += [[f'{scene}_{frame1_id}', f'{scene}_{frame2_id}']] 116 | 117 | elif split == 'validation': 118 | image_root = osp.join(root, 'training', dstype) 119 | flow_root = osp.join(root, 'training', 'flow') 120 | 121 | for scene in os.listdir(image_root): 122 | image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) 123 | for i in range(len(image_list)-1): 124 | self.image_list += [ [image_list[i], image_list[i+1]] ] 125 | 126 | self.extra_info += [ (scene, i) ] # scene and frame_id 127 | 128 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) 129 | 130 | 131 | class KITTI(FlowDataset): 132 | def __init__(self, args, augmentor, is_training=True, split='training', root='/path/to/kitti_data/data_scene_flow/'): 133 | super(KITTI, self).__init__(args, augmentor=augmentor, sparse=True, is_training=is_training) 134 | 135 | images1 = [] 136 | images2 = [] 137 | self.flow_occ_list = [] 138 | self.flow_noc_list = [] 139 | 140 | if split == 'training': 141 | root = osp.join(root, 'testing') 142 | 143 | for dir in ['image_2', 'image_3']: 144 | for i in range(200): 145 | images1 += sorted(glob(osp.join(root, f'{dir}/{str(i).zfill(6)}_*.png')))[:-1] 146 | images2 += sorted(glob(osp.join(root, f'{dir}/{str(i).zfill(6)}_*.png')))[1:] 147 | 148 | elif split == 'validation': 149 | root = osp.join(root, 'training') 150 | 151 | for i in range(200): 152 | images1 += [osp.join(root, f'image_2/{str(i).zfill(6)}_10.png')] 
153 | images2 += [osp.join(root, f'image_2/{str(i).zfill(6)}_11.png')] 154 | for i in range(200): 155 | 156 | self.flow_occ_list += [osp.join(root, f'flow_occ/{str(i).zfill(6)}_10.png')] 157 | self.flow_noc_list += [osp.join(root, f'flow_noc/{str(i).zfill(6)}_10.png')] 158 | 159 | for im1, im2 in zip(images1, images2): 160 | dir = im1.split('/')[-2] 161 | frame1_id = im1.split('/')[-1].split('.')[0] 162 | frame2_id = im2.split('/')[-1].split('.')[0] 163 | self.extra_info += [ [f'{dir}_{frame1_id}', f'{dir}_{frame2_id}'] ] 164 | self.image_list += [ [im1, im2] ] 165 | 166 | 167 | class HD1K(FlowDataset): 168 | def __init__(self, args, augmentor, is_training=True, split='training', root='/path/to/HD1K'): 169 | super(HD1K, self).__init__(args, augmentor=augmentor, sparse=True, is_training=is_training) 170 | 171 | self.flow_occ_list = [] 172 | self.flow_noc_list = [] 173 | 174 | if split == 'validation': 175 | 176 | seq_ix = 0 177 | while True: 178 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) 179 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) 180 | 181 | if len(flows) == 0: 182 | break 183 | 184 | for i in range(len(flows)-1): 185 | self.flow_occ_list += [flows[i]] 186 | self.flow_noc_list += [flows[i]] 187 | self.image_list += [ [images[i], images[i+1]] ] 188 | frame1_id = images[i].split('/')[-1].split('.')[0] 189 | frame2_id = images[i+1].split('/')[-1].split('.')[0] 190 | self.extra_info += [ [frame1_id, frame2_id] ] 191 | 192 | seq_ix += 1 193 | 194 | elif split == 'training': 195 | 196 | seq_ix = 0 197 | while True: 198 | images = sorted(glob(os.path.join(root, 'hd1k_challenge', 'image_2/%06d_*.png' % seq_ix))) 199 | 200 | if len(images) == 0: 201 | break 202 | 203 | for i in range(len(images)-1): 204 | self.image_list += [ [images[i], images[i+1]] ] 205 | frame1_id = images[i].split('/')[-1].split('.')[0] 206 | frame2_id = images[i+1].split('/')[-1].split('.')[0] 207 | self.extra_info += [ [frame1_id, frame2_id] ] 208 | 209 | seq_ix += 1 210 | 211 | 212 | class Chairs(FlowDataset): 213 | def __init__(self, args, augmentor, is_training=True, split='training', root='/path/to/FlyingChairs_release/data'): 214 | super(Chairs, self).__init__(args, augmentor=augmentor, sparse=False, is_training=is_training) 215 | 216 | self.args = args 217 | 218 | flows = sorted(glob(osp.join(root, '*.flo'))) 219 | split_list = np.loadtxt('/path/to/FlyingChairs_release/FlyingChairs_train_val.txt', dtype=np.int32) 220 | 221 | images = sorted(glob(osp.join(root, '*.ppm'))) 222 | assert (len(images)//2 == len(flows)) 223 | 224 | for i in range(len(flows)): 225 | xid = split_list[i] 226 | if (split=='training' and xid==1) or (split=='validation' and xid==2): 227 | self.flow_list += [ flows[i] ] 228 | self.image_list += [ [images[2*i], images[2*i+1]] ] 229 | frame1_id = images[2*i].split('/')[-1].split('.')[0] 230 | frame2_id = images[2*i+1].split('/')[-1].split('.')[0] 231 | self.extra_info += [ [frame1_id, frame2_id] ] 232 | --------------------------------------------------------------------------------
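Note: the dataset classes in datasets/datasets.py above all return an example dictionary built in FlowDataset.__getitem__. The snippet below is a minimal, illustrative sketch (not part of the repository) of how one of them could be plugged into a standard PyTorch DataLoader. The KITTI root is the placeholder default from datasets.py, and the Namespace carries only the single field __getitem__ reads here (data_in_grayscale); in the real pipeline the full args object comes from utils/argument_parser.py and an Augmentor class from the augmentations package is passed instead of None.

    # Minimal usage sketch, assuming the KITTI data lives under the placeholder path below.
    from argparse import Namespace
    from torch.utils.data import DataLoader

    from datasets.datasets import KITTI

    args = Namespace(data_in_grayscale=False)                 # assumed minimal args object
    dataset = KITTI(args, augmentor=None, is_training=True,   # augmentor=None: images are only stacked
                    split='training', root='/path/to/kitti_data/data_scene_flow/')

    # batch_size=1 because raw KITTI frames have slightly different resolutions
    # and no augmentor/cropping is applied in this sketch.
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    print(batch['ims'].shape)                                 # [1, 2, 3, H, W]: batch, frame pair, channels, H, W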
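For reference, the sparse cost volume built in models/scv/sparsenet.py (compute_sparse_corr) keeps, for every source pixel, only the k best-matching target features found with knn_faiss_raw, instead of materialising the full dense correlation volume. The torch-only snippet below illustrates that idea on toy tensors; it is a sketch, with torch.topk standing in for the faiss k-NN, and none of the variable names come from the repository.

    # Toy illustration of the k-largest-hypotheses idea behind compute_sparse_corr.
    import torch

    B, C, H, W = 1, 8, 6, 6
    k = 4
    fmap1 = torch.randn(B, C, H * W)                      # source features, flattened over pixels
    fmap2 = torch.randn(B, C, H * W)                      # target features, flattened over pixels

    # Dense correlation volume [B, H1*W1, H2*W2], normalised by sqrt(C) as in the code above
    corr_dense = torch.einsum('bcn,bcm->bnm', fmap1, fmap2) / C ** 0.5

    # Keep only the k largest matches per source pixel -> sparse cost volume
    corr_topk, indices = corr_dense.topk(k, dim=-1)       # [B, H1*W1, k]
    print(corr_dense.numel(), '->', corr_topk.numel())    # memory drops from (H*W)^2 to k*(H*W)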