├── .gitignore
├── README.md
├── correlation_package
│   ├── __init__.py
│   ├── correlation.py
│   ├── correlation_cuda.cc
│   ├── correlation_cuda_kernel.cu
│   ├── correlation_cuda_kernel.cuh
│   ├── nvcc setting.md
│   ├── pyproject.toml
│   └── setup.py
├── figure
│   ├── ERF_dataset.png
│   ├── Quantitative_eval_ERF_x170FPS.png
│   ├── driving.gif
│   ├── flower.gif
│   └── popcorn.gif
├── install_correlation.sh
├── models
│   ├── final_models
│   │   ├── ours.py
│   │   ├── ours_large.py
│   │   └── submodules.py
│   ├── loss_handler.py
│   └── model_manager.py
├── pretrained_model
│   └── README.md
├── run_samples.py
├── sample_data
│   ├── 00000.png
│   ├── 00000_0t.npz
│   ├── 00000_t0.npz
│   ├── 00000_t1.npz
│   └── 00001.png
├── test_bsergb.py
├── tools
│   ├── event_utils.py
│   └── preprocess_events.py
├── train.py
└── utils
    ├── dataloader_bsergb.py
    ├── flow_utils.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | output
3 | experiments
4 | pretrained_model
5 | tools/unused
6 | logs
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CBMNet (CVPR 2023, Highlight)
2 | **Official repository for the CVPR 2023 paper, "Event-based Video Frame Interpolation with Cross-Modal Asymmetric Bidirectional Motion Fields"**
3 | 
4 | \[[Paper](https://openaccess.thecvf.com/content/CVPR2023/papers/Kim_Event-Based_Video_Frame_Interpolation_With_Cross-Modal_Asymmetric_Bidirectional_Motion_Fields_CVPR_2023_paper.pdf)\]
5 | \[[Supp](https://openaccess.thecvf.com/content/CVPR2023/supplemental/Kim_Event-Based_Video_Frame_CVPR_2023_supplemental.zip)\]
6 | 
7 | 
8 | 
9 | ## Qualitative video demos on the ERF-X170FPS dataset
10 | ### Falling pop-corn
11 | 
12 | ![Falling pop-corn demo](./figure/popcorn.gif)
15 | 
16 | ### Flowers
17 | 
18 | ![Flowers demo](./figure/flower.gif)
21 | 
22 | ### Driving scene
23 | 
24 | ![Driving scene demo](./figure/driving.gif)
27 | 
28 | ## ERF-X170FPS dataset
29 | #### A dataset of high-resolution (1440x975), high-fps (170 fps) video frames paired with high-resolution events, containing extremely large motion and captured with a beam-splitter acquisition system:
30 | ![image info](./figure/ERF_dataset.png)
31 | 
32 | ### Quantitative results on the ERF-X170FPS dataset
33 | 
34 | ![Quantitative evaluation on ERF-X170FPS](./figure/Quantitative_eval_ERF_x170FPS.png)
35 | 
36 | ### Downloading the ERF-X170FPS dataset
37 | You can download the raw data (collected frames and events) from these links:
38 | 
39 | * [[RAW-Train](https://drive.google.com/file/d/1Bsf9qreziPcVEuf0_v3kjdPUh27zsFXK/view?usp=drive_link)]
40 | * [[RAW-Test](https://drive.google.com/file/d/1Dk7jVQD29HqRVV11e8vxg5bDOh6KxrzL/view?usp=drive_link)]
41 | 
42 | **Caution:** the x, y coordinates in the raw event files are multiplied by 128.
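The snippet below is a minimal sketch of how this scaling could be undone when loading the raw events. It assumes the raw events are stored as NumPy arrays named `x`, `y`, `t`, `p` inside an `.npz` archive; the actual container format and field names of the raw ERF-X170FPS files may differ.

```python
import numpy as np

# Minimal sketch: load raw events and undo the 128x coordinate scaling.
# Assumption: the raw file is an .npz archive with arrays x, y, t, p
# (the real raw-data format and field names may differ).
raw = np.load("raw_events.npz")

x = raw["x"].astype(np.float32) / 128.0  # x pixel coordinate
y = raw["y"].astype(np.float32) / 128.0  # y pixel coordinate
t = raw["t"]                             # timestamps
p = raw["p"]                             # polarities

print("x range:", x.min(), x.max(), "| y range:", y.min(), y.max())
```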
43 | 
44 | ## Requirements
45 | * PyTorch 1.8.0
46 | * CUDA 11.2
47 | * Python 3.8
48 | 
49 | ## Quick Usage
50 | 
51 | Clone the repository:
52 | 
53 | ```bash
54 | $ git clone https://github.com/intelpro/CBMNet
55 | ```
56 | 
57 | Install the correlation package:
58 | 
59 | ```bash
60 | $ sh install_correlation.sh
61 | ```
62 | 
63 | Download the network weights (trained on the ERF-X170FPS dataset) and place the downloaded models in `./pretrained_model/`:
64 | 
65 | * [[Ours](https://drive.google.com/file/d/1VJKyuoRSMOJkl8fQlJIkc7S4Fmd2_X8K/view?usp=sharing)]
66 | * [[Ours-Large](https://drive.google.com/file/d/1jI6_RwhXeM-pW5CnHf0exB5RP2zp2SbY/view?usp=sharing)]
67 | 
68 | 
69 | Generate an intermediate video frame using the `ours` model:
70 | 
71 | ```bash
72 | $ python run_samples.py --model_name ours --ckpt_path pretrained_model/ours_weight.pth --save_output_dir ./output --image_number 0
73 | ```
74 | 
75 | You can also generate an intermediate video frame using the `ours_large` model:
76 | 
77 | ```bash
78 | $ python run_samples.py --model_name ours_large --ckpt_path pretrained_model/ours_large_weight.pth --save_output_dir ./output --image_number 0
79 | 
80 | ```
81 | 
82 | 
83 | ## 🚀 Quick Test on BSERGB Dataset
84 | 
85 | This section describes how to test the model on the **BSERGB dataset** using the pre-trained weights.
86 | 
87 | ### 1. Download BSERGB Dataset
88 | 
89 | You can download the [BSERGB dataset](https://github.com/uzh-rpg/timelens-pp) from the official TimeLens++ GitHub repository.
90 | 
91 | ### 2. Preprocess Event Voxel Data
92 | 
93 | After downloading, the BSERGB dataset should have the following directory structure:
94 | 
95 | 
96 | ```
97 | ├── BSERGB/
98 | │ ├── 1_TEST/
99 | │ │ ├── scene_001/
100 | │ │ │ ├── images/
101 | │ │ │ │ ├── 000000.png
102 | │ │ │ │ ├── ...
103 | │ │ │ ├── events/
104 | │ │ │ │ ├── 000000.npz
105 | │ │ │ │ ├── ...
106 | │ │ ├── scene_002/
107 | │ │ ├── scene_003/
108 | │ │ ├── ...
109 | │ ├── 2_VALIDATION/
110 | │ │ ├── scene_001/
111 | │ │ ├── scene_002/
112 | │ │ ├── ...
113 | │ ├── 3_TRAINING/
114 | │ │ ├── scene_001/
115 | │ │ ├── scene_002/
116 | │ │ ├── ...
117 | ```
118 | 
119 | Now, convert the raw event data into event voxel grids using the following command:
120 | 
121 | 
122 | ```bash
123 | $ python tools/preprocess_events.py --dataset_dir BSERGB_DATASET_DIR --mode 1_TEST
124 | 
125 | ```
126 | - ``--dataset_dir BSERGB_DATASET_DIR``: Specifies the BSERGB dataset directory.
127 | - ``--mode 1_TEST``: Selects which split to convert from raw events into event voxels. Choose `1_TEST` if you only want to run testing.
128 | 
129 | ### 🛠️ Event Voxel Preprocessing Output
130 | 
131 | After preprocessing, event voxel files will be generated and saved into the **target folder**.
132 | For each sample, three types of voxel grids will be saved:
133 | 
134 | - `0t`: Events from the start frame to the interpolated frame
135 | - `t0`: Reversed version of `0t` (used for backward flow)
136 | - `t1`: Events from the interpolated frame to the end frame
137 | 
138 | Each event voxel is stored using the following naming format:
139 | 
140 | `{start_index}-{interpolation_index}-{end_index}_{suffix}.npz`
141 | 
142 | Each index is zero-padded using `zfill(6)`, and `{suffix}` is one of the three types: `0t`, `t0`, or `t1`.
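As a concrete illustration of this naming scheme, here is a minimal sketch that builds the three voxel filenames for one sample; the frame indices are example values only.

```python
# Illustrative sketch of the voxel-file naming scheme described above.
# The start/interpolation/end indices below are example values only.
start_idx, interp_idx, end_idx = 0, 2, 4

prefix = "-".join(str(i).zfill(6) for i in (start_idx, interp_idx, end_idx))
for suffix in ("0t", "t0", "t1"):
    print(f"{prefix}_{suffix}.npz")
# -> 000000-000002-000004_0t.npz, 000000-000002-000004_t0.npz, 000000-000002-000004_t1.npz
```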
143 | 
144 | #### 📁 Example
145 | 
146 | ```text
147 | 000000-000002-000004_0t.npz # event voxel from 000000 to 000002
148 | 000000-000002-000004_t0.npz # reversed event voxel from 000002 to 000000
149 | 000000-000002-000004_t1.npz # event voxel from 000002 to 000004
150 | ```
151 | 
152 | Once the voxel preprocessing is complete and the files are generated in the proper format, you can proceed to download the pretrained model and run the test script.
153 | 
154 | ### 3. Download Pretrained Weights
155 | 
156 | Download the ours-large weights (trained on the BSERGB dataset) and place the downloaded model inside the `./pretrained_model` directory.
157 | 
158 | 🔗 **[Ours-Large(BSERGB)](https://drive.google.com/file/d/1T5ycqQK4KVZQ4pAnNkr2Ff8XIWvhmj_f/view?usp=sharing)**
159 | 
160 | Then move the file to the `./pretrained_model` directory:
161 | 
162 | ```bash
163 | # Ensure the directory exists
164 | mkdir -p pretrained_model
165 | 
166 | # Move the downloaded model to the correct location
167 | mv /path/to/downloaded/Ours_Large_BSERGB.pth ./pretrained_model/
168 | ```
169 | 
170 | Make sure the final path is:
171 | 
172 | 
173 | ```bash
174 | ./pretrained_model/Ours_Large_BSERGB.pth
175 | ```
176 | 
177 | ### 4. Run the test script
178 | 
179 | Once preprocessing is done and the pretrained model is downloaded, you can test the model on the BSERGB dataset:
180 | 
181 | ```bash
182 | $ python test_bsergb.py --dataset_dir BSERGB_DATASET_DIR
183 | ```
184 | 
185 | After running this script, ground-truth (gt) and result images will be generated inside the `./output` directory.
186 | 
187 | By evaluating the output images, you can reproduce the quantitative results reported in the paper (a minimal evaluation sketch is provided at the end of this README).
188 | 
189 | ## 🚀 Train the Model on the BSERGB Dataset
190 | 
191 | Training instructions and documentation will be available in the near future. (Work in progress)
192 | 
193 | In the meantime, if you need to proceed quickly, please refer to the `train.py` file for a rough guide.
194 | 
195 | 
196 | ## Reference
197 | > Taewoo Kim, Yujeong Chae, Hyun-Kurl Jang, and Kuk-Jin Yoon, "Event-based Video Frame Interpolation with Cross-modal Asymmetric Bidirectional Motion Fields", In _CVPR_, 2023.
198 | ```bibtex
199 | @InProceedings{Kim_2023_CVPR,
200 | author = {Kim, Taewoo and Chae, Yujeong and Jang, Hyun-Kurl and Yoon, Kuk-Jin},
201 | title = {Event-Based Video Frame Interpolation With Cross-Modal Asymmetric Bidirectional Motion Fields},
202 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
203 | month = {June},
204 | year = {2023},
205 | pages = {18032-18042}
206 | }
207 | ```
208 | ## Contact
209 | If you have any questions, please send an email to Taewoo Kim (intelpro@kaist.ac.kr).
210 | 
211 | ## License
212 | The project code and datasets may be used for research and educational purposes only.
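## Appendix: Minimal evaluation sketch

The following is a minimal sketch of how the saved output images could be compared with PSNR. It assumes the ground-truth and result PNGs are written to `./output/gt` and `./output/result` with matching filenames and that Pillow is installed; the actual folder layout produced by `test_bsergb.py` may differ, and the paper additionally reports SSIM, which is not computed here.

```python
import glob
import os

import numpy as np
from PIL import Image  # assumes Pillow is available

def psnr(gt, pred, data_range=255.0):
    """PSNR between two uint8 images."""
    mse = np.mean((gt.astype(np.float64) - pred.astype(np.float64)) ** 2)
    return float("inf") if mse == 0 else 10.0 * np.log10(data_range ** 2 / mse)

# Assumption: gt/result images live in ./output/gt and ./output/result with
# matching filenames; adapt the paths to the layout test_bsergb.py actually writes.
scores = []
for gt_path in sorted(glob.glob("./output/gt/*.png")):
    res_path = os.path.join("./output/result", os.path.basename(gt_path))
    gt = np.array(Image.open(gt_path))
    pred = np.array(Image.open(res_path))
    scores.append(psnr(gt, pred))

print(f"Average PSNR over {len(scores)} frames: {np.mean(scores):.2f} dB")
```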
213 | -------------------------------------------------------------------------------- /correlation_package/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /correlation_package/correlation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.module import Module 3 | from torch.autograd import Function 4 | import correlation_cuda 5 | 6 | class CorrelationFunction(Function): 7 | 8 | # def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1): 9 | # super(CorrelationFunction, self).__init__() 10 | # self.pad_size = pad_size 11 | # self.kernel_size = kernel_size 12 | # self.max_displacement = max_displacement 13 | # self.stride1 = stride1 14 | # self.stride2 = stride2 15 | # self.corr_multiply = corr_multiply 16 | # # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1) 17 | 18 | @staticmethod 19 | def forward(ctx, input1, input2, pad_size, kernel_size, max_displacement,stride1, stride2, corr_multiply): 20 | ctx.save_for_backward(input1, input2) 21 | ctx.pad_size = pad_size 22 | ctx.kernel_size = kernel_size 23 | ctx.max_displacement = max_displacement 24 | ctx.stride1 = stride1 25 | ctx.stride2 = stride2 26 | ctx.corr_multiply = corr_multiply 27 | 28 | with torch.cuda.device_of(input1): 29 | rbot1 = input1.new() 30 | rbot2 = input2.new() 31 | output = input1.new() 32 | 33 | correlation_cuda.forward(input1, input2, rbot1, rbot2, output, 34 | pad_size, kernel_size, max_displacement, stride1, stride2, corr_multiply) 35 | 36 | return output 37 | 38 | @staticmethod 39 | def backward(ctx, grad_output): 40 | input1, input2 = ctx.saved_tensors 41 | 42 | with torch.cuda.device_of(input1): 43 | rbot1 = input1.new() 44 | rbot2 = input2.new() 45 | 46 | grad_input1 = input1.new() 47 | grad_input2 = input2.new() 48 | 49 | correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2, 50 | ctx.pad_size, ctx.kernel_size, ctx.max_displacement, ctx.stride1, ctx.stride2, ctx.corr_multiply) 51 | 52 | return grad_input1, grad_input2, None, None, None, None, None, None 53 | 54 | 55 | class Correlation(Module): 56 | def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1): 57 | super(Correlation, self).__init__() 58 | self.pad_size = pad_size 59 | self.kernel_size = kernel_size 60 | self.max_displacement = max_displacement 61 | self.stride1 = stride1 62 | self.stride2 = stride2 63 | self.corr_multiply = corr_multiply 64 | 65 | # @staticmethod 66 | def forward(self, input1, input2): 67 | 68 | input1 = input1.contiguous() 69 | input2 = input2.contiguous() 70 | # result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)(input1, input2) 71 | result = CorrelationFunction.apply(input1, input2, self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply) 72 | 73 | return result 74 | 75 | -------------------------------------------------------------------------------- /correlation_package/correlation_cuda.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include "correlation_cuda_kernel.cuh" 8 | 9 | int 
correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output, 10 | int pad_size, 11 | int kernel_size, 12 | int max_displacement, 13 | int stride1, 14 | int stride2, 15 | int corr_type_multiply) 16 | { 17 | 18 | int batchSize = input1.size(0); 19 | 20 | int nInputChannels = input1.size(1); 21 | int inputHeight = input1.size(2); 22 | int inputWidth = input1.size(3); 23 | 24 | int kernel_radius = (kernel_size - 1) / 2; 25 | int border_radius = kernel_radius + max_displacement; 26 | 27 | int paddedInputHeight = inputHeight + 2 * pad_size; 28 | int paddedInputWidth = inputWidth + 2 * pad_size; 29 | 30 | int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1); 31 | 32 | int outputHeight = ceil(static_cast(paddedInputHeight - 2 * border_radius) / static_cast(stride1)); 33 | int outputwidth = ceil(static_cast(paddedInputWidth - 2 * border_radius) / static_cast(stride1)); 34 | 35 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 36 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 37 | output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth}); 38 | 39 | rInput1.fill_(0); 40 | rInput2.fill_(0); 41 | output.fill_(0); 42 | 43 | int success = correlation_forward_cuda_kernel( 44 | output, 45 | output.size(0), 46 | output.size(1), 47 | output.size(2), 48 | output.size(3), 49 | output.stride(0), 50 | output.stride(1), 51 | output.stride(2), 52 | output.stride(3), 53 | input1, 54 | input1.size(1), 55 | input1.size(2), 56 | input1.size(3), 57 | input1.stride(0), 58 | input1.stride(1), 59 | input1.stride(2), 60 | input1.stride(3), 61 | input2, 62 | input2.size(1), 63 | input2.stride(0), 64 | input2.stride(1), 65 | input2.stride(2), 66 | input2.stride(3), 67 | rInput1, 68 | rInput2, 69 | pad_size, 70 | kernel_size, 71 | max_displacement, 72 | stride1, 73 | stride2, 74 | corr_type_multiply, 75 | at::cuda::getCurrentCUDAStream() 76 | ); 77 | 78 | //check for errors 79 | if (!success) { 80 | AT_ERROR("CUDA call failed"); 81 | } 82 | 83 | return 1; 84 | 85 | } 86 | 87 | int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput, 88 | at::Tensor& gradInput1, at::Tensor& gradInput2, 89 | int pad_size, 90 | int kernel_size, 91 | int max_displacement, 92 | int stride1, 93 | int stride2, 94 | int corr_type_multiply) 95 | { 96 | 97 | int batchSize = input1.size(0); 98 | int nInputChannels = input1.size(1); 99 | int paddedInputHeight = input1.size(2)+ 2 * pad_size; 100 | int paddedInputWidth = input1.size(3)+ 2 * pad_size; 101 | 102 | int height = input1.size(2); 103 | int width = input1.size(3); 104 | 105 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 106 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 107 | gradInput1.resize_({batchSize, nInputChannels, height, width}); 108 | gradInput2.resize_({batchSize, nInputChannels, height, width}); 109 | 110 | rInput1.fill_(0); 111 | rInput2.fill_(0); 112 | gradInput1.fill_(0); 113 | gradInput2.fill_(0); 114 | 115 | int success = correlation_backward_cuda_kernel(gradOutput, 116 | gradOutput.size(0), 117 | gradOutput.size(1), 118 | gradOutput.size(2), 119 | gradOutput.size(3), 120 | gradOutput.stride(0), 121 | gradOutput.stride(1), 122 | gradOutput.stride(2), 123 | gradOutput.stride(3), 124 | input1, 125 | input1.size(1), 126 | input1.size(2), 
127 | input1.size(3), 128 | input1.stride(0), 129 | input1.stride(1), 130 | input1.stride(2), 131 | input1.stride(3), 132 | input2, 133 | input2.stride(0), 134 | input2.stride(1), 135 | input2.stride(2), 136 | input2.stride(3), 137 | gradInput1, 138 | gradInput1.stride(0), 139 | gradInput1.stride(1), 140 | gradInput1.stride(2), 141 | gradInput1.stride(3), 142 | gradInput2, 143 | gradInput2.size(1), 144 | gradInput2.stride(0), 145 | gradInput2.stride(1), 146 | gradInput2.stride(2), 147 | gradInput2.stride(3), 148 | rInput1, 149 | rInput2, 150 | pad_size, 151 | kernel_size, 152 | max_displacement, 153 | stride1, 154 | stride2, 155 | corr_type_multiply, 156 | at::cuda::getCurrentCUDAStream() 157 | ); 158 | 159 | if (!success) { 160 | AT_ERROR("CUDA call failed"); 161 | } 162 | 163 | return 1; 164 | } 165 | 166 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 167 | m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)"); 168 | m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)"); 169 | } 170 | 171 | -------------------------------------------------------------------------------- /correlation_package/correlation_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "correlation_cuda_kernel.cuh" 4 | 5 | #define CUDA_NUM_THREADS 1024 6 | #define THREADS_PER_BLOCK 32 7 | #define FULL_MASK 0xffffffff 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | using at::Half; 15 | 16 | template 17 | __forceinline__ __device__ scalar_t warpReduceSum(scalar_t val) { 18 | for (int offset = 16; offset > 0; offset /= 2) 19 | val += __shfl_down_sync(FULL_MASK, val, offset); 20 | return val; 21 | } 22 | 23 | template 24 | __forceinline__ __device__ scalar_t blockReduceSum(scalar_t val) { 25 | 26 | static __shared__ scalar_t shared[32]; 27 | int lane = threadIdx.x % warpSize; 28 | int wid = threadIdx.x / warpSize; 29 | 30 | val = warpReduceSum(val); 31 | 32 | if (lane == 0) 33 | shared[wid] = val; 34 | 35 | __syncthreads(); 36 | 37 | val = (threadIdx.x < blockDim.x / warpSize) ? 
shared[lane] : 0; 38 | 39 | if (wid == 0) 40 | val = warpReduceSum(val); 41 | 42 | return val; 43 | } 44 | 45 | 46 | template 47 | __global__ void channels_first(const scalar_t* __restrict__ input, scalar_t* rinput, int channels, int height, int width, int pad_size) 48 | { 49 | 50 | // n (batch size), c (num of channels), y (height), x (width) 51 | int n = blockIdx.x; 52 | int y = blockIdx.y; 53 | int x = blockIdx.z; 54 | 55 | int ch_off = threadIdx.x; 56 | scalar_t value; 57 | 58 | int dimcyx = channels * height * width; 59 | int dimyx = height * width; 60 | 61 | int p_dimx = (width + 2 * pad_size); 62 | int p_dimy = (height + 2 * pad_size); 63 | int p_dimyxc = channels * p_dimy * p_dimx; 64 | int p_dimxc = p_dimx * channels; 65 | 66 | for (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) { 67 | value = input[n * dimcyx + c * dimyx + y * width + x]; 68 | rinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value; 69 | } 70 | } 71 | 72 | 73 | template 74 | __global__ void correlation_forward(scalar_t* __restrict__ output, const int nOutputChannels, 75 | const int outputHeight, const int outputWidth, const scalar_t* __restrict__ rInput1, 76 | const int nInputChannels, const int inputHeight, const int inputWidth, 77 | const scalar_t* __restrict__ rInput2, const int pad_size, const int kernel_size, 78 | const int max_displacement, const int stride1, const int stride2) { 79 | 80 | int32_t pInputWidth = inputWidth + 2 * pad_size; 81 | int32_t pInputHeight = inputHeight + 2 * pad_size; 82 | 83 | int32_t kernel_rad = (kernel_size - 1) / 2; 84 | 85 | int32_t displacement_rad = max_displacement / stride2; 86 | 87 | int32_t displacement_size = 2 * displacement_rad + 1; 88 | 89 | int32_t n = blockIdx.x; 90 | int32_t y1 = blockIdx.y * stride1 + max_displacement; 91 | int32_t x1 = blockIdx.z * stride1 + max_displacement; 92 | int32_t c = threadIdx.x; 93 | 94 | int32_t pdimyxc = pInputHeight * pInputWidth * nInputChannels; 95 | 96 | int32_t pdimxc = pInputWidth * nInputChannels; 97 | 98 | int32_t pdimc = nInputChannels; 99 | 100 | int32_t tdimcyx = nOutputChannels * outputHeight * outputWidth; 101 | int32_t tdimyx = outputHeight * outputWidth; 102 | int32_t tdimx = outputWidth; 103 | 104 | int32_t nelems = kernel_size * kernel_size * pdimc; 105 | 106 | // element-wise product along channel axis 107 | for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) { 108 | for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) { 109 | int x2 = x1 + ti * stride2; 110 | int y2 = y1 + tj * stride2; 111 | 112 | float acc0 = 0.0f; 113 | 114 | for (int j = -kernel_rad; j <= kernel_rad; ++j) { 115 | for (int i = -kernel_rad; i <= kernel_rad; ++i) { 116 | // THREADS_PER_BLOCK 117 | #pragma unroll 118 | for (int ch = c; ch < pdimc; ch += blockDim.x) { 119 | 120 | int indx1 = n * pdimyxc + (y1 + j) * pdimxc 121 | + (x1 + i) * pdimc + ch; 122 | int indx2 = n * pdimyxc + (y2 + j) * pdimxc 123 | + (x2 + i) * pdimc + ch; 124 | acc0 += static_cast(rInput1[indx1] * rInput2[indx2]); 125 | } 126 | } 127 | } 128 | 129 | if (blockDim.x == warpSize) { 130 | __syncwarp(); 131 | acc0 = warpReduceSum(acc0); 132 | } else { 133 | __syncthreads(); 134 | acc0 = blockReduceSum(acc0); 135 | } 136 | 137 | if (threadIdx.x == 0) { 138 | 139 | int tc = (tj + displacement_rad) * displacement_size 140 | + (ti + displacement_rad); 141 | const int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx 142 | + blockIdx.z; 143 | output[tindx] = static_cast(acc0 / nelems); 144 | } 145 | } 146 | } 147 | } 
148 | 149 | 150 | template 151 | __global__ void correlation_backward_input1(int item, scalar_t* gradInput1, int nInputChannels, int inputHeight, int inputWidth, 152 | const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth, 153 | const scalar_t* __restrict__ rInput2, 154 | int pad_size, 155 | int kernel_size, 156 | int max_displacement, 157 | int stride1, 158 | int stride2) 159 | { 160 | // n (batch size), c (num of channels), y (height), x (width) 161 | 162 | int n = item; 163 | int y = blockIdx.x * stride1 + pad_size; 164 | int x = blockIdx.y * stride1 + pad_size; 165 | int c = blockIdx.z; 166 | int tch_off = threadIdx.x; 167 | 168 | int kernel_rad = (kernel_size - 1) / 2; 169 | int displacement_rad = max_displacement / stride2; 170 | int displacement_size = 2 * displacement_rad + 1; 171 | 172 | int xmin = (x - kernel_rad - max_displacement) / stride1; 173 | int ymin = (y - kernel_rad - max_displacement) / stride1; 174 | 175 | int xmax = (x + kernel_rad - max_displacement) / stride1; 176 | int ymax = (y + kernel_rad - max_displacement) / stride1; 177 | 178 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) { 179 | // assumes gradInput1 is pre-allocated and zero filled 180 | return; 181 | } 182 | 183 | if (xmin > xmax || ymin > ymax) { 184 | // assumes gradInput1 is pre-allocated and zero filled 185 | return; 186 | } 187 | 188 | xmin = max(0,xmin); 189 | xmax = min(outputWidth-1,xmax); 190 | 191 | ymin = max(0,ymin); 192 | ymax = min(outputHeight-1,ymax); 193 | 194 | int pInputWidth = inputWidth + 2 * pad_size; 195 | int pInputHeight = inputHeight + 2 * pad_size; 196 | 197 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels; 198 | int pdimxc = pInputWidth * nInputChannels; 199 | int pdimc = nInputChannels; 200 | 201 | int tdimcyx = nOutputChannels * outputHeight * outputWidth; 202 | int tdimyx = outputHeight * outputWidth; 203 | int tdimx = outputWidth; 204 | 205 | int odimcyx = nInputChannels * inputHeight* inputWidth; 206 | int odimyx = inputHeight * inputWidth; 207 | int odimx = inputWidth; 208 | 209 | scalar_t nelems = kernel_size * kernel_size * nInputChannels; 210 | 211 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK]; 212 | prod_sum[tch_off] = 0; 213 | 214 | for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) { 215 | 216 | int i2 = (tc % displacement_size - displacement_rad) * stride2; 217 | int j2 = (tc / displacement_size - displacement_rad) * stride2; 218 | 219 | int indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c; 220 | 221 | scalar_t val2 = rInput2[indx2]; 222 | 223 | for (int j = ymin; j <= ymax; ++j) { 224 | for (int i = xmin; i <= xmax; ++i) { 225 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i; 226 | prod_sum[tch_off] += gradOutput[tindx] * val2; 227 | } 228 | } 229 | } 230 | __syncthreads(); 231 | 232 | if(tch_off == 0) { 233 | scalar_t reduce_sum = 0; 234 | for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) { 235 | reduce_sum += prod_sum[idx]; 236 | } 237 | const int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size); 238 | gradInput1[indx1] = reduce_sum / nelems; 239 | } 240 | 241 | } 242 | 243 | template 244 | __global__ void correlation_backward_input2(int item, scalar_t* gradInput2, int nInputChannels, int inputHeight, int inputWidth, 245 | const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth, 246 | const scalar_t* __restrict__ rInput1, 247 | int pad_size, 248 | int kernel_size, 249 | int 
max_displacement, 250 | int stride1, 251 | int stride2) 252 | { 253 | // n (batch size), c (num of channels), y (height), x (width) 254 | 255 | int n = item; 256 | int y = blockIdx.x * stride1 + pad_size; 257 | int x = blockIdx.y * stride1 + pad_size; 258 | int c = blockIdx.z; 259 | 260 | int tch_off = threadIdx.x; 261 | 262 | int kernel_rad = (kernel_size - 1) / 2; 263 | int displacement_rad = max_displacement / stride2; 264 | int displacement_size = 2 * displacement_rad + 1; 265 | 266 | int pInputWidth = inputWidth + 2 * pad_size; 267 | int pInputHeight = inputHeight + 2 * pad_size; 268 | 269 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels; 270 | int pdimxc = pInputWidth * nInputChannels; 271 | int pdimc = nInputChannels; 272 | 273 | int tdimcyx = nOutputChannels * outputHeight * outputWidth; 274 | int tdimyx = outputHeight * outputWidth; 275 | int tdimx = outputWidth; 276 | 277 | int odimcyx = nInputChannels * inputHeight* inputWidth; 278 | int odimyx = inputHeight * inputWidth; 279 | int odimx = inputWidth; 280 | 281 | scalar_t nelems = kernel_size * kernel_size * nInputChannels; 282 | 283 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK]; 284 | prod_sum[tch_off] = 0; 285 | 286 | for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) { 287 | int i2 = (tc % displacement_size - displacement_rad) * stride2; 288 | int j2 = (tc / displacement_size - displacement_rad) * stride2; 289 | 290 | int xmin = (x - kernel_rad - max_displacement - i2) / stride1; 291 | int ymin = (y - kernel_rad - max_displacement - j2) / stride1; 292 | 293 | int xmax = (x + kernel_rad - max_displacement - i2) / stride1; 294 | int ymax = (y + kernel_rad - max_displacement - j2) / stride1; 295 | 296 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) { 297 | // assumes gradInput2 is pre-allocated and zero filled 298 | continue; 299 | } 300 | 301 | if (xmin > xmax || ymin > ymax) { 302 | // assumes gradInput2 is pre-allocated and zero filled 303 | continue; 304 | } 305 | 306 | xmin = max(0,xmin); 307 | xmax = min(outputWidth-1,xmax); 308 | 309 | ymin = max(0,ymin); 310 | ymax = min(outputHeight-1,ymax); 311 | 312 | int indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c; 313 | scalar_t val1 = rInput1[indx1]; 314 | 315 | for (int j = ymin; j <= ymax; ++j) { 316 | for (int i = xmin; i <= xmax; ++i) { 317 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i; 318 | prod_sum[tch_off] += gradOutput[tindx] * val1; 319 | } 320 | } 321 | } 322 | 323 | __syncthreads(); 324 | 325 | if(tch_off == 0) { 326 | scalar_t reduce_sum = 0; 327 | for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) { 328 | reduce_sum += prod_sum[idx]; 329 | } 330 | const int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size); 331 | gradInput2[indx2] = reduce_sum / nelems; 332 | } 333 | 334 | } 335 | 336 | int correlation_forward_cuda_kernel(at::Tensor& output, 337 | int ob, 338 | int oc, 339 | int oh, 340 | int ow, 341 | int osb, 342 | int osc, 343 | int osh, 344 | int osw, 345 | 346 | at::Tensor& input1, 347 | int ic, 348 | int ih, 349 | int iw, 350 | int isb, 351 | int isc, 352 | int ish, 353 | int isw, 354 | 355 | at::Tensor& input2, 356 | int gc, 357 | int gsb, 358 | int gsc, 359 | int gsh, 360 | int gsw, 361 | 362 | at::Tensor& rInput1, 363 | at::Tensor& rInput2, 364 | int pad_size, 365 | int kernel_size, 366 | int max_displacement, 367 | int stride1, 368 | int stride2, 369 | int corr_type_multiply, 370 | cudaStream_t stream) 371 | { 372 | 373 | int batchSize = ob; 374 | 
375 | int nInputChannels = ic; 376 | int inputWidth = iw; 377 | int inputHeight = ih; 378 | 379 | int nOutputChannels = oc; 380 | int outputWidth = ow; 381 | int outputHeight = oh; 382 | 383 | dim3 blocks_grid(batchSize, inputHeight, inputWidth); 384 | dim3 threads_block(THREADS_PER_BLOCK); 385 | 386 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_fwd_1", ([&] { 387 | 388 | channels_first<<>>( 389 | input1.data(), rInput1.data(), nInputChannels, inputHeight, inputWidth, pad_size); 390 | 391 | })); 392 | 393 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_fwd_2", ([&] { 394 | 395 | channels_first<<>> ( 396 | input2.data(), rInput2.data(), nInputChannels, inputHeight, inputWidth, pad_size); 397 | 398 | })); 399 | 400 | dim3 threadsPerBlock(THREADS_PER_BLOCK); 401 | dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth); 402 | 403 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "correlation_forward", ([&] { 404 | 405 | correlation_forward<<>> 406 | (output.data(), nOutputChannels, outputHeight, outputWidth, 407 | rInput1.data(), nInputChannels, inputHeight, inputWidth, 408 | rInput2.data(), 409 | pad_size, 410 | kernel_size, 411 | max_displacement, 412 | stride1, 413 | stride2); 414 | 415 | })); 416 | 417 | cudaError_t err = cudaGetLastError(); 418 | 419 | 420 | // check for errors 421 | if (err != cudaSuccess) { 422 | printf("error in correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err)); 423 | return 0; 424 | } 425 | 426 | return 1; 427 | } 428 | 429 | 430 | int correlation_backward_cuda_kernel( 431 | at::Tensor& gradOutput, 432 | int gob, 433 | int goc, 434 | int goh, 435 | int gow, 436 | int gosb, 437 | int gosc, 438 | int gosh, 439 | int gosw, 440 | 441 | at::Tensor& input1, 442 | int ic, 443 | int ih, 444 | int iw, 445 | int isb, 446 | int isc, 447 | int ish, 448 | int isw, 449 | 450 | at::Tensor& input2, 451 | int gsb, 452 | int gsc, 453 | int gsh, 454 | int gsw, 455 | 456 | at::Tensor& gradInput1, 457 | int gisb, 458 | int gisc, 459 | int gish, 460 | int gisw, 461 | 462 | at::Tensor& gradInput2, 463 | int ggc, 464 | int ggsb, 465 | int ggsc, 466 | int ggsh, 467 | int ggsw, 468 | 469 | at::Tensor& rInput1, 470 | at::Tensor& rInput2, 471 | int pad_size, 472 | int kernel_size, 473 | int max_displacement, 474 | int stride1, 475 | int stride2, 476 | int corr_type_multiply, 477 | cudaStream_t stream) 478 | { 479 | 480 | int batchSize = gob; 481 | int num = batchSize; 482 | 483 | int nInputChannels = ic; 484 | int inputWidth = iw; 485 | int inputHeight = ih; 486 | 487 | int nOutputChannels = goc; 488 | int outputWidth = gow; 489 | int outputHeight = goh; 490 | 491 | dim3 blocks_grid(batchSize, inputHeight, inputWidth); 492 | dim3 threads_block(THREADS_PER_BLOCK); 493 | 494 | 495 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "lltm_forward_cuda", ([&] { 496 | 497 | channels_first<<>>( 498 | input1.data(), 499 | rInput1.data(), 500 | nInputChannels, 501 | inputHeight, 502 | inputWidth, 503 | pad_size 504 | ); 505 | })); 506 | 507 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] { 508 | 509 | channels_first<<>>( 510 | input2.data(), 511 | rInput2.data(), 512 | nInputChannels, 513 | inputHeight, 514 | inputWidth, 515 | pad_size 516 | ); 517 | })); 518 | 519 | dim3 threadsPerBlock(THREADS_PER_BLOCK); 520 | dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels); 521 | 522 | for (int n = 0; n < num; ++n) { 523 | 524 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] { 
525 | 526 | 527 | correlation_backward_input1<<>> ( 528 | n, gradInput1.data(), nInputChannels, inputHeight, inputWidth, 529 | gradOutput.data(), nOutputChannels, outputHeight, outputWidth, 530 | rInput2.data(), 531 | pad_size, 532 | kernel_size, 533 | max_displacement, 534 | stride1, 535 | stride2); 536 | })); 537 | } 538 | 539 | for(int n = 0; n < batchSize; n++) { 540 | 541 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(rInput1.type(), "lltm_forward_cuda", ([&] { 542 | 543 | correlation_backward_input2<<>>( 544 | n, gradInput2.data(), nInputChannels, inputHeight, inputWidth, 545 | gradOutput.data(), nOutputChannels, outputHeight, outputWidth, 546 | rInput1.data(), 547 | pad_size, 548 | kernel_size, 549 | max_displacement, 550 | stride1, 551 | stride2); 552 | 553 | })); 554 | } 555 | 556 | // check for errors 557 | cudaError_t err = cudaGetLastError(); 558 | if (err != cudaSuccess) { 559 | printf("error in correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err)); 560 | return 0; 561 | } 562 | 563 | return 1; 564 | } 565 | -------------------------------------------------------------------------------- /correlation_package/correlation_cuda_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | int correlation_forward_cuda_kernel(at::Tensor& output, 8 | int ob, 9 | int oc, 10 | int oh, 11 | int ow, 12 | int osb, 13 | int osc, 14 | int osh, 15 | int osw, 16 | 17 | at::Tensor& input1, 18 | int ic, 19 | int ih, 20 | int iw, 21 | int isb, 22 | int isc, 23 | int ish, 24 | int isw, 25 | 26 | at::Tensor& input2, 27 | int gc, 28 | int gsb, 29 | int gsc, 30 | int gsh, 31 | int gsw, 32 | 33 | at::Tensor& rInput1, 34 | at::Tensor& rInput2, 35 | int pad_size, 36 | int kernel_size, 37 | int max_displacement, 38 | int stride1, 39 | int stride2, 40 | int corr_type_multiply, 41 | cudaStream_t stream); 42 | 43 | 44 | int correlation_backward_cuda_kernel( 45 | at::Tensor& gradOutput, 46 | int gob, 47 | int goc, 48 | int goh, 49 | int gow, 50 | int gosb, 51 | int gosc, 52 | int gosh, 53 | int gosw, 54 | 55 | at::Tensor& input1, 56 | int ic, 57 | int ih, 58 | int iw, 59 | int isb, 60 | int isc, 61 | int ish, 62 | int isw, 63 | 64 | at::Tensor& input2, 65 | int gsb, 66 | int gsc, 67 | int gsh, 68 | int gsw, 69 | 70 | at::Tensor& gradInput1, 71 | int gisb, 72 | int gisc, 73 | int gish, 74 | int gisw, 75 | 76 | at::Tensor& gradInput2, 77 | int ggc, 78 | int ggsb, 79 | int ggsc, 80 | int ggsh, 81 | int ggsw, 82 | 83 | at::Tensor& rInput1, 84 | at::Tensor& rInput2, 85 | int pad_size, 86 | int kernel_size, 87 | int max_displacement, 88 | int stride1, 89 | int stride2, 90 | int corr_type_multiply, 91 | cudaStream_t stream); 92 | -------------------------------------------------------------------------------- /correlation_package/nvcc setting.md: -------------------------------------------------------------------------------- 1 | # How to set the nvcc version (Default:CUDA-11.0) 2 | 3 | ## Pre-requisite 4 | 5 | Check installed `cuda-11.0` path: 6 | ```bash 7 | $ cd /usr/local 8 | $ find . -maxdepth 1 -name 'cuda-11.0' 9 | ``` 10 | - If there is no `cuda-11.0` folder in the directory, install `cuda-11.0` first. 
11 | 
12 | ## Change the terminal environment (temporary)
13 | Change your terminal `PATH` and `LD_LIBRARY_PATH`:
14 | ```bash
15 | $ export PATH=/usr/local/cuda-11.0/bin${PATH:+:${PATH}}
16 | $ export LD_LIBRARY_PATH=/usr/local/cuda-11.0/lib64
17 | $ nvcc --version
18 | ```
19 | ## Change the default terminal environment
20 | To change the default `nvcc` version of your terminal, add the two lines below to your `~/.bashrc`.
21 | ```bash
22 | $ gedit ~/.bashrc
23 | export PATH=/usr/local/cuda-11.0/bin${PATH:+:${PATH}}
24 | export LD_LIBRARY_PATH=/usr/local/cuda-11.0/lib64
25 | ```
26 | Save the file and then open a new terminal:
27 | ```bash
28 | $ nvcc --version
29 | ```
30 | 
--------------------------------------------------------------------------------
/correlation_package/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | # Minimum requirements for the build system to execute.
3 | requires = ["setuptools", "wheel", "numpy", "torch"] # PEP 508 specifications.
4 | 
--------------------------------------------------------------------------------
/correlation_package/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import torch
4 | 
5 | from setuptools import setup, find_packages
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 | 
8 | cxx_args = ['-std=c++14']
9 | 
10 | nvcc_args = [
11 |     '-gencode', 'arch=compute_50,code=sm_50',
12 |     '-gencode', 'arch=compute_52,code=sm_52',
13 |     '-gencode', 'arch=compute_60,code=sm_60',
14 |     '-gencode', 'arch=compute_61,code=sm_61',
15 |     '-gencode', 'arch=compute_70,code=sm_70',
16 |     '-gencode', 'arch=compute_70,code=compute_70'
17 | ]
18 | 
19 | setup(
20 |     name='correlation_cuda',
21 |     ext_modules=[
22 |         CUDAExtension('correlation_cuda', [
23 |             'correlation_cuda.cc',
24 |             'correlation_cuda_kernel.cu'
25 |         ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
26 |     ],
27 |     cmdclass={
28 |         'build_ext': BuildExtension
29 |     })
30 | 
--------------------------------------------------------------------------------
/figure/ERF_dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intelpro/CBMNet/248405c313ce29f87ad42873df46792c76f55d1a/figure/ERF_dataset.png
--------------------------------------------------------------------------------
/figure/Quantitative_eval_ERF_x170FPS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intelpro/CBMNet/248405c313ce29f87ad42873df46792c76f55d1a/figure/Quantitative_eval_ERF_x170FPS.png
--------------------------------------------------------------------------------
/figure/driving.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intelpro/CBMNet/248405c313ce29f87ad42873df46792c76f55d1a/figure/driving.gif
--------------------------------------------------------------------------------
/figure/flower.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intelpro/CBMNet/248405c313ce29f87ad42873df46792c76f55d1a/figure/flower.gif
--------------------------------------------------------------------------------
/figure/popcorn.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intelpro/CBMNet/248405c313ce29f87ad42873df46792c76f55d1a/figure/popcorn.gif -------------------------------------------------------------------------------- /install_correlation.sh: -------------------------------------------------------------------------------- 1 | cd correlation_package 2 | python3 setup.py install 3 | cd .. -------------------------------------------------------------------------------- /models/final_models/ours.py: -------------------------------------------------------------------------------- 1 | from models.final_models.submodules import * 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.modules import conv 6 | from correlation_package.correlation import Correlation 7 | from einops import rearrange, repeat 8 | from einops.layers.torch import Rearrange 9 | from timm.models.layers import DropPath, trunc_normal_, to_2tuple 10 | from functools import reduce, lru_cache 11 | import torch.nn.functional as tf 12 | from torch.autograd import Variable 13 | from einops import rearrange 14 | import math 15 | import numbers 16 | import collections 17 | 18 | 19 | def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isReLU=True): 20 | if isReLU: 21 | return nn.Sequential( 22 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, 23 | padding=((kernel_size - 1) * dilation) // 2, bias=True), 24 | nn.LeakyReLU(0.1, inplace=True) 25 | ) 26 | else: 27 | return nn.Sequential( 28 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, 29 | padding=((kernel_size - 1) * dilation) // 2, bias=True) 30 | ) 31 | 32 | def predict_flow(in_planes): 33 | return nn.Conv2d(in_planes,2,kernel_size=3,stride=1,padding=1,bias=True) 34 | 35 | def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): 36 | return nn.ConvTranspose2d(in_planes, out_planes, kernel_size, stride, padding, bias=True) 37 | 38 | 39 | class encoder_event_flow(nn.Module): 40 | def __init__(self, num_chs): 41 | super(encoder_event_flow, self).__init__() 42 | self.conv1 = conv_resblock_one(num_chs[0], num_chs[1], stride=1) 43 | self.conv2 = conv_resblock_one(num_chs[1], num_chs[2], stride=1) 44 | self.conv3 = conv_resblock_one(num_chs[2], num_chs[3], stride=2) 45 | self.conv4 = conv_resblock_one(num_chs[3], num_chs[4], stride=2) 46 | 47 | def forward(self, im): 48 | x = self.conv1(im) 49 | c11 = self.conv2(x) 50 | c12 = self.conv3(c11) 51 | c13 = self.conv4(c12) 52 | return c11, c12, c13 53 | 54 | 55 | class encoder_event_for_image_flow(nn.Module): 56 | def __init__(self, num_chs): 57 | super(encoder_event_for_image_flow, self).__init__() 58 | self.conv1 = conv_resblock_one(num_chs[0], num_chs[1], stride=1) 59 | self.conv2 = conv_resblock_one(num_chs[1], num_chs[2], stride=2) 60 | self.conv3 = conv_resblock_one(num_chs[2], num_chs[3], stride=2) 61 | self.conv4 = conv_resblock_one(num_chs[3], num_chs[4], stride=2) 62 | 63 | def forward(self, im): 64 | x = self.conv1(im) 65 | c11 = self.conv2(x) 66 | c12 = self.conv3(c11) 67 | c13 = self.conv4(c12) 68 | return c11, c12, c13 69 | 70 | 71 | class encoder_image_for_image_flow(nn.Module): 72 | def __init__(self, num_chs): 73 | super(encoder_image_for_image_flow, self).__init__() 74 | self.conv1 = conv_resblock_one(num_chs[0], num_chs[1], stride=1) 75 | self.conv2 = conv_resblock_one(num_chs[1], num_chs[2], stride=2) 76 | self.conv3 = conv_resblock_one(num_chs[2], num_chs[3], stride=2) 77 | self.conv4 = 
conv_resblock_one(num_chs[3], num_chs[4], stride=2) 78 | 79 | def forward(self, image): 80 | x = self.conv1(image) 81 | f1 = self.conv2(x) 82 | f2 = self.conv3(f1) 83 | f3 = self.conv4(f2) 84 | return f1, f2, f3 85 | 86 | def upsample2d(inputs, target_as, mode="bilinear"): 87 | _, _, h, w = target_as.size() 88 | return tf.interpolate(inputs, [h, w], mode=mode, align_corners=True) 89 | 90 | def upsample2d_hw(inputs, h, w, mode="bilinear"): 91 | return tf.interpolate(inputs, [h, w], mode=mode, align_corners=True) 92 | 93 | 94 | class DenseBlock(nn.Module): 95 | def __init__(self, ch_in): 96 | super(DenseBlock, self).__init__() 97 | self.conv1 = conv(ch_in, 128) 98 | self.conv2 = conv(ch_in + 128, 128) 99 | self.conv3 = conv(ch_in + 256, 96) 100 | self.conv4 = conv(ch_in + 352, 64) 101 | self.conv5 = conv(ch_in + 416, 32) 102 | self.conv_last = conv(ch_in + 448, 2, isReLU=False) 103 | 104 | def forward(self, x): 105 | x1 = torch.cat([self.conv1(x), x], dim=1) 106 | x2 = torch.cat([self.conv2(x1), x1], dim=1) 107 | x3 = torch.cat([self.conv3(x2), x2], dim=1) 108 | x4 = torch.cat([self.conv4(x3), x3], dim=1) 109 | x5 = torch.cat([self.conv5(x4), x4], dim=1) 110 | x_out = self.conv_last(x5) 111 | return x5, x_out 112 | 113 | 114 | 115 | class FlowEstimatorDense(nn.Module): 116 | def __init__(self, ch_in=64, f_channels=(128, 128, 96, 64, 32, 32), ch_out=2): 117 | super(FlowEstimatorDense, self).__init__() 118 | N = 0 119 | ind = 0 120 | N += ch_in 121 | self.conv1 = conv(N, f_channels[ind]) 122 | N += f_channels[ind] 123 | ind += 1 124 | self.conv2 = conv(N, f_channels[ind]) 125 | N += f_channels[ind] 126 | ind += 1 127 | self.conv3 = conv(N, f_channels[ind]) 128 | N += f_channels[ind] 129 | ind += 1 130 | self.conv4 = conv(N, f_channels[ind]) 131 | N += f_channels[ind] 132 | ind += 1 133 | self.conv5 = conv(N, f_channels[ind]) 134 | N += f_channels[ind] 135 | self.num_feature_channel = N 136 | ind += 1 137 | self.conv_last = conv(N, ch_out, isReLU=False) 138 | 139 | def forward(self, x): 140 | x1 = torch.cat([self.conv1(x), x], axis=1) 141 | x2 = torch.cat([self.conv2(x1), x1], axis=1) 142 | x3 = torch.cat([self.conv3(x2), x2], axis=1) 143 | x4 = torch.cat([self.conv4(x3), x3], axis=1) 144 | x5 = torch.cat([self.conv5(x4), x4], axis=1) 145 | x_out = self.conv_last(x5) 146 | return x5, x_out 147 | 148 | class Tfeat_RefineBlock(nn.Module): 149 | def __init__(self, ch_in_frame, ch_in_event, ch_in_frame_prev, prev_scale=False): 150 | super(Tfeat_RefineBlock, self).__init__() 151 | if prev_scale: 152 | nf = int((ch_in_frame*2+ch_in_event+ch_in_frame_prev)/4) 153 | else: 154 | nf = int((ch_in_frame*2+ch_in_event)/4) 155 | self.conv_refine = nn.Sequential(conv1x1(4*nf, nf), nn.ReLU(), conv3x3(nf, 2*nf), nn.ReLU(), conv_resblock_one(2*nf, ch_in_frame)) 156 | 157 | def forward(self, x): 158 | x1 = self.conv_refine(x) 159 | return x1 160 | 161 | def rescale_flow(flow, width_im, height_im): 162 | u_scale = float(width_im / flow.size(3)) 163 | v_scale = float(height_im / flow.size(2)) 164 | u, v = flow.chunk(2, dim=1) 165 | u = u_scale*u 166 | v = v_scale*v 167 | return torch.cat([u, v], dim=1) 168 | 169 | 170 | class FlowNet(nn.Module): 171 | def __init__(self, md=4, tb_debug=False): 172 | super(FlowNet, self).__init__() 173 | ## argument 174 | self.tb_debug = tb_debug 175 | # flow scale 176 | self.flow_scale = 20 177 | num_chs_frame = [3, 16, 32, 64, 96] 178 | num_chs_event = [16, 16, 32, 64, 128] 179 | num_chs_event_image = [16, 16, 16, 32, 64] 180 | ## for event-level flow 181 | self.encoder_event = 
encoder_event_flow(num_chs_event) 182 | ## for image-level flow 183 | self.encoder_image_flow = encoder_image_for_image_flow(num_chs_frame) 184 | self.encoder_image_flow_event = encoder_event_for_image_flow(num_chs_event_image) 185 | ## leaky relu 186 | self.leakyRELU = nn.LeakyReLU(0.1) 187 | ## correlation channel value 188 | self.corr = Correlation(pad_size=md, kernel_size=1, max_displacement=md, stride1=1, stride2=1, corr_multiply=1) 189 | ## correlation channel value 190 | nd = (2*md+1)**2 191 | self.corr_refinement = nn.ModuleList([DenseBlock(nd+num_chs_frame[-1]+2), 192 | DenseBlock(nd+num_chs_frame[-2]+2), 193 | DenseBlock(nd+num_chs_frame[-3]+2), 194 | DenseBlock(nd+num_chs_frame[-4]+2) 195 | ]) 196 | self.decoder_event = nn.ModuleList([conv_resblock_one(num_chs_event[-1], num_chs_event[-1]), 197 | conv_resblock_one(num_chs_event[-2]+ num_chs_event[-1]+2, num_chs_event[-2]), 198 | conv_resblock_one(num_chs_event[-3]+ num_chs_event[-2]+2, num_chs_event[-3])]) 199 | self.predict_flow = nn.ModuleList([conv3x3_leaky_relu(num_chs_event[-1], 2), 200 | conv3x3_leaky_relu(num_chs_event[-2], 2), 201 | conv3x3_leaky_relu(num_chs_event[-3], 2)]) 202 | self.conv_frame = nn.ModuleList([conv3x3_leaky_relu(num_chs_frame[-2], 32), 203 | conv3x3_leaky_relu(num_chs_frame[-3], 32)]) 204 | self.conv_frame_t = nn.ModuleList([conv3x3_leaky_relu(num_chs_frame[-2], 32), 205 | conv3x3_leaky_relu(num_chs_frame[-3], 32)]) 206 | self.flow_fusion_block = FlowEstimatorDense(32*3+4, (32, 32, 32, 16, 8), 1) 207 | self.feat_t_refinement = nn.ModuleList([Tfeat_RefineBlock(num_chs_frame[-1], num_chs_event_image[-1]*2, None, prev_scale=False), 208 | Tfeat_RefineBlock(num_chs_frame[-2], num_chs_event_image[-2]*2, num_chs_frame[-1], prev_scale=True), 209 | Tfeat_RefineBlock(num_chs_frame[-3], num_chs_event_image[-3]*2, num_chs_frame[-2], prev_scale=True), 210 | ]) 211 | 212 | 213 | def warp(self, x, flo): 214 | """ 215 | warp an image/tensor (im2) back to im1, according to the optical flow 216 | x: [B, C, H, W] (im2) 217 | flo: [B, 2, H, W] flow 218 | """ 219 | B, C, H, W = x.size() 220 | # mesh grid 221 | xx = torch.arange(0, W).view(1,-1).repeat(H,1) 222 | yy = torch.arange(0, H).view(-1,1).repeat(1,W) 223 | xx = xx.view(1,1,H,W).repeat(B,1,1,1) 224 | yy = yy.view(1,1,H,W).repeat(B,1,1,1) 225 | grid = torch.cat((xx,yy),1).float() 226 | 227 | if x.is_cuda: 228 | grid = grid.cuda() 229 | vgrid = Variable(grid) + flo 230 | 231 | # scale grid to [-1,1] 232 | vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0 233 | vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0 234 | 235 | vgrid = vgrid.permute(0,2,3,1) 236 | output = nn.functional.grid_sample(x, vgrid) 237 | mask = torch.autograd.Variable(torch.ones(x.size())).cuda() 238 | mask = nn.functional.grid_sample(mask, vgrid) 239 | mask[mask<0.9999] = 0 240 | mask[mask>0] = 1 241 | return output*mask 242 | 243 | def normalize_features(self, feature_list, normalize, center, moments_across_channels=True, moments_across_images=True): 244 | # Compute feature statistics. 
245 | statistics = collections.defaultdict(list) 246 | axes = [1, 2, 3] if moments_across_channels else [2, 3] # [b, c, h, w] 247 | for feature_image in feature_list: 248 | mean = torch.mean(feature_image, axis=axes, keepdims=True) # [b,1,1,1] or [b,c,1,1] 249 | variance = torch.var(feature_image, axis=axes, keepdims=True) # [b,1,1,1] or [b,c,1,1] 250 | statistics['mean'].append(mean) 251 | statistics['var'].append(variance) 252 | 253 | if moments_across_images: 254 | statistics['mean'] = ([torch.mean(F.stack(statistics['mean'], axis=0), axis=(0, ))] * len(feature_list)) 255 | statistics['var'] = ([torch.var(F.stack(statistics['var'], axis=0), axis=(0, ))] * len(feature_list)) 256 | 257 | statistics['std'] = [torch.sqrt(v + 1e-16) for v in statistics['var']] 258 | # Center and normalize features. 259 | if center: 260 | feature_list = [f - mean for f, mean in zip(feature_list, statistics['mean'])] 261 | if normalize: 262 | feature_list = [f / std for f, std in zip(feature_list, statistics['std'])] 263 | return feature_list 264 | 265 | def forward(self, batch): 266 | # F for frame feature 267 | # E for event feature 268 | ### encoding 269 | ## image feature 270 | # feature pyramid 271 | F0_pyramid = self.encoder_image_flow(batch['image_input0'])[::-1] 272 | F1_pyramid = self.encoder_image_flow(batch['image_input1'])[::-1] 273 | E_0t_pyramid = self.encoder_image_flow_event(batch['event_input_0t'])[::-1] 274 | E_t1_pyramid = self.encoder_image_flow_event(batch['event_input_t1'])[::-1] 275 | # encoder event 276 | E_t0_pyramid_flow = self.encoder_event(batch['event_input_t0'])[::-1] 277 | E_t1_pyramid_flow = self.encoder_event(batch['event_input_t1'])[::-1] 278 | ### decoding optical flow 279 | ## level 0 280 | flow_t0_out_dict, flow_t1_out_dict, flow_t0_dict, flow_t1_dict = [], [], [], [] 281 | ## event flow and image flow 282 | if self.tb_debug: 283 | event_flow_dict, image_flow_dict, fusion_flow_dict, mask_dict = [], [], [], [] 284 | for level, (E_t0_flow, E_t1_flow, E_0t, E_t1, F0, F1) in enumerate(zip(E_t0_pyramid_flow, E_t1_pyramid_flow, E_0t_pyramid, E_t1_pyramid, F0_pyramid, F1_pyramid)): 285 | if level==0: 286 | ## event flow generation 287 | feat_t0_ev = self.decoder_event[level](E_t0_flow) 288 | feat_t1_ev = self.decoder_event[level](E_t1_flow) 289 | flow_event_t0 = self.predict_flow[level](feat_t0_ev) 290 | flow_event_t1 = self.predict_flow[level](feat_t1_ev) 291 | ## fusion flow(scale == 0) 292 | flow_fusion_t0 = flow_event_t0 293 | flow_fusion_t1 = flow_event_t1 294 | ## t feature 295 | feat_t_in = torch.cat((F0, F1, E_0t, E_t1), dim=1) 296 | feat_t = self.feat_t_refinement[level](feat_t_in) 297 | else: 298 | ## feat t 299 | upfeat0_t = upsample2d(feat_t, F0) 300 | feat_t_in = torch.cat((upfeat0_t, F0, F1, E_0t, E_t1), dim=1) 301 | feat_t = self.feat_t_refinement[level](feat_t_in) 302 | #### event-based optical flow 303 | ## event flow generation 304 | upflow_t0 = rescale_flow(upsample2d(flow_t0_out_dict[level-1], E_t0_flow), E_t0_flow.size(3), E_t0_flow.size(2)) 305 | upflow_t1 = rescale_flow(upsample2d(flow_t1_out_dict[level-1], E_t1_flow), E_t1_flow.size(3), E_t1_flow.size(2)) 306 | # upsample feat_t0 307 | feat_t0_ev_up = upsample2d(feat_t0_ev, E_t0_flow) 308 | feat_t1_ev_up = upsample2d(feat_t1_ev, E_t1_flow) 309 | # decoder event 310 | flow_t0_ev_up = rescale_flow(upsample2d(flow_event_t0, E_t0_flow), E_t0_flow.size(3), E_t0_flow.size(2)) 311 | flow_t1_ev_up = rescale_flow(upsample2d(flow_event_t1, E_t1_flow), E_t1_flow.size(3), E_t1_flow.size(2)) 312 | feat_t0_ev = 
self.decoder_event[level](torch.cat((E_t0_flow, feat_t0_ev_up, flow_t0_ev_up ), dim=1)) 313 | feat_t1_ev = self.decoder_event[level](torch.cat((E_t1_flow, feat_t1_ev_up, flow_t1_ev_up), dim=1)) 314 | ## project flow 315 | flow_event_t0_ = self.predict_flow[level](feat_t0_ev) 316 | flow_event_t1_ = self.predict_flow[level](feat_t1_ev) 317 | ## fusion flow 318 | flow_event_t0 = flow_t0_ev_up + flow_event_t0_ 319 | flow_event_t1 = flow_t1_ev_up + flow_event_t1_ 320 | # flow rescale 321 | down_evflow_t0 = rescale_flow(upsample2d(flow_event_t0, F0), F0.size(3), F0.size(2)) 322 | down_evflow_t1 = rescale_flow(upsample2d(flow_event_t1, F1), F1.size(3), F1.size(2)) 323 | down_upflow_t0 = rescale_flow(upsample2d(flow_t0_out_dict[level-1], F0), F0.size(3), F0.size(2)) 324 | down_upflow_t1 = rescale_flow(upsample2d(flow_t1_out_dict[level-1], F1), F1.size(3), F1.size(2)) 325 | ## warping with event flow and fusion flow 326 | F0_re = self.conv_frame[level-1](F0) 327 | F0_up_warp_ev = self.warp(F0_re, self.flow_scale*down_evflow_t0) 328 | F0_up_warp_frame = self.warp(F0_re, self.flow_scale*down_upflow_t0) 329 | F1_re = self.conv_frame[level-1](F1) 330 | F1_up_warp_ev = self.warp(F1_re, self.flow_scale*down_evflow_t1) 331 | F1_up_warp_frame = self.warp(F1_re, self.flow_scale*down_upflow_t1) 332 | Ft_up = self.conv_frame_t[level-1](feat_t) 333 | ## flow fusion 334 | _, out_fusion_t0 = self.flow_fusion_block(torch.cat((F0_up_warp_ev, F0_up_warp_frame, Ft_up, down_evflow_t0, down_upflow_t0), dim=1)) 335 | _, out_fusion_t1 = self.flow_fusion_block(torch.cat((F1_up_warp_ev, F1_up_warp_frame, Ft_up, down_evflow_t1, down_upflow_t1), dim=1)) 336 | mask_t0 = upsample2d(torch.sigmoid(out_fusion_t0[:, -1, : ,:])[:, None, :, :], E_t0_flow) 337 | mask_t1 = upsample2d(torch.sigmoid(out_fusion_t1[:, -1, :, :])[:, None, :, :], E_t1_flow) 338 | flow_fusion_t0 = (1-mask_t0)*upflow_t0 + mask_t0*flow_event_t0 339 | flow_fusion_t1 = (1-mask_t1)*upflow_t1 + mask_t1*flow_event_t1 340 | ## intermediate output 341 | if self.tb_debug: 342 | event_flow_dict.append(flow_event_t0) 343 | fusion_flow_dict.append(flow_fusion_t0) 344 | image_flow_dict.append(upflow_t0) 345 | mask_dict.append(mask_t0) 346 | # flow rescale 347 | down_flow_fusion_t0 = rescale_flow(upsample2d(flow_fusion_t0, F0), F0.size(3), F0.size(2)) 348 | down_flow_fusion_t1 = rescale_flow(upsample2d(flow_fusion_t1, F1), F1.size(3), F1.size(2)) 349 | # warping with optical flow 350 | feat10 = self.warp(F0, self.flow_scale*down_flow_fusion_t0) 351 | feat11 = self.warp(F1, self.flow_scale*down_flow_fusion_t1) 352 | # feature normalization 353 | feat_t_norm, feat10_norm, feat11_norm = self.normalize_features([feat_t, feat10, feat11], normalize=True, center=True, moments_across_channels=False, moments_across_images=False) 354 | # correlation 355 | corr_t0 = self.leakyRELU(self.corr(feat_t_norm, feat10_norm)) 356 | corr_t1 = self.leakyRELU(self.corr(feat_t_norm, feat11_norm)) 357 | # correlation refienement 358 | _, res_flow_t0 = self.corr_refinement[level](torch.cat((corr_t0, feat_t, down_flow_fusion_t0), dim=1)) 359 | _, res_flow_t1 = self.corr_refinement[level](torch.cat((corr_t1, feat_t, down_flow_fusion_t1), dim=1)) 360 | # frame-based optical flow generation 361 | flow_t0_frame = down_flow_fusion_t0 + res_flow_t0 362 | flow_t1_frame = down_flow_fusion_t1 + res_flow_t1 363 | ## upsampling frame-based optical flow 364 | upflow_t0_frame = rescale_flow(upsample2d(flow_t0_frame, flow_fusion_t0), flow_fusion_t0.size(3), flow_fusion_t0.size(2)) 365 | upflow_t1_frame = 
rescale_flow(upsample2d(flow_t1_frame, flow_fusion_t1), flow_fusion_t1.size(3), flow_fusion_t1.size(2)) 366 | ### output 367 | flow_t0_out_dict.append(upflow_t0_frame) 368 | flow_t1_out_dict.append(upflow_t1_frame) 369 | flow_t0_dict.append(self.flow_scale*upflow_t0_frame) 370 | flow_t1_dict.append(self.flow_scale*upflow_t1_frame) 371 | flow_t0_dict = flow_t0_dict[::-1] 372 | flow_t1_dict = flow_t1_dict[::-1] 373 | ## final output return 374 | flow_output_dict = {} 375 | flow_output_dict['flow_t0_dict'] = flow_t0_dict 376 | flow_output_dict['flow_t1_dict'] = flow_t1_dict 377 | if self.tb_debug: 378 | flow_output_dict['event_flow_dict'] = event_flow_dict 379 | flow_output_dict['fusion_flow_dict'] = fusion_flow_dict 380 | flow_output_dict['image_flow_dict'] = image_flow_dict 381 | flow_output_dict['mask_dict'] = mask_dict 382 | return flow_output_dict 383 | 384 | 385 | class frame_encoder(nn.Module): 386 | def __init__(self, in_dims, nf): 387 | super(frame_encoder, self).__init__() 388 | self.conv0 = conv3x3_leaky_relu(in_dims, nf) 389 | self.conv1 = conv_resblock_two(nf, nf) 390 | self.conv2 = conv_resblock_two(nf, 2*nf, stride=2) 391 | self.conv3 = conv_resblock_two(2*nf, 4*nf, stride=2) 392 | 393 | def forward(self, x): 394 | x_ = self.conv0(x) 395 | f1 = self.conv1(x_) 396 | f2 = self.conv2(f1) 397 | f3 = self.conv3(f2) 398 | return [f1, f2, f3] 399 | 400 | class event_encoder(nn.Module): 401 | def __init__(self, in_dims, nf): 402 | super(event_encoder, self).__init__() 403 | self.conv0 = conv3x3_leaky_relu(in_dims, nf) 404 | self.conv1 = conv_resblock_two(nf, nf) 405 | self.conv2 = conv_resblock_two(nf, 2*nf, stride=2) 406 | self.conv3 = conv_resblock_two(2*nf, 4*nf, stride=2) 407 | 408 | def forward(self, x): 409 | x_ = self.conv0(x) 410 | f1 = self.conv1(x_) 411 | f2 = self.conv2(f1) 412 | f3 = self.conv3(f2) 413 | return [f1, f2, f3] 414 | 415 | ################################################## 416 | ################# Restormer ##################### 417 | 418 | ########################################################################## 419 | ## Layer Norm 420 | def to_3d(x): 421 | return rearrange(x, 'b c h w -> b (h w) c') 422 | 423 | def to_4d(x,h,w): 424 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) 425 | 426 | class BiasFree_LayerNorm(nn.Module): 427 | def __init__(self, normalized_shape): 428 | super(BiasFree_LayerNorm, self).__init__() 429 | if isinstance(normalized_shape, numbers.Integral): 430 | normalized_shape = (normalized_shape,) 431 | normalized_shape = torch.Size(normalized_shape) 432 | 433 | assert len(normalized_shape) == 1 434 | 435 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 436 | self.normalized_shape = normalized_shape 437 | 438 | def forward(self, x): 439 | sigma = x.var(-1, keepdim=True, unbiased=False) 440 | return x / torch.sqrt(sigma+1e-5) * self.weight 441 | 442 | 443 | class WithBias_LayerNorm(nn.Module): 444 | def __init__(self, normalized_shape): 445 | super(WithBias_LayerNorm, self).__init__() 446 | if isinstance(normalized_shape, numbers.Integral): 447 | normalized_shape = (normalized_shape,) 448 | normalized_shape = torch.Size(normalized_shape) 449 | 450 | assert len(normalized_shape) == 1 451 | 452 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 453 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 454 | self.normalized_shape = normalized_shape 455 | 456 | def forward(self, x): 457 | mu = x.mean(-1, keepdim=True) 458 | sigma = x.var(-1, keepdim=True, unbiased=False) 459 | return (x - mu) / 
torch.sqrt(sigma+1e-5) * self.weight + self.bias 460 | 461 | 462 | class LayerNorm(nn.Module): 463 | def __init__(self, dim, LayerNorm_type): 464 | super(LayerNorm, self).__init__() 465 | if LayerNorm_type =='BiasFree': 466 | self.body = BiasFree_LayerNorm(dim) 467 | else: 468 | self.body = WithBias_LayerNorm(dim) 469 | 470 | def forward(self, x): 471 | h, w = x.shape[-2:] 472 | return to_4d(self.body(to_3d(x)), h, w) 473 | 474 | ########################################################################## 475 | ## Gated-Dconv Feed-Forward Network (GDFN) 476 | class FeedForward(nn.Module): 477 | def __init__(self, dim, ffn_expansion_factor, bias): 478 | super(FeedForward, self).__init__() 479 | hidden_features = int(dim*ffn_expansion_factor) 480 | self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) 481 | self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) 482 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) 483 | 484 | def forward(self, x): 485 | x = self.project_in(x) 486 | x1, x2 = self.dwconv(x).chunk(2, dim=1) 487 | x = F.gelu(x1) * x2 488 | x = self.project_out(x) 489 | return x 490 | 491 | class CrossAttention(nn.Module): 492 | def __init__(self, dim, num_heads, bias): 493 | super(CrossAttention, self).__init__() 494 | self.num_heads = num_heads 495 | self.temperature1 = nn.Parameter(torch.ones(num_heads, 1, 1)) 496 | self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1)) 497 | 498 | self.q = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias) 499 | self.kv1 = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias) 500 | self.kv2 = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias) 501 | self.q_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias) 502 | self.kv1_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim*2, bias=bias) 503 | self.kv2_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim*2, bias=bias) 504 | self.project_out = nn.Conv2d(dim*2, dim, kernel_size=1, bias=bias) 505 | 506 | def forward(self, x, attn_kv1, attn_kv2): 507 | b,c,h,w = x.shape 508 | 509 | q_ = self.q_dwconv(self.q(x)) 510 | kv1 = self.kv1_dwconv(self.kv1(attn_kv1)) 511 | kv2 = self.kv2_dwconv(self.kv2(attn_kv2)) 512 | q1,q2 = q_.chunk(2, dim=1) 513 | k1,v1 = kv1.chunk(2, dim=1) 514 | k2,v2 = kv2.chunk(2, dim=1) 515 | 516 | q1 = rearrange(q1, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 517 | q2 = rearrange(q2, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 518 | k1 = rearrange(k1, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 519 | v1 = rearrange(v1, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 520 | k2 = rearrange(k2, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 521 | v2 = rearrange(v2, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 522 | 523 | q1 = torch.nn.functional.normalize(q1, dim=-1) 524 | q2 = torch.nn.functional.normalize(q2, dim=-1) 525 | k1 = torch.nn.functional.normalize(k1, dim=-1) 526 | k2 = torch.nn.functional.normalize(k2, dim=-1) 527 | 528 | attn = (q1 @ k1.transpose(-2, -1)) * self.temperature1 529 | attn = attn.softmax(dim=-1) 530 | out1 = (attn @ v1) 531 | 532 | attn = (q2 @ k2.transpose(-2, -1)) * self.temperature2 533 | attn = attn.softmax(dim=-1) 534 | out2 = (attn @ v2) 535 | 536 | out1 = rearrange(out1, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 537 | out2 = rearrange(out2, 
'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 538 | out = torch.cat((out1, out2), dim=1) 539 | out = self.project_out(out) 540 | return out 541 | 542 | 543 | ########################################################################## 544 | class CrossTransformerBlock(nn.Module): 545 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): 546 | super(CrossTransformerBlock, self).__init__() 547 | self.norm1 = LayerNorm(dim, LayerNorm_type) 548 | self.norm_kv1 = LayerNorm(dim, LayerNorm_type) 549 | self.norm_kv2 = LayerNorm(dim, LayerNorm_type) 550 | self.attn = CrossAttention(dim, num_heads, bias) 551 | self.norm2 = LayerNorm(dim, LayerNorm_type) 552 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias) 553 | 554 | def forward(self, x, attn_kv1, attn_kv2): 555 | x = x + self.attn(self.norm1(x), self.norm_kv1(attn_kv1), self.norm_kv2(attn_kv2)) 556 | x = x + self.ffn(self.norm2(x)) 557 | return x 558 | 559 | class CrossTransformerLayer(nn.Module): 560 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, num_blocks): 561 | super(CrossTransformerLayer, self).__init__() 562 | self.blocks = nn.ModuleList([CrossTransformerBlock(dim=dim, num_heads=num_heads, 563 | ffn_expansion_factor=ffn_expansion_factor, bias=bias, 564 | LayerNorm_type=LayerNorm_type) for i in range(num_blocks)]) 565 | 566 | def forward(self, x, attn_kv=None, attn_kv2=None): 567 | for blk in self.blocks: 568 | x = blk(x, attn_kv, attn_kv2) 569 | return x 570 | 571 | 572 | ########################################################################## 573 | ## Multi-DConv Head Transposed Self-Attention (MDTA) 574 | class Attention(nn.Module): 575 | def __init__(self, dim, num_heads, bias): 576 | super(Attention, self).__init__() 577 | self.num_heads = num_heads 578 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) 579 | 580 | self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) 581 | self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) 582 | self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) 583 | 584 | 585 | 586 | def forward(self, x): 587 | b,c,h,w = x.shape 588 | 589 | qkv = self.qkv_dwconv(self.qkv(x)) 590 | q,k,v = qkv.chunk(3, dim=1) 591 | 592 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 593 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 594 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 595 | 596 | q = torch.nn.functional.normalize(q, dim=-1) 597 | k = torch.nn.functional.normalize(k, dim=-1) 598 | 599 | attn = (q @ k.transpose(-2, -1)) * self.temperature 600 | attn = attn.softmax(dim=-1) 601 | 602 | out = (attn @ v) 603 | 604 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 605 | 606 | out = self.project_out(out) 607 | return out 608 | 609 | 610 | class Self_attention(nn.Module): 611 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): 612 | super(Self_attention, self).__init__() 613 | self.norm1 = LayerNorm(dim, LayerNorm_type) 614 | self.attn = Attention(dim, num_heads, bias) 615 | self.norm2 = LayerNorm(dim, LayerNorm_type) 616 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias) 617 | 618 | def forward(self, x): 619 | x = x + self.attn(self.norm1(x)) 620 | x = x + self.ffn(self.norm2(x)) 621 | return x 622 | 623 | 624 | class SelfAttentionLayer(nn.Module): 625 | def __init__(self, dim, num_heads, 
ffn_expansion_factor, bias, LayerNorm_type, num_blocks): 626 | super(SelfAttentionLayer, self).__init__() 627 | self.blocks = nn.ModuleList([Self_attention(dim=dim, num_heads=num_heads, ffn_expansion_factor=ffn_expansion_factor, bias=bias, 628 | LayerNorm_type=LayerNorm_type) for i in range(num_blocks)]) 629 | 630 | def forward(self, x): 631 | for blk in self.blocks: 632 | x = blk(x) 633 | return x 634 | 635 | 636 | class Upsample(nn.Module): 637 | def __init__(self, in_channel, out_channel): 638 | super(Upsample, self).__init__() 639 | self.deconv = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2) 640 | 641 | def forward(self, x): 642 | out = self.deconv(x) 643 | return out 644 | 645 | def flops(self, H, W): 646 | flops = 0 647 | # conv 648 | flops += H*2*W*2*self.in_channel*self.out_channel*2*2 649 | print("Upsample:{%.2f}"%(flops/1e9)) 650 | return flops 651 | 652 | 653 | class Transformer(nn.Module): 654 | def __init__(self, unit_dim): 655 | super(Transformer, self).__init__() 656 | ## init qurey networks 657 | self.init_qurey_net(unit_dim) 658 | self.init_decoder(unit_dim) 659 | ## last conv 660 | self.last_conv0 = conv3x3(unit_dim*4, 3) 661 | self.last_conv1 = conv3x3(unit_dim*2, 3) 662 | self.last_conv2 = conv3x3(unit_dim, 3) 663 | 664 | def init_decoder(self, unit_dim): 665 | ### decoder 666 | ### attention k,v building (synthesis) 667 | self.build_kv0_syn = conv3x3_leaky_relu(unit_dim*3, unit_dim*4) 668 | self.build_kv1_syn = conv3x3_leaky_relu(int(unit_dim*1.5), unit_dim*2) 669 | self.build_kv2_syn = conv3x3_leaky_relu(int(unit_dim*0.75), unit_dim) 670 | ### attention k, v building (warping) 671 | self.build_kv0_warp = conv3x3_leaky_relu(unit_dim*3+6, unit_dim*4) 672 | self.build_kv1_warp = conv3x3_leaky_relu(int(unit_dim*1.5)+6, unit_dim*2) 673 | self.build_kv2_warp = conv3x3_leaky_relu(int(unit_dim*0.75)+6, unit_dim) 674 | ## level 1 675 | self.decoder1_1 = CrossTransformerLayer(dim=unit_dim*4, num_heads=4, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2) 676 | self.decoder1_2 = SelfAttentionLayer(dim=unit_dim*4, num_heads=4, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2) 677 | ## level 2 678 | self.decoder2_1 = CrossTransformerLayer(dim=unit_dim*2, num_heads=4, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2) 679 | self.decoder2_2 = SelfAttentionLayer(dim=unit_dim*2, num_heads=2, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2) 680 | ## level 3 681 | self.decoder3_1 = CrossTransformerLayer(dim=unit_dim, num_heads=4, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2) 682 | self.decoder3_2 = SelfAttentionLayer(dim=unit_dim, num_heads=1, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2) 683 | ## upsample 684 | self.upsample0 = Upsample(unit_dim*4, unit_dim*2) 685 | self.upsample1 = Upsample(unit_dim*2, unit_dim) 686 | ## conv after body 687 | self.conv_after_body0 = conv_resblock_one(4*unit_dim, 2*unit_dim) 688 | self.conv_after_body1 = conv_resblock_one(2*unit_dim, unit_dim) 689 | 690 | ### qurey network 691 | def init_qurey_net(self, unit_dim): 692 | ### building query 693 | ## stage 1 694 | self.enc_conv0 = conv3x3_leaky_relu(unit_dim+6, unit_dim) 695 | ## stage 2 696 | self.enc_conv1 = conv3x3_leaky_relu(unit_dim, 2*unit_dim, stride=2) 697 | ## stage 3 698 | self.enc_conv2 = conv3x3_leaky_relu(2*unit_dim, 4*unit_dim, stride=2) 699 | 700 | ## query buiding !! 
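# build_qurey (below) concatenates the finest-scale event, frame, and warped
# features, then applies two stride-2 convolutions, producing a three-level query
# pyramid [Q0, Q1, Q2] that forward_decoder consumes from coarsest (Q2) to finest (Q0).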
701 | def build_qurey(self, event_feature, frame_feature, warped_feature): 702 | cat_in0 = torch.cat((event_feature[0], frame_feature[0], warped_feature[0]), dim=1) 703 | Q0 = self.enc_conv0(cat_in0) 704 | Q1 = self.enc_conv1(Q0) 705 | Q2 = self.enc_conv2(Q1) 706 | return [Q0, Q1, Q2] 707 | 708 | def forward_decoder(self, Q_list, warped_feature, frame_feature, event_feature): 709 | ## syntheis kv building 710 | cat_in0_syn = torch.cat((frame_feature[2], event_feature[2]), dim=1) 711 | attn_kv0_syn = self.build_kv0_syn(cat_in0_syn) 712 | cat_in1_syn = torch.cat((frame_feature[1], event_feature[1]), dim=1) 713 | attn_kv1_syn = self.build_kv1_syn(cat_in1_syn) 714 | cat_in2_syn = torch.cat((frame_feature[0], event_feature[0]), dim=1) 715 | attn_kv2_syn = self.build_kv2_syn(cat_in2_syn) 716 | ## warping kv building 717 | cat_in0_warp = torch.cat((warped_feature[2], event_feature[2]), dim=1) 718 | attn_kv0_warp = self.build_kv0_warp(cat_in0_warp) 719 | cat_in1_warp = torch.cat((warped_feature[1], event_feature[1]), dim=1) 720 | attn_kv1_warp = self.build_kv1_warp(cat_in1_warp) 721 | cat_in2_warp = torch.cat((warped_feature[0], event_feature[0]), dim=1) 722 | attn_kv2_warp = self.build_kv2_warp(cat_in2_warp) 723 | ## out 0 724 | _Q0 = Q_list[2] 725 | out0 = self.decoder1_1(_Q0, attn_kv0_syn, attn_kv0_warp) 726 | out0 = self.decoder1_2(out0) 727 | up_out0 = self.upsample0(out0) 728 | ## out 1 729 | _Q1 = Q_list[1] 730 | _Q1 = self.conv_after_body0(torch.cat((_Q1, up_out0), dim=1)) 731 | out1 = self.decoder2_1(_Q1, attn_kv1_syn, attn_kv1_warp) 732 | out1 = self.decoder2_2(out1) 733 | up_out1 = self.upsample1(out1) 734 | ## out2 735 | _Q2 = Q_list[0] 736 | _Q2 = self.conv_after_body1(torch.cat((_Q2, up_out1), dim=1)) 737 | out2 = self.decoder3_1(_Q2, attn_kv2_syn, attn_kv2_warp) 738 | out2 = self.decoder3_2(out2) 739 | return [out0, out1, out2] 740 | 741 | def forward(self, event_feature, frame_feature, warped_feature): 742 | ### forward encoder 743 | Q_list = self.build_qurey(event_feature, frame_feature, warped_feature) 744 | ### forward decoder 745 | out_decoder = self.forward_decoder(Q_list, warped_feature, frame_feature, event_feature) 746 | ### synthesis frame 747 | img0 = self.last_conv0(out_decoder[0]) 748 | img1 = self.last_conv1(out_decoder[1]) 749 | img2 = self.last_conv2(out_decoder[2]) 750 | return [img2, img1, img0] 751 | 752 | 753 | 754 | 755 | class EventInterpNet(nn.Module): 756 | def __init__(self, num_bins=16, flow_debug=False): 757 | super(EventInterpNet, self).__init__() 758 | unit_dim = 32 759 | # scale 760 | self.scale = 3 761 | # flownet 762 | self.flownet = FlowNet(md=4, tb_debug=flow_debug) 763 | self.flow_debug = flow_debug 764 | # encoder 765 | self.encoder_f = frame_encoder(3, unit_dim//4) 766 | self.encoder_e = event_encoder(16, unit_dim//2) 767 | # decoder 768 | self.transformer = Transformer(unit_dim*2) 769 | # channel scaling convolution 770 | self.conv_list = nn.ModuleList([conv1x1(unit_dim, unit_dim), conv1x1(unit_dim, unit_dim), conv1x1(unit_dim, unit_dim)]) 771 | # mode information 772 | self.mode = 'flow' 773 | 774 | def set_mode(self, mode): 775 | self.mode = mode 776 | 777 | def bwarp(self, x, flo): 778 | ''' 779 | x shape : [B,C,T,H,W] 780 | t_value shape : [B,1] ############### 781 | ''' 782 | B, C, H, W = x.size() 783 | # mesh grid 784 | xx = torch.arange(0, W).view(1, 1, 1, W).expand(B, 1, H, W) 785 | yy = torch.arange(0, H).view(1, 1, H, 1).expand(B, 1, H, W) 786 | grid = torch.cat((xx, yy), 1).float() 787 | 788 | if x.is_cuda: 789 | grid = grid.cuda() 
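# The remaining steps implement standard backward warping: the flow is added to the
# base pixel grid, the coordinates are rescaled to grid_sample's [-1, 1] range, and
# an all-ones mask is warped alongside to zero out samples that fall outside the image.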
790 | vgrid = torch.autograd.Variable(grid) + flo 791 | 792 | # scale grid to [-1,1] 793 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0 794 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0 795 | 796 | vgrid = vgrid.permute(0, 2, 3, 1) # [B,H,W,2] 797 | output = nn.functional.grid_sample(x, vgrid, align_corners=True) 798 | mask = torch.autograd.Variable(torch.ones(x.size())).cuda() 799 | mask = nn.functional.grid_sample(mask, vgrid, align_corners=True) 800 | # mask[mask<0.9999] = 0 801 | # mask[mask>0] = 1 802 | mask = mask.masked_fill_(mask < 0.999, 0) 803 | mask = mask.masked_fill_(mask > 0, 1) 804 | return output * mask 805 | 806 | def Flow_pyramid(self, flow): 807 | flow_pyr = [] 808 | flow_pyr.append(flow) 809 | for i in range(1, 3): 810 | flow_pyr.append(F.interpolate(flow, scale_factor=0.5 ** i, mode='bilinear') * (0.5 ** i)) 811 | return flow_pyr 812 | 813 | def Img_pyramid(self, Img): 814 | img_pyr = [] 815 | img_pyr.append(Img) 816 | for i in range(1, 3): 817 | img_pyr.append(F.interpolate(Img, scale_factor=0.5 ** i, mode='bilinear')) 818 | return img_pyr 819 | 820 | def synthesis(self, batch, OF_t0, OF_t1): 821 | ## frame encoding 822 | f_frame0 = self.encoder_f(batch['image_input0']) 823 | f_frame1 = self.encoder_f(batch['image_input1']) 824 | ## OF pyramid 825 | OF_t0_pyramid = self.Flow_pyramid(OF_t0[0]) 826 | OF_t1_pyramid = self.Flow_pyramid(OF_t1[0]) 827 | ## image pyramid 828 | I0_pyramid = self.Img_pyramid(batch['image_input0']) 829 | I1_pyramid = self.Img_pyramid(batch['image_input1']) 830 | # frame0_warped, frame1_warped = [], [] 831 | warped_feature, frame_feature = [], [] 832 | for idx in range(self.scale): 833 | frame0_warped = self.bwarp(torch.cat((f_frame0[idx], I0_pyramid[idx]),dim=1), OF_t0_pyramid[idx]) 834 | frame1_warped = self.bwarp(torch.cat((f_frame1[idx], I1_pyramid[idx]),dim=1), OF_t1_pyramid[idx]) 835 | warped_feature.append(torch.cat((frame0_warped, frame1_warped), dim=1)) 836 | frame_feature.append(torch.cat((f_frame0[idx], f_frame1[idx]), dim=1)) 837 | event_feature = [] 838 | # event encoding for frame interpolation 839 | f_event_0t = self.encoder_e(batch['event_input_0t']) 840 | f_event_t1 = self.encoder_e(batch['event_input_t1']) 841 | for idx in range(self.scale): 842 | event_feature.append(torch.cat((f_event_0t[idx], f_event_t1[idx]), dim=1)) 843 | img_out = self.transformer(event_feature, frame_feature, warped_feature) 844 | interp_out = [] 845 | for i in range(self.scale): 846 | interp_out.append(torch.clamp(img_out[i], 0, 1)) 847 | return interp_out 848 | 849 | def forward(self, batch): 850 | output_dict = {} 851 | # --- Flow-only mode --- 852 | if self.mode == 'flow': 853 | output_dict['flow_out'] = self.flownet(batch) 854 | # --- Joint mode: --- 855 | elif self.mode == 'joint': 856 | flow_out = self.flownet(batch) 857 | interp_out = self.synthesis(batch, flow_out['flow_t0_dict'], flow_out['flow_t1_dict']) 858 | output_dict.update({'flow_out': flow_out, 'interp_out': interp_out}) 859 | else: 860 | raise ValueError(f"Unsupported mode: {self.mode}") 861 | return output_dict -------------------------------------------------------------------------------- /models/final_models/ours_large.py: -------------------------------------------------------------------------------- 1 | from models.final_models.submodules import * 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.modules import conv 6 | from correlation_package.correlation import Correlation 7 | from 
einops import rearrange, repeat 8 | from einops.layers.torch import Rearrange 9 | from timm.models.layers import DropPath, trunc_normal_, to_2tuple 10 | from functools import reduce, lru_cache 11 | import torch.nn.functional as tf 12 | from torch.autograd import Variable 13 | from einops import rearrange 14 | import math 15 | import numbers 16 | import collections 17 | 18 | 19 | def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isReLU=True): 20 | if isReLU: 21 | return nn.Sequential( 22 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, 23 | padding=((kernel_size - 1) * dilation) // 2, bias=True), 24 | nn.LeakyReLU(0.1, inplace=True) 25 | ) 26 | else: 27 | return nn.Sequential( 28 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation, 29 | padding=((kernel_size - 1) * dilation) // 2, bias=True) 30 | ) 31 | 32 | def predict_flow(in_planes): 33 | return nn.Conv2d(in_planes,2,kernel_size=3,stride=1,padding=1,bias=True) 34 | 35 | def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): 36 | return nn.ConvTranspose2d(in_planes, out_planes, kernel_size, stride, padding, bias=True) 37 | 38 | 39 | class encoder_event_flow(nn.Module): 40 | def __init__(self, num_chs): 41 | super(encoder_event_flow, self).__init__() 42 | self.conv1 = conv_resblock_one(num_chs[0], num_chs[1], stride=1) 43 | self.conv2 = conv_resblock_one(num_chs[1], num_chs[2], stride=1) 44 | self.conv3 = conv_resblock_one(num_chs[2], num_chs[3], stride=2) 45 | self.conv4 = conv_resblock_one(num_chs[3], num_chs[4], stride=2) 46 | 47 | def forward(self, im): 48 | x = self.conv1(im) 49 | c11 = self.conv2(x) 50 | c12 = self.conv3(c11) 51 | c13 = self.conv4(c12) 52 | return c11, c12, c13 53 | 54 | 55 | class encoder_event_for_image_flow(nn.Module): 56 | def __init__(self, num_chs): 57 | super(encoder_event_for_image_flow, self).__init__() 58 | self.conv1 = conv_resblock_one(num_chs[0], num_chs[1], stride=1) 59 | self.conv2 = conv_resblock_one(num_chs[1], num_chs[2], stride=2) 60 | self.conv3 = conv_resblock_one(num_chs[2], num_chs[3], stride=2) 61 | self.conv4 = conv_resblock_one(num_chs[3], num_chs[4], stride=2) 62 | 63 | def forward(self, im): 64 | x = self.conv1(im) 65 | c11 = self.conv2(x) 66 | c12 = self.conv3(c11) 67 | c13 = self.conv4(c12) 68 | return c11, c12, c13 69 | 70 | 71 | class encoder_image_for_image_flow(nn.Module): 72 | def __init__(self, num_chs): 73 | super(encoder_image_for_image_flow, self).__init__() 74 | self.conv1 = conv_resblock_one(num_chs[0], num_chs[1], stride=1) 75 | self.conv2 = conv_resblock_one(num_chs[1], num_chs[2], stride=2) 76 | self.conv3 = conv_resblock_one(num_chs[2], num_chs[3], stride=2) 77 | self.conv4 = conv_resblock_one(num_chs[3], num_chs[4], stride=2) 78 | 79 | def forward(self, image): 80 | x = self.conv1(image) 81 | f1 = self.conv2(x) 82 | f2 = self.conv3(f1) 83 | f3 = self.conv4(f2) 84 | return f1, f2, f3 85 | 86 | def upsample2d(inputs, target_as, mode="bilinear"): 87 | _, _, h, w = target_as.size() 88 | return tf.interpolate(inputs, [h, w], mode=mode, align_corners=True) 89 | 90 | def upsample2d_hw(inputs, h, w, mode="bilinear"): 91 | return tf.interpolate(inputs, [h, w], mode=mode, align_corners=True) 92 | 93 | 94 | class DenseBlock(nn.Module): 95 | def __init__(self, ch_in): 96 | super(DenseBlock, self).__init__() 97 | self.conv1 = conv(ch_in, 128) 98 | self.conv2 = conv(ch_in + 128, 128) 99 | self.conv3 = conv(ch_in + 256, 96) 100 | self.conv4 = conv(ch_in + 352, 64) 101 
| self.conv5 = conv(ch_in + 416, 32) 102 | self.conv_last = conv(ch_in + 448, 2, isReLU=False) 103 | 104 | def forward(self, x): 105 | x1 = torch.cat([self.conv1(x), x], dim=1) 106 | x2 = torch.cat([self.conv2(x1), x1], dim=1) 107 | x3 = torch.cat([self.conv3(x2), x2], dim=1) 108 | x4 = torch.cat([self.conv4(x3), x3], dim=1) 109 | x5 = torch.cat([self.conv5(x4), x4], dim=1) 110 | x_out = self.conv_last(x5) 111 | return x5, x_out 112 | 113 | 114 | 115 | class FlowEstimatorDense(nn.Module): 116 | def __init__(self, ch_in=64, f_channels=(128, 128, 96, 64, 32, 32), ch_out=2): 117 | super(FlowEstimatorDense, self).__init__() 118 | N = 0 119 | ind = 0 120 | N += ch_in 121 | self.conv1 = conv(N, f_channels[ind]) 122 | N += f_channels[ind] 123 | ind += 1 124 | self.conv2 = conv(N, f_channels[ind]) 125 | N += f_channels[ind] 126 | ind += 1 127 | self.conv3 = conv(N, f_channels[ind]) 128 | N += f_channels[ind] 129 | ind += 1 130 | self.conv4 = conv(N, f_channels[ind]) 131 | N += f_channels[ind] 132 | ind += 1 133 | self.conv5 = conv(N, f_channels[ind]) 134 | N += f_channels[ind] 135 | self.num_feature_channel = N 136 | ind += 1 137 | self.conv_last = conv(N, ch_out, isReLU=False) 138 | 139 | def forward(self, x): 140 | x1 = torch.cat([self.conv1(x), x], axis=1) 141 | x2 = torch.cat([self.conv2(x1), x1], axis=1) 142 | x3 = torch.cat([self.conv3(x2), x2], axis=1) 143 | x4 = torch.cat([self.conv4(x3), x3], axis=1) 144 | x5 = torch.cat([self.conv5(x4), x4], axis=1) 145 | x_out = self.conv_last(x5) 146 | return x5, x_out 147 | 148 | class Tfeat_RefineBlock(nn.Module): 149 | def __init__(self, ch_in_frame, ch_in_event, ch_in_frame_prev, prev_scale=False): 150 | super(Tfeat_RefineBlock, self).__init__() 151 | if prev_scale: 152 | nf = int((ch_in_frame*2+ch_in_event+ch_in_frame_prev)/4) 153 | else: 154 | nf = int((ch_in_frame*2+ch_in_event)/4) 155 | self.conv_refine = nn.Sequential(conv1x1(4*nf, nf), nn.ReLU(), conv3x3(nf, 2*nf), nn.ReLU(), conv_resblock_one(2*nf, ch_in_frame)) 156 | 157 | def forward(self, x): 158 | x1 = self.conv_refine(x) 159 | return x1 160 | 161 | def rescale_flow(flow, width_im, height_im): 162 | u_scale = float(width_im / flow.size(3)) 163 | v_scale = float(height_im / flow.size(2)) 164 | u, v = flow.chunk(2, dim=1) 165 | u = u_scale*u 166 | v = v_scale*v 167 | return torch.cat([u, v], dim=1) 168 | 169 | 170 | class FlowNet(nn.Module): 171 | def __init__(self, md=4, tb_debug=False): 172 | super(FlowNet, self).__init__() 173 | ## argument 174 | self.tb_debug = tb_debug 175 | # flow scale 176 | self.flow_scale = 20 177 | num_chs_frame = [3, 16, 32, 64, 96] 178 | num_chs_event = [16, 16, 32, 64, 128] 179 | num_chs_event_image = [16, 16, 16, 32, 64] 180 | ## for event-level flow 181 | self.encoder_event = encoder_event_flow(num_chs_event) 182 | ## for image-level flow 183 | self.encoder_image_flow = encoder_image_for_image_flow(num_chs_frame) 184 | self.encoder_image_flow_event = encoder_event_for_image_flow(num_chs_event_image) 185 | ## leaky relu 186 | self.leakyRELU = nn.LeakyReLU(0.1) 187 | ## correlation channel value 188 | self.corr = Correlation(pad_size=md, kernel_size=1, max_displacement=md, stride1=1, stride2=1, corr_multiply=1) 189 | ## correlation channel value 190 | nd = (2*md+1)**2 191 | self.corr_refinement = nn.ModuleList([DenseBlock(nd+num_chs_frame[-1]+2), 192 | DenseBlock(nd+num_chs_frame[-2]+2), 193 | DenseBlock(nd+num_chs_frame[-3]+2), 194 | DenseBlock(nd+num_chs_frame[-4]+2) 195 | ]) 196 | self.decoder_event = nn.ModuleList([conv_resblock_one(num_chs_event[-1], 
num_chs_event[-1]), 197 | conv_resblock_one(num_chs_event[-2]+ num_chs_event[-1]+2, num_chs_event[-2]), 198 | conv_resblock_one(num_chs_event[-3]+ num_chs_event[-2]+2, num_chs_event[-3])]) 199 | self.predict_flow = nn.ModuleList([conv3x3_leaky_relu(num_chs_event[-1], 2), 200 | conv3x3_leaky_relu(num_chs_event[-2], 2), 201 | conv3x3_leaky_relu(num_chs_event[-3], 2)]) 202 | self.conv_frame = nn.ModuleList([conv3x3_leaky_relu(num_chs_frame[-2], 32), 203 | conv3x3_leaky_relu(num_chs_frame[-3], 32)]) 204 | self.conv_frame_t = nn.ModuleList([conv3x3_leaky_relu(num_chs_frame[-2], 32), 205 | conv3x3_leaky_relu(num_chs_frame[-3], 32)]) 206 | self.flow_fusion_block = FlowEstimatorDense(32*3+4, (32, 32, 32, 16, 8), 1) 207 | self.feat_t_refinement = nn.ModuleList([Tfeat_RefineBlock(num_chs_frame[-1], num_chs_event_image[-1]*2, None, prev_scale=False), 208 | Tfeat_RefineBlock(num_chs_frame[-2], num_chs_event_image[-2]*2, num_chs_frame[-1], prev_scale=True), 209 | Tfeat_RefineBlock(num_chs_frame[-3], num_chs_event_image[-3]*2, num_chs_frame[-2], prev_scale=True), 210 | ]) 211 | 212 | 213 | def warp(self, x, flo): 214 | """ 215 | warp an image/tensor (im2) back to im1, according to the optical flow 216 | x: [B, C, H, W] (im2) 217 | flo: [B, 2, H, W] flow 218 | """ 219 | B, C, H, W = x.size() 220 | # mesh grid 221 | xx = torch.arange(0, W).view(1,-1).repeat(H,1) 222 | yy = torch.arange(0, H).view(-1,1).repeat(1,W) 223 | xx = xx.view(1,1,H,W).repeat(B,1,1,1) 224 | yy = yy.view(1,1,H,W).repeat(B,1,1,1) 225 | grid = torch.cat((xx,yy),1).float() 226 | 227 | if x.is_cuda: 228 | grid = grid.cuda() 229 | vgrid = Variable(grid) + flo 230 | 231 | # scale grid to [-1,1] 232 | vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0 233 | vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0 234 | 235 | vgrid = vgrid.permute(0,2,3,1) 236 | output = nn.functional.grid_sample(x, vgrid) 237 | mask = torch.autograd.Variable(torch.ones(x.size())).cuda() 238 | mask = nn.functional.grid_sample(mask, vgrid) 239 | mask[mask<0.9999] = 0 240 | mask[mask>0] = 1 241 | return output*mask 242 | 243 | def normalize_features(self, feature_list, normalize, center, moments_across_channels=True, moments_across_images=True): 244 | # Compute feature statistics. 245 | statistics = collections.defaultdict(list) 246 | axes = [1, 2, 3] if moments_across_channels else [2, 3] # [b, c, h, w] 247 | for feature_image in feature_list: 248 | mean = torch.mean(feature_image, axis=axes, keepdims=True) # [b,1,1,1] or [b,c,1,1] 249 | variance = torch.var(feature_image, axis=axes, keepdims=True) # [b,1,1,1] or [b,c,1,1] 250 | statistics['mean'].append(mean) 251 | statistics['var'].append(variance) 252 | 253 | if moments_across_images: 254 | statistics['mean'] = ([torch.mean(F.stack(statistics['mean'], axis=0), axis=(0, ))] * len(feature_list)) 255 | statistics['var'] = ([torch.var(F.stack(statistics['var'], axis=0), axis=(0, ))] * len(feature_list)) 256 | 257 | statistics['std'] = [torch.sqrt(v + 1e-16) for v in statistics['var']] 258 | 259 | # Center and normalize features. 
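# With moments_across_channels=False and moments_across_images=False (as used in
# forward below), each feature map is standardized with its own per-sample,
# per-channel mean and standard deviation.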
260 | if center: 261 | feature_list = [f - mean for f, mean in zip(feature_list, statistics['mean'])] 262 | if normalize: 263 | feature_list = [f / std for f, std in zip(feature_list, statistics['std'])] 264 | return feature_list 265 | 266 | def forward(self, batch): 267 | # F for frame feature 268 | # E for event feature 269 | ### encoding 270 | ## image feature 271 | # feature pyramid 272 | F0_pyramid = self.encoder_image_flow(batch['image_input0'])[::-1] 273 | F1_pyramid = self.encoder_image_flow(batch['image_input1'])[::-1] 274 | E_0t_pyramid = self.encoder_image_flow_event(batch['event_input_0t'])[::-1] 275 | E_t1_pyramid = self.encoder_image_flow_event(batch['event_input_t1'])[::-1] 276 | # encoder event 277 | E_t0_pyramid_flow = self.encoder_event(batch['event_input_t0'])[::-1] 278 | E_t1_pyramid_flow = self.encoder_event(batch['event_input_t1'])[::-1] 279 | ### decoding optical flow 280 | ## level 0 281 | flow_t0_out_dict, flow_t1_out_dict, flow_t0_dict, flow_t1_dict = [], [], [], [] 282 | ## event flow and image flow 283 | if self.tb_debug: 284 | event_flow_dict, image_flow_dict, fusion_flow_dict, mask_dict = [], [], [], [] 285 | for level, (E_t0_flow, E_t1_flow, E_0t, E_t1, F0, F1) in enumerate(zip(E_t0_pyramid_flow, E_t1_pyramid_flow, E_0t_pyramid, E_t1_pyramid, F0_pyramid, F1_pyramid)): 286 | if level==0: 287 | ## event flow generation 288 | feat_t0_ev = self.decoder_event[level](E_t0_flow) 289 | feat_t1_ev = self.decoder_event[level](E_t1_flow) 290 | flow_event_t0 = self.predict_flow[level](feat_t0_ev) 291 | flow_event_t1 = self.predict_flow[level](feat_t1_ev) 292 | ## fusion flow(scale == 0) 293 | flow_fusion_t0 = flow_event_t0 294 | flow_fusion_t1 = flow_event_t1 295 | ## t feature 296 | feat_t_in = torch.cat((F0, F1, E_0t, E_t1), dim=1) 297 | feat_t = self.feat_t_refinement[level](feat_t_in) 298 | else: 299 | ## feat t 300 | upfeat0_t = upsample2d(feat_t, F0) 301 | feat_t_in = torch.cat((upfeat0_t, F0, F1, E_0t, E_t1), dim=1) 302 | feat_t = self.feat_t_refinement[level](feat_t_in) 303 | #### event-based optical flow 304 | ## event flow generation 305 | upflow_t0 = rescale_flow(upsample2d(flow_t0_out_dict[level-1], E_t0_flow), E_t0_flow.size(3), E_t0_flow.size(2)) 306 | upflow_t1 = rescale_flow(upsample2d(flow_t1_out_dict[level-1], E_t1_flow), E_t1_flow.size(3), E_t1_flow.size(2)) 307 | # upsample feat_t0 308 | feat_t0_ev_up = upsample2d(feat_t0_ev, E_t0_flow) 309 | feat_t1_ev_up = upsample2d(feat_t1_ev, E_t1_flow) 310 | # decoder event 311 | flow_t0_ev_up = rescale_flow(upsample2d(flow_event_t0, E_t0_flow), E_t0_flow.size(3), E_t0_flow.size(2)) 312 | flow_t1_ev_up = rescale_flow(upsample2d(flow_event_t1, E_t1_flow), E_t1_flow.size(3), E_t1_flow.size(2)) 313 | feat_t0_ev = self.decoder_event[level](torch.cat((E_t0_flow, feat_t0_ev_up, flow_t0_ev_up ), dim=1)) 314 | feat_t1_ev = self.decoder_event[level](torch.cat((E_t1_flow, feat_t1_ev_up, flow_t1_ev_up), dim=1)) 315 | ## project flow 316 | flow_event_t0_ = self.predict_flow[level](feat_t0_ev) 317 | flow_event_t1_ = self.predict_flow[level](feat_t1_ev) 318 | ## fusion flow 319 | flow_event_t0 = flow_t0_ev_up + flow_event_t0_ 320 | flow_event_t1 = flow_t1_ev_up + flow_event_t1_ 321 | # flow rescale 322 | down_evflow_t0 = rescale_flow(upsample2d(flow_event_t0, F0), F0.size(3), F0.size(2)) 323 | down_evflow_t1 = rescale_flow(upsample2d(flow_event_t1, F1), F1.size(3), F1.size(2)) 324 | down_upflow_t0 = rescale_flow(upsample2d(flow_t0_out_dict[level-1], F0), F0.size(3), F0.size(2)) 325 | down_upflow_t1 = 
rescale_flow(upsample2d(flow_t1_out_dict[level-1], F1), F1.size(3), F1.size(2)) 326 | ## warping with event flow and fusion flow 327 | F0_re = self.conv_frame[level-1](F0) 328 | F0_up_warp_ev = self.warp(F0_re, self.flow_scale*down_evflow_t0) 329 | F0_up_warp_frame = self.warp(F0_re, self.flow_scale*down_upflow_t0) 330 | F1_re = self.conv_frame[level-1](F1) 331 | F1_up_warp_ev = self.warp(F1_re, self.flow_scale*down_evflow_t1) 332 | F1_up_warp_frame = self.warp(F1_re, self.flow_scale*down_upflow_t1) 333 | Ft_up = self.conv_frame_t[level-1](feat_t) 334 | ## flow fusion 335 | _, out_fusion_t0 = self.flow_fusion_block(torch.cat((F0_up_warp_ev, F0_up_warp_frame, Ft_up, down_evflow_t0, down_upflow_t0), dim=1)) 336 | _, out_fusion_t1 = self.flow_fusion_block(torch.cat((F1_up_warp_ev, F1_up_warp_frame, Ft_up, down_evflow_t1, down_upflow_t1), dim=1)) 337 | mask_t0 = upsample2d(torch.sigmoid(out_fusion_t0[:, -1, : ,:])[:, None, :, :], E_t0_flow) 338 | mask_t1 = upsample2d(torch.sigmoid(out_fusion_t1[:, -1, :, :])[:, None, :, :], E_t1_flow) 339 | flow_fusion_t0 = (1-mask_t0)*upflow_t0 + mask_t0*flow_event_t0 340 | flow_fusion_t1 = (1-mask_t1)*upflow_t1 + mask_t1*flow_event_t1 341 | ## intermediate output 342 | if self.tb_debug: 343 | event_flow_dict.append(flow_event_t0) 344 | fusion_flow_dict.append(flow_fusion_t0) 345 | image_flow_dict.append(upflow_t0) 346 | mask_dict.append(mask_t0) 347 | # flow rescale 348 | down_flow_fusion_t0 = rescale_flow(upsample2d(flow_fusion_t0, F0), F0.size(3), F0.size(2)) 349 | down_flow_fusion_t1 = rescale_flow(upsample2d(flow_fusion_t1, F1), F1.size(3), F1.size(2)) 350 | # warping with optical flow 351 | feat10 = self.warp(F0, self.flow_scale*down_flow_fusion_t0) 352 | feat11 = self.warp(F1, self.flow_scale*down_flow_fusion_t1) 353 | # feature normalization 354 | feat_t_norm, feat10_norm, feat11_norm = self.normalize_features([feat_t, feat10, feat11], normalize=True, center=True, moments_across_channels=False, moments_across_images=False) 355 | # correlation 356 | corr_t0 = self.leakyRELU(self.corr(feat_t_norm, feat10_norm)) 357 | corr_t1 = self.leakyRELU(self.corr(feat_t_norm, feat11_norm)) 358 | # correlation refienement 359 | _, res_flow_t0 = self.corr_refinement[level](torch.cat((corr_t0, feat_t, down_flow_fusion_t0), dim=1)) 360 | _, res_flow_t1 = self.corr_refinement[level](torch.cat((corr_t1, feat_t, down_flow_fusion_t1), dim=1)) 361 | # frame-based optical flow generation 362 | flow_t0_frame = down_flow_fusion_t0 + res_flow_t0 363 | flow_t1_frame = down_flow_fusion_t1 + res_flow_t1 364 | ## upsampling frame-based optical flow 365 | upflow_t0_frame = rescale_flow(upsample2d(flow_t0_frame, flow_fusion_t0), flow_fusion_t0.size(3), flow_fusion_t0.size(2)) 366 | upflow_t1_frame = rescale_flow(upsample2d(flow_t1_frame, flow_fusion_t1), flow_fusion_t1.size(3), flow_fusion_t1.size(2)) 367 | ### output 368 | flow_t0_out_dict.append(upflow_t0_frame) 369 | flow_t1_out_dict.append(upflow_t1_frame) 370 | flow_t0_dict.append(self.flow_scale*upflow_t0_frame) 371 | flow_t1_dict.append(self.flow_scale*upflow_t1_frame) 372 | flow_t0_dict = flow_t0_dict[::-1] 373 | flow_t1_dict = flow_t1_dict[::-1] 374 | ## final output return 375 | flow_output_dict = {} 376 | flow_output_dict['flow_t0_dict'] = flow_t0_dict 377 | flow_output_dict['flow_t1_dict'] = flow_t1_dict 378 | if self.tb_debug: 379 | flow_output_dict['event_flow_dict'] = event_flow_dict 380 | flow_output_dict['fusion_flow_dict'] = fusion_flow_dict 381 | flow_output_dict['image_flow_dict'] = image_flow_dict 382 | 
flow_output_dict['mask_dict'] = mask_dict 383 | return flow_output_dict 384 | 385 | 386 | class frame_encoder(nn.Module): 387 | def __init__(self, in_dims, nf): 388 | super(frame_encoder, self).__init__() 389 | self.conv0 = conv3x3_leaky_relu(in_dims, nf) 390 | self.conv1 = conv_resblock_two(nf, nf) 391 | self.conv2 = conv_resblock_two(nf, 2*nf, stride=2) 392 | self.conv3 = conv_resblock_two(2*nf, 4*nf, stride=2) 393 | 394 | def forward(self, x): 395 | x_ = self.conv0(x) 396 | f1 = self.conv1(x_) 397 | f2 = self.conv2(f1) 398 | f3 = self.conv3(f2) 399 | return [f1, f2, f3] 400 | 401 | class event_encoder(nn.Module): 402 | def __init__(self, in_dims, nf): 403 | super(event_encoder, self).__init__() 404 | self.conv0 = conv3x3_leaky_relu(in_dims, nf) 405 | self.conv1 = conv_resblock_two(nf, nf) 406 | self.conv2 = conv_resblock_two(nf, 2*nf, stride=2) 407 | self.conv3 = conv_resblock_two(2*nf, 4*nf, stride=2) 408 | 409 | def forward(self, x): 410 | x_ = self.conv0(x) 411 | f1 = self.conv1(x_) 412 | f2 = self.conv2(f1) 413 | f3 = self.conv3(f2) 414 | return [f1, f2, f3] 415 | 416 | ################################################## 417 | ################# Restormer ##################### 418 | 419 | ########################################################################## 420 | ## Layer Norm 421 | def to_3d(x): 422 | return rearrange(x, 'b c h w -> b (h w) c') 423 | 424 | def to_4d(x,h,w): 425 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) 426 | 427 | class BiasFree_LayerNorm(nn.Module): 428 | def __init__(self, normalized_shape): 429 | super(BiasFree_LayerNorm, self).__init__() 430 | if isinstance(normalized_shape, numbers.Integral): 431 | normalized_shape = (normalized_shape,) 432 | normalized_shape = torch.Size(normalized_shape) 433 | 434 | assert len(normalized_shape) == 1 435 | 436 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 437 | self.normalized_shape = normalized_shape 438 | 439 | def forward(self, x): 440 | sigma = x.var(-1, keepdim=True, unbiased=False) 441 | return x / torch.sqrt(sigma+1e-5) * self.weight 442 | 443 | 444 | class WithBias_LayerNorm(nn.Module): 445 | def __init__(self, normalized_shape): 446 | super(WithBias_LayerNorm, self).__init__() 447 | if isinstance(normalized_shape, numbers.Integral): 448 | normalized_shape = (normalized_shape,) 449 | normalized_shape = torch.Size(normalized_shape) 450 | 451 | assert len(normalized_shape) == 1 452 | 453 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 454 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 455 | self.normalized_shape = normalized_shape 456 | 457 | def forward(self, x): 458 | mu = x.mean(-1, keepdim=True) 459 | sigma = x.var(-1, keepdim=True, unbiased=False) 460 | return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias 461 | 462 | 463 | class LayerNorm(nn.Module): 464 | def __init__(self, dim, LayerNorm_type): 465 | super(LayerNorm, self).__init__() 466 | if LayerNorm_type =='BiasFree': 467 | self.body = BiasFree_LayerNorm(dim) 468 | else: 469 | self.body = WithBias_LayerNorm(dim) 470 | 471 | def forward(self, x): 472 | h, w = x.shape[-2:] 473 | return to_4d(self.body(to_3d(x)), h, w) 474 | 475 | ########################################################################## 476 | ## Gated-Dconv Feed-Forward Network (GDFN) 477 | class FeedForward(nn.Module): 478 | def __init__(self, dim, ffn_expansion_factor, bias): 479 | super(FeedForward, self).__init__() 480 | hidden_features = int(dim*ffn_expansion_factor) 481 | self.project_in = nn.Conv2d(dim, hidden_features*2, 
kernel_size=1, bias=bias) 482 | self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) 483 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) 484 | 485 | def forward(self, x): 486 | x = self.project_in(x) 487 | x1, x2 = self.dwconv(x).chunk(2, dim=1) 488 | x = F.gelu(x1) * x2 489 | x = self.project_out(x) 490 | return x 491 | 492 | class CrossAttention(nn.Module): 493 | def __init__(self, dim, num_heads, bias): 494 | super(CrossAttention, self).__init__() 495 | self.num_heads = num_heads 496 | self.temperature1 = nn.Parameter(torch.ones(num_heads, 1, 1)) 497 | self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1)) 498 | 499 | self.q = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias) 500 | self.kv1 = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias) 501 | self.kv2 = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias) 502 | self.q_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias) 503 | self.kv1_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim*2, bias=bias) 504 | self.kv2_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim*2, bias=bias) 505 | self.project_out = nn.Conv2d(dim*2, dim, kernel_size=1, bias=bias) 506 | 507 | def forward(self, x, attn_kv1, attn_kv2): 508 | b,c,h,w = x.shape 509 | 510 | q_ = self.q_dwconv(self.q(x)) 511 | kv1 = self.kv1_dwconv(self.kv1(attn_kv1)) 512 | kv2 = self.kv2_dwconv(self.kv2(attn_kv2)) 513 | q1,q2 = q_.chunk(2, dim=1) 514 | k1,v1 = kv1.chunk(2, dim=1) 515 | k2,v2 = kv2.chunk(2, dim=1) 516 | 517 | q1 = rearrange(q1, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 518 | q2 = rearrange(q2, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 519 | k1 = rearrange(k1, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 520 | v1 = rearrange(v1, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 521 | k2 = rearrange(k2, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 522 | v2 = rearrange(v2, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 523 | 524 | q1 = torch.nn.functional.normalize(q1, dim=-1) 525 | q2 = torch.nn.functional.normalize(q2, dim=-1) 526 | k1 = torch.nn.functional.normalize(k1, dim=-1) 527 | k2 = torch.nn.functional.normalize(k2, dim=-1) 528 | 529 | attn = (q1 @ k1.transpose(-2, -1)) * self.temperature1 530 | attn = attn.softmax(dim=-1) 531 | out1 = (attn @ v1) 532 | 533 | attn = (q2 @ k2.transpose(-2, -1)) * self.temperature2 534 | attn = attn.softmax(dim=-1) 535 | out2 = (attn @ v2) 536 | 537 | out1 = rearrange(out1, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 538 | out2 = rearrange(out2, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 539 | out = torch.cat((out1, out2), dim=1) 540 | out = self.project_out(out) 541 | return out 542 | 543 | 544 | ########################################################################## 545 | class CrossTransformerBlock(nn.Module): 546 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): 547 | super(CrossTransformerBlock, self).__init__() 548 | self.norm1 = LayerNorm(dim, LayerNorm_type) 549 | self.norm_kv1 = LayerNorm(dim, LayerNorm_type) 550 | self.norm_kv2 = LayerNorm(dim, LayerNorm_type) 551 | self.attn = CrossAttention(dim, num_heads, bias) 552 | self.norm2 = LayerNorm(dim, LayerNorm_type) 553 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias) 554 | 555 | def forward(self, x, attn_kv1, 
attn_kv2): 556 | x = x + self.attn(self.norm1(x), self.norm_kv1(attn_kv1), self.norm_kv2(attn_kv2)) 557 | x = x + self.ffn(self.norm2(x)) 558 | return x 559 | 560 | class CrossTransformerLayer(nn.Module): 561 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, num_blocks): 562 | super(CrossTransformerLayer, self).__init__() 563 | self.blocks = nn.ModuleList([CrossTransformerBlock(dim=dim, num_heads=num_heads, 564 | ffn_expansion_factor=ffn_expansion_factor, bias=bias, 565 | LayerNorm_type=LayerNorm_type) for i in range(num_blocks)]) 566 | 567 | def forward(self, x, attn_kv=None, attn_kv2=None): 568 | for blk in self.blocks: 569 | x = blk(x, attn_kv, attn_kv2) 570 | return x 571 | 572 | 573 | ########################################################################## 574 | ## Multi-DConv Head Transposed Self-Attention (MDTA) 575 | class Attention(nn.Module): 576 | def __init__(self, dim, num_heads, bias): 577 | super(Attention, self).__init__() 578 | self.num_heads = num_heads 579 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) 580 | 581 | self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) 582 | self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) 583 | self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) 584 | 585 | 586 | 587 | def forward(self, x): 588 | b,c,h,w = x.shape 589 | 590 | qkv = self.qkv_dwconv(self.qkv(x)) 591 | q,k,v = qkv.chunk(3, dim=1) 592 | 593 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 594 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 595 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 596 | 597 | q = torch.nn.functional.normalize(q, dim=-1) 598 | k = torch.nn.functional.normalize(k, dim=-1) 599 | 600 | attn = (q @ k.transpose(-2, -1)) * self.temperature 601 | attn = attn.softmax(dim=-1) 602 | 603 | out = (attn @ v) 604 | 605 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 606 | 607 | out = self.project_out(out) 608 | return out 609 | 610 | 611 | class Self_attention(nn.Module): 612 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): 613 | super(Self_attention, self).__init__() 614 | self.norm1 = LayerNorm(dim, LayerNorm_type) 615 | self.attn = Attention(dim, num_heads, bias) 616 | self.norm2 = LayerNorm(dim, LayerNorm_type) 617 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias) 618 | 619 | def forward(self, x): 620 | x = x + self.attn(self.norm1(x)) 621 | x = x + self.ffn(self.norm2(x)) 622 | return x 623 | 624 | 625 | class SelfAttentionLayer(nn.Module): 626 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, num_blocks): 627 | super(SelfAttentionLayer, self).__init__() 628 | self.blocks = nn.ModuleList([Self_attention(dim=dim, num_heads=num_heads, ffn_expansion_factor=ffn_expansion_factor, bias=bias, 629 | LayerNorm_type=LayerNorm_type) for i in range(num_blocks)]) 630 | 631 | def forward(self, x): 632 | for blk in self.blocks: 633 | x = blk(x) 634 | return x 635 | 636 | 637 | class Upsample(nn.Module): 638 | def __init__(self, in_channel, out_channel): 639 | super(Upsample, self).__init__() 640 | self.deconv = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2) 641 | 642 | def forward(self, x): 643 | out = self.deconv(x) 644 | return out 645 | 646 | def flops(self, H, W): 647 | flops = 0 648 | # conv 649 | flops += 
H*2*W*2*self.in_channel*self.out_channel*2*2 650 | print("Upsample:{%.2f}"%(flops/1e9)) 651 | return flops 652 | 653 | 654 | class Transformer(nn.Module): 655 | def __init__(self, unit_dim): 656 | super(Transformer, self).__init__() 657 | ## init qurey networks 658 | self.init_qurey_net(unit_dim) 659 | self.init_decoder(unit_dim) 660 | ## last conv 661 | self.last_conv0 = conv3x3(unit_dim*4, 3) 662 | self.last_conv1 = conv3x3(unit_dim*2, 3) 663 | self.last_conv2 = conv3x3(unit_dim, 3) 664 | 665 | def init_decoder(self, unit_dim): 666 | ### decoder 667 | ### attention k,v building (synthesis) 668 | self.build_kv0_syn = conv3x3_leaky_relu(unit_dim*3, unit_dim*4) 669 | self.build_kv1_syn = conv3x3_leaky_relu(int(unit_dim*1.5), unit_dim*2) 670 | self.build_kv2_syn = conv3x3_leaky_relu(int(unit_dim*0.75), unit_dim) 671 | ### attention k, v building (warping) 672 | self.build_kv0_warp = conv3x3_leaky_relu(unit_dim*3+6, unit_dim*4) 673 | self.build_kv1_warp = conv3x3_leaky_relu(int(unit_dim*1.5)+6, unit_dim*2) 674 | self.build_kv2_warp = conv3x3_leaky_relu(int(unit_dim*0.75)+6, unit_dim) 675 | ## level 1 676 | self.decoder1_1 = CrossTransformerLayer(dim=unit_dim*4, num_heads=4, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2) 677 | self.decoder1_2 = SelfAttentionLayer(dim=unit_dim*4, num_heads=4, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2) 678 | ## level 2 679 | self.decoder2_1 = CrossTransformerLayer(dim=unit_dim*2, num_heads=4, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2) 680 | self.decoder2_2 = SelfAttentionLayer(dim=unit_dim*2, num_heads=2, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2) 681 | ## level 3 682 | self.decoder3_1 = CrossTransformerLayer(dim=unit_dim, num_heads=4, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2) 683 | self.decoder3_2 = SelfAttentionLayer(dim=unit_dim, num_heads=1, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2) 684 | ## upsample 685 | self.upsample0 = Upsample(unit_dim*4, unit_dim*2) 686 | self.upsample1 = Upsample(unit_dim*2, unit_dim) 687 | ## conv after body 688 | self.conv_after_body0 = conv_resblock_one(4*unit_dim, 2*unit_dim) 689 | self.conv_after_body1 = conv_resblock_one(2*unit_dim, unit_dim) 690 | 691 | ### qurey network 692 | def init_qurey_net(self, unit_dim): 693 | ### building query 694 | ## stage 1 695 | self.enc_conv0 = conv3x3_leaky_relu(unit_dim+6, unit_dim) 696 | ## stage 2 697 | self.enc_conv1 = conv3x3_leaky_relu(unit_dim, 2*unit_dim, stride=2) 698 | ## stage 3 699 | self.enc_conv2 = conv3x3_leaky_relu(2*unit_dim, 4*unit_dim, stride=2) 700 | 701 | ## query buiding !! 
702 | def build_qurey(self, event_feature, frame_feature, warped_feature): 703 | cat_in0 = torch.cat((event_feature[0], frame_feature[0], warped_feature[0]), dim=1) 704 | Q0 = self.enc_conv0(cat_in0) 705 | Q1 = self.enc_conv1(Q0) 706 | Q2 = self.enc_conv2(Q1) 707 | return [Q0, Q1, Q2] 708 | 709 | def forward_decoder(self, Q_list, warped_feature, frame_feature, event_feature): 710 | ## syntheis kv building 711 | cat_in0_syn = torch.cat((frame_feature[2], event_feature[2]), dim=1) 712 | attn_kv0_syn = self.build_kv0_syn(cat_in0_syn) 713 | cat_in1_syn = torch.cat((frame_feature[1], event_feature[1]), dim=1) 714 | attn_kv1_syn = self.build_kv1_syn(cat_in1_syn) 715 | cat_in2_syn = torch.cat((frame_feature[0], event_feature[0]), dim=1) 716 | attn_kv2_syn = self.build_kv2_syn(cat_in2_syn) 717 | ## warping kv building 718 | cat_in0_warp = torch.cat((warped_feature[2], event_feature[2]), dim=1) 719 | attn_kv0_warp = self.build_kv0_warp(cat_in0_warp) 720 | cat_in1_warp = torch.cat((warped_feature[1], event_feature[1]), dim=1) 721 | attn_kv1_warp = self.build_kv1_warp(cat_in1_warp) 722 | cat_in2_warp = torch.cat((warped_feature[0], event_feature[0]), dim=1) 723 | attn_kv2_warp = self.build_kv2_warp(cat_in2_warp) 724 | ## out 0 725 | _Q0 = Q_list[2] 726 | out0 = self.decoder1_1(_Q0, attn_kv0_syn, attn_kv0_warp) 727 | out0 = self.decoder1_2(out0) 728 | up_out0 = self.upsample0(out0) 729 | ## out 1 730 | _Q1 = Q_list[1] 731 | _Q1 = self.conv_after_body0(torch.cat((_Q1, up_out0), dim=1)) 732 | out1 = self.decoder2_1(_Q1, attn_kv1_syn, attn_kv1_warp) 733 | out1 = self.decoder2_2(out1) 734 | up_out1 = self.upsample1(out1) 735 | ## out2 736 | _Q2 = Q_list[0] 737 | _Q2 = self.conv_after_body1(torch.cat((_Q2, up_out1), dim=1)) 738 | out2 = self.decoder3_1(_Q2, attn_kv2_syn, attn_kv2_warp) 739 | out2 = self.decoder3_2(out2) 740 | return [out0, out1, out2] 741 | 742 | def forward(self, event_feature, frame_feature, warped_feature): 743 | ### forward encoder 744 | Q_list = self.build_qurey(event_feature, frame_feature, warped_feature) 745 | ### forward decoder 746 | out_decoder = self.forward_decoder(Q_list, warped_feature, frame_feature, event_feature) 747 | ### synthesis frame 748 | img0 = self.last_conv0(out_decoder[0]) 749 | img1 = self.last_conv1(out_decoder[1]) 750 | img2 = self.last_conv2(out_decoder[2]) 751 | return [img2, img1, img0] 752 | 753 | 754 | 755 | 756 | class EventInterpNet(nn.Module): 757 | def __init__(self, num_bins=16, flow_debug=False): 758 | super(EventInterpNet, self).__init__() 759 | unit_dim = 44 760 | # scale 761 | self.scale = 3 762 | # flownet 763 | self.flownet = FlowNet(md=4, tb_debug=flow_debug) 764 | self.flow_debug = flow_debug 765 | # encoder 766 | self.encoder_f = frame_encoder(3, unit_dim//4) 767 | self.encoder_e = event_encoder(16, unit_dim//2) 768 | # decoder 769 | self.transformer = Transformer(unit_dim*2) 770 | # channel scaling convolution 771 | self.conv_list = nn.ModuleList([conv1x1(unit_dim, unit_dim), conv1x1(unit_dim, unit_dim), conv1x1(unit_dim, unit_dim)]) 772 | 773 | def set_mode(self, mode): 774 | self.mode = mode 775 | 776 | def bwarp(self, x, flo): 777 | ''' 778 | x shape : [B,C,T,H,W] 779 | t_value shape : [B,1] ############### 780 | ''' 781 | B, C, H, W = x.size() 782 | # mesh grid 783 | xx = torch.arange(0, W).view(1, 1, 1, W).expand(B, 1, H, W) 784 | yy = torch.arange(0, H).view(1, 1, H, 1).expand(B, 1, H, W) 785 | grid = torch.cat((xx, yy), 1).float() 786 | 787 | if x.is_cuda: 788 | grid = grid.cuda() 789 | vgrid = torch.autograd.Variable(grid) + flo 
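# vgrid now holds the absolute sampling coordinate (base grid plus flow) of each output pixel.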
790 | 791 | # scale grid to [-1,1] 792 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0 793 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0 794 | 795 | vgrid = vgrid.permute(0, 2, 3, 1) # [B,H,W,2] 796 | output = nn.functional.grid_sample(x, vgrid, align_corners=True) 797 | mask = torch.autograd.Variable(torch.ones(x.size())).cuda() 798 | mask = nn.functional.grid_sample(mask, vgrid, align_corners=True) 799 | # mask[mask<0.9999] = 0 800 | # mask[mask>0] = 1 801 | mask = mask.masked_fill_(mask < 0.999, 0) 802 | mask = mask.masked_fill_(mask > 0, 1) 803 | return output * mask 804 | 805 | def Flow_pyramid(self, flow): 806 | flow_pyr = [] 807 | flow_pyr.append(flow) 808 | for i in range(1, 3): 809 | flow_pyr.append(F.interpolate(flow, scale_factor=0.5 ** i, mode='bilinear') * (0.5 ** i)) 810 | return flow_pyr 811 | 812 | def Img_pyramid(self, Img): 813 | img_pyr = [] 814 | img_pyr.append(Img) 815 | for i in range(1, 3): 816 | img_pyr.append(F.interpolate(Img, scale_factor=0.5 ** i, mode='bilinear')) 817 | return img_pyr 818 | 819 | def synthesis(self, batch, OF_t0, OF_t1): 820 | ## frame encoding 821 | f_frame0 = self.encoder_f(batch['image_input0']) 822 | f_frame1 = self.encoder_f(batch['image_input1']) 823 | ## OF pyramid 824 | OF_t0_pyramid = self.Flow_pyramid(OF_t0[0]) 825 | OF_t1_pyramid = self.Flow_pyramid(OF_t1[0]) 826 | ## image pyramid 827 | I0_pyramid = self.Img_pyramid(batch['image_input0']) 828 | I1_pyramid = self.Img_pyramid(batch['image_input1']) 829 | # frame0_warped, frame1_warped = [], [] 830 | warped_feature, frame_feature = [], [] 831 | for idx in range(self.scale): 832 | frame0_warped = self.bwarp(torch.cat((f_frame0[idx], I0_pyramid[idx]),dim=1), OF_t0_pyramid[idx]) 833 | frame1_warped = self.bwarp(torch.cat((f_frame1[idx], I1_pyramid[idx]),dim=1), OF_t1_pyramid[idx]) 834 | warped_feature.append(torch.cat((frame0_warped, frame1_warped), dim=1)) 835 | frame_feature.append(torch.cat((f_frame0[idx], f_frame1[idx]), dim=1)) 836 | # after_tmp_feature = self.conv_list[idx](tmp_feature) 837 | event_feature = [] 838 | # event encoding for frame interpolation 839 | f_event_0t = self.encoder_e(batch['event_input_0t']) 840 | f_event_t1 = self.encoder_e(batch['event_input_t1']) 841 | for idx in range(self.scale): 842 | event_feature.append(torch.cat((f_event_0t[idx], f_event_t1[idx]), dim=1)) 843 | img_out = self.transformer(event_feature, frame_feature, warped_feature) 844 | output_clean = [] 845 | for i in range(self.scale): 846 | output_clean.append(torch.clamp(img_out[i], 0, 1)) 847 | return output_clean 848 | 849 | def forward(self, batch): 850 | output_dict = {} 851 | # --- Flow-only mode --- 852 | if self.mode == 'flow': 853 | output_dict['flow_out'] = self.flownet(batch) 854 | # --- Joint mode: --- 855 | elif self.mode == 'joint': 856 | flow_out = self.flownet(batch) 857 | interp_out = self.synthesis(batch, flow_out['flow_t0_dict'], flow_out['flow_t1_dict']) 858 | output_dict.update({'flow_out': flow_out, 'interp_out': interp_out}) 859 | else: 860 | raise ValueError(f"Unsupported mode: {self.mode}") 861 | return output_dict -------------------------------------------------------------------------------- /models/final_models/submodules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | def actFunc(act, *args, **kwargs): 6 | act = act.lower() 7 | if act == 'relu': 8 | return nn.ReLU() 9 | elif act == 
'relu6': 10 | return nn.ReLU6() 11 | elif act == 'leakyrelu': 12 | return nn.LeakyReLU(0.1) 13 | elif act == 'prelu': 14 | return nn.PReLU() 15 | elif act == 'rrelu': 16 | return nn.RReLU(0.1, 0.3) 17 | elif act == 'selu': 18 | return nn.SELU() 19 | elif act == 'celu': 20 | return nn.CELU() 21 | elif act == 'elu': 22 | return nn.ELU() 23 | elif act == 'gelu': 24 | return nn.GELU() 25 | elif act == 'tanh': 26 | return nn.Tanh() 27 | else: 28 | raise NotImplementedError 29 | 30 | class ResBlock(nn.Module): 31 | """ 32 | Residual block 33 | """ 34 | def __init__(self, in_chs, activation='relu', batch_norm=False): 35 | super(ResBlock, self).__init__() 36 | op = [] 37 | for i in range(2): 38 | op.append(conv3x3(in_chs, in_chs)) 39 | if batch_norm: 40 | op.append(nn.BatchNorm2d(in_chs)) 41 | if i == 0: 42 | op.append(actFunc(activation)) 43 | self.main_branch = nn.Sequential(*op) 44 | 45 | def forward(self, x): 46 | out = self.main_branch(x) 47 | out += x 48 | return out 49 | 50 | # conv blocks 51 | def conv1x1(in_channels, out_channels, stride=1): 52 | return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=True) 53 | 54 | def conv3x3(in_channels, out_channels, stride=1): 55 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True) 56 | 57 | def conv5x5(in_channels, out_channels, stride=1): 58 | return nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=stride, padding=2, bias=True) 59 | 60 | def deconv4x4(in_channels, out_channels, stride=2): 61 | return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=1) 62 | 63 | def deconv5x5(in_channels, out_channels, stride=2): 64 | return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=5, stride=stride, padding=2, output_padding=1) 65 | 66 | # conv resblock 67 | def conv_resblock_three(in_channels, out_channels, stride=1): 68 | return nn.Sequential(conv3x3(in_channels, out_channels, stride), nn.ReLU(), ResBlock(out_channels), ResBlock(out_channels), ResBlock(out_channels)) 69 | 70 | def conv_resblock_two(in_channels, out_channels, stride=1): 71 | return nn.Sequential(conv3x3(in_channels, out_channels, stride), nn.ReLU(), ResBlock(out_channels), ResBlock(out_channels)) 72 | 73 | def conv_resblock_one(in_channels, out_channels, stride=1): 74 | return nn.Sequential(conv3x3(in_channels, out_channels, stride), nn.ReLU(), ResBlock(out_channels)) 75 | 76 | def conv_1x1_resblock_one(in_channels, out_channels, stride=1): 77 | return nn.Sequential(conv1x1(in_channels, out_channels, stride), nn.ReLU(), ResBlock(out_channels)) 78 | 79 | def conv_resblock_two_DS(in_channels, out_channels, stride=2): 80 | return nn.Sequential(conv3x3(in_channels, out_channels, stride), nn.ReLU(), ResBlock(out_channels), ResBlock(out_channels)) 81 | 82 | def conv3x3_leaky_relu(in_channels, out_channels, stride=1): 83 | return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True), nn.LeakyReLU(0.1)) 84 | 85 | def conv1x1_leaky_relu(in_channels, out_channels, stride=1): 86 | return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=True), nn.LeakyReLU(0.1)) -------------------------------------------------------------------------------- /models/loss_handler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import importlib 4 | from models.final_models.submodules import * 5 | from math 
import ceil 6 | from utils.flow_utils import cal_grad2_error 7 | from utils.utils import AverageMeter 8 | 9 | 10 | class LossHandler: 11 | def __init__(self, smoothness_weight, scale, loss_weight): 12 | self.smoothness_weight = smoothness_weight 13 | self.scale = scale 14 | self.loss_weight = loss_weight 15 | self.loss_total_meter = AverageMeter() 16 | self.loss_image_meter = AverageMeter() 17 | self.loss_warp_meter = AverageMeter() 18 | self.loss_flow_meter = AverageMeter() 19 | self.loss_smoothness_meter = AverageMeter() 20 | 21 | self.reset_cache() 22 | 23 | def reset_cache(self): 24 | self.loss = 0 25 | self.loss_image = 0 26 | self.loss_flow = 0 27 | self.loss_warping = 0 28 | self.loss_smoothness = 0 29 | 30 | def compute_multiscale_loss(self, gt_list, pred_list): 31 | self.loss_image = 0 32 | for i in range(self.scale): 33 | self.loss_image += self.loss_weight[i] * self._l1_loss(gt_list[i], pred_list[i]) 34 | self.loss = self.loss_image 35 | return self.loss 36 | 37 | def compute_flow_loss(self, outputs, batch): 38 | self.loss_warping = 0 39 | self.loss_smoothness = 0 40 | 41 | imaget_est0_list, imaget_est1_list = [], [] 42 | 43 | for idx in range(len(outputs['flow_out']['flow_t0_dict'])): 44 | est0, _ = self._warp(batch['image_pyramid_0'][idx], outputs['flow_out']['flow_t0_dict'][idx]) 45 | est1, _ = self._warp(batch['image_pyramid_1'][idx], outputs['flow_out']['flow_t1_dict'][idx]) 46 | 47 | imaget_est0_list.append(est0) 48 | imaget_est1_list.append(est1) 49 | 50 | gt = batch['clean_gt_MS_images'][idx] 51 | loss0 = self._l1_loss(gt, est0) 52 | loss1 = self._l1_loss(gt, est1) 53 | smooth0 = cal_grad2_error(outputs['flow_out']['flow_t0_dict'][idx]/20, gt, 1.0) 54 | smooth1 = cal_grad2_error(outputs['flow_out']['flow_t1_dict'][idx]/20, gt, 1.0) 55 | 56 | self.loss_warping += loss0 + loss1 57 | self.loss_smoothness += smooth0 + smooth1 58 | 59 | self.loss_flow = self.loss_warping + self.smoothness_weight * self.loss_smoothness 60 | self.loss = self.loss_flow 61 | return self.loss, imaget_est0_list, imaget_est1_list 62 | 63 | def _l1_loss(self, x, y): 64 | return torch.sqrt((x - y) ** 2 + 1e-6).mean() 65 | 66 | def _warp(self, x, flo): 67 | B, C, H, W = x.size() 68 | xx = torch.arange(0, W).view(1, 1, 1, W).expand(B, 1, H, W) 69 | yy = torch.arange(0, H).view(1, 1, H, 1).expand(B, 1, H, W) 70 | grid = torch.cat((xx, yy), 1).float() 71 | if x.is_cuda: 72 | grid = grid.cuda() 73 | vgrid = torch.autograd.Variable(grid) + flo 74 | 75 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0 76 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0 77 | vgrid = vgrid.permute(0, 2, 3, 1) 78 | 79 | output = nn.functional.grid_sample(x, vgrid, align_corners=True) 80 | mask = torch.ones_like(x).cuda() if x.is_cuda else torch.ones_like(x) 81 | mask = nn.functional.grid_sample(mask, vgrid, align_corners=True) 82 | mask = mask.masked_fill(mask < 0.999, 0).masked_fill(mask > 0, 1) 83 | return output * mask, mask 84 | 85 | def update_meters(self, mode): 86 | self.loss_flow_meter.update(self.loss_flow) 87 | self.loss_warp_meter.update(self.loss_warping) 88 | self.loss_smoothness_meter.update(self.loss_smoothness) 89 | if mode == 'joint': 90 | self.loss_total_meter.update(self.loss) 91 | self.loss_image_meter.update(self.loss_image) 92 | 93 | def reset_meters(self): 94 | self.loss_total_meter.reset() 95 | self.loss_image_meter.reset() 96 | self.loss_flow_meter.reset() 97 | self.loss_warp_meter.reset() 98 | self.loss_smoothness_meter.reset() 
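Note: `_l1_loss` above is a Charbonnier (smoothed L1) penalty, sqrt((x - y)^2 + 1e-6), rather than a plain L1. The warping helper `_warp` follows the same pattern as the model's `bwarp`: displace a pixel grid by the predicted flow, normalize the coordinates to [-1, 1] for `grid_sample`, and mask out locations that sample outside the source frame. Below is a minimal, self-contained sketch of that pattern; this is an editorial illustration only, and the name `backward_warp` does not exist in the repository.

```python
import torch
import torch.nn.functional as F

def backward_warp(x, flow):
    """Warp x (B,C,H,W) to the target view using a backward flow (B,2,H,W)."""
    B, _, H, W = x.shape
    # pixel-coordinate grid, displaced by the flow (target -> source coordinates)
    xx = torch.arange(W, device=x.device).view(1, 1, 1, W).expand(B, 1, H, W).float()
    yy = torch.arange(H, device=x.device).view(1, 1, H, 1).expand(B, 1, H, W).float()
    grid = torch.cat((xx, yy), dim=1) + flow
    # normalize coordinates to [-1, 1] as expected by grid_sample
    gx = 2.0 * grid[:, 0:1] / max(W - 1, 1) - 1.0
    gy = 2.0 * grid[:, 1:2] / max(H - 1, 1) - 1.0
    vgrid = torch.cat((gx, gy), dim=1).permute(0, 2, 3, 1)  # [B, H, W, 2]
    warped = F.grid_sample(x, vgrid, align_corners=True)
    # warp an all-ones tensor and threshold it: ~0 wherever sampling left the frame
    valid = F.grid_sample(torch.ones_like(x), vgrid, align_corners=True)
    valid = (valid > 0.999).float()
    return warped * valid, valid
```

In `compute_flow_loss`, this is the operation that warps `image_pyramid_0` and `image_pyramid_1` toward the intermediate timestamp with the `flow_t0_dict` and `flow_t1_dict` predictions, so the warped frames can be compared against the ground-truth middle frame at each scale.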
-------------------------------------------------------------------------------- /models/model_manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import importlib 4 | from math import ceil 5 | from models.final_models.submodules import * 6 | from utils.utils import AverageMeter, batch2device 7 | from .loss_handler import LossHandler 8 | 9 | 10 | class OurModel: 11 | def __init__(self, args): 12 | self.voxel_num_bins = args.voxel_num_bins 13 | self.flow_debug = bool(args.flow_tb_debug) 14 | self.scale = 3 15 | self.loss_weight = [1, 0.1, 0.1] 16 | self.batch = {} 17 | self.outputs = {} 18 | self.test_outputs = {} 19 | self.downsample = nn.AvgPool2d(2, stride=2) 20 | 21 | self.loss_handler = LossHandler( 22 | smoothness_weight=args.smoothness_weight, 23 | scale=self.scale, 24 | loss_weight=self.loss_weight 25 | ) 26 | 27 | def initialize(self, model_folder, model_name): 28 | mod = importlib.import_module(f'models.{model_folder}.{model_name}') 29 | self.net = mod.EventInterpNet(self.voxel_num_bins, self.flow_debug) 30 | 31 | def cuda(self): 32 | self.net.cuda() 33 | 34 | def train(self): 35 | self.net.train() 36 | 37 | def eval(self): 38 | self.net.eval() 39 | 40 | def use_multi_gpu(self): 41 | self.net = nn.DataParallel(self.net) 42 | 43 | def fix_flownet(self): 44 | net = self.net.module if isinstance(self.net, nn.DataParallel) else self.net 45 | for param in net.flownet.parameters(): 46 | param.requires_grad = False 47 | 48 | def get_optimizer_params(self): 49 | return self.net.parameters() 50 | 51 | def set_mode(self, mode): 52 | net = self.net.module if isinstance(self.net, nn.DataParallel) else self.net 53 | net.set_mode(mode) 54 | 55 | def set_train_input(self, sample): 56 | self._set_common_input(sample) 57 | self._generate_multi_scale_inputs(sample) 58 | 59 | def set_input(self, sample): 60 | self._set_common_input(sample) 61 | 62 | def _set_common_input(self, sample): 63 | self.batch['image_input0'] = sample['clean_image_first'].float() 64 | self.batch['image_input1'] = sample['clean_image_last'].float() 65 | self.batch['imaget_input'] = sample['clean_middle'].float() 66 | self.batch['event_input_t0'] = sample['voxel_grid_t0'].float() 67 | self.batch['event_input_0t'] = sample['voxel_grid_0t'].float() 68 | self.batch['event_input_t1'] = sample['voxel_grid_t1'].float() 69 | self.batch['clean_gt_images'] = sample['clean_middle'] 70 | 71 | def _generate_multi_scale_inputs(self, sample): 72 | labels = sample['clean_middle'] 73 | image_0 = sample['clean_image_first'] 74 | image_1 = sample['clean_image_last'] 75 | self.batch['clean_gt_MS_images'] = [labels] 76 | self.batch['image_pyramid_0'] = [image_0] 77 | self.batch['image_pyramid_1'] = [image_1] 78 | 79 | for _ in range(self.scale - 1): 80 | labels = self.downsample(labels.clone()) 81 | image_0 = self.downsample(image_0.clone()) 82 | image_1 = self.downsample(image_1.clone()) 83 | self.batch['clean_gt_MS_images'].append(labels) 84 | self.batch['image_pyramid_0'].append(image_0) 85 | self.batch['image_pyramid_1'].append(image_1) 86 | 87 | def forward_nets(self): 88 | self.outputs = self.net(self.batch) 89 | 90 | def forward_joint_test(self): 91 | self.test_outputs = self.net(self.batch) 92 | self.test_outputs['flow_out']['flow_t0_dict'] = self.test_outputs['flow_out']['flow_t0_dict'][0][..., 0:self.H_org,0:self.W_org] 93 | self.test_outputs['flow_out']['flow_t1_dict'] = self.test_outputs['flow_out']['flow_t1_dict'][0][..., 
0:self.H_org,0:self.W_org] 94 | self.test_outputs['interp_out'] = self.test_outputs['interp_out'][0][..., 0:self.H_org,0:self.W_org] 95 | 96 | def get_multi_scale_loss(self): 97 | return self.loss_handler.compute_multiscale_loss( 98 | self.batch['clean_gt_MS_images'], 99 | self.outputs['interp_out'] 100 | ) 101 | 102 | def get_flow_loss(self): 103 | loss_flow, imaget_est0_list, imaget_est1_list = self.loss_handler.compute_flow_loss(self.outputs, self.batch) 104 | self.batch['imaget_est0_warp'] = imaget_est0_list 105 | self.batch['imaget_est1_warp'] = imaget_est1_list 106 | return loss_flow 107 | 108 | def update_loss_meters(self, mode): 109 | self.loss_handler.update_meters(mode) 110 | 111 | def reset_loss_meters(self): 112 | self.loss_handler.reset_meters() 113 | 114 | def load_model(self, state_dict): 115 | if list(state_dict.keys())[0].startswith('module.'): 116 | new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} 117 | self.net.load_state_dict(new_state_dict) 118 | else: 119 | self.net.load_state_dict(state_dict) 120 | 121 | def set_test_input(self, sample): 122 | B, _, H, W = sample['clean_image_first'].shape 123 | H_ = ceil(H / 64) * 64 124 | W_ = ceil(W / 64) * 64 125 | 126 | C1 = torch.zeros((B, 3, H_, W_)).cuda() 127 | C2 = torch.zeros((B, 3, H_, W_)).cuda() 128 | 129 | self.batch['image_input0_org'] = sample['clean_image_first'] 130 | self.batch['image_input1_org'] = sample['clean_image_last'] 131 | 132 | C1[:, :, :H, :W] = sample['clean_image_first'] 133 | C2[:, :, :H, :W] = sample['clean_image_last'] 134 | 135 | self.batch['image_input0'] = C1 136 | self.batch['image_input1'] = C2 137 | 138 | Vt0 = torch.zeros((B, self.voxel_num_bins, H_, W_)).cuda() 139 | Vt1 = torch.zeros((B, self.voxel_num_bins, H_, W_)).cuda() 140 | V0t = torch.zeros((B, self.voxel_num_bins, H_, W_)).cuda() 141 | 142 | Vt0[:, :, :H, :W] = sample['voxel_grid_t0'] 143 | Vt1[:, :, :H, :W] = sample['voxel_grid_t1'] 144 | V0t[:, :, :H, :W] = sample['voxel_grid_0t'] 145 | 146 | self.batch['event_input_t0'] = Vt0 147 | self.batch['event_input_0t'] = V0t 148 | self.batch['event_input_t1'] = Vt1 149 | 150 | self.H_org = H 151 | self.W_org = W 152 | 153 | -------------------------------------------------------------------------------- /pretrained_model/README.md: -------------------------------------------------------------------------------- 1 | ### put the downloaded model here -------------------------------------------------------------------------------- /run_samples.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import numpy as np 4 | import os 5 | from models.model_manager import OurModel 6 | from skimage.io import imread 7 | import cv2 8 | from utils.utils import * 9 | 10 | 11 | torch.backends.cudnn.enabled = True 12 | torch.backends.cudnn.benchmark = True 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--voxel_num_bins', type = int, default=16) 16 | 17 | parser.add_argument('--sample_folder_path', type = str, default='./sample_data') 18 | parser.add_argument('--save_output_dir', type = str, default='./output') 19 | parser.add_argument('--image_number', type = int, default=0) 20 | 21 | parser.add_argument('--model_folder', type = str, default='final_models') 22 | parser.add_argument('--model_name', type = str, default='ours') 23 | parser.add_argument('--flow_tb_debug', type = str2bool, default='False') 24 | parser.add_argument('--smoothness_weight', type = float, default=10.0) 25 | 
parser.add_argument('--ckpt_path', type=str, default='pretrained_model/ours_weight.pth') 26 | args = parser.parse_args() 27 | 28 | 29 | first_image_name = os.path.join(args.sample_folder_path, str(args.image_number).zfill(5) + '.png') 30 | second_image_name = os.path.join(args.sample_folder_path, str(args.image_number+1).zfill(5) + '.png') 31 | first_image_np = imread(first_image_name) 32 | second_image_np = imread(second_image_name) 33 | frame1 = torch.from_numpy(first_image_np).permute(2,0,1).float().unsqueeze(0) / 255.0 34 | frame3 = torch.from_numpy(second_image_np).permute(2,0,1).float().unsqueeze(0) / 255.0 35 | 36 | voxel_0t_name = os.path.join(args.sample_folder_path, str(args.image_number).zfill(5) + '_0t.npz') 37 | voxel_t0_name = os.path.join(args.sample_folder_path, str(args.image_number).zfill(5) + '_t0.npz') 38 | voxel_t1_name = os.path.join(args.sample_folder_path, str(args.image_number).zfill(5) + '_t1.npz') 39 | voxel_0t = torch.from_numpy(np.load(voxel_0t_name)["data"])[None, ...] 40 | voxel_t1 = torch.from_numpy(np.load(voxel_t1_name)["data"])[None, ...] 41 | voxel_t0 = torch.from_numpy(np.load(voxel_t0_name)["data"])[None, ...] 42 | 43 | model = OurModel(args) 44 | model.initialize(args.model_folder, args.model_name) 45 | 46 | ckpt = torch.load(args.ckpt_path, map_location='cpu') 47 | model.load_model(ckpt) 48 | 49 | model.cuda() 50 | model.set_mode('joint') 51 | with torch.no_grad(): 52 | # patch-wise evaluation 53 | iter_idx = 0 54 | h_size_patch_testing = 640 55 | h_overlap_size = 305 56 | w_size_patch_testing = 896 57 | w_overlap_size = 352 58 | sample = {} 59 | sample['clean_image_first'] = frame1.cuda() 60 | sample['clean_image_last'] = frame3.cuda() 61 | sample['voxel_grid_0t'] = voxel_0t.cuda() 62 | sample['voxel_grid_t1'] = voxel_t1.cuda() 63 | sample['voxel_grid_t0'] = voxel_t0.cuda() 64 | 65 | B, C, H, W = frame1.shape 66 | 67 | h_stride = h_size_patch_testing - h_overlap_size 68 | w_stride = w_size_patch_testing - w_overlap_size 69 | h_idx_list = list(range(0, H-h_size_patch_testing, h_stride)) + [max(0, H-h_size_patch_testing)] 70 | w_idx_list = list(range(0, W-w_size_patch_testing, w_stride)) + [max(0, W-w_size_patch_testing)] 71 | # output 72 | E = torch.zeros(B, C, H, W).cuda() 73 | W_ = torch.zeros_like(E).cuda() 74 | input_keys = ['clean_image_first', 'clean_image_last', 'voxel_grid_0t', 'voxel_grid_t1', 'voxel_grid_t0'] 75 | not_overlap_border = True 76 | for h_idx in h_idx_list: 77 | for w_idx in w_idx_list: 78 | _sample = {} 79 | for input_key in input_keys: 80 | _sample[input_key] = sample[input_key][..., h_idx:h_idx+h_size_patch_testing, w_idx:w_idx+w_size_patch_testing] 81 | model.set_test_input(_sample) 82 | model.forward_joint_test() 83 | out_patch = model.test_outputs['interp_out'] 84 | out_patch_mask = torch.ones_like(out_patch) 85 | if not_overlap_border: 86 | if h_idx < h_idx_list[-1]: 87 | out_patch[..., -h_overlap_size//2:, :] *= 0 88 | out_patch_mask[..., -h_overlap_size//2:, :] *= 0 89 | if w_idx < w_idx_list[-1]: 90 | out_patch[..., -w_overlap_size//2:] *= 0 91 | out_patch_mask[..., -w_overlap_size//2:] *= 0 92 | if h_idx > h_idx_list[0]: 93 | out_patch[..., :h_overlap_size//2, :] *= 0 94 | out_patch_mask[..., :h_overlap_size//2, :] *= 0 95 | if w_idx > w_idx_list[0]: 96 | out_patch[..., :w_overlap_size//2] *= 0 97 | out_patch_mask[..., :w_overlap_size//2] *= 0 98 | E[:, :, h_idx:(h_idx+h_size_patch_testing), w_idx:(w_idx+w_size_patch_testing)].add_(out_patch) 99 | W_[:, :, h_idx:(h_idx+h_size_patch_testing), 
w_idx:(w_idx+w_size_patch_testing)].add_(out_patch_mask) 100 | output = E.div_(W_) 101 | clean_middle_np = tensor2numpy(output) 102 | ## save output 103 | os.makedirs(args.save_output_dir, exist_ok=True) 104 | ## _0,_2 is output 105 | cv2.imwrite(os.path.join(args.save_output_dir, str(args.image_number).zfill(5) + '_0.png'), cv2.cvtColor(first_image_np, cv2.COLOR_RGB2BGR)) 106 | cv2.imwrite(os.path.join(args.save_output_dir, str(args.image_number).zfill(5) + '_2.png'), cv2.cvtColor(second_image_np, cv2.COLOR_RGB2BGR)) 107 | ## _1 is output 108 | cv2.imwrite(os.path.join(args.save_output_dir, str(args.image_number).zfill(5) + '_1.png'), cv2.cvtColor(clean_middle_np, cv2.COLOR_RGB2BGR)) -------------------------------------------------------------------------------- /sample_data/00000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelpro/CBMNet/248405c313ce29f87ad42873df46792c76f55d1a/sample_data/00000.png -------------------------------------------------------------------------------- /sample_data/00000_0t.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelpro/CBMNet/248405c313ce29f87ad42873df46792c76f55d1a/sample_data/00000_0t.npz -------------------------------------------------------------------------------- /sample_data/00000_t0.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelpro/CBMNet/248405c313ce29f87ad42873df46792c76f55d1a/sample_data/00000_t0.npz -------------------------------------------------------------------------------- /sample_data/00000_t1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelpro/CBMNet/248405c313ce29f87ad42873df46792c76f55d1a/sample_data/00000_t1.npz -------------------------------------------------------------------------------- /sample_data/00001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intelpro/CBMNet/248405c313ce29f87ad42873df46792c76f55d1a/sample_data/00001.png -------------------------------------------------------------------------------- /test_bsergb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.utils import * 3 | from tqdm import tqdm 4 | import argparse 5 | import os 6 | from utils.dataloader_bsergb import * 7 | from models.model_manager import OurModel 8 | from utils.flow_utils import * 9 | import torchvision.utils as vutils 10 | 11 | 12 | def get_argument(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--val_batch_size', type=int, default=1) 15 | parser.add_argument('--voxel_num_bins', type=int, default=16) 16 | parser.add_argument('--flow_tb_debug', type=str2bool, default='False') 17 | parser.add_argument('--flow_tb_viz', type=str2bool, default='True') 18 | parser.add_argument('--warp_tb_debug', type=str2bool, default='True') 19 | parser.add_argument('--val_mode', type=str2bool, default='False') 20 | parser.add_argument('--val_skip_num_list', default=[1, 3]) 21 | parser.add_argument('--model_folder', type=str, default='final_models') 22 | parser.add_argument('--model_name', type=str, default='ours_large') 23 | parser.add_argument('--use_smoothness_loss', type=str2bool, default='True') 24 | parser.add_argument('--smoothness_weight', type=float, default=10.0) 25 | 
parser.add_argument('--num_threads', type=int, default=12) 26 | parser.add_argument('--experiment_name', type=str, default='test_bsergb_dataset') 27 | parser.add_argument('--tb_update_thresh', type=int, default=1) 28 | parser.add_argument('--data_dir', type=str, default='/home/user/dataset/bsergb_interpolation_v2/') 29 | parser.add_argument('--ckpt_path', type=str, default='pretrained_model/Ours_Large_BSERGB.pth') 30 | parser.add_argument('--use_multigpu', type=str2bool, default='True') 31 | parser.add_argument('--train_skip_num_list', default=[1, 3]) 32 | args = parser.parse_args() 33 | return args 34 | 35 | 36 | class Trainer(object): 37 | def __init__(self, args): 38 | self.args = args 39 | self._init_model() 40 | self._init_metrics() 41 | self._init_dataloader() 42 | 43 | def _init_dataloader(self): 44 | val_set_dict = get_BSERGB_val_dataset(self.args.data_dir, self.args.val_skip_num_list, mode='1_TEST') 45 | self.val_loader_dict = {} 46 | for skip_num, val_dataset in val_set_dict.items(): 47 | self.val_loader_dict[skip_num] = torch.utils.data.DataLoader( 48 | val_dataset, 49 | batch_size=self.args.val_batch_size, 50 | shuffle=False, 51 | num_workers=self.args.num_threads, 52 | pin_memory=True 53 | ) 54 | 55 | def _init_model(self): 56 | self.model = OurModel(self.args) 57 | self.model.initialize(self.args.model_folder, self.args.model_name) 58 | ckpt = torch.load(self.args.ckpt_path, map_location='cpu') 59 | self.model.load_model(ckpt['model_state_dict']) 60 | 61 | if torch.cuda.is_available(): 62 | self.model.cuda() 63 | 64 | if self.args.use_multigpu: 65 | self.model.use_multi_gpu() 66 | 67 | def _init_metrics(self): 68 | self.PSNR_calculator = PSNR() 69 | 70 | def test_joint(self, epoch=0): 71 | psnr_total = AverageMeter() 72 | psnr_interval = AverageMeter() 73 | 74 | self.model.eval() 75 | self.model.set_mode('joint') 76 | 77 | os.makedirs('./outputs', exist_ok=True) 78 | os.makedirs(f'./logs/{self.args.experiment_name}', exist_ok=True) 79 | 80 | with torch.no_grad(): 81 | for skip_num, val_loader in self.val_loader_dict.items(): 82 | output_save_path = f'./outputs/net_out/{skip_num}skip' 83 | gt_save_path = f'./outputs/gt/{skip_num}skip' 84 | os.makedirs(output_save_path, exist_ok=True) 85 | os.makedirs(gt_save_path, exist_ok=True) 86 | for i, sample in enumerate(tqdm(val_loader, desc=f'val skip {skip_num}')): 87 | sample = batch2device(sample) 88 | self.model.set_test_input(sample) 89 | self.model.forward_joint_test() 90 | 91 | gt = sample['clean_middle'] 92 | pred = self.model.test_outputs['interp_out'] 93 | 94 | psnr = self.PSNR_calculator(gt, pred).mean().item() 95 | 96 | psnr_interval.update(psnr) 97 | psnr_total.update(psnr) 98 | 99 | output_name = os.path.join(output_save_path, str(i).zfill(5) + '.png') 100 | vutils.save_image(pred[0], output_name) 101 | gt_name = os.path.join(gt_save_path, str(i).zfill(5) + '.png') 102 | vutils.save_image(gt[0], gt_name) 103 | 104 | print(f"[Skip {skip_num}] PSNR: {psnr_interval.avg:.2f}") 105 | psnr_interval.reset() 106 | 107 | avg_psnr = psnr_total.avg 108 | 109 | print(f"\n[Test Summary] Avg PSNR: {avg_psnr:.2f}") 110 | 111 | log_path = os.path.join('./logs', self.args.experiment_name, f'test_result_epoch{epoch}.txt') 112 | with open(log_path, 'w') as f: 113 | f.write(f"Experiment: {self.args.experiment_name}\n") 114 | f.write(f"Epoch: {epoch}\n") 115 | f.write(f"Average PSNR: {avg_psnr:.2f}\n") 116 | 117 | torch.cuda.empty_cache() 118 | self.model.test_outputs = {} 119 | return avg_psnr 120 | 121 | 122 | if __name__ == '__main__': 123 
| args = get_argument() 124 | trainer = Trainer(args) 125 | trainer.test_joint(epoch=0) -------------------------------------------------------------------------------- /tools/event_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def extract_events(event): 5 | x = event['x'] 6 | y = event['y'] 7 | t = event['timestamp'] 8 | p = event['polarity'] 9 | return np.stack([t, x, y, p], axis=1) # shape: (N, 4) 10 | 11 | def event_reverse(events): 12 | end_time = events[:, 0].max() 13 | events[:, 0] = end_time - events[:, 0] 14 | events[:, 3][events[:, 3] == 0] = -1 15 | events[:, 3] = -events[:, 3] 16 | events = np.copy(np.flipud(events)) 17 | return events 18 | 19 | def events_to_voxel_grid(events, num_bins, width, height): 20 | assert events.shape[1] == 4 21 | voxel_grid = np.zeros((num_bins, height, width), np.float32).ravel() 22 | 23 | last_stamp, first_stamp = events[-1, 0], events[0, 0] 24 | deltaT = max(last_stamp - first_stamp, 1.0) 25 | 26 | events[:, 0] = (num_bins - 1) * (events[:, 0] - first_stamp) / deltaT 27 | ts, xs, ys, pols = events[:, 0], events[:, 1].astype(int), events[:, 2].astype(int), events[:, 3] 28 | pols[pols == 0] = -1 29 | 30 | # 인덱스 범위 제한 31 | xs = np.clip(xs, 0, width - 1) 32 | ys = np.clip(ys, 0, height - 1) 33 | tis = np.clip(ts.astype(int), 0, num_bins - 1) 34 | 35 | dts = ts - tis 36 | vals_left = pols * (1.0 - dts) 37 | vals_right = pols * dts 38 | 39 | valid_indices = (tis >= 0) & (tis < num_bins) 40 | np.add.at(voxel_grid, xs[valid_indices] + ys[valid_indices] * width + tis[valid_indices] * width * height, vals_left[valid_indices]) 41 | 42 | valid_indices = (tis >= 0) & ((tis + 1) < num_bins) 43 | np.add.at(voxel_grid, xs[valid_indices] + ys[valid_indices] * width + (tis[valid_indices] + 1) * width * height, vals_right[valid_indices]) 44 | 45 | return voxel_grid.reshape((num_bins, height, width)) 46 | -------------------------------------------------------------------------------- /tools/preprocess_events.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from tqdm import tqdm 4 | from event_utils import * 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser(description="Event voxel grid generation arguments") 8 | 9 | parser.add_argument('--skip_nums', type=int, nargs='+', default=[1, 3], 10 | help='List of skip numbers (e.g., --skip_nums 1 3)') 11 | parser.add_argument('--dataset_dir', type=str, default='/home/user/dataset/bs_ergb', 12 | help='Path to input dataset directory') 13 | parser.add_argument('--mode', type=str, default='1_TEST', 14 | help='Dataset mode (e.g., 1_TEST)') 15 | parser.add_argument('--voxel_prefix', type=str, default='event_voxel_grid_bin16', 16 | help='Prefix for voxel grid folder') 17 | 18 | return parser.parse_args() 19 | 20 | 21 | if __name__ == '__main__': 22 | args = parse_args() 23 | skip_num_list = args.skip_nums 24 | dataset_dir = args.dataset_dir 25 | mode = args.mode 26 | event_voxel_dir_prefix = args.voxel_prefix 27 | 28 | width, height, num_bins = 970, 625, 16 29 | dataset_with_mode = os.path.join(dataset_dir, mode) 30 | scene_list = sorted(os.listdir(dataset_with_mode)) 31 | 32 | for scene_name in scene_list: 33 | image_dir = os.path.join(dataset_with_mode, scene_name, 'images') 34 | index_list = sorted([f.split('.png')[0] for f in os.listdir(image_dir) if f.endswith('.png')]) 35 | event_dir = os.path.join(dataset_with_mode, scene_name, 'events') 36 | 37 | for 
skip_num in skip_num_list: 38 | save_dir = os.path.join(dataset_dir, mode, scene_name, event_voxel_dir_prefix, f"{skip_num}skip") 39 | os.makedirs(save_dir, exist_ok=True) 40 | 41 | num_triplets = (len(index_list) - 1) // (skip_num + 1) 42 | triplets = [] 43 | ## gathering triplets 44 | for i in range(num_triplets): 45 | start = i * (skip_num+1) 46 | end = start + (skip_num+1) 47 | for i in range(1, skip_num+1): 48 | middle = start + i 49 | triplets.append((start, middle, end)) 50 | for start_idx, middle_idx, end_idx in tqdm(triplets, desc=f"[{mode}] {scene_name} - skip{skip_num}"): 51 | event_0t = np.concatenate([extract_events(np.load(os.path.join(event_dir, f"{idx:06d}.npz"))) 52 | for idx in range(start_idx, middle_idx)], axis=0) 53 | event_t1 = np.concatenate([extract_events(np.load(os.path.join(event_dir, f"{idx:06d}.npz"))) 54 | for idx in range(middle_idx, end_idx)], axis=0) 55 | 56 | # event_0t 57 | if event_0t.shape[0] > 0: 58 | mask_0t = (event_0t[:, 1] / 32 < width) & (event_0t[:, 2] / 32 < height) 59 | _0t = np.column_stack((event_0t[mask_0t][:, 0], 60 | event_0t[mask_0t][:, 1] / 32, 61 | event_0t[mask_0t][:, 2] / 32, 62 | event_0t[mask_0t][:, 3])) 63 | event_0t_vox = events_to_voxel_grid(_0t, num_bins, width, height) 64 | event_t0_vox = events_to_voxel_grid(event_reverse(_0t.copy()), num_bins, width, height) 65 | else: 66 | event_0t_vox = event_t0_vox = np.zeros((num_bins, height, width)) 67 | 68 | # event_t1 69 | if event_t1.shape[0] > 0: 70 | mask_t1 = (event_t1[:, 1] / 32 < width) & (event_t1[:, 2] / 32 < height) 71 | _t1 = np.column_stack((event_t1[mask_t1][:, 0], 72 | event_t1[mask_t1][:, 1] / 32, 73 | event_t1[mask_t1][:, 2] / 32, 74 | event_t1[mask_t1][:, 3])) 75 | event_t1_vox = events_to_voxel_grid(_t1, num_bins, width, height) 76 | else: 77 | event_t1_vox = np.zeros((num_bins, height, width)) 78 | 79 | # Save 80 | base_name = f"{start_idx:06d}-{middle_idx:06d}-{end_idx:06d}" 81 | np.savez_compressed(os.path.join(save_dir, base_name + '_0t.npz'), data=event_0t_vox) 82 | np.savez_compressed(os.path.join(save_dir, base_name + '_t0.npz'), data=event_t0_vox) 83 | np.savez_compressed(os.path.join(save_dir, base_name + '_t1.npz'), data=event_t1_vox) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.utils import * 3 | from torch.optim import AdamW 4 | from tensorboardX import SummaryWriter 5 | from tqdm import tqdm, trange 6 | import datetime 7 | import argparse 8 | import os 9 | from tools.unused.dataloader_bsergb import * 10 | from models.model_manager import OurModel 11 | import torch.optim as optim 12 | from utils.flow_utils import * 13 | 14 | 15 | 16 | def get_argument(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--total_epochs', type = int, default=301) 19 | parser.add_argument('--end_epochs_flow', type = int, default=100) 20 | parser.add_argument('--batch_size', type = int, default=1) 21 | parser.add_argument('--val_batch_size', type = int, default=1) 22 | # training params 23 | parser.add_argument('--voxel_num_bins', type = int, default=16) 24 | parser.add_argument('--crop_size', type = int, default=256) 25 | parser.add_argument('--learning_rate', type = float, default=1e-4) 26 | parser.add_argument('--mode', type = str, default='flow') 27 | parser.add_argument('--flow_tb_debug', type = str2bool, default='True') 28 | parser.add_argument('--flow_tb_viz', type = str2bool, default='True') 29 | 
parser.add_argument('--warp_tb_debug', type = str2bool, default='True') 30 | ## val folder 31 | parser.add_argument('--val_mode', type = str2bool, default='False') 32 | parser.add_argument('--val_skip_num_list', default=[1, 3]) 33 | # model discription 34 | parser.add_argument('--model_folder', type=str, default='final_models') 35 | parser.add_argument('--model_name', type=str, default='ours') 36 | parser.add_argument('--use_smoothness_loss', type=str2bool, default='True') 37 | parser.add_argument('--smoothness_weight', type = float, default=10.0) 38 | # data loading params 39 | parser.add_argument('--num_threads', type = int, default=12) 40 | parser.add_argument('--experiment_name', type = str, default='train_bsergb_networks') 41 | parser.add_argument('--tb_update_thresh', type = int, default=1) 42 | parser.add_argument('--data_dir', type = str, default = '/media/mnt2/bs_ergb') 43 | parser.add_argument('--use_multigpu', type=str2bool, default='True') 44 | parser.add_argument('--train_skip_num_list', default=[1, 3]) 45 | # loading module 46 | args = parser.parse_args() 47 | return args 48 | 49 | 50 | class Trainer(object): 51 | def __init__(self, args): 52 | self.args = args 53 | self._init_counters() 54 | self._init_tensorboard() 55 | self._init_dataloader() 56 | self._init_model() 57 | self._init_optimizer() 58 | self._init_scheduler() 59 | self._init_metrics() 60 | 61 | def _init_counters(self): 62 | self.tb_iter_cnt = 0 63 | self.tb_iter_cnt_val = 0 64 | self.tb_iter_cnt2 = 0 65 | self.tb_iter_cnt2_val = 0 66 | self.tb_iter_thresh = self.args.tb_update_thresh 67 | self.batchsize = self.args.batch_size 68 | self.start_epoch = 0 69 | self.end_epoch = self.args.total_epochs 70 | self.best_psnr = 0.0 71 | self.start_epoch_flow = 0 72 | self.end_epoch_flow = self.args.end_epochs_flow 73 | self.start_epoch_joint = self.args.end_epochs_flow + 1 74 | 75 | def _init_tensorboard(self): 76 | timestamp = datetime.datetime.now().strftime('%y%m%d-%H%M') 77 | tb_path = os.path.join('./experiments', f"{timestamp}-{self.args.experiment_name}") 78 | self.tb = SummaryWriter(tb_path, flush_secs=1) 79 | 80 | def _init_dataloader(self): 81 | ## train set 82 | train_set = get_BSERGB_train_dataset(self.args.data_dir, self.args.train_skip_num_list, mode='3_TRAINING') 83 | self.train_loader = torch.utils.data.DataLoader(train_set, batch_size=self.args.batch_size, shuffle=True, num_workers=self.args.num_threads, pin_memory=True, drop_last=True) 84 | ## val set 85 | val_set_dict = get_BSERGB_val_dataset(self.args.data_dir, self.args.val_skip_num_list, mode='1_TEST') 86 | # make loader per skip 87 | self.val_loader_dict = {} 88 | for skip_num, val_dataset in val_set_dict.items(): 89 | self.val_loader_dict[skip_num] = torch.utils.data.DataLoader( 90 | val_dataset, 91 | batch_size=self.args.val_batch_size, 92 | shuffle=False, 93 | num_workers=self.args.num_threads, 94 | pin_memory=True 95 | ) 96 | 97 | def _init_model(self): 98 | self.model = OurModel(self.args) 99 | self.model.initialize(self.args.model_folder, self.args.model_name) 100 | 101 | if torch.cuda.is_available(): 102 | self.model.cuda() 103 | 104 | if self.args.use_multigpu: 105 | self.model.use_multi_gpu() 106 | 107 | def _init_optimizer(self): 108 | params = self.model.get_optimizer_params() 109 | self.optimizer = AdamW(params, lr=self.args.learning_rate) 110 | 111 | def _init_scheduler(self): 112 | if self.args.mode == 'joint': 113 | milestones = [200, 270] 114 | elif self.args.mode == 'flow': 115 | milestones = [60] 116 | else: 117 | milestones = [] 
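        # Editor's note: MultiStepLR(gamma=0.5) below halves the learning rate at each
        # milestone -- once at epoch 60 when pre-training the flow network (mode='flow'),
        # and at epochs 200 and 270 during joint training (mode='joint').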
118 | 119 | if milestones: 120 | self.scheduler = optim.lr_scheduler.MultiStepLR( 121 | self.optimizer, 122 | milestones=milestones, 123 | gamma=0.5 124 | ) 125 | 126 | def _init_metrics(self): 127 | self.PSNR_calculator = PSNR() 128 | self.SSIM_calculator = SSIM() 129 | 130 | def mode_classify(self): 131 | # Mode override by argument 132 | if self.args.mode == 'joint': 133 | mode = 'joint' 134 | elif self.epoch <= self.end_epoch_flow: 135 | mode = 'flow' 136 | elif self.start_epoch_joint <= self.epoch <= self.end_epoch: 137 | mode = 'joint' 138 | else: 139 | raise ValueError(f"Invalid epoch {self.epoch} for mode classification.") 140 | self.model.set_mode(mode) 141 | # Automatically freeze flownet if in joint mode 142 | if mode == 'joint': 143 | self.model.fix_flownet() 144 | return mode 145 | 146 | def train(self): 147 | for self.epoch in trange(self.start_epoch, self.end_epoch, desc='epoch progress'): 148 | self.model.train() 149 | mode_now = self.mode_classify() 150 | 151 | for _, sample in enumerate(tqdm(self.train_loader, desc='train progress')): 152 | self.train_step(sample, mode=mode_now) 153 | 154 | if self.epoch % 10 == 0 and mode_now == 'joint': 155 | psnr_val, _ = self.val_joint(self.epoch) 156 | if psnr_val > self.best_psnr: 157 | self.best_psnr = psnr_val 158 | self.save_model(self.epoch, best=True) 159 | print(f"[Best Model Updated] Epoch {self.epoch} - PSNR: {psnr_val:.2f}") 160 | 161 | self.scheduler.step() 162 | 163 | 164 | def train_step(self, sample, mode): 165 | # --- Move batch to device and zero optimizer --- 166 | sample = batch2device(sample) 167 | self.optimizer.zero_grad() 168 | 169 | # --- Set input for model --- 170 | self.model.set_train_input(sample) 171 | 172 | # --- Forward pass and compute loss --- 173 | self.model.forward_nets() 174 | if mode == 'flow': 175 | loss = self.model.get_flow_loss() 176 | elif mode == 'joint': 177 | loss = self.model.get_multi_scale_loss() 178 | else: 179 | raise ValueError(f"Unsupported mode: {mode}") 180 | 181 | # --- Backpropagation and optimization --- 182 | loss.backward() 183 | self.optimizer.step() 184 | 185 | # --- Update training status --- 186 | self.model.update_loss_meters(mode) 187 | self.tb_iter_cnt += 1 188 | 189 | if self.batchsize * self.tb_iter_cnt > self.tb_iter_thresh: 190 | self.log_train_tb(mode) 191 | 192 | # --- Clean up --- 193 | del sample 194 | 195 | 196 | def log_train_tb(self, mode): 197 | def add_scalar(tag, value): 198 | self.tb.add_scalar(tag, value, self.tb_iter_cnt2) 199 | 200 | def add_image(tag, image): 201 | self.tb.add_image(tag, image, self.tb_iter_cnt2) 202 | 203 | def add_flow_image(tag, flow_tensor): 204 | flow_img = flow_to_image(flow_tensor.detach().cpu().permute(1, 2, 0).numpy()).transpose(2, 0, 1) 205 | add_image(tag, flow_img) 206 | 207 | # --- Log loss values --- 208 | add_scalar('train_progress/loss_total', self.model.loss_handler.loss_total_meter.avg) 209 | add_scalar('train_progress/loss_flow', self.model.loss_handler.loss_flow_meter.avg) 210 | add_scalar('train_progress/loss_warp', self.model.loss_handler.loss_warp_meter.avg) 211 | add_scalar('train_progress/loss_smoothness', self.model.loss_handler.loss_smoothness_meter.avg) 212 | 213 | # --- Log interpolation input images --- 214 | add_image('train_image/clean_image_first', self.model.batch['image_input0'][0]) 215 | add_image('train_image/clean_image_last', self.model.batch['image_input1'][0]) 216 | add_image('train_image/interp_gt', self.model.batch['clean_gt_images'][0]) 217 | 218 | # --- Log predicted optical flow 
(estimated) --- 219 | if self.args.flow_tb_viz: 220 | add_flow_image('train_flow/flow_t0_est', self.model.outputs['flow_out']['flow_t0_dict'][0][0]) 221 | add_flow_image('train_flow/flow_t1_est', self.model.outputs['flow_out']['flow_t1_dict'][0][0]) 222 | 223 | # --- Debug intermediate flow results --- 224 | if self.args.flow_tb_debug: 225 | add_flow_image('train_flow_debug_0/flow_event', self.model.outputs['flow_out']['event_flow_dict'][0][0]) 226 | add_flow_image('train_flow_debug_0/flow_image', self.model.outputs['flow_out']['image_flow_dict'][0][0]) 227 | add_flow_image('train_flow_debug_0/flow_fusion', self.model.outputs['flow_out']['fusion_flow_dict'][0][0]) 228 | add_image('train_flow_debug_0/event_flow_mask', self.model.outputs['flow_out']['mask_dict'][0][0]) 229 | 230 | # --- Joint training-specific logging --- 231 | if mode == 'joint': 232 | add_scalar('train_progress/loss_image', self.model.loss_handler.loss_image_meter.avg) 233 | add_image('train_image/interp_out', self.model.outputs['interp_out'][0][0]) 234 | elif mode == 'flow': 235 | # --- Warp output visualization --- 236 | if self.args.warp_tb_debug: 237 | add_image('train_warp_output/warp_image_0t', self.model.batch['imaget_est0_warp'][0][0]) 238 | add_image('train_warp_output/warp_image_t1', self.model.batch['imaget_est1_warp'][0][0]) 239 | add_image('train_warp_output/warp_image_gt', self.model.batch['clean_gt_images'][0]) 240 | 241 | # --- Update counters and reset meters --- 242 | self.tb_iter_cnt2 += 1 243 | self.tb_iter_cnt = 0 244 | self.model.loss_handler.reset_meters() 245 | 246 | def val_joint(self, epoch): 247 | # Total and per-interval metric meters 248 | psnr_total = AverageMeter() 249 | ssim_total = AverageMeter() 250 | psnr_interval = AverageMeter() 251 | ssim_interval = AverageMeter() 252 | 253 | # Set model to evaluation mode 254 | self.model.eval() 255 | # set model mode 256 | self.model.set_mode('joint') 257 | 258 | with torch.no_grad(): 259 | for skip_num, val_loader in self.val_loader_dict.items(): 260 | for _, sample in enumerate(tqdm(val_loader, desc=f'val skip {skip_num}')): 261 | sample = batch2device(sample) 262 | self.model.set_test_input(sample) 263 | self.model.forward_joint_test() 264 | 265 | gt = sample['clean_middle'] 266 | pred = self.model.test_outputs['interp_out'] 267 | 268 | psnr = self.PSNR_calculator(gt, pred).mean().item() 269 | ssim = self.SSIM_calculator(gt, pred).mean().item() 270 | 271 | psnr_interval.update(psnr) 272 | ssim_interval.update(ssim) 273 | 274 | psnr_total.update(psnr) 275 | ssim_total.update(ssim) 276 | 277 | # Log per interval result 278 | self.tb.add_scalar(f'val_progress/BSERGB/{skip_num}skip/avg_psnr_interp', psnr_interval.avg, epoch) 279 | self.tb.add_scalar(f'val_progress/BSERGB/{skip_num}skip/avg_ssim_interp', ssim_interval.avg, epoch) 280 | 281 | psnr_interval.reset() 282 | ssim_interval.reset() 283 | 284 | # Log total result 285 | self.tb.add_scalar('val_progress/BSERGB/average/avg_psnr_interp', psnr_total.avg, epoch) 286 | self.tb.add_scalar('val_progress/BSERGB/average/avg_ssim_interp', ssim_total.avg, epoch) 287 | 288 | torch.cuda.empty_cache() 289 | self.model.test_outputs = {} 290 | return psnr_total.avg, ssim_total.avg 291 | 292 | def save_model(self, epoch): 293 | combined_state_dict = { 294 | 'epoch': self.epoch, 295 | 'model_state_dict': self.model.net.state_dict(), 296 | 'Optimizer_state_dict' : self.optimizer.state_dict(), 297 | 'Scheduler_state_dict' : self.scheduler.state_dict()} 298 | torch.save(combined_state_dict, 
os.path.join(self.model.save_path, 'best_model_' + str(epoch) + '_ep.pth')) 299 | 300 | 301 | if __name__=='__main__': 302 | args = get_argument() 303 | trainer = Trainer(args) 304 | if args.val_mode == True: 305 | trainer.val_joint(0) 306 | else: 307 | trainer.train() 308 | -------------------------------------------------------------------------------- /utils/dataloader_bsergb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.utils.data as data 3 | from torchvision import transforms 4 | import numpy as np 5 | import random 6 | from PIL import Image 7 | from torch.utils.data import ConcatDataset 8 | import re 9 | 10 | 11 | 12 | 13 | class BSERGB_val_dataset(data.Dataset): 14 | def __init__(self, data_path, skip_num): 15 | super(BSERGB_val_dataset, self).__init__() 16 | # Image and event prefix 17 | self.event_vox_prefix = 'event_voxel_grid_bin16' 18 | self.sharp_image_prefix = 'images' 19 | # skip list 20 | self.skip_num = skip_num 21 | # Transform 22 | self.transform = transforms.ToTensor() 23 | # Initialize file taxonomy 24 | self.get_filetaxnomy(data_path) 25 | ## crop size 26 | self.crop_height = 256 27 | self.crop_width = 256 28 | 29 | def get_filetaxnomy(self, data_dir): 30 | event_voxel_dir = os.path.join(data_dir, self.event_vox_prefix, str(self.skip_num) + 'skip') 31 | clean_image_dir = os.path.join(data_dir, self.sharp_image_prefix) 32 | index_list = [f.split('.png')[0] for f in os.listdir(clean_image_dir) if f.endswith(".png")] 33 | index_list.sort() 34 | self.input_name_dict = {} 35 | self.input_name_dict['event_voxels_0t'] = [] 36 | self.input_name_dict['event_voxels_t0'] = [] 37 | self.input_name_dict['event_voxels_t1'] = [] 38 | self.input_name_dict['clean_image_first'] = [] 39 | self.input_name_dict['clean_image_last'] = [] 40 | self.input_name_dict['gt_image'] = [] 41 | num_triplets = (len(index_list)-1) // (self.skip_num+1) 42 | triplets = [] 43 | ## gathering triplets 44 | for i in range(num_triplets): 45 | start = i * (self.skip_num+1) 46 | end = start + (self.skip_num+1) 47 | for i in range(1, self.skip_num+1): 48 | middle = start + i 49 | triplets.append((start, middle, end)) 50 | for triplet in triplets: 51 | first_idx = triplet[0] 52 | middle_idx = triplet[1] 53 | end_idx = triplet[2] 54 | self.input_name_dict['clean_image_first'].append(os.path.join(clean_image_dir, str(first_idx).zfill(6) + '.png')) 55 | self.input_name_dict['clean_image_last'].append(os.path.join(clean_image_dir, str(end_idx).zfill(6) + '.png')) 56 | self.input_name_dict['gt_image'].append(os.path.join(clean_image_dir, str(middle_idx).zfill(6) + '.png')) 57 | self.input_name_dict['event_voxels_0t'].append(os.path.join(event_voxel_dir, f"{first_idx:06d}-{middle_idx:06d}-{end_idx:06d}_0t.npz")) 58 | self.input_name_dict['event_voxels_t0'].append(os.path.join(event_voxel_dir, f"{first_idx:06d}-{middle_idx:06d}-{end_idx:06d}_t0.npz")) 59 | self.input_name_dict['event_voxels_t1'].append(os.path.join(event_voxel_dir, f"{first_idx:06d}-{middle_idx:06d}-{end_idx:06d}_t1.npz")) 60 | 61 | def __getitem__(self, index): 62 | # first image 63 | first_image_path = self.input_name_dict['clean_image_first'][index] 64 | first_image = Image.open(first_image_path) 65 | first_image_tensor = self.transform(first_image) 66 | # second image 67 | second_image_path = self.input_name_dict['clean_image_last'][index] 68 | second_image = Image.open(second_image_path) 69 | second_image_tensor = self.transform(second_image) 70 | # gt image 71 | gt_image_path = 
self.input_name_dict['gt_image'][index] 72 | gt_image = Image.open(gt_image_path) 73 | gt_image_tensor = self.transform(gt_image) 74 | ## event voxel 75 | # 0t voxel 76 | event_vox_0t_path = self.input_name_dict['event_voxels_0t'][index] 77 | event_vox_0t = np.load(event_vox_0t_path)["data"] 78 | # t1 voxel 79 | event_vox_t1_path = self.input_name_dict['event_voxels_t1'][index] 80 | event_vox_t1 = np.load(event_vox_t1_path)["data"] 81 | # 1t voxel 82 | event_vox_t0_path = self.input_name_dict['event_voxels_t0'][index] 83 | event_vox_t0 = np.load(event_vox_t0_path)["data"] 84 | ## return sample!! 85 | sample = dict() 86 | sample['clean_middle'] = gt_image_tensor 87 | sample['clean_image_first'] = first_image_tensor 88 | sample['clean_image_last'] = second_image_tensor 89 | sample['voxel_grid_0t'] = event_vox_0t 90 | sample['voxel_grid_t1'] = event_vox_t1 91 | sample['voxel_grid_t0'] = event_vox_t0 92 | return sample 93 | 94 | def __len__(self): 95 | return len(self.input_name_dict['gt_image']) 96 | 97 | def get_BSERGB_val_dataset(data_dir, skip_list, mode='1_TEST'): 98 | dataset_path_sub = os.path.join(data_dir, mode) 99 | scene_list = sorted(os.listdir(dataset_path_sub)) 100 | dataset_dict = {} 101 | for skip_num in skip_list: 102 | dataset_list = [] 103 | for scene in scene_list: 104 | dataset_path_full = os.path.join(dataset_path_sub, scene) 105 | dset = BSERGB_val_dataset(dataset_path_full, skip_num) 106 | dataset_list.append(dset) 107 | dataset_concat = ConcatDataset(dataset_list) 108 | dataset_dict[skip_num] = dataset_concat 109 | return dataset_dict 110 | 111 | 112 | -------------------------------------------------------------------------------- /utils/flow_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import inspect 6 | UNKNOWN_FLOW_THRESH = 1e7 7 | 8 | 9 | def compute_color(u, v): 10 | """ 11 | compute optical flow color map 12 | :param u: optical flow horizontal map 13 | :param v: optical flow vertical map 14 | :return: optical flow in color code 15 | """ 16 | [h, w] = u.shape 17 | img = np.zeros([h, w, 3]) 18 | nanIdx = np.isnan(u) | np.isnan(v) 19 | u[nanIdx] = 0 20 | v[nanIdx] = 0 21 | 22 | colorwheel = make_color_wheel() 23 | ncols = np.size(colorwheel, 0) 24 | 25 | rad = np.sqrt(u**2+v**2) 26 | 27 | a = np.arctan2(-v, -u) / np.pi 28 | 29 | fk = (a+1) / 2 * (ncols - 1) + 1 30 | 31 | k0 = np.floor(fk).astype(int) 32 | 33 | k1 = k0 + 1 34 | k1[k1 == ncols+1] = 1 35 | f = fk - k0 36 | 37 | for i in range(0, np.size(colorwheel,1)): 38 | tmp = colorwheel[:, i] 39 | col0 = tmp[k0-1] / 255 40 | col1 = tmp[k1-1] / 255 41 | col = (1-f) * col0 + f * col1 42 | 43 | idx = rad <= 1 44 | col[idx] = 1-rad[idx]*(1-col[idx]) 45 | notidx = np.logical_not(idx) 46 | 47 | col[notidx] *= 0.75 48 | img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx))) 49 | return img 50 | 51 | 52 | def mesh_grid(B, H, W): 53 | # mesh grid 54 | x_base = torch.arange(0, W).repeat(B, H, 1) # BHW 55 | y_base = torch.arange(0, H).repeat(B, W, 1).transpose(1, 2) # BHW 56 | base_grid = torch.stack([x_base, y_base], 1) # B2HW 57 | return base_grid 58 | 59 | 60 | def norm_grid(v_grid): 61 | _, _, H, W = v_grid.size() 62 | 63 | # scale grid to [-1,1] 64 | v_grid_norm = torch.zeros_like(v_grid) 65 | v_grid_norm[:, 0, :, :] = 2.0 * v_grid[:, 0, :, :] / (W - 1) - 1.0 66 | v_grid_norm[:, 1, :, :] = 2.0 * v_grid[:, 1, :, :] / (H - 1) - 1.0 67 | return v_grid_norm.permute(0, 2, 3, 
1) # BHW2 68 | 69 | def cal_grad2_error(flo, image, beta): 70 | """ 71 | Calculate the image-edge-aware second-order smoothness loss for flo 72 | """ 73 | def gradient(pred): 74 | D_dy = pred[:, :, 1:, :] - pred[:, :, :-1, :] 75 | D_dx = pred[:, :, :, 1:] - pred[:, :, :, :-1] 76 | return D_dx, D_dy 77 | 78 | img_grad_x, img_grad_y = gradient(image) 79 | weights_x = torch.exp(-10.0 * torch.mean(torch.abs(img_grad_x), 1, keepdim=True)) 80 | weights_y = torch.exp(-10.0 * torch.mean(torch.abs(img_grad_y), 1, keepdim=True)) 81 | 82 | dx, dy = gradient(flo) 83 | dx2, dxdy = gradient(dx) 84 | dydx, dy2 = gradient(dy) 85 | return (torch.mean(beta*weights_x[:,:, :, 1:]*torch.abs(dx2)) + torch.mean(beta*weights_y[:, :, 1:, :]*torch.abs(dy2))) / 2.0 86 | 87 | def make_colorwheel(): 88 | """ 89 | Generates a color wheel for optical flow visualization as presented in: 90 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 91 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 92 | Code follows the original C++ source code of Daniel Scharstein. 93 | Code follows the the Matlab source code of Deqing Sun. 94 | Returns: 95 | np.ndarray: Color wheel 96 | """ 97 | 98 | RY = 15 99 | YG = 6 100 | GC = 4 101 | CB = 11 102 | BM = 13 103 | MR = 6 104 | 105 | ncols = RY + YG + GC + CB + BM + MR 106 | colorwheel = np.zeros((ncols, 3)) 107 | col = 0 108 | 109 | # RY 110 | colorwheel[0:RY, 0] = 255 111 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 112 | col = col+RY 113 | # YG 114 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 115 | colorwheel[col:col+YG, 1] = 255 116 | col = col+YG 117 | # GC 118 | colorwheel[col:col+GC, 1] = 255 119 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 120 | col = col+GC 121 | # CB 122 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 123 | colorwheel[col:col+CB, 2] = 255 124 | col = col+CB 125 | # BM 126 | colorwheel[col:col+BM, 2] = 255 127 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 128 | col = col+BM 129 | # MR 130 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 131 | colorwheel[col:col+MR, 0] = 255 132 | return colorwheel 133 | 134 | 135 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 136 | """ 137 | Applies the flow color wheel to (possibly clipped) flow components u and v. 138 | According to the C++ source code of Daniel Scharstein 139 | According to the Matlab source code of Deqing Sun 140 | Args: 141 | u (np.ndarray): Input horizontal flow of shape [H,W] 142 | v (np.ndarray): Input vertical flow of shape [H,W] 143 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
144 | Returns: 145 | np.ndarray: Flow visualization image of shape [H,W,3] 146 | """ 147 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 148 | colorwheel = make_colorwheel() # shape [55x3] 149 | ncols = colorwheel.shape[0] 150 | rad = np.sqrt(np.square(u) + np.square(v)) 151 | a = np.arctan2(-v, -u)/np.pi 152 | fk = (a+1) / 2*(ncols-1) 153 | k0 = np.floor(fk).astype(np.int32) 154 | k1 = k0 + 1 155 | k1[k1 == ncols] = 0 156 | f = fk - k0 157 | for i in range(colorwheel.shape[1]): 158 | tmp = colorwheel[:,i] 159 | col0 = tmp[k0] / 255.0 160 | col1 = tmp[k1] / 255.0 161 | col = (1-f)*col0 + f*col1 162 | idx = (rad <= 1) 163 | col[idx] = 1 - rad[idx] * (1-col[idx]) 164 | col[~idx] = col[~idx] * 0.75 # out of range 165 | # Note the 2-i => BGR instead of RGB 166 | ch_idx = 2-i if convert_to_bgr else i 167 | flow_image[:,:,ch_idx] = np.floor(255*col) 168 | return flow_image 169 | 170 | def flow_warp(x, flow, pad='border', mode='bilinear'): 171 | B, _, H, W = x.size() 172 | 173 | base_grid = mesh_grid(B, H, W).type_as(x) # B2HW 174 | v_grid = norm_grid(base_grid + flow) # BHW2 175 | if 'align_corners' in inspect.getfullargspec(torch.nn.functional.grid_sample).args: 176 | im1_recons = nn.functional.grid_sample(x, v_grid, mode=mode, padding_mode=pad, align_corners=True) 177 | else: 178 | im1_recons = nn.functional.grid_sample(x, v_grid, mode=mode, padding_mode=pad) 179 | return im1_recons 180 | 181 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 182 | """ 183 | Expects a two dimensional flow image of shape. 184 | Args: 185 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 186 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 187 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
188 | Returns: 189 | np.ndarray: Flow visualization image of shape [H,W,3] 190 | """ 191 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 192 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 193 | if clip_flow is not None: 194 | flow_uv = np.clip(flow_uv, 0, clip_flow) 195 | u = flow_uv[:,:,0] 196 | v = flow_uv[:,:,1] 197 | rad = np.sqrt(np.square(u) + np.square(v)) 198 | rad_max = np.max(rad) 199 | epsilon = 1e-5 200 | u = u / (rad_max + epsilon) 201 | v = v / (rad_max + epsilon) 202 | return flow_uv_to_colors(u, v, convert_to_bgr) 203 | 204 | def flow2rgb(flow_map, max_value): 205 | flow_map_np = flow_map.detach().cpu().numpy() 206 | _, h, w = flow_map_np.shape 207 | flow_map_np[:,(flow_map_np[0] == 0) & (flow_map_np[1] == 0)] = float('nan') 208 | rgb_map = np.ones((3,h,w)).astype(np.float32) 209 | if max_value is not None: 210 | normalized_flow_map = flow_map_np / max_value 211 | else: 212 | normalized_flow_map = flow_map_np / (np.abs(flow_map_np).max()) 213 | rgb_map[0] += normalized_flow_map[0] 214 | rgb_map[1] -= 0.5*(normalized_flow_map[0] + normalized_flow_map[1]) 215 | rgb_map[2] += normalized_flow_map[1] 216 | return rgb_map.clip(0,1) 217 | 218 | def warp(x, flo, return_mask=False): 219 | B, C, H, W = x.size() 220 | # mesh grid 221 | xx = torch.arange(0, W).view(1, 1, 1, W).expand(B, 1, H, W) 222 | yy = torch.arange(0, H).view(1, 1, H, 1).expand(B, 1, H, W) 223 | 224 | grid = torch.cat((xx, yy), 1).float() 225 | 226 | if x.is_cuda: 227 | grid = grid.cuda() 228 | 229 | vgrid = torch.autograd.Variable(grid) + flo 230 | 231 | # scale grid to [-1,1] 232 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :] / max(W - 1, 1) - 1.0 233 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :] / max(H - 1, 1) - 1.0 234 | 235 | vgrid = vgrid.permute(0, 2, 3, 1) 236 | output = nn.functional.grid_sample(x, vgrid) 237 | mask = torch.autograd.Variable(torch.ones(x.size())).cuda() 238 | mask = nn.functional.grid_sample(mask, vgrid) 239 | 240 | mask = mask.masked_fill_(mask < 0.999, 0) 241 | mask = mask.masked_fill_(mask > 0, 1) 242 | 243 | if return_mask: 244 | return output, mask 245 | else: 246 | return output 247 | 248 | def EPE(input_flow, target_flow, sparse=False, mean=True): 249 | EPE_map = torch.norm(target_flow-input_flow,2,1) 250 | batch_size = EPE_map.size(0) 251 | if sparse: 252 | # invalid flow is defined with both flow coordinates to be exactly 0 253 | mask = (target_flow[:,0] == 0) & (target_flow[:,1] == 0) 254 | 255 | EPE_map = EPE_map[~mask] 256 | if mean: 257 | return EPE_map.mean() 258 | else: 259 | return EPE_map.sum()/batch_size 260 | 261 | def realEPE(output, target, sparse=False): 262 | b, _, h, w = target.size() 263 | upsampled_output = F.interpolate(output, (h,w), mode='bilinear', align_corners=False) 264 | return EPE(upsampled_output, target, sparse, mean=True) -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | from math import exp 7 | 8 | def str2bool(v): 9 | return v.lower() in ('true') 10 | 11 | def tensor2numpy(tensor, rgb_range=1.): 12 | rgb_coefficient = 255 / rgb_range 13 | img = tensor.mul(rgb_coefficient).clamp(0, 255).round() 14 | img = img[0].data 15 | img = np.transpose(img.cpu().numpy(), (1, 2, 0)).astype(np.uint8) 16 | return img 17 | 18 | def batch2device(dictionary_of_tensors): 
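    # Recursively moves every value of a (possibly nested) dict onto the GPU;
    # non-dict leaves are assumed to be tensors that support .cuda().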
19 | if isinstance(dictionary_of_tensors, dict): 20 | return {key: batch2device(value) for key, value in dictionary_of_tensors.items()} 21 | return dictionary_of_tensors.cuda() 22 | 23 | def str2bool(v): 24 | return v.lower() in ('true') 25 | 26 | def randomCrop(tensor, x, y, height, width): 27 | tensor = tensor[..., y:y+height, x:x+width] 28 | return tensor 29 | 30 | def tensor2numpy(tensor, rgb_range=1.): 31 | rgb_coefficient = 255 / rgb_range 32 | img = tensor.mul(rgb_coefficient).clamp(0, 255).round() 33 | img = img[0].data 34 | img = np.transpose(img.cpu().numpy(), (1, 2, 0)).astype(np.uint8) 35 | return img 36 | 37 | def tensor2numpy_batch_idxs(tensor, batch_idx, rgb_range=1.): 38 | rgb_coefficient = 255 / rgb_range 39 | img = tensor.mul(rgb_coefficient).clamp(0, 255).round() 40 | img = img[batch_idx].data 41 | img = np.transpose(img.cpu().numpy(), (1, 2, 0)).astype(np.uint8) 42 | return img 43 | 44 | 45 | class AverageMeter(object): 46 | """Computes and stores the average and current value""" 47 | def __init__(self): 48 | self.reset() 49 | 50 | def reset(self): 51 | self.val = 0 52 | self.avg = 0 53 | self.sum = 0 54 | self.count = 0 55 | 56 | def update(self, val, n=1): 57 | self.val = val 58 | self.sum += val * n 59 | self.count += n 60 | self.avg = self.sum / self.count 61 | 62 | def gaussian(window_size, sigma): 63 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 64 | return gauss/gauss.sum() 65 | 66 | def create_window(window_size, channel): 67 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 68 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 69 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 70 | return window 71 | 72 | class PSNR: 73 | def __init__(self): 74 | self.name = "PSNR" 75 | 76 | @staticmethod 77 | def __call__(img1, img2): 78 | img1 = img1.reshape(img1.shape[0], -1) 79 | img2 = img2.reshape(img2.shape[0], -1) 80 | mse = torch.mean((img1 - img2) ** 2, dim=1) 81 | return 10* torch.log10(1 / mse) 82 | 83 | 84 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 85 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 86 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 87 | 88 | mu1_sq = mu1.pow(2) 89 | mu2_sq = mu2.pow(2) 90 | mu1_mu2 = mu1*mu2 91 | 92 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 93 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 94 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 95 | 96 | C1 = 0.01**2 97 | C2 = 0.03**2 98 | 99 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 100 | 101 | if size_average: 102 | return ssim_map.mean() 103 | else: 104 | return ssim_map.mean(1).mean(1).mean(1) 105 | 106 | 107 | class SSIM(torch.nn.Module): 108 | def __init__(self, window_size = 11, size_average = False): 109 | super(SSIM, self).__init__() 110 | self.window_size = window_size 111 | self.size_average = size_average 112 | self.channel = 1 113 | self.window = create_window(window_size, self.channel) 114 | 115 | def forward(self, img1, img2): 116 | (_, channel, _, _) = img1.size() 117 | 118 | if channel == self.channel and self.window.data.type() == img1.data.type(): 119 | window = self.window 120 | else: 121 | window = create_window(self.window_size, channel) 122 | 123 
| if img1.is_cuda: 124 | window = window.cuda(img1.get_device()) 125 | window = window.type_as(img1) 126 | 127 | self.window = window 128 | self.channel = channel 129 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 130 | --------------------------------------------------------------------------------
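For reference, below is a minimal usage sketch of the metric helpers defined in `utils/utils.py` above (editorial illustration, run from the repository root). Inputs are expected as `[B, C, H, W]` tensors scaled to [0, 1]; with the default `size_average=False`, `SSIM` returns one value per batch element, which is why `test_bsergb.py` and `train.py` call `.mean()` on the result.

```python
import torch
from utils.utils import PSNR, SSIM

psnr_fn = PSNR()                 # 10 * log10(1 / MSE), one value per batch element
ssim_fn = SSIM(window_size=11)   # size_average=False by default -> per-sample SSIM

pred = torch.rand(2, 3, 64, 64)  # stand-in for a batch of interpolated frames
gt   = torch.rand(2, 3, 64, 64)  # stand-in for the ground-truth middle frames

print('PSNR:', psnr_fn(pred, gt).mean().item())
print('SSIM:', ssim_fn(pred, gt).mean().item())
```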