├── .gitignore
├── README.md
├── correlation_package
│   ├── __init__.py
│   ├── correlation.py
│   ├── correlation_cuda.cc
│   ├── correlation_cuda_kernel.cu
│   ├── correlation_cuda_kernel.cuh
│   ├── nvcc setting.md
│   ├── pyproject.toml
│   └── setup.py
├── figure
│   ├── ERF_dataset.png
│   ├── Quantitative_eval_ERF_x170FPS.png
│   ├── driving.gif
│   ├── flower.gif
│   └── popcorn.gif
├── install_correlation.sh
├── models
│   ├── final_models
│   │   ├── ours.py
│   │   ├── ours_large.py
│   │   └── submodules.py
│   ├── loss_handler.py
│   └── model_manager.py
├── pretrained_model
│   └── README.md
├── run_samples.py
├── sample_data
│   ├── 00000.png
│   ├── 00000_0t.npz
│   ├── 00000_t0.npz
│   ├── 00000_t1.npz
│   └── 00001.png
├── test_bsergb.py
├── tools
│   ├── event_utils.py
│   └── preprocess_events.py
├── train.py
└── utils
    ├── dataloader_bsergb.py
    ├── flow_utils.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | output
3 | experiments
4 | pretrained_model
5 | tools/unused
6 | logs
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CBMNet(CVPR 2023, highlight)
2 | **Official repository for the CVPR 2023 paper, "Event-based Video Frame Interpolation with Cross-Modal Asymmetric Bidirectional Motion Fields"**
3 |
4 | \[[Paper](https://openaccess.thecvf.com/content/CVPR2023/papers/Kim_Event-Based_Video_Frame_Interpolation_With_Cross-Modal_Asymmetric_Bidirectional_Motion_Fields_CVPR_2023_paper.pdf)\]
5 | \[[Supp](https://openaccess.thecvf.com/content/CVPR2023/supplemental/Kim_Event-Based_Video_Frame_CVPR_2023_supplemental.zip)\]
6 |
7 |
8 |
9 | ## Qualitative video demos on ERF-X170FPS dataset
10 | ### Falling pop-corn
11 |
12 |
15 |
16 | ### Flowers
17 |
18 |
21 |
22 | ### Driving scene
23 |
24 |
27 |
28 | ## ERF-X170FPS dataset
29 | #### A dataset of high-resolution (1440x975), high-fps (170fps) video frames plus high-resolution events with extremely large motion, captured using a beam-splitter acquisition system:
30 | 
31 |
32 | ### Quantitative results on the ERF-X170FPS datasets
33 |
34 |
35 |
36 | ### Downloading ERF-X170FPS datasets
37 | You can download the raw data (collected frames and events) from the links below:
38 |
39 | * [[RAW-Train](https://drive.google.com/file/d/1Bsf9qreziPcVEuf0_v3kjdPUh27zsFXK/view?usp=drive_link)]
40 | * [[RAW-Test](https://drive.google.com/file/d/1Dk7jVQD29HqRVV11e8vxg5bDOh6KxrzL/view?usp=drive_link)]
41 |
42 | **Caution:** the x, y coordinates in the raw event files are multiplied by 128.
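For reference, here is a minimal decoding sketch (it assumes the raw events are stored as `.npz` arrays; the file name and the keys `x`, `y`, `t`, `p` below are hypothetical, so check the actual files):

```python
import numpy as np

# Hypothetical file/key names -- inspect the downloaded .npz files for the real ones.
raw = np.load("raw_events_000000.npz")
x = raw["x"] / 128.0  # undo the x128 scaling of the raw coordinates noted above
y = raw["y"] / 128.0
t, p = raw["t"], raw["p"]  # timestamps and polarities are stored as-is
```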
43 |
44 | ## Requirements
45 | * PyTorch 1.8.0
46 | * CUDA 11.2
47 | * python 3.8
48 |
49 | ## Quick Usage
50 |
51 | Download repository:
52 |
53 | ```bash
54 | $ git clone https://github.com/intelpro/CBMNet
55 | ```
56 |
57 | Install correlation package:
58 |
59 | ```bash
60 | $ sh install_correlation.sh
61 | ```
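To confirm the extension was built and installed correctly, a quick import check can help (a minimal sketch; it assumes the build above succeeded and a CUDA-capable GPU is available):

```python
# Sanity check: 'correlation_cuda' is the extension name registered in correlation_package/setup.py.
import torch
import correlation_cuda  # raises ImportError if the build or install failed

print("correlation_cuda loaded; CUDA available:", torch.cuda.is_available())
```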
62 |
63 | Download the network weights (trained on the ERF-X170FPS dataset) and place the downloaded model in `./pretrained_model/`:
64 |
65 | * [[Ours](https://drive.google.com/file/d/1VJKyuoRSMOJkl8fQlJIkc7S4Fmd2_X8K/view?usp=sharing)]
66 | * [[Ours-Large](https://drive.google.com/file/d/1jI6_RwhXeM-pW5CnHf0exB5RP2zp2SbY/view?usp=sharing)]
67 |
68 |
69 | Generate an intermediate video frame using the `ours` model:
70 |
71 | ```bash
72 | $ python run_samples.py --model_name ours --ckpt_path pretrained_model/ours_weight.pth --save_output_dir ./output --image_number 0
73 | ```
74 |
75 | You can also generate an intermediate video frame using the `ours_large` model:
76 |
77 | ```bash
78 | $ python run_samples.py --model_name ours_large --ckpt_path pretrained_model/ours_large_weight.pth --save_output_dir ./output --image_number 0
79 |
80 | ```
81 |
82 |
83 | ## 🚀 Quick Test on BSERGB Dataset
84 |
85 | This section describes how to test the model on the **BSERGB dataset** using the pre-trained weights.
86 |
87 | ### 1. Download BSERGB Dataset
88 |
89 | You can download the [BSERGB dataset](https://github.com/uzh-rpg/timelens-pp) from the official TimeLens++ GitHub repository.
90 |
91 | ### 2. Preprocess Event Voxel Data
92 |
93 | After downloading, the BSERGB dataset should have the following directory structure:
94 |
95 |
96 | ```
97 | ├── BSERGB/
98 | │ ├── 1_TEST/
99 | │ │ ├── scene_001/
100 | │ │ │ ├── images/
101 | │ │ │ │ ├── 000000.png
102 | │ │ │ │ ├── ...
103 | │ │ │ ├── events/
104 | │ │ │ │ ├── 000000.npz
105 | │ │ │ │ ├── ...
106 | │ │ ├── scene_002/
107 | │ │ ├── scene_003/
108 | │ │ ├── ...
109 | │ ├── 2_VALIDATION/
110 | │ │ ├── scene_001/
111 | │ │ ├── scene_002/
112 | │ │ ├── ...
113 | │ ├── 3_TRAINING/
114 | │ │ ├── scene_001/
115 | │ │ ├── scene_002/
116 | │ │ ├── ...
117 | ```
118 |
119 | Now, convert the raw event data into event voxel grids using the following command:
120 |
121 |
122 | ```bash
123 | $ python tools/preprocess_events.py --dataset_dir BSERGB_DATASET_DIR --mode 1_TEST
124 |
125 | ```
126 | - ``--dataset_dir BSERGB_DATASET_DIR``: Specifies the BSERGB dataset directory.
127 | - ``--mode 1_TEST``: Selects which split's raw events are converted into event voxels. Choose `1_TEST` if you only want to perform testing.
128 |
129 | ### 🛠️ Event Voxel Preprocessing Output
130 |
131 | After preprocessing, event voxel files will be generated and saved into the **target folder**.
132 | For each sample, three types of voxel grids will be saved:
133 |
134 | - `0t`: Events from the start frame to the interpolated frame
135 | - `t0`: Reversed version of `0t` (used for backward flow)
136 | - `t1`: Events from the interpolated frame to the end frame
137 |
138 | Each event voxel is stored in the following naming format:
139 |
140 | `{start_index}-{interpolated_index}-{end_index}_{suffix}.npz`
141 |
142 | Each index is zero-padded using `zfill(6)`. The `{suffix}` represents one of the three types: `0t`, `t0`, or `t1`.
143 |
144 | #### 📁 Example
145 |
146 | ```text
147 | 000000-000002-000004_0t.npz # event voxel from 000000 to 000002
148 | 000000-000002-000004_t0.npz # reversed event voxel from 000002 to 000000
149 | 000000-000002-000004_t1.npz # event voxel from 000002 to 000004
150 | ```
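As a reference, here is a minimal loading sketch for one interpolation triplet (the target folder name and the `.npz` key handling below are assumptions; inspect the files written by `tools/preprocess_events.py` for the exact layout):

```python
import numpy as np

scene_dir = "BSERGB/1_TEST/scene_001/voxels"   # hypothetical output folder name
stem = "000000-000002-000004"                  # start-interpolated-end frame indices

voxels = {}
for suffix in ("0t", "t0", "t1"):
    with np.load(f"{scene_dir}/{stem}_{suffix}.npz") as data:
        voxels[suffix] = data[data.files[0]]   # first stored array; the key name may vary

print({k: v.shape for k, v in voxels.items()})
```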
151 |
152 | Once the voxel preprocessing is complete and the files are generated in the proper format, you can proceed to download the pretrained model and run the test script.
153 |
154 | ### 3. Download Pretrained Weights
155 |
156 | Download the Ours-Large weights (trained on the BSERGB dataset) and place the downloaded model inside the `./pretrained_model` directory.
157 |
158 | 🔗 **[Ours-Large(BSERGB)](https://drive.google.com/file/d/1T5ycqQK4KVZQ4pAnNkr2Ff8XIWvhmj_f/view?usp=sharing)**
159 |
160 | Then move the file to the `./pretrained_model` directory:
161 |
162 | ``` bash
163 | # Ensure the directory exists
164 | mkdir -p pretrained_model
165 |
166 | # Move the downloaded model to the correct location
167 | mv /path/to/downloaded/Ours_Large_BSERGB.pth ./pretrained_model/
168 | ```
169 |
170 | Make sure the final path is:
171 |
172 |
173 | ``` bash
174 | ./pretrained_model/Ours_Large_BSERGB.pth
175 | ```
176 |
177 | ### 4. Run test scripts
178 |
179 | Once preprocessing and downloading the pretrained model are complete, you can test the model on the BSERGB dataset:
180 |
181 | ``` bash
182 | $ python test_bsergb.py --dataset_dir BSERGB_DATASET_DIR
183 | ```
184 |
185 | After running this script, ground-truth (gt) and result images will be generated inside the `./output` directory.
186 |
187 | By evaluating the output images, you can reproduce the same quantitative results reported in the paper.
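For example, a minimal PSNR/SSIM evaluation sketch over the saved images (the file-name patterns below are assumptions; adjust them to whatever `test_bsergb.py` actually writes):

```python
import glob
import numpy as np
from skimage.io import imread
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

# Hypothetical naming patterns -- match them to the files found under ./output.
gt_paths = sorted(glob.glob("output/**/*gt*.png", recursive=True))
res_paths = sorted(glob.glob("output/**/*result*.png", recursive=True))

psnrs, ssims = [], []
for gt_path, res_path in zip(gt_paths, res_paths):
    gt, res = imread(gt_path), imread(res_path)
    psnrs.append(peak_signal_noise_ratio(gt, res, data_range=255))
    ssims.append(structural_similarity(gt, res, channel_axis=-1, data_range=255))

print(f"PSNR: {np.mean(psnrs):.2f} dB  SSIM: {np.mean(ssims):.4f}")
```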
188 |
189 | ## 🚀 Train model on BSERGB Dataset
190 |
191 | Training instructions and documentation will be available in the near future. (Work in progress)
192 |
193 | In the meantime, if you need to proceed quickly, please refer to the `train.py` file for a rough guide.
194 |
195 |
196 | ## Reference
197 | > Taewoo Kim, Yujeong Chae, Hyun-Kurl Jang, and Kuk-Jin Yoon, "Event-based Video Frame Interpolation with Cross-Modal Asymmetric Bidirectional Motion Fields", In _CVPR_, 2023.
198 | ```bibtex
199 | @InProceedings{Kim_2023_CVPR,
200 | author = {Kim, Taewoo and Chae, Yujeong and Jang, Hyun-Kurl and Yoon, Kuk-Jin},
201 | title = {Event-Based Video Frame Interpolation With Cross-Modal Asymmetric Bidirectional Motion Fields},
202 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
203 | month = {June},
204 | year = {2023},
205 | pages = {18032-18042}
206 | }
207 | ```
208 | ## Contact
209 | If you have any questions, please send an email to Taewoo Kim (intelpro@kaist.ac.kr).
210 |
211 | ## License
212 | The project code and datasets can be used for research and education purposes only.
213 |
--------------------------------------------------------------------------------
/correlation_package/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/correlation_package/correlation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.modules.module import Module
3 | from torch.autograd import Function
4 | import correlation_cuda
5 |
6 | class CorrelationFunction(Function):
7 |
8 | # def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):
9 | # super(CorrelationFunction, self).__init__()
10 | # self.pad_size = pad_size
11 | # self.kernel_size = kernel_size
12 | # self.max_displacement = max_displacement
13 | # self.stride1 = stride1
14 | # self.stride2 = stride2
15 | # self.corr_multiply = corr_multiply
16 | # # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1)
17 |
18 | @staticmethod
19 | def forward(ctx, input1, input2, pad_size, kernel_size, max_displacement,stride1, stride2, corr_multiply):
20 | ctx.save_for_backward(input1, input2)
21 | ctx.pad_size = pad_size
22 | ctx.kernel_size = kernel_size
23 | ctx.max_displacement = max_displacement
24 | ctx.stride1 = stride1
25 | ctx.stride2 = stride2
26 | ctx.corr_multiply = corr_multiply
27 |
28 | with torch.cuda.device_of(input1):
29 | rbot1 = input1.new()
30 | rbot2 = input2.new()
31 | output = input1.new()
32 |
33 | correlation_cuda.forward(input1, input2, rbot1, rbot2, output,
34 | pad_size, kernel_size, max_displacement, stride1, stride2, corr_multiply)
35 |
36 | return output
37 |
38 | @staticmethod
39 | def backward(ctx, grad_output):
40 | input1, input2 = ctx.saved_tensors
41 |
42 | with torch.cuda.device_of(input1):
43 | rbot1 = input1.new()
44 | rbot2 = input2.new()
45 |
46 | grad_input1 = input1.new()
47 | grad_input2 = input2.new()
48 |
49 | correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,
50 | ctx.pad_size, ctx.kernel_size, ctx.max_displacement, ctx.stride1, ctx.stride2, ctx.corr_multiply)
51 |
52 | return grad_input1, grad_input2, None, None, None, None, None, None
53 |
54 |
55 | class Correlation(Module):
56 | def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1):
57 | super(Correlation, self).__init__()
58 | self.pad_size = pad_size
59 | self.kernel_size = kernel_size
60 | self.max_displacement = max_displacement
61 | self.stride1 = stride1
62 | self.stride2 = stride2
63 | self.corr_multiply = corr_multiply
64 |
65 | # @staticmethod
66 | def forward(self, input1, input2):
67 |
68 | input1 = input1.contiguous()
69 | input2 = input2.contiguous()
70 | # result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)(input1, input2)
71 | result = CorrelationFunction.apply(input1, input2, self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)
72 |
73 | return result
74 |
75 |
--------------------------------------------------------------------------------
/correlation_package/correlation_cuda.cc:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <ATen/ATen.h>
3 | #include <ATen/cuda/CUDAContext.h>
4 | #include <stdio.h>
5 | #include <iostream>
6 |
7 | #include "correlation_cuda_kernel.cuh"
8 |
9 | int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output,
10 | int pad_size,
11 | int kernel_size,
12 | int max_displacement,
13 | int stride1,
14 | int stride2,
15 | int corr_type_multiply)
16 | {
17 |
18 | int batchSize = input1.size(0);
19 |
20 | int nInputChannels = input1.size(1);
21 | int inputHeight = input1.size(2);
22 | int inputWidth = input1.size(3);
23 |
24 | int kernel_radius = (kernel_size - 1) / 2;
25 | int border_radius = kernel_radius + max_displacement;
26 |
27 | int paddedInputHeight = inputHeight + 2 * pad_size;
28 | int paddedInputWidth = inputWidth + 2 * pad_size;
29 |
30 | int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1);
31 |
32 |   int outputHeight = ceil(static_cast<float>(paddedInputHeight - 2 * border_radius) / static_cast<float>(stride1));
33 |   int outputwidth = ceil(static_cast<float>(paddedInputWidth - 2 * border_radius) / static_cast<float>(stride1));
34 |
35 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
36 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
37 | output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth});
38 |
39 | rInput1.fill_(0);
40 | rInput2.fill_(0);
41 | output.fill_(0);
42 |
43 | int success = correlation_forward_cuda_kernel(
44 | output,
45 | output.size(0),
46 | output.size(1),
47 | output.size(2),
48 | output.size(3),
49 | output.stride(0),
50 | output.stride(1),
51 | output.stride(2),
52 | output.stride(3),
53 | input1,
54 | input1.size(1),
55 | input1.size(2),
56 | input1.size(3),
57 | input1.stride(0),
58 | input1.stride(1),
59 | input1.stride(2),
60 | input1.stride(3),
61 | input2,
62 | input2.size(1),
63 | input2.stride(0),
64 | input2.stride(1),
65 | input2.stride(2),
66 | input2.stride(3),
67 | rInput1,
68 | rInput2,
69 | pad_size,
70 | kernel_size,
71 | max_displacement,
72 | stride1,
73 | stride2,
74 | corr_type_multiply,
75 | at::cuda::getCurrentCUDAStream()
76 | );
77 |
78 | //check for errors
79 | if (!success) {
80 | AT_ERROR("CUDA call failed");
81 | }
82 |
83 | return 1;
84 |
85 | }
86 |
87 | int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput,
88 | at::Tensor& gradInput1, at::Tensor& gradInput2,
89 | int pad_size,
90 | int kernel_size,
91 | int max_displacement,
92 | int stride1,
93 | int stride2,
94 | int corr_type_multiply)
95 | {
96 |
97 | int batchSize = input1.size(0);
98 | int nInputChannels = input1.size(1);
99 | int paddedInputHeight = input1.size(2)+ 2 * pad_size;
100 | int paddedInputWidth = input1.size(3)+ 2 * pad_size;
101 |
102 | int height = input1.size(2);
103 | int width = input1.size(3);
104 |
105 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
106 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
107 | gradInput1.resize_({batchSize, nInputChannels, height, width});
108 | gradInput2.resize_({batchSize, nInputChannels, height, width});
109 |
110 | rInput1.fill_(0);
111 | rInput2.fill_(0);
112 | gradInput1.fill_(0);
113 | gradInput2.fill_(0);
114 |
115 | int success = correlation_backward_cuda_kernel(gradOutput,
116 | gradOutput.size(0),
117 | gradOutput.size(1),
118 | gradOutput.size(2),
119 | gradOutput.size(3),
120 | gradOutput.stride(0),
121 | gradOutput.stride(1),
122 | gradOutput.stride(2),
123 | gradOutput.stride(3),
124 | input1,
125 | input1.size(1),
126 | input1.size(2),
127 | input1.size(3),
128 | input1.stride(0),
129 | input1.stride(1),
130 | input1.stride(2),
131 | input1.stride(3),
132 | input2,
133 | input2.stride(0),
134 | input2.stride(1),
135 | input2.stride(2),
136 | input2.stride(3),
137 | gradInput1,
138 | gradInput1.stride(0),
139 | gradInput1.stride(1),
140 | gradInput1.stride(2),
141 | gradInput1.stride(3),
142 | gradInput2,
143 | gradInput2.size(1),
144 | gradInput2.stride(0),
145 | gradInput2.stride(1),
146 | gradInput2.stride(2),
147 | gradInput2.stride(3),
148 | rInput1,
149 | rInput2,
150 | pad_size,
151 | kernel_size,
152 | max_displacement,
153 | stride1,
154 | stride2,
155 | corr_type_multiply,
156 | at::cuda::getCurrentCUDAStream()
157 | );
158 |
159 | if (!success) {
160 | AT_ERROR("CUDA call failed");
161 | }
162 |
163 | return 1;
164 | }
165 |
166 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
167 | m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)");
168 | m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)");
169 | }
170 |
171 |
--------------------------------------------------------------------------------
/correlation_package/correlation_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 |
3 | #include "correlation_cuda_kernel.cuh"
4 |
5 | #define CUDA_NUM_THREADS 1024
6 | #define THREADS_PER_BLOCK 32
7 | #define FULL_MASK 0xffffffff
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/NativeFunctions.h>
11 | #include <ATen/Dispatch.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 | using at::Half;
15 |
16 | template <typename scalar_t>
17 | __forceinline__ __device__ scalar_t warpReduceSum(scalar_t val) {
18 | for (int offset = 16; offset > 0; offset /= 2)
19 | val += __shfl_down_sync(FULL_MASK, val, offset);
20 | return val;
21 | }
22 |
23 | template <typename scalar_t>
24 | __forceinline__ __device__ scalar_t blockReduceSum(scalar_t val) {
25 |
26 | static __shared__ scalar_t shared[32];
27 | int lane = threadIdx.x % warpSize;
28 | int wid = threadIdx.x / warpSize;
29 |
30 | val = warpReduceSum(val);
31 |
32 | if (lane == 0)
33 | shared[wid] = val;
34 |
35 | __syncthreads();
36 |
37 | val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0;
38 |
39 | if (wid == 0)
40 | val = warpReduceSum(val);
41 |
42 | return val;
43 | }
44 |
45 |
46 | template <typename scalar_t>
47 | __global__ void channels_first(const scalar_t* __restrict__ input, scalar_t* rinput, int channels, int height, int width, int pad_size)
48 | {
49 |
50 | // n (batch size), c (num of channels), y (height), x (width)
51 | int n = blockIdx.x;
52 | int y = blockIdx.y;
53 | int x = blockIdx.z;
54 |
55 | int ch_off = threadIdx.x;
56 | scalar_t value;
57 |
58 | int dimcyx = channels * height * width;
59 | int dimyx = height * width;
60 |
61 | int p_dimx = (width + 2 * pad_size);
62 | int p_dimy = (height + 2 * pad_size);
63 | int p_dimyxc = channels * p_dimy * p_dimx;
64 | int p_dimxc = p_dimx * channels;
65 |
66 | for (int c = ch_off; c < channels; c += THREADS_PER_BLOCK) {
67 | value = input[n * dimcyx + c * dimyx + y * width + x];
68 | rinput[n * p_dimyxc + (y + pad_size) * p_dimxc + (x + pad_size) * channels + c] = value;
69 | }
70 | }
71 |
72 |
73 | template <typename scalar_t>
74 | __global__ void correlation_forward(scalar_t* __restrict__ output, const int nOutputChannels,
75 | const int outputHeight, const int outputWidth, const scalar_t* __restrict__ rInput1,
76 | const int nInputChannels, const int inputHeight, const int inputWidth,
77 | const scalar_t* __restrict__ rInput2, const int pad_size, const int kernel_size,
78 | const int max_displacement, const int stride1, const int stride2) {
79 |
80 | int32_t pInputWidth = inputWidth + 2 * pad_size;
81 | int32_t pInputHeight = inputHeight + 2 * pad_size;
82 |
83 | int32_t kernel_rad = (kernel_size - 1) / 2;
84 |
85 | int32_t displacement_rad = max_displacement / stride2;
86 |
87 | int32_t displacement_size = 2 * displacement_rad + 1;
88 |
89 | int32_t n = blockIdx.x;
90 | int32_t y1 = blockIdx.y * stride1 + max_displacement;
91 | int32_t x1 = blockIdx.z * stride1 + max_displacement;
92 | int32_t c = threadIdx.x;
93 |
94 | int32_t pdimyxc = pInputHeight * pInputWidth * nInputChannels;
95 |
96 | int32_t pdimxc = pInputWidth * nInputChannels;
97 |
98 | int32_t pdimc = nInputChannels;
99 |
100 | int32_t tdimcyx = nOutputChannels * outputHeight * outputWidth;
101 | int32_t tdimyx = outputHeight * outputWidth;
102 | int32_t tdimx = outputWidth;
103 |
104 | int32_t nelems = kernel_size * kernel_size * pdimc;
105 |
106 | // element-wise product along channel axis
107 | for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) {
108 | for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) {
109 | int x2 = x1 + ti * stride2;
110 | int y2 = y1 + tj * stride2;
111 |
112 | float acc0 = 0.0f;
113 |
114 | for (int j = -kernel_rad; j <= kernel_rad; ++j) {
115 | for (int i = -kernel_rad; i <= kernel_rad; ++i) {
116 | // THREADS_PER_BLOCK
117 | #pragma unroll
118 | for (int ch = c; ch < pdimc; ch += blockDim.x) {
119 |
120 | int indx1 = n * pdimyxc + (y1 + j) * pdimxc
121 | + (x1 + i) * pdimc + ch;
122 | int indx2 = n * pdimyxc + (y2 + j) * pdimxc
123 | + (x2 + i) * pdimc + ch;
124 |           acc0 += static_cast<float>(rInput1[indx1] * rInput2[indx2]);
125 | }
126 | }
127 | }
128 |
129 | if (blockDim.x == warpSize) {
130 | __syncwarp();
131 | acc0 = warpReduceSum(acc0);
132 | } else {
133 | __syncthreads();
134 | acc0 = blockReduceSum(acc0);
135 | }
136 |
137 | if (threadIdx.x == 0) {
138 |
139 | int tc = (tj + displacement_rad) * displacement_size
140 | + (ti + displacement_rad);
141 | const int tindx = n * tdimcyx + tc * tdimyx + blockIdx.y * tdimx
142 | + blockIdx.z;
143 |         output[tindx] = static_cast<scalar_t>(acc0 / nelems);
144 | }
145 | }
146 | }
147 | }
148 |
149 |
150 | template <typename scalar_t>
151 | __global__ void correlation_backward_input1(int item, scalar_t* gradInput1, int nInputChannels, int inputHeight, int inputWidth,
152 | const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth,
153 | const scalar_t* __restrict__ rInput2,
154 | int pad_size,
155 | int kernel_size,
156 | int max_displacement,
157 | int stride1,
158 | int stride2)
159 | {
160 | // n (batch size), c (num of channels), y (height), x (width)
161 |
162 | int n = item;
163 | int y = blockIdx.x * stride1 + pad_size;
164 | int x = blockIdx.y * stride1 + pad_size;
165 | int c = blockIdx.z;
166 | int tch_off = threadIdx.x;
167 |
168 | int kernel_rad = (kernel_size - 1) / 2;
169 | int displacement_rad = max_displacement / stride2;
170 | int displacement_size = 2 * displacement_rad + 1;
171 |
172 | int xmin = (x - kernel_rad - max_displacement) / stride1;
173 | int ymin = (y - kernel_rad - max_displacement) / stride1;
174 |
175 | int xmax = (x + kernel_rad - max_displacement) / stride1;
176 | int ymax = (y + kernel_rad - max_displacement) / stride1;
177 |
178 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
179 | // assumes gradInput1 is pre-allocated and zero filled
180 | return;
181 | }
182 |
183 | if (xmin > xmax || ymin > ymax) {
184 | // assumes gradInput1 is pre-allocated and zero filled
185 | return;
186 | }
187 |
188 | xmin = max(0,xmin);
189 | xmax = min(outputWidth-1,xmax);
190 |
191 | ymin = max(0,ymin);
192 | ymax = min(outputHeight-1,ymax);
193 |
194 | int pInputWidth = inputWidth + 2 * pad_size;
195 | int pInputHeight = inputHeight + 2 * pad_size;
196 |
197 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
198 | int pdimxc = pInputWidth * nInputChannels;
199 | int pdimc = nInputChannels;
200 |
201 | int tdimcyx = nOutputChannels * outputHeight * outputWidth;
202 | int tdimyx = outputHeight * outputWidth;
203 | int tdimx = outputWidth;
204 |
205 | int odimcyx = nInputChannels * inputHeight* inputWidth;
206 | int odimyx = inputHeight * inputWidth;
207 | int odimx = inputWidth;
208 |
209 | scalar_t nelems = kernel_size * kernel_size * nInputChannels;
210 |
211 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
212 | prod_sum[tch_off] = 0;
213 |
214 | for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {
215 |
216 | int i2 = (tc % displacement_size - displacement_rad) * stride2;
217 | int j2 = (tc / displacement_size - displacement_rad) * stride2;
218 |
219 | int indx2 = n * pdimyxc + (y + j2)* pdimxc + (x + i2) * pdimc + c;
220 |
221 | scalar_t val2 = rInput2[indx2];
222 |
223 | for (int j = ymin; j <= ymax; ++j) {
224 | for (int i = xmin; i <= xmax; ++i) {
225 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;
226 | prod_sum[tch_off] += gradOutput[tindx] * val2;
227 | }
228 | }
229 | }
230 | __syncthreads();
231 |
232 | if(tch_off == 0) {
233 | scalar_t reduce_sum = 0;
234 | for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) {
235 | reduce_sum += prod_sum[idx];
236 | }
237 | const int indx1 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);
238 | gradInput1[indx1] = reduce_sum / nelems;
239 | }
240 |
241 | }
242 |
243 | template <typename scalar_t>
244 | __global__ void correlation_backward_input2(int item, scalar_t* gradInput2, int nInputChannels, int inputHeight, int inputWidth,
245 | const scalar_t* __restrict__ gradOutput, int nOutputChannels, int outputHeight, int outputWidth,
246 | const scalar_t* __restrict__ rInput1,
247 | int pad_size,
248 | int kernel_size,
249 | int max_displacement,
250 | int stride1,
251 | int stride2)
252 | {
253 | // n (batch size), c (num of channels), y (height), x (width)
254 |
255 | int n = item;
256 | int y = blockIdx.x * stride1 + pad_size;
257 | int x = blockIdx.y * stride1 + pad_size;
258 | int c = blockIdx.z;
259 |
260 | int tch_off = threadIdx.x;
261 |
262 | int kernel_rad = (kernel_size - 1) / 2;
263 | int displacement_rad = max_displacement / stride2;
264 | int displacement_size = 2 * displacement_rad + 1;
265 |
266 | int pInputWidth = inputWidth + 2 * pad_size;
267 | int pInputHeight = inputHeight + 2 * pad_size;
268 |
269 | int pdimyxc = pInputHeight * pInputWidth * nInputChannels;
270 | int pdimxc = pInputWidth * nInputChannels;
271 | int pdimc = nInputChannels;
272 |
273 | int tdimcyx = nOutputChannels * outputHeight * outputWidth;
274 | int tdimyx = outputHeight * outputWidth;
275 | int tdimx = outputWidth;
276 |
277 | int odimcyx = nInputChannels * inputHeight* inputWidth;
278 | int odimyx = inputHeight * inputWidth;
279 | int odimx = inputWidth;
280 |
281 | scalar_t nelems = kernel_size * kernel_size * nInputChannels;
282 |
283 | __shared__ scalar_t prod_sum[THREADS_PER_BLOCK];
284 | prod_sum[tch_off] = 0;
285 |
286 | for (int tc = tch_off; tc < nOutputChannels; tc += THREADS_PER_BLOCK) {
287 | int i2 = (tc % displacement_size - displacement_rad) * stride2;
288 | int j2 = (tc / displacement_size - displacement_rad) * stride2;
289 |
290 | int xmin = (x - kernel_rad - max_displacement - i2) / stride1;
291 | int ymin = (y - kernel_rad - max_displacement - j2) / stride1;
292 |
293 | int xmax = (x + kernel_rad - max_displacement - i2) / stride1;
294 | int ymax = (y + kernel_rad - max_displacement - j2) / stride1;
295 |
296 | if (xmax < 0 || ymax < 0 || xmin >= outputWidth || ymin >= outputHeight) {
297 | // assumes gradInput2 is pre-allocated and zero filled
298 | continue;
299 | }
300 |
301 | if (xmin > xmax || ymin > ymax) {
302 | // assumes gradInput2 is pre-allocated and zero filled
303 | continue;
304 | }
305 |
306 | xmin = max(0,xmin);
307 | xmax = min(outputWidth-1,xmax);
308 |
309 | ymin = max(0,ymin);
310 | ymax = min(outputHeight-1,ymax);
311 |
312 | int indx1 = n * pdimyxc + (y - j2)* pdimxc + (x - i2) * pdimc + c;
313 | scalar_t val1 = rInput1[indx1];
314 |
315 | for (int j = ymin; j <= ymax; ++j) {
316 | for (int i = xmin; i <= xmax; ++i) {
317 | int tindx = n * tdimcyx + tc * tdimyx + j * tdimx + i;
318 | prod_sum[tch_off] += gradOutput[tindx] * val1;
319 | }
320 | }
321 | }
322 |
323 | __syncthreads();
324 |
325 | if(tch_off == 0) {
326 | scalar_t reduce_sum = 0;
327 | for(int idx = 0; idx < THREADS_PER_BLOCK; idx++) {
328 | reduce_sum += prod_sum[idx];
329 | }
330 | const int indx2 = n * odimcyx + c * odimyx + (y - pad_size) * odimx + (x - pad_size);
331 | gradInput2[indx2] = reduce_sum / nelems;
332 | }
333 |
334 | }
335 |
336 | int correlation_forward_cuda_kernel(at::Tensor& output,
337 | int ob,
338 | int oc,
339 | int oh,
340 | int ow,
341 | int osb,
342 | int osc,
343 | int osh,
344 | int osw,
345 |
346 | at::Tensor& input1,
347 | int ic,
348 | int ih,
349 | int iw,
350 | int isb,
351 | int isc,
352 | int ish,
353 | int isw,
354 |
355 | at::Tensor& input2,
356 | int gc,
357 | int gsb,
358 | int gsc,
359 | int gsh,
360 | int gsw,
361 |
362 | at::Tensor& rInput1,
363 | at::Tensor& rInput2,
364 | int pad_size,
365 | int kernel_size,
366 | int max_displacement,
367 | int stride1,
368 | int stride2,
369 | int corr_type_multiply,
370 | cudaStream_t stream)
371 | {
372 |
373 | int batchSize = ob;
374 |
375 | int nInputChannels = ic;
376 | int inputWidth = iw;
377 | int inputHeight = ih;
378 |
379 | int nOutputChannels = oc;
380 | int outputWidth = ow;
381 | int outputHeight = oh;
382 |
383 | dim3 blocks_grid(batchSize, inputHeight, inputWidth);
384 | dim3 threads_block(THREADS_PER_BLOCK);
385 |
386 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channels_first_fwd_1", ([&] {
387 |
388 |   channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
389 |       input1.data<scalar_t>(), rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);
390 |
391 | }));
392 |
393 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "channels_first_fwd_2", ([&] {
394 |
395 |   channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>> (
396 |       input2.data<scalar_t>(), rInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth, pad_size);
397 |
398 | }));
399 |
400 | dim3 threadsPerBlock(THREADS_PER_BLOCK);
401 | dim3 totalBlocksCorr(batchSize, outputHeight, outputWidth);
402 |
403 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "correlation_forward", ([&] {
404 |
405 |   correlation_forward<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>>
406 |       (output.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
407 |       rInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
408 |       rInput2.data<scalar_t>(),
409 | pad_size,
410 | kernel_size,
411 | max_displacement,
412 | stride1,
413 | stride2);
414 |
415 | }));
416 |
417 | cudaError_t err = cudaGetLastError();
418 |
419 |
420 | // check for errors
421 | if (err != cudaSuccess) {
422 | printf("error in correlation_forward_cuda_kernel: %s\n", cudaGetErrorString(err));
423 | return 0;
424 | }
425 |
426 | return 1;
427 | }
428 |
429 |
430 | int correlation_backward_cuda_kernel(
431 | at::Tensor& gradOutput,
432 | int gob,
433 | int goc,
434 | int goh,
435 | int gow,
436 | int gosb,
437 | int gosc,
438 | int gosh,
439 | int gosw,
440 |
441 | at::Tensor& input1,
442 | int ic,
443 | int ih,
444 | int iw,
445 | int isb,
446 | int isc,
447 | int ish,
448 | int isw,
449 |
450 | at::Tensor& input2,
451 | int gsb,
452 | int gsc,
453 | int gsh,
454 | int gsw,
455 |
456 | at::Tensor& gradInput1,
457 | int gisb,
458 | int gisc,
459 | int gish,
460 | int gisw,
461 |
462 | at::Tensor& gradInput2,
463 | int ggc,
464 | int ggsb,
465 | int ggsc,
466 | int ggsh,
467 | int ggsw,
468 |
469 | at::Tensor& rInput1,
470 | at::Tensor& rInput2,
471 | int pad_size,
472 | int kernel_size,
473 | int max_displacement,
474 | int stride1,
475 | int stride2,
476 | int corr_type_multiply,
477 | cudaStream_t stream)
478 | {
479 |
480 | int batchSize = gob;
481 | int num = batchSize;
482 |
483 | int nInputChannels = ic;
484 | int inputWidth = iw;
485 | int inputHeight = ih;
486 |
487 | int nOutputChannels = goc;
488 | int outputWidth = gow;
489 | int outputHeight = goh;
490 |
491 | dim3 blocks_grid(batchSize, inputHeight, inputWidth);
492 | dim3 threads_block(THREADS_PER_BLOCK);
493 |
494 |
495 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "lltm_forward_cuda", ([&] {
496 |
497 |   channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
498 |       input1.data<scalar_t>(),
499 |       rInput1.data<scalar_t>(),
500 | nInputChannels,
501 | inputHeight,
502 | inputWidth,
503 | pad_size
504 | );
505 | }));
506 |
507 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] {
508 |
509 |   channels_first<scalar_t><<<blocks_grid, threads_block, 0, stream>>>(
510 |       input2.data<scalar_t>(),
511 |       rInput2.data<scalar_t>(),
512 | nInputChannels,
513 | inputHeight,
514 | inputWidth,
515 | pad_size
516 | );
517 | }));
518 |
519 | dim3 threadsPerBlock(THREADS_PER_BLOCK);
520 | dim3 totalBlocksCorr(inputHeight, inputWidth, nInputChannels);
521 |
522 | for (int n = 0; n < num; ++n) {
523 |
524 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input2.type(), "lltm_forward_cuda", ([&] {
525 |
526 |
527 |     correlation_backward_input1<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>> (
528 |         n, gradInput1.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
529 |         gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
530 |         rInput2.data<scalar_t>(),
531 | pad_size,
532 | kernel_size,
533 | max_displacement,
534 | stride1,
535 | stride2);
536 | }));
537 | }
538 |
539 | for(int n = 0; n < batchSize; n++) {
540 |
541 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(rInput1.type(), "lltm_forward_cuda", ([&] {
542 |
543 |     correlation_backward_input2<scalar_t><<<totalBlocksCorr, threadsPerBlock, 0, stream>>>(
544 |         n, gradInput2.data<scalar_t>(), nInputChannels, inputHeight, inputWidth,
545 |         gradOutput.data<scalar_t>(), nOutputChannels, outputHeight, outputWidth,
546 |         rInput1.data<scalar_t>(),
547 | pad_size,
548 | kernel_size,
549 | max_displacement,
550 | stride1,
551 | stride2);
552 |
553 | }));
554 | }
555 |
556 | // check for errors
557 | cudaError_t err = cudaGetLastError();
558 | if (err != cudaSuccess) {
559 | printf("error in correlation_backward_cuda_kernel: %s\n", cudaGetErrorString(err));
560 | return 0;
561 | }
562 |
563 | return 1;
564 | }
565 |
--------------------------------------------------------------------------------
/correlation_package/correlation_cuda_kernel.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 | #include <ATen/Context.h>
5 | #include <cuda_runtime.h>
6 |
7 | int correlation_forward_cuda_kernel(at::Tensor& output,
8 | int ob,
9 | int oc,
10 | int oh,
11 | int ow,
12 | int osb,
13 | int osc,
14 | int osh,
15 | int osw,
16 |
17 | at::Tensor& input1,
18 | int ic,
19 | int ih,
20 | int iw,
21 | int isb,
22 | int isc,
23 | int ish,
24 | int isw,
25 |
26 | at::Tensor& input2,
27 | int gc,
28 | int gsb,
29 | int gsc,
30 | int gsh,
31 | int gsw,
32 |
33 | at::Tensor& rInput1,
34 | at::Tensor& rInput2,
35 | int pad_size,
36 | int kernel_size,
37 | int max_displacement,
38 | int stride1,
39 | int stride2,
40 | int corr_type_multiply,
41 | cudaStream_t stream);
42 |
43 |
44 | int correlation_backward_cuda_kernel(
45 | at::Tensor& gradOutput,
46 | int gob,
47 | int goc,
48 | int goh,
49 | int gow,
50 | int gosb,
51 | int gosc,
52 | int gosh,
53 | int gosw,
54 |
55 | at::Tensor& input1,
56 | int ic,
57 | int ih,
58 | int iw,
59 | int isb,
60 | int isc,
61 | int ish,
62 | int isw,
63 |
64 | at::Tensor& input2,
65 | int gsb,
66 | int gsc,
67 | int gsh,
68 | int gsw,
69 |
70 | at::Tensor& gradInput1,
71 | int gisb,
72 | int gisc,
73 | int gish,
74 | int gisw,
75 |
76 | at::Tensor& gradInput2,
77 | int ggc,
78 | int ggsb,
79 | int ggsc,
80 | int ggsh,
81 | int ggsw,
82 |
83 | at::Tensor& rInput1,
84 | at::Tensor& rInput2,
85 | int pad_size,
86 | int kernel_size,
87 | int max_displacement,
88 | int stride1,
89 | int stride2,
90 | int corr_type_multiply,
91 | cudaStream_t stream);
92 |
--------------------------------------------------------------------------------
/correlation_package/nvcc setting.md:
--------------------------------------------------------------------------------
1 | # How to set the nvcc version (Default:CUDA-11.0)
2 |
3 | ## Pre-requisite
4 |
5 | Check installed `cuda-11.0` path:
6 | ```bash
7 | $ cd /usr/local
8 | $ find . -maxdepth 1 -name 'cuda-11.0'
9 | ```
10 | - If there is no `cuda-11.0` folder in the directory, install `cuda-11.0` first.
11 |
12 | ## Change the environment of the terminal (temporary)
13 | Change your terminal `PATH` and `LD_LIBRARY_PATH`:
14 | ```bash
15 | $ export PATH=/usr/local/cuda-11.0/bin${PATH:+:${PATH}}
16 | $ export LD_LIBRARY_PATH=/usr/local/cuda-11.0/lib64
17 | $ nvcc --version
18 | ```
19 | ## Change the default environment of the terminal
20 | To change the default `nvcc` version of your terminal, add the two lines below to your `~/.bashrc`.
21 | ```bash
22 | $ gedit ~/.bashrc
23 | export PATH=/usr/local/cuda-11.0/bin${PATH:+:${PATH}}
24 | export LD_LIBRARY_PATH=/usr/local/cuda-11.0/lib64
25 | ```
26 | Save the file and then open a new terminal:
27 | ```bash
28 | $ nvcc --version
29 | ```
30 |
--------------------------------------------------------------------------------
/correlation_package/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | # Minimum requirements for the build system to execute.
3 | requires = ["setuptools", "wheel", "numpy", "torch"] # PEP 508 specifications.
4 |
--------------------------------------------------------------------------------
/correlation_package/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import torch
4 |
5 | from setuptools import setup, find_packages
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 |
8 | cxx_args = ['-std=c++14']
9 |
10 | nvcc_args = [
11 | '-gencode', 'arch=compute_50,code=sm_50',
12 | '-gencode', 'arch=compute_52,code=sm_52',
13 | '-gencode', 'arch=compute_60,code=sm_60',
14 | '-gencode', 'arch=compute_61,code=sm_61',
15 | '-gencode', 'arch=compute_70,code=sm_70',
16 | '-gencode', 'arch=compute_70,code=compute_70'
17 | ]
18 |
19 | setup(
20 | name='correlation_cuda',
21 | ext_modules=[
22 | CUDAExtension('correlation_cuda', [
23 | 'correlation_cuda.cc',
24 | 'correlation_cuda_kernel.cu'
25 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
26 | ],
27 | cmdclass={
28 | 'build_ext': BuildExtension
29 | })
30 |
--------------------------------------------------------------------------------
/figure/ERF_dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intelpro/CBMNet/248405c313ce29f87ad42873df46792c76f55d1a/figure/ERF_dataset.png
--------------------------------------------------------------------------------
/figure/Quantitative_eval_ERF_x170FPS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intelpro/CBMNet/248405c313ce29f87ad42873df46792c76f55d1a/figure/Quantitative_eval_ERF_x170FPS.png
--------------------------------------------------------------------------------
/figure/driving.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intelpro/CBMNet/248405c313ce29f87ad42873df46792c76f55d1a/figure/driving.gif
--------------------------------------------------------------------------------
/figure/flower.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intelpro/CBMNet/248405c313ce29f87ad42873df46792c76f55d1a/figure/flower.gif
--------------------------------------------------------------------------------
/figure/popcorn.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intelpro/CBMNet/248405c313ce29f87ad42873df46792c76f55d1a/figure/popcorn.gif
--------------------------------------------------------------------------------
/install_correlation.sh:
--------------------------------------------------------------------------------
1 | cd correlation_package
2 | python3 setup.py install
3 | cd ..
--------------------------------------------------------------------------------
/models/final_models/ours.py:
--------------------------------------------------------------------------------
1 | from models.final_models.submodules import *
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.modules import conv
6 | from correlation_package.correlation import Correlation
7 | from einops import rearrange, repeat
8 | from einops.layers.torch import Rearrange
9 | from timm.models.layers import DropPath, trunc_normal_, to_2tuple
10 | from functools import reduce, lru_cache
11 | import torch.nn.functional as tf
12 | from torch.autograd import Variable
13 | from einops import rearrange
14 | import math
15 | import numbers
16 | import collections
17 |
18 |
19 | def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isReLU=True):
20 | if isReLU:
21 | return nn.Sequential(
22 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
23 | padding=((kernel_size - 1) * dilation) // 2, bias=True),
24 | nn.LeakyReLU(0.1, inplace=True)
25 | )
26 | else:
27 | return nn.Sequential(
28 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
29 | padding=((kernel_size - 1) * dilation) // 2, bias=True)
30 | )
31 |
32 | def predict_flow(in_planes):
33 | return nn.Conv2d(in_planes,2,kernel_size=3,stride=1,padding=1,bias=True)
34 |
35 | def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
36 | return nn.ConvTranspose2d(in_planes, out_planes, kernel_size, stride, padding, bias=True)
37 |
38 |
39 | class encoder_event_flow(nn.Module):
40 | def __init__(self, num_chs):
41 | super(encoder_event_flow, self).__init__()
42 | self.conv1 = conv_resblock_one(num_chs[0], num_chs[1], stride=1)
43 | self.conv2 = conv_resblock_one(num_chs[1], num_chs[2], stride=1)
44 | self.conv3 = conv_resblock_one(num_chs[2], num_chs[3], stride=2)
45 | self.conv4 = conv_resblock_one(num_chs[3], num_chs[4], stride=2)
46 |
47 | def forward(self, im):
48 | x = self.conv1(im)
49 | c11 = self.conv2(x)
50 | c12 = self.conv3(c11)
51 | c13 = self.conv4(c12)
52 | return c11, c12, c13
53 |
54 |
55 | class encoder_event_for_image_flow(nn.Module):
56 | def __init__(self, num_chs):
57 | super(encoder_event_for_image_flow, self).__init__()
58 | self.conv1 = conv_resblock_one(num_chs[0], num_chs[1], stride=1)
59 | self.conv2 = conv_resblock_one(num_chs[1], num_chs[2], stride=2)
60 | self.conv3 = conv_resblock_one(num_chs[2], num_chs[3], stride=2)
61 | self.conv4 = conv_resblock_one(num_chs[3], num_chs[4], stride=2)
62 |
63 | def forward(self, im):
64 | x = self.conv1(im)
65 | c11 = self.conv2(x)
66 | c12 = self.conv3(c11)
67 | c13 = self.conv4(c12)
68 | return c11, c12, c13
69 |
70 |
71 | class encoder_image_for_image_flow(nn.Module):
72 | def __init__(self, num_chs):
73 | super(encoder_image_for_image_flow, self).__init__()
74 | self.conv1 = conv_resblock_one(num_chs[0], num_chs[1], stride=1)
75 | self.conv2 = conv_resblock_one(num_chs[1], num_chs[2], stride=2)
76 | self.conv3 = conv_resblock_one(num_chs[2], num_chs[3], stride=2)
77 | self.conv4 = conv_resblock_one(num_chs[3], num_chs[4], stride=2)
78 |
79 | def forward(self, image):
80 | x = self.conv1(image)
81 | f1 = self.conv2(x)
82 | f2 = self.conv3(f1)
83 | f3 = self.conv4(f2)
84 | return f1, f2, f3
85 |
86 | def upsample2d(inputs, target_as, mode="bilinear"):
87 | _, _, h, w = target_as.size()
88 | return tf.interpolate(inputs, [h, w], mode=mode, align_corners=True)
89 |
90 | def upsample2d_hw(inputs, h, w, mode="bilinear"):
91 | return tf.interpolate(inputs, [h, w], mode=mode, align_corners=True)
92 |
93 |
94 | class DenseBlock(nn.Module):
95 | def __init__(self, ch_in):
96 | super(DenseBlock, self).__init__()
97 | self.conv1 = conv(ch_in, 128)
98 | self.conv2 = conv(ch_in + 128, 128)
99 | self.conv3 = conv(ch_in + 256, 96)
100 | self.conv4 = conv(ch_in + 352, 64)
101 | self.conv5 = conv(ch_in + 416, 32)
102 | self.conv_last = conv(ch_in + 448, 2, isReLU=False)
103 |
104 | def forward(self, x):
105 | x1 = torch.cat([self.conv1(x), x], dim=1)
106 | x2 = torch.cat([self.conv2(x1), x1], dim=1)
107 | x3 = torch.cat([self.conv3(x2), x2], dim=1)
108 | x4 = torch.cat([self.conv4(x3), x3], dim=1)
109 | x5 = torch.cat([self.conv5(x4), x4], dim=1)
110 | x_out = self.conv_last(x5)
111 | return x5, x_out
112 |
113 |
114 |
115 | class FlowEstimatorDense(nn.Module):
116 | def __init__(self, ch_in=64, f_channels=(128, 128, 96, 64, 32, 32), ch_out=2):
117 | super(FlowEstimatorDense, self).__init__()
118 | N = 0
119 | ind = 0
120 | N += ch_in
121 | self.conv1 = conv(N, f_channels[ind])
122 | N += f_channels[ind]
123 | ind += 1
124 | self.conv2 = conv(N, f_channels[ind])
125 | N += f_channels[ind]
126 | ind += 1
127 | self.conv3 = conv(N, f_channels[ind])
128 | N += f_channels[ind]
129 | ind += 1
130 | self.conv4 = conv(N, f_channels[ind])
131 | N += f_channels[ind]
132 | ind += 1
133 | self.conv5 = conv(N, f_channels[ind])
134 | N += f_channels[ind]
135 | self.num_feature_channel = N
136 | ind += 1
137 | self.conv_last = conv(N, ch_out, isReLU=False)
138 |
139 | def forward(self, x):
140 | x1 = torch.cat([self.conv1(x), x], axis=1)
141 | x2 = torch.cat([self.conv2(x1), x1], axis=1)
142 | x3 = torch.cat([self.conv3(x2), x2], axis=1)
143 | x4 = torch.cat([self.conv4(x3), x3], axis=1)
144 | x5 = torch.cat([self.conv5(x4), x4], axis=1)
145 | x_out = self.conv_last(x5)
146 | return x5, x_out
147 |
148 | class Tfeat_RefineBlock(nn.Module):
149 | def __init__(self, ch_in_frame, ch_in_event, ch_in_frame_prev, prev_scale=False):
150 | super(Tfeat_RefineBlock, self).__init__()
151 | if prev_scale:
152 | nf = int((ch_in_frame*2+ch_in_event+ch_in_frame_prev)/4)
153 | else:
154 | nf = int((ch_in_frame*2+ch_in_event)/4)
155 | self.conv_refine = nn.Sequential(conv1x1(4*nf, nf), nn.ReLU(), conv3x3(nf, 2*nf), nn.ReLU(), conv_resblock_one(2*nf, ch_in_frame))
156 |
157 | def forward(self, x):
158 | x1 = self.conv_refine(x)
159 | return x1
160 |
161 | def rescale_flow(flow, width_im, height_im):
162 | u_scale = float(width_im / flow.size(3))
163 | v_scale = float(height_im / flow.size(2))
164 | u, v = flow.chunk(2, dim=1)
165 | u = u_scale*u
166 | v = v_scale*v
167 | return torch.cat([u, v], dim=1)
168 |
169 |
170 | class FlowNet(nn.Module):
171 | def __init__(self, md=4, tb_debug=False):
172 | super(FlowNet, self).__init__()
173 | ## argument
174 | self.tb_debug = tb_debug
175 | # flow scale
176 | self.flow_scale = 20
177 | num_chs_frame = [3, 16, 32, 64, 96]
178 | num_chs_event = [16, 16, 32, 64, 128]
179 | num_chs_event_image = [16, 16, 16, 32, 64]
180 | ## for event-level flow
181 | self.encoder_event = encoder_event_flow(num_chs_event)
182 | ## for image-level flow
183 | self.encoder_image_flow = encoder_image_for_image_flow(num_chs_frame)
184 | self.encoder_image_flow_event = encoder_event_for_image_flow(num_chs_event_image)
185 | ## leaky relu
186 | self.leakyRELU = nn.LeakyReLU(0.1)
187 | ## correlation channel value
188 | self.corr = Correlation(pad_size=md, kernel_size=1, max_displacement=md, stride1=1, stride2=1, corr_multiply=1)
189 | ## correlation channel value
190 | nd = (2*md+1)**2
191 | self.corr_refinement = nn.ModuleList([DenseBlock(nd+num_chs_frame[-1]+2),
192 | DenseBlock(nd+num_chs_frame[-2]+2),
193 | DenseBlock(nd+num_chs_frame[-3]+2),
194 | DenseBlock(nd+num_chs_frame[-4]+2)
195 | ])
196 | self.decoder_event = nn.ModuleList([conv_resblock_one(num_chs_event[-1], num_chs_event[-1]),
197 | conv_resblock_one(num_chs_event[-2]+ num_chs_event[-1]+2, num_chs_event[-2]),
198 | conv_resblock_one(num_chs_event[-3]+ num_chs_event[-2]+2, num_chs_event[-3])])
199 | self.predict_flow = nn.ModuleList([conv3x3_leaky_relu(num_chs_event[-1], 2),
200 | conv3x3_leaky_relu(num_chs_event[-2], 2),
201 | conv3x3_leaky_relu(num_chs_event[-3], 2)])
202 | self.conv_frame = nn.ModuleList([conv3x3_leaky_relu(num_chs_frame[-2], 32),
203 | conv3x3_leaky_relu(num_chs_frame[-3], 32)])
204 | self.conv_frame_t = nn.ModuleList([conv3x3_leaky_relu(num_chs_frame[-2], 32),
205 | conv3x3_leaky_relu(num_chs_frame[-3], 32)])
206 | self.flow_fusion_block = FlowEstimatorDense(32*3+4, (32, 32, 32, 16, 8), 1)
207 | self.feat_t_refinement = nn.ModuleList([Tfeat_RefineBlock(num_chs_frame[-1], num_chs_event_image[-1]*2, None, prev_scale=False),
208 | Tfeat_RefineBlock(num_chs_frame[-2], num_chs_event_image[-2]*2, num_chs_frame[-1], prev_scale=True),
209 | Tfeat_RefineBlock(num_chs_frame[-3], num_chs_event_image[-3]*2, num_chs_frame[-2], prev_scale=True),
210 | ])
211 |
212 |
213 | def warp(self, x, flo):
214 | """
215 | warp an image/tensor (im2) back to im1, according to the optical flow
216 | x: [B, C, H, W] (im2)
217 | flo: [B, 2, H, W] flow
218 | """
219 | B, C, H, W = x.size()
220 | # mesh grid
221 | xx = torch.arange(0, W).view(1,-1).repeat(H,1)
222 | yy = torch.arange(0, H).view(-1,1).repeat(1,W)
223 | xx = xx.view(1,1,H,W).repeat(B,1,1,1)
224 | yy = yy.view(1,1,H,W).repeat(B,1,1,1)
225 | grid = torch.cat((xx,yy),1).float()
226 |
227 | if x.is_cuda:
228 | grid = grid.cuda()
229 | vgrid = Variable(grid) + flo
230 |
231 | # scale grid to [-1,1]
232 | vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0
233 | vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0
234 |
235 | vgrid = vgrid.permute(0,2,3,1)
236 | output = nn.functional.grid_sample(x, vgrid)
237 | mask = torch.autograd.Variable(torch.ones(x.size())).cuda()
238 | mask = nn.functional.grid_sample(mask, vgrid)
239 | mask[mask<0.9999] = 0
240 | mask[mask>0] = 1
241 | return output*mask
242 |
243 | def normalize_features(self, feature_list, normalize, center, moments_across_channels=True, moments_across_images=True):
244 | # Compute feature statistics.
245 | statistics = collections.defaultdict(list)
246 | axes = [1, 2, 3] if moments_across_channels else [2, 3] # [b, c, h, w]
247 | for feature_image in feature_list:
248 | mean = torch.mean(feature_image, axis=axes, keepdims=True) # [b,1,1,1] or [b,c,1,1]
249 | variance = torch.var(feature_image, axis=axes, keepdims=True) # [b,1,1,1] or [b,c,1,1]
250 | statistics['mean'].append(mean)
251 | statistics['var'].append(variance)
252 |
253 | if moments_across_images:
254 |             statistics['mean'] = ([torch.mean(torch.stack(statistics['mean'], axis=0), axis=(0, ))] * len(feature_list))
255 |             statistics['var'] = ([torch.var(torch.stack(statistics['var'], axis=0), axis=(0, ))] * len(feature_list))
256 |
257 | statistics['std'] = [torch.sqrt(v + 1e-16) for v in statistics['var']]
258 | # Center and normalize features.
259 | if center:
260 | feature_list = [f - mean for f, mean in zip(feature_list, statistics['mean'])]
261 | if normalize:
262 | feature_list = [f / std for f, std in zip(feature_list, statistics['std'])]
263 | return feature_list
264 |
265 | def forward(self, batch):
266 | # F for frame feature
267 | # E for event feature
268 | ### encoding
269 | ## image feature
270 | # feature pyramid
271 | F0_pyramid = self.encoder_image_flow(batch['image_input0'])[::-1]
272 | F1_pyramid = self.encoder_image_flow(batch['image_input1'])[::-1]
273 | E_0t_pyramid = self.encoder_image_flow_event(batch['event_input_0t'])[::-1]
274 | E_t1_pyramid = self.encoder_image_flow_event(batch['event_input_t1'])[::-1]
275 | # encoder event
276 | E_t0_pyramid_flow = self.encoder_event(batch['event_input_t0'])[::-1]
277 | E_t1_pyramid_flow = self.encoder_event(batch['event_input_t1'])[::-1]
278 | ### decoding optical flow
279 | ## level 0
280 | flow_t0_out_dict, flow_t1_out_dict, flow_t0_dict, flow_t1_dict = [], [], [], []
281 | ## event flow and image flow
282 | if self.tb_debug:
283 | event_flow_dict, image_flow_dict, fusion_flow_dict, mask_dict = [], [], [], []
284 | for level, (E_t0_flow, E_t1_flow, E_0t, E_t1, F0, F1) in enumerate(zip(E_t0_pyramid_flow, E_t1_pyramid_flow, E_0t_pyramid, E_t1_pyramid, F0_pyramid, F1_pyramid)):
285 | if level==0:
286 | ## event flow generation
287 | feat_t0_ev = self.decoder_event[level](E_t0_flow)
288 | feat_t1_ev = self.decoder_event[level](E_t1_flow)
289 | flow_event_t0 = self.predict_flow[level](feat_t0_ev)
290 | flow_event_t1 = self.predict_flow[level](feat_t1_ev)
291 | ## fusion flow(scale == 0)
292 | flow_fusion_t0 = flow_event_t0
293 | flow_fusion_t1 = flow_event_t1
294 | ## t feature
295 | feat_t_in = torch.cat((F0, F1, E_0t, E_t1), dim=1)
296 | feat_t = self.feat_t_refinement[level](feat_t_in)
297 | else:
298 | ## feat t
299 | upfeat0_t = upsample2d(feat_t, F0)
300 | feat_t_in = torch.cat((upfeat0_t, F0, F1, E_0t, E_t1), dim=1)
301 | feat_t = self.feat_t_refinement[level](feat_t_in)
302 | #### event-based optical flow
303 | ## event flow generation
304 | upflow_t0 = rescale_flow(upsample2d(flow_t0_out_dict[level-1], E_t0_flow), E_t0_flow.size(3), E_t0_flow.size(2))
305 | upflow_t1 = rescale_flow(upsample2d(flow_t1_out_dict[level-1], E_t1_flow), E_t1_flow.size(3), E_t1_flow.size(2))
306 | # upsample feat_t0
307 | feat_t0_ev_up = upsample2d(feat_t0_ev, E_t0_flow)
308 | feat_t1_ev_up = upsample2d(feat_t1_ev, E_t1_flow)
309 | # decoder event
310 | flow_t0_ev_up = rescale_flow(upsample2d(flow_event_t0, E_t0_flow), E_t0_flow.size(3), E_t0_flow.size(2))
311 | flow_t1_ev_up = rescale_flow(upsample2d(flow_event_t1, E_t1_flow), E_t1_flow.size(3), E_t1_flow.size(2))
312 | feat_t0_ev = self.decoder_event[level](torch.cat((E_t0_flow, feat_t0_ev_up, flow_t0_ev_up ), dim=1))
313 | feat_t1_ev = self.decoder_event[level](torch.cat((E_t1_flow, feat_t1_ev_up, flow_t1_ev_up), dim=1))
314 | ## project flow
315 | flow_event_t0_ = self.predict_flow[level](feat_t0_ev)
316 | flow_event_t1_ = self.predict_flow[level](feat_t1_ev)
317 | ## fusion flow
318 | flow_event_t0 = flow_t0_ev_up + flow_event_t0_
319 | flow_event_t1 = flow_t1_ev_up + flow_event_t1_
320 | # flow rescale
321 | down_evflow_t0 = rescale_flow(upsample2d(flow_event_t0, F0), F0.size(3), F0.size(2))
322 | down_evflow_t1 = rescale_flow(upsample2d(flow_event_t1, F1), F1.size(3), F1.size(2))
323 | down_upflow_t0 = rescale_flow(upsample2d(flow_t0_out_dict[level-1], F0), F0.size(3), F0.size(2))
324 | down_upflow_t1 = rescale_flow(upsample2d(flow_t1_out_dict[level-1], F1), F1.size(3), F1.size(2))
325 | ## warping with event flow and fusion flow
326 | F0_re = self.conv_frame[level-1](F0)
327 | F0_up_warp_ev = self.warp(F0_re, self.flow_scale*down_evflow_t0)
328 | F0_up_warp_frame = self.warp(F0_re, self.flow_scale*down_upflow_t0)
329 | F1_re = self.conv_frame[level-1](F1)
330 | F1_up_warp_ev = self.warp(F1_re, self.flow_scale*down_evflow_t1)
331 | F1_up_warp_frame = self.warp(F1_re, self.flow_scale*down_upflow_t1)
332 | Ft_up = self.conv_frame_t[level-1](feat_t)
333 | ## flow fusion
334 | _, out_fusion_t0 = self.flow_fusion_block(torch.cat((F0_up_warp_ev, F0_up_warp_frame, Ft_up, down_evflow_t0, down_upflow_t0), dim=1))
335 | _, out_fusion_t1 = self.flow_fusion_block(torch.cat((F1_up_warp_ev, F1_up_warp_frame, Ft_up, down_evflow_t1, down_upflow_t1), dim=1))
336 | mask_t0 = upsample2d(torch.sigmoid(out_fusion_t0[:, -1, : ,:])[:, None, :, :], E_t0_flow)
337 | mask_t1 = upsample2d(torch.sigmoid(out_fusion_t1[:, -1, :, :])[:, None, :, :], E_t1_flow)
338 | flow_fusion_t0 = (1-mask_t0)*upflow_t0 + mask_t0*flow_event_t0
339 | flow_fusion_t1 = (1-mask_t1)*upflow_t1 + mask_t1*flow_event_t1
340 | ## intermediate output
341 | if self.tb_debug:
342 | event_flow_dict.append(flow_event_t0)
343 | fusion_flow_dict.append(flow_fusion_t0)
344 | image_flow_dict.append(upflow_t0)
345 | mask_dict.append(mask_t0)
346 | # flow rescale
347 | down_flow_fusion_t0 = rescale_flow(upsample2d(flow_fusion_t0, F0), F0.size(3), F0.size(2))
348 | down_flow_fusion_t1 = rescale_flow(upsample2d(flow_fusion_t1, F1), F1.size(3), F1.size(2))
349 | # warping with optical flow
350 | feat10 = self.warp(F0, self.flow_scale*down_flow_fusion_t0)
351 | feat11 = self.warp(F1, self.flow_scale*down_flow_fusion_t1)
352 | # feature normalization
353 | feat_t_norm, feat10_norm, feat11_norm = self.normalize_features([feat_t, feat10, feat11], normalize=True, center=True, moments_across_channels=False, moments_across_images=False)
354 | # correlation
355 | corr_t0 = self.leakyRELU(self.corr(feat_t_norm, feat10_norm))
356 | corr_t1 = self.leakyRELU(self.corr(feat_t_norm, feat11_norm))
357 |             # correlation refinement
358 | _, res_flow_t0 = self.corr_refinement[level](torch.cat((corr_t0, feat_t, down_flow_fusion_t0), dim=1))
359 | _, res_flow_t1 = self.corr_refinement[level](torch.cat((corr_t1, feat_t, down_flow_fusion_t1), dim=1))
360 | # frame-based optical flow generation
361 | flow_t0_frame = down_flow_fusion_t0 + res_flow_t0
362 | flow_t1_frame = down_flow_fusion_t1 + res_flow_t1
363 | ## upsampling frame-based optical flow
364 | upflow_t0_frame = rescale_flow(upsample2d(flow_t0_frame, flow_fusion_t0), flow_fusion_t0.size(3), flow_fusion_t0.size(2))
365 | upflow_t1_frame = rescale_flow(upsample2d(flow_t1_frame, flow_fusion_t1), flow_fusion_t1.size(3), flow_fusion_t1.size(2))
366 | ### output
367 | flow_t0_out_dict.append(upflow_t0_frame)
368 | flow_t1_out_dict.append(upflow_t1_frame)
369 | flow_t0_dict.append(self.flow_scale*upflow_t0_frame)
370 | flow_t1_dict.append(self.flow_scale*upflow_t1_frame)
371 | flow_t0_dict = flow_t0_dict[::-1]
372 | flow_t1_dict = flow_t1_dict[::-1]
373 | ## final output return
374 | flow_output_dict = {}
375 | flow_output_dict['flow_t0_dict'] = flow_t0_dict
376 | flow_output_dict['flow_t1_dict'] = flow_t1_dict
377 | if self.tb_debug:
378 | flow_output_dict['event_flow_dict'] = event_flow_dict
379 | flow_output_dict['fusion_flow_dict'] = fusion_flow_dict
380 | flow_output_dict['image_flow_dict'] = image_flow_dict
381 | flow_output_dict['mask_dict'] = mask_dict
382 | return flow_output_dict
383 |
384 |
385 | class frame_encoder(nn.Module):
386 | def __init__(self, in_dims, nf):
387 | super(frame_encoder, self).__init__()
388 | self.conv0 = conv3x3_leaky_relu(in_dims, nf)
389 | self.conv1 = conv_resblock_two(nf, nf)
390 | self.conv2 = conv_resblock_two(nf, 2*nf, stride=2)
391 | self.conv3 = conv_resblock_two(2*nf, 4*nf, stride=2)
392 |
393 | def forward(self, x):
394 | x_ = self.conv0(x)
395 | f1 = self.conv1(x_)
396 | f2 = self.conv2(f1)
397 | f3 = self.conv3(f2)
398 | return [f1, f2, f3]
399 |
400 | class event_encoder(nn.Module):
401 | def __init__(self, in_dims, nf):
402 | super(event_encoder, self).__init__()
403 | self.conv0 = conv3x3_leaky_relu(in_dims, nf)
404 | self.conv1 = conv_resblock_two(nf, nf)
405 | self.conv2 = conv_resblock_two(nf, 2*nf, stride=2)
406 | self.conv3 = conv_resblock_two(2*nf, 4*nf, stride=2)
407 |
408 | def forward(self, x):
409 | x_ = self.conv0(x)
410 | f1 = self.conv1(x_)
411 | f2 = self.conv2(f1)
412 | f3 = self.conv3(f2)
413 | return [f1, f2, f3]
414 |
415 | ##################################################
416 | ################# Restormer #####################
417 |
418 | ##########################################################################
419 | ## Layer Norm
420 | def to_3d(x):
421 | return rearrange(x, 'b c h w -> b (h w) c')
422 |
423 | def to_4d(x,h,w):
424 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)
425 |
426 | class BiasFree_LayerNorm(nn.Module):
427 | def __init__(self, normalized_shape):
428 | super(BiasFree_LayerNorm, self).__init__()
429 | if isinstance(normalized_shape, numbers.Integral):
430 | normalized_shape = (normalized_shape,)
431 | normalized_shape = torch.Size(normalized_shape)
432 |
433 | assert len(normalized_shape) == 1
434 |
435 | self.weight = nn.Parameter(torch.ones(normalized_shape))
436 | self.normalized_shape = normalized_shape
437 |
438 | def forward(self, x):
439 | sigma = x.var(-1, keepdim=True, unbiased=False)
440 | return x / torch.sqrt(sigma+1e-5) * self.weight
441 |
442 |
443 | class WithBias_LayerNorm(nn.Module):
444 | def __init__(self, normalized_shape):
445 | super(WithBias_LayerNorm, self).__init__()
446 | if isinstance(normalized_shape, numbers.Integral):
447 | normalized_shape = (normalized_shape,)
448 | normalized_shape = torch.Size(normalized_shape)
449 |
450 | assert len(normalized_shape) == 1
451 |
452 | self.weight = nn.Parameter(torch.ones(normalized_shape))
453 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
454 | self.normalized_shape = normalized_shape
455 |
456 | def forward(self, x):
457 | mu = x.mean(-1, keepdim=True)
458 | sigma = x.var(-1, keepdim=True, unbiased=False)
459 | return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias
460 |
461 |
462 | class LayerNorm(nn.Module):
463 | def __init__(self, dim, LayerNorm_type):
464 | super(LayerNorm, self).__init__()
465 | if LayerNorm_type =='BiasFree':
466 | self.body = BiasFree_LayerNorm(dim)
467 | else:
468 | self.body = WithBias_LayerNorm(dim)
469 |
470 | def forward(self, x):
471 | h, w = x.shape[-2:]
472 | return to_4d(self.body(to_3d(x)), h, w)
473 |
474 | ##########################################################################
475 | ## Gated-Dconv Feed-Forward Network (GDFN)
476 | class FeedForward(nn.Module):
477 | def __init__(self, dim, ffn_expansion_factor, bias):
478 | super(FeedForward, self).__init__()
479 | hidden_features = int(dim*ffn_expansion_factor)
480 | self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
481 | self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)
482 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
483 |
484 | def forward(self, x):
485 | x = self.project_in(x)
486 | x1, x2 = self.dwconv(x).chunk(2, dim=1)
487 | x = F.gelu(x1) * x2
488 | x = self.project_out(x)
489 | return x
490 |
491 | class CrossAttention(nn.Module):
492 | def __init__(self, dim, num_heads, bias):
493 | super(CrossAttention, self).__init__()
494 | self.num_heads = num_heads
495 | self.temperature1 = nn.Parameter(torch.ones(num_heads, 1, 1))
496 | self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1))
497 |
498 | self.q = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias)
499 | self.kv1 = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias)
500 | self.kv2 = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias)
501 | self.q_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias)
502 | self.kv1_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim*2, bias=bias)
503 | self.kv2_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim*2, bias=bias)
504 | self.project_out = nn.Conv2d(dim*2, dim, kernel_size=1, bias=bias)
505 |
506 | def forward(self, x, attn_kv1, attn_kv2):
507 | b,c,h,w = x.shape
508 |
509 | q_ = self.q_dwconv(self.q(x))
510 | kv1 = self.kv1_dwconv(self.kv1(attn_kv1))
511 | kv2 = self.kv2_dwconv(self.kv2(attn_kv2))
512 | q1,q2 = q_.chunk(2, dim=1)
513 | k1,v1 = kv1.chunk(2, dim=1)
514 | k2,v2 = kv2.chunk(2, dim=1)
515 |
516 | q1 = rearrange(q1, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
517 | q2 = rearrange(q2, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
518 | k1 = rearrange(k1, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
519 | v1 = rearrange(v1, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
520 | k2 = rearrange(k2, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
521 | v2 = rearrange(v2, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
522 |
523 | q1 = torch.nn.functional.normalize(q1, dim=-1)
524 | q2 = torch.nn.functional.normalize(q2, dim=-1)
525 | k1 = torch.nn.functional.normalize(k1, dim=-1)
526 | k2 = torch.nn.functional.normalize(k2, dim=-1)
527 |
528 | attn = (q1 @ k1.transpose(-2, -1)) * self.temperature1
529 | attn = attn.softmax(dim=-1)
530 | out1 = (attn @ v1)
531 |
532 | attn = (q2 @ k2.transpose(-2, -1)) * self.temperature2
533 | attn = attn.softmax(dim=-1)
534 | out2 = (attn @ v2)
535 |
536 | out1 = rearrange(out1, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
537 | out2 = rearrange(out2, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
538 | out = torch.cat((out1, out2), dim=1)
539 | out = self.project_out(out)
540 | return out
541 |
542 |
543 | ##########################################################################
544 | class CrossTransformerBlock(nn.Module):
545 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
546 | super(CrossTransformerBlock, self).__init__()
547 | self.norm1 = LayerNorm(dim, LayerNorm_type)
548 | self.norm_kv1 = LayerNorm(dim, LayerNorm_type)
549 | self.norm_kv2 = LayerNorm(dim, LayerNorm_type)
550 | self.attn = CrossAttention(dim, num_heads, bias)
551 | self.norm2 = LayerNorm(dim, LayerNorm_type)
552 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
553 |
554 | def forward(self, x, attn_kv1, attn_kv2):
555 | x = x + self.attn(self.norm1(x), self.norm_kv1(attn_kv1), self.norm_kv2(attn_kv2))
556 | x = x + self.ffn(self.norm2(x))
557 | return x
558 |
559 | class CrossTransformerLayer(nn.Module):
560 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, num_blocks):
561 | super(CrossTransformerLayer, self).__init__()
562 | self.blocks = nn.ModuleList([CrossTransformerBlock(dim=dim, num_heads=num_heads,
563 | ffn_expansion_factor=ffn_expansion_factor, bias=bias,
564 | LayerNorm_type=LayerNorm_type) for i in range(num_blocks)])
565 |
566 | def forward(self, x, attn_kv=None, attn_kv2=None):
567 | for blk in self.blocks:
568 | x = blk(x, attn_kv, attn_kv2)
569 | return x
570 |
571 |
572 | ##########################################################################
573 | ## Multi-DConv Head Transposed Self-Attention (MDTA)
574 | class Attention(nn.Module):
575 | def __init__(self, dim, num_heads, bias):
576 | super(Attention, self).__init__()
577 | self.num_heads = num_heads
578 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
579 |
580 | self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
581 | self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
582 | self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
583 |
584 |
585 |
586 | def forward(self, x):
587 | b,c,h,w = x.shape
588 |
589 | qkv = self.qkv_dwconv(self.qkv(x))
590 | q,k,v = qkv.chunk(3, dim=1)
591 |
592 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
593 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
594 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
595 |
596 | q = torch.nn.functional.normalize(q, dim=-1)
597 | k = torch.nn.functional.normalize(k, dim=-1)
598 |
599 | attn = (q @ k.transpose(-2, -1)) * self.temperature
600 | attn = attn.softmax(dim=-1)
601 |
602 | out = (attn @ v)
603 |
604 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
605 |
606 | out = self.project_out(out)
607 | return out
608 |
609 |
610 | class Self_attention(nn.Module):
611 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
612 | super(Self_attention, self).__init__()
613 | self.norm1 = LayerNorm(dim, LayerNorm_type)
614 | self.attn = Attention(dim, num_heads, bias)
615 | self.norm2 = LayerNorm(dim, LayerNorm_type)
616 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
617 |
618 | def forward(self, x):
619 | x = x + self.attn(self.norm1(x))
620 | x = x + self.ffn(self.norm2(x))
621 | return x
622 |
623 |
624 | class SelfAttentionLayer(nn.Module):
625 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, num_blocks):
626 | super(SelfAttentionLayer, self).__init__()
627 | self.blocks = nn.ModuleList([Self_attention(dim=dim, num_heads=num_heads, ffn_expansion_factor=ffn_expansion_factor, bias=bias,
628 | LayerNorm_type=LayerNorm_type) for i in range(num_blocks)])
629 |
630 | def forward(self, x):
631 | for blk in self.blocks:
632 | x = blk(x)
633 | return x
634 |
635 |
636 | class Upsample(nn.Module):
637 | def __init__(self, in_channel, out_channel):
638 | super(Upsample, self).__init__()
639 | self.deconv = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2)
640 |
641 | def forward(self, x):
642 | out = self.deconv(x)
643 | return out
644 |
645 | def flops(self, H, W):
646 | flops = 0
647 | # transposed conv (kernel 2, stride 2)
648 | flops += H*2*W*2*self.deconv.in_channels*self.deconv.out_channels*2*2
649 | print("Upsample:{%.2f}"%(flops/1e9))
650 | return flops
651 |
652 |
653 | class Transformer(nn.Module):
654 | def __init__(self, unit_dim):
655 | super(Transformer, self).__init__()
656 | ## init query networks
657 | self.init_qurey_net(unit_dim)
658 | self.init_decoder(unit_dim)
659 | ## last conv
660 | self.last_conv0 = conv3x3(unit_dim*4, 3)
661 | self.last_conv1 = conv3x3(unit_dim*2, 3)
662 | self.last_conv2 = conv3x3(unit_dim, 3)
663 |
664 | def init_decoder(self, unit_dim):
665 | ### decoder
666 | ### attention k,v building (synthesis)
667 | self.build_kv0_syn = conv3x3_leaky_relu(unit_dim*3, unit_dim*4)
668 | self.build_kv1_syn = conv3x3_leaky_relu(int(unit_dim*1.5), unit_dim*2)
669 | self.build_kv2_syn = conv3x3_leaky_relu(int(unit_dim*0.75), unit_dim)
670 | ### attention k, v building (warping)
671 | self.build_kv0_warp = conv3x3_leaky_relu(unit_dim*3+6, unit_dim*4)
672 | self.build_kv1_warp = conv3x3_leaky_relu(int(unit_dim*1.5)+6, unit_dim*2)
673 | self.build_kv2_warp = conv3x3_leaky_relu(int(unit_dim*0.75)+6, unit_dim)
674 | ## level 1
675 | self.decoder1_1 = CrossTransformerLayer(dim=unit_dim*4, num_heads=4, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2)
676 | self.decoder1_2 = SelfAttentionLayer(dim=unit_dim*4, num_heads=4, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2)
677 | ## level 2
678 | self.decoder2_1 = CrossTransformerLayer(dim=unit_dim*2, num_heads=4, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2)
679 | self.decoder2_2 = SelfAttentionLayer(dim=unit_dim*2, num_heads=2, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2)
680 | ## level 3
681 | self.decoder3_1 = CrossTransformerLayer(dim=unit_dim, num_heads=4, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2)
682 | self.decoder3_2 = SelfAttentionLayer(dim=unit_dim, num_heads=1, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2)
683 | ## upsample
684 | self.upsample0 = Upsample(unit_dim*4, unit_dim*2)
685 | self.upsample1 = Upsample(unit_dim*2, unit_dim)
686 | ## conv after body
687 | self.conv_after_body0 = conv_resblock_one(4*unit_dim, 2*unit_dim)
688 | self.conv_after_body1 = conv_resblock_one(2*unit_dim, unit_dim)
689 |
690 | ### query network
691 | def init_qurey_net(self, unit_dim):
692 | ### building query
693 | ## stage 1
694 | self.enc_conv0 = conv3x3_leaky_relu(unit_dim+6, unit_dim)
695 | ## stage 2
696 | self.enc_conv1 = conv3x3_leaky_relu(unit_dim, 2*unit_dim, stride=2)
697 | ## stage 3
698 | self.enc_conv2 = conv3x3_leaky_relu(2*unit_dim, 4*unit_dim, stride=2)
699 |
700 | ## query building
701 | def build_qurey(self, event_feature, frame_feature, warped_feature):
702 | cat_in0 = torch.cat((event_feature[0], frame_feature[0], warped_feature[0]), dim=1)
703 | Q0 = self.enc_conv0(cat_in0)
704 | Q1 = self.enc_conv1(Q0)
705 | Q2 = self.enc_conv2(Q1)
706 | return [Q0, Q1, Q2]
707 |
708 | def forward_decoder(self, Q_list, warped_feature, frame_feature, event_feature):
709 | ## synthesis kv building
710 | cat_in0_syn = torch.cat((frame_feature[2], event_feature[2]), dim=1)
711 | attn_kv0_syn = self.build_kv0_syn(cat_in0_syn)
712 | cat_in1_syn = torch.cat((frame_feature[1], event_feature[1]), dim=1)
713 | attn_kv1_syn = self.build_kv1_syn(cat_in1_syn)
714 | cat_in2_syn = torch.cat((frame_feature[0], event_feature[0]), dim=1)
715 | attn_kv2_syn = self.build_kv2_syn(cat_in2_syn)
716 | ## warping kv building
717 | cat_in0_warp = torch.cat((warped_feature[2], event_feature[2]), dim=1)
718 | attn_kv0_warp = self.build_kv0_warp(cat_in0_warp)
719 | cat_in1_warp = torch.cat((warped_feature[1], event_feature[1]), dim=1)
720 | attn_kv1_warp = self.build_kv1_warp(cat_in1_warp)
721 | cat_in2_warp = torch.cat((warped_feature[0], event_feature[0]), dim=1)
722 | attn_kv2_warp = self.build_kv2_warp(cat_in2_warp)
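# coarse-to-fine decoding: at each level the query cross-attends to the synthesis and warping
# key/value features, and the upsampled result is fused into the next level's query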
723 | ## out 0
724 | _Q0 = Q_list[2]
725 | out0 = self.decoder1_1(_Q0, attn_kv0_syn, attn_kv0_warp)
726 | out0 = self.decoder1_2(out0)
727 | up_out0 = self.upsample0(out0)
728 | ## out 1
729 | _Q1 = Q_list[1]
730 | _Q1 = self.conv_after_body0(torch.cat((_Q1, up_out0), dim=1))
731 | out1 = self.decoder2_1(_Q1, attn_kv1_syn, attn_kv1_warp)
732 | out1 = self.decoder2_2(out1)
733 | up_out1 = self.upsample1(out1)
734 | ## out2
735 | _Q2 = Q_list[0]
736 | _Q2 = self.conv_after_body1(torch.cat((_Q2, up_out1), dim=1))
737 | out2 = self.decoder3_1(_Q2, attn_kv2_syn, attn_kv2_warp)
738 | out2 = self.decoder3_2(out2)
739 | return [out0, out1, out2]
740 |
741 | def forward(self, event_feature, frame_feature, warped_feature):
742 | ### forward encoder
743 | Q_list = self.build_qurey(event_feature, frame_feature, warped_feature)
744 | ### forward decoder
745 | out_decoder = self.forward_decoder(Q_list, warped_feature, frame_feature, event_feature)
746 | ### synthesis frame
747 | img0 = self.last_conv0(out_decoder[0])
748 | img1 = self.last_conv1(out_decoder[1])
749 | img2 = self.last_conv2(out_decoder[2])
750 | return [img2, img1, img0]
751 |
752 |
753 |
754 |
755 | class EventInterpNet(nn.Module):
756 | def __init__(self, num_bins=16, flow_debug=False):
757 | super(EventInterpNet, self).__init__()
758 | unit_dim = 32
759 | # scale
760 | self.scale = 3
761 | # flownet
762 | self.flownet = FlowNet(md=4, tb_debug=flow_debug)
763 | self.flow_debug = flow_debug
764 | # encoder
765 | self.encoder_f = frame_encoder(3, unit_dim//4)
766 | self.encoder_e = event_encoder(16, unit_dim//2)
767 | # decoder
768 | self.transformer = Transformer(unit_dim*2)
769 | # channel scaling convolution
770 | self.conv_list = nn.ModuleList([conv1x1(unit_dim, unit_dim), conv1x1(unit_dim, unit_dim), conv1x1(unit_dim, unit_dim)])
771 | # mode information
772 | self.mode = 'flow'
773 |
774 | def set_mode(self, mode):
775 | self.mode = mode
776 |
777 | def bwarp(self, x, flo):
778 | '''
779 | Backward-warp x with optical flow flo.
780 | x: [B, C, H, W], flo: [B, 2, H, W]
781 | '''
782 | B, C, H, W = x.size()
783 | # mesh grid
784 | xx = torch.arange(0, W).view(1, 1, 1, W).expand(B, 1, H, W)
785 | yy = torch.arange(0, H).view(1, 1, H, 1).expand(B, 1, H, W)
786 | grid = torch.cat((xx, yy), 1).float()
787 |
788 | if x.is_cuda:
789 | grid = grid.cuda()
790 | vgrid = torch.autograd.Variable(grid) + flo
791 |
792 | # scale grid to [-1,1]
793 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
794 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0
795 |
796 | vgrid = vgrid.permute(0, 2, 3, 1) # [B,H,W,2]
797 | output = nn.functional.grid_sample(x, vgrid, align_corners=True)
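# validity mask: warp an all-ones tensor with the same grid and zero out samples that fell outside the image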
798 | mask = torch.autograd.Variable(torch.ones(x.size())).cuda()
799 | mask = nn.functional.grid_sample(mask, vgrid, align_corners=True)
800 | # mask[mask<0.9999] = 0
801 | # mask[mask>0] = 1
802 | mask = mask.masked_fill_(mask < 0.999, 0)
803 | mask = mask.masked_fill_(mask > 0, 1)
804 | return output * mask
805 |
806 | def Flow_pyramid(self, flow):
807 | flow_pyr = []
808 | flow_pyr.append(flow)
809 | for i in range(1, 3):
810 | flow_pyr.append(F.interpolate(flow, scale_factor=0.5 ** i, mode='bilinear') * (0.5 ** i))
811 | return flow_pyr
812 |
813 | def Img_pyramid(self, Img):
814 | img_pyr = []
815 | img_pyr.append(Img)
816 | for i in range(1, 3):
817 | img_pyr.append(F.interpolate(Img, scale_factor=0.5 ** i, mode='bilinear'))
818 | return img_pyr
819 |
820 | def synthesis(self, batch, OF_t0, OF_t1):
821 | ## frame encoding
822 | f_frame0 = self.encoder_f(batch['image_input0'])
823 | f_frame1 = self.encoder_f(batch['image_input1'])
824 | ## OF pyramid
825 | OF_t0_pyramid = self.Flow_pyramid(OF_t0[0])
826 | OF_t1_pyramid = self.Flow_pyramid(OF_t1[0])
827 | ## image pyramid
828 | I0_pyramid = self.Img_pyramid(batch['image_input0'])
829 | I1_pyramid = self.Img_pyramid(batch['image_input1'])
830 | # frame0_warped, frame1_warped = [], []
831 | warped_feature, frame_feature = [], []
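# backward-warp each pyramid level of frame features (concatenated with the image itself) toward the target time using the estimated flows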
832 | for idx in range(self.scale):
833 | frame0_warped = self.bwarp(torch.cat((f_frame0[idx], I0_pyramid[idx]),dim=1), OF_t0_pyramid[idx])
834 | frame1_warped = self.bwarp(torch.cat((f_frame1[idx], I1_pyramid[idx]),dim=1), OF_t1_pyramid[idx])
835 | warped_feature.append(torch.cat((frame0_warped, frame1_warped), dim=1))
836 | frame_feature.append(torch.cat((f_frame0[idx], f_frame1[idx]), dim=1))
837 | event_feature = []
838 | # event encoding for frame interpolation
839 | f_event_0t = self.encoder_e(batch['event_input_0t'])
840 | f_event_t1 = self.encoder_e(batch['event_input_t1'])
841 | for idx in range(self.scale):
842 | event_feature.append(torch.cat((f_event_0t[idx], f_event_t1[idx]), dim=1))
843 | img_out = self.transformer(event_feature, frame_feature, warped_feature)
844 | interp_out = []
845 | for i in range(self.scale):
846 | interp_out.append(torch.clamp(img_out[i], 0, 1))
847 | return interp_out
848 |
849 | def forward(self, batch):
850 | output_dict = {}
851 | # --- Flow-only mode ---
852 | if self.mode == 'flow':
853 | output_dict['flow_out'] = self.flownet(batch)
854 | # --- Joint mode: flow estimation + frame synthesis ---
855 | elif self.mode == 'joint':
856 | flow_out = self.flownet(batch)
857 | interp_out = self.synthesis(batch, flow_out['flow_t0_dict'], flow_out['flow_t1_dict'])
858 | output_dict.update({'flow_out': flow_out, 'interp_out': interp_out})
859 | else:
860 | raise ValueError(f"Unsupported mode: {self.mode}")
861 | return output_dict
--------------------------------------------------------------------------------
/models/final_models/ours_large.py:
--------------------------------------------------------------------------------
1 | from models.final_models.submodules import *
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.modules import conv
6 | from correlation_package.correlation import Correlation
7 | from einops import rearrange, repeat
8 | from einops.layers.torch import Rearrange
9 | from timm.models.layers import DropPath, trunc_normal_, to_2tuple
10 | from functools import reduce, lru_cache
11 | import torch.nn.functional as tf
12 | from torch.autograd import Variable
13 | from einops import rearrange
14 | import math
15 | import numbers
16 | import collections
17 |
18 |
19 | def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isReLU=True):
20 | if isReLU:
21 | return nn.Sequential(
22 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
23 | padding=((kernel_size - 1) * dilation) // 2, bias=True),
24 | nn.LeakyReLU(0.1, inplace=True)
25 | )
26 | else:
27 | return nn.Sequential(
28 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, dilation=dilation,
29 | padding=((kernel_size - 1) * dilation) // 2, bias=True)
30 | )
31 |
32 | def predict_flow(in_planes):
33 | return nn.Conv2d(in_planes,2,kernel_size=3,stride=1,padding=1,bias=True)
34 |
35 | def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
36 | return nn.ConvTranspose2d(in_planes, out_planes, kernel_size, stride, padding, bias=True)
37 |
38 |
39 | class encoder_event_flow(nn.Module):
40 | def __init__(self, num_chs):
41 | super(encoder_event_flow, self).__init__()
42 | self.conv1 = conv_resblock_one(num_chs[0], num_chs[1], stride=1)
43 | self.conv2 = conv_resblock_one(num_chs[1], num_chs[2], stride=1)
44 | self.conv3 = conv_resblock_one(num_chs[2], num_chs[3], stride=2)
45 | self.conv4 = conv_resblock_one(num_chs[3], num_chs[4], stride=2)
46 |
47 | def forward(self, im):
48 | x = self.conv1(im)
49 | c11 = self.conv2(x)
50 | c12 = self.conv3(c11)
51 | c13 = self.conv4(c12)
52 | return c11, c12, c13
53 |
54 |
55 | class encoder_event_for_image_flow(nn.Module):
56 | def __init__(self, num_chs):
57 | super(encoder_event_for_image_flow, self).__init__()
58 | self.conv1 = conv_resblock_one(num_chs[0], num_chs[1], stride=1)
59 | self.conv2 = conv_resblock_one(num_chs[1], num_chs[2], stride=2)
60 | self.conv3 = conv_resblock_one(num_chs[2], num_chs[3], stride=2)
61 | self.conv4 = conv_resblock_one(num_chs[3], num_chs[4], stride=2)
62 |
63 | def forward(self, im):
64 | x = self.conv1(im)
65 | c11 = self.conv2(x)
66 | c12 = self.conv3(c11)
67 | c13 = self.conv4(c12)
68 | return c11, c12, c13
69 |
70 |
71 | class encoder_image_for_image_flow(nn.Module):
72 | def __init__(self, num_chs):
73 | super(encoder_image_for_image_flow, self).__init__()
74 | self.conv1 = conv_resblock_one(num_chs[0], num_chs[1], stride=1)
75 | self.conv2 = conv_resblock_one(num_chs[1], num_chs[2], stride=2)
76 | self.conv3 = conv_resblock_one(num_chs[2], num_chs[3], stride=2)
77 | self.conv4 = conv_resblock_one(num_chs[3], num_chs[4], stride=2)
78 |
79 | def forward(self, image):
80 | x = self.conv1(image)
81 | f1 = self.conv2(x)
82 | f2 = self.conv3(f1)
83 | f3 = self.conv4(f2)
84 | return f1, f2, f3
85 |
86 | def upsample2d(inputs, target_as, mode="bilinear"):
87 | _, _, h, w = target_as.size()
88 | return tf.interpolate(inputs, [h, w], mode=mode, align_corners=True)
89 |
90 | def upsample2d_hw(inputs, h, w, mode="bilinear"):
91 | return tf.interpolate(inputs, [h, w], mode=mode, align_corners=True)
92 |
93 |
94 | class DenseBlock(nn.Module):
95 | def __init__(self, ch_in):
96 | super(DenseBlock, self).__init__()
97 | self.conv1 = conv(ch_in, 128)
98 | self.conv2 = conv(ch_in + 128, 128)
99 | self.conv3 = conv(ch_in + 256, 96)
100 | self.conv4 = conv(ch_in + 352, 64)
101 | self.conv5 = conv(ch_in + 416, 32)
102 | self.conv_last = conv(ch_in + 448, 2, isReLU=False)
103 |
104 | def forward(self, x):
105 | x1 = torch.cat([self.conv1(x), x], dim=1)
106 | x2 = torch.cat([self.conv2(x1), x1], dim=1)
107 | x3 = torch.cat([self.conv3(x2), x2], dim=1)
108 | x4 = torch.cat([self.conv4(x3), x3], dim=1)
109 | x5 = torch.cat([self.conv5(x4), x4], dim=1)
110 | x_out = self.conv_last(x5)
111 | return x5, x_out
112 |
113 |
114 |
115 | class FlowEstimatorDense(nn.Module):
116 | def __init__(self, ch_in=64, f_channels=(128, 128, 96, 64, 32, 32), ch_out=2):
117 | super(FlowEstimatorDense, self).__init__()
118 | N = 0
119 | ind = 0
120 | N += ch_in
121 | self.conv1 = conv(N, f_channels[ind])
122 | N += f_channels[ind]
123 | ind += 1
124 | self.conv2 = conv(N, f_channels[ind])
125 | N += f_channels[ind]
126 | ind += 1
127 | self.conv3 = conv(N, f_channels[ind])
128 | N += f_channels[ind]
129 | ind += 1
130 | self.conv4 = conv(N, f_channels[ind])
131 | N += f_channels[ind]
132 | ind += 1
133 | self.conv5 = conv(N, f_channels[ind])
134 | N += f_channels[ind]
135 | self.num_feature_channel = N
136 | ind += 1
137 | self.conv_last = conv(N, ch_out, isReLU=False)
138 |
139 | def forward(self, x):
140 | x1 = torch.cat([self.conv1(x), x], axis=1)
141 | x2 = torch.cat([self.conv2(x1), x1], axis=1)
142 | x3 = torch.cat([self.conv3(x2), x2], axis=1)
143 | x4 = torch.cat([self.conv4(x3), x3], axis=1)
144 | x5 = torch.cat([self.conv5(x4), x4], axis=1)
145 | x_out = self.conv_last(x5)
146 | return x5, x_out
147 |
148 | class Tfeat_RefineBlock(nn.Module):
149 | def __init__(self, ch_in_frame, ch_in_event, ch_in_frame_prev, prev_scale=False):
150 | super(Tfeat_RefineBlock, self).__init__()
151 | if prev_scale:
152 | nf = int((ch_in_frame*2+ch_in_event+ch_in_frame_prev)/4)
153 | else:
154 | nf = int((ch_in_frame*2+ch_in_event)/4)
155 | self.conv_refine = nn.Sequential(conv1x1(4*nf, nf), nn.ReLU(), conv3x3(nf, 2*nf), nn.ReLU(), conv_resblock_one(2*nf, ch_in_frame))
156 |
157 | def forward(self, x):
158 | x1 = self.conv_refine(x)
159 | return x1
160 |
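# rescale the u/v flow components so their magnitudes match a flow field resampled to (height_im, width_im)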
161 | def rescale_flow(flow, width_im, height_im):
162 | u_scale = float(width_im / flow.size(3))
163 | v_scale = float(height_im / flow.size(2))
164 | u, v = flow.chunk(2, dim=1)
165 | u = u_scale*u
166 | v = v_scale*v
167 | return torch.cat([u, v], dim=1)
168 |
169 |
170 | class FlowNet(nn.Module):
171 | def __init__(self, md=4, tb_debug=False):
172 | super(FlowNet, self).__init__()
173 | ## argument
174 | self.tb_debug = tb_debug
175 | # flow scale
176 | self.flow_scale = 20
177 | num_chs_frame = [3, 16, 32, 64, 96]
178 | num_chs_event = [16, 16, 32, 64, 128]
179 | num_chs_event_image = [16, 16, 16, 32, 64]
180 | ## for event-level flow
181 | self.encoder_event = encoder_event_flow(num_chs_event)
182 | ## for image-level flow
183 | self.encoder_image_flow = encoder_image_for_image_flow(num_chs_frame)
184 | self.encoder_image_flow_event = encoder_event_for_image_flow(num_chs_event_image)
185 | ## leaky relu
186 | self.leakyRELU = nn.LeakyReLU(0.1)
187 | ## correlation layer
188 | self.corr = Correlation(pad_size=md, kernel_size=1, max_displacement=md, stride1=1, stride2=1, corr_multiply=1)
189 | ## correlation channel value
190 | nd = (2*md+1)**2
191 | self.corr_refinement = nn.ModuleList([DenseBlock(nd+num_chs_frame[-1]+2),
192 | DenseBlock(nd+num_chs_frame[-2]+2),
193 | DenseBlock(nd+num_chs_frame[-3]+2),
194 | DenseBlock(nd+num_chs_frame[-4]+2)
195 | ])
196 | self.decoder_event = nn.ModuleList([conv_resblock_one(num_chs_event[-1], num_chs_event[-1]),
197 | conv_resblock_one(num_chs_event[-2]+ num_chs_event[-1]+2, num_chs_event[-2]),
198 | conv_resblock_one(num_chs_event[-3]+ num_chs_event[-2]+2, num_chs_event[-3])])
199 | self.predict_flow = nn.ModuleList([conv3x3_leaky_relu(num_chs_event[-1], 2),
200 | conv3x3_leaky_relu(num_chs_event[-2], 2),
201 | conv3x3_leaky_relu(num_chs_event[-3], 2)])
202 | self.conv_frame = nn.ModuleList([conv3x3_leaky_relu(num_chs_frame[-2], 32),
203 | conv3x3_leaky_relu(num_chs_frame[-3], 32)])
204 | self.conv_frame_t = nn.ModuleList([conv3x3_leaky_relu(num_chs_frame[-2], 32),
205 | conv3x3_leaky_relu(num_chs_frame[-3], 32)])
206 | self.flow_fusion_block = FlowEstimatorDense(32*3+4, (32, 32, 32, 16, 8), 1)
207 | self.feat_t_refinement = nn.ModuleList([Tfeat_RefineBlock(num_chs_frame[-1], num_chs_event_image[-1]*2, None, prev_scale=False),
208 | Tfeat_RefineBlock(num_chs_frame[-2], num_chs_event_image[-2]*2, num_chs_frame[-1], prev_scale=True),
209 | Tfeat_RefineBlock(num_chs_frame[-3], num_chs_event_image[-3]*2, num_chs_frame[-2], prev_scale=True),
210 | ])
211 |
212 |
213 | def warp(self, x, flo):
214 | """
215 | warp an image/tensor (im2) back to im1, according to the optical flow
216 | x: [B, C, H, W] (im2)
217 | flo: [B, 2, H, W] flow
218 | """
219 | B, C, H, W = x.size()
220 | # mesh grid
221 | xx = torch.arange(0, W).view(1,-1).repeat(H,1)
222 | yy = torch.arange(0, H).view(-1,1).repeat(1,W)
223 | xx = xx.view(1,1,H,W).repeat(B,1,1,1)
224 | yy = yy.view(1,1,H,W).repeat(B,1,1,1)
225 | grid = torch.cat((xx,yy),1).float()
226 |
227 | if x.is_cuda:
228 | grid = grid.cuda()
229 | vgrid = Variable(grid) + flo
230 |
231 | # scale grid to [-1,1]
232 | vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0
233 | vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0
234 |
235 | vgrid = vgrid.permute(0,2,3,1)
236 | output = nn.functional.grid_sample(x, vgrid)
237 | mask = torch.autograd.Variable(torch.ones(x.size())).cuda()
238 | mask = nn.functional.grid_sample(mask, vgrid)
239 | mask[mask<0.9999] = 0
240 | mask[mask>0] = 1
241 | return output*mask
242 |
243 | def normalize_features(self, feature_list, normalize, center, moments_across_channels=True, moments_across_images=True):
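# used to normalize the frame features (zero mean / unit variance per sample) before the correlation volume is computed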
244 | # Compute feature statistics.
245 | statistics = collections.defaultdict(list)
246 | axes = [1, 2, 3] if moments_across_channels else [2, 3] # [b, c, h, w]
247 | for feature_image in feature_list:
248 | mean = torch.mean(feature_image, axis=axes, keepdims=True) # [b,1,1,1] or [b,c,1,1]
249 | variance = torch.var(feature_image, axis=axes, keepdims=True) # [b,1,1,1] or [b,c,1,1]
250 | statistics['mean'].append(mean)
251 | statistics['var'].append(variance)
252 |
253 | if moments_across_images:
254 | statistics['mean'] = ([torch.mean(torch.stack(statistics['mean'], dim=0), dim=0)] * len(feature_list))
255 | statistics['var'] = ([torch.var(torch.stack(statistics['var'], dim=0), dim=0)] * len(feature_list))
256 |
257 | statistics['std'] = [torch.sqrt(v + 1e-16) for v in statistics['var']]
258 |
259 | # Center and normalize features.
260 | if center:
261 | feature_list = [f - mean for f, mean in zip(feature_list, statistics['mean'])]
262 | if normalize:
263 | feature_list = [f / std for f, std in zip(feature_list, statistics['std'])]
264 | return feature_list
265 |
266 | def forward(self, batch):
267 | # F for frame feature
268 | # E for event feature
269 | ### encoding
270 | ## image feature
271 | # feature pyramid
272 | F0_pyramid = self.encoder_image_flow(batch['image_input0'])[::-1]
273 | F1_pyramid = self.encoder_image_flow(batch['image_input1'])[::-1]
274 | E_0t_pyramid = self.encoder_image_flow_event(batch['event_input_0t'])[::-1]
275 | E_t1_pyramid = self.encoder_image_flow_event(batch['event_input_t1'])[::-1]
276 | # encoder event
277 | E_t0_pyramid_flow = self.encoder_event(batch['event_input_t0'])[::-1]
278 | E_t1_pyramid_flow = self.encoder_event(batch['event_input_t1'])[::-1]
279 | ### decoding optical flow
280 | ## level 0
281 | flow_t0_out_dict, flow_t1_out_dict, flow_t0_dict, flow_t1_dict = [], [], [], []
282 | ## event flow and image flow
283 | if self.tb_debug:
284 | event_flow_dict, image_flow_dict, fusion_flow_dict, mask_dict = [], [], [], []
285 | for level, (E_t0_flow, E_t1_flow, E_0t, E_t1, F0, F1) in enumerate(zip(E_t0_pyramid_flow, E_t1_pyramid_flow, E_0t_pyramid, E_t1_pyramid, F0_pyramid, F1_pyramid)):
286 | if level==0:
287 | ## event flow generation
288 | feat_t0_ev = self.decoder_event[level](E_t0_flow)
289 | feat_t1_ev = self.decoder_event[level](E_t1_flow)
290 | flow_event_t0 = self.predict_flow[level](feat_t0_ev)
291 | flow_event_t1 = self.predict_flow[level](feat_t1_ev)
292 | ## fusion flow(scale == 0)
293 | flow_fusion_t0 = flow_event_t0
294 | flow_fusion_t1 = flow_event_t1
295 | ## t feature
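# feat_t serves as a stand-in feature for the latent frame at time t, refined from both input-frame and bidirectional event features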
296 | feat_t_in = torch.cat((F0, F1, E_0t, E_t1), dim=1)
297 | feat_t = self.feat_t_refinement[level](feat_t_in)
298 | else:
299 | ## feat t
300 | upfeat0_t = upsample2d(feat_t, F0)
301 | feat_t_in = torch.cat((upfeat0_t, F0, F1, E_0t, E_t1), dim=1)
302 | feat_t = self.feat_t_refinement[level](feat_t_in)
303 | #### event-based optical flow
304 | ## event flow generation
305 | upflow_t0 = rescale_flow(upsample2d(flow_t0_out_dict[level-1], E_t0_flow), E_t0_flow.size(3), E_t0_flow.size(2))
306 | upflow_t1 = rescale_flow(upsample2d(flow_t1_out_dict[level-1], E_t1_flow), E_t1_flow.size(3), E_t1_flow.size(2))
307 | # upsample feat_t0
308 | feat_t0_ev_up = upsample2d(feat_t0_ev, E_t0_flow)
309 | feat_t1_ev_up = upsample2d(feat_t1_ev, E_t1_flow)
310 | # decoder event
311 | flow_t0_ev_up = rescale_flow(upsample2d(flow_event_t0, E_t0_flow), E_t0_flow.size(3), E_t0_flow.size(2))
312 | flow_t1_ev_up = rescale_flow(upsample2d(flow_event_t1, E_t1_flow), E_t1_flow.size(3), E_t1_flow.size(2))
313 | feat_t0_ev = self.decoder_event[level](torch.cat((E_t0_flow, feat_t0_ev_up, flow_t0_ev_up ), dim=1))
314 | feat_t1_ev = self.decoder_event[level](torch.cat((E_t1_flow, feat_t1_ev_up, flow_t1_ev_up), dim=1))
315 | ## project flow
316 | flow_event_t0_ = self.predict_flow[level](feat_t0_ev)
317 | flow_event_t1_ = self.predict_flow[level](feat_t1_ev)
318 | ## fusion flow
319 | flow_event_t0 = flow_t0_ev_up + flow_event_t0_
320 | flow_event_t1 = flow_t1_ev_up + flow_event_t1_
321 | # flow rescale
322 | down_evflow_t0 = rescale_flow(upsample2d(flow_event_t0, F0), F0.size(3), F0.size(2))
323 | down_evflow_t1 = rescale_flow(upsample2d(flow_event_t1, F1), F1.size(3), F1.size(2))
324 | down_upflow_t0 = rescale_flow(upsample2d(flow_t0_out_dict[level-1], F0), F0.size(3), F0.size(2))
325 | down_upflow_t1 = rescale_flow(upsample2d(flow_t1_out_dict[level-1], F1), F1.size(3), F1.size(2))
326 | ## warping with event flow and fusion flow
327 | F0_re = self.conv_frame[level-1](F0)
328 | F0_up_warp_ev = self.warp(F0_re, self.flow_scale*down_evflow_t0)
329 | F0_up_warp_frame = self.warp(F0_re, self.flow_scale*down_upflow_t0)
330 | F1_re = self.conv_frame[level-1](F1)
331 | F1_up_warp_ev = self.warp(F1_re, self.flow_scale*down_evflow_t1)
332 | F1_up_warp_frame = self.warp(F1_re, self.flow_scale*down_upflow_t1)
333 | Ft_up = self.conv_frame_t[level-1](feat_t)
334 | ## flow fusion
335 | _, out_fusion_t0 = self.flow_fusion_block(torch.cat((F0_up_warp_ev, F0_up_warp_frame, Ft_up, down_evflow_t0, down_upflow_t0), dim=1))
336 | _, out_fusion_t1 = self.flow_fusion_block(torch.cat((F1_up_warp_ev, F1_up_warp_frame, Ft_up, down_evflow_t1, down_upflow_t1), dim=1))
337 | mask_t0 = upsample2d(torch.sigmoid(out_fusion_t0[:, -1, : ,:])[:, None, :, :], E_t0_flow)
338 | mask_t1 = upsample2d(torch.sigmoid(out_fusion_t1[:, -1, :, :])[:, None, :, :], E_t1_flow)
339 | flow_fusion_t0 = (1-mask_t0)*upflow_t0 + mask_t0*flow_event_t0
340 | flow_fusion_t1 = (1-mask_t1)*upflow_t1 + mask_t1*flow_event_t1
341 | ## intermediate output
342 | if self.tb_debug:
343 | event_flow_dict.append(flow_event_t0)
344 | fusion_flow_dict.append(flow_fusion_t0)
345 | image_flow_dict.append(upflow_t0)
346 | mask_dict.append(mask_t0)
347 | # flow rescale
348 | down_flow_fusion_t0 = rescale_flow(upsample2d(flow_fusion_t0, F0), F0.size(3), F0.size(2))
349 | down_flow_fusion_t1 = rescale_flow(upsample2d(flow_fusion_t1, F1), F1.size(3), F1.size(2))
350 | # warping with optical flow
351 | feat10 = self.warp(F0, self.flow_scale*down_flow_fusion_t0)
352 | feat11 = self.warp(F1, self.flow_scale*down_flow_fusion_t1)
353 | # feature normalization
354 | feat_t_norm, feat10_norm, feat11_norm = self.normalize_features([feat_t, feat10, feat11], normalize=True, center=True, moments_across_channels=False, moments_across_images=False)
355 | # correlation
356 | corr_t0 = self.leakyRELU(self.corr(feat_t_norm, feat10_norm))
357 | corr_t1 = self.leakyRELU(self.corr(feat_t_norm, feat11_norm))
358 | # correlation refinement
359 | _, res_flow_t0 = self.corr_refinement[level](torch.cat((corr_t0, feat_t, down_flow_fusion_t0), dim=1))
360 | _, res_flow_t1 = self.corr_refinement[level](torch.cat((corr_t1, feat_t, down_flow_fusion_t1), dim=1))
361 | # frame-based optical flow generation
362 | flow_t0_frame = down_flow_fusion_t0 + res_flow_t0
363 | flow_t1_frame = down_flow_fusion_t1 + res_flow_t1
364 | ## upsampling frame-based optical flow
365 | upflow_t0_frame = rescale_flow(upsample2d(flow_t0_frame, flow_fusion_t0), flow_fusion_t0.size(3), flow_fusion_t0.size(2))
366 | upflow_t1_frame = rescale_flow(upsample2d(flow_t1_frame, flow_fusion_t1), flow_fusion_t1.size(3), flow_fusion_t1.size(2))
367 | ### output
368 | flow_t0_out_dict.append(upflow_t0_frame)
369 | flow_t1_out_dict.append(upflow_t1_frame)
370 | flow_t0_dict.append(self.flow_scale*upflow_t0_frame)
371 | flow_t1_dict.append(self.flow_scale*upflow_t1_frame)
372 | flow_t0_dict = flow_t0_dict[::-1]
373 | flow_t1_dict = flow_t1_dict[::-1]
374 | ## final output return
375 | flow_output_dict = {}
376 | flow_output_dict['flow_t0_dict'] = flow_t0_dict
377 | flow_output_dict['flow_t1_dict'] = flow_t1_dict
378 | if self.tb_debug:
379 | flow_output_dict['event_flow_dict'] = event_flow_dict
380 | flow_output_dict['fusion_flow_dict'] = fusion_flow_dict
381 | flow_output_dict['image_flow_dict'] = image_flow_dict
382 | flow_output_dict['mask_dict'] = mask_dict
383 | return flow_output_dict
384 |
385 |
386 | class frame_encoder(nn.Module):
387 | def __init__(self, in_dims, nf):
388 | super(frame_encoder, self).__init__()
389 | self.conv0 = conv3x3_leaky_relu(in_dims, nf)
390 | self.conv1 = conv_resblock_two(nf, nf)
391 | self.conv2 = conv_resblock_two(nf, 2*nf, stride=2)
392 | self.conv3 = conv_resblock_two(2*nf, 4*nf, stride=2)
393 |
394 | def forward(self, x):
395 | x_ = self.conv0(x)
396 | f1 = self.conv1(x_)
397 | f2 = self.conv2(f1)
398 | f3 = self.conv3(f2)
399 | return [f1, f2, f3]
400 |
401 | class event_encoder(nn.Module):
402 | def __init__(self, in_dims, nf):
403 | super(event_encoder, self).__init__()
404 | self.conv0 = conv3x3_leaky_relu(in_dims, nf)
405 | self.conv1 = conv_resblock_two(nf, nf)
406 | self.conv2 = conv_resblock_two(nf, 2*nf, stride=2)
407 | self.conv3 = conv_resblock_two(2*nf, 4*nf, stride=2)
408 |
409 | def forward(self, x):
410 | x_ = self.conv0(x)
411 | f1 = self.conv1(x_)
412 | f2 = self.conv2(f1)
413 | f3 = self.conv3(f2)
414 | return [f1, f2, f3]
415 |
416 | ##################################################
417 | ################# Restormer #####################
418 |
419 | ##########################################################################
420 | ## Layer Norm
421 | def to_3d(x):
422 | return rearrange(x, 'b c h w -> b (h w) c')
423 |
424 | def to_4d(x,h,w):
425 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)
426 |
427 | class BiasFree_LayerNorm(nn.Module):
428 | def __init__(self, normalized_shape):
429 | super(BiasFree_LayerNorm, self).__init__()
430 | if isinstance(normalized_shape, numbers.Integral):
431 | normalized_shape = (normalized_shape,)
432 | normalized_shape = torch.Size(normalized_shape)
433 |
434 | assert len(normalized_shape) == 1
435 |
436 | self.weight = nn.Parameter(torch.ones(normalized_shape))
437 | self.normalized_shape = normalized_shape
438 |
439 | def forward(self, x):
440 | sigma = x.var(-1, keepdim=True, unbiased=False)
441 | return x / torch.sqrt(sigma+1e-5) * self.weight
442 |
443 |
444 | class WithBias_LayerNorm(nn.Module):
445 | def __init__(self, normalized_shape):
446 | super(WithBias_LayerNorm, self).__init__()
447 | if isinstance(normalized_shape, numbers.Integral):
448 | normalized_shape = (normalized_shape,)
449 | normalized_shape = torch.Size(normalized_shape)
450 |
451 | assert len(normalized_shape) == 1
452 |
453 | self.weight = nn.Parameter(torch.ones(normalized_shape))
454 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
455 | self.normalized_shape = normalized_shape
456 |
457 | def forward(self, x):
458 | mu = x.mean(-1, keepdim=True)
459 | sigma = x.var(-1, keepdim=True, unbiased=False)
460 | return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias
461 |
462 |
463 | class LayerNorm(nn.Module):
464 | def __init__(self, dim, LayerNorm_type):
465 | super(LayerNorm, self).__init__()
466 | if LayerNorm_type =='BiasFree':
467 | self.body = BiasFree_LayerNorm(dim)
468 | else:
469 | self.body = WithBias_LayerNorm(dim)
470 |
471 | def forward(self, x):
472 | h, w = x.shape[-2:]
473 | return to_4d(self.body(to_3d(x)), h, w)
474 |
475 | ##########################################################################
476 | ## Gated-Dconv Feed-Forward Network (GDFN)
477 | class FeedForward(nn.Module):
478 | def __init__(self, dim, ffn_expansion_factor, bias):
479 | super(FeedForward, self).__init__()
480 | hidden_features = int(dim*ffn_expansion_factor)
481 | self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
482 | self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)
483 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
484 |
485 | def forward(self, x):
486 | x = self.project_in(x)
487 | x1, x2 = self.dwconv(x).chunk(2, dim=1)
488 | x = F.gelu(x1) * x2
489 | x = self.project_out(x)
490 | return x
491 |
492 | class CrossAttention(nn.Module):
493 | def __init__(self, dim, num_heads, bias):
494 | super(CrossAttention, self).__init__()
495 | self.num_heads = num_heads
496 | self.temperature1 = nn.Parameter(torch.ones(num_heads, 1, 1))
497 | self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1))
498 |
499 | self.q = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias)
500 | self.kv1 = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias)
501 | self.kv2 = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias)
502 | self.q_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias)
503 | self.kv1_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim*2, bias=bias)
504 | self.kv2_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim*2, bias=bias)
505 | self.project_out = nn.Conv2d(dim*2, dim, kernel_size=1, bias=bias)
506 |
507 | def forward(self, x, attn_kv1, attn_kv2):
508 | b,c,h,w = x.shape
509 |
510 | q_ = self.q_dwconv(self.q(x))
511 | kv1 = self.kv1_dwconv(self.kv1(attn_kv1))
512 | kv2 = self.kv2_dwconv(self.kv2(attn_kv2))
513 | q1,q2 = q_.chunk(2, dim=1)
514 | k1,v1 = kv1.chunk(2, dim=1)
515 | k2,v2 = kv2.chunk(2, dim=1)
516 |
517 | q1 = rearrange(q1, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
518 | q2 = rearrange(q2, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
519 | k1 = rearrange(k1, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
520 | v1 = rearrange(v1, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
521 | k2 = rearrange(k2, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
522 | v2 = rearrange(v2, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
523 |
524 | q1 = torch.nn.functional.normalize(q1, dim=-1)
525 | q2 = torch.nn.functional.normalize(q2, dim=-1)
526 | k1 = torch.nn.functional.normalize(k1, dim=-1)
527 | k2 = torch.nn.functional.normalize(k2, dim=-1)
528 |
529 | attn = (q1 @ k1.transpose(-2, -1)) * self.temperature1
530 | attn = attn.softmax(dim=-1)
531 | out1 = (attn @ v1)
532 |
533 | attn = (q2 @ k2.transpose(-2, -1)) * self.temperature2
534 | attn = attn.softmax(dim=-1)
535 | out2 = (attn @ v2)
536 |
537 | out1 = rearrange(out1, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
538 | out2 = rearrange(out2, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
539 | out = torch.cat((out1, out2), dim=1)
540 | out = self.project_out(out)
541 | return out
542 |
543 |
544 | ##########################################################################
545 | class CrossTransformerBlock(nn.Module):
546 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
547 | super(CrossTransformerBlock, self).__init__()
548 | self.norm1 = LayerNorm(dim, LayerNorm_type)
549 | self.norm_kv1 = LayerNorm(dim, LayerNorm_type)
550 | self.norm_kv2 = LayerNorm(dim, LayerNorm_type)
551 | self.attn = CrossAttention(dim, num_heads, bias)
552 | self.norm2 = LayerNorm(dim, LayerNorm_type)
553 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
554 |
555 | def forward(self, x, attn_kv1, attn_kv2):
556 | x = x + self.attn(self.norm1(x), self.norm_kv1(attn_kv1), self.norm_kv2(attn_kv2))
557 | x = x + self.ffn(self.norm2(x))
558 | return x
559 |
560 | class CrossTransformerLayer(nn.Module):
561 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, num_blocks):
562 | super(CrossTransformerLayer, self).__init__()
563 | self.blocks = nn.ModuleList([CrossTransformerBlock(dim=dim, num_heads=num_heads,
564 | ffn_expansion_factor=ffn_expansion_factor, bias=bias,
565 | LayerNorm_type=LayerNorm_type) for i in range(num_blocks)])
566 |
567 | def forward(self, x, attn_kv=None, attn_kv2=None):
568 | for blk in self.blocks:
569 | x = blk(x, attn_kv, attn_kv2)
570 | return x
571 |
572 |
573 | ##########################################################################
574 | ## Multi-DConv Head Transposed Self-Attention (MDTA)
575 | class Attention(nn.Module):
576 | def __init__(self, dim, num_heads, bias):
577 | super(Attention, self).__init__()
578 | self.num_heads = num_heads
579 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
580 |
581 | self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
582 | self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
583 | self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
584 |
585 |
586 |
587 | def forward(self, x):
588 | b,c,h,w = x.shape
589 |
590 | qkv = self.qkv_dwconv(self.qkv(x))
591 | q,k,v = qkv.chunk(3, dim=1)
592 |
593 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
594 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
595 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
596 |
597 | q = torch.nn.functional.normalize(q, dim=-1)
598 | k = torch.nn.functional.normalize(k, dim=-1)
599 |
600 | attn = (q @ k.transpose(-2, -1)) * self.temperature
601 | attn = attn.softmax(dim=-1)
602 |
603 | out = (attn @ v)
604 |
605 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
606 |
607 | out = self.project_out(out)
608 | return out
609 |
610 |
611 | class Self_attention(nn.Module):
612 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
613 | super(Self_attention, self).__init__()
614 | self.norm1 = LayerNorm(dim, LayerNorm_type)
615 | self.attn = Attention(dim, num_heads, bias)
616 | self.norm2 = LayerNorm(dim, LayerNorm_type)
617 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
618 |
619 | def forward(self, x):
620 | x = x + self.attn(self.norm1(x))
621 | x = x + self.ffn(self.norm2(x))
622 | return x
623 |
624 |
625 | class SelfAttentionLayer(nn.Module):
626 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, num_blocks):
627 | super(SelfAttentionLayer, self).__init__()
628 | self.blocks = nn.ModuleList([Self_attention(dim=dim, num_heads=num_heads, ffn_expansion_factor=ffn_expansion_factor, bias=bias,
629 | LayerNorm_type=LayerNorm_type) for i in range(num_blocks)])
630 |
631 | def forward(self, x):
632 | for blk in self.blocks:
633 | x = blk(x)
634 | return x
635 |
636 |
637 | class Upsample(nn.Module):
638 | def __init__(self, in_channel, out_channel):
639 | super(Upsample, self).__init__()
640 | self.deconv = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2)
641 |
642 | def forward(self, x):
643 | out = self.deconv(x)
644 | return out
645 |
646 | def flops(self, H, W):
647 | flops = 0
648 | # transposed conv (kernel 2, stride 2)
649 | flops += H*2*W*2*self.deconv.in_channels*self.deconv.out_channels*2*2
650 | print("Upsample:{%.2f}"%(flops/1e9))
651 | return flops
652 |
653 |
654 | class Transformer(nn.Module):
655 | def __init__(self, unit_dim):
656 | super(Transformer, self).__init__()
657 | ## init query networks
658 | self.init_qurey_net(unit_dim)
659 | self.init_decoder(unit_dim)
660 | ## last conv
661 | self.last_conv0 = conv3x3(unit_dim*4, 3)
662 | self.last_conv1 = conv3x3(unit_dim*2, 3)
663 | self.last_conv2 = conv3x3(unit_dim, 3)
664 |
665 | def init_decoder(self, unit_dim):
666 | ### decoder
667 | ### attention k,v building (synthesis)
668 | self.build_kv0_syn = conv3x3_leaky_relu(unit_dim*3, unit_dim*4)
669 | self.build_kv1_syn = conv3x3_leaky_relu(int(unit_dim*1.5), unit_dim*2)
670 | self.build_kv2_syn = conv3x3_leaky_relu(int(unit_dim*0.75), unit_dim)
671 | ### attention k, v building (warping)
672 | self.build_kv0_warp = conv3x3_leaky_relu(unit_dim*3+6, unit_dim*4)
673 | self.build_kv1_warp = conv3x3_leaky_relu(int(unit_dim*1.5)+6, unit_dim*2)
674 | self.build_kv2_warp = conv3x3_leaky_relu(int(unit_dim*0.75)+6, unit_dim)
675 | ## level 1
676 | self.decoder1_1 = CrossTransformerLayer(dim=unit_dim*4, num_heads=4, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2)
677 | self.decoder1_2 = SelfAttentionLayer(dim=unit_dim*4, num_heads=4, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2)
678 | ## level 2
679 | self.decoder2_1 = CrossTransformerLayer(dim=unit_dim*2, num_heads=4, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2)
680 | self.decoder2_2 = SelfAttentionLayer(dim=unit_dim*2, num_heads=2, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2)
681 | ## level 3
682 | self.decoder3_1 = CrossTransformerLayer(dim=unit_dim, num_heads=4, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2)
683 | self.decoder3_2 = SelfAttentionLayer(dim=unit_dim, num_heads=1, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', num_blocks=2)
684 | ## upsample
685 | self.upsample0 = Upsample(unit_dim*4, unit_dim*2)
686 | self.upsample1 = Upsample(unit_dim*2, unit_dim)
687 | ## conv after body
688 | self.conv_after_body0 = conv_resblock_one(4*unit_dim, 2*unit_dim)
689 | self.conv_after_body1 = conv_resblock_one(2*unit_dim, unit_dim)
690 |
691 | ### query network
692 | def init_qurey_net(self, unit_dim):
693 | ### building query
694 | ## stage 1
695 | self.enc_conv0 = conv3x3_leaky_relu(unit_dim+6, unit_dim)
696 | ## stage 2
697 | self.enc_conv1 = conv3x3_leaky_relu(unit_dim, 2*unit_dim, stride=2)
698 | ## stage 3
699 | self.enc_conv2 = conv3x3_leaky_relu(2*unit_dim, 4*unit_dim, stride=2)
700 |
701 | ## query building
702 | def build_qurey(self, event_feature, frame_feature, warped_feature):
703 | cat_in0 = torch.cat((event_feature[0], frame_feature[0], warped_feature[0]), dim=1)
704 | Q0 = self.enc_conv0(cat_in0)
705 | Q1 = self.enc_conv1(Q0)
706 | Q2 = self.enc_conv2(Q1)
707 | return [Q0, Q1, Q2]
708 |
709 | def forward_decoder(self, Q_list, warped_feature, frame_feature, event_feature):
710 | ## synthesis kv building
711 | cat_in0_syn = torch.cat((frame_feature[2], event_feature[2]), dim=1)
712 | attn_kv0_syn = self.build_kv0_syn(cat_in0_syn)
713 | cat_in1_syn = torch.cat((frame_feature[1], event_feature[1]), dim=1)
714 | attn_kv1_syn = self.build_kv1_syn(cat_in1_syn)
715 | cat_in2_syn = torch.cat((frame_feature[0], event_feature[0]), dim=1)
716 | attn_kv2_syn = self.build_kv2_syn(cat_in2_syn)
717 | ## warping kv building
718 | cat_in0_warp = torch.cat((warped_feature[2], event_feature[2]), dim=1)
719 | attn_kv0_warp = self.build_kv0_warp(cat_in0_warp)
720 | cat_in1_warp = torch.cat((warped_feature[1], event_feature[1]), dim=1)
721 | attn_kv1_warp = self.build_kv1_warp(cat_in1_warp)
722 | cat_in2_warp = torch.cat((warped_feature[0], event_feature[0]), dim=1)
723 | attn_kv2_warp = self.build_kv2_warp(cat_in2_warp)
724 | ## out 0
725 | _Q0 = Q_list[2]
726 | out0 = self.decoder1_1(_Q0, attn_kv0_syn, attn_kv0_warp)
727 | out0 = self.decoder1_2(out0)
728 | up_out0 = self.upsample0(out0)
729 | ## out 1
730 | _Q1 = Q_list[1]
731 | _Q1 = self.conv_after_body0(torch.cat((_Q1, up_out0), dim=1))
732 | out1 = self.decoder2_1(_Q1, attn_kv1_syn, attn_kv1_warp)
733 | out1 = self.decoder2_2(out1)
734 | up_out1 = self.upsample1(out1)
735 | ## out2
736 | _Q2 = Q_list[0]
737 | _Q2 = self.conv_after_body1(torch.cat((_Q2, up_out1), dim=1))
738 | out2 = self.decoder3_1(_Q2, attn_kv2_syn, attn_kv2_warp)
739 | out2 = self.decoder3_2(out2)
740 | return [out0, out1, out2]
741 |
742 | def forward(self, event_feature, frame_feature, warped_feature):
743 | ### forward encoder
744 | Q_list = self.build_qurey(event_feature, frame_feature, warped_feature)
745 | ### forward decoder
746 | out_decoder = self.forward_decoder(Q_list, warped_feature, frame_feature, event_feature)
747 | ### synthesis frame
748 | img0 = self.last_conv0(out_decoder[0])
749 | img1 = self.last_conv1(out_decoder[1])
750 | img2 = self.last_conv2(out_decoder[2])
751 | return [img2, img1, img0]
752 |
753 |
754 |
755 |
756 | class EventInterpNet(nn.Module):
757 | def __init__(self, num_bins=16, flow_debug=False):
758 | super(EventInterpNet, self).__init__()
759 | unit_dim = 44
760 | # scale
761 | self.scale = 3
762 | # flownet
763 | self.flownet = FlowNet(md=4, tb_debug=flow_debug)
764 | self.flow_debug = flow_debug
765 | # encoder
766 | self.encoder_f = frame_encoder(3, unit_dim//4)
767 | self.encoder_e = event_encoder(16, unit_dim//2)
768 | # decoder
769 | self.transformer = Transformer(unit_dim*2)
770 | # channel scaling convolution
771 | self.conv_list = nn.ModuleList([conv1x1(unit_dim, unit_dim), conv1x1(unit_dim, unit_dim), conv1x1(unit_dim, unit_dim)])
772 |         self.mode = 'flow'  # default mode; set_mode() switches between 'flow' and 'joint'
773 | def set_mode(self, mode):
774 | self.mode = mode
775 |
776 | def bwarp(self, x, flo):
777 | '''
778 | Backward-warp x with optical flow flo.
779 | x: [B, C, H, W], flo: [B, 2, H, W]
780 | '''
781 | B, C, H, W = x.size()
782 | # mesh grid
783 | xx = torch.arange(0, W).view(1, 1, 1, W).expand(B, 1, H, W)
784 | yy = torch.arange(0, H).view(1, 1, H, 1).expand(B, 1, H, W)
785 | grid = torch.cat((xx, yy), 1).float()
786 |
787 | if x.is_cuda:
788 | grid = grid.cuda()
789 | vgrid = torch.autograd.Variable(grid) + flo
790 |
791 | # scale grid to [-1,1]
792 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
793 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0
794 |
795 | vgrid = vgrid.permute(0, 2, 3, 1) # [B,H,W,2]
796 | output = nn.functional.grid_sample(x, vgrid, align_corners=True)
797 | mask = torch.autograd.Variable(torch.ones(x.size())).cuda()
798 | mask = nn.functional.grid_sample(mask, vgrid, align_corners=True)
799 | # mask[mask<0.9999] = 0
800 | # mask[mask>0] = 1
801 | mask = mask.masked_fill_(mask < 0.999, 0)
802 | mask = mask.masked_fill_(mask > 0, 1)
803 | return output * mask
804 |
805 | def Flow_pyramid(self, flow):
806 | flow_pyr = []
807 | flow_pyr.append(flow)
808 | for i in range(1, 3):
809 | flow_pyr.append(F.interpolate(flow, scale_factor=0.5 ** i, mode='bilinear') * (0.5 ** i))
810 | return flow_pyr
811 |
812 | def Img_pyramid(self, Img):
813 | img_pyr = []
814 | img_pyr.append(Img)
815 | for i in range(1, 3):
816 | img_pyr.append(F.interpolate(Img, scale_factor=0.5 ** i, mode='bilinear'))
817 | return img_pyr
818 |
819 | def synthesis(self, batch, OF_t0, OF_t1):
820 | ## frame encoding
821 | f_frame0 = self.encoder_f(batch['image_input0'])
822 | f_frame1 = self.encoder_f(batch['image_input1'])
823 | ## OF pyramid
824 | OF_t0_pyramid = self.Flow_pyramid(OF_t0[0])
825 | OF_t1_pyramid = self.Flow_pyramid(OF_t1[0])
826 | ## image pyramid
827 | I0_pyramid = self.Img_pyramid(batch['image_input0'])
828 | I1_pyramid = self.Img_pyramid(batch['image_input1'])
829 | # frame0_warped, frame1_warped = [], []
830 | warped_feature, frame_feature = [], []
831 | for idx in range(self.scale):
832 | frame0_warped = self.bwarp(torch.cat((f_frame0[idx], I0_pyramid[idx]),dim=1), OF_t0_pyramid[idx])
833 | frame1_warped = self.bwarp(torch.cat((f_frame1[idx], I1_pyramid[idx]),dim=1), OF_t1_pyramid[idx])
834 | warped_feature.append(torch.cat((frame0_warped, frame1_warped), dim=1))
835 | frame_feature.append(torch.cat((f_frame0[idx], f_frame1[idx]), dim=1))
836 | # after_tmp_feature = self.conv_list[idx](tmp_feature)
837 | event_feature = []
838 | # event encoding for frame interpolation
839 | f_event_0t = self.encoder_e(batch['event_input_0t'])
840 | f_event_t1 = self.encoder_e(batch['event_input_t1'])
841 | for idx in range(self.scale):
842 | event_feature.append(torch.cat((f_event_0t[idx], f_event_t1[idx]), dim=1))
843 | img_out = self.transformer(event_feature, frame_feature, warped_feature)
844 | output_clean = []
845 | for i in range(self.scale):
846 | output_clean.append(torch.clamp(img_out[i], 0, 1))
847 | return output_clean
848 |
849 | def forward(self, batch):
850 | output_dict = {}
851 | # --- Flow-only mode ---
852 | if self.mode == 'flow':
853 | output_dict['flow_out'] = self.flownet(batch)
854 | # --- Joint mode: ---
855 | elif self.mode == 'joint':
856 | flow_out = self.flownet(batch)
857 | interp_out = self.synthesis(batch, flow_out['flow_t0_dict'], flow_out['flow_t1_dict'])
858 | output_dict.update({'flow_out': flow_out, 'interp_out': interp_out})
859 | else:
860 | raise ValueError(f"Unsupported mode: {self.mode}")
861 | return output_dict
--------------------------------------------------------------------------------
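
`bwarp` backward-warps a source tensor with a per-pixel flow by adding the flow to a base pixel grid, normalizing it to [-1, 1], and sampling with `grid_sample`; out-of-frame samples are then masked out. Below is a minimal standalone sketch of the same idea, checked against the identity flow; the `backward_warp` name is illustrative and not part of this repository.

```python
import torch
import torch.nn.functional as F

def backward_warp(x, flow):
    # x: [B,C,H,W] source, flow: [B,2,H,W] backward flow in pixels (dx, dy)
    B, _, H, W = x.shape
    xx = torch.arange(W).view(1, 1, 1, W).expand(B, 1, H, W)
    yy = torch.arange(H).view(1, 1, H, 1).expand(B, 1, H, W)
    grid = torch.cat((xx, yy), dim=1).float().to(x.device) + flow
    # normalize pixel coordinates to [-1, 1] as grid_sample expects
    grid[:, 0] = 2.0 * grid[:, 0] / max(W - 1, 1) - 1.0
    grid[:, 1] = 2.0 * grid[:, 1] / max(H - 1, 1) - 1.0
    return F.grid_sample(x, grid.permute(0, 2, 3, 1), align_corners=True)

x = torch.arange(12.).reshape(1, 1, 3, 4)
assert torch.allclose(backward_warp(x, torch.zeros(1, 2, 3, 4)), x)  # zero flow reproduces the input
```
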
/models/final_models/submodules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 |
5 | def actFunc(act, *args, **kwargs):
6 | act = act.lower()
7 | if act == 'relu':
8 | return nn.ReLU()
9 | elif act == 'relu6':
10 | return nn.ReLU6()
11 | elif act == 'leakyrelu':
12 | return nn.LeakyReLU(0.1)
13 | elif act == 'prelu':
14 | return nn.PReLU()
15 | elif act == 'rrelu':
16 | return nn.RReLU(0.1, 0.3)
17 | elif act == 'selu':
18 | return nn.SELU()
19 | elif act == 'celu':
20 | return nn.CELU()
21 | elif act == 'elu':
22 | return nn.ELU()
23 | elif act == 'gelu':
24 | return nn.GELU()
25 | elif act == 'tanh':
26 | return nn.Tanh()
27 | else:
28 |         raise NotImplementedError(f'unsupported activation: {act}')
29 |
30 | class ResBlock(nn.Module):
31 | """
32 | Residual block
33 | """
34 | def __init__(self, in_chs, activation='relu', batch_norm=False):
35 | super(ResBlock, self).__init__()
36 | op = []
37 | for i in range(2):
38 | op.append(conv3x3(in_chs, in_chs))
39 | if batch_norm:
40 | op.append(nn.BatchNorm2d(in_chs))
41 | if i == 0:
42 | op.append(actFunc(activation))
43 | self.main_branch = nn.Sequential(*op)
44 |
45 | def forward(self, x):
46 | out = self.main_branch(x)
47 | out += x
48 | return out
49 |
50 | # conv blocks
51 | def conv1x1(in_channels, out_channels, stride=1):
52 | return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=True)
53 |
54 | def conv3x3(in_channels, out_channels, stride=1):
55 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True)
56 |
57 | def conv5x5(in_channels, out_channels, stride=1):
58 | return nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=stride, padding=2, bias=True)
59 |
60 | def deconv4x4(in_channels, out_channels, stride=2):
61 | return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=1)
62 |
63 | def deconv5x5(in_channels, out_channels, stride=2):
64 | return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=5, stride=stride, padding=2, output_padding=1)
65 |
66 | # conv resblock
67 | def conv_resblock_three(in_channels, out_channels, stride=1):
68 | return nn.Sequential(conv3x3(in_channels, out_channels, stride), nn.ReLU(), ResBlock(out_channels), ResBlock(out_channels), ResBlock(out_channels))
69 |
70 | def conv_resblock_two(in_channels, out_channels, stride=1):
71 | return nn.Sequential(conv3x3(in_channels, out_channels, stride), nn.ReLU(), ResBlock(out_channels), ResBlock(out_channels))
72 |
73 | def conv_resblock_one(in_channels, out_channels, stride=1):
74 | return nn.Sequential(conv3x3(in_channels, out_channels, stride), nn.ReLU(), ResBlock(out_channels))
75 |
76 | def conv_1x1_resblock_one(in_channels, out_channels, stride=1):
77 | return nn.Sequential(conv1x1(in_channels, out_channels, stride), nn.ReLU(), ResBlock(out_channels))
78 |
79 | def conv_resblock_two_DS(in_channels, out_channels, stride=2):
80 | return nn.Sequential(conv3x3(in_channels, out_channels, stride), nn.ReLU(), ResBlock(out_channels), ResBlock(out_channels))
81 |
82 | def conv3x3_leaky_relu(in_channels, out_channels, stride=1):
83 | return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True), nn.LeakyReLU(0.1))
84 |
85 | def conv1x1_leaky_relu(in_channels, out_channels, stride=1):
86 | return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=True), nn.LeakyReLU(0.1))
--------------------------------------------------------------------------------
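
These helpers wrap padding-preserving convolutions and residual blocks; the `_DS` variant uses stride 2, so it halves spatial resolution, while the ResBlocks keep channel count and spatial size fixed. A quick shape check, assuming the module is importable as `models.final_models.submodules` from the repository root:

```python
import torch
from models.final_models.submodules import conv_resblock_two, conv_resblock_two_DS

x = torch.rand(1, 16, 64, 64)
print(conv_resblock_two(16, 32)(x).shape)     # torch.Size([1, 32, 64, 64]): stride 1 keeps H and W
print(conv_resblock_two_DS(16, 32)(x).shape)  # torch.Size([1, 32, 32, 32]): stride 2 halves H and W
```
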
/models/loss_handler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import importlib
4 | from models.final_models.submodules import *
5 | from math import ceil
6 | from utils.flow_utils import cal_grad2_error
7 | from utils.utils import AverageMeter
8 |
9 |
10 | class LossHandler:
11 | def __init__(self, smoothness_weight, scale, loss_weight):
12 | self.smoothness_weight = smoothness_weight
13 | self.scale = scale
14 | self.loss_weight = loss_weight
15 | self.loss_total_meter = AverageMeter()
16 | self.loss_image_meter = AverageMeter()
17 | self.loss_warp_meter = AverageMeter()
18 | self.loss_flow_meter = AverageMeter()
19 | self.loss_smoothness_meter = AverageMeter()
20 |
21 | self.reset_cache()
22 |
23 | def reset_cache(self):
24 | self.loss = 0
25 | self.loss_image = 0
26 | self.loss_flow = 0
27 | self.loss_warping = 0
28 | self.loss_smoothness = 0
29 |
30 | def compute_multiscale_loss(self, gt_list, pred_list):
31 | self.loss_image = 0
32 | for i in range(self.scale):
33 | self.loss_image += self.loss_weight[i] * self._l1_loss(gt_list[i], pred_list[i])
34 | self.loss = self.loss_image
35 | return self.loss
36 |
37 | def compute_flow_loss(self, outputs, batch):
38 | self.loss_warping = 0
39 | self.loss_smoothness = 0
40 |
41 | imaget_est0_list, imaget_est1_list = [], []
42 |
43 | for idx in range(len(outputs['flow_out']['flow_t0_dict'])):
44 | est0, _ = self._warp(batch['image_pyramid_0'][idx], outputs['flow_out']['flow_t0_dict'][idx])
45 | est1, _ = self._warp(batch['image_pyramid_1'][idx], outputs['flow_out']['flow_t1_dict'][idx])
46 |
47 | imaget_est0_list.append(est0)
48 | imaget_est1_list.append(est1)
49 |
50 | gt = batch['clean_gt_MS_images'][idx]
51 | loss0 = self._l1_loss(gt, est0)
52 | loss1 = self._l1_loss(gt, est1)
53 | smooth0 = cal_grad2_error(outputs['flow_out']['flow_t0_dict'][idx]/20, gt, 1.0)
54 | smooth1 = cal_grad2_error(outputs['flow_out']['flow_t1_dict'][idx]/20, gt, 1.0)
55 |
56 | self.loss_warping += loss0 + loss1
57 | self.loss_smoothness += smooth0 + smooth1
58 |
59 | self.loss_flow = self.loss_warping + self.smoothness_weight * self.loss_smoothness
60 | self.loss = self.loss_flow
61 | return self.loss, imaget_est0_list, imaget_est1_list
62 |
63 | def _l1_loss(self, x, y):
64 |         return torch.sqrt((x - y) ** 2 + 1e-6).mean()  # Charbonnier (smoothed L1) penalty
65 |
66 | def _warp(self, x, flo):
67 | B, C, H, W = x.size()
68 | xx = torch.arange(0, W).view(1, 1, 1, W).expand(B, 1, H, W)
69 | yy = torch.arange(0, H).view(1, 1, H, 1).expand(B, 1, H, W)
70 | grid = torch.cat((xx, yy), 1).float()
71 | if x.is_cuda:
72 | grid = grid.cuda()
73 | vgrid = torch.autograd.Variable(grid) + flo
74 |
75 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
76 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0
77 | vgrid = vgrid.permute(0, 2, 3, 1)
78 |
79 | output = nn.functional.grid_sample(x, vgrid, align_corners=True)
80 |         mask = torch.ones_like(x)  # ones_like already matches x's device
81 | mask = nn.functional.grid_sample(mask, vgrid, align_corners=True)
82 | mask = mask.masked_fill(mask < 0.999, 0).masked_fill(mask > 0, 1)
83 | return output * mask, mask
84 |
85 | def update_meters(self, mode):
86 | self.loss_flow_meter.update(self.loss_flow)
87 | self.loss_warp_meter.update(self.loss_warping)
88 | self.loss_smoothness_meter.update(self.loss_smoothness)
89 | if mode == 'joint':
90 | self.loss_total_meter.update(self.loss)
91 | self.loss_image_meter.update(self.loss_image)
92 |
93 | def reset_meters(self):
94 | self.loss_total_meter.reset()
95 | self.loss_image_meter.reset()
96 | self.loss_flow_meter.reset()
97 | self.loss_warp_meter.reset()
98 | self.loss_smoothness_meter.reset()
--------------------------------------------------------------------------------
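
The `_l1_loss` above is, strictly speaking, a Charbonnier (smoothed L1) penalty: `mean(sqrt((x - y)^2 + eps))` behaves like `|x - y|` for large residuals but stays differentiable at zero, which is why it is often preferred over plain L1 for image reconstruction losses. A small illustrative comparison; the function names here are for illustration only:

```python
import torch

def charbonnier(x, y, eps=1e-6):
    # smoothed L1: close to |x - y| away from zero, differentiable at zero
    return torch.sqrt((x - y) ** 2 + eps).mean()

def plain_l1(x, y):
    return (x - y).abs().mean()

x = torch.tensor([0.0, 0.5, 1.0])
y = torch.zeros(3)
print(charbonnier(x, y).item())  # ~0.5003 (slightly above L1 because of eps)
print(plain_l1(x, y).item())     # 0.5
```
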
/models/model_manager.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import importlib
4 | from math import ceil
5 | from models.final_models.submodules import *
6 | from utils.utils import AverageMeter, batch2device
7 | from .loss_handler import LossHandler
8 |
9 |
10 | class OurModel:
11 | def __init__(self, args):
12 | self.voxel_num_bins = args.voxel_num_bins
13 | self.flow_debug = bool(args.flow_tb_debug)
14 | self.scale = 3
15 | self.loss_weight = [1, 0.1, 0.1]
16 | self.batch = {}
17 | self.outputs = {}
18 | self.test_outputs = {}
19 | self.downsample = nn.AvgPool2d(2, stride=2)
20 |
21 | self.loss_handler = LossHandler(
22 | smoothness_weight=args.smoothness_weight,
23 | scale=self.scale,
24 | loss_weight=self.loss_weight
25 | )
26 |
27 | def initialize(self, model_folder, model_name):
28 | mod = importlib.import_module(f'models.{model_folder}.{model_name}')
29 | self.net = mod.EventInterpNet(self.voxel_num_bins, self.flow_debug)
30 |
31 | def cuda(self):
32 | self.net.cuda()
33 |
34 | def train(self):
35 | self.net.train()
36 |
37 | def eval(self):
38 | self.net.eval()
39 |
40 | def use_multi_gpu(self):
41 | self.net = nn.DataParallel(self.net)
42 |
43 | def fix_flownet(self):
44 | net = self.net.module if isinstance(self.net, nn.DataParallel) else self.net
45 | for param in net.flownet.parameters():
46 | param.requires_grad = False
47 |
48 | def get_optimizer_params(self):
49 | return self.net.parameters()
50 |
51 | def set_mode(self, mode):
52 | net = self.net.module if isinstance(self.net, nn.DataParallel) else self.net
53 | net.set_mode(mode)
54 |
55 | def set_train_input(self, sample):
56 | self._set_common_input(sample)
57 | self._generate_multi_scale_inputs(sample)
58 |
59 | def set_input(self, sample):
60 | self._set_common_input(sample)
61 |
62 | def _set_common_input(self, sample):
63 | self.batch['image_input0'] = sample['clean_image_first'].float()
64 | self.batch['image_input1'] = sample['clean_image_last'].float()
65 | self.batch['imaget_input'] = sample['clean_middle'].float()
66 | self.batch['event_input_t0'] = sample['voxel_grid_t0'].float()
67 | self.batch['event_input_0t'] = sample['voxel_grid_0t'].float()
68 | self.batch['event_input_t1'] = sample['voxel_grid_t1'].float()
69 | self.batch['clean_gt_images'] = sample['clean_middle']
70 |
71 | def _generate_multi_scale_inputs(self, sample):
72 | labels = sample['clean_middle']
73 | image_0 = sample['clean_image_first']
74 | image_1 = sample['clean_image_last']
75 | self.batch['clean_gt_MS_images'] = [labels]
76 | self.batch['image_pyramid_0'] = [image_0]
77 | self.batch['image_pyramid_1'] = [image_1]
78 |
79 | for _ in range(self.scale - 1):
80 | labels = self.downsample(labels.clone())
81 | image_0 = self.downsample(image_0.clone())
82 | image_1 = self.downsample(image_1.clone())
83 | self.batch['clean_gt_MS_images'].append(labels)
84 | self.batch['image_pyramid_0'].append(image_0)
85 | self.batch['image_pyramid_1'].append(image_1)
86 |
87 | def forward_nets(self):
88 | self.outputs = self.net(self.batch)
89 |
90 | def forward_joint_test(self):
91 | self.test_outputs = self.net(self.batch)
92 |         self.test_outputs['flow_out']['flow_t0_dict'] = self.test_outputs['flow_out']['flow_t0_dict'][0][..., :self.H_org, :self.W_org]
93 |         self.test_outputs['flow_out']['flow_t1_dict'] = self.test_outputs['flow_out']['flow_t1_dict'][0][..., :self.H_org, :self.W_org]
94 |         self.test_outputs['interp_out'] = self.test_outputs['interp_out'][0][..., :self.H_org, :self.W_org]
95 |
96 | def get_multi_scale_loss(self):
97 | return self.loss_handler.compute_multiscale_loss(
98 | self.batch['clean_gt_MS_images'],
99 | self.outputs['interp_out']
100 | )
101 |
102 | def get_flow_loss(self):
103 | loss_flow, imaget_est0_list, imaget_est1_list = self.loss_handler.compute_flow_loss(self.outputs, self.batch)
104 | self.batch['imaget_est0_warp'] = imaget_est0_list
105 | self.batch['imaget_est1_warp'] = imaget_est1_list
106 | return loss_flow
107 |
108 | def update_loss_meters(self, mode):
109 | self.loss_handler.update_meters(mode)
110 |
111 | def reset_loss_meters(self):
112 | self.loss_handler.reset_meters()
113 |
114 | def load_model(self, state_dict):
115 | if list(state_dict.keys())[0].startswith('module.'):
116 | new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
117 | self.net.load_state_dict(new_state_dict)
118 | else:
119 | self.net.load_state_dict(state_dict)
120 |
121 | def set_test_input(self, sample):
122 | B, _, H, W = sample['clean_image_first'].shape
123 | H_ = ceil(H / 64) * 64
124 | W_ = ceil(W / 64) * 64
125 |
126 | C1 = torch.zeros((B, 3, H_, W_)).cuda()
127 | C2 = torch.zeros((B, 3, H_, W_)).cuda()
128 |
129 | self.batch['image_input0_org'] = sample['clean_image_first']
130 | self.batch['image_input1_org'] = sample['clean_image_last']
131 |
132 | C1[:, :, :H, :W] = sample['clean_image_first']
133 | C2[:, :, :H, :W] = sample['clean_image_last']
134 |
135 | self.batch['image_input0'] = C1
136 | self.batch['image_input1'] = C2
137 |
138 | Vt0 = torch.zeros((B, self.voxel_num_bins, H_, W_)).cuda()
139 | Vt1 = torch.zeros((B, self.voxel_num_bins, H_, W_)).cuda()
140 | V0t = torch.zeros((B, self.voxel_num_bins, H_, W_)).cuda()
141 |
142 | Vt0[:, :, :H, :W] = sample['voxel_grid_t0']
143 | Vt1[:, :, :H, :W] = sample['voxel_grid_t1']
144 | V0t[:, :, :H, :W] = sample['voxel_grid_0t']
145 |
146 | self.batch['event_input_t0'] = Vt0
147 | self.batch['event_input_0t'] = V0t
148 | self.batch['event_input_t1'] = Vt1
149 |
150 | self.H_org = H
151 | self.W_org = W
152 |
153 |
--------------------------------------------------------------------------------
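
`set_test_input` zero-pads each test tensor up to the next multiple of 64 (`ceil(H/64)*64`) so every level of the multi-scale network sees cleanly divisible spatial sizes, and `forward_joint_test` crops the outputs back to the original `H_org x W_org`. A minimal sketch of that pad-then-crop pattern; the helper name is illustrative, not repository API:

```python
from math import ceil
import torch

def pad_to_multiple(x, multiple=64):
    # zero-pad a [B,C,H,W] tensor on the bottom/right so H and W divide by `multiple`
    B, C, H, W = x.shape
    H_, W_ = ceil(H / multiple) * multiple, ceil(W / multiple) * multiple
    out = torch.zeros(B, C, H_, W_, dtype=x.dtype)
    out[:, :, :H, :W] = x
    return out, (H, W)

x = torch.rand(1, 3, 625, 970)              # BS-ERGB frame resolution
padded, (H, W) = pad_to_multiple(x)
assert padded.shape[-2:] == (640, 1024)
assert torch.equal(padded[..., :H, :W], x)  # cropping back recovers the original frame
```
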
/pretrained_model/README.md:
--------------------------------------------------------------------------------
1 | ### put the downloaded model here
--------------------------------------------------------------------------------
/run_samples.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import numpy as np
4 | import os
5 | from models.model_manager import OurModel
6 | from skimage.io import imread
7 | import cv2
8 | from utils.utils import *
9 |
10 |
11 | torch.backends.cudnn.enabled = True
12 | torch.backends.cudnn.benchmark = True
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--voxel_num_bins', type = int, default=16)
16 |
17 | parser.add_argument('--sample_folder_path', type = str, default='./sample_data')
18 | parser.add_argument('--save_output_dir', type = str, default='./output')
19 | parser.add_argument('--image_number', type = int, default=0)
20 |
21 | parser.add_argument('--model_folder', type = str, default='final_models')
22 | parser.add_argument('--model_name', type = str, default='ours')
23 | parser.add_argument('--flow_tb_debug', type = str2bool, default='False')
24 | parser.add_argument('--smoothness_weight', type = float, default=10.0)
25 | parser.add_argument('--ckpt_path', type=str, default='pretrained_model/ours_weight.pth')
26 | args = parser.parse_args()
27 |
28 |
29 | first_image_name = os.path.join(args.sample_folder_path, str(args.image_number).zfill(5) + '.png')
30 | second_image_name = os.path.join(args.sample_folder_path, str(args.image_number+1).zfill(5) + '.png')
31 | first_image_np = imread(first_image_name)
32 | second_image_np = imread(second_image_name)
33 | frame1 = torch.from_numpy(first_image_np).permute(2,0,1).float().unsqueeze(0) / 255.0
34 | frame3 = torch.from_numpy(second_image_np).permute(2,0,1).float().unsqueeze(0) / 255.0
35 |
36 | voxel_0t_name = os.path.join(args.sample_folder_path, str(args.image_number).zfill(5) + '_0t.npz')
37 | voxel_t0_name = os.path.join(args.sample_folder_path, str(args.image_number).zfill(5) + '_t0.npz')
38 | voxel_t1_name = os.path.join(args.sample_folder_path, str(args.image_number).zfill(5) + '_t1.npz')
39 | voxel_0t = torch.from_numpy(np.load(voxel_0t_name)["data"])[None, ...]
40 | voxel_t1 = torch.from_numpy(np.load(voxel_t1_name)["data"])[None, ...]
41 | voxel_t0 = torch.from_numpy(np.load(voxel_t0_name)["data"])[None, ...]
42 |
43 | model = OurModel(args)
44 | model.initialize(args.model_folder, args.model_name)
45 |
46 | ckpt = torch.load(args.ckpt_path, map_location='cpu')
47 | model.load_model(ckpt)
48 |
49 | model.cuda()
50 | model.set_mode('joint')
51 | with torch.no_grad():
52 | # patch-wise evaluation
53 | iter_idx = 0
54 | h_size_patch_testing = 640
55 | h_overlap_size = 305
56 | w_size_patch_testing = 896
57 | w_overlap_size = 352
58 | sample = {}
59 | sample['clean_image_first'] = frame1.cuda()
60 | sample['clean_image_last'] = frame3.cuda()
61 | sample['voxel_grid_0t'] = voxel_0t.cuda()
62 | sample['voxel_grid_t1'] = voxel_t1.cuda()
63 | sample['voxel_grid_t0'] = voxel_t0.cuda()
64 |
65 | B, C, H, W = frame1.shape
66 |
67 | h_stride = h_size_patch_testing - h_overlap_size
68 | w_stride = w_size_patch_testing - w_overlap_size
69 | h_idx_list = list(range(0, H-h_size_patch_testing, h_stride)) + [max(0, H-h_size_patch_testing)]
70 | w_idx_list = list(range(0, W-w_size_patch_testing, w_stride)) + [max(0, W-w_size_patch_testing)]
71 | # output
72 | E = torch.zeros(B, C, H, W).cuda()
73 | W_ = torch.zeros_like(E).cuda()
74 | input_keys = ['clean_image_first', 'clean_image_last', 'voxel_grid_0t', 'voxel_grid_t1', 'voxel_grid_t0']
75 | not_overlap_border = True
76 | for h_idx in h_idx_list:
77 | for w_idx in w_idx_list:
78 | _sample = {}
79 | for input_key in input_keys:
80 | _sample[input_key] = sample[input_key][..., h_idx:h_idx+h_size_patch_testing, w_idx:w_idx+w_size_patch_testing]
81 | model.set_test_input(_sample)
82 | model.forward_joint_test()
83 | out_patch = model.test_outputs['interp_out']
84 | out_patch_mask = torch.ones_like(out_patch)
85 | if not_overlap_border:
86 | if h_idx < h_idx_list[-1]:
87 | out_patch[..., -h_overlap_size//2:, :] *= 0
88 | out_patch_mask[..., -h_overlap_size//2:, :] *= 0
89 | if w_idx < w_idx_list[-1]:
90 | out_patch[..., -w_overlap_size//2:] *= 0
91 | out_patch_mask[..., -w_overlap_size//2:] *= 0
92 | if h_idx > h_idx_list[0]:
93 | out_patch[..., :h_overlap_size//2, :] *= 0
94 | out_patch_mask[..., :h_overlap_size//2, :] *= 0
95 | if w_idx > w_idx_list[0]:
96 | out_patch[..., :w_overlap_size//2] *= 0
97 | out_patch_mask[..., :w_overlap_size//2] *= 0
98 | E[:, :, h_idx:(h_idx+h_size_patch_testing), w_idx:(w_idx+w_size_patch_testing)].add_(out_patch)
99 | W_[:, :, h_idx:(h_idx+h_size_patch_testing), w_idx:(w_idx+w_size_patch_testing)].add_(out_patch_mask)
100 | output = E.div_(W_)
101 | clean_middle_np = tensor2numpy(output)
102 | ## save output
103 | os.makedirs(args.save_output_dir, exist_ok=True)
104 | ## _0 and _2 are the input frames
105 | cv2.imwrite(os.path.join(args.save_output_dir, str(args.image_number).zfill(5) + '_0.png'), cv2.cvtColor(first_image_np, cv2.COLOR_RGB2BGR))
106 | cv2.imwrite(os.path.join(args.save_output_dir, str(args.image_number).zfill(5) + '_2.png'), cv2.cvtColor(second_image_np, cv2.COLOR_RGB2BGR))
107 | ## _1 is the interpolated output frame
108 | cv2.imwrite(os.path.join(args.save_output_dir, str(args.image_number).zfill(5) + '_1.png'), cv2.cvtColor(clean_middle_np, cv2.COLOR_RGB2BGR))
--------------------------------------------------------------------------------
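
The loop above tiles the full frame into overlapping patches, accumulates each network output into `E` and its coverage mask into `W_`, and finally divides so that pixels covered by several patches are averaged; zeroing half of each overlap border keeps the seams away from patch edges. A toy 1-D sketch of the accumulate-and-normalize step, with illustrative variable names:

```python
import torch

signal = torch.arange(10.0)          # stand-in for a full-resolution prediction
patch = 6                            # two overlapping windows: [0:6] and [4:10]
acc = torch.zeros_like(signal)
weight = torch.zeros_like(signal)

for start in (0, 4):
    acc[start:start + patch] += signal[start:start + patch]  # per-patch "prediction"
    weight[start:start + patch] += 1.0                        # coverage count

blended = acc / weight               # overlapped positions are averaged
assert torch.allclose(blended, signal)
```
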
/sample_data/00000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intelpro/CBMNet/248405c313ce29f87ad42873df46792c76f55d1a/sample_data/00000.png
--------------------------------------------------------------------------------
/sample_data/00000_0t.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intelpro/CBMNet/248405c313ce29f87ad42873df46792c76f55d1a/sample_data/00000_0t.npz
--------------------------------------------------------------------------------
/sample_data/00000_t0.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intelpro/CBMNet/248405c313ce29f87ad42873df46792c76f55d1a/sample_data/00000_t0.npz
--------------------------------------------------------------------------------
/sample_data/00000_t1.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intelpro/CBMNet/248405c313ce29f87ad42873df46792c76f55d1a/sample_data/00000_t1.npz
--------------------------------------------------------------------------------
/sample_data/00001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/intelpro/CBMNet/248405c313ce29f87ad42873df46792c76f55d1a/sample_data/00001.png
--------------------------------------------------------------------------------
/test_bsergb.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils.utils import *
3 | from tqdm import tqdm
4 | import argparse
5 | import os
6 | from utils.dataloader_bsergb import *
7 | from models.model_manager import OurModel
8 | from utils.flow_utils import *
9 | import torchvision.utils as vutils
10 |
11 |
12 | def get_argument():
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--val_batch_size', type=int, default=1)
15 | parser.add_argument('--voxel_num_bins', type=int, default=16)
16 | parser.add_argument('--flow_tb_debug', type=str2bool, default='False')
17 | parser.add_argument('--flow_tb_viz', type=str2bool, default='True')
18 | parser.add_argument('--warp_tb_debug', type=str2bool, default='True')
19 | parser.add_argument('--val_mode', type=str2bool, default='False')
20 | parser.add_argument('--val_skip_num_list', default=[1, 3])
21 | parser.add_argument('--model_folder', type=str, default='final_models')
22 | parser.add_argument('--model_name', type=str, default='ours_large')
23 | parser.add_argument('--use_smoothness_loss', type=str2bool, default='True')
24 | parser.add_argument('--smoothness_weight', type=float, default=10.0)
25 | parser.add_argument('--num_threads', type=int, default=12)
26 | parser.add_argument('--experiment_name', type=str, default='test_bsergb_dataset')
27 | parser.add_argument('--tb_update_thresh', type=int, default=1)
28 | parser.add_argument('--data_dir', type=str, default='/home/user/dataset/bsergb_interpolation_v2/')
29 | parser.add_argument('--ckpt_path', type=str, default='pretrained_model/Ours_Large_BSERGB.pth')
30 | parser.add_argument('--use_multigpu', type=str2bool, default='True')
31 | parser.add_argument('--train_skip_num_list', default=[1, 3])
32 | args = parser.parse_args()
33 | return args
34 |
35 |
36 | class Tester(object):
37 | def __init__(self, args):
38 | self.args = args
39 | self._init_model()
40 | self._init_metrics()
41 | self._init_dataloader()
42 |
43 | def _init_dataloader(self):
44 | val_set_dict = get_BSERGB_val_dataset(self.args.data_dir, self.args.val_skip_num_list, mode='1_TEST')
45 | self.val_loader_dict = {}
46 | for skip_num, val_dataset in val_set_dict.items():
47 | self.val_loader_dict[skip_num] = torch.utils.data.DataLoader(
48 | val_dataset,
49 | batch_size=self.args.val_batch_size,
50 | shuffle=False,
51 | num_workers=self.args.num_threads,
52 | pin_memory=True
53 | )
54 |
55 | def _init_model(self):
56 | self.model = OurModel(self.args)
57 | self.model.initialize(self.args.model_folder, self.args.model_name)
58 | ckpt = torch.load(self.args.ckpt_path, map_location='cpu')
59 | self.model.load_model(ckpt['model_state_dict'])
60 |
61 | if torch.cuda.is_available():
62 | self.model.cuda()
63 |
64 | if self.args.use_multigpu:
65 | self.model.use_multi_gpu()
66 |
67 | def _init_metrics(self):
68 | self.PSNR_calculator = PSNR()
69 |
70 | def test_joint(self, epoch=0):
71 | psnr_total = AverageMeter()
72 | psnr_interval = AverageMeter()
73 |
74 | self.model.eval()
75 | self.model.set_mode('joint')
76 |
77 | os.makedirs('./outputs', exist_ok=True)
78 | os.makedirs(f'./logs/{self.args.experiment_name}', exist_ok=True)
79 |
80 | with torch.no_grad():
81 | for skip_num, val_loader in self.val_loader_dict.items():
82 | output_save_path = f'./outputs/net_out/{skip_num}skip'
83 | gt_save_path = f'./outputs/gt/{skip_num}skip'
84 | os.makedirs(output_save_path, exist_ok=True)
85 | os.makedirs(gt_save_path, exist_ok=True)
86 | for i, sample in enumerate(tqdm(val_loader, desc=f'val skip {skip_num}')):
87 | sample = batch2device(sample)
88 | self.model.set_test_input(sample)
89 | self.model.forward_joint_test()
90 |
91 | gt = sample['clean_middle']
92 | pred = self.model.test_outputs['interp_out']
93 |
94 | psnr = self.PSNR_calculator(gt, pred).mean().item()
95 |
96 | psnr_interval.update(psnr)
97 | psnr_total.update(psnr)
98 |
99 | output_name = os.path.join(output_save_path, str(i).zfill(5) + '.png')
100 | vutils.save_image(pred[0], output_name)
101 | gt_name = os.path.join(gt_save_path, str(i).zfill(5) + '.png')
102 | vutils.save_image(gt[0], gt_name)
103 |
104 | print(f"[Skip {skip_num}] PSNR: {psnr_interval.avg:.2f}")
105 | psnr_interval.reset()
106 |
107 | avg_psnr = psnr_total.avg
108 |
109 | print(f"\n[Test Summary] Avg PSNR: {avg_psnr:.2f}")
110 |
111 | log_path = os.path.join('./logs', self.args.experiment_name, f'test_result_epoch{epoch}.txt')
112 | with open(log_path, 'w') as f:
113 | f.write(f"Experiment: {self.args.experiment_name}\n")
114 | f.write(f"Epoch: {epoch}\n")
115 | f.write(f"Average PSNR: {avg_psnr:.2f}\n")
116 |
117 | torch.cuda.empty_cache()
118 | self.model.test_outputs = {}
119 | return avg_psnr
120 |
121 |
122 | if __name__ == '__main__':
123 | args = get_argument()
124 |     tester = Tester(args)
125 |     tester.test_joint(epoch=0)
--------------------------------------------------------------------------------
/tools/event_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def extract_events(event):
5 | x = event['x']
6 | y = event['y']
7 | t = event['timestamp']
8 | p = event['polarity']
9 | return np.stack([t, x, y, p], axis=1) # shape: (N, 4)
10 |
11 | def event_reverse(events):
12 | end_time = events[:, 0].max()
13 | events[:, 0] = end_time - events[:, 0]
14 | events[:, 3][events[:, 3] == 0] = -1
15 | events[:, 3] = -events[:, 3]
16 | events = np.copy(np.flipud(events))
17 | return events
18 |
19 | def events_to_voxel_grid(events, num_bins, width, height):
20 | assert events.shape[1] == 4
21 | voxel_grid = np.zeros((num_bins, height, width), np.float32).ravel()
22 |
23 | last_stamp, first_stamp = events[-1, 0], events[0, 0]
24 | deltaT = max(last_stamp - first_stamp, 1.0)
25 |
26 | events[:, 0] = (num_bins - 1) * (events[:, 0] - first_stamp) / deltaT
27 | ts, xs, ys, pols = events[:, 0], events[:, 1].astype(int), events[:, 2].astype(int), events[:, 3]
28 | pols[pols == 0] = -1
29 |
30 |     # clamp coordinates and bin indices to the valid range
31 | xs = np.clip(xs, 0, width - 1)
32 | ys = np.clip(ys, 0, height - 1)
33 | tis = np.clip(ts.astype(int), 0, num_bins - 1)
34 |
35 | dts = ts - tis
36 | vals_left = pols * (1.0 - dts)
37 | vals_right = pols * dts
38 |
39 | valid_indices = (tis >= 0) & (tis < num_bins)
40 | np.add.at(voxel_grid, xs[valid_indices] + ys[valid_indices] * width + tis[valid_indices] * width * height, vals_left[valid_indices])
41 |
42 | valid_indices = (tis >= 0) & ((tis + 1) < num_bins)
43 | np.add.at(voxel_grid, xs[valid_indices] + ys[valid_indices] * width + (tis[valid_indices] + 1) * width * height, vals_right[valid_indices])
44 |
45 | return voxel_grid.reshape((num_bins, height, width))
46 |
--------------------------------------------------------------------------------
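
`events_to_voxel_grid` rescales event timestamps to `[0, num_bins - 1]` and splats each polarity into the two neighbouring temporal bins with bilinear weights `1 - dt` and `dt` (events with polarity 0 are mapped to -1). A tiny worked example on a 2x2 sensor, assuming `tools/` is on the Python path:

```python
import numpy as np
from event_utils import events_to_voxel_grid  # assumes tools/ is importable

# three events as (t, x, y, polarity) rows
events = np.array([[0.0, 0, 0, 1],
                   [0.5, 1, 0, 1],
                   [1.0, 1, 1, 0]], dtype=np.float64)

vox = events_to_voxel_grid(events.copy(), num_bins=2, width=2, height=2)
# t=0.0 -> bin 0 with weight 1; t=0.5 -> split 0.5/0.5 between bins 0 and 1;
# t=1.0 -> polarity 0 becomes -1 and lands entirely in bin 1
print(vox[0])  # [[1.  0.5]
               #  [0.  0. ]]
print(vox[1])  # [[ 0.   0.5]
               #  [ 0.  -1. ]]
```
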
/tools/preprocess_events.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from tqdm import tqdm
4 | from event_utils import *
5 |
6 | def parse_args():
7 | parser = argparse.ArgumentParser(description="Event voxel grid generation arguments")
8 |
9 | parser.add_argument('--skip_nums', type=int, nargs='+', default=[1, 3],
10 | help='List of skip numbers (e.g., --skip_nums 1 3)')
11 | parser.add_argument('--dataset_dir', type=str, default='/home/user/dataset/bs_ergb',
12 | help='Path to input dataset directory')
13 | parser.add_argument('--mode', type=str, default='1_TEST',
14 | help='Dataset mode (e.g., 1_TEST)')
15 | parser.add_argument('--voxel_prefix', type=str, default='event_voxel_grid_bin16',
16 | help='Prefix for voxel grid folder')
17 |
18 | return parser.parse_args()
19 |
20 |
21 | if __name__ == '__main__':
22 | args = parse_args()
23 | skip_num_list = args.skip_nums
24 | dataset_dir = args.dataset_dir
25 | mode = args.mode
26 | event_voxel_dir_prefix = args.voxel_prefix
27 |
28 | width, height, num_bins = 970, 625, 16
29 | dataset_with_mode = os.path.join(dataset_dir, mode)
30 | scene_list = sorted(os.listdir(dataset_with_mode))
31 |
32 | for scene_name in scene_list:
33 | image_dir = os.path.join(dataset_with_mode, scene_name, 'images')
34 | index_list = sorted([f.split('.png')[0] for f in os.listdir(image_dir) if f.endswith('.png')])
35 | event_dir = os.path.join(dataset_with_mode, scene_name, 'events')
36 |
37 | for skip_num in skip_num_list:
38 | save_dir = os.path.join(dataset_dir, mode, scene_name, event_voxel_dir_prefix, f"{skip_num}skip")
39 | os.makedirs(save_dir, exist_ok=True)
40 |
41 | num_triplets = (len(index_list) - 1) // (skip_num + 1)
42 | triplets = []
43 | ## gathering triplets
44 | for i in range(num_triplets):
45 | start = i * (skip_num+1)
46 | end = start + (skip_num+1)
47 |             for j in range(1, skip_num+1):
48 |                 middle = start + j
49 | triplets.append((start, middle, end))
50 | for start_idx, middle_idx, end_idx in tqdm(triplets, desc=f"[{mode}] {scene_name} - skip{skip_num}"):
51 | event_0t = np.concatenate([extract_events(np.load(os.path.join(event_dir, f"{idx:06d}.npz")))
52 | for idx in range(start_idx, middle_idx)], axis=0)
53 | event_t1 = np.concatenate([extract_events(np.load(os.path.join(event_dir, f"{idx:06d}.npz")))
54 | for idx in range(middle_idx, end_idx)], axis=0)
55 |
56 | # event_0t
57 | if event_0t.shape[0] > 0:
58 | mask_0t = (event_0t[:, 1] / 32 < width) & (event_0t[:, 2] / 32 < height)
59 | _0t = np.column_stack((event_0t[mask_0t][:, 0],
60 | event_0t[mask_0t][:, 1] / 32,
61 | event_0t[mask_0t][:, 2] / 32,
62 | event_0t[mask_0t][:, 3]))
63 | event_0t_vox = events_to_voxel_grid(_0t, num_bins, width, height)
64 | event_t0_vox = events_to_voxel_grid(event_reverse(_0t.copy()), num_bins, width, height)
65 | else:
66 | event_0t_vox = event_t0_vox = np.zeros((num_bins, height, width))
67 |
68 | # event_t1
69 | if event_t1.shape[0] > 0:
70 | mask_t1 = (event_t1[:, 1] / 32 < width) & (event_t1[:, 2] / 32 < height)
71 | _t1 = np.column_stack((event_t1[mask_t1][:, 0],
72 | event_t1[mask_t1][:, 1] / 32,
73 | event_t1[mask_t1][:, 2] / 32,
74 | event_t1[mask_t1][:, 3]))
75 | event_t1_vox = events_to_voxel_grid(_t1, num_bins, width, height)
76 | else:
77 | event_t1_vox = np.zeros((num_bins, height, width))
78 |
79 | # Save
80 | base_name = f"{start_idx:06d}-{middle_idx:06d}-{end_idx:06d}"
81 | np.savez_compressed(os.path.join(save_dir, base_name + '_0t.npz'), data=event_0t_vox)
82 | np.savez_compressed(os.path.join(save_dir, base_name + '_t0.npz'), data=event_t0_vox)
83 | np.savez_compressed(os.path.join(save_dir, base_name + '_t1.npz'), data=event_t1_vox)
--------------------------------------------------------------------------------
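
For each skip setting, the script pairs keyframes that are `skip_num + 1` indices apart and treats every index in between as a middle frame, producing `(start, middle, end)` triplets. A short sketch of the same indexing; the function name is illustrative only:

```python
def make_triplets(num_frames, skip_num):
    # keyframes are skip_num+1 apart; every in-between index becomes a middle frame
    triplets = []
    for i in range((num_frames - 1) // (skip_num + 1)):
        start = i * (skip_num + 1)
        end = start + (skip_num + 1)
        for j in range(1, skip_num + 1):
            triplets.append((start, start + j, end))
    return triplets

print(make_triplets(7, skip_num=1))  # [(0, 1, 2), (2, 3, 4), (4, 5, 6)]
print(make_triplets(7, skip_num=3))  # [(0, 1, 4), (0, 2, 4), (0, 3, 4)]
```
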
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils.utils import *
3 | from torch.optim import AdamW
4 | from tensorboardX import SummaryWriter
5 | from tqdm import tqdm, trange
6 | import datetime
7 | import argparse
8 | import os
9 | from tools.unused.dataloader_bsergb import *
10 | from models.model_manager import OurModel
11 | import torch.optim as optim
12 | from utils.flow_utils import *
13 |
14 |
15 |
16 | def get_argument():
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('--total_epochs', type = int, default=301)
19 | parser.add_argument('--end_epochs_flow', type = int, default=100)
20 | parser.add_argument('--batch_size', type = int, default=1)
21 | parser.add_argument('--val_batch_size', type = int, default=1)
22 | # training params
23 | parser.add_argument('--voxel_num_bins', type = int, default=16)
24 | parser.add_argument('--crop_size', type = int, default=256)
25 | parser.add_argument('--learning_rate', type = float, default=1e-4)
26 | parser.add_argument('--mode', type = str, default='flow')
27 | parser.add_argument('--flow_tb_debug', type = str2bool, default='True')
28 | parser.add_argument('--flow_tb_viz', type = str2bool, default='True')
29 | parser.add_argument('--warp_tb_debug', type = str2bool, default='True')
30 | ## val folder
31 | parser.add_argument('--val_mode', type = str2bool, default='False')
32 | parser.add_argument('--val_skip_num_list', default=[1, 3])
33 | # model description
34 | parser.add_argument('--model_folder', type=str, default='final_models')
35 | parser.add_argument('--model_name', type=str, default='ours')
36 | parser.add_argument('--use_smoothness_loss', type=str2bool, default='True')
37 | parser.add_argument('--smoothness_weight', type = float, default=10.0)
38 | # data loading params
39 | parser.add_argument('--num_threads', type = int, default=12)
40 | parser.add_argument('--experiment_name', type = str, default='train_bsergb_networks')
41 | parser.add_argument('--tb_update_thresh', type = int, default=1)
42 | parser.add_argument('--data_dir', type = str, default = '/media/mnt2/bs_ergb')
43 | parser.add_argument('--use_multigpu', type=str2bool, default='True')
44 | parser.add_argument('--train_skip_num_list', default=[1, 3])
45 | # loading module
46 | args = parser.parse_args()
47 | return args
48 |
49 |
50 | class Trainer(object):
51 | def __init__(self, args):
52 | self.args = args
53 | self._init_counters()
54 | self._init_tensorboard()
55 | self._init_dataloader()
56 | self._init_model()
57 | self._init_optimizer()
58 | self._init_scheduler()
59 | self._init_metrics()
60 |
61 | def _init_counters(self):
62 | self.tb_iter_cnt = 0
63 | self.tb_iter_cnt_val = 0
64 | self.tb_iter_cnt2 = 0
65 | self.tb_iter_cnt2_val = 0
66 | self.tb_iter_thresh = self.args.tb_update_thresh
67 | self.batchsize = self.args.batch_size
68 | self.start_epoch = 0
69 | self.end_epoch = self.args.total_epochs
70 | self.best_psnr = 0.0
71 | self.start_epoch_flow = 0
72 | self.end_epoch_flow = self.args.end_epochs_flow
73 | self.start_epoch_joint = self.args.end_epochs_flow + 1
74 |
75 | def _init_tensorboard(self):
76 | timestamp = datetime.datetime.now().strftime('%y%m%d-%H%M')
77 | tb_path = os.path.join('./experiments', f"{timestamp}-{self.args.experiment_name}")
78 | self.tb = SummaryWriter(tb_path, flush_secs=1)
79 |
80 | def _init_dataloader(self):
81 | ## train set
82 | train_set = get_BSERGB_train_dataset(self.args.data_dir, self.args.train_skip_num_list, mode='3_TRAINING')
83 | self.train_loader = torch.utils.data.DataLoader(train_set, batch_size=self.args.batch_size, shuffle=True, num_workers=self.args.num_threads, pin_memory=True, drop_last=True)
84 | ## val set
85 | val_set_dict = get_BSERGB_val_dataset(self.args.data_dir, self.args.val_skip_num_list, mode='1_TEST')
86 | # make loader per skip
87 | self.val_loader_dict = {}
88 | for skip_num, val_dataset in val_set_dict.items():
89 | self.val_loader_dict[skip_num] = torch.utils.data.DataLoader(
90 | val_dataset,
91 | batch_size=self.args.val_batch_size,
92 | shuffle=False,
93 | num_workers=self.args.num_threads,
94 | pin_memory=True
95 | )
96 |
97 | def _init_model(self):
98 | self.model = OurModel(self.args)
99 | self.model.initialize(self.args.model_folder, self.args.model_name)
100 |
101 | if torch.cuda.is_available():
102 | self.model.cuda()
103 |
104 | if self.args.use_multigpu:
105 | self.model.use_multi_gpu()
106 |
107 | def _init_optimizer(self):
108 | params = self.model.get_optimizer_params()
109 | self.optimizer = AdamW(params, lr=self.args.learning_rate)
110 |
111 | def _init_scheduler(self):
112 | if self.args.mode == 'joint':
113 | milestones = [200, 270]
114 | elif self.args.mode == 'flow':
115 | milestones = [60]
116 | else:
117 | milestones = []
118 |
119 | if milestones:
120 | self.scheduler = optim.lr_scheduler.MultiStepLR(
121 | self.optimizer,
122 | milestones=milestones,
123 | gamma=0.5
124 | )
125 |
126 | def _init_metrics(self):
127 | self.PSNR_calculator = PSNR()
128 | self.SSIM_calculator = SSIM()
129 |
130 | def mode_classify(self):
131 | # Mode override by argument
132 | if self.args.mode == 'joint':
133 | mode = 'joint'
134 | elif self.epoch <= self.end_epoch_flow:
135 | mode = 'flow'
136 | elif self.start_epoch_joint <= self.epoch <= self.end_epoch:
137 | mode = 'joint'
138 | else:
139 | raise ValueError(f"Invalid epoch {self.epoch} for mode classification.")
140 | self.model.set_mode(mode)
141 | # Automatically freeze flownet if in joint mode
142 | if mode == 'joint':
143 | self.model.fix_flownet()
144 | return mode
145 |
146 | def train(self):
147 | for self.epoch in trange(self.start_epoch, self.end_epoch, desc='epoch progress'):
148 | self.model.train()
149 | mode_now = self.mode_classify()
150 |
151 | for _, sample in enumerate(tqdm(self.train_loader, desc='train progress')):
152 | self.train_step(sample, mode=mode_now)
153 |
154 | if self.epoch % 10 == 0 and mode_now == 'joint':
155 | psnr_val, _ = self.val_joint(self.epoch)
156 | if psnr_val > self.best_psnr:
157 | self.best_psnr = psnr_val
158 | self.save_model(self.epoch, best=True)
159 | print(f"[Best Model Updated] Epoch {self.epoch} - PSNR: {psnr_val:.2f}")
160 |
161 | self.scheduler.step()
162 |
163 |
164 | def train_step(self, sample, mode):
165 | # --- Move batch to device and zero optimizer ---
166 | sample = batch2device(sample)
167 | self.optimizer.zero_grad()
168 |
169 | # --- Set input for model ---
170 | self.model.set_train_input(sample)
171 |
172 | # --- Forward pass and compute loss ---
173 | self.model.forward_nets()
174 | if mode == 'flow':
175 | loss = self.model.get_flow_loss()
176 | elif mode == 'joint':
177 | loss = self.model.get_multi_scale_loss()
178 | else:
179 | raise ValueError(f"Unsupported mode: {mode}")
180 |
181 | # --- Backpropagation and optimization ---
182 | loss.backward()
183 | self.optimizer.step()
184 |
185 | # --- Update training status ---
186 | self.model.update_loss_meters(mode)
187 | self.tb_iter_cnt += 1
188 |
189 | if self.batchsize * self.tb_iter_cnt > self.tb_iter_thresh:
190 | self.log_train_tb(mode)
191 |
192 | # --- Clean up ---
193 | del sample
194 |
195 |
196 | def log_train_tb(self, mode):
197 | def add_scalar(tag, value):
198 | self.tb.add_scalar(tag, value, self.tb_iter_cnt2)
199 |
200 | def add_image(tag, image):
201 | self.tb.add_image(tag, image, self.tb_iter_cnt2)
202 |
203 | def add_flow_image(tag, flow_tensor):
204 | flow_img = flow_to_image(flow_tensor.detach().cpu().permute(1, 2, 0).numpy()).transpose(2, 0, 1)
205 | add_image(tag, flow_img)
206 |
207 | # --- Log loss values ---
208 | add_scalar('train_progress/loss_total', self.model.loss_handler.loss_total_meter.avg)
209 | add_scalar('train_progress/loss_flow', self.model.loss_handler.loss_flow_meter.avg)
210 | add_scalar('train_progress/loss_warp', self.model.loss_handler.loss_warp_meter.avg)
211 | add_scalar('train_progress/loss_smoothness', self.model.loss_handler.loss_smoothness_meter.avg)
212 |
213 | # --- Log interpolation input images ---
214 | add_image('train_image/clean_image_first', self.model.batch['image_input0'][0])
215 | add_image('train_image/clean_image_last', self.model.batch['image_input1'][0])
216 | add_image('train_image/interp_gt', self.model.batch['clean_gt_images'][0])
217 |
218 | # --- Log predicted optical flow (estimated) ---
219 | if self.args.flow_tb_viz:
220 | add_flow_image('train_flow/flow_t0_est', self.model.outputs['flow_out']['flow_t0_dict'][0][0])
221 | add_flow_image('train_flow/flow_t1_est', self.model.outputs['flow_out']['flow_t1_dict'][0][0])
222 |
223 | # --- Debug intermediate flow results ---
224 | if self.args.flow_tb_debug:
225 | add_flow_image('train_flow_debug_0/flow_event', self.model.outputs['flow_out']['event_flow_dict'][0][0])
226 | add_flow_image('train_flow_debug_0/flow_image', self.model.outputs['flow_out']['image_flow_dict'][0][0])
227 | add_flow_image('train_flow_debug_0/flow_fusion', self.model.outputs['flow_out']['fusion_flow_dict'][0][0])
228 | add_image('train_flow_debug_0/event_flow_mask', self.model.outputs['flow_out']['mask_dict'][0][0])
229 |
230 | # --- Joint training-specific logging ---
231 | if mode == 'joint':
232 | add_scalar('train_progress/loss_image', self.model.loss_handler.loss_image_meter.avg)
233 | add_image('train_image/interp_out', self.model.outputs['interp_out'][0][0])
234 | elif mode == 'flow':
235 | # --- Warp output visualization ---
236 | if self.args.warp_tb_debug:
237 | add_image('train_warp_output/warp_image_0t', self.model.batch['imaget_est0_warp'][0][0])
238 | add_image('train_warp_output/warp_image_t1', self.model.batch['imaget_est1_warp'][0][0])
239 | add_image('train_warp_output/warp_image_gt', self.model.batch['clean_gt_images'][0])
240 |
241 | # --- Update counters and reset meters ---
242 | self.tb_iter_cnt2 += 1
243 | self.tb_iter_cnt = 0
244 | self.model.loss_handler.reset_meters()
245 |
246 | def val_joint(self, epoch):
247 | # Total and per-interval metric meters
248 | psnr_total = AverageMeter()
249 | ssim_total = AverageMeter()
250 | psnr_interval = AverageMeter()
251 | ssim_interval = AverageMeter()
252 |
253 | # Set model to evaluation mode
254 | self.model.eval()
255 | # set model mode
256 | self.model.set_mode('joint')
257 |
258 | with torch.no_grad():
259 | for skip_num, val_loader in self.val_loader_dict.items():
260 | for _, sample in enumerate(tqdm(val_loader, desc=f'val skip {skip_num}')):
261 | sample = batch2device(sample)
262 | self.model.set_test_input(sample)
263 | self.model.forward_joint_test()
264 |
265 | gt = sample['clean_middle']
266 | pred = self.model.test_outputs['interp_out']
267 |
268 | psnr = self.PSNR_calculator(gt, pred).mean().item()
269 | ssim = self.SSIM_calculator(gt, pred).mean().item()
270 |
271 | psnr_interval.update(psnr)
272 | ssim_interval.update(ssim)
273 |
274 | psnr_total.update(psnr)
275 | ssim_total.update(ssim)
276 |
277 | # Log per interval result
278 | self.tb.add_scalar(f'val_progress/BSERGB/{skip_num}skip/avg_psnr_interp', psnr_interval.avg, epoch)
279 | self.tb.add_scalar(f'val_progress/BSERGB/{skip_num}skip/avg_ssim_interp', ssim_interval.avg, epoch)
280 |
281 | psnr_interval.reset()
282 | ssim_interval.reset()
283 |
284 | # Log total result
285 | self.tb.add_scalar('val_progress/BSERGB/average/avg_psnr_interp', psnr_total.avg, epoch)
286 | self.tb.add_scalar('val_progress/BSERGB/average/avg_ssim_interp', ssim_total.avg, epoch)
287 |
288 | torch.cuda.empty_cache()
289 | self.model.test_outputs = {}
290 | return psnr_total.avg, ssim_total.avg
291 |
292 |     def save_model(self, epoch, best=False):
293 |         combined_state_dict = {
294 |             'epoch': self.epoch,
295 |             'model_state_dict': self.model.net.state_dict(),
296 |             'Optimizer_state_dict' : self.optimizer.state_dict(),
297 |             'Scheduler_state_dict' : self.scheduler.state_dict()}
298 |         torch.save(combined_state_dict, os.path.join(getattr(self.model, 'save_path', './experiments'), ('best_model_' if best else 'model_') + str(epoch) + '_ep.pth'))  # fall back to ./experiments when save_path is not set
299 |
300 |
301 | if __name__=='__main__':
302 | args = get_argument()
303 | trainer = Trainer(args)
304 |     if args.val_mode:
305 | trainer.val_joint(0)
306 | else:
307 | trainer.train()
308 |
--------------------------------------------------------------------------------
/utils/dataloader_bsergb.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.utils.data as data
3 | from torchvision import transforms
4 | import numpy as np
5 | import random
6 | from PIL import Image
7 | from torch.utils.data import ConcatDataset
8 | import re
9 |
10 |
11 |
12 |
13 | class BSERGB_val_dataset(data.Dataset):
14 | def __init__(self, data_path, skip_num):
15 | super(BSERGB_val_dataset, self).__init__()
16 | # Image and event prefix
17 | self.event_vox_prefix = 'event_voxel_grid_bin16'
18 | self.sharp_image_prefix = 'images'
19 | # skip list
20 | self.skip_num = skip_num
21 | # Transform
22 | self.transform = transforms.ToTensor()
23 | # Initialize file taxonomy
24 | self.get_filetaxnomy(data_path)
25 | ## crop size
26 | self.crop_height = 256
27 | self.crop_width = 256
28 |
29 | def get_filetaxnomy(self, data_dir):
30 | event_voxel_dir = os.path.join(data_dir, self.event_vox_prefix, str(self.skip_num) + 'skip')
31 | clean_image_dir = os.path.join(data_dir, self.sharp_image_prefix)
32 | index_list = [f.split('.png')[0] for f in os.listdir(clean_image_dir) if f.endswith(".png")]
33 | index_list.sort()
34 | self.input_name_dict = {}
35 | self.input_name_dict['event_voxels_0t'] = []
36 | self.input_name_dict['event_voxels_t0'] = []
37 | self.input_name_dict['event_voxels_t1'] = []
38 | self.input_name_dict['clean_image_first'] = []
39 | self.input_name_dict['clean_image_last'] = []
40 | self.input_name_dict['gt_image'] = []
41 | num_triplets = (len(index_list)-1) // (self.skip_num+1)
42 | triplets = []
43 | ## gathering triplets
44 | for i in range(num_triplets):
45 | start = i * (self.skip_num+1)
46 | end = start + (self.skip_num+1)
47 |             for j in range(1, self.skip_num+1):
48 |                 middle = start + j
49 | triplets.append((start, middle, end))
50 | for triplet in triplets:
51 | first_idx = triplet[0]
52 | middle_idx = triplet[1]
53 | end_idx = triplet[2]
54 | self.input_name_dict['clean_image_first'].append(os.path.join(clean_image_dir, str(first_idx).zfill(6) + '.png'))
55 | self.input_name_dict['clean_image_last'].append(os.path.join(clean_image_dir, str(end_idx).zfill(6) + '.png'))
56 | self.input_name_dict['gt_image'].append(os.path.join(clean_image_dir, str(middle_idx).zfill(6) + '.png'))
57 | self.input_name_dict['event_voxels_0t'].append(os.path.join(event_voxel_dir, f"{first_idx:06d}-{middle_idx:06d}-{end_idx:06d}_0t.npz"))
58 | self.input_name_dict['event_voxels_t0'].append(os.path.join(event_voxel_dir, f"{first_idx:06d}-{middle_idx:06d}-{end_idx:06d}_t0.npz"))
59 | self.input_name_dict['event_voxels_t1'].append(os.path.join(event_voxel_dir, f"{first_idx:06d}-{middle_idx:06d}-{end_idx:06d}_t1.npz"))
60 |
61 | def __getitem__(self, index):
62 | # first image
63 | first_image_path = self.input_name_dict['clean_image_first'][index]
64 | first_image = Image.open(first_image_path)
65 | first_image_tensor = self.transform(first_image)
66 | # second image
67 | second_image_path = self.input_name_dict['clean_image_last'][index]
68 | second_image = Image.open(second_image_path)
69 | second_image_tensor = self.transform(second_image)
70 | # gt image
71 | gt_image_path = self.input_name_dict['gt_image'][index]
72 | gt_image = Image.open(gt_image_path)
73 | gt_image_tensor = self.transform(gt_image)
74 | ## event voxel
75 | # 0t voxel
76 | event_vox_0t_path = self.input_name_dict['event_voxels_0t'][index]
77 | event_vox_0t = np.load(event_vox_0t_path)["data"]
78 | # t1 voxel
79 | event_vox_t1_path = self.input_name_dict['event_voxels_t1'][index]
80 | event_vox_t1 = np.load(event_vox_t1_path)["data"]
81 |         # t0 voxel
82 |         event_vox_t0_path = self.input_name_dict['event_voxels_t0'][index]
83 |         event_vox_t0 = np.load(event_vox_t0_path)["data"]
84 |         ## assemble the sample
85 | sample = dict()
86 | sample['clean_middle'] = gt_image_tensor
87 | sample['clean_image_first'] = first_image_tensor
88 | sample['clean_image_last'] = second_image_tensor
89 | sample['voxel_grid_0t'] = event_vox_0t
90 | sample['voxel_grid_t1'] = event_vox_t1
91 | sample['voxel_grid_t0'] = event_vox_t0
92 | return sample
93 |
94 | def __len__(self):
95 | return len(self.input_name_dict['gt_image'])
96 |
97 | def get_BSERGB_val_dataset(data_dir, skip_list, mode='1_TEST'):
98 | dataset_path_sub = os.path.join(data_dir, mode)
99 | scene_list = sorted(os.listdir(dataset_path_sub))
100 | dataset_dict = {}
101 | for skip_num in skip_list:
102 | dataset_list = []
103 | for scene in scene_list:
104 | dataset_path_full = os.path.join(dataset_path_sub, scene)
105 | dset = BSERGB_val_dataset(dataset_path_full, skip_num)
106 | dataset_list.append(dset)
107 | dataset_concat = ConcatDataset(dataset_list)
108 | dataset_dict[skip_num] = dataset_concat
109 | return dataset_dict
110 |
111 |
112 |
--------------------------------------------------------------------------------
/utils/flow_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import inspect
6 | UNKNOWN_FLOW_THRESH = 1e7
7 |
8 |
9 | def compute_color(u, v):
10 | """
11 | compute optical flow color map
12 | :param u: optical flow horizontal map
13 | :param v: optical flow vertical map
14 | :return: optical flow in color code
15 | """
16 | [h, w] = u.shape
17 | img = np.zeros([h, w, 3])
18 | nanIdx = np.isnan(u) | np.isnan(v)
19 | u[nanIdx] = 0
20 | v[nanIdx] = 0
21 |
22 |     colorwheel = make_colorwheel()
23 | ncols = np.size(colorwheel, 0)
24 |
25 | rad = np.sqrt(u**2+v**2)
26 |
27 | a = np.arctan2(-v, -u) / np.pi
28 |
29 | fk = (a+1) / 2 * (ncols - 1) + 1
30 |
31 | k0 = np.floor(fk).astype(int)
32 |
33 | k1 = k0 + 1
34 | k1[k1 == ncols+1] = 1
35 | f = fk - k0
36 |
37 | for i in range(0, np.size(colorwheel,1)):
38 | tmp = colorwheel[:, i]
39 | col0 = tmp[k0-1] / 255
40 | col1 = tmp[k1-1] / 255
41 | col = (1-f) * col0 + f * col1
42 |
43 | idx = rad <= 1
44 | col[idx] = 1-rad[idx]*(1-col[idx])
45 | notidx = np.logical_not(idx)
46 |
47 | col[notidx] *= 0.75
48 | img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
49 | return img
50 |
51 |
52 | def mesh_grid(B, H, W):
53 | # mesh grid
54 | x_base = torch.arange(0, W).repeat(B, H, 1) # BHW
55 | y_base = torch.arange(0, H).repeat(B, W, 1).transpose(1, 2) # BHW
56 | base_grid = torch.stack([x_base, y_base], 1) # B2HW
57 | return base_grid
58 |
59 |
60 | def norm_grid(v_grid):
61 | _, _, H, W = v_grid.size()
62 |
63 | # scale grid to [-1,1]
64 | v_grid_norm = torch.zeros_like(v_grid)
65 | v_grid_norm[:, 0, :, :] = 2.0 * v_grid[:, 0, :, :] / (W - 1) - 1.0
66 | v_grid_norm[:, 1, :, :] = 2.0 * v_grid[:, 1, :, :] / (H - 1) - 1.0
67 | return v_grid_norm.permute(0, 2, 3, 1) # BHW2
68 |
69 | def cal_grad2_error(flo, image, beta):
70 | """
71 | Calculate the image-edge-aware second-order smoothness loss for flo
72 | """
73 | def gradient(pred):
74 | D_dy = pred[:, :, 1:, :] - pred[:, :, :-1, :]
75 | D_dx = pred[:, :, :, 1:] - pred[:, :, :, :-1]
76 | return D_dx, D_dy
77 |
78 | img_grad_x, img_grad_y = gradient(image)
79 | weights_x = torch.exp(-10.0 * torch.mean(torch.abs(img_grad_x), 1, keepdim=True))
80 | weights_y = torch.exp(-10.0 * torch.mean(torch.abs(img_grad_y), 1, keepdim=True))
81 |
82 | dx, dy = gradient(flo)
83 | dx2, dxdy = gradient(dx)
84 | dydx, dy2 = gradient(dy)
85 | return (torch.mean(beta*weights_x[:,:, :, 1:]*torch.abs(dx2)) + torch.mean(beta*weights_y[:, :, 1:, :]*torch.abs(dy2))) / 2.0
86 |
87 | def make_colorwheel():
88 | """
89 | Generates a color wheel for optical flow visualization as presented in:
90 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
91 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
92 | Code follows the original C++ source code of Daniel Scharstein.
93 |     Code follows the Matlab source code of Deqing Sun.
94 | Returns:
95 | np.ndarray: Color wheel
96 | """
97 |
98 | RY = 15
99 | YG = 6
100 | GC = 4
101 | CB = 11
102 | BM = 13
103 | MR = 6
104 |
105 | ncols = RY + YG + GC + CB + BM + MR
106 | colorwheel = np.zeros((ncols, 3))
107 | col = 0
108 |
109 | # RY
110 | colorwheel[0:RY, 0] = 255
111 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
112 | col = col+RY
113 | # YG
114 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
115 | colorwheel[col:col+YG, 1] = 255
116 | col = col+YG
117 | # GC
118 | colorwheel[col:col+GC, 1] = 255
119 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
120 | col = col+GC
121 | # CB
122 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
123 | colorwheel[col:col+CB, 2] = 255
124 | col = col+CB
125 | # BM
126 | colorwheel[col:col+BM, 2] = 255
127 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
128 | col = col+BM
129 | # MR
130 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
131 | colorwheel[col:col+MR, 0] = 255
132 | return colorwheel
133 |
134 |
135 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
136 | """
137 | Applies the flow color wheel to (possibly clipped) flow components u and v.
138 | According to the C++ source code of Daniel Scharstein
139 | According to the Matlab source code of Deqing Sun
140 | Args:
141 | u (np.ndarray): Input horizontal flow of shape [H,W]
142 | v (np.ndarray): Input vertical flow of shape [H,W]
143 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
144 | Returns:
145 | np.ndarray: Flow visualization image of shape [H,W,3]
146 | """
147 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
148 | colorwheel = make_colorwheel() # shape [55x3]
149 | ncols = colorwheel.shape[0]
150 | rad = np.sqrt(np.square(u) + np.square(v))
151 | a = np.arctan2(-v, -u)/np.pi
152 | fk = (a+1) / 2*(ncols-1)
153 | k0 = np.floor(fk).astype(np.int32)
154 | k1 = k0 + 1
155 | k1[k1 == ncols] = 0
156 | f = fk - k0
157 | for i in range(colorwheel.shape[1]):
158 | tmp = colorwheel[:,i]
159 | col0 = tmp[k0] / 255.0
160 | col1 = tmp[k1] / 255.0
161 | col = (1-f)*col0 + f*col1
162 | idx = (rad <= 1)
163 | col[idx] = 1 - rad[idx] * (1-col[idx])
164 | col[~idx] = col[~idx] * 0.75 # out of range
165 | # Note the 2-i => BGR instead of RGB
166 | ch_idx = 2-i if convert_to_bgr else i
167 | flow_image[:,:,ch_idx] = np.floor(255*col)
168 | return flow_image
169 |
170 | def flow_warp(x, flow, pad='border', mode='bilinear'):
171 | B, _, H, W = x.size()
172 |
173 | base_grid = mesh_grid(B, H, W).type_as(x) # B2HW
174 | v_grid = norm_grid(base_grid + flow) # BHW2
175 | if 'align_corners' in inspect.getfullargspec(torch.nn.functional.grid_sample).args:
176 | im1_recons = nn.functional.grid_sample(x, v_grid, mode=mode, padding_mode=pad, align_corners=True)
177 | else:
178 | im1_recons = nn.functional.grid_sample(x, v_grid, mode=mode, padding_mode=pad)
179 | return im1_recons
180 |
181 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
182 | """
183 |     Expects a two-dimensional flow field of shape [H,W,2].
184 | Args:
185 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
186 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
187 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
188 | Returns:
189 | np.ndarray: Flow visualization image of shape [H,W,3]
190 | """
191 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
192 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
193 | if clip_flow is not None:
194 | flow_uv = np.clip(flow_uv, 0, clip_flow)
195 | u = flow_uv[:,:,0]
196 | v = flow_uv[:,:,1]
197 | rad = np.sqrt(np.square(u) + np.square(v))
198 | rad_max = np.max(rad)
199 | epsilon = 1e-5
200 | u = u / (rad_max + epsilon)
201 | v = v / (rad_max + epsilon)
202 | return flow_uv_to_colors(u, v, convert_to_bgr)
203 |
204 | def flow2rgb(flow_map, max_value):
205 | flow_map_np = flow_map.detach().cpu().numpy()
206 | _, h, w = flow_map_np.shape
207 | flow_map_np[:,(flow_map_np[0] == 0) & (flow_map_np[1] == 0)] = float('nan')
208 | rgb_map = np.ones((3,h,w)).astype(np.float32)
209 | if max_value is not None:
210 | normalized_flow_map = flow_map_np / max_value
211 | else:
212 | normalized_flow_map = flow_map_np / (np.abs(flow_map_np).max())
213 | rgb_map[0] += normalized_flow_map[0]
214 | rgb_map[1] -= 0.5*(normalized_flow_map[0] + normalized_flow_map[1])
215 | rgb_map[2] += normalized_flow_map[1]
216 | return rgb_map.clip(0,1)
217 |
218 | def warp(x, flo, return_mask=False):
219 | B, C, H, W = x.size()
220 | # mesh grid
221 | xx = torch.arange(0, W).view(1, 1, 1, W).expand(B, 1, H, W)
222 | yy = torch.arange(0, H).view(1, 1, H, 1).expand(B, 1, H, W)
223 |
224 | grid = torch.cat((xx, yy), 1).float()
225 |
226 | if x.is_cuda:
227 | grid = grid.cuda()
228 |
229 | vgrid = torch.autograd.Variable(grid) + flo
230 |
231 | # scale grid to [-1,1]
232 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :] / max(W - 1, 1) - 1.0
233 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :] / max(H - 1, 1) - 1.0
234 |
235 | vgrid = vgrid.permute(0, 2, 3, 1)
236 | output = nn.functional.grid_sample(x, vgrid)
237 |     mask = torch.ones(x.size(), device=x.device, dtype=x.dtype)  # keep the mask on the same device as x
238 |     mask = nn.functional.grid_sample(mask, vgrid)  # out-of-bounds samples drop below 1 and are zeroed below
239 |
240 | mask = mask.masked_fill_(mask < 0.999, 0)
241 | mask = mask.masked_fill_(mask > 0, 1)
242 |
243 | if return_mask:
244 | return output, mask
245 | else:
246 | return output
247 |
248 | def EPE(input_flow, target_flow, sparse=False, mean=True):
249 | EPE_map = torch.norm(target_flow-input_flow,2,1)
250 | batch_size = EPE_map.size(0)
251 | if sparse:
252 | # invalid flow is defined with both flow coordinates to be exactly 0
253 | mask = (target_flow[:,0] == 0) & (target_flow[:,1] == 0)
254 |
255 | EPE_map = EPE_map[~mask]
256 | if mean:
257 | return EPE_map.mean()
258 | else:
259 | return EPE_map.sum()/batch_size
260 |
261 | def realEPE(output, target, sparse=False):
262 | b, _, h, w = target.size()
263 | upsampled_output = F.interpolate(output, (h,w), mode='bilinear', align_corners=False)
264 | return EPE(upsampled_output, target, sparse, mean=True)
--------------------------------------------------------------------------------
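
The helpers above cover both directions of flow handling: `flow_to_image` maps a dense [H,W,2] flow field onto the standard color wheel for visualization, while `warp` backward-warps an image with a [B,2,H,W] flow tensor and optionally returns a validity mask. A minimal usage sketch, not part of the repository, assuming the repo root is on `PYTHONPATH` and using random data purely for illustration:

```python
# Minimal sketch (not from the repository): exercise flow_to_image and warp
# on synthetic data. Shapes and values are illustrative only.
import numpy as np
import torch

from utils.flow_utils import flow_to_image, warp

# Synthetic [H,W,2] flow field and a random [B,3,H,W] image.
flow_np = (np.random.randn(64, 64, 2) * 5.0).astype(np.float32)
image = torch.rand(1, 3, 64, 64)

# Color-wheel visualization: returns an [H,W,3] uint8 array.
flow_color = flow_to_image(flow_np)
print(flow_color.shape, flow_color.dtype)  # (64, 64, 3) uint8

# Backward warping: the flow must be a [B,2,H,W] tensor on the same device as the image.
flow_t = torch.from_numpy(flow_np).permute(2, 0, 1).unsqueeze(0)
warped, valid = warp(image, flow_t, return_mask=True)
print(warped.shape, valid.shape)  # torch.Size([1, 3, 64, 64]) for both
```
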
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from torch.autograd import Variable
5 | import torch.nn.functional as F
6 | from math import exp
7 |
8 | def str2bool(v):
9 |     return v.lower() in ('true',)
10 |
11 | def tensor2numpy(tensor, rgb_range=1.):
12 | rgb_coefficient = 255 / rgb_range
13 | img = tensor.mul(rgb_coefficient).clamp(0, 255).round()
14 | img = img[0].data
15 | img = np.transpose(img.cpu().numpy(), (1, 2, 0)).astype(np.uint8)
16 | return img
17 |
18 | def batch2device(dictionary_of_tensors):
19 | if isinstance(dictionary_of_tensors, dict):
20 | return {key: batch2device(value) for key, value in dictionary_of_tensors.items()}
21 | return dictionary_of_tensors.cuda()
22 |
26 | def randomCrop(tensor, x, y, height, width):
27 | tensor = tensor[..., y:y+height, x:x+width]
28 | return tensor
29 |
37 | def tensor2numpy_batch_idxs(tensor, batch_idx, rgb_range=1.):
38 | rgb_coefficient = 255 / rgb_range
39 | img = tensor.mul(rgb_coefficient).clamp(0, 255).round()
40 | img = img[batch_idx].data
41 | img = np.transpose(img.cpu().numpy(), (1, 2, 0)).astype(np.uint8)
42 | return img
43 |
44 |
45 | class AverageMeter(object):
46 | """Computes and stores the average and current value"""
47 | def __init__(self):
48 | self.reset()
49 |
50 | def reset(self):
51 | self.val = 0
52 | self.avg = 0
53 | self.sum = 0
54 | self.count = 0
55 |
56 | def update(self, val, n=1):
57 | self.val = val
58 | self.sum += val * n
59 | self.count += n
60 | self.avg = self.sum / self.count
61 |
62 | def gaussian(window_size, sigma):
63 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
64 | return gauss/gauss.sum()
65 |
66 | def create_window(window_size, channel):
67 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
68 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
69 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
70 | return window
71 |
72 | class PSNR:
73 | def __init__(self):
74 | self.name = "PSNR"
75 |
76 | @staticmethod
77 | def __call__(img1, img2):
78 | img1 = img1.reshape(img1.shape[0], -1)
79 | img2 = img2.reshape(img2.shape[0], -1)
80 | mse = torch.mean((img1 - img2) ** 2, dim=1)
81 |         return 10 * torch.log10(1 / mse)  # assumes a peak signal value of 1, i.e. inputs normalized to [0, 1]
82 |
83 |
84 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
85 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
86 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
87 |
88 | mu1_sq = mu1.pow(2)
89 | mu2_sq = mu2.pow(2)
90 | mu1_mu2 = mu1*mu2
91 |
92 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
93 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
94 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
95 |
96 | C1 = 0.01**2
97 | C2 = 0.03**2
98 |
99 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
100 |
101 | if size_average:
102 | return ssim_map.mean()
103 | else:
104 | return ssim_map.mean(1).mean(1).mean(1)
105 |
106 |
107 | class SSIM(torch.nn.Module):
108 | def __init__(self, window_size = 11, size_average = False):
109 | super(SSIM, self).__init__()
110 | self.window_size = window_size
111 | self.size_average = size_average
112 | self.channel = 1
113 | self.window = create_window(window_size, self.channel)
114 |
115 | def forward(self, img1, img2):
116 | (_, channel, _, _) = img1.size()
117 |
118 | if channel == self.channel and self.window.data.type() == img1.data.type():
119 | window = self.window
120 | else:
121 | window = create_window(self.window_size, channel)
122 |
123 | if img1.is_cuda:
124 | window = window.cuda(img1.get_device())
125 | window = window.type_as(img1)
126 |
127 | self.window = window
128 | self.channel = channel
129 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
130 |
--------------------------------------------------------------------------------
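
The metric helpers in `utils/utils.py` are self-contained, so they can be exercised directly. A minimal sketch, not part of the repository, using random tensors and assuming the repo root is on `PYTHONPATH`:

```python
# Minimal sketch (not from the repository): compute PSNR/SSIM on random
# tensors and accumulate PSNR across a batch with AverageMeter.
import torch

from utils.utils import PSNR, SSIM, AverageMeter

psnr_fn = PSNR()
ssim_fn = SSIM(window_size=11, size_average=True)  # size_average=True -> scalar SSIM
psnr_meter = AverageMeter()

pred = torch.rand(2, 3, 128, 128)  # stand-in for an interpolated frame batch in [0, 1]
gt = torch.rand(2, 3, 128, 128)    # stand-in for the ground-truth frames

psnr = psnr_fn(pred, gt)           # per-sample PSNR, shape [B]
ssim = ssim_fn(pred, gt)           # scalar because size_average=True
psnr_meter.update(psnr.mean().item(), n=pred.size(0))
print(f"PSNR avg: {psnr_meter.avg:.2f} dB, SSIM: {ssim.item():.4f}")
```
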