├── DVSTool ├── lib │ ├── __init__.py │ ├── pwcNet │ │ ├── __init__.py │ │ ├── correlation_pytorch1_1 │ │ │ ├── clean.sh │ │ │ ├── setup.py │ │ │ ├── correlation_cuda_kernel.cuh │ │ │ └── correlation_cuda.cc │ │ └── pwc_net.pth.tar │ ├── forwardWarpTorch │ │ ├── __init__.py │ │ ├── fWarpGaussion.py │ │ └── forwardWarp.py │ ├── distribTool.py │ ├── fileTool.py │ ├── fitTool.py │ ├── videoTool.py │ ├── metrics.py │ ├── checkTool.py │ ├── eventJiangzhe.py │ └── visualTool.py ├── splitTrainTest.py ├── resplit.py ├── mainDVSProcess_01.py └── DVSBase.py ├── stage1 ├── dataset │ └── .gitkeep ├── lib │ ├── __init__.py │ ├── pwcNet │ │ ├── __init__.py │ │ ├── correlation_pytorch1_1 │ │ │ ├── clean.sh │ │ │ ├── setup.py │ │ │ ├── correlation_cuda_kernel.cuh │ │ │ └── correlation_cuda.cc │ │ └── pwc_net.pth.tar │ ├── distribTool.py │ ├── lossTool.py │ ├── fileTool.py │ ├── videoTool.py │ ├── fitTool.py │ ├── metrics.py │ ├── visualTool.py │ ├── warp.py │ └── checkTool.py ├── output │ └── .gitkeep ├── configs │ ├── __init__.py │ └── configEVI.py ├── models │ ├── __init__.py │ ├── FrameUnet.py │ ├── EventUnet.py │ ├── FuseNet.py │ └── Generator.py ├── dataloader │ ├── __init__.py │ ├── dataloaderBase.py │ └── eventReader.py └── runBash │ └── runEvi.py ├── stage2 ├── lib │ ├── __init__.py │ ├── pwcNet │ │ ├── __init__.py │ │ ├── correlation_pytorch1_1 │ │ │ ├── clean.sh │ │ │ ├── setup.py │ │ │ ├── correlation_cuda_kernel.cuh │ │ │ └── correlation_cuda.cc │ │ └── pwc_net.pth.tar │ ├── distribTool.py │ ├── getMSQMatrix.py │ ├── lossTool.py │ ├── fileTool.py │ ├── fitTool.py │ ├── videoTool.py │ ├── metrics.py │ ├── warp.py │ └── visualTool.py ├── configs │ ├── __init__.py │ ├── configTest.py │ └── configEVI.py ├── models │ ├── __init__.py │ ├── FrameUnet.py │ ├── EventUnet.py │ ├── FuseNet.py │ ├── Generator.py │ └── subPixelAttn.py ├── dataloader │ ├── __init__.py │ └── dataloaderBase.py ├── output │ └── Demo_train_on_lowfps_S1 │ │ ├── gif │ │ └── .gitkeep │ │ ├── state │ │ └── bestEVI_epoch100.pth │ │ ├── events │ │ └── events.out.tfevents.1637397568.sensetime │ │ └── config.txt ├── matrixC.npy └── runBash │ └── runEvi.py ├── dataset ├── fastDVS_dataset │ └── .gitkeep ├── fastDVS_process │ └── .gitkeep └── aedat4 │ ├── test │ └── play_dvSave-2020_10_27_10_50_01.aedat4 │ └── train │ └── play_dvSave-2020_10_27_10_50_01.aedat4 ├── .gitignore ├── .gitattributes ├── LICENSE └── README.md /DVSTool/lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stage1/dataset/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stage1/lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stage1/output/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stage2/lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stage1/configs/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /stage1/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stage2/configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stage2/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /DVSTool/lib/pwcNet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/fastDVS_dataset/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/fastDVS_process/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stage1/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stage1/lib/pwcNet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stage2/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stage2/lib/pwcNet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /DVSTool/lib/forwardWarpTorch/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /stage2/output/Demo_train_on_lowfps_S1/gif/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | *.jpg 3 | .idea/ 4 | __pycache__/ 5 | -------------------------------------------------------------------------------- /stage2/matrixC.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hityzy1122/WEVI/HEAD/stage2/matrixC.npy -------------------------------------------------------------------------------- /DVSTool/lib/pwcNet/correlation_pytorch1_1/clean.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | rm -rf build *.egg-info dist 3 | # python setup.py install 4 | -------------------------------------------------------------------------------- /stage1/lib/pwcNet/correlation_pytorch1_1/clean.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | rm -rf build *.egg-info dist 3 | # python setup.py install 4 | 
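The clean.sh copies above and below only clear stale build artifacts; the correlation extension itself is built with the commented-out python setup.py install (the setup.py files appear further down in this tree). A minimal sanity check after installing, assuming a CUDA-enabled PyTorch environment — note that torch must be imported before the extension so its libtorch symbols can resolve:

import torch              # must come first: the extension links against libtorch
import correlation_cuda   # module name declared in setup(name='correlation_cuda', ...)
print('correlation_cuda loaded from:', correlation_cuda.__file__)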
-------------------------------------------------------------------------------- /stage2/lib/pwcNet/correlation_pytorch1_1/clean.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | rm -rf build *.egg-info dist 3 | # python setup.py install 4 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.aedat4 filter=lfs diff=lfs merge=lfs -text 2 | *.pth filter=lfs diff=lfs merge=lfs -text 3 | *.tar filter=lfs diff=lfs merge=lfs -text 4 | -------------------------------------------------------------------------------- /stage1/lib/pwcNet/pwc_net.pth.tar: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:862f65ab73f1074e3f1bd7bfdcdc0eddb9b3fe5ed98dd8f2acb01b1122b37acc 3 | size 37512100 4 | -------------------------------------------------------------------------------- /stage2/lib/pwcNet/pwc_net.pth.tar: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:862f65ab73f1074e3f1bd7bfdcdc0eddb9b3fe5ed98dd8f2acb01b1122b37acc 3 | size 37512100 4 | -------------------------------------------------------------------------------- /DVSTool/lib/pwcNet/pwc_net.pth.tar: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:862f65ab73f1074e3f1bd7bfdcdc0eddb9b3fe5ed98dd8f2acb01b1122b37acc 3 | size 37512100 4 | -------------------------------------------------------------------------------- /dataset/aedat4/test/play_dvSave-2020_10_27_10_50_01.aedat4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:fa49d2b92dc203fb374290a5228003c541165116b8ab9ce60a0e71d15c96afa6 3 | size 15846310 4 | -------------------------------------------------------------------------------- /dataset/aedat4/train/play_dvSave-2020_10_27_10_50_01.aedat4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:fa49d2b92dc203fb374290a5228003c541165116b8ab9ce60a0e71d15c96afa6 3 | size 15846310 4 | -------------------------------------------------------------------------------- /stage2/output/Demo_train_on_lowfps_S1/state/bestEVI_epoch100.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:380850ca551aa580bb7057aa5fe13f5ed3fee2bcb5f2e588006f8ac49db8d967 3 | size 185441584 4 | -------------------------------------------------------------------------------- /stage2/output/Demo_train_on_lowfps_S1/events/events.out.tfevents.1637397568.sensetime: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hityzy1122/WEVI/HEAD/stage2/output/Demo_train_on_lowfps_S1/events/events.out.tfevents.1637397568.sensetime -------------------------------------------------------------------------------- /stage1/lib/pwcNet/correlation_pytorch1_1/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import torch 4 | 5 | from setuptools import setup, find_packages 6 | from torch.utils.cpp_extension import 
BuildExtension, CUDAExtension 7 | 8 | cxx_args = ['-std=c++11'] 9 | 10 | nvcc_args = [ 11 | '-gencode', 'arch=compute_50,code=sm_50', 12 | '-gencode', 'arch=compute_52,code=sm_52', 13 | '-gencode', 'arch=compute_60,code=sm_60', 14 | '-gencode', 'arch=compute_61,code=sm_61' 15 | # '-gencode', 'arch=compute_70,code=sm_70', 16 | # '-gencode', 'arch=compute_70,code=compute_70' 17 | ] 18 | 19 | setup( 20 | name='correlation_cuda', 21 | ext_modules=[ 22 | CUDAExtension('correlation_cuda', [ 23 | 'correlation_cuda.cc', 24 | 'correlation_cuda_kernel.cu' 25 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) 26 | ], 27 | cmdclass={ 28 | 'build_ext': BuildExtension 29 | }) 30 | -------------------------------------------------------------------------------- /stage2/lib/pwcNet/correlation_pytorch1_1/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import torch 4 | 5 | from setuptools import setup, find_packages 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | cxx_args = ['-std=c++11'] 9 | 10 | nvcc_args = [ 11 | '-gencode', 'arch=compute_50,code=sm_50', 12 | '-gencode', 'arch=compute_52,code=sm_52', 13 | '-gencode', 'arch=compute_60,code=sm_60', 14 | '-gencode', 'arch=compute_61,code=sm_61' 15 | # '-gencode', 'arch=compute_70,code=sm_70', 16 | # '-gencode', 'arch=compute_70,code=compute_70' 17 | ] 18 | 19 | setup( 20 | name='correlation_cuda', 21 | ext_modules=[ 22 | CUDAExtension('correlation_cuda', [ 23 | 'correlation_cuda.cc', 24 | 'correlation_cuda_kernel.cu' 25 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) 26 | ], 27 | cmdclass={ 28 | 'build_ext': BuildExtension 29 | }) 30 | -------------------------------------------------------------------------------- /DVSTool/lib/pwcNet/correlation_pytorch1_1/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import torch 4 | 5 | from setuptools import setup, find_packages 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | include_dirs = ['/usr/local/cuda/include/'] 8 | cxx_args = ['-std=c++11'] 9 | 10 | nvcc_args = [ 11 | '-gencode', 'arch=compute_50,code=sm_50', 12 | '-gencode', 'arch=compute_52,code=sm_52', 13 | '-gencode', 'arch=compute_60,code=sm_60', 14 | '-gencode', 'arch=compute_61,code=sm_61', 15 | # '-gencode', 'arch=compute_70,code=sm_70', 16 | # '-gencode', 'arch=compute_86,code=sm_86' 17 | ] 18 | 19 | setup( 20 | name='correlation_cuda', 21 | ext_modules=[ 22 | CUDAExtension('correlation_cuda', [ 23 | 'correlation_cuda.cc', 24 | 'correlation_cuda_kernel.cu' 25 | ],extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) 26 | ], 27 | cmdclass={ 28 | 'build_ext': BuildExtension 29 | }) 30 | -------------------------------------------------------------------------------- /DVSTool/lib/distribTool.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | 3 | 4 | def synchronize(): 5 | """ 6 | Helper function to synchronize (barrier) among all processes when 7 | using distributed training 8 | """ 9 | if not dist.is_available(): 10 | return 11 | if not dist.is_initialized(): 12 | return 13 | world_size = dist.get_world_size() 14 | if world_size == 1: 15 | return 16 | dist.barrier() 17 | 18 | 19 | def get_world_size(): 20 | if not dist.is_available(): 21 | return 1 22 | if not dist.is_initialized(): 23 | return 1 24 | return 
dist.get_world_size() 25 | 26 | 27 | def reduceTensorMean(cfg, tensor): 28 | rt = tensor.clone() 29 | # print(rt) 30 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 31 | rt /= cfg.envWorldSize 32 | 33 | # if 0 == cfg.envRank or 1 == cfg.envRank: print('rank={}, rtdivid={}'.format(cfg.envRank, rt)) 34 | return rt 35 | 36 | # def reduceTensorSum(tensor): 37 | # rt = tensor.clone() 38 | # dist.all_reduce(rt, op=dist.reduce_op.SUM) 39 | # return rt 40 | -------------------------------------------------------------------------------- /stage1/lib/distribTool.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | 3 | 4 | def synchronize(): 5 | """ 6 | Helper function to synchronize (barrier) among all processes when 7 | using distributed training 8 | """ 9 | if not dist.is_available(): 10 | return 11 | if not dist.is_initialized(): 12 | return 13 | world_size = dist.get_world_size() 14 | if world_size == 1: 15 | return 16 | dist.barrier() 17 | 18 | 19 | def get_world_size(): 20 | if not dist.is_available(): 21 | return 1 22 | if not dist.is_initialized(): 23 | return 1 24 | return dist.get_world_size() 25 | 26 | 27 | def reduceTensorMean(cfg, tensor): 28 | rt = tensor.clone() 29 | # print(rt) 30 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 31 | rt /= cfg.envWorldSize 32 | 33 | # if 0 == cfg.envRank or 1 == cfg.envRank: print('rank={}, rtdivid={}'.format(cfg.envRank, rt)) 34 | return rt 35 | 36 | # def reduceTensorSum(tensor): 37 | # rt = tensor.clone() 38 | # dist.all_reduce(rt, op=dist.reduce_op.SUM) 39 | # return rt 40 | -------------------------------------------------------------------------------- /stage2/lib/distribTool.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | 3 | 4 | def synchronize(): 5 | """ 6 | Helper function to synchronize (barrier) among all processes when 7 | using distributed training 8 | """ 9 | if not dist.is_available(): 10 | return 11 | if not dist.is_initialized(): 12 | return 13 | world_size = dist.get_world_size() 14 | if world_size == 1: 15 | return 16 | dist.barrier() 17 | 18 | 19 | def get_world_size(): 20 | if not dist.is_available(): 21 | return 1 22 | if not dist.is_initialized(): 23 | return 1 24 | return dist.get_world_size() 25 | 26 | 27 | def reduceTensorMean(cfg, tensor): 28 | rt = tensor.clone() 29 | # print(rt) 30 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 31 | rt /= cfg.envWorldSize 32 | 33 | # if 0 == cfg.envRank or 1 == cfg.envRank: print('rank={}, rtdivid={}'.format(cfg.envRank, rt)) 34 | return rt 35 | 36 | # def reduceTensorSum(tensor): 37 | # rt = tensor.clone() 38 | # dist.all_reduce(rt, op=dist.reduce_op.SUM) 39 | # return rt 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Ocean YU 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies 
or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /stage1/models/FrameUnet.py: -------------------------------------------------------------------------------- 1 | from models.baseModule import BaseNet, conv3x3, deconv 2 | 3 | 4 | class FrameUnet(BaseNet): 5 | def __init__(self, cfg): 6 | super(FrameUnet, self).__init__(cfg.netInitType) 7 | self.netScale = 16 8 | self.conv1 = conv3x3(6, 32, cfg, ks=1) # 32, 1x 9 | self.conv2 = conv3x3(32, 64, cfg) # 64, 2x 10 | self.conv3 = conv3x3(64, 128, cfg) # 128, 4x 11 | self.conv4 = conv3x3(128, 256, cfg) # 256, 8x 12 | self.conv5 = conv3x3(256, 256, cfg) # 256, 16x 13 | 14 | self.deconv2 = deconv(256, 256, 3, cfg) # 256, 8x 15 | self.deconv3 = deconv(512, 128, 3, cfg) # 128, 4x 16 | self.deconv4 = deconv(256, 64, 3, cfg) # 64, 2x 17 | self.deconv5 = deconv(128, 32, 3, cfg) # 32, 1x 18 | self.randomInitNet() 19 | 20 | def forward(self, Xt): 21 | feat1 = self.conv1(Xt) # 32 22 | feat2 = self.conv2(feat1) # 64 23 | feat3 = self.conv3(feat2) # 128 24 | feat4 = self.conv4(feat3) # 256 25 | z16 = self.conv5(feat4) # 256 26 | 27 | z8 = self.deconv2(z16, feat4) # 512 28 | z4 = self.deconv3(z8, feat3) # 256 29 | z2 = self.deconv4(z4, feat2) # 128 30 | z1 = self.deconv5(z2, feat1) # 64 31 | 32 | return z16, z8, z4, z2, z1 -------------------------------------------------------------------------------- /stage2/models/FrameUnet.py: -------------------------------------------------------------------------------- 1 | from models.baseModule import BaseNet, conv3x3, deconv 2 | 3 | 4 | class FrameUnet(BaseNet): 5 | def __init__(self, cfg): 6 | super(FrameUnet, self).__init__(cfg.netInitType) 7 | self.netScale = 16 8 | self.conv1 = conv3x3(6, 32, cfg, ks=1) # 32, 1x 9 | self.conv2 = conv3x3(32, 64, cfg) # 64, 2x 10 | self.conv3 = conv3x3(64, 128, cfg) # 128, 4x 11 | self.conv4 = conv3x3(128, 256, cfg) # 256, 8x 12 | self.conv5 = conv3x3(256, 256, cfg) # 256, 16x 13 | 14 | self.deconv2 = deconv(256, 256, 3, cfg) # 256, 8x 15 | self.deconv3 = deconv(512, 128, 3, cfg) # 128, 4x 16 | self.deconv4 = deconv(256, 64, 3, cfg) # 64, 2x 17 | self.deconv5 = deconv(128, 32, 3, cfg) # 32, 1x 18 | self.randomInitNet() 19 | 20 | def forward(self, Xt): 21 | feat1 = self.conv1(Xt) # 32 22 | feat2 = self.conv2(feat1) # 64 23 | feat3 = self.conv3(feat2) # 128 24 | feat4 = self.conv4(feat3) # 256 25 | z16 = self.conv5(feat4) # 256 26 | 27 | z8 = self.deconv2(z16, feat4) # 512 28 | z4 = self.deconv3(z8, feat3) # 256 29 | z2 = self.deconv4(z4, feat2) # 128 30 | z1 = self.deconv5(z2, feat1) # 64 31 | 32 | return z16, z8, z4, z2, z1 -------------------------------------------------------------------------------- /stage2/lib/getMSQMatrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | 4 | 5 | def getGaussianW(shape: tuple): 6 | sigma = 0.5 7 | H, W = shape 8 | x = np.linspace(0, H - 1, H).astype(np.float32) 9 | y = np.linspace(0, W - 1, 
W).astype(np.float32) 10 | xv, yv = np.meshgrid(x, y) 11 | xv: np.ndarray = xv - H // 2 12 | yv: np.ndarray = yv - W // 2 13 | WeightMatrix: np.ndarray = np.exp(-(np.square(xv) + np.square(yv)) / (2 * np.square(sigma))) / np.square( 14 | sigma * np.sqrt(np.pi * 2)) 15 | 16 | # WeightMatrix /= WeightMatrix.sum() 17 | 18 | return np.diag(WeightMatrix.flatten()) 19 | 20 | 21 | def getArray(): 22 | P: np.ndarray = np.array([[0.5, 0.5, -1, -1, 1], 23 | [0, 0.5, 0, -1, 1], 24 | [0.5, 0.5, 1, -1, 1], 25 | [0.5, 0, -1, 0, 1], 26 | [0, 0, 0, 0, 1], 27 | [0.5, 0, 1, 0, 1], 28 | [0.5, 0.5, -1, 1, 1], 29 | [0, 0.5, 0, 1, 1], 30 | [0.5, 0.5, 1, 1, 1]]).astype(np.float32) 31 | W = getGaussianW((3, 3)) 32 | C = np.linalg.inv(P.T @ W @ P) @ P.T @ W 33 | np.save('matrixC.npy', C) 34 | pass 35 | 36 | 37 | if __name__ == '__main__': 38 | getArray() 39 | -------------------------------------------------------------------------------- /stage1/models/EventUnet.py: -------------------------------------------------------------------------------- 1 | from models.baseModule import BaseNet, conv3x3, deconv 2 | 3 | 4 | class EventUnet(BaseNet): 5 | def __init__(self, cfg): 6 | super(EventUnet, self).__init__(cfg.netInitType) 7 | self.netScale = 16 8 | self.conv1 = conv3x3(8, 32, cfg, ks=1) # 32, 1x 9 | self.conv2 = conv3x3(32, 64, cfg) # 64, 2x 10 | self.conv3 = conv3x3(64, 128, cfg) # 128, 4x 11 | self.conv4 = conv3x3(128, 256, cfg) # 256, 8x 12 | self.conv5 = conv3x3(256, 256, cfg) # 256, 16x 13 | 14 | self.deconv2 = deconv(256, 256, 3, cfg) # 256, 8x 15 | self.deconv3 = deconv(512, 128, 3, cfg) # 128, 4x 16 | self.deconv4 = deconv(256, 64, 3, cfg) # 64, 2x 17 | self.deconv5 = deconv(128, 32, 3, cfg) # 32, 1x 18 | 19 | # self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 20 | self.randomInitNet() 21 | 22 | def forward(self, Xt): 23 | feat1 = self.conv1(Xt) # 32,1x 24 | feat2 = self.conv2(feat1) # 64,2x 25 | feat3 = self.conv3(feat2) # 128,4x 26 | feat4 = self.conv4(feat3) # 256,8x 27 | z16 = self.conv5(feat4) # 256,16x 28 | 29 | z8 = self.deconv2(z16, feat4) # 512,8x 30 | z4 = self.deconv3(z8, feat3) # 256, 4x 31 | z2 = self.deconv4(z4, feat2) # 128, 2x 32 | z1 = self.deconv5(z2, feat1) # 64, 1x 33 | 34 | return z16, z8, z4, z2, z1 -------------------------------------------------------------------------------- /stage2/models/EventUnet.py: -------------------------------------------------------------------------------- 1 | from models.baseModule import BaseNet, conv3x3, deconv 2 | 3 | 4 | class EventUnet(BaseNet): 5 | def __init__(self, cfg): 6 | super(EventUnet, self).__init__(cfg.netInitType) 7 | self.netScale = 16 8 | self.conv1 = conv3x3(8, 32, cfg, ks=1) # 32, 1x 9 | self.conv2 = conv3x3(32, 64, cfg) # 64, 2x 10 | self.conv3 = conv3x3(64, 128, cfg) # 128, 4x 11 | self.conv4 = conv3x3(128, 256, cfg) # 256, 8x 12 | self.conv5 = conv3x3(256, 256, cfg) # 256, 16x 13 | 14 | self.deconv2 = deconv(256, 256, 3, cfg) # 256, 8x 15 | self.deconv3 = deconv(512, 128, 3, cfg) # 128, 4x 16 | self.deconv4 = deconv(256, 64, 3, cfg) # 64, 2x 17 | self.deconv5 = deconv(128, 32, 3, cfg) # 32, 1x 18 | 19 | # self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 20 | self.randomInitNet() 21 | 22 | def forward(self, Xt): 23 | feat1 = self.conv1(Xt) # 32,1x 24 | feat2 = self.conv2(feat1) # 64,2x 25 | feat3 = self.conv3(feat2) # 128,4x 26 | feat4 = self.conv4(feat3) # 256,8x 27 | z16 = self.conv5(feat4) # 256,16x 28 | 29 | z8 = self.deconv2(z16, feat4) # 512,8x 30 | z4 = 
self.deconv3(z8, feat3) # 256, 4x 31 | z2 = self.deconv4(z4, feat2) # 128, 2x 32 | z1 = self.deconv5(z2, feat1) # 64, 1x 33 | 34 | return z16, z8, z4, z2, z1 -------------------------------------------------------------------------------- /stage1/models/FuseNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.baseModule import FuseBlock, BaseNet, Interp 4 | 5 | 6 | class FuseNet(BaseNet): 7 | def __init__(self): 8 | super(FuseNet, self).__init__() 9 | self.ConvIn = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1, stride=1, bias=False) 10 | self.AADBlk1 = FuseBlock(cin=512, cout=256, c_ef=512) 11 | self.AADBlk2 = FuseBlock(cin=256, cout=128, c_ef=256) 12 | self.AADBlk3 = FuseBlock(cin=128, cout=64, c_ef=128) 13 | self.AADBlk4 = FuseBlock(cin=64, cout=32, c_ef=64) 14 | self.Up2x = Interp(scale=2) 15 | self.ItStage1 = nn.Sequential(nn.ReplicationPad2d([1, 1, 1, 1]), 16 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3), 17 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 18 | nn.Conv2d(in_channels=32, out_channels=3, kernel_size=1), 19 | ) 20 | 21 | self.randomInitNet() 22 | 23 | def forward(self, z_e, z_f): 24 | ST16x = self.ConvIn(torch.cat([z_e[0], z_f[0]], dim=1)) # 64 25 | 26 | ST8x = self.AADBlk1(self.Up2x(ST16x), z_e[1], z_f[1]) # 32 27 | 28 | ST4x = self.AADBlk2(self.Up2x(ST8x), z_e[2], z_f[2]) # 16 29 | 30 | ST2x = self.AADBlk3(self.Up2x(ST4x), z_e[3], z_f[3]) # 8 31 | 32 | ST1x = self.AADBlk4(self.Up2x(ST2x), z_e[4], z_f[4]) # 4 33 | 34 | ItStage1 = self.ItStage1(ST1x) 35 | 36 | return ItStage1 -------------------------------------------------------------------------------- /stage1/lib/lossTool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from lib.warp import forwadWarp 4 | import torch.nn as nn 5 | 6 | fsplat = forwadWarp(bilinear=False) 7 | fbiwarp = forwadWarp() 8 | 9 | 10 | def TPerceptualLoss(gt_lv1, gt_lv2, gt_lv3, T_lv1, T_lv2, T_lv3): 11 | loss_texture = F.mse_loss(gt_lv3.detach(), T_lv3) 12 | loss_texture += F.mse_loss(gt_lv2.detach(), T_lv2) 13 | loss_texture += F.mse_loss(gt_lv1.detach(), T_lv1) 14 | 15 | loss_texture /= 3. 
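        # dividing by 3 averages the MSE across the three feature levels, keeping the texture term on a per-level scale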
16 | 17 | return loss_texture 18 | 19 | 20 | def l1Loss(source, target, reduction='mean', mask=None): 21 | if mask is None: 22 | return F.l1_loss(source, target, reduction=reduction) 23 | else: 24 | return F.l1_loss(source * mask, target * mask, reduction=reduction) 25 | 26 | 27 | def l2Loss(source, target, mask=None, reduction='none'): 28 | if mask is None: 29 | return F.mse_loss(source, target, reduction=reduction) 30 | else: 31 | return F.mse_loss(source * mask, target * mask, reduction=reduction) 32 | 33 | 34 | def CharbonnierLoss(source, target, mask=None): 35 | eps = 1e-6 36 | 37 | if mask is None: 38 | diff: torch.Tensor = source - target 39 | 40 | else: 41 | diff: torch.Tensor = (source - target) * mask 42 | loss = torch.sqrt(diff ** 2 + eps ** 2).mean() 43 | return loss 44 | 45 | 46 | def minLoss(I0t, I1t, It): 47 | loss0t = F.mse_loss(I0t, It, reduction='none') 48 | loss1t = F.mse_loss(I1t, It, reduction='none') 49 | 50 | minLoss = torch.min(loss0t, loss1t).mean() 51 | 52 | return minLoss 53 | -------------------------------------------------------------------------------- /stage2/models/FuseNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.baseModule import FuseBlock, BaseNet, Interp 4 | 5 | 6 | class FuseNet(BaseNet): 7 | def __init__(self): 8 | super(FuseNet, self).__init__() 9 | self.ConvIn = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1, stride=1, bias=False) 10 | self.AADBlk1 = FuseBlock(cin=512, cout=256, c_ef=512) 11 | self.AADBlk2 = FuseBlock(cin=256, cout=128, c_ef=256) 12 | self.AADBlk3 = FuseBlock(cin=128, cout=64, c_ef=128) 13 | self.AADBlk4 = FuseBlock(cin=64, cout=32, c_ef=64) 14 | self.Up2x = Interp(scale=2) 15 | self.ItStage1 = nn.Sequential(nn.ReplicationPad2d([1, 1, 1, 1]), 16 | nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3), 17 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 18 | nn.Conv2d(in_channels=32, out_channels=3, kernel_size=1), 19 | ) 20 | 21 | self.randomInitNet() 22 | 23 | def forward(self, z_e, z_f): 24 | ST16x = self.ConvIn(torch.cat([z_e[0], z_f[0]], dim=1)) # 64 25 | 26 | ST8x = self.AADBlk1(self.Up2x(ST16x), z_e[1], z_f[1]) # 32 27 | 28 | ST4x = self.AADBlk2(self.Up2x(ST8x), z_e[2], z_f[2]) # 16 29 | 30 | ST2x = self.AADBlk3(self.Up2x(ST4x), z_e[3], z_f[3]) # 8 31 | 32 | ST1x = self.AADBlk4(self.Up2x(ST2x), z_e[4], z_f[4]) # 4 33 | 34 | ItStage1 = self.ItStage1(ST1x) 35 | 36 | return ItStage1, ST4x, ST2x, ST1x 37 | -------------------------------------------------------------------------------- /stage2/output/Demo_train_on_lowfps_S1/config.txt: -------------------------------------------------------------------------------- 1 | -------------Config------------------- 2 | a_name = DVS_S2FullHard_gn_ 3 | dump = False 4 | envApexLevel = O0 5 | envDistributed = 0 6 | envLocalRank = 0 7 | envNodeID = 0 8 | envNumGPUs = 1 9 | envParallel = False 10 | envRank = 0 11 | envUseApex = False 12 | envWorkers = 0 13 | envWorldSize = 1 14 | envnodeName = SingleNode 15 | lrGamma = 0.999 16 | lrInit = 0.0005 17 | lrMilestones = [100, 150] 18 | lrPolicy = exp 19 | lrdecayIter = 100 20 | netActivate = leakyrelu 21 | netCheck = False 22 | netInitGain = 0.2 23 | netInitType = xavier 24 | netNorm = group 25 | optBetas = [0.9, 0.999] 26 | optDecay = 0 27 | optMomentum = 0.995 28 | optPolicy = Adam 29 | outPathS2 = ./output/Demo_train_on_lowfps_S1/Real_S2/ 30 | pathEvents = output/Demo_train_on_lowfps_S1/events 31 | pathExp = 
output/Demo_train_on_lowfps_S1 32 | pathGif = output/Demo_train_on_lowfps_S1/gif 33 | pathOut = ./output/ 34 | pathState = output/Demo_train_on_lowfps_S1/state 35 | pathTrainEvent = /home/sensetime/research/release_ICCV2021/dataset/fastDVS_dataset/train 36 | pathValEvent = /home/sensetime/research/release_ICCV2021/dataset/fastDVS_dataset/test 37 | pathWeight = ./output/Demo_train_on_lowfps_S1/state/bestEVI_epoch100.pth 38 | setRandSeed = 2021 39 | snapShot = 10 40 | step = 1 41 | testBatch = 8 42 | testBatchPerGPU = 8 43 | trainBatch = 4 44 | trainBatchPerGPU = 4 45 | trainEpoch = 5000 46 | trainLogger = 47 | trainMaxSave = 10 48 | trainMean = 0 49 | trainSize = (64, 64) 50 | trainStd = 1 51 | trainVisual = False 52 | trainWriter = 53 | valNumInter = 3 54 | valScale = 1 55 | --------------End--------------------- 56 | -------------------------------------------------------------------------------- /stage2/lib/lossTool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from lib.warp import forwadWarp 4 | import torch.nn as nn 5 | 6 | fsplat = forwadWarp(bilinear=False) 7 | fbiwarp = forwadWarp() 8 | 9 | 10 | def TPerceptualLoss(gt_lv1, gt_lv2, gt_lv3, T_lv1, T_lv2, T_lv3): 11 | loss_texture = F.mse_loss(gt_lv3.detach(), T_lv3) 12 | loss_texture += F.mse_loss(gt_lv2.detach(), T_lv2) 13 | loss_texture += F.mse_loss(gt_lv1.detach(), T_lv1) 14 | 15 | loss_texture /= 3. 16 | 17 | return loss_texture 18 | 19 | 20 | def l1Loss(source, target, reduction='mean', mask=None): 21 | if mask is None: 22 | return F.l1_loss(source, target, reduction=reduction) 23 | else: 24 | return F.l1_loss(source * mask, target * mask, reduction=reduction) 25 | 26 | 27 | def l2Loss(source, target, mask=None, reduction='none'): 28 | if mask is None: 29 | return F.mse_loss(source, target, reduction=reduction) 30 | else: 31 | return F.mse_loss(source * mask, target * mask, reduction=reduction) 32 | 33 | 34 | def CharbonnierLoss(source, target, mask=None): 35 | eps = 1e-6 36 | 37 | if mask is None: 38 | diff: torch.Tensor = source - target 39 | 40 | else: 41 | diff: torch.Tensor = (source - target) * mask 42 | loss = torch.sqrt(diff ** 2 + eps ** 2).mean() 43 | return loss 44 | 45 | 46 | def minLoss(I0t: torch.Tensor, I1t: torch.Tensor, It: torch.Tensor): 47 | I0t.requires_grad_(True) 48 | I1t.requires_grad_(True) 49 | 50 | loss0t = F.mse_loss(I0t, It.detach(), reduction='none') 51 | loss1t = F.mse_loss(I1t, It.detach(), reduction='none') 52 | 53 | minLoss = torch.min(loss0t, loss1t).mean() 54 | 55 | return minLoss 56 | -------------------------------------------------------------------------------- /stage1/lib/pwcNet/correlation_pytorch1_1/correlation_cuda_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <ATen/ATen.h> 4 | #include <ATen/Context.h> 5 | #include <cuda_runtime.h> 6 | 7 | int correlation_forward_cuda_kernel(at::Tensor& output, 8 | int ob, 9 | int oc, 10 | int oh, 11 | int ow, 12 | int osb, 13 | int osc, 14 | int osh, 15 | int osw, 16 | 17 | at::Tensor& input1, 18 | int ic, 19 | int ih, 20 | int iw, 21 | int isb, 22 | int isc, 23 | int ish, 24 | int isw, 25 | 26 | at::Tensor& input2, 27 | int gc, 28 | int gsb, 29 | int gsc, 30 | int gsh, 31 | int gsw, 32 | 33 | at::Tensor& rInput1, 34 | at::Tensor& rInput2, 35 | int pad_size, 36 | int kernel_size, 37 | int max_displacement, 38 | int stride1, 39 | int stride2, 40 | int corr_type_multiply, 41 | cudaStream_t stream); 42 | 43 | 44 | int 
correlation_backward_cuda_kernel( 45 | at::Tensor& gradOutput, 46 | int gob, 47 | int goc, 48 | int goh, 49 | int gow, 50 | int gosb, 51 | int gosc, 52 | int gosh, 53 | int gosw, 54 | 55 | at::Tensor& input1, 56 | int ic, 57 | int ih, 58 | int iw, 59 | int isb, 60 | int isc, 61 | int ish, 62 | int isw, 63 | 64 | at::Tensor& input2, 65 | int gsb, 66 | int gsc, 67 | int gsh, 68 | int gsw, 69 | 70 | at::Tensor& gradInput1, 71 | int gisb, 72 | int gisc, 73 | int gish, 74 | int gisw, 75 | 76 | at::Tensor& gradInput2, 77 | int ggc, 78 | int ggsb, 79 | int ggsc, 80 | int ggsh, 81 | int ggsw, 82 | 83 | at::Tensor& rInput1, 84 | at::Tensor& rInput2, 85 | int pad_size, 86 | int kernel_size, 87 | int max_displacement, 88 | int stride1, 89 | int stride2, 90 | int corr_type_multiply, 91 | cudaStream_t stream); 92 | -------------------------------------------------------------------------------- /stage2/lib/pwcNet/correlation_pytorch1_1/correlation_cuda_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <ATen/ATen.h> 4 | #include <ATen/Context.h> 5 | #include <cuda_runtime.h> 6 | 7 | int correlation_forward_cuda_kernel(at::Tensor& output, 8 | int ob, 9 | int oc, 10 | int oh, 11 | int ow, 12 | int osb, 13 | int osc, 14 | int osh, 15 | int osw, 16 | 17 | at::Tensor& input1, 18 | int ic, 19 | int ih, 20 | int iw, 21 | int isb, 22 | int isc, 23 | int ish, 24 | int isw, 25 | 26 | at::Tensor& input2, 27 | int gc, 28 | int gsb, 29 | int gsc, 30 | int gsh, 31 | int gsw, 32 | 33 | at::Tensor& rInput1, 34 | at::Tensor& rInput2, 35 | int pad_size, 36 | int kernel_size, 37 | int max_displacement, 38 | int stride1, 39 | int stride2, 40 | int corr_type_multiply, 41 | cudaStream_t stream); 42 | 43 | 44 | int correlation_backward_cuda_kernel( 45 | at::Tensor& gradOutput, 46 | int gob, 47 | int goc, 48 | int goh, 49 | int gow, 50 | int gosb, 51 | int gosc, 52 | int gosh, 53 | int gosw, 54 | 55 | at::Tensor& input1, 56 | int ic, 57 | int ih, 58 | int iw, 59 | int isb, 60 | int isc, 61 | int ish, 62 | int isw, 63 | 64 | at::Tensor& input2, 65 | int gsb, 66 | int gsc, 67 | int gsh, 68 | int gsw, 69 | 70 | at::Tensor& gradInput1, 71 | int gisb, 72 | int gisc, 73 | int gish, 74 | int gisw, 75 | 76 | at::Tensor& gradInput2, 77 | int ggc, 78 | int ggsb, 79 | int ggsc, 80 | int ggsh, 81 | int ggsw, 82 | 83 | at::Tensor& rInput1, 84 | at::Tensor& rInput2, 85 | int pad_size, 86 | int kernel_size, 87 | int max_displacement, 88 | int stride1, 89 | int stride2, 90 | int corr_type_multiply, 91 | cudaStream_t stream); 92 | -------------------------------------------------------------------------------- /DVSTool/lib/pwcNet/correlation_pytorch1_1/correlation_cuda_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <ATen/ATen.h> 4 | #include <ATen/Context.h> 5 | #include <cuda_runtime.h> 6 | 7 | int correlation_forward_cuda_kernel(at::Tensor& output, 8 | int ob, 9 | int oc, 10 | int oh, 11 | int ow, 12 | int osb, 13 | int osc, 14 | int osh, 15 | int osw, 16 | 17 | at::Tensor& input1, 18 | int ic, 19 | int ih, 20 | int iw, 21 | int isb, 22 | int isc, 23 | int ish, 24 | int isw, 25 | 26 | at::Tensor& input2, 27 | int gc, 28 | int gsb, 29 | int gsc, 30 | int gsh, 31 | int gsw, 32 | 33 | at::Tensor& rInput1, 34 | at::Tensor& rInput2, 35 | int pad_size, 36 | int kernel_size, 37 | int max_displacement, 38 | int stride1, 39 | int stride2, 40 | int corr_type_multiply, 41 | cudaStream_t stream); 42 | 43 | 44 | int correlation_backward_cuda_kernel( 45 | at::Tensor& gradOutput, 46 | int gob, 47 
| int goc, 48 | int goh, 49 | int gow, 50 | int gosb, 51 | int gosc, 52 | int gosh, 53 | int gosw, 54 | 55 | at::Tensor& input1, 56 | int ic, 57 | int ih, 58 | int iw, 59 | int isb, 60 | int isc, 61 | int ish, 62 | int isw, 63 | 64 | at::Tensor& input2, 65 | int gsb, 66 | int gsc, 67 | int gsh, 68 | int gsw, 69 | 70 | at::Tensor& gradInput1, 71 | int gisb, 72 | int gisc, 73 | int gish, 74 | int gisw, 75 | 76 | at::Tensor& gradInput2, 77 | int ggc, 78 | int ggsb, 79 | int ggsc, 80 | int ggsh, 81 | int ggsw, 82 | 83 | at::Tensor& rInput1, 84 | at::Tensor& rInput2, 85 | int pad_size, 86 | int kernel_size, 87 | int max_displacement, 88 | int stride1, 89 | int stride2, 90 | int corr_type_multiply, 91 | cudaStream_t stream); 92 | -------------------------------------------------------------------------------- /stage1/models/Generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from models.baseModule import BaseNet 4 | from models.EventUnet import EventUnet 5 | from models.FrameUnet import FrameUnet 6 | from models.FuseNet import FuseNet 7 | import math 8 | 9 | 10 | class Generator(BaseNet): 11 | def __init__(self, cfg): 12 | super(Generator, self).__init__(cfg.netInitType, cfg.netInitGain) 13 | self.cfg = cfg 14 | self.netScale = 16 15 | 16 | self.eventUnet = EventUnet(cfg) 17 | self.frameUnet = FrameUnet(cfg) 18 | self.fuseNet = FuseNet() 19 | 20 | if cfg.step in [1, 2, 3]: 21 | self.initPreweight(cfg.pathWeight) 22 | 23 | def getWeight(self, pathPreWeight: str = None): 24 | checkpoints = torch.load(pathPreWeight, map_location=torch.device('cpu')) 25 | try: 26 | weightDict = checkpoints['Generator'] 27 | except Exception as e: 28 | weightDict = checkpoints['model_state_dict'] 29 | return weightDict 30 | 31 | def adap2Net(self, tensor: torch.Tensor): 32 | Height, Width = tensor.size(2), tensor.size(3) 33 | 34 | Height_ = int(math.floor(math.ceil(Height / self.netScale) * self.netScale)) 35 | Width_ = int(math.floor(math.ceil(Width / self.netScale) * self.netScale)) 36 | 37 | if any([Height_ != Height, Width_ != Width]): 38 | tensor = F.pad(tensor, [0, Width_ - Width, 0, Height_ - Height]) 39 | 40 | return tensor 41 | 42 | def forward(self, I0t, I1t, Et): 43 | N, C, H, W = I0t.shape 44 | 45 | I0t = self.adap2Net(I0t) 46 | I1t = self.adap2Net(I1t) 47 | 48 | Et = self.adap2Net(Et) 49 | 50 | z_e = self.eventUnet(Et) 51 | z_f = self.frameUnet(torch.cat([I0t, I1t], dim=1)) 52 | 53 | fusedOut = self.fuseNet(z_e, z_f) 54 | 55 | output = fusedOut[:, :, 0:H, 0:W] 56 | 57 | return output 58 | -------------------------------------------------------------------------------- /DVSTool/lib/fileTool.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from pathlib import Path 3 | from queue import Queue 4 | 5 | 6 | def delPath(root): 7 | """ 8 | remove dir trees 9 | :param path: root dir 10 | :return: True or False 11 | """ 12 | root = Path(root) 13 | if root.is_file(): 14 | try: 15 | root.unlink() 16 | except Exception as e: 17 | print(e) 18 | elif root.is_dir(): 19 | for item in root.iterdir(): 20 | delPath(item) 21 | try: 22 | root.rmdir() 23 | # print('Files in {} is removed'.format(root)) 24 | except Exception as e: 25 | print(e) 26 | 27 | 28 | def mkPath(path): 29 | p = Path(path) 30 | try: 31 | p.mkdir(parents=True, exist_ok=False) 32 | return True 33 | except Exception as e: 34 | return False 35 | 36 | 37 | def getAllFiles(root, ext=None): 38 | p = 
Path(root) 39 | if ext is not None: 40 | pathnames = p.glob("**/*.{}".format(ext)) 41 | 42 | else: 43 | pathnames = p.glob("**/*") 44 | filenames = sorted([x.as_posix() for x in pathnames]) 45 | return filenames 46 | 47 | 48 | def copyFile(src, dst): 49 | parent = Path(dst).parent 50 | mkPath(parent) 51 | try: 52 | shutil.copytree(str(src), str(dst)) 53 | except: 54 | shutil.copy(str(src), str(dst)) 55 | 56 | 57 | def movFile(src, dst): 58 | parent = Path(dst).parent 59 | mkPath(parent) 60 | shutil.move(str(src), str(dst)) 61 | 62 | 63 | def getSubDirs(root): 64 | p = Path(root) 65 | dirs = [x for x in p.iterdir() if x.is_dir()] 66 | return sorted(dirs) 67 | 68 | 69 | class fileBuffer(Queue): 70 | def __init__(self, capacity): 71 | super(fileBuffer, self).__init__() 72 | assert capacity > 0 73 | self.capacity = capacity 74 | 75 | def __call__(self, x): 76 | while self.qsize() >= self.capacity: 77 | delPath(self.get()) 78 | self.put(x) 79 | 80 | -------------------------------------------------------------------------------- /stage2/runBash/runEvi.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | jobName = 'demo_train_for_lowfps' 4 | part = 'Pixel' 5 | 6 | freeNodes = ['SH-IDC1-10-5-30-135', 'SH-IDC1-10-5-30-138'] # fast 7 | 8 | ntaskPerNode = 8 # number of GPUs per nodes 9 | 10 | envDistributed = 1 11 | 12 | gpuDict = "\"{\'SH-IDC1-10-5-30-135\': \'0,1,2,3,4,5,6,7\', \'SH-IDC1-10-5-30-138\': \'0,1,2,3,4,5,6,7\'}\"" 13 | 14 | nodeNum = len(freeNodes) 15 | nTasks = ntaskPerNode * nodeNum if envDistributed else 1 16 | nodeList = ','.join(freeNodes) 17 | initNode = freeNodes[0] 18 | reuseGPU = 0 19 | scrip = 'train' 20 | # scrip = 'train_EAttn' 21 | config = 'configEVI' 22 | 23 | 24 | def runDist(): 25 | pyCode = [] 26 | pyCode.append('python') 27 | pyCode.append('-m') 28 | pyCode.append(scrip) 29 | pyCode.append('--initNode {}'.format(initNode)) 30 | pyCode.append('--config {}'.format(config)) 31 | pyCode.append('--gpuList {}'.format(gpuDict)) 32 | pyCode.append('--reuseGPU {}'.format(reuseGPU)) 33 | pyCode.append('--envDistributed {}'.format(envDistributed)) 34 | pyCode.append('--expName {}'.format(jobName)) 35 | pyCode = ' '.join(pyCode) 36 | 37 | srunCode = [] 38 | srunCode.append('srun') 39 | srunCode.append('--gres=gpu:{}'.format(ntaskPerNode)) 40 | srunCode.append('--job-name={}'.format(jobName)) 41 | srunCode.append('--partition={}'.format(part)) 42 | srunCode.append('--nodelist={}'.format(nodeList)) if freeNodes is not None else print('Get node by slurm') 43 | srunCode.append('--ntasks={}'.format(nTasks)) 44 | srunCode.append('--nodes={}'.format(nodeNum)) 45 | srunCode.append('--ntasks-per-node={}'.format(ntaskPerNode)) if envDistributed else print( 46 | 'ntasks-per-node is 1') 47 | srunCode.append('--cpus-per-task=4') 48 | srunCode.append('--kill-on-bad-exit=1') 49 | srunCode.append('--mpi=pmi2') 50 | # srunCode.append(' --pty bash') 51 | srunCode.append(pyCode) 52 | 53 | srunCode = ' '.join(srunCode) 54 | 55 | os.system(srunCode) 56 | # else: 57 | # os.system(pyCode) 58 | 59 | 60 | if __name__ == '__main__': 61 | runDist() 62 | -------------------------------------------------------------------------------- /stage2/models/Generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from models.baseModule import BaseNet 4 | from models.EventUnet import EventUnet 5 | from models.FrameUnet import FrameUnet 6 | from models.FuseNet import 
FuseNet 7 | import math 8 | 9 | 10 | class Generator(BaseNet): 11 | def __init__(self, cfg): 12 | super(Generator, self).__init__(cfg.netInitType, cfg.netInitGain) 13 | self.cfg = cfg 14 | self.netScale = 16 15 | 16 | self.eventUnet = EventUnet(cfg) 17 | self.frameUnet = FrameUnet(cfg) 18 | self.fuseNet = FuseNet() 19 | 20 | if cfg.step in [1, 2, 3]: 21 | self.initPreweight(cfg.pathWeight) 22 | 23 | def getWeight(self, pathPreWeight: str = None): 24 | checkpoints = torch.load(pathPreWeight, map_location=torch.device('cpu')) 25 | try: 26 | weightDict = checkpoints['Generator'] 27 | except Exception as e: 28 | weightDict = checkpoints['model_state_dict'] 29 | return weightDict 30 | 31 | def adap2Net(self, tensor: torch.Tensor): 32 | Height, Width = tensor.size(2), tensor.size(3) 33 | 34 | Height_ = int(math.floor(math.ceil(Height / self.netScale) * self.netScale)) 35 | Width_ = int(math.floor(math.ceil(Width / self.netScale) * self.netScale)) 36 | 37 | if any([Height_ != Height, Width_ != Width]): 38 | tensor = F.pad(tensor, [0, Width_ - Width, 0, Height_ - Height]) 39 | 40 | return tensor 41 | 42 | def forward(self, I0t, I1t, Et): 43 | N, C, H, W = I0t.shape 44 | 45 | I0t = self.adap2Net(I0t) 46 | I1t = self.adap2Net(I1t) 47 | 48 | Et = self.adap2Net(Et) 49 | 50 | z_e = self.eventUnet(Et) 51 | z_f = self.frameUnet(torch.cat([I0t, I1t], dim=1)) 52 | 53 | fusedOut, ST4x, ST2x, ST1x = self.fuseNet(z_e, z_f) 54 | 55 | ST1x = ST1x[:, :, 0:H, 0:W] 56 | ST2x = ST2x[:, :, 0:H // 2, 0:W // 2] 57 | ST4x = ST4x[:, :, 0:H // 4, 0:W // 4] 58 | 59 | output = fusedOut[:, :, 0:H, 0:W] 60 | 61 | return output, ST4x, ST2x, ST1x 62 | -------------------------------------------------------------------------------- /stage1/lib/fileTool.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from pathlib import Path 3 | from queue import Queue 4 | 5 | 6 | def delPath(root): 7 | """ 8 | remove dir trees 9 | :param path: root dir 10 | :return: True or False 11 | """ 12 | root = Path(root) 13 | if root.is_file(): 14 | try: 15 | root.unlink() 16 | except Exception as e: 17 | print(e) 18 | elif root.is_dir(): 19 | for item in root.iterdir(): 20 | delPath(item) 21 | try: 22 | root.rmdir() 23 | # print('Files in {} is removed'.format(root)) 24 | except Exception as e: 25 | print(e) 26 | 27 | 28 | def mkPath(path): 29 | p = Path(path) 30 | try: 31 | p.mkdir(parents=True, exist_ok=False) 32 | return True 33 | except Exception as e: 34 | return False 35 | 36 | 37 | def getAllFiles(root, ext=None): 38 | p = Path(root) 39 | if ext is not None: 40 | pathnames = p.glob("**/*.{}".format(ext)) 41 | 42 | else: 43 | pathnames = p.glob("**/*") 44 | filenames = sorted([x.as_posix() for x in pathnames]) 45 | return filenames 46 | 47 | 48 | def copyFile(src, dst): 49 | parent = Path(dst).parent 50 | mkPath(parent) 51 | try: 52 | shutil.copytree(str(src), str(dst)) 53 | except: 54 | shutil.copy(str(src), str(dst)) 55 | 56 | 57 | def movFile(src, dst): 58 | parent = Path(dst).parent 59 | mkPath(parent) 60 | shutil.move(str(src), str(dst)) 61 | 62 | 63 | def getSubDirs(root): 64 | p = Path(root) 65 | dirs = [x for x in p.iterdir() if x.is_dir()] 66 | return sorted(dirs) 67 | 68 | 69 | class fileBuffer(Queue): 70 | def __init__(self, capacity): 71 | super(fileBuffer, self).__init__() 72 | assert capacity > 0 73 | self.capacity = capacity 74 | 75 | def __call__(self, x): 76 | while self.qsize() >= self.capacity: 77 | delPath(self.get()) 78 | self.put(x) 79 | 80 | 81 | if __name__ == 
'__main__': 82 | path = '/home/sensetime/project/lib/a' 83 | # mkPath(path) 84 | filenames = delPath(path) 85 | pass 86 | -------------------------------------------------------------------------------- /stage2/lib/fileTool.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from pathlib import Path 3 | from queue import Queue 4 | 5 | 6 | def delPath(root): 7 | """ 8 | remove dir trees 9 | :param path: root dir 10 | :return: True or False 11 | """ 12 | root = Path(root) 13 | if root.is_file(): 14 | try: 15 | root.unlink() 16 | except Exception as e: 17 | print(e) 18 | elif root.is_dir(): 19 | for item in root.iterdir(): 20 | delPath(item) 21 | try: 22 | root.rmdir() 23 | # print('Files in {} is removed'.format(root)) 24 | except Exception as e: 25 | print(e) 26 | 27 | 28 | def mkPath(path): 29 | p = Path(path) 30 | try: 31 | p.mkdir(parents=True, exist_ok=False) 32 | return True 33 | except Exception as e: 34 | return False 35 | 36 | 37 | def getAllFiles(root, ext=None): 38 | p = Path(root) 39 | if ext is not None: 40 | pathnames = p.glob("**/*.{}".format(ext)) 41 | 42 | else: 43 | pathnames = p.glob("**/*") 44 | filenames = sorted([x.as_posix() for x in pathnames]) 45 | return filenames 46 | 47 | 48 | def copyFile(src, dst): 49 | parent = Path(dst).parent 50 | mkPath(parent) 51 | try: 52 | shutil.copytree(str(src), str(dst)) 53 | except: 54 | shutil.copy(str(src), str(dst)) 55 | 56 | 57 | def movFile(src, dst): 58 | parent = Path(dst).parent 59 | mkPath(parent) 60 | shutil.move(str(src), str(dst)) 61 | 62 | 63 | def getSubDirs(root): 64 | p = Path(root) 65 | dirs = [x for x in p.iterdir() if x.is_dir()] 66 | return sorted(dirs) 67 | 68 | 69 | class fileBuffer(Queue): 70 | def __init__(self, capacity): 71 | super(fileBuffer, self).__init__() 72 | assert capacity > 0 73 | self.capacity = capacity 74 | 75 | def __call__(self, x): 76 | while self.qsize() >= self.capacity: 77 | delPath(self.get()) 78 | self.put(x) 79 | 80 | 81 | if __name__ == '__main__': 82 | path = '/home/sensetime/project/lib/a' 83 | # mkPath(path) 84 | filenames = delPath(path) 85 | pass 86 | -------------------------------------------------------------------------------- /stage1/runBash/runEvi.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | jobName = 'Demo_Train_for_lowfps' 4 | part = 'Pixel' # node part 5 | 6 | # name of computational nodes which are available 7 | freeNodes = ['SH-IDC1-10-5-30-135', 'SH-IDC1-10-5-30-138'] 8 | 9 | 10 | ntaskPerNode = 8 # number of GPUs per nodes 11 | reuseGPU = 0 12 | envDistributed = 1 13 | # gpu id on computational node 14 | gpuDict = "\"{\'SH-IDC1-10-5-30-135\': \'0,1,2,3,4,5,6,7\', \'SH-IDC1-10-5-30-138\': \'0,1,2,3,4,5,6,7\'}\"" 15 | 16 | nodeNum = len(freeNodes) 17 | nTasks = ntaskPerNode * nodeNum if envDistributed else 1 18 | nodeList = ','.join(freeNodes) 19 | initNode = freeNodes[0] 20 | 21 | scrip = 'train' 22 | config = 'configEVI' # config name (configEVI.py here) 23 | 24 | 25 | def runDist(): 26 | pyCode = [] 27 | pyCode.append('python') 28 | pyCode.append('-m') 29 | pyCode.append(scrip) 30 | pyCode.append('--initNode {}'.format(initNode)) 31 | pyCode.append('--config {}'.format(config)) 32 | pyCode.append('--gpuList {}'.format(gpuDict)) 33 | pyCode.append('--reuseGPU {}'.format(reuseGPU)) 34 | pyCode.append('--envDistributed {}'.format(envDistributed)) 35 | pyCode.append('--expName {}'.format(jobName)) 36 | pyCode = ' '.join(pyCode) 37 | 38 | srunCode = [] 39 | 
srunCode.append('srun') 40 | srunCode.append('--gres=gpu:{}'.format(ntaskPerNode)) 41 | srunCode.append('--job-name={}'.format(jobName)) 42 | srunCode.append('--partition={}'.format(part)) 43 | srunCode.append('--nodelist={}'.format(nodeList)) if freeNodes is not None else print('Get node by slurm') 44 | srunCode.append('--ntasks={}'.format(nTasks)) 45 | srunCode.append('--nodes={}'.format(nodeNum)) 46 | srunCode.append('--ntasks-per-node={}'.format(ntaskPerNode)) if envDistributed else print( 47 | 'ntasks-per-node is 1') 48 | srunCode.append('--cpus-per-task=4') 49 | srunCode.append('--kill-on-bad-exit=1') 50 | srunCode.append('--mpi=pmi2') 51 | # srunCode.append(' --pty bash') 52 | srunCode.append(pyCode) 53 | 54 | srunCode = ' '.join(srunCode) 55 | 56 | os.system(srunCode) 57 | # else: 58 | # os.system(pyCode) 59 | 60 | 61 | if __name__ == '__main__': 62 | runDist() 63 | -------------------------------------------------------------------------------- /DVSTool/splitTrainTest.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('../') 4 | from lib import fileTool as FT 5 | from pathlib import Path 6 | from tqdm import tqdm 7 | from functools import partial 8 | import multiprocessing 9 | import os 10 | from multiprocessing import Pool, RLock, freeze_support 11 | 12 | 13 | def sortFunc(name: str): 14 | name = Path(name).stem 15 | return int(name) 16 | 17 | 18 | def splitFile(subDir, srcDir, dstDir): 19 | curProc = multiprocessing.current_process() 20 | allFiles = FT.getAllFiles(subDir) 21 | 22 | allFiles.sort(key=sortFunc) 23 | num = len(allFiles) 24 | testSubset = allFiles[0:num // 3] 25 | pbar = tqdm(total=len(testSubset), position=int(curProc._identity[0])) 26 | for test in testSubset: 27 | targetName = test.replace(srcDir, dstDir) 28 | FT.mkPath(str(Path(targetName).parent)) 29 | FT.movFile(test, targetName) 30 | pbar.update(1) 31 | pbar.clear() 32 | pbar.close() 33 | 34 | 35 | def batchZip(srcDir, dstDir, poolSize=1): 36 | # srcDir = '/mnt/lustre/yuzhiyang/dataset/GoPro_public/event/simEvents/' 37 | 38 | # zipFrameEvent(srcDir, dstDir, imgDir, vis=True) 39 | allSubDirs = FT.getSubDirs(srcDir) 40 | kernelFunc = partial(splitFile, srcDir=srcDir, dstDir=dstDir) 41 | 42 | freeze_support() 43 | tqdm.set_lock(RLock()) 44 | p = Pool(processes=poolSize, initializer=tqdm.set_lock, initargs=(tqdm.get_lock(),)) 45 | p.map(func=kernelFunc, iterable=allSubDirs) 46 | p.close() 47 | p.join() 48 | # for subDir in allSubDirs: 49 | # zipFrameEvent(subDir, dstDir, vis=True) 50 | 51 | 52 | def mainServer(): 53 | # dirs of simulated events 54 | srcDir = '/mnt/lustre/yuzhiyang/dataset/slomoDVS2/event/train' 55 | 56 | # dirs of output samples 57 | dstDir = '/mnt/lustre/yuzhiyang/dataset/slomoDVS2/event/trainSave' 58 | 59 | poolSize = 20 60 | 61 | batchZip(srcDir=srcDir, 62 | dstDir=dstDir, 63 | poolSize=poolSize) 64 | 65 | 66 | def mainLocal(): 67 | # dirs of simulated events 68 | srcDir = '/home/sensetime/data/event/DVS/slomoDVS/event/train' 69 | 70 | # dirs of output samples 71 | dstDir = '/home/sensetime/data/event/DVS/slomoDVS/event/test' 72 | 73 | poolSize = 1 74 | 75 | batchZip(srcDir=srcDir, 76 | dstDir=dstDir, 77 | poolSize=poolSize) 78 | 79 | 80 | if __name__ == '__main__': 81 | mainServer() 82 | # mainLocal() 83 | -------------------------------------------------------------------------------- /stage1/lib/videoTool.py: -------------------------------------------------------------------------------- 1 | import os 2 | import 
lib.fileTool as FT 3 | 4 | ffmpegPath = '/usr/bin/ffmpeg' 5 | videoName = '/home/sensetime/data/VideoInterpolation/highfps/gopro_yzy/bbb.MP4' 6 | 7 | 8 | def video2Frame(vPath: str, fdir: str, H: int = None, W: int = None): 9 | FT.mkPath(fdir) 10 | if H is None or W is None: 11 | os.system('{} -y -i {} -vsync 0 -qscale:v 2 {}/%04d.png'.format(ffmpegPath, vPath, fdir)) 12 | else: 13 | os.system('{} -y -i {} -vf scale={}:{} -vsync 0 -qscale:v 2 {}/%04d.jpg'.format(ffmpegPath, vPath, W, H, fdir)) 14 | 15 | 16 | def frame2Video(fdir: str, vPath: str, fps: int, H: int = None, W: int = None): 17 | if H is None or W is None: 18 | # os.system('{} -y -r {} -f image2 -i {}/%*.png -vcodec libx264 -crf 18 -pix_fmt yuv420p {}' 19 | # .format(ffmpegPath, fps, fdir, vPath)) 20 | 21 | os.system('{} -y -r {} -f image2 -i {}/%6d.png -vcodec libx264 -crf 18 -pix_fmt yuv420p {}' 22 | .format(ffmpegPath, fps, fdir, vPath)) 23 | else: 24 | os.system('{} -y -r {} -f image2 -s {}x{} -i {}/%*.png -vcodec libx264 -crf 25 -pix_fmt yuv420p {}' 25 | .format(ffmpegPath, fps, W, H, fdir, vPath)) 26 | 27 | 28 | def slomo(vPath: str, dstPath: str, fps): 29 | os.system( 30 | '{} -y -r {} -i {} -strict -2 -vcodec libx264 -c:a aac -crf 18 {}'.format(ffmpegPath, fps, vPath, dstPath)) 31 | 32 | 33 | def downFPS(vPath: str, dstPath: str, fps): 34 | os.system( 35 | '{} -i {} -strict -2 -r {} {}'.format(ffmpegPath, vPath, fps, dstPath)) 36 | 37 | 38 | def downSample(vPath: str, dstPath: str, H, W): 39 | os.system( 40 | '{} -i {} -strict -2 -s {}x{} {}'.format(ffmpegPath, vPath, W, H, dstPath))  # ffmpeg -s expects WxH 41 | 42 | 43 | if __name__ == '__main__': 44 | framePath = '/home/sensetime/data/event/outputTest/pencil2/' 45 | # framePath = '/home/sensetime/data/VideoInterpolation/highfps/gopro_yzy/output' 46 | video = '/home/sensetime/data/event/outputvideo/pencil2.mp4' 47 | # video2Frame(video, framePath) 48 | 49 | # video = '/home/sensetime/data/VideoInterpolation/highfps/goPro_240fps/train/GOPR0372_07_00/out.mp4' 50 | # framePath = '/home/sensetime/data/VideoInterpolation/highfps/goPro_240fps/train/GOPR0372_07_00' 51 | frame2Video(framePath, video, 60) 52 | 53 | # vPath = '/media/sensetime/Elements/0721 /0716_video/1.avi' 54 | # dstPath = '/media/sensetime/Elements/0721/0716_video/1_.mp4' 55 | # downFPS(vPath, dstPath, 8) 56 | -------------------------------------------------------------------------------- /DVSTool/lib/fitTool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.nn import functional as F 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | 7 | 8 | def backWarp(img: torch.Tensor, flow: torch.Tensor): 9 | device = img.device 10 | N, C, H, W = img.size() 11 | 12 | u = flow[:, 0, :, :] 13 | v = flow[:, 1, :, :] 14 | 15 | gridX, gridY = np.meshgrid(np.arange(W), np.arange(H)) 16 | gridX = torch.tensor(gridX, requires_grad=False).to(device) 17 | gridY = torch.tensor(gridY, requires_grad=False).to(device) 18 | 19 | x = gridX.unsqueeze(0).expand_as(u).float() + u 20 | y = gridY.unsqueeze(0).expand_as(v).float() + v 21 | 22 | # range -1 to 1 23 | x = 2 * x / (W - 1.0) - 1.0 24 | y = 2 * y / (H - 1.0) - 1.0 25 | # stacking X and Y 26 | grid = torch.stack((x, y), dim=3) 27 | # Sample pixels using bilinear interpolation. 
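    # note: the 2*x/(W-1) - 1 normalization above maps corner pixel centers to -1/+1,
    # i.e. align_corners=True semantics; PyTorch >= 1.3 defaults grid_sample to
    # align_corners=False, so on newer versions this call warns and is shifted by half
    # a pixel unless align_corners=True is passed explicitly.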
28 | imgOut = F.grid_sample(img, grid, mode='bilinear', padding_mode='zeros') 29 | 30 | # mask = torch.ones_like(img, requires_grad=False) 31 | # mask = F.grid_sample(mask, grid) 32 | # 33 | # mask[mask < 0.9999] = 0 34 | # mask[mask > 0] = 1 35 | 36 | # return imgOut * (mask.detach()) 37 | return imgOut 38 | 39 | 40 | class IdCoRe(object): 41 | def __init__(self, intime, device, target='idx'): 42 | self.device = device 43 | if isinstance(intime, int): 44 | intime = torch.tensor(intime).to(self.device) 45 | 46 | if target == 'idx': 47 | self.coord = intime + 1 # if the index in frameT is 0 48 | elif target == 'coord': 49 | self.coord = intime # then the related time in dct coord is 1 50 | elif target == 'time': # and the real time is 0.125s 51 | self.coord = intime * 8.0 52 | 53 | @property 54 | def idx(self): 55 | return (self.coord - 1).int() 56 | 57 | @idx.setter 58 | def idx(self, x): 59 | x = torch.tensor(x) 60 | self.coord = x + 1 61 | 62 | @property 63 | def time(self): 64 | return self.coord.float() / 8.0 65 | 66 | @time.setter 67 | def time(self, x): 68 | x = torch.tensor(x).to(self.device) 69 | self.coord = x * 8.0 70 | 71 | 72 | def getAccFlow(aF, bF, aB, bB, t, device): 73 | F0t = aF * (t ** 2) + bF * t 74 | F1t = aB * ((1 - t) ** 2) + bB * (1 - t) 75 | 76 | return F0t.to(device), F1t.to(device) 77 | 78 | 79 | def getAccParam(F0_1, F01): 80 | a = (F01 + F0_1) / 2.0 81 | b = (F01 - F0_1) / 2.0 82 | 83 | return a, b 84 | -------------------------------------------------------------------------------- /stage1/lib/fitTool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.nn import functional as F 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | 7 | 8 | def backWarp(img: torch.Tensor, flow: torch.Tensor): 9 | device = img.device 10 | N, C, H, W = img.size() 11 | 12 | u = flow[:, 0, :, :] 13 | v = flow[:, 1, :, :] 14 | 15 | gridX, gridY = np.meshgrid(np.arange(W), np.arange(H)) 16 | gridX = torch.tensor(gridX, requires_grad=False).to(device) 17 | gridY = torch.tensor(gridY, requires_grad=False).to(device) 18 | 19 | x = gridX.unsqueeze(0).expand_as(u).float() + u 20 | y = gridY.unsqueeze(0).expand_as(v).float() + v 21 | 22 | # range -1 to 1 23 | x = 2 * x / (W - 1.0) - 1.0 24 | y = 2 * y / (H - 1.0) - 1.0 25 | # stacking X and Y 26 | grid = torch.stack((x, y), dim=3) 27 | # Sample pixels using bilinear interpolation. 
28 | imgOut = F.grid_sample(img, grid, mode='bilinear', padding_mode='zeros') 29 | 30 | # mask = torch.ones_like(img, requires_grad=False) 31 | # mask = F.grid_sample(mask, grid) 32 | # 33 | # mask[mask < 0.9999] = 0 34 | # mask[mask > 0] = 1 35 | 36 | # return imgOut * (mask.detach()) 37 | return imgOut 38 | 39 | 40 | class IdCoRe(object): 41 | def __init__(self, intime, device, target='idx'): 42 | self.device = device 43 | if isinstance(intime, int): 44 | intime = torch.tensor(intime).to(self.device) 45 | 46 | if target == 'idx': 47 | self.coord = intime + 1 # if the index in frameT is 0 48 | elif target == 'coord': 49 | self.coord = intime # then the related time in dct coord is 1 50 | elif target == 'time': # and the real time is 0.125s 51 | self.coord = intime * 8.0 52 | 53 | @property 54 | def idx(self): 55 | return (self.coord - 1).int() 56 | 57 | @idx.setter 58 | def idx(self, x): 59 | x = torch.tensor(x) 60 | self.coord = x + 1 61 | 62 | @property 63 | def time(self): 64 | return self.coord.float() / 8.0 65 | 66 | @time.setter 67 | def time(self, x): 68 | x = torch.tensor(x).to(self.device) 69 | self.coord = x * 8.0 70 | 71 | 72 | def getAccFlow(a0, b0, a1, b1, t, device): 73 | F0t = a0 * (t ** 2) + b0 * t 74 | F1t = a1 * ((1 - t) ** 2) + b1 * (1 - t) 75 | 76 | return F0t.to(device), F1t.to(device) 77 | 78 | 79 | def getAccParam(F0_1, F01): 80 | a = (F01 + F0_1) / 2.0 81 | b = (F01 - F0_1) / 2.0 82 | 83 | return a, b 84 | -------------------------------------------------------------------------------- /stage2/lib/fitTool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.nn import functional as F 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | 7 | 8 | def backWarp(img: torch.Tensor, flow: torch.Tensor): 9 | device = img.device 10 | N, C, H, W = img.size() 11 | 12 | u = flow[:, 0, :, :] 13 | v = flow[:, 1, :, :] 14 | 15 | gridX, gridY = np.meshgrid(np.arange(W), np.arange(H)) 16 | gridX = torch.tensor(gridX, requires_grad=False).to(device) 17 | gridY = torch.tensor(gridY, requires_grad=False).to(device) 18 | 19 | x = gridX.unsqueeze(0).expand_as(u).float() + u 20 | y = gridY.unsqueeze(0).expand_as(v).float() + v 21 | 22 | # range -1 to 1 23 | x = 2 * x / (W - 1.0) - 1.0 24 | y = 2 * y / (H - 1.0) - 1.0 25 | # stacking X and Y 26 | grid = torch.stack((x, y), dim=3) 27 | # Sample pixels using bilinear interpolation. 
28 | imgOut = F.grid_sample(img, grid, mode='bilinear', padding_mode='zeros') 29 | 30 | # mask = torch.ones_like(img, requires_grad=False) 31 | # mask = F.grid_sample(mask, grid) 32 | # 33 | # mask[mask < 0.9999] = 0 34 | # mask[mask > 0] = 1 35 | 36 | # return imgOut * (mask.detach()) 37 | return imgOut 38 | 39 | 40 | class IdCoRe(object): 41 | def __init__(self, intime, device, target='idx'): 42 | self.device = device 43 | if isinstance(intime, int): 44 | intime = torch.tensor(intime).to(self.device) 45 | 46 | if target == 'idx': 47 | self.coord = intime + 1 # if the index in frameT is 0 48 | elif target == 'coord': 49 | self.coord = intime # then the related time in dct coord is 1 50 | elif target == 'time': # and the real time is 0.125s 51 | self.coord = intime * 8.0 52 | 53 | @property 54 | def idx(self): 55 | return (self.coord - 1).int() 56 | 57 | @idx.setter 58 | def idx(self, x): 59 | x = torch.tensor(x) 60 | self.coord = x + 1 61 | 62 | @property 63 | def time(self): 64 | return self.coord.float() / 8.0 65 | 66 | @time.setter 67 | def time(self, x): 68 | x = torch.tensor(x).to(self.device) 69 | self.coord = x * 8.0 70 | 71 | 72 | def getAccFlow(a0, b0, a1, b1, t, device): 73 | F0t = a0 * (t ** 2) + b0 * t 74 | F1t = a1 * ((1 - t) ** 2) + b1 * (1 - t) 75 | 76 | return F0t.to(device), F1t.to(device) 77 | 78 | 79 | def getAccParam(F0_1, F01): 80 | a = (F01 + F0_1) / 2.0 81 | b = (F01 - F0_1) / 2.0 82 | 83 | return a, b 84 | -------------------------------------------------------------------------------- /DVSTool/resplit.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('../') 4 | from lib import fileTool as FT 5 | from pathlib import Path 6 | from tqdm import tqdm 7 | from functools import partial 8 | import multiprocessing 9 | import os 10 | from multiprocessing import Pool, RLock, freeze_support 11 | 12 | 13 | def sortFunc(name: str): 14 | name = Path(name).stem 15 | return int(name) 16 | 17 | 18 | def splitFile(testFile, dstDir): 19 | curProc = multiprocessing.current_process() 20 | # allFiles = FT.getAllFiles(subDir) 21 | 22 | # allFiles.sort(key=sortFunc) 23 | # num = len(allFiles) 24 | # testSubset = allFiles[0:num // 3] 25 | # pbar = tqdm(total=len(allFiles), position=int(curProc._identity[0])) 26 | # for test in allFiles: 27 | fileName = Path(testFile).stem 28 | dirName = str(Path(dstDir) / Path(fileName)) 29 | 30 | targetName = dirName.replace('/train', '/test') 31 | FT.mkPath(str(Path(targetName).parent)) 32 | FT.movFile(dirName, targetName) 33 | # pbar.update(1) 34 | # pbar.clear() 35 | # pbar.close() 36 | 37 | 38 | def batchRun(srcDir, dstDir, poolSize=1): 39 | # srcDir = '/mnt/lustre/yuzhiyang/dataset/GoPro_public/event/simEvents/' 40 | 41 | # zipFrameEvent(srcDir, dstDir, imgDir, vis=True) 42 | allAedats = FT.getAllFiles(srcDir, 'aedat4') 43 | kernelFunc = partial(splitFile, dstDir=dstDir) 44 | 45 | freeze_support() 46 | tqdm.set_lock(RLock()) 47 | p = Pool(processes=poolSize, initializer=tqdm.set_lock, initargs=(tqdm.get_lock(),)) 48 | p.map(func=kernelFunc, iterable=allAedats) 49 | p.close() 50 | p.join() 51 | # for subDir in allSubDirs: 52 | # zipFrameEvent(subDir, dstDir, vis=True) 53 | 54 | 55 | def mainServer(): 56 | # dirs of simulated events 57 | srcDir = '/mnt/lustre/yuzhiyang/dataset/slomoDVS/aedat4/test' 58 | 59 | # dirs of output samples 60 | dstDir = '/mnt/lustre/yuzhiyang/dataset/slomoDVS3/event/train' 61 | 62 | poolSize = 20 63 | 64 | batchRun(srcDir=srcDir, 65 | dstDir=dstDir, 66 
| poolSize=poolSize) 67 | 68 | 69 | def mainLocal(): 70 | # dirs of simulated events 71 | srcDir = '/home/sensetime/data/event/DVS/slomoDVS/event/train' 72 | 73 | # dirs of output samples 74 | dstDir = '/home/sensetime/data/event/DVS/slomoDVS/event/test' 75 | 76 | poolSize = 1 77 | 78 | batchRun(srcDir=srcDir, 79 | dstDir=dstDir, 80 | poolSize=poolSize) 81 | 82 | 83 | if __name__ == '__main__': 84 | mainServer() 85 | # mainLocal() 86 | -------------------------------------------------------------------------------- /stage2/lib/videoTool.py: -------------------------------------------------------------------------------- 1 | import os 2 | import lib.fileTool as FT 3 | 4 | ffmpegPath = '/usr/bin/ffmpeg' 5 | videoName = '/home/sensetime/data/VideoInterpolation/highfps/gopro_yzy/bbb.MP4' 6 | 7 | 8 | def video2Frame(vPath: str, fdir: str, H: int = None, W: int = None): 9 | FT.mkPath(fdir) 10 | if H is None or W is None: 11 | os.system('{} -y -i {} -vsync 0 -qscale:v 2 {}/%04d.png'.format(ffmpegPath, vPath, fdir)) 12 | else: 13 | os.system('{} -y -i {} -vf scale={}:{} -vsync 0 -qscale:v 2 {}/%04d.jpg'.format(ffmpegPath, vPath, W, H, fdir)) 14 | 15 | 16 | def frame2Video(fdir: str, vPath: str, fps: int, H: int = None, W: int = None, ): 17 | if H is None or W is None: 18 | os.system('{} -y -r {} -f image2 -i {}/%*.png -vcodec libx264 -crf 18 -pix_fmt yuv420p {}' 19 | .format(ffmpegPath, fps, fdir, vPath)) 20 | 21 | # os.system('{} -y -r {} -f image2 -i {}/%6d.png -vcodec libx264 -crf 18 -pix_fmt yuv420p {}' 22 | # .format(ffmpegPath, fps, fdir, vPath)) 23 | else: 24 | os.system('{} -y -r {} -f image2 -s {}x{} -i {}/%*.png -vcodec libx264 -crf 25 -pix_fmt yuv420p {}' 25 | .format(ffmpegPath, fps, W, H, fdir, vPath)) 26 | 27 | 28 | def slomo(vPath: str, dstPath: str, fps): 29 | os.system( 30 | '{} -y -r {} -i {} -strict -2 -vcodec libx264 -c:a aac -crf 18 {}'.format(ffmpegPath, fps, vPath, dstPath)) 31 | 32 | 33 | def downFPS(vPath: str, dstPath: str, fps): 34 | os.system( 35 | '{} -i {} -strict -2 -r {} {}'.format(ffmpegPath, vPath, fps, dstPath)) 36 | 37 | 38 | def downSample(vPath: str, dstPath: str, H, W): 39 | os.system( 40 | '{} -i {} -strict -2 -s {}x{} {}'.format(ffmpegPath, vPath, H, W, dstPath)) 41 | 42 | 43 | if __name__ == '__main__': 44 | framePath = '/home/sensetime/data/ICCV2021/OurResults/slomoDVS34_16/Ours_S2/slomoDVS-2021_02_24_11_48_40' 45 | # framePath = '/home/sensetime/data/VideoInterpolation/highfps/gopro_yzy/output' 46 | video = '/home/sensetime/data/ICCV2021/OurResults/slomoDVS34_16/Ours_S2/slomoDVS-2021_02_24_11_48_40/video.mp4' 47 | # video2Frame(video, framePath) 48 | 49 | # video = '/home/sensetime/data/VideoInterpolation/highfps/goPro_240fps/train/GOPR0372_07_00/out.mp4' 50 | # framePath = '/home/sensetime/data/VideoInterpolation/highfps/goPro_240fps/train/GOPR0372_07_00' 51 | frame2Video(framePath, video, 24) 52 | 53 | # vPath = '/media/sensetime/ 54 | # Elements/0721 /0716_video/1.avi' 55 | # dstPath = '/media/sensetime/Elements/0721/0716_video/1_.mp4' 56 | # downFPS(vPath, dstPath, 8) 57 | -------------------------------------------------------------------------------- /DVSTool/lib/videoTool.py: -------------------------------------------------------------------------------- 1 | import os 2 | import lib.fileTool as FT 3 | from pathlib import Path 4 | import cv2 5 | 6 | ffmpegPath = '/usr/bin/ffmpeg' 7 | import numpy as np 8 | 9 | 10 | def video2Frame(vPath: str, fdir: str, H: int = None, W: int = None): 11 | FT.mkPath(fdir) 12 | if H is None or W is None: 13 | 
os.system('{} -y -i {} -vsync 0 -qscale:v 2 {}/%07d.png'.format(ffmpegPath, vPath, fdir)) 14 | else: 15 | os.system('{} -y -i {} -vf scale={}:{} -vsync 0 -qscale:v 2 {}/%07d.jpg'.format(ffmpegPath, vPath, W, H, fdir)) 16 | 17 | 18 | def frame2Video(fdir: str, vPath: str, fps: int, H: int = None, W: int = None, ): 19 | if H is None or W is None: 20 | # os.system('{} -y -r {} -f image2 -i {}/%*.png -vcodec libx264 -crf 18 -pix_fmt yuv420p {}' 21 | # .format(ffmpegPath, fps, fdir, vPath)) 22 | 23 | os.system('{} -y -r {} -f image2 -i {}/%6d.png -vcodec libx264 -crf 18 -pix_fmt yuv420p {}' 24 | .format(ffmpegPath, fps, fdir, vPath)) 25 | else: 26 | os.system('{} -y -r {} -f image2 -s {}x{} -i {}/%*.png -vcodec libx264 -crf 25 -pix_fmt yuv420p {}' 27 | .format(ffmpegPath, fps, W, H, fdir, vPath)) 28 | 29 | 30 | def slomo(vPath: str, dstPath: str, fps): 31 | os.system( 32 | '{} -y -r {} -i {} -strict -2 -vcodec libx264 -c:a aac -crf 18 {}'.format(ffmpegPath, fps, vPath, dstPath)) 33 | 34 | 35 | def downFPS(vPath: str, dstPath: str, fps): 36 | os.system( 37 | '{} -i {} -strict -2 -r {} {}'.format(ffmpegPath, vPath, fps, dstPath)) 38 | 39 | 40 | def downSample(vPath: str, dstPath: str, H, W): 41 | os.system( 42 | '{} -i {} -strict -2 -s {}x{} {}'.format(ffmpegPath, vPath, H, W, dstPath)) 43 | 44 | 45 | def batchVideo2Frames(videosDir, outDir): 46 | allVideos = FT.getAllFiles(videosDir) 47 | for video in allVideos: 48 | targetPath = str(Path(outDir) / Path(video).stem) 49 | 50 | video2Frame(video, targetPath) 51 | 52 | 53 | def roteResize(videoPath, outPath): 54 | allsubDirs = FT.getSubDirs(videoPath) 55 | for subDir in allsubDirs: 56 | allFrames = FT.getAllFiles(subDir) 57 | for idx, frame in enumerate(allFrames): 58 | targetPath = frame.replace(videoPath, outPath) 59 | FT.mkPath(Path(targetPath).parent) 60 | if idx <= 5: 61 | # FT.delPath(frame) 62 | continue 63 | img: np.ndarray = cv2.imread(frame) 64 | H, W, C = img.shape 65 | if H > W: 66 | img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE) 67 | img = cv2.resize(img, dsize=(640, 360)) 68 | cv2.imwrite(targetPath, img) 69 | pass 70 | 71 | 72 | if __name__ == '__main__': 73 | videoDir = '/home/sensetime/data/VideoInterpolation/highfps/adobe_240fps/DeepVideoDeblurring_Dataset_Original_High_FPS_Videos' 74 | outDir = '/home/sensetime/data/VideoInterpolation/highfps/adobe_240fps/adobe240_frames' 75 | outPath = '/home/sensetime/data/VideoInterpolation/highfps/adobe_240fps/adobe240_frames_small' 76 | # batchVideo2Frames(videoDir, outDir) 77 | roteResize(outDir, outPath) 78 | # framePath = '/home/sensetime/data/event/outputTest/pencil2/' 79 | # framePath = '/home/sensetime/data/VideoInterpolation/highfps/gopro_yzy/output' 80 | # video = '/home/sensetime/data/event/outputvideo/pencil2.mp4' 81 | # video2Frame(video, framePath) 82 | 83 | # video = '/home/sensetime/data/VideoInterpolation/highfps/goPro_240fps/train/GOPR0372_07_00/out.mp4' 84 | # framePath = '/home/sensetime/data/VideoInterpolation/highfps/goPro_240fps/train/GOPR0372_07_00' 85 | # frame2Video(framePath, video, 60) 86 | 87 | # vPath = '/media/sensetime/Elements/0721 /0716_video/1.avi' 88 | # dstPath = '/media/sensetime/Elements/0721/0716_video/1_.mp4' 89 | # downFPS(vPath, dstPath, 8) 90 | -------------------------------------------------------------------------------- /stage2/models/subPixelAttn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from 
lib.warp import backWarp as bWarp 6 | # from lib.softSplit import ModuleSoftsplat 7 | from models.baseModule import Correlation 8 | 9 | 10 | class MultiScaleAttn(nn.Module): 11 | def __init__(self): 12 | super(MultiScaleAttn, self).__init__() 13 | # self.p = [8, 4, 2] 14 | self.p = [8, 4, 2] 15 | self.norm = F.normalize 16 | self.ConvIn = nn.Conv2d(in_channels=1152, out_channels=32, kernel_size=1, bias=False) 17 | self.corr4x = Correlation(pad_size=self.p[0], kernel_size=1, max_displacement=self.p[0], stride1=1, stride2=1) 18 | 19 | def maxId2Offset(self, maxIdx: torch.Tensor, pSize): 20 | U = maxIdx % (pSize * 2 + 1) - pSize 21 | V = maxIdx.int() / int((pSize * 2 + 1)) - pSize 22 | return torch.stack([U.float(), V.float()], dim=1) 23 | 24 | def getSubOffset(self, cosDis, maxIdx, subAttnConv): 25 | nx, sxsx, hk, wk = cosDis.shape 26 | sx = int(np.sqrt(sxsx)) 27 | cout = subAttnConv.shape[0] 28 | 29 | maxIdx = maxIdx.reshape(nx, 1, hk, wk).expand(nx, cout, hk, wk).reshape(nx * cout, 1, hk, wk) 30 | 31 | l2Dis: torch.Tensor = 2.0 - 2.0 * cosDis # N, hkwk, hqwq 32 | l2Dis = l2Dis.permute(0, 2, 3, 1).reshape(nx, 1, hk * wk, sx, sx) # N, 1, hkwk, 17, 17 33 | 34 | l2Dis = F.pad(l2Dis, [1, 1, 1, 1], value=2) 35 | ABC = F.conv3d(input=l2Dis, weight=subAttnConv, bias=None, stride=(1, 1, 1)) # N, 5, hkwk, 17, 17 36 | 37 | ABC = ABC.reshape(nx * cout, hk, wk, sxsx).contiguous() # nx * 5, hk, wk, sxsx 38 | ABC = ABC.permute(0, 3, 1, 2).contiguous() # nx * 5, sxsx, hk, wk 39 | 40 | ABCHard = torch.gather(ABC, dim=1, index=maxIdx) 41 | # nx * 5, 1, hk, wk 42 | ABCHard = ABCHard.reshape(nx, cout, hk, wk) # nx, 5, hqwq 43 | # 44 | subOffU = - ABCHard[:, 2, ...] / (ABCHard[:, 0, ...].clamp(min=1e-6)) # nx, hqwq 45 | subOffU = subOffU.clamp(max=1, min=-1).reshape(nx, 1, hk, wk) # nx, 1, hq, wq 46 | # 47 | subOffV = - ABCHard[:, 3, ...] / (ABCHard[:, 1, ...].clamp(min=1e-6)) # nx, 1, hqwq 48 | subOffV = subOffV.clamp(max=1, min=-1).reshape(nx, 1, hk, wk) # nx, 1, hq, wq 49 | return torch.cat([subOffU, subOffV], dim=1) 50 | 51 | def forward(self, Kt1x, Kt2x, KVt4x, V01x, V02x, KV04x, subAttnMatC): 52 | N, C, H, W = KVt4x.shape 53 | 54 | KVt4x_unfold = self.ConvIn(F.unfold(KVt4x, kernel_size=(3, 3), padding=1).reshape(N, -1, H, W)) 55 | KV04x_unfold = self.ConvIn(F.unfold(KV04x, kernel_size=(3, 3), padding=1).reshape(N, -1, H, W)) 56 | 57 | cosDis4x = self.corr4x(self.norm(KVt4x_unfold, dim=1), 58 | self.norm(KV04x_unfold, dim=1)) * KV04x_unfold.shape[1] # N, (17*17), Hk, Wk 59 | 60 | maxValue4x, maxIdx4x = torch.max(cosDis4x, dim=1) # [N, Hq, Wq] 61 | maxValue4x = maxValue4x.view(N, 1, H, W) 62 | 63 | hardOffset4x = self.maxId2Offset(maxIdx4x, self.p[0]) 64 | subOffset4x = self.getSubOffset(cosDis4x, maxIdx4x, subAttnMatC.detach()) 65 | flowOff4x = hardOffset4x + subOffset4x 66 | 67 | KV04x_unfold = F.unfold(KV04x, kernel_size=(3, 3), padding=1).reshape(N, -1, H, W) 68 | V02x_unfold = F.unfold(V02x, kernel_size=(6, 6), padding=2, stride=2).reshape(N, -1, H, W) 69 | V01x_unfold = F.unfold(V01x, kernel_size=(12, 12), padding=4, stride=4).reshape(N, -1, H, W) 70 | 71 | T4x_unfold = bWarp(KV04x_unfold, flowOff4x).reshape(N, -1, H * W) 72 | T2x_unfold = bWarp(V02x_unfold, flowOff4x).reshape(N, -1, H * W) 73 | T1x_unfold = bWarp(V01x_unfold, flowOff4x).reshape(N, -1, H * W) 74 | 75 | T4x = F.fold(T4x_unfold, output_size=(H, W), kernel_size=(3, 3), padding=1, stride=1) / (3. * 3.) 76 | T2x = F.fold(T2x_unfold, output_size=(H * 2, W * 2), kernel_size=(6, 6), padding=2, stride=2) / (3. * 3.) 
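# fold() sums the contributions of overlapping patches; at every scale the
# kernel is 3x the stride (3/1 and 6/2 above, 12/4 below), so each output pixel
# is covered by 3 * 3 = 9 patches, and dividing by (3. * 3.) turns that sum
# into a mean.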
77 |         T1x = F.fold(T1x_unfold, output_size=(H * 4, W * 4), kernel_size=(12, 12), padding=4, stride=4) / (3. * 3.)
78 | 
79 |         S = maxValue4x.view(N, 1, H, W)
80 | 
81 |         return S, T4x, T2x, T1x
82 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Training Weakly Supervised Video Frame Interpolation with Events
2 | (accepted by ICCV 2021)
3 | 
4 | [[Paper](https://openaccess.thecvf.com/content/ICCV2021/html/Yu_Training_Weakly_Supervised_Video_Frame_Interpolation_With_Events_ICCV_2021_paper.html)]
5 | [[Video](https://www.youtube.com/watch?v=ktG5U3WKGes&t=2s)]
6 | 
7 | ### 1. Abstract
8 | This version of the code is used for training on real low-fps DVS data collected by a [DAVIS240C](https://inivation.com/wp-content/uploads/2019/08/DAVIS240.pdf). The model is trained on the visible low-fps frames (12 fps) together with the corresponding events stored in the aedat4 files, and it can interpolate the in-between frames at arbitrary times. An aedat4 file is provided in dataset/aedat4, which can be used as a demo to run the whole process.
9 | 
10 | ### 2. Environments
11 | 1) cuda 9.0
12 | 
13 | 2) python 3.7
14 | 
15 | 3) pytorch 1.1
16 | 
17 | 4) numpy 1.17.2
18 | 
19 | 5) tqdm
20 | 
21 | 6) gcc 5.2.0
22 | 
23 | 7) cmake 3.16.0
24 | 
25 | 8) opencv_contrib_python
26 | 
27 | 9) Compile the correlation module
28 | (The PWCNet and the correlation module are modified from [DAIN](https://github.com/baowenbo/DAIN/tree/master/PWCNet))
29 | 
30 |     a) cd stage1/lib/pwcNet/correlation_pytorch1_1
31 | 
32 |     b) python setup.py install
33 | 
34 | 
35 | 10) Install apex: https://github.com/NVIDIA/apex
36 | 
37 | 11) For processing DVS files:
38 | 
39 |     a) More detailed information about the aedat4 format and the DAVIS240C can be found [here](https://inivation.gitlab.io/dv/dv-docs/docs/getting-started/)
40 | 
41 |     b) Tools for processing aedat4 files: [dv-python](https://gitlab.com/inivation/dv/dv-python)
42 | 
43 | 12) For distributed training with multiple GPUs on a cluster: slurm 15.08.11
44 | 
45 | ### 3. Preparing training data
46 | You can prepare your own event data according to the demo in DVSTool:
47 | 
48 | 1) Place the aedat4 file in ./dataset/aedat4
49 | 2) cd DVSTool
50 | 3) python mainDVSProcess_01.py
51 | It extracts the events and frames stored in the .aedat4 file into pkl files, which are saved in dataset/fastDVS_process
52 | 4) python mainGetDVSTrain_02.py
53 | It gathers the train samples and saves them in dataset/fastDVS_dataset/train.
(A train sample includes I0, I1, I2, I01, I21 and E1)
54 | 5) python mainGetDVSTest_03.py
55 | It gathers the test samples and saves them in dataset/fastDVS_dataset/test (A test sample includes I_-1, I0, I1, I2, E1/3, E2/3)
56 | ### 4. Training stage1
57 | cd stage1
58 | #### 1) Training with a single gpu:
59 | a) Modify the config in configs/configEVI.py accordingly
60 | 
61 | b) python train.py
62 | 
63 | #### 2) Training with multi-gpus (16) on a cluster managed by slurm:
64 | a) Modify the config in configs/configEVI.py accordingly
65 | 
66 | b) Modify runEvi.py in runBash accordingly
67 | 
68 | c) python runBash/runEvi.py
69 | 
70 | ### 5. Training stage2
71 | cd stage2
72 | 
73 | Place the experiment directory trained by stage1 in ./output
74 | 
75 | #### 1) Training with a single gpu:
76 | a) Modify the config in configs/configEVI.py accordingly, especially the paths in lines 28 and 29
77 | 
78 | b) python train.py
79 | 
80 | #### 2) Training with multi-gpus (16) on a cluster managed by slurm:
81 | a) Modify the config in configs/configEVI.py accordingly, especially the paths in lines 28 and 29
82 | 
83 | b) Modify runEvi.py in runBash accordingly
84 | 
85 | c) python runBash/runEvi.py
86 | 
87 | ### 6. Citation
88 | ```
89 | @InProceedings{Yu_2021_ICCV,
90 |     author    = {Yu, Zhiyang and Zhang, Yu and Liu, Deyuan and Zou, Dongqing and Chen, Xijun and Liu, Yebin and Ren, Jimmy S.},
91 |     title     = {Training Weakly Supervised Video Frame Interpolation With Events},
92 |     booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
93 |     month     = {October},
94 |     year      = {2021},
95 |     pages     = {14589-14598}
96 | }
97 | ```
98 | ### 7. Reference code base
99 | [[styleGAN](https://github.com/tomguluson92/StyleGAN_PyTorch)], [[TTSR](https://github.com/researchmm/TTSR)], [[DAIN](https://github.com/baowenbo/DAIN/tree/master/PWCNet)], [[superSlomo](https://github.com/avinashpaliwal/Super-SloMo)], [[QVI](https://sites.google.com/view/xiangyuxu/qvi_nips19)], [[faceShifter](https://github.com/mindslab-ai/faceshifter)]
100 | 
-------------------------------------------------------------------------------- /stage1/lib/metrics.py: -------------------------------------------------------------------------------- 1 | import torch
2 | import numpy as np
3 | from skimage.measure import compare_psnr, compare_ssim
4 | import torch.nn.functional as F
5 | from math import exp
6 | 
7 | 
8 | class AverageMeter(object):
9 |     """Computes and stores the average and current value"""
10 | 
11 |     def __init__(self):
12 |         self.reset()
13 | 
14 |     def reset(self):
15 |         self.val = 0
16 |         self.avg = 0
17 |         self.sum = 0
18 |         self.count = 0
19 | 
20 |     def update(self, val, n=1):
21 |         self.val = val
22 |         self.sum += val * n
23 |         self.count += n
24 |         self.avg = self.sum / self.count
25 | 
26 | 
27 | def calPSNR(imgTrue: torch.Tensor, imgTest: torch.Tensor, cfg):
28 |     # inv = inNormalize(cfg.trainMean, cfg.trainStd)
29 |     #
30 |     # imgTrue = inv(imgTrue)
31 |     # imgTest = inv(imgTest)
32 | 
33 |     imgTrue = imgTrue[0, ...].detach().cpu().numpy().transpose(1, 2, 0)
34 |     imgTrue = np.ascontiguousarray(imgTrue)
35 | 
36 |     imgTest = imgTest[0, ...].detach().cpu().numpy().transpose(1, 2, 0)
37 |     imgTest = np.ascontiguousarray(imgTest)
38 | 
39 |     return compare_psnr(imgTrue, imgTest, data_range=1)
40 |     # return compare_psnr(imgTrue, imgTest)
41 | 
42 | 
43 | def calPSNRBatch(imgTrue: torch.Tensor, imgTest: torch.Tensor, cfg):
44 |     bSize = imgTrue.size(0)
45 |     mse = ((imgTrue - imgTest) ** 2).mean(dim=[1, 2, 3])
46 |     psnrBatch = 10 * torch.log10(mse.reciprocal())
47 |     return
psnrBatch, bSize 48 | 49 | 50 | def calSSIM(imgTrue: torch.Tensor, imgTest: torch.Tensor, cfg): 51 | # inv = inNormalize(cfg.trainMean, cfg.trainStd) 52 | # 53 | # imgTrue = inv(imgTrue) 54 | # imgTest = inv(imgTest) 55 | 56 | imgTrue = imgTrue[0, ...].detach().cpu().numpy().transpose(1, 2, 0) 57 | imgTrue = np.ascontiguousarray(imgTrue) 58 | imgTest = imgTest[0, ...].detach().cpu().numpy().transpose(1, 2, 0) 59 | imgTest = np.ascontiguousarray(imgTest) 60 | 61 | return compare_ssim(imgTrue, imgTest, data_range=1, multichannel=True, gaussian_weights=True) 62 | # return compare_ssim(imgTrue, imgTest, multichannel=True) 63 | 64 | 65 | def gaussian(window_size, sigma): 66 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 67 | return gauss / gauss.sum() 68 | 69 | 70 | def create_window(window_size, channel): 71 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 72 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 73 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 74 | return window 75 | 76 | 77 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 78 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 79 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 80 | 81 | mu1_sq = mu1.pow(2) 82 | mu2_sq = mu2.pow(2) 83 | mu1_mu2 = mu1 * mu2 84 | 85 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 86 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 87 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 88 | 89 | C1 = 0.01 ** 2 90 | C2 = 0.03 ** 2 91 | 92 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 93 | 94 | if size_average: 95 | return ssim_map.mean() 96 | else: 97 | return ssim_map.mean(1).mean(1).mean(1) 98 | 99 | 100 | class SSIM(torch.nn.Module): 101 | def __init__(self, window_size=11, size_average=False): 102 | super(SSIM, self).__init__() 103 | self.window_size = window_size 104 | self.size_average = size_average 105 | self.channel = 1 106 | self.window = create_window(window_size, self.channel) 107 | 108 | def forward(self, img1, img2): 109 | (_, channel, _, _) = img1.size() 110 | 111 | if channel == self.channel and self.window.data.type() == img1.data.type(): 112 | window = self.window 113 | else: 114 | window = create_window(self.window_size, channel) 115 | 116 | if img1.is_cuda: 117 | window = window.cuda(img1.get_device()) 118 | window = window.type_as(img1) 119 | 120 | self.window = window 121 | self.channel = channel 122 | 123 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 124 | -------------------------------------------------------------------------------- /stage2/lib/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from skimage.measure import compare_psnr, compare_ssim 4 | import torch.nn.functional as F 5 | from math import exp 6 | 7 | 8 | class AverageMeter(object): 9 | """Computes and stores the average and current value""" 10 | 11 | def __init__(self): 12 | self.reset() 13 | 14 | def reset(self): 15 | self.val = 0 16 | self.avg = 0 17 | self.sum = 0 18 | self.count = 0 19 | 20 | def update(self, val, n=1): 21 | self.val = val 22 | self.sum += val * n 23 | self.count += n 24 | 
self.avg = self.sum / self.count 25 | 26 | 27 | def calPSNR(imgTrue: torch.Tensor, imgTest: torch.Tensor, cfg): 28 | # inv = inNormalize(cfg.trainMean, cfg.trainStd) 29 | # 30 | # imgTrue = inv(imgTrue) 31 | # imgTest = inv(imgTest) 32 | 33 | imgTrue = imgTrue[0, ...].detach().cpu().numpy().transpose(1, 2, 0) 34 | imgTrue = np.ascontiguousarray(imgTrue) 35 | 36 | imgTest = imgTest[0, ...].detach().cpu().numpy().transpose(1, 2, 0) 37 | imgTest = np.ascontiguousarray(imgTest) 38 | 39 | return compare_psnr(imgTrue, imgTest, data_range=1) 40 | # return compare_psnr(imgTrue, imgTest) 41 | 42 | 43 | def calPSNRBatch(imgTrue: torch.Tensor, imgTest: torch.Tensor, cfg): 44 | bSize = imgTrue.size(0) 45 | mse = ((imgTrue - imgTest) ** 2).mean(dim=[1, 2, 3]) 46 | psnrBatch = 10 * torch.log10(mse.reciprocal()) 47 | return psnrBatch, bSize 48 | 49 | 50 | def calSSIM(imgTrue: torch.Tensor, imgTest: torch.Tensor, cfg): 51 | # inv = inNormalize(cfg.trainMean, cfg.trainStd) 52 | # 53 | # imgTrue = inv(imgTrue) 54 | # imgTest = inv(imgTest) 55 | 56 | imgTrue = imgTrue[0, ...].detach().cpu().numpy().transpose(1, 2, 0) 57 | imgTrue = np.ascontiguousarray(imgTrue) 58 | imgTest = imgTest[0, ...].detach().cpu().numpy().transpose(1, 2, 0) 59 | imgTest = np.ascontiguousarray(imgTest) 60 | 61 | return compare_ssim(imgTrue, imgTest, data_range=1, multichannel=True, gaussian_weights=True) 62 | # return compare_ssim(imgTrue, imgTest, multichannel=True) 63 | 64 | 65 | def gaussian(window_size, sigma): 66 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 67 | return gauss / gauss.sum() 68 | 69 | 70 | def create_window(window_size, channel): 71 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 72 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 73 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 74 | return window 75 | 76 | 77 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 78 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 79 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 80 | 81 | mu1_sq = mu1.pow(2) 82 | mu2_sq = mu2.pow(2) 83 | mu1_mu2 = mu1 * mu2 84 | 85 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 86 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 87 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 88 | 89 | C1 = 0.01 ** 2 90 | C2 = 0.03 ** 2 91 | 92 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 93 | 94 | if size_average: 95 | return ssim_map.mean() 96 | else: 97 | return ssim_map.mean(1).mean(1).mean(1) 98 | 99 | 100 | class SSIM(torch.nn.Module): 101 | def __init__(self, window_size=11, size_average=False): 102 | super(SSIM, self).__init__() 103 | self.window_size = window_size 104 | self.size_average = size_average 105 | self.channel = 1 106 | self.window = create_window(window_size, self.channel) 107 | 108 | def forward(self, img1, img2): 109 | (_, channel, _, _) = img1.size() 110 | 111 | if channel == self.channel and self.window.data.type() == img1.data.type(): 112 | window = self.window 113 | else: 114 | window = create_window(self.window_size, channel) 115 | 116 | if img1.is_cuda: 117 | window = window.cuda(img1.get_device()) 118 | window = window.type_as(img1) 119 | 120 | self.window = window 121 | 
self.channel = channel 122 | 123 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 124 | -------------------------------------------------------------------------------- /DVSTool/lib/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from skimage.measure import compare_psnr, compare_ssim 4 | import torch.nn.functional as F 5 | from math import exp 6 | 7 | 8 | class AverageMeter(object): 9 | """Computes and stores the average and current value""" 10 | 11 | def __init__(self): 12 | self.reset() 13 | 14 | def reset(self): 15 | self.val = 0 16 | self.avg = 0 17 | self.sum = 0 18 | self.count = 0 19 | 20 | def update(self, val, n=1): 21 | self.val = val 22 | self.sum += val * n 23 | self.count += n 24 | self.avg = self.sum / self.count 25 | 26 | 27 | def calPSNR(imgTrue: torch.Tensor, imgTest: torch.Tensor, cfg): 28 | # inv = inNormalize(cfg.trainMean, cfg.trainStd) 29 | # 30 | # imgTrue = inv(imgTrue) 31 | # imgTest = inv(imgTest) 32 | 33 | imgTrue = imgTrue[0, ...].detach().cpu().numpy().transpose(1, 2, 0) 34 | imgTrue = np.ascontiguousarray(imgTrue) 35 | 36 | imgTest = imgTest[0, ...].detach().cpu().numpy().transpose(1, 2, 0) 37 | imgTest = np.ascontiguousarray(imgTest) 38 | 39 | return compare_psnr(imgTrue, imgTest, data_range=1) 40 | # return compare_psnr(imgTrue, imgTest) 41 | 42 | 43 | def calPSNRBatch(imgTrue: torch.Tensor, imgTest: torch.Tensor, cfg): 44 | bSize = imgTrue.size(0) 45 | mse = ((imgTrue - imgTest) ** 2).mean(dim=[1, 2, 3]) 46 | psnrBatch = 10 * torch.log10(mse.reciprocal()) 47 | return psnrBatch, bSize 48 | 49 | 50 | def calSSIM(imgTrue: torch.Tensor, imgTest: torch.Tensor, cfg): 51 | # inv = inNormalize(cfg.trainMean, cfg.trainStd) 52 | # 53 | # imgTrue = inv(imgTrue) 54 | # imgTest = inv(imgTest) 55 | 56 | imgTrue = imgTrue[0, ...].detach().cpu().numpy().transpose(1, 2, 0) 57 | imgTrue = np.ascontiguousarray(imgTrue) 58 | imgTest = imgTest[0, ...].detach().cpu().numpy().transpose(1, 2, 0) 59 | imgTest = np.ascontiguousarray(imgTest) 60 | 61 | return compare_ssim(imgTrue, imgTest, data_range=1, multichannel=True, gaussian_weights=True) 62 | # return compare_ssim(imgTrue, imgTest, multichannel=True) 63 | 64 | 65 | def gaussian(window_size, sigma): 66 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 67 | return gauss / gauss.sum() 68 | 69 | 70 | def create_window(window_size, channel): 71 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 72 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 73 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 74 | return window 75 | 76 | 77 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 78 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 79 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 80 | 81 | mu1_sq = mu1.pow(2) 82 | mu2_sq = mu2.pow(2) 83 | mu1_mu2 = mu1 * mu2 84 | 85 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 86 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 87 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 88 | 89 | C1 = 0.01 ** 2 90 | C2 = 0.03 ** 2 91 | 92 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 93 | 
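# C1 and C2 are the standard SSIM stabilizing constants, (K1 * L)^2 and
# (K2 * L)^2 with K1 = 0.01, K2 = 0.03 and data range L = 1; they keep the
# ratio well-defined on flat, low-variance patches.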
94 | if size_average: 95 | return ssim_map.mean() 96 | else: 97 | return ssim_map.mean(1).mean(1).mean(1) 98 | 99 | 100 | class SSIM(torch.nn.Module): 101 | def __init__(self, window_size=11, size_average=False): 102 | super(SSIM, self).__init__() 103 | self.window_size = window_size 104 | self.size_average = size_average 105 | self.channel = 1 106 | self.window = create_window(window_size, self.channel) 107 | 108 | def forward(self, img1, img2): 109 | (_, channel, _, _) = img1.size() 110 | 111 | if channel == self.channel and self.window.data.type() == img1.data.type(): 112 | window = self.window 113 | else: 114 | window = create_window(self.window_size, channel) 115 | 116 | if img1.is_cuda: 117 | window = window.cuda(img1.get_device()) 118 | window = window.type_as(img1) 119 | 120 | self.window = window 121 | self.channel = channel 122 | 123 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 124 | -------------------------------------------------------------------------------- /stage1/dataloader/dataloaderBase.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Sampler 2 | import torch.distributed as dist 3 | import math 4 | 5 | 6 | class DistributedSamplerVali(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 8 | 9 | It is especially useful in conjunction with 10 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 11 | process can pass a DistributedSampler instance as a DataLoader sampler, 12 | and load a subset of the original dataset that is exclusive to it. 13 | 14 | .. note:: 15 | Dataset is assumed to be of constant size. 16 | 17 | Arguments: 18 | dataset: Dataset used for sampling. 19 | num_replicas (optional): Number of processes participating in 20 | distributed training. 21 | rank (optional): Rank of the current process within num_replicas. 
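        Example (a minimal sketch; ``myDataset`` is a stand-in name, not a
        class from this repo)::

            sampler = DistributedSamplerVali(myDataset)
            loader = torch.utils.data.DataLoader(myDataset, batch_size=4,
                                                 sampler=sampler)

        Unlike the shuffling sampler used for training, __iter__ keeps the
        natural index order, so every rank sees a fixed, reproducible
        validation subset.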
22 |     """
23 | 
24 |     def __init__(self, dataset, num_replicas=None, rank=None):
25 |         # super(DistributedSamplerVali, self).__init__()
26 |         if num_replicas is None:
27 |             if not dist.is_available():
28 |                 raise RuntimeError("Requires distributed package to be available")
29 |             num_replicas = dist.get_world_size()
30 |         if rank is None:
31 |             if not dist.is_available():
32 |                 raise RuntimeError("Requires distributed package to be available")
33 |             rank = dist.get_rank()
34 |         self.dataset = dataset
35 |         self.num_replicas = num_replicas
36 |         self.rank = rank
37 |         self.epoch = 0
38 |         self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
39 |         self.total_size = self.num_samples * self.num_replicas
40 | 
41 |     def __iter__(self):
42 |         # keep a fixed, deterministic index order: this sampler is meant for validation
43 | 
44 |         indices = list(range(len(self.dataset)))
45 | 
46 |         # add extra samples to make it evenly divisible (empty slice when already divisible)
47 |         indices += indices[:(self.total_size - len(indices))]
48 |         assert len(indices) == self.total_size
49 | 
50 |         # subsample
51 |         indices = indices[self.rank:self.total_size:self.num_replicas]
52 |         assert len(indices) == self.num_samples
53 | 
54 |         return iter(indices)
55 | 
56 |     def __len__(self):
57 |         return self.num_samples
58 | 
59 |     def set_epoch(self, epoch):
60 |         self.epoch = epoch
61 | 
62 | 
63 | class Sample(object):
64 |     __slots__ = {'I0', 'I1', 'It', 'Et', 'I0t', 'I1t'}
65 | 
66 |     def __init__(self, I0, I1, It, Et, I0t, I1t):
67 |         super(Sample, self).__init__()
68 | 
69 |         self.I0 = I0
70 |         self.I1 = I1
71 |         self.It = It
72 |         self.Et = Et
73 | 
74 |         self.I0t = I0t
75 |         self.I1t = I1t
76 | 
77 | 
78 | class Samples(object):
79 |     __slots__ = {'I1', 'I2', 'I3', 'I4', 'I5',
80 |                  'E2', 'E3', 'E4',
81 |                  'I12', 'I13', 'I14', 'I23', 'I24', 'I32', 'I34', 'I43', 'I42', 'I54', 'I53', 'I52'}
82 | 
83 |     def __init__(self, I1, I2, I3, I4, I5,
84 |                  E2, E3, E4,
85 |                  I12, I13, I14, I23, I24, I32, I34, I43, I42, I54, I53, I52):
86 |         super(Samples, self).__init__()
87 |         self.I1, self.I2, self.I3, self.I4, self.I5 = I1, I2, I3, I4, I5
88 | 
89 |         self.E2, self.E3, self.E4 = E2, E3, E4
90 | 
91 |         self.I12, self.I13, self.I14 = I12, I13, I14
92 | 
93 |         self.I23, self.I24 = I23, I24
94 |         self.I32, self.I34 = I32, I34
95 |         self.I43, self.I42 = I43, I42
96 |         self.I54, self.I53, self.I52 = I54, I53, I52
97 | 
98 |     def getSample(self, idx):
99 |         cases = {'0': Sample(self.I1, self.I3, self.I2, self.E2, self.I12, self.I32),
100 | 
101 |                  '1': Sample(self.I1, self.I4, self.I2, self.E2, self.I12, self.I42),
102 | 
103 |                  '2': Sample(self.I1, self.I5, self.I2, self.E2, self.I12, self.I52),
104 | 
105 |                  '3': Sample(self.I1, self.I4, self.I3, self.E3, self.I13, self.I43),
106 | 
107 |                  '4': Sample(self.I1, self.I5, self.I3, self.E3, self.I13, self.I53),
108 | 
109 |                  '5': Sample(self.I1, self.I5, self.I4, self.E4, self.I14, self.I54),
110 | 
111 |                  '6': Sample(self.I2, self.I4, self.I3, self.E3, self.I23, self.I43),
112 | 
113 |                  '7': Sample(self.I2, self.I5, self.I3, self.E3, self.I23, self.I53),
114 | 
115 |                  '8': Sample(self.I2, self.I5, self.I4, self.E4, self.I24, self.I54),
116 | 
117 |                  '9': Sample(self.I3, self.I5, self.I4, self.E4, self.I34, self.I54)
118 |                  }
119 |         return cases[str(idx)]
120 | 
-------------------------------------------------------------------------------- /stage2/dataloader/dataloaderBase.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Sampler
2 | import torch.distributed as dist
3 | import math
4 | 
5 | 
6 | class DistributedSamplerVali(Sampler):
7 |     """Sampler
that restricts data loading to a subset of the dataset. 8 | 9 | It is especially useful in conjunction with 10 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 11 | process can pass a DistributedSampler instance as a DataLoader sampler, 12 | and load a subset of the original dataset that is exclusive to it. 13 | 14 | .. note:: 15 | Dataset is assumed to be of constant size. 16 | 17 | Arguments: 18 | dataset: Dataset used for sampling. 19 | num_replicas (optional): Number of processes participating in 20 | distributed training. 21 | rank (optional): Rank of the current process within num_replicas. 22 | """ 23 | 24 | def __init__(self, dataset, num_replicas=None, rank=None): 25 | # super(DistributedSamplerVali, self).__init__() 26 | if num_replicas is None: 27 | if not dist.is_available(): 28 | raise RuntimeError("Requires distributed package to be available") 29 | num_replicas = dist.get_world_size() 30 | if rank is None: 31 | if not dist.is_available(): 32 | raise RuntimeError("Requires distributed package to be available") 33 | rank = dist.get_rank() 34 | self.dataset = dataset 35 | self.num_replicas = num_replicas 36 | self.rank = rank 37 | self.epoch = 0 38 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 39 | self.total_size = self.num_samples * self.num_replicas 40 | 41 | def __iter__(self): 42 | # deterministically shuffle based on epoch 43 | 44 | indices = list(range(len(self.dataset))) 45 | 46 | # add extra samples to make it evenly divisible 47 | indices += indices[:(self.total_size - len(indices))] 48 | assert len(indices) == self.total_size 49 | 50 | # subsample 51 | indices = indices[self.rank:self.total_size:self.num_replicas] 52 | assert len(indices) == self.num_samples 53 | 54 | return iter(indices) 55 | 56 | def __len__(self): 57 | return self.num_samples 58 | 59 | def set_epoch(self, epoch): 60 | self.epoch = epoch 61 | 62 | 63 | class Sample(object): 64 | __slots__ = {'I0', 'I1', 'It', 'Et', 'I0t', 'I1t'} 65 | 66 | def __init__(self, I0, I1, It, Et, I0t, I1t): 67 | super(Sample, self).__init__() 68 | 69 | self.I0 = I0 70 | self.I1 = I1 71 | self.It = It 72 | self.Et = Et 73 | 74 | self.I0t = I0t 75 | self.I1t = I1t 76 | 77 | 78 | class Samples(object): 79 | __slots__ = {'I1', 'I2', 'I3', 'I4', 'I5', 80 | 'E2', 'E3', 'E4', 81 | 'I12', 'I13', 'I14', 'I23', 'I24', 'I32', 'I34', 'I43', 'I42', 'I54', 'I53', 'I52'} 82 | 83 | def __init__(self, I1, I2, I3, I4, I5, 84 | E2, E3, E4, 85 | I12, I13, I14, I23, I24, I32, I34, I43, I42, I54, I53, I52): 86 | super(Samples, self).__init__() 87 | self.I1, self.I2, self.I3, self.I4, self.I5 = I1, I2, I3, I4, I5 88 | 89 | self.E2, self.E3, self.E4 = E2, E3, E4 90 | 91 | self.I12, self.I13, self.I14 = I12, I13, I14 92 | 93 | self.I23, self.I24 = I23, I24 94 | self.I32, self.I34 = I32, I34 95 | self.I43, self.I42 = I43, I42 96 | self.I54, self.I53, self.I52 = I54, I53, I52 97 | 98 | def getSample(self, idx): 99 | cases = {'0': Sample(self.I1, self.I3, self.I2, self.E2, self.I12, self.I32), 100 | 101 | '1': Sample(self.I1, self.I4, self.I2, self.E2, self.I12, self.I42), 102 | 103 | '2': Sample(self.I1, self.I5, self.I2, self.E2, self.I12, self.I52), 104 | 105 | '3': Sample(self.I1, self.I4, self.I3, self.E3, self.I13, self.I43), 106 | 107 | '4': Sample(self.I1, self.I5, self.I3, self.E3, self.I13, self.I53), 108 | 109 | '5': Sample(self.I1, self.I5, self.I4, self.E4, self.I14, self.I54), 110 | 111 | '6': Sample(self.I2, self.I4, self.I3, self.E3, self.I23, self.I43), 112 | 113 | '7': 
Sample(self.I2, self.I5, self.I3, self.E3, self.I23, self.I53), 114 | 115 | '8': Sample(self.I2, self.I5, self.I4, self.E4, self.I24, self.I54), 116 | 117 | '9': Sample(self.I3, self.I5, self.I4, self.E4, self.I34, self.I54) 118 | } 119 | return cases[str(idx)] 120 | 121 | -------------------------------------------------------------------------------- /DVSTool/mainDVSProcess_01.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('../') 4 | import cv2 5 | import lib.fileTool as FT 6 | import numpy as np 7 | import pickle 8 | from pathlib import Path 9 | from DVSBase import DVSReader, ESIMReader 10 | from tqdm import tqdm 11 | from multiprocessing import Pool, RLock, freeze_support 12 | from functools import partial 13 | import multiprocessing 14 | import os 15 | 16 | 17 | def simulate(dvsFile: str, outPath): 18 | curProc = multiprocessing.current_process() 19 | dvs = DVSReader(dvsFile) 20 | outDir = Path(outPath) / Path(dvsFile).stem 21 | FT.mkPath(outDir) 22 | pbar = tqdm(total=len(dvs.tImg), position=int(curProc._identity[0])) 23 | for idx in range(len(dvs.tImg) - 1): 24 | img = dvs.Img[idx] 25 | tImgStart = dvs.tImg[idx] 26 | tImgStop = dvs.tImg[idx + 1] 27 | sliceIdx = (np.array(dvs.tE) >= tImgStart) & (np.array(dvs.tE) < tImgStop) 28 | 29 | tE = np.array(dvs.tE)[sliceIdx].tolist() 30 | xE = np.array(dvs.xE)[sliceIdx].tolist() 31 | yE = np.array(dvs.yE)[sliceIdx].tolist() 32 | pE = np.array(dvs.pE)[sliceIdx].tolist() 33 | 34 | recordEvent = {'tE': tE, 'xE': xE, 'yE': yE, 'pE': pE, 35 | 'tImgStart': tImgStart, 'tImgStop': tImgStop, 36 | 'pathImgStart': img} 37 | targetPath = Path(outDir) / Path('{:07d}.pkl'.format(idx)) 38 | 39 | with open(targetPath, 'wb') as fs: 40 | pickle.dump(recordEvent, fs) 41 | 42 | pbar.set_description('{}'.format(Path(dvsFile).stem)) 43 | pbar.update(1) 44 | pbar.clear() 45 | pbar.close() 46 | 47 | 48 | def batchESIM(dirPath: str, outPath, poolSize=4): 49 | allDvsFiles = FT.getAllFiles(dirPath, 'aedat4') 50 | freeze_support() 51 | tqdm.set_lock(RLock()) 52 | p = Pool(processes=poolSize, initializer=tqdm.set_lock, initargs=(tqdm.get_lock(),)) 53 | 54 | kernelFunc = partial(simulate, outPath=outPath) 55 | p.map(func=kernelFunc, iterable=allDvsFiles) 56 | p.close() 57 | p.join() 58 | 59 | 60 | def mainDVSServer(): 61 | # dataPath = '/mnt/lustre/yuzhiyang/dataset/slomeDVS/aedat4/train' 62 | dataPath = '/mnt/lustre/yuzhiyang/dataset/fastDVS/aedat4/oneTest/' 63 | # outPath = '/mnt/lustre/yuzhiyang/dataset/slomeDVS/event/train' 64 | outPath = '/mnt/lustre/yuzhiyang/dataset/fastDVS/event/oneTest/' 65 | batchESIM(dirPath=dataPath, outPath=outPath, poolSize=2) 66 | 67 | 68 | def mainDVSLocal(): 69 | dataPath = '../dataset/aedat4/train' 70 | outPath = '../dataset/fastDVS_process/train' 71 | batchESIM(dirPath=dataPath, outPath=outPath, poolSize=1) 72 | 73 | 74 | def check(): 75 | channel = 1 76 | cv2.namedWindow('1', 0) 77 | eventDirs = '/home/sensetime/data/event/DVS/slomoDVS/event' 78 | allsubEventDirs = FT.getSubDirs(eventDirs) 79 | for eventDir in allsubEventDirs: 80 | allEvents = FT.getAllFiles(eventDir, 'pkl') 81 | for fIdx, evePath in enumerate(allEvents): 82 | esim = ESIMReader(evePath) 83 | # relateEvents = np.zeros([2, esim['height'], esim['width']], np.int8) 84 | tEvents = np.linspace(start=esim.tImgStart[0], stop=esim.tImgStop[0], num=channel + 1, 85 | endpoint=True).tolist() 86 | for eIdx in range(channel): 87 | eStart = tEvents[eIdx] 88 | eStop = tEvents[eIdx + 1] 89 | relateEvents = 
esim.aggregEvent(eStart, eStop) 90 | 91 | eventImg = relateEvents.astype(np.float32) 92 | eventImg = ((eventImg - eventImg.min()) / (eventImg.max() - eventImg.min() + 1e-5) * 255.0).astype( 93 | np.uint8) 94 | eventImg = cv2.cvtColor(eventImg, cv2.COLOR_GRAY2BGR) 95 | 96 | img = cv2.cvtColor(esim.pathImgStart[0].copy(), cv2.COLOR_GRAY2BGR) 97 | img[:, :, 0][relateEvents != 0] = 0 98 | 99 | img[:, :, 2][relateEvents > 0] = 255 100 | img[:, :, 1][relateEvents < 0] = 255 101 | 102 | cv2.putText(img, '{}_{}'.format(fIdx, eIdx), (20, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.8, 103 | (0, 0, 255), 104 | 5, 105 | cv2.LINE_AA) 106 | 107 | cv2.imshow('1', np.concatenate([img.astype(np.uint8), eventImg], axis=1)) 108 | cv2.waitKey(100) 109 | 110 | 111 | if __name__ == '__main__': 112 | # srun -p Pixel --nodelist= --cpus-per-task=22 --job-name=Train python mainESIM.py 113 | # mainDVSServer() 114 | mainDVSLocal() 115 | # check() 116 | # mainSimGoproTest() 117 | -------------------------------------------------------------------------------- /DVSTool/lib/forwardWarpTorch/fWarpGaussion.py: -------------------------------------------------------------------------------- 1 | # class WarpLayer warps image x based on optical flow flo. 2 | import numpy 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | 9 | class ForwardWarp(nn.Module): 10 | """docstring for WarpLayer""" 11 | 12 | def __init__(self, ): 13 | super(ForwardWarp, self).__init__() 14 | 15 | def forward(self, img, flo): 16 | """ 17 | -img: image (N, C, H, W) 18 | -flo: optical flow (N, 2, H, W) 19 | elements of flo is in [0, H] and [0, W] for dx, dy 20 | 21 | """ 22 | 23 | # (x1, y1) (x1, y2) 24 | # +---------------+ 25 | # | | 26 | # | o(x, y) | 27 | # | | 28 | # | | 29 | # | | 30 | # | | 31 | # +---------------+ 32 | # (x2, y1) (x2, y2) 33 | 34 | N, C, _, _ = img.size() 35 | 36 | # translate start-point optical flow to end-point optical flow 37 | y = flo[:, 0:1:, :] 38 | x = flo[:, 1:2, :, :] 39 | 40 | x = x.repeat(1, C, 1, 1) 41 | y = y.repeat(1, C, 1, 1) 42 | 43 | # Four point of square (x1, y1), (x1, y2), (x2, y1), (y2, y2) 44 | x1 = torch.floor(x) 45 | x2 = x1 + 1 46 | y1 = torch.floor(y) 47 | y2 = y1 + 1 48 | 49 | # firstly, get gaussian weights 50 | w11, w12, w21, w22 = self.get_gaussian_weights(x, y, x1, x2, y1, y2) 51 | 52 | # secondly, sample each weighted corner 53 | img11, o11 = self.sample_one(img, x1, y1, w11) 54 | img12, o12 = self.sample_one(img, x1, y2, w12) 55 | img21, o21 = self.sample_one(img, x2, y1, w21) 56 | img22, o22 = self.sample_one(img, x2, y2, w22) 57 | 58 | imgw = img11 + img12 + img21 + img22 59 | o = o11 + o12 + o21 + o22 60 | 61 | # nonZero = o != 0 62 | imgw[o > 0] = imgw[o > 0] / o[o > 0] 63 | 64 | return imgw, o 65 | 66 | def get_gaussian_weights(self, x, y, x1, x2, y1, y2): 67 | w11 = torch.exp(-((x - x1) ** 2 + (y - y1) ** 2)) 68 | w12 = torch.exp(-((x - x1) ** 2 + (y - y2) ** 2)) 69 | w21 = torch.exp(-((x - x2) ** 2 + (y - y1) ** 2)) 70 | w22 = torch.exp(-((x - x2) ** 2 + (y - y2) ** 2)) 71 | 72 | return w11, w12, w21, w22 73 | 74 | def sample_one(self, img, shiftx, shifty, weight): 75 | """ 76 | Input: 77 | -img (N, C, H, W) 78 | -shiftx, shifty (N, c, H, W) 79 | """ 80 | 81 | N, C, H, W = img.size() 82 | 83 | # flatten all (all restored as Tensors) 84 | flat_shiftx = shiftx.view(-1) 85 | flat_shifty = shifty.view(-1) 86 | flat_basex = torch.arange(0, H, requires_grad=False).view(-1, 1)[None, None].cuda().long().repeat(N, C, 1, 87 | 
W).view(-1) 88 | flat_basey = torch.arange(0, W, requires_grad=False).view(1, -1)[None, None].cuda().long().repeat(N, C, H, 89 | 1).view(-1) 90 | flat_weight = weight.view(-1) 91 | flat_img = img.view(-1) 92 | 93 | # The corresponding positions in I1 94 | idxn = torch.arange(0, N, requires_grad=False).view(N, 1, 1, 1).long().cuda().repeat(1, C, H, W).view(-1) 95 | idxc = torch.arange(0, C, requires_grad=False).view(1, C, 1, 1).long().cuda().repeat(N, 1, H, W).view(-1) 96 | # ttype = flat_basex.type() 97 | idxx = flat_shiftx.long() + flat_basex 98 | idxy = flat_shifty.long() + flat_basey 99 | 100 | # recording the inside part the shifted 101 | mask = idxx.ge(0) & idxx.lt(H) & idxy.ge(0) & idxy.lt(W) 102 | 103 | # Mask off points out of boundaries 104 | ids = (idxn * C * H * W + idxc * H * W + idxx * W + idxy) 105 | ids_mask = torch.masked_select(ids, mask).clone().cuda() 106 | 107 | # (zero part - gt) -> difference 108 | # difference back propagate -> No influence! Whether we do need mask? mask? 109 | # put (add) them together 110 | # Note here! accmulate fla must be true for proper bp 111 | img_warp = torch.zeros([N * C * H * W, ]).cuda() 112 | img_warp.put_(ids_mask, torch.masked_select(flat_img * flat_weight, mask), accumulate=True) 113 | 114 | one_warp = torch.zeros([N * C * H * W, ]).cuda() 115 | one_warp.put_(ids_mask, torch.masked_select(flat_weight, mask), accumulate=True) 116 | 117 | return img_warp.view(N, C, H, W), one_warp.view(N, C, H, W) 118 | -------------------------------------------------------------------------------- /DVSTool/lib/forwardWarpTorch/forwardWarp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class forwadWarp(nn.Module): 6 | def __init__(self, bilinear=True): 7 | super(forwadWarp, self).__init__() 8 | 9 | self.bilinear = bilinear 10 | 11 | def forward(self, srcTensor: torch.Tensor, flow: torch.Tensor, weight: torch.Tensor = None): 12 | if weight is None: 13 | weight = torch.ones_like(srcTensor) 14 | 15 | srcTensor = srcTensor * weight 16 | 17 | self.device = srcTensor.device 18 | N, C, H, W = srcTensor.shape 19 | 20 | xx = torch.arange(0, W, requires_grad=False, device=self.device).view(1, 1, 1, -1).repeat(N, C, H, 1).float() \ 21 | + flow[:, 0:1, :, :] 22 | yy = torch.arange(0, H, requires_grad=False, device=self.device).view(1, 1, -1, 1).repeat(N, C, 1, W).float() \ 23 | + flow[:, 1:2, :, :] 24 | 25 | xxFloor = xx.floor().float().detach() 26 | xxCeil = xxFloor + 1.0 27 | 28 | yyFloor = yy.floor().float().detach() 29 | yyCeil = yyFloor + 1.0 30 | 31 | if self.bilinear: 32 | ltWeight, rtWeight, lbWeight, rbWeight = self.getBilinearWeight(xx, yy, xxFloor, yyFloor) 33 | else: 34 | ltWeight = torch.ones_like(srcTensor).detach() 35 | rtWeight = torch.ones_like(srcTensor).detach() 36 | lbWeight = torch.ones_like(srcTensor).detach() 37 | rbWeight = torch.ones_like(srcTensor).detach() 38 | 39 | ltImg = srcTensor * ltWeight 40 | rtImg = srcTensor * rtWeight 41 | lbImg = srcTensor * lbWeight 42 | rbImg = srcTensor * rbWeight 43 | 44 | ltNorm = weight * ltWeight 45 | rtNorm = weight * rtWeight 46 | lbNorm = weight * lbWeight 47 | rbNorm = weight * rbWeight 48 | 49 | ltTarget, ltScaler = self.splatting(xxFloor, yyFloor, ltImg, ltNorm) 50 | rtTarget, rtScaler = self.splatting(xxCeil, yyFloor, rtImg, rtNorm) 51 | lbTarget, lbScaler = self.splatting(xxFloor, yyCeil, lbImg, lbNorm) 52 | rbTarget, rbScaler = self.splatting(xxCeil, yyCeil, rbImg, rbNorm) 53 | 54 | scale = ltScaler + 
rtScaler + lbScaler + rbScaler 55 | remapTensor = torch.zeros_like(srcTensor) 56 | 57 | nonZero = scale != 0 58 | remapTensor[nonZero] = (ltTarget[nonZero] + rtTarget[nonZero] + lbTarget[nonZero] + rbTarget[nonZero]) / scale[ 59 | nonZero] 60 | # remapTensor = ltTarget + rtTarget + lbTarget + rbTarget 61 | 62 | # eps = 1e-8 63 | # remapTensor = (ltTarget + rtTarget + lbTarget + rbTarget + eps) / (scale + eps) 64 | 65 | # return remapTensor, scale 66 | return remapTensor 67 | 68 | def getBilinearWeight(self, xx, yy, xxFloor, yyFloor): 69 | alpha = xx - xxFloor 70 | beta = yy - yyFloor 71 | ltWeight = (1 - alpha) * (1 - beta) 72 | rtWeight = alpha * (1 - beta) 73 | lbWeight = (1 - alpha) * beta 74 | rbWeight = alpha * beta 75 | return ltWeight, rtWeight, lbWeight, rbWeight 76 | 77 | def splatting(self, xx, yy, img, Norm): 78 | N, C, H, W = xx.shape 79 | 80 | nn = torch.arange(0, N, requires_grad=False, device=self.device).view(N, 1, 1, 1).long(). \ 81 | repeat(1, C, H, W) # NCHW 82 | cc = torch.arange(0, C, requires_grad=False, device=self.device).view(1, C, 1, 1).long(). \ 83 | repeat(N, 1, H, W) # NCHW 84 | 85 | # grid = xx + yy * W 86 | stride = xx.stride() 87 | 88 | # grid = nn * C * H * W + cc * H * W + yy * W + xx 89 | grid = nn * stride[0] + cc * stride[1] + yy.long() * stride[2] + xx.long() 90 | 91 | mask = (xx.ge(0) & xx.lt(W)) & (yy.ge(0) & yy.lt(H)) 92 | 93 | gridSelect = grid.masked_select(mask).long() 94 | 95 | targetImg = torch.zeros_like(img).float() 96 | 97 | scaler = torch.zeros_like(Norm).float() 98 | 99 | targetImg.put_(gridSelect, img.masked_select(mask), accumulate=True) 100 | scaler.put_(gridSelect, Norm.masked_select(mask), accumulate=True) 101 | 102 | return targetImg, scaler 103 | 104 | # if __name__ == '__main__': 105 | # img = cv2.imread('./135_left.jpg') 106 | # flow = utils.readFlowFile('135.flo') 107 | # devices = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 108 | # 109 | # Func = ForwadWarpLayer() 110 | # H, W, C = img.shape 111 | # 112 | # imgTensor = torch.from_numpy(img).float().to(devices).unsqueeze(0).contiguous() 113 | # flowTensor = torch.from_numpy(flow).float().to(devices).unsqueeze(0).contiguous() 114 | # 115 | # # imgTensor = imgTensor.permute([0, 3, 1, 2]) 116 | # # flowTensor = flowTensor.permute([0, 3, 1, 2]) 117 | # 118 | # remapTensor = Func(imgTensor, flowTensor) 119 | # 120 | # remapImage = (remapTensor[0, ...].to('cpu').numpy()).astype(np.uint8) 121 | # cv2.namedWindow('1', 0) 122 | # cv2.imshow('1', remapImage) 123 | # cv2.waitKey(0) 124 | -------------------------------------------------------------------------------- /DVSTool/DVSBase.py: -------------------------------------------------------------------------------- 1 | from dv import AedatFile 2 | import numpy as np 3 | import torch 4 | import cv2 5 | import pickle 6 | 7 | 8 | class DVSReader(object): 9 | def __init__(self, fileName): 10 | super(DVSReader, self).__init__() 11 | self.C = 3 12 | 13 | with AedatFile(fileName) as f: 14 | self.height, self.width = f['events'].size 15 | self.Events = np.hstack([packet for packet in f['events'].numpy()]) 16 | self.tE = self.Events['timestamp'] 17 | self.xE = self.Events['x'] 18 | self.yE = self.Events['y'] 19 | self.pE = 2 * self.Events['polarity'] - 1 20 | 21 | tImg = [] 22 | Img = [] 23 | # Img H*W*1 24 | for packet in f['frames']: 25 | tImg.append(packet.timestamp) 26 | Img.append(packet.image) 27 | 28 | # self.tImg = np.hstack(tImg) 29 | self.tImg = tImg 30 | # self.Img = np.expand_dims(np.dstack(Img).transpose([2, 0, 1]), 
axis=1) 31 | self.Img = Img 32 | 33 | def aggregEvent(self, tStart=0, tStop=1e20, P=1): 34 | reverse = False 35 | if tStart >= tStop: 36 | reverse = True 37 | tStart, tStop = tStop, tStart 38 | 39 | if P is not None: 40 | sliceIdx = (self.tE >= tStart) & (self.tE < tStop) & (self.pE == P) 41 | else: 42 | sliceIdx = (self.tE >= tStart) & (self.tE < tStop) 43 | 44 | target = torch.zeros((self.height, self.width)).half() # self.pE is a numpy array here (no .device), so accumulate on CPU 45 | 46 | if not sliceIdx.any(): 47 | return target.cpu().char().numpy() 48 | 49 | # tSlice = self.tE[sliceIdx] 50 | xSlice = self.xE[sliceIdx] 51 | ySlice = self.yE[sliceIdx] 52 | pSlice = self.pE[sliceIdx] 53 | 54 | index = torch.from_numpy(ySlice.astype(np.int64) * self.width + xSlice.astype(np.int64)) 55 | 56 | target.put_(index=index, source=torch.from_numpy(pSlice).half(), accumulate=True) 57 | target = target.clamp(-10, 10) 58 | 59 | # print(target.max().cpu().item(), target.min().cpu().item()) 60 | if reverse: 61 | target = -target 62 | return target.cpu().char().numpy() 63 | 64 | 65 | class ESIMReader(object): 66 | def __init__(self, fileNames=None): 67 | super(ESIMReader, self).__init__() 68 | tE = [] 69 | xE = [] 70 | yE = [] 71 | pE = [] 72 | self.pathImgStart = [] 73 | self.pathImgStop = [] 74 | self.tImgStart = [] 75 | self.tImgStop = [] 76 | if not isinstance(fileNames, list): 77 | fileNames = [fileNames] 78 | for fileName in fileNames: 79 | fs = open(fileName, 'rb') 80 | record: dict = pickle.load(fs) 81 | tE.extend(record['tE']) 82 | xE.extend(record['xE']) 83 | yE.extend(record['yE']) 84 | pE.extend(record['pE']) 85 | pathImgStart = record['pathImgStart'] 86 | self.pathImgStart.append(pathImgStart) 87 | self.tImgStart.append(record['tImgStart']) 88 | self.tImgStop.append(record['tImgStop']) 89 | 90 | fs.close() 91 | 92 | # self.numStep = 10 93 | # self.tE = torch.from_numpy(np.array(record['tE'])).cuda() 94 | self.tE = torch.from_numpy(np.array(tE)).cuda() 95 | self.xE = torch.from_numpy(np.array(xE)).long().cuda() 96 | self.yE = torch.from_numpy(np.array(yE)).long().cuda() 97 | self.pE = torch.from_numpy(2 * np.array(pE) - 1).half().cuda() 98 | 99 | self.height = record.get('height', 180) 100 | self.width = record.get('width', 240) 101 | 102 | def aggregEvent(self, tStart=0, tStop=1e20, P=None): 103 | reverse = False 104 | if tStart >= tStop: 105 | reverse = True 106 | tStart, tStop = tStop, tStart 107 | 108 | if P is not None: 109 | sliceIdx = (self.tE >= tStart) & (self.tE < tStop) & (self.pE == P) 110 | else: 111 | sliceIdx = (self.tE >= tStart) & (self.tE < tStop) 112 | 113 | target = torch.zeros((self.height, self.width)).half().to(self.pE.device) 114 | 115 | if not sliceIdx.any(): 116 | return target.cpu().char().numpy() 117 | 118 | # tSlice = self.tE[sliceIdx] 119 | xSlice = self.xE[sliceIdx] 120 | ySlice = self.yE[sliceIdx] 121 | pSlice = self.pE[sliceIdx] 122 | 123 | index = ySlice * self.width + xSlice 124 | 125 | target.put_(index=index, source=pSlice, accumulate=True) 126 | target = target.clamp(-10, 10) 127 | 128 | # print(target.max().cpu().item(), target.min().cpu().item()) 129 | if reverse: 130 | target = -target 131 | return target.cpu().char().numpy() 132 | 133 | 134 | if __name__ == '__main__': 135 | fileName = '/home/sensetime/research/research_vi/EventTool/dataset/1.aedat4' 136 | event = DVSReader(fileName) 137 | cv2.namedWindow('1', 0) 138 | subimg = event.Img[0::10] 139 | for img in subimg: 140 | cv2.imshow('1', img) 141 | cv2.waitKey(0) 142 | -------------------------------------------------------------------------------- /stage1/lib/visualTool.py:
-------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | import cv2 20 | import torch 21 | 22 | 23 | def make_colorwheel(): 24 | """ 25 | Generates a color wheel for optical flow visualization as presented in: 26 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 27 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 28 | 29 | Code follows the original C++ source code of Daniel Scharstein. 30 | Code follows the Matlab source code of Deqing Sun. 31 | 32 | Returns: 33 | np.ndarray: Color wheel 34 | """ 35 | 36 | RY = 15 37 | YG = 6 38 | GC = 4 39 | CB = 11 40 | BM = 13 41 | MR = 6 42 | 43 | ncols = RY + YG + GC + CB + BM + MR 44 | colorwheel = np.zeros((ncols, 3)) 45 | col = 0 46 | 47 | # RY 48 | colorwheel[0:RY, 0] = 255 49 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) 50 | col = col + RY 51 | # YG 52 | colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) 53 | colorwheel[col:col + YG, 1] = 255 54 | col = col + YG 55 | # GC 56 | colorwheel[col:col + GC, 1] = 255 57 | colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) 58 | col = col + GC 59 | # CB 60 | colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) 61 | colorwheel[col:col + CB, 2] = 255 62 | col = col + CB 63 | # BM 64 | colorwheel[col:col + BM, 2] = 255 65 | colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) 66 | col = col + BM 67 | # MR 68 | colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) 69 | colorwheel[col:col + MR, 0] = 255 70 | return colorwheel 71 | 72 | 73 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 74 | """ 75 | Applies the flow color wheel to (possibly clipped) flow components u and v. 76 | 77 | According to the C++ source code of Daniel Scharstein 78 | According to the Matlab source code of Deqing Sun 79 | 80 | Args: 81 | u (np.ndarray): Input horizontal flow of shape [H,W] 82 | v (np.ndarray): Input vertical flow of shape [H,W] 83 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
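Note: u and v are expected to be pre-normalized so that magnitudes lie mostly in [0, 1]
(flow_to_image below divides by the per-frame maximum radius); pixels with radius > 1
are treated as out of range and dimmed by a factor of 0.75.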
84 | 85 | Returns: 86 | np.ndarray: Flow visualization image of shape [H,W,3] 87 | """ 88 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 89 | colorwheel = make_colorwheel() # shape [55x3] 90 | ncols = colorwheel.shape[0] 91 | rad = np.sqrt(np.square(u) + np.square(v)) 92 | a = np.arctan2(-v, -u) / np.pi 93 | fk = (a + 1) / 2 * (ncols - 1) 94 | k0 = np.floor(fk).astype(np.int32) 95 | k1 = k0 + 1 96 | k1[k1 == ncols] = 0 97 | f = fk - k0 98 | for i in range(colorwheel.shape[1]): 99 | tmp = colorwheel[:, i] 100 | col0 = tmp[k0] / 255.0 101 | col1 = tmp[k1] / 255.0 102 | col = (1 - f) * col0 + f * col1 103 | idx = (rad <= 1) 104 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 105 | col[~idx] = col[~idx] * 0.75 # out of range 106 | # Note the 2-i => BGR instead of RGB 107 | ch_idx = 2 - i if convert_to_bgr else i 108 | flow_image[:, :, ch_idx] = np.floor(255 * col) 109 | return flow_image 110 | 111 | 112 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 113 | """ 114 | Expects a two dimensional flow image of shape. 115 | 116 | Args: 117 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 118 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 119 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 120 | 121 | Returns: 122 | np.ndarray: Flow visualization image of shape [H,W,3] 123 | """ 124 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 125 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 126 | if clip_flow is not None: 127 | flow_uv = np.clip(flow_uv, 0, clip_flow) 128 | u = flow_uv[:, :, 0] 129 | v = flow_uv[:, :, 1] 130 | rad = np.sqrt(np.square(u) + np.square(v)) 131 | rad_max = np.max(rad) 132 | epsilon = 1e-5 133 | u = u / (rad_max + epsilon) 134 | v = v / (rad_max + epsilon) 135 | return flow_uv_to_colors(u, v, convert_to_bgr) 136 | 137 | 138 | def viz(img: torch.Tensor = None, name='1', wait=0): 139 | img = img.cpu().detach().float()[0].permute([1,2,0]) 140 | img = (255*(img - img.min()) / (img.max() - img.min())).byte().numpy() 141 | cv2.namedWindow(name, 0) 142 | cv2.imshow(name, img) 143 | cv2.waitKey(wait) 144 | -------------------------------------------------------------------------------- /DVSTool/lib/checkTool.py: -------------------------------------------------------------------------------- 1 | import lib.imgTool_ as trans 2 | import cv2 3 | import imageio 4 | from pathlib import Path 5 | from lib.fileTool import mkPath 6 | import numpy as np 7 | import torch 8 | 9 | 10 | def visualContinuousFrames(imgrendListOut, gtframeTList, path): 11 | mkPath(path) 12 | totalList = [] 13 | for t, (imgrendOut, gtframeT) in enumerate(zip(imgrendListOut, gtframeTList)): 14 | imgrendOut = trans.ToCVImage(imgrendOut) 15 | 16 | gtframeT = trans.ToCVImage(gtframeT) 17 | 18 | imgrendOut = cv2.putText(imgrendOut, "rdOut{}".format(t + 1), (10, 20), 19 | cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 1, cv2.LINE_AA) 20 | 21 | gtframeT = cv2.putText(gtframeT, "gt{}".format(t + 1), (10, 20), 22 | cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 1, cv2.LINE_AA) 23 | 24 | imgcat = trans.makeGrid([gtframeT, imgrendOut], shape=[1, 2]) 25 | totalList.append(imgcat) 26 | 27 | imageio.mimsave(str(Path(path) / 'Input.gif'), totalList, 'GIF', duration=0.5) 28 | 29 | 30 | def checkGrad(net): 31 | for parem in list(net.named_parameters()): 32 | if parem[1].grad is not None: 33 | print(parem[0] + ' \t shape={}, \t mean={}, \t std={}\n'.format(parem[1].shape, 34 | 
parem[1].grad.abs().mean().cpu().item(), 35 | parem[1].grad.abs().std().cpu().item())) 36 | 37 | 38 | def checkInput(I0t, I1t, It, Et): 39 | cv2.namedWindow('1', 0) 40 | for bidx in range(I0t.size(0)): 41 | It_ = ((It[bidx, 0] + 1) * 127.5).cpu().byte().numpy() 42 | I0t_ = ((I0t[bidx, 0] + 1) * 127.5).cpu().byte().numpy() 43 | I1t_ = ((I1t[bidx, 0] + 1) * 127.5).cpu().byte().numpy() 44 | # Et_ = ((Et[bidx, 0] + 1) * 127.5).cpu().byte().numpy() 45 | Et_ = Et.cpu().float().numpy() 46 | cv2.imshow('1', It_) 47 | cv2.waitKey(0) 48 | 49 | cv2.imshow('1', I0t_) 50 | cv2.waitKey(0) 51 | 52 | cv2.imshow('1', I1t_) 53 | cv2.waitKey(0) 54 | 55 | Pos = Et_[bidx, 0::2, ...] 56 | Neg = Et_[bidx, 1::2, ...] 57 | 58 | for eIdx in range(4): 59 | imgCV = np.zeros((180, 240, 3), dtype=np.uint8) 60 | imgCV[..., 0] = It_ 61 | imgCV[..., 1] = It_ 62 | imgCV[..., 2] = It_ 63 | 64 | pPos = Pos[eIdx, ...] 65 | pNeg = Neg[eIdx, ...] 66 | eventGray = pPos + pNeg 67 | eventGray = ((eventGray - eventGray.min()) / (eventGray.max() - eventGray.min())) * 255 68 | eventGray = cv2.cvtColor(eventGray, cv2.COLOR_BGR2RGB).astype(np.uint8) 69 | 70 | imgCV[..., 2][pPos > 0] = 255 71 | imgCV[..., 1][pNeg < 0] = 255 72 | 73 | cv2.imshow('1', np.concatenate([imgCV, eventGray], axis=1)) 74 | cv2.waitKey(0) 75 | cv2.destroyAllWindows() 76 | 77 | 78 | def checkEvents(imgs: torch.Tensor, events: torch.Tensor): 79 | # N, F, C, H, W = img.shape 80 | N, C, H, W = events.shape 81 | 82 | Pos = events[:, 0::2, ...] 83 | # Neg = -events[:, 1::2, :, ...] 84 | Neg = -events[:, 1::2, ...] 85 | 86 | cv2.namedWindow('1', 0) 87 | 88 | for n in range(N): 89 | for eIdx in range(C // 2): 90 | pPos = Pos[n, eIdx, ...].detach().cpu().numpy() 91 | pNeg = Neg[n, eIdx, ...].detach().cpu().numpy() 92 | 93 | img = np.zeros([H, W, 3]) 94 | imgSample = imgs[n, 0, ...].detach().cpu().numpy() 95 | imgSample = ((imgSample - imgSample.min()) / (imgSample.max() - imgSample.min()) * 255).astype(np.uint8) 96 | 97 | img[:, :, -1] = imgSample 98 | img[:, :, 1] = imgSample 99 | img[:, :, 0] = imgSample 100 | 101 | img[:, :, -1][pPos > 0] = 255 102 | img[:, :, 1][pNeg > 0] = 255 103 | 104 | # cv2.putText(img, '{}_{}'.format(fIdx, eIdx), (20, 20), 105 | # cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 1, cv2.LINE_AA) 106 | cv2.putText(img, '{}_{}'.format(n, eIdx), (20, 20), 107 | cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 1, cv2.LINE_AA) 108 | cv2.imshow('1', img.astype(np.uint8)) 109 | cv2.waitKey(100) 110 | 111 | pPos = Pos[n, ...].sum(dim=0).detach().cpu().numpy() 112 | pNeg = Neg[n, ...].sum(dim=0).detach().cpu().numpy() 113 | 114 | img = np.zeros([H, W, 3]) 115 | imgSample = imgs[n, 0, ...].detach().cpu().numpy() 116 | imgSample = ((imgSample - imgSample.min()) / (imgSample.max() - imgSample.min()) * 255).astype(np.uint8) 117 | 118 | img[:, :, -1] = imgSample 119 | img[:, :, 1] = imgSample 120 | img[:, :, 0] = imgSample 121 | 122 | img[:, :, -1][pPos > 0] = 255 123 | img[:, :, 1][pNeg > 0] = 255 124 | 125 | # cv2.putText(img, '{}_{}'.format(fIdx, eIdx), (20, 20), 126 | # cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 1, cv2.LINE_AA) 127 | cv2.putText(img, 'final_{}'.format(n), (20, 20), 128 | cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 1, cv2.LINE_AA) 129 | cv2.imshow('1', img.astype(np.uint8)) 130 | cv2.waitKey(100) 131 | -------------------------------------------------------------------------------- /DVSTool/lib/eventJiangzhe.py: -------------------------------------------------------------------------------- 1 | # only use for python2 2 | 3 | from __future__ import 
division 4 | from __future__ import print_function 5 | from glob import glob 6 | import numpy as np 7 | import os.path as op 8 | import struct 9 | import copy 10 | import cv2 11 | import os 12 | from tqdm import tqdm 13 | import aedat 14 | 15 | EVT_DVS = 0 # DVS event type 16 | EVT_APS = 1 # APS event 17 | 18 | 19 | def mkdir(path): 20 | if not os.path.exists(path): 21 | os.makedirs(path) 22 | 23 | 24 | def loadaerdat(datafile='.aedat', length=0): 25 | aeLen = 8 26 | readMode = '>II' 27 | td = 0.000001 28 | 29 | sizeX = 240 30 | sizeY = 180 31 | x0 = 1 32 | y0 = 1 33 | xmask = 0x003ff000 34 | xshift = 12 35 | ymask = 0x7fc00000 36 | yshift = 22 37 | pmask = 0x800 38 | pshift = 11 39 | eventtypeshift = 31 40 | adcmask = 0x3ff 41 | 42 | frame = np.zeros([6, sizeX, sizeY], dtype=np.int32) 43 | 44 | aerdatafh = open(datafile, 'rb') 45 | k = 0 # line number 46 | p = 0 # pointer, position on bytes 47 | statinfo = os.stat(datafile) 48 | if length == 0: 49 | length = statinfo.st_size 50 | print('file size', length) 51 | 52 | # header 53 | lt = aerdatafh.readline() 54 | while lt and lt[0] == '#': 55 | p += len(lt) 56 | k += 1 57 | lt = aerdatafh.readline() 58 | continue 59 | 60 | # variables to parse 61 | timestamps = [] 62 | xaddr = [] 63 | yaddr = [] 64 | pol = [] 65 | frames = [] 66 | # read data-part of file 67 | aerdatafh.seek(p) 68 | s = aerdatafh.read(aeLen) 69 | p += aeLen 70 | pbar = tqdm(total=length) 71 | while p < length: 72 | addr, ts = struct.unpack(readMode, s) 73 | eventtype = (addr >> eventtypeshift) 74 | 75 | # parse event's data 76 | if eventtype == EVT_DVS: # this is a DVS event 77 | x_addr = (addr & xmask) >> xshift 78 | y_addr = (addr & ymask) >> yshift 79 | a_pol = (addr & pmask) >> pshift 80 | 81 | timestamps.append(ts) 82 | xaddr.append(sizeX - x_addr) 83 | yaddr.append(sizeY - y_addr) 84 | pol.append(a_pol) 85 | 86 | if eventtype == EVT_APS: # this is an APS packet 87 | x1 = sizeX 88 | y1 = sizeY 89 | 90 | x_addr = (addr & xmask) >> xshift 91 | y_addr = (addr & ymask) >> yshift 92 | adc_data = addr & adcmask 93 | read_reset = (addr >> 10) & 3 94 | 95 | if x_addr >= x0 and x_addr < x1 and y_addr >= y0 and y_addr < y1: 96 | if (read_reset == 0): # is reset read 97 | frame[0, x_addr, y_addr] = adc_data 98 | frame[4, x_addr, y_addr] = ts # resetTsBuffer; 99 | if (read_reset == 1): # is read signal 100 | # print "read", read_reset 101 | frame[1, x_addr, y_addr] = adc_data 102 | frame[3, x_addr, y_addr] = ts # readTsBuffer; 103 | 104 | if (read_reset == 0) and x_addr == 0 and y_addr == 0: 105 | frame[2, :, :] = frame[0, :, :] - frame[1, :, :] 106 | frame[5, :, :] = frame[3, :, :] - frame[4, :, :] 107 | frames.append(frame) 108 | frame = np.zeros([6, sizeX, sizeY], dtype=np.int32) 109 | 110 | aerdatafh.seek(p) 111 | s = aerdatafh.read(aeLen) 112 | p += aeLen 113 | pbar.update(aeLen) 114 | pbar.close() 115 | 116 | try: 117 | print('read %i (~ %.2fM) AE events, duration= %.2fs' % ( 118 | len(timestamps), len(timestamps) / float(10 ** 6), (timestamps[-1] - timestamps[0]) * td)) 119 | n = 5 120 | print('showing first %i:' % (n)) 121 | print('timestamps: %s \nX-addr: %s\nY-addr: %s\npolarity: %s' % ( 122 | timestamps[0:n], xaddr[0:n], yaddr[0:n], pol[0:n])) 123 | except: 124 | print('failed to print statistics') 125 | 126 | return timestamps, xaddr, yaddr, pol, frames 127 | 128 | 129 | def save_data(timestamps, xaddr, yaddr, pol, frames, save_frame=True, save_event=True, save_path=''): 130 | if save_frame: 131 | start = timestamps[0] 132 | f = open(op.join(save_path, 'images.txt'), 
'w') 133 | mkdir(os.path.join(save_path, 'images')) 134 | 135 | for idx, frame in enumerate(frames[1:]): 136 | img_path = op.join(save_path, 'images/%06d.png' % (idx + 1)) 137 | img = frame[2] 138 | img[img < 0] = 0 139 | img = (np.rot90(img, 1) / np.power(2, 10) * 255.0).astype('uint8') 140 | time = (np.max(frame[3]) - start) * 1e-6 141 | 142 | # put the writing as the end 143 | cv2.imwrite(img_path, img) 144 | f.write('%.6f' % time + ' ' + 'images/%06d.png' % (idx + 1) + '\n') 145 | f.close() 146 | 147 | if save_event: 148 | f = open(op.join(save_path, 'events.txt'), 'w') 149 | start = timestamps[0] 150 | timestamps = ['%.6f' % ((x - start) * 1e-6) for x in timestamps] 151 | xaddr = np.array(xaddr) - 1 152 | yaddr = np.array(yaddr) - 1 153 | for line in zip(timestamps, xaddr.tolist(), yaddr.tolist(), pol): 154 | f.write(' '.join([str(item) for item in line]) + '\n') 155 | f.close() 156 | 157 | 158 | if __name__ == '__main__': 159 | # filename = 'DAVIS240C-2020-08-18T11-46-15 0800-00000000-0.aedat' 160 | # save_data(*loadaerdat(filename)) 161 | import aer 162 | 163 | # read all at once 164 | events = aer.AEData("DAVIS240C-2020-08-18T11-46-15 0800-00000000-0.aedat") 165 | pass 166 | -------------------------------------------------------------------------------- /stage1/dataloader/eventReader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | from torch.utils.data.distributed import DistributedSampler 4 | import torch.nn.functional as F 5 | import random 6 | from lib import fileTool as filLib 7 | import numpy as np 8 | from pathlib import Path 9 | import pickle 10 | from dataloader.dataloaderBase import DistributedSamplerVali 11 | import cv2 12 | 13 | # train------------------------------------------------------------------- 14 | class eventReaderTrain(data.Dataset): 15 | def __init__(self, cfg=None): 16 | super(eventReaderTrain, self).__init__() 17 | self.cfg = cfg 18 | self.eventPath = cfg.pathTrainEvent 19 | self.numPerSamples = 10 20 | 21 | self.eventGroups = filLib.getAllFiles(self.eventPath, 'pkl') 22 | self.len = len(self.eventGroups) 23 | self.size = cfg.trainSize 24 | 25 | def __len__(self): 26 | return self.len 27 | 28 | def __getitem__(self, index): 29 | with open(self.eventGroups[index], 'rb') as f: 30 | record = pickle.load(f) 31 | 32 | I0 = record['I0'] 33 | It = record['It'] 34 | I1 = record['I1'] 35 | 36 | F0t = record['F0t'] 37 | F1t = record['F1t'] 38 | 39 | Et = record['Et'] 40 | 41 | C, H, W = It.shape 42 | 43 | # i = random.randint(0, H - self.size[0]) 44 | i = 0 45 | # j = random.randint(0, W - self.size[1]) 46 | j = 0 47 | 48 | I0 = I0[:, i: i + self.size[0], j:j + self.size[1]] 49 | It = It[:, i: i + self.size[0], j:j + self.size[1]] 50 | I1 = I1[:, i: i + self.size[0], j:j + self.size[1]] 51 | 52 | F0t = F0t[:, i: i + self.size[0], j:j + self.size[1]] 53 | F1t = F1t[:, i: i + self.size[0], j:j + self.size[1]] 54 | 55 | Et = Et[:, i: i + self.size[0], j:j + self.size[1]] 56 | 57 | I0 = torch.from_numpy(I0.copy()).float() / 127.5 - 1 58 | I0 = I0.repeat((3, 1, 1)) 59 | 60 | It = torch.from_numpy(It.copy()).float() / 127.5 - 1 61 | It = It.repeat((3, 1, 1)) 62 | 63 | I1 = torch.from_numpy(I1.copy()).float() / 127.5 - 1 64 | I1 = I1.repeat((3, 1, 1)) 65 | 66 | F0t = torch.from_numpy(F0t.copy()).float() 67 | F1t = torch.from_numpy(F1t.copy()).float() 68 | 69 | Et = torch.from_numpy(Et.copy()).float().clamp(-10, 10) 70 | 71 | return I0, It, I1, F0t, F1t, Et 72 | 73 | def 
getFramGroups(self): 74 | framDirs = filLib.getAllFiles(self.eventPath, 'pkl') 75 | framDirs.sort(key=lambda x: int(Path(x).stem)) 76 | return framDirs 77 | 78 | 79 | def createEventVITrain(cfg=None): 80 | trainDataset = eventReaderTrain(cfg) 81 | 82 | if cfg.envDistributed: 83 | trainSampler = DistributedSampler(trainDataset, num_replicas=cfg.envWorldSize, rank=cfg.envRank) 84 | 85 | trainLoader = data.DataLoader(dataset=trainDataset, 86 | batch_size=cfg.trainBatchPerGPU, 87 | shuffle=False, 88 | num_workers=cfg.envWorkers, 89 | pin_memory=False, # False if memory is not enough 90 | drop_last=False, 91 | sampler=trainSampler 92 | ) 93 | return trainSampler, trainLoader 94 | else: 95 | trainLoader = data.DataLoader(dataset=trainDataset, 96 | batch_size=cfg.trainBatch, 97 | shuffle=True, 98 | num_workers=cfg.envWorkers, 99 | pin_memory=False, # False if memory is not enough 100 | drop_last=False 101 | ) 102 | return None, trainLoader 103 | 104 | 105 | 106 | def checkTrain(): 107 | import configs.configEVI as config 108 | import cv2 109 | cfg = config.Config({'SingleNode': '0'}, False) 110 | sample, Testloader = createEventVITrain(cfg) 111 | cv2.namedWindow('1', 0) 112 | for trainIndex, (I0t, It, I1t, F0t, F1t, Et) in enumerate(Testloader, 0): # the dataset yields (I0, It, I1, F0t, F1t, Et) 113 | I0t = ((I0t[0] + 1) * 127.5).cpu().numpy().transpose([1, 2, 0]).astype(np.uint8) 114 | I1t = ((I1t[0] + 1) * 127.5).cpu().numpy().transpose([1, 2, 0]).astype(np.uint8) 115 | It = ((It[0] + 1) * 127.5).cpu().numpy().transpose([1, 2, 0]).astype(np.uint8) 116 | 117 | print('check I0t, I1t') 118 | cv2.imshow('1', np.concatenate([I1t, I0t], axis=1)) 119 | cv2.waitKey(0) 120 | 121 | cv2.imshow('1', np.concatenate([It, It], axis=1)) 122 | cv2.waitKey(0) 123 | 124 | Et = Et[0].cpu().numpy().astype(np.float32) 125 | 126 | for eIdx, p in enumerate(Et): 127 | eventImg = p 128 | eventImg = ((eventImg - eventImg.min()) / (eventImg.max() - eventImg.min()) * 255.0).astype( 129 | np.uint8) 130 | eventImg = cv2.cvtColor(eventImg, cv2.COLOR_GRAY2BGR) 131 | 132 | img = It.copy() 133 | 134 | img[:, :, 0][p != 0] = 0 135 | 136 | img[:, :, 2][p > 0] = 255 137 | img[:, :, 1][p < 0] = 255 138 | 139 | cv2.putText(img, '{}'.format(eIdx), (20, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.8, 140 | (0, 0, 255), 141 | 5, 142 | cv2.LINE_AA) 143 | 144 | cv2.imshow('1', np.concatenate([img.astype(np.uint8), eventImg], axis=1)) 145 | cv2.waitKey(0) 146 | 147 | 148 | if __name__ == '__main__': 149 | checkTrain() 150 | -------------------------------------------------------------------------------- /stage1/lib/warp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | 7 | def backWarp(img: torch.Tensor, flow: torch.Tensor): 8 | device = img.device 9 | N, C, H, W = img.size() 10 | 11 | u = flow[:, 0, :, :] 12 | v = flow[:, 1, :, :] 13 | 14 | gridX, gridY = np.meshgrid(np.arange(W), np.arange(H)) 15 | gridX = torch.from_numpy(gridX).detach().to(device) 16 | gridY = torch.from_numpy(gridY).detach().to(device) 17 | 18 | x = gridX.unsqueeze(0).expand_as(u).float() + u 19 | y = gridY.unsqueeze(0).expand_as(v).float() + v 20 | 21 | # range -1 to 1 22 | x = 2 * x / (W - 1.0) - 1.0 23 | y = 2 * y / (H - 1.0) - 1.0 24 | # stacking X and Y 25 | grid = torch.stack((x, y), dim=3) 26 | # Sample pixels using bilinear interpolation.
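# (note) grid_sample expects normalized sampling locations: the mapping x -> 2*x/(W-1) - 1
# above sends column 0 to -1 and column W-1 to +1 (the align_corners=True convention);
# displaced locations that leave [-1, 1] are zero-filled via padding_mode='zeros'.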
27 | imgOut = F.grid_sample(img, grid, mode='bilinear', padding_mode='zeros') 28 | 29 | # mask = torch.ones_like(img, requires_grad=False) 30 | # mask = F.grid_sample(mask, grid) 31 | # 32 | # mask[mask < 0.9999] = 0 33 | # mask[mask > 0] = 1 34 | 35 | # return imgOut * (mask.detach()) 36 | return imgOut 37 | 38 | 39 | class forwadWarp(nn.Module): 40 | def __init__(self, bilinear=True): 41 | super(forwadWarp, self).__init__() 42 | 43 | self.bilinear = bilinear 44 | 45 | def forward(self, srcTensor: torch.Tensor, flow: torch.Tensor, weight: torch.Tensor = None): 46 | if weight is None: 47 | weight = torch.ones_like(srcTensor) 48 | 49 | srcTensor = srcTensor * weight 50 | 51 | self.device = srcTensor.device 52 | N, C, H, W = srcTensor.shape 53 | 54 | xx = torch.arange(0, W, requires_grad=False, device=self.device).view(1, 1, 1, -1).repeat(N, C, H, 1).float() \ 55 | + flow[:, 0:1, :, :] 56 | yy = torch.arange(0, H, requires_grad=False, device=self.device).view(1, 1, -1, 1).repeat(N, C, 1, W).float() \ 57 | + flow[:, 1:2, :, :] 58 | 59 | xxFloor = xx.floor().float().detach() 60 | xxCeil = xxFloor + 1.0 61 | 62 | yyFloor = yy.floor().float().detach() 63 | yyCeil = yyFloor + 1.0 64 | 65 | if self.bilinear: 66 | ltWeight, rtWeight, lbWeight, rbWeight = self.getBilinearWeight(xx, yy, xxFloor, yyFloor) 67 | else: 68 | ltWeight = torch.ones_like(srcTensor).detach() 69 | rtWeight = torch.ones_like(srcTensor).detach() 70 | lbWeight = torch.ones_like(srcTensor).detach() 71 | rbWeight = torch.ones_like(srcTensor).detach() 72 | 73 | ltImg = srcTensor * ltWeight 74 | rtImg = srcTensor * rtWeight 75 | lbImg = srcTensor * lbWeight 76 | rbImg = srcTensor * rbWeight 77 | 78 | ltNorm = weight * ltWeight 79 | rtNorm = weight * rtWeight 80 | lbNorm = weight * lbWeight 81 | rbNorm = weight * rbWeight 82 | 83 | ltTarget, ltScaler = self.splatting(xxFloor, yyFloor, ltImg, ltNorm) 84 | rtTarget, rtScaler = self.splatting(xxCeil, yyFloor, rtImg, rtNorm) 85 | lbTarget, lbScaler = self.splatting(xxFloor, yyCeil, lbImg, lbNorm) 86 | rbTarget, rbScaler = self.splatting(xxCeil, yyCeil, rbImg, rbNorm) 87 | 88 | scale = ltScaler + rtScaler + lbScaler + rbScaler 89 | remapTensor = torch.zeros_like(srcTensor) 90 | 91 | nonZero = scale != 0 92 | remapTensor[nonZero] = (ltTarget[nonZero] + rtTarget[nonZero] + lbTarget[nonZero] + rbTarget[nonZero]) / scale[ 93 | nonZero] 94 | # remapTensor = ltTarget + rtTarget + lbTarget + rbTarget 95 | 96 | # eps = 1e-8 97 | # remapTensor = (ltTarget + rtTarget + lbTarget + rbTarget + eps) / (scale + eps) 98 | 99 | # return remapTensor, scale 100 | return remapTensor 101 | 102 | def getBilinearWeight(self, xx, yy, xxFloor, yyFloor): 103 | alpha = xx - xxFloor 104 | beta = yy - yyFloor 105 | ltWeight = (1 - alpha) * (1 - beta) 106 | rtWeight = alpha * (1 - beta) 107 | lbWeight = (1 - alpha) * beta 108 | rbWeight = alpha * beta 109 | return ltWeight, rtWeight, lbWeight, rbWeight 110 | 111 | def splatting(self, xx, yy, img, Norm): 112 | N, C, H, W = xx.shape 113 | 114 | nn = torch.arange(0, N, requires_grad=False, device=self.device).view(N, 1, 1, 1).long(). \ 115 | repeat(1, C, H, W) # NCHW 116 | cc = torch.arange(0, C, requires_grad=False, device=self.device).view(1, C, 1, 1).long(). 
\ 117 | repeat(N, 1, H, W) # NCHW 118 | 119 | # grid = xx + yy * W 120 | stride = xx.stride() 121 | 122 | # grid = nn * C * H * W + cc * H * W + yy * W + xx 123 | grid = nn * stride[0] + cc * stride[1] + yy.long() * stride[2] + xx.long() 124 | 125 | mask = (xx.ge(0) & xx.lt(W)) & (yy.ge(0) & yy.lt(H)) 126 | 127 | gridSelect = grid.masked_select(mask).long() 128 | 129 | targetImg = torch.zeros_like(img).float() 130 | 131 | scaler = torch.zeros_like(Norm).float() 132 | 133 | targetImg.put_(gridSelect, img.masked_select(mask), accumulate=True) 134 | scaler.put_(gridSelect, Norm.masked_select(mask), accumulate=True) 135 | 136 | return targetImg, scaler 137 | 138 | # if __name__ == '__main__': 139 | # img = cv2.imread('./135_left.jpg') 140 | # flow = utils.readFlowFile('135.flo') 141 | # devices = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 142 | # 143 | # Func = ForwadWarpLayer() 144 | # H, W, C = img.shape 145 | # 146 | # imgTensor = torch.from_numpy(img).float().to(devices).unsqueeze(0).contiguous() 147 | # flowTensor = torch.from_numpy(flow).float().to(devices).unsqueeze(0).contiguous() 148 | # 149 | # # imgTensor = imgTensor.permute([0, 3, 1, 2]) 150 | # # flowTensor = flowTensor.permute([0, 3, 1, 2]) 151 | # 152 | # remapTensor = Func(imgTensor, flowTensor) 153 | # 154 | # remapImage = (remapTensor[0, ...].to('cpu').numpy()).astype(np.uint8) 155 | # cv2.namedWindow('1', 0) 156 | # cv2.imshow('1', remapImage) 157 | # cv2.waitKey(0) 158 | -------------------------------------------------------------------------------- /stage2/lib/warp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | 7 | def backWarp(img: torch.Tensor, flow: torch.Tensor): 8 | device = img.device 9 | N, C, H, W = img.size() 10 | 11 | u = flow[:, 0, :, :] 12 | v = flow[:, 1, :, :] 13 | 14 | gridY, gridX = torch.meshgrid([torch.arange(start=0, end=H, device=device, requires_grad=False), 15 | torch.arange(start=0, end=W, device=device, requires_grad=False)]) 16 | 17 | x = gridX.unsqueeze(0).expand_as(u).float().detach() + u 18 | y = gridY.unsqueeze(0).expand_as(v).float().detach() + v 19 | 20 | # range -1 to 1 21 | x = 2 * x / (W - 1.0) - 1.0 22 | y = 2 * y / (H - 1.0) - 1.0 23 | # stacking X and Y 24 | grid = torch.stack((x, y), dim=3) 25 | # Sample pixels using bilinear interpolation. 
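# (note) no align_corners flag is passed to grid_sample here; since torch 1.3 the default is
# align_corners=False, which does not match the (W - 1.0)/(H - 1.0) normalization above, so on
# newer torch versions passing align_corners=True would likely be needed for the same result.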
26 | imgOut = F.grid_sample(img, grid, mode='bilinear', padding_mode='zeros') 27 | 28 | # mask = torch.ones_like(img, requires_grad=False) 29 | # mask = F.grid_sample(mask, grid) 30 | # 31 | # mask[mask < 0.9999] = 0 32 | # mask[mask > 0] = 1 33 | 34 | # return imgOut * (mask.detach()) 35 | return imgOut 36 | 37 | 38 | class forwadWarp(nn.Module): 39 | def __init__(self, bilinear=True): 40 | super(forwadWarp, self).__init__() 41 | 42 | self.bilinear = bilinear 43 | 44 | def forward(self, srcTensor: torch.Tensor, flow: torch.Tensor, weight: torch.Tensor = None): 45 | if weight is None: 46 | weight = torch.ones_like(srcTensor) 47 | 48 | srcTensor = srcTensor * weight 49 | 50 | self.device = srcTensor.device 51 | N, C, H, W = srcTensor.shape 52 | 53 | xx = torch.arange(0, W, requires_grad=False, device=self.device).view(1, 1, 1, -1).repeat(N, C, H, 1).float() \ 54 | + flow[:, 0:1, :, :] 55 | yy = torch.arange(0, H, requires_grad=False, device=self.device).view(1, 1, -1, 1).repeat(N, C, 1, W).float() \ 56 | + flow[:, 1:2, :, :] 57 | 58 | xxFloor = xx.floor().float().detach() 59 | xxCeil = xxFloor + 1.0 60 | 61 | yyFloor = yy.floor().float().detach() 62 | yyCeil = yyFloor + 1.0 63 | 64 | if self.bilinear: 65 | ltWeight, rtWeight, lbWeight, rbWeight = self.getBilinearWeight(xx, yy, xxFloor, yyFloor) 66 | else: 67 | ltWeight = torch.ones_like(srcTensor).detach() 68 | rtWeight = torch.ones_like(srcTensor).detach() 69 | lbWeight = torch.ones_like(srcTensor).detach() 70 | rbWeight = torch.ones_like(srcTensor).detach() 71 | 72 | ltImg = srcTensor * ltWeight 73 | rtImg = srcTensor * rtWeight 74 | lbImg = srcTensor * lbWeight 75 | rbImg = srcTensor * rbWeight 76 | 77 | ltNorm = weight * ltWeight 78 | rtNorm = weight * rtWeight 79 | lbNorm = weight * lbWeight 80 | rbNorm = weight * rbWeight 81 | 82 | ltTarget, ltScaler = self.splatting(xxFloor, yyFloor, ltImg, ltNorm) 83 | rtTarget, rtScaler = self.splatting(xxCeil, yyFloor, rtImg, rtNorm) 84 | lbTarget, lbScaler = self.splatting(xxFloor, yyCeil, lbImg, lbNorm) 85 | rbTarget, rbScaler = self.splatting(xxCeil, yyCeil, rbImg, rbNorm) 86 | 87 | scale = ltScaler + rtScaler + lbScaler + rbScaler 88 | remapTensor = torch.zeros_like(srcTensor) 89 | 90 | nonZero = scale != 0 91 | remapTensor[nonZero] = (ltTarget[nonZero] + rtTarget[nonZero] + lbTarget[nonZero] + rbTarget[nonZero]) / scale[ 92 | nonZero] 93 | # remapTensor = ltTarget + rtTarget + lbTarget + rbTarget 94 | 95 | # eps = 1e-8 96 | # remapTensor = (ltTarget + rtTarget + lbTarget + rbTarget + eps) / (scale + eps) 97 | 98 | # return remapTensor, scale 99 | return remapTensor 100 | 101 | def getBilinearWeight(self, xx, yy, xxFloor, yyFloor): 102 | alpha = xx - xxFloor 103 | beta = yy - yyFloor 104 | ltWeight = (1 - alpha) * (1 - beta) 105 | rtWeight = alpha * (1 - beta) 106 | lbWeight = (1 - alpha) * beta 107 | rbWeight = alpha * beta 108 | return ltWeight, rtWeight, lbWeight, rbWeight 109 | 110 | def splatting(self, xx, yy, img, Norm): 111 | N, C, H, W = xx.shape 112 | 113 | nn = torch.arange(0, N, requires_grad=False, device=self.device).view(N, 1, 1, 1).long(). \ 114 | repeat(1, C, H, W) # NCHW 115 | cc = torch.arange(0, C, requires_grad=False, device=self.device).view(1, C, 1, 1).long(). 
\ 116 | repeat(N, 1, H, W) # NCHW 117 | 118 | # grid = xx + yy * W 119 | stride = xx.stride() 120 | 121 | # grid = nn * C * H * W + cc * H * W + yy * W + xx 122 | grid = nn * stride[0] + cc * stride[1] + yy.long() * stride[2] + xx.long() 123 | 124 | mask = (xx.ge(0) & xx.lt(W)) & (yy.ge(0) & yy.lt(H)) 125 | 126 | gridSelect = grid.masked_select(mask).long() 127 | 128 | targetImg = torch.zeros_like(img).float() 129 | 130 | scaler = torch.zeros_like(Norm).float() 131 | 132 | targetImg.put_(gridSelect, img.masked_select(mask), accumulate=True) 133 | scaler.put_(gridSelect, Norm.masked_select(mask), accumulate=True) 134 | 135 | return targetImg, scaler 136 | 137 | # if __name__ == '__main__': 138 | # img = cv2.imread('./135_left.jpg') 139 | # flow = utils.readFlowFile('135.flo') 140 | # devices = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 141 | # 142 | # Func = ForwadWarpLayer() 143 | # H, W, C = img.shape 144 | # 145 | # imgTensor = torch.from_numpy(img).float().to(devices).unsqueeze(0).contiguous() 146 | # flowTensor = torch.from_numpy(flow).float().to(devices).unsqueeze(0).contiguous() 147 | # 148 | # # imgTensor = imgTensor.permute([0, 3, 1, 2]) 149 | # # flowTensor = flowTensor.permute([0, 3, 1, 2]) 150 | # 151 | # remapTensor = Func(imgTensor, flowTensor) 152 | # 153 | # remapImage = (remapTensor[0, ...].to('cpu').numpy()).astype(np.uint8) 154 | # cv2.namedWindow('1', 0) 155 | # cv2.imshow('1', remapImage) 156 | # cv2.waitKey(0) 157 | -------------------------------------------------------------------------------- /DVSTool/lib/visualTool.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | import cv2 20 | import torch 21 | 22 | 23 | def make_colorwheel(): 24 | """ 25 | Generates a color wheel for optical flow visualization as presented in: 26 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 27 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 28 | 29 | Code follows the original C++ source code of Daniel Scharstein. 30 | Code follows the Matlab source code of Deqing Sun.
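The wheel chains six color transitions (RY, YG, GC, CB, BM, MR), so it holds
RY + YG + GC + CB + BM + MR = 15 + 6 + 4 + 11 + 13 + 6 = 55 colors in total.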
31 | 32 | Returns: 33 | np.ndarray: Color wheel 34 | """ 35 | 36 | RY = 15 37 | YG = 6 38 | GC = 4 39 | CB = 11 40 | BM = 13 41 | MR = 6 42 | 43 | ncols = RY + YG + GC + CB + BM + MR 44 | colorwheel = np.zeros((ncols, 3)) 45 | col = 0 46 | 47 | # RY 48 | colorwheel[0:RY, 0] = 255 49 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) 50 | col = col + RY 51 | # YG 52 | colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) 53 | colorwheel[col:col + YG, 1] = 255 54 | col = col + YG 55 | # GC 56 | colorwheel[col:col + GC, 1] = 255 57 | colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) 58 | col = col + GC 59 | # CB 60 | colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) 61 | colorwheel[col:col + CB, 2] = 255 62 | col = col + CB 63 | # BM 64 | colorwheel[col:col + BM, 2] = 255 65 | colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) 66 | col = col + BM 67 | # MR 68 | colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) 69 | colorwheel[col:col + MR, 0] = 255 70 | return colorwheel 71 | 72 | 73 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 74 | """ 75 | Applies the flow color wheel to (possibly clipped) flow components u and v. 76 | 77 | According to the C++ source code of Daniel Scharstein 78 | According to the Matlab source code of Deqing Sun 79 | 80 | Args: 81 | u (np.ndarray): Input horizontal flow of shape [H,W] 82 | v (np.ndarray): Input vertical flow of shape [H,W] 83 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 84 | 85 | Returns: 86 | np.ndarray: Flow visualization image of shape [H,W,3] 87 | """ 88 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 89 | colorwheel = make_colorwheel() # shape [55x3] 90 | ncols = colorwheel.shape[0] 91 | rad = np.sqrt(np.square(u) + np.square(v)) 92 | a = np.arctan2(-v, -u) / np.pi 93 | fk = (a + 1) / 2 * (ncols - 1) 94 | k0 = np.floor(fk).astype(np.int32) 95 | k1 = k0 + 1 96 | k1[k1 == ncols] = 0 97 | f = fk - k0 98 | for i in range(colorwheel.shape[1]): 99 | tmp = colorwheel[:, i] 100 | col0 = tmp[k0] / 255.0 101 | col1 = tmp[k1] / 255.0 102 | col = (1 - f) * col0 + f * col1 103 | idx = (rad <= 1) 104 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 105 | col[~idx] = col[~idx] * 0.75 # out of range 106 | # Note the 2-i => BGR instead of RGB 107 | ch_idx = 2 - i if convert_to_bgr else i 108 | flow_image[:, :, ch_idx] = np.floor(255 * col) 109 | return flow_image 110 | 111 | 112 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 113 | """ 114 | Expects a two dimensional flow image of shape. 115 | 116 | Args: 117 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 118 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 119 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
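Note: the flow is rescaled by its own maximum magnitude before coloring, so hue encodes
direction and saturation encodes speed relative to the fastest pixel in this frame;
visualizations of different frames are therefore not directly comparable in scale.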
120 | 121 | Returns: 122 | np.ndarray: Flow visualization image of shape [H,W,3] 123 | """ 124 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 125 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 126 | if clip_flow is not None: 127 | flow_uv = np.clip(flow_uv, 0, clip_flow) 128 | u = flow_uv[:, :, 0] 129 | v = flow_uv[:, :, 1] 130 | rad = np.sqrt(np.square(u) + np.square(v)) 131 | rad_max = np.max(rad) 132 | epsilon = 1e-5 133 | u = u / (rad_max + epsilon) 134 | v = v / (rad_max + epsilon) 135 | return flow_uv_to_colors(u, v, convert_to_bgr) 136 | 137 | 138 | def viz(img: torch.Tensor = None, flo: torch.Tensor = None, num='1', savePath=None): 139 | """ 140 | 141 | :param img: tensor(N,3,H,W) or None 142 | :param flo: tensor(N,3,H,W) or None 143 | :return: None 144 | """ 145 | if img is not None: 146 | N, _, _, _ = img.shape 147 | elif flo is not None: 148 | N, _, _, _ = flo.shape 149 | else: 150 | N = 1 151 | for n in range(N): 152 | if img is not None: 153 | imgN = (img.float() - img.min()) / (img.max() - img.min()) # use a per-frame copy so the batch tensor survives later iterations 154 | # imgN[imgN > 0] = 1.0 155 | if imgN.size(1) == 1: 156 | imgN = imgN.repeat([1, 3, 1, 1]) 157 | imgN = imgN[n].permute(1, 2, 0).cpu().detach().numpy() 158 | img_flo = imgN 159 | if flo is not None: 160 | floN = flo[n].permute(1, 2, 0).cpu().detach().numpy() 161 | # map flow to rgb image 162 | floN = flow_to_image(floN) 163 | img_flo = floN 164 | 165 | if all([img is not None, flo is not None]): 166 | img_flo = np.concatenate([imgN, floN], axis=0) 167 | 168 | cv2.namedWindow(num, 0) 169 | cv2.imshow(num, img_flo) 170 | cv2.waitKey(0) 171 | if savePath is not None: 172 | cv2.imwrite(savePath, img_flo) 173 | -------------------------------------------------------------------------------- /stage2/lib/visualTool.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def make_colorwheel(): 8 | """ 9 | Generates a color wheel for optical flow visualization as presented in: 10 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 11 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 12 | 13 | Code follows the original C++ source code of Daniel Scharstein. 14 | Code follows the Matlab source code of Deqing Sun.
15 | 16 | Returns: 17 | np.ndarray: Color wheel 18 | """ 19 | 20 | RY = 15 21 | YG = 6 22 | GC = 4 23 | CB = 11 24 | BM = 13 25 | MR = 6 26 | 27 | ncols = RY + YG + GC + CB + BM + MR 28 | colorwheel = np.zeros((ncols, 3)) 29 | col = 0 30 | 31 | # RY 32 | colorwheel[0:RY, 0] = 255 33 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) 34 | col = col + RY 35 | # YG 36 | colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) 37 | colorwheel[col:col + YG, 1] = 255 38 | col = col + YG 39 | # GC 40 | colorwheel[col:col + GC, 1] = 255 41 | colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) 42 | col = col + GC 43 | # CB 44 | colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) 45 | colorwheel[col:col + CB, 2] = 255 46 | col = col + CB 47 | # BM 48 | colorwheel[col:col + BM, 2] = 255 49 | colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) 50 | col = col + BM 51 | # MR 52 | colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) 53 | colorwheel[col:col + MR, 0] = 255 54 | return colorwheel 55 | 56 | 57 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 58 | """ 59 | Applies the flow color wheel to (possibly clipped) flow components u and v. 60 | 61 | According to the C++ source code of Daniel Scharstein 62 | According to the Matlab source code of Deqing Sun 63 | 64 | Args: 65 | u (np.ndarray): Input horizontal flow of shape [H,W] 66 | v (np.ndarray): Input vertical flow of shape [H,W] 67 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 68 | 69 | Returns: 70 | np.ndarray: Flow visualization image of shape [H,W,3] 71 | """ 72 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 73 | colorwheel = make_colorwheel() # shape [55x3] 74 | ncols = colorwheel.shape[0] 75 | rad = np.sqrt(np.square(u) + np.square(v)) 76 | a = np.arctan2(-v, -u) / np.pi 77 | fk = (a + 1) / 2 * (ncols - 1) 78 | k0 = np.floor(fk).astype(np.int32) 79 | k1 = k0 + 1 80 | k1[k1 == ncols] = 0 81 | f = fk - k0 82 | for i in range(colorwheel.shape[1]): 83 | tmp = colorwheel[:, i] 84 | col0 = tmp[k0] / 255.0 85 | col1 = tmp[k1] / 255.0 86 | col = (1 - f) * col0 + f * col1 87 | idx = (rad <= 1) 88 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 89 | col[~idx] = col[~idx] * 0.75 # out of range 90 | # Note the 2-i => BGR instead of RGB 91 | ch_idx = 2 - i if convert_to_bgr else i 92 | flow_image[:, :, ch_idx] = np.floor(255 * col) 93 | return flow_image 94 | 95 | 96 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 97 | """ 98 | Expects a two dimensional flow image of shape. 99 | 100 | Args: 101 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 102 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 103 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
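Example (a minimal sketch; any float array of shape [H,W,2] works):
>>> flow = np.stack(np.meshgrid(np.linspace(-1, 1, 64), np.linspace(-1, 1, 64)), axis=2)
>>> rgb = flow_to_image(flow.astype(np.float32))  # uint8, shape [64, 64, 3]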
104 | 105 | Returns: 106 | np.ndarray: Flow visualization image of shape [H,W,3] 107 | """ 108 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 109 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 110 | if clip_flow is not None: 111 | flow_uv = np.clip(flow_uv, 0, clip_flow) 112 | u = flow_uv[:, :, 0] 113 | v = flow_uv[:, :, 1] 114 | rad = np.sqrt(np.square(u) + np.square(v)) 115 | rad_max = np.max(rad) 116 | epsilon = 1e-5 117 | u = u / (rad_max + epsilon) 118 | v = v / (rad_max + epsilon) 119 | return flow_uv_to_colors(u, v, convert_to_bgr) 120 | 121 | 122 | def viz(img: torch.Tensor = None, name='1', wait=0): 123 | img = img.cpu().detach().float()[0].permute([1,2,0]) 124 | img = (255*(img - img.min()) / (img.max() - img.min())).byte().numpy() 125 | cv2.namedWindow(name, 0) 126 | cv2.imshow(name, img) 127 | cv2.waitKey(wait) 128 | 129 | 130 | def flow2ImgBatch(flos: torch.Tensor): 131 | N, C, H, W = flos.shape 132 | output = np.zeros([N, H, W, 3]) 133 | for n in range(N): 134 | flo = flos[n, ...].permute(1, 2, 0).cpu().numpy() 135 | output[n, ...] = flow_to_image(flo) 136 | return output 137 | 138 | 139 | def mask2ImgBatch(mask: torch.Tensor): 140 | N, C, H, W = mask.shape 141 | output = np.zeros([N, H, W, 3]) 142 | mask = (mask - mask.min()) / (mask.max() - mask.min()) 143 | mask = (mask * 255).cpu().byte().numpy() 144 | for n in range(N): 145 | output[n, ...] = cv2.applyColorMap(mask[n, 0], cv2.COLORMAP_JET) 146 | return output.astype(np.uint8) 147 | 148 | 149 | def bis(input, dim, index): 150 | N, C, H, W = input.shape 151 | input = F.unfold(input, kernel_size=(24, 24), padding=8, stride=8) 152 | views = [input.size(0)] + [1 if i != dim else -1 for i in range(1, len(input.size()))] 153 | expanse = list(input.size()) 154 | expanse[0] = -1 155 | expanse[dim] = -1 156 | index = index.view(views).expand(expanse) 157 | input = torch.gather(input, dim, index) 158 | 159 | input = F.fold(input, output_size=(H, W), kernel_size=(24, 24), padding=8, stride=8) 160 | 161 | return input 162 | 163 | 164 | def softAttn(input, R4x): 165 | N, C, H, W = input.shape 166 | n = 1 167 | # R4x = torch.softmax(R4x / 0.01, dim=1) 168 | 169 | input_ubfold = F.unfold(input, kernel_size=(24, 24), padding=8, stride=8) 170 | input_ubfold = input_ubfold.view([1, 1728, 54, 96]) 171 | R4x = R4x.view(1, 54, 96, 54, 96) 172 | output = torch.zeros([1, 1728, 54, 96]) 173 | 174 | for x in range(n, output.shape[2] - n - 1): 175 | for y in range(n, output.shape[3] - n - 1): 176 | subWeight = R4x[:, x - n:x + n + 1, y - n:y + n + 1, x, y].unsqueeze(1) 177 | # subWeight = F.softmax(subWeight.contiguous().view([1, 1, -1])/0.00001, dim=-1).view([1, 1, 2*n+1, 2*n+1]) 178 | subWeight = F.softmax(subWeight.contiguous().view([1, 1, -1])/0.15, dim=-1).view( 179 | [1, 1, 2 * n + 1, 2 * n + 1]) 180 | subPatch = input_ubfold[:, :, x - n:x + n + 1, y - n:y + n + 1] 181 | patch = subWeight * subPatch 182 | 183 | weight = subWeight.sum() 184 | 185 | patch = patch.sum(2).sum(2) / weight 186 | output[:, :, x, y] = patch 187 | 188 | output = F.fold(output.view([1, 1728, -1]), output_size=(H, W), kernel_size=(24, 24), padding=8, stride=8) 189 | outPut = output.view([N, C, H, W]) 190 | viz(outPut) 191 | return outPut -------------------------------------------------------------------------------- /stage1/lib/checkTool.py: -------------------------------------------------------------------------------- 1 | import lib.imgTool_ as trans 2 | import cv2 3 | import imageio 4 | from pathlib import Path 5 | from 
lib.fileTool import mkPath 6 | import numpy as np 7 | import torch 8 | 9 | 10 | def visualContinuousFrames(imgrendListOut, gtframeTList, path): 11 | mkPath(path) 12 | totalList = [] 13 | for t, (imgrendOut, gtframeT) in enumerate(zip(imgrendListOut, gtframeTList)): 14 | imgrendOut = trans.ToCVImage(imgrendOut) 15 | 16 | gtframeT = trans.ToCVImage(gtframeT) 17 | 18 | imgrendOut = cv2.putText(imgrendOut, "rdOut{}".format(t + 1), (10, 20), 19 | cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 1, cv2.LINE_AA) 20 | 21 | gtframeT = cv2.putText(gtframeT, "gt{}".format(t + 1), (10, 20), 22 | cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 1, cv2.LINE_AA) 23 | 24 | imgcat = trans.makeGrid([gtframeT, imgrendOut], shape=[1, 2]) 25 | totalList.append(imgcat) 26 | 27 | imageio.mimsave(str(Path(path) / 'Input.gif'), totalList, 'GIF', duration=0.5) 28 | 29 | 30 | def checkGrad(net): 31 | for parem in list(net.named_parameters()): 32 | if parem[1].grad is not None: 33 | print(parem[0] + ' \t shape={}, \t mean={}, \t std={}\n'.format(parem[1].shape, 34 | parem[1].grad.abs().mean().cpu().item(), 35 | parem[1].grad.abs().std().cpu().item())) 36 | 37 | 38 | def checkImgList(imgList): 39 | cv2.namedWindow('1', 0) 40 | N, C, H, W = imgList[0].shape 41 | for n in N: 42 | for imgn in imgList: 43 | img: torch.Tensor = (imgn[n] - imgn[n].min()) / (imgn[n].max() - imgn[n].min()) 44 | img = (img.cpu().numpy() * 255).astype(np.uint8) 45 | cv2.imshow('1', img) 46 | cv2.waitKey(0) 47 | 48 | 49 | def checkIE(I0t: torch.Tensor, I1t: torch.Tensor, It: torch.Tensor, Et: torch.Tensor): 50 | # N, F, C, H, W = img.shape 51 | N, C, H, W = Et.shape 52 | 53 | cv2.namedWindow('1', 0) 54 | 55 | for n in range(N): 56 | img0t = ((I0t[n].float() - I0t[n].min()) / (I0t[n].max() - I0t[n].min()) * 255).cpu().numpy().transpose( 57 | [1, 2, 0]).astype(np.uint8) 58 | 59 | img1t = ((I1t[n].float() - I1t[n].min()) / (I1t[n].max() - I1t[n].min()) * 255).cpu().numpy().transpose( 60 | [1, 2, 0]).astype(np.uint8) 61 | I = ((It[n].float() - It[n].min()) / (It[n].max() - It[n].min()) * 255).cpu().numpy().transpose( 62 | [1, 2, 0]).astype(np.uint8) 63 | cv2.imshow('1', np.concatenate([img0t, img1t], axis=1)) 64 | cv2.waitKey(100) 65 | cv2.imshow('1', np.concatenate([I, I], axis=1)) 66 | cv2.waitKey(100) 67 | 68 | E = Et[n].cpu().numpy().astype(np.float32) 69 | 70 | for eIdx, p in enumerate(E): 71 | eventImg = p 72 | eventImg = ((eventImg - eventImg.min()) / (eventImg.max() - eventImg.min()) * 255.0).astype( 73 | np.uint8) 74 | eventImg = cv2.cvtColor(eventImg, cv2.COLOR_GRAY2BGR) 75 | 76 | img = I.copy() 77 | 78 | img[:, :, 0][p != 0] = 0 79 | 80 | img[:, :, 2][p > 0] = 255 81 | img[:, :, 1][p < 0] = 255 82 | 83 | cv2.putText(img, '{}'.format(eIdx), (20, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.8, 84 | (0, 0, 255), 85 | 5, 86 | cv2.LINE_AA) 87 | 88 | cv2.imshow('1', np.concatenate([img.astype(np.uint8), eventImg], axis=1)) 89 | cv2.waitKey(100) 90 | 91 | 92 | def checkValInput(IV, IT, ET): 93 | N, C, H, W = IV[0].shape 94 | for n in range(N): 95 | for Iv in IV: 96 | I = ((Iv[n].float() - Iv[n].min()) / (Iv[n].max() - Iv[n].min()) * 255).cpu().numpy().transpose( 97 | [1, 2, 0]).astype(np.uint8) 98 | cv2.imshow('1', I) 99 | cv2.waitKey(0) 100 | for It in IT: 101 | I = ((It[n].float() - It[n].min()) / (It[n].max() - It[n].min()) * 255).cpu().numpy().transpose( 102 | [1, 2, 0]).astype(np.uint8) 103 | cv2.imshow('1', I) 104 | cv2.waitKey(0) 105 | 106 | E = [i[n].cpu().numpy().astype(np.float32) for i in ET] 107 | for Idx, Img in enumerate(IT): 108 | Img = (Img[n] 
* 127.5 + 127.5).cpu().numpy().astype(np.uint8).transpose([1, 2, 0]) 109 | Et = E[Idx] 110 | Ce, He, We = Et.shape 111 | for eIdx in range(Ce): 112 | p = Et[eIdx, ...] 113 | 114 | pPos = np.zeros_like(p) 115 | pPos[p > 0] = 255 116 | 117 | pNeg = np.zeros_like(p) 118 | pNeg[p < 0] = 255 119 | 120 | eventImg = np.clip( 121 | np.abs(np.stack([np.zeros_like(Img)[:, :, 0], pPos, pNeg], axis=2)) 122 | .astype(np.float32) * 255, 0, 255).astype(np.uint8) 123 | 124 | img = cv2.cvtColor(Img.copy(), cv2.COLOR_GRAY2BGR) 125 | img[..., 0][p != 0] = 0 126 | img[..., 0][p != 0] = 0 127 | img[eventImg != 0] = eventImg[eventImg != 0] 128 | 129 | cv2.putText(img, '{}_{}'.format(1, eIdx), (20, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.8, 130 | (0, 0, 255), 131 | 5, 132 | cv2.LINE_AA) 133 | 134 | cv2.imshow('1', np.concatenate([img.astype(np.uint8), eventImg], axis=1)) 135 | cv2.waitKey(500) 136 | 137 | 138 | def checkTrainInput(I0ts, I1ts, Its, Ets): 139 | cv2.namedWindow('1', 0) 140 | N, C, H, W = I0ts.shape 141 | for n in range(N): 142 | I0t = ((I0ts[n] + 1) * 127.5).cpu().numpy().transpose([1, 2, 0]).astype(np.uint8) 143 | I1t = ((I1ts[n] + 1) * 127.5).cpu().numpy().transpose([1, 2, 0]).astype(np.uint8) 144 | It = ((Its[n] + 1) * 127.5).cpu().numpy().transpose([1, 2, 0]).astype(np.uint8) 145 | 146 | print('check I0t, I1t') 147 | cv2.imshow('1', np.concatenate([I1t, I0t], axis=1)) 148 | cv2.waitKey(0) 149 | 150 | cv2.imshow('1', np.concatenate([It, It], axis=1)) 151 | cv2.waitKey(0) 152 | 153 | Et = Ets[n].cpu().numpy().astype(np.float32) 154 | 155 | for eIdx in range(8): 156 | Img = It.copy() 157 | pPos = Et[eIdx, ...] > 0 158 | pNeg = Et[eIdx, ...] < 0 159 | 160 | eventImg = np.clip( 161 | np.abs(np.stack([np.zeros_like(Img)[:, :, 0], pPos, pNeg], axis=2)) 162 | .astype(np.float32) * 255, 0, 255).astype(np.uint8) 163 | 164 | img = cv2.cvtColor(Img.copy(), cv2.COLOR_GRAY2BGR) 165 | img[..., 0][pPos != 0] = 0 166 | img[..., 0][pNeg != 0] = 0 167 | img[eventImg != 0] = eventImg[eventImg != 0] 168 | 169 | cv2.putText(img, '{}_{}'.format(1, eIdx), (20, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.8, 170 | (0, 0, 255), 171 | 5, 172 | cv2.LINE_AA) 173 | 174 | cv2.imshow('1', np.concatenate([img.astype(np.uint8), eventImg], axis=1)) 175 | cv2.waitKey(500) 176 | -------------------------------------------------------------------------------- /DVSTool/lib/pwcNet/correlation_pytorch1_1/correlation_cuda.cc: -------------------------------------------------------------------------------- 1 | #include <torch/torch.h> 2 | #include <ATen/ATen.h> 3 | #include <ATen/Context.h> 4 | #include <stdio.h> 5 | #include <ATen/cuda/CUDAContext.h> //works for 1.0.0 6 | #include "correlation_cuda_kernel.cuh" 7 | 8 | int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output, 9 | int pad_size, 10 | int kernel_size, 11 | int max_displacement, 12 | int stride1, 13 | int stride2, 14 | int corr_type_multiply) 15 | { 16 | 17 | int batchSize = input1.size(0); 18 | 19 | int nInputChannels = input1.size(1); 20 | int inputHeight = input1.size(2); 21 | int inputWidth = input1.size(3); 22 | 23 | int kernel_radius = (kernel_size - 1) / 2; 24 | int border_radius = kernel_radius + max_displacement; 25 | 26 | int paddedInputHeight = inputHeight + 2 * pad_size; 27 | int paddedInputWidth = inputWidth + 2 * pad_size; 28 | 29 | int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1); 30 | 31 | int outputHeight = ceil(static_cast<float>(paddedInputHeight - 2 * border_radius) / static_cast<float>(stride1)); 32 | int outputwidth =
ceil(static_cast<float>(paddedInputWidth - 2 * border_radius) / static_cast<float>(stride1)); 33 | 34 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 35 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 36 | output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth}); 37 | 38 | rInput1.fill_(0); 39 | rInput2.fill_(0); 40 | output.fill_(0); 41 | 42 | int success = correlation_forward_cuda_kernel( 43 | output, 44 | output.size(0), 45 | output.size(1), 46 | output.size(2), 47 | output.size(3), 48 | output.stride(0), 49 | output.stride(1), 50 | output.stride(2), 51 | output.stride(3), 52 | input1, 53 | input1.size(1), 54 | input1.size(2), 55 | input1.size(3), 56 | input1.stride(0), 57 | input1.stride(1), 58 | input1.stride(2), 59 | input1.stride(3), 60 | input2, 61 | input2.size(1), 62 | input2.stride(0), 63 | input2.stride(1), 64 | input2.stride(2), 65 | input2.stride(3), 66 | rInput1, 67 | rInput2, 68 | pad_size, 69 | kernel_size, 70 | max_displacement, 71 | stride1, 72 | stride2, 73 | corr_type_multiply, 74 | // at::globalContext().getCurrentCUDAStream() //works for 0.4.1 75 | at::cuda::getCurrentCUDAStream() //works for 1.0.0 76 | ); 77 | 78 | //check for errors 79 | if (!success) { 80 | AT_ERROR("CUDA call failed"); 81 | } 82 | 83 | return 1; 84 | 85 | } 86 | 87 | int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput, 88 | at::Tensor& gradInput1, at::Tensor& gradInput2, 89 | int pad_size, 90 | int kernel_size, 91 | int max_displacement, 92 | int stride1, 93 | int stride2, 94 | int corr_type_multiply) 95 | { 96 | 97 | int batchSize = input1.size(0); 98 | int nInputChannels = input1.size(1); 99 | int paddedInputHeight = input1.size(2)+ 2 * pad_size; 100 | int paddedInputWidth = input1.size(3)+ 2 * pad_size; 101 | 102 | int height = input1.size(2); 103 | int width = input1.size(3); 104 | 105 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 106 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 107 | gradInput1.resize_({batchSize, nInputChannels, height, width}); 108 | gradInput2.resize_({batchSize, nInputChannels, height, width}); 109 | 110 | rInput1.fill_(0); 111 | rInput2.fill_(0); 112 | gradInput1.fill_(0); 113 | gradInput2.fill_(0); 114 | 115 | int success = correlation_backward_cuda_kernel(gradOutput, 116 | gradOutput.size(0), 117 | gradOutput.size(1), 118 | gradOutput.size(2), 119 | gradOutput.size(3), 120 | gradOutput.stride(0), 121 | gradOutput.stride(1), 122 | gradOutput.stride(2), 123 | gradOutput.stride(3), 124 | input1, 125 | input1.size(1), 126 | input1.size(2), 127 | input1.size(3), 128 | input1.stride(0), 129 | input1.stride(1), 130 | input1.stride(2), 131 | input1.stride(3), 132 | input2, 133 | input2.stride(0), 134 | input2.stride(1), 135 | input2.stride(2), 136 | input2.stride(3), 137 | gradInput1, 138 | gradInput1.stride(0), 139 | gradInput1.stride(1), 140 | gradInput1.stride(2), 141 | gradInput1.stride(3), 142 | gradInput2, 143 | gradInput2.size(1), 144 | gradInput2.stride(0), 145 | gradInput2.stride(1), 146 | gradInput2.stride(2), 147 | gradInput2.stride(3), 148 | rInput1, 149 | rInput2, 150 | pad_size, 151 | kernel_size, 152 | max_displacement, 153 | stride1, 154 | stride2, 155 | corr_type_multiply, 156 | // at::globalContext().getCurrentCUDAStream() //works for 0.4.1 157 | at::cuda::getCurrentCUDAStream() //works for 1.0.0 158 | ); 159 | 160 | if
(!success) { 161 | AT_ERROR("CUDA call failed"); 162 | } 163 | 164 | return 1; 165 | } 166 | 167 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 168 | m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)"); 169 | m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)"); 170 | } 171 | 172 | -------------------------------------------------------------------------------- /stage1/lib/pwcNet/correlation_pytorch1_1/correlation_cuda.cc: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <ATen/ATen.h> 3 | #include <ATen/Context.h> 4 | #include <stdio.h> 5 | #include <ATen/cuda/CUDAContext.h> //works for 1.0.0 6 | #include "correlation_cuda_kernel.cuh" 7 | 8 | int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output, 9 | int pad_size, 10 | int kernel_size, 11 | int max_displacement, 12 | int stride1, 13 | int stride2, 14 | int corr_type_multiply) 15 | { 16 | 17 | int batchSize = input1.size(0); 18 | 19 | int nInputChannels = input1.size(1); 20 | int inputHeight = input1.size(2); 21 | int inputWidth = input1.size(3); 22 | 23 | int kernel_radius = (kernel_size - 1) / 2; 24 | int border_radius = kernel_radius + max_displacement; 25 | 26 | int paddedInputHeight = inputHeight + 2 * pad_size; 27 | int paddedInputWidth = inputWidth + 2 * pad_size; 28 | 29 | int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1); 30 | 31 | int outputHeight = ceil(static_cast<float>(paddedInputHeight - 2 * border_radius) / static_cast<float>(stride1)); 32 | int outputwidth = ceil(static_cast<float>(paddedInputWidth - 2 * border_radius) / static_cast<float>(stride1)); 33 | 34 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 35 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 36 | output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth}); 37 | 38 | rInput1.fill_(0); 39 | rInput2.fill_(0); 40 | output.fill_(0); 41 | 42 | int success = correlation_forward_cuda_kernel( 43 | output, 44 | output.size(0), 45 | output.size(1), 46 | output.size(2), 47 | output.size(3), 48 | output.stride(0), 49 | output.stride(1), 50 | output.stride(2), 51 | output.stride(3), 52 | input1, 53 | input1.size(1), 54 | input1.size(2), 55 | input1.size(3), 56 | input1.stride(0), 57 | input1.stride(1), 58 | input1.stride(2), 59 | input1.stride(3), 60 | input2, 61 | input2.size(1), 62 | input2.stride(0), 63 | input2.stride(1), 64 | input2.stride(2), 65 | input2.stride(3), 66 | rInput1, 67 | rInput2, 68 | pad_size, 69 | kernel_size, 70 | max_displacement, 71 | stride1, 72 | stride2, 73 | corr_type_multiply, 74 | // at::globalContext().getCurrentCUDAStream() //works for 0.4.1 75 | at::cuda::getCurrentCUDAStream() //works for 1.0.0 76 | ); 77 | 78 | //check for errors 79 | if (!success) { 80 | AT_ERROR("CUDA call failed"); 81 | } 82 | 83 | return 1; 84 | 85 | } 86 | 87 | int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput, 88 | at::Tensor& gradInput1, at::Tensor& gradInput2, 89 | int pad_size, 90 | int kernel_size, 91 | int max_displacement, 92 | int stride1, 93 | int stride2, 94 | int corr_type_multiply) 95 | { 96 | 97 | int batchSize = input1.size(0); 98 | int nInputChannels = input1.size(1); 99 | int paddedInputHeight = input1.size(2)+ 2 * pad_size; 100 | int paddedInputWidth = input1.size(3)+ 2 * pad_size; 101 | 102 | int height = input1.size(2); 103 | int width = input1.size(3); 104
| 105 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 106 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 107 | gradInput1.resize_({batchSize, nInputChannels, height, width}); 108 | gradInput2.resize_({batchSize, nInputChannels, height, width}); 109 | 110 | rInput1.fill_(0); 111 | rInput2.fill_(0); 112 | gradInput1.fill_(0); 113 | gradInput2.fill_(0); 114 | 115 | int success = correlation_backward_cuda_kernel(gradOutput, 116 | gradOutput.size(0), 117 | gradOutput.size(1), 118 | gradOutput.size(2), 119 | gradOutput.size(3), 120 | gradOutput.stride(0), 121 | gradOutput.stride(1), 122 | gradOutput.stride(2), 123 | gradOutput.stride(3), 124 | input1, 125 | input1.size(1), 126 | input1.size(2), 127 | input1.size(3), 128 | input1.stride(0), 129 | input1.stride(1), 130 | input1.stride(2), 131 | input1.stride(3), 132 | input2, 133 | input2.stride(0), 134 | input2.stride(1), 135 | input2.stride(2), 136 | input2.stride(3), 137 | gradInput1, 138 | gradInput1.stride(0), 139 | gradInput1.stride(1), 140 | gradInput1.stride(2), 141 | gradInput1.stride(3), 142 | gradInput2, 143 | gradInput2.size(1), 144 | gradInput2.stride(0), 145 | gradInput2.stride(1), 146 | gradInput2.stride(2), 147 | gradInput2.stride(3), 148 | rInput1, 149 | rInput2, 150 | pad_size, 151 | kernel_size, 152 | max_displacement, 153 | stride1, 154 | stride2, 155 | corr_type_multiply, 156 | // at::globalContext().getCurrentCUDAStream() //works for 0.4.1 157 | at::cuda::getCurrentCUDAStream() //works for 1.0.0 158 | ); 159 | 160 | if (!success) { 161 | AT_ERROR("CUDA call failed"); 162 | } 163 | 164 | return 1; 165 | } 166 | 167 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 168 | m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)"); 169 | m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)"); 170 | } 171 | 172 | -------------------------------------------------------------------------------- /stage2/lib/pwcNet/correlation_pytorch1_1/correlation_cuda.cc: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <ATen/ATen.h> 3 | #include <ATen/Context.h> 4 | #include <stdio.h> 5 | #include <ATen/cuda/CUDAContext.h> //works for 1.0.0 6 | #include "correlation_cuda_kernel.cuh" 7 | 8 | int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output, 9 | int pad_size, 10 | int kernel_size, 11 | int max_displacement, 12 | int stride1, 13 | int stride2, 14 | int corr_type_multiply) 15 | { 16 | 17 | int batchSize = input1.size(0); 18 | 19 | int nInputChannels = input1.size(1); 20 | int inputHeight = input1.size(2); 21 | int inputWidth = input1.size(3); 22 | 23 | int kernel_radius = (kernel_size - 1) / 2; 24 | int border_radius = kernel_radius + max_displacement; 25 | 26 | int paddedInputHeight = inputHeight + 2 * pad_size; 27 | int paddedInputWidth = inputWidth + 2 * pad_size; 28 | 29 | int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1); 30 | 31 | int outputHeight = ceil(static_cast<float>(paddedInputHeight - 2 * border_radius) / static_cast<float>(stride1)); 32 | int outputwidth = ceil(static_cast<float>(paddedInputWidth - 2 * border_radius) / static_cast<float>(stride1)); 33 | 34 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 35 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 36 | output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth}); 37 | 38 | rInput1.fill_(0); 39 |
rInput2.fill_(0); 40 | output.fill_(0); 41 | 42 | int success = correlation_forward_cuda_kernel( 43 | output, 44 | output.size(0), 45 | output.size(1), 46 | output.size(2), 47 | output.size(3), 48 | output.stride(0), 49 | output.stride(1), 50 | output.stride(2), 51 | output.stride(3), 52 | input1, 53 | input1.size(1), 54 | input1.size(2), 55 | input1.size(3), 56 | input1.stride(0), 57 | input1.stride(1), 58 | input1.stride(2), 59 | input1.stride(3), 60 | input2, 61 | input2.size(1), 62 | input2.stride(0), 63 | input2.stride(1), 64 | input2.stride(2), 65 | input2.stride(3), 66 | rInput1, 67 | rInput2, 68 | pad_size, 69 | kernel_size, 70 | max_displacement, 71 | stride1, 72 | stride2, 73 | corr_type_multiply, 74 | // at::globalContext().getCurrentCUDAStream() //works for 0.4.1 75 | at::cuda::getCurrentCUDAStream() //works for 1.0.0 76 | ); 77 | 78 | //check for errors 79 | if (!success) { 80 | AT_ERROR("CUDA call failed"); 81 | } 82 | 83 | return 1; 84 | 85 | } 86 | 87 | int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput, 88 | at::Tensor& gradInput1, at::Tensor& gradInput2, 89 | int pad_size, 90 | int kernel_size, 91 | int max_displacement, 92 | int stride1, 93 | int stride2, 94 | int corr_type_multiply) 95 | { 96 | 97 | int batchSize = input1.size(0); 98 | int nInputChannels = input1.size(1); 99 | int paddedInputHeight = input1.size(2)+ 2 * pad_size; 100 | int paddedInputWidth = input1.size(3)+ 2 * pad_size; 101 | 102 | int height = input1.size(2); 103 | int width = input1.size(3); 104 | 105 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 106 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 107 | gradInput1.resize_({batchSize, nInputChannels, height, width}); 108 | gradInput2.resize_({batchSize, nInputChannels, height, width}); 109 | 110 | rInput1.fill_(0); 111 | rInput2.fill_(0); 112 | gradInput1.fill_(0); 113 | gradInput2.fill_(0); 114 | 115 | int success = correlation_backward_cuda_kernel(gradOutput, 116 | gradOutput.size(0), 117 | gradOutput.size(1), 118 | gradOutput.size(2), 119 | gradOutput.size(3), 120 | gradOutput.stride(0), 121 | gradOutput.stride(1), 122 | gradOutput.stride(2), 123 | gradOutput.stride(3), 124 | input1, 125 | input1.size(1), 126 | input1.size(2), 127 | input1.size(3), 128 | input1.stride(0), 129 | input1.stride(1), 130 | input1.stride(2), 131 | input1.stride(3), 132 | input2, 133 | input2.stride(0), 134 | input2.stride(1), 135 | input2.stride(2), 136 | input2.stride(3), 137 | gradInput1, 138 | gradInput1.stride(0), 139 | gradInput1.stride(1), 140 | gradInput1.stride(2), 141 | gradInput1.stride(3), 142 | gradInput2, 143 | gradInput2.size(1), 144 | gradInput2.stride(0), 145 | gradInput2.stride(1), 146 | gradInput2.stride(2), 147 | gradInput2.stride(3), 148 | rInput1, 149 | rInput2, 150 | pad_size, 151 | kernel_size, 152 | max_displacement, 153 | stride1, 154 | stride2, 155 | corr_type_multiply, 156 | // at::globalContext().getCurrentCUDAStream() //works for 0.4.1 157 | at::cuda::getCurrentCUDAStream() //works for 1.0.0 158 | ); 159 | 160 | if (!success) { 161 | AT_ERROR("CUDA call failed"); 162 | } 163 | 164 | return 1; 165 | } 166 | 167 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 168 | m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)"); 169 | m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)"); 170 | } 171 | 172 | 
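Editor's note: the Python-side autograd wrapper that drives this extension lives elsewhere in lib/pwcNet and is not reproduced in this dump. For orientation only, below is a minimal sketch of such a wrapper, assuming the compiled extension is importable as correlation_cuda (the name bound by TORCH_EXTENSION_NAME at build time); the class name, parameter defaults, and buffer names are illustrative assumptions, not the repository's own code.

import torch
from torch.autograd import Function

import correlation_cuda  # hypothetical import name; the module bound by PYBIND11_MODULE above


class CorrelationFunction(Function):
    # Sketch only: defaults (pad_size=4, max_displacement=4, ...) are assumptions.
    @staticmethod
    def forward(ctx, input1, input2, pad_size=4, kernel_size=1,
                max_displacement=4, stride1=1, stride2=1, corr_multiply=1):
        ctx.save_for_backward(input1, input2)
        ctx.params = (pad_size, kernel_size, max_displacement, stride1, stride2, corr_multiply)
        with torch.cuda.device_of(input1):
            # empty scratch tensors; the C++ side resize_()s and zero-fills them
            rbot1 = input1.new()
            rbot2 = input2.new()
            output = input1.new()
            correlation_cuda.forward(input1, input2, rbot1, rbot2, output,
                                     pad_size, kernel_size, max_displacement,
                                     stride1, stride2, corr_multiply)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input1, input2 = ctx.saved_tensors
        pad_size, kernel_size, max_displacement, stride1, stride2, corr_multiply = ctx.params
        with torch.cuda.device_of(input1):
            rbot1 = input1.new()
            rbot2 = input2.new()
            grad_input1 = input1.new()
            grad_input2 = input2.new()
            correlation_cuda.backward(input1, input2, rbot1, rbot2,
                                      grad_output.contiguous(),
                                      grad_input1, grad_input2,
                                      pad_size, kernel_size, max_displacement,
                                      stride1, stride2, corr_multiply)
        # one return value per forward argument; the integer parameters get None
        return grad_input1, grad_input2, None, None, None, None, None, None

At the call site this reduces to output = CorrelationFunction.apply(feat1, feat2) on contiguous CUDA tensors.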
-------------------------------------------------------------------------------- /stage2/configs/configTest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import torch 4 | from torch.backends import cudnn 5 | from datetime import datetime 6 | from lib import fileTool as FT 7 | from pathlib import Path 8 | from tensorboardX import SummaryWriter 9 | 10 | 11 | class Config(object): 12 | def __init__(self, gpuList=None, envDistributed=False, expName='gopro_gn_Test'): 13 | super(Config, self).__init__() 14 | 15 | # always change--------------------------------------------------------------------------------------- 16 | self.a_name = expName 17 | 18 | self.step = 2 19 | self.snapShot = 20 20 | self.pathOut = './output/' 21 | self.pathExp = 'New_gn__202101161228' 22 | self.pathWeight = './output/New_gn__202101161228/state/bestModel_epoch400.pth' 23 | 24 | self.trainSize = (432, 768) 25 | # self.trainSize = (255, 255) 26 | self.testScale = 0.5 27 | self.valNumInter = 9 28 | 29 | # self.lrInit = 5e-4 30 | self.lrInit = 2e-5 31 | self.trainEpoch = 5000 32 | self.trainBatchPerGPU = 1 33 | self.lrGamma = 0.999 # decay scale for exp scheduler 34 | 35 | self.trainMean = 0 36 | self.trainStd = 1 37 | # network param------------------------------------------------------------------------------------------- 38 | self.netActivate = 'leakyrelu' # prelu, relu, leakyrelu, swish 39 | if '_bn_' in self.a_name: 40 | self.netNorm = 'bn' # instance, lrn, bn, identity, group 41 | if '_gn_' in self.a_name: 42 | self.netNorm = 'group' # instance, lrn, bn, identity, group 43 | 44 | # self.pathTestEvent = str(Path(__file__).parents[1] / Path('dataset/event/testNew')) 45 | self.pathTestEvent = '/home/sensetime/data/VideoInterpolation/highfps/goPro/240fps/GoPro_public/event/testNew/' 46 | self.pathInference = '/home/sensetime/data/VideoInterpolation/highfps/goPro/240fps/GoPro_public/event/Inference/' 47 | self.envUseApex = False 48 | self.envApexLevel = 'O0' 49 | 50 | # hardware environment----------------------------------------------------------------------------------- 51 | self.envDistributed = envDistributed 52 | self.envnodeName = self.getEnviron('SLURMD_NODENAME', 'SingleNode') 53 | self.envWorldSize = self.getEnviron('SLURM_NTASKS', 1) 54 | self.envNodeID = self.getEnviron('SLURM_NODEID', 0) 55 | self.envRank = self.getEnviron('SLURM_PROCID', 0) 56 | self.envLocalRank = self.getEnviron('SLURM_LOCALID', 0) 57 | 58 | if gpuList is not None and self.envDistributed: 59 | self.gpulist = gpuList[self.envnodeName] 60 | os.environ['CUDA_VISIBLE_DEVICES'] = self.gpulist 61 | 62 | self.envNumGPUs = torch.cuda.device_count() 63 | if self.envNumGPUs > 0: 64 | assert (torch.cuda.is_available()) and cudnn.enabled 65 | 66 | self.envParallel = True if (self.envNumGPUs > 2 and not self.envDistributed) else False 67 | self.testBatchPerGPU = 1 68 | self.netInitType = 'xavier' # normal, xavier, orthogonal, kaiming,default 69 | 70 | self.netInitGain = 0.2 71 | if self.envDistributed: 72 | self.testBatch = self.testBatchPerGPU * self.envWorldSize 73 | elif self.envParallel: 74 | self.testBatch = self.testBatchPerGPU * self.envNumGPUs 75 | else: # CPU 76 | self.testBatch = self.testBatchPerGPU 77 | 78 | # self.netCheck = True 79 | self.netCheck = False 80 | 81 | # ---------------------------------------------------------------------------------------------------- 82 | 83 | self.setRandSeed = 2020 84 | 85 | # path and 
init------------------------------------------------------------------------------------------ 86 | self.envWorkers = 4 if self.envDistributed else 0 87 | self.pathExp, self.pathEvents, self.pathState, self.pathGif = self.expInit() 88 | 89 | # train config------------------------------------------------------------------------------------------ 90 | if self.envRank == 0: 91 | self.trainLogger = self.logInit() 92 | 93 | if self.step in [1, 2]: 94 | checkpoints = torch.load(self.pathWeight, map_location=torch.device('cpu')) 95 | try: 96 | totalIter = checkpoints['totalIter'] 97 | except Exception as e: 98 | totalIter = 0 99 | self.trainWriter = SummaryWriter(self.pathEvents, purge_step=totalIter) 100 | else: 101 | self.trainWriter = SummaryWriter(self.pathEvents) 102 | 103 | self.trainMaxSave = 2 104 | 105 | # finally ------------------------------------------------------------------------------------------------ 106 | if 0 == self.envRank: 107 | self.record() 108 | 109 | def set_gpu_ids(self): 110 | str_ids = self.gpu_ids 111 | self.gpu_ids = [] 112 | for str_id in str_ids: 113 | id = int(str_id) 114 | if id >= 0: 115 | self.gpu_ids.append(id) 116 | if len(self.gpu_ids) > 0: 117 | torch.cuda.set_device(self.gpu_ids[0]) 118 | 119 | def expInit(self): 120 | if self.step not in [1, 2]: 121 | # if self.envRank == 0: 122 | now = datetime.now().strftime("%Y%m%d%H%M") 123 | pathExp = str(Path(self.pathOut) / '{}_{}'.format(self.a_name, now)) 124 | pathEvents = str(Path(pathExp) / 'events') 125 | pathState = str(Path(pathExp) / 'state') 126 | pathGif = str(Path(pathExp) / 'gif') 127 | 128 | FT.mkPath(pathExp) 129 | FT.mkPath(pathEvents) 130 | FT.mkPath(pathState) 131 | FT.mkPath(pathGif) 132 | else: 133 | pathExp = str(Path(self.pathOut) / self.pathExp) 134 | assert Path(pathExp).is_dir(), pathExp 135 | pathEvents = str(Path(pathExp) / 'events') 136 | pathState = str(Path(pathExp) / 'state') 137 | pathGif = str(Path(pathExp) / 'gif') 138 | return pathExp, pathEvents, pathState, pathGif 139 | 140 | def logInit(self): 141 | logger = logging.getLogger(__name__) 142 | logfile = str(Path(self.pathExp) / 'log.txt') 143 | fh = logging.FileHandler(logfile, mode='a') 144 | formatter = logging.Formatter("%(asctime)s: %(message)s") 145 | fh.setFormatter(formatter) 146 | logger.addHandler(fh) 147 | return logger 148 | 149 | def getEnviron(self, key: str, default): 150 | 151 | out = os.environ[key] if all([key in os.environ, self.envDistributed]) else default 152 | if isinstance(default, int): 153 | out = int(out) 154 | elif isinstance(default, str): 155 | out = str(out) 156 | return out 157 | 158 | def record(self): 159 | logger_ = logging.getLogger(__name__ + 'sub') 160 | logging.basicConfig(level=logging.INFO) 161 | path = os.path.join(self.pathExp, 'config.txt') 162 | fh = logging.FileHandler(path, mode='w') 163 | formatter = logging.Formatter("%(message)s") 164 | fh.setFormatter(formatter) 165 | logger_.addHandler(fh) 166 | 167 | args = vars(self) 168 | logger_.info('-------------Config-------------------') 169 | for k, v in sorted(args.items()): 170 | logger_.info('{} = {}'.format(k, v)) 171 | logger_.info('--------------End---------------------') 172 | -------------------------------------------------------------------------------- /stage1/configs/configEVI.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import torch 4 | from torch.backends import cudnn 5 | from datetime import datetime 6 | from lib import fileTool as FT 7 | from 
pathlib import Path 8 | from tensorboardX import SummaryWriter 9 | 10 | 11 | class Config(object): 12 | def __init__(self, gpuList=None, envDistributed=False, expName='gopro_gn_'): 13 | super(Config, self).__init__() 14 | 15 | # always change--------------------------------------------------------------------------------------- 16 | self.a_name = expName 17 | 18 | # step=0: train EVI from scratch 19 | # lr=5e-4 20 | 21 | # step=1: resume EVI training if the job was killed 22 | # lr=5e-5 23 | 24 | self.step = 0 25 | self.snapShot = 10 26 | 27 | self.pathOut = './output/' # the output path to save checkpoints 28 | self.pathExp = 'Demo_train_on_lowfps_202111201639' # used for step=1: dir name of saved checkpoints 29 | # used for step=1: filename of the saved weights to resume from 30 | self.pathWeight = './output/Demo_train_on_lowfps_202111201639/state/bestEVI_epoch100.pth' 31 | 32 | self.trainSize = (176, 240) # crop size of training data 33 | 34 | self.valScale = 1 # deprecated 35 | self.valNumInter = 3 # number of frames to be interpolated 36 | 37 | self.netCheck = False # check gradients, visualize intermediate results 38 | 39 | self.lrInit = 5e-4 40 | # self.lrInit = 5e-5 41 | self.trainEpoch = 200 42 | # self.trainBatchPerGPU = 8 43 | self.trainBatchPerGPU = 2 # batch per GPU; total_batch = batch_per_gpu * num_of_gpus 44 | self.lrGamma = 0.999 # decay scale for exp scheduler 45 | 46 | # network param------------------------------------------------------------------------------------------- 47 | self.netActivate = 'leakyrelu' # prelu, relu, leakyrelu, swish 48 | self.netNorm = 'group' # instance, lrn, bn, identity, group 49 | 50 | self.netInitType = 'xavier' # normal, xavier, orthogonal, kaiming, default 51 | 52 | self.netInitGain = 0.2 53 | self.pathTrainEvent = str(Path(__file__).parents[2] / 'dataset/fastDVS_dataset/train') 54 | 55 | # optimizer----------------------------------------------------------------------------------------------- 56 | 57 | self.optPolicy = 'Adam' # adam, sgd 58 | self.optBetas = [0.9, 0.999] 59 | self.optDecay = 0 60 | self.optMomentum = 0.995 61 | 62 | self.lrPolicy = 'exp' # step, multistep, cosine, plateau, exp 63 | self.lrdecayIter = 100 64 | self.lrMilestones = [100, 150] 65 | 66 | self.trainMean = 0 67 | self.trainStd = 1 68 | 69 | self.envUseApex = False 70 | self.envApexLevel = 'O0' 71 | 72 | # hardware environment----------------------------------------------------------------------------------- 73 | self.envDistributed = envDistributed 74 | self.envnodeName = self.getEnviron('SLURMD_NODENAME', 'SingleNode') 75 | self.envWorldSize = self.getEnviron('SLURM_NTASKS', 1) 76 | self.envNodeID = self.getEnviron('SLURM_NODEID', 0) 77 | self.envRank = self.getEnviron('SLURM_PROCID', 0) 78 | self.envLocalRank = self.getEnviron('SLURM_LOCALID', 0) 79 | 80 | self.envNumGPUs = torch.cuda.device_count() 81 | if self.envNumGPUs > 0: 82 | assert (torch.cuda.is_available()) and cudnn.enabled 83 | 84 | self.envParallel = True if (self.envNumGPUs > 2 and not self.envDistributed) else False 85 | 86 | # self.testBatchPerGPU = self.trainBatchPerGPU 87 | self.testBatchPerGPU = 2 * self.trainBatchPerGPU 88 | 89 | if self.envDistributed: 90 | self.trainBatch = self.trainBatchPerGPU * self.envWorldSize 91 | self.testBatch = self.testBatchPerGPU * self.envWorldSize 92 | elif self.envParallel: 93 | self.trainBatch = self.trainBatchPerGPU * self.envNumGPUs 94 | self.testBatch = self.testBatchPerGPU * self.envNumGPUs 95 | # self.testBatch = self.testBatchPerGPU 96 | else: # CPU 97 | self.trainBatch =
self.trainBatchPerGPU 98 | self.testBatch = self.testBatchPerGPU 99 | self.trainSize = (176, 240) 100 | self.valScale = 1 101 | 102 | # self.netCheck = True 103 | self.netCheck = False 104 | self.trainVisual = False 105 | 106 | # ---------------------------------------------------------------------------------------------------- 107 | 108 | self.setRandSeed = 2021 109 | 110 | # path and init------------------------------------------------------------------------------------------ 111 | self.envWorkers = 4 if self.envDistributed else 0 112 | self.pathExp, self.pathEvents, self.pathState, self.pathGif = self.expInit() 113 | 114 | # train config------------------------------------------------------------------------------------------ 115 | if self.envRank == 0: 116 | self.trainLogger = self.logInit() 117 | 118 | if self.step in [1, 2]: 119 | checkpoints = torch.load(self.pathWeight, map_location=torch.device('cpu')) 120 | try: 121 | totalIter = checkpoints['totalIter'] 122 | except Exception as e: 123 | totalIter = 0 124 | self.trainWriter = SummaryWriter(self.pathEvents, purge_step=totalIter) 125 | else: 126 | self.trainWriter = SummaryWriter(self.pathEvents) 127 | 128 | self.trainMaxSave = 4 # max checkpoints to save 129 | 130 | # finally ------------------------------------------------------------------------------------------------ 131 | if 0 == self.envRank: 132 | self.record() 133 | 134 | def set_gpu_ids(self): 135 | str_ids = self.gpu_ids 136 | self.gpu_ids = [] 137 | for str_id in str_ids: 138 | id = int(str_id) 139 | if id >= 0: 140 | self.gpu_ids.append(id) 141 | if len(self.gpu_ids) > 0: 142 | torch.cuda.set_device(self.gpu_ids[0]) 143 | 144 | def expInit(self): 145 | if 0 == self.step: 146 | # if self.envRank == 0: 147 | now = datetime.now().strftime("%Y%m%d%H%M") 148 | pathExp = str(Path(self.pathOut) / '{}_{}'.format(self.a_name, now)) 149 | pathEvents = str(Path(pathExp) / 'events') 150 | pathState = str(Path(pathExp) / 'state') 151 | pathGif = str(Path(pathExp) / 'gif') 152 | 153 | FT.mkPath(pathExp) 154 | FT.mkPath(pathEvents) 155 | FT.mkPath(pathState) 156 | FT.mkPath(pathGif) 157 | else: 158 | pathExp = str(Path(self.pathOut) / self.pathExp) 159 | assert Path(pathExp).is_dir(), pathExp 160 | pathEvents = str(Path(pathExp) / 'events') 161 | pathState = str(Path(pathExp) / 'state') 162 | pathGif = str(Path(pathExp) / 'gif') 163 | return pathExp, pathEvents, pathState, pathGif 164 | 165 | def logInit(self): 166 | logger = logging.getLogger(__name__) 167 | logfile = str(Path(self.pathExp) / 'log.txt') 168 | fh = logging.FileHandler(logfile, mode='a') 169 | formatter = logging.Formatter("%(asctime)s: %(message)s") 170 | fh.setFormatter(formatter) 171 | logger.addHandler(fh) 172 | return logger 173 | 174 | def getEnviron(self, key: str, default): 175 | 176 | out = os.environ[key] if all([key in os.environ, self.envDistributed]) else default 177 | if isinstance(default, int): 178 | out = int(out) 179 | elif isinstance(default, str): 180 | out = str(out) 181 | return out 182 | 183 | def record(self): 184 | logger_ = logging.getLogger(__name__ + 'sub') 185 | logging.basicConfig(level=logging.INFO) 186 | path = os.path.join(self.pathExp, 'config.txt') 187 | fh = logging.FileHandler(path, mode='w') 188 | formatter = logging.Formatter("%(message)s") 189 | fh.setFormatter(formatter) 190 | logger_.addHandler(fh) 191 | 192 | args = vars(self) 193 | logger_.info('-------------Config-------------------') 194 | for k, v in sorted(args.items()): 195 | logger_.info('{} = {}'.format(k, v)) 196 
| logger_.info('--------------End---------------------') 197 | -------------------------------------------------------------------------------- /stage2/configs/configEVI.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import torch 4 | from torch.backends import cudnn 5 | from datetime import datetime 6 | from lib import fileTool as FT 7 | from pathlib import Path 8 | from tensorboardX import SummaryWriter 9 | 10 | 11 | class Config(object): 12 | def __init__(self, gpuList=None, envDistributed=False, expName='gopro_gn_'): 13 | super(Config, self).__init__() 14 | 15 | # always change--------------------------------------------------------------------------------------- 16 | self.a_name = expName 17 | 18 | # step=1: train stage 2 19 | # lr=5e-5 20 | 21 | # step=2: resume stage-2 training if the job was killed 22 | # lr=5e-5 23 | 24 | self.step = 1 25 | self.snapShot = 10 26 | 27 | self.pathOut = './output/' 28 | self.pathExp = 'Demo_train_on_lowfps_S1' # dir name of the stage-1 experiment 29 | self.pathWeight = f'./output/{self.pathExp}/state/bestEVI_epoch100.pth' # path to the weights trained in stage 1 30 | 31 | self.dump = False # whether to save image results 32 | self.outPathS2 = './output/Demo_train_on_lowfps_S1/Real_S2/' # path to save image results 33 | 34 | self.trainSize = (176, 240) 35 | 36 | self.valScale = 1 37 | self.valNumInter = 3 38 | 39 | self.netCheck = False 40 | 41 | self.lrInit = 5e-4 42 | self.trainEpoch = 5000 43 | 44 | self.trainBatchPerGPU = 8 45 | self.testBatchPerGPU = 2 * self.trainBatchPerGPU 46 | if self.dump: 47 | self.testBatchPerGPU = 1 48 | 49 | self.lrGamma = 0.999 # decay scale for exp scheduler 50 | 51 | # network param------------------------------------------------------------------------------------------- 52 | self.netActivate = 'leakyrelu' # prelu, relu, leakyrelu, swish 53 | 54 | self.netNorm = 'group' # instance, lrn, bn, identity, group 55 | 56 | self.netInitType = 'xavier' # normal, xavier, orthogonal, kaiming, default 57 | 58 | self.netInitGain = 0.2 59 | self.pathTrainEvent = str(Path(__file__).parents[2] / 'dataset/fastDVS_dataset/train') 60 | self.pathValEvent = str(Path(__file__).parents[2] / 'dataset/fastDVS_dataset/test') 61 | 62 | # optimizer----------------------------------------------------------------------------------------------- 63 | 64 | self.optPolicy = 'Adam' # adam, sgd 65 | self.optBetas = [0.9, 0.999] 66 | self.optDecay = 0 67 | self.optMomentum = 0.995 68 | 69 | self.lrPolicy = 'exp' # step, multistep, cosine, plateau, exp 70 | self.lrdecayIter = 100 71 | self.lrMilestones = [100, 150] 72 | 73 | self.trainMean = 0 74 | self.trainStd = 1 75 | 76 | self.envUseApex = False 77 | self.envApexLevel = 'O0' 78 | 79 | # hardware environment----------------------------------------------------------------------------------- 80 | self.envDistributed = envDistributed 81 | self.envnodeName = self.getEnviron('SLURMD_NODENAME', 'SingleNode') 82 | self.envWorldSize = self.getEnviron('SLURM_NTASKS', 1) 83 | self.envNodeID = self.getEnviron('SLURM_NODEID', 0) 84 | self.envRank = self.getEnviron('SLURM_PROCID', 0) 85 | self.envLocalRank = self.getEnviron('SLURM_LOCALID', 0) 86 | 87 | if gpuList is not None and self.envDistributed: 88 | self.gpulist = gpuList[self.envnodeName] 89 | os.environ['CUDA_VISIBLE_DEVICES'] = self.gpulist 90 | 91 | self.envNumGPUs = torch.cuda.device_count() 92 | if self.envNumGPUs > 0: 93 | assert (torch.cuda.is_available()) and cudnn.enabled 94 | 95 | self.envParallel = True if
(self.envNumGPUs > 2 and not self.envDistributed) else False 96 | 97 | # self.testBatchPerGPU = self.trainBatchPerGPU 98 | 99 | if self.envDistributed: 100 | self.trainBatch = self.trainBatchPerGPU * self.envWorldSize 101 | self.testBatch = self.testBatchPerGPU * self.envWorldSize 102 | elif self.envParallel: 103 | self.trainBatch = self.trainBatchPerGPU * self.envNumGPUs 104 | self.testBatch = self.testBatchPerGPU * self.envNumGPUs 105 | # self.testBatch = self.testBatchPerGPU 106 | else: # CPU 107 | self.trainBatch = self.trainBatchPerGPU 108 | self.testBatch = self.testBatchPerGPU 109 | # self.valScale = 0.8 110 | self.valScale = 1 111 | self.trainSize = (64, 64) 112 | 113 | self.trainVisual = False 114 | 115 | # ---------------------------------------------------------------------------------------------------- 116 | 117 | self.setRandSeed = 2021 118 | 119 | # path and init------------------------------------------------------------------------------------------ 120 | self.envWorkers = 4 if self.envDistributed else 0 121 | self.pathExp, self.pathEvents, self.pathState, self.pathGif = self.expInit() 122 | 123 | # train config------------------------------------------------------------------------------------------ 124 | if self.envRank == 0: 125 | self.trainLogger = self.logInit() 126 | 127 | if self.step in [1, 2]: 128 | checkpoints = torch.load(self.pathWeight, map_location=torch.device('cpu')) 129 | try: 130 | totalIter = checkpoints['totalIter'] 131 | except Exception as e: 132 | totalIter = 0 133 | self.trainWriter = SummaryWriter(self.pathEvents, purge_step=totalIter) 134 | else: 135 | self.trainWriter = SummaryWriter(self.pathEvents) 136 | 137 | self.trainMaxSave = 10 138 | 139 | # finally ------------------------------------------------------------------------------------------------ 140 | if 0 == self.envRank: 141 | self.record() 142 | 143 | def set_gpu_ids(self): 144 | str_ids = self.gpu_ids 145 | self.gpu_ids = [] 146 | for str_id in str_ids: 147 | id = int(str_id) 148 | if id >= 0: 149 | self.gpu_ids.append(id) 150 | if len(self.gpu_ids) > 0: 151 | torch.cuda.set_device(self.gpu_ids[0]) 152 | 153 | def expInit(self): 154 | if 0 == self.step: 155 | # if self.envRank == 0: 156 | now = datetime.now().strftime("%Y%m%d%H%M") 157 | pathExp = str(Path(self.pathOut) / '{}_{}'.format(self.a_name, now)) 158 | pathEvents = str(Path(pathExp) / 'events') 159 | pathState = str(Path(pathExp) / 'state') 160 | pathGif = str(Path(pathExp) / 'gif') 161 | 162 | FT.mkPath(pathExp) 163 | FT.mkPath(pathEvents) 164 | FT.mkPath(pathState) 165 | FT.mkPath(pathGif) 166 | else: 167 | pathExp = str(Path(self.pathOut) / self.pathExp) 168 | assert Path(pathExp).is_dir(), pathExp 169 | pathEvents = str(Path(pathExp) / 'events') 170 | pathState = str(Path(pathExp) / 'state') 171 | pathGif = str(Path(pathExp) / 'gif') 172 | return pathExp, pathEvents, pathState, pathGif 173 | 174 | def logInit(self): 175 | logger = logging.getLogger(__name__) 176 | logfile = str(Path(self.pathExp) / 'log.txt') 177 | fh = logging.FileHandler(logfile, mode='a') 178 | formatter = logging.Formatter("%(asctime)s: %(message)s") 179 | fh.setFormatter(formatter) 180 | logger.addHandler(fh) 181 | return logger 182 | 183 | def getEnviron(self, key: str, default): 184 | 185 | out = os.environ[key] if all([key in os.environ, self.envDistributed]) else default 186 | if isinstance(default, int): 187 | out = int(out) 188 | elif isinstance(default, str): 189 | out = str(out) 190 | return out 191 | 192 | def record(self): 193 | logger_ = 
logging.getLogger(__name__ + 'sub') 194 | logging.basicConfig(level=logging.INFO) 195 | path = os.path.join(self.pathExp, 'config.txt') 196 | fh = logging.FileHandler(path, mode='w') 197 | formatter = logging.Formatter("%(message)s") 198 | fh.setFormatter(formatter) 199 | logger_.addHandler(fh) 200 | 201 | args = vars(self) 202 | logger_.info('-------------Config-------------------') 203 | for k, v in sorted(args.items()): 204 | logger_.info('{} = {}'.format(k, v)) 205 | logger_.info('--------------End---------------------') 206 | --------------------------------------------------------------------------------
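Editor's note: the runBash/runEvi.py launchers are not part of this section. Since the Config classes above are self-initializing (expInit, logInit, and record all run inside __init__), a minimal single-node launch reduces to instantiating one. The sketch below illustrates this under that assumption; it is not the repository's actual script.

from configs.configEVI import Config

if __name__ == '__main__':
    # self.step is set inside the class: 0 trains from scratch,
    # 1/2 load self.pathWeight and purge TensorBoard events past the saved totalIter
    cfg = Config(gpuList=None, envDistributed=False, expName='gopro_gn_')

    # expInit() has already created (step=0) or located (step=1/2) the experiment dirs
    print(cfg.pathExp, cfg.pathState, cfg.pathGif)

    # batch sizes follow the envDistributed / envParallel / CPU branches above
    print(cfg.trainBatch, cfg.testBatch)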