├── README.md ├── alt_cuda_corr ├── build │ ├── lib.linux-x86_64-3.6 │ │ └── alt_cuda_corr.cpython-36m-x86_64-linux-gnu.so │ └── temp.linux-x86_64-3.6 │ │ ├── .ninja_deps │ │ ├── .ninja_log │ │ ├── build.ninja │ │ ├── correlation.o │ │ └── correlation_kernel.o ├── correlation.cpp ├── correlation.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ └── top_level.txt ├── correlation_kernel.cu ├── dist │ └── correlation-0.0.0-py3.6-linux-x86_64.egg ├── run_install.sh └── setup.py ├── configs ├── kitti.py ├── kitti_multiframes.py ├── multiframes_sintel_submission.py ├── sintel.py ├── sintel_multiframes.py ├── sintel_submission.py ├── things.py └── things_multiframes.py ├── core ├── Networks │ ├── BOFNet │ │ ├── __init__.py │ │ ├── cnn.py │ │ ├── corr.py │ │ ├── gma.py │ │ ├── network.py │ │ ├── sk.py │ │ ├── sk2.py │ │ └── update.py │ ├── MOFNetStack │ │ ├── __init__.py │ │ ├── cnn.py │ │ ├── corr.py │ │ ├── gma.py │ │ ├── network.py │ │ ├── resstack.py │ │ ├── sk.py │ │ ├── sk2.py │ │ ├── stack.py │ │ ├── stackcat.py │ │ └── update.py │ ├── __init__.py │ ├── common.py │ ├── encoders.py │ └── twins_ft.py ├── __init__.py ├── datasets_3frames.py ├── datasets_multiframes.py ├── loss.py ├── optimizer │ └── __init__.py └── utils │ ├── __init__.py │ ├── augmentor.py │ ├── augmentor_multiframes.py │ ├── augmentor_twoframes.py │ ├── flow_transforms.py │ ├── flow_viz.py │ ├── frame_utils.py │ ├── logger.py │ ├── misc.py │ └── utils.py ├── demo_input_images ├── frame_0001.png ├── frame_0002.png ├── frame_0003.png ├── frame_0004.png ├── frame_0005.png ├── frame_0006.png ├── frame_0007.png ├── frame_0008.png ├── frame_0009.png └── frame_0010.png ├── evaluate_BOFNet.py ├── evaluate_MOFNet.py ├── flow_dataset_mf ├── convert_HD1K.py ├── convert_sintel.py ├── convert_things.py ├── flyingthings_frames_cleanpass_future_pfm.pkl ├── flyingthings_frames_cleanpass_past_pfm.pkl ├── flyingthings_frames_cleanpass_png.pkl ├── flyingthings_frames_finalpass_future_pfm.pkl ├── flyingthings_frames_finalpass_past_pfm.pkl ├── flyingthings_frames_finalpass_png.pkl ├── flyingthings_thres5.pkl ├── hd1k_flo.pkl ├── hd1k_png.pkl ├── sintel_training_clean_flo.pkl ├── sintel_training_clean_png.pkl ├── sintel_training_final_flo.pkl ├── sintel_training_final_png.pkl └── sintel_training_scene.pkl ├── flow_datasets ├── KITTI │ ├── KITTI_testing_extra_info.txt │ ├── KITTI_testing_image.txt │ ├── KITTI_training_extra_info.txt │ ├── KITTI_training_flow.txt │ ├── KITTI_training_image.txt │ └── generate_KITTI_list.py ├── flying_things_three_frames │ ├── convert_things.py │ ├── flyingthings_frames_cleanpass_pfm.txt │ ├── flyingthings_frames_cleanpass_png.txt │ ├── flyingthings_frames_finalpass_pfm.txt │ └── flyingthings_frames_finalpass_png.txt ├── hd1k_three_frames │ ├── convert_HD1K.py │ ├── hd1k_flo.txt │ └── hd1k_image.txt └── sintel_three_frames │ ├── Sintel_clean_extra_info.txt │ ├── Sintel_clean_flo.txt │ ├── Sintel_clean_png.txt │ ├── Sintel_final_extra_info.txt │ ├── Sintel_final_flo.txt │ ├── Sintel_final_png.txt │ ├── Sintel_reverse_flo.txt │ └── convert_sintel.py ├── inference.py ├── train_BOFNet.py └── train_MOFNet.py /README.md: -------------------------------------------------------------------------------- 1 | # [VideoFlow: Exploiting Temporal Cues for Multi-frame Optical Flow Estimation](https://arxiv.org/abs/2303.08340) 2 | 3 | 4 | > VideoFlow: Exploiting Temporal Cues for Multi-frame Optical Flow Estimation 5 | > [Xiaoyu Shi](https://xiaoyushi97.github.io/), [Zhaoyang 
Huang](https://drinkingcoder.github.io), [Weikang Bian](https://wkbian.github.io/), [Dasong Li](https://dasongli1.github.io/), [Manyuan Zhang](https://manyuan97.github.io/), Ka Chun Cheung, Simon See, [Hongwei Qin](http://qinhongwei.com/academic/), [Jifeng Dai](https://jifengdai.org/), [Hongsheng Li](https://www.ee.cuhk.edu.hk/~hsli/) 6 | > ICCV 2023 7 | 8 | https://github.com/XiaoyuShi97/VideoFlow/assets/25840016/8121acc6-b874-411e-86de-df55f7d386a9 9 | 10 | 11 | ## Requirements 12 | ```shell 13 | conda create --name videoflow 14 | conda activate videoflow 15 | conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv-python -c pytorch 16 | pip install yacs loguru einops timm==0.4.12 imageio 17 | ``` 18 | 19 | ## Models 20 | We provide pretrained [models](https://drive.google.com/drive/folders/16YqDD_IQpzrVWvDHI9xK3kO0MaXnNIGx?usp=sharing). The default path of the models for evaluation is: 21 | ```Shell 22 | ├── VideoFlow_ckpt 23 | ├── MOF_sintel.pth 24 | ├── BOF_sintel.pth 25 | ├── MOF_things.pth 26 | ├── BOF_things.pth 27 | ├── MOF_kitti.pth 28 | ├── BOF_kitti.pth 29 | ``` 30 | 31 | ## Inference & Visualization 32 | Download VideoFlow_ckpt and put it in the root dir. Run the following command: 33 | ```shell 34 | python -u inference.py --mode MOF --seq_dir demo_input_images --vis_dir demo_flow_vis 35 | ``` 36 | If your input only contain three frames, we recommend to use the BOF model: 37 | ```shell 38 | python -u inference.py --mode BOF --seq_dir demo_input_images_three_frames --vis_dir demo_flow_vis_three_frames 39 | ``` 40 | 41 | ## Data Preparation 42 | To evaluate/train VideoFlow, you will need to download the required datasets. 43 | * [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) 44 | * [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) 45 | * [Sintel](http://sintel.is.tue.mpg.de/) 46 | * [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) (multi-view extension, 20 frames per scene, 14 GB) 47 | * [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) 48 | 49 | By default `datasets.py` will search for the datasets in these locations. You can create symbolic links to wherever the datasets were downloaded in the `datasets` folder 50 | 51 | ```Shell 52 | ├── datasets 53 | ├── Sintel 54 | ├── test 55 | ├── training 56 | ├── KITTI 57 | ├── testing 58 | ├── training 59 | ├── devkit 60 | ├── FlyingChairs_release 61 | ├── data 62 | ├── FlyingThings3D 63 | ├── frames_cleanpass 64 | ├── frames_finalpass 65 | ├── optical_flow 66 | ``` 67 | 68 | 69 | ## Training 70 | The script will load the config according to the training stage. The trained model will be saved in a directory in `logs` and `checkpoints`. For example, the following script will load the config `configs/***.py`. The trained model will be saved as `logs/xxxx/final`. 
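Each stage name maps to a config in `configs/` (the `*_multiframes` variants are used for MOF training); below is a minimal sketch of resolving a stage to its `get_cfg()` helper — `load_stage_cfg` is a hypothetical illustration, not the actual code in `train_MOFNet.py`:

```python
# Hypothetical sketch: resolve a training stage to its yacs config module.
from importlib import import_module

def load_stage_cfg(stage: str, multiframe: bool = True):
    module = f"configs.{stage}_multiframes" if multiframe else f"configs.{stage}"
    return import_module(module).get_cfg()  # yacs CfgNode (network, trainer, restore_ckpt, ...)

cfg = load_stage_cfg("things")      # -> configs/things_multiframes.py for the MOF model
print(cfg.network, cfg.trainer.num_steps)
```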
71 | ```shell 72 | # Train MOF model 73 | python -u train_MOFNet.py --name MOF-things --stage things --validation sintel 74 | python -u train_MOFNet.py --name MOF-sintel --stage sintel --validation sintel 75 | python -u train_MOFNet.py --name MOF-kitti --stage kitti --validation sintel 76 | 77 | # Train BOF model 78 | python -u train_BOFNet.py --name BOF-things --stage things --validation sintel 79 | python -u train_BOFNet.py --name BOF-sintel --stage sintel --validation sintel 80 | python -u train_BOFNet.py --name BOF-kitti --stage kitti --validation sintel 81 | ``` 82 | 83 | ## Evaluation 84 | The script will load the config `configs/multiframes_sintel_submission.py` or `configs/sintel_submission.py`. Please change `_CN.model` in the config file to load the corresponding checkpoint. 85 | ```shell 86 | # Evaluate MOF_things.pth after the C stage 87 | python -u evaluate_MOFNet.py --dataset=sintel 88 | python -u evaluate_MOFNet.py --dataset=things 89 | python -u evaluate_MOFNet.py --dataset=kitti 90 | # To evaluate MOF_sintel.pth, create a submission to the Sintel benchmark after C+S 91 | python -u evaluate_MOFNet.py --dataset=sintel_submission_stride1 92 | # To evaluate MOF_kitti.pth, create a submission to the KITTI benchmark after C+S+K 93 | python -u evaluate_MOFNet.py --dataset=kitti_submission 94 | ``` 95 | Similarly, to evaluate BOF models: 96 | ```shell 97 | # Evaluate BOF_things.pth after the C stage 98 | python -u evaluate_BOFNet.py --dataset=sintel 99 | python -u evaluate_BOFNet.py --dataset=things 100 | python -u evaluate_BOFNet.py --dataset=kitti 101 | # To evaluate BOF_sintel.pth, create a submission to the Sintel benchmark after C+S 102 | python -u evaluate_BOFNet.py --dataset=sintel_submission 103 | # To evaluate BOF_kitti.pth, create a submission to the KITTI benchmark after C+S+K 104 | python -u evaluate_BOFNet.py --dataset=kitti_submission 105 | ``` 106 | 107 | ## (Optional & Inference Only) Efficient Implementation 108 | You can optionally use RAFT's alternate (efficient) implementation by compiling the provided CUDA extension and changing the [`corr_fn`](https://github.com/XiaoyuShi97/VideoFlow/blob/main/configs/multiframes_sintel_submission.py#L32) flag to `efficient` in the config files. 109 | ```Shell 110 | cd alt_cuda_corr && python setup.py install && cd .. 111 | ``` 112 | Note that this implementation is somewhat slower than the all-pairs version, but it uses significantly less GPU memory during the forward pass. It does not implement the backward function, so do not use it for training.
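To switch it on after compiling the extension, change the `corr_fn` flag in the config you run; a minimal sketch of the one-line edits (these lines already exist in the submission configs, only the value changes):

```python
# configs/multiframes_sintel_submission.py (MOF inference)
_CN.MOFNetStack.corr_fn = 'efficient'   # 'default' keeps the all-pairs CorrBlock

# configs/sintel_submission.py (BOF inference)
_CN.BOFNet.corr_fn = 'efficient'        # 'efficient' selects AlternateCorrBlock, which calls alt_cuda_corr
```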
113 | 114 | ## License 115 | VideoFlow is released under the Apache License 116 | 117 | ## Citation 118 | ```bibtex 119 | @article{shi2023videoflow, 120 | title={Videoflow: Exploiting temporal cues for multi-frame optical flow estimation}, 121 | author={Shi, Xiaoyu and Huang, Zhaoyang and Bian, Weikang and Li, Dasong and Zhang, Manyuan and Cheung, Ka Chun and See, Simon and Qin, Hongwei and Dai, Jifeng and Li, Hongsheng}, 122 | journal={arXiv preprint arXiv:2303.08340}, 123 | year={2023} 124 | } 125 | ``` 126 | 127 | ## Acknowledgement 128 | 129 | In this project, we use parts of codes in: 130 | - [RAFT](https://github.com/princeton-vl/RAFT) 131 | - [GMA](https://github.com/zacjiang/GMA) 132 | - [timm](https://github.com/rwightman/pytorch-image-models) 133 | -------------------------------------------------------------------------------- /alt_cuda_corr/build/lib.linux-x86_64-3.6/alt_cuda_corr.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaoyuShi97/VideoFlow/51489304db6e75fbdd9ff64d4656c1d120b5a673/alt_cuda_corr/build/lib.linux-x86_64-3.6/alt_cuda_corr.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /alt_cuda_corr/build/temp.linux-x86_64-3.6/.ninja_deps: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaoyuShi97/VideoFlow/51489304db6e75fbdd9ff64d4656c1d120b5a673/alt_cuda_corr/build/temp.linux-x86_64-3.6/.ninja_deps -------------------------------------------------------------------------------- /alt_cuda_corr/build/temp.linux-x86_64-3.6/.ninja_log: -------------------------------------------------------------------------------- 1 | # ninja log v5 2 | 14 43414 1674813234000000000 /mnt/lustre/shixiaoyu1/flow/BOFNet/alt_cuda_corr/build/temp.linux-x86_64-3.6/correlation.o a907b39a4eb15a34 3 | 19 120472 1674813310000000000 /mnt/lustre/shixiaoyu1/flow/BOFNet/alt_cuda_corr/build/temp.linux-x86_64-3.6/correlation_kernel.o ee596d56bae4937b 4 | -------------------------------------------------------------------------------- /alt_cuda_corr/build/temp.linux-x86_64-3.6/build.ninja: -------------------------------------------------------------------------------- 1 | ninja_required_version = 1.3 2 | cxx = /mnt/lustre/share/gcc/gcc-5.4/bin/g++ 3 | nvcc = /mnt/lustre/share/cuda-11.2/bin/nvcc 4 | 5 | cflags = -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.4/lib/python3.6/site-packages/torch/include -I/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.4/lib/python3.6/site-packages/torch/include/torch/csrc/api/include -I/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.4/lib/python3.6/site-packages/torch/include/TH -I/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.4/lib/python3.6/site-packages/torch/include/THC -I/mnt/lustre/share/cuda-11.2/include -I/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.4/include/python3.6m -c 6 | post_cflags = -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1009"' -DTORCH_EXTENSION_NAME=alt_cuda_corr -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++14 7 | cuda_cflags = -I/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.4/lib/python3.6/site-packages/torch/include 
-I/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.4/lib/python3.6/site-packages/torch/include/torch/csrc/api/include -I/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.4/lib/python3.6/site-packages/torch/include/TH -I/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.4/lib/python3.6/site-packages/torch/include/THC -I/mnt/lustre/share/cuda-11.2/include -I/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.4/include/python3.6m -c 8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1009"' -DTORCH_EXTENSION_NAME=alt_cuda_corr -D_GLIBCXX_USE_CXX11_ABI=1 -gencode=arch=compute_70,code=compute_70 -gencode=arch=compute_70,code=sm_70 -ccbin /mnt/lustre/share/gcc/gcc-5.4/bin/gcc -std=c++14 9 | ldflags = 10 | 11 | rule compile 12 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags 13 | depfile = $out.d 14 | deps = gcc 15 | 16 | rule cuda_compile 17 | depfile = $out.d 18 | deps = gcc 19 | command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags 20 | 21 | 22 | 23 | build /mnt/lustre/shixiaoyu1/flow/BOFNet/alt_cuda_corr/build/temp.linux-x86_64-3.6/correlation.o: compile /mnt/lustre/shixiaoyu1/flow/BOFNet/alt_cuda_corr/correlation.cpp 24 | build /mnt/lustre/shixiaoyu1/flow/BOFNet/alt_cuda_corr/build/temp.linux-x86_64-3.6/correlation_kernel.o: cuda_compile /mnt/lustre/shixiaoyu1/flow/BOFNet/alt_cuda_corr/correlation_kernel.cu 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /alt_cuda_corr/build/temp.linux-x86_64-3.6/correlation.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaoyuShi97/VideoFlow/51489304db6e75fbdd9ff64d4656c1d120b5a673/alt_cuda_corr/build/temp.linux-x86_64-3.6/correlation.o -------------------------------------------------------------------------------- /alt_cuda_corr/build/temp.linux-x86_64-3.6/correlation_kernel.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaoyuShi97/VideoFlow/51489304db6e75fbdd9ff64d4656c1d120b5a673/alt_cuda_corr/build/temp.linux-x86_64-3.6/correlation_kernel.o -------------------------------------------------------------------------------- /alt_cuda_corr/correlation.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | // CUDA forward declarations 5 | std::vector corr_cuda_forward( 6 | torch::Tensor fmap1, 7 | torch::Tensor fmap2, 8 | torch::Tensor coords, 9 | int radius); 10 | 11 | std::vector corr_cuda_backward( 12 | torch::Tensor fmap1, 13 | torch::Tensor fmap2, 14 | torch::Tensor coords, 15 | torch::Tensor corr_grad, 16 | int radius); 17 | 18 | // C++ interface 19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 22 | 23 | std::vector corr_forward( 24 | torch::Tensor fmap1, 25 | torch::Tensor fmap2, 26 | torch::Tensor coords, 27 | int radius) { 28 | CHECK_INPUT(fmap1); 29 | CHECK_INPUT(fmap2); 30 | CHECK_INPUT(coords); 
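// Inputs validated above (CUDA device + contiguous layout); dispatch to the CUDA implementation.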
31 | 32 | return corr_cuda_forward(fmap1, fmap2, coords, radius); 33 | } 34 | 35 | 36 | std::vector corr_backward( 37 | torch::Tensor fmap1, 38 | torch::Tensor fmap2, 39 | torch::Tensor coords, 40 | torch::Tensor corr_grad, 41 | int radius) { 42 | CHECK_INPUT(fmap1); 43 | CHECK_INPUT(fmap2); 44 | CHECK_INPUT(coords); 45 | CHECK_INPUT(corr_grad); 46 | 47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius); 48 | } 49 | 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", &corr_forward, "CORR forward"); 53 | m.def("backward", &corr_backward, "CORR backward"); 54 | } -------------------------------------------------------------------------------- /alt_cuda_corr/correlation.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 1.0 2 | Name: correlation 3 | Version: 0.0.0 4 | Summary: UNKNOWN 5 | Home-page: UNKNOWN 6 | Author: UNKNOWN 7 | Author-email: UNKNOWN 8 | License: UNKNOWN 9 | Description: UNKNOWN 10 | Platform: UNKNOWN 11 | -------------------------------------------------------------------------------- /alt_cuda_corr/correlation.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | correlation.cpp 2 | correlation_kernel.cu 3 | setup.py 4 | correlation.egg-info/PKG-INFO 5 | correlation.egg-info/SOURCES.txt 6 | correlation.egg-info/dependency_links.txt 7 | correlation.egg-info/top_level.txt -------------------------------------------------------------------------------- /alt_cuda_corr/correlation.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /alt_cuda_corr/correlation.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | alt_cuda_corr 2 | -------------------------------------------------------------------------------- /alt_cuda_corr/correlation_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | 7 | #define BLOCK_H 4 8 | #define BLOCK_W 8 9 | #define BLOCK_HW BLOCK_H * BLOCK_W 10 | #define CHANNEL_STRIDE 32 11 | 12 | 13 | __forceinline__ __device__ 14 | bool within_bounds(int h, int w, int H, int W) { 15 | return h >= 0 && h < H && w >= 0 && w < W; 16 | } 17 | 18 | template 19 | __global__ void corr_forward_kernel( 20 | const torch::PackedTensorAccessor32 fmap1, 21 | const torch::PackedTensorAccessor32 fmap2, 22 | const torch::PackedTensorAccessor32 coords, 23 | torch::PackedTensorAccessor32 corr, 24 | int r) 25 | { 26 | const int b = blockIdx.x; 27 | const int h0 = blockIdx.y * blockDim.x; 28 | const int w0 = blockIdx.z * blockDim.y; 29 | const int tid = threadIdx.x * blockDim.y + threadIdx.y; 30 | 31 | const int H1 = fmap1.size(1); 32 | const int W1 = fmap1.size(2); 33 | const int H2 = fmap2.size(1); 34 | const int W2 = fmap2.size(2); 35 | const int N = coords.size(1); 36 | const int C = fmap1.size(3); 37 | 38 | __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1]; 39 | __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1]; 40 | __shared__ scalar_t x2s[BLOCK_HW]; 41 | __shared__ scalar_t y2s[BLOCK_HW]; 42 | 43 | for (int c=0; c(floor(y2s[k1]))-r+iy; 76 | int w2 = static_cast(floor(x2s[k1]))-r+ix; 77 | int c2 = tid % CHANNEL_STRIDE; 78 | 79 | auto fptr = fmap2[b][h2][w2]; 80 | if (within_bounds(h2, w2, H2, W2)) 81 | f2[c2][k1] = 
fptr[c+c2]; 82 | else 83 | f2[c2][k1] = 0.0; 84 | } 85 | 86 | __syncthreads(); 87 | 88 | scalar_t s = 0.0; 89 | for (int k=0; k 0 && ix > 0 && within_bounds(h1, w1, H1, W1)) 105 | *(corr_ptr + ix_nw) += nw; 106 | 107 | if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1)) 108 | *(corr_ptr + ix_ne) += ne; 109 | 110 | if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1)) 111 | *(corr_ptr + ix_sw) += sw; 112 | 113 | if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1)) 114 | *(corr_ptr + ix_se) += se; 115 | } 116 | } 117 | } 118 | } 119 | } 120 | 121 | 122 | template 123 | __global__ void corr_backward_kernel( 124 | const torch::PackedTensorAccessor32 fmap1, 125 | const torch::PackedTensorAccessor32 fmap2, 126 | const torch::PackedTensorAccessor32 coords, 127 | const torch::PackedTensorAccessor32 corr_grad, 128 | torch::PackedTensorAccessor32 fmap1_grad, 129 | torch::PackedTensorAccessor32 fmap2_grad, 130 | torch::PackedTensorAccessor32 coords_grad, 131 | int r) 132 | { 133 | 134 | const int b = blockIdx.x; 135 | const int h0 = blockIdx.y * blockDim.x; 136 | const int w0 = blockIdx.z * blockDim.y; 137 | const int tid = threadIdx.x * blockDim.y + threadIdx.y; 138 | 139 | const int H1 = fmap1.size(1); 140 | const int W1 = fmap1.size(2); 141 | const int H2 = fmap2.size(1); 142 | const int W2 = fmap2.size(2); 143 | const int N = coords.size(1); 144 | const int C = fmap1.size(3); 145 | 146 | __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1]; 147 | __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1]; 148 | 149 | __shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1]; 150 | __shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1]; 151 | 152 | __shared__ scalar_t x2s[BLOCK_HW]; 153 | __shared__ scalar_t y2s[BLOCK_HW]; 154 | 155 | for (int c=0; c(floor(y2s[k1]))-r+iy; 190 | int w2 = static_cast(floor(x2s[k1]))-r+ix; 191 | int c2 = tid % CHANNEL_STRIDE; 192 | 193 | auto fptr = fmap2[b][h2][w2]; 194 | if (within_bounds(h2, w2, H2, W2)) 195 | f2[c2][k1] = fptr[c+c2]; 196 | else 197 | f2[c2][k1] = 0.0; 198 | 199 | f2_grad[c2][k1] = 0.0; 200 | } 201 | 202 | __syncthreads(); 203 | 204 | const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1]; 205 | scalar_t g = 0.0; 206 | 207 | int ix_nw = H1*W1*((iy-1) + rd*(ix-1)); 208 | int ix_ne = H1*W1*((iy-1) + rd*ix); 209 | int ix_sw = H1*W1*(iy + rd*(ix-1)); 210 | int ix_se = H1*W1*(iy + rd*ix); 211 | 212 | if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1)) 213 | g += *(grad_ptr + ix_nw) * dy * dx; 214 | 215 | if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1)) 216 | g += *(grad_ptr + ix_ne) * dy * (1-dx); 217 | 218 | if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1)) 219 | g += *(grad_ptr + ix_sw) * (1-dy) * dx; 220 | 221 | if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1)) 222 | g += *(grad_ptr + ix_se) * (1-dy) * (1-dx); 223 | 224 | for (int k=0; k(floor(y2s[k1]))-r+iy; 232 | int w2 = static_cast(floor(x2s[k1]))-r+ix; 233 | int c2 = tid % CHANNEL_STRIDE; 234 | 235 | scalar_t* fptr = &fmap2_grad[b][h2][w2][0]; 236 | if (within_bounds(h2, w2, H2, W2)) 237 | atomicAdd(fptr+c+c2, f2_grad[c2][k1]); 238 | } 239 | } 240 | } 241 | } 242 | __syncthreads(); 243 | 244 | 245 | for (int k=0; k corr_cuda_forward( 261 | torch::Tensor fmap1, 262 | torch::Tensor fmap2, 263 | torch::Tensor coords, 264 | int radius) 265 | { 266 | const auto B = coords.size(0); 267 | const auto N = coords.size(1); 268 | const auto H = coords.size(2); 269 | const auto W = coords.size(3); 270 | 271 | const auto rd = 2 * radius + 1; 272 | auto opts = fmap1.options(); 273 | auto 
corr = torch::zeros({B, N, rd*rd, H, W}, opts); 274 | 275 | const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W); 276 | const dim3 threads(BLOCK_H, BLOCK_W); 277 | 278 | corr_forward_kernel<<>>( 279 | fmap1.packed_accessor32(), 280 | fmap2.packed_accessor32(), 281 | coords.packed_accessor32(), 282 | corr.packed_accessor32(), 283 | radius); 284 | 285 | return {corr}; 286 | } 287 | 288 | std::vector corr_cuda_backward( 289 | torch::Tensor fmap1, 290 | torch::Tensor fmap2, 291 | torch::Tensor coords, 292 | torch::Tensor corr_grad, 293 | int radius) 294 | { 295 | const auto B = coords.size(0); 296 | const auto N = coords.size(1); 297 | 298 | const auto H1 = fmap1.size(1); 299 | const auto W1 = fmap1.size(2); 300 | const auto H2 = fmap2.size(1); 301 | const auto W2 = fmap2.size(2); 302 | const auto C = fmap1.size(3); 303 | 304 | auto opts = fmap1.options(); 305 | auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts); 306 | auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts); 307 | auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts); 308 | 309 | const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W); 310 | const dim3 threads(BLOCK_H, BLOCK_W); 311 | 312 | 313 | corr_backward_kernel<<>>( 314 | fmap1.packed_accessor32(), 315 | fmap2.packed_accessor32(), 316 | coords.packed_accessor32(), 317 | corr_grad.packed_accessor32(), 318 | fmap1_grad.packed_accessor32(), 319 | fmap2_grad.packed_accessor32(), 320 | coords_grad.packed_accessor32(), 321 | radius); 322 | 323 | return {fmap1_grad, fmap2_grad, coords_grad}; 324 | } -------------------------------------------------------------------------------- /alt_cuda_corr/dist/correlation-0.0.0-py3.6-linux-x86_64.egg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaoyuShi97/VideoFlow/51489304db6e75fbdd9ff64d4656c1d120b5a673/alt_cuda_corr/dist/correlation-0.0.0-py3.6-linux-x86_64.egg -------------------------------------------------------------------------------- /alt_cuda_corr/run_install.sh: -------------------------------------------------------------------------------- 1 | PYTHONPATH=/mnt/cache/shixiaoyu1/.local/lib/python3.6/site-packages 2 | export CXX=/mnt/lustre/share/gcc/gcc-5.4/bin/g++ 3 | export CC=/mnt/lustre/share/gcc/gcc-5.4/bin/gcc 4 | export CUDA_HOME=/mnt/lustre/share/cuda-11.2 5 | srun --cpus-per-task=5 --ntasks-per-node=1 -p ISPCodec -n1 --gres=gpu:1 python setup.py install --user 6 | -------------------------------------------------------------------------------- /alt_cuda_corr/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup( 6 | name='correlation', 7 | ext_modules=[ 8 | CUDAExtension('alt_cuda_corr', 9 | sources=['correlation.cpp', 'correlation_kernel.cu'], 10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) 15 | -------------------------------------------------------------------------------- /configs/kitti.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | _CN = CN() 3 | 4 | _CN.name = '' 5 | _CN.suffix ='' 6 | _CN.gamma = 0.85 7 | _CN.max_flow = 400 8 | _CN.batch_size = 8 9 | _CN.sum_freq = 100 10 | _CN.val_freq = 100000000 11 | _CN.image_size = [432, 960] 12 | _CN.add_noise = False 13 | _CN.use_smoothl1 = False 14 | 
_CN.critical_params = [] 15 | 16 | _CN.network = 'BOFNet' 17 | _CN.mixed_precision = False 18 | _CN.filter_epe = False 19 | 20 | _CN.restore_ckpt = "PATH-TO-FINAL" 21 | 22 | _CN.BOFNet = CN() 23 | _CN.BOFNet.pretrain = True 24 | _CN.BOFNet.cnet = 'twins' 25 | _CN.BOFNet.fnet = 'twins' 26 | _CN.BOFNet.gma = 'GMA-SK2' 27 | _CN.BOFNet.corr_fn = "default" 28 | _CN.BOFNet.mixed_precision = False 29 | 30 | _CN.BOFNet.decoder_depth = 12 31 | _CN.BOFNet.critical_params = ["cnet", "fnet", "pretrain", "corr_fn", "mixed_precision"] 32 | 33 | ### TRAINER 34 | _CN.trainer = CN() 35 | _CN.trainer.scheduler = 'OneCycleLR' 36 | _CN.trainer.optimizer = 'adamw' 37 | _CN.trainer.canonical_lr = 12.5e-5 38 | _CN.trainer.adamw_decay = 1e-4 39 | _CN.trainer.clip = 1.0 40 | _CN.trainer.num_steps = 80000 41 | _CN.trainer.epsilon = 1e-8 42 | _CN.trainer.anneal_strategy = 'linear' 43 | def get_cfg(): 44 | return _CN.clone() 45 | -------------------------------------------------------------------------------- /configs/kitti_multiframes.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | _CN = CN() 3 | 4 | _CN.name = '' 5 | _CN.suffix ='' 6 | _CN.gamma = 0.85 7 | _CN.max_flow = 400 8 | _CN.batch_size = 8 9 | _CN.sum_freq = 100 10 | _CN.val_freq = 100000000 11 | _CN.image_size = [432, 960] 12 | _CN.add_noise = False 13 | _CN.use_smoothl1 = False 14 | _CN.critical_params = [] 15 | 16 | _CN.network = 'MOFNetStack' 17 | 18 | _CN.restore_ckpt = "PATH-TO-FINAL/final" 19 | 20 | _CN.mixed_precision = False 21 | _CN.input_frames = 5 22 | _CN.filter_epe = False 23 | 24 | ################################################### 25 | ################################################### 26 | _CN.MOFNetStack = CN() 27 | _CN.MOFNetStack.pretrain = True 28 | _CN.MOFNetStack.Tfusion = 'stack' 29 | _CN.MOFNetStack.cnet = 'twins' 30 | _CN.MOFNetStack.fnet = 'twins' 31 | _CN.MOFNetStack.down_ratio = 8 32 | _CN.MOFNetStack.feat_dim = 256 33 | _CN.MOFNetStack.corr_fn = 'default' 34 | _CN.MOFNetStack.corr_levels = 4 35 | _CN.MOFNetStack.mixed_precision = False 36 | _CN.MOFNetStack.context_3D = False 37 | _CN.MOFNetStack.GMA_MF = False 38 | 39 | _CN.MOFNetStack.decoder_depth = 12 40 | _CN.MOFNetStack.critical_params = ["cnet", "fnet", "pretrain", 'corr_fn', "Tfusion", "corr_levels", "decoder_depth", "mixed_precision", "GMA_MF"] 41 | 42 | ### TRAINER 43 | _CN.trainer = CN() 44 | _CN.trainer.scheduler = 'OneCycleLR' 45 | _CN.trainer.optimizer = 'adamw' 46 | _CN.trainer.canonical_lr = 12.5e-5 47 | _CN.trainer.adamw_decay = 1e-4 48 | _CN.trainer.clip = 1.0 49 | _CN.trainer.num_steps = 25000 50 | _CN.trainer.epsilon = 1e-8 51 | _CN.trainer.anneal_strategy = 'linear' 52 | def get_cfg(): 53 | return _CN.clone() 54 | -------------------------------------------------------------------------------- /configs/multiframes_sintel_submission.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | _CN = CN() 3 | 4 | _CN.name = '' 5 | _CN.suffix ='' 6 | _CN.gamma = 0.85 7 | _CN.max_flow = 400 8 | _CN.batch_size = 8 9 | _CN.sum_freq = 100 10 | _CN.val_freq = 100000000 11 | _CN.image_size = [432, 960] 12 | _CN.add_noise = False 13 | _CN.use_smoothl1 = False 14 | _CN.critical_params = [] 15 | 16 | _CN.network = 'MOFNetStack' 17 | 18 | _CN.model = 'VideoFlow_ckpt/MOF_sintel.pth' 19 | _CN.input_frames = 5 20 | 21 | _CN.restore_ckpt = None 22 | 23 | ################################################ 24 | 
################################################ 25 | _CN.MOFNetStack = CN() 26 | _CN.MOFNetStack.pretrain = True 27 | _CN.MOFNetStack.Tfusion = 'stack' 28 | _CN.MOFNetStack.cnet = 'twins' 29 | _CN.MOFNetStack.fnet = 'twins' 30 | _CN.MOFNetStack.down_ratio = 8 31 | _CN.MOFNetStack.feat_dim = 256 32 | _CN.MOFNetStack.corr_fn = 'default' 33 | _CN.MOFNetStack.corr_levels = 4 34 | _CN.MOFNetStack.mixed_precision = True 35 | _CN.MOFNetStack.context_3D = False 36 | 37 | _CN.MOFNetStack.decoder_depth = 32 38 | _CN.MOFNetStack.critical_params = ["cnet", "fnet", "pretrain", 'corr_fn', "Tfusion", "corr_levels", "decoder_depth", "mixed_precision"] 39 | 40 | ### TRAINER 41 | _CN.trainer = CN() 42 | _CN.trainer.scheduler = 'OneCycleLR' 43 | _CN.trainer.optimizer = 'adamw' 44 | _CN.trainer.canonical_lr = 12.5e-5 45 | _CN.trainer.adamw_decay = 1e-4 46 | _CN.trainer.clip = 1.0 47 | _CN.trainer.num_steps = 90000 48 | _CN.trainer.epsilon = 1e-8 49 | _CN.trainer.anneal_strategy = 'linear' 50 | def get_cfg(): 51 | return _CN.clone() 52 | -------------------------------------------------------------------------------- /configs/sintel.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | _CN = CN() 3 | 4 | _CN.name = '' 5 | _CN.suffix ='' 6 | _CN.gamma = 0.85 7 | _CN.max_flow = 400 8 | _CN.batch_size = 8 9 | _CN.sum_freq = 100 10 | _CN.val_freq = 100000000 11 | _CN.image_size = [432, 960] 12 | _CN.add_noise = False 13 | _CN.use_smoothl1 = False 14 | _CN.critical_params = [] 15 | 16 | _CN.network = 'BOFNet' 17 | _CN.mixed_precision = False 18 | _CN.filter_epe = False 19 | 20 | _CN.restore_ckpt = "PATH_TO_FINAL/final" 21 | 22 | _CN.BOFNet = CN() 23 | _CN.BOFNet.pretrain = True 24 | _CN.BOFNet.cnet = 'twins' 25 | _CN.BOFNet.fnet = 'twins' 26 | _CN.BOFNet.gma = 'GMA-SK2' 27 | _CN.BOFNet.corr_fn = "default" 28 | _CN.BOFNet.mixed_precision = False 29 | 30 | _CN.BOFNet.decoder_depth = 12 31 | _CN.BOFNet.critical_params = ["cnet", "fnet", "pretrain", "corr_fn", "mixed_precision"] 32 | 33 | 34 | ### TRAINER 35 | _CN.trainer = CN() 36 | _CN.trainer.scheduler = 'OneCycleLR' 37 | _CN.trainer.optimizer = 'adamw' 38 | _CN.trainer.canonical_lr = 12.5e-5 39 | _CN.trainer.adamw_decay = 1e-4 40 | _CN.trainer.clip = 1.0 41 | _CN.trainer.num_steps = 120000 42 | _CN.trainer.epsilon = 1e-8 43 | _CN.trainer.anneal_strategy = 'linear' 44 | def get_cfg(): 45 | return _CN.clone() 46 | -------------------------------------------------------------------------------- /configs/sintel_multiframes.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | _CN = CN() 3 | 4 | _CN.name = '' 5 | _CN.suffix ='' 6 | _CN.gamma = 0.85 7 | _CN.max_flow = 400 8 | _CN.batch_size = 8 9 | _CN.sum_freq = 100 10 | _CN.val_freq = 100000000 11 | _CN.image_size = [432, 960] 12 | _CN.add_noise = False 13 | _CN.use_smoothl1 = False 14 | _CN.critical_params = [] 15 | 16 | _CN.network = 'MOFNetStack' 17 | 18 | _CN.restore_ckpt = "PATH_TO_FINAL/final" 19 | 20 | _CN.mixed_precision = True 21 | _CN.input_frames = 5 22 | _CN.filter_epe = False 23 | 24 | ################################################### 25 | ################################################### 26 | _CN.MOFNetStack = CN() 27 | _CN.MOFNetStack.pretrain = True 28 | _CN.MOFNetStack.Tfusion = 'stack' 29 | _CN.MOFNetStack.cnet = 'twins' 30 | _CN.MOFNetStack.fnet = 'twins' 31 | _CN.MOFNetStack.down_ratio = 8 32 | _CN.MOFNetStack.feat_dim = 256 33 | 
_CN.MOFNetStack.corr_fn = 'default' 34 | _CN.MOFNetStack.corr_levels = 4 35 | _CN.MOFNetStack.mixed_precision = True 36 | _CN.MOFNetStack.context_3D = False 37 | 38 | _CN.MOFNetStack.decoder_depth = 12 39 | _CN.MOFNetStack.critical_params = ["cnet", "fnet", "pretrain", 'corr_fn', "Tfusion", "corr_levels", "decoder_depth", "mixed_precision"] 40 | 41 | 42 | ### TRAINER 43 | _CN.trainer = CN() 44 | _CN.trainer.scheduler = 'OneCycleLR' 45 | _CN.trainer.optimizer = 'adamw' 46 | _CN.trainer.canonical_lr = 12.5e-5 47 | _CN.trainer.adamw_decay = 1e-4 48 | _CN.trainer.clip = 1.0 49 | _CN.trainer.num_steps = 40000 50 | _CN.trainer.epsilon = 1e-8 51 | _CN.trainer.anneal_strategy = 'linear' 52 | def get_cfg(): 53 | return _CN.clone() 54 | -------------------------------------------------------------------------------- /configs/sintel_submission.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | _CN = CN() 3 | 4 | _CN.name = '' 5 | _CN.suffix ='' 6 | _CN.gamma = 0.85 7 | _CN.max_flow = 400 8 | _CN.batch_size = 8 9 | _CN.sum_freq = 100 10 | _CN.val_freq = 100000000 11 | _CN.image_size = [432, 960] 12 | _CN.add_noise = False 13 | _CN.use_smoothl1 = False 14 | _CN.critical_params = [] 15 | 16 | _CN.network = 'BOFNet' 17 | 18 | _CN.restore_ckpt = None 19 | 20 | _CN.model = "VideoFlow_ckpt/BOF_sintel.pth" 21 | 22 | 23 | _CN.BOFNet = CN() 24 | _CN.BOFNet.pretrain = True 25 | _CN.BOFNet.cnet = 'twins' 26 | _CN.BOFNet.fnet = 'twins' 27 | _CN.BOFNet.gma = 'GMA-SK2' 28 | _CN.BOFNet.corr_fn = "default" 29 | _CN.BOFNet.corr_levels = 4 30 | _CN.BOFNet.mixed_precision = True 31 | 32 | _CN.BOFNet.decoder_depth = 32 33 | _CN.BOFNet.critical_params = ["cnet", "fnet", "pretrain"] 34 | 35 | 36 | ### TRAINER 37 | _CN.trainer = CN() 38 | _CN.trainer.scheduler = 'OneCycleLR' 39 | _CN.trainer.optimizer = 'adamw' 40 | _CN.trainer.canonical_lr = 12.5e-5 41 | _CN.trainer.adamw_decay = 1e-4 42 | _CN.trainer.clip = 1.0 43 | _CN.trainer.num_steps = 90000 44 | _CN.trainer.epsilon = 1e-8 45 | _CN.trainer.anneal_strategy = 'linear' 46 | def get_cfg(): 47 | return _CN.clone() 48 | -------------------------------------------------------------------------------- /configs/things.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | _CN = CN() 3 | 4 | _CN.name = '' 5 | _CN.suffix ='' 6 | _CN.gamma = 0.8 7 | _CN.max_flow = 400 8 | _CN.batch_size = 8 9 | _CN.sum_freq = 100 10 | _CN.val_freq = 100000000 11 | _CN.image_size = [432, 960] 12 | _CN.add_noise = False 13 | _CN.use_smoothl1 = False 14 | _CN.critical_params = [] 15 | 16 | _CN.network = 'BOFNet' 17 | _CN.mixed_precision = False 18 | _CN.filter_epe = False 19 | 20 | _CN.restore_ckpt = None 21 | 22 | _CN.BOFNet = CN() 23 | _CN.BOFNet.pretrain = True 24 | _CN.BOFNet.gma = 'GMA-SK2' 25 | _CN.BOFNet.cnet = 'twins' 26 | _CN.BOFNet.fnet = 'twins' 27 | _CN.BOFNet.corr_fn = 'default' 28 | _CN.BOFNet.corr_levels = 4 29 | _CN.BOFNet.mixed_precision = False 30 | 31 | _CN.BOFNet.decoder_depth = 12 32 | _CN.BOFNet.critical_params = ["cnet", "fnet", "pretrain", 'corr_fn', "gma", "corr_levels", "decoder_depth", "mixed_precision"] 33 | 34 | ### TRAINER 35 | _CN.trainer = CN() 36 | _CN.trainer.scheduler = 'OneCycleLR' 37 | _CN.trainer.optimizer = 'adamw' 38 | _CN.trainer.canonical_lr = 25e-5 39 | _CN.trainer.adamw_decay = 1e-4 40 | _CN.trainer.clip = 1.0 41 | _CN.trainer.num_steps = 120000 42 | _CN.trainer.epsilon = 1e-8 43 | 
_CN.trainer.anneal_strategy = 'linear' 44 | def get_cfg(): 45 | return _CN.clone() 46 | -------------------------------------------------------------------------------- /configs/things_multiframes.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | _CN = CN() 3 | 4 | _CN.name = '' 5 | _CN.suffix ='' 6 | _CN.gamma = 0.8 7 | _CN.max_flow = 400 8 | _CN.batch_size = 8 9 | _CN.sum_freq = 100 10 | _CN.val_freq = 100000000 11 | _CN.image_size = [432, 960] 12 | _CN.add_noise = False 13 | _CN.use_smoothl1 = False 14 | _CN.critical_params = [] 15 | 16 | _CN.network = 'MOFNetStack' 17 | _CN.mixed_precision = True 18 | _CN.input_frames = 5 19 | _CN.filter_epe = False 20 | 21 | _CN.restore_ckpt = None 22 | 23 | _CN.MOFNetStack = CN() 24 | _CN.MOFNetStack.pretrain = True 25 | _CN.MOFNetStack.Tfusion = 'stack' 26 | _CN.MOFNetStack.cnet = 'twins' 27 | _CN.MOFNetStack.fnet = 'twins' 28 | _CN.MOFNetStack.down_ratio = 8 29 | _CN.MOFNetStack.feat_dim = 256 30 | _CN.MOFNetStack.corr_fn = 'default' 31 | _CN.MOFNetStack.corr_levels = 4 32 | _CN.MOFNetStack.mixed_precision = True 33 | _CN.MOFNetStack.context_3D = False 34 | 35 | _CN.MOFNetStack.decoder_depth = 6 36 | _CN.MOFNetStack.critical_params = ["cnet", "fnet", "pretrain", "Tfusion", "decoder_depth", "mixed_precision", "down_ratio", "feat_dim"] 37 | 38 | ### TRAINER 39 | _CN.trainer = CN() 40 | _CN.trainer.scheduler = 'OneCycleLR' 41 | _CN.trainer.optimizer = 'adamw' 42 | _CN.trainer.canonical_lr = 25e-5 43 | _CN.trainer.adamw_decay = 1e-4 44 | _CN.trainer.clip = 1.0 45 | _CN.trainer.num_steps = 125000 46 | _CN.trainer.epsilon = 1e-8 47 | _CN.trainer.anneal_strategy = 'linear' 48 | def get_cfg(): 49 | return _CN.clone() 50 | -------------------------------------------------------------------------------- /core/Networks/BOFNet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaoyuShi97/VideoFlow/51489304db6e75fbdd9ff64d4656c1d120b5a673/core/Networks/BOFNet/__init__.py -------------------------------------------------------------------------------- /core/Networks/BOFNet/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from ...utils.utils import bilinear_sampler, coords_grid 6 | # from compute_sparse_correlation import compute_sparse_corr, compute_sparse_corr_torch, compute_sparse_corr_mink 7 | import alt_cuda_corr 8 | 9 | try: 10 | import alt_cuda_corr 11 | except: 12 | # alt_cuda_corr is not compiled 13 | print("[!!alt_cuda_corr is not compiled!!]") 14 | pass 15 | 16 | class DirectCorr(torch.autograd.Function): 17 | @staticmethod 18 | def forward(ctx, fmap1, fmap2, coords): 19 | ctx.save_for_backward(fmap1, fmap2, coords) 20 | corr, = alt_cuda_corr.forward(fmap1, fmap2, coords, 4) 21 | return corr 22 | 23 | def backward(ctx, grad_output): 24 | fmap1, fmap2, coords = ctx.saved_tensors 25 | grad_output = grad_output.contiguous() 26 | fmap1_grad, fmap2_grad, coords_grad = \ 27 | alt_cuda_corr.backward(fmap1, fmap2, coords, grad_output, 4) 28 | 29 | return fmap1_grad, fmap2_grad, coords_grad 30 | 31 | class OLCorrBlock: 32 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 33 | self.num_levels = num_levels 34 | self.radius = radius 35 | 36 | batch, dim, ht, wd = fmap1.shape 37 | self.fmap1 = fmap1.permute(0, 2, 3, 1).view(batch*ht*wd, 1, dim) 38 | 39 | 
self.fmap2_pyramid = [] 40 | self.fmap2_pyramid.append(fmap2) 41 | for i in range(self.num_levels - 1): 42 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 43 | self.fmap2_pyramid.append(fmap2) 44 | 45 | def __call__(self, coords): 46 | r = self.radius 47 | coords = coords.permute(0, 2, 3, 1) 48 | batch, h1, w1, _ = coords.shape 49 | _, _, dim = self.fmap1.shape 50 | 51 | out_pyramid = [] 52 | for i in range(self.num_levels): 53 | fmap2 = self.fmap2_pyramid[i] 54 | dx = torch.linspace(-r, r, 2 * r + 1) 55 | dy = torch.linspace(-r, r, 2 * r + 1) 56 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 57 | 58 | centroid_lvl = coords.reshape(batch, h1 * w1, 1, 2) / 2 ** i 59 | delta_lvl = delta.view(1, 1, (2 * r + 1) ** 2, 2) 60 | coords_lvl = centroid_lvl + delta_lvl 61 | 62 | fmap2 = bilinear_sampler(fmap2, coords_lvl) # B, 256, h*w, 9*9 63 | fmap2 = fmap2.permute(0, 2, 1, 3).view(batch*h1*w1, dim, (2 * r + 1) ** 2) 64 | #print(self.fmap1.shape, fmap2.shape) 65 | corr = torch.bmm(self.fmap1, fmap2) / torch.sqrt(torch.tensor(dim).float()) 66 | 67 | corr = corr.view(batch, h1, w1, -1) 68 | out_pyramid.append(corr) 69 | 70 | out = torch.cat(out_pyramid, dim=-1) 71 | return out.permute(0, 3, 1, 2).float() 72 | 73 | 74 | class CorrBlock: 75 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 76 | self.num_levels = num_levels 77 | self.radius = radius 78 | self.corr_pyramid = [] 79 | 80 | # all pairs correlation 81 | corr = CorrBlock.corr(fmap1, fmap2) 82 | 83 | batch, h1, w1, dim, h2, w2 = corr.shape 84 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 85 | 86 | self.corr_pyramid.append(corr) 87 | for i in range(self.num_levels - 1): 88 | corr = F.avg_pool2d(corr, 2, stride=2) 89 | self.corr_pyramid.append(corr) 90 | 91 | def __call__(self, coords): 92 | r = self.radius 93 | coords = coords.permute(0, 2, 3, 1) 94 | batch, h1, w1, _ = coords.shape 95 | 96 | out_pyramid = [] 97 | for i in range(self.num_levels): 98 | corr = self.corr_pyramid[i] 99 | dx = torch.linspace(-r, r, 2 * r + 1) 100 | dy = torch.linspace(-r, r, 2 * r + 1) 101 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 102 | 103 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2 ** i 104 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 105 | coords_lvl = centroid_lvl + delta_lvl 106 | 107 | corr = bilinear_sampler(corr, coords_lvl) 108 | corr = corr.view(batch, h1, w1, -1) 109 | out_pyramid.append(corr) 110 | 111 | out = torch.cat(out_pyramid, dim=-1) 112 | return out.permute(0, 3, 1, 2).contiguous().float() 113 | 114 | @staticmethod 115 | def corr(fmap1, fmap2): 116 | batch, dim, ht, wd = fmap1.shape 117 | fmap1 = fmap1.view(batch, dim, ht * wd) 118 | fmap2 = fmap2.view(batch, dim, ht * wd) 119 | 120 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 121 | corr = corr.view(batch, ht, wd, 1, ht, wd) 122 | return corr / torch.sqrt(torch.tensor(dim).float()) 123 | 124 | 125 | class CorrBlockSingleScale(nn.Module): 126 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 127 | super().__init__() 128 | self.radius = radius 129 | 130 | # all pairs correlation 131 | corr = CorrBlock.corr(fmap1, fmap2) 132 | batch, h1, w1, dim, h2, w2 = corr.shape 133 | self.corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 134 | 135 | def __call__(self, coords): 136 | r = self.radius 137 | coords = coords.permute(0, 2, 3, 1) 138 | batch, h1, w1, _ = coords.shape 139 | 140 | corr = self.corr 141 | dx = torch.linspace(-r, r, 2 * r + 1) 142 | dy = torch.linspace(-r, r, 2 * r + 1) 143 | 
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 144 | 145 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) 146 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 147 | coords_lvl = centroid_lvl + delta_lvl 148 | 149 | corr = bilinear_sampler(corr, coords_lvl) 150 | out = corr.view(batch, h1, w1, -1) 151 | out = out.permute(0, 3, 1, 2).contiguous().float() 152 | return out 153 | 154 | @staticmethod 155 | def corr(fmap1, fmap2): 156 | batch, dim, ht, wd = fmap1.shape 157 | fmap1 = fmap1.view(batch, dim, ht * wd) 158 | fmap2 = fmap2.view(batch, dim, ht * wd) 159 | 160 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 161 | corr = corr.view(batch, ht, wd, 1, ht, wd) 162 | return corr / torch.sqrt(torch.tensor(dim).float()) 163 | 164 | class AlternateCorrBlock: 165 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 166 | self.num_levels = num_levels 167 | self.radius = radius 168 | 169 | self.pyramid = [(fmap1, fmap2)] 170 | for i in range(self.num_levels): 171 | #fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 172 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 173 | self.pyramid.append((None, fmap2)) 174 | 175 | def __call__(self, coords): 176 | coords = coords.permute(0, 2, 3, 1) 177 | B, H, W, _ = coords.shape 178 | dim = self.pyramid[0][0].shape[1] 179 | 180 | corr_list = [] 181 | for i in range(self.num_levels): 182 | r = self.radius 183 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous().float() 184 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous().float() 185 | 186 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 187 | corr, = DirectCorr.apply(fmap1_i, fmap2_i, coords_i) 188 | corr_list.append(corr.squeeze(1)) 189 | 190 | corr = torch.stack(corr_list, dim=1) 191 | corr = corr.reshape(B, -1, H, W) 192 | return corr / torch.sqrt(torch.tensor(dim).float()) 193 | -------------------------------------------------------------------------------- /core/Networks/BOFNet/gma.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from einops import rearrange 4 | 5 | 6 | class RelPosEmb(nn.Module): 7 | def __init__( 8 | self, 9 | max_pos_size, 10 | dim_head 11 | ): 12 | super().__init__() 13 | self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head) 14 | self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head) 15 | 16 | deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange(max_pos_size).view(-1, 1) 17 | rel_ind = deltas + max_pos_size - 1 18 | self.register_buffer('rel_ind', rel_ind) 19 | 20 | def forward(self, q): 21 | batch, heads, h, w, c = q.shape 22 | height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1)) 23 | width_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1)) 24 | 25 | height_emb = rearrange(height_emb, '(x u) d -> x u () d', x=h) 26 | width_emb = rearrange(width_emb, '(y v) d -> y () v d', y=w) 27 | 28 | height_score = einsum('b h x y d, x u v d -> b h x y u v', q, height_emb) 29 | width_score = einsum('b h x y d, y u v d -> b h x y u v', q, width_emb) 30 | 31 | return height_score + width_score 32 | 33 | 34 | class Attention(nn.Module): 35 | def __init__( 36 | self, 37 | *, 38 | args, 39 | dim, 40 | max_pos_size = 100, 41 | heads = 4, 42 | dim_head = 128, 43 | ): 44 | super().__init__() 45 | self.args = args 46 | self.heads = heads 47 | self.scale = dim_head ** -0.5 48 | inner_dim = heads * dim_head 49 | 50 | self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False) 51 | 52 | self.pos_emb = 
RelPosEmb(max_pos_size, dim_head) 53 | 54 | def forward(self, fmap): 55 | heads, b, c, h, w = self.heads, *fmap.shape 56 | 57 | q, k = self.to_qk(fmap).chunk(2, dim=1) 58 | 59 | q, k = map(lambda t: rearrange(t, 'b (h d) x y -> b h x y d', h=heads), (q, k)) 60 | q = self.scale * q 61 | 62 | # if self.args.position_only: 63 | # sim = self.pos_emb(q) 64 | 65 | # elif self.args.position_and_content: 66 | # sim_content = einsum('b h x y d, b h u v d -> b h x y u v', q, k) 67 | # sim_pos = self.pos_emb(q) 68 | # sim = sim_content + sim_pos 69 | 70 | # else: 71 | sim = einsum('b h x y d, b h u v d -> b h x y u v', q, k) 72 | 73 | sim = rearrange(sim, 'b h x y u v -> b h (x y) (u v)') 74 | attn = sim.softmax(dim=-1) 75 | 76 | return attn 77 | 78 | 79 | class Aggregate(nn.Module): 80 | def __init__( 81 | self, 82 | args, 83 | dim, 84 | heads = 4, 85 | dim_head = 128, 86 | ): 87 | super().__init__() 88 | self.args = args 89 | self.heads = heads 90 | self.scale = dim_head ** -0.5 91 | inner_dim = heads * dim_head 92 | 93 | self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False) 94 | 95 | self.gamma = nn.Parameter(torch.zeros(1)) 96 | 97 | if dim != inner_dim: 98 | self.project = nn.Conv2d(inner_dim, dim, 1, bias=False) 99 | else: 100 | self.project = None 101 | 102 | def forward(self, attn, fmap): 103 | heads, b, c, h, w = self.heads, *fmap.shape 104 | 105 | v = self.to_v(fmap) 106 | v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads) 107 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 108 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w) 109 | 110 | if self.project is not None: 111 | out = self.project(out) 112 | 113 | out = fmap + self.gamma * out 114 | 115 | return out 116 | 117 | 118 | if __name__ == "__main__": 119 | att = Attention(dim=128, heads=1) 120 | fmap = torch.randn(2, 128, 40, 90) 121 | out = att(fmap) 122 | 123 | print(out.shape) 124 | -------------------------------------------------------------------------------- /core/Networks/BOFNet/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .update import GMAUpdateBlock 6 | from ..encoders import twins_svt_large 7 | from .cnn import BasicEncoder 8 | from .corr import CorrBlock, OLCorrBlock, AlternateCorrBlock 9 | from ...utils.utils import bilinear_sampler, coords_grid, upflow8 10 | from .gma import Attention, Aggregate 11 | from .sk import SKUpdateBlock6_Deep_nopoolres_AllDecoder 12 | from .sk2 import SKUpdateBlock6_Deep_nopoolres_AllDecoder2 13 | 14 | from torchvision.utils import save_image 15 | 16 | autocast = torch.cuda.amp.autocast 17 | 18 | class BOFNet(nn.Module): 19 | def __init__(self, cfg): 20 | super().__init__() 21 | self.cfg = cfg 22 | 23 | self.hidden_dim = hdim = 128 24 | self.context_dim = cdim = 128 25 | 26 | cfg.corr_radius = 4 27 | cfg.corr_levels = 4 28 | 29 | # feature network, context network, and update block 30 | if cfg.cnet == 'twins': 31 | print("[Using twins as context encoder]") 32 | self.cnet = twins_svt_large(pretrained=self.cfg.pretrain) 33 | elif cfg.cnet == 'basicencoder': 34 | print("[Using basicencoder as context encoder]") 35 | self.cnet = BasicEncoder(output_dim=256, norm_fn='instance') 36 | 37 | if cfg.fnet == 'twins': 38 | print("[Using twins as feature encoder]") 39 | self.fnet = twins_svt_large(pretrained=self.cfg.pretrain) 40 | elif cfg.fnet == 'basicencoder': 41 | print("[Using basicencoder as feature encoder]") 42 | self.fnet = 
BasicEncoder(output_dim=256, norm_fn='instance') 43 | 44 | if self.cfg.gma == "GMA": 45 | print("[Using GMA]") 46 | self.update_block = GMAUpdateBlock(self.cfg, hidden_dim=128) 47 | elif self.cfg.gma == 'GMA-SK': 48 | print("[Using GMA-SK]") 49 | self.cfg.cost_heads_num = 1 50 | self.update_block = SKUpdateBlock6_Deep_nopoolres_AllDecoder(args=self.cfg, hidden_dim=128) 51 | elif self.cfg.gma == 'GMA-SK2': 52 | print("[Using GMA-SK2]") 53 | self.cfg.cost_heads_num = 1 54 | self.update_block = SKUpdateBlock6_Deep_nopoolres_AllDecoder2(args=self.cfg, hidden_dim=128) 55 | 56 | print("[Using corr_fn {}]".format(self.cfg.corr_fn)) 57 | 58 | self.att = Attention(args=self.cfg, dim=128, heads=1, max_pos_size=160, dim_head=128) 59 | 60 | def initialize_flow(self, img): 61 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 62 | N, C, H, W = img.shape 63 | coords0 = coords_grid(N, H // 8, W // 8).to(img.device) 64 | coords1 = coords_grid(N, H // 8, W // 8).to(img.device) 65 | 66 | # optical flow computed as difference: flow = coords1 - coords0 67 | return coords0, coords1 68 | 69 | def upsample_flow(self, flow, mask): 70 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 71 | N, _, H, W = flow.shape 72 | mask = mask.view(N, 1, 9, 8, 8, H, W) 73 | mask = torch.softmax(mask, dim=2) 74 | 75 | up_flow = F.unfold(8 * flow, [3, 3], padding=1) 76 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 77 | 78 | up_flow = torch.sum(mask * up_flow, dim=2) 79 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 80 | return up_flow.reshape(N, 2, 8 * H, 8 * W) 81 | 82 | def forward(self, images, data={}, flow_init=None): 83 | 84 | B, N, _, H, W = images.shape 85 | 86 | images = 2 * (images / 255.0) - 1.0 87 | 88 | hdim = self.hidden_dim 89 | cdim = self.context_dim 90 | 91 | with autocast(enabled=self.cfg.mixed_precision): 92 | fmaps = self.fnet(images.reshape(B*N, 3, H, W)).reshape(B, N, -1, H//8, W//8) 93 | fmaps = fmaps.float() 94 | fmap1 = fmaps[:, 0, ...] 95 | fmap2 = fmaps[:, 1, ...] 96 | fmap3 = fmaps[:, 2, ...] 
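# fmap2 is the centre frame; the two correlation volumes built below relate it to the previous frame (2->1) and the next frame (2->3), so the update block can estimate the centre frame's backward and forward flows jointly.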
97 | 98 | if self.cfg.corr_fn == "efficient": 99 | corr_fn_21 = AlternateCorrBlock(fmap2, fmap1, radius=self.cfg.corr_radius) 100 | corr_fn_23 = AlternateCorrBlock(fmap2, fmap3, radius=self.cfg.corr_radius) 101 | else: 102 | corr_fn_21 = CorrBlock(fmap2, fmap1, num_levels=self.cfg.corr_levels, radius=self.cfg.corr_radius) 103 | corr_fn_23 = CorrBlock(fmap2, fmap3, num_levels=self.cfg.corr_levels, radius=self.cfg.corr_radius) 104 | 105 | with autocast(enabled=self.cfg.mixed_precision): 106 | cnet = self.cnet(images[:, 1, ...]) 107 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 108 | net = torch.tanh(net) 109 | inp = torch.relu(inp) 110 | attention = self.att(inp) 111 | 112 | coords0_21, coords1_21 = self.initialize_flow(images[:, 0, ...]) 113 | coords0_23, coords1_23 = self.initialize_flow(images[:, 0, ...]) 114 | 115 | flow_predictions = [] 116 | for itr in range(self.cfg.decoder_depth): 117 | coords1_21 = coords1_21.detach() 118 | coords1_23 = coords1_23.detach() 119 | 120 | corr21 = corr_fn_21(coords1_21) 121 | corr23 = corr_fn_23(coords1_23) 122 | corr = torch.cat([corr23, corr21], dim=1) 123 | 124 | flow21 = coords1_21 - coords0_21 125 | flow23 = coords1_23 - coords0_23 126 | flow = torch.cat([flow23, flow21], dim=1) 127 | 128 | with autocast(enabled=self.cfg.mixed_precision): 129 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow, attention) 130 | 131 | up_mask_21, up_mask_23 = torch.split(up_mask, [64*9, 64*9], dim=1) 132 | 133 | coords1_23 = coords1_23 + delta_flow[:, 0:2, ...] 134 | coords1_21 = coords1_21 + delta_flow[:, 2:4, ...] 135 | 136 | # upsample predictions 137 | flow_up_23 = self.upsample_flow(coords1_23 - coords0_23, up_mask_23) 138 | flow_up_21 = self.upsample_flow(coords1_21 - coords0_21, up_mask_21) 139 | 140 | flow_predictions.append(torch.stack([flow_up_23, flow_up_21], dim=1)) 141 | 142 | if self.training: 143 | return flow_predictions 144 | else: 145 | return flow_predictions[-1], torch.stack([coords1_23-coords0_23, coords1_21-coords0_21], dim=1) 146 | -------------------------------------------------------------------------------- /core/Networks/BOFNet/sk.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .gma import Aggregate 5 | 6 | class PCBlock4_Deep_nopool_res(nn.Module): 7 | def __init__(self, C_in, C_out, k_conv): 8 | super().__init__() 9 | self.conv_list = nn.ModuleList([ 10 | nn.Conv2d(C_in, C_in, kernel, stride=1, padding=kernel//2, groups=C_in) for kernel in k_conv]) 11 | 12 | self.ffn1 = nn.Sequential( 13 | nn.Conv2d(C_in, int(1.5*C_in), 1, padding=0), 14 | nn.GELU(), 15 | nn.Conv2d(int(1.5*C_in), C_in, 1, padding=0), 16 | ) 17 | self.pw = nn.Conv2d(C_in, C_in, 1, padding=0) 18 | self.ffn2 = nn.Sequential( 19 | nn.Conv2d(C_in, int(1.5*C_in), 1, padding=0), 20 | nn.GELU(), 21 | nn.Conv2d(int(1.5*C_in), C_out, 1, padding=0), 22 | ) 23 | 24 | def forward(self, x): 25 | x = F.gelu(x + self.ffn1(x)) 26 | for conv in self.conv_list: 27 | x = F.gelu(x + conv(x)) 28 | x = F.gelu(x + self.pw(x)) 29 | x = self.ffn2(x) 30 | return x 31 | 32 | 33 | class SKMotionEncoder6_Deep_nopool_res(nn.Module): 34 | def __init__(self, args): 35 | super().__init__() 36 | cor_planes = 81*4*args.cost_heads_num*2 37 | self.convc1 = PCBlock4_Deep_nopool_res(cor_planes, 256, k_conv=args.k_conv) 38 | self.convc2 = PCBlock4_Deep_nopool_res(256, 192, k_conv=args.k_conv) 39 | 40 | self.convf1_ = nn.Conv2d(4, 128, 1, 1, 0) 41 | self.convf2 = 
PCBlock4_Deep_nopool_res(128, 64, k_conv=args.k_conv) 42 | 43 | self.conv = PCBlock4_Deep_nopool_res(64+192, 128-4, k_conv=args.k_conv) 44 | 45 | 46 | def forward(self, flow, corr): 47 | cor = F.gelu(self.convc1(corr)) 48 | 49 | cor = self.convc2(cor) 50 | 51 | flo = self.convf1_(flow) 52 | flo = self.convf2(flo) 53 | 54 | cor_flo = torch.cat([cor, flo], dim=1) 55 | out = self.conv(cor_flo) 56 | 57 | return torch.cat([out, flow], dim=1) 58 | 59 | 60 | class SKUpdateBlock6_Deep_nopoolres_AllDecoder(nn.Module): 61 | def __init__(self, args, hidden_dim): 62 | super().__init__() 63 | self.args = args 64 | 65 | args.k_conv = [1, 15] 66 | args.PCUpdater_conv = [1, 7] 67 | 68 | self.encoder = SKMotionEncoder6_Deep_nopool_res(args) 69 | self.gru = PCBlock4_Deep_nopool_res(128+hidden_dim+hidden_dim+128, 128, k_conv=args.PCUpdater_conv) 70 | self.flow_head = PCBlock4_Deep_nopool_res(128, 4, k_conv=args.k_conv) 71 | 72 | self.mask = nn.Sequential( 73 | nn.Conv2d(128, 256, 3, padding=1), 74 | nn.ReLU(inplace=True), 75 | nn.Conv2d(256, 64*9*2, 1, padding=0)) 76 | 77 | self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=1) 78 | 79 | def forward(self, net, inp, corr, flow, attention): 80 | motion_features = self.encoder(flow, corr) 81 | motion_features_global = self.aggregator(attention, motion_features) 82 | inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1) 83 | 84 | # Attentional update 85 | net = self.gru(torch.cat([net, inp_cat], dim=1)) 86 | 87 | delta_flow = self.flow_head(net) 88 | 89 | # scale mask to balence gradients 90 | mask = .25 * self.mask(net) 91 | return net, mask, delta_flow 92 | -------------------------------------------------------------------------------- /core/Networks/BOFNet/sk2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .gma import Aggregate 5 | 6 | class PCBlock4_Deep_nopool_res(nn.Module): 7 | def __init__(self, C_in, C_out, k_conv): 8 | super().__init__() 9 | self.conv_list = nn.ModuleList([ 10 | nn.Conv2d(C_in, C_in, kernel, stride=1, padding=kernel//2, groups=C_in) for kernel in k_conv]) 11 | 12 | self.ffn1 = nn.Sequential( 13 | nn.Conv2d(C_in, int(1.5*C_in), 1, padding=0), 14 | nn.GELU(), 15 | nn.Conv2d(int(1.5*C_in), C_in, 1, padding=0), 16 | ) 17 | self.pw = nn.Conv2d(C_in, C_in, 1, padding=0) 18 | self.ffn2 = nn.Sequential( 19 | nn.Conv2d(C_in, int(1.5*C_in), 1, padding=0), 20 | nn.GELU(), 21 | nn.Conv2d(int(1.5*C_in), C_out, 1, padding=0), 22 | ) 23 | 24 | def forward(self, x): 25 | x = F.gelu(x + self.ffn1(x)) 26 | for conv in self.conv_list: 27 | x = F.gelu(x + conv(x)) 28 | x = F.gelu(x + self.pw(x)) 29 | x = self.ffn2(x) 30 | return x 31 | 32 | 33 | class SKMotionEncoder6_Deep_nopool_res(nn.Module): 34 | def __init__(self, args): 35 | super().__init__() 36 | self.cor_planes = cor_planes = (args.corr_radius*2+1)**2*args.cost_heads_num*args.corr_levels 37 | self.convc1 = PCBlock4_Deep_nopool_res(cor_planes, 128, k_conv=args.k_conv) 38 | self.convc2 = PCBlock4_Deep_nopool_res(256, 192, k_conv=args.k_conv) 39 | 40 | self.convf1_ = nn.Conv2d(4, 128, 1, 1, 0) 41 | self.convf2 = PCBlock4_Deep_nopool_res(128, 64, k_conv=args.k_conv) 42 | 43 | self.conv = PCBlock4_Deep_nopool_res(64+192, 128-4, k_conv=args.k_conv) 44 | 45 | 46 | def forward(self, flow, corr): 47 | corr1, corr2 = torch.split(corr, [self.cor_planes, self.cor_planes], dim=1) 48 | cor = F.gelu(torch.cat([self.convc1(corr1), 
self.convc1(corr2)], dim=1)) 49 | 50 | cor = self.convc2(cor) 51 | 52 | flo = self.convf1_(flow) 53 | flo = self.convf2(flo) 54 | 55 | cor_flo = torch.cat([cor, flo], dim=1) 56 | out = self.conv(cor_flo) 57 | 58 | return torch.cat([out, flow], dim=1) 59 | 60 | 61 | class SKUpdateBlock6_Deep_nopoolres_AllDecoder2(nn.Module): 62 | def __init__(self, args, hidden_dim): 63 | super().__init__() 64 | self.args = args 65 | 66 | args.k_conv = [1, 15] 67 | args.PCUpdater_conv = [1, 7] 68 | 69 | self.encoder = SKMotionEncoder6_Deep_nopool_res(args) 70 | self.gru = PCBlock4_Deep_nopool_res(128+hidden_dim+hidden_dim+128, 128, k_conv=args.PCUpdater_conv) 71 | self.flow_head = PCBlock4_Deep_nopool_res(128, 4, k_conv=args.k_conv) 72 | 73 | self.mask = nn.Sequential( 74 | nn.Conv2d(128, 256, 3, padding=1), 75 | nn.ReLU(inplace=True), 76 | nn.Conv2d(256, 64*9*2, 1, padding=0)) 77 | 78 | self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=1) 79 | 80 | def forward(self, net, inp, corr, flow, attention): 81 | motion_features = self.encoder(flow, corr) 82 | motion_features_global = self.aggregator(attention, motion_features) 83 | inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1) 84 | 85 | # Attentional update 86 | net = self.gru(torch.cat([net, inp_cat], dim=1)) 87 | 88 | delta_flow = self.flow_head(net) 89 | 90 | # scale mask to balence gradients 91 | mask = .25 * self.mask(net) 92 | return net, mask, delta_flow 93 | -------------------------------------------------------------------------------- /core/Networks/BOFNet/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .gma import Aggregate 5 | 6 | 7 | class FlowHead(nn.Module): 8 | def __init__(self, input_dim=128, hidden_dim=256): 9 | super(FlowHead, self).__init__() 10 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 11 | self.conv2 = nn.Conv2d(hidden_dim, 4, 3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | def forward(self, x): 15 | return self.conv2(self.relu(self.conv1(x))) 16 | 17 | 18 | class ConvGRU(nn.Module): 19 | def __init__(self, hidden_dim=128, input_dim=128+128): 20 | super(ConvGRU, self).__init__() 21 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 23 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 24 | 25 | def forward(self, h, x): 26 | hx = torch.cat([h, x], dim=1) 27 | 28 | z = torch.sigmoid(self.convz(hx)) 29 | r = torch.sigmoid(self.convr(hx)) 30 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 31 | 32 | h = (1-z) * h + z * q 33 | return h 34 | 35 | 36 | class SepConvGRU(nn.Module): 37 | def __init__(self, hidden_dim=128, input_dim=192+128): 38 | super(SepConvGRU, self).__init__() 39 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 40 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 41 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 42 | 43 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 44 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 45 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 46 | 47 | 48 | def forward(self, h, x): 49 | # horizontal 50 | hx = torch.cat([h, x], dim=1) 51 | z = 
torch.sigmoid(self.convz1(hx)) 52 | r = torch.sigmoid(self.convr1(hx)) 53 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 54 | h = (1-z) * h + z * q 55 | 56 | # vertical 57 | hx = torch.cat([h, x], dim=1) 58 | z = torch.sigmoid(self.convz2(hx)) 59 | r = torch.sigmoid(self.convr2(hx)) 60 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 61 | h = (1-z) * h + z * q 62 | 63 | return h 64 | 65 | 66 | class BasicMotionEncoder(nn.Module): 67 | def __init__(self, args): 68 | super(BasicMotionEncoder, self).__init__() 69 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 * 2 70 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 71 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 72 | self.convf1 = nn.Conv2d(4, 128, 7, padding=3) 73 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 74 | self.conv = nn.Conv2d(64+192, 128-4, 3, padding=1) 75 | 76 | def forward(self, flow, corr): 77 | cor = F.relu(self.convc1(corr)) 78 | cor = F.relu(self.convc2(cor)) 79 | flo = F.relu(self.convf1(flow)) 80 | flo = F.relu(self.convf2(flo)) 81 | 82 | cor_flo = torch.cat([cor, flo], dim=1) 83 | out = F.relu(self.conv(cor_flo)) 84 | return torch.cat([out, flow], dim=1) 85 | 86 | 87 | class BasicUpdateBlock(nn.Module): 88 | def __init__(self, args, hidden_dim=128, input_dim=128): 89 | super(BasicUpdateBlock, self).__init__() 90 | self.args = args 91 | self.encoder = BasicMotionEncoder(args) 92 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 93 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 94 | 95 | self.mask = nn.Sequential( 96 | nn.Conv2d(128, 256, 3, padding=1), 97 | nn.ReLU(inplace=True), 98 | nn.Conv2d(256, 64*9, 1, padding=0)) 99 | 100 | def forward(self, net, inp, corr, flow, upsample=True): 101 | motion_features = self.encoder(flow, corr) 102 | inp = torch.cat([inp, motion_features], dim=1) 103 | 104 | net = self.gru(net, inp) 105 | delta_flow = self.flow_head(net) 106 | 107 | # scale mask to balence gradients 108 | mask = .25 * self.mask(net) 109 | return net, mask, delta_flow 110 | 111 | 112 | class GMAUpdateBlock(nn.Module): 113 | def __init__(self, args, hidden_dim=128): 114 | super().__init__() 115 | self.args = args 116 | self.encoder = BasicMotionEncoder(args) 117 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim+hidden_dim) 118 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 119 | 120 | self.mask = nn.Sequential( 121 | nn.Conv2d(128, 256, 3, padding=1), 122 | nn.ReLU(inplace=True), 123 | nn.Conv2d(256, 64*9*2, 1, padding=0)) 124 | 125 | self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=1) 126 | 127 | def forward(self, net, inp, corr, flow, attention): 128 | motion_features = self.encoder(flow, corr) 129 | motion_features_global = self.aggregator(attention, motion_features) 130 | inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1) 131 | 132 | # Attentional update 133 | net = self.gru(net, inp_cat) 134 | 135 | delta_flow = self.flow_head(net) 136 | 137 | # scale mask to balence gradients 138 | mask = .25 * self.mask(net) 139 | 140 | return net, mask, delta_flow 141 | 142 | 143 | 144 | -------------------------------------------------------------------------------- /core/Networks/MOFNetStack/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaoyuShi97/VideoFlow/51489304db6e75fbdd9ff64d4656c1d120b5a673/core/Networks/MOFNetStack/__init__.py 
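The `ConvGRU` and `SepConvGRU` modules in `update.py` above share the same gated-update arithmetic; only the convolution shapes differ. The snippet below is a minimal, self-contained sketch of that update with toy channel counts and 1x1 convolutions (the real blocks use 3x3 or separable 1x5 / 5x1 kernels on 128-dimensional states); it is an illustration, not code from the repository.

```python
import torch
import torch.nn as nn

# Toy GRU-style gated update mirroring the z/r/q gates of ConvGRU/SepConvGRU above.
hidden_dim, input_dim = 8, 8                     # illustrative sizes, not the real 128-dim states
convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 1)   # update gate
convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 1)   # reset gate
convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 1)   # candidate state

h = torch.zeros(1, hidden_dim, 4, 4)             # recurrent hidden state
x = torch.randn(1, input_dim, 4, 4)              # motion/context features

hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(convz(hx))
r = torch.sigmoid(convr(hx))
q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
h = (1 - z) * h + z * q                          # convex blend of old state and candidate
print(h.shape)                                   # torch.Size([1, 8, 4, 4])
```

`SepConvGRU` simply runs this update twice per iteration, once with horizontal (1x5) and once with vertical (5x1) kernels.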
-------------------------------------------------------------------------------- /core/Networks/MOFNetStack/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from ...utils.utils import bilinear_sampler, coords_grid 6 | # from compute_sparse_correlation import compute_sparse_corr, compute_sparse_corr_torch, compute_sparse_corr_mink 7 | 8 | try: 9 | import alt_cuda_corr 10 | except: 11 | # alt_cuda_corr is not compiled 12 | print("[!!alt_cuda_corr is not compiled!!]") 13 | pass 14 | 15 | 16 | class OLCorrBlock: 17 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 18 | self.num_levels = num_levels 19 | self.radius = radius 20 | 21 | batch, dim, ht, wd = fmap1.shape 22 | self.fmap1 = fmap1.permute(0, 2, 3, 1).view(batch*ht*wd, 1, dim) 23 | 24 | self.fmap2_pyramid = [] 25 | self.fmap2_pyramid.append(fmap2) 26 | for i in range(self.num_levels - 1): 27 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 28 | self.fmap2_pyramid.append(fmap2) 29 | 30 | def __call__(self, coords): 31 | r = self.radius 32 | coords = coords.permute(0, 2, 3, 1) 33 | batch, h1, w1, _ = coords.shape 34 | _, _, dim = self.fmap1.shape 35 | 36 | out_pyramid = [] 37 | for i in range(self.num_levels): 38 | fmap2 = self.fmap2_pyramid[i] 39 | dx = torch.linspace(-r, r, 2 * r + 1) 40 | dy = torch.linspace(-r, r, 2 * r + 1) 41 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 42 | 43 | centroid_lvl = coords.reshape(batch, h1 * w1, 1, 2) / 2 ** i 44 | delta_lvl = delta.view(1, 1, (2 * r + 1) ** 2, 2) 45 | coords_lvl = centroid_lvl + delta_lvl 46 | 47 | fmap2 = bilinear_sampler(fmap2, coords_lvl) # B, 256, h*w, 9*9 48 | fmap2 = fmap2.permute(0, 2, 1, 3).view(batch*h1*w1, dim, (2 * r + 1) ** 2) 49 | #print(self.fmap1.shape, fmap2.shape) 50 | corr = torch.bmm(self.fmap1, fmap2) / torch.sqrt(torch.tensor(dim).float()) 51 | 52 | corr = corr.view(batch, h1, w1, -1) 53 | out_pyramid.append(corr) 54 | 55 | out = torch.cat(out_pyramid, dim=-1) 56 | return out.permute(0, 3, 1, 2).float() 57 | 58 | 59 | class CorrBlock: 60 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 61 | self.num_levels = num_levels 62 | self.radius = radius 63 | self.corr_pyramid = [] 64 | 65 | # all pairs correlation 66 | corr = CorrBlock.corr(fmap1, fmap2) 67 | 68 | batch, h1, w1, dim, h2, w2 = corr.shape 69 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 70 | 71 | self.corr_pyramid.append(corr) 72 | for i in range(self.num_levels - 1): 73 | corr = F.avg_pool2d(corr, 2, stride=2) 74 | self.corr_pyramid.append(corr) 75 | 76 | def __call__(self, coords): 77 | r = self.radius 78 | coords = coords.permute(0, 2, 3, 1) 79 | batch, h1, w1, _ = coords.shape 80 | 81 | out_pyramid = [] 82 | for i in range(self.num_levels): 83 | corr = self.corr_pyramid[i] 84 | dx = torch.linspace(-r, r, 2 * r + 1) 85 | dy = torch.linspace(-r, r, 2 * r + 1) 86 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 87 | 88 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2 ** i 89 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 90 | coords_lvl = centroid_lvl + delta_lvl 91 | 92 | corr = bilinear_sampler(corr, coords_lvl) 93 | corr = corr.view(batch, h1, w1, -1) 94 | out_pyramid.append(corr) 95 | 96 | out = torch.cat(out_pyramid, dim=-1) 97 | return out.permute(0, 3, 1, 2).contiguous().float() 98 | 99 | @staticmethod 100 | def corr(fmap1, fmap2): 101 | batch, dim, ht, wd = fmap1.shape 102 
| fmap1 = fmap1.view(batch, dim, ht * wd) 103 | fmap2 = fmap2.view(batch, dim, ht * wd) 104 | 105 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 106 | corr = corr.view(batch, ht, wd, 1, ht, wd) 107 | return corr / torch.sqrt(torch.tensor(dim).float()) 108 | 109 | 110 | class CorrBlockSingleScale(nn.Module): 111 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 112 | super().__init__() 113 | self.radius = radius 114 | 115 | # all pairs correlation 116 | corr = CorrBlock.corr(fmap1, fmap2) 117 | batch, h1, w1, dim, h2, w2 = corr.shape 118 | self.corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 119 | 120 | def __call__(self, coords): 121 | r = self.radius 122 | coords = coords.permute(0, 2, 3, 1) 123 | batch, h1, w1, _ = coords.shape 124 | 125 | corr = self.corr 126 | dx = torch.linspace(-r, r, 2 * r + 1) 127 | dy = torch.linspace(-r, r, 2 * r + 1) 128 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 129 | 130 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) 131 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 132 | coords_lvl = centroid_lvl + delta_lvl 133 | 134 | corr = bilinear_sampler(corr, coords_lvl) 135 | out = corr.view(batch, h1, w1, -1) 136 | out = out.permute(0, 3, 1, 2).contiguous().float() 137 | return out 138 | 139 | @staticmethod 140 | def corr(fmap1, fmap2): 141 | batch, dim, ht, wd = fmap1.shape 142 | fmap1 = fmap1.view(batch, dim, ht * wd) 143 | fmap2 = fmap2.view(batch, dim, ht * wd) 144 | 145 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 146 | corr = corr.view(batch, ht, wd, 1, ht, wd) 147 | return corr / torch.sqrt(torch.tensor(dim).float()) 148 | 149 | class AlternateCorrBlock: 150 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 151 | self.num_levels = num_levels 152 | self.radius = radius 153 | 154 | self.pyramid = [(fmap1, fmap2)] 155 | for i in range(self.num_levels): 156 | #fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 157 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 158 | self.pyramid.append((None, fmap2)) 159 | 160 | def __call__(self, coords): 161 | coords = coords.permute(0, 2, 3, 1) 162 | B, H, W, _ = coords.shape 163 | dim = self.pyramid[0][0].shape[1] 164 | 165 | corr_list = [] 166 | for i in range(self.num_levels): 167 | r = self.radius 168 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 169 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 170 | 171 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 172 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 173 | corr_list.append(corr.squeeze(1)) 174 | 175 | corr = torch.stack(corr_list, dim=1) 176 | corr = corr.reshape(B, -1, H, W) 177 | return corr / torch.sqrt(torch.tensor(dim).float()) 178 | #return corr.mul_(1.0/torch.sqrt(torch.tensor(dim).float())) 179 | -------------------------------------------------------------------------------- /core/Networks/MOFNetStack/gma.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from einops import rearrange 4 | 5 | 6 | class RelPosEmb(nn.Module): 7 | def __init__( 8 | self, 9 | max_pos_size, 10 | dim_head 11 | ): 12 | super().__init__() 13 | self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head) 14 | self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head) 15 | 16 | deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange(max_pos_size).view(-1, 1) 17 | rel_ind = deltas + max_pos_size - 1 18 | self.register_buffer('rel_ind', rel_ind) 19 | 20 | def 
forward(self, q): 21 | batch, heads, h, w, c = q.shape 22 | height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1)) 23 | width_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1)) 24 | 25 | height_emb = rearrange(height_emb, '(x u) d -> x u () d', x=h) 26 | width_emb = rearrange(width_emb, '(y v) d -> y () v d', y=w) 27 | 28 | height_score = einsum('b h x y d, x u v d -> b h x y u v', q, height_emb) 29 | width_score = einsum('b h x y d, y u v d -> b h x y u v', q, width_emb) 30 | 31 | return height_score + width_score 32 | 33 | 34 | class Attention(nn.Module): 35 | def __init__( 36 | self, 37 | *, 38 | args, 39 | dim, 40 | max_pos_size = 100, 41 | heads = 4, 42 | dim_head = 128, 43 | ): 44 | super().__init__() 45 | self.args = args 46 | self.heads = heads 47 | self.scale = dim_head ** -0.5 48 | inner_dim = heads * dim_head 49 | 50 | self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False) 51 | 52 | self.pos_emb = RelPosEmb(max_pos_size, dim_head) 53 | 54 | def forward(self, fmap): 55 | heads, b, c, h, w = self.heads, *fmap.shape 56 | 57 | q, k = self.to_qk(fmap).chunk(2, dim=1) 58 | 59 | q, k = map(lambda t: rearrange(t, 'b (h d) x y -> b h x y d', h=heads), (q, k)) 60 | q = self.scale * q 61 | 62 | # if self.args.position_only: 63 | # sim = self.pos_emb(q) 64 | 65 | # elif self.args.position_and_content: 66 | # sim_content = einsum('b h x y d, b h u v d -> b h x y u v', q, k) 67 | # sim_pos = self.pos_emb(q) 68 | # sim = sim_content + sim_pos 69 | 70 | # else: 71 | sim = einsum('b h x y d, b h u v d -> b h x y u v', q, k) 72 | 73 | sim = rearrange(sim, 'b h x y u v -> b h (x y) (u v)') 74 | attn = sim.softmax(dim=-1) 75 | 76 | return attn 77 | 78 | 79 | class Aggregate(nn.Module): 80 | def __init__( 81 | self, 82 | args, 83 | dim, 84 | heads = 4, 85 | dim_head = 128, 86 | ): 87 | super().__init__() 88 | self.args = args 89 | self.heads = heads 90 | self.scale = dim_head ** -0.5 91 | inner_dim = heads * dim_head 92 | 93 | self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False) 94 | 95 | self.gamma = nn.Parameter(torch.zeros(1)) 96 | 97 | if dim != inner_dim: 98 | self.project = nn.Conv2d(inner_dim, dim, 1, bias=False) 99 | else: 100 | self.project = None 101 | 102 | def forward(self, attn, fmap): 103 | heads, b, c, h, w = self.heads, *fmap.shape 104 | 105 | v = self.to_v(fmap) 106 | v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads) 107 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 108 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w) 109 | 110 | if self.project is not None: 111 | out = self.project(out) 112 | 113 | out = fmap + self.gamma * out 114 | 115 | return out 116 | 117 | 118 | if __name__ == "__main__": 119 | att = Attention(dim=128, heads=1) 120 | fmap = torch.randn(2, 128, 40, 90) 121 | out = att(fmap) 122 | 123 | print(out.shape) 124 | -------------------------------------------------------------------------------- /core/Networks/MOFNetStack/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .update import GMAUpdateBlock 6 | from ..encoders import twins_svt_large, convnext_Xlarge_4x, convnext_base_2x 7 | from .corr import CorrBlock, OLCorrBlock, AlternateCorrBlock 8 | from ...utils.utils import bilinear_sampler, coords_grid, upflow8 9 | from .gma import Attention, Aggregate 10 | 11 | from torchvision.utils import save_image 12 | 13 | autocast = torch.cuda.amp.autocast 14 | 15 | class MOFNet(nn.Module): 16 | def 
__init__(self, cfg): 17 | super().__init__() 18 | self.cfg = cfg 19 | 20 | self.hidden_dim = hdim = self.cfg.feat_dim // 2 21 | self.context_dim = cdim = self.cfg.feat_dim // 2 22 | 23 | cfg.corr_radius = 4 24 | 25 | # feature network, context network, and update block 26 | if cfg.cnet == 'twins': 27 | print("[Using twins as context encoder]") 28 | self.cnet = twins_svt_large(pretrained=self.cfg.pretrain) 29 | elif cfg.cnet == 'basicencoder': 30 | print("[Using basicencoder as context encoder]") 31 | self.cnet = BasicEncoder(output_dim=256, norm_fn='instance') 32 | elif cfg.cnet == 'convnext_Xlarge_4x': 33 | print("[Using convnext_Xlarge_4x as context encoder]") 34 | self.cnet = convnext_Xlarge_4x(pretrained=self.cfg.pretrain) 35 | elif cfg.cnet == 'convnext_base_2x': 36 | print("[Using convnext_base_2x as context encoder]") 37 | self.cnet = convnext_base_2x(pretrained=self.cfg.pretrain) 38 | 39 | if cfg.fnet == 'twins': 40 | print("[Using twins as feature encoder]") 41 | self.fnet = twins_svt_large(pretrained=self.cfg.pretrain) 42 | elif cfg.fnet == 'basicencoder': 43 | print("[Using basicencoder as feature encoder]") 44 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance') 45 | elif cfg.fnet == 'convnext_Xlarge_4x': 46 | print("[Using convnext_Xlarge_4x as feature encoder]") 47 | self.fnet = convnext_Xlarge_4x(pretrained=self.cfg.pretrain) 48 | elif cfg.fnet == 'convnext_base_2x': 49 | print("[Using convnext_base_2x as feature encoder]") 50 | self.fnet = convnext_base_2x(pretrained=self.cfg.pretrain) 51 | 52 | hidden_dim_ratio = 256 // cfg.feat_dim 53 | 54 | if self.cfg.Tfusion == 'stack': 55 | print("[Using stack.]") 56 | self.cfg.cost_heads_num = 1 57 | from .stack import SKUpdateBlock6_Deep_nopoolres_AllDecoder2 58 | self.update_block = SKUpdateBlock6_Deep_nopoolres_AllDecoder2(args=self.cfg, hidden_dim=128//hidden_dim_ratio) 59 | # elif self.cfg.Tfusion == 'resstack': 60 | # print("[Using resstack.]") 61 | # self.cfg.cost_heads_num = 1 62 | # from .resstack import SKUpdateBlock6_Deep_nopoolres_AllDecoder2 63 | # self.update_block = SKUpdateBlock6_Deep_nopoolres_AllDecoder2(args=self.cfg, hidden_dim=128) 64 | # elif self.cfg.Tfusion == 'stackcat': 65 | # print("[Using stackcat.]") 66 | # self.cfg.cost_heads_num = 1 67 | # from .stackcat import SKUpdateBlock6_Deep_nopoolres_AllDecoder2 68 | # self.update_block = SKUpdateBlock6_Deep_nopoolres_AllDecoder2(args=self.cfg, hidden_dim=128) 69 | 70 | 71 | print("[Using corr_fn {}]".format(self.cfg.corr_fn)) 72 | 73 | gma_down_ratio = 256 // cfg.feat_dim 74 | 75 | self.att = Attention(args=self.cfg, dim=128//hidden_dim_ratio, heads=1, max_pos_size=160, dim_head=128//hidden_dim_ratio) 76 | 77 | if self.cfg.context_3D: 78 | print("[Using 3D Conv on context feature.]") 79 | self.context_3D = nn.Sequential( 80 | nn.Conv3d(256, 256, 3, stride=1, padding=1), 81 | nn.GELU(), 82 | nn.Conv3d(256, 256, 3, stride=1, padding=1), 83 | nn.GELU(), 84 | nn.Conv3d(256, 256, 3, stride=1, padding=1), 85 | nn.GELU(), 86 | ) 87 | 88 | def initialize_flow(self, img, bs, down_ratio): 89 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 90 | N, C, H, W = img.shape 91 | coords0 = coords_grid(bs, H // down_ratio, W // down_ratio).to(img.device) 92 | coords1 = coords_grid(bs, H // down_ratio, W // down_ratio).to(img.device) 93 | 94 | # optical flow computed as difference: flow = coords1 - coords0 95 | return coords0, coords1 96 | 97 | def upsample_flow(self, flow, mask): 98 | """ Upsample flow field [H/8, W/8, 2] 
-> [H, W, 2] using convex combination """ 99 | N, _, H, W = flow.shape 100 | mask = mask.view(N, 1, 9, 8, 8, H, W) 101 | mask = torch.softmax(mask, dim=2) 102 | 103 | up_flow = F.unfold(8 * flow, [3, 3], padding=1) 104 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 105 | 106 | up_flow = torch.sum(mask * up_flow, dim=2) 107 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 108 | return up_flow.reshape(N, 2, 8 * H, 8 * W) 109 | 110 | def upsample_flow_4x(self, flow, mask): 111 | 112 | N, _, H, W = flow.shape 113 | mask = mask.view(N, 1, 9, 4, 4, H, W) 114 | mask = torch.softmax(mask, dim=2) 115 | 116 | up_flow = F.unfold(4 * flow, [3, 3], padding=1) 117 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 118 | 119 | up_flow = torch.sum(mask * up_flow, dim=2) 120 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 121 | return up_flow.reshape(N, 2, 4 * H, 4 * W) 122 | 123 | def upsample_flow_2x(self, flow, mask): 124 | 125 | N, _, H, W = flow.shape 126 | mask = mask.view(N, 1, 9, 2, 2, H, W) 127 | mask = torch.softmax(mask, dim=2) 128 | 129 | up_flow = F.unfold(2 * flow, [3, 3], padding=1) 130 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 131 | 132 | up_flow = torch.sum(mask * up_flow, dim=2) 133 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 134 | return up_flow.reshape(N, 2, 2 * H, 2 * W) 135 | 136 | 137 | 138 | def forward(self, images, data={}, flow_init=None): 139 | 140 | down_ratio = self.cfg.down_ratio 141 | 142 | B, N, _, H, W = images.shape 143 | 144 | images = 2 * (images / 255.0) - 1.0 145 | 146 | hdim = self.hidden_dim 147 | cdim = self.context_dim 148 | 149 | with autocast(enabled=self.cfg.mixed_precision): 150 | fmaps = self.fnet(images.reshape(B*N, 3, H, W)).reshape(B, N, -1, H//down_ratio, W//down_ratio) 151 | fmaps = fmaps.float() 152 | 153 | if self.cfg.corr_fn == "default": 154 | corr_fn = CorrBlock 155 | elif self.cfg.corr_fn == "efficient": 156 | corr_fn = AlternateCorrBlock 157 | forward_corr_fn = corr_fn(fmaps[:, 1:N-1, ...].reshape(B*(N-2), -1, H//down_ratio, W//down_ratio), fmaps[:, 2:N, ...].reshape(B*(N-2), -1, H//down_ratio, W//down_ratio), num_levels=self.cfg.corr_levels, radius=self.cfg.corr_radius) 158 | backward_corr_fn = corr_fn(fmaps[:, 1:N-1, ...].reshape(B*(N-2), -1, H//down_ratio, W//down_ratio), fmaps[:, 0:N-2, ...].reshape(B*(N-2), -1, H//down_ratio, W//down_ratio), num_levels=self.cfg.corr_levels, radius=self.cfg.corr_radius) 159 | 160 | with autocast(enabled=self.cfg.mixed_precision): 161 | cnet = self.cnet(images[:, 1:N-1, ...].reshape(B*(N-2), 3, H, W)) 162 | if self.cfg.context_3D: 163 | #print("!@!@@#!@#!@") 164 | cnet = cnet.reshape(B, N-2, -1, H//2, W//2).permute(0, 2, 1, 3, 4) 165 | cnet = self.context_3D(cnet) + cnet 166 | #print(cnet.shape) 167 | cnet = cnet.permute(0, 2, 1, 3, 4).reshape(B*(N-2), -1, H//down_ratio, W//down_ratio) 168 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 169 | net = torch.tanh(net) 170 | inp = torch.relu(inp) 171 | attention = self.att(inp) 172 | 173 | forward_coords1, forward_coords0 = self.initialize_flow(images[:, 0, ...], bs=B*(N-2), down_ratio=down_ratio) 174 | backward_coords1, backward_coords0 = self.initialize_flow(images[:, 0, ...], bs=B*(N-2), down_ratio=down_ratio) 175 | 176 | flow_predictions = [] # forward flows followed by backward flows 177 | 178 | motion_hidden_state = None 179 | 180 | for itr in range(self.cfg.decoder_depth): 181 | 182 | forward_coords1 = forward_coords1.detach() 183 | backward_coords1 = backward_coords1.detach() 184 | 185 | forward_corr = forward_corr_fn(forward_coords1) 186 | backward_corr = 
backward_corr_fn(backward_coords1) 187 | 188 | forward_flow = forward_coords1 - forward_coords0 189 | backward_flow = backward_coords1 - backward_coords0 190 | 191 | with autocast(enabled=self.cfg.mixed_precision): 192 | net, motion_hidden_state, up_mask, delta_flow = self.update_block(net, motion_hidden_state, inp, forward_corr, backward_corr, forward_flow, backward_flow, forward_coords0, attention, bs=B) 193 | 194 | forward_up_mask, backward_up_mask = torch.split(up_mask, [down_ratio**2*9, down_ratio**2*9], dim=1) 195 | 196 | forward_coords1 = forward_coords1 + delta_flow[:, 0:2, ...] 197 | backward_coords1 = backward_coords1 + delta_flow[:, 2:4, ...] 198 | 199 | # upsample predictions 200 | if down_ratio == 4: 201 | forward_flow_up = self.upsample_flow_4x(forward_coords1-forward_coords0, forward_up_mask) 202 | backward_flow_up = self.upsample_flow_4x(backward_coords1-backward_coords0, backward_up_mask) 203 | elif down_ratio == 2: 204 | forward_flow_up = self.upsample_flow_2x(forward_coords1-forward_coords0, forward_up_mask) 205 | backward_flow_up = self.upsample_flow_2x(backward_coords1-backward_coords0, backward_up_mask) 206 | elif down_ratio == 8: 207 | forward_flow_up = self.upsample_flow(forward_coords1-forward_coords0, forward_up_mask) 208 | backward_flow_up = self.upsample_flow(backward_coords1-backward_coords0, backward_up_mask) 209 | 210 | flow_predictions.append(torch.cat([forward_flow_up.reshape(B, N-2, 2, H, W), backward_flow_up.reshape(B, N-2, 2, H, W)], dim=1)) 211 | 212 | if self.training: 213 | return flow_predictions 214 | else: 215 | return flow_predictions[-1], flow_predictions[-1] 216 | -------------------------------------------------------------------------------- /core/Networks/MOFNetStack/resstack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .gma import Aggregate 5 | 6 | from ...utils.utils import bilinear_sampler 7 | 8 | class PCBlock4_Deep_nopool_res(nn.Module): 9 | def __init__(self, C_in, C_out, k_conv): 10 | super().__init__() 11 | self.conv_list = nn.ModuleList([ 12 | nn.Conv2d(C_in, C_in, kernel, stride=1, padding=kernel//2, groups=C_in) for kernel in k_conv]) 13 | 14 | self.ffn1 = nn.Sequential( 15 | nn.Conv2d(C_in, int(1.3*C_in), 1, padding=0), 16 | nn.GELU(), 17 | nn.Conv2d(int(1.3*C_in), C_in, 1, padding=0), 18 | ) 19 | self.pw = nn.Conv2d(C_in, C_in, 1, padding=0) 20 | self.ffn2 = nn.Sequential( 21 | nn.Conv2d(C_in, int(1.3*C_in), 1, padding=0), 22 | nn.GELU(), 23 | nn.Conv2d(int(1.3*C_in), C_out, 1, padding=0), 24 | ) 25 | 26 | def forward(self, x): 27 | x = F.gelu(x + self.ffn1(x)) 28 | for conv in self.conv_list: 29 | x = F.gelu(x + conv(x)) 30 | x = F.gelu(x + self.pw(x)) 31 | x = self.ffn2(x) 32 | return x 33 | 34 | class velocity_update_block(nn.Module): 35 | def __init__(self, C_in=43+128+43, C_out=43, C_hidden=64): 36 | super().__init__() 37 | self.mlp = nn.Sequential( 38 | nn.Conv2d(C_in, C_hidden, 3, padding=1), 39 | nn.GELU(), 40 | nn.Conv2d(C_hidden, C_hidden, 3, padding=1), 41 | nn.GELU(), 42 | nn.Conv2d(C_hidden, C_out, 3, padding=1), 43 | ) 44 | def forward(self, x): 45 | return self.mlp(x) 46 | 47 | 48 | class SKMotionEncoder6_Deep_nopool_res(nn.Module): 49 | def __init__(self, args): 50 | super().__init__() 51 | self.cor_planes = cor_planes = (args.corr_radius*2+1)**2*args.cost_heads_num*args.corr_levels 52 | self.convc1 = PCBlock4_Deep_nopool_res(cor_planes, 128, k_conv=args.k_conv) 53 | self.convc2 = 
PCBlock4_Deep_nopool_res(256, 192, k_conv=args.k_conv) 54 | 55 | self.init_hidden_state = nn.Parameter(torch.randn(1, 1, 48, 1, 1)) 56 | 57 | self.convf1_ = nn.Conv2d(4, 96, 1, 1, 0) 58 | self.convf2 = PCBlock4_Deep_nopool_res(96, 64, k_conv=args.k_conv) 59 | 60 | self.conv = PCBlock4_Deep_nopool_res(64+192+48*3, 128-4+48, k_conv=args.k_conv) 61 | 62 | self.velocity_update_block = velocity_update_block() 63 | 64 | def sample_flo_feat(self, flow, feat): 65 | 66 | sampled_feat = bilinear_sampler(feat.float(), flow.permute(0, 2, 3, 1)) 67 | return sampled_feat 68 | 69 | def forward(self, motion_hidden_state, forward_flow, backward_flow, coords0, forward_corr, backward_corr, bs): 70 | 71 | BN, _, H, W = forward_flow.shape 72 | N = BN // bs 73 | 74 | if motion_hidden_state is None: 75 | #print("initialized as None") 76 | motion_hidden_state = self.init_hidden_state.repeat(bs, N, 1, H, W) 77 | else: 78 | #print("later iterations") 79 | motion_hidden_state = motion_hidden_state.reshape(bs, N, -1, H, W) 80 | 81 | motion_hidden_state_sc = motion_hidden_state.clone() 82 | 83 | forward_loc = forward_flow+coords0 84 | backward_loc = backward_flow+coords0 85 | 86 | forward_motion_hidden_state = torch.cat([motion_hidden_state[:, 1:, ...], torch.zeros(bs, 1, 48, H, W).to(motion_hidden_state.device)], dim=1).reshape(BN, -1, H, W) 87 | forward_motion_hidden_state = self.sample_flo_feat(forward_loc, forward_motion_hidden_state) 88 | backward_motion_hidden_state = torch.cat([torch.zeros(bs, 1, 48, H, W).to(motion_hidden_state.device), motion_hidden_state[:, :N-1, ...]], dim=1).reshape(BN, -1, H, W) 89 | backward_motion_hidden_state = self.sample_flo_feat(backward_loc, backward_motion_hidden_state) 90 | 91 | forward_cor = self.convc1(forward_corr) 92 | backward_cor = self.convc1(backward_corr) 93 | cor = F.gelu(torch.cat([forward_cor, backward_cor], dim=1)) 94 | cor = self.convc2(cor) 95 | 96 | flow = torch.cat([forward_flow, backward_flow], dim=1) 97 | flo = self.convf1_(flow) 98 | flo = self.convf2(flo) 99 | 100 | cor_flo = torch.cat([cor, flo, forward_motion_hidden_state, backward_motion_hidden_state, motion_hidden_state.reshape(BN, -1, H, W)], dim=1) 101 | out = self.conv(cor_flo) 102 | 103 | out, motion_hidden_state = torch.split(out, [124, 48], dim=1) 104 | 105 | motion_hidden_state = motion_hidden_state + motion_hidden_state_sc 106 | 107 | return torch.cat([out, flow], dim=1), motion_hidden_state 108 | 109 | 110 | class SKUpdateBlock6_Deep_nopoolres_AllDecoder2(nn.Module): 111 | def __init__(self, args, hidden_dim): 112 | super().__init__() 113 | self.args = args 114 | 115 | args.k_conv = [1, 15] 116 | args.PCUpdater_conv = [1, 7] 117 | 118 | self.encoder = SKMotionEncoder6_Deep_nopool_res(args) 119 | self.gru = PCBlock4_Deep_nopool_res(128+hidden_dim+hidden_dim+128, 128, k_conv=args.PCUpdater_conv) 120 | self.flow_head = PCBlock4_Deep_nopool_res(128, 4, k_conv=args.k_conv) 121 | 122 | self.mask = nn.Sequential( 123 | nn.Conv2d(128, 256, 3, padding=1), 124 | nn.ReLU(inplace=True), 125 | nn.Conv2d(256, 64*9*2, 1, padding=0)) 126 | 127 | self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=1) 128 | 129 | def forward(self, net, motion_hidden_state, inp, forward_corr, backward_corr, forward_flow, backward_flow, coords0, attention, bs): 130 | 131 | motion_features, motion_hidden_state = self.encoder(motion_hidden_state, forward_flow, backward_flow, coords0, forward_corr, backward_corr, bs=bs) 132 | motion_features_global = self.aggregator(attention, motion_features) 133 | inp_cat = 
torch.cat([inp, motion_features, motion_features_global], dim=1) 134 | 135 | # Attentional update 136 | net = self.gru(torch.cat([net, inp_cat], dim=1)) 137 | 138 | delta_flow = self.flow_head(net) 139 | 140 | # scale mask to balence gradients 141 | mask = 100.0 * self.mask(net) 142 | return net, motion_hidden_state, mask, delta_flow 143 | -------------------------------------------------------------------------------- /core/Networks/MOFNetStack/sk.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .gma import Aggregate 5 | 6 | class PCBlock4_Deep_nopool_res(nn.Module): 7 | def __init__(self, C_in, C_out, k_conv): 8 | super().__init__() 9 | self.conv_list = nn.ModuleList([ 10 | nn.Conv2d(C_in, C_in, kernel, stride=1, padding=kernel//2, groups=C_in) for kernel in k_conv]) 11 | 12 | self.ffn1 = nn.Sequential( 13 | nn.Conv2d(C_in, int(1.5*C_in), 1, padding=0), 14 | nn.GELU(), 15 | nn.Conv2d(int(1.5*C_in), C_in, 1, padding=0), 16 | ) 17 | self.pw = nn.Conv2d(C_in, C_in, 1, padding=0) 18 | self.ffn2 = nn.Sequential( 19 | nn.Conv2d(C_in, int(1.5*C_in), 1, padding=0), 20 | nn.GELU(), 21 | nn.Conv2d(int(1.5*C_in), C_out, 1, padding=0), 22 | ) 23 | 24 | def forward(self, x): 25 | x = F.gelu(x + self.ffn1(x)) 26 | for conv in self.conv_list: 27 | x = F.gelu(x + conv(x)) 28 | x = F.gelu(x + self.pw(x)) 29 | x = self.ffn2(x) 30 | return x 31 | 32 | 33 | class SKMotionEncoder6_Deep_nopool_res(nn.Module): 34 | def __init__(self, args): 35 | super().__init__() 36 | cor_planes = 81*4*args.cost_heads_num*2 37 | self.convc1 = PCBlock4_Deep_nopool_res(cor_planes, 256, k_conv=args.k_conv) 38 | self.convc2 = PCBlock4_Deep_nopool_res(256, 192, k_conv=args.k_conv) 39 | 40 | self.convf1_ = nn.Conv2d(4, 128, 1, 1, 0) 41 | self.convf2 = PCBlock4_Deep_nopool_res(128, 64, k_conv=args.k_conv) 42 | 43 | self.conv = PCBlock4_Deep_nopool_res(64+192, 128-4, k_conv=args.k_conv) 44 | 45 | 46 | def forward(self, flow, corr): 47 | cor = F.gelu(self.convc1(corr)) 48 | 49 | cor = self.convc2(cor) 50 | 51 | flo = self.convf1_(flow) 52 | flo = self.convf2(flo) 53 | 54 | cor_flo = torch.cat([cor, flo], dim=1) 55 | out = self.conv(cor_flo) 56 | 57 | return torch.cat([out, flow], dim=1) 58 | 59 | 60 | class SKUpdateBlock6_Deep_nopoolres_AllDecoder(nn.Module): 61 | def __init__(self, args, hidden_dim): 62 | super().__init__() 63 | self.args = args 64 | 65 | args.k_conv = [1, 15] 66 | args.PCUpdater_conv = [1, 7] 67 | 68 | self.encoder = SKMotionEncoder6_Deep_nopool_res(args) 69 | self.gru = PCBlock4_Deep_nopool_res(128+hidden_dim+hidden_dim+128, 128, k_conv=args.PCUpdater_conv) 70 | self.flow_head = PCBlock4_Deep_nopool_res(128, 4, k_conv=args.k_conv) 71 | 72 | self.mask = nn.Sequential( 73 | nn.Conv2d(128, 256, 3, padding=1), 74 | nn.ReLU(inplace=True), 75 | nn.Conv2d(256, 64*9*2, 1, padding=0)) 76 | 77 | self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=1) 78 | 79 | def forward(self, net, inp, corr, flow, attention): 80 | motion_features = self.encoder(flow, corr) 81 | motion_features_global = self.aggregator(attention, motion_features) 82 | inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1) 83 | 84 | # Attentional update 85 | net = self.gru(torch.cat([net, inp_cat], dim=1)) 86 | 87 | delta_flow = self.flow_head(net) 88 | 89 | # scale mask to balence gradients 90 | mask = .25 * self.mask(net) 91 | return net, mask, delta_flow 92 | 
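The update blocks in these `sk*.py` files are built almost entirely from `PCBlock4_Deep_nopool_res`: a residual pointwise FFN, a list of depthwise convolutions whose kernel sizes come from `args.k_conv`, a residual pointwise convolution, and a final projection FFN. The sketch below restates that block in isolation with small, made-up channel sizes purely for illustration; the names and dimensions are not taken from the configs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyPCBlock(nn.Module):
    """Illustrative restatement of PCBlock4_Deep_nopool_res with toy sizes."""
    def __init__(self, c_in=16, c_out=8, k_conv=(1, 7)):
        super().__init__()
        # depthwise convolutions with mixed kernel sizes (groups=c_in keeps them per-channel)
        self.conv_list = nn.ModuleList(
            nn.Conv2d(c_in, c_in, k, padding=k // 2, groups=c_in) for k in k_conv)
        self.ffn1 = nn.Sequential(
            nn.Conv2d(c_in, int(1.5 * c_in), 1), nn.GELU(), nn.Conv2d(int(1.5 * c_in), c_in, 1))
        self.pw = nn.Conv2d(c_in, c_in, 1)
        self.ffn2 = nn.Sequential(
            nn.Conv2d(c_in, int(1.5 * c_in), 1), nn.GELU(), nn.Conv2d(int(1.5 * c_in), c_out, 1))

    def forward(self, x):
        x = F.gelu(x + self.ffn1(x))        # residual pointwise FFN
        for conv in self.conv_list:
            x = F.gelu(x + conv(x))         # residual depthwise mixing at several kernel sizes
        x = F.gelu(x + self.pw(x))
        return self.ffn2(x)                 # only this projection changes the channel count

print(TinyPCBlock()(torch.randn(1, 16, 12, 12)).shape)  # torch.Size([1, 8, 12, 12])
```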
-------------------------------------------------------------------------------- /core/Networks/MOFNetStack/sk2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .gma import Aggregate 5 | 6 | class PCBlock4_Deep_nopool_res(nn.Module): 7 | def __init__(self, C_in, C_out, k_conv): 8 | super().__init__() 9 | self.conv_list = nn.ModuleList([ 10 | nn.Conv2d(C_in, C_in, kernel, stride=1, padding=kernel//2, groups=C_in) for kernel in k_conv]) 11 | 12 | self.ffn1 = nn.Sequential( 13 | nn.Conv2d(C_in, int(1.5*C_in), 1, padding=0), 14 | nn.GELU(), 15 | nn.Conv2d(int(1.5*C_in), C_in, 1, padding=0), 16 | ) 17 | self.pw = nn.Conv2d(C_in, C_in, 1, padding=0) 18 | self.ffn2 = nn.Sequential( 19 | nn.Conv2d(C_in, int(1.5*C_in), 1, padding=0), 20 | nn.GELU(), 21 | nn.Conv2d(int(1.5*C_in), C_out, 1, padding=0), 22 | ) 23 | 24 | def forward(self, x): 25 | x = F.gelu(x + self.ffn1(x)) 26 | for conv in self.conv_list: 27 | x = F.gelu(x + conv(x)) 28 | x = F.gelu(x + self.pw(x)) 29 | x = self.ffn2(x) 30 | return x 31 | 32 | 33 | class SKMotionEncoder6_Deep_nopool_res(nn.Module): 34 | def __init__(self, args): 35 | super().__init__() 36 | self.cor_planes = cor_planes = (args.corr_radius*2+1)**2*args.cost_heads_num*args.corr_levels 37 | self.convc1 = PCBlock4_Deep_nopool_res(cor_planes, 128, k_conv=args.k_conv) 38 | self.convc2 = PCBlock4_Deep_nopool_res(256, 192, k_conv=args.k_conv) 39 | 40 | self.convf1_ = nn.Conv2d(4, 128, 1, 1, 0) 41 | self.convf2 = PCBlock4_Deep_nopool_res(128, 64, k_conv=args.k_conv) 42 | 43 | self.conv = PCBlock4_Deep_nopool_res(64+192, 128-4, k_conv=args.k_conv) 44 | 45 | 46 | def forward(self, forward_flow, backward_flow, forward_corr, backward_corr): 47 | cor = F.gelu(torch.cat([self.convc1(forward_corr), self.convc1(backward_corr)], dim=1)) 48 | 49 | cor = self.convc2(cor) 50 | 51 | flow = torch.cat([forward_flow, backward_flow], dim=1) 52 | flo = self.convf1_(flow) 53 | flo = self.convf2(flo) 54 | 55 | cor_flo = torch.cat([cor, flo], dim=1) 56 | out = self.conv(cor_flo) 57 | 58 | return torch.cat([out, flow], dim=1) 59 | 60 | 61 | class SKUpdateBlock6_Deep_nopoolres_AllDecoder2(nn.Module): 62 | def __init__(self, args, hidden_dim): 63 | super().__init__() 64 | self.args = args 65 | 66 | args.k_conv = [1, 15] 67 | args.PCUpdater_conv = [1, 7] 68 | 69 | self.encoder = SKMotionEncoder6_Deep_nopool_res(args) 70 | self.gru = PCBlock4_Deep_nopool_res(128+hidden_dim+hidden_dim+128, 128, k_conv=args.PCUpdater_conv) 71 | self.flow_head = PCBlock4_Deep_nopool_res(128, 4, k_conv=args.k_conv) 72 | 73 | self.mask = nn.Sequential( 74 | nn.Conv2d(128, 256, 3, padding=1), 75 | nn.ReLU(inplace=True), 76 | nn.Conv2d(256, 64*9*2, 1, padding=0)) 77 | 78 | self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=1) 79 | 80 | def forward(self, net, inp, forward_corr, backward_corr, forward_flow, backward_flow, coords0, attention, bs): 81 | 82 | motion_features = self.encoder(forward_flow, backward_flow, forward_corr, backward_corr) 83 | motion_features_global = self.aggregator(attention, motion_features) 84 | inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1) 85 | 86 | # Attentional update 87 | net = self.gru(torch.cat([net, inp_cat], dim=1)) 88 | 89 | delta_flow = self.flow_head(net) 90 | 91 | # scale mask to balence gradients 92 | mask = .25 * self.mask(net) 93 | return net, mask, delta_flow 94 | 
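A quick sanity check on the correlation channel counts that `SKMotionEncoder6_Deep_nopool_res` expects. The radius of 4 is hard-coded in `MOFNet.__init__` (`cfg.corr_radius = 4`), and `cost_heads_num` is set to 1 when the stack/SK update blocks are selected; the number of pyramid levels comes from the config files, so the value below is an assumption shown only to make the arithmetic concrete.

```python
# Per-direction correlation channels sampled by CorrBlock at one pixel.
corr_radius = 4          # hard-coded in MOFNet.__init__
corr_levels = 4          # assumed RAFT-style default; the actual value comes from configs/*.py
cost_heads_num = 1       # set when the stack / SK update blocks are chosen in network.py

window = (2 * corr_radius + 1) ** 2               # 9 x 9 lookup window -> 81
cor_planes = window * corr_levels * cost_heads_num
print(cor_planes)                                 # 324 channels per direction

# The motion encoder receives a forward and a backward cost volume, both passed through
# the shared convc1, so the concatenated correlation input is twice as wide:
print(2 * cor_planes)                             # 648
```

Under these assumptions the total matches the hard-coded `81*4*args.cost_heads_num*2` used in `BOFNet/sk.py`.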
-------------------------------------------------------------------------------- /core/Networks/MOFNetStack/stack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .gma import Aggregate 5 | 6 | from ...utils.utils import bilinear_sampler 7 | 8 | class PCBlock4_Deep_nopool_res(nn.Module): 9 | def __init__(self, C_in, C_out, k_conv): 10 | super().__init__() 11 | self.conv_list = nn.ModuleList([ 12 | nn.Conv2d(C_in, C_in, kernel, stride=1, padding=kernel//2, groups=C_in) for kernel in k_conv]) 13 | 14 | self.ffn1 = nn.Sequential( 15 | nn.Conv2d(C_in, int(1.5*C_in), 1, padding=0), 16 | nn.GELU(), 17 | nn.Conv2d(int(1.5*C_in), C_in, 1, padding=0), 18 | ) 19 | self.pw = nn.Conv2d(C_in, C_in, 1, padding=0) 20 | self.ffn2 = nn.Sequential( 21 | nn.Conv2d(C_in, int(1.5*C_in), 1, padding=0), 22 | nn.GELU(), 23 | nn.Conv2d(int(1.5*C_in), C_out, 1, padding=0), 24 | ) 25 | 26 | def forward(self, x): 27 | x = F.gelu(x + self.ffn1(x)) 28 | for conv in self.conv_list: 29 | x = F.gelu(x + conv(x)) 30 | x = F.gelu(x + self.pw(x)) 31 | x = self.ffn2(x) 32 | return x 33 | 34 | class velocity_update_block(nn.Module): 35 | def __init__(self, C_in=43+128+43, C_out=43, C_hidden=64): 36 | super().__init__() 37 | self.mlp = nn.Sequential( 38 | nn.Conv2d(C_in, C_hidden, 3, padding=1), 39 | nn.GELU(), 40 | nn.Conv2d(C_hidden, C_hidden, 3, padding=1), 41 | nn.GELU(), 42 | nn.Conv2d(C_hidden, C_out, 3, padding=1), 43 | ) 44 | def forward(self, x): 45 | return self.mlp(x) 46 | 47 | 48 | class SKMotionEncoder6_Deep_nopool_res(nn.Module): 49 | def __init__(self, args): 50 | super().__init__() 51 | self.cor_planes = cor_planes = (args.corr_radius*2+1)**2*args.cost_heads_num*args.corr_levels 52 | self.convc1 = PCBlock4_Deep_nopool_res(cor_planes, 128, k_conv=args.k_conv) 53 | self.convc2 = PCBlock4_Deep_nopool_res(256, 192, k_conv=args.k_conv) 54 | 55 | self.init_hidden_state = nn.Parameter(torch.randn(1, 1, 48, 1, 1)) 56 | 57 | self.convf1_ = nn.Conv2d(4, 128, 1, 1, 0) 58 | self.convf2 = PCBlock4_Deep_nopool_res(128, 64, k_conv=args.k_conv) 59 | 60 | self.conv = PCBlock4_Deep_nopool_res(64+192+48*3, 128-4+48, k_conv=args.k_conv) 61 | 62 | self.velocity_update_block = velocity_update_block() 63 | 64 | def sample_flo_feat(self, flow, feat): 65 | 66 | sampled_feat = bilinear_sampler(feat.float(), flow.permute(0, 2, 3, 1)) 67 | return sampled_feat 68 | 69 | def forward(self, motion_hidden_state, forward_flow, backward_flow, coords0, forward_corr, backward_corr, bs): 70 | 71 | BN, _, H, W = forward_flow.shape 72 | N = BN // bs 73 | 74 | if motion_hidden_state is None: 75 | #print("initialized as None") 76 | motion_hidden_state = self.init_hidden_state.repeat(bs, N, 1, H, W) 77 | else: 78 | #print("later iterations") 79 | motion_hidden_state = motion_hidden_state.reshape(bs, N, -1, H, W) 80 | 81 | forward_loc = forward_flow+coords0 82 | backward_loc = backward_flow+coords0 83 | 84 | forward_motion_hidden_state = torch.cat([motion_hidden_state[:, 1:, ...], torch.zeros(bs, 1, 48, H, W).to(motion_hidden_state.device)], dim=1).reshape(BN, -1, H, W) 85 | forward_motion_hidden_state = self.sample_flo_feat(forward_loc, forward_motion_hidden_state) 86 | backward_motion_hidden_state = torch.cat([torch.zeros(bs, 1, 48, H, W).to(motion_hidden_state.device), motion_hidden_state[:, :N-1, ...]], dim=1).reshape(BN, -1, H, W) 87 | backward_motion_hidden_state = self.sample_flo_feat(backward_loc, backward_motion_hidden_state) 
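# NOTE: the four statements above realise the temporal fusion of the "stack" variant:
# the per-frame motion hidden states are shifted by one step along the time axis
# (zero-padded at the sequence ends) and then warped to the current frame with the
# forward/backward flow estimates via bilinear sampling, so each centre frame reads the
# hidden state of its temporal neighbours at the matching spatial locations.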
88 | 89 | forward_cor = self.convc1(forward_corr) 90 | backward_cor = self.convc1(backward_corr) 91 | cor = F.gelu(torch.cat([forward_cor, backward_cor], dim=1)) 92 | cor = self.convc2(cor) 93 | 94 | flow = torch.cat([forward_flow, backward_flow], dim=1) 95 | flo = self.convf1_(flow) 96 | flo = self.convf2(flo) 97 | 98 | cor_flo = torch.cat([cor, flo, forward_motion_hidden_state, backward_motion_hidden_state, motion_hidden_state.reshape(BN, -1, H, W)], dim=1) 99 | out = self.conv(cor_flo) 100 | 101 | out, motion_hidden_state = torch.split(out, [124, 48], dim=1) 102 | 103 | return torch.cat([out, flow], dim=1), motion_hidden_state 104 | 105 | 106 | class SKUpdateBlock6_Deep_nopoolres_AllDecoder2(nn.Module): 107 | def __init__(self, args, hidden_dim): 108 | super().__init__() 109 | self.args = args 110 | 111 | args.k_conv = [1, 15] 112 | args.PCUpdater_conv = [1, 7] 113 | 114 | hidden_dim_ratio = 256 // args.feat_dim 115 | 116 | self.encoder = SKMotionEncoder6_Deep_nopool_res(args) 117 | self.gru = PCBlock4_Deep_nopool_res(128+hidden_dim+hidden_dim+128, 128//hidden_dim_ratio, k_conv=args.PCUpdater_conv) 118 | self.flow_head = PCBlock4_Deep_nopool_res(128//hidden_dim_ratio, 4, k_conv=args.k_conv) 119 | 120 | self.mask = nn.Sequential( 121 | nn.Conv2d(128//hidden_dim_ratio, 256//hidden_dim_ratio, 3, padding=1), 122 | nn.ReLU(inplace=True), 123 | nn.Conv2d(256//hidden_dim_ratio, args.down_ratio**2*9*2, 1, padding=0)) 124 | 125 | self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=1) 126 | 127 | def forward(self, net, motion_hidden_state, inp, forward_corr, backward_corr, forward_flow, backward_flow, coords0, attention, bs): 128 | 129 | motion_features, motion_hidden_state = self.encoder(motion_hidden_state, forward_flow, backward_flow, coords0, forward_corr, backward_corr, bs=bs) 130 | motion_features_global = self.aggregator(attention, motion_features) 131 | inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1) 132 | 133 | # Attentional update 134 | net = self.gru(torch.cat([net, inp_cat], dim=1)) 135 | 136 | delta_flow = self.flow_head(net) 137 | 138 | # scale mask to balence gradients 139 | mask = 100.0 * self.mask(net) 140 | return net, motion_hidden_state, mask, delta_flow 141 | -------------------------------------------------------------------------------- /core/Networks/MOFNetStack/stackcat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .gma import Aggregate 5 | 6 | from ...utils.utils import bilinear_sampler 7 | 8 | class PCBlock4_Deep_nopool_res(nn.Module): 9 | def __init__(self, C_in, C_out, k_conv): 10 | super().__init__() 11 | self.conv_list = nn.ModuleList([ 12 | nn.Conv2d(C_in, C_in, kernel, stride=1, padding=kernel//2, groups=C_in) for kernel in k_conv]) 13 | 14 | self.ffn1 = nn.Sequential( 15 | nn.Conv2d(C_in, int(1.5*C_in), 1, padding=0), 16 | nn.GELU(), 17 | nn.Conv2d(int(1.5*C_in), C_in, 1, padding=0), 18 | ) 19 | self.pw = nn.Conv2d(C_in, C_in, 1, padding=0) 20 | self.ffn2 = nn.Sequential( 21 | nn.Conv2d(C_in, int(1.5*C_in), 1, padding=0), 22 | nn.GELU(), 23 | nn.Conv2d(int(1.5*C_in), C_out, 1, padding=0), 24 | ) 25 | 26 | def forward(self, x): 27 | x = F.gelu(x + self.ffn1(x)) 28 | for conv in self.conv_list: 29 | x = F.gelu(x + conv(x)) 30 | x = F.gelu(x + self.pw(x)) 31 | x = self.ffn2(x) 32 | return x 33 | 34 | class velocity_update_block(nn.Module): 35 | def __init__(self, C_in=43+128+43, C_out=43, 
C_hidden=64): 36 | super().__init__() 37 | self.mlp = nn.Sequential( 38 | nn.Conv2d(C_in, C_hidden, 3, padding=1), 39 | nn.GELU(), 40 | nn.Conv2d(C_hidden, C_hidden, 3, padding=1), 41 | nn.GELU(), 42 | nn.Conv2d(C_hidden, C_out, 3, padding=1), 43 | ) 44 | def forward(self, x): 45 | return self.mlp(x) 46 | 47 | 48 | class SKMotionEncoder6_Deep_nopool_res(nn.Module): 49 | def __init__(self, args): 50 | super().__init__() 51 | self.cor_planes = cor_planes = (args.corr_radius*2+1)**2*args.cost_heads_num*args.corr_levels 52 | self.convc1 = PCBlock4_Deep_nopool_res(cor_planes, 128, k_conv=args.k_conv) 53 | self.convc2 = PCBlock4_Deep_nopool_res(256, 192, k_conv=args.k_conv) 54 | 55 | self.init_hidden_state = nn.Parameter(torch.randn(1, 1, 48, 1, 1)) 56 | 57 | self.convf1_ = nn.Conv2d(4, 128, 1, 1, 0) 58 | self.convf2 = PCBlock4_Deep_nopool_res(128, 64, k_conv=args.k_conv) 59 | 60 | self.conv = PCBlock4_Deep_nopool_res(64+192+48*3, 128-4+48, k_conv=args.k_conv) 61 | 62 | self.velocity_update_block = velocity_update_block() 63 | 64 | def sample_flo_feat(self, flow, feat): 65 | 66 | sampled_feat = bilinear_sampler(feat.float(), flow.permute(0, 2, 3, 1)) 67 | return sampled_feat 68 | 69 | def forward(self, motion_hidden_state, forward_flow, backward_flow, coords0, forward_corr, backward_corr, bs): 70 | 71 | BN, _, H, W = forward_flow.shape 72 | N = BN // bs 73 | 74 | if motion_hidden_state is None: 75 | #print("initialized as None") 76 | motion_hidden_state = self.init_hidden_state.repeat(bs, N, 1, H, W) 77 | else: 78 | #print("later iterations") 79 | motion_hidden_state = motion_hidden_state.reshape(bs, N, -1, H, W) 80 | 81 | forward_loc = forward_flow+coords0 82 | backward_loc = backward_flow+coords0 83 | 84 | forward_motion_hidden_state = torch.cat([motion_hidden_state[:, 1:, ...], torch.zeros(bs, 1, 48, H, W).to(motion_hidden_state.device)], dim=1).reshape(BN, -1, H, W) 85 | backward_motion_hidden_state = torch.cat([torch.zeros(bs, 1, 48, H, W).to(motion_hidden_state.device), motion_hidden_state[:, :N-1, ...]], dim=1).reshape(BN, -1, H, W) 86 | 87 | forward_cor = self.convc1(forward_corr) 88 | backward_cor = self.convc1(backward_corr) 89 | cor = F.gelu(torch.cat([forward_cor, backward_cor], dim=1)) 90 | cor = self.convc2(cor) 91 | 92 | flow = torch.cat([forward_flow, backward_flow], dim=1) 93 | flo = self.convf1_(flow) 94 | flo = self.convf2(flo) 95 | 96 | cor_flo = torch.cat([cor, flo, forward_motion_hidden_state, backward_motion_hidden_state, motion_hidden_state.reshape(BN, -1, H, W)], dim=1) 97 | out = self.conv(cor_flo) 98 | 99 | out, motion_hidden_state = torch.split(out, [124, 48], dim=1) 100 | 101 | return torch.cat([out, flow], dim=1), motion_hidden_state 102 | 103 | 104 | class SKUpdateBlock6_Deep_nopoolres_AllDecoder2(nn.Module): 105 | def __init__(self, args, hidden_dim): 106 | super().__init__() 107 | self.args = args 108 | 109 | args.k_conv = [1, 15] 110 | args.PCUpdater_conv = [1, 7] 111 | 112 | self.encoder = SKMotionEncoder6_Deep_nopool_res(args) 113 | self.gru = PCBlock4_Deep_nopool_res(128+hidden_dim+hidden_dim+128, 128, k_conv=args.PCUpdater_conv) 114 | self.flow_head = PCBlock4_Deep_nopool_res(128, 4, k_conv=args.k_conv) 115 | 116 | self.mask = nn.Sequential( 117 | nn.Conv2d(128, 256, 3, padding=1), 118 | nn.ReLU(inplace=True), 119 | nn.Conv2d(256, 64*9*2, 1, padding=0)) 120 | 121 | self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=1) 122 | 123 | def forward(self, net, motion_hidden_state, inp, forward_corr, backward_corr, forward_flow, backward_flow, 
coords0, attention, bs): 124 | 125 | motion_features, motion_hidden_state = self.encoder(motion_hidden_state, forward_flow, backward_flow, coords0, forward_corr, backward_corr, bs=bs) 126 | motion_features_global = self.aggregator(attention, motion_features) 127 | inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1) 128 | 129 | # Attentional update 130 | net = self.gru(torch.cat([net, inp_cat], dim=1)) 131 | 132 | delta_flow = self.flow_head(net) 133 | 134 | # scale mask to balence gradients 135 | mask = 100.0 * self.mask(net) 136 | return net, motion_hidden_state, mask, delta_flow 137 | -------------------------------------------------------------------------------- /core/Networks/MOFNetStack/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .gma import Aggregate 5 | 6 | 7 | class FlowHead(nn.Module): 8 | def __init__(self, input_dim=128, hidden_dim=256): 9 | super(FlowHead, self).__init__() 10 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 11 | self.conv2 = nn.Conv2d(hidden_dim, 4, 3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | def forward(self, x): 15 | return self.conv2(self.relu(self.conv1(x))) 16 | 17 | 18 | class ConvGRU(nn.Module): 19 | def __init__(self, hidden_dim=128, input_dim=128+128): 20 | super(ConvGRU, self).__init__() 21 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 23 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 24 | 25 | def forward(self, h, x): 26 | hx = torch.cat([h, x], dim=1) 27 | 28 | z = torch.sigmoid(self.convz(hx)) 29 | r = torch.sigmoid(self.convr(hx)) 30 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 31 | 32 | h = (1-z) * h + z * q 33 | return h 34 | 35 | 36 | class SepConvGRU(nn.Module): 37 | def __init__(self, hidden_dim=128, input_dim=192+128): 38 | super(SepConvGRU, self).__init__() 39 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 40 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 41 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 42 | 43 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 44 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 45 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 46 | 47 | 48 | def forward(self, h, x): 49 | # horizontal 50 | hx = torch.cat([h, x], dim=1) 51 | z = torch.sigmoid(self.convz1(hx)) 52 | r = torch.sigmoid(self.convr1(hx)) 53 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 54 | h = (1-z) * h + z * q 55 | 56 | # vertical 57 | hx = torch.cat([h, x], dim=1) 58 | z = torch.sigmoid(self.convz2(hx)) 59 | r = torch.sigmoid(self.convr2(hx)) 60 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 61 | h = (1-z) * h + z * q 62 | 63 | return h 64 | 65 | 66 | class BasicMotionEncoder(nn.Module): 67 | def __init__(self, args): 68 | super(BasicMotionEncoder, self).__init__() 69 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 * 2 70 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 71 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 72 | self.convf1 = nn.Conv2d(4, 128, 7, padding=3) 73 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 74 | self.conv = nn.Conv2d(64+192, 128-4, 3, 
padding=1) 75 | 76 | def forward(self, flow, corr): 77 | cor = F.relu(self.convc1(corr)) 78 | cor = F.relu(self.convc2(cor)) 79 | flo = F.relu(self.convf1(flow)) 80 | flo = F.relu(self.convf2(flo)) 81 | 82 | cor_flo = torch.cat([cor, flo], dim=1) 83 | out = F.relu(self.conv(cor_flo)) 84 | return torch.cat([out, flow], dim=1) 85 | 86 | 87 | class BasicUpdateBlock(nn.Module): 88 | def __init__(self, args, hidden_dim=128, input_dim=128): 89 | super(BasicUpdateBlock, self).__init__() 90 | self.args = args 91 | self.encoder = BasicMotionEncoder(args) 92 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 93 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 94 | 95 | self.mask = nn.Sequential( 96 | nn.Conv2d(128, 256, 3, padding=1), 97 | nn.ReLU(inplace=True), 98 | nn.Conv2d(256, 64*9, 1, padding=0)) 99 | 100 | def forward(self, net, inp, corr, flow, upsample=True): 101 | motion_features = self.encoder(flow, corr) 102 | inp = torch.cat([inp, motion_features], dim=1) 103 | 104 | net = self.gru(net, inp) 105 | delta_flow = self.flow_head(net) 106 | 107 | # scale mask to balence gradients 108 | mask = .25 * self.mask(net) 109 | return net, mask, delta_flow 110 | 111 | 112 | class GMAUpdateBlock(nn.Module): 113 | def __init__(self, args, hidden_dim=128): 114 | super().__init__() 115 | self.args = args 116 | self.encoder = BasicMotionEncoder(args) 117 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim+hidden_dim) 118 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 119 | 120 | self.mask = nn.Sequential( 121 | nn.Conv2d(128, 256, 3, padding=1), 122 | nn.ReLU(inplace=True), 123 | nn.Conv2d(256, 64*9*2, 1, padding=0)) 124 | 125 | self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=1) 126 | 127 | def forward(self, net, inp, corr, flow, attention): 128 | motion_features = self.encoder(flow, corr) 129 | motion_features_global = self.aggregator(attention, motion_features) 130 | inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1) 131 | 132 | # Attentional update 133 | net = self.gru(net, inp_cat) 134 | 135 | delta_flow = self.flow_head(net) 136 | 137 | # scale mask to balence gradients 138 | mask = .25 * self.mask(net) 139 | 140 | return net, mask, delta_flow 141 | 142 | 143 | 144 | -------------------------------------------------------------------------------- /core/Networks/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | def build_network(cfg): 3 | name = cfg.network 4 | if name == 'MOFNetStack': 5 | from .MOFNetStack.network import MOFNet as network 6 | elif name == 'BOFNet': 7 | from .BOFNet.network import BOFNet as network 8 | else: 9 | raise ValueError(f"Network = {name} is not a valid optimizer!") 10 | 11 | return network(cfg[name]) 12 | -------------------------------------------------------------------------------- /core/Networks/encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import timm 4 | import numpy as np 5 | #from .twins_ft import _twins_svt_large_jihao 6 | 7 | class twins_svt_large(nn.Module): 8 | def __init__(self, pretrained=True, del_layers=True): 9 | super().__init__() 10 | self.svt = timm.create_model('twins_svt_large', pretrained=pretrained) 11 | 12 | if del_layers: 13 | del self.svt.head 14 | del self.svt.patch_embeds[2] 15 | del self.svt.patch_embeds[2] 16 | del self.svt.blocks[2] 17 | del self.svt.blocks[2] 18 | del 
self.svt.pos_block[2] 19 | del self.svt.pos_block[2] 20 | 21 | def forward(self, x, data=None, layer=2): 22 | B = x.shape[0] 23 | for i, (embed, drop, blocks, pos_blk) in enumerate( 24 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)): 25 | 26 | x, size = embed(x) 27 | x = drop(x) 28 | for j, blk in enumerate(blocks): 29 | x = blk(x, size) 30 | if j==0: 31 | x = pos_blk(x, size) 32 | if i < len(self.svt.depths) - 1: 33 | x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() 34 | 35 | if i == 0: 36 | x_16 = x.clone() 37 | if i == layer-1: 38 | break 39 | 40 | return x 41 | 42 | def extract_ml_features(self, x, data=None, layer=2): 43 | res = [] 44 | B = x.shape[0] 45 | for i, (embed, drop, blocks, pos_blk) in enumerate( 46 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)): 47 | x, size = embed(x) 48 | if i == layer-1: 49 | x1 = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() 50 | x = drop(x) 51 | for j, blk in enumerate(blocks): 52 | x = blk(x, size) 53 | if j==0: 54 | x = pos_blk(x, size) 55 | if i < len(self.svt.depths) - 1: 56 | x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() 57 | 58 | if i == layer-1: 59 | break 60 | 61 | return x1, x 62 | 63 | def compute_params(self): 64 | num = 0 65 | 66 | for i, (embed, drop, blocks, pos_blk) in enumerate( 67 | zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)): 68 | 69 | for param in embed.parameters(): 70 | num += np.prod(param.size()) 71 | for param in blocks.parameters(): 72 | num += np.prod(param.size()) 73 | for param in pos_blk.parameters(): 74 | num += np.prod(param.size()) 75 | for param in drop.parameters(): 76 | num += np.prod(param.size()) 77 | if i == 1: 78 | break 79 | return num 80 | 81 | class convnext_large(nn.Module): 82 | def __init__(self, pretrained=True): 83 | super().__init__() 84 | self.convnext = timm.create_model('convnext_large', pretrained=pretrained) 85 | 86 | def forward(self, x, data=None, layer=2): 87 | 88 | x = self.convnext.stem(x) 89 | x = self.convnext.stages[0](x) 90 | x = self.convnext.stages[1](x) 91 | return x 92 | 93 | def compute_params(self): 94 | num = 0 95 | 96 | for param in self.convnext.stem.parameters(): 97 | num += np.prod(param.size()) 98 | for param in self.convnext.stages[0].parameters(): 99 | num += np.prod(param.size()) 100 | for param in self.convnext.stages[1].parameters(): 101 | num += np.prod(param.size()) 102 | 103 | return num 104 | 105 | class convnext_Xlarge_4x(nn.Module): 106 | def __init__(self, pretrained=True, del_layers=True): 107 | super().__init__() 108 | self.convnext = timm.create_model('convnext_xlarge_in22k', pretrained=pretrained) 109 | 110 | # self.convnext.stem[0].stride = (2, 2) 111 | # self.convnext.stem[0].padding = (1, 1) 112 | 113 | if del_layers: 114 | del self.convnext.head 115 | del self.convnext.stages[1] 116 | del self.convnext.stages[1] 117 | del self.convnext.stages[1] 118 | 119 | # print(self.convnext) 120 | 121 | 122 | def forward(self, x, data=None, layer=2): 123 | 124 | x = self.convnext.stem(x) 125 | x = self.convnext.stages[0](x) 126 | return x 127 | 128 | class convnext_base_2x(nn.Module): 129 | def __init__(self, pretrained=True, del_layers=True): 130 | super().__init__() 131 | self.convnext = timm.create_model('convnext_base_in22k', pretrained=pretrained) 132 | 133 | self.convnext.stem[0].stride = (2, 2) 134 | self.convnext.stem[0].padding = (1, 1) 135 | 136 | if del_layers: 137 | del self.convnext.head 138 | del 
self.convnext.stages[1] 139 | del self.convnext.stages[1] 140 | del self.convnext.stages[1] 141 | 142 | # print(self.convnext) 143 | 144 | 145 | def forward(self, x, data=None, layer=2): 146 | 147 | x = self.convnext.stem(x) 148 | x = self.convnext.stages[0](x) 149 | return x 150 | 151 | 152 | if __name__ == "__main__": 153 | m = convnext_Xlarge_2x() 154 | input = torch.randn(2, 3, 64, 64) 155 | out = m(input) 156 | print(out.shape) 157 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaoyuShi97/VideoFlow/51489304db6e75fbdd9ff64d4656c1d120b5a673/core/__init__.py -------------------------------------------------------------------------------- /core/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | MAX_FLOW = 400 4 | 5 | def sequence_loss_twoframes(flow_preds, flow_gt, valid, cfg): 6 | """ Loss function defined over sequence of flow predictions """ 7 | 8 | gamma = cfg.gamma 9 | max_flow = cfg.max_flow 10 | n_predictions = len(flow_preds) 11 | flow_loss = 0.0 12 | flow_gt_thresholds = [5, 10, 20] 13 | 14 | # exlude invalid pixels and extremely large diplacements 15 | mag = torch.sum(flow_gt**2, dim=1).sqrt() 16 | valid = (valid >= 0.5) & (mag < max_flow) 17 | 18 | for i in range(n_predictions): 19 | i_weight = gamma**(n_predictions - i - 1) 20 | i_loss = (flow_preds[i] - flow_gt).abs() 21 | flow_loss += i_weight * (valid[:, None] * i_loss).mean() 22 | 23 | epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt() 24 | 25 | 26 | epe = epe.view(-1)[valid.view(-1)] 27 | 28 | metrics = { 29 | 'epe': epe.mean().item(), 30 | '1px': (epe < 1).float().mean().item(), 31 | '3px': (epe < 3).float().mean().item(), 32 | '5px': (epe < 5).float().mean().item(), 33 | } 34 | 35 | return flow_loss, metrics 36 | 37 | def sequence_loss(flow_preds, flow_gt, valid, cfg): 38 | """ Loss function defined over sequence of flow predictions """ 39 | 40 | #print(flow_gt.shape, valid.shape, flow_preds[0].shape) 41 | #exit() 42 | 43 | gamma = cfg.gamma 44 | max_flow = cfg.max_flow 45 | n_predictions = len(flow_preds) 46 | flow_loss = 0.0 47 | 48 | B, N, _, H, W = flow_gt.shape 49 | 50 | NAN_flag = False 51 | 52 | # exlude invalid pixels and extremely large diplacements 53 | mag = torch.sum(flow_gt**2, dim=2).sqrt() 54 | valid = (valid >= 0.5) & (mag < max_flow) 55 | 56 | for i in range(n_predictions): 57 | i_weight = gamma**(n_predictions - i - 1) 58 | 59 | flow_pre = flow_preds[i] 60 | i_loss = (flow_pre - flow_gt).abs() 61 | 62 | if torch.isnan(i_loss).any(): 63 | NAN_flag = True 64 | 65 | _valid = valid[:, :, None] 66 | if cfg.filter_epe: 67 | loss_mag = torch.sum(i_loss**2, dim=2).sqrt() 68 | mask = loss_mag > 1000 69 | #print(mask.shape, _valid.shape) 70 | if torch.any(mask): 71 | print("[Found extrem epe. Filtered out. Max is {}. 
Ratio is {}]".format(torch.max(loss_mag), torch.mean(mask.float()))) 72 | _valid = _valid & (~mask[:, :, None]) 73 | 74 | flow_loss += i_weight * (_valid * i_loss).mean() 75 | 76 | epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=2).sqrt() 77 | epe = epe.view(-1)[valid.view(-1)] 78 | 79 | metrics = { 80 | 'epe': epe.mean().item(), 81 | '1px': (epe < 1).float().mean().item(), 82 | '3px': (epe < 3).float().mean().item(), 83 | '5px': (epe < 5).float().mean().item(), 84 | } 85 | 86 | return flow_loss, metrics, NAN_flag 87 | 88 | # def smooth_l1_loss(diff): 89 | # cond = diff.abs() < 1 90 | # loss = torch.where(cond, 0.5*diff**2, diff.abs()-0.5) 91 | # return loss 92 | 93 | # def sequence_loss_smooth(flow_preds, flow_gt, valid, cfg): 94 | # """ Loss function defined over sequence of flow predictions """ 95 | 96 | # gamma = cfg.gamma 97 | # max_flow = cfg.max_flow 98 | # n_predictions = len(flow_preds) 99 | # flow_loss = 0.0 100 | # flow_gt_thresholds = [5, 10, 20] 101 | 102 | # # exlude invalid pixels and extremely large diplacements 103 | # mag = torch.sum(flow_gt**2, dim=1).sqrt() 104 | # valid = (valid >= 0.5) & (mag < max_flow) 105 | 106 | # for i in range(n_predictions): 107 | # i_weight = gamma**(n_predictions - i - 1) 108 | # i_loss = smooth_l1_loss((flow_preds[i] - flow_gt)) 109 | # flow_loss += i_weight * (valid[:, None] * i_loss).mean() 110 | 111 | # epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt() 112 | # epe = epe.view(-1)[valid.view(-1)] 113 | 114 | # metrics = { 115 | # 'epe': epe.mean().item(), 116 | # '1px': (epe < 1).float().mean().item(), 117 | # '3px': (epe < 3).float().mean().item(), 118 | # '5px': (epe < 5).float().mean().item(), 119 | # } 120 | 121 | # flow_gt_length = torch.sum(flow_gt**2, dim=1).sqrt() 122 | # flow_gt_length = flow_gt_length.view(-1)[valid.view(-1)] 123 | # for t in flow_gt_thresholds: 124 | # e = epe[flow_gt_length < t] 125 | # metrics.update({ 126 | # f"{t}-th-5px": (e < 5).float().mean().item() 127 | # }) 128 | 129 | 130 | # return flow_loss, metrics 131 | 132 | -------------------------------------------------------------------------------- /core/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR, OneCycleLR 3 | 4 | def fetch_optimizer(model, cfg): 5 | """ Create the optimizer and learning rate scheduler """ 6 | # optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon) 7 | 8 | # scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps+100, 9 | # pct_start=0.05, cycle_momentum=False, anneal_strategy='linear') 10 | optimizer = build_optimizer(model, cfg) 11 | scheduler = build_scheduler(cfg, optimizer) 12 | 13 | return optimizer, scheduler 14 | 15 | def build_optimizer(model, config): 16 | name = config.optimizer 17 | lr = config.canonical_lr 18 | 19 | if name == "adam": 20 | return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.adam_decay, eps=config.epsilon) 21 | elif name == "adamw": 22 | if hasattr(config, 'twins_lr_factor'): 23 | factor = config.twins_lr_factor 24 | print("[Decrease lr of pre-trained model by factor {}]".format(factor)) 25 | param_dicts = [ 26 | {"params": [p for n, p in model.named_parameters() if "feat_encoder" not in n and 'context_encoder' not in n and p.requires_grad]}, 27 | { 28 | "params": [p for n, p in model.named_parameters() if ("feat_encoder" in n or 'context_encoder' in 
n) and p.requires_grad], 29 | "lr": lr*factor, 30 | }, 31 | ] 32 | full = [n for n, _ in model.named_parameters()] 33 | return torch.optim.AdamW(param_dicts, lr=lr, weight_decay=config.adamw_decay, eps=config.epsilon) 34 | else: 35 | return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.adamw_decay, eps=config.epsilon) 36 | else: 37 | raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!") 38 | 39 | 40 | def build_scheduler(config, optimizer): 41 | """ 42 | Returns: 43 | scheduler (dict):{ 44 | 'scheduler': lr_scheduler, 45 | 'interval': 'step', # or 'epoch' 46 | } 47 | """ 48 | # scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL} 49 | name = config.scheduler 50 | lr = config.canonical_lr 51 | 52 | if name == 'OneCycleLR': 53 | # scheduler = OneCycleLR(optimizer, ) 54 | if hasattr(config, 'twins_lr_factor'): 55 | factor = config.twins_lr_factor 56 | scheduler = OneCycleLR(optimizer, [lr, lr*factor], config.num_steps+100, 57 | pct_start=0.05, cycle_momentum=False, anneal_strategy=config.anneal_strategy) 58 | else: 59 | scheduler = OneCycleLR(optimizer, lr, config.num_steps+100, 60 | pct_start=0.05, cycle_momentum=False, anneal_strategy=config.anneal_strategy) 61 | # elif name == 'MultiStepLR': 62 | # scheduler.update( 63 | # {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)}) 64 | #elif name == 'CosineAnnealing': 65 | # scheduler = CosineAnnealingLR(optimizer, config.num_steps+100) 66 | # scheduler.update( 67 | # {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)}) 68 | # elif name == 'ExponentialLR': 69 | # scheduler.update( 70 | # {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)}) 71 | else: 72 | raise NotImplementedError() 73 | 74 | return scheduler 75 | -------------------------------------------------------------------------------- /core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaoyuShi97/VideoFlow/51489304db6e75fbdd9ff64d4656c1d120b5a673/core/utils/__init__.py -------------------------------------------------------------------------------- /core/utils/augmentor_multiframes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import math 4 | from PIL import Image 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | import torch 11 | from torchvision.transforms import ColorJitter 12 | import torch.nn.functional as F 13 | from . 
import flow_transforms 14 | 15 | class FlowAugmentor: 16 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): 17 | 18 | # spatial augmentation params 19 | self.crop_size = crop_size 20 | self.min_scale = min_scale 21 | self.max_scale = max_scale 22 | self.spatial_aug_prob = 0.8 23 | self.stretch_prob = 0.8 24 | self.max_stretch = 0.2 25 | 26 | # flip augmentation params 27 | self.do_flip = do_flip 28 | self.h_flip_prob = 0.5 29 | self.v_flip_prob = 0.1 30 | 31 | # photometric augmentation params 32 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) 33 | self.asymmetric_color_aug_prob = 0.2 34 | self.eraser_aug_prob = 0.5 35 | 36 | def color_transform(self, imgs): 37 | """ Photometric augmentation """ 38 | 39 | # asymmetric 40 | if np.random.rand() < self.asymmetric_color_aug_prob: 41 | imgs = [np.array(self.photo_aug(Image.fromarray(img)), dtype=np.uint8) for img in imgs] 42 | 43 | # symmetric 44 | else: 45 | img_num = len(imgs) 46 | image_stack = np.concatenate(imgs, axis=0) 47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 48 | imgs = np.split(image_stack, img_num, axis=0) 49 | 50 | return imgs 51 | 52 | def eraser_transform(self, imgs, bounds=[50, 100]): 53 | print("[erasing]") 54 | ht, wd = imgs[0].shape[:2] 55 | if np.random.rand() < self.eraser_aug_prob: 56 | for idx in range(len(imgs)): 57 | mean_color = np.mean(imgs[idx].reshape(-1, 3), axis=0) 58 | for _ in range(np.random.randint(1, 3)): 59 | #print("!@#!@#!@#!@#!@#!@#!") 60 | x0 = np.random.randint(0, wd) 61 | y0 = np.random.randint(0, ht) 62 | dx = np.random.randint(bounds[0], bounds[1]) 63 | dy = np.random.randint(bounds[0], bounds[1]) 64 | imgs[idx][y0:y0+dy, x0:x0+dx, :] = mean_color 65 | return imgs 66 | 67 | def spatial_transform(self, imgs, flows): 68 | # randomly sample scale 69 | ht, wd = imgs[0].shape[:2] 70 | min_scale = np.maximum( 71 | (self.crop_size[0] + 8) / float(ht), 72 | (self.crop_size[1] + 8) / float(wd)) 73 | 74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 75 | scale_x = scale 76 | scale_y = scale 77 | if np.random.rand() < self.stretch_prob: 78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 80 | scale_x = np.clip(scale_x, min_scale, None) 81 | scale_y = np.clip(scale_y, min_scale, None) 82 | 83 | if np.random.rand() < self.spatial_aug_prob: 84 | # rescale the images 85 | imgs = [cv2.resize(img, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) for img in imgs] 86 | 87 | flows = [cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) for flow in flows] 88 | flows = [flow * [scale_x, scale_y] for flow in flows] 89 | 90 | if self.do_flip: 91 | if np.random.rand() < self.h_flip_prob: # h-flip 92 | imgs = [img[:, ::-1] for img in imgs] 93 | flows = [flow[:, ::-1] * [-1.0, 1.0] for flow in flows] 94 | 95 | if np.random.rand() < self.v_flip_prob: # v-flip 96 | imgs = [img[::-1, :] for img in imgs] 97 | flows = [flow[::-1, :] * [1.0, -1.0] for flow in flows] 98 | 99 | if imgs[0].shape[0] == self.crop_size[0]: 100 | y0 = 0 101 | else: 102 | y0 = np.random.randint(0, imgs[0].shape[0] - self.crop_size[0]) 103 | if imgs[0].shape[1] == self.crop_size[1]: 104 | x0 = 0 105 | else: 106 | x0 = np.random.randint(0, imgs[0].shape[1] - self.crop_size[1]) 107 | 108 | imgs = [img[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] for img in imgs] 109 | flows = 
[flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] for flow in flows] 110 | 111 | return imgs, flows 112 | 113 | def __call__(self, imgs, flows): 114 | imgs = self.color_transform(imgs) 115 | #imgs = self.eraser_transform(imgs) 116 | imgs, flows = self.spatial_transform(imgs, flows) 117 | 118 | imgs = [np.ascontiguousarray(img) for img in imgs] 119 | flows = [np.ascontiguousarray(flow) for flow in flows] 120 | 121 | return imgs, flows 122 | 123 | class SparseFlowAugmentor: 124 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): 125 | # spatial augmentation params 126 | self.crop_size = crop_size 127 | self.min_scale = min_scale 128 | self.max_scale = max_scale 129 | self.spatial_aug_prob = 0.8 130 | self.stretch_prob = 0.8 131 | self.max_stretch = 0.2 132 | 133 | # flip augmentation params 134 | self.do_flip = do_flip 135 | self.h_flip_prob = 0.5 136 | self.v_flip_prob = 0.1 137 | 138 | # photometric augmentation params 139 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) 140 | self.asymmetric_color_aug_prob = 0.2 141 | self.eraser_aug_prob = 0.5 142 | 143 | def color_transform(self, imgs): 144 | 145 | img_num = len(imgs) 146 | image_stack = np.concatenate(imgs, axis=0) 147 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 148 | imgs = np.split(image_stack, img_num, axis=0) 149 | 150 | return imgs 151 | 152 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 153 | ht, wd = flow.shape[:2] 154 | coords = np.meshgrid(np.arange(wd), np.arange(ht)) 155 | coords = np.stack(coords, axis=-1) 156 | 157 | coords = coords.reshape(-1, 2).astype(np.float32) 158 | flow = flow.reshape(-1, 2).astype(np.float32) 159 | valid = valid.reshape(-1).astype(np.float32) 160 | 161 | coords0 = coords[valid>=1] 162 | flow0 = flow[valid>=1] 163 | 164 | ht1 = int(round(ht * fy)) 165 | wd1 = int(round(wd * fx)) 166 | 167 | coords1 = coords0 * [fx, fy] 168 | flow1 = flow0 * [fx, fy] 169 | 170 | xx = np.round(coords1[:,0]).astype(np.int32) 171 | yy = np.round(coords1[:,1]).astype(np.int32) 172 | 173 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 174 | xx = xx[v] 175 | yy = yy[v] 176 | flow1 = flow1[v] 177 | 178 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) 179 | valid_img = np.zeros([ht1, wd1], dtype=np.int32) 180 | 181 | flow_img[yy, xx] = flow1 182 | valid_img[yy, xx] = 1 183 | 184 | return flow_img, valid_img 185 | 186 | def spatial_transform(self, imgs, flows, valids): 187 | pad_t = 0 188 | pad_b = 0 189 | pad_l = 0 190 | pad_r = 0 191 | if self.crop_size[0] > imgs[0].shape[0]: 192 | #pad_t = self.crop_size[0] - img1.shape[0] 193 | pad_b = self.crop_size[0] - imgs[0].shape[0] 194 | if self.crop_size[1] > imgs[0].shape[1]: 195 | print("[In kitti data, padding along width axis now!]") 196 | pad_r = self.crop_size[1] - imgs[0].shape[1] 197 | if pad_b != 0 or pad_r != 0 or pad_t != 0: 198 | imgs = [np.pad(img, ((pad_t, pad_b), (pad_l, pad_r), (0, 0)), 'constant', constant_values=((0, 0), (0, 0), (0, 0))) for img in imgs] 199 | flows = [np.pad(flow, ((pad_t, pad_b), (pad_l, pad_r), (0, 0)), 'constant', constant_values=((0, 0), (0, 0), (0, 0))) for flow in flows] 200 | valids = [np.pad(valid, ((pad_t, pad_b), (pad_l, pad_r)), 'constant', constant_values=((0, 0), (0, 0))) for valid in valids] 201 | # randomly sample scale 202 | 203 | ht, wd = imgs[0].shape[:2] 204 | min_scale = np.maximum( 205 | (self.crop_size[0] + 1) / float(ht), 206 | (self.crop_size[1] + 1) / float(wd)) 207 | 208 | 
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 209 | scale_x = np.clip(scale, min_scale, None) 210 | scale_y = np.clip(scale, min_scale, None) 211 | 212 | if np.random.rand() < self.spatial_aug_prob: 213 | # rescale the images 214 | imgs = [cv2.resize(img, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) for img in imgs] 215 | for idx in range(len(flows)): 216 | flows[idx], valids[idx] = self.resize_sparse_flow_map(flows[idx], valids[idx], fx=scale_x, fy=scale_y) 217 | 218 | if self.do_flip: 219 | if np.random.rand() < 0.5: # h-flip 220 | imgs = [img[:, ::-1] for img in imgs] 221 | flows = [flow[:, ::-1] * [-1.0, 1.0] for flow in flows] 222 | valids = [valid[:, ::-1] for valid in valids] 223 | 224 | margin_y = 20 225 | margin_x = 50 226 | 227 | y0 = np.random.randint(0, imgs[0].shape[0] - self.crop_size[0] + margin_y) 228 | x0 = np.random.randint(-margin_x, imgs[0].shape[1] - self.crop_size[1] + margin_x) 229 | 230 | y0 = np.clip(y0, 0, imgs[0].shape[0] - self.crop_size[0]) 231 | x0 = np.clip(x0, 0, imgs[0].shape[1] - self.crop_size[1]) 232 | 233 | imgs = [img[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] for img in imgs] 234 | 235 | flows = [flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] for flow in flows] 236 | 237 | valids = [valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] for valid in valids] 238 | 239 | return imgs, flows, valids 240 | 241 | 242 | def __call__(self, imgs, flows, valids): 243 | imgs = self.color_transform(imgs) 244 | imgs, flows, valids = self.spatial_transform(imgs, flows, valids) 245 | 246 | imgs = [np.ascontiguousarray(img) for img in imgs] 247 | flows = [np.ascontiguousarray(flow) for flow in flows] 248 | valids = [np.ascontiguousarray(valid) for valid in valids] 249 | 250 | return imgs, flows, valids 251 | 252 | -------------------------------------------------------------------------------- /core/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 
28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) 133 | -------------------------------------------------------------------------------- /core/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 
76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- /core/utils/logger.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | from loguru import logger as loguru_logger 3 | 4 | class Logger: 5 | def __init__(self, model, scheduler, cfg): 6 | self.model = model 7 | self.scheduler = scheduler 8 | self.total_steps = 0 9 | self.running_loss = {} 10 | self.writer = None 11 | self.cfg = cfg 12 | 13 | def _print_training_status(self): 14 | metrics_data = [self.running_loss[k]/self.cfg.sum_freq for k in sorted(self.running_loss.keys())] 15 | training_str = "[{:6d}, {}] ".format(self.total_steps+1, self.scheduler.get_last_lr()) 16 | metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data) 17 | 18 | # print the training status 19 | loguru_logger.info(training_str + metrics_str) 20 | 21 | if self.writer is None: 22 | if self.cfg.log_dir is None: 23 | self.writer = SummaryWriter() 24 | else: 25 | self.writer = SummaryWriter(self.cfg.log_dir) 26 | 27 | for k in self.running_loss: 28 | self.writer.add_scalar(k, self.running_loss[k]/self.cfg.sum_freq, self.total_steps) 29 | self.running_loss[k] = 0.0 30 | 31 | def push(self, metrics): 32 | self.total_steps += 1 33 | 34 | for key in metrics: 35 | if key not in self.running_loss: 36 | self.running_loss[key] = 0.0 37 | 38 | self.running_loss[key] += metrics[key] 39 | 40 | if self.total_steps % self.cfg.sum_freq == self.cfg.sum_freq-1: 41 | self._print_training_status() 42 | self.running_loss = {} 43 | 44 | def write_dict(self, results): 45 | if self.writer is None: 46 | self.writer = 
SummaryWriter() 47 | 48 | for key in results: 49 | self.writer.add_scalar(key, results[key], self.total_steps) 50 | 51 | def close(self): 52 | self.writer.close() 53 | 54 | -------------------------------------------------------------------------------- /core/utils/misc.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import shutil 4 | 5 | def process_transformer_cfg(cfg): 6 | log_dir = '' 7 | if 'critical_params' in cfg: 8 | critical_params = [cfg[key] for key in cfg.critical_params] 9 | for name, param in zip(cfg["critical_params"], critical_params): 10 | log_dir += "{:s}[{:s}]".format(name, str(param)) 11 | 12 | return log_dir 13 | 14 | def process_cfg(cfg): 15 | log_dir = 'logs/' + cfg.name + '/' + cfg.network + '/' 16 | critical_params = [cfg.trainer[key] for key in cfg.critical_params] 17 | for name, param in zip(cfg["critical_params"], critical_params): 18 | log_dir += "{:s}[{:s}]".format(name, str(param)) 19 | 20 | log_dir += process_transformer_cfg(cfg[cfg.network]) 21 | 22 | now = time.localtime() 23 | now_time = '{:02d}_{:02d}_{:02d}_{:02d}'.format(now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min) 24 | log_dir += cfg.suffix + '(' + now_time + ')' 25 | cfg.log_dir = log_dir 26 | os.makedirs(log_dir) 27 | 28 | shutil.copytree('configs', f'{log_dir}/configs') 29 | shutil.copytree('core', f'{log_dir}/core') -------------------------------------------------------------------------------- /core/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel'): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 13 | self.mode = mode 14 | if mode == 'sintel': 15 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2, 0, 0] 16 | elif mode == "downzero": 17 | self._pad = [0, pad_wd, 0, pad_ht, 0, 0] 18 | else: 19 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht, 0, 0] 20 | 21 | def pad(self, input): 22 | if self.mode == "downzero": 23 | return F.pad(input, self._pad) 24 | else: 25 | return F.pad(input, self._pad, mode='replicate') 26 | 27 | def unpad(self,x): 28 | ht, wd = x.shape[-2:] 29 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 30 | return x[..., c[0]:c[1], c[2]:c[3]] 31 | 32 | def forward_interpolate(flow): 33 | flow = flow.detach().cpu().numpy() 34 | dx, dy = flow[0], flow[1] 35 | 36 | ht, wd = dx.shape 37 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 38 | 39 | x1 = x0 + dx 40 | y1 = y0 + dy 41 | 42 | x1 = x1.reshape(-1) 43 | y1 = y1.reshape(-1) 44 | dx = dx.reshape(-1) 45 | dy = dy.reshape(-1) 46 | 47 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 48 | x1 = x1[valid] 49 | y1 = y1[valid] 50 | dx = dx[valid] 51 | dy = dy[valid] 52 | 53 | flow_x = interpolate.griddata( 54 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 55 | 56 | flow_y = interpolate.griddata( 57 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 58 | 59 | flow = np.stack([flow_x, flow_y], axis=0) 60 | return torch.from_numpy(flow).float() 61 | 62 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 63 | """ Wrapper for grid_sample, uses pixel coordinates """ 64 | H, W = img.shape[-2:] 65 
| xgrid, ygrid = coords.split([1,1], dim=-1) 66 | xgrid = 2*xgrid/(W-1) - 1 67 | ygrid = 2*ygrid/(H-1) - 1 68 | 69 | grid = torch.cat([xgrid, ygrid], dim=-1) 70 | img = F.grid_sample(img, grid, align_corners=True) 71 | 72 | if mask: 73 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 74 | return img, mask.float() 75 | 76 | return img 77 | 78 | def indexing(img, coords, mask=False): 79 | """ Wrapper for grid_sample, uses pixel coordinates """ 80 | """ 81 | TODO: directly indexing features instead of sampling 82 | """ 83 | H, W = img.shape[-2:] 84 | xgrid, ygrid = coords.split([1,1], dim=-1) 85 | xgrid = 2*xgrid/(W-1) - 1 86 | ygrid = 2*ygrid/(H-1) - 1 87 | 88 | grid = torch.cat([xgrid, ygrid], dim=-1) 89 | img = F.grid_sample(img, grid, align_corners=True, mode='nearest') 90 | 91 | if mask: 92 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 93 | return img, mask.float() 94 | 95 | return img 96 | 97 | def coords_grid(batch, ht, wd): 98 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 99 | coords = torch.stack(coords[::-1], dim=0).float() 100 | return coords[None].repeat(batch, 1, 1, 1) 101 | 102 | 103 | def upflow8(flow, mode='bilinear'): 104 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 105 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 106 | -------------------------------------------------------------------------------- /demo_input_images/frame_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaoyuShi97/VideoFlow/51489304db6e75fbdd9ff64d4656c1d120b5a673/demo_input_images/frame_0001.png -------------------------------------------------------------------------------- /demo_input_images/frame_0002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaoyuShi97/VideoFlow/51489304db6e75fbdd9ff64d4656c1d120b5a673/demo_input_images/frame_0002.png -------------------------------------------------------------------------------- /demo_input_images/frame_0003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaoyuShi97/VideoFlow/51489304db6e75fbdd9ff64d4656c1d120b5a673/demo_input_images/frame_0003.png -------------------------------------------------------------------------------- /demo_input_images/frame_0004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaoyuShi97/VideoFlow/51489304db6e75fbdd9ff64d4656c1d120b5a673/demo_input_images/frame_0004.png -------------------------------------------------------------------------------- /demo_input_images/frame_0005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaoyuShi97/VideoFlow/51489304db6e75fbdd9ff64d4656c1d120b5a673/demo_input_images/frame_0005.png -------------------------------------------------------------------------------- /demo_input_images/frame_0006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaoyuShi97/VideoFlow/51489304db6e75fbdd9ff64d4656c1d120b5a673/demo_input_images/frame_0006.png -------------------------------------------------------------------------------- /demo_input_images/frame_0007.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/XiaoyuShi97/VideoFlow/51489304db6e75fbdd9ff64d4656c1d120b5a673/demo_input_images/frame_0007.png -------------------------------------------------------------------------------- /demo_input_images/frame_0008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaoyuShi97/VideoFlow/51489304db6e75fbdd9ff64d4656c1d120b5a673/demo_input_images/frame_0008.png -------------------------------------------------------------------------------- /demo_input_images/frame_0009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaoyuShi97/VideoFlow/51489304db6e75fbdd9ff64d4656c1d120b5a673/demo_input_images/frame_0009.png -------------------------------------------------------------------------------- /demo_input_images/frame_0010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaoyuShi97/VideoFlow/51489304db6e75fbdd9ff64d4656c1d120b5a673/demo_input_images/frame_0010.png -------------------------------------------------------------------------------- /evaluate_BOFNet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | 4 | from PIL import Image 5 | import argparse 6 | import os 7 | import time 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | import matplotlib.pyplot as plt 12 | from configs.sintel_submission import get_cfg 13 | from core.utils.misc import process_cfg 14 | from utils import flow_viz 15 | import core.datasets_3frames as datasets 16 | from core import datasets_multiframes 17 | 18 | from core.Networks import build_network 19 | 20 | from utils import frame_utils 21 | from utils.utils import InputPadder, forward_interpolate 22 | import itertools 23 | 24 | 25 | @torch.no_grad() 26 | def create_sintel_submission(model, output_path='output'): 27 | """ Create submission for the Sintel leaderboard """ 28 | print("no warm start") 29 | results = {} 30 | model.eval() 31 | 32 | for dstype in ['final', 'clean']: 33 | test_dataset = datasets.MpiSintel_submission(split='test', aug_params=None, dstype=dstype, root="Sintel-test", reverse_rate=-1) 34 | 35 | for test_id in range(len(test_dataset)): 36 | if (test_id+1) % 100 == 0: 37 | print(f"{test_id} / {len(test_dataset)}") 38 | 39 | images, (sequence, frame) = test_dataset[test_id] 40 | images = images[None].cuda() 41 | 42 | padder = InputPadder(images.shape) 43 | images = padder.pad(images) 44 | 45 | flow_pre, _ = model(images, {}) 46 | 47 | flow = padder.unpad(flow_pre[0][0]).permute(1, 2, 0).cpu().numpy() 48 | 49 | # flow_img = flow_viz.flow_to_image(flow) 50 | # image = Image.fromarray(flow_img) 51 | # if not os.path.exists(f'vis_sintel_3frames_f'): 52 | # os.makedirs(f'vis_sintel_3frames_f/flow') 53 | 54 | # image.save(f'vis_sintel_3frames_f/flow/{sequence}_{frame}_forward.png') 55 | 56 | output_dir = os.path.join(output_path, dstype, sequence) 57 | output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1)) 58 | 59 | if not os.path.exists(output_dir): 60 | os.makedirs(output_dir) 61 | 62 | frame_utils.writeFlow(output_file, flow) 63 | 64 | return results 65 | 66 | @torch.no_grad() 67 | def validate_sintel(model): 68 | """ Peform validation using the Sintel (train) split """ 69 | 70 | model.eval() 71 | results = {} 72 | 73 | records = [] 74 | 75 | boundary_index = [0, 19, 68, 117, 166, 215, 264, 
313, 352, 401, 421, 470, 519, 568, 617, 666, 715, 764, 813, 862, 911, 943, 992] 76 | 77 | for dstype in ['final', "clean"]: 78 | val_dataset = datasets.MpiSintel(split='training', dstype=dstype, reverse_rate=-1) 79 | 80 | epe_list = [] 81 | epe_list_no_boundary = [] 82 | 83 | for val_id in range(len(val_dataset)): 84 | if val_id % 50 == 0: 85 | print(val_id) 86 | 87 | images, flows, valids = val_dataset[val_id] 88 | 89 | images = images[None].cuda() 90 | 91 | padder = InputPadder(images.shape) 92 | images = padder.pad(images) 93 | 94 | flow_pre, _ = model(images, {}) 95 | 96 | flow = padder.unpad(flow_pre[0][0]).cpu() 97 | flow_gt = flows[0] 98 | 99 | epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt() 100 | 101 | records.append("{}\n".format(torch.mean(epe))) 102 | 103 | epe_list.append(epe.view(-1).numpy()) 104 | 105 | if val_id not in boundary_index: 106 | epe_list_no_boundary.append(epe.view(-1).numpy()) 107 | else: 108 | print("skip~", val_id) 109 | 110 | epe_all = np.concatenate(epe_list) 111 | epe = np.mean(epe_all) 112 | px1 = np.mean(epe_all<1) 113 | px3 = np.mean(epe_all<3) 114 | px5 = np.mean(epe_all<5) 115 | 116 | print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5)) 117 | results[dstype] = np.mean(epe_list) 118 | 119 | epe_all = np.concatenate(epe_list_no_boundary) 120 | epe = np.mean(epe_all) 121 | px1 = np.mean(epe_all<1) 122 | px3 = np.mean(epe_all<3) 123 | px5 = np.mean(epe_all<5) 124 | 125 | print("Validation (%s) no boundary EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5)) 126 | 127 | return results 128 | 129 | @torch.no_grad() 130 | def validate_things(model): 131 | """ Peform validation using the Sintel (train) split """ 132 | 133 | model.eval() 134 | results = {} 135 | 136 | for dstype in ['frames_cleanpass', "frames_finalpass"]: 137 | val_dataset = datasets_multiframes.ThingsTEST(dstype=dstype, input_frames=3, return_gt=True) 138 | 139 | epe_list = [] 140 | epe_list_no_boundary = [] 141 | 142 | records = [] 143 | import pickle 144 | 145 | for val_id in range(len(val_dataset)): 146 | if val_id % 50 == 0: 147 | print(val_id) 148 | 149 | images, flows, extra_info = val_dataset[val_id] 150 | 151 | images = images[None].cuda() 152 | # images = torch.flip(images, dims=[1]) 153 | 154 | padder = InputPadder(images.shape) 155 | images = padder.pad(images) 156 | 157 | flow_pre, _ = model(images, {}) 158 | 159 | flow_pre = padder.unpad(flow_pre[0]).cpu() 160 | 161 | flow_pre = flow_pre[:flow_pre.shape[0]//2, ...][-flows.shape[0]:, ...] 162 | # flow_pre = flow_pre[1:, ...] 
163 | 164 | epe = torch.sum((flow_pre - flows)**2, dim=1).sqrt() 165 | valid = torch.sum(flows**2, dim=1).sqrt() < 400 166 | this_error = epe.view(-1)[valid.view(-1)].mean().item() 167 | #records.append(this_error) 168 | epe_list.append(epe.view(-1)[valid.view(-1)].numpy()) 169 | 170 | records.append(extra_info) 171 | 172 | flow_pre = flow_pre[0].permute(1, 2, 0).numpy() 173 | flow_gt = flows[0].permute(1, 2, 0).numpy() 174 | 175 | epe_all = np.concatenate(epe_list) 176 | epe = np.mean(epe_all) 177 | px1 = np.mean(epe_all<1) 178 | px3 = np.mean(epe_all<3) 179 | px5 = np.mean(epe_all<5) 180 | 181 | print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5)) 182 | results[dstype] = epe 183 | return results 184 | 185 | @torch.no_grad() 186 | def validate_kitti(model): 187 | """ Peform validation using the Sintel (train) split """ 188 | 189 | model.eval() 190 | results = {} 191 | 192 | val_dataset = datasets_multiframes.KITTITest(input_frames=3, aug_params=None, reverse_rate=0) 193 | 194 | epe_list = [] 195 | out_list = [] 196 | 197 | for val_id in range(len(val_dataset)): 198 | if val_id % 50 == 0: 199 | print(val_id) 200 | 201 | images, flows, valids = val_dataset[val_id] 202 | 203 | images = images[None].cuda() 204 | 205 | padder = InputPadder(images.shape) 206 | images = padder.pad(images) 207 | 208 | flow_pre, _ = model(images, {}) 209 | 210 | flow_pre = padder.unpad(flow_pre[0]).cpu() 211 | 212 | flow_pre = flow_pre[0] 213 | valids = valids[0] 214 | flows = flows[0] 215 | 216 | epe = torch.sum((flow_pre - flows)**2, dim=0).sqrt() 217 | mag = torch.sum(flows**2, dim=0).sqrt() 218 | 219 | epe = epe.view(-1) 220 | mag = mag.view(-1) 221 | val = valids.view(-1) >= 0.5 222 | 223 | out = ((epe > 3.0) & ((epe/mag) > 0.05)).float() 224 | epe_list.append(epe[val].mean().item()) 225 | out_list.append(out[val].cpu().numpy()) 226 | 227 | epe_list = np.array(epe_list) 228 | out_list = np.concatenate(out_list) 229 | 230 | epe = np.mean(epe_list) 231 | f1 = 100 * np.mean(out_list) 232 | 233 | print("Validation KITTI: %f, %f" % (epe, f1)) 234 | return 235 | 236 | @torch.no_grad() 237 | def create_kitti_submission(model, output_path): 238 | """ Peform validation using the Sintel (train) split """ 239 | 240 | model.eval() 241 | results = {} 242 | 243 | val_dataset = datasets_multiframes.KITTISubmission(input_frames=3, return_gt=False) 244 | 245 | epe_list = [] 246 | out_list = [] 247 | 248 | if not os.path.exists(output_path): 249 | os.makedirs(output_path) 250 | 251 | for val_id in range(len(val_dataset)): 252 | 253 | images, frame_id = val_dataset[val_id] 254 | 255 | print(frame_id, images.shape) 256 | 257 | images = images[None].cuda() 258 | 259 | padder = InputPadder(images.shape) 260 | images = padder.pad(images) 261 | 262 | flow_pre, _ = model(images, {}) 263 | 264 | flow_pre = padder.unpad(flow_pre[0]).cpu() 265 | 266 | flow_pre = flow_pre[0].permute(1, 2, 0).numpy() 267 | 268 | output_filename = os.path.join(output_path, frame_id) 269 | frame_utils.writeFlowKITTI(output_filename, flow_pre) 270 | 271 | flow_img = flow_viz.flow_to_image(flow_pre) 272 | image = Image.fromarray(flow_img) 273 | 274 | if not os.path.exists(f'vis_kitti'): 275 | os.makedirs(f'vis_kitti/flow') 276 | 277 | image.save(f'vis_kitti/flow/{frame_id}.png') 278 | 279 | return 280 | 281 | def count_parameters(model): 282 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 283 | 284 | if __name__ == '__main__': 285 | parser = argparse.ArgumentParser() 286 | 
#parser.add_argument('--model', help="restore checkpoint") 287 | parser.add_argument('--dataset', help="dataset for evaluation") 288 | parser.add_argument('--small', action='store_true', help='use small model') 289 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 290 | parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation') 291 | args = parser.parse_args() 292 | cfg = get_cfg() 293 | cfg.update(vars(args)) 294 | 295 | model = torch.nn.DataParallel(build_network(cfg)) 296 | 297 | if cfg.model is not None: 298 | model.load_state_dict(torch.load(cfg.model)) 299 | else: 300 | print("[Not loading pretrained checkpoint]") 301 | 302 | model.cuda() 303 | model.eval() 304 | 305 | print(cfg.model) 306 | print("Parameter Count: %d" % count_parameters(model)) 307 | print(args.dataset) 308 | with torch.no_grad(): 309 | if args.dataset == 'sintel': 310 | validate_sintel(model.module) 311 | elif args.dataset == 'things': 312 | validate_things(model.module) 313 | elif args.dataset == 'kitti': 314 | validate_kitti(model.module) 315 | elif args.dataset == 'kitti_submission': 316 | create_kitti_submission(model.module, output_path="flow") 317 | elif args.dataset == 'sintel_submission': 318 | create_sintel_submission(model.module, output_path="output") 319 | 320 | 321 | print(cfg.model) 322 | 323 | 324 | -------------------------------------------------------------------------------- /flow_dataset_mf/convert_HD1K.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os.path as osp 3 | from glob import glob 4 | import os 5 | import pickle 6 | 7 | root = "/mnt/lustre/share/cp/caodongliang/HD1K/" 8 | 9 | image_list = [] 10 | flow_list = [] 11 | 12 | seq_ix = 0 13 | 14 | while 1: 15 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) 16 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) 17 | 18 | if len(flows) == 0: 19 | break 20 | 21 | print(seq_ix, len(flows), images[0], images[-1], "!!!!!!!!!!!!!!") 22 | 23 | for idx in range(len(images)): 24 | images[idx] = images[idx].replace("/mnt/lustre/share/cp/caodongliang/HD1K", "HD1K") 25 | for idx in range(len(flows)): 26 | flows[idx] = flows[idx].replace("/mnt/lustre/share/cp/caodongliang/HD1K", "HD1K") 27 | 28 | seq_ix += 1 29 | 30 | image_list.append(images) 31 | flow_list.append(flows) 32 | 33 | with open("hd1k_png.pkl", 'wb') as f: 34 | pickle.dump(image_list, f) 35 | with open("hd1k_flo.pkl", 'wb') as f: 36 | pickle.dump(flow_list, f) -------------------------------------------------------------------------------- /flow_dataset_mf/convert_sintel.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os.path as osp 3 | from glob import glob 4 | import os 5 | 6 | import pickle 7 | 8 | root = "/mnt/lustre/share/cp/caodongliang/MPI-Sintel/" 9 | 10 | for split in ['training']: 11 | for dstype in ['clean', 'final']: 12 | image_list = [] 13 | flow_list = [] 14 | extra_info_list = [] 15 | 16 | flow_root = osp.join(root, split, 'flow') 17 | image_root = osp.join(root, split, dstype) 18 | 19 | for scene in os.listdir(image_root): 20 | images = sorted(glob(osp.join(image_root, scene, '*.png'))) 21 | flows = sorted(glob(osp.join(flow_root, scene, '*.flo'))) 22 | 23 | for idx in range(len(images)): 24 | images[idx] = images[idx].replace("/mnt/lustre/share/cp/caodongliang/MPI-Sintel", "Sintel") + 
"\n" 25 | for idx in range(len(flows)): 26 | flows[idx] = flows[idx].replace("/mnt/lustre/share/cp/caodongliang/MPI-Sintel", "Sintel") + "\n" 27 | 28 | image_list.append(images) 29 | flow_list.append(flows) 30 | extra_info_list.append(scene) 31 | 32 | with open("sintel_training_"+dstype+"_png.pkl", 'wb') as f: 33 | pickle.dump(image_list, f) 34 | with open("sintel_training_"+dstype+"_flo.pkl", 'wb') as f: 35 | pickle.dump(flow_list, f) 36 | with open("sintel_training_scene.pkl", 'wb') as f: 37 | pickle.dump(extra_info_list, f) 38 | 39 | 40 | -------------------------------------------------------------------------------- /flow_dataset_mf/convert_things.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os.path as osp 3 | from glob import glob 4 | import pickle 5 | 6 | root = "/mnt/lustre/share/cp/caodongliang/FlyingThings3D/" 7 | 8 | for dstype in ['frames_cleanpass', 'frames_finalpass']: 9 | image_list = [] 10 | fflow_list = [] 11 | pflow_list = [] 12 | 13 | 14 | for cam in ['left']: 15 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) 16 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) 17 | 18 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) 19 | flow_future_dirs = sorted([osp.join(f, 'into_future', cam) for f in flow_dirs]) 20 | flow_past_dirs = sorted([osp.join(f, 'into_past', cam) for f in flow_dirs]) 21 | 22 | for idir, fdir, pdir in zip(image_dirs, flow_future_dirs, flow_past_dirs): 23 | images = sorted(glob(osp.join(idir, '*.png')) ) 24 | future_flows = sorted(glob(osp.join(fdir, '*.pfm')) ) 25 | past_flows = sorted(glob(osp.join(pdir, '*.pfm')) ) 26 | 27 | for idx in range(len(images)): 28 | images[idx] = images[idx].replace("/mnt/lustre/share/cp/caodongliang/FlyingThings3D", "flow_data") + "\n" 29 | for idx in range(len(future_flows)): 30 | future_flows[idx] = future_flows[idx].replace("/mnt/lustre/share/cp/caodongliang/FlyingThings3D", "flow_data") + "\n" 31 | for idx in range(len(past_flows)): 32 | past_flows[idx] = past_flows[idx].replace("/mnt/lustre/share/cp/caodongliang/FlyingThings3D", "flow_data") + "\n" 33 | 34 | image_list.append(images) 35 | fflow_list.append(future_flows) 36 | pflow_list.append(past_flows) 37 | 38 | with open("flyingthings_"+dstype+"_png.pkl", 'wb') as f: 39 | pickle.dump(image_list, f) 40 | with open("flyingthings_"+dstype+"_future_pfm.pkl", 'wb') as f: 41 | pickle.dump(fflow_list, f) 42 | with open("flyingthings_"+dstype+"_past_pfm.pkl", 'wb') as f: 43 | pickle.dump(pflow_list, f) 44 | -------------------------------------------------------------------------------- /flow_dataset_mf/flyingthings_thres5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaoyuShi97/VideoFlow/51489304db6e75fbdd9ff64d4656c1d120b5a673/flow_dataset_mf/flyingthings_thres5.pkl -------------------------------------------------------------------------------- /flow_dataset_mf/sintel_training_scene.pkl: -------------------------------------------------------------------------------- 1 | (lp0 2 | S'ambush_6' 3 | p1 4 | aS'bandage_1' 5 | p2 6 | aS'market_2' 7 | p3 8 | aS'cave_2' 9 | p4 10 | aS'temple_2' 11 | p5 12 | aS'shaman_3' 13 | p6 14 | aS'sleeping_2' 15 | p7 16 | aS'market_6' 17 | p8 18 | aS'cave_4' 19 | p9 20 | aS'ambush_2' 21 | p10 22 | aS'market_5' 23 | p11 24 | aS'alley_2' 25 | p12 26 | aS'ambush_5' 27 | p13 28 | aS'ambush_7' 29 | p14 30 | aS'temple_3' 31 | p15 32 | aS'bandage_2' 33 | p16 34 | 
aS'mountain_1' 35 | p17 36 | aS'alley_1' 37 | p18 38 | aS'shaman_2' 39 | p19 40 | aS'bamboo_1' 41 | p20 42 | aS'ambush_4' 43 | p21 44 | aS'bamboo_2' 45 | p22 46 | aS'sleeping_1' 47 | p23 48 | a. -------------------------------------------------------------------------------- /flow_datasets/KITTI/KITTI_testing_extra_info.txt: -------------------------------------------------------------------------------- 1 | 000000_10.png 2 | 000001_10.png 3 | 000002_10.png 4 | 000003_10.png 5 | 000004_10.png 6 | 000005_10.png 7 | 000006_10.png 8 | 000007_10.png 9 | 000008_10.png 10 | 000009_10.png 11 | 000010_10.png 12 | 000011_10.png 13 | 000012_10.png 14 | 000013_10.png 15 | 000014_10.png 16 | 000015_10.png 17 | 000016_10.png 18 | 000017_10.png 19 | 000018_10.png 20 | 000019_10.png 21 | 000020_10.png 22 | 000021_10.png 23 | 000022_10.png 24 | 000023_10.png 25 | 000024_10.png 26 | 000025_10.png 27 | 000026_10.png 28 | 000027_10.png 29 | 000028_10.png 30 | 000029_10.png 31 | 000030_10.png 32 | 000031_10.png 33 | 000032_10.png 34 | 000033_10.png 35 | 000034_10.png 36 | 000035_10.png 37 | 000036_10.png 38 | 000037_10.png 39 | 000038_10.png 40 | 000039_10.png 41 | 000040_10.png 42 | 000041_10.png 43 | 000042_10.png 44 | 000043_10.png 45 | 000044_10.png 46 | 000045_10.png 47 | 000046_10.png 48 | 000047_10.png 49 | 000048_10.png 50 | 000049_10.png 51 | 000050_10.png 52 | 000051_10.png 53 | 000052_10.png 54 | 000053_10.png 55 | 000054_10.png 56 | 000055_10.png 57 | 000056_10.png 58 | 000057_10.png 59 | 000058_10.png 60 | 000059_10.png 61 | 000060_10.png 62 | 000061_10.png 63 | 000062_10.png 64 | 000063_10.png 65 | 000064_10.png 66 | 000065_10.png 67 | 000066_10.png 68 | 000067_10.png 69 | 000068_10.png 70 | 000069_10.png 71 | 000070_10.png 72 | 000071_10.png 73 | 000072_10.png 74 | 000073_10.png 75 | 000074_10.png 76 | 000075_10.png 77 | 000076_10.png 78 | 000077_10.png 79 | 000078_10.png 80 | 000079_10.png 81 | 000080_10.png 82 | 000081_10.png 83 | 000082_10.png 84 | 000083_10.png 85 | 000084_10.png 86 | 000085_10.png 87 | 000086_10.png 88 | 000087_10.png 89 | 000088_10.png 90 | 000089_10.png 91 | 000090_10.png 92 | 000091_10.png 93 | 000092_10.png 94 | 000093_10.png 95 | 000094_10.png 96 | 000095_10.png 97 | 000096_10.png 98 | 000097_10.png 99 | 000098_10.png 100 | 000099_10.png 101 | 000100_10.png 102 | 000101_10.png 103 | 000102_10.png 104 | 000103_10.png 105 | 000104_10.png 106 | 000105_10.png 107 | 000106_10.png 108 | 000107_10.png 109 | 000108_10.png 110 | 000109_10.png 111 | 000110_10.png 112 | 000111_10.png 113 | 000112_10.png 114 | 000113_10.png 115 | 000114_10.png 116 | 000115_10.png 117 | 000116_10.png 118 | 000117_10.png 119 | 000118_10.png 120 | 000119_10.png 121 | 000120_10.png 122 | 000121_10.png 123 | 000122_10.png 124 | 000123_10.png 125 | 000124_10.png 126 | 000125_10.png 127 | 000126_10.png 128 | 000127_10.png 129 | 000128_10.png 130 | 000129_10.png 131 | 000130_10.png 132 | 000131_10.png 133 | 000132_10.png 134 | 000133_10.png 135 | 000134_10.png 136 | 000135_10.png 137 | 000136_10.png 138 | 000137_10.png 139 | 000138_10.png 140 | 000139_10.png 141 | 000140_10.png 142 | 000141_10.png 143 | 000142_10.png 144 | 000143_10.png 145 | 000144_10.png 146 | 000145_10.png 147 | 000146_10.png 148 | 000147_10.png 149 | 000148_10.png 150 | 000149_10.png 151 | 000150_10.png 152 | 000151_10.png 153 | 000152_10.png 154 | 000153_10.png 155 | 000154_10.png 156 | 000155_10.png 157 | 000156_10.png 158 | 000157_10.png 159 | 000158_10.png 160 | 000159_10.png 161 | 000160_10.png 162 | 000161_10.png 163 | 
000162_10.png 164 | 000163_10.png 165 | 000164_10.png 166 | 000165_10.png 167 | 000166_10.png 168 | 000167_10.png 169 | 000168_10.png 170 | 000169_10.png 171 | 000170_10.png 172 | 000171_10.png 173 | 000172_10.png 174 | 000173_10.png 175 | 000174_10.png 176 | 000175_10.png 177 | 000176_10.png 178 | 000177_10.png 179 | 000178_10.png 180 | 000179_10.png 181 | 000180_10.png 182 | 000181_10.png 183 | 000182_10.png 184 | 000183_10.png 185 | 000184_10.png 186 | 000185_10.png 187 | 000186_10.png 188 | 000187_10.png 189 | 000188_10.png 190 | 000189_10.png 191 | 000190_10.png 192 | 000191_10.png 193 | 000192_10.png 194 | 000193_10.png 195 | 000194_10.png 196 | 000195_10.png 197 | 000196_10.png 198 | 000197_10.png 199 | 000198_10.png 200 | 000199_10.png 201 | -------------------------------------------------------------------------------- /flow_datasets/KITTI/KITTI_training_extra_info.txt: -------------------------------------------------------------------------------- 1 | 000000_10.png 2 | 000001_10.png 3 | 000002_10.png 4 | 000003_10.png 5 | 000004_10.png 6 | 000005_10.png 7 | 000006_10.png 8 | 000007_10.png 9 | 000008_10.png 10 | 000009_10.png 11 | 000010_10.png 12 | 000011_10.png 13 | 000012_10.png 14 | 000013_10.png 15 | 000014_10.png 16 | 000015_10.png 17 | 000016_10.png 18 | 000017_10.png 19 | 000018_10.png 20 | 000019_10.png 21 | 000020_10.png 22 | 000021_10.png 23 | 000022_10.png 24 | 000023_10.png 25 | 000024_10.png 26 | 000025_10.png 27 | 000026_10.png 28 | 000027_10.png 29 | 000028_10.png 30 | 000029_10.png 31 | 000030_10.png 32 | 000031_10.png 33 | 000032_10.png 34 | 000033_10.png 35 | 000034_10.png 36 | 000035_10.png 37 | 000036_10.png 38 | 000037_10.png 39 | 000038_10.png 40 | 000039_10.png 41 | 000040_10.png 42 | 000041_10.png 43 | 000042_10.png 44 | 000043_10.png 45 | 000044_10.png 46 | 000045_10.png 47 | 000046_10.png 48 | 000047_10.png 49 | 000048_10.png 50 | 000049_10.png 51 | 000050_10.png 52 | 000051_10.png 53 | 000052_10.png 54 | 000053_10.png 55 | 000054_10.png 56 | 000055_10.png 57 | 000056_10.png 58 | 000057_10.png 59 | 000058_10.png 60 | 000059_10.png 61 | 000060_10.png 62 | 000061_10.png 63 | 000062_10.png 64 | 000063_10.png 65 | 000064_10.png 66 | 000065_10.png 67 | 000066_10.png 68 | 000067_10.png 69 | 000068_10.png 70 | 000069_10.png 71 | 000070_10.png 72 | 000071_10.png 73 | 000072_10.png 74 | 000073_10.png 75 | 000074_10.png 76 | 000075_10.png 77 | 000076_10.png 78 | 000077_10.png 79 | 000078_10.png 80 | 000079_10.png 81 | 000080_10.png 82 | 000081_10.png 83 | 000082_10.png 84 | 000083_10.png 85 | 000084_10.png 86 | 000085_10.png 87 | 000086_10.png 88 | 000087_10.png 89 | 000088_10.png 90 | 000089_10.png 91 | 000090_10.png 92 | 000091_10.png 93 | 000092_10.png 94 | 000093_10.png 95 | 000094_10.png 96 | 000095_10.png 97 | 000096_10.png 98 | 000097_10.png 99 | 000098_10.png 100 | 000099_10.png 101 | 000100_10.png 102 | 000101_10.png 103 | 000102_10.png 104 | 000103_10.png 105 | 000104_10.png 106 | 000105_10.png 107 | 000106_10.png 108 | 000107_10.png 109 | 000108_10.png 110 | 000109_10.png 111 | 000110_10.png 112 | 000111_10.png 113 | 000112_10.png 114 | 000113_10.png 115 | 000114_10.png 116 | 000115_10.png 117 | 000116_10.png 118 | 000117_10.png 119 | 000118_10.png 120 | 000119_10.png 121 | 000120_10.png 122 | 000121_10.png 123 | 000122_10.png 124 | 000123_10.png 125 | 000124_10.png 126 | 000125_10.png 127 | 000126_10.png 128 | 000127_10.png 129 | 000128_10.png 130 | 000129_10.png 131 | 000130_10.png 132 | 000131_10.png 133 | 000132_10.png 134 | 000133_10.png 135 | 
000134_10.png 136 | 000135_10.png 137 | 000136_10.png 138 | 000137_10.png 139 | 000138_10.png 140 | 000139_10.png 141 | 000140_10.png 142 | 000141_10.png 143 | 000142_10.png 144 | 000143_10.png 145 | 000144_10.png 146 | 000145_10.png 147 | 000146_10.png 148 | 000147_10.png 149 | 000148_10.png 150 | 000149_10.png 151 | 000150_10.png 152 | 000151_10.png 153 | 000152_10.png 154 | 000153_10.png 155 | 000154_10.png 156 | 000155_10.png 157 | 000156_10.png 158 | 000157_10.png 159 | 000158_10.png 160 | 000159_10.png 161 | 000160_10.png 162 | 000161_10.png 163 | 000162_10.png 164 | 000163_10.png 165 | 000164_10.png 166 | 000165_10.png 167 | 000166_10.png 168 | 000167_10.png 169 | 000168_10.png 170 | 000169_10.png 171 | 000170_10.png 172 | 000171_10.png 173 | 000172_10.png 174 | 000173_10.png 175 | 000174_10.png 176 | 000175_10.png 177 | 000176_10.png 178 | 000177_10.png 179 | 000178_10.png 180 | 000179_10.png 181 | 000180_10.png 182 | 000181_10.png 183 | 000182_10.png 184 | 000183_10.png 185 | 000184_10.png 186 | 000185_10.png 187 | 000186_10.png 188 | 000187_10.png 189 | 000188_10.png 190 | 000189_10.png 191 | 000190_10.png 192 | 000191_10.png 193 | 000192_10.png 194 | 000193_10.png 195 | 000194_10.png 196 | 000195_10.png 197 | 000196_10.png 198 | 000197_10.png 199 | 000198_10.png 200 | 000199_10.png 201 | -------------------------------------------------------------------------------- /flow_datasets/KITTI/KITTI_training_flow.txt: -------------------------------------------------------------------------------- 1 | KITTI/training/flow_occ/000000_10.png 2 | KITTI/training/flow_occ/000001_10.png 3 | KITTI/training/flow_occ/000002_10.png 4 | KITTI/training/flow_occ/000003_10.png 5 | KITTI/training/flow_occ/000004_10.png 6 | KITTI/training/flow_occ/000005_10.png 7 | KITTI/training/flow_occ/000006_10.png 8 | KITTI/training/flow_occ/000007_10.png 9 | KITTI/training/flow_occ/000008_10.png 10 | KITTI/training/flow_occ/000009_10.png 11 | KITTI/training/flow_occ/000010_10.png 12 | KITTI/training/flow_occ/000011_10.png 13 | KITTI/training/flow_occ/000012_10.png 14 | KITTI/training/flow_occ/000013_10.png 15 | KITTI/training/flow_occ/000014_10.png 16 | KITTI/training/flow_occ/000015_10.png 17 | KITTI/training/flow_occ/000016_10.png 18 | KITTI/training/flow_occ/000017_10.png 19 | KITTI/training/flow_occ/000018_10.png 20 | KITTI/training/flow_occ/000019_10.png 21 | KITTI/training/flow_occ/000020_10.png 22 | KITTI/training/flow_occ/000021_10.png 23 | KITTI/training/flow_occ/000022_10.png 24 | KITTI/training/flow_occ/000023_10.png 25 | KITTI/training/flow_occ/000024_10.png 26 | KITTI/training/flow_occ/000025_10.png 27 | KITTI/training/flow_occ/000026_10.png 28 | KITTI/training/flow_occ/000027_10.png 29 | KITTI/training/flow_occ/000028_10.png 30 | KITTI/training/flow_occ/000029_10.png 31 | KITTI/training/flow_occ/000030_10.png 32 | KITTI/training/flow_occ/000031_10.png 33 | KITTI/training/flow_occ/000032_10.png 34 | KITTI/training/flow_occ/000033_10.png 35 | KITTI/training/flow_occ/000034_10.png 36 | KITTI/training/flow_occ/000035_10.png 37 | KITTI/training/flow_occ/000036_10.png 38 | KITTI/training/flow_occ/000037_10.png 39 | KITTI/training/flow_occ/000038_10.png 40 | KITTI/training/flow_occ/000039_10.png 41 | KITTI/training/flow_occ/000040_10.png 42 | KITTI/training/flow_occ/000041_10.png 43 | KITTI/training/flow_occ/000042_10.png 44 | KITTI/training/flow_occ/000043_10.png 45 | KITTI/training/flow_occ/000044_10.png 46 | KITTI/training/flow_occ/000045_10.png 47 | KITTI/training/flow_occ/000046_10.png 48 | 
KITTI/training/flow_occ/000047_10.png 49 | KITTI/training/flow_occ/000048_10.png 50 | KITTI/training/flow_occ/000049_10.png 51 | KITTI/training/flow_occ/000050_10.png 52 | KITTI/training/flow_occ/000051_10.png 53 | KITTI/training/flow_occ/000052_10.png 54 | KITTI/training/flow_occ/000053_10.png 55 | KITTI/training/flow_occ/000054_10.png 56 | KITTI/training/flow_occ/000055_10.png 57 | KITTI/training/flow_occ/000056_10.png 58 | KITTI/training/flow_occ/000057_10.png 59 | KITTI/training/flow_occ/000058_10.png 60 | KITTI/training/flow_occ/000059_10.png 61 | KITTI/training/flow_occ/000060_10.png 62 | KITTI/training/flow_occ/000061_10.png 63 | KITTI/training/flow_occ/000062_10.png 64 | KITTI/training/flow_occ/000063_10.png 65 | KITTI/training/flow_occ/000064_10.png 66 | KITTI/training/flow_occ/000065_10.png 67 | KITTI/training/flow_occ/000066_10.png 68 | KITTI/training/flow_occ/000067_10.png 69 | KITTI/training/flow_occ/000068_10.png 70 | KITTI/training/flow_occ/000069_10.png 71 | KITTI/training/flow_occ/000070_10.png 72 | KITTI/training/flow_occ/000071_10.png 73 | KITTI/training/flow_occ/000072_10.png 74 | KITTI/training/flow_occ/000073_10.png 75 | KITTI/training/flow_occ/000074_10.png 76 | KITTI/training/flow_occ/000075_10.png 77 | KITTI/training/flow_occ/000076_10.png 78 | KITTI/training/flow_occ/000077_10.png 79 | KITTI/training/flow_occ/000078_10.png 80 | KITTI/training/flow_occ/000079_10.png 81 | KITTI/training/flow_occ/000080_10.png 82 | KITTI/training/flow_occ/000081_10.png 83 | KITTI/training/flow_occ/000082_10.png 84 | KITTI/training/flow_occ/000083_10.png 85 | KITTI/training/flow_occ/000084_10.png 86 | KITTI/training/flow_occ/000085_10.png 87 | KITTI/training/flow_occ/000086_10.png 88 | KITTI/training/flow_occ/000087_10.png 89 | KITTI/training/flow_occ/000088_10.png 90 | KITTI/training/flow_occ/000089_10.png 91 | KITTI/training/flow_occ/000090_10.png 92 | KITTI/training/flow_occ/000091_10.png 93 | KITTI/training/flow_occ/000092_10.png 94 | KITTI/training/flow_occ/000093_10.png 95 | KITTI/training/flow_occ/000094_10.png 96 | KITTI/training/flow_occ/000095_10.png 97 | KITTI/training/flow_occ/000096_10.png 98 | KITTI/training/flow_occ/000097_10.png 99 | KITTI/training/flow_occ/000098_10.png 100 | KITTI/training/flow_occ/000099_10.png 101 | KITTI/training/flow_occ/000100_10.png 102 | KITTI/training/flow_occ/000101_10.png 103 | KITTI/training/flow_occ/000102_10.png 104 | KITTI/training/flow_occ/000103_10.png 105 | KITTI/training/flow_occ/000104_10.png 106 | KITTI/training/flow_occ/000105_10.png 107 | KITTI/training/flow_occ/000106_10.png 108 | KITTI/training/flow_occ/000107_10.png 109 | KITTI/training/flow_occ/000108_10.png 110 | KITTI/training/flow_occ/000109_10.png 111 | KITTI/training/flow_occ/000110_10.png 112 | KITTI/training/flow_occ/000111_10.png 113 | KITTI/training/flow_occ/000112_10.png 114 | KITTI/training/flow_occ/000113_10.png 115 | KITTI/training/flow_occ/000114_10.png 116 | KITTI/training/flow_occ/000115_10.png 117 | KITTI/training/flow_occ/000116_10.png 118 | KITTI/training/flow_occ/000117_10.png 119 | KITTI/training/flow_occ/000118_10.png 120 | KITTI/training/flow_occ/000119_10.png 121 | KITTI/training/flow_occ/000120_10.png 122 | KITTI/training/flow_occ/000121_10.png 123 | KITTI/training/flow_occ/000122_10.png 124 | KITTI/training/flow_occ/000123_10.png 125 | KITTI/training/flow_occ/000124_10.png 126 | KITTI/training/flow_occ/000125_10.png 127 | KITTI/training/flow_occ/000126_10.png 128 | KITTI/training/flow_occ/000127_10.png 129 | KITTI/training/flow_occ/000128_10.png 130 
| KITTI/training/flow_occ/000129_10.png 131 | KITTI/training/flow_occ/000130_10.png 132 | KITTI/training/flow_occ/000131_10.png 133 | KITTI/training/flow_occ/000132_10.png 134 | KITTI/training/flow_occ/000133_10.png 135 | KITTI/training/flow_occ/000134_10.png 136 | KITTI/training/flow_occ/000135_10.png 137 | KITTI/training/flow_occ/000136_10.png 138 | KITTI/training/flow_occ/000137_10.png 139 | KITTI/training/flow_occ/000138_10.png 140 | KITTI/training/flow_occ/000139_10.png 141 | KITTI/training/flow_occ/000140_10.png 142 | KITTI/training/flow_occ/000141_10.png 143 | KITTI/training/flow_occ/000142_10.png 144 | KITTI/training/flow_occ/000143_10.png 145 | KITTI/training/flow_occ/000144_10.png 146 | KITTI/training/flow_occ/000145_10.png 147 | KITTI/training/flow_occ/000146_10.png 148 | KITTI/training/flow_occ/000147_10.png 149 | KITTI/training/flow_occ/000148_10.png 150 | KITTI/training/flow_occ/000149_10.png 151 | KITTI/training/flow_occ/000150_10.png 152 | KITTI/training/flow_occ/000151_10.png 153 | KITTI/training/flow_occ/000152_10.png 154 | KITTI/training/flow_occ/000153_10.png 155 | KITTI/training/flow_occ/000154_10.png 156 | KITTI/training/flow_occ/000155_10.png 157 | KITTI/training/flow_occ/000156_10.png 158 | KITTI/training/flow_occ/000157_10.png 159 | KITTI/training/flow_occ/000158_10.png 160 | KITTI/training/flow_occ/000159_10.png 161 | KITTI/training/flow_occ/000160_10.png 162 | KITTI/training/flow_occ/000161_10.png 163 | KITTI/training/flow_occ/000162_10.png 164 | KITTI/training/flow_occ/000163_10.png 165 | KITTI/training/flow_occ/000164_10.png 166 | KITTI/training/flow_occ/000165_10.png 167 | KITTI/training/flow_occ/000166_10.png 168 | KITTI/training/flow_occ/000167_10.png 169 | KITTI/training/flow_occ/000168_10.png 170 | KITTI/training/flow_occ/000169_10.png 171 | KITTI/training/flow_occ/000170_10.png 172 | KITTI/training/flow_occ/000171_10.png 173 | KITTI/training/flow_occ/000172_10.png 174 | KITTI/training/flow_occ/000173_10.png 175 | KITTI/training/flow_occ/000174_10.png 176 | KITTI/training/flow_occ/000175_10.png 177 | KITTI/training/flow_occ/000176_10.png 178 | KITTI/training/flow_occ/000177_10.png 179 | KITTI/training/flow_occ/000178_10.png 180 | KITTI/training/flow_occ/000179_10.png 181 | KITTI/training/flow_occ/000180_10.png 182 | KITTI/training/flow_occ/000181_10.png 183 | KITTI/training/flow_occ/000182_10.png 184 | KITTI/training/flow_occ/000183_10.png 185 | KITTI/training/flow_occ/000184_10.png 186 | KITTI/training/flow_occ/000185_10.png 187 | KITTI/training/flow_occ/000186_10.png 188 | KITTI/training/flow_occ/000187_10.png 189 | KITTI/training/flow_occ/000188_10.png 190 | KITTI/training/flow_occ/000189_10.png 191 | KITTI/training/flow_occ/000190_10.png 192 | KITTI/training/flow_occ/000191_10.png 193 | KITTI/training/flow_occ/000192_10.png 194 | KITTI/training/flow_occ/000193_10.png 195 | KITTI/training/flow_occ/000194_10.png 196 | KITTI/training/flow_occ/000195_10.png 197 | KITTI/training/flow_occ/000196_10.png 198 | KITTI/training/flow_occ/000197_10.png 199 | KITTI/training/flow_occ/000198_10.png 200 | KITTI/training/flow_occ/000199_10.png 201 | -------------------------------------------------------------------------------- /flow_datasets/KITTI/generate_KITTI_list.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | from glob import glob 5 | import os.path as osp 6 | 7 | split = "testing" 8 | root = "KITTI" 9 | 10 | 11 | root = osp.join(root, split) 12 | images1 = 
sorted(glob(osp.join(root, 'image_2/*_10.png'))) 13 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) 14 | 15 | extra_info = [] 16 | flow_list = [] 17 | image_list = [] 18 | 19 | for img1, img2 in zip(images1, images2): 20 | frame_id = img1.split('/')[-1] 21 | extra_info += [ frame_id+"\n" ] 22 | image_list += [ img1+"\n", img2+"\n" ] 23 | 24 | if split == 'training': 25 | _flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) 26 | flow_list = [s+"\n" for s in _flow_list] 27 | 28 | print(len(image_list), len(flow_list), len(extra_info)) 29 | 30 | with open('KITTI_{}_image.txt'.format(split), 'w') as f: 31 | f.writelines(image_list) 32 | 33 | with open('KITTI_{}_flow.txt'.format(split), 'w') as f: 34 | f.writelines(flow_list) 35 | 36 | with open('KITTI_{}_extra_info.txt'.format(split), 'w') as f: 37 | f.writelines(extra_info) 38 | -------------------------------------------------------------------------------- /flow_datasets/flying_things_three_frames/convert_things.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os.path as osp 3 | from glob import glob 4 | 5 | root = "/mnt/lustre/share/cp/caodongliang/FlyingThings3D/" 6 | 7 | for dstype in ['frames_cleanpass', 'frames_finalpass']: 8 | image_list = [] 9 | flow_list = [] 10 | for cam in ['left']: 11 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) 12 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) 13 | 14 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) 15 | flow_future_dirs = sorted([osp.join(f, 'into_future', cam) for f in flow_dirs]) 16 | flow_past_dirs = sorted([osp.join(f, 'into_past', cam) for f in flow_dirs]) 17 | 18 | for idir, fdir, pdir in zip(image_dirs, flow_future_dirs, flow_past_dirs): 19 | images = sorted(glob(osp.join(idir, '*.png')) ) 20 | future_flows = sorted(glob(osp.join(fdir, '*.pfm')) ) 21 | past_flows = sorted(glob(osp.join(pdir, '*.pfm')) ) 22 | 23 | for i in range(1, len(images)-1): 24 | image_list.append(images[i-1]) 25 | image_list.append(images[i]) 26 | image_list.append(images[i+1]) 27 | 28 | flow_list.append(future_flows[i]) 29 | flow_list.append(past_flows[i]) 30 | 31 | for idx in range(len(image_list)): 32 | image_list[idx] = image_list[idx].replace("/mnt/lustre/share/cp/caodongliang/FlyingThings3D", "flow_data") + "\n" 33 | for idx in range(len(flow_list)): 34 | flow_list[idx] = flow_list[idx].replace("/mnt/lustre/share/cp/caodongliang/FlyingThings3D", "flow_data") + "\n" 35 | 36 | 37 | with open(osp.join("flying_things_three_frames", "flyingthings_"+dstype+"_png.txt"), 'w') as f: 38 | f.writelines(image_list) 39 | print(len(image_list)) 40 | with open(osp.join("flying_things_three_frames", "flyingthings_"+dstype+"_pfm.txt"), 'w') as f: 41 | f.writelines(flow_list) 42 | print(len(flow_list)) 43 | 44 | 45 | 46 | -------------------------------------------------------------------------------- /flow_datasets/hd1k_three_frames/convert_HD1K.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os.path as osp 3 | from glob import glob 4 | import os 5 | 6 | root = "/mnt/lustre/share/cp/caodongliang/HD1K/" 7 | 8 | image_list = [] 9 | flow_list = [] 10 | 11 | seq_ix = 0 12 | 13 | while 1: 14 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) 15 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) 16 | 17 | if len(flows) == 0: 18 | break 19 | 20 | 
print(seq_ix, len(flows), images[0], images[-1], "!!!!!!!!!!!!!!") 21 | 22 | for i in range(len(images)-1): 23 | if i==0: 24 | image_list.append(images[0]) 25 | else: 26 | image_list.append(images[i-1]) 27 | 28 | image_list.append(images[i]) 29 | image_list.append(images[i+1]) 30 | 31 | flow_list.append(flows[i]) 32 | 33 | seq_ix += 1 34 | 35 | for idx in range(len(image_list)): 36 | image_list[idx] = image_list[idx].replace("/mnt/lustre/share/cp/caodongliang/HD1K", "HD1K") + "\n" 37 | for idx in range(len(flow_list)): 38 | flow_list[idx] = flow_list[idx].replace("/mnt/lustre/share/cp/caodongliang/HD1K", "HD1K") + "\n" 39 | 40 | with open(osp.join("hd1k_three_frames", "hd1k"+"_image.txt"), 'w') as f: 41 | f.writelines(image_list) 42 | print(len(image_list)) 43 | with open(osp.join("hd1k_three_frames", "hd1k"+"_flo.txt"), 'w') as f: 44 | f.writelines(flow_list) 45 | print(len(flow_list)) 46 | -------------------------------------------------------------------------------- /flow_datasets/sintel_three_frames/convert_sintel.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os.path as osp 3 | from glob import glob 4 | import os 5 | 6 | root = "/mnt/lustre/share/cp/caodongliang/MPI-Sintel/" 7 | 8 | for split in ['training']: 9 | for dstype in ['clean', 'final']: 10 | image_list = [] 11 | flow_list = [] 12 | extra_info_list = [] 13 | 14 | flow_root = osp.join(root, split, 'flow') 15 | image_root = osp.join(root, split, dstype) 16 | 17 | for scene in os.listdir(image_root): 18 | images = sorted(glob(osp.join(image_root, scene, '*.png'))) 19 | flows = sorted(glob(osp.join(flow_root, scene, '*.flo'))) 20 | 21 | for i in range(len(images)-1): 22 | if i==0: 23 | image_list.append(images[0]) 24 | else: 25 | image_list.append(images[i-1]) 26 | 27 | image_list.append(images[i]) 28 | image_list.append(images[i+1]) 29 | 30 | flow_list.append(flows[i]) 31 | extra_info_list.append(scene) 32 | extra_info_list.append(str(i)) 33 | 34 | for idx in range(len(image_list)): 35 | image_list[idx] = image_list[idx].replace("/mnt/lustre/share/cp/caodongliang/MPI-Sintel", "Sintel") + "\n" 36 | for idx in range(len(flow_list)): 37 | flow_list[idx] = flow_list[idx].replace("/mnt/lustre/share/cp/caodongliang/MPI-Sintel", "Sintel") + "\n" 38 | for idx in range(len(extra_info_list)): 39 | extra_info_list[idx] = extra_info_list[idx] + "\n" 40 | 41 | with open(osp.join("sintel_three_frames", "Sintel_"+dstype+"_png.txt"), 'w') as f: 42 | f.writelines(image_list) 43 | print(len(image_list)) 44 | with open(osp.join("sintel_three_frames", "Sintel_"+dstype+"_flo.txt"), 'w') as f: 45 | f.writelines(flow_list) 46 | print(len(flow_list)) 47 | with open(osp.join("sintel_three_frames", "Sintel_"+dstype+"_extra_info.txt"), 'w') as f: 48 | f.writelines(extra_info_list) 49 | print(len(extra_info_list)) -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | 4 | from PIL import Image 5 | import argparse 6 | import os 7 | import time 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | import matplotlib.pyplot as plt 12 | from core.utils.misc import process_cfg 13 | from utils import flow_viz 14 | 15 | from core.Networks import build_network 16 | 17 | from utils import frame_utils 18 | from utils.utils import InputPadder, forward_interpolate 19 | import itertools 20 | import imageio 21 | 
22 | def prepare_image(seq_dir): 23 | print(f"preparing image...") 24 | print(f"Input image sequence dir = {seq_dir}") 25 | 26 | images = [] 27 | 28 | image_list = sorted(os.listdir(seq_dir)) 29 | 30 | for fn in image_list: 31 | img = Image.open(os.path.join(seq_dir, fn)) 32 | img = np.array(img).astype(np.uint8)[..., :3] 33 | img = torch.from_numpy(img).permute(2, 0, 1).float() 34 | images.append(img) 35 | 36 | return torch.stack(images) 37 | 38 | def vis_pre(flow_pre, vis_dir): 39 | 40 | if not os.path.exists(vis_dir): 41 | os.makedirs(vis_dir) 42 | 43 | N = flow_pre.shape[0] 44 | 45 | for idx in range(N//2): 46 | flow_img = flow_viz.flow_to_image(flow_pre[idx].permute(1, 2, 0).numpy()) 47 | image = Image.fromarray(flow_img) 48 | image.save('{}/flow_{:04}_to_{:04}.png'.format(vis_dir, idx+2, idx+3)) 49 | 50 | for idx in range(N//2, N): 51 | flow_img = flow_viz.flow_to_image(flow_pre[idx].permute(1, 2, 0).numpy()) 52 | image = Image.fromarray(flow_img) 53 | image.save('{}/flow_{:04}_to_{:04}.png'.format(vis_dir, idx-N//2+2, idx-N//2+1)) 54 | 55 | @torch.no_grad() 56 | def MOF_inference(model, cfg): 57 | 58 | model.eval() 59 | 60 | input_images = prepare_image(cfg.seq_dir) 61 | input_images = input_images[None].cuda() 62 | padder = InputPadder(input_images.shape) 63 | input_images = padder.pad(input_images) 64 | flow_pre, _ = model(input_images, {}) 65 | flow_pre = padder.unpad(flow_pre[0]).cpu() 66 | 67 | return flow_pre 68 | 69 | @torch.no_grad() 70 | def BOF_inference(model, cfg): 71 | 72 | model.eval() 73 | 74 | input_images = prepare_image(cfg.seq_dir) 75 | input_images = input_images[None].cuda() 76 | padder = InputPadder(input_images.shape) 77 | input_images = padder.pad(input_images) 78 | flow_pre, _ = model(input_images, {}) 79 | flow_pre = padder.unpad(flow_pre[0]).cpu() 80 | 81 | return flow_pre 82 | 83 | def count_parameters(model): 84 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 85 | 86 | if __name__ == '__main__': 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument('--mode', default='MOF') 89 | parser.add_argument('--seq_dir', default='default') 90 | parser.add_argument('--vis_dir', default='default') 91 | 92 | args = parser.parse_args() 93 | 94 | if args.mode == 'MOF': 95 | from configs.multiframes_sintel_submission import get_cfg 96 | elif args.mode == 'BOF': 97 | from configs.sintel_submission import get_cfg 98 | 99 | cfg = get_cfg() 100 | cfg.update(vars(args)) 101 | 102 | model = torch.nn.DataParallel(build_network(cfg)) 103 | model.load_state_dict(torch.load(cfg.model)) 104 | 105 | model.cuda() 106 | model.eval() 107 | 108 | print(cfg.model) 109 | print("Parameter Count: %d" % count_parameters(model)) 110 | 111 | with torch.no_grad(): 112 | if args.mode == 'MOF': 113 | from configs.multiframes_sintel_submission import get_cfg 114 | flow_pre = MOF_inference(model.module, cfg) 115 | elif args.mode == 'BOF': 116 | from configs.sintel_submission import get_cfg 117 | flow_pre = BOF_inference(model.module, cfg) 118 | 119 | vis_pre(flow_pre, cfg.vis_dir) 120 | 121 | 122 | 123 | -------------------------------------------------------------------------------- /train_BOFNet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import sys 3 | # sys.path.append('core') 4 | 5 | import argparse 6 | import os 7 | import cv2 8 | import time 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | from pathlib import Path 12 | 13 | import torch 14 | 
import torch.nn as nn 15 | import torch.optim as optim 16 | import torch.nn.functional as F 17 | 18 | from torch.utils.data import DataLoader 19 | from core import optimizer 20 | import core.datasets_3frames as datasets 21 | from core.loss import sequence_loss 22 | from core.optimizer import fetch_optimizer 23 | from core.utils.misc import process_cfg 24 | from loguru import logger as loguru_logger 25 | from core.utils.logger import Logger 26 | 27 | from core.Networks import build_network 28 | 29 | try: 30 | from torch.cuda.amp import GradScaler 31 | except: 32 | # dummy GradScaler for PyTorch < 1.6 33 | class GradScaler: 34 | def __init__(self): 35 | pass 36 | def scale(self, loss): 37 | return loss 38 | def unscale_(self, optimizer): 39 | pass 40 | def step(self, optimizer): 41 | optimizer.step() 42 | def update(self): 43 | pass 44 | 45 | def count_parameters(model): 46 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 47 | 48 | def train(cfg): 49 | 50 | loss_func = sequence_loss 51 | if cfg.use_smoothl1: 52 | print("[Using smooth L1 loss]") 53 | loss_func = sequence_loss_smooth 54 | 55 | model = nn.DataParallel(build_network(cfg)) 56 | loguru_logger.info("Parameter Count: %d" % count_parameters(model)) 57 | 58 | if cfg.restore_ckpt is not None: 59 | print("[Loading ckpt from {}]".format(cfg.restore_ckpt)) 60 | model.load_state_dict(torch.load(cfg.restore_ckpt), strict=True) 61 | 62 | model.cuda() 63 | model.train() 64 | 65 | train_loader = datasets.fetch_dataloader(cfg) 66 | optimizer, scheduler = fetch_optimizer(model, cfg.trainer) 67 | 68 | total_steps = 0 69 | scaler = GradScaler(enabled=cfg.mixed_precision) 70 | logger = Logger(model, scheduler, cfg) 71 | 72 | should_keep_training = True 73 | while should_keep_training: 74 | 75 | for i_batch, data_blob in enumerate(train_loader): 76 | optimizer.zero_grad() 77 | images, flows, valids = [x.cuda() for x in data_blob] 78 | if cfg.add_noise: 79 | stdv = np.random.uniform(0.0, 5.0) 80 | images = (images + stdv * torch.randn(*images.shape).cuda()).clamp(0.0, 255.0) 81 | 82 | output = {} 83 | flow_predictions = model(images, output) 84 | loss, metrics, _ = loss_func(flow_predictions, flows, valids, cfg) 85 | scaler.scale(loss).backward() 86 | scaler.unscale_(optimizer) 87 | torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.trainer.clip) 88 | 89 | scaler.step(optimizer) 90 | scheduler.step() 91 | scaler.update() 92 | 93 | metrics.update(output) 94 | logger.push(metrics) 95 | 96 | if total_steps % cfg.val_freq == cfg.val_freq - 1: 97 | PATH = '%s/%d_%s.pth' % (cfg.log_dir, total_steps+1, cfg.name) 98 | # torch.save(model.state_dict(), PATH) 99 | 100 | results = {} 101 | for val_dataset in cfg.validation: 102 | if val_dataset == 'sintel_train': 103 | results.update(evaluate_tile.validate_sintel(model.module)) 104 | 105 | logger.write_dict(results) 106 | 107 | model.train() 108 | 109 | total_steps += 1 110 | 111 | if total_steps > cfg.trainer.num_steps: 112 | should_keep_training = False 113 | break 114 | 115 | logger.close() 116 | PATH = cfg.log_dir + '/final' 117 | torch.save(model.state_dict(), PATH) 118 | 119 | return PATH 120 | 121 | if __name__ == '__main__': 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('--name', default='flowformer', help="name your experiment") 124 | parser.add_argument('--stage', help="determines which dataset to use for training") 125 | parser.add_argument('--validation', type=str, nargs='+') 126 | 127 | args = parser.parse_args() 128 | 129 | if args.stage == 'things': 
130 | from configs.things import get_cfg 131 | elif args.stage == 'sintel': 132 | from configs.sintel import get_cfg 133 | elif args.stage == 'kitti': 134 | from configs.kitti import get_cfg 135 | 136 | cfg = get_cfg() 137 | cfg.update(vars(args)) 138 | process_cfg(cfg) 139 | loguru_logger.add(str(Path(cfg.log_dir) / 'log.txt'), encoding="utf8") 140 | loguru_logger.info(cfg) 141 | 142 | torch.manual_seed(1234) 143 | np.random.seed(1234) 144 | 145 | train(cfg) 146 | -------------------------------------------------------------------------------- /train_MOFNet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import sys 3 | # sys.path.append('core') 4 | 5 | import argparse 6 | import os 7 | import cv2 8 | import time 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | from pathlib import Path 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | import torch.nn.functional as F 17 | 18 | from torch.utils.data import DataLoader 19 | from core import optimizer 20 | import core.datasets_multiframes as datasets 21 | from core.loss import sequence_loss 22 | from core.optimizer import fetch_optimizer 23 | from core.utils.misc import process_cfg 24 | from loguru import logger as loguru_logger 25 | from core.utils.logger import Logger 26 | 27 | from core.Networks import build_network 28 | 29 | try: 30 | from torch.cuda.amp import GradScaler 31 | except: 32 | # dummy GradScaler for PyTorch < 1.6 33 | class GradScaler: 34 | def __init__(self): 35 | pass 36 | def scale(self, loss): 37 | return loss 38 | def unscale_(self, optimizer): 39 | pass 40 | def step(self, optimizer): 41 | optimizer.step() 42 | def update(self): 43 | pass 44 | 45 | from torchvision.utils import save_image 46 | 47 | def count_parameters(model): 48 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 49 | 50 | def train(cfg): 51 | 52 | loss_func = sequence_loss 53 | if cfg.use_smoothl1: 54 | print("[Using smooth L1 loss]") 55 | loss_func = sequence_loss_smooth 56 | 57 | model = nn.DataParallel(build_network(cfg)) 58 | loguru_logger.info("Parameter Count: %d" % count_parameters(model)) 59 | 60 | if cfg.restore_ckpt is not None: 61 | print("[Loading ckpt from {}]".format(cfg.restore_ckpt)) 62 | model.load_state_dict(torch.load(cfg.restore_ckpt), strict=True) 63 | 64 | model.cuda() 65 | model.train() 66 | 67 | train_loader = datasets.fetch_dataloader(cfg) 68 | optimizer, scheduler = fetch_optimizer(model, cfg.trainer) 69 | 70 | total_steps = 0 71 | scaler = GradScaler(enabled=cfg.mixed_precision) 72 | logger = Logger(model, scheduler, cfg) 73 | 74 | should_keep_training = True 75 | while should_keep_training: 76 | 77 | for i_batch, data_blob in enumerate(train_loader): 78 | 79 | optimizer.zero_grad() 80 | images, flows, valids = [x.cuda() for x in data_blob] 81 | 82 | if cfg.add_noise: 83 | stdv = np.random.uniform(0.0, 5.0) 84 | images = (images + stdv * torch.randn(*images.shape).cuda()).clamp(0.0, 255.0) 85 | 86 | output = {} 87 | flow_predictions = model(images, output) 88 | loss, metrics, NAN_flag = loss_func(flow_predictions, flows, valids, cfg) 89 | 90 | scaler.scale(loss).backward() 91 | scaler.unscale_(optimizer) 92 | torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.trainer.clip) 93 | 94 | scaler.step(optimizer) 95 | scheduler.step() 96 | scaler.update() 97 | 98 | metrics.update(output) 99 | logger.push(metrics) 100 | 101 | total_steps += 1 102 | 103 | if total_steps 
> cfg.trainer.num_steps: 104 | should_keep_training = False 105 | break 106 | 107 | logger.close() 108 | PATH = cfg.log_dir + '/final' 109 | torch.save(model.state_dict(), PATH) 110 | 111 | return PATH 112 | 113 | if __name__ == '__main__': 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument('--name', default='mofnet', help="name your experiment") 116 | parser.add_argument('--stage', help="determines which dataset to use for training") 117 | parser.add_argument('--validation', type=str, nargs='+') 118 | 119 | args = parser.parse_args() 120 | 121 | if args.stage == 'things': 122 | from configs.things_multiframes import get_cfg 123 | elif args.stage == "sintel": 124 | from configs.sintel_multiframes import get_cfg 125 | 126 | cfg = get_cfg() 127 | cfg.update(vars(args)) 128 | process_cfg(cfg) 129 | loguru_logger.add(str(Path(cfg.log_dir) / 'log.txt'), encoding="utf8") 130 | loguru_logger.info(cfg) 131 | 132 | torch.manual_seed(1234) 133 | np.random.seed(1234) 134 | 135 | train(cfg) 136 | --------------------------------------------------------------------------------
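
For reference, the pickle files produced by `flow_dataset_mf/convert_things.py` above store one Python list per `TRAIN/*/*/left` scene directory, and each entry is a sorted list of newline-terminated paths whose cluster prefix has been rewritten to `flow_data` (the Sintel converter rewrites to `Sintel` in the same way). A minimal loading sketch, assuming the default output file names used by that script; this is illustration only, not part of the repository:

```python
# Minimal sketch: inspect a per-scene path list pickled by
# flow_dataset_mf/convert_things.py. The pickle is a list of scenes;
# each scene is a list of newline-terminated relative image paths.
import pickle

with open("flyingthings_frames_cleanpass_png.pkl", "rb") as f:
    scenes = pickle.load(f)

first_scene = [p.strip() for p in scenes[0]]  # drop the trailing "\n" appended by the converter
print(len(scenes), "scenes;", len(first_scene), "frames in the first scene")
```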
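Similarly, the three-frame converters under `flow_datasets/` write flat text lists: three consecutive image paths per training sample, with one ground-truth flow per sample for Sintel and HD1K and two per sample (into_future and into_past) for FlyingThings, while `generate_KITTI_list.py` writes two image paths and one frame id per sample (flow paths only for the training split). A hedged sketch of regrouping such a flat list into per-sample records, using the Sintel file names produced by the script above:

```python
# Illustrative sketch only: regroup the flat lists written by
# flow_datasets/sintel_three_frames/convert_sintel.py into samples.
# flows_per_sample would be 2 for the FlyingThings three-frame lists
# (into_future and into_past) and 1 for Sintel/HD1K.
def group_three_frame_lists(image_txt, flow_txt, flows_per_sample=1):
    with open(image_txt) as f:
        images = [line.strip() for line in f]
    with open(flow_txt) as f:
        flows = [line.strip() for line in f]

    assert len(images) % 3 == 0
    num_samples = len(images) // 3
    assert len(flows) == num_samples * flows_per_sample

    samples = []
    for i in range(num_samples):
        samples.append({
            "images": images[3 * i: 3 * i + 3],  # previous, current, next frame
            "flows": flows[flows_per_sample * i: flows_per_sample * (i + 1)],
        })
    return samples

samples = group_three_frame_lists("sintel_three_frames/Sintel_clean_png.txt",
                                  "sintel_three_frames/Sintel_clean_flo.txt")
```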