├── .gitignore
├── LICENSE
├── RAFT
    ├── .gitignore
    ├── LICENSE
    ├── RAFT.png
    ├── README.md
    ├── alt_cuda_corr
    │   ├── correlation.cpp
    │   ├── correlation_kernel.cu
    │   └── setup.py
    ├── chairs_split.txt
    ├── core
    │   ├── __init__.py
    │   ├── corr.py
    │   ├── datasets.py
    │   ├── extractor.py
    │   ├── raft.py
    │   ├── update.py
    │   └── utils
    │   │   ├── __init__.py
    │   │   ├── augmentor.py
    │   │   ├── flow_viz.py
    │   │   ├── frame_utils.py
    │   │   └── utils.py
    ├── demo.py
    ├── download_models.sh
    ├── evaluate.py
    ├── train.py
    ├── train_mixed.sh
    └── train_standard.sh
├── README.md
├── data
    └── dis_index.py
├── demo
    ├── DR-RIFE-vgg_0.gif
    ├── DR-RIFE-vgg_1.gif
    ├── DR-RIFE_0.gif
    ├── DR-RIFE_1.gif
    ├── I0_0.png
    ├── I0_1.png
    ├── I1_0.png
    ├── I1_1.png
    ├── T-RIFE_0.gif
    ├── T-RIFE_1.gif
    ├── cctv5_interpany-clearer.mp4
    ├── manipulation.jpg
    ├── manipulation1.gif
    ├── manipulation2.gif
    ├── manipulation3.gif
    └── teaser.jpg
├── docker
    ├── Dockerfile
    ├── README.md
    └── entrypoint.sh
├── inference_img.py
├── inference_video.py
├── models
    ├── DI-AMT-and-IFRNet
    │   ├── LICENSE
    │   ├── README.md
    │   ├── benchmarks
    │   │   ├── Vimeo90K_m.py
    │   │   ├── Vimeo90K_sdi_m.py
    │   │   ├── Vimeo90K_sdi_m_recur.py
    │   │   ├── __init__.py
    │   │   ├── adobe240.py
    │   │   ├── gopro.py
    │   │   ├── snu_film.py
    │   │   ├── speed_parameters.py
    │   │   ├── ucf101.py
    │   │   ├── vimeo90k.py
    │   │   ├── vimeo90k_sdi.py
    │   │   ├── vimeo90k_tta.py
    │   │   └── xiph.py
    │   ├── cfgs
    │   │   ├── AMT-G.yaml
    │   │   ├── AMT-L.yaml
    │   │   ├── AMT-S.yaml
    │   │   ├── AMT-S_gopro.yaml
    │   │   ├── AMT-S_septuplet.yaml
    │   │   ├── AMT-S_septuplet_wofloloss.yaml
    │   │   ├── IFRNet.yaml
    │   │   ├── IFRNet_septuplet_wofloloss.yaml
    │   │   ├── M-SDI-AMT-S_septuplet_wofloloss.yaml
    │   │   ├── M-SDI-R-AMT-S_septuplet_wofloloss.yaml
    │   │   ├── M-SDI-R-AMT-S_v1_septuplet_wofloloss.yaml
    │   │   ├── R-AMT-S_v1_septuplet_wofloloss.yaml
    │   │   ├── R-IFRNet_septuplet_wofloloss.yaml
    │   │   ├── SDI-AMT-S_septuplet_wofloloss.yaml
    │   │   ├── SDI-AMT-S_triplet_wofloloss.yaml
    │   │   ├── SDI-IFRNet_septuplet_wofloloss.yaml
    │   │   ├── SDI-IFRNet_triplet_wofloloss.yaml
    │   │   ├── SDI-R-AMT-S_septuplet_wofloloss.yaml
    │   │   ├── SDI-R-AMT-S_v1_septuplet_wofloloss.yaml
    │   │   ├── SDI-R-AMT-S_v2_septuplet_wofloloss.yaml
    │   │   └── SDI-R-IFRNet_septuplet_wofloloss.yaml
    │   ├── datasets
    │   │   ├── __init__.py
    │   │   ├── adobe_datasets.py
    │   │   ├── gopro_datasets.py
    │   │   ├── vimeo_datasets.py
    │   │   ├── vimeo_septuplet_datasets.py
    │   │   └── vimeo_septuplet_recur_datasets.py
    │   ├── environment.yaml
    │   ├── flow_generation
    │   │   ├── __init__.py
    │   │   ├── gen_flow.py
    │   │   ├── gen_multi_flow.py
    │   │   ├── liteflownet
    │   │   │   ├── README.md
    │   │   │   ├── __init__.py
    │   │   │   ├── correlation
    │   │   │   │   ├── README.md
    │   │   │   │   └── correlation.py
    │   │   │   └── run.py
    │   │   └── multiprocess_gen_multi_flow.py
    │   ├── inference_img_plus.py
    │   ├── inference_img_plus_sdi.py
    │   ├── inference_img_plus_sdi_recur.py
    │   ├── inference_video_plus.py
    │   ├── inference_video_plus_sdi.py
    │   ├── inference_video_plus_sdi_recur.py
    │   ├── losses
    │   │   ├── __init__.py
    │   │   └── loss.py
    │   ├── metrics
    │   │   ├── __init__.py
    │   │   └── psnr_ssim.py
    │   ├── networks
    │   │   ├── AMT-G.py
    │   │   ├── AMT-L.py
    │   │   ├── AMT-S.py
    │   │   ├── IFRNet.py
    │   │   ├── SDI-AMT-S.py
    │   │   ├── SDI-IFRNet.py
    │   │   ├── SDI-R-AMT-S.py
    │   │   ├── SDI-R-AMT-S_v1.py
    │   │   ├── SDI-R-AMT-S_v2.py
    │   │   ├── SDI-R-IFRNet.py
    │   │   ├── __init__.py
    │   │   └── blocks
    │   │   │   ├── __init__.py
    │   │   │   ├── feat_enc.py
    │   │   │   ├── ifrnet.py
    │   │   │   ├── ifrnet_recur.py
    │   │   │   ├── ifrnet_recur_v1.py
    │   │   │   ├── ifrnet_recur_v2.py
    │   │   │   ├── multi_flow.py
    │   │   │   ├── multi_flow_recur.py
    │   │   │   ├── multi_flow_recur_v1.py
    │   │   │   ├── multi_flow_recur_v2.py
    │   │   │   └── raft.py
    │   ├── scripts
    │   │   ├── benchmark_arbitrary.sh
    │   │   ├── benchmark_fixed.sh
    │   │   └── train.sh
    │   ├── train.py
    │   ├── trainers
    │   │   ├── __init__.py
    │   │   ├── base_trainer.py
    │   │   └── logger.py
    │   └── utils
    │   │   ├── __init__.py
    │   │   ├── build_utils.py
    │   │   ├── dist_utils.py
    │   │   ├── flow_utils.py
    │   │   └── utils.py
    ├── DI-EMA-VFI
    │   ├── LICENSE
    │   ├── README.md
    │   ├── Trainer.py
    │   ├── Trainer_recur.py
    │   ├── benchmark
    │   │   ├── HD_4X.py
    │   │   ├── MiddleBury.py
    │   │   ├── SNU_FILM.py
    │   │   ├── TimeTest.py
    │   │   ├── UCF101.py
    │   │   ├── Vimeo90K.py
    │   │   ├── Vimeo90K_m.py
    │   │   ├── Vimeo90K_sdi.py
    │   │   ├── Vimeo90K_sdi_m.py
    │   │   ├── Vimeo90K_sdi_m_recur.py
    │   │   ├── XTest_8X.py
    │   │   ├── Xiph.py
    │   │   └── utils
    │   │   │   ├── padder.py
    │   │   │   ├── pytorch_msssim.py
    │   │   │   └── yuv_frame_io.py
    │   ├── config.py
    │   ├── config_recur.py
    │   ├── dataset.py
    │   ├── dataset_sdi_m_mask.py
    │   ├── dataset_sdi_m_mask_recur.py
    │   ├── demo_2x.py
    │   ├── demo_Nx.py
    │   ├── inference_img_plus.py
    │   ├── inference_img_plus_sdi.py
    │   ├── inference_img_plus_sdi_recur.py
    │   ├── inference_video_plus.py
    │   ├── inference_video_plus_sdi.py
    │   ├── inference_video_plus_sdi_recur.py
    │   ├── model
    │   │   ├── __init__.py
    │   │   ├── feature_extractor.py
    │   │   ├── feature_recur_extractor.py
    │   │   ├── flow_estimation.py
    │   │   ├── flow_recur_estimation.py
    │   │   ├── loss.py
    │   │   ├── refine.py
    │   │   └── warplayer.py
    │   ├── train.py
    │   ├── train_sdi_m_mask.py
    │   └── train_sdi_m_mask_recur.py
    └── DI-RIFE
    │   ├── README.md
    │   ├── benchmark
    │   ├── ATD12K.py
    │   ├── HD.py
    │   ├── HD_multi_4X.py
    │   ├── MiddleBury_Other.py
    │   ├── UCF101.py
    │   ├── Vimeo90K.py
    │   ├── Vimeo90K_m.py
    │   ├── Vimeo90K_sdi.py
    │   ├── Vimeo90K_sdi_m.py
    │   ├── Vimeo90K_sdi_m_recur.py
    │   ├── Vimeo90K_sdi_unif.py
    │   ├── Vimeo90K_sdi_unif_m.py
    │   ├── Vimeo90K_sdi_unif_m_recur.py
    │   ├── testtime.py
    │   └── yuv_frame_io.py
    │   ├── dataset.py
    │   ├── dataset_m.py
    │   ├── dataset_sdi.py
    │   ├── dataset_sdi_m.py
    │   ├── dataset_sdi_m_mask_recur.py
    │   ├── inference_img_plus.py
    │   ├── inference_img_plus_sdi.py
    │   ├── inference_img_plus_sdi_recur.py
    │   ├── inference_video_plus.py
    │   ├── inference_video_plus_sdi.py
    │   ├── inference_video_plus_sdi_recur.py
    │   ├── model
    │   ├── IFNet.py
    │   ├── IFNet_2R.py
    │   ├── IFNet_m.py
    │   ├── IFNet_sdi.py
    │   ├── IFNet_sdi_recur.py
    │   ├── RIFE.py
    │   ├── RIFE_m.py
    │   ├── RIFE_sdi.py
    │   ├── RIFE_sdi_recur.py
    │   ├── laplacian.py
    │   ├── loss.py
    │   ├── oldmodel
    │   │   ├── IFNet_HD.py
    │   │   ├── IFNet_HDv2.py
    │   │   ├── RIFE_HD.py
    │   │   └── RIFE_HDv2.py
    │   ├── pytorch_msssim
    │   │   └── __init__.py
    │   ├── refine.py
    │   ├── refine_2R.py
    │   └── warplayer.py
    │   ├── requirements.txt
    │   ├── train.py
    │   ├── train_m.py
    │   ├── train_sdi.py
    │   ├── train_sdi_m.py
    │   ├── train_sdi_m_mask_recur.py
    │   └── utils.py
├── multiprocess_create_dis_index.py
├── process_create_dis_index.py
├── requirements.txt
├── test.py
├── train.py
└── webapp
    ├── .gitignore
    ├── backend
    ├── .gitignore
    ├── README.md
    ├── app.py
    ├── data
    │   ├── embeddings
    │   │   └── .gitkeep
    │   ├── models
    │   │   └── .gitkeep
    │   └── uploads
    │   │   └── .gitkeep
    ├── dataset.py
    ├── embedding.py
    ├── interpolate_ours.py
    └── testServer.py
    └── webapp
    ├── .gitignore
    ├── CREATE_EMBEDDING.md
    ├── README.md
    ├── configs
    └── webpack
    │   ├── common.js
    │   ├── dev.js
    │   └── prod.js
    ├── model
    ├── sam_onnx_quantized_example_fast.onnx
    └── sam_onnx_quantized_example_full.onnx
    ├── package-lock.json
    ├── package.json
    ├── postcss.config.js
    ├── public
    └── images
    │   └── close_button.png
    ├── src
    ├── App.tsx
    ├── ControlApp.tsx
    ├── StageApp.tsx
    ├── assets
    │   ├── images
    │   │   └── loader.gif
    │   ├── index.html
    │   └── scss
    │   │   └── App.scss
    ├── components
    │   ├── DragDropFile.jsx
    │   ├── LineChart.tsx
    │   ├── Mask.tsx
    │   ├── MaskList.tsx
    │   ├── Stage.tsx
    │   ├── Tool.tsx
    │   ├── helpers
    │   │   ├── Interfaces.tsx
    │   │   ├── maskUtils.tsx
    │   │   ├── onnxModelAPI.tsx
    │   │   └── scaleHelper.tsx
    │   └── hooks
    │   │   ├── context.tsx
    │   │   └── createContext.tsx
    └── index.tsx
    ├── tailwind.config.js
    └── tsconfig.json
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Zhihang Zhong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 |
-------------------------------------------------------------------------------- /RAFT/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.egg-info 3 | dist 4 | datasets 5 | pytorch_env 6 | models 7 | build 8 | correlation.egg-info 9 |
-------------------------------------------------------------------------------- /RAFT/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, princeton-vl 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /RAFT/RAFT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/RAFT/RAFT.png -------------------------------------------------------------------------------- /RAFT/README.md: -------------------------------------------------------------------------------- 1 | # RAFT 2 | This repository contains the source code for our paper: 3 | 4 | [RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)
5 | ECCV 2020
6 | Zachary Teed and Jia Deng
7 | 8 |  9 |  10 | ## Requirements 11 | The code has been tested with PyTorch 1.6 and CUDA 10.1. 12 | ```Shell 13 | conda create --name raft 14 | conda activate raft 15 | conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch 16 | ``` 17 | 18 | ## Demos 19 | Pretrained models can be downloaded by running 20 | ```Shell 21 | ./download_models.sh 22 | ``` 23 | or downloaded from [google drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing) 24 | 25 | You can demo a trained model on a sequence of frames 26 | ```Shell 27 | python demo.py --model=models/raft-things.pth --path=demo-frames 28 | ``` 29 | 30 | ## Required Data 31 | To evaluate/train RAFT, you will need to download the required datasets. 32 | * [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) 33 | * [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) 34 | * [Sintel](http://sintel.is.tue.mpg.de/) 35 | * [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) 36 | * [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional) 37 | 38 | 39 | By default `datasets.py` will search for the datasets in these locations. You can create symbolic links to wherever the datasets were downloaded in the `datasets` folder. 40 | 41 | ```Shell 42 | ├── datasets 43 | ├── Sintel 44 | ├── test 45 | ├── training 46 | ├── KITTI 47 | ├── testing 48 | ├── training 49 | ├── devkit 50 | ├── FlyingChairs_release 51 | ├── data 52 | ├── FlyingThings3D 53 | ├── frames_cleanpass 54 | ├── frames_finalpass 55 | ├── optical_flow 56 | ``` 57 | 58 | ## Evaluation 59 | You can evaluate a trained model using `evaluate.py` 60 | ```Shell 61 | python evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision 62 | ``` 63 | 64 | ## Training 65 | We used the following training schedule in our paper (2 GPUs). Training logs will be written to the `runs` directory, which can be visualized using tensorboard 66 | ```Shell 67 | ./train_standard.sh 68 | ``` 69 | 70 | If you have an RTX GPU, training can be accelerated using mixed precision. You can expect similar results in this setting (1 GPU) 71 | ```Shell 72 | ./train_mixed.sh 73 | ``` 74 | 75 | ## (Optional) Efficient Implementation 76 | You can optionally use our alternate (efficient) implementation by compiling the provided CUDA extension 77 | ```Shell 78 | cd alt_cuda_corr && python setup.py install && cd .. 79 | ``` 80 | and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag. Note that this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
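For example, a typical sequence with the compiled extension looks like this (a minimal sketch; it reuses the `demo-frames` sequence and checkpoint paths from the sections above):

```Shell
# build and install the CUDA extension
cd alt_cuda_corr && python setup.py install && cd ..
# run the demo and evaluation with the memory-efficient correlation lookup
python demo.py --model=models/raft-things.pth --path=demo-frames --alternate_corr
python evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision --alternate_corr
```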
81 | -------------------------------------------------------------------------------- /RAFT/alt_cuda_corr/correlation.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <vector> 3 | 4 | // CUDA forward declarations 5 | std::vector<torch::Tensor> corr_cuda_forward( 6 | torch::Tensor fmap1, 7 | torch::Tensor fmap2, 8 | torch::Tensor coords, 9 | int radius); 10 | 11 | std::vector<torch::Tensor> corr_cuda_backward( 12 | torch::Tensor fmap1, 13 | torch::Tensor fmap2, 14 | torch::Tensor coords, 15 | torch::Tensor corr_grad, 16 | int radius); 17 | 18 | // C++ interface 19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 22 | 23 | std::vector<torch::Tensor> corr_forward( 24 | torch::Tensor fmap1, 25 | torch::Tensor fmap2, 26 | torch::Tensor coords, 27 | int radius) { 28 | CHECK_INPUT(fmap1); 29 | CHECK_INPUT(fmap2); 30 | CHECK_INPUT(coords); 31 | 32 | return corr_cuda_forward(fmap1, fmap2, coords, radius); 33 | } 34 | 35 | 36 | std::vector<torch::Tensor> corr_backward( 37 | torch::Tensor fmap1, 38 | torch::Tensor fmap2, 39 | torch::Tensor coords, 40 | torch::Tensor corr_grad, 41 | int radius) { 42 | CHECK_INPUT(fmap1); 43 | CHECK_INPUT(fmap2); 44 | CHECK_INPUT(coords); 45 | CHECK_INPUT(corr_grad); 46 | 47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius); 48 | } 49 | 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", &corr_forward, "CORR forward"); 53 | m.def("backward", &corr_backward, "CORR backward"); 54 | } -------------------------------------------------------------------------------- /RAFT/alt_cuda_corr/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup( 6 | name='correlation', 7 | ext_modules=[ 8 | CUDAExtension('alt_cuda_corr', 9 | sources=['correlation.cpp', 'correlation_kernel.cu'], 10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) 15 | 16 | -------------------------------------------------------------------------------- /RAFT/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/RAFT/core/__init__.py -------------------------------------------------------------------------------- /RAFT/core/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | try: 5 | import alt_cuda_corr 6 | from utils.utils import bilinear_sampler, coords_grid 7 | except ImportError: 8 | from RAFT.core.utils.utils import bilinear_sampler, coords_grid 9 | 10 | # alt_cuda_corr is not compiled 11 | pass 12 | 13 | 14 | class CorrBlock: 15 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 16 | self.num_levels = num_levels 17 | self.radius = radius 18 | self.corr_pyramid = [] 19 | 20 | # all pairs correlation 21 | corr = CorrBlock.corr(fmap1, fmap2) 22 | 23 | batch, h1, w1, dim, h2, w2 = corr.shape 24 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 25 | 26 | self.corr_pyramid.append(corr) 27 | for i in range(self.num_levels - 1): 28 | corr = F.avg_pool2d(corr, 2, stride=2) 29 | self.corr_pyramid.append(corr) 30 | 31 | 
def __call__(self, coords): 32 | r = self.radius 33 | coords = coords.permute(0, 2, 3, 1) 34 | batch, h1, w1, _ = coords.shape 35 | 36 | out_pyramid = [] 37 | for i in range(self.num_levels): 38 | corr = self.corr_pyramid[i] 39 | dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device) 40 | dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device) 41 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) 42 | 43 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2 ** i 44 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 45 | coords_lvl = centroid_lvl + delta_lvl 46 | 47 | corr = bilinear_sampler(corr, coords_lvl) 48 | corr = corr.view(batch, h1, w1, -1) 49 | out_pyramid.append(corr) 50 | 51 | out = torch.cat(out_pyramid, dim=-1) 52 | return out.permute(0, 3, 1, 2).contiguous().float() 53 | 54 | @staticmethod 55 | def corr(fmap1, fmap2): 56 | batch, dim, ht, wd = fmap1.shape 57 | fmap1 = fmap1.view(batch, dim, ht * wd) 58 | fmap2 = fmap2.view(batch, dim, ht * wd) 59 | 60 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 61 | corr = corr.view(batch, ht, wd, 1, ht, wd) 62 | return corr / torch.sqrt(torch.tensor(dim).float()) 63 | 64 | 65 | class AlternateCorrBlock: 66 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 67 | self.num_levels = num_levels 68 | self.radius = radius 69 | 70 | self.pyramid = [(fmap1, fmap2)] 71 | for i in range(self.num_levels): 72 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 73 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 74 | self.pyramid.append((fmap1, fmap2)) 75 | 76 | def __call__(self, coords): 77 | coords = coords.permute(0, 2, 3, 1) 78 | B, H, W, _ = coords.shape 79 | dim = self.pyramid[0][0].shape[1] 80 | 81 | corr_list = [] 82 | for i in range(self.num_levels): 83 | r = self.radius 84 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 85 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 86 | 87 | coords_i = (coords / 2 ** i).reshape(B, 1, H, W, 2).contiguous() 88 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 89 | corr_list.append(corr.squeeze(1)) 90 | 91 | corr = torch.stack(corr_list, dim=1) 92 | corr = corr.reshape(B, -1, H, W) 93 | return corr / torch.sqrt(torch.tensor(dim).float()) 94 | -------------------------------------------------------------------------------- /RAFT/core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/RAFT/core/utils/__init__.py -------------------------------------------------------------------------------- /RAFT/core/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel'): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, *inputs): 19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 20 | 21 | def unpad(self,x): 22 | ht, wd = x.shape[-2:] 23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 24 | return 
x[..., c[0]:c[1], c[2]:c[3]] 25 | 26 | def forward_interpolate(flow): 27 | flow = flow.detach().cpu().numpy() 28 | dx, dy = flow[0], flow[1] 29 | 30 | ht, wd = dx.shape 31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 32 | 33 | x1 = x0 + dx 34 | y1 = y0 + dy 35 | 36 | x1 = x1.reshape(-1) 37 | y1 = y1.reshape(-1) 38 | dx = dx.reshape(-1) 39 | dy = dy.reshape(-1) 40 | 41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 42 | x1 = x1[valid] 43 | y1 = y1[valid] 44 | dx = dx[valid] 45 | dy = dy[valid] 46 | 47 | flow_x = interpolate.griddata( 48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 49 | 50 | flow_y = interpolate.griddata( 51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 52 | 53 | flow = np.stack([flow_x, flow_y], axis=0) 54 | return torch.from_numpy(flow).float() 55 | 56 | 57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 58 | """ Wrapper for grid_sample, uses pixel coordinates """ 59 | H, W = img.shape[-2:] 60 | xgrid, ygrid = coords.split([1,1], dim=-1) 61 | xgrid = 2*xgrid/(W-1) - 1 62 | ygrid = 2*ygrid/(H-1) - 1 63 | 64 | grid = torch.cat([xgrid, ygrid], dim=-1) 65 | img = F.grid_sample(img, grid, align_corners=True) 66 | 67 | if mask: 68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 69 | return img, mask.float() 70 | 71 | return img 72 | 73 | 74 | def coords_grid(batch, ht, wd, device): 75 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) 76 | coords = torch.stack(coords[::-1], dim=0).float() 77 | return coords[None].repeat(batch, 1, 1, 1) 78 | 79 | 80 | def upflow8(flow, mode='bilinear'): 81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 83 | -------------------------------------------------------------------------------- /RAFT/demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('core') 3 | 4 | import argparse 5 | import os 6 | import cv2 7 | import glob 8 | import numpy as np 9 | import torch 10 | from PIL import Image 11 | 12 | from raft import RAFT 13 | from utils import flow_viz 14 | from utils.utils import InputPadder 15 | 16 | 17 | 18 | DEVICE = 'cuda' 19 | 20 | def load_image(imfile): 21 | img = np.array(Image.open(imfile)).astype(np.uint8) 22 | img = torch.from_numpy(img).permute(2, 0, 1).float() 23 | return img[None].to(DEVICE) 24 | 25 | 26 | def viz(img, flo): 27 | img = img[0].permute(1,2,0).cpu().numpy() 28 | flo = flo[0].permute(1,2,0).cpu().numpy() 29 | 30 | # map flow to rgb image 31 | flo = flow_viz.flow_to_image(flo) 32 | img_flo = np.concatenate([img, flo], axis=0) 33 | 34 | # import matplotlib.pyplot as plt 35 | # plt.imshow(img_flo / 255.0) 36 | # plt.show() 37 | 38 | cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0) 39 | cv2.waitKey() 40 | 41 | 42 | def demo(args): 43 | model = torch.nn.DataParallel(RAFT(args)) 44 | model.load_state_dict(torch.load(args.model)) 45 | 46 | model = model.module 47 | model.to(DEVICE) 48 | model.eval() 49 | 50 | with torch.no_grad(): 51 | images = glob.glob(os.path.join(args.path, '*.png')) + \ 52 | glob.glob(os.path.join(args.path, '*.jpg')) 53 | 54 | images = sorted(images) 55 | for imfile1, imfile2 in zip(images[:-1], images[1:]): 56 | image1 = load_image(imfile1) 57 | image2 = load_image(imfile2) 58 | 59 | padder = InputPadder(image1.shape) 60 | image1, image2 = padder.pad(image1, image2) 61 | 62 | flow_low, flow_up = model(image1, image2, iters=20, 
test_mode=True) 63 | viz(image1, flow_up) 64 | 65 | 66 | if __name__ == '__main__': 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('--model', help="restore checkpoint") 69 | parser.add_argument('--path', help="dataset for evaluation") 70 | parser.add_argument('--small', action='store_true', help='use small model') 71 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 72 | parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation') 73 | args = parser.parse_args() 74 | 75 | demo(args) 76 | -------------------------------------------------------------------------------- /RAFT/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip 3 | unzip models.zip 4 | -------------------------------------------------------------------------------- /RAFT/train_mixed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p checkpoints 3 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision 4 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision 5 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 --mixed_precision 6 | python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --mixed_precision 7 | -------------------------------------------------------------------------------- /RAFT/train_standard.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p checkpoints 3 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 10 --lr 0.0004 --image_size 368 496 --wdecay 0.0001 4 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 5 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 6 | python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 7 | -------------------------------------------------------------------------------- /demo/DR-RIFE-vgg_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/demo/DR-RIFE-vgg_0.gif -------------------------------------------------------------------------------- /demo/DR-RIFE-vgg_1.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/demo/DR-RIFE-vgg_1.gif -------------------------------------------------------------------------------- /demo/DR-RIFE_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/demo/DR-RIFE_0.gif -------------------------------------------------------------------------------- /demo/DR-RIFE_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/demo/DR-RIFE_1.gif -------------------------------------------------------------------------------- /demo/I0_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/demo/I0_0.png -------------------------------------------------------------------------------- /demo/I0_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/demo/I0_1.png -------------------------------------------------------------------------------- /demo/I1_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/demo/I1_0.png -------------------------------------------------------------------------------- /demo/I1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/demo/I1_1.png -------------------------------------------------------------------------------- /demo/T-RIFE_0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/demo/T-RIFE_0.gif -------------------------------------------------------------------------------- /demo/T-RIFE_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/demo/T-RIFE_1.gif -------------------------------------------------------------------------------- /demo/cctv5_interpany-clearer.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/demo/cctv5_interpany-clearer.mp4 -------------------------------------------------------------------------------- /demo/manipulation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/demo/manipulation.jpg -------------------------------------------------------------------------------- /demo/manipulation1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/demo/manipulation1.gif 
-------------------------------------------------------------------------------- /demo/manipulation2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/demo/manipulation2.gif -------------------------------------------------------------------------------- /demo/manipulation3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/demo/manipulation3.gif -------------------------------------------------------------------------------- /demo/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/demo/teaser.jpg -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | 2 | FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04 3 | 4 | 5 | # # Avoid Public GPG key error 6 | # # https://github.com/NVIDIA/nvidia-docker/issues/1631 7 | # RUN apt-key del 7fa2af80 \ 8 | # && apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub \ 9 | # && apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub 10 | 11 | ENV TZ=Asia/Shanghai 12 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone 13 | 14 | # install apt packages 15 | RUN apt-get update \ 16 | && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 wget curl\ 17 | && apt-get clean \ 18 | && rm -rf /var/lib/apt/lists/* 19 | 20 | SHELL ["/bin/bash", "-c"] 21 | 22 | # Install miniconda 23 | RUN wget -q \ 24 | https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ 25 | && bash Miniconda3-latest-Linux-x86_64.sh -b -p /opt/miniconda \ 26 | && rm -f Miniconda3-latest-Linux-x86_64.sh 27 | 28 | ENV PATH="/opt/miniconda/bin:$PATH" 29 | 30 | # Update in bashrc 31 | RUN echo "source /opt/miniconda/etc/profile.d/conda.sh" >> /root/.bashrc 32 | 33 | # RUN git clone https://github.com/Wei-ucas/InterpAny-Clearer.git /InterpAny-Clearer \ 34 | COPY . 
/InterpAny-Clearer 35 | RUN source ~/.bashrc \ 36 | && cd /InterpAny-Clearer \ 37 | && conda create -n InterpAny python=3.8 -y\ 38 | && source activate InterpAny \ 39 | && pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116 \ 40 | && pip install -r requirements.txt \ 41 | && conda clean -y --all\ 42 | && pip cache purge 43 | 44 | 45 | WORKDIR /InterpAny-Clearer 46 | # Download pretrained models 47 | RUN wget https://pjlab-3090-sport.oss-cn-beijing.aliyuncs.com/downloads/InterpAny-Clearer/checkpoints.tar.gz\ 48 | && tar -zxvf checkpoints.tar.gz \ 49 | && rm checkpoints.tar.gz 50 | 51 | # Prepare backend environment 52 | RUN cd /InterpAny-Clearer/webapp/backend \ 53 | && cd data/models\ 54 | && wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth 55 | 56 | 57 | # Install webapp dependencies 58 | ARG VERSION="v16.14.2" 59 | ARG DISTRO="linux-x64" 60 | RUN wget -c https://nodejs.org/dist/$VERSION/node-$VERSION-$DISTRO.tar.xz -P ~/Downloads\ 61 | && mkdir -p /usr/local/lib/nodejs \ 62 | && tar -xJvf ~/Downloads/node-$VERSION-$DISTRO.tar.xz -C /usr/local/lib/nodejs \ 63 | && rm -rf ~/Downloads/node-$VERSION-$DISTRO.tar.xz \ 64 | && sed -i 's/^#.*nodejs.*$//gi' ~/.profile \ 65 | && sed -i 's/^export PATH=\/usr\/local\/lib\/nodejs\/node.*$//g' ~/.profile \ 66 | && export PATH=/usr/local/lib/nodejs/node-$VERSION-$DISTRO/bin:$PATH | tee -a ~/.profile \ 67 | && . ~/.profile \ 68 | && ln -sf /usr/local/lib/nodejs/node-$VERSION-$DISTRO/bin/node /usr/bin/node \ 69 | && ln -sf /usr/local/lib/nodejs/node-$VERSION-$DISTRO/bin/npm /usr/bin/npm \ 70 | && ln -sf /usr/local/lib/nodejs/node-$VERSION-$DISTRO/bin/npx /usr/bin/npx \ 71 | && echo "Node.js version: $(node -v)" \ 72 | && echo "NPM version: $(npm -v)" \ 73 | && npm install -g yarn \ 74 | && ln -sf /usr/local/lib/nodejs/node-$VERSION-$DISTRO/bin/yarn /usr/bin/yarn \ 75 | && echo "Yarn version: $(yarn -v)" 76 | 77 | 78 | RUN echo "conda activate InterpAny" >> ~/.bashrc 79 | 80 | EXPOSE 5001 8080 81 | 82 | COPY ./docker/entrypoint.sh /usr/local/bin/entrypoint.sh 83 | RUN chmod +x /usr/local/bin/entrypoint.sh 84 | ENTRYPOINT [ "/usr/local/bin/entrypoint.sh" ] 85 | CMD [ "serve" ] 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # Run with Docker Container 2 | 3 | ## Build Docker image 4 | Make sure docker >= 19.03 and nvidia-container-toolkit >= 1.3 are installed. 5 | ```shell 6 | docker build -t interpany:v0 -f docker/Dockerfile . 7 | ``` 8 | 9 | ## Run Docker container 10 | 11 | ### Run the container 12 | 13 | ```shell 14 | docker run -it --name=interp -p 5001:5001 -p 8080:8080 --gpus all --shm-size=8g interpany:v0 15 | ``` 16 | This command will create and start a container named `interp`, which serves a webapp on http://localhost:8080/ (only accessible from the local machine). 17 | 18 | ```shell 19 | docker run -it --name=interp -p 5001:5001 -p 8080:8080 --gpus all --shm-size=8g interpany:v0 /bin/bash 20 | ``` 21 | With the `/bin/bash` argument, only a container with all the dependencies is started (the webapp is not launched), and the project code is located at `/InterpAny-Clearer` in the container.
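If you started the container with `/bin/bash`, you can bring the webapp up by hand. The following is a minimal sketch mirroring what `docker/entrypoint.sh` (shown below) does, with the same paths and ports; the only deviation is `mkdir -p`, used here so the command is safe to re-run:

```shell
# inside the running container
source activate InterpAny
mkdir -p /InterpAny-Clearer/webapp/backend/data/results
cd /InterpAny-Clearer/webapp/backend && nohup python app.py &
cd /InterpAny-Clearer/webapp/webapp && yarn && nohup yarn start &
# the webapp is then served on http://localhost:8080
```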
22 | -------------------------------------------------------------------------------- /docker/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | if [[ "$1" = "serve" ]]; then 5 | shift 1 6 | source activate InterpAny 7 | mkdir /InterpAny-Clearer/webapp/backend/data/results 8 | cd /InterpAny-Clearer/webapp/backend && nohup python app.py & 9 | cd /InterpAny-Clearer/webapp/webapp && yarn && nohup yarn start & 10 | echo "Webapp is running on http://localhost:8080" 11 | else 12 | eval "$@" 13 | fi 14 | 15 | # prevent docker exit 16 | tail -f /dev/null -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/benchmarks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/models/DI-AMT-and-IFRNet/benchmarks/__init__.py -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/benchmarks/adobe240.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import tqdm 3 | import torch 4 | import argparse 5 | import numpy as np 6 | from omegaconf import OmegaConf 7 | 8 | sys.path.append('.') 9 | from utils.build_utils import build_from_cfg 10 | from datasets.adobe_datasets import Adobe240_Dataset 11 | from metrics.psnr_ssim import calculate_psnr, calculate_ssim 12 | 13 | parser = argparse.ArgumentParser( 14 | prog = 'AMT', 15 | description = 'Adobe240 evaluation', 16 | ) 17 | parser.add_argument('-c', '--config', default='cfgs/AMT-S_gopro.yaml') 18 | parser.add_argument('-p', '--ckpt', default='pretrained/gopro_amt-s.pth',) 19 | parser.add_argument('-r', '--root', default='data/Adobe240/test_frames',) 20 | args = parser.parse_args() 21 | 22 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 23 | cfg_path = args.config 24 | ckpt_path = args.ckpt 25 | root = args.root 26 | 27 | network_cfg = OmegaConf.load(cfg_path).network 28 | network_name = network_cfg.name 29 | model = build_from_cfg(network_cfg) 30 | ckpt = torch.load(ckpt_path) 31 | model.load_state_dict(ckpt['state_dict']) 32 | model = model.to(device) 33 | model.eval() 34 | 35 | dataset = Adobe240_Dataset(dataset_dir=root, augment=False) 36 | 37 | psnr_list = [] 38 | ssim_list = [] 39 | pbar = tqdm.tqdm(dataset, total=len(dataset)) 40 | for data in pbar: 41 | input_dict = {} 42 | for k, v in data.items(): 43 | input_dict[k] = v.to(device).unsqueeze(0) 44 | with torch.no_grad(): 45 | imgt_pred = model(**input_dict)['imgt_pred'] 46 | psnr = calculate_psnr(imgt_pred, input_dict['imgt']) 47 | ssim = calculate_ssim(imgt_pred, input_dict['imgt']) 48 | psnr_list.append(psnr) 49 | ssim_list.append(ssim) 50 | avg_psnr = np.mean(psnr_list) 51 | avg_ssim = np.mean(ssim_list) 52 | desc_str = f'[{network_name}/Adobe240] psnr: {avg_psnr:.02f}, ssim: {avg_ssim:.04f}' 53 | pbar.set_description_str(desc_str) 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/benchmarks/gopro.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import tqdm 3 | import torch 4 | import argparse 5 | import numpy as np 6 | from omegaconf import OmegaConf 7 | 8 | sys.path.append('.') 9 | from utils.build_utils import build_from_cfg 10 | from datasets.gopro_datasets import 
GoPro_Test_Dataset 11 | from metrics.psnr_ssim import calculate_psnr, calculate_ssim 12 | 13 | parser = argparse.ArgumentParser( 14 | prog = 'AMT', 15 | description = 'GOPRO evaluation', 16 | ) 17 | parser.add_argument('-c', '--config', default='cfgs/AMT-S_gopro.yaml') 18 | parser.add_argument('-p', '--ckpt', default='pretrained/gopro_amt-s.pth',) 19 | parser.add_argument('-r', '--root', default='data/GOPRO',) 20 | args = parser.parse_args() 21 | 22 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 23 | cfg_path = args.config 24 | ckpt_path = args.ckpt 25 | root = args.root 26 | 27 | network_cfg = OmegaConf.load(cfg_path).network 28 | network_name = network_cfg.name 29 | model = build_from_cfg(network_cfg) 30 | ckpt = torch.load(ckpt_path) 31 | model.load_state_dict(ckpt['state_dict']) 32 | model = model.to(device) 33 | model.eval() 34 | 35 | dataset = GoPro_Test_Dataset(dataset_dir=root) 36 | 37 | psnr_list = [] 38 | ssim_list = [] 39 | pbar = tqdm.tqdm(dataset, total=len(dataset)) 40 | for data in pbar: 41 | input_dict = {} 42 | for k, v in data.items(): 43 | input_dict[k] = v.to(device).unsqueeze(0) 44 | with torch.no_grad(): 45 | imgt_pred = model(**input_dict)['imgt_pred'] 46 | psnr = calculate_psnr(imgt_pred, input_dict['imgt']) 47 | ssim = calculate_ssim(imgt_pred, input_dict['imgt']) 48 | psnr_list.append(psnr) 49 | ssim_list.append(ssim) 50 | avg_psnr = np.mean(psnr_list) 51 | avg_ssim = np.mean(ssim_list) 52 | desc_str = f'[{network_name}/GOPRO] psnr: {avg_psnr:.02f}, ssim: {avg_ssim:.04f}' 53 | pbar.set_description_str(desc_str) 54 | 55 | 56 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/benchmarks/snu_film.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import tqdm 4 | import torch 5 | import argparse 6 | import numpy as np 7 | import os.path as osp 8 | from omegaconf import OmegaConf 9 | 10 | sys.path.append('.') 11 | from utils.build_utils import build_from_cfg 12 | from metrics.psnr_ssim import calculate_psnr, calculate_ssim 13 | from utils.utils import InputPadder, read, img2tensor 14 | 15 | 16 | def parse_path(path): 17 | path_list = path.split('/') 18 | new_path = osp.join(*path_list[-3:]) 19 | return new_path 20 | 21 | parser = argparse.ArgumentParser( 22 | prog = 'AMT', 23 | description = 'SNU-FILM evaluation', 24 | ) 25 | parser.add_argument('-c', '--config', default='cfgs/AMT-S.yaml') 26 | parser.add_argument('-p', '--ckpt', default='pretrained/amt-s.pth') 27 | parser.add_argument('-r', '--root', default='data/SNU_FILM') 28 | args = parser.parse_args() 29 | 30 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 31 | cfg_path = args.config 32 | ckpt_path = args.ckpt 33 | root = args.root 34 | 35 | network_cfg = OmegaConf.load(cfg_path).network 36 | network_name = network_cfg.name 37 | model = build_from_cfg(network_cfg) 38 | ckpt = torch.load(ckpt_path) 39 | model.load_state_dict(ckpt['state_dict']) 40 | model = model.to(device) 41 | model.eval() 42 | 43 | divisor = 20; scale_factor = 0.8 44 | splits = ['easy', 'medium', 'hard', 'extreme'] 45 | for split in splits: 46 | with open(os.path.join(root, f'test-{split}.txt'), "r") as fr: 47 | file_list = [l.strip().split(' ') for l in fr.readlines()] 48 | pbar = tqdm.tqdm(file_list, total=len(file_list)) 49 | 50 | psnr_list = []; ssim_list = [] 51 | for name in pbar: 52 | img0 = img2tensor(read(osp.join(root, parse_path(name[0])))).to(device) 53 | imgt = 
img2tensor(read(osp.join(root, parse_path(name[1])))).to(device) 54 | img1 = img2tensor(read(osp.join(root, parse_path(name[2])))).to(device) 55 | padder = InputPadder(img0.shape, divisor) 56 | img0, img1 = padder.pad(img0, img1) 57 | 58 | embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device) 59 | imgt_pred = model(img0, img1, embt, scale_factor=scale_factor, eval=True)['imgt_pred'] 60 | imgt_pred = padder.unpad(imgt_pred) 61 | 62 | psnr = calculate_psnr(imgt_pred, imgt).detach().cpu().numpy() 63 | ssim = calculate_ssim(imgt_pred, imgt).detach().cpu().numpy() 64 | 65 | psnr_list.append(psnr) 66 | ssim_list.append(ssim) 67 | avg_psnr = np.mean(psnr_list) 68 | avg_ssim = np.mean(ssim_list) 69 | desc_str = f'[{network_name}/SNU-FILM] [{split}] psnr: {avg_psnr:.02f}, ssim: {avg_ssim:.04f}' 70 | pbar.set_description_str(desc_str) 71 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/benchmarks/speed_parameters.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import torch 4 | import argparse 5 | from omegaconf import OmegaConf 6 | 7 | sys.path.append('.') 8 | from utils.build_utils import build_from_cfg 9 | 10 | parser = argparse.ArgumentParser( 11 | prog = 'AMT', 12 | description = 'Speed&parameter benchmark', 13 | ) 14 | parser.add_argument('-c', '--config', default='cfgs/AMT-S.yaml') 15 | args = parser.parse_args() 16 | 17 | cfg_path = args.config 18 | network_cfg = OmegaConf.load(cfg_path).network 19 | model = build_from_cfg(network_cfg) 20 | model = model.cuda() 21 | model.eval() 22 | 23 | img0 = torch.randn(1, 3, 256, 448).cuda() 24 | img1 = torch.randn(1, 3, 256, 448).cuda() 25 | embt = torch.tensor(1/2).float().view(1, 1, 1, 1).cuda() 26 | 27 | with torch.no_grad(): 28 | for i in range(100): 29 | out = model(img0, img1, embt, eval=True) 30 | torch.cuda.synchronize() 31 | time_stamp = time.time() 32 | for i in range(1000): 33 | out = model(img0, img1, embt, eval=True) 34 | torch.cuda.synchronize() 35 | print('Time: {:.5f}s'.format((time.time() - time_stamp) / 1)) 36 | 37 | total = sum([param.nelement() for param in model.parameters()]) 38 | print('Parameters: {:.2f}M'.format(total / 1e6)) 39 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/benchmarks/ucf101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import tqdm 4 | import torch 5 | import argparse 6 | import numpy as np 7 | import os.path as osp 8 | from omegaconf import OmegaConf 9 | 10 | sys.path.append('.') 11 | from utils.utils import read, img2tensor 12 | from utils.build_utils import build_from_cfg 13 | from metrics.psnr_ssim import calculate_psnr, calculate_ssim 14 | 15 | parser = argparse.ArgumentParser( 16 | prog = 'AMT', 17 | description = 'UCF101 evaluation', 18 | ) 19 | parser.add_argument('-c', '--config', default='cfgs/AMT-S.yaml') 20 | parser.add_argument('-p', '--ckpt', default='pretrained/amt-s.pth') 21 | parser.add_argument('-r', '--root', default='data/ucf101_interp_ours') 22 | args = parser.parse_args() 23 | 24 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 25 | cfg_path = args.config 26 | ckpt_path = args.ckpt 27 | root = args.root 28 | 29 | network_cfg = OmegaConf.load(cfg_path).network 30 | network_name = network_cfg.name 31 | model = build_from_cfg(network_cfg) 32 | ckpt = torch.load(ckpt_path) 33 | 
model.load_state_dict(ckpt['state_dict']) 34 | model = model.to(device) 35 | model.eval() 36 | 37 | dirs = sorted(os.listdir(root)) 38 | psnr_list = [] 39 | ssim_list = [] 40 | pbar = tqdm.tqdm(dirs, total=len(dirs)) 41 | for d in pbar: 42 | dir_path = osp.join(root, d) 43 | I0 = img2tensor(read(osp.join(dir_path, 'frame_00.png'))).to(device) 44 | I1 = img2tensor(read(osp.join(dir_path, 'frame_01_gt.png'))).to(device) 45 | I2 = img2tensor(read(osp.join(dir_path, 'frame_02.png'))).to(device) 46 | embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device) 47 | 48 | I1_pred = model(I0, I2, embt, eval=True)['imgt_pred'] 49 | 50 | psnr = calculate_psnr(I1_pred, I1).detach().cpu().numpy() 51 | ssim = calculate_ssim(I1_pred, I1).detach().cpu().numpy() 52 | 53 | psnr_list.append(psnr) 54 | ssim_list.append(ssim) 55 | 56 | avg_psnr = np.mean(psnr_list) 57 | avg_ssim = np.mean(ssim_list) 58 | desc_str = f'[{network_name}/UCF101] psnr: {avg_psnr:.02f}, ssim: {avg_ssim:.04f}' 59 | pbar.set_description_str(desc_str) -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/benchmarks/vimeo90k_tta.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import tqdm 3 | import torch 4 | import argparse 5 | import numpy as np 6 | import os.path as osp 7 | from omegaconf import OmegaConf 8 | 9 | sys.path.append('.') 10 | from utils.utils import read, img2tensor 11 | from utils.build_utils import build_from_cfg 12 | from metrics.psnr_ssim import calculate_psnr, calculate_ssim 13 | 14 | parser = argparse.ArgumentParser( 15 | prog = 'AMT', 16 | description = 'Vimeo90K evaluation (with Test-Time Augmentation)', 17 | ) 18 | parser.add_argument('-c', '--config', default='cfgs/AMT-S.yaml') 19 | parser.add_argument('-p', '--ckpt', default='pretrained/amt-s.pth',) 20 | parser.add_argument('-r', '--root', default='data/vimeo_triplet',) 21 | args = parser.parse_args() 22 | 23 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 24 | cfg_path = args.config 25 | ckpt_path = args.ckpt 26 | root = args.root 27 | 28 | network_cfg = OmegaConf.load(cfg_path).network 29 | network_name = network_cfg.name 30 | model = build_from_cfg(network_cfg) 31 | ckpt = torch.load(ckpt_path) 32 | model.load_state_dict(ckpt['state_dict']) 33 | model = model.to(device) 34 | model.eval() 35 | 36 | with open(osp.join(root, 'tri_testlist.txt'), 'r') as fr: 37 | file_list = fr.readlines() 38 | 39 | psnr_list = [] 40 | ssim_list = [] 41 | 42 | pbar = tqdm.tqdm(file_list, total=len(file_list)) 43 | for name in pbar: 44 | name = str(name).strip() 45 | if(len(name) <= 1): 46 | continue 47 | dir_path = osp.join(root, 'sequences', name) 48 | I0 = img2tensor(read(osp.join(dir_path, 'im1.png'))).to(device) 49 | I1 = img2tensor(read(osp.join(dir_path, 'im2.png'))).to(device) 50 | I2 = img2tensor(read(osp.join(dir_path, 'im3.png'))).to(device) 51 | embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device) 52 | 53 | I1_pred1 = model(I0, I2, embt, 54 | scale_factor=1.0, eval=True)['imgt_pred'] 55 | I1_pred2 = model(torch.flip(I0, [2]), torch.flip(I2, [2]), embt, 56 | scale_factor=1.0, eval=True)['imgt_pred'] 57 | I1_pred = I1_pred1 / 2 + torch.flip(I1_pred2, [2]) / 2 58 | psnr = calculate_psnr(I1_pred, I1).detach().cpu().numpy() 59 | ssim = calculate_ssim(I1_pred, I1).detach().cpu().numpy() 60 | 61 | psnr_list.append(psnr) 62 | ssim_list.append(ssim) 63 | avg_psnr = np.mean(psnr_list) 64 | avg_ssim = np.mean(ssim_list) 65 | desc_str = 
f'[{network_name}/Vimeo90K] psnr: {avg_psnr:.02f}, ssim: {avg_ssim:.04f}' 66 | pbar.set_description_str(desc_str) 67 | 68 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/cfgs/AMT-G.yaml: -------------------------------------------------------------------------------- 1 | exp_name: floloss1e-2_300epoch_bs24_lr1p5e-4 2 | seed: 2023 3 | epochs: 300 4 | distributed: true 5 | lr: 1.5e-4 6 | lr_min: 2e-5 7 | weight_decay: 0.0 8 | resume_state: null 9 | save_dir: work_dir 10 | eval_interval: 1 11 | 12 | network: 13 | name: networks.AMT-G.Model 14 | params: 15 | corr_radius: 3 16 | corr_lvls: 4 17 | num_flows: 5 18 | data: 19 | train: 20 | name: datasets.vimeo_datasets.Vimeo90K_Train_Dataset 21 | params: 22 | dataset_dir: data/vimeo_triplet 23 | val: 24 | name: datasets.vimeo_datasets.Vimeo90K_Test_Dataset 25 | params: 26 | dataset_dir: data/vimeo_triplet 27 | train_loader: 28 | batch_size: 24 29 | num_workers: 12 30 | val_loader: 31 | batch_size: 24 32 | num_workers: 3 33 | 34 | logger: 35 | use_wandb: true 36 | resume_id: null 37 | 38 | losses: 39 | - { 40 | name: losses.loss.CharbonnierLoss, 41 | nickname: l_rec, 42 | params: { 43 | loss_weight: 1.0, 44 | keys: [imgt_pred, imgt] 45 | } 46 | } 47 | - { 48 | name: losses.loss.TernaryLoss, 49 | nickname: l_ter, 50 | params: { 51 | loss_weight: 1.0, 52 | keys: [imgt_pred, imgt] 53 | } 54 | } 55 | - { 56 | name: losses.loss.MultipleFlowLoss, 57 | nickname: l_flo, 58 | params: { 59 | loss_weight: 0.005, 60 | keys: [flow0_pred, flow1_pred, flow] 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/cfgs/AMT-L.yaml: -------------------------------------------------------------------------------- 1 | exp_name: floloss1e-2_300epoch_bs24_lr2e-4 2 | seed: 2023 3 | epochs: 300 4 | distributed: true 5 | lr: 2e-4 6 | lr_min: 2e-5 7 | weight_decay: 0.0 8 | resume_state: null 9 | save_dir: work_dir 10 | eval_interval: 1 11 | 12 | network: 13 | name: networks.AMT-L.Model 14 | params: 15 | corr_radius: 3 16 | corr_lvls: 4 17 | num_flows: 5 18 | data: 19 | train: 20 | name: datasets.vimeo_datasets.Vimeo90K_Train_Dataset 21 | params: 22 | dataset_dir: data/vimeo_triplet 23 | val: 24 | name: datasets.vimeo_datasets.Vimeo90K_Test_Dataset 25 | params: 26 | dataset_dir: data/vimeo_triplet 27 | train_loader: 28 | batch_size: 24 29 | num_workers: 12 30 | val_loader: 31 | batch_size: 24 32 | num_workers: 3 33 | 34 | logger: 35 | use_wandb: true 36 | resume_id: null 37 | 38 | losses: 39 | - { 40 | name: losses.loss.CharbonnierLoss, 41 | nickname: l_rec, 42 | params: { 43 | loss_weight: 1.0, 44 | keys: [imgt_pred, imgt] 45 | } 46 | } 47 | - { 48 | name: losses.loss.TernaryLoss, 49 | nickname: l_ter, 50 | params: { 51 | loss_weight: 1.0, 52 | keys: [imgt_pred, imgt] 53 | } 54 | } 55 | - { 56 | name: losses.loss.MultipleFlowLoss, 57 | nickname: l_flo, 58 | params: { 59 | loss_weight: 0.002, 60 | keys: [flow0_pred, flow1_pred, flow] 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/cfgs/AMT-S.yaml: -------------------------------------------------------------------------------- 1 | exp_name: floloss1e-2_300epoch_bs24_lr2e-4 2 | seed: 2023 3 | epochs: 300 4 | distributed: true 5 | lr: 2e-4 6 | lr_min: 2e-5 7 | weight_decay: 0.0 8 | resume_state: null 9 | save_dir: work_dir 10 | eval_interval: 1 11 | 12 | network: 13 | name: networks.AMT-S.Model 14 | params: 15 | 
corr_radius: 3 16 | corr_lvls: 4 17 | num_flows: 3 18 | 19 | data: 20 | train: 21 | name: datasets.vimeo_datasets.Vimeo90K_Train_Dataset 22 | params: 23 | dataset_dir: data/vimeo_triplet 24 | val: 25 | name: datasets.vimeo_datasets.Vimeo90K_Test_Dataset 26 | params: 27 | dataset_dir: data/vimeo_triplet 28 | train_loader: 29 | batch_size: 24 30 | num_workers: 12 31 | val_loader: 32 | batch_size: 24 33 | num_workers: 3 34 | 35 | logger: 36 | use_wandb: false 37 | resume_id: null 38 | 39 | losses: 40 | - { 41 | name: losses.loss.CharbonnierLoss, 42 | nickname: l_rec, 43 | params: { 44 | loss_weight: 1.0, 45 | keys: [imgt_pred, imgt] 46 | } 47 | } 48 | - { 49 | name: losses.loss.TernaryLoss, 50 | nickname: l_ter, 51 | params: { 52 | loss_weight: 1.0, 53 | keys: [imgt_pred, imgt] 54 | } 55 | } 56 | - { 57 | name: losses.loss.MultipleFlowLoss, 58 | nickname: l_flo, 59 | params: { 60 | loss_weight: 0.002, 61 | keys: [flow0_pred, flow1_pred, flow] 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/cfgs/AMT-S_gopro.yaml: -------------------------------------------------------------------------------- 1 | exp_name: wofloloss_400epoch_bs24_lr2e-4 2 | seed: 2023 3 | epochs: 400 4 | distributed: true 5 | lr: 2e-4 6 | lr_min: 2e-5 7 | weight_decay: 0.0 8 | resume_state: null 9 | save_dir: work_dir 10 | eval_interval: 1 11 | 12 | network: 13 | name: networks.AMT-S.Model 14 | params: 15 | corr_radius: 3 16 | corr_lvls: 4 17 | num_flows: 3 18 | 19 | data: 20 | train: 21 | name: datasets.gopro_datasets.GoPro_Train_Dataset 22 | params: 23 | dataset_dir: data/GOPRO 24 | val: 25 | name: datasets.gopro_datasets.GoPro_Test_Dataset 26 | params: 27 | dataset_dir: data/GOPRO 28 | train_loader: 29 | batch_size: 24 30 | num_workers: 12 31 | val_loader: 32 | batch_size: 24 33 | num_workers: 3 34 | 35 | logger: 36 | use_wandb: false 37 | resume_id: null 38 | 39 | losses: 40 | - { 41 | name: losses.loss.CharbonnierLoss, 42 | nickname: l_rec, 43 | params: { 44 | loss_weight: 1.0, 45 | keys: [imgt_pred, imgt] 46 | } 47 | } 48 | - { 49 | name: losses.loss.TernaryLoss, 50 | nickname: l_ter, 51 | params: { 52 | loss_weight: 1.0, 53 | keys: [imgt_pred, imgt] 54 | } 55 | } 56 | 57 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/cfgs/AMT-S_septuplet.yaml: -------------------------------------------------------------------------------- 1 | # CUDA_VISIBLE_DEVICES=4,5,6,7 sh ./scripts/train.sh 4 cfgs/AMT-S_septuplet.yaml 14514 2 | # CUDA_VISIBLE_DEVICES=4,5,6,7 screen sh ./scripts/train.sh 4 cfgs/AMT-S_septuplet.yaml 14514 3 | exp_name: 400epoch_bs24_lr2e-4 4 | seed: 2023 5 | epochs: 400 6 | distributed: true 7 | lr: 2e-4 8 | lr_min: 2e-5 9 | weight_decay: 0.0 10 | resume_state: null 11 | save_dir: experiments 12 | eval_interval: 1 13 | 14 | network: 15 | name: networks.AMT-S.Model 16 | params: 17 | corr_radius: 3 18 | corr_lvls: 4 19 | num_flows: 3 20 | 21 | data: 22 | train: 23 | name: datasets.vimeo_septuplet_datasets.Vimeo90K_Train_Dataset 24 | params: 25 | dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 26 | use_flow: true 27 | val: 28 | name: datasets.vimeo_septuplet_datasets.Vimeo90K_Test_Dataset 29 | params: 30 | dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 31 | use_flow: true 32 | train_loader: 33 | batch_size: 24 34 | num_workers: 12 35 | val_loader: 36 | batch_size: 24 37 | num_workers: 3 38 | 39 | logger: 40 | use_wandb: false 41 | resume_id: null 42 | 43 | losses: 44 
| - { 45 | name: losses.loss.CharbonnierLoss, 46 | nickname: l_rec, 47 | params: { 48 | loss_weight: 1.0, 49 | keys: [ imgt_pred, imgt ] 50 | } 51 | } 52 | - { 53 | name: losses.loss.TernaryLoss, 54 | nickname: l_ter, 55 | params: { 56 | loss_weight: 1.0, 57 | keys: [ imgt_pred, imgt ] 58 | } 59 | } 60 | - { 61 | name: losses.loss.MultipleFlowLoss, 62 | nickname: l_flo, 63 | params: { 64 | loss_weight: 0.002, 65 | keys: [ flow0_pred, flow1_pred, flow ] 66 | } 67 | } 68 | 69 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/cfgs/AMT-S_septuplet_wofloloss.yaml: -------------------------------------------------------------------------------- 1 | # CUDA_VISISIBLE_DEVICES=4,5,6,7 sh ./scripts/train.sh 4 cfgs/AMT-S_septuplet_wofloloss.yaml 14514 2 | # CUDA_VISISIBLE_DEVICES=0,1,2,3,4,5,6,7 screen sh ./scripts/train.sh 8 cfgs/AMT-S_septuplet_wofloloss.yaml 14514 3 | #exp_name: 400epoch_bs24_lr2e-4 4 | exp_name: T-AMT-S 5 | seed: 2023 6 | epochs: 400 7 | distributed: true 8 | lr: 2e-4 9 | lr_min: 2e-5 10 | weight_decay: 0.0 11 | resume_state: null 12 | save_dir: experiments 13 | eval_interval: 1 14 | 15 | network: 16 | name: networks.AMT-S.Model 17 | params: 18 | corr_radius: 3 19 | corr_lvls: 4 20 | num_flows: 3 21 | 22 | data: 23 | train: 24 | name: datasets.vimeo_septuplet_datasets.Vimeo90K_Train_Dataset 25 | params: 26 | # dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 27 | dataset_dir: ../../dataset/vimeo_septuplet 28 | val: 29 | name: datasets.vimeo_septuplet_datasets.Vimeo90K_Test_Dataset 30 | params: 31 | # dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 32 | dataset_dir: ../../dataset/vimeo_septuplet 33 | train_loader: 34 | batch_size: 24 35 | num_workers: 12 36 | val_loader: 37 | batch_size: 24 38 | num_workers: 3 39 | 40 | logger: 41 | use_wandb: false 42 | resume_id: null 43 | 44 | losses: 45 | - { 46 | name: losses.loss.CharbonnierLoss, 47 | nickname: l_rec, 48 | params: { 49 | loss_weight: 1.0, 50 | keys: [ imgt_pred, imgt ] 51 | } 52 | } 53 | - { 54 | name: losses.loss.TernaryLoss, 55 | nickname: l_ter, 56 | params: { 57 | loss_weight: 1.0, 58 | keys: [ imgt_pred, imgt ] 59 | } 60 | 61 | } 62 | 63 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/cfgs/IFRNet.yaml: -------------------------------------------------------------------------------- 1 | exp_name: floloss1e-2_geoloss1e-2_300epoch_bs24_lr1e-4 2 | seed: 2023 3 | epochs: 300 4 | distributed: true 5 | lr: 1e-4 6 | lr_min: 1e-5 7 | weight_decay: 1e-6 8 | resume_state: null 9 | save_dir: work_dir 10 | eval_interval: 1 11 | 12 | network: 13 | name: networks.IFRNet.Model 14 | 15 | data: 16 | train: 17 | name: datasets.datasets.Vimeo90K_Train_Dataset 18 | params: 19 | dataset_dir: data/vimeo_triplet 20 | val: 21 | name: datasets.datasets.Vimeo90K_Test_Dataset 22 | params: 23 | dataset_dir: data/vimeo_triplet 24 | train_loader: 25 | batch_size: 24 26 | num_workers: 12 27 | val_loader: 28 | batch_size: 24 29 | num_workers: 3 30 | 31 | logger: 32 | use_wandb: true 33 | resume_id: null 34 | 35 | losses: 36 | - { 37 | name: losses.loss.CharbonnierLoss, 38 | nickname: l_rec, 39 | params: { 40 | loss_weight: 1.0, 41 | keys: [imgt_pred, imgt] 42 | } 43 | } 44 | - { 45 | name: losses.loss.TernaryLoss, 46 | nickname: l_ter, 47 | params: { 48 | loss_weight: 1.0, 49 | keys: [imgt_pred, imgt] 50 | } 51 | } 52 | - { 53 | name: losses.loss.IFRFlowLoss, 54 | nickname: l_flo, 55 | params: { 56 | loss_weight: 0.01, 57 | 
keys: [flow0_pred, flow1_pred, flow] 58 | } 59 | } 60 | - { 61 | name: losses.loss.GeometryLoss, 62 | nickname: l_geo, 63 | params: { 64 | loss_weight: 0.01, 65 | keys: [ft_pred, ft_gt] 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/cfgs/IFRNet_septuplet_wofloloss.yaml: -------------------------------------------------------------------------------- 1 | # screen sh ./scripts/train.sh 4 cfgs/IFRNet_septuplet_wofloloss.yaml 14514 2 | #exp_name: 300epoch_bs24_lr1e-4 3 | exp_name: T-IFRNet 4 | seed: 2023 5 | epochs: 300 6 | distributed: true 7 | lr: 1e-4 8 | lr_min: 1e-5 9 | weight_decay: 1e-6 10 | resume_state: null 11 | save_dir: experiments 12 | eval_interval: 1 13 | 14 | network: 15 | name: networks.IFRNet.Model 16 | 17 | data: 18 | train: 19 | name: datasets.vimeo_septuplet_datasets.Vimeo90K_Train_Dataset 20 | params: 21 | # dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 22 | dataset_dir: ../../dataset/vimeo_septuplet 23 | val: 24 | name: datasets.vimeo_septuplet_datasets.Vimeo90K_Test_Dataset 25 | params: 26 | # dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 27 | dataset_dir: ../../dataset/vimeo_septuplet 28 | train_loader: 29 | batch_size: 24 30 | num_workers: 12 31 | val_loader: 32 | batch_size: 24 33 | num_workers: 3 34 | 35 | logger: 36 | use_wandb: false 37 | resume_id: null 38 | 39 | losses: 40 | - { 41 | name: losses.loss.CharbonnierLoss, 42 | nickname: l_rec, 43 | params: { 44 | loss_weight: 1.0, 45 | keys: [ imgt_pred, imgt ] 46 | } 47 | } 48 | - { 49 | name: losses.loss.TernaryLoss, 50 | nickname: l_ter, 51 | params: { 52 | loss_weight: 1.0, 53 | keys: [ imgt_pred, imgt ] 54 | } 55 | } 56 | # - { 57 | # name: losses.loss.GeometryLoss, 58 | # nickname: l_geo, 59 | # params: { 60 | # loss_weight: 0.01, 61 | # keys: [ ft_pred, ft_gt ] 62 | # } 63 | # } 64 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/cfgs/M-SDI-AMT-S_septuplet_wofloloss.yaml: -------------------------------------------------------------------------------- 1 | # CUDA_VISISIBLE_DEVICES=4,5,6,7 sh ./scripts/train.sh 4 cfgs/M-SDI-AMT-S_septuplet_wofloloss.yaml 14514 2 | # CUDA_VISISIBLE_DEVICES=0,1,2,3,4,5,6,7 screen sh ./scripts/train.sh 8 cfgs/M-SDI-AMT-S_septuplet_wofloloss.yaml 14514 3 | exp_name: 400epoch_bs24_lr2e-4 4 | seed: 2023 5 | epochs: 400 6 | distributed: true 7 | lr: 2e-4 8 | lr_min: 2e-5 9 | weight_decay: 0.0 10 | resume_state: null 11 | save_dir: experiments 12 | eval_interval: 1 13 | 14 | network: 15 | name: networks.SDI-AMT-S.Model 16 | params: 17 | corr_radius: 3 18 | corr_lvls: 4 19 | num_flows: 3 20 | 21 | data: 22 | train: 23 | name: datasets.vimeo_septuplet_datasets.Vimeo90K_Train_Dataset 24 | params: 25 | dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 26 | use_sdi: true 27 | use_mask: true 28 | val: 29 | name: datasets.vimeo_septuplet_datasets.Vimeo90K_Test_Dataset 30 | params: 31 | dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 32 | use_sdi: true 33 | use_mask: true 34 | train_loader: 35 | batch_size: 24 36 | num_workers: 12 37 | val_loader: 38 | batch_size: 24 39 | num_workers: 3 40 | 41 | logger: 42 | use_wandb: false 43 | resume_id: null 44 | 45 | losses: 46 | - { 47 | name: losses.loss.CharbonnierLoss, 48 | nickname: l_rec, 49 | params: { 50 | loss_weight: 1.0, 51 | keys: [ imgt_pred, imgt ] 52 | } 53 | } 54 | - { 55 | name: losses.loss.TernaryLoss, 56 | nickname: l_ter, 57 | params: { 58 | loss_weight: 1.0, 59 | keys: [ 
imgt_pred, imgt ] 60 | } 61 | } 62 | 63 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/cfgs/M-SDI-R-AMT-S_septuplet_wofloloss.yaml: -------------------------------------------------------------------------------- 1 | # CUDA_VISISIBLE_DEVICES=4,5,6,7 screen sh ./scripts/train.sh 4 cfgs/M-SDI-R-AMT-S_septuplet_wofloloss.yaml 14514 2 | # CUDA_VISISIBLE_DEVICES=0,1,2,3,4,5,6,7 screen sh ./scripts/train.sh 8 cfgs/M-SDI-R-AMT-S_septuplet_wofloloss.yaml 14514 3 | exp_name: 400epoch_bs24_lr2e-4 4 | seed: 2023 5 | epochs: 400 6 | distributed: true 7 | lr: 2e-4 8 | lr_min: 2e-5 9 | weight_decay: 0.0 10 | resume_state: null 11 | save_dir: experiments 12 | eval_interval: 1 13 | 14 | network: 15 | name: networks.SDI-R-AMT-S.Model 16 | params: 17 | corr_radius: 3 18 | corr_lvls: 4 19 | num_flows: 3 20 | 21 | data: 22 | train: 23 | name: datasets.vimeo_septuplet_recur_datasets.Vimeo90K_Train_Dataset 24 | params: 25 | dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 26 | use_sdi: true 27 | use_mask: true 28 | val: 29 | name: datasets.vimeo_septuplet_recur_datasets.Vimeo90K_Test_Dataset 30 | params: 31 | dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 32 | use_sdi: true 33 | use_mask: true 34 | train_loader: 35 | batch_size: 24 36 | num_workers: 12 37 | val_loader: 38 | batch_size: 24 39 | num_workers: 3 40 | 41 | logger: 42 | use_wandb: false 43 | resume_id: null 44 | 45 | losses: 46 | - { 47 | name: losses.loss.CharbonnierLoss, 48 | nickname: l_rec, 49 | params: { 50 | loss_weight: 1.0, 51 | keys: [ imgt_pred, imgt ] 52 | } 53 | } 54 | - { 55 | name: losses.loss.TernaryLoss, 56 | nickname: l_ter, 57 | params: { 58 | loss_weight: 1.0, 59 | keys: [ imgt_pred, imgt ] 60 | } 61 | } 62 | 63 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/cfgs/M-SDI-R-AMT-S_v1_septuplet_wofloloss.yaml: -------------------------------------------------------------------------------- 1 | # screen sh ./scripts/train.sh 4 cfgs/M-SDI-R-AMT-S_v1_septuplet_wofloloss.yaml 14514 2 | exp_name: 400epoch_bs24_lr2e-4 3 | seed: 2023 4 | epochs: 400 5 | distributed: true 6 | lr: 2e-4 7 | lr_min: 2e-5 8 | weight_decay: 0.0 9 | resume_state: null 10 | save_dir: experiments 11 | eval_interval: 1 12 | 13 | network: 14 | name: networks.SDI-R-AMT-S_v1.Model 15 | params: 16 | corr_radius: 3 17 | corr_lvls: 4 18 | num_flows: 3 19 | 20 | data: 21 | train: 22 | name: datasets.vimeo_septuplet_recur_datasets.Vimeo90K_Train_Dataset 23 | params: 24 | dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 25 | use_sdi: true 26 | use_mask: true 27 | val: 28 | name: datasets.vimeo_septuplet_recur_datasets.Vimeo90K_Test_Dataset 29 | params: 30 | dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 31 | use_sdi: true 32 | use_mask: true 33 | train_loader: 34 | batch_size: 24 35 | num_workers: 12 36 | val_loader: 37 | batch_size: 24 38 | num_workers: 3 39 | 40 | logger: 41 | use_wandb: false 42 | resume_id: null 43 | 44 | losses: 45 | - { 46 | name: losses.loss.CharbonnierLoss, 47 | nickname: l_rec, 48 | params: { 49 | loss_weight: 1.0, 50 | keys: [ imgt_pred, imgt ] 51 | } 52 | } 53 | - { 54 | name: losses.loss.TernaryLoss, 55 | nickname: l_ter, 56 | params: { 57 | loss_weight: 1.0, 58 | keys: [ imgt_pred, imgt ] 59 | } 60 | } 61 | 62 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/cfgs/R-AMT-S_v1_septuplet_wofloloss.yaml: 
-------------------------------------------------------------------------------- 1 | # screen sh ./scripts/train.sh 4 cfgs/R-AMT-S_v1_septuplet_wofloloss.yaml 14514 2 | #exp_name: 400epoch_bs24_lr2e-4 3 | exp_name: TR-AMT-S 4 | seed: 2023 5 | epochs: 400 6 | distributed: true 7 | lr: 2e-4 8 | lr_min: 2e-5 9 | weight_decay: 0.0 10 | resume_state: null 11 | save_dir: experiments 12 | eval_interval: 1 13 | 14 | network: 15 | name: networks.SDI-R-AMT-S_v1.Model 16 | params: 17 | corr_radius: 3 18 | corr_lvls: 4 19 | num_flows: 3 20 | 21 | data: 22 | train: 23 | name: datasets.vimeo_septuplet_recur_datasets.Vimeo90K_Train_Dataset 24 | params: 25 | # dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 26 | dataset_dir: ../../dataset/vimeo_septuplet 27 | use_sdi: false 28 | use_mask: false 29 | val: 30 | name: datasets.vimeo_septuplet_recur_datasets.Vimeo90K_Test_Dataset 31 | params: 32 | # dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 33 | dataset_dir: ../../dataset/vimeo_septuplet 34 | use_sdi: false 35 | use_mask: false 36 | train_loader: 37 | batch_size: 24 38 | num_workers: 12 39 | val_loader: 40 | batch_size: 24 41 | num_workers: 3 42 | 43 | logger: 44 | use_wandb: false 45 | resume_id: null 46 | 47 | losses: 48 | - { 49 | name: losses.loss.CharbonnierLoss, 50 | nickname: l_rec, 51 | params: { 52 | loss_weight: 1.0, 53 | keys: [ imgt_pred, imgt ] 54 | } 55 | } 56 | - { 57 | name: losses.loss.TernaryLoss, 58 | nickname: l_ter, 59 | params: { 60 | loss_weight: 1.0, 61 | keys: [ imgt_pred, imgt ] 62 | } 63 | } 64 | 65 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/cfgs/R-IFRNet_septuplet_wofloloss.yaml: -------------------------------------------------------------------------------- 1 | # screen sh ./scripts/train.sh 4 cfgs/R-IFRNet_septuplet_wofloloss.yaml 14514 2 | #exp_name: 300epoch_bs24_lr1e-4 3 | exp_name: TR-IFRNet 4 | seed: 2023 5 | epochs: 300 6 | distributed: true 7 | lr: 1e-4 8 | lr_min: 1e-5 9 | weight_decay: 1e-6 10 | resume_state: null 11 | save_dir: experiments 12 | eval_interval: 1 13 | 14 | network: 15 | name: networks.SDI-R-IFRNet.Model 16 | 17 | data: 18 | train: 19 | name: datasets.vimeo_septuplet_recur_datasets.Vimeo90K_Train_Dataset 20 | params: 21 | # dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 22 | dataset_dir: ../../dataset/vimeo_septuplet 23 | use_sdi: false 24 | use_mask: false 25 | val: 26 | name: datasets.vimeo_septuplet_recur_datasets.Vimeo90K_Test_Dataset 27 | params: 28 | # dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 29 | dataset_dir: ../../dataset/vimeo_septuplet 30 | use_sdi: false 31 | use_mask: false 32 | train_loader: 33 | batch_size: 24 34 | num_workers: 12 35 | val_loader: 36 | batch_size: 24 37 | num_workers: 3 38 | 39 | logger: 40 | use_wandb: false 41 | resume_id: null 42 | 43 | losses: 44 | - { 45 | name: losses.loss.CharbonnierLoss, 46 | nickname: l_rec, 47 | params: { 48 | loss_weight: 1.0, 49 | keys: [ imgt_pred, imgt ] 50 | } 51 | } 52 | - { 53 | name: losses.loss.TernaryLoss, 54 | nickname: l_ter, 55 | params: { 56 | loss_weight: 1.0, 57 | keys: [ imgt_pred, imgt ] 58 | } 59 | } 60 | # - { 61 | # name: losses.loss.GeometryLoss, 62 | # nickname: l_geo, 63 | # params: { 64 | # loss_weight: 0.01, 65 | # keys: [ ft_pred, ft_gt ] 66 | # } 67 | # } 68 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/cfgs/SDI-AMT-S_septuplet_wofloloss.yaml: 
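Every config above pairs a starting `lr` with a floor `lr_min` over a fixed number of `epochs`, but the scheduler itself is not shown in this excerpt. The sketch below assumes cosine annealing (a common pairing for such lr/lr_min settings in IFRNet/AMT-style training code) purely to illustrate how the two values interact; the model and optimizer here are placeholders, not repo code.

```
import torch

model = torch.nn.Linear(4, 4)  # placeholder for the configured network
opt = torch.optim.Adam(model.parameters(), lr=2e-4, weight_decay=0.0)
# assumed schedule: lr_t = lr_min + 0.5 * (lr - lr_min) * (1 + cos(pi * t / epochs))
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=400, eta_min=2e-5)
for epoch in range(400):
    # ... one training epoch would run here ...
    opt.step()
    sched.step()  # decay once per epoch: 2e-4 at epoch 0 down to 2e-5 at epoch 400
```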
-------------------------------------------------------------------------------- 1 | # CUDA_VISISIBLE_DEVICES=4,5,6,7 sh ./scripts/train.sh 4 cfgs/SDI-AMT-S_septuplet_wofloloss.yaml 14514 2 | # CUDA_VISISIBLE_DEVICES=0,1,2,3,4,5,6,7 screen sh ./scripts/train.sh 8 cfgs/SDI-AMT-S_septuplet_wofloloss.yaml 14514 3 | #exp_name: 400epoch_bs24_lr2e-4 4 | exp_name: D-AMT-S 5 | seed: 2023 6 | epochs: 400 7 | distributed: true 8 | lr: 2e-4 9 | lr_min: 2e-5 10 | weight_decay: 0.0 11 | resume_state: null 12 | save_dir: experiments 13 | eval_interval: 1 14 | 15 | network: 16 | name: networks.SDI-AMT-S.Model 17 | params: 18 | corr_radius: 3 19 | corr_lvls: 4 20 | num_flows: 3 21 | 22 | data: 23 | train: 24 | name: datasets.vimeo_septuplet_datasets.Vimeo90K_Train_Dataset 25 | params: 26 | # dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 27 | dataset_dir: ../../dataset/vimeo_septuplet 28 | use_sdi: true 29 | use_mask: false 30 | val: 31 | name: datasets.vimeo_septuplet_datasets.Vimeo90K_Test_Dataset 32 | params: 33 | # dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 34 | dataset_dir: ../../dataset/vimeo_septuplet 35 | use_sdi: true 36 | use_mask: false 37 | train_loader: 38 | batch_size: 24 39 | num_workers: 12 40 | val_loader: 41 | batch_size: 24 42 | num_workers: 3 43 | 44 | logger: 45 | use_wandb: false 46 | resume_id: null 47 | 48 | losses: 49 | - { 50 | name: losses.loss.CharbonnierLoss, 51 | nickname: l_rec, 52 | params: { 53 | loss_weight: 1.0, 54 | keys: [ imgt_pred, imgt ] 55 | } 56 | } 57 | - { 58 | name: losses.loss.TernaryLoss, 59 | nickname: l_ter, 60 | params: { 61 | loss_weight: 1.0, 62 | keys: [ imgt_pred, imgt ] 63 | } 64 | } 65 | 66 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/cfgs/SDI-AMT-S_triplet_wofloloss.yaml: -------------------------------------------------------------------------------- 1 | # CUDA_VISISIBLE_DEVICES=0,1,2,3 screen sh ./scripts/train.sh 4 cfgs/SDI-AMT-S_triplet_wofloloss.yaml 14514 2 | exp_name: 300epoch_bs24_lr2e-4 3 | seed: 2023 4 | epochs: 300 5 | distributed: true 6 | lr: 2e-4 7 | lr_min: 2e-5 8 | weight_decay: 0.0 9 | resume_state: null 10 | save_dir: experiments 11 | eval_interval: 1 12 | 13 | network: 14 | name: networks.SDI-AMT-S.Model 15 | params: 16 | corr_radius: 3 17 | corr_lvls: 4 18 | num_flows: 3 19 | 20 | data: 21 | train: 22 | name: datasets.vimeo_septuplet_datasets.Vimeo90K_Train_Dataset 23 | params: 24 | dataset_dir: /mnt/disks/ssd0/dataset/vimeo_triplet 25 | use_sdi: true 26 | use_mask: false 27 | triplet: true 28 | val: 29 | name: datasets.vimeo_septuplet_datasets.Vimeo90K_Test_Dataset 30 | params: 31 | dataset_dir: /mnt/disks/ssd0/dataset/vimeo_triplet 32 | use_sdi: true 33 | use_mask: false 34 | triplet: true 35 | train_loader: 36 | batch_size: 24 37 | num_workers: 12 38 | val_loader: 39 | batch_size: 24 40 | num_workers: 3 41 | 42 | logger: 43 | use_wandb: false 44 | resume_id: null 45 | 46 | losses: 47 | - { 48 | name: losses.loss.CharbonnierLoss, 49 | nickname: l_rec, 50 | params: { 51 | loss_weight: 1.0, 52 | keys: [ imgt_pred, imgt ] 53 | } 54 | } 55 | - { 56 | name: losses.loss.TernaryLoss, 57 | nickname: l_ter, 58 | params: { 59 | loss_weight: 1.0, 60 | keys: [ imgt_pred, imgt ] 61 | } 62 | } 63 | 64 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/cfgs/SDI-IFRNet_septuplet_wofloloss.yaml: -------------------------------------------------------------------------------- 1 | # screen sh ./scripts/train.sh 
4 cfgs/SDI-IFRNet_septuplet_wofloloss.yaml 14514 2 | #exp_name: 300epoch_bs24_lr1e-4 3 | exp_name: D-IFRNet 4 | seed: 2023 5 | epochs: 300 6 | distributed: true 7 | lr: 1e-4 8 | lr_min: 1e-5 9 | weight_decay: 1e-6 10 | resume_state: null 11 | save_dir: experiments 12 | eval_interval: 1 13 | 14 | network: 15 | name: networks.SDI-IFRNet.Model 16 | 17 | data: 18 | train: 19 | name: datasets.vimeo_septuplet_datasets.Vimeo90K_Train_Dataset 20 | params: 21 | # dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 22 | dataset_dir: ../../dataset/vimeo_septuplet 23 | use_sdi: true 24 | use_mask: false 25 | val: 26 | name: datasets.vimeo_septuplet_datasets.Vimeo90K_Test_Dataset 27 | params: 28 | # dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 29 | dataset_dir: ../../dataset/vimeo_septuplet 30 | use_sdi: true 31 | use_mask: false 32 | train_loader: 33 | batch_size: 24 34 | num_workers: 12 35 | val_loader: 36 | batch_size: 24 37 | num_workers: 3 38 | 39 | logger: 40 | use_wandb: false 41 | resume_id: null 42 | 43 | losses: 44 | - { 45 | name: losses.loss.CharbonnierLoss, 46 | nickname: l_rec, 47 | params: { 48 | loss_weight: 1.0, 49 | keys: [ imgt_pred, imgt ] 50 | } 51 | } 52 | - { 53 | name: losses.loss.TernaryLoss, 54 | nickname: l_ter, 55 | params: { 56 | loss_weight: 1.0, 57 | keys: [ imgt_pred, imgt ] 58 | } 59 | } 60 | # - { 61 | # name: losses.loss.GeometryLoss, 62 | # nickname: l_geo, 63 | # params: { 64 | # loss_weight: 0.01, 65 | # keys: [ ft_pred, ft_gt ] 66 | # } 67 | # } 68 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/cfgs/SDI-IFRNet_triplet_wofloloss.yaml: -------------------------------------------------------------------------------- 1 | # screen sh ./scripts/train.sh 4 cfgs/SDI-IFRNet_triplet_wofloloss.yaml 14514 2 | exp_name: 300epoch_bs24_lr1e-4 3 | seed: 2023 4 | epochs: 300 5 | distributed: true 6 | lr: 1e-4 7 | lr_min: 1e-5 8 | weight_decay: 1e-6 9 | resume_state: null 10 | save_dir: experiments 11 | eval_interval: 1 12 | 13 | network: 14 | name: networks.SDI-IFRNet.Model 15 | 16 | data: 17 | train: 18 | name: datasets.vimeo_septuplet_datasets.Vimeo90K_Train_Dataset 19 | params: 20 | dataset_dir: /mnt/disks/ssd0/dataset/vimeo_triplet 21 | use_sdi: true 22 | use_mask: false 23 | triplet: true 24 | val: 25 | name: datasets.vimeo_septuplet_datasets.Vimeo90K_Test_Dataset 26 | params: 27 | dataset_dir: /mnt/disks/ssd0/dataset/vimeo_triplet 28 | use_sdi: true 29 | use_mask: false 30 | triplet: true 31 | train_loader: 32 | batch_size: 24 33 | num_workers: 12 34 | val_loader: 35 | batch_size: 24 36 | num_workers: 3 37 | 38 | logger: 39 | use_wandb: false 40 | resume_id: null 41 | 42 | losses: 43 | - { 44 | name: losses.loss.CharbonnierLoss, 45 | nickname: l_rec, 46 | params: { 47 | loss_weight: 1.0, 48 | keys: [ imgt_pred, imgt ] 49 | } 50 | } 51 | - { 52 | name: losses.loss.TernaryLoss, 53 | nickname: l_ter, 54 | params: { 55 | loss_weight: 1.0, 56 | keys: [ imgt_pred, imgt ] 57 | } 58 | } 59 | # - { 60 | # name: losses.loss.GeometryLoss, 61 | # nickname: l_geo, 62 | # params: { 63 | # loss_weight: 0.01, 64 | # keys: [ ft_pred, ft_gt ] 65 | # } 66 | # } 67 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/cfgs/SDI-R-AMT-S_septuplet_wofloloss.yaml: -------------------------------------------------------------------------------- 1 | # CUDA_VISISIBLE_DEVICES=4,5,6,7 screen sh ./scripts/train.sh 4 cfgs/SDI-R-AMT-S_septuplet_wofloloss.yaml 14514 2 | # 
CUDA_VISISIBLE_DEVICES=0,1,2,3,4,5,6,7 screen sh ./scripts/train.sh 8 cfgs/SDI-R-AMT-S_septuplet_wofloloss.yaml 14514 3 | exp_name: 400epoch_bs24_lr2e-4 4 | seed: 2023 5 | epochs: 400 6 | distributed: true 7 | lr: 2e-4 8 | lr_min: 2e-5 9 | weight_decay: 0.0 10 | resume_state: null 11 | save_dir: experiments 12 | eval_interval: 1 13 | 14 | network: 15 | name: networks.SDI-R-AMT-S.Model 16 | params: 17 | corr_radius: 3 18 | corr_lvls: 4 19 | num_flows: 3 20 | 21 | data: 22 | train: 23 | name: datasets.vimeo_septuplet_recur_datasets.Vimeo90K_Train_Dataset 24 | params: 25 | dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 26 | use_sdi: true 27 | use_mask: false 28 | val: 29 | name: datasets.vimeo_septuplet_recur_datasets.Vimeo90K_Test_Dataset 30 | params: 31 | dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 32 | use_sdi: true 33 | use_mask: false 34 | train_loader: 35 | batch_size: 24 36 | num_workers: 12 37 | val_loader: 38 | batch_size: 24 39 | num_workers: 3 40 | 41 | logger: 42 | use_wandb: false 43 | resume_id: null 44 | 45 | losses: 46 | - { 47 | name: losses.loss.CharbonnierLoss, 48 | nickname: l_rec, 49 | params: { 50 | loss_weight: 1.0, 51 | keys: [ imgt_pred, imgt ] 52 | } 53 | } 54 | - { 55 | name: losses.loss.TernaryLoss, 56 | nickname: l_ter, 57 | params: { 58 | loss_weight: 1.0, 59 | keys: [ imgt_pred, imgt ] 60 | } 61 | } 62 | 63 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/cfgs/SDI-R-AMT-S_v1_septuplet_wofloloss.yaml: -------------------------------------------------------------------------------- 1 | # screen sh ./scripts/train.sh 4 cfgs/SDI-R-AMT-S_v1_septuplet_wofloloss.yaml 14514 2 | #exp_name: 400epoch_bs24_lr2e-4 3 | exp_name: DR-AMT-S 4 | seed: 2023 5 | epochs: 400 6 | distributed: true 7 | lr: 2e-4 8 | lr_min: 2e-5 9 | weight_decay: 0.0 10 | resume_state: null 11 | save_dir: experiments 12 | eval_interval: 1 13 | 14 | network: 15 | name: networks.SDI-R-AMT-S_v1.Model 16 | params: 17 | corr_radius: 3 18 | corr_lvls: 4 19 | num_flows: 3 20 | 21 | data: 22 | train: 23 | name: datasets.vimeo_septuplet_recur_datasets.Vimeo90K_Train_Dataset 24 | params: 25 | # dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 26 | dataset_dir: ../../dataset/vimeo_septuplet 27 | use_sdi: true 28 | use_mask: false 29 | val: 30 | name: datasets.vimeo_septuplet_recur_datasets.Vimeo90K_Test_Dataset 31 | params: 32 | # dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 33 | dataset_dir: ../../dataset/vimeo_septuplet 34 | use_sdi: true 35 | use_mask: false 36 | train_loader: 37 | batch_size: 24 38 | num_workers: 12 39 | val_loader: 40 | batch_size: 24 41 | num_workers: 3 42 | 43 | logger: 44 | use_wandb: false 45 | resume_id: null 46 | 47 | losses: 48 | - { 49 | name: losses.loss.CharbonnierLoss, 50 | nickname: l_rec, 51 | params: { 52 | loss_weight: 1.0, 53 | keys: [ imgt_pred, imgt ] 54 | } 55 | } 56 | - { 57 | name: losses.loss.TernaryLoss, 58 | nickname: l_ter, 59 | params: { 60 | loss_weight: 1.0, 61 | keys: [ imgt_pred, imgt ] 62 | } 63 | } 64 | 65 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/cfgs/SDI-R-AMT-S_v2_septuplet_wofloloss.yaml: -------------------------------------------------------------------------------- 1 | # screen sh ./scripts/train.sh 4 cfgs/SDI-R-AMT-S_v2_septuplet_wofloloss.yaml 14514 2 | exp_name: 400epoch_bs24_lr2e-4 3 | seed: 2023 4 | epochs: 400 5 | distributed: true 6 | lr: 2e-4 7 | lr_min: 2e-5 8 | weight_decay: 0.0 9 | 
resume_state: null 10 | save_dir: experiments 11 | eval_interval: 1 12 | 13 | network: 14 | name: networks.SDI-R-AMT-S_v2.Model 15 | params: 16 | corr_radius: 3 17 | corr_lvls: 4 18 | num_flows: 3 19 | 20 | data: 21 | train: 22 | name: datasets.vimeo_septuplet_recur_datasets.Vimeo90K_Train_Dataset 23 | params: 24 | dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 25 | use_sdi: true 26 | use_mask: false 27 | val: 28 | name: datasets.vimeo_septuplet_recur_datasets.Vimeo90K_Test_Dataset 29 | params: 30 | dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 31 | use_sdi: true 32 | use_mask: false 33 | train_loader: 34 | batch_size: 24 35 | num_workers: 12 36 | val_loader: 37 | batch_size: 24 38 | num_workers: 3 39 | 40 | logger: 41 | use_wandb: false 42 | resume_id: null 43 | 44 | losses: 45 | - { 46 | name: losses.loss.CharbonnierLoss, 47 | nickname: l_rec, 48 | params: { 49 | loss_weight: 1.0, 50 | keys: [ imgt_pred, imgt ] 51 | } 52 | } 53 | - { 54 | name: losses.loss.TernaryLoss, 55 | nickname: l_ter, 56 | params: { 57 | loss_weight: 1.0, 58 | keys: [ imgt_pred, imgt ] 59 | } 60 | } 61 | 62 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/cfgs/SDI-R-IFRNet_septuplet_wofloloss.yaml: -------------------------------------------------------------------------------- 1 | # screen sh ./scripts/train.sh 4 cfgs/SDI-R-IFRNet_septuplet_wofloloss.yaml 14514 2 | #exp_name: 300epoch_bs24_lr1e-4 3 | exp_name: DR-IFRNet 4 | seed: 2023 5 | epochs: 300 6 | distributed: true 7 | lr: 1e-4 8 | lr_min: 1e-5 9 | weight_decay: 1e-6 10 | resume_state: null 11 | save_dir: experiments 12 | eval_interval: 1 13 | 14 | network: 15 | name: networks.SDI-R-IFRNet.Model 16 | 17 | data: 18 | train: 19 | name: datasets.vimeo_septuplet_recur_datasets.Vimeo90K_Train_Dataset 20 | params: 21 | # dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 22 | dataset_dir: ../../dataset/vimeo_septuplet 23 | use_sdi: true 24 | use_mask: false 25 | val: 26 | name: datasets.vimeo_septuplet_recur_datasets.Vimeo90K_Test_Dataset 27 | params: 28 | # dataset_dir: /mnt/disks/ssd0/dataset/vimeo_septuplet 29 | dataset_dir: ../../dataset/vimeo_septuplet 30 | use_sdi: true 31 | use_mask: false 32 | train_loader: 33 | batch_size: 24 34 | num_workers: 12 35 | val_loader: 36 | batch_size: 24 37 | num_workers: 3 38 | 39 | logger: 40 | use_wandb: false 41 | resume_id: null 42 | 43 | losses: 44 | - { 45 | name: losses.loss.CharbonnierLoss, 46 | nickname: l_rec, 47 | params: { 48 | loss_weight: 1.0, 49 | keys: [ imgt_pred, imgt ] 50 | } 51 | } 52 | - { 53 | name: losses.loss.TernaryLoss, 54 | nickname: l_ter, 55 | params: { 56 | loss_weight: 1.0, 57 | keys: [ imgt_pred, imgt ] 58 | } 59 | } 60 | # - { 61 | # name: losses.loss.GeometryLoss, 62 | # nickname: l_geo, 63 | # params: { 64 | # loss_weight: 0.01, 65 | # keys: [ ft_pred, ft_gt ] 66 | # } 67 | # } 68 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/models/DI-AMT-and-IFRNet/datasets/__init__.py -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/environment.yaml: -------------------------------------------------------------------------------- 1 | name: amt 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | 
dependencies: 7 | - python=3.8.5 8 | - pip=20.3 9 | - cudatoolkit=11.3 10 | - pytorch=1.11.0 11 | - torchvision=0.12.0 12 | - numpy=1.21.5 13 | - pip: 14 | - opencv-python==4.1.2.30 15 | - imageio==2.19.3 16 | - omegaconf==2.3.0 17 | - Pillow==9.4.0 18 | - tqdm==4.64.1 19 | - wandb==0.12.21 -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/flow_generation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/models/DI-AMT-and-IFRNet/flow_generation/__init__.py -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/flow_generation/gen_flow.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import argparse 5 | import numpy as np 6 | import os.path as osp 7 | import torch.nn.functional as F 8 | 9 | sys.path.append('.') 10 | from utils.utils import read, write 11 | from flow_generation.liteflownet.run import estimate 12 | 13 | parser = argparse.ArgumentParser( 14 | prog='AMT', 15 | description='Flow generation', 16 | ) 17 | parser.add_argument('-r', '--root', default='data/vimeo_triplet') 18 | args = parser.parse_args() 19 | 20 | vimeo90k_dir = args.root 21 | vimeo90k_sequences_dir = osp.join(vimeo90k_dir, 'sequences') 22 | vimeo90k_flow_dir = osp.join(vimeo90k_dir, 'flow') 23 | 24 | 25 | def pred_flow(img1, img2): 26 | img1 = torch.from_numpy(img1).float().permute(2, 0, 1) / 255.0 27 | img2 = torch.from_numpy(img2).float().permute(2, 0, 1) / 255.0 28 | 29 | flow = estimate(img1, img2) 30 | 31 | flow = flow.permute(1, 2, 0).cpu().numpy() 32 | return flow 33 | 34 | 35 | print('Built Flow Path') 36 | if not osp.exists(vimeo90k_flow_dir): 37 | os.makedirs(vimeo90k_flow_dir) 38 | 39 | for sequences_path in sorted(os.listdir(vimeo90k_sequences_dir)): 40 | vimeo90k_sequences_path_dir = osp.join(vimeo90k_sequences_dir, sequences_path) 41 | vimeo90k_flow_path_dir = osp.join(vimeo90k_flow_dir, sequences_path) 42 | if not osp.exists(vimeo90k_flow_path_dir): 43 | os.mkdir(vimeo90k_flow_path_dir) 44 | 45 | for sequences_id in sorted(os.listdir(vimeo90k_sequences_path_dir)): 46 | vimeo90k_flow_id_dir = osp.join(vimeo90k_flow_path_dir, sequences_id) 47 | if not osp.exists(vimeo90k_flow_id_dir): 48 | os.mkdir(vimeo90k_flow_id_dir) 49 | 50 | for sequences_path in sorted(os.listdir(vimeo90k_sequences_dir)): 51 | vimeo90k_sequences_path_dir = os.path.join(vimeo90k_sequences_dir, sequences_path) 52 | vimeo90k_flow_path_dir = os.path.join(vimeo90k_flow_dir, sequences_path) 53 | 54 | for sequences_id in sorted(os.listdir(vimeo90k_sequences_path_dir)): 55 | vimeo90k_sequences_id_dir = os.path.join(vimeo90k_sequences_path_dir, sequences_id) 56 | vimeo90k_flow_id_dir = os.path.join(vimeo90k_flow_path_dir, sequences_id) 57 | 58 | img0_path = vimeo90k_sequences_id_dir + '/im1.png' 59 | imgt_path = vimeo90k_sequences_id_dir + '/im2.png' 60 | img1_path = vimeo90k_sequences_id_dir + '/im3.png' 61 | flow_t0_path = vimeo90k_flow_id_dir + '/flow_t0.flo' 62 | flow_t1_path = vimeo90k_flow_id_dir + '/flow_t1.flo' 63 | 64 | img0 = read(img0_path) 65 | imgt = read(imgt_path) 66 | img1 = read(img1_path) 67 | 68 | flow_t0 = pred_flow(imgt, img0) 69 | flow_t1 = pred_flow(imgt, img1) 70 | 71 | write(flow_t0_path, flow_t0) 72 | write(flow_t1_path, flow_t1) 73 | 74 | print('Written Sequences 
{}'.format(sequences_path)) 75 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/flow_generation/gen_multi_flow.py: -------------------------------------------------------------------------------- 1 | """ 2 | python flow_generation/gen_multi_flow.py -r '../dataset/vimeo_septuplet' 3 | """ 4 | import os 5 | import sys 6 | import torch 7 | import argparse 8 | import numpy as np 9 | import os.path as osp 10 | import torch.nn.functional as F 11 | from itertools import combinations 12 | 13 | sys.path.append('.') 14 | from utils.utils import read, write 15 | from flow_generation.liteflownet.run import estimate 16 | 17 | parser = argparse.ArgumentParser( 18 | prog='AMT', 19 | description='Flow generation', 20 | ) 21 | parser.add_argument('-r', '--root', default='/mnt/disks/ssd0/dataset/vimeo_septuplet') 22 | parser.add_argument('--worker_id', type=int, default=0) 23 | parser.add_argument('--num_workers', type=int, default=16) 24 | args = parser.parse_args() 25 | 26 | vimeo90k_dir = args.root 27 | vimeo90k_sequences_dir = osp.join(vimeo90k_dir, 'sequences') 28 | vimeo90k_flow_dir = osp.join(vimeo90k_dir, 'flow') 29 | 30 | 31 | def pred_flow(img1, img2): 32 | img1 = torch.from_numpy(img1).float().permute(2, 0, 1) / 255.0 33 | img2 = torch.from_numpy(img2).float().permute(2, 0, 1) / 255.0 34 | 35 | flow = estimate(img1, img2) 36 | 37 | flow = flow.permute(1, 2, 0).cpu().numpy() 38 | return flow 39 | 40 | 41 | print('Built Flow Path') 42 | if not osp.exists(vimeo90k_flow_dir): 43 | os.makedirs(vimeo90k_flow_dir) 44 | 45 | sequences = sorted(os.listdir(vimeo90k_sequences_dir)) 46 | for sequences_path in sequences[args.worker_id::args.num_workers]: 47 | vimeo90k_sequences_path_dir = osp.join(vimeo90k_sequences_dir, sequences_path) 48 | vimeo90k_flow_path_dir = osp.join(vimeo90k_flow_dir, sequences_path) 49 | if not osp.exists(vimeo90k_flow_path_dir): 50 | os.mkdir(vimeo90k_flow_path_dir) 51 | 52 | for sequences_id in sorted(os.listdir(vimeo90k_sequences_path_dir)): 53 | vimeo90k_flow_id_dir = osp.join(vimeo90k_flow_path_dir, sequences_id) 54 | if not osp.exists(vimeo90k_flow_id_dir): 55 | os.mkdir(vimeo90k_flow_id_dir) 56 | 57 | for sequences_path in sequences[args.worker_id::args.num_workers]: 58 | vimeo90k_sequences_path_dir = os.path.join(vimeo90k_sequences_dir, sequences_path) 59 | vimeo90k_flow_path_dir = os.path.join(vimeo90k_flow_dir, sequences_path) 60 | 61 | for sequences_id in sorted(os.listdir(vimeo90k_sequences_path_dir)): 62 | vimeo90k_sequences_id_dir = os.path.join(vimeo90k_sequences_path_dir, sequences_id) 63 | vimeo90k_flow_id_dir = os.path.join(vimeo90k_flow_path_dir, sequences_id) 64 | 65 | img_paths = [osp.join(vimeo90k_sequences_id_dir, 'im{}.png'.format(i + 1)) for i in range(7)] 66 | combs = list(combinations(list(range(7)), r=3)) 67 | 68 | for comb in combs: 69 | img0_path = img_paths[comb[0]] 70 | imgt_path = img_paths[comb[1]] 71 | img1_path = img_paths[comb[2]] 72 | 73 | flow_t0_path = osp.join(vimeo90k_flow_id_dir + '/flow_{}_{}.flo'.format(comb[1], comb[0])) 74 | flow_t1_path = osp.join(vimeo90k_flow_id_dir + '/flow_{}_{}.flo'.format(comb[1], comb[2])) 75 | 76 | img0 = read(img0_path) 77 | imgt = read(imgt_path) 78 | img1 = read(img1_path) 79 | 80 | flow_t0 = pred_flow(imgt, img0) 81 | flow_t1 = pred_flow(imgt, img1) 82 | 83 | write(flow_t0_path, flow_t0) 84 | write(flow_t1_path, flow_t1) 85 | 86 | print('{} finished'.format(vimeo90k_sequences_id_dir)) 87 | 88 | print('Written Sequences {}'.format(sequences_path)) 
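The script above shards the dataset by slicing `sequences[worker_id::num_workers]` and enumerates every ordered triplet of a septuplet with `combinations(range(7), r=3)`, estimating flow from the middle frame to both endpoints. A stdlib-only check of that indexing (dummy names, not repo code):

```
from itertools import combinations

sequences = ['{:05d}'.format(i) for i in range(10)]  # dummy sequence folders
num_workers = 4
shards = [sequences[w::num_workers] for w in range(num_workers)]
assert sorted(sum(shards, [])) == sequences  # shards are disjoint and cover everything

combs = list(combinations(range(7), r=3))
assert len(combs) == 35  # each septuplet yields C(7,3) = 35 (past, middle, future) triplets
i0, it, i1 = combs[0]    # (0, 1, 2) -> files flow_1_0.flo and flow_1_2.flo
```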
89 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/flow_generation/liteflownet/README.md: -------------------------------------------------------------------------------- 1 | # pytorch-liteflownet 2 | This is a personal reimplementation of LiteFlowNet [1] using PyTorch. Should you be making use of this work, please cite the paper accordingly. Also, make sure to adhere to the licensing terms of the authors. Should you be making use of this particular implementation, please acknowledge it appropriately [2]. 3 | 4 | [image: Paper] 5 | 6 | For the original Caffe version of this work, please see: https://github.com/twhui/LiteFlowNet 7 |
8 | Other optical flow implementations from me: [pytorch-pwc](https://github.com/sniklaus/pytorch-pwc), [pytorch-unflow](https://github.com/sniklaus/pytorch-unflow), [pytorch-spynet](https://github.com/sniklaus/pytorch-spynet) 9 | 10 | ## setup 11 | The correlation layer is implemented in CUDA using CuPy, which is why CuPy is a required dependency. It can be installed using `pip install cupy` or alternatively using one of the provided [binary packages](https://docs.cupy.dev/en/stable/install.html#installing-cupy) as outlined in the CuPy repository. If you would like to use Docker, you can take a look at [this](https://github.com/sniklaus/pytorch-liteflownet/pull/43) pull request to get started. 12 | 13 | ## usage 14 | To run it on your own pair of images, use the following command. You can choose between three models; please make sure to see their paper / the code for more details. 15 | 16 | ``` 17 | python run.py --model default --one ./images/one.png --two ./images/two.png --out ./out.flo 18 | ``` 19 | 20 | I am afraid that I cannot guarantee that this reimplementation is correct. However, it produced results pretty much identical to the implementation of the original authors in the examples that I tried. There are some numerical deviations that stem from differences in the `DownsampleLayer` of Caffe and the `torch.nn.functional.interpolate` function of PyTorch. Please feel free to contribute to this repository by submitting issues and pull requests. 21 | 22 | ## comparison 23 |
[image: Comparison]
24 | 25 | ## license 26 | As stated in the licensing terms of the authors of the paper, their material is provided for research purposes only. Please make sure to further consult their licensing terms. 27 | 28 | ## references 29 | ``` 30 | [1] @inproceedings{Hui_CVPR_2018, 31 | author = {Tak-Wai Hui and Xiaoou Tang and Chen Change Loy}, 32 | title = {{LiteFlowNet}: A Lightweight Convolutional Neural Network for Optical Flow Estimation}, 33 | booktitle = {IEEE Conference on Computer Vision and Pattern Recognition}, 34 | year = {2018} 35 | } 36 | ``` 37 | 38 | ``` 39 | [2] @misc{pytorch-liteflownet, 40 | author = {Simon Niklaus}, 41 | title = {A Reimplementation of {LiteFlowNet} Using {PyTorch}}, 42 | year = {2019}, 43 | howpublished = {\url{https://github.com/sniklaus/pytorch-liteflownet}} 44 | } 45 | ``` -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/flow_generation/liteflownet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/models/DI-AMT-and-IFRNet/flow_generation/liteflownet/__init__.py -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/flow_generation/liteflownet/correlation/README.md: -------------------------------------------------------------------------------- 1 | This is an adaptation of the FlowNet2 implementation in order to compute cost volumes. Should you be making use of this work, please make sure to adhere to the licensing terms of the original authors. Should you be making use of or modifying this particular implementation, please acknowledge it appropriately. -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/flow_generation/multiprocess_gen_multi_flow.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | import time 4 | import os 5 | import os.path as osp 6 | 7 | if __name__ == '__main__': 8 | """ 9 | cmd: 10 | sudo chmod -R +w /mnt/disks/ssd0/dataset/vimeo_septuplet/flow/ 11 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 screen python flow_generation/multiprocess_gen_multi_flow.py --num_gpus 8 --workers_per_gpu 8 12 | """ 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--num_gpus', type=int, default=8) 15 | parser.add_argument('--workers_per_gpu', type=int, default=4) 16 | args = parser.parse_args() 17 | 18 | # total number of worker processes; each worker takes every num_processes-th sequence 19 | num_processes = args.num_gpus * args.workers_per_gpu 20 | 21 | # launch one subprocess per worker for flow generation, pinned to a GPU via CUDA_VISIBLE_DEVICES 22 | pool = [] 23 | for i in range(num_processes): 24 | cmd = ['python', 'flow_generation/gen_multi_flow.py', 25 | '--root=/mnt/disks/ssd0/dataset/vimeo_septuplet', 26 | '--worker_id={}'.format(i), '--num_workers={}'.format(num_processes)] 27 | env = { 28 | **os.environ, 29 | 'CUDA_VISIBLE_DEVICES': str(i // args.workers_per_gpu) 30 | } 31 | p = subprocess.Popen(cmd, env=env) 32 | pool.append(p) 33 | exit_codes = [p.wait() for p in pool] 34 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/models/DI-AMT-and-IFRNet/losses/__init__.py
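`losses/loss.py` itself is not part of this excerpt, but every config above constructs its entries by passing `loss_weight` and `keys` straight into the loss constructor. The stand-in below is an assumed reading of that interface, not the repo's class: `keys` name tensors in the model's results dict, and the weight scales the term before the trainer sums everything.

```
import torch

class CharbonnierLoss(torch.nn.Module):
    """Stand-in with the configured interface (loss_weight, keys); losses.loss may differ."""
    def __init__(self, loss_weight=1.0, keys=None, eps=1e-6):
        super().__init__()
        self.loss_weight, self.keys, self.eps = loss_weight, list(keys), eps

    def forward(self, results):
        pred, gt = (results[k] for k in self.keys)
        # smooth L1 variant: sqrt((x - y)^2 + eps^2), averaged over all elements
        return self.loss_weight * torch.sqrt((pred - gt) ** 2 + self.eps ** 2).mean()

results = {'imgt_pred': torch.rand(1, 3, 8, 8), 'imgt': torch.rand(1, 3, 8, 8)}
l_rec = CharbonnierLoss(loss_weight=1.0, keys=['imgt_pred', 'imgt'])(results)
total = sum([l_rec])  # the trainer presumably accumulates all configured terms this way
```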
-------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/models/DI-AMT-and-IFRNet/metrics/__init__.py -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/models/DI-AMT-and-IFRNet/networks/__init__.py -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/networks/blocks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/models/DI-AMT-and-IFRNet/networks/blocks/__init__.py -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/networks/blocks/multi_flow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils.flow_utils import warp 4 | from networks.blocks.ifrnet import ( 5 | convrelu, resize, 6 | ResBlock, 7 | ) 8 | 9 | 10 | def multi_flow_combine(comb_block, img0, img1, flow0, flow1, 11 | mask=None, img_res=None, mean=None): 12 | ''' 13 | A parallel implementation of multiple flow field warping. 14 | comb_block: An nn.Sequential object. 15 | img shape: [b, c, h, w] 16 | flow shape: [b, 2*num_flows, h, w] 17 | mask (opt): 18 | If 'mask' is None, the function conducts a simple average. 19 | img_res (opt): 20 | If 'img_res' is None, the function adds zero instead. 21 | mean (opt): 22 | If 'mean' is None, the function adds zero instead.
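Shape walk-through (illustrative numbers, not part of the original comment):
with b = 2 and num_flows = 3, flow0 arrives as [2, 6, h, w] and is flattened
to [6, 2, h, w], so every flow field of the whole batch is warped by a single
call to warp(); img_warps is then restored to [2, 3, 3, h, w], mean(1)
averages the three candidate frames, and comb_block sees them concatenated
as [2, 9, h, w] to predict a refinement residual.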
23 | ''' 24 | b, c, h, w = flow0.shape 25 | num_flows = c // 2 26 | flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) 27 | flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) 28 | 29 | mask = mask.reshape(b, num_flows, 1, h, w 30 | ).reshape(-1, 1, h, w) if mask is not None else None 31 | img_res = img_res.reshape(b, num_flows, 3, h, w 32 | ).reshape(-1, 3, h, w) if img_res is not None else 0 33 | img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w) 34 | img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w) 35 | mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1 36 | ) if mean is not None else 0 37 | 38 | img0_warp = warp(img0, flow0) 39 | img1_warp = warp(img1, flow1) 40 | img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res 41 | img_warps = img_warps.reshape(b, num_flows, 3, h, w) 42 | imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w)) 43 | return imgt_pred 44 | 45 | 46 | class MultiFlowDecoder(nn.Module): 47 | def __init__(self, in_ch, skip_ch, num_flows=3): 48 | super(MultiFlowDecoder, self).__init__() 49 | self.num_flows = num_flows 50 | self.convblock = nn.Sequential( 51 | convrelu(in_ch * 3 + 4, in_ch * 3), 52 | ResBlock(in_ch * 3, skip_ch), 53 | nn.ConvTranspose2d(in_ch * 3, 8 * num_flows, 4, 2, 1, bias=True) 54 | ) 55 | 56 | def forward(self, ft_, f0, f1, flow0, flow1): 57 | n = self.num_flows 58 | f0_warp = warp(f0, flow0) 59 | f1_warp = warp(f1, flow1) 60 | out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1)) 61 | delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2 * n, 2 * n, n, 3 * n], 1) 62 | mask = torch.sigmoid(mask) 63 | 64 | flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0 65 | ).repeat(1, self.num_flows, 1, 1) 66 | flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0 67 | ).repeat(1, self.num_flows, 1, 1) 68 | 69 | return flow0, flow1, mask, img_res 70 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/networks/blocks/multi_flow_recur.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils.flow_utils import warp 4 | from networks.blocks.ifrnet import ( 5 | convrelu, resize, 6 | ResBlock, 7 | ) 8 | 9 | 10 | def multi_flow_combine(comb_block, img0, img1, flow0, flow1, 11 | mask=None, img_res=None, mean=None): 12 | ''' 13 | A parallel implementation of multiple flow field warping. 14 | comb_block: An nn.Sequential object. 15 | img shape: [b, c, h, w] 16 | flow shape: [b, 2*num_flows, h, w] 17 | mask (opt): 18 | If 'mask' is None, the function conducts a simple average. 19 | img_res (opt): 20 | If 'img_res' is None, the function adds zero instead. 21 | mean (opt): 22 | If 'mean' is None, the function adds zero instead.
23 | ''' 24 | b, c, h, w = flow0.shape 25 | num_flows = c // 2 26 | flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) 27 | flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) 28 | 29 | mask = mask.reshape(b, num_flows, 1, h, w 30 | ).reshape(-1, 1, h, w) if mask is not None else None 31 | img_res = img_res.reshape(b, num_flows, 3, h, w 32 | ).reshape(-1, 3, h, w) if img_res is not None else 0 33 | img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w) 34 | img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w) 35 | mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1 36 | ) if mean is not None else 0 37 | 38 | img0_warp = warp(img0, flow0) 39 | img1_warp = warp(img1, flow1) 40 | img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res 41 | img_warps = img_warps.reshape(b, num_flows, 3, h, w) 42 | imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w)) 43 | return imgt_pred 44 | 45 | 46 | class MultiFlowDecoder(nn.Module): 47 | def __init__(self, in_ch, skip_ch, num_flows=3): 48 | super(MultiFlowDecoder, self).__init__() 49 | self.num_flows = num_flows 50 | self.convblock = nn.Sequential( 51 | convrelu(in_ch * 4 + 4, in_ch * 3), 52 | ResBlock(in_ch * 3, skip_ch), 53 | nn.ConvTranspose2d(in_ch * 3, 8 * num_flows, 4, 2, 1, bias=True) 54 | ) 55 | 56 | def forward(self, ft_, f0, f1, f_ref, flow0, flow1): 57 | n = self.num_flows 58 | f0_warp = warp(f0, flow0) 59 | f1_warp = warp(f1, flow1) 60 | out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, f_ref, flow0, flow1], 1)) 61 | delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2 * n, 2 * n, n, 3 * n], 1) 62 | mask = torch.sigmoid(mask) 63 | 64 | flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0 65 | ).repeat(1, self.num_flows, 1, 1) 66 | flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0 67 | ).repeat(1, self.num_flows, 1, 1) 68 | 69 | return flow0, flow1, mask, img_res 70 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/networks/blocks/multi_flow_recur_v1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils.flow_utils import warp 4 | from networks.blocks.ifrnet import ( 5 | convrelu, resize, 6 | ResBlock, 7 | ) 8 | 9 | 10 | def multi_flow_combine(comb_block, img0, img1, flow0, flow1, 11 | mask=None, img_res=None, mean=None): 12 | ''' 13 | A parallel implementation of multiple flow field warping. 14 | comb_block: An nn.Sequential object. 15 | img shape: [b, c, h, w] 16 | flow shape: [b, 2*num_flows, h, w] 17 | mask (opt): 18 | If 'mask' is None, the function conducts a simple average. 19 | img_res (opt): 20 | If 'img_res' is None, the function adds zero instead. 21 | mean (opt): 22 | If 'mean' is None, the function adds zero instead.
23 | ''' 24 | b, c, h, w = flow0.shape 25 | num_flows = c // 2 26 | flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) 27 | flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) 28 | 29 | mask = mask.reshape(b, num_flows, 1, h, w 30 | ).reshape(-1, 1, h, w) if mask is not None else None 31 | img_res = img_res.reshape(b, num_flows, 3, h, w 32 | ).reshape(-1, 3, h, w) if img_res is not None else 0 33 | img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w) 34 | img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w) 35 | mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1 36 | ) if mean is not None else 0 37 | 38 | img0_warp = warp(img0, flow0) 39 | img1_warp = warp(img1, flow1) 40 | img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res 41 | img_warps = img_warps.reshape(b, num_flows, 3, h, w) 42 | imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w)) 43 | return imgt_pred 44 | 45 | 46 | class MultiFlowDecoder(nn.Module): 47 | def __init__(self, in_ch, skip_ch, num_flows=3): 48 | super(MultiFlowDecoder, self).__init__() 49 | self.num_flows = num_flows 50 | self.convblock = nn.Sequential( 51 | convrelu(in_ch * 3 + 4, in_ch * 3), 52 | ResBlock(in_ch * 3, skip_ch), 53 | nn.ConvTranspose2d(in_ch * 3, 8 * num_flows, 4, 2, 1, bias=True) 54 | ) 55 | 56 | def forward(self, ft_, f0, f1, flow0, flow1): 57 | n = self.num_flows 58 | f0_warp = warp(f0, flow0) 59 | f1_warp = warp(f1, flow1) 60 | out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1)) 61 | delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2 * n, 2 * n, n, 3 * n], 1) 62 | mask = torch.sigmoid(mask) 63 | 64 | flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0 65 | ).repeat(1, self.num_flows, 1, 1) 66 | flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0 67 | ).repeat(1, self.num_flows, 1, 1) 68 | 69 | return flow0, flow1, mask, img_res 70 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/networks/blocks/multi_flow_recur_v2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils.flow_utils import warp 4 | from networks.blocks.ifrnet import ( 5 | convrelu, resize, 6 | ResBlock, 7 | ) 8 | 9 | 10 | def multi_flow_combine(comb_block, img0, img1, flow0, flow1, 11 | mask=None, img_res=None, mean=None): 12 | ''' 13 | A parallel implementation of multiple flow field warping. 14 | comb_block: An nn.Sequential object. 15 | img shape: [b, c, h, w] 16 | flow shape: [b, 2*num_flows, h, w] 17 | mask (opt): 18 | If 'mask' is None, the function conducts a simple average. 19 | img_res (opt): 20 | If 'img_res' is None, the function adds zero instead. 21 | mean (opt): 22 | If 'mean' is None, the function adds zero instead.
23 | ''' 24 | b, c, h, w = flow0.shape 25 | num_flows = c // 2 26 | flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) 27 | flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) 28 | 29 | mask = mask.reshape(b, num_flows, 1, h, w 30 | ).reshape(-1, 1, h, w) if mask is not None else None 31 | img_res = img_res.reshape(b, num_flows, 3, h, w 32 | ).reshape(-1, 3, h, w) if img_res is not None else 0 33 | img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w) 34 | img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w) 35 | mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1 36 | ) if mean is not None else 0 37 | 38 | img0_warp = warp(img0, flow0) 39 | img1_warp = warp(img1, flow1) 40 | img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res 41 | img_warps = img_warps.reshape(b, num_flows, 3, h, w) 42 | imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w)) 43 | return imgt_pred 44 | 45 | 46 | class MultiFlowDecoder(nn.Module): 47 | def __init__(self, in_ch, skip_ch, num_flows=3): 48 | super(MultiFlowDecoder, self).__init__() 49 | self.num_flows = num_flows 50 | self.convblock = nn.Sequential( 51 | convrelu(in_ch * 3 + 4 + 1, in_ch * 3), 52 | ResBlock(in_ch * 3, skip_ch), 53 | nn.ConvTranspose2d(in_ch * 3, 8 * num_flows, 4, 2, 1, bias=True) 54 | ) 55 | 56 | def forward(self, ft_, f0, f1, flow0, flow1, embed): 57 | n = self.num_flows 58 | f0_warp = warp(f0, flow0) 59 | f1_warp = warp(f1, flow1) 60 | out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1, embed], 1)) 61 | delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2 * n, 2 * n, n, 3 * n], 1) 62 | mask = torch.sigmoid(mask) 63 | 64 | flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0 65 | ).repeat(1, self.num_flows, 1, 1) 66 | flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0 67 | ).repeat(1, self.num_flows, 1, 1) 68 | 69 | return flow0, flow1, mask, img_res 70 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/scripts/benchmark_arbitrary.sh: -------------------------------------------------------------------------------- 1 | CFG=$1 2 | CKPT=$2 3 | 4 | python benchmarks/gopro.py -c $CFG -p $CKPT 5 | python benchmarks/adobe240.py -c $CFG -p $CKPT -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/scripts/benchmark_fixed.sh: -------------------------------------------------------------------------------- 1 | CFG=$1 2 | CKPT=$2 3 | 4 | python benchmarks/vimeo90k.py -c $CFG -p $CKPT 5 | python benchmarks/ucf101.py -c $CFG -p $CKPT 6 | python benchmarks/snu_film.py -c $CFG -p $CKPT 7 | python benchmarks/xiph.py -c $CFG -p $CKPT -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/scripts/train.sh: -------------------------------------------------------------------------------- 1 | NUM_GPU=$1 2 | CFG=$2 3 | PORT=$3 4 | # select GPUs by exporting CUDA_VISIBLE_DEVICES before running this script 5 | python -m torch.distributed.launch \ 6 | --nproc_per_node $NUM_GPU \ 7 | --master_port $PORT train.py -c $CFG -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from shutil import copyfile 4 | import torch.distributed as dist 5 | import torch 6 | import importlib 7 | import datetime 8 | from utils.dist_utils import ( 9 |
get_world_size, 10 | ) 11 | from omegaconf import OmegaConf 12 | from utils.utils import seed_all 13 | 14 | parser = argparse.ArgumentParser(description='VFI') 15 | parser.add_argument('-c', '--config', type=str) 16 | parser.add_argument('-p', '--port', default='23455', type=str) 17 | parser.add_argument('--local_rank', default='0') 18 | 19 | args = parser.parse_args() 20 | 21 | 22 | def main_worker(rank, config): 23 | if 'local_rank' not in config: 24 | config['local_rank'] = config['global_rank'] = rank 25 | if torch.cuda.is_available(): 26 | print(f'Rank {rank} is available') 27 | config['device'] = f"cuda:{rank}" 28 | if config['distributed']: 29 | dist.init_process_group(backend='nccl', 30 | timeout=datetime.timedelta(seconds=5400)) 31 | else: 32 | config['device'] = 'cpu' 33 | 34 | # cfg_name = os.path.basename(args.config).split('.')[0] 35 | # config['exp_name'] = cfg_name + '_' + config['exp_name'] 36 | config['save_dir'] = os.path.join(config['save_dir'], config['exp_name']) 37 | print(config['save_dir']) 38 | 39 | if (not config['distributed']) or rank == 0: 40 | os.makedirs(config['save_dir'], exist_ok=True) 41 | os.makedirs(f'{config["save_dir"]}/ckpts', exist_ok=True) 42 | config_path = os.path.join(config['save_dir'], 43 | args.config.split('/')[-1]) 44 | if not os.path.isfile(config_path): 45 | copyfile(args.config, config_path) 46 | print('[**] create folder {}'.format(config['save_dir'])) 47 | 48 | trainer_name = config.get('trainer_type', 'base_trainer') 49 | print(f'using GPU {rank} for training') 50 | if rank == 0: 51 | print(trainer_name) 52 | trainer_pack = importlib.import_module('trainers.' + trainer_name) 53 | trainer = trainer_pack.Trainer(config) 54 | 55 | trainer.train() 56 | 57 | 58 | if __name__ == "__main__": 59 | torch.backends.cudnn.benchmark = True 60 | cfg = OmegaConf.load(args.config) 61 | seed_all(cfg.seed) 62 | rank = int(args.local_rank) 63 | torch.cuda.set_device(torch.device(f'cuda:{rank}')) 64 | # setting distributed configurations 65 | cfg['world_size'] = get_world_size() 66 | cfg['local_rank'] = rank 67 | if rank == 0: 68 | print('world_size:', cfg['world_size']) 69 | main_worker(rank, cfg) 70 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/trainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/models/DI-AMT-and-IFRNet/trainers/__init__.py -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/trainers/logger.py: -------------------------------------------------------------------------------- 1 | import time 2 | import wandb 3 | import shutil 4 | import logging 5 | import os.path as osp 6 | from torch.utils.tensorboard import SummaryWriter 7 | 8 | 9 | def mv_archived_logger(name): 10 | timestamp = time.strftime("%Y-%m-%d_%H:%M:%S_", time.localtime()) 11 | basename = 'archived_' + timestamp + osp.basename(name) 12 | archived_name = osp.join(osp.dirname(name), basename) 13 | shutil.move(name, archived_name) 14 | 15 | 16 | class CustomLogger: 17 | def __init__(self, common_cfg, tb_cfg=None, wandb_cfg=None, rank=0): 18 | global global_logger 19 | self.rank = rank 20 | 21 | if self.rank == 0: 22 | self.logger = logging.getLogger('VFI') 23 | self.logger.setLevel(logging.INFO) 24 | format_str = logging.Formatter(common_cfg['format']) 25 | 26 | console_handler = logging.StreamHandler() 27 |
console_handler.setFormatter(format_str) 28 | 29 | if osp.exists(common_cfg['filename']): 30 | mv_archived_logger(common_cfg['filename']) 31 | 32 | file_handler = logging.FileHandler(common_cfg['filename'], 33 | common_cfg['filemode']) 34 | file_handler.setFormatter(format_str) 35 | 36 | self.logger.addHandler(console_handler) 37 | self.logger.addHandler(file_handler) 38 | self.tb_logger = None 39 | 40 | self.enable_wandb = False 41 | 42 | if wandb_cfg is not None: 43 | self.enable_wandb = True 44 | wandb.init(**wandb_cfg) 45 | 46 | if tb_cfg is not None: 47 | self.tb_logger = SummaryWriter(**tb_cfg) 48 | 49 | global_logger = self 50 | 51 | def __call__(self, msg=None, level=logging.INFO, tb_msg=None): 52 | if self.rank != 0: 53 | return 54 | if msg is not None: 55 | self.logger.log(level, msg) 56 | 57 | if self.tb_logger is not None and tb_msg is not None: 58 | self.tb_logger.add_scalar(*tb_msg) 59 | 60 | def close(self): 61 | if self.rank == 0 and self.enable_wandb: 62 | wandb.finish() 63 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/models/DI-AMT-and-IFRNet/utils/__init__.py -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/utils/build_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | def base_build_fn(module, cls, params): 5 | return getattr(importlib.import_module( 6 | module, package=None), cls)(**params) 7 | 8 | 9 | def build_from_cfg(config): 10 | module, cls = config['name'].rsplit(".", 1) 11 | params = config.get('params', {}) 12 | return base_build_fn(module, cls, params) 13 | -------------------------------------------------------------------------------- /models/DI-AMT-and-IFRNet/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | def get_world_size(): 6 | """Find OMPI world size without calling mpi functions 7 | :rtype: int 8 | """ 9 | if os.environ.get('PMI_SIZE') is not None: 10 | return int(os.environ.get('PMI_SIZE') or 1) 11 | elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None: 12 | return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1) 13 | else: 14 | return int(os.environ['WORLD_SIZE']) 15 | # return torch.cuda.device_count() 16 | 17 | 18 | def get_global_rank(): 19 | """Find OMPI world rank without calling mpi functions 20 | :rtype: int 21 | """ 22 | if os.environ.get('PMI_RANK') is not None: 23 | return int(os.environ.get('PMI_RANK') or 0) 24 | elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None: 25 | return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0) 26 | else: 27 | return 0 28 | 29 | 30 | def get_local_rank(): 31 | """Find OMPI local rank without calling mpi functions 32 | :rtype: int 33 | """ 34 | if os.environ.get('MPI_LOCALRANKID') is not None: 35 | return int(os.environ.get('MPI_LOCALRANKID') or 0) 36 | elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None: 37 | return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0) 38 | else: 39 | return 0 40 | 41 | 42 | def get_master_ip(): 43 | if os.environ.get('AZ_BATCH_MASTER_NODE') is not None: 44 | return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0] 45 | elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') 
is not None: 46 | return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') 47 | else: 48 | return "127.0.0.1" 49 | -------------------------------------------------------------------------------- /models/DI-EMA-VFI/benchmark/MiddleBury.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import torch 5 | import argparse 6 | import numpy as np 7 | import warnings 8 | warnings.filterwarnings('ignore') 9 | torch.set_grad_enabled(False) 10 | 11 | '''==========import from our code==========''' 12 | sys.path.append('.') 13 | import config as cfg 14 | from Trainer import Model 15 | from benchmark.utils.padder import InputPadder 16 | from benchmark.utils.pytorch_msssim import ssim_matlab 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--model', default='ours', type=str) 20 | parser.add_argument('--path', type=str, required=True) 21 | args = parser.parse_args() 22 | assert args.model in ['ours', 'ours_small'], 'Model not exists!' 23 | 24 | 25 | '''==========Model setting==========''' 26 | TTA = True 27 | if args.model == 'ours_small': 28 | TTA = False 29 | cfg.MODEL_CONFIG['LOGNAME'] = 'ours_small' 30 | cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config( 31 | F = 16, 32 | depth = [2, 2, 2, 2, 2] 33 | ) 34 | else: 35 | cfg.MODEL_CONFIG['LOGNAME'] = 'ours' 36 | cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config( 37 | F = 32, 38 | depth = [2, 2, 2, 4, 4] 39 | ) 40 | model = Model(-1) 41 | model.load_model() 42 | model.eval() 43 | model.device() 44 | 45 | print(f'=========================Starting testing=========================') 46 | print(f'Dataset: MiddleBury Model: {model.name} TTA: {TTA}') 47 | path = args.path 48 | name = ['Beanbags', 'Dimetrodon', 'DogDance', 'Grove2', 'Grove3', 'Hydrangea', 'MiniCooper', 'RubberWhale', 'Urban2', 'Urban3', 'Venus', 'Walking'] 49 | IE_list = [] 50 | for i in name: 51 | i0 = cv2.imread(path + '/other-data/{}/frame10.png'.format(i)).transpose(2, 0, 1) / 255. 52 | i1 = cv2.imread(path + '/other-data/{}/frame11.png'.format(i)).transpose(2, 0, 1) / 255. 53 | gt = cv2.imread(path + '/other-gt-interp/{}/frame10i11.png'.format(i)) 54 | i0 = torch.from_numpy(i0).unsqueeze(0).float().cuda() 55 | i1 = torch.from_numpy(i1).unsqueeze(0).float().cuda() 56 | padder = InputPadder(i0.shape, divisor = 32) 57 | i0, i1 = padder.pad(i0, i1) 58 | pred1 = model.inference(i0, i1, TTA=TTA, fast_TTA=TTA)[0] 59 | pred = padder.unpad(pred1) 60 | out = pred.detach().cpu().numpy().transpose(1, 2, 0) 61 | out = np.round(out * 255.)
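# IE (interpolation error) is the mean absolute difference between the rounded
# prediction and the ground-truth middle frame, both on the [0, 255] scale;
# the `* 1.0` below promotes the uint8 ground truth to float before subtraction.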
62 | IE_list.append(np.abs((out - gt * 1.0)).mean()) 63 | print(f"Avg IE: {np.mean(IE_list)}") -------------------------------------------------------------------------------- /models/DI-EMA-VFI/benchmark/SNU_FILM.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import math 5 | import torch 6 | import argparse 7 | import warnings 8 | import numpy as np 9 | from tqdm import tqdm 10 | warnings.filterwarnings('ignore') 11 | torch.set_grad_enabled(False) 12 | 13 | '''==========import from our code==========''' 14 | sys.path.append('.') 15 | import config as cfg 16 | from Trainer import Model 17 | from benchmark.utils.padder import InputPadder 18 | from benchmark.utils.pytorch_msssim import ssim_matlab 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--model', default='ours', type=str) 22 | parser.add_argument('--path', type=str, required=True) 23 | args = parser.parse_args() 24 | assert args.model in ['ours', 'ours_small'], 'Model not exists!' 25 | 26 | 27 | '''==========Model setting==========''' 28 | TTA = True 29 | down_scale = 0.5 30 | if args.model == 'ours_small': 31 | TTA = False 32 | cfg.MODEL_CONFIG['LOGNAME'] = 'ours_small' 33 | cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config( 34 | F = 16, 35 | depth = [2, 2, 2, 2, 2] 36 | ) 37 | else: 38 | cfg.MODEL_CONFIG['LOGNAME'] = 'ours' 39 | cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config( 40 | F = 32, 41 | depth = [2, 2, 2, 4, 4] 42 | ) 43 | model = Model(-1) 44 | model.load_model() 45 | model.eval() 46 | model.device() 47 | 48 | 49 | print(f'=========================Starting testing=========================') 50 | print(f'Dataset: SNU_FILM Model: {model.name} TTA: {TTA}') 51 | path = args.path 52 | level_list = ['test-easy.txt', 'test-medium.txt', 'test-hard.txt', 'test-extreme.txt'] 53 | for test_file in level_list: 54 | psnr_list, ssim_list = [], [] 55 | file_list = [] 56 | 57 | with open(os.path.join(path, test_file), "r") as f: 58 | for line in f: 59 | line = line.strip() 60 | file_list.append(line.split(' ')) 61 | 62 | for line in tqdm(file_list): 63 | I0_path = os.path.join(path, line[0]) 64 | I1_path = os.path.join(path, line[1]) 65 | I2_path = os.path.join(path, line[2]) 66 | I0 = cv2.imread(I0_path) 67 | I1_ = cv2.imread(I1_path) 68 | I2 = cv2.imread(I2_path) 69 | I0 = (torch.tensor(I0.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).cuda() 70 | I1 = (torch.tensor(I1_.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).cuda() 71 | I2 = (torch.tensor(I2.transpose(2, 0, 1)).float() / 255.0).unsqueeze(0).cuda() 72 | padder = InputPadder(I0.shape, divisor=32) 73 | I0, I2 = padder.pad(I0, I2) 74 | I1_pred = model.hr_inference(I0, I2, TTA, down_scale=down_scale, fast_TTA=TTA)[0] 75 | I1_pred = padder.unpad(I1_pred) 76 | ssim = ssim_matlab(I1, I1_pred.unsqueeze(0)).detach().cpu().numpy() 77 | 78 | I1_pred = I1_pred.detach().cpu().numpy().transpose(1, 2, 0) 79 | I1_ = I1_ / 255. 
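# With images scaled to [0, 1], MAX = 1, so PSNR = 10 * log10(MAX^2 / MSE)
# reduces to the -10 * log10(MSE) form used below.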
80 | psnr = -10 * math.log10(((I1_ - I1_pred) * (I1_ - I1_pred)).mean()) 81 | 82 | psnr_list.append(psnr) 83 | ssim_list.append(ssim) 84 | 85 | print('Testing level:' + test_file[:-4]) 86 | print('Avg PSNR: {} SSIM: {}'.format(np.mean(psnr_list), np.mean(ssim_list))) 87 | -------------------------------------------------------------------------------- /models/DI-EMA-VFI/benchmark/TimeTest.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | import sys 3 | import torch 4 | import argparse 5 | import os 6 | import warnings 7 | warnings.filterwarnings('ignore') 8 | torch.set_grad_enabled(False) 9 | 10 | '''==========import from our code==========''' 11 | sys.path.append('.') 12 | import config as cfg 13 | from Trainer import Model 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--model', default='ours_small', type=str) 17 | parser.add_argument('--H', default=256, type=int) 18 | parser.add_argument('--W', default=256, type=int) 19 | args = parser.parse_args() 20 | assert args.model in ['ours', 'ours_small'], 'Model not exists!' 21 | 22 | '''==========Model setting==========''' 23 | TTA = True 24 | if args.model == 'ours_small': 25 | TTA = False 26 | cfg.MODEL_CONFIG['LOGNAME'] = 'ours_small' 27 | cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config( 28 | F = 16, 29 | depth = [2, 2, 2, 2, 2] 30 | ) 31 | else: 32 | cfg.MODEL_CONFIG['LOGNAME'] = 'ours' 33 | cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config( 34 | F = 32, 35 | depth = [2, 2, 2, 4, 4] 36 | ) 37 | 38 | model = Model(-1) 39 | model.load_model() 40 | model.eval() 41 | model.device() 42 | 43 | if torch.cuda.is_available(): 44 | torch.backends.cudnn.enabled = True 45 | torch.backends.cudnn.benchmark = True 46 | 47 | H, W = args.H, args.W 48 | I0 = torch.rand(1, 3, H, W).cuda() 49 | I1 = torch.rand(1, 3, H, W).cuda() 50 | 51 | print(f'Test model: {model.name} TTA: {TTA}') 52 | with torch.no_grad(): 53 | for i in range(50): 54 | pred = model.inference(I0, I1) 55 | if torch.cuda.is_available(): 56 | torch.cuda.synchronize() 57 | time_stamp = time() 58 | for i in range(100): 59 | pred = model.inference(I0, I1) 60 | if torch.cuda.is_available(): 61 | torch.cuda.synchronize() 62 | print((time() - time_stamp) / 100 * 1000) -------------------------------------------------------------------------------- /models/DI-EMA-VFI/benchmark/UCF101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import math 5 | import torch 6 | import argparse 7 | import warnings 8 | import numpy as np 9 | from tqdm import tqdm 10 | warnings.filterwarnings('ignore') 11 | torch.set_grad_enabled(False) 12 | 13 | '''==========import from our code==========''' 14 | sys.path.append('.') 15 | import config as cfg 16 | from Trainer import Model 17 | from benchmark.utils.pytorch_msssim import ssim_matlab 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--model', default='ours', type=str) 21 | parser.add_argument('--path', type=str, required=True) 22 | args = parser.parse_args() 23 | assert args.model in ['ours', 'ours_small'], 'Model not exists!' 
24 | 25 | 26 | '''==========Model setting==========''' 27 | TTA = True 28 | if args.model == 'ours_small': 29 | TTA = False 30 | cfg.MODEL_CONFIG['LOGNAME'] = 'ours_small' 31 | cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config( 32 | F = 16, 33 | depth = [2, 2, 2, 2, 2] 34 | ) 35 | else: 36 | cfg.MODEL_CONFIG['LOGNAME'] = 'ours' 37 | cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config( 38 | F = 32, 39 | depth = [2, 2, 2, 4, 4] 40 | ) 41 | model = Model(-1) 42 | model.load_model() 43 | model.eval() 44 | model.device() 45 | 46 | 47 | print(f'=========================Starting testing=========================') 48 | print(f'Dataset: UCF101 Model: {model.name} TTA: {TTA}') 49 | path = args.path 50 | dirs = os.listdir(path) 51 | psnr_list, ssim_list = [], [] 52 | for d in tqdm(dirs): 53 | img0 = (path + '/' + d + '/frame_00.png') 54 | img1 = (path + '/' + d + '/frame_02.png') 55 | gt = (path + '/' + d + '/frame_01_gt.png') 56 | img0 = (torch.tensor(cv2.imread(img0).transpose(2, 0, 1) / 255.)).cuda().float().unsqueeze(0) 57 | img1 = (torch.tensor(cv2.imread(img1).transpose(2, 0, 1) / 255.)).cuda().float().unsqueeze(0) 58 | gt = (torch.tensor(cv2.imread(gt).transpose(2, 0, 1) / 255.)).cuda().float().unsqueeze(0) 59 | pred = model.inference(img0, img1, TTA=TTA, fast_TTA=TTA)[0] 60 | ssim = ssim_matlab(gt, torch.round(pred * 255).unsqueeze(0) / 255.).detach().cpu().numpy() 61 | out = pred.detach().cpu().numpy().transpose(1, 2, 0) 62 | out = np.round(out * 255) / 255. 63 | gt = gt[0].cpu().numpy().transpose(1, 2, 0) 64 | psnr = -10 * math.log10(((gt - out) * (gt - out)).mean()) 65 | psnr_list.append(psnr) 66 | ssim_list.append(ssim) 67 | print("Avg PSNR: {} SSIM: {}".format(np.mean(psnr_list), np.mean(ssim_list))) -------------------------------------------------------------------------------- /models/DI-EMA-VFI/benchmark/utils/padder.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn.functional as F 3 | 4 | 5 | 6 | class InputPadder: 7 | """ Pads images such that dimensions are divisible by divisor """ 8 | def __init__(self, dims, divisor = 16): 9 | self.ht, self.wd = dims[-2:] 10 | pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor 11 | pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor 12 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 13 | 14 | def pad(self, *inputs): 15 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 16 | 17 | def unpad(self,x): 18 | ht, wd = x.shape[-2:] 19 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 20 | return x[..., c[0]:c[1], c[2]:c[3]] -------------------------------------------------------------------------------- /models/DI-EMA-VFI/config.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch.nn as nn 3 | 4 | from model import feature_extractor 5 | from model import flow_estimation 6 | 7 | '''==========Model config==========''' 8 | 9 | 10 | def init_model_config(F=32, W=7, depth=[2, 2, 2, 4, 4]): 11 | '''This function should not be modified''' 12 | return { 13 | 'embed_dims': [F, 2 * F, 4 * F, 8 * F, 16 * F], 14 | 'motion_dims': [0, 0, 0, 8 * F // depth[-2], 16 * F // depth[-1]], 15 | 'num_heads': [8 * F // 32, 16 * F // 32], 16 | 'mlp_ratios': [4, 4], 17 | 'qkv_bias': True, 18 | 'norm_layer': partial(nn.LayerNorm, eps=1e-6), 19 | 'depths': depth, 20 | 'window_sizes': [W, W] 21 | }, { 22 | 'embed_dims': [F, 2 * F, 4 * F, 8 * 
F, 16 * F], 23 | 'motion_dims': [0, 0, 0, 8 * F // depth[-2], 16 * F // depth[-1]], 24 | 'depths': depth, 25 | 'num_heads': [8 * F // 32, 16 * F // 32], 26 | 'window_sizes': [W, W], 27 | 'scales': [4, 8, 16], 28 | 'hidden_dims': [4 * F, 4 * F], 29 | 'c': F 30 | } 31 | 32 | 33 | MODEL_CONFIG = { 34 | 'LOGNAME': 'ema-vfi', 35 | 'MODEL_TYPE': (feature_extractor, flow_estimation), 36 | 'MODEL_ARCH': init_model_config( 37 | F=32, 38 | W=7, 39 | depth=[2, 2, 2, 4, 4] 40 | ) 41 | } 42 | 43 | # MODEL_CONFIG = { 44 | # 'LOGNAME': 'ours_small', 45 | # 'MODEL_TYPE': (feature_extractor, flow_estimation), 46 | # 'MODEL_ARCH': init_model_config( 47 | # F = 16, 48 | # W = 7, 49 | # depth = [2, 2, 2, 2, 2] 50 | # ) 51 | # } 52 | -------------------------------------------------------------------------------- /models/DI-EMA-VFI/config_recur.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch.nn as nn 3 | 4 | from model import feature_recur_extractor 5 | from model import flow_recur_estimation 6 | 7 | '''==========Model config==========''' 8 | 9 | 10 | def init_model_config(F=32, W=7, depth=[2, 2, 2, 4, 4]): 11 | '''This function should not be modified''' 12 | return { 13 | 'embed_dims': [F, 2 * F, 4 * F, 8 * F, 16 * F], 14 | 'motion_dims': [0, 0, 0, 8 * F // depth[-2], 16 * F // depth[-1]], 15 | 'num_heads': [8 * F // 32, 16 * F // 32], 16 | 'mlp_ratios': [4, 4], 17 | 'qkv_bias': True, 18 | 'norm_layer': partial(nn.LayerNorm, eps=1e-6), 19 | 'depths': depth, 20 | 'window_sizes': [W, W] 21 | }, { 22 | 'embed_dims': [F, 2 * F, 4 * F, 8 * F, 16 * F], 23 | 'motion_dims': [0, 0, 0, 8 * F // depth[-2], 16 * F // depth[-1]], 24 | 'depths': depth, 25 | 'num_heads': [8 * F // 32, 16 * F // 32], 26 | 'window_sizes': [W, W], 27 | 'scales': [4, 8, 16], 28 | 'hidden_dims': [4 * F, 4 * F], 29 | 'c': F 30 | } 31 | 32 | 33 | MODEL_CONFIG = { 34 | 'LOGNAME': 'ema-vfi', 35 | 'MODEL_TYPE': (feature_recur_extractor, flow_recur_estimation), 36 | 'MODEL_ARCH': init_model_config( 37 | F=32, 38 | W=7, 39 | depth=[2, 2, 2, 4, 4] 40 | ) 41 | } 42 | 43 | # MODEL_CONFIG = { 44 | # 'LOGNAME': 'ours_small', 45 | # 'MODEL_TYPE': (feature_recur_extractor, flow_recur_estimation), 46 | # 'MODEL_ARCH': init_model_config( 47 | # F = 16, 48 | # W = 7, 49 | # depth = [2, 2, 2, 2, 2] 50 | # ) 51 | # } 52 | -------------------------------------------------------------------------------- /models/DI-EMA-VFI/demo_2x.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import math 5 | import sys 6 | import torch 7 | import numpy as np 8 | import argparse 9 | from imageio import mimsave 10 | 11 | # I0_path = '../dataset/vimeo_septuplet/sequences/00068/0261/im1.png' 12 | # I2_path = '../dataset/vimeo_septuplet/sequences/00068/0261/im7.png' 13 | # save_dir = './demo/00068/0261_EMA' 14 | 15 | I0_path = '../dataset/vimeo_septuplet/sequences/00080/0050/im1.png' 16 | I2_path = '../dataset/vimeo_septuplet/sequences/00080/0050/im7.png' 17 | save_dir = './demo/00080/0050_EMA' 18 | 19 | iters = 7 20 | 21 | os.makedirs(save_dir, exist_ok=True) 22 | 23 | '''==========import from our code==========''' 24 | sys.path.append('.') 25 | import config as cfg 26 | from Trainer import Model 27 | from benchmark.utils.padder import InputPadder 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--model', default='ours', type=str) 31 | args = parser.parse_args() 32 | assert args.model in ['ours', 
'ours_small'], 'Model not exists!' 33 | 34 | '''==========Model setting==========''' 35 | TTA = True 36 | if args.model == 'ours_small': 37 | TTA = False 38 | cfg.MODEL_CONFIG['LOGNAME'] = 'ours_small' 39 | cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config( 40 | F=16, 41 | depth=[2, 2, 2, 2, 2] 42 | ) 43 | else: 44 | cfg.MODEL_CONFIG['LOGNAME'] = 'ours' 45 | cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config( 46 | F=32, 47 | depth=[2, 2, 2, 4, 4] 48 | ) 49 | model = Model(-1) 50 | model.load_model() 51 | model.eval() 52 | model.device() 53 | 54 | print(f'=========================Start Generating=========================') 55 | 56 | I0 = cv2.imread(I0_path) 57 | I2 = cv2.imread(I2_path) 58 | 59 | I0_ = (torch.tensor(I0.transpose(2, 0, 1)).cuda() / 255.).unsqueeze(0) 60 | I2_ = (torch.tensor(I2.transpose(2, 0, 1)).cuda() / 255.).unsqueeze(0) 61 | 62 | padder = InputPadder(I0_.shape, divisor=32) 63 | I0_, I2_ = padder.pad(I0_, I2_) 64 | 65 | imgs = [I0_, I2_] 66 | while iters != 0: 67 | imgs_temp = [I0_, ] 68 | for I_start, I_end in zip(imgs[:-1], imgs[1:]): 69 | mid = model.inference(I_start, I_end, TTA=TTA, fast_TTA=TTA) 70 | imgs_temp.append(mid) 71 | imgs_temp.append(I_end) 72 | imgs = imgs_temp 73 | iters -= 1 74 | 75 | imgs = [ 76 | (padder.unpad(img)[0].detach().cpu().numpy().transpose(1, 2, 0) * 255.0).astype(np.uint8)[:, :, ::-1] 77 | for img in imgs 78 | ] 79 | 80 | mimsave('{}/demo.gif'.format(save_dir), imgs, duration=1000 / 15.) 81 | 82 | print(f'=========================Done=========================') 83 | -------------------------------------------------------------------------------- /models/DI-EMA-VFI/demo_Nx.py: -------------------------------------------------------------------------------- 1 | """ 2 | python demo_Nx.py --n 128 3 | """ 4 | import os 5 | 6 | import cv2 7 | import math 8 | import sys 9 | import torch 10 | import numpy as np 11 | import argparse 12 | from imageio import mimsave 13 | 14 | # I0_path = '../dataset/vimeo_septuplet/sequences/00068/0261/im1.png' 15 | # I2_path = '../dataset/vimeo_septuplet/sequences/00068/0261/im7.png' 16 | # save_dir = './demo/00068/0261_EMA_t' 17 | 18 | I0_path = '../dataset/vimeo_septuplet/sequences/00080/0050/im1.png' 19 | I2_path = '../dataset/vimeo_septuplet/sequences/00080/0050/im7.png' 20 | save_dir = './demo/00080/0050_EMA_t' 21 | 22 | os.makedirs(save_dir, exist_ok=True) 23 | 24 | '''==========import from our code==========''' 25 | sys.path.append('.') 26 | import config as cfg 27 | from Trainer import Model 28 | from benchmark.utils.padder import InputPadder 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--model', default='ours_t', type=str) 32 | parser.add_argument('--n', default=16, type=int) 33 | args = parser.parse_args() 34 | assert args.model in ['ours_t', 'ours_small_t'], 'Model not exists!' 
35 | 36 | '''==========Model setting==========''' 37 | TTA = True 38 | if args.model == 'ours_small_t': 39 | TTA = False 40 | cfg.MODEL_CONFIG['LOGNAME'] = 'ours_small_t' 41 | cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config( 42 | F=16, 43 | depth=[2, 2, 2, 2, 2] 44 | ) 45 | else: 46 | cfg.MODEL_CONFIG['LOGNAME'] = 'ours_t' 47 | cfg.MODEL_CONFIG['MODEL_ARCH'] = cfg.init_model_config( 48 | F=32, 49 | depth=[2, 2, 2, 4, 4] 50 | ) 51 | model = Model(-1) 52 | model.load_model() 53 | model.eval() 54 | model.device() 55 | 56 | print(f'=========================Start Generating=========================') 57 | 58 | I0 = cv2.imread(I0_path) 59 | I2 = cv2.imread(I2_path) 60 | 61 | I0_ = (torch.tensor(I0.transpose(2, 0, 1)).cuda() / 255.).unsqueeze(0) 62 | I2_ = (torch.tensor(I2.transpose(2, 0, 1)).cuda() / 255.).unsqueeze(0) 63 | 64 | padder = InputPadder(I0_.shape, divisor=32) 65 | I0_, I2_ = padder.pad(I0_, I2_) 66 | 67 | images = [I0[:, :, ::-1]] 68 | preds = model.multi_inference(I0_, I2_, TTA=TTA, time_list=[(i + 1) * (1. / args.n) for i in range(args.n - 1)], 69 | fast_TTA=TTA) 70 | for pred in preds: 71 | images.append((padder.unpad(pred).detach().cpu().numpy().transpose(1, 2, 0) * 255.0).astype(np.uint8)[:, :, ::-1]) 72 | images.append(I2[:, :, ::-1]) 73 | mimsave('{}/demo.gif'.format(save_dir), images, duration=1000 / 15.) 74 | 75 | print(f'=========================Done=========================') 76 | -------------------------------------------------------------------------------- /models/DI-EMA-VFI/inference_img_plus.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('.') 4 | 5 | import warnings 6 | 7 | warnings.filterwarnings('ignore') 8 | 9 | import os 10 | import cv2 11 | import torch 12 | import imageio as iio 13 | import os.path as osp 14 | import numpy as np 15 | from argparse import ArgumentParser 16 | from Trainer import Model 17 | from benchmark.utils.padder import InputPadder 18 | 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | 21 | 22 | @torch.no_grad() 23 | def interpolate(I0, I1, num): 24 | imgs = [] 25 | I0 = (torch.tensor(I0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) 26 | I1 = (torch.tensor(I1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) 27 | padder = InputPadder(I0.shape, 32) 28 | I0, I1 = padder.pad(I0, I1) 29 | 30 | embts = [torch.zeros_like(I0[:, :1, :, :]) + j / (num + 1) for j in range(1, num + 1)] 31 | 32 | for i in range(num): 33 | mid = model.multi_inference(I0, I1, time_list=[embts[i], ], TTA=False, fast_TTA=False)[0] 34 | mid = padder.unpad(mid) 35 | mid = mid.clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy() 36 | mid = (mid * 255.).astype(np.uint8) 37 | imgs.append(mid) 38 | return imgs 39 | 40 | 41 | if __name__ == '__main__': 42 | parser = ArgumentParser() 43 | parser.add_argument('--img0', type=str, default='./demo/I0_0.png', help='path of start image') 44 | parser.add_argument('--img1', type=str, default='./demo/I0_1.png', help='path of end image') 45 | parser.add_argument('--checkpoint', type=str, default='./experiments/EMA-VFI_m/train_sdi_log/', 46 | help='path of checkpoint') 47 | parser.add_argument('--save_dir', type=str, default='./demo/I0_results/', help='where to save image results') 48 | parser.add_argument('--num', type=int, nargs='+', default=[5, 5], help='number of extracted images') 49 | parser.add_argument('--gif', action='store_true', help='whether to generate the corresponding gif') 50 | args = parser.parse_args() 51 
| 52 | extracted_num = 2 53 | for sub_num in args.num: 54 | extracted_num += sub_num * (extracted_num - 1) 55 | 56 | # ----------------------- Load model ----------------------- 57 | model = Model(-1) 58 | model.load_model(log_path=args.checkpoint) 59 | model.eval() 60 | model.device() 61 | # ----------------------- Load input frames ----------------------- 62 | os.makedirs(args.save_dir, exist_ok=True) 63 | I0 = cv2.imread(args.img0) 64 | I1 = cv2.imread(args.img1) 65 | gif_imgs = [I0, I1] 66 | 67 | for sub_num in args.num: 68 | gif_imgs_temp = [gif_imgs[0], ] 69 | for i, (img_start, img_end) in enumerate(zip(gif_imgs[:-1], gif_imgs[1:])): 70 | interp_imgs = interpolate(img_start, img_end, num=sub_num) 71 | gif_imgs_temp += interp_imgs 72 | gif_imgs_temp += [img_end, ] 73 | gif_imgs = gif_imgs_temp 74 | 75 | print('Interpolate 2 images to {} images'.format(extracted_num)) 76 | 77 | for i, img in enumerate(gif_imgs): 78 | save_path = osp.join(args.save_dir, '{:03d}.png'.format(i)) 79 | cv2.imwrite(save_path, img) 80 | 81 | if args.gif: 82 | gif_path = osp.join(args.save_dir, 'demo.gif') 83 | with iio.get_writer(gif_path, mode='I') as writer: 84 | for img in gif_imgs: 85 | writer.append_data(img[:, :, ::-1]) 86 | -------------------------------------------------------------------------------- /models/DI-EMA-VFI/inference_img_plus_sdi.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('.') 4 | 5 | import warnings 6 | 7 | warnings.filterwarnings('ignore') 8 | 9 | import os 10 | import cv2 11 | import torch 12 | import imageio as iio 13 | import os.path as osp 14 | import numpy as np 15 | from argparse import ArgumentParser 16 | from Trainer import Model 17 | from benchmark.utils.padder import InputPadder 18 | 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | 21 | 22 | @torch.no_grad() 23 | def interpolate(I0, I1, num): 24 | imgs = [] 25 | I0 = (torch.tensor(I0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) 26 | I1 = (torch.tensor(I1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) 27 | padder = InputPadder(I0.shape, 32) 28 | I0, I1 = padder.pad(I0, I1) 29 | 30 | embts = [torch.zeros_like(I0[:, :1, :, :]) + j / (num + 1) for j in range(1, num + 1)] 31 | 32 | for i in range(num): 33 | mid = model.multi_inference(I0, I1, time_list=[embts[i], ], TTA=False, fast_TTA=False)[0] 34 | mid = padder.unpad(mid) 35 | mid = mid.clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy() 36 | mid = (mid * 255.).astype(np.uint8) 37 | imgs.append(mid) 38 | return imgs 39 | 40 | 41 | if __name__ == '__main__': 42 | parser = ArgumentParser() 43 | parser.add_argument('--img0', type=str, default='./demo/I0_0.png', help='path of start image') 44 | parser.add_argument('--img1', type=str, default='./demo/I0_1.png', help='path of end image') 45 | parser.add_argument('--checkpoint', type=str, default='./experiments/EMA-VFI_m/train_sdi_log/', 46 | help='path of checkpoint') 47 | parser.add_argument('--save_dir', type=str, default='./demo/I0_results/', help='where to save image results') 48 | parser.add_argument('--num', type=int, nargs='+', default=[5, 5], help='number of extracted images') 49 | parser.add_argument('--gif', action='store_true', help='whether to generate the corresponding gif') 50 | args = parser.parse_args() 51 | 52 | extracted_num = 2 53 | for sub_num in args.num: 54 | extracted_num += sub_num * (extracted_num - 1) 55 | 56 | # ----------------------- Load model ----------------------- 57 | model = 
Model(-1) 58 | model.load_model(log_path=args.checkpoint) 59 | model.eval() 60 | model.device() 61 | # ----------------------- Load input frames ----------------------- 62 | os.makedirs(args.save_dir, exist_ok=True) 63 | I0 = cv2.imread(args.img0) 64 | I1 = cv2.imread(args.img1) 65 | gif_imgs = [I0, I1] 66 | 67 | for sub_num in args.num: 68 | gif_imgs_temp = [gif_imgs[0], ] 69 | for i, (img_start, img_end) in enumerate(zip(gif_imgs[:-1], gif_imgs[1:])): 70 | interp_imgs = interpolate(img_start, img_end, num=sub_num) 71 | gif_imgs_temp += interp_imgs 72 | gif_imgs_temp += [img_end, ] 73 | gif_imgs = gif_imgs_temp 74 | 75 | print('Interpolate 2 images to {} images'.format(extracted_num)) 76 | 77 | for i, img in enumerate(gif_imgs): 78 | save_path = osp.join(args.save_dir, '{:03d}.png'.format(i)) 79 | cv2.imwrite(save_path, img) 80 | 81 | if args.gif: 82 | gif_path = osp.join(args.save_dir, 'demo.gif') 83 | with iio.get_writer(gif_path, mode='I') as writer: 84 | for img in gif_imgs: 85 | writer.append_data(img[:, :, ::-1]) 86 | -------------------------------------------------------------------------------- /models/DI-EMA-VFI/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .feature_extractor import feature_extractor 2 | from .feature_recur_extractor import feature_recur_extractor 3 | from .flow_estimation import MultiScaleFlow as flow_estimation 4 | from .flow_recur_estimation import MultiScaleFlow as flow_recur_estimation 5 | 6 | __all__ = ['feature_extractor', 'feature_recur_extractor', 'flow_estimation', 'flow_recur_estimation'] 7 | -------------------------------------------------------------------------------- /models/DI-EMA-VFI/model/refine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from timm.models.layers import trunc_normal_ 5 | 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 9 | return nn.Sequential( 10 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 11 | padding=padding, dilation=dilation, bias=True), 12 | nn.PReLU(out_planes) 13 | ) 14 | 15 | def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): 16 | return nn.Sequential( 17 | torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True), 18 | nn.PReLU(out_planes) 19 | ) 20 | 21 | class Conv2(nn.Module): 22 | def __init__(self, in_planes, out_planes, stride=2): 23 | super(Conv2, self).__init__() 24 | self.conv1 = conv(in_planes, out_planes, 3, stride, 1) 25 | self.conv2 = conv(out_planes, out_planes, 3, 1, 1) 26 | 27 | def forward(self, x): 28 | x = self.conv1(x) 29 | x = self.conv2(x) 30 | return x 31 | 32 | class Unet(nn.Module): 33 | def __init__(self, c, out=3): 34 | super(Unet, self).__init__() 35 | self.down0 = Conv2(17+c, 2*c) 36 | self.down1 = Conv2(4*c, 4*c) 37 | self.down2 = Conv2(8*c, 8*c) 38 | self.down3 = Conv2(16*c, 16*c) 39 | self.up0 = deconv(32*c, 8*c) 40 | self.up1 = deconv(16*c, 4*c) 41 | self.up2 = deconv(8*c, 2*c) 42 | self.up3 = deconv(4*c, c) 43 | self.conv = nn.Conv2d(c, out, 3, 1, 1) 44 | self.apply(self._init_weights) 45 | 46 | def _init_weights(self, m): 47 | if isinstance(m, nn.Linear): 48 | trunc_normal_(m.weight, std=.02) 49 | if isinstance(m, nn.Linear) and m.bias is not None: 50 | nn.init.constant_(m.bias, 0) 51 | elif 
isinstance(m, nn.LayerNorm): 52 | nn.init.constant_(m.bias, 0) 53 | nn.init.constant_(m.weight, 1.0) 54 | elif isinstance(m, nn.Conv2d): 55 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 56 | fan_out //= m.groups 57 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 58 | if m.bias is not None: 59 | m.bias.data.zero_() 60 | 61 | def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1): 62 | s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow,c0[0], c1[0]), 1)) 63 | s1 = self.down1(torch.cat((s0, c0[1], c1[1]), 1)) 64 | s2 = self.down2(torch.cat((s1, c0[2], c1[2]), 1)) 65 | s3 = self.down3(torch.cat((s2, c0[3], c1[3]), 1)) 66 | x = self.up0(torch.cat((s3, c0[4], c1[4]), 1)) 67 | x = self.up1(torch.cat((x, s2), 1)) 68 | x = self.up2(torch.cat((x, s1), 1)) 69 | x = self.up3(torch.cat((x, s0), 1)) 70 | x = self.conv(x) 71 | return torch.sigmoid(x) 72 | -------------------------------------------------------------------------------- /models/DI-EMA-VFI/model/warplayer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 4 | backwarp_tenGrid = {} 5 | 6 | def warp(tenInput, tenFlow): 7 | k = (str(tenFlow.device), str(tenFlow.size())) 8 | if k not in backwarp_tenGrid: 9 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view( 10 | 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) 11 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view( 12 | 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) 13 | backwarp_tenGrid[k] = torch.cat( 14 | [tenHorizontal, tenVertical], 1).to(device) 15 | 16 | tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), 17 | tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1) 18 | 19 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) 20 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True) 21 | -------------------------------------------------------------------------------- /models/DI-RIFE/benchmark/ATD12K.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('.') 4 | import cv2 5 | import math 6 | import torch 7 | import argparse 8 | import numpy as np 9 | from torch.nn import functional as F 10 | from model.pytorch_msssim import ssim_matlab 11 | from model.RIFE import Model 12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | 14 | model = Model() 15 | model.load_model('train_log') 16 | model.eval() 17 | model.device() 18 | 19 | path = 'datasets/test_2k_540p/' 20 | dirs = os.listdir(path) 21 | psnr_list = [] 22 | ssim_list = [] 23 | print(len(dirs)) 24 | for d in dirs: 25 | img0 = (path + d + '/frame1.png') 26 | img1 = (path + d + '/frame3.png') 27 | gt = (path + d + '/frame2.png') 28 | img0 = (torch.tensor(cv2.imread(img0).transpose(2, 0, 1) / 255.)).to(device).float().unsqueeze(0) 29 | img1 = (torch.tensor(cv2.imread(img1).transpose(2, 0, 1) / 255.)).to(device).float().unsqueeze(0) 30 | gt = (torch.tensor(cv2.imread(gt).transpose(2, 0, 1) / 255.)).to(device).float().unsqueeze(0) 31 | pader = torch.nn.ReplicationPad2d([0, 0, 2, 2]) 32 | img0 = pader(img0) 33 | img1 = pader(img1) 34 | pred = model.inference(img0, img1)[0][:, 2:-2] 35 | ssim = ssim_matlab(gt, 
torch.round(pred * 255).unsqueeze(0) / 255.).detach().cpu().numpy() 36 | out = pred.detach().cpu().numpy().transpose(1, 2, 0) 37 | out = np.round(out * 255) / 255. 38 | gt = gt[0].cpu().numpy().transpose(1, 2, 0) 39 | psnr = -10 * math.log10(((gt - out) * (gt - out)).mean()) 40 | psnr_list.append(psnr) 41 | ssim_list.append(ssim) 42 | print("Avg PSNR: {} SSIM: {}".format(np.mean(psnr_list), np.mean(ssim_list))) 43 | -------------------------------------------------------------------------------- /models/DI-RIFE/benchmark/HD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('.') 4 | import cv2 5 | import math 6 | import torch 7 | import argparse 8 | import numpy as np 9 | from torch.nn import functional as F 10 | from model.pytorch_msssim import ssim_matlab 11 | from model.RIFE import Model 12 | from skimage.color import rgb2yuv, yuv2rgb 13 | from yuv_frame_io import YUV_Read,YUV_Write 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | 16 | model = Model() 17 | model.load_model('train_log') 18 | model.eval() 19 | model.device() 20 | 21 | name_list = [ 22 | ('HD_dataset/HD720p_GT/parkrun_1280x720_50.yuv', 720, 1280), 23 | ('HD_dataset/HD720p_GT/shields_1280x720_60.yuv', 720, 1280), 24 | ('HD_dataset/HD720p_GT/stockholm_1280x720_60.yuv', 720, 1280), 25 | ('HD_dataset/HD1080p_GT/BlueSky.yuv', 1080, 1920), 26 | ('HD_dataset/HD1080p_GT/Kimono1_1920x1080_24.yuv', 1080, 1920), 27 | ('HD_dataset/HD1080p_GT/ParkScene_1920x1080_24.yuv', 1080, 1920), 28 | ('HD_dataset/HD1080p_GT/sunflower_1080p25.yuv', 1080, 1920), 29 | ('HD_dataset/HD544p_GT/Sintel_Alley2_1280x544.yuv', 544, 1280), 30 | ('HD_dataset/HD544p_GT/Sintel_Market5_1280x544.yuv', 544, 1280), 31 | ('HD_dataset/HD544p_GT/Sintel_Temple1_1280x544.yuv', 544, 1280), 32 | ('HD_dataset/HD544p_GT/Sintel_Temple2_1280x544.yuv', 544, 1280), 33 | ] 34 | tot = 0. 
35 | for data in name_list: 36 | psnr_list = [] 37 | name = data[0] 38 | h = data[1] 39 | w = data[2] 40 | if 'yuv' in name: 41 | Reader = YUV_Read(name, h, w, toRGB=True) 42 | else: 43 | Reader = cv2.VideoCapture(name) 44 | _, lastframe = Reader.read() 45 | # fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v') 46 | # video = cv2.VideoWriter(name + '.mp4', fourcc, 30, (w, h)) 47 | for index in range(0, 100, 2): 48 | if 'yuv' in name: 49 | IMAGE1, success1 = Reader.read(index) 50 | gt, _ = Reader.read(index + 1) 51 | IMAGE2, success2 = Reader.read(index + 2) 52 | if not success2: 53 | break 54 | else: 55 | success1, gt = Reader.read() 56 | success2, frame = Reader.read() 57 | IMAGE1 = lastframe 58 | IMAGE2 = frame 59 | lastframe = frame 60 | if not success2: 61 | break 62 | I0 = torch.from_numpy(np.transpose(IMAGE1, (2,0,1)).astype("float32") / 255.).cuda().unsqueeze(0) 63 | I1 = torch.from_numpy(np.transpose(IMAGE2, (2,0,1)).astype("float32") / 255.).cuda().unsqueeze(0) 64 | 65 | if h == 720: 66 | pad = 24 67 | elif h == 1080: 68 | pad = 4 69 | else: 70 | pad = 16 71 | pader = torch.nn.ReplicationPad2d([0, 0, pad, pad]) 72 | I0 = pader(I0) 73 | I1 = pader(I1) 74 | with torch.no_grad(): 75 | pred = model.inference(I0, I1) 76 | pred = pred[:, :, pad: -pad] 77 | out = (np.round(pred[0].detach().cpu().numpy().transpose(1, 2, 0) * 255)).astype('uint8') 78 | # video.write(out) 79 | if 'yuv' in name: 80 | diff_rgb = 128.0 + rgb2yuv(gt / 255.)[:, :, 0] * 255 - rgb2yuv(out / 255.)[:, :, 0] * 255 81 | mse = np.mean((diff_rgb - 128.0) ** 2) 82 | PIXEL_MAX = 255.0 83 | psnr = 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 84 | else: 85 | psnr = 20 * math.log10(255.0 / math.sqrt(np.mean((gt * 1.0 - out * 1.0) ** 2)))  # full-RGB PSNR for non-YUV inputs 86 | psnr_list.append(psnr) 87 | print(np.mean(psnr_list)) 88 | tot += np.mean(psnr_list) 89 | print('avg psnr', tot / len(name_list)) 90 | -------------------------------------------------------------------------------- /models/DI-RIFE/benchmark/MiddleBury_Other.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('.') 4 | import cv2 5 | import math 6 | import torch 7 | import argparse 8 | import numpy as np 9 | from torch.nn import functional as F 10 | from model.pytorch_msssim import ssim_matlab 11 | from model.RIFE import Model 12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | 14 | model = Model() 15 | model.load_model('train_log') 16 | model.eval() 17 | model.device() 18 | 19 | name = ['Beanbags', 'Dimetrodon', 'DogDance', 'Grove2', 'Grove3', 'Hydrangea', 'MiniCooper', 'RubberWhale', 'Urban2', 'Urban3', 'Venus', 'Walking'] 20 | IE_list = [] 21 | for i in name: 22 | i0 = cv2.imread('other-data/{}/frame10.png'.format(i)).transpose(2, 0, 1) / 255. 23 | i1 = cv2.imread('other-data/{}/frame11.png'.format(i)).transpose(2, 0, 1) / 255.
24 | gt = cv2.imread('other-gt-interp/{}/frame10i11.png'.format(i)) 25 | h, w = i0.shape[1], i0.shape[2] 26 | imgs = torch.zeros([1, 6, 480, 640]).to(device) 27 | ph = (480 - h) // 2 28 | pw = (640 - w) // 2 29 | imgs[:, :3, :h, :w] = torch.from_numpy(i0).unsqueeze(0).float().to(device) 30 | imgs[:, 3:, :h, :w] = torch.from_numpy(i1).unsqueeze(0).float().to(device) 31 | I0 = imgs[:, :3] 32 | I2 = imgs[:, 3:] 33 | pred = model.inference(I0, I2) 34 | out = pred[0].detach().cpu().numpy().transpose(1, 2, 0) 35 | out = np.round(out[:h, :w] * 255) 36 | IE_list.append(np.abs((out - gt * 1.0)).mean()) 37 | print(np.mean(IE_list)) 38 | -------------------------------------------------------------------------------- /models/DI-RIFE/benchmark/UCF101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append('.') 4 | import cv2 5 | import math 6 | import torch 7 | import argparse 8 | import numpy as np 9 | from torch.nn import functional as F 10 | from model.pytorch_msssim import ssim_matlab 11 | from model.RIFE import Model 12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | 14 | model = Model() 15 | model.load_model('train_log') 16 | model.eval() 17 | model.device() 18 | 19 | path = 'UCF101/ucf101_interp_ours/' 20 | dirs = os.listdir(path) 21 | psnr_list = [] 22 | ssim_list = [] 23 | print(len(dirs)) 24 | for d in dirs: 25 | img0 = (path + d + '/frame_00.png') 26 | img1 = (path + d + '/frame_02.png') 27 | gt = (path + d + '/frame_01_gt.png') 28 | img0 = (torch.tensor(cv2.imread(img0).transpose(2, 0, 1) / 255.)).to(device).float().unsqueeze(0) 29 | img1 = (torch.tensor(cv2.imread(img1).transpose(2, 0, 1) / 255.)).to(device).float().unsqueeze(0) 30 | gt = (torch.tensor(cv2.imread(gt).transpose(2, 0, 1) / 255.)).to(device).float().unsqueeze(0) 31 | pred = model.inference(img0, img1)[0] 32 | ssim = ssim_matlab(gt, torch.round(pred * 255).unsqueeze(0) / 255.).detach().cpu().numpy() 33 | out = pred.detach().cpu().numpy().transpose(1, 2, 0) 34 | out = np.round(out * 255) / 255. 
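# Rounding to 8-bit and rescaling to [0, 1] quantizes the prediction the same
# way a saved PNG would be, so PSNR/SSIM match evaluation on written files.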
35 | gt = gt[0].cpu().numpy().transpose(1, 2, 0) 36 | psnr = -10 * math.log10(((gt - out) * (gt - out)).mean()) 37 | psnr_list.append(psnr) 38 | ssim_list.append(ssim) 39 | print("Avg PSNR: {} SSIM: {}".format(np.mean(psnr_list), np.mean(ssim_list))) 40 | -------------------------------------------------------------------------------- /models/DI-RIFE/benchmark/Vimeo90K.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append('..') 5 | import cv2 6 | import math 7 | import torch 8 | import lpips 9 | import argparse 10 | import numpy as np 11 | import os.path as osp 12 | from model.pytorch_msssim import ssim_matlab 13 | from model.RIFE import Model 14 | from utils import Logger 15 | from basicsr.metrics.niqe import calculate_niqe 16 | 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | 19 | if __name__ == '__main__': 20 | """ 21 | CUDA_VISIBLE_DEVICES=0 python Vimeo90K.py --model_dir ../experiments/rife/train_log --testset_path /mnt/disks/ssd0/dataset/vimeo_triplet/ 22 | """ 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--model_dir', type=str, default='train_sdi_log') 25 | parser.add_argument('--testset_path', type=str, default='../dataset/vimeo_septuplet/') 26 | args = parser.parse_args() 27 | 28 | model = Model() 29 | model.load_model(args.model_dir) 30 | model.eval() 31 | model.device() 32 | 33 | logger = Logger(osp.join(args.model_dir, 'test_log.txt')) 34 | 35 | path = args.testset_path 36 | f = open(path + 'tri_testlist.txt', 'r') 37 | psnr_list = [] 38 | ssim_list = [] 39 | lpips_list = [] 40 | niqe_list = [] 41 | loss_fn_alex = lpips.LPIPS(net='alex').to(device) 42 | for i in f: 43 | name = str(i).strip() 44 | if (len(name) <= 1): 45 | continue 46 | print(path + 'sequences/' + name + '/im1.png') 47 | I0 = cv2.imread(path + 'sequences/' + name + '/im1.png') 48 | I1 = cv2.imread(path + 'sequences/' + name + '/im2.png') 49 | I2 = cv2.imread(path + 'sequences/' + name + '/im3.png') 50 | I0 = (torch.tensor(I0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) 51 | I2 = (torch.tensor(I2.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) 52 | mid = model.inference(I0, I2)[0] 53 | ssim = ssim_matlab(torch.tensor(I1.transpose(2, 0, 1)).to(device).unsqueeze(0) / 255., 54 | torch.round(mid * 255).unsqueeze(0) / 255.).detach().cpu().numpy() 55 | mid = np.round((mid * 255).detach().cpu().numpy()).astype('uint8').transpose(1, 2, 0) / 255. 56 | I1 = I1 / 255. 
57 | psnr = -10 * math.log10(((I1 - mid) * (I1 - mid)).mean()) 58 | 59 | psnr_list.append(psnr) 60 | ssim_list.append(ssim) 61 | 62 | # calculate niqe score 63 | niqe = calculate_niqe(mid * 255., crop_border=0) 64 | niqe_list.append(niqe) 65 | 66 | # calculate lpips score 67 | mid = mid[:, :, ::-1] # rgb image 68 | mid = torch.from_numpy(2 * mid - 1.).permute(2, 0, 1)[None].float().to( 69 | device 70 | ) # (1, 3, h, w) value range from [-1, 1] 71 | I1 = I1[:, :, ::-1] # rgb image 72 | I1 = torch.from_numpy(2 * I1 - 1.).permute(2, 0, 1)[None].float().to( 73 | device 74 | ) # (1, 3, h, w) value range from [-1, 1] 75 | lpips_value = loss_fn_alex.forward(mid, I1).detach().cpu().numpy() 76 | lpips_list.append(lpips_value) 77 | 78 | logger("Avg PSNR: {:.4f} SSIM: {:.4f} LPIPS: {:.4f} NIQE: {:.4f}".format( 79 | np.mean(psnr), np.mean(ssim), np.mean(lpips_value), np.mean(niqe) 80 | )) 81 | logger("Total Avg PSNR: {:.4f} SSIM: {:.4f} LPIPS: {:.4f} NIQE: {:.4f}".format( 82 | np.mean(psnr_list), np.mean(ssim_list), np.mean(lpips_list), np.mean(niqe_list) 83 | )) 84 | -------------------------------------------------------------------------------- /models/DI-RIFE/benchmark/testtime.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import sys 3 | sys.path.append('.') 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | from model.RIFE import Model 8 | 9 | model = Model() 10 | model.eval() 11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | torch.set_grad_enabled(False) 13 | if torch.cuda.is_available(): 14 | torch.backends.cudnn.enabled = True 15 | torch.backends.cudnn.benchmark = True 16 | 17 | I0 = torch.rand(1, 3, 480, 640).to(device) 18 | I1 = torch.rand(1, 3, 480, 640).to(device) 19 | with torch.no_grad(): 20 | for i in range(100): 21 | pred = model.inference(I0, I1) 22 | if torch.cuda.is_available(): 23 | torch.cuda.synchronize() 24 | time_stamp = time.time() 25 | for i in range(100): 26 | pred = model.inference(I0, I1) 27 | if torch.cuda.is_available(): 28 | torch.cuda.synchronize() 29 | print((time.time() - time_stamp) / 100) 30 | -------------------------------------------------------------------------------- /models/DI-RIFE/inference_img_plus.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('.') 4 | 5 | import warnings 6 | 7 | warnings.filterwarnings('ignore') 8 | 9 | import shutil 10 | import math 11 | import os 12 | import cv2 13 | import torch 14 | import lpips 15 | import imageio as iio 16 | import os.path as osp 17 | import numpy as np 18 | import torch.nn.functional as F 19 | from model.RIFE_m import Model 20 | from argparse import ArgumentParser 21 | from tqdm import tqdm 22 | from model.pytorch_msssim import ssim_matlab 23 | 24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | 26 | 27 | def interpolate(I0, I1, num): 28 | imgs = [] 29 | I0 = (torch.tensor(I0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) 30 | I1 = (torch.tensor(I1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) 31 | 32 | _, _, h, w = I0.shape 33 | ph = ((h - 1) // 32 + 1) * 32 34 | pw = ((w - 1) // 32 + 1) * 32 35 | padding = (0, pw - w, 0, ph - h) 36 | I0 = F.pad(I0, padding) 37 | I1 = F.pad(I1, padding) 38 | 39 | timesteps = [j / (num + 1 + 1e-6) for j in range(1, num + 1)] 40 | 41 | for i, timestep in enumerate(timesteps): 42 | mid = model.inference(I0, I1, timestep=timestep)[0] 43 | mid = mid.clamp(0, 1).permute(1, 
2, 0).detach().cpu().numpy() 44 | mid = (mid * 255.).astype(np.uint8) 45 | imgs.append(mid[:h, :w]) 46 | return imgs 47 | 48 | 49 | if __name__ == '__main__': 50 | parser = ArgumentParser() 51 | parser.add_argument('--img0', type=str, default='./demo/I0.png', help='path of start image') 52 | parser.add_argument('--img1', type=str, default='./demo/I0_1.png', help='path of end image') 53 | parser.add_argument('--checkpoint', type=str, default='./experiments/rife_m/train_m_log_official', 54 | help='path of checkpoint') 55 | parser.add_argument('--save_dir', type=str, default='./demo/I0_results/', help='where to save image results') 56 | parser.add_argument('--num', type=int, nargs='+', default=[5, 5], help='number of extracted images') 57 | parser.add_argument('--gif', action='store_true', help='whether to generate the corresponding gif') 58 | args = parser.parse_args() 59 | 60 | extracted_num = 2 61 | for sub_num in args.num: 62 | extracted_num += sub_num * (extracted_num - 1) 63 | 64 | model = Model() 65 | model.load_model(args.checkpoint) 66 | model.eval() 67 | model.device() 68 | 69 | os.makedirs(args.save_dir, exist_ok=True) 70 | I0 = cv2.imread(args.img0) 71 | I1 = cv2.imread(args.img1) 72 | gif_imgs = [I0, I1] 73 | 74 | for sub_num in args.num: 75 | gif_imgs_temp = [gif_imgs[0], ] 76 | for i, (img_start, img_end) in enumerate(zip(gif_imgs[:-1], gif_imgs[1:])): 77 | interp_imgs = interpolate(img_start, img_end, num=sub_num) 78 | gif_imgs_temp += interp_imgs 79 | gif_imgs_temp += [img_end, ] 80 | gif_imgs = gif_imgs_temp 81 | 82 | print('Interpolate 2 images to {} images'.format(extracted_num)) 83 | 84 | for i, img in enumerate(gif_imgs): 85 | save_path = osp.join(args.save_dir, '{:03d}.png'.format(i)) 86 | cv2.imwrite(save_path, img) 87 | 88 | if args.gif: 89 | gif_path = osp.join(args.save_dir, 'demo.gif') 90 | with iio.get_writer(gif_path, mode='I') as writer: 91 | for img in gif_imgs: 92 | writer.append_data(img[:, :, ::-1]) 93 | -------------------------------------------------------------------------------- /models/DI-RIFE/inference_img_plus_sdi.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('.') 4 | 5 | import warnings 6 | 7 | warnings.filterwarnings('ignore') 8 | 9 | import shutil 10 | import os 11 | import math 12 | import cv2 13 | import torch 14 | import lpips 15 | import imageio as iio 16 | import os.path as osp 17 | import numpy as np 18 | import numpy.ma as ma 19 | import torch.nn.functional as F 20 | from model.RIFE_sdi import Model 21 | from argparse import ArgumentParser 22 | from tqdm import tqdm 23 | from model.pytorch_msssim import ssim_matlab 24 | 25 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 | 27 | 28 | def interpolate(I0, I1, num): 29 | imgs = [] 30 | I0 = (torch.tensor(I0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) 31 | I1 = (torch.tensor(I1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) 32 | 33 | sdi_maps = [torch.zeros_like(I0[:, :1, :, :]) + j / (num + 1) for j in range(1, num + 1)] 34 | 35 | _, _, h, w = I0.shape 36 | ph = ((h - 1) // 32 + 1) * 32 37 | pw = ((w - 1) // 32 + 1) * 32 38 | padding = (0, pw - w, 0, ph - h) 39 | I0 = F.pad(I0, padding) 40 | I1 = F.pad(I1, padding) 41 | 42 | for i, sdi_map in enumerate(sdi_maps): 43 | mid = model.inference(I0, I1, sdi_map=sdi_map)[0] 44 | mid = mid.clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy() 45 | mid = (mid * 255.).astype(np.uint8) 46 | imgs.append(mid[:h, :w]) 47 | return imgs 48 | 49 | 
50 | if __name__ == '__main__': 51 | parser = ArgumentParser() 52 | parser.add_argument('--img0', type=str, default='./demo/I0.png', help='path of start image') 53 | parser.add_argument('--img1', type=str, default='./demo/I0_1.png', help='path of end image') 54 | parser.add_argument('--checkpoint', type=str, default='./experiments/rife_sdi_m_mask_noavg_blur/train_sdi_log', 55 | help='path of checkpoint') 56 | parser.add_argument('--save_dir', type=str, default='./demo/I0_results/', help='where to save image results') 57 | parser.add_argument('--num', type=int, nargs='+', default=[5, 5], help='number of extracted images') 58 | parser.add_argument('--gif', action='store_true', help='whether to generate the corresponding gif') 59 | parser.add_argument('--no_distill', action='store_true', help='no optical flow distillation') 60 | args = parser.parse_args() 61 | 62 | extracted_num = 2 63 | for sub_num in args.num: 64 | extracted_num += sub_num * (extracted_num - 1) 65 | 66 | model = Model(distill=not args.no_distill) 67 | model.load_model(args.checkpoint) 68 | model.eval() 69 | model.device() 70 | 71 | os.makedirs(args.save_dir, exist_ok=True) 72 | I0 = cv2.imread(args.img0) 73 | I1 = cv2.imread(args.img1) 74 | gif_imgs = [I0, I1] 75 | 76 | for sub_num in args.num: 77 | gif_imgs_temp = [gif_imgs[0], ] 78 | for i, (img_start, img_end) in enumerate(zip(gif_imgs[:-1], gif_imgs[1:])): 79 | interp_imgs = interpolate(img_start, img_end, num=sub_num) 80 | gif_imgs_temp += interp_imgs 81 | gif_imgs_temp += [img_end, ] 82 | gif_imgs = gif_imgs_temp 83 | 84 | print('Interpolate 2 images to {} images'.format(extracted_num)) 85 | 86 | for i, img in enumerate(gif_imgs): 87 | save_path = osp.join(args.save_dir, '{:03d}.png'.format(i)) 88 | cv2.imwrite(save_path, img) 89 | 90 | if args.gif: 91 | gif_path = osp.join(args.save_dir, 'demo.gif') 92 | with iio.get_writer(gif_path, mode='I') as writer: 93 | for img in gif_imgs: 94 | writer.append_data(img[:, :, ::-1]) 95 | -------------------------------------------------------------------------------- /models/DI-RIFE/model/laplacian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | import torch 9 | 10 | def gauss_kernel(size=5, channels=3): 11 | kernel = torch.tensor([[1., 4., 6., 4., 1], 12 | [4., 16., 24., 16., 4.], 13 | [6., 24., 36., 24., 6.], 14 | [4., 16., 24., 16., 4.], 15 | [1., 4., 6., 4., 1.]]) 16 | kernel /= 256. 
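# This 5x5 kernel is the outer product of the binomial filter [1, 4, 6, 4, 1]
# with itself (its entries sum to 256, hence the normalization); repeating it
# per channel lets conv_gauss() apply it depthwise via groups=channels.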
17 | kernel = kernel.repeat(channels, 1, 1, 1) 18 | kernel = kernel.to(device) 19 | return kernel 20 | 21 | def downsample(x): 22 | return x[:, :, ::2, ::2] 23 | 24 | def upsample(x): 25 | cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3) 26 | cc = cc.view(x.shape[0], x.shape[1], x.shape[2]*2, x.shape[3]) 27 | cc = cc.permute(0,1,3,2) 28 | cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2]*2).to(device)], dim=3) 29 | cc = cc.view(x.shape[0], x.shape[1], x.shape[3]*2, x.shape[2]*2) 30 | x_up = cc.permute(0,1,3,2) 31 | return conv_gauss(x_up, 4*gauss_kernel(channels=x.shape[1])) 32 | 33 | def conv_gauss(img, kernel): 34 | img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect') 35 | out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1]) 36 | return out 37 | 38 | def laplacian_pyramid(img, kernel, max_levels=3): 39 | current = img 40 | pyr = [] 41 | for level in range(max_levels): 42 | filtered = conv_gauss(current, kernel) 43 | down = downsample(filtered) 44 | up = upsample(down) 45 | diff = current-up 46 | pyr.append(diff) 47 | current = down 48 | return pyr 49 | 50 | class LapLoss(torch.nn.Module): 51 | def __init__(self, max_levels=5, channels=3): 52 | super(LapLoss, self).__init__() 53 | self.max_levels = max_levels 54 | self.gauss_kernel = gauss_kernel(channels=channels) 55 | 56 | def forward(self, input, target): 57 | pyr_input = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels) 58 | pyr_target = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels) 59 | return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target)) 60 | -------------------------------------------------------------------------------- /models/DI-RIFE/model/refine_2R.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.optim as optim 5 | import itertools 6 | from model.warplayer import warp 7 | from torch.nn.parallel import DistributedDataParallel as DDP 8 | import torch.nn.functional as F 9 | 10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 11 | 12 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 13 | return nn.Sequential( 14 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 15 | padding=padding, dilation=dilation, bias=True), 16 | nn.PReLU(out_planes) 17 | ) 18 | 19 | def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): 20 | return nn.Sequential( 21 | torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True), 22 | nn.PReLU(out_planes) 23 | ) 24 | 25 | class Conv2(nn.Module): 26 | def __init__(self, in_planes, out_planes, stride=2): 27 | super(Conv2, self).__init__() 28 | self.conv1 = conv(in_planes, out_planes, 3, stride, 1) 29 | self.conv2 = conv(out_planes, out_planes, 3, 1, 1) 30 | 31 | def forward(self, x): 32 | x = self.conv1(x) 33 | x = self.conv2(x) 34 | return x 35 | 36 | c = 16 37 | class Contextnet(nn.Module): 38 | def __init__(self): 39 | super(Contextnet, self).__init__() 40 | self.conv1 = Conv2(3, c, 1) 41 | self.conv2 = Conv2(c, 2*c) 42 | self.conv3 = Conv2(2*c, 4*c) 43 | self.conv4 = Conv2(4*c, 8*c) 44 | 45 | def forward(self, x, flow): 46 | x = self.conv1(x) 47 | # flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, 
recompute_scale_factor=False) * 0.5 48 | f1 = warp(x, flow) 49 | x = self.conv2(x) 50 | flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5 51 | f2 = warp(x, flow) 52 | x = self.conv3(x) 53 | flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5 54 | f3 = warp(x, flow) 55 | x = self.conv4(x) 56 | flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5 57 | f4 = warp(x, flow) 58 | return [f1, f2, f3, f4] 59 | 60 | class Unet(nn.Module): 61 | def __init__(self): 62 | super(Unet, self).__init__() 63 | self.down0 = Conv2(17, 2*c, 1) 64 | self.down1 = Conv2(4*c, 4*c) 65 | self.down2 = Conv2(8*c, 8*c) 66 | self.down3 = Conv2(16*c, 16*c) 67 | self.up0 = deconv(32*c, 8*c) 68 | self.up1 = deconv(16*c, 4*c) 69 | self.up2 = deconv(8*c, 2*c) 70 | self.up3 = deconv(4*c, c) 71 | self.conv = nn.Conv2d(c, 3, 3, 2, 1) 72 | 73 | def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1): 74 | s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1)) 75 | s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1)) 76 | s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1)) 77 | s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1)) 78 | x = self.up0(torch.cat((s3, c0[3], c1[3]), 1)) 79 | x = self.up1(torch.cat((x, s2), 1)) 80 | x = self.up2(torch.cat((x, s1), 1)) 81 | x = self.up3(torch.cat((x, s0), 1)) 82 | x = self.conv(x) 83 | return torch.sigmoid(x) 84 | -------------------------------------------------------------------------------- /models/DI-RIFE/model/warplayer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 5 | backwarp_tenGrid = {} 6 | 7 | 8 | def warp(tenInput, tenFlow): 9 | k = (str(tenFlow.device), str(tenFlow.size())) 10 | if k not in backwarp_tenGrid: 11 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view( 12 | 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) 13 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view( 14 | 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) 15 | backwarp_tenGrid[k] = torch.cat( 16 | [tenHorizontal, tenVertical], 1).to(device) 17 | 18 | tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), 19 | tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1) 20 | 21 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) 22 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True) 23 | -------------------------------------------------------------------------------- /models/DI-RIFE/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.16 2 | tqdm>=4.35.0 3 | sk-video>=1.1.10 4 | torch==1.7.1 5 | opencv-python>=4.1.2 6 | moviepy>=1.0.3 7 | torchvision==0.8.2 -------------------------------------------------------------------------------- /models/DI-RIFE/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from datetime import datetime 4 | 5 | 6 | class Logger: 7 | """ 8 | Logger class to record training log 9 | """ 10 | 11 | def __init__(self, file_path, 
verbose=True): 12 | self.verbose = verbose 13 | self.create_dir(file_path) 14 | self.logger = open(file_path, 'a+') 15 | 16 | def create_dir(self, file_path): 17 | dir = osp.dirname(file_path) 18 | os.makedirs(dir, exist_ok=True) 19 | 20 | def __call__(self, *args, prefix='', timestamp=False): 21 | if timestamp: 22 | now = datetime.now() 23 | now = now.strftime("%Y/%m/%d, %H:%M:%S - ") 24 | else: 25 | now = '' 26 | if prefix == '': 27 | info = prefix + now 28 | else: 29 | info = prefix + ' ' + now 30 | for msg in args: 31 | if not isinstance(msg, str): 32 | msg = str(msg) 33 | info += msg + '\n' 34 | self.logger.write(info) 35 | if self.verbose: 36 | print(info, end='') 37 | self.logger.flush() 38 | 39 | def __del__(self): 40 | self.logger.close() 41 | -------------------------------------------------------------------------------- /multiprocess_create_dis_index.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | import time 4 | import os 5 | import os.path as osp 6 | 7 | if __name__ == '__main__': 8 | """ 9 | cmd: 10 | CUDA_VISIBLE_DEVICES=0,1,2,3 python multiprocess_create_dis_index.py --num_gpus 4 --num_workers 5 11 | CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 screen python multiprocess_create_dis_index.py --num_gpus 7 --num_workers 4 --path /mnt/disks/ssd0/dataset/gopro/ --sample_list_path train.txt --downsample_ratio 4. --sample_length 9 12 | CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 screen python multiprocess_create_dis_index.py --num_gpus 7 --num_workers 4 --path /mnt/disks/ssd0/dataset/gopro/ --sample_list_path test.txt --downsample_ratio 4. --sample_length 9 13 | CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 screen python multiprocess_create_dis_index.py --num_gpus 7 --num_workers 4 --path /mnt/disks/ssd0/dataset/dvd/ --sample_list_path train.txt --downsample_ratio 4. --sample_length 9 14 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 screen python multiprocess_create_dis_index.py --num_gpus 8 --num_workers 4 --path ./dataset/vimeo_triplet/ --sample_list_path tri_trainlist.txt --sample_length 3 15 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 screen python multiprocess_create_dis_index.py --num_gpus 8 --num_workers 4 --path ./dataset/vimeo_triplet/ --sample_list_path tri_testlist.txt --sample_length 3 16 | """ 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--num_gpus', type=int, default=2) 19 | parser.add_argument('--num_workers', type=int, default=8) 20 | parser.add_argument('--path', type=str, default='/mnt/disks/ssd0/dataset/vimeo_septuplet/') 21 | # argument for the path of sample list 22 | parser.add_argument('--sample_list_path', nargs='+', type=str, default=['sep_trainlist.txt', 'sep_testlist.txt']) 23 | parser.add_argument('--downsample_ratio', type=float, default=2.) 
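# Per the example commands in the docstring above: Vimeo septuplet uses the
# defaults (downsample_ratio 2, sample_length 7), Vimeo triplet uses
# sample_length 3, and the GoPro/DVD runs use downsample_ratio 4 with
# sample_length 9.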
24 | parser.add_argument('--sample_length', type=int, default=7) 25 | args = parser.parse_args() 26 | 27 | path = args.path 28 | sample_paths = [] 29 | for set_name in args.sample_list_path: 30 | sep_list = osp.join(path, set_name) 31 | with open(sep_list) as f: 32 | samples = f.readlines() 33 | for sample in samples: 34 | if '/' not in sample: 35 | continue 36 | sample_path = osp.join(path, 'sequences', sample.strip()) 37 | sample_paths.append(sample_path) 38 | 39 | # split sample_paths 40 | num_processes = args.num_gpus * args.num_workers 41 | for i in range(num_processes): 42 | gpu_sample_paths = sample_paths[i::num_processes] 43 | with open(f'sample_paths_{i}.txt', 'w') as f: 44 | f.write('\n'.join(gpu_sample_paths)) 45 | 46 | # launch multiprocess for masks generation 47 | pool = [] 48 | for i in range(num_processes): 49 | cmd = ['python', 'process_create_dis_index.py', '--sample_list_path', f'sample_paths_{i}.txt', 50 | '--downsample_ratio', str(args.downsample_ratio), '--sample_length', str(args.sample_length)] 51 | print(' '.join(cmd)) 52 | env = { 53 | **os.environ, 54 | 'CUDA_VISIBLE_DEVICES': str(i // args.num_workers) 55 | } 56 | p = subprocess.Popen(cmd, env=env) 57 | pool.append(p) 58 | exit_codes = [p.wait() for p in pool] 59 | 60 | # clean split result 61 | for i in range(num_processes): 62 | os.remove(f'sample_paths_{i}.txt') 63 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.0.0 2 | addict==2.4.0 3 | antlr4-python3-runtime==4.9.3 4 | appdirs==1.4.4 5 | basicsr==1.4.2 6 | blinker==1.6.3 7 | cachetools==5.3.1 8 | certifi==2023.7.22 9 | charset-normalizer==3.3.0 10 | click==8.1.7 11 | coloredlogs==15.0.1 12 | contourpy==1.1.1 13 | cycler==0.12.1 14 | decorator==4.4.2 15 | docker-pycreds==0.4.0 16 | filelock==3.12.4 17 | Flask==3.0.0 18 | Flask-Cors==4.0.0 19 | flatbuffers==23.5.26 20 | fonttools==4.43.1 21 | fsspec==2023.9.2 22 | future==0.18.3 23 | gitdb==4.0.10 24 | GitPython==3.1.37 25 | google-auth==2.23.3 26 | google-auth-oauthlib==1.0.0 27 | grpcio==1.59.0 28 | huggingface-hub==0.18.0 29 | humanfriendly==10.0 30 | idna==3.4 31 | imageio==2.31.5 32 | imageio-ffmpeg==0.4.9 33 | importlib-metadata==6.8.0 34 | importlib-resources==6.1.0 35 | install==1.3.5 36 | itsdangerous==2.1.2 37 | Jinja2==3.1.2 38 | kiwisolver==1.4.5 39 | lazy_loader==0.3 40 | lightning-utilities==0.9.0 41 | lmdb==1.4.1 42 | lpips==0.1.4 43 | Markdown==3.5 44 | MarkupSafe==2.1.3 45 | matplotlib==3.7.3 46 | moviepy==1.0.3 47 | mpmath==1.3.0 48 | networkx==3.1 49 | numpy==1.24.4 50 | oauthlib==3.2.2 51 | omegaconf==2.3.0 52 | onnx==1.14.1 53 | onnxruntime==1.15.1 54 | opencv-python==4.8.1.78 55 | packaging==23.2 56 | pathtools==0.1.2 57 | Pillow==10.1.0 58 | platformdirs==3.11.0 59 | proglog==0.1.10 60 | protobuf==4.24.4 61 | psutil==5.9.6 62 | pyasn1==0.5.0 63 | pyasn1-modules==0.3.0 64 | pycocotools==2.0.7 65 | pyparsing==3.1.1 66 | python-dateutil==2.8.2 67 | PyWavelets==1.4.1 68 | PyYAML==6.0.1 69 | requests==2.31.0 70 | requests-oauthlib==1.3.1 71 | rsa==4.9 72 | safetensors==0.4.0 73 | scikit-image==0.21.0 74 | scipy==1.10.1 75 | segment-anything @ git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 76 | sentry-sdk==1.32.0 77 | setproctitle==1.3.3 78 | six==1.16.0 79 | sk-video==1.1.10 80 | smmap==5.0.1 81 | sympy==1.12 82 | tb-nightly==2.14.0a20230808 83 | tensorboard==2.14.0 84 | 
tensorboard-data-server==0.7.1 85 | tifffile==2023.7.10 86 | timm==0.9.7 87 | tomli==2.0.1 88 | torch==1.12.1+cu116 89 | torchaudio==0.12.1+cu116 90 | torchmetrics==1.2.0 91 | torchvision==0.13.1+cu116 92 | tqdm==4.66.1 93 | typing_extensions==4.8.0 94 | urllib3==2.0.6 95 | wandb==0.15.12 96 | Werkzeug==3.0.0 97 | yapf==0.40.2 98 | zipp==3.17.0 99 | -------------------------------------------------------------------------------- /webapp/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode -------------------------------------------------------------------------------- /webapp/backend/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | images 3 | data/models/sam_vit_b_01ec64.pth 4 | __pycache__ 5 | *.npy 6 | *.pth 7 | *.onnx 8 | dances 9 | app_all.py 10 | dancimation.py 11 | filters.py -------------------------------------------------------------------------------- /webapp/backend/README.md: -------------------------------------------------------------------------------- 1 | ## Server for "Manipulated Interpolation of Anything" 2 | We need to set up a server to serve segmentation tasks. 3 | 4 | ### Download checkpoint 5 | ``` 6 | cd data/models 7 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth 8 | cd ../.. 9 | ``` 10 | 11 | ### Start Server 12 | ``` 13 | python app.py 14 | ``` -------------------------------------------------------------------------------- /webapp/backend/data/embeddings/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/webapp/backend/data/embeddings/.gitkeep -------------------------------------------------------------------------------- /webapp/backend/data/models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/webapp/backend/data/models/.gitkeep -------------------------------------------------------------------------------- /webapp/backend/data/uploads/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/webapp/backend/data/uploads/.gitkeep -------------------------------------------------------------------------------- /webapp/backend/testServer.py: -------------------------------------------------------------------------------- 1 | import redis 2 | from pprint import pprint 3 | import time 4 | 5 | rcli = redis.Redis() 6 | 7 | pprint("Idle") 8 | rcli.set('dancetype', 0) 9 | time.sleep(10) 10 | 11 | rcli.set('dancetype', 3) 12 | rcli.set('tempo', 120) 13 | time.sleep(20) 14 | 15 | pprint("Idle") 16 | rcli.set('dancetype', 0) 17 | -------------------------------------------------------------------------------- /webapp/webapp/.gitignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | yarn.lock 3 | -------------------------------------------------------------------------------- /webapp/webapp/CREATE_EMBEDDING.md: -------------------------------------------------------------------------------- 1 | ## Export the image embedding 2 | 3 | In the [ONNX Model Example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb)
upload the image of your choice, then generate and save the corresponding embedding. 4 | 5 | Initialize the predictor: 6 | 7 | ```python 8 | checkpoint = "sam_vit_h_4b8939.pth" 9 | model_type = "vit_h" 10 | sam = sam_model_registry[model_type](checkpoint=checkpoint) 11 | sam.to(device='cuda') 12 | predictor = SamPredictor(sam) 13 | ``` 14 | 15 | Set the new image and export the embedding: 16 | 17 | ```python 18 | image = cv2.imread('src/assets/dogs.jpg') 19 | predictor.set_image(image) 20 | image_embedding = predictor.get_image_embedding().cpu().numpy() 21 | np.save("dogs_embedding.npy", image_embedding) 22 | ``` 23 | 24 | Save the new image and embedding in `src/assets/data`. 25 | 26 | ## Export the ONNX model 27 | 28 | You also need to export the quantized ONNX model from the [ONNX Model Example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb). 29 | 30 | Run the cell in the notebook which saves the `sam_onnx_quantized_example.onnx` file, then download it and copy it to the path `/model/sam_onnx_quantized_example.onnx`. 31 | 32 | Here is a snippet of the export/quantization code: 33 | 34 | ```python 35 | onnx_model_path = "sam_onnx_example.onnx" 36 | onnx_model_quantized_path = "sam_onnx_quantized_example.onnx" 37 | quantize_dynamic( 38 | model_input=onnx_model_path, 39 | model_output=onnx_model_quantized_path, 40 | optimize_model=True, 41 | per_channel=False, 42 | reduce_range=False, 43 | weight_type=QuantType.QUInt8, 44 | ) 45 | ``` 46 | 47 | **NOTE: if you change the ONNX model by using a new checkpoint, you also need to re-export the embedding.** 48 | 49 | ## Update the image, embedding, and model in the app 50 | 51 | Update the following file paths at the top of `App.tsx`: 52 | 53 | ```ts 54 | const IMAGE_PATH = "/assets/data/dogs.jpg"; 55 | const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy"; 56 | const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx"; 57 | ``` 58 | 59 | ## ONNX multithreading with SharedArrayBuffer 60 | 61 | To use multithreading, the appropriate headers need to be set to create a cross-origin isolation state, which enables use of `SharedArrayBuffer` (see this [blog post](https://cloudblogs.microsoft.com/opensource/2021/09/02/onnx-runtime-web-running-your-machine-learning-model-in-browser/) for more details). 62 | 63 | The headers below are set in `configs/webpack/dev.js`: 64 | 65 | ```js 66 | headers: { 67 | "Cross-Origin-Opener-Policy": "same-origin", 68 | "Cross-Origin-Embedder-Policy": "credentialless", 69 | } 70 | ``` 71 | 72 | ## Structure of the app 73 | 74 | **`App.tsx`** 75 | 76 | - Initializes the ONNX model 77 | - Loads the image embedding and image 78 | - Runs the ONNX model based on input prompts 79 | 80 | **`Stage.tsx`** 81 | 82 | - Handles mouse move interaction to update the ONNX model prompt 83 | 84 | **`Tool.tsx`** 85 | 86 | - Renders the image and the mask prediction 87 | 88 | **`helpers/maskUtils.tsx`** 89 | 90 | - Converts the ONNX model output from an array to an HTMLImageElement 91 | 92 | **`helpers/onnxModelAPI.tsx`** 93 | 94 | - Formats the inputs for the ONNX model 95 | 96 | **`helpers/scaleHelper.tsx`** 97 | 98 | - Handles image scaling logic for SAM (longest side 1024) 99 | 100 | **`hooks/`** 101 | 102 | - Handles shared state for the app -------------------------------------------------------------------------------- /webapp/webapp/README.md: -------------------------------------------------------------------------------- 1 | ## Manipulated Interpolation of Anything 2 | 3 | ## Run the app 4 | 5 | 1.
Install node and npm 6 | 7 | 2. Install Yarn 8 | 9 | ``` 10 | npm install -g yarn 11 | ``` 12 | > Note: yarn version 1.22.18 or above 13 | > (check with `yarn --version`) 14 | 15 | 3. Build and run: 16 | 17 | ``` 18 | yarn && yarn start 19 | ``` 20 | 21 | 4. Navigate to [`http://localhost:8080/`](http://localhost:8080/) 22 | > Note: running `yarn start` will automatically open http://localhost:8080/ 23 | 24 | Move your cursor around to see the mask prediction update in real time. -------------------------------------------------------------------------------- /webapp/webapp/configs/webpack/common.js: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | const { resolve } = require("path"); 8 | const HtmlWebpackPlugin = require("html-webpack-plugin"); 9 | const FriendlyErrorsWebpackPlugin = require("friendly-errors-webpack-plugin"); 10 | const CopyPlugin = require("copy-webpack-plugin"); 11 | const webpack = require("webpack"); 12 | 13 | module.exports = { 14 | entry: "./src/index.tsx", 15 | resolve: { 16 | extensions: [".js", ".jsx", ".ts", ".tsx"], 17 | }, 18 | output: { 19 | path: resolve(__dirname, "dist"), 20 | }, 21 | module: { 22 | rules: [ 23 | { 24 | test: /\.mjs$/, 25 | include: /node_modules/, 26 | type: "javascript/auto", 27 | resolve: { 28 | fullySpecified: false, 29 | }, 30 | }, 31 | { 32 | test: [/\.jsx?$/, /\.tsx?$/], 33 | use: ["ts-loader"], 34 | exclude: /node_modules/, 35 | }, 36 | { 37 | test: /\.css$/, 38 | use: ["style-loader", "css-loader"], 39 | }, 40 | { 41 | test: /\.(scss|sass)$/, 42 | use: ["style-loader", "css-loader", "postcss-loader"], 43 | }, 44 | { 45 | test: /\.(jpe?g|png|gif|svg)$/i, 46 | use: [ 47 | "file-loader?hash=sha512&digest=hex&name=img/[contenthash].[ext]", 48 | "image-webpack-loader?bypassOnDebug&optipng.optimizationLevel=7&gifsicle.interlaced=false", 49 | ], 50 | }, 51 | { 52 | test: /\.(woff|woff2|ttf)$/, 53 | use: { 54 | loader: "url-loader", 55 | }, 56 | }, 57 | ], 58 | }, 59 | plugins: [ 60 | new CopyPlugin({ 61 | patterns: [ 62 | { 63 | from: "node_modules/onnxruntime-web/dist/*.wasm", 64 | to: "[name][ext]", 65 | }, 66 | { 67 | from: "model", 68 | to: "model", 69 | }, 70 | { 71 | from: "src/assets", 72 | to: "assets", 73 | }, 74 | ], 75 | }), 76 | new HtmlWebpackPlugin({ 77 | template: "./src/assets/index.html", 78 | }), 79 | new FriendlyErrorsWebpackPlugin(), 80 | new webpack.ProvidePlugin({ 81 | process: "process/browser", 82 | }), 83 | ], 84 | }; 85 | -------------------------------------------------------------------------------- /webapp/webapp/configs/webpack/dev.js: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | // development config 8 | const { merge } = require("webpack-merge"); 9 | const commonConfig = require("./common"); 10 | 11 | module.exports = merge(commonConfig, { 12 | mode: "development", 13 | devServer: { 14 | hot: true, // enable HMR on the server 15 | open: true, 16 | // These headers enable the cross origin isolation state 17 | // needed to enable use of SharedArrayBuffer for ONNX 18 | // multithreading.
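// ("credentialless" is the newer COEP value; if a browser rejects it,
// "require-corp" is the stricter, more widely supported alternative.)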
19 | headers: { 20 | "Cross-Origin-Opener-Policy": "same-origin", 21 | "Cross-Origin-Embedder-Policy": "credentialless", 22 | }, 23 | }, 24 | devtool: "cheap-module-source-map", 25 | }); 26 | -------------------------------------------------------------------------------- /webapp/webapp/configs/webpack/prod.js: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | // production config 8 | const { merge } = require("webpack-merge"); 9 | const { resolve } = require("path"); 10 | const Dotenv = require("dotenv-webpack"); 11 | const commonConfig = require("./common"); 12 | 13 | module.exports = merge(commonConfig, { 14 | mode: "production", 15 | output: { 16 | filename: "js/bundle.[contenthash].min.js", 17 | path: resolve(__dirname, "../../dist"), 18 | publicPath: "/", 19 | }, 20 | devtool: "source-map", 21 | plugins: [new Dotenv()], 22 | }); 23 | -------------------------------------------------------------------------------- /webapp/webapp/model/sam_onnx_quantized_example_fast.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/webapp/webapp/model/sam_onnx_quantized_example_fast.onnx -------------------------------------------------------------------------------- /webapp/webapp/model/sam_onnx_quantized_example_full.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/webapp/webapp/model/sam_onnx_quantized_example_full.onnx -------------------------------------------------------------------------------- /webapp/webapp/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "segment-anything-mini-demo", 3 | "version": "0.1.0", 4 | "license": "MIT", 5 | "scripts": { 6 | "build": "yarn run clean-dist && webpack --config=configs/webpack/prod.js && mv dist/*.wasm dist/js", 7 | "clean-dist": "rimraf dist/*", 8 | "lint": "eslint './src/**/*.{js,ts,tsx}' --quiet", 9 | "start": "yarn run start-dev", 10 | "test": "yarn run start-model-test", 11 | "start-dev": "webpack serve --config=configs/webpack/dev.js" 12 | }, 13 | "devDependencies": { 14 | "@babel/core": "^7.18.13", 15 | "@babel/preset-env": "^7.18.10", 16 | "@babel/preset-react": "^7.18.6", 17 | "@babel/preset-typescript": "^7.18.6", 18 | "@pmmmwh/react-refresh-webpack-plugin": "^0.5.7", 19 | "@testing-library/react": "^13.3.0", 20 | "@types/node": "^18.7.13", 21 | "@types/react": "18", 22 | "@types/react-dom": "18", 23 | "@types/underscore": "^1.11.4", 24 | "@typescript-eslint/eslint-plugin": "^5.35.1", 25 | "@typescript-eslint/parser": "^5.35.1", 26 | "babel-loader": "^8.2.5", 27 | "copy-webpack-plugin": "^11.0.0", 28 | "css-loader": "^6.7.1", 29 | "dotenv": "^16.0.2", 30 | "dotenv-webpack": "^8.0.1", 31 | "eslint": "^8.22.0", 32 | "eslint-plugin-react": "^7.31.0", 33 | "file-loader": "^6.2.0", 34 | "fork-ts-checker-webpack-plugin": "^7.2.13", 35 | "friendly-errors-webpack-plugin": "^1.7.0", 36 | "html-webpack-plugin": "^5.5.0", 37 | "image-webpack-loader": "^8.1.0", 38 | "postcss-loader": "^7.0.1", 39 | "postcss-preset-env": "^7.8.0", 40 | "process": "^0.11.10", 41 | 
"rimraf": "^3.0.2", 42 | "sass": "^1.54.5", 43 | "sass-loader": "^13.0.2", 44 | "style-loader": "^3.3.1", 45 | "tailwindcss": "^3.1.8", 46 | "ts-loader": "^9.3.1", 47 | "typescript": "^4.8.2", 48 | "webpack": "^5.74.0", 49 | "webpack-cli": "^4.10.0", 50 | "webpack-dev-server": "^4.10.0", 51 | "webpack-dotenv-plugin": "^2.1.0", 52 | "webpack-merge": "^5.8.0" 53 | }, 54 | "dependencies": { 55 | "@heroicons/react": "^2.0.18", 56 | "@material-tailwind/react": "^2.0.1", 57 | "axios": "^1.4.0", 58 | "chart.js": "2.9.3", 59 | "chartjs-plugin-dragdata": "1.1.3", 60 | "npyjs": "^0.4.0", 61 | "onnxruntime-web": "^1.14.0", 62 | "react": "^18.2.0", 63 | "react-chartjs-2": "2.9.0", 64 | "react-dom": "^18.2.0", 65 | "react-refresh": "^0.14.0", 66 | "underscore": "^1.13.6", 67 | "util": "^0.12.5", 68 | "uuidv4": "^6.2.13" 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /webapp/webapp/postcss.config.js: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | const tailwindcss = require("tailwindcss"); 8 | module.exports = { 9 | plugins: ["postcss-preset-env", 'tailwindcss/nesting', tailwindcss], 10 | }; 11 | -------------------------------------------------------------------------------- /webapp/webapp/public/images/close_button.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/webapp/webapp/public/images/close_button.png -------------------------------------------------------------------------------- /webapp/webapp/src/App.tsx: -------------------------------------------------------------------------------- 1 | import React, { useState } from 'react'; 2 | import { AppContextProvider, ControlContextProvider } from './components/hooks/context'; 3 | import StageApp from './StageApp'; 4 | import ControlApp from './ControlApp'; 5 | import { modelMaskProps, modelRawMaskProps } from './components/helpers/Interfaces'; 6 | 7 | // dog 8 | // const IMAGE_PATH1 = "/assets/data/dog.jpg"; 9 | // const IMAGE_PATH2 = "/assets/data/dog.jpg"; 10 | // const IMAGE_EMBEDDING = "/assets/data/dog_embedding.npy"; 11 | 12 | // dogs 13 | // const IMAGE_PATH1 = "/assets/data/dogs.jpg"; 14 | // const IMAGE_PATH2 = "/assets/data/dogs.jpg"; 15 | // const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy"; 16 | 17 | // truck 18 | // const IMAGE_PATH1 = "/assets/data/truck.jpg"; 19 | // const IMAGE_PATH2 = "/assets/data/truck.jpg"; 20 | // const IMAGE_EMBEDDING = "/assets/data/truck_embedding.npy"; 21 | 22 | // groceries 23 | const IMAGE_PATH1 = "/assets/data/groceries.jpg"; 24 | const IMAGE_PATH2 = "/assets/data/truck.jpg"; 25 | const IMAGE_EMBEDDING = "/assets/data/groceries_embedding_fast.npy"; 26 | 27 | // 000 28 | // const IMAGE_PATH1 = "/assets/data/000.png"; 29 | // const IMAGE_PATH2 = "/assets/data/001.png"; 30 | // const IMAGE_EMBEDDING = "/assets/data/000_embedding.npy"; 31 | 32 | const App = () => { 33 | const [masks, setMasks] = useState>([]); 34 | const [defaultRawMask, setDefaultRawMask] = useState(null); 35 | const [blocking, setBlocking] = useState(false); 36 | const addMask = (mask: modelMaskProps) => { 37 | setMasks([...masks, mask]); 38 | } 39 | 40 | const [imagePath1, 
setImagePath1] = useState(""); 41 | const [imagePath2, setImagePath2] = useState(""); 42 | 43 | const handleDeleteMask = (index: number) => { 44 | const newMasks = [...masks]; 45 | newMasks.splice(index,1) 46 | setMasks(newMasks); 47 | }; 48 | 49 | const updateLoader = (loading: boolean) => { 50 | setBlocking(loading) 51 | } 52 | 53 | const updateMasks = (masks: Array) => { 54 | console.log("vatran updateMasks", "") 55 | setMasks(masks); 56 | } 57 | 58 | const flexCenterClasses = "flex items-center justify-center m-auto"; 59 | return ( 60 | <> 61 |
62 |

Manipulated Interpolation of Anything

63 |
64 | 65 | 66 | 67 | 68 | 70 | 71 |
72 |
73 | 74 |
75 |
76 |
77 |
Loading...
78 |
79 |
80 | 81 | ); 82 | } 83 | 84 | export default App; -------------------------------------------------------------------------------- /webapp/webapp/src/assets/images/loader.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zzh-tech/InterpAny-Clearer/50ece06a1fca91bfeed2bd3f2e724d696833ff63/webapp/webapp/src/assets/images/loader.gif -------------------------------------------------------------------------------- /webapp/webapp/src/assets/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 9 | Manipulated Interpolation of Anything 10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | -------------------------------------------------------------------------------- /webapp/webapp/src/components/DragDropFile.jsx: -------------------------------------------------------------------------------- 1 | import React, { useState, useContext, useEffect, useRef } from "react"; 2 | const axios = require("axios"); 3 | const SERVER_URL = "http://127.0.0.1:5001"; 4 | 5 | // drag drop file component 6 | function DragDropFile({ message, uploadedCallback }) { 7 | // drag state 8 | const [dragActive, setDragActive] = React.useState(false); 9 | // ref 10 | const inputRef = React.useRef(null); 11 | 12 | // handle drag events 13 | const handleDrag = function(e) { 14 | e.preventDefault(); 15 | e.stopPropagation(); 16 | if (e.type === "dragenter" || e.type === "dragover") { 17 | setDragActive(true); 18 | } else if (e.type === "dragleave") { 19 | setDragActive(false); 20 | } 21 | }; 22 | 23 | // triggers when file is dropped 24 | const handleDrop = function(e) { 25 | e.preventDefault(); 26 | e.stopPropagation(); 27 | setDragActive(false); 28 | if (e.dataTransfer.files && e.dataTransfer.files[0]) { 29 | // at least one file has been dropped so do something 30 | // handleFiles(e.dataTransfer.files); 31 | console.log(e.dataTransfer.files[0]); 32 | uploadFile(e.dataTransfer.files[0]); 33 | } 34 | }; 35 | 36 | // triggers when file is selected with click 37 | const handleChange = function(e) { 38 | e.preventDefault(); 39 | if (e.target.files && e.target.files[0]) { 40 | // handleFiles(e.target.files); 41 | console.log(e.target.files[0]); 42 | uploadFile(e.target.files[0]); 43 | } 44 | }; 45 | 46 | // triggers the input when the button is clicked 47 | const onButtonClick = () => { 48 | inputRef.current.click(); 49 | }; 50 | 51 | const uploadFile = (file) => { 52 | let formData = new FormData(); 53 | 54 | formData.append("file", file); 55 | 56 | axios.post(`${SERVER_URL}/upload`, formData, { 57 | headers: { 58 | "Content-Type": "multipart/form-data", 59 | }, 60 | }) 61 | .then((response) => { 62 | console.log(response); 63 | const url = new URL(response['data']['url'], SERVER_URL); 64 | const fileName = response['data']['filename']; 65 | uploadedCallback(url, fileName); 66 | }, (error) => { 67 | console.log(error); 68 | }); 69 | }; 70 | 71 | return ( 72 |
e.preventDefault()}> 73 | 74 | 80 | { dragActive &&
} 81 |
82 | ); 83 | }; 84 | 85 | export default DragDropFile; -------------------------------------------------------------------------------- /webapp/webapp/src/components/Mask.tsx: -------------------------------------------------------------------------------- 1 | import React, { useContext } from 'react'; 2 | import { MaskProps } from "./helpers/Interfaces"; 3 | import LineChart from './LineChart'; 4 | import { Button, Checkbox } from "@material-tailwind/react"; 5 | import { 6 | TrashIcon 7 | } from "@heroicons/react/24/outline"; 8 | import { ControlContext } from './hooks/createContext'; 9 | 10 | const Mask = ({ mask, index, handleDelete }: MaskProps) => { 11 | const { 12 | selectedIndices: [selectedIndices, setSelectedIndices] 13 | } = useContext(ControlContext)!; 14 | 15 | const handleChange = (event: any) => { 16 | console.log("vatran", event, index) 17 | if (event.target.checked) { 18 | addToSelectedIndices(index); 19 | } else { 20 | removeFromSelectedIndices(index); 21 | } 22 | } 23 | 24 | const addToSelectedIndices = (index: number) => { 25 | setSelectedIndices([...selectedIndices, index]) 26 | } 27 | 28 | const removeFromSelectedIndices = (index: number) => { 29 | const indices = selectedIndices.filter((x) => x != index) 30 | setSelectedIndices(indices) 31 | } 32 | 33 | return ( 34 |
35 |
36 | = 0} onResize={undefined} onResizeCapture={undefined} crossOrigin={undefined}/> 37 | 40 |
41 | Photo 42 |
43 |
44 | ); 45 | } 46 | 47 | export default Mask; -------------------------------------------------------------------------------- /webapp/webapp/src/components/MaskList.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import { MaskListProps } from "./helpers/Interfaces"; 3 | import Mask from './Mask'; 4 | 5 | const MaskList = ({ masks, handleDelete }: MaskListProps) => { 6 | 7 | return ( 8 | <> 9 | { 10 | masks.map(mask => { 11 | const index = masks.indexOf(mask); 12 | return 13 | }) 14 | } 15 | 16 | ); 17 | } 18 | 19 | export default MaskList; -------------------------------------------------------------------------------- /webapp/webapp/src/components/Stage.tsx: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | import React, { useContext } from "react"; 8 | import * as _ from "underscore"; 9 | import Tool from "./Tool"; 10 | import { StageProps, modelInputProps } from "./helpers/Interfaces"; 11 | import { AppContext } from "./hooks/createContext"; 12 | import { v4 as uuidv4 } from 'uuid'; 13 | 14 | const Stage = ({ setDefaultRawMask, addMask }: StageProps) => { 15 | const { 16 | clicks: [, setClicks], 17 | image: [image, setImage], 18 | rawMask: [rawMask], 19 | maskImg: [maskImg, setMaskImg], 20 | } = useContext(AppContext)!; 21 | 22 | const getClick = (x: number, y: number): modelInputProps => { 23 | const clickType = 1; 24 | return { x, y, clickType }; 25 | }; 26 | 27 | // Get mouse position and scale the (x, y) coordinates back to the natural 28 | // scale of the image. Update the state of clicks with setClicks to trigger 29 | // the ONNX model to run and generate a new mask via a useEffect in App.tsx 30 | const handleMouseMove = _.throttle((e: any) => { 31 | let el = e.nativeEvent.target; 32 | const rect = el.getBoundingClientRect(); 33 | let x = e.clientX - rect.left; 34 | let y = e.clientY - rect.top; 35 | const imageScale = image ? image.width / el.offsetWidth : 1; 36 | x *= imageScale; 37 | y *= imageScale; 38 | const click = getClick(x, y); 39 | if (click) setClicks([click]); 40 | }, 15); 41 | 42 | const handleMouseClick = (e: any) => { 43 | let el = e.nativeEvent.target; 44 | const rect = el.getBoundingClientRect(); 45 | let x = e.clientX - rect.left; 46 | let y = e.clientY - rect.top; 47 | const imageScale = image ? image.width / el.offsetWidth : 1; 48 | x *= imageScale; 49 | y *= imageScale; 50 | 51 | // add a mask to the mask panel 52 | if (maskImg && rawMask) { 53 | addMask( 54 | { 55 | id: uuidv4(), 56 | name: `${length}`, 57 | maskImg: maskImg, 58 | rawMask: rawMask 59 | } 60 | ); 61 | } 62 | } 63 | 64 | const clearImage1 = (event: { target: any; }) => { 65 | setImage(null); 66 | setDefaultRawMask(null); 67 | setMaskImg(null); 68 | }; 69 | 70 | const flexCenterClasses = "flex items-center justify-center"; 71 | return ( 72 |
73 |
74 | {image && ( 75 |
76 | {image && ( 77 | 78 | )} 79 | 80 |
81 | )} 82 |
83 |
84 | ); 85 | }; 86 | 87 | export default Stage; 88 | -------------------------------------------------------------------------------- /webapp/webapp/src/components/Tool.tsx: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | import React, { useContext, useEffect, useState } from "react"; 8 | import { AppContext } from "./hooks/createContext"; 9 | import { ToolProps } from "./helpers/Interfaces"; 10 | import * as _ from "underscore"; 11 | 12 | const Tool = ({ handleMouseClick, handleMouseMove }: ToolProps) => { 13 | const { 14 | image: [image], 15 | maskImg: [maskImg, setMaskImg], 16 | } = useContext(AppContext)!; 17 | 18 | // Determine if we should shrink or grow the images to match the 19 | // width or the height of the page and setup a ResizeObserver to 20 | // monitor changes in the size of the page 21 | const [shouldFitToWidth, setShouldFitToWidth] = useState(true); 22 | const bodyEl = document.body; 23 | const fitToPage = () => { 24 | if (!image) return; 25 | const imageAspectRatio = image.width / image.height; 26 | const screenAspectRatio = window.innerWidth / window.innerHeight; 27 | setShouldFitToWidth(imageAspectRatio > screenAspectRatio); 28 | }; 29 | const resizeObserver = new ResizeObserver((entries) => { 30 | for (const entry of entries) { 31 | if (entry.target === bodyEl) { 32 | fitToPage(); 33 | } 34 | } 35 | }); 36 | useEffect(() => { 37 | fitToPage(); 38 | resizeObserver.observe(bodyEl); 39 | return () => { 40 | resizeObserver.unobserve(bodyEl); 41 | }; 42 | }, [image]); 43 | 44 | const imageClasses = ""; 45 | const maskImageClasses = `absolute opacity-40 pointer-events-none`; 46 | 47 | // Render the image and the predicted mask image on top 48 | return ( 49 | <> 50 | {image && ( 51 | _.defer(() => setMaskImg(null))} 55 | onTouchStart={handleMouseMove} 56 | src={image.src} 57 | className={`${ 58 | shouldFitToWidth ? "w-full" : "h-full" 59 | } ${imageClasses}`} 60 | > 61 | )} 62 | {maskImg && ( 63 | 69 | )} 70 | 71 | ); 72 | }; 73 | 74 | export default Tool; 75 | -------------------------------------------------------------------------------- /webapp/webapp/src/components/helpers/Interfaces.tsx: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | import { Tensor } from "onnxruntime-web"; 8 | 9 | export interface modelRawMaskProps { 10 | height: number; 11 | width: number; 12 | data: Uint8ClampedArray; 13 | } 14 | 15 | export interface modelMaskProps { 16 | id: string; 17 | name: string; 18 | maskImg: HTMLImageElement; 19 | rawMask: modelRawMaskProps; 20 | } 21 | 22 | export interface modelControlProps { 23 | points: Array 24 | } 25 | 26 | export interface modelScaleProps { 27 | samScale: number; 28 | height: number; 29 | width: number; 30 | } 31 | 32 | export interface modelInputProps { 33 | x: number; 34 | y: number; 35 | clickType: number; 36 | } 37 | 38 | export interface modeDataProps { 39 | clicks?: Array; 40 | tensor: Tensor; 41 | modelScale: modelScaleProps; 42 | } 43 | 44 | export interface StageAppProps { 45 | setDefaultRawMask: (rawMask: modelRawMaskProps | null) => void; 46 | addMask: (mask: modelMaskProps) => void; 47 | updateLoader: (loading: boolean) => void; 48 | onUploadImage1: (path: string) => void; 49 | onUploadImage2: (path: string) => void; 50 | } 51 | 52 | export interface StageProps { 53 | setDefaultRawMask: (rawMask: modelRawMaskProps | null) => void; 54 | addMask: (mask: modelMaskProps) => void; 55 | } 56 | 57 | export interface ControlProps { 58 | defaultRawMask: modelRawMaskProps | null; 59 | masks: Array; 60 | image1Path: string; 61 | image2Path: string; 62 | handleDelete: (index: number) => void; 63 | updateMasks: (masks: Array) => void; 64 | updateLoader: (loading: boolean) => void; 65 | } 66 | 67 | export interface ToolProps { 68 | handleMouseClick: (e: any) => void; 69 | handleMouseMove: (e: any) => void; 70 | } 71 | 72 | export interface MaskListProps { 73 | masks: Array<{id: string, name: string, maskImg: HTMLImageElement}> 74 | handleDelete: (index: number) => void; 75 | } 76 | 77 | export interface MaskProps { 78 | mask: {id: string, name: string, maskImg: HTMLImageElement}; 79 | index: number; 80 | handleDelete: (index: number) => void; 81 | } 82 | 83 | export interface LineChartProps { 84 | maskIndex: number 85 | } -------------------------------------------------------------------------------- /webapp/webapp/src/components/helpers/maskUtils.tsx: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
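// The ONNX mask head emits a flat Float32Array of per-pixel logits; the
// helpers below either threshold it at 0 into a binary Uint8ClampedArray
// mask, or render it as a translucent blue overlay image for display.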
6 | 7 | // Convert the onnx model mask prediction to a binary 0/1 mask array 8 | export function arrayToMaskArray(input: any) { 9 | const arr = new Uint8ClampedArray(input.length).fill(0); 10 | for (let i = 0; i < input.length; i++) { 11 | 12 | // Threshold the onnx model mask prediction at 0.0 13 | // This is equivalent to thresholding the mask using predictor.model.mask_threshold 14 | // in python 15 | if (input[i] > 0.0) { 16 | arr[i] = 1; 17 | } 18 | } 19 | return arr; 20 | } 21 | 22 | export function allOnesMaskArray(width: number, height: number) { 23 | const arr = new Uint8ClampedArray(width * height).fill(1); 24 | return arr; 25 | } 26 | 27 | // Convert the onnx model mask prediction to ImageData 28 | function arrayToImageData(input: any, width: number, height: number) { 29 | const [r, g, b, a] = [0, 114, 189, 255]; // the mask's blue color 30 | const arr = new Uint8ClampedArray(4 * width * height).fill(0); 31 | for (let i = 0; i < input.length; i++) { 32 | 33 | // Threshold the onnx model mask prediction at 0.0 34 | // This is equivalent to thresholding the mask using predictor.model.mask_threshold 35 | // in python 36 | if (input[i] > 0.0) { 37 | arr[4 * i + 0] = r; 38 | arr[4 * i + 1] = g; 39 | arr[4 * i + 2] = b; 40 | arr[4 * i + 3] = a; 41 | } 42 | } 43 | return new ImageData(arr, height, width); 44 | } 45 | 46 | // Use a Canvas element to produce an image from ImageData 47 | function imageDataToImage(imageData: ImageData) { 48 | const canvas = imageDataToCanvas(imageData); 49 | const image = new Image(); 50 | image.src = canvas.toDataURL(); 51 | return image; 52 | } 53 | 54 | // Canvas elements can be created from ImageData 55 | function imageDataToCanvas(imageData: ImageData) { 56 | const canvas = document.createElement("canvas"); 57 | const ctx = canvas.getContext("2d"); 58 | canvas.width = imageData.width; 59 | canvas.height = imageData.height; 60 | ctx?.putImageData(imageData, 0, 0); 61 | return canvas; 62 | } 63 | 64 | // Convert the onnx model mask output to an HTMLImageElement 65 | export function onnxMaskToImage(input: any, width: number, height: number) { 66 | return imageDataToImage(arrayToImageData(input, width, height)); 67 | } 68 | -------------------------------------------------------------------------------- /webapp/webapp/src/components/helpers/onnxModelAPI.tsx: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | import { Tensor } from "onnxruntime-web"; 8 | import { modeDataProps } from "./Interfaces"; 9 | 10 | const modelData = ({ clicks, tensor, modelScale }: modeDataProps) => { 11 | const imageEmbedding = tensor; 12 | let pointCoords; 13 | let pointLabels; 14 | let pointCoordsTensor; 15 | let pointLabelsTensor; 16 | 17 | // Check there are input click prompts 18 | if (clicks) { 19 | let n = clicks.length; 20 | 21 | // If there is no box input, a single padding point with 22 | // label -1 and coordinates (0.0, 0.0) should be concatenated 23 | // so initialize the array to support (n + 1) points.
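// The coordinate buffer is laid out flat as [x0, y0, x1, y1, ...], which
// is why it holds 2 * (n + 1) floats for n clicks plus the padding point.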
24 | pointCoords = new Float32Array(2 * (n + 1)); 25 | pointLabels = new Float32Array(n + 1); 26 | 27 | // Add clicks and scale to what SAM expects 28 | for (let i = 0; i < n; i++) { 29 | pointCoords[2 * i] = clicks[i].x * modelScale.samScale; 30 | pointCoords[2 * i + 1] = clicks[i].y * modelScale.samScale; 31 | pointLabels[i] = clicks[i].clickType; 32 | } 33 | 34 | // Add in the extra point/label when only clicks and no box 35 | // The extra point is at (0, 0) with label -1 36 | pointCoords[2 * n] = 0.0; 37 | pointCoords[2 * n + 1] = 0.0; 38 | pointLabels[n] = -1.0; 39 | 40 | // Create the tensor 41 | pointCoordsTensor = new Tensor("float32", pointCoords, [1, n + 1, 2]); 42 | pointLabelsTensor = new Tensor("float32", pointLabels, [1, n + 1]); 43 | } 44 | const imageSizeTensor = new Tensor("float32", [ 45 | modelScale.height, 46 | modelScale.width, 47 | ]); 48 | 49 | if (pointCoordsTensor === undefined || pointLabelsTensor === undefined) 50 | return; 51 | 52 | // There is no previous mask, so default to an empty tensor 53 | const maskInput = new Tensor( 54 | "float32", 55 | new Float32Array(256 * 256), 56 | [1, 1, 256, 256] 57 | ); 58 | // There is no previous mask, so default to 0 59 | const hasMaskInput = new Tensor("float32", [0]); 60 | 61 | return { 62 | image_embeddings: imageEmbedding, 63 | point_coords: pointCoordsTensor, 64 | point_labels: pointLabelsTensor, 65 | orig_im_size: imageSizeTensor, 66 | mask_input: maskInput, 67 | has_mask_input: hasMaskInput, 68 | }; 69 | }; 70 | 71 | export { modelData }; 72 | -------------------------------------------------------------------------------- /webapp/webapp/src/components/helpers/scaleHelper.tsx: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | 8 | // Helper function for handling image scaling needed for SAM 9 | const handleImageScale = (image: HTMLImageElement) => { 10 | // Input images to SAM must be resized so the longest side is 1024 11 | const LONG_SIDE_LENGTH = 1024; 12 | let w = image.naturalWidth; 13 | let h = image.naturalHeight; 14 | const samScale = LONG_SIDE_LENGTH / Math.max(h, w); 15 | return { height: h, width: w, samScale }; 16 | }; 17 | 18 | export { handleImageScale }; 19 | -------------------------------------------------------------------------------- /webapp/webapp/src/components/hooks/context.tsx: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
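// AppContextProvider holds the per-image SAM state (clicks, image, raw mask,
// rendered mask image); ControlContextProvider holds the control-point arrays
// and the indices of the masks currently selected in the mask panel.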
6 | 7 | import React, { useState } from "react"; 8 | import { modelInputProps, modelRawMaskProps, modelMaskProps, modelControlProps } from "../helpers/Interfaces"; 9 | import { AppContext, ControlContext } from "./createContext"; 10 | 11 | export const AppContextProvider = (props: { 12 | children: React.ReactElement>; 13 | }) => { 14 | const [clicks, setClicks] = useState | null>(null); 15 | const [image, setImage] = useState(null); 16 | const [rawMask, setRawMask] = useState(null); 17 | const [maskImg, setMaskImg] = useState(null); 18 | 19 | return ( 20 | 28 | {props.children} 29 | 30 | ); 31 | }; 32 | 33 | export const ControlContextProvider = (props: { 34 | children: React.ReactElement>; 35 | }) => { 36 | 37 | const [controls, setControls] = useState>([]); 38 | const [selectedIndices, setSelectedIndices] = useState>([]); 39 | 40 | return ( 41 | 47 | {props.children} 48 | 49 | ); 50 | }; 51 | -------------------------------------------------------------------------------- /webapp/webapp/src/components/hooks/createContext.tsx: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | import { createContext } from "react"; 8 | import { modelInputProps, modelRawMaskProps, modelControlProps } from "../helpers/Interfaces"; 9 | 10 | interface contextProps { 11 | clicks: [ 12 | clicks: modelInputProps[] | null, 13 | setClicks: (e: modelInputProps[] | null) => void 14 | ]; 15 | image: [ 16 | image: HTMLImageElement | null, 17 | setImage: (e: HTMLImageElement | null) => void 18 | ]; 19 | rawMask: [ 20 | rawMask: modelRawMaskProps | null, 21 | setRawMask: (e: modelRawMaskProps | null) => void 22 | ], 23 | maskImg: [ 24 | maskImg: HTMLImageElement | null, 25 | setMaskImg: (e: HTMLImageElement | null) => void 26 | ]; 27 | } 28 | 29 | interface contextControlsProps { 30 | controls: [ 31 | controls: Array, 32 | setControls: (e: Array) => void 33 | ]; 34 | selectedIndices: [ 35 | selectedIndices: Array, 36 | setSelectedIndices: (e: Array) => void 37 | ]; 38 | } 39 | 40 | export const AppContext = createContext(null); 41 | export const ControlContext = createContext(null); 42 | -------------------------------------------------------------------------------- /webapp/webapp/src/index.tsx: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | import * as React from "react"; 8 | import { createRoot } from "react-dom/client"; 9 | import App from "./App"; 10 | const container = document.getElementById("root"); 11 | const root = createRoot(container!); 12 | root.render( 13 | 14 | ); 15 | -------------------------------------------------------------------------------- /webapp/webapp/tailwind.config.js: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
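// withMT merges Material Tailwind's theme defaults into this Tailwind config;
// `content` tells Tailwind which files under src/ to scan for class names.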
6 | const withMT = require("@material-tailwind/react/utils/withMT"); 7 | 8 | /** @type {import('tailwindcss').Config} */ 9 | module.exports = withMT({ 10 | content: ["./src/**/*.{html,js,tsx}"], 11 | theme: {}, 12 | plugins: [], 13 | variants: { 14 | extend: { 15 | opacity: ['disabled'], 16 | } 17 | }, 18 | }); 19 | -------------------------------------------------------------------------------- /webapp/webapp/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "lib": ["dom", "dom.iterable", "esnext"], 4 | "allowJs": true, 5 | "skipLibCheck": true, 6 | "strict": true, 7 | "forceConsistentCasingInFileNames": true, 8 | "noEmit": false, 9 | "esModuleInterop": true, 10 | "module": "esnext", 11 | "moduleResolution": "node", 12 | "resolveJsonModule": true, 13 | "isolatedModules": true, 14 | "jsx": "react", 15 | "incremental": true, 16 | "target": "ESNext", 17 | "useDefineForClassFields": true, 18 | "allowSyntheticDefaultImports": true, 19 | "outDir": "./dist/", 20 | "sourceMap": true 21 | }, 22 | "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", "src"], 23 | "exclude": ["node_modules"] 24 | } 25 | --------------------------------------------------------------------------------