├── model ├── __init__.py ├── base │ ├── __init__.py │ └── base_model.py ├── hyper │ ├── bessel.npy │ ├── __init__.py │ ├── fourier_bessel.py │ └── hyper_dynamic.py ├── eraft │ ├── utils.py │ ├── corr.py │ ├── update.py │ ├── eraft.py │ └── image_utils.py ├── eitr │ ├── eitr.py │ ├── position_encoding.py │ ├── transformer_encoder.py │ ├── transformer_decoder.py │ └── transformer.py ├── default_config.py ├── nernet_model.py └── loss.py ├── config ├── webvid_root.txt ├── evbird_lhy.txt ├── mvsec_test_flow.txt ├── mvsec_test.txt ├── ijrr_test.txt ├── evaid_test.txt ├── hqf_test.txt ├── test_evbird.yaml ├── test_etnet_original.yaml ├── test_hypere2vid_original.yaml ├── test_e2vid++_original.yaml ├── test_eraft_original.yaml ├── test_evflow_original.yaml ├── test_nernet_original.yaml ├── train_ablation_e2vid_esim.yaml ├── train_v2v_etnet_10k.yaml ├── train_v2v_evflow_10k.yaml ├── train_v2v_eraft_10k.yaml ├── train_v2v_e2vid_10k.yaml ├── train_ablation_e2vid_filtered.yaml ├── train_ablation_e2vid_10k_fixed.yaml ├── train_ablation_e2vid_hdr.yaml ├── train_v2v_hyper_10k.yaml └── webvid100_unfiltered.txt ├── ckpt_paths ├── v2v_etnet_10k.txt ├── v2v_hyper_10k.txt ├── v2v_e2vid_10k.txt ├── v2v_eraft_10k.txt ├── eraft_original.txt ├── etnet_original.txt ├── nernet_original.txt ├── v2v_evflow_10k.txt ├── evflow_original.txt ├── hypere2vid_original.txt └── e2vid++_original.txt ├── utils ├── __init__.py ├── henri_compatible.py ├── default_config.py ├── extract_images_MMP.py ├── timers.py ├── data.py ├── color_utils.py ├── myutil.py ├── parse_config.py └── training_utils.py ├── .gitignore ├── PerceptualSimilarity └── models │ ├── weights │ ├── v0.0 │ │ ├── alex.pth │ │ ├── vgg.pth │ │ └── squeeze.pth │ └── v0.1 │ │ ├── alex.pth │ │ ├── vgg.pth │ │ └── squeeze.pth │ ├── base_model.py │ └── __init__.py ├── scripts ├── subsample_unfiltered.py ├── convert_checkpoint_from_original.py ├── clean_checkpoints.py ├── flow_result_to_col.py ├── generate_random_thresholds.py ├── esim_to_voxel.py ├── testset_evcnt_maps.py ├── result_to_col.py ├── aedat4_to_h5.py ├── make_ref_videos.py ├── select_best_checkpoint.py ├── save_gt_images.py ├── hs_ergb_to_h5.py ├── evaid_to_h5.py ├── ijrr_to_h5.py └── qwen_vl_annotate.py ├── requirements.txt ├── data ├── data_interface.py ├── v2v_core_esim.py └── esim_dataset.py ├── clear_experiment.sh └── test_flow.py /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/base/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /config/webvid_root.txt: -------------------------------------------------------------------------------- 1 | /mnt/ssd/webvid/ -------------------------------------------------------------------------------- /ckpt_paths/v2v_etnet_10k.txt: -------------------------------------------------------------------------------- 1 | checkpoints/v2v_etnet_10k/epoch_0096.pth -------------------------------------------------------------------------------- /ckpt_paths/v2v_hyper_10k.txt: -------------------------------------------------------------------------------- 1 | checkpoints/v2v_hyper_10k/epoch_0078.pth -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import * 2 | from 
.event_utils import * -------------------------------------------------------------------------------- /ckpt_paths/v2v_e2vid_10k.txt: -------------------------------------------------------------------------------- 1 | checkpoints/v2v_e2vid_10k/epoch_0077.pth 2 | -------------------------------------------------------------------------------- /ckpt_paths/v2v_eraft_10k.txt: -------------------------------------------------------------------------------- 1 | checkpoints/v2v_eraft_10k/epoch_0049.pth 2 | -------------------------------------------------------------------------------- /ckpt_paths/eraft_original.txt: -------------------------------------------------------------------------------- 1 | checkpoints/eraft_original/eraft_mvsec_20.pth -------------------------------------------------------------------------------- /ckpt_paths/etnet_original.txt: -------------------------------------------------------------------------------- 1 | checkpoints/etnet_original/etnet_state_dict.pth -------------------------------------------------------------------------------- /ckpt_paths/nernet_original.txt: -------------------------------------------------------------------------------- 1 | checkpoints/nernet_original/ner_state_dict.pth -------------------------------------------------------------------------------- /ckpt_paths/v2v_evflow_10k.txt: -------------------------------------------------------------------------------- 1 | checkpoints/v2v_evflow_10k/epoch_0049.pth 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints 2 | results 3 | tensorboard_logs 4 | __pycache__ 5 | **.pyc -------------------------------------------------------------------------------- /ckpt_paths/evflow_original.txt: -------------------------------------------------------------------------------- 1 | checkpoints/evflow_original/flow_model_state_dict.pth -------------------------------------------------------------------------------- /ckpt_paths/hypere2vid_original.txt: -------------------------------------------------------------------------------- 1 | checkpoints/hypere2vid_original/hyper_state_dict.pth -------------------------------------------------------------------------------- /model/hyper/bessel.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HYLZ-2019/V2V/HEAD/model/hyper/bessel.npy -------------------------------------------------------------------------------- /ckpt_paths/e2vid++_original.txt: -------------------------------------------------------------------------------- 1 | checkpoints/e2vid++_original/reconstruction_model_state_dict.pth -------------------------------------------------------------------------------- /config/evbird_lhy.txt: -------------------------------------------------------------------------------- 1 | /mnt/ssd/evbird/maque.h5 2 | /mnt/ssd/evbird/xique_fly.h5 3 | /mnt/ssd/evbird/heitiane_1.h5 -------------------------------------------------------------------------------- /model/hyper/__init__.py: -------------------------------------------------------------------------------- 1 | from .hyper_dynamic import ConvolutionalContextFusion, DynamicAtomGeneration, DynamicConv 2 | -------------------------------------------------------------------------------- /PerceptualSimilarity/models/weights/v0.0/alex.pth: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/HYLZ-2019/V2V/HEAD/PerceptualSimilarity/models/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /PerceptualSimilarity/models/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HYLZ-2019/V2V/HEAD/PerceptualSimilarity/models/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /PerceptualSimilarity/models/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HYLZ-2019/V2V/HEAD/PerceptualSimilarity/models/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /PerceptualSimilarity/models/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HYLZ-2019/V2V/HEAD/PerceptualSimilarity/models/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /PerceptualSimilarity/models/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HYLZ-2019/V2V/HEAD/PerceptualSimilarity/models/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /PerceptualSimilarity/models/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HYLZ-2019/V2V/HEAD/PerceptualSimilarity/models/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /config/mvsec_test_flow.txt: -------------------------------------------------------------------------------- 1 | /mnt/ssd/MVSEC_wflow/indoor_flying1.h5 2 | /mnt/ssd/MVSEC_wflow/indoor_flying2.h5 3 | /mnt/ssd/MVSEC_wflow/indoor_flying3.h5 4 | /mnt/ssd/MVSEC_wflow/outdoor_day1.h5 5 | /mnt/ssd/MVSEC_wflow/outdoor_day2.h5 -------------------------------------------------------------------------------- /config/mvsec_test.txt: -------------------------------------------------------------------------------- 1 | /mnt/ssd/MVSEC_cut/indoor_flying1.h5 2 | /mnt/ssd/MVSEC_cut/indoor_flying2.h5 3 | /mnt/ssd/MVSEC_cut/indoor_flying3.h5 4 | /mnt/ssd/MVSEC_cut/indoor_flying4.h5 5 | /mnt/ssd/MVSEC_cut/outdoor_day1.h5 6 | /mnt/ssd/MVSEC_cut/outdoor_day2.h5 -------------------------------------------------------------------------------- /config/ijrr_test.txt: -------------------------------------------------------------------------------- 1 | /mnt/ssd/IJRR_cut/boxes_6dof.h5 2 | /mnt/ssd/IJRR_cut/calibration.h5 3 | /mnt/ssd/IJRR_cut/dynamic_6dof.h5 4 | /mnt/ssd/IJRR_cut/office_zigzag.h5 5 | /mnt/ssd/IJRR_cut/poster_6dof.h5 6 | /mnt/ssd/IJRR_cut/shapes_6dof.h5 7 | /mnt/ssd/IJRR_cut/slider_depth.h5 8 | -------------------------------------------------------------------------------- /scripts/subsample_unfiltered.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | with open("config/webvid10000_unfiltered.txt", "r") as f: 4 | lines = f.readlines() 5 | 6 | sub = random.sample(lines, 1000) 7 | 8 | with open("config/webvid1000_unfiltered.txt", "w") as f: 9 | f.write("".join(sub)) 10 | 11 | sub = random.sample(sub, 100) 12 | 13 | with 
open("config/webvid100_unfiltered.txt", "w") as f: 14 | f.write("".join(sub)) -------------------------------------------------------------------------------- /config/evaid_test.txt: -------------------------------------------------------------------------------- 1 | /mnt/ssd/EventAid-R-h5/ball.h5 2 | /mnt/ssd/EventAid-R-h5/bear.h5 3 | /mnt/ssd/EventAid-R-h5/box.h5 4 | /mnt/ssd/EventAid-R-h5/building.h5 5 | /mnt/ssd/EventAid-R-h5/outdoor.h5 6 | /mnt/ssd/EventAid-R-h5/playball.h5 7 | /mnt/ssd/EventAid-R-h5/room1.h5 8 | /mnt/ssd/EventAid-R-h5/sculpture.h5 9 | /mnt/ssd/EventAid-R-h5/toy.h5 10 | /mnt/ssd/EventAid-R-h5/traffic.h5 11 | /mnt/ssd/EventAid-R-h5/wall.h5 -------------------------------------------------------------------------------- /config/hqf_test.txt: -------------------------------------------------------------------------------- 1 | /mnt/ssd/HQF_h5/bike_bay_hdr.h5 2 | /mnt/ssd/HQF_h5/boxes.h5 3 | /mnt/ssd/HQF_h5/desk.h5 4 | /mnt/ssd/HQF_h5/desk_fast.h5 5 | /mnt/ssd/HQF_h5/desk_hand_only.h5 6 | /mnt/ssd/HQF_h5/desk_slow.h5 7 | /mnt/ssd/HQF_h5/engineering_posters.h5 8 | /mnt/ssd/HQF_h5/high_texture_plants.h5 9 | /mnt/ssd/HQF_h5/poster_pillar_1.h5 10 | /mnt/ssd/HQF_h5/poster_pillar_2.h5 11 | /mnt/ssd/HQF_h5/reflective_materials.h5 12 | /mnt/ssd/HQF_h5/slow_and_fast_desk.h5 13 | /mnt/ssd/HQF_h5/slow_hand.h5 14 | /mnt/ssd/HQF_h5/still_life.h5 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name python=3.10.14 3 | # $ pip install -r requirements.txt 4 | dv-processing==1.7.9 # only needed for aedat4 reading 5 | einops 6 | event-voxel-builder 7 | h5py 8 | hdf5plugin 9 | matplotlib 10 | numpy 11 | opencv-python 12 | torch==2.4.0 13 | torchaudio 14 | torchvision 15 | torchmetrics 16 | tqdm 17 | pyyaml 18 | pandas 19 | scikit-image 20 | IPython 21 | tensorboard 22 | ffmpeg # Option for video reading, delete related code if you only want to use OpenCV for video reading 23 | moviepy==1.0.3 # Version matters -------------------------------------------------------------------------------- /model/base/base_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | from abc import abstractmethod 4 | 5 | 6 | class BaseModel(nn.Module): 7 | """ 8 | Base class for all models 9 | """ 10 | @abstractmethod 11 | def forward(self, *inputs): 12 | """ 13 | Forward pass logic 14 | 15 | :return: Model output 16 | """ 17 | raise NotImplementedError 18 | 19 | def __str__(self): 20 | """ 21 | Model prints with number of trainable parameters 22 | """ 23 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 24 | params = sum([np.prod(p.size()) for p in model_parameters]) 25 | return super().__str__() + '\nTrainable parameters: {}'.format(params) 26 | -------------------------------------------------------------------------------- /scripts/convert_checkpoint_from_original.py: -------------------------------------------------------------------------------- 1 | # Run this script in the root directory of another codebase, such as https://github.com/TimoStoff/event_cnn_minimal, since the torch.load will try to import model classes from the original code structure. Then move the extracted state_dict-only checkpoint wherever you need it. 
2 | 3 | import torch 4 | 5 | ckpt = torch.load("pretrained/flow_model.pth") 6 | 7 | new_dict = { 8 | "state_dict": ckpt["state_dict"], 9 | } 10 | 11 | torch.save(new_dict, "pretrained/flow_model_state_dict.pth") 12 | 13 | ''' 14 | # Code for ERAFT: 15 | 16 | import torch 17 | 18 | in_pth = "mvsec_20.tar" 19 | checkpoint = torch.load(in_pth, map_location="cpu") 20 | out_pth = "checkpoints/eraft_original/eraft_mvsec_20.pth" 21 | 22 | new_dict = { 23 | "state_dict": checkpoint["model"], 24 | } 25 | torch.save(new_dict, out_pth) 26 | 27 | ''' -------------------------------------------------------------------------------- /scripts/clean_checkpoints.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import sys 4 | 5 | exp_list = sorted(glob.glob("checkpoints/*")) 6 | exp_list = [os.path.basename(exp) for exp in exp_list] 7 | exp_list = [exp for exp in exp_list if exp[2] != "5"] 8 | 9 | print(exp_list) 10 | 11 | for exp_name in exp_list: 12 | ckpt_path_file = f"ckpt_paths/{exp_name}.txt" 13 | try: 14 | with open(ckpt_path_file, "r") as f: 15 | lines = f.readlines() 16 | lines = [line.strip() for line in lines if line.strip() != ""] 17 | last_epoch = lines[-1].split("/")[-1] 18 | print(exp_name, last_epoch) 19 | 20 | all_ckpts = sorted(glob.glob(f"checkpoints/{exp_name}/*.pth")) 21 | for ac in all_ckpts: 22 | if os.path.basename(ac) != last_epoch: 23 | print(f"Removing {ac}") 24 | os.remove(ac) 25 | except Exception as e: 26 | print(e) -------------------------------------------------------------------------------- /utils/henri_compatible.py: -------------------------------------------------------------------------------- 1 | from utils.default_config import default_config 2 | import copy 3 | from parse_config import ConfigParser 4 | 5 | 6 | def make_henri_compatible(checkpoint, final_activation=''): 7 | """ 8 | Checkpoints have ConfigParser type configs, whereas Henri checkpoints have 9 | dictionary type configs or "arch, model" dicts. 10 | We will generate and add a ConfigParser to the checkpoint and return it. 
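Usage sketch (the checkpoint path and final activation below are placeholders):
    checkpoint = torch.load("some_original_checkpoint.pth", map_location="cpu")
    checkpoint = make_henri_compatible(checkpoint, final_activation="sigmoid")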
11 | """ 12 | assert ('config' in checkpoint or ('arch' in checkpoint and 'model' in checkpoint)) 13 | check_config = checkpoint['config'] if 'config' in checkpoint else checkpoint 14 | new_config = copy.deepcopy(default_config) 15 | new_config['arch']['type'] = check_config['arch'] 16 | new_config['arch']['args']['unet_kwargs'] = check_config['model'] 17 | if final_activation: 18 | new_config['arch']['args']['unet_kwargs']['final_activation'] = final_activation 19 | config = ConfigParser(new_config) 20 | checkpoint['config'] = config 21 | print(new_config) 22 | return checkpoint 23 | -------------------------------------------------------------------------------- /data/data_interface.py: -------------------------------------------------------------------------------- 1 | from utils.util import get_obj_from_str 2 | from torch.utils.data import ConcatDataset 3 | import pandas as pd 4 | from tqdm import tqdm 5 | 6 | def make_concat_dataset(configs): 7 | data_file = configs["data_file"] 8 | class_name = configs["class_name"] 9 | dataset_type = get_obj_from_str(class_name) 10 | data_paths = pd.read_csv(data_file, header=None).values.flatten().tolist() 11 | 12 | begin_seq = configs.get("begin_seq", 0) 13 | end_seq = configs.get("end_seq", len(data_paths)) 14 | data_paths = data_paths[begin_seq:end_seq] 15 | 16 | dataset_list = [] 17 | print('Concatenating {} datasets'.format(dataset_type)) 18 | for data_path in tqdm(data_paths): 19 | dataset_list.append(dataset_type(data_path, configs)) 20 | print("Total samples: ", sum([len(d) for d in dataset_list])) 21 | return ConcatDataset(dataset_list) 22 | 23 | def make_concat_multi_dataset(configs): 24 | datasets = [] 25 | for config in configs: 26 | datasets.append(make_concat_dataset(config)) 27 | return ConcatDataset(datasets) -------------------------------------------------------------------------------- /config/test_evbird.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: v2v_e2vid_10k 2 | test_output_dir: results/v2v_e2vid_10k 3 | 4 | module: 5 | loss: 6 | lpips_weight: 1.0 7 | lpips_type: vgg 8 | l2_weight: 0 9 | l1_weight: 1.0 10 | ssim_weight: 0 11 | temporal_consistency_weight: 1.0 12 | optical_flow_source: raft_small 13 | temporal_consistency_L0: 20 14 | 15 | normalize_voxels: false 16 | model: 17 | target: model.model.E2VIDRecurrent 18 | params: 19 | unet_kwargs: 20 | num_bins: 5 21 | skip_type: sum 22 | recurrent_block_type: convlstm 23 | num_encoders: 3 24 | base_num_channels: 32 25 | num_residual_blocks: 2 26 | use_upsample_conv: true 27 | final_activation: "" 28 | norm: none 29 | 30 | test_stage: 31 | test_batch_size: 1 32 | test_num_workers: 4 33 | test: 34 | - data_file: config/evbird_lhy.txt 35 | class_name: data.testh5.FPS_H5Dataset 36 | dataset_name: evbird 37 | FPS: 100 38 | H: 260 39 | W: 346 40 | num_bins: 5 41 | sequence_length: 80 42 | interpolate_bins: false -------------------------------------------------------------------------------- /model/eraft/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 8 | """ Wrapper for grid_sample, uses pixel coordinates """ 9 | H, W = img.shape[-2:] 10 | xgrid, ygrid = coords.split([1,1], dim=-1) 11 | xgrid = 2*xgrid/(W-1) - 1 12 | ygrid = 2*ygrid/(H-1) - 1 13 | 14 | grid = torch.cat([xgrid, ygrid], dim=-1) 15 | 
img = F.grid_sample(img, grid, align_corners=True) 16 | 17 | if mask: 18 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 19 | return img, mask.float() 20 | 21 | return img 22 | 23 | 24 | def coords_grid(batch, ht, wd): 25 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 26 | coords = torch.stack(coords[::-1], dim=0).float() 27 | return coords[None].repeat(batch, 1, 1, 1) 28 | 29 | 30 | def upflow8(flow, mode='bilinear'): 31 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 32 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 33 | -------------------------------------------------------------------------------- /model/eitr/eitr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from model.model_util import CropSize 6 | from .u_trans import mls_tpa 7 | 8 | 9 | class EITR(mls_tpa): 10 | def __init__(self, eitr_kwargs): 11 | super().__init__(eitr_kwargs['num_bins'], eitr_kwargs['norm']) 12 | 13 | def forward(self, event_tensor): 14 | """ 15 | :param event_tensor: N x num_bins x H x W 16 | :return: output dict with image taking values in [0,1], and 17 | displacement within event_tensor. 18 | N x 1 x H x W 19 | """ 20 | n, c, H, W = event_tensor.size() 21 | 22 | # pad size 23 | factor = {'h':8, 'w':8} 24 | pad_crop = CropSize(W, H, factor) 25 | if (H % factor['h'] != 0) or (W % factor['w'] != 0): 26 | event_tensor = pad_crop.pad(event_tensor) 27 | 28 | out = self.func(event_tensor) 29 | #print("Min and max of out: ", out.min(), out.max()) 30 | 31 | # crop size 32 | if (H % factor['h'] != 0) or (W % factor['w'] != 0): 33 | out = pad_crop.crop(out) 34 | 35 | return {'image': out} 36 | 37 | -------------------------------------------------------------------------------- /model/eitr/position_encoding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import numpy as np 5 | 6 | class PositionalEncodingSine(nn.Module): 7 | def __init__(self, d_hid, n_position=20000): 8 | super().__init__() 9 | pos_table = self._get_sinusoid_encoding_table(n_position, d_hid) 10 | self.register_buffer("pos_table", pos_table) 11 | 12 | def _get_sinusoid_encoding_table(self, n_position, d_hid): 13 | def get_position_angle_vec(position): 14 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 15 | 16 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 17 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) 18 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) 19 | return torch.FloatTensor(sinusoid_table).unsqueeze(0) 20 | 21 | def forward(self, x): 22 | return self.pos_table[:, :x.size(1)] 23 | 24 | def build_position_encoding(pos_type, d_model): 25 | if pos_type == 'sine': 26 | position_embedding = PositionalEncodingSine(d_model) 27 | else: 28 | raise ValueError(f"not support {pos_type}") 29 | return position_embedding 30 | -------------------------------------------------------------------------------- /clear_experiment.sh: -------------------------------------------------------------------------------- 1 | # Usage: ./clear_experiment.sh 2 | # rm -r tensorboard_logs/{experiment_name} 3 | # rm ckpt_paths/{experiment_name}.txt 4 | # rm -r checkpoints/{experiment_name} 5 | 6 | #!/bin/bash 7 | 8 | # Check if experiment name is provided 9 | if [ -z "$1" ]; then 10 
| echo "Usage: $0 " 11 | exit 1 12 | fi 13 | 14 | experiment_name=$1 15 | 16 | # Define paths 17 | tensorboard_dir="tensorboard_logs/${experiment_name}" 18 | ckpt_file="ckpt_paths/${experiment_name}.txt" 19 | 20 | # Remove tensorboard directory if it exists 21 | if [ -d "$tensorboard_dir" ]; then 22 | echo "Removing directory: $tensorboard_dir" 23 | rm -r "$tensorboard_dir" 24 | else 25 | echo "Directory not found: $tensorboard_dir" 26 | fi 27 | 28 | # Remove checkpoint file if it exists 29 | if [ -f "$ckpt_file" ]; then 30 | echo "Removing file: $ckpt_file" 31 | rm "$ckpt_file" 32 | else 33 | echo "File not found: $ckpt_file" 34 | fi 35 | 36 | # Remove checkpoints directory if it exists 37 | checkpoints_dir="checkpoints/${experiment_name}" 38 | if [ -d "$checkpoints_dir" ]; then 39 | echo "Removing directory: $checkpoints_dir" 40 | rm -r "$checkpoints_dir" 41 | else 42 | echo "Directory not found: $checkpoints_dir" 43 | fi 44 | 45 | echo "Cleanup for experiment '$experiment_name' complete." -------------------------------------------------------------------------------- /config/test_etnet_original.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: etnet_original 2 | test_output_dir: results/etnet_original 3 | 4 | module: 5 | loss: 6 | lpips_weight: 1.0 7 | lpips_type: alex 8 | l2_weight: 0 9 | l1_weight: 0 10 | ssim_weight: 0 11 | temporal_consistency_weight: 1.0 12 | 13 | normalize_voxels: false 14 | model: 15 | target: model.eitr.eitr.EITR 16 | params: 17 | eitr_kwargs: 18 | num_bins: 5 19 | norm: none 20 | 21 | test_stage: 22 | test_num_workers: 4 23 | test: 24 | - data_file: config/evaid_test.txt 25 | class_name: data.testh5.TestH5Dataset 26 | dataset_name: evaid 27 | num_bins: 5 28 | sequence_length: 80 29 | interpolate_bins: true 30 | - data_file: config/ijrr_test.txt 31 | class_name: data.testh5.TestH5Dataset 32 | dataset_name: ijrr 33 | num_bins: 5 34 | sequence_length: 80 35 | interpolate_bins: true 36 | - data_file: config/hqf_test.txt 37 | class_name: data.testh5.TestH5Dataset 38 | dataset_name: hqf 39 | num_bins: 5 40 | sequence_length: 80 41 | interpolate_bins: true 42 | - data_file: config/mvsec_test.txt 43 | class_name: data.testh5.TestH5Dataset 44 | dataset_name: mvsec 45 | num_bins: 5 46 | sequence_length: 80 47 | interpolate_bins: true -------------------------------------------------------------------------------- /scripts/flow_result_to_col.py: -------------------------------------------------------------------------------- 1 | # Tool script that converts a row from a test result to a column (easy to copy & paste into a feishu doc / excel sheet). 
2 | 3 | import os 4 | import sys 5 | 6 | sequences = { 7 | "MVSEC": [ 8 | "indoor_flying1", 9 | "indoor_flying2", 10 | "indoor_flying3", 11 | "outdoor_day1", 12 | "outdoor_day2" 13 | ] 14 | } 15 | 16 | all_metric_names = [] 17 | for dataset in ["MVSEC"]: 18 | for seqname in sequences[dataset]: 19 | for metric in ["dense_EPE", "dense_3PE", "sparse_EPE", "sparse_3PE"]: 20 | all_metric_names.append(f"{dataset}/{seqname}/{metric}") 21 | 22 | with open("debug/col_heads.txt", "w") as f: 23 | for key in all_metric_names: 24 | f.write(f"{key}\n") 25 | 26 | # line_n is the line number in vscode 27 | def extract_line(file_name, line_n): 28 | with open(file_name, "r", encoding="UTF-8") as f: 29 | lines = f.readlines() 30 | head = lines[0].split(",") 31 | 32 | data = {} 33 | dataline = lines[line_n-1].split(",") 34 | for i in range(len(dataline)): 35 | data[head[i]] = dataline[i] 36 | 37 | of_path = os.path.join(os.path.dirname(file_name), f"col_from_line_{line_n:03d}_{dataline[0]}.txt") 38 | 39 | with open(of_path, "w", encoding="UTF-8") as of: 40 | for key in all_metric_names: 41 | val = float(data[key]) 42 | of.write(f"{val:.03f}\n") 43 | 44 | if __name__ == "__main__": 45 | file_name = sys.argv[1] 46 | line_n = int(sys.argv[2]) 47 | extract_line(file_name, line_n) -------------------------------------------------------------------------------- /scripts/generate_random_thresholds.py: -------------------------------------------------------------------------------- 1 | # Used to generate random thresholds for fixed-threshold ablation experiments. 2 | 3 | import numpy as np 4 | 5 | def ran_thres(threshold_range=[0.05, 2], max_thres_pos_neg_gap=1.5): 6 | thres_1 = np.random.uniform(*threshold_range) 7 | pos_neg_gap = np.random.uniform(1, max_thres_pos_neg_gap) 8 | thres_2 = thres_1 * pos_neg_gap 9 | if np.random.rand() > 0.5: 10 | pos_thres = thres_1 11 | neg_thres = thres_2 12 | else: 13 | pos_thres = thres_2 14 | neg_thres = thres_1 15 | return pos_thres, neg_thres 16 | 17 | def process_file(input_file): 18 | # Original content of file: 19 | # 000401_000450/5876366.mp4 225 20 | # 000351_000400/1050217609.mp4 885 21 | 22 | # Target content of file: 23 | # 000401_000450/5876366.mp4 225 0.07 0.075 24 | # 000351_000400/1050217609.mp4 885 0.05 0.06 25 | with open(input_file, 'r') as f: 26 | lines = f.readlines() 27 | processed_lines = [] 28 | 29 | for line in lines: 30 | line = line.strip() 31 | if not line: 32 | continue 33 | parts = line.split() 34 | video_path = parts[0] 35 | frame_num = parts[1] 36 | pos_thres, neg_thres = ran_thres() 37 | processed_line = f"{video_path} {frame_num} {pos_thres:.3f} {neg_thres:.3f}" 38 | processed_lines.append(processed_line) 39 | 40 | # Write the processed lines to a new file 41 | with open(input_file, 'w') as f: 42 | for processed_line in processed_lines: 43 | f.write(processed_line + '\n') 44 | 45 | if __name__ == "__main__": 46 | process_file("config/webvid10000_unfiltered.txt") -------------------------------------------------------------------------------- /model/default_config.py: -------------------------------------------------------------------------------- 1 | default_config = { 2 | 'name': 'inference', 3 | 'n_gpu': 1, 4 | 'arch': { 5 | 'args': {} 6 | }, 7 | 'valid_data_loader': { 8 | 'type': 'HDF5DataLoader', 9 | 'args': { 10 | 'batch_size': 1, 11 | 'shuffle': False, 12 | 'num_workers': 1, 13 | 'pin_memory': True, 14 | 'sequence_kwargs': { 15 | 'dataset_type': 'HDF5Dataset', 16 | 'normalize_image': True, 17 | 'dataset_kwargs': { 18 | 'transforms': { 19 | 
'CenterCrop': { 20 | 'size': 160 21 | } 22 | } 23 | } 24 | } 25 | } 26 | }, 27 | 'optimizer': { 28 | 'type': 'Adam', 29 | 'args': { 30 | 'lr': 0.0001, 31 | 'weight_decay': 0, 32 | 'amsgrad': True 33 | } 34 | }, 35 | 'loss_ftns': { 36 | 'perceptual_loss': { 37 | 'weight': 1.0 38 | }, 39 | 'temporal_consistency_loss': { 40 | 'weight': 1.0 41 | } 42 | }, 43 | 'lr_scheduler': { 44 | 'type': 'StepLR', 45 | 'args': { 46 | 'step_size': 50, 47 | 'gamma': 1.0 48 | } 49 | }, 50 | 'trainer': { 51 | 'epochs': 1, 52 | 'save_dir': '/tmp/inference', 53 | 'save_period': 1, 54 | 'verbosity': 2, 55 | 'monitor': 'min val_loss', 56 | 'num_previews': 4, 57 | 'val_num_previews': 8, 58 | 'tensorboard': True 59 | } 60 | } 61 | 62 | -------------------------------------------------------------------------------- /utils/default_config.py: -------------------------------------------------------------------------------- 1 | default_config = { 2 | 'name': 'inference', 3 | 'n_gpu': 1, 4 | 'arch': { 5 | 'args': {} 6 | }, 7 | 'valid_data_loader': { 8 | 'type': 'HDF5DataLoader', 9 | 'args': { 10 | 'batch_size': 1, 11 | 'shuffle': False, 12 | 'num_workers': 1, 13 | 'pin_memory': True, 14 | 'sequence_kwargs': { 15 | 'dataset_type': 'HDF5Dataset', 16 | 'normalize_image': True, 17 | 'dataset_kwargs': { 18 | 'transforms': { 19 | 'CenterCrop': { 20 | 'size': 160 21 | } 22 | } 23 | } 24 | } 25 | } 26 | }, 27 | 'optimizer': { 28 | 'type': 'Adam', 29 | 'args': { 30 | 'lr': 0.0001, 31 | 'weight_decay': 0, 32 | 'amsgrad': True 33 | } 34 | }, 35 | 'loss_ftns': { 36 | 'perceptual_loss': { 37 | 'weight': 1.0 38 | }, 39 | 'temporal_consistency_loss': { 40 | 'weight': 1.0 41 | } 42 | }, 43 | 'lr_scheduler': { 44 | 'type': 'StepLR', 45 | 'args': { 46 | 'step_size': 50, 47 | 'gamma': 1.0 48 | } 49 | }, 50 | 'trainer': { 51 | 'epochs': 1, 52 | 'save_dir': '/tmp/inference', 53 | 'save_period': 1, 54 | 'verbosity': 2, 55 | 'monitor': 'min val_loss', 56 | 'num_previews': 4, 57 | 'val_num_previews': 8, 58 | 'tensorboard': True 59 | } 60 | } 61 | 62 | -------------------------------------------------------------------------------- /utils/extract_images_MMP.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import os 4 | import numpy as np 5 | from util import setup_output_folder, append_timestamp 6 | from os.path import join 7 | from tqdm import tqdm 8 | 9 | 10 | def load_data(data_path, timestamp_fname='timestamps.npy', image_fname='images.npy'): 11 | 12 | assert os.path.isdir(data_path), '%s is not a valid data_pathectory' % data_path 13 | 14 | data = {} 15 | for subroot, _, fnames in sorted(os.walk(data_path)): 16 | for fname in sorted(fnames): 17 | path = os.path.join(subroot, fname) 18 | if fname.endswith('.npy'): 19 | if fname.endswith(timestamp_fname): 20 | frame_stamps = np.load(path) 21 | data['frame_stamps'] = frame_stamps 22 | elif fname.endswith(image_fname): 23 | data['images'] = np.load(path, mmap_mode='r') # N x H x W x C 24 | return data 25 | 26 | 27 | def save_images(data, output_folder, ts_path): 28 | for i, (image, ts) in enumerate(zip(tqdm(data['images']), data['frame_stamps'])): 29 | fname = 'frame_{:010d}.png'.format(i) 30 | cv2.imwrite(join(output_folder, fname), image) 31 | append_timestamp(ts_path, fname, ts) 32 | 33 | 34 | def main(args): 35 | data = load_data(args.data_path) 36 | ts_path = setup_output_folder(args.output_folder) 37 | save_images(data, args.output_folder, ts_path) 38 | 39 | 40 | if __name__ == '__main__': 41 | parser = 
argparse.ArgumentParser() 42 | parser.add_argument('data_path', type=str) 43 | parser.add_argument('output_folder', type=str) 44 | args = parser.parse_args() 45 | main(args) 46 | -------------------------------------------------------------------------------- /config/test_hypere2vid_original.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: hypere2vid_original 2 | test_output_dir: results/hypere2vid_original 3 | 4 | module: 5 | loss: 6 | lpips_weight: 1.0 7 | lpips_type: alex 8 | l2_weight: 0 9 | l1_weight: 0 10 | ssim_weight: 0 11 | temporal_consistency_weight: 1.0 12 | 13 | normalize_voxels: false 14 | model: 15 | target: model.hyper_model.HyperE2VID 16 | params: 17 | unet_kwargs: 18 | num_bins: 5 19 | skip_type: sum 20 | recurrent_block_type: convlstm 21 | kernel_size: 5 22 | channel_multiplier: 2 23 | num_encoders: 3 24 | base_num_channels: 32 25 | num_residual_blocks: 2 26 | use_upsample_conv: true 27 | norm: none 28 | num_output_channels: 1 29 | use_dynamic_decoder: true # Key difference of HyperE2VID 30 | 31 | test_stage: 32 | test_batch_size: 1 33 | test_num_workers: 4 34 | test: 35 | - data_file: config/evaid_test.txt 36 | class_name: data.testh5.TestH5Dataset 37 | dataset_name: evaid 38 | num_bins: 5 39 | sequence_length: 80 40 | interpolate_bins: true 41 | - data_file: config/ijrr_test.txt 42 | class_name: data.testh5.TestH5Dataset 43 | dataset_name: ijrr 44 | num_bins: 5 45 | sequence_length: 80 46 | interpolate_bins: true 47 | - data_file: config/hqf_test.txt 48 | class_name: data.testh5.TestH5Dataset 49 | dataset_name: hqf 50 | num_bins: 5 51 | sequence_length: 80 52 | interpolate_bins: true 53 | - data_file: config/mvsec_test.txt 54 | class_name: data.testh5.TestH5Dataset 55 | dataset_name: mvsec 56 | num_bins: 5 57 | sequence_length: 80 58 | interpolate_bins: true -------------------------------------------------------------------------------- /config/test_e2vid++_original.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: e2vid++_original 2 | test_output_dir: results/e2vid++_original 3 | 4 | module: 5 | lr_scheduler: 6 | target: torch.optim.lr_scheduler.ReduceLROnPlateau 7 | params: 8 | mode: 'min' 9 | factor: 0.5 # 当触发条件满足时,学习率将乘以此因子 10 | patience: 10 # 在触发学习率下降之前,允许的连续无改进的epoch数 11 | 12 | loss: 13 | lpips_weight: 1.0 14 | lpips_type: alex 15 | l2_weight: 0 16 | l1_weight: 0 17 | ssim_weight: 0 18 | temporal_consistency_weight: 1.0 19 | 20 | normalize_voxels: false 21 | model: 22 | target: model.model.FlowNet 23 | params: 24 | unet_kwargs: 25 | num_bins: 5 26 | skip_type: sum 27 | recurrent_block_type: convlstm 28 | num_encoders: 3 29 | base_num_channels: 32 30 | num_residual_blocks: 2 31 | use_upsample_conv: true 32 | norm: none 33 | num_output_channels: 3 34 | 35 | test_stage: 36 | test_num_workers: 4 37 | test: 38 | - data_file: config/evaid_test.txt 39 | class_name: data.testh5.TestH5Dataset 40 | dataset_name: evaid 41 | num_bins: 5 42 | sequence_length: 80 43 | interpolate_bins: true 44 | - data_file: config/ijrr_test.txt 45 | class_name: data.testh5.TestH5Dataset 46 | dataset_name: ijrr 47 | num_bins: 5 48 | sequence_length: 80 49 | interpolate_bins: true 50 | - data_file: config/hqf_test.txt 51 | class_name: data.testh5.TestH5Dataset 52 | dataset_name: hqf 53 | num_bins: 5 54 | sequence_length: 80 55 | interpolate_bins: true 56 | - data_file: config/mvsec_test.txt 57 | class_name: data.testh5.TestH5Dataset 58 | dataset_name: mvsec 59 | num_bins: 5 60 | 
sequence_length: 80 61 | interpolate_bins: true -------------------------------------------------------------------------------- /utils/timers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import numpy as np 4 | import atexit 5 | from collections import defaultdict 6 | 7 | cuda_timers = defaultdict(list) 8 | timers = defaultdict(list) 9 | 10 | 11 | class CudaTimer: 12 | def __init__(self, timer_name=''): 13 | self.timer_name = timer_name 14 | 15 | self.start = torch.cuda.Event(enable_timing=True) 16 | self.end = torch.cuda.Event(enable_timing=True) 17 | 18 | def __enter__(self): 19 | self.start.record() 20 | return self 21 | 22 | def __exit__(self, *args): 23 | self.end.record() 24 | torch.cuda.synchronize() 25 | cuda_timers[self.timer_name].append(self.start.elapsed_time(self.end)) 26 | 27 | 28 | class Timer: 29 | def __init__(self, timer_name=''): 30 | self.timer_name = timer_name 31 | 32 | def __enter__(self): 33 | self.start = time.time() 34 | return self 35 | 36 | def __exit__(self, *args): 37 | self.end = time.time() 38 | self.interval = self.end - self.start # measured in seconds 39 | self.interval *= 1000.0 # convert to milliseconds 40 | timers[self.timer_name].append(self.interval) 41 | 42 | 43 | def print_timing_info(): 44 | print('== Timing statistics ==') 45 | for timer_name, timing_values in [*cuda_timers.items(), *timers.items()]: 46 | timing_value = np.mean(np.array(timing_values)) 47 | if timing_value < 1000.0: 48 | print('{}: {:.2f} ms ({} samples)'.format(timer_name, timing_value, len(timing_values))) 49 | else: 50 | print('{}: {:.2f} s ({} samples)'.format(timer_name, timing_value / 1000.0, len(timing_values))) 51 | 52 | 53 | # this will print all the timer values upon termination of any program that imported this file 54 | atexit.register(print_timing_info) 55 | -------------------------------------------------------------------------------- /PerceptualSimilarity/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.autograd import Variable 4 | from pdb import set_trace as st 5 | from IPython import embed 6 | 7 | class BaseModel(): 8 | def __init__(self): 9 | pass; 10 | 11 | def name(self): 12 | return 'BaseModel' 13 | 14 | def initialize(self, use_gpu=True, gpu_ids=[0]): 15 | self.use_gpu = use_gpu 16 | self.gpu_ids = gpu_ids 17 | 18 | def forward(self): 19 | pass 20 | 21 | def get_image_paths(self): 22 | pass 23 | 24 | def optimize_parameters(self): 25 | pass 26 | 27 | def get_current_visuals(self): 28 | return self.input 29 | 30 | def get_current_errors(self): 31 | return {} 32 | 33 | def save(self, label): 34 | pass 35 | 36 | # helper saving function that can be used by subclasses 37 | def save_network(self, network, path, network_label, epoch_label): 38 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 39 | save_path = os.path.join(path, save_filename) 40 | torch.save(network.state_dict(), save_path) 41 | 42 | # helper loading function that can be used by subclasses 43 | def load_network(self, network, network_label, epoch_label): 44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 45 | save_path = os.path.join(self.save_dir, save_filename) 46 | print('Loading network from %s'%save_path) 47 | network.load_state_dict(torch.load(save_path)) 48 | 49 | def update_learning_rate(): 50 | pass 51 | 52 | def get_image_paths(self): 53 | return self.image_paths 54 | 55 | def 
save_done(self, flag=False): 56 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 57 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 58 | 59 | -------------------------------------------------------------------------------- /config/test_eraft_original.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: eraft_original 2 | test_output_dir: results/eraft_original 3 | save_npy: false 4 | save_png: true 5 | use_compile: false 6 | task: flow 7 | 8 | module: 9 | forward_type: eraft 10 | 11 | loss: 12 | l1_weight: 1.0 13 | optical_flow_source: raft_large 14 | raft_num_flow_updates: 12 15 | 16 | normalize_voxels: false 17 | model: 18 | target: model.eraft.eraft.ERAFT 19 | params: 20 | config: 21 | subtype: warm_start 22 | n_first_channels: 15 23 | 24 | test_stage: 25 | test_batch_size: 1 26 | test_num_workers: 16 27 | test: 28 | - data_file: config/evaid_test.txt 29 | class_name: data.testh5.TestH5Dataset 30 | dataset_name: evaid 31 | num_bins: 15 32 | sequence_length: 10 33 | interpolate_bins: true 34 | output_additional_frame: true 35 | output_additional_evs: true 36 | image_range: 1 37 | max_samples: 2 38 | - data_file: config/ijrr_test.txt 39 | class_name: data.testh5.TestH5Dataset 40 | dataset_name: ijrr 41 | num_bins: 15 42 | sequence_length: 80 43 | interpolate_bins: true 44 | output_additional_frame: true 45 | output_additional_evs: true 46 | image_range: 1 47 | max_samples: 2 48 | - data_file: config/hqf_test.txt 49 | class_name: data.testh5.TestH5Dataset 50 | dataset_name: hqf 51 | num_bins: 15 52 | sequence_length: 80 53 | interpolate_bins: true 54 | output_additional_frame: true 55 | output_additional_evs: true 56 | image_range: 1 57 | max_samples: 2 58 | - data_file: config/mvsec_test_flow.txt 59 | class_name: data.testh5.TestH5FlowDataset 60 | dataset_name: mvsec 61 | num_bins: 15 62 | sequence_length: 80 63 | interpolate_bins: true 64 | output_additional_frame: true 65 | output_additional_evs: true 66 | image_range: 1 -------------------------------------------------------------------------------- /config/test_evflow_original.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: evflow_original 2 | test_output_dir: results/evflow_original 3 | use_compile: false 4 | task: flow 5 | 6 | module: 7 | loss: 8 | l1_weight: 1.0 9 | optical_flow_source: raft_large 10 | raft_num_flow_updates: 12 11 | 12 | normalize_voxels: false 13 | model: 14 | target: model.model.EVFlowNet 15 | params: 16 | unet_kwargs: 17 | num_bins: 5 18 | base_num_channels: 32 19 | num_encoders: 4 20 | num_residual_blocks: 2 21 | num_output_channels: 2 22 | skip_type: concat 23 | norm: null 24 | use_upsample_conv: true 25 | kernel_size: 3 26 | channel_multiplier: 2 27 | 28 | test_stage: 29 | test_batch_size: 1 30 | test_num_workers: 16 31 | test: 32 | - data_file: config/evaid_test.txt 33 | class_name: data.testh5.TestH5Dataset 34 | dataset_name: evaid 35 | num_bins: 5 36 | sequence_length: 10 37 | interpolate_bins: true 38 | output_additional_frame: true 39 | image_range: 1 40 | max_samples: 2 41 | - data_file: config/ijrr_test.txt 42 | class_name: data.testh5.TestH5Dataset 43 | dataset_name: ijrr 44 | num_bins: 5 45 | sequence_length: 80 46 | interpolate_bins: true 47 | output_additional_frame: true 48 | image_range: 1 49 | max_samples: 2 50 | - data_file: config/hqf_test.txt 51 | class_name: data.testh5.TestH5Dataset 52 | dataset_name: hqf 53 | num_bins: 5 54 | sequence_length: 80 
55 | interpolate_bins: true 56 | output_additional_frame: true 57 | image_range: 1 58 | max_samples: 2 59 | - data_file: config/mvsec_test_flow.txt 60 | class_name: data.testh5.TestH5FlowDataset 61 | dataset_name: mvsec 62 | num_bins: 5 63 | sequence_length: 80 64 | interpolate_bins: true 65 | output_additional_frame: true 66 | image_range: 1 -------------------------------------------------------------------------------- /config/test_nernet_original.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: nernet_original 2 | test_output_dir: results/nernet_original 3 | 4 | module: 5 | is_nernet: true 6 | loss: 7 | lpips_weight: 1.0 8 | lpips_type: alex 9 | l2_weight: 0 10 | l1_weight: 0 11 | ssim_weight: 0 12 | temporal_consistency_weight: 1.0 13 | 14 | normalize_voxels: false 15 | model: 16 | target: model.nernet_model.RepresentationRecurrent 17 | params: 18 | unet_kwargs: 19 | num_bins: 5 20 | skip_type: "sum" 21 | recurrent_network: "NIAM_STcell_GCB" 22 | recurrent_block_type: "" 23 | num_encoders: 3 24 | base_num_channels: 32 25 | num_residual_blocks: 2 26 | use_upsample_conv: true 27 | norm: "" 28 | crop_size: 224 29 | mlp_layers: 30 | - 1 31 | - 50 32 | - 50 33 | - 50 34 | - 1 35 | use_cnn_representation: true 36 | normalize: false 37 | combine_voxel: false 38 | RepCNN_kernel_size: 3 39 | RepCNN_padding: 1 40 | RepCNN_channel: 64 41 | RepCNN_num_layers: 1 42 | num_output_channels: 1 43 | 44 | test_stage: 45 | test_batch_size: 1 46 | test_num_workers: 0 47 | test: 48 | - data_file: config/evaid_test.txt 49 | class_name: data.testh5.TestH5EventDataset 50 | dataset_name: evaid 51 | num_bins: 5 52 | sequence_length: 80 53 | interpolate_bins: true 54 | - data_file: config/ijrr_test.txt 55 | class_name: data.testh5.TestH5EventDataset 56 | dataset_name: ijrr 57 | num_bins: 5 58 | sequence_length: 80 59 | interpolate_bins: true 60 | - data_file: config/hqf_test.txt 61 | class_name: data.testh5.TestH5EventDataset 62 | dataset_name: hqf 63 | num_bins: 5 64 | sequence_length: 80 65 | interpolate_bins: true 66 | - data_file: config/mvsec_test.txt 67 | class_name: data.testh5.TestH5EventDataset 68 | dataset_name: mvsec 69 | num_bins: 5 70 | sequence_length: 80 71 | interpolate_bins: true -------------------------------------------------------------------------------- /scripts/esim_to_voxel.py: -------------------------------------------------------------------------------- 1 | # Code used to convert h5 ESIM events to cached h5 voxels. 
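# Each output file stores stacked "frames", "flow", "events", "timestamps" and "dt" arrays
# (see the create_dataset calls below), so the voxelization does not need to be redone every time the data is read.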
2 | 3 | import os 4 | import sys 5 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 6 | 7 | from data.dataset import DynamicH5Dataset 8 | import glob 9 | import h5py 10 | import numpy as np 11 | import tqdm 12 | 13 | original_paths = sorted(glob.glob("/mnt/ssd/esim_h5/*.h5")) 14 | os.makedirs("/mnt/ssd/esim_voxel_nobi", exist_ok=True) 15 | os.makedirs("/mnt/ssd/esim_voxel_cache", exist_ok=True) 16 | 17 | def convert(temporal_bilinear, output_file_name): 18 | for p in tqdm.tqdm(original_paths): 19 | out_path = p.replace("esim_h5", output_file_name) 20 | print("Saving to ", out_path) 21 | dataset = DynamicH5Dataset(data_path=p, temporal_bilinear=temporal_bilinear) 22 | 23 | all_frames = [] 24 | all_flow = [] 25 | all_events = [] 26 | all_timestamps = [] 27 | all_dt = [] 28 | 29 | for i in range(len(dataset)): 30 | item = dataset[i] 31 | all_frames.append(item["frame"].numpy()) 32 | all_flow.append(item["flow"].numpy()) 33 | all_events.append(item["events"].numpy()) 34 | all_timestamps.append(item["timestamp"]) 35 | all_dt.append(item["dt"]) 36 | 37 | with h5py.File(out_path, "w") as f: 38 | 39 | all_frames = np.stack(all_frames) 40 | all_flow = np.stack(all_flow) 41 | all_events = np.stack(all_events) 42 | all_timestamps = np.stack(all_timestamps) 43 | all_dt = np.stack(all_dt) 44 | 45 | f.attrs["sensor_resolution"] = dataset.sensor_resolution 46 | f.attrs["source"] = "esim" 47 | f.create_dataset(f"frames", data=all_frames, dtype=np.float32) 48 | f.create_dataset(f"flow", data=all_flow, dtype=np.float32) 49 | f.create_dataset(f"events", data=all_events, dtype=np.float32) 50 | f.create_dataset(f"timestamps", data=all_timestamps, dtype=np.float32) 51 | f.create_dataset(f"dt", data=all_dt, dtype=np.float32) 52 | print(all_frames.shape) 53 | print(all_flow.shape) 54 | print(all_events.shape) 55 | print(all_timestamps.shape) 56 | print(all_dt.shape) 57 | 58 | convert(True, "esim_voxel_cache") # Already done with previous code, not tested with this code yet 59 | convert(False, "esim_voxel_nobi") -------------------------------------------------------------------------------- /model/eraft/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .utils import bilinear_sampler, coords_grid 4 | 5 | try: 6 | import alt_cuda_corr 7 | except: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape 22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 23 | 24 | self.corr_pyramid.append(corr) 25 | for i in range(self.num_levels-1): 26 | corr = F.avg_pool2d(corr, 2, stride=2) 27 | self.corr_pyramid.append(corr) 28 | 29 | def __call__(self, coords): 30 | r = self.radius 31 | coords = coords.permute(0, 2, 3, 1) 32 | batch, h1, w1, _ = coords.shape 33 | 34 | out_pyramid = [] 35 | for i in range(self.num_levels): 36 | corr = self.corr_pyramid[i] 37 | dx = torch.linspace(-r, r, 2*r+1) 38 | dy = torch.linspace(-r, r, 2*r+1) 39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 40 | 41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 43 | coords_lvl = centroid_lvl + delta_lvl 44 | 45 | corr = bilinear_sampler(corr, 
coords_lvl) 46 | corr = corr.view(batch, h1, w1, -1) 47 | out_pyramid.append(corr) 48 | 49 | out = torch.cat(out_pyramid, dim=-1) 50 | return out.permute(0, 3, 1, 2).contiguous().float() 51 | 52 | @staticmethod 53 | def corr(fmap1, fmap2): 54 | batch, dim, ht, wd = fmap1.shape 55 | fmap1 = fmap1.view(batch, dim, ht*wd) 56 | fmap2 = fmap2.view(batch, dim, ht*wd) 57 | 58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 59 | corr = corr.view(batch, ht, wd, 1, ht, wd) 60 | return corr / torch.sqrt(torch.tensor(dim).float()) 61 | -------------------------------------------------------------------------------- /scripts/testset_evcnt_maps.py: -------------------------------------------------------------------------------- 1 | # Produce visualizations for Figure 10 of the V2V paper. 2 | 3 | import tqdm 4 | import h5py 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import os 8 | from mpl_toolkits.axes_grid1 import make_axes_locatable 9 | from matplotlib.colors import LogNorm 10 | 11 | def read_txt(filename): 12 | files = [] 13 | with open(filename, "r") as f: 14 | for line in f: 15 | files.append(line.strip()) 16 | return files 17 | 18 | def make_plot(h5file, out_path): 19 | with h5py.File(h5file, "r") as f: 20 | img_keys = sorted(f["images"].keys()) 21 | H, W = f["images"][img_keys[0]].shape 22 | evcnt = np.zeros((H, W)) 23 | xs = f["events/xs"][:] 24 | ys = f["events/ys"][:] 25 | np.add.at(evcnt, (ys, xs), 1) 26 | # Use log scale for visualization but show actual values 27 | evcnt_log = np.log1p(evcnt) # log(1+x) handles zeros gracefully 28 | im = plt.imshow(evcnt_log, cmap="jet", vmin=0, vmax=np.max(evcnt_log)) 29 | ax = plt.gca() 30 | divider = make_axes_locatable(ax) 31 | cax = divider.append_axes("right", size="5%", pad=0.05) 32 | cbar = plt.colorbar(im, cax=cax) 33 | 34 | # Create custom ticks that show original values 35 | max_val = evcnt.max() 36 | if max_val > 10000: 37 | ticks = [1, 10, 100, 1000, 10000] 38 | if max_val > 1000: 39 | ticks = [1, 10, 100, 1000] 40 | elif max_val > 100: 41 | ticks = [1, 10, 100] 42 | else: 43 | ticks = [1, 10] 44 | ticks = [t for t in ticks if t <= max_val] 45 | log_ticks = [np.log1p(t) for t in ticks] # Convert to log scale 46 | cbar.set_ticks(log_ticks) 47 | cbar.set_ticklabels([str(t) for t in ticks]) # Show original values 48 | cbar.ax.tick_params(labelsize=8) 49 | ax.axis("off") 50 | # Add a dummy label in white font with number 10000, so all images have the same width 51 | ax.text(0.5, 0.5, "100000000", color="white", fontsize=8, ha="left", va="top", transform=cax.transAxes) 52 | 53 | plt.savefig(out_path, bbox_inches="tight", pad_inches=0, dpi=300) 54 | plt.close() 55 | 56 | for classname, filepath in [ 57 | ("MVSEC", "config/mvsec_test.txt"), 58 | ("IJRR", "config/ijrr_test.txt"), 59 | ("HQF", "config/hqf_test.txt"), 60 | ("EVAID", "config/evaid_test.txt"), 61 | ]: 62 | files = read_txt(filepath) 63 | for file in tqdm.tqdm(files): 64 | seqname = file.split("/")[-1].split(".")[0] 65 | out_path = f"videos/event_count_maps/{classname}/{seqname}.png" 66 | os.makedirs(os.path.dirname(out_path), exist_ok=True) 67 | make_plot(file, out_path) 68 | -------------------------------------------------------------------------------- /data/v2v_core_esim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def reverse_gamma_correction(imgs, gamma=2.2): 4 | return (imgs / 255) ** gamma * 255 5 | 6 | class EventEmulator(object): 7 | 8 | def __init__( 9 | self, 10 | pos_thres: float = 
0.2, 11 | neg_thres: float = 0.2, 12 | base_noise_std: float = 0.1, 13 | hot_pixel_fraction: float = 0.001, 14 | hot_pixel_std: float = 0.1, 15 | put_noise_external: bool = False, 16 | seed: int = None, 17 | ): 18 | self.pos_threshold = pos_thres 19 | self.neg_threshold = neg_thres 20 | self.base_noise_std = base_noise_std 21 | self.hot_pixel_fraction = hot_pixel_fraction 22 | self.hot_pixel_std = hot_pixel_std 23 | self.put_noise_external = put_noise_external 24 | self.seed = seed 25 | 26 | def video_to_voxel(self, video): 27 | N, H, W = video.shape 28 | # Initialize the potential uniform random between -neg_thres and pos_thres 29 | self.potential = np.random.rand(H, W) * (self.pos_threshold + self.neg_threshold) - self.neg_threshold 30 | 31 | all_voxels = [] 32 | # Reverse gamma correction will make video more linear. 33 | video = reverse_gamma_correction(video) 34 | log_imgs = np.log(0.001 + video/255.0) 35 | 36 | # The hot noise persists for the entire video 37 | hot_pixel_mask = np.random.rand(H, W) < self.hot_pixel_fraction 38 | hot_noise = self.hot_pixel_std * np.random.randn(H, W) 39 | hot_noise = np.where(hot_pixel_mask, hot_noise, 0) 40 | 41 | for i in range(N-1): 42 | diff = log_imgs[i+1] - log_imgs[i] 43 | self.potential += diff 44 | base_noise = self.base_noise_std * np.random.randn(H, W) 45 | 46 | if not self.put_noise_external: 47 | # The noise influences the potential. 48 | self.potential += base_noise 49 | self.potential += hot_noise 50 | 51 | pos_events = np.floor_divide(self.potential, self.pos_threshold) 52 | pos_events = np.where(self.potential >= self.pos_threshold, pos_events, 0) 53 | 54 | neg_events = np.floor_divide(-self.potential, self.neg_threshold) 55 | neg_events = np.where(self.potential <= -self.neg_threshold, neg_events, 0) 56 | 57 | self.potential -= pos_events * self.pos_threshold 58 | self.potential += neg_events * self.neg_threshold 59 | 60 | voxel = pos_events - neg_events 61 | 62 | if self.put_noise_external: 63 | # Directly add the noise (a float) to the voxel output 64 | voxel = voxel + base_noise 65 | voxel = voxel + hot_noise 66 | 67 | all_voxels.append(voxel) 68 | 69 | return np.array(all_voxels) 70 | -------------------------------------------------------------------------------- /scripts/result_to_col.py: -------------------------------------------------------------------------------- 1 | # Tool script that converts a row from a test result to a column (easy to copy & paste into a feishu doc / excel sheet). 
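# Usage (as implied by the __main__ block below): python scripts/result_to_col.py <result_csv> <line_number>.
# Note that SSIM values are negated when written out, and per-dataset averages go to a separate avg_metrics file.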
2 | 3 | import os 4 | import sys 5 | 6 | sequences = { 7 | "HQF": [ 8 | "bike_bay_hdr", 9 | "boxes", 10 | "desk", 11 | "desk_fast", 12 | "desk_hand_only", 13 | "desk_slow", 14 | "engineering_posters", 15 | "high_texture_plants", 16 | "poster_pillar_1", 17 | "poster_pillar_2", 18 | "reflective_materials", 19 | "slow_and_fast_desk", 20 | "slow_hand", 21 | "still_life" 22 | ], 23 | "EVAID": [ 24 | "ball", 25 | "bear", 26 | "box", 27 | "building", 28 | "outdoor", 29 | "playball", 30 | "room1", 31 | "sculpture", 32 | "toy", 33 | "traffic", 34 | "wall" 35 | ], 36 | "IJRR": [ 37 | "boxes_6dof", 38 | "calibration", 39 | "dynamic_6dof", 40 | "office_zigzag", 41 | "poster_6dof", 42 | "shapes_6dof", 43 | "slider_depth" 44 | ], 45 | "MVSEC": [ 46 | "indoor_flying1", 47 | "indoor_flying2", 48 | "indoor_flying3", 49 | "indoor_flying4", 50 | "outdoor_day1", 51 | "outdoor_day2" 52 | ] 53 | } 54 | 55 | all_metric_names = [] 56 | for dataset in ["HQF", "EVAID"]: 57 | for seqname in sequences[dataset]: 58 | for metric in ["MSE", "SSIM", "LPIPS"]: 59 | all_metric_names.append(f"{dataset}/{seqname}/{metric}") 60 | 61 | with open("debug/col_heads.txt", "w") as f: 62 | for key in all_metric_names: 63 | f.write(f"{key}\n") 64 | 65 | # line_n is the line number in vscode 66 | def extract_line(file_name, line_n): 67 | with open(file_name, "r", encoding="UTF-8") as f: 68 | lines = f.readlines() 69 | head = lines[0].split(",") 70 | 71 | data = {} 72 | dataline = lines[line_n-1].split(",") 73 | for i in range(len(dataline)): 74 | data[head[i]] = dataline[i] 75 | 76 | of_path = os.path.join(os.path.dirname(file_name), f"col_from_line_{line_n:03d}_{dataline[0]}.txt") 77 | 78 | with open(of_path, "w", encoding="UTF-8") as of: 79 | for key in all_metric_names: 80 | val = float(data[key]) 81 | if "SSIM" in key: 82 | val = -val 83 | of.write(f"{val:.03f}\n") 84 | 85 | # In avg_metrics_from_line_{line_n}.txt, each line is average over all sequence, e.g. HQF/MSE. 86 | metric_pth = os.path.join(os.path.dirname(file_name), f"avg_metrics_from_line_{line_n:03d}_{dataline[0]}.txt") 87 | with open(metric_pth, "w") as of: 88 | for dataset in ["HQF", "EVAID", "IJRR", "MVSEC"]: 89 | for metric in ["MSE", "SSIM", "LPIPS"]: 90 | vals = [v for k, v in data.items() if dataset in k and metric in k] 91 | avg = sum([float(v) for v in vals]) / len(vals) 92 | of.write(f"{avg:.03f}\n") 93 | 94 | if __name__ == "__main__": 95 | file_name = sys.argv[1] 96 | line_n = int(sys.argv[2]) 97 | extract_line(file_name, line_n) -------------------------------------------------------------------------------- /scripts/aedat4_to_h5.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import dv_processing as dv 4 | import numpy as np 5 | 6 | # Read from a aedat4 file. 7 | # Only keep data from seconds [begin] to [end]. 8 | # Save to TestH5Dataset format. 
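# Example call (hypothetical paths), mirroring the maque.aedat4 conversion at the bottom of this file:
#   convert("recording.aedat4", "recording.h5", begin=0, end=10)  # keep only the first 10 seconds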
9 | 10 | def convert(aedat4_file, h5_file, begin, end): 11 | reader = dv.io.MonoCameraRecording(aedat4_file) 12 | # Run the loop while camera is still connected 13 | base_time = None 14 | all_x = [] 15 | all_y = [] 16 | all_t = [] 17 | all_p = [] 18 | while reader.isRunning(): 19 | # Read batch of events 20 | events = reader.getNextEventBatch() 21 | if events is not None: 22 | # Convert the packet to a structured numpy array 23 | evs = events.numpy() 24 | # print(evs.dtype) 25 | # evs fields: 'timestamp' (microseconds), 'x', 'y', 'polarity' 26 | if base_time is None: 27 | base_time = evs['timestamp'][0] 28 | 29 | timestamp = (evs['timestamp'][0] - base_time) / 1e6 30 | if timestamp < begin or timestamp > end: 31 | continue 32 | 33 | all_x.append(evs['x']) 34 | all_y.append(evs['y']) 35 | all_t.append(evs['timestamp']) 36 | all_p.append(evs['polarity']) 37 | 38 | reader = dv.io.MonoCameraRecording(aedat4_file) 39 | # Read the images 40 | all_imgs = [] 41 | img_timestamps = [] 42 | while reader.isRunning(): 43 | # Read a frame from the camera 44 | frame = reader.getNextFrame() 45 | 46 | if frame is not None: 47 | timestamp = (frame.timestamp - base_time) / 1e6 48 | if timestamp < begin or timestamp > end: 49 | continue 50 | all_imgs.append(frame.image) 51 | img_timestamps.append(frame.timestamp) 52 | 53 | # Save to h5 file 54 | all_x = np.concatenate(all_x) 55 | all_y = np.concatenate(all_y) 56 | base_t = all_t[0][0] 57 | all_t = (np.concatenate(all_t) - base_t).astype(np.float64) / 1e6 58 | print(all_t[0]) 59 | print(all_t[-1]) 60 | 61 | all_p = np.concatenate(all_p) 62 | # Frame timestamps are raw microseconds; convert them to seconds relative to base_t so they match all_t 63 | img_event_idxs = np.searchsorted(all_t, (np.asarray(img_timestamps) - base_t) / 1e6) 64 | 65 | with h5py.File(h5_file, 'w') as f: 66 | f.create_dataset("events/ts", data=all_t, dtype=np.float32) 67 | f.create_dataset("events/xs", data=all_x, dtype=np.int16) 68 | f.create_dataset("events/ys", data=all_y, dtype=np.int16) 69 | f.create_dataset("events/ps", data=all_p, dtype=np.bool_) 70 | 71 | for i, img in enumerate(all_imgs): 72 | f.create_dataset(f"images/{i:06d}", data=img) 73 | # Set attribute f["images"][f"{i:06d}"].attrs["event_idx"] = img_event_idxs[i] 74 | f["images"][f"{i:06d}"].attrs.create("event_idx", img_event_idxs[i]) 75 | 76 | 77 | src_dir = "/mnt/nas-cp/hylou/Datasets/EvBirding/20250331/raw" 78 | dst_dir = "/mnt/nas-cp/hylou/Datasets/EvBirding/20250331/h5" 79 | convert(f"{src_dir}/maque.aedat4", f"{dst_dir}/maque.h5", 0, 10) -------------------------------------------------------------------------------- /scripts/make_ref_videos.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import glob 4 | import os 5 | import ffmpeg 6 | import sys 7 | 8 | def make_video(src_dirs, dataset, seqname, dst_dir): 9 | # Read all images in {src_dir}/{dataset}/{seqname}/*.png 10 | src_img_paths = [ 11 | sorted(glob.glob(f"{src_dir}/{dataset}/{seqname}/*.png")) 12 | for src_dir in src_dirs 13 | ] 14 | assert all(len(src_img_paths[0]) == len(src_img_paths[i]) for i in range(1, len(src_dirs))) 15 | T = len(src_img_paths[0]) 16 | 17 | # Get image dimensions from the first image 18 | img0 = cv2.imread(src_img_paths[0][0]) 19 | H, W, _ = img0.shape 20 | pad_W = W % 4 21 | pad_H = H % 4 22 | H += pad_H 23 | W += pad_W 24 | 25 | os.makedirs(f"{dst_dir}/{dataset}", exist_ok=True) 26 | output_path = f"{dst_dir}/{dataset}/{seqname}.mp4" 27 | 28 | # Start ffmpeg process with highest quality settings 29 | process = ( 30 | ffmpeg 31 | .input('pipe:', format='rawvideo', pix_fmt='rgb24', s=f'{W*len(src_dirs)}x{H}', r=24) 32 | .output(output_path, 33 | vcodec='libx264', # Specify H.264 codec 34 | pix_fmt='yuv420p', # Keep yuv420p for compatibility 35 | r=24, 36 |
movflags='faststart', # Allow streaming 37 | tune='film', # Optimize for high quality video 38 | **{'b:v': '0'} # Let CRF control quality 39 | ) 40 | .overwrite_output() 41 | .run_async(pipe_stdin=True) 42 | ) 43 | 44 | try: 45 | # Process and write frames one by one 46 | for t in range(T): 47 | imgs = [cv2.imread(src_img_paths[i][t]) for i in range(len(src_dirs))] 48 | imgs = [ 49 | cv2.copyMakeBorder(img, 0, pad_H, 0, pad_W, cv2.BORDER_CONSTANT, value=[0, 0, 0]) 50 | for img in imgs 51 | ] 52 | img = np.concatenate(imgs, axis=1) 53 | # Convert from BGR to RGB for ffmpeg 54 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 55 | process.stdin.write(img.tobytes()) 56 | finally: 57 | # Ensure pipe is closed and process ends 58 | process.stdin.close() 59 | process.wait() 60 | 61 | def make_videos(experiment_name): 62 | ref_dir = "results/e2vid++_original" 63 | my_dir = f"results/{experiment_name}" 64 | dst_dir = f"videos/{experiment_name}" 65 | gt_dir = "results/gt_images" 66 | 67 | datasets = ["EVAID", "IJRR", "HQF", "MVSEC"] 68 | for dataset in datasets: 69 | seqnames = os.listdir(f"{ref_dir}/{dataset}") 70 | for seqname in seqnames: 71 | make_video([my_dir, ref_dir, gt_dir], dataset, seqname, dst_dir) 72 | 73 | def make_videos_comb(out_name, experiment_name_list): 74 | 75 | all_dirs = [f"results/{experiment_name}" for experiment_name in experiment_name_list] 76 | 77 | datasets = ["EVAID", "IJRR", "HQF", "MVSEC"] 78 | for dataset in datasets: 79 | seqnames = os.listdir(f"{all_dirs[0]}/{dataset}") 80 | for seqname in seqnames: 81 | make_video(all_dirs, dataset, seqname, f"videos/{out_name}") 82 | 83 | if __name__ == "__main__": 84 | # Take the experiment name as argument 85 | if len(sys.argv) > 1: 86 | experiment_name = sys.argv[1] 87 | make_videos(experiment_name) 88 | else: 89 | print("Please provide the experiment name as argument.") -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from tqdm import tqdm 4 | from torch.utils.data import ConcatDataset 5 | 6 | 7 | data_sources = ('esim', 'ijrr', 'mvsec', 'eccd', 'hqf', 'unknown', 'reds', 'sportsslomo', 'adobe', 'youcook', 'vimeo', 'webvid', 'evbird', 'evaid', 'hs-ergb', 'openvid') 8 | # Usage: name = data_sources[1], idx = data_sources.index('ijrr') 9 | 10 | 11 | def concatenate_subfolders(data_file, dataset, dataset_kwargs): 12 | """ 13 | Create an instance of ConcatDataset by aggregating all the datasets in a given folder 14 | """ 15 | if os.path.isdir(data_file): 16 | subfolders = [os.path.join(data_file, s) for s in os.listdir(data_file)] 17 | elif os.path.isfile(data_file): 18 | subfolders = pd.read_csv(data_file, header=None).values.flatten().tolist() 19 | else: 20 | raise Exception('{} must be data_file.txt or base/folder'.format(data_file)) 21 | print('Found {} samples in {}'.format(len(subfolders), data_file)) 22 | datasets = [] 23 | for subfolder in subfolders: 24 | dataset_kwargs['item_kwargs'].update({'base_folder': subfolder}) 25 | datasets.append(dataset(**dataset_kwargs)) 26 | return ConcatDataset(datasets) 27 | 28 | 29 | def concatenate_datasets(data_file, dataset_type, dataset_kwargs={}): 30 | """ 31 | Generates a dataset for each data_path specified in data_file and concatenates the datasets. 32 | :param data_file: A file containing a list of paths to CTI h5 files. 
33 | Each file is expected to have a sequence of frame_{:09d} 34 | :param dataset_type: Pointer to dataset class 35 | :param sequence_length: Desired length of each sequence 36 | :return ConcatDataset: concatenated dataset of all data_paths in data_file 37 | """ 38 | data_paths = pd.read_csv(data_file, header=None).values.flatten().tolist() 39 | dataset_list = [] 40 | print('Concatenating {} datasets'.format(dataset_type)) 41 | for data_path in tqdm(data_paths): 42 | dataset_list.append(dataset_type(data_path, **dataset_kwargs)) 43 | print("Total samples: ", sum([len(d) for d in dataset_list])) 44 | return ConcatDataset(dataset_list) 45 | 46 | def concatenate_memmap_datasets(data_file, dataset_type, dataset_kwargs): 47 | """ 48 | Generates a dataset for each memmap_path specified in data_file and concatenates the datasets. 49 | :param data_file: A file containing a list of paths to memmap root dirs. 50 | :param dataset_type: Pointer to dataset class 51 | :param dataset_kwargs: Dataset keyword arguments 52 | :return ConcatDataset: concatenated dataset of all memmap_paths in data_file 53 | """ 54 | if dataset_kwargs is None: 55 | dataset_kwargs = {} 56 | 57 | memmap_paths = pd.read_csv(data_file, header=None).values.flatten().tolist() 58 | dataset_list = [] 59 | print('Concatenating {} datasets'.format(dataset_type)) 60 | for memmap_path in tqdm(memmap_paths): 61 | dataset_kwargs['dataset_kwargs'].update({'root': memmap_path}) 62 | dataset_list.append(dataset_type(**dataset_kwargs)) 63 | return ConcatDataset(dataset_list) 64 | -------------------------------------------------------------------------------- /model/eitr/transformer_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import copy 5 | 6 | 7 | class transformer_encoder(nn.Module): 8 | def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, activation='relu', 9 | dim_feedforward=2048, dropout=0.1): 10 | super().__init__() 11 | self.d_model = d_model 12 | self.nhead = nhead 13 | 14 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 15 | dropout, activation) 16 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers) 17 | 18 | self._reset_parameters() 19 | 20 | def _reset_parameters(self): 21 | for p in self.parameters(): 22 | if p.dim() > 1: 23 | nn.init.xavier_uniform_(p) 24 | 25 | def forward(self, src, pos): 26 | output = self.encoder(src, pos) 27 | 28 | return output 29 | 30 | 31 | class TransformerEncoder(nn.Module): 32 | def __init__(self, encoder_layer, num_layers): 33 | super().__init__() 34 | self.layers = _get_clones(encoder_layer, num_layers) 35 | 36 | def with_embed(self, tensor, pos): 37 | return tensor if pos is None else tensor + pos 38 | 39 | def forward(self, src, pos): 40 | output = self.with_embed(src, pos) 41 | 42 | for layer in self.layers: 43 | output = layer(output) 44 | 45 | return output 46 | 47 | 48 | class TransformerEncoderLayer(nn.Module): 49 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 50 | activation="relu"): 51 | super().__init__() 52 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 53 | self.attn_dropout = nn.Dropout(dropout) 54 | self.norm1 = nn.LayerNorm(d_model) 55 | self.linear1 = nn.Linear(d_model, dim_feedforward) 56 | self.activation = _get_activation_fn(activation) 57 | self.ffn_dropout1 = nn.Dropout(dropout) 58 | self.linear2 = nn.Linear(dim_feedforward, d_model) 59 | 
self.ffn_dropout2 = nn.Dropout(dropout) 60 | self.norm2 = nn.LayerNorm(d_model) 61 | 62 | def with_embed(self, tensor, pos): 63 | return tensor if pos is None else tensor + pos 64 | 65 | def forward(self, src): 66 | # self attention 67 | q = k = v = self.norm1(src) 68 | src1 = self.self_attn(q, k, v)[0] 69 | src2 = src + self.attn_dropout(src1) 70 | 71 | # FFN 72 | src3 = self.norm2(src2) 73 | src4 = self.linear2(self.ffn_dropout1(self.activation(self.linear1(src3)))) 74 | src5 = src2 + self.ffn_dropout2(src4) 75 | 76 | return src5 77 | 78 | 79 | def _get_clones(module, N): 80 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 81 | 82 | 83 | def _get_activation_fn(activation): 84 | """Return an activation function given a string""" 85 | if activation == "relu": 86 | return F.relu 87 | if activation == "gelu": 88 | return F.gelu 89 | if activation == "glu": 90 | return F.glu 91 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 92 | 93 | 94 | def build_transformer(args): 95 | return transformer(**args) 96 | -------------------------------------------------------------------------------- /model/hyper/fourier_bessel.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The code is directly translated from the matlab code 3 | https://github.com/xycheng/DCFNet/blob/master/calculate_FB_bases.m 4 | ''' 5 | import os 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from scipy import special 11 | 12 | path_to_bessel = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'bessel.npy') 13 | 14 | 15 | def bases_list(ks, num_bases): 16 | len_list = ks // 2 17 | b_list = [] 18 | for i in range(len_list): 19 | kernel_size = (i + 1) * 2 + 1 20 | normed_bases, _, _ = calculate_FB_bases(i + 1) 21 | normed_bases = normed_bases.transpose().reshape(-1, kernel_size, kernel_size).astype(np.float32)[:num_bases, 22 | ...] 23 | 24 | pad = len_list - (i + 1) 25 | bases = torch.Tensor(normed_bases) 26 | bases = F.pad(bases, (pad, pad, pad, pad, 0, 0)).view(num_bases, ks * ks) 27 | b_list.append(bases) 28 | return torch.cat(b_list, 0) 29 | 30 | 31 | def cart2pol(x, y): 32 | rho = np.sqrt(x ** 2 + y ** 2) 33 | phi = np.arctan2(y, x) 34 | return (phi, rho) 35 | 36 | 37 | def calculate_FB_bases(L1): 38 | maxK = (2 * L1 + 1) ** 2 - 1 39 | 40 | L = L1 + 1 41 | R = L1 + 0.5 42 | 43 | truncate_freq_factor = 1.5 44 | 45 | if L1 < 2: 46 | truncate_freq_factor = 2 47 | 48 | xx, yy = np.meshgrid(range(-L, L + 1), range(-L, L + 1)) 49 | 50 | xx = xx / R 51 | yy = yy / R 52 | 53 | ugrid = np.concatenate([yy.reshape(-1, 1), xx.reshape(-1, 1)], 1) 54 | tgrid, rgrid = cart2pol(ugrid[:, 0], ugrid[:, 1]) 55 | 56 | num_grid_points = ugrid.shape[0] 57 | 58 | kmax = 15 59 | 60 | bessel = np.load(path_to_bessel) 61 | 62 | B = bessel[(bessel[:, 0] <= kmax) & (bessel[:, 3] <= np.pi * R * truncate_freq_factor)] 63 | 64 | idxB = np.argsort(B[:, 2]) 65 | 66 | mu_ns = B[idxB, 2] ** 2 67 | 68 | ang_freqs = B[idxB, 0] 69 | rad_freqs = B[idxB, 1] 70 | R_ns = B[idxB, 2] 71 | 72 | num_kq_all = len(ang_freqs) 73 | max_ang_freqs = max(ang_freqs) 74 | 75 | Phi_ns = np.zeros((num_grid_points, num_kq_all), np.float32) 76 | 77 | Psi = [] 78 | kq_Psi = [] 79 | num_bases = 0 80 | 81 | for i in range(B.shape[0]): 82 | ki = ang_freqs[i] 83 | qi = rad_freqs[i] 84 | rkqi = R_ns[i] 85 | 86 | r0grid = rgrid * R_ns[i] 87 | 88 | F = special.jv(ki, r0grid) 89 | 90 | Phi = 1. 
/ np.abs(special.jv(ki + 1, R_ns[i])) * F 91 | 92 | Phi[rgrid >= 1] = 0 93 | 94 | Phi_ns[:, i] = Phi 95 | 96 | if ki == 0: 97 | Psi.append(Phi) 98 | kq_Psi.append([ki, qi, rkqi]) 99 | num_bases = num_bases + 1 100 | 101 | else: 102 | Psi.append(Phi * np.cos(ki * tgrid) * np.sqrt(2)) 103 | Psi.append(Phi * np.sin(ki * tgrid) * np.sqrt(2)) 104 | kq_Psi.append([ki, qi, rkqi]) 105 | kq_Psi.append([ki, qi, rkqi]) 106 | num_bases = num_bases + 2 107 | 108 | Psi = np.array(Psi) 109 | kq_Psi = np.array(kq_Psi) 110 | 111 | num_bases = Psi.shape[1] 112 | 113 | if num_bases > maxK: 114 | Psi = Psi[:maxK] 115 | kq_Psi = kq_Psi[:maxK] 116 | num_bases = Psi.shape[0] 117 | p = Psi.reshape(num_bases, 2 * L + 1, 2 * L + 1).transpose(1, 2, 0) 118 | psi = p[1:-1, 1:-1, :] 119 | # print(psi.shape) 120 | psi = psi.reshape((2 * L1 + 1) ** 2, num_bases) 121 | 122 | c = np.sqrt(np.sum(psi ** 2, 0).mean()) 123 | 124 | psi = psi / c 125 | 126 | return psi, c, kq_Psi 127 | -------------------------------------------------------------------------------- /scripts/select_best_checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | from tensorboard.backend.event_processing import event_accumulator 5 | import matplotlib.pyplot as plt 6 | import glob 7 | import time 8 | 9 | size_guidance = { 10 | event_accumulator.COMPRESSED_HISTOGRAMS: 0, # Skip compressed histograms 11 | event_accumulator.IMAGES: 0, # Skip image data 12 | event_accumulator.AUDIO: 0, # Skip audio data 13 | event_accumulator.SCALARS: 10000, # Keep only scalar data 14 | event_accumulator.HISTOGRAMS: 0, # Skip regular histograms 15 | event_accumulator.TENSORS: 0, # Skip tensor data 16 | event_accumulator.GRAPH: 0, # Skip graph data 17 | event_accumulator.META_GRAPH: 0 # Skip meta graph data 18 | } 19 | 20 | def process(experiment_name, epochs_per_val): 21 | log_list = sorted(glob.glob(f"tensorboard_logs/{experiment_name}/events.out.tfevents*")) 22 | 23 | tags_to_track = ["val/perceptual_loss/evaid", "val/perceptual_loss/hqf", "val/perceptual_loss/ijrr", "val/perceptual_loss/mvsec"] 24 | 25 | all_events = {tag: [] for tag in tags_to_track} 26 | 27 | for in_pth in log_list[:]: 28 | try: 29 | ea = event_accumulator.EventAccumulator(in_pth, size_guidance=size_guidance) 30 | ea.Reload() 31 | for tag in tags_to_track: 32 | events = ea.scalars.Items(tag) 33 | all_events[tag].extend(events) 34 | except: 35 | pass 36 | 37 | avg_metric_per_epoch = {} 38 | 39 | for tag in tags_to_track: 40 | # Each metric has a different number of steps due to the mixed dataset 41 | steps = [event.step for event in all_events[tag]] 42 | print(steps) 43 | steps = np.array(steps) 44 | epochs = np.zeros_like(steps) 45 | ep = 0 46 | for i in range(0, len(steps)): 47 | # if steps[i] - steps[i-1] > 500: 48 | # ep += 1 49 | # epochs[i] = ep 50 | epochs[i] = steps[i] // (381*epochs_per_val) 51 | max_ep = max(epochs) + 1 52 | metric = np.array([event.value for event in all_events[tag]]) 53 | avg_metric_per_epoch[tag] = np.zeros((max_ep)) 54 | 55 | for i in range(max_ep): 56 | sub_metrics = metric[epochs == i] 57 | if len(sub_metrics) > 0: 58 | avg_metric_per_epoch[tag][i] = np.mean(sub_metrics) 59 | else: 60 | avg_metric_per_epoch[tag][i] = 1e6 61 | 62 | num_epochs = len(avg_metric_per_epoch[tags_to_track[0]]) 63 | for tag in tags_to_track: 64 | assert len(avg_metric_per_epoch[tag]) == num_epochs, f"Length mismatch for {tag}: {len(avg_metric_per_epoch[tag])} vs {num_epochs}" 65 | 66 | loss_output_file = os.path.join("tensorboard_logs", experiment_name, "calc_val_loss_per_checkpoint.txt") 67 |
all_total_loss = [] 68 | with open(loss_output_file, "w") as f: 69 | for i in range(num_epochs): 70 | # IMPORTANT: Only use the two 71 | total_loss = avg_metric_per_epoch["val/perceptual_loss/evaid"][i] + avg_metric_per_epoch["val/perceptual_loss/hqf"][i] 72 | #total_loss = avg_metric_per_epoch["val/perceptual_loss/hqf"][i] + avg_metric_per_epoch["val/perceptual_loss/ijrr"][i] + avg_metric_per_epoch["val/perceptual_loss/mvsec"][i] 73 | all_total_loss.append(total_loss) 74 | f.write(f"{total_loss:.03f}\n") 75 | 76 | best_idx = np.argmin(all_total_loss) 77 | print("best_idx: ", best_idx) 78 | all_checkpoints = sorted(glob.glob(f"checkpoints/{experiment_name}/*.pth")) 79 | print("Selected checkpoint: ", all_checkpoints[best_idx]) 80 | 81 | if __name__ == "__main__": 82 | # Take the experiment name as argument 83 | if len(sys.argv) > 2: 84 | epochs_per_val = int(sys.argv[2]) 85 | else: 86 | epochs_per_val = 1 87 | 88 | if len(sys.argv) > 1: 89 | experiment_name = sys.argv[1] 90 | start = time.time() 91 | process(experiment_name, epochs_per_val) 92 | end = time.time() 93 | print("Used seconds: ", end-start) 94 | else: 95 | print("Please provide the experiment name as argument.") -------------------------------------------------------------------------------- /model/eitr/transformer_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import copy 5 | 6 | 7 | class transformer_decoder(nn.Module): 8 | def __init__(self, d_model=256, nhead=8, num_decoder_layers=6, dim_feedforward=2048, activation='relu', dropout=0.1): 9 | super().__init__() 10 | self.d_model = d_model 11 | self.nhead = nhead 12 | 13 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 14 | dropout, activation) 15 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers) 16 | 17 | self._reset_parameters() 18 | 19 | def _reset_parameters(self): 20 | for p in self.parameters(): 21 | if p.dim() > 1: 22 | nn.init.xavier_uniform_(p) 23 | 24 | def forward(self, tgt, memory): 25 | output = self.decoder(tgt, memory) 26 | 27 | return output 28 | 29 | 30 | class TransformerDecoder(nn.Module): 31 | def __init__(self, encoder_layer, num_layers): 32 | super().__init__() 33 | self.layers = _get_clones(encoder_layer, num_layers) 34 | 35 | def forward(self, tgt, memory): 36 | output = tgt 37 | 38 | for layer in self.layers: 39 | output = layer(output, memory) 40 | 41 | return output 42 | 43 | 44 | class TransformerDecoderLayer(nn.Module): 45 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 46 | activation="relu"): 47 | super().__init__() 48 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 49 | self.sattn_dropout = nn.Dropout(dropout) 50 | self.norm1 = nn.LayerNorm(d_model) 51 | self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 52 | self.cattn_dropout = nn.Dropout(dropout) 53 | self.norm21 = nn.LayerNorm(d_model) 54 | self.norm22 = nn.LayerNorm(d_model) 55 | self.linear1 = nn.Linear(d_model, dim_feedforward) 56 | self.activation = _get_activation_fn(activation) 57 | self.ffn_dropout1 = nn.Dropout(dropout) 58 | self.linear2 = nn.Linear(dim_feedforward, d_model) 59 | self.ffn_dropout2 = nn.Dropout(dropout) 60 | self.norm3 = nn.LayerNorm(d_model) 61 | 62 | def with_embed(self, tensor, pos): 63 | return tensor if pos is None else tensor + pos 64 | 65 | def forward(self, tgt, memory): 66 | # self attention 67 | q = k = v = 
self.norm1(tgt) 68 | tgt1 = self.self_attn(q, k, v)[0] 69 | tgt2 = tgt + self.sattn_dropout(tgt1) 70 | 71 | # cross attention 72 | q = self.norm21(tgt2) 73 | k = v = self.norm22(memory) 74 | tgt3 = self.cross_attn(q, k, v)[0] 75 | tgt4 = tgt2 + self.cattn_dropout(tgt3) 76 | 77 | # FFN 78 | tgt5 = self.norm3(tgt4) 79 | tgt6 = self.linear2(self.ffn_dropout1(self.activation(self.linear1(tgt5)))) 80 | tgt7 = tgt4 + self.ffn_dropout2(tgt6) 81 | 82 | return tgt7 83 | 84 | 85 | def _get_clones(module, N): 86 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 87 | 88 | 89 | def _get_activation_fn(activation): 90 | """Return an activation function given a string""" 91 | if activation == "relu": 92 | return F.relu 93 | if activation == "gelu": 94 | return F.gelu 95 | if activation == "glu": 96 | return F.glu 97 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 98 | 99 | 100 | def build_transformer(args): 101 | return transformer(**args) 102 | -------------------------------------------------------------------------------- /config/train_ablation_e2vid_esim.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: ablation_e2vid_esim 2 | check_val_every_n_epoch: 6 3 | test_output_dir: results/ablation_e2vid_esim 4 | use_compile: false 5 | 6 | module: 7 | loss: 8 | lpips_weight: 1.0 9 | lpips_type: alex 10 | l2_weight: 0 11 | l1_weight: 0 12 | ssim_weight: 0 13 | temporal_consistency_weight: 1.0 14 | optical_flow_source: gt 15 | 16 | normalize_voxels: false 17 | model: 18 | target: model.model.E2VIDRecurrent 19 | params: 20 | unet_kwargs: 21 | num_bins: 5 22 | skip_type: sum 23 | recurrent_block_type: convlstm 24 | num_encoders: 3 25 | base_num_channels: 32 26 | num_residual_blocks: 2 27 | use_upsample_conv: true 28 | final_activation: "" 29 | norm: none 30 | 31 | train_stages: 32 | - stage_name: stage1 33 | max_epochs: 500 34 | 35 | optimizer: 36 | target: torch.optim.Adam 37 | params: 38 | lr: 0.0001 39 | weight_decay: 0 40 | amsgrad: true 41 | 42 | lr_scheduler: 43 | target: torch.optim.lr_scheduler.StepLR 44 | params: 45 | step_size: 50 46 | gamma: 1.0 # This is actually constant LR 47 | 48 | dataset: 49 | train_batch_size: 12 50 | num_workers: 9 51 | val_batch_size: 1 # Because test data has different sizes 52 | persistent_workers: true 53 | pin_memory: true 54 | 55 | train: 56 | - data_file: config/esim_h5.txt 57 | class_name: data.esim_dataset.ESIMH5Dataset 58 | data_source_name: esim 59 | sequence_length: 40 60 | proba_pause_when_running: 0.05 61 | proba_pause_when_paused: 0.9 62 | noise_std: 0.1 63 | noise_fraction: 1.0 64 | hot_pixel_std: 0.1 65 | max_hot_pixel_fraction: 0.001 66 | random_crop_size: 128 67 | random_flip: true 68 | 69 | val: 70 | - data_file: config/evaid_test.txt 71 | class_name: data.testh5.TestH5Dataset 72 | dataset_name: evaid 73 | num_bins: 5 74 | sequence_length: 80 75 | interpolate_bins: true 76 | max_samples: 1 # Limit val time, 720p runs really slow 77 | image_range: 1 78 | - data_file: config/ijrr_test.txt 79 | class_name: data.testh5.TestH5Dataset 80 | dataset_name: ijrr 81 | num_bins: 5 82 | sequence_length: 80 83 | interpolate_bins: true 84 | image_range: 1 85 | - data_file: config/hqf_test.txt 86 | class_name: data.testh5.TestH5Dataset 87 | dataset_name: hqf 88 | num_bins: 5 89 | sequence_length: 80 90 | interpolate_bins: true 91 | image_range: 1 92 | - data_file: config/mvsec_test.txt 93 | class_name: data.testh5.TestH5Dataset 94 | dataset_name: mvsec 95 | num_bins: 5 96 | 
sequence_length: 80 97 | interpolate_bins: true 98 | image_range: 1 99 | 100 | test_stage: 101 | need_do_gamma: False 102 | test_batch_size: 1 103 | test_num_workers: 16 104 | test: 105 | - data_file: config/evaid_test.txt 106 | class_name: data.testh5.TestH5Dataset 107 | dataset_name: evaid 108 | num_bins: 5 109 | sequence_length: 80 110 | interpolate_bins: true 111 | - data_file: config/ijrr_test.txt 112 | class_name: data.testh5.TestH5Dataset 113 | dataset_name: ijrr 114 | num_bins: 5 115 | sequence_length: 80 116 | interpolate_bins: true 117 | - data_file: config/hqf_test.txt 118 | class_name: data.testh5.TestH5Dataset 119 | dataset_name: hqf 120 | num_bins: 5 121 | sequence_length: 80 122 | interpolate_bins: true 123 | - data_file: config/mvsec_test.txt 124 | class_name: data.testh5.TestH5Dataset 125 | dataset_name: mvsec 126 | num_bins: 5 127 | sequence_length: 80 128 | interpolate_bins: true -------------------------------------------------------------------------------- /scripts/save_gt_images.py: -------------------------------------------------------------------------------- 1 | # This tool script saves the ground truth images & event visualizations of the test datasets for aligned comparison (such as for making videos with scripts/make_ref_videos.py). 2 | 3 | import os 4 | import sys 5 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 6 | import yaml 7 | import torch 8 | import tqdm 9 | import cv2 10 | import numpy as np 11 | from collections import defaultdict 12 | from utils.data import data_sources 13 | 14 | from data.data_interface import make_concat_multi_dataset 15 | from torch.utils.data import DataLoader 16 | 17 | def map_color(val, clip=10): 18 | BLUE = np.expand_dims(np.expand_dims(np.array([255, 0, 0]), 0), 0) 19 | RED = np.expand_dims(np.expand_dims(np.array([0, 0, 255]), 0), 0) 20 | WHITE = np.expand_dims(np.expand_dims(np.array([255, 255, 255]), 0), 0) 21 | val = np.clip(val, -clip, clip) 22 | val = np.expand_dims(val, -1) 23 | red_side = (1 - val / clip) * WHITE + (val / clip) * RED 24 | blue_side = (1 + val / clip) * WHITE + (-val / clip) * BLUE 25 | return np.where(val > 0, red_side, blue_side).astype(np.uint8) 26 | 27 | 28 | def create_test_dataloader(stage_cfg): 29 | dataset = make_concat_multi_dataset(stage_cfg["test"]) 30 | dataloader = DataLoader(dataset, 31 | batch_size=1, 32 | num_workers=stage_cfg["test_num_workers"], 33 | shuffle=False) 34 | return dataloader 35 | 36 | def save_gt(dataloader, output_dir): 37 | 38 | previous_test_sequence = None 39 | 40 | with torch.no_grad(): 41 | 42 | for batch_idx, batch in enumerate(tqdm.tqdm(dataloader)): 43 | sequence_name = batch["sequence_name"][0][0] 44 | 45 | if previous_test_sequence is None or previous_test_sequence != sequence_name: 46 | output_img_idx = 0 47 | if output_dir is not None: 48 | data_source_idx = batch["data_source_idx"][0] 49 | data_source = data_sources[data_source_idx].upper() 50 | seq_output_dir = os.path.join(output_dir, data_source, sequence_name) 51 | #print("seq_output_dir:", seq_output_dir) 52 | os.makedirs(seq_output_dir, exist_ok=True) 53 | 54 | pred = batch["frame"] 55 | 56 | if output_dir is not None: 57 | one, T, one, H, W = pred.shape 58 | for t in range(T): 59 | img = pred[0, t, 0].cpu().numpy() 60 | img = np.clip(img, 0, 255).astype(np.uint8) 61 | cv2.imwrite(os.path.join(seq_output_dir, f"{output_img_idx:06d}.png"), img) 62 | output_img_idx += 1 63 | 64 | previous_test_sequence = sequence_name 65 | 66 | def save_evs(dataloader, output_dir): 67 | 68 | 
previous_test_sequence = None 69 | 70 | with torch.no_grad(): 71 | 72 | for batch_idx, batch in enumerate(tqdm.tqdm(dataloader)): 73 | sequence_name = batch["sequence_name"][0][0] 74 | 75 | if previous_test_sequence is None or previous_test_sequence != sequence_name: 76 | output_img_idx = 0 77 | if output_dir is not None: 78 | data_source_idx = batch["data_source_idx"][0] 79 | data_source = data_sources[data_source_idx].upper() 80 | seq_output_dir = os.path.join(output_dir, data_source, sequence_name) 81 | #print("seq_output_dir:", seq_output_dir) 82 | os.makedirs(seq_output_dir, exist_ok=True) 83 | 84 | voxel = batch["events"] 85 | one, T, B, H, W = voxel.shape 86 | for t in range(T): 87 | vis = map_color(voxel[0, t, :, :, :].sum(axis=0).cpu().numpy(), clip=5) 88 | 89 | cv2.imwrite(os.path.join(seq_output_dir, f"{output_img_idx:06d}.png"), vis) 90 | output_img_idx += 1 91 | 92 | previous_test_sequence = sequence_name 93 | 94 | 95 | def main(output_dir, output_evs_dir): 96 | # Add two arguments. 97 | # Argument 1: config_path 98 | # Argument 2 (optional): test_all_pths (default=False) 99 | if len(sys.argv) > 1: 100 | config_path = sys.argv[1] 101 | else: 102 | config_path = "configs/template.yaml" 103 | 104 | with open(config_path) as f: 105 | config = yaml.load(f, Loader=yaml.Loader) 106 | 107 | test_dataloader = create_test_dataloader(config["test_stage"]) 108 | save_evs(test_dataloader, output_evs_dir) 109 | save_gt(test_dataloader, output_dir) 110 | 111 | if __name__ == "__main__": 112 | main("results/gt_images", "results/gt_evs") -------------------------------------------------------------------------------- /utils/color_utils.py: -------------------------------------------------------------------------------- 1 | from .timers import Timer, CudaTimer 2 | import numpy as np 3 | import cv2 4 | 5 | 6 | def shift_image(X, dx, dy): 7 | X = np.roll(X, dy, axis=0) 8 | X = np.roll(X, dx, axis=1) 9 | if dy > 0: 10 | X[:dy, :] = np.expand_dims(X[dy, :], axis=0) 11 | elif dy < 0: 12 | X[dy:, :] = np.expand_dims(X[dy, :], axis=0) 13 | if dx > 0: 14 | X[:, :dx] = np.expand_dims(X[:, dx], axis=1) 15 | elif dx < 0: 16 | X[:, dx:] = np.expand_dims(X[:, dx], axis=1) 17 | return X 18 | 19 | 20 | def upsample_color_image(grayscale_highres, color_lowres_bgr, colorspace='LAB'): 21 | """ 22 | Generate a high res color image from a high res grayscale image, and a low res color image, 23 | using the trick described in: 24 | http://www.planetary.org/blogs/emily-lakdawalla/2013/04231204-image-processing-colorizing-images.html 25 | """ 26 | assert(len(grayscale_highres.shape) == 2) 27 | assert(len(color_lowres_bgr.shape) == 3 and color_lowres_bgr.shape[2] == 3) 28 | 29 | if colorspace == 'LAB': 30 | # convert color image to LAB space 31 | lab = cv2.cvtColor(src=color_lowres_bgr, code=cv2.COLOR_BGR2LAB) 32 | # replace lightness channel with the highres image 33 | lab[:, :, 0] = grayscale_highres 34 | # convert back to BGR 35 | color_highres_bgr = cv2.cvtColor(src=lab, code=cv2.COLOR_LAB2BGR) 36 | elif colorspace == 'HSV': 37 | # convert color image to HSV space 38 | hsv = cv2.cvtColor(src=color_lowres_bgr, code=cv2.COLOR_BGR2HSV) 39 | # replace value channel with the highres image 40 | hsv[:, :, 2] = grayscale_highres 41 | # convert back to BGR 42 | color_highres_bgr = cv2.cvtColor(src=hsv, code=cv2.COLOR_HSV2BGR) 43 | elif colorspace == 'HLS': 44 | # convert color image to HLS space 45 | hls = cv2.cvtColor(src=color_lowres_bgr, code=cv2.COLOR_BGR2HLS) 46 | # replace lightness channel with the highres image 47 
| hls[:, :, 1] = grayscale_highres 48 | # convert back to BGR 49 | color_highres_bgr = cv2.cvtColor(src=hls, code=cv2.COLOR_HLS2BGR) 50 | 51 | return color_highres_bgr 52 | 53 | 54 | def merge_channels_into_color_image(channels): 55 | """ 56 | Combine a full resolution grayscale reconstruction and four color channels at half resolution 57 | into a color image at full resolution. 58 | 59 | :param channels: dictionary containing the four color reconstructions (at quarter resolution), 60 | and the full resolution grayscale reconstruction. 61 | :return a color image at full resolution 62 | """ 63 | 64 | with Timer('Merge color channels'): 65 | 66 | assert('R' in channels) 67 | assert('G' in channels) 68 | assert('W' in channels) 69 | assert('B' in channels) 70 | assert('grayscale' in channels) 71 | 72 | # upsample each channel independently 73 | for channel in ['R', 'G', 'W', 'B']: 74 | channels[channel] = cv2.resize(channels[channel], dsize=None, fx=2, fy=2, interpolation=cv2.INTER_LINEAR) 75 | 76 | # Shift the channels so that they all have the same origin 77 | channels['B'] = shift_image(channels['B'], dx=1, dy=1) 78 | channels['G'] = shift_image(channels['G'], dx=1, dy=0) 79 | channels['W'] = shift_image(channels['W'], dx=0, dy=1) 80 | 81 | # reconstruct the color image at half the resolution using the reconstructed channels RGBW 82 | reconstruction_bgr = np.dstack([channels['B'], 83 | cv2.addWeighted(src1=channels['G'], alpha=0.5, 84 | src2=channels['W'], beta=0.5, 85 | gamma=0.0, dtype=cv2.CV_8U), 86 | channels['R']]) 87 | 88 | reconstruction_grayscale = channels['grayscale'] 89 | 90 | # combine the full res grayscale resolution with the low res to get a full res color image 91 | upsampled_img = upsample_color_image(reconstruction_grayscale, reconstruction_bgr) 92 | return upsampled_img 93 | -------------------------------------------------------------------------------- /config/train_v2v_etnet_10k.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: v2v_etnet_10k 2 | check_val_every_n_epoch: 1 3 | test_output_dir: results/v2v_etnet_10k 4 | use_compile: false 5 | 6 | module: 7 | loss: 8 | lpips_weight: 1.0 9 | lpips_type: vgg 10 | l2_weight: 0 11 | l1_weight: 1.0 12 | ssim_weight: 0 13 | temporal_consistency_weight: 1.0 14 | optical_flow_source: raft_small 15 | temporal_consistency_L0: 20 16 | 17 | normalize_voxels: false 18 | model: 19 | target: model.eitr.eitr.EITR 20 | params: 21 | eitr_kwargs: 22 | num_bins: 5 23 | norm: none 24 | 25 | train_stages: 26 | - stage_name: stage1 27 | max_epochs: 100 28 | 29 | optimizer: 30 | target: torch.optim.AdamW 31 | params: 32 | lr: 0.0002 33 | weight_decay: 0.01 34 | amsgrad: true 35 | 36 | lr_scheduler: 37 | target: torch.optim.lr_scheduler.ExponentialLR 38 | params: 39 | gamma: 0.94 # 0.99**700 ~ 1e-3 0.94**110 40 | 41 | dataset: 42 | train_batch_size: 6 43 | num_workers: 9 44 | val_batch_size: 1 # Because test data has different sizes 45 | persistent_workers: true 46 | pin_memory: true 47 | 48 | train: 49 | - data_file: config/webvid_root.txt 50 | class_name: data.v2v_datasets.WebvidDatasetV2 51 | video_list_file: config/webvid10000_filtered.txt # Inconsistent with other models due to training accident 52 | data_source_name: webvid 53 | video_reader: opencv 54 | sequence_length: 40 55 | pause_granularity: 5 56 | proba_pause_when_running: 0.0102 57 | proba_pause_when_paused: 0.9791 58 | crop_size: 128 59 | random_flip: true 60 | num_bins: 5 61 | min_resize_scale: 1 62 | max_resize_scale: 1 
63 | frames_per_bin: 1 64 | threshold_range: [0.05, 2] 65 | max_thres_pos_neg_gap: 1.5 66 | base_noise_std_range: [0, 0.1] 67 | hot_pixel_std_range: [0, 10] 68 | max_samples_per_shot: 10 69 | 70 | val: 71 | - data_file: config/evaid_test.txt 72 | class_name: data.testh5.TestH5Dataset 73 | dataset_name: evaid 74 | num_bins: 5 75 | sequence_length: 80 76 | interpolate_bins: false 77 | max_samples: 1 # Limit val time, 720p runs really slow 78 | image_range: 1 79 | - data_file: config/ijrr_test.txt 80 | class_name: data.testh5.TestH5Dataset 81 | dataset_name: ijrr 82 | num_bins: 5 83 | sequence_length: 80 84 | interpolate_bins: false 85 | image_range: 1 86 | - data_file: config/hqf_test.txt 87 | class_name: data.testh5.TestH5Dataset 88 | dataset_name: hqf 89 | num_bins: 5 90 | sequence_length: 80 91 | interpolate_bins: false 92 | image_range: 1 93 | - data_file: config/mvsec_test.txt 94 | class_name: data.testh5.TestH5Dataset 95 | dataset_name: mvsec 96 | num_bins: 5 97 | sequence_length: 80 98 | interpolate_bins: false 99 | image_range: 1 100 | 101 | test_stage: 102 | test_batch_size: 1 103 | test_num_workers: 4 104 | test: 105 | # - data_file: config/evaid_test.txt 106 | # class_name: data.testh5.TestH5Dataset 107 | # dataset_name: evaid 108 | # num_bins: 5 109 | # sequence_length: 80 # Decrease test sequence length to save GPU memory, it won't have any behaviour impact 110 | # interpolate_bins: false 111 | - data_file: config/ijrr_test.txt 112 | class_name: data.testh5.TestH5Dataset 113 | dataset_name: ijrr 114 | num_bins: 5 115 | sequence_length: 80 116 | interpolate_bins: false 117 | - data_file: config/hqf_test.txt 118 | class_name: data.testh5.TestH5Dataset 119 | dataset_name: hqf 120 | num_bins: 5 121 | sequence_length: 80 122 | interpolate_bins: false 123 | - data_file: config/mvsec_test.txt 124 | class_name: data.testh5.TestH5Dataset 125 | dataset_name: mvsec 126 | num_bins: 5 127 | sequence_length: 80 128 | interpolate_bins: false -------------------------------------------------------------------------------- /config/train_v2v_evflow_10k.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: v2v_evflow_10k 2 | check_val_every_n_epoch: 1 3 | test_output_dir: results/v2v_evflow_10k 4 | save_npy: false 5 | save_png: true 6 | use_compile: false 7 | task: flow 8 | 9 | module: 10 | loss: 11 | l1_weight: 1.0 12 | optical_flow_source: raft_large 13 | raft_num_flow_updates: 12 14 | 15 | normalize_voxels: false 16 | model: 17 | target: model.model.EVFlowNet 18 | params: 19 | unet_kwargs: 20 | num_bins: 5 21 | base_num_channels: 32 22 | num_encoders: 4 23 | num_residual_blocks: 2 24 | num_output_channels: 2 25 | skip_type: concat 26 | norm: null 27 | use_upsample_conv: true 28 | kernel_size: 3 29 | channel_multiplier: 2 30 | 31 | train_stages: 32 | - stage_name: stage1 33 | max_epochs: 50 34 | 35 | optimizer: 36 | target: torch.optim.Adam 37 | params: 38 | lr: 0.0001 39 | weight_decay: 0 40 | amsgrad: true 41 | 42 | lr_scheduler: 43 | target: torch.optim.lr_scheduler.StepLR 44 | params: 45 | step_size: 50 46 | gamma: 1.0 # This is actually constant LR 47 | 48 | dataset: 49 | train_batch_size: 10 50 | num_workers: 10 51 | val_batch_size: 1 # Because test data has different sizes 52 | persistent_workers: true 53 | pin_memory: true 54 | 55 | train: 56 | - data_file: config/webvid_root.txt 57 | class_name: data.v2v_datasets.WebvidDatasetV2 58 | video_list_file: config/webvid10000_full.txt 59 | data_source_name: webvid 60 | video_reader: opencv 
61 | sequence_length: 40 62 | pause_granularity: 5 63 | proba_pause_when_running: 0.0102 64 | proba_pause_when_paused: 0.9791 65 | crop_size: 128 66 | random_flip: true 67 | num_bins: 5 68 | min_resize_scale: 1 69 | max_resize_scale: 1 70 | frames_per_bin: 1 71 | threshold_range: [0.05, 2] 72 | max_thres_pos_neg_gap: 1.5 73 | base_noise_std_range: [0, 0.1] 74 | hot_pixel_std_range: [0, 10] 75 | max_samples_per_shot: 10 76 | output_additional_frame: true 77 | 78 | val: 79 | - data_file: config/hqf_test.txt 80 | class_name: data.testh5.TestH5Dataset 81 | dataset_name: hqf 82 | num_bins: 5 83 | sequence_length: 80 84 | interpolate_bins: false 85 | output_additional_frame: true 86 | image_range: 1 87 | max_samples: 1 88 | - data_file: config/mvsec_test_flow.txt 89 | class_name: data.testh5.TestH5FlowDataset 90 | dataset_name: mvsec 91 | num_bins: 5 92 | sequence_length: 80 93 | interpolate_bins: false 94 | output_additional_frame: true 95 | image_range: 1 96 | 97 | test_stage: 98 | test_batch_size: 1 99 | test_num_workers: 16 100 | test: 101 | # - data_file: config/evaid_test.txt 102 | # class_name: data.testh5.TestH5Dataset 103 | # dataset_name: evaid 104 | # num_bins: 5 105 | # sequence_length: 10 106 | # interpolate_bins: false 107 | # output_additional_frame: true 108 | # image_range: 1 109 | # max_samples: 4 110 | - data_file: config/ijrr_test.txt 111 | class_name: data.testh5.TestH5Dataset 112 | dataset_name: ijrr 113 | num_bins: 5 114 | sequence_length: 80 115 | interpolate_bins: false 116 | output_additional_frame: true 117 | image_range: 1 118 | max_samples: 2 119 | - data_file: config/hqf_test.txt 120 | class_name: data.testh5.TestH5Dataset 121 | dataset_name: hqf 122 | num_bins: 5 123 | sequence_length: 80 124 | interpolate_bins: false 125 | output_additional_frame: true 126 | image_range: 1 127 | max_samples: 2 128 | - data_file: config/mvsec_test_flow.txt 129 | class_name: data.testh5.TestH5FlowDataset 130 | dataset_name: mvsec 131 | num_bins: 5 132 | sequence_length: 80 133 | interpolate_bins: false 134 | output_additional_frame: true 135 | image_range: 1 -------------------------------------------------------------------------------- /config/train_v2v_eraft_10k.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: v2v_eraft_10k 2 | check_val_every_n_epoch: 1 3 | test_output_dir: results/v2v_eraft_10k 4 | save_npy: false 5 | save_png: true 6 | use_compile: false 7 | task: flow 8 | 9 | module: 10 | forward_type: eraft 11 | 12 | loss: 13 | l1_weight: 1.0 14 | optical_flow_source: raft_large 15 | raft_num_flow_updates: 12 16 | 17 | normalize_voxels: false 18 | model: 19 | target: model.eraft.eraft.ERAFT 20 | params: 21 | config: 22 | subtype: warm_start 23 | n_first_channels: 5 24 | 25 | train_stages: 26 | - stage_name: stage1 27 | max_epochs: 50 28 | 29 | optimizer: 30 | target: torch.optim.Adam 31 | params: 32 | lr: 0.0001 33 | weight_decay: 0 34 | amsgrad: true 35 | 36 | lr_scheduler: 37 | target: torch.optim.lr_scheduler.StepLR 38 | params: 39 | step_size: 50 40 | gamma: 1.0 # This is actually constant LR 41 | 42 | dataset: 43 | train_batch_size: 10 44 | num_workers: 10 45 | val_batch_size: 1 # Because test data has different sizes 46 | persistent_workers: true 47 | pin_memory: true 48 | 49 | train: 50 | - data_file: config/webvid_root.txt 51 | class_name: data.v2v_datasets.WebvidDatasetV2 52 | video_list_file: config/webvid10000_unfiltered.txt 53 | data_source_name: webvid 54 | video_reader: opencv 55 | sequence_length: 40 
56 | pause_granularity: 5 57 | proba_pause_when_running: 0.0102 58 | proba_pause_when_paused: 0.9791 59 | crop_size: 128 60 | random_flip: true 61 | num_bins: 5 62 | min_resize_scale: 1 63 | max_resize_scale: 1 64 | frames_per_bin: 1 65 | threshold_range: [0.05, 2] 66 | max_thres_pos_neg_gap: 1.5 67 | base_noise_std_range: [0, 0.1] 68 | hot_pixel_std_range: [0, 10] 69 | max_samples_per_shot: 10 70 | output_additional_frame: true 71 | output_additional_evs: true 72 | 73 | val: 74 | - data_file: config/hqf_test.txt 75 | class_name: data.testh5.TestH5Dataset 76 | dataset_name: hqf 77 | num_bins: 5 78 | sequence_length: 80 79 | interpolate_bins: false 80 | output_additional_frame: true 81 | output_additional_evs: true 82 | image_range: 1 83 | max_samples: 1 84 | - data_file: config/mvsec_test_flow.txt 85 | class_name: data.testh5.TestH5FlowDataset 86 | dataset_name: mvsec 87 | num_bins: 5 88 | sequence_length: 80 89 | interpolate_bins: false 90 | output_additional_frame: true 91 | output_additional_evs: true 92 | image_range: 1 93 | 94 | test_stage: 95 | test_batch_size: 1 96 | test_num_workers: 16 97 | test: 98 | # - data_file: config/evaid_test.txt 99 | # class_name: data.testh5.TestH5Dataset 100 | # dataset_name: evaid 101 | # num_bins: 5 102 | # sequence_length: 10 103 | # interpolate_bins: true 104 | # output_additional_frame: true 105 | # output_additional_evs: true 106 | # image_range: 1 107 | # max_samples: 2 108 | - data_file: config/ijrr_test.txt 109 | class_name: data.testh5.TestH5Dataset 110 | dataset_name: ijrr 111 | num_bins: 5 112 | sequence_length: 80 113 | interpolate_bins: true 114 | output_additional_frame: true 115 | output_additional_evs: true 116 | image_range: 1 117 | max_samples: 2 118 | - data_file: config/hqf_test.txt 119 | class_name: data.testh5.TestH5Dataset 120 | dataset_name: hqf 121 | num_bins: 5 122 | sequence_length: 80 123 | interpolate_bins: true 124 | output_additional_frame: true 125 | output_additional_evs: true 126 | image_range: 1 127 | max_samples: 2 128 | - data_file: config/mvsec_test_flow.txt 129 | class_name: data.testh5.TestH5FlowDataset 130 | dataset_name: mvsec 131 | num_bins: 5 132 | sequence_length: 80 133 | interpolate_bins: false 134 | output_additional_frame: true 135 | output_additional_evs: true 136 | image_range: 1 -------------------------------------------------------------------------------- /config/train_v2v_e2vid_10k.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: v2v_e2vid_10k 2 | check_val_every_n_epoch: 1 3 | test_output_dir: results/v2v_e2vid_10k 4 | use_compile: false 5 | 6 | module: 7 | loss: 8 | lpips_weight: 1.0 9 | lpips_type: vgg 10 | l2_weight: 0 11 | l1_weight: 1.0 12 | ssim_weight: 0 13 | temporal_consistency_weight: 1.0 14 | optical_flow_source: raft_small 15 | temporal_consistency_L0: 20 16 | 17 | normalize_voxels: false 18 | model: 19 | target: model.model.E2VIDRecurrent 20 | params: 21 | unet_kwargs: 22 | num_bins: 5 23 | skip_type: sum 24 | recurrent_block_type: convlstm 25 | num_encoders: 3 26 | base_num_channels: 32 27 | num_residual_blocks: 2 28 | use_upsample_conv: true 29 | final_activation: "" 30 | norm: none 31 | 32 | train_stages: 33 | - stage_name: stage1 34 | max_epochs: 80 35 | 36 | optimizer: 37 | target: torch.optim.Adam 38 | params: 39 | lr: 0.0001 40 | weight_decay: 0 41 | amsgrad: true 42 | 43 | lr_scheduler: 44 | target: torch.optim.lr_scheduler.StepLR 45 | params: 46 | step_size: 50 47 | gamma: 1.0 # This is actually constant LR 48 | 49 | 
dataset: 50 | train_batch_size: 12 51 | num_workers: 9 52 | val_batch_size: 1 # Because test data has different sizes 53 | persistent_workers: true 54 | pin_memory: true 55 | 56 | train: 57 | - data_file: config/webvid_root.txt 58 | class_name: data.v2v_datasets.WebvidDatasetV2 59 | video_list_file: config/webvid10000_unfiltered.txt 60 | data_source_name: webvid 61 | video_reader: opencv 62 | sequence_length: 40 63 | pause_granularity: 5 64 | proba_pause_when_running: 0.0102 65 | proba_pause_when_paused: 0.9791 66 | crop_size: 128 67 | random_flip: true 68 | num_bins: 5 69 | min_resize_scale: 1 70 | max_resize_scale: 1 71 | frames_per_bin: 1 72 | threshold_range: [0.05, 2] 73 | max_thres_pos_neg_gap: 1.5 74 | base_noise_std_range: [0, 0.1] 75 | hot_pixel_std_range: [0, 10] 76 | max_samples_per_shot: 10 77 | 78 | val: 79 | - data_file: config/evaid_test.txt 80 | class_name: data.testh5.TestH5Dataset 81 | dataset_name: evaid 82 | num_bins: 5 83 | sequence_length: 80 84 | interpolate_bins: false 85 | max_samples: 1 # Limit val time, 720p runs really slow 86 | image_range: 1 87 | - data_file: config/ijrr_test.txt 88 | class_name: data.testh5.TestH5Dataset 89 | dataset_name: ijrr 90 | num_bins: 5 91 | sequence_length: 80 92 | interpolate_bins: false 93 | image_range: 1 94 | - data_file: config/hqf_test.txt 95 | class_name: data.testh5.TestH5Dataset 96 | dataset_name: hqf 97 | num_bins: 5 98 | sequence_length: 80 99 | interpolate_bins: false 100 | image_range: 1 101 | - data_file: config/mvsec_test.txt 102 | class_name: data.testh5.TestH5Dataset 103 | dataset_name: mvsec 104 | num_bins: 5 105 | sequence_length: 80 106 | interpolate_bins: false 107 | image_range: 1 108 | 109 | test_stage: 110 | test_batch_size: 1 111 | test_num_workers: 4 112 | test: 113 | - data_file: config/evaid_test.txt 114 | class_name: data.testh5.TestH5Dataset 115 | dataset_name: evaid 116 | num_bins: 5 117 | sequence_length: 80 118 | interpolate_bins: false 119 | - data_file: config/ijrr_test.txt 120 | class_name: data.testh5.TestH5Dataset 121 | dataset_name: ijrr 122 | num_bins: 5 123 | sequence_length: 80 124 | interpolate_bins: false 125 | - data_file: config/hqf_test.txt 126 | class_name: data.testh5.TestH5Dataset 127 | dataset_name: hqf 128 | num_bins: 5 129 | sequence_length: 80 130 | interpolate_bins: false 131 | - data_file: config/mvsec_test.txt 132 | class_name: data.testh5.TestH5Dataset 133 | dataset_name: mvsec 134 | num_bins: 5 135 | sequence_length: 80 136 | interpolate_bins: false -------------------------------------------------------------------------------- /config/train_ablation_e2vid_filtered.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: ablation_e2vid_filtered 2 | check_val_every_n_epoch: 1 3 | test_output_dir: results/ablation_e2vid_filtered 4 | use_compile: false 5 | 6 | module: 7 | loss: 8 | lpips_weight: 1.0 9 | lpips_type: vgg 10 | l2_weight: 0 11 | l1_weight: 1.0 12 | ssim_weight: 0 13 | temporal_consistency_weight: 1.0 14 | optical_flow_source: raft_small 15 | temporal_consistency_L0: 20 16 | 17 | normalize_voxels: false 18 | model: 19 | target: model.model.E2VIDRecurrent 20 | params: 21 | unet_kwargs: 22 | num_bins: 5 23 | skip_type: sum 24 | recurrent_block_type: convlstm 25 | num_encoders: 3 26 | base_num_channels: 32 27 | num_residual_blocks: 2 28 | use_upsample_conv: true 29 | final_activation: "" 30 | norm: none 31 | 32 | train_stages: 33 | - stage_name: stage1 34 | max_epochs: 80 35 | 36 | optimizer: 37 | target: 
torch.optim.Adam 38 | params: 39 | lr: 0.0001 40 | weight_decay: 0 41 | amsgrad: true 42 | 43 | lr_scheduler: 44 | target: torch.optim.lr_scheduler.StepLR 45 | params: 46 | step_size: 50 47 | gamma: 1.0 # This is actually constant LR 48 | 49 | dataset: 50 | train_batch_size: 3 51 | num_workers: 9 52 | val_batch_size: 1 # Because test data has different sizes 53 | persistent_workers: true 54 | pin_memory: true 55 | 56 | train: 57 | - data_file: config/webvid_root.txt 58 | class_name: data.v2v_datasets.WebvidDatasetV2 59 | video_list_file: config/webvid100_unfiltered.txt # Use different video list 60 | data_source_name: webvid 61 | video_reader: opencv 62 | sequence_length: 40 63 | pause_granularity: 5 64 | proba_pause_when_running: 0.0102 65 | proba_pause_when_paused: 0.9791 66 | crop_size: 128 67 | random_flip: true 68 | num_bins: 5 69 | min_resize_scale: 1 70 | max_resize_scale: 1 71 | frames_per_bin: 1 72 | threshold_range: [0.05, 2] 73 | max_thres_pos_neg_gap: 1.5 74 | base_noise_std_range: [0, 0.1] 75 | hot_pixel_std_range: [0, 10] 76 | max_samples_per_shot: 10 77 | 78 | val: 79 | - data_file: config/evaid_test.txt 80 | class_name: data.testh5.TestH5Dataset 81 | dataset_name: evaid 82 | num_bins: 5 83 | sequence_length: 80 84 | interpolate_bins: false 85 | max_samples: 1 # Limit val time, 720p runs really slow 86 | image_range: 1 87 | - data_file: config/ijrr_test.txt 88 | class_name: data.testh5.TestH5Dataset 89 | dataset_name: ijrr 90 | num_bins: 5 91 | sequence_length: 80 92 | interpolate_bins: false 93 | image_range: 1 94 | - data_file: config/hqf_test.txt 95 | class_name: data.testh5.TestH5Dataset 96 | dataset_name: hqf 97 | num_bins: 5 98 | sequence_length: 80 99 | interpolate_bins: false 100 | image_range: 1 101 | - data_file: config/mvsec_test.txt 102 | class_name: data.testh5.TestH5Dataset 103 | dataset_name: mvsec 104 | num_bins: 5 105 | sequence_length: 80 106 | interpolate_bins: false 107 | image_range: 1 108 | 109 | test_stage: 110 | test_batch_size: 1 111 | test_num_workers: 16 112 | test: 113 | - data_file: config/evaid_test.txt 114 | class_name: data.testh5.TestH5Dataset 115 | dataset_name: evaid 116 | num_bins: 5 117 | sequence_length: 80 118 | interpolate_bins: false 119 | - data_file: config/ijrr_test.txt 120 | class_name: data.testh5.TestH5Dataset 121 | dataset_name: ijrr 122 | num_bins: 5 123 | sequence_length: 80 124 | interpolate_bins: false 125 | - data_file: config/hqf_test.txt 126 | class_name: data.testh5.TestH5Dataset 127 | dataset_name: hqf 128 | num_bins: 5 129 | sequence_length: 80 130 | interpolate_bins: false 131 | - data_file: config/mvsec_test.txt 132 | class_name: data.testh5.TestH5Dataset 133 | dataset_name: mvsec 134 | num_bins: 5 135 | sequence_length: 80 136 | interpolate_bins: false -------------------------------------------------------------------------------- /model/nernet_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import copy 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | # local modules 7 | 8 | from .model_util import CropParameters, recursive_clone 9 | from .base.base_model import BaseModel 10 | 11 | from .nernet.representation_modules import Voxelization 12 | from .nernet.unet import UNetNIAM_STcell_GCB, UNetRecurrent 13 | 14 | def copy_states(states): 15 | """ 16 | LSTM states: [(torch.tensor, torch.tensor), ...] 17 | GRU states: [torch.tensor, ...] 
18 | """ 19 | if states[0] is None: 20 | return copy.deepcopy(states) 21 | return recursive_clone(states) 22 | 23 | class RepresentationRecurrent(BaseModel): 24 | """ 25 | Compatible with E2VID_lightweight and Representation network 26 | Recurrent, UNet-like architecture where each encoder is followed by a ConvLSTM or ConvGRU. 27 | """ 28 | def __init__(self, unet_kwargs): 29 | super().__init__() 30 | self.num_bins = unet_kwargs['num_bins'] # legacy 31 | self.num_encoders = unet_kwargs['num_encoders'] # legacy 32 | self.crop_size = unet_kwargs['crop_size'] 33 | self.mlp_layers = unet_kwargs['mlp_layers'] 34 | self.normalize = unet_kwargs['normalize'] 35 | # self.normalize = False 36 | self.height = None 37 | self.width = None 38 | self.representation = None 39 | 40 | self.unet_kwargs = unet_kwargs 41 | 42 | self.network = unet_kwargs['recurrent_network'] 43 | if self.network == 'NIAM_STcell_GCB': 44 | self.unetrecurrent = UNetNIAM_STcell_GCB(unet_kwargs) 45 | else: 46 | self.unetrecurrent = UNetRecurrent(unet_kwargs) 47 | 48 | self.set_resolution(256, 256) # Make placeholder so weights can be loaded in 49 | 50 | def set_resolution(self, H, W): 51 | if self.height is None or self.width is None: 52 | # First time setting resolution 53 | self.height = H 54 | self.width = W 55 | # Reset the resolution. 56 | device = next(self.unetrecurrent.parameters()).device 57 | self.representation = Voxelization(self.unet_kwargs, self.unet_kwargs['use_cnn_representation'], voxel_dimension=(self.num_bins, self.height, self.width), mlp_layers=self.mlp_layers, activation=nn.LeakyReLU(negative_slope=0.1), pretrained=True, normalize=self.normalize, combine_voxel=self.unet_kwargs['combine_voxel']).to(device) 58 | self.crop = CropParameters(self.width, self.height, self.num_encoders) 59 | return 60 | 61 | if H != self.height or W != self.width: 62 | # Resolution has changed. Keep the network parameters of Voxelization. 63 | old_state_dict = self.representation.state_dict() 64 | self.height = H 65 | self.width = W 66 | # Reset the resolution. 67 | device = next(self.unetrecurrent.parameters()).device 68 | self.representation = Voxelization(self.unet_kwargs, self.unet_kwargs['use_cnn_representation'], voxel_dimension=(self.num_bins, self.height, self.width), mlp_layers=self.mlp_layers, activation=nn.LeakyReLU(negative_slope=0.1), pretrained=True, normalize=self.normalize, combine_voxel=self.unet_kwargs['combine_voxel']).to(device) 69 | self.crop = CropParameters(self.width, self.height, self.num_encoders) 70 | # Copy the parameters from the old representation to the new one. 71 | self.representation.load_state_dict(old_state_dict, strict=True) 72 | return 73 | 74 | 75 | 76 | @property 77 | def states(self): 78 | return copy_states(self.unetrecurrent.states) 79 | 80 | @states.setter 81 | def states(self, states): 82 | self.unetrecurrent.states = states 83 | 84 | def reset_states(self): 85 | if 'NIAM' in self.network or 'NAS' in self.network: 86 | self.unetrecurrent.states = None 87 | else: 88 | self.unetrecurrent.states = [None] * self.unetrecurrent.num_encoders 89 | 90 | def forward(self, x): 91 | """ 92 | :param x: events[x, y, t, p] 93 | :return: output dict with image taking values in [0,1], and 94 | displacement within event_tensor. 
95 | """ 96 | event_tensor = self.representation.forward(x) 97 | event_tensor = self.crop.pad(event_tensor) 98 | output_dict = self.unetrecurrent.forward(event_tensor) 99 | return output_dict, event_tensor -------------------------------------------------------------------------------- /config/train_ablation_e2vid_10k_fixed.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: ablation_e2vid_10k_fixed 2 | check_val_every_n_epoch: 1 3 | test_output_dir: results/ablation_e2vid_10k_fixed 4 | use_compile: false 5 | 6 | module: 7 | loss: 8 | lpips_weight: 1.0 9 | lpips_type: vgg 10 | l2_weight: 0 11 | l1_weight: 1.0 12 | ssim_weight: 0 13 | temporal_consistency_weight: 1.0 14 | optical_flow_source: raft_small 15 | temporal_consistency_L0: 20 16 | 17 | normalize_voxels: false 18 | model: 19 | target: model.model.E2VIDRecurrent 20 | params: 21 | unet_kwargs: 22 | num_bins: 5 23 | skip_type: sum 24 | recurrent_block_type: convlstm 25 | num_encoders: 3 26 | base_num_channels: 32 27 | num_residual_blocks: 2 28 | use_upsample_conv: true 29 | final_activation: "" 30 | norm: none 31 | 32 | train_stages: 33 | - stage_name: stage1 34 | max_epochs: 80 35 | 36 | optimizer: 37 | target: torch.optim.Adam 38 | params: 39 | lr: 0.0001 40 | weight_decay: 0 41 | amsgrad: true 42 | 43 | lr_scheduler: 44 | target: torch.optim.lr_scheduler.StepLR 45 | params: 46 | step_size: 50 47 | gamma: 1.0 # This is actually constant LR 48 | 49 | dataset: 50 | train_batch_size: 12 51 | num_workers: 9 52 | val_batch_size: 1 # Because test data has different sizes 53 | persistent_workers: true 54 | pin_memory: true 55 | 56 | train: 57 | - data_file: config/webvid_root.txt 58 | class_name: data.v2v_datasets.WebvidDatasetV2 59 | video_list_file: config/webvid10000_unfiltered.txt 60 | data_source_name: webvid 61 | video_reader: opencv 62 | sequence_length: 40 63 | pause_granularity: 5 64 | proba_pause_when_running: 0.0102 65 | proba_pause_when_paused: 0.9791 66 | crop_size: 128 67 | random_flip: true 68 | num_bins: 5 69 | min_resize_scale: 1 70 | max_resize_scale: 1 71 | frames_per_bin: 1 72 | # threshold_range: [0.05, 2] 73 | # max_thres_pos_neg_gap: 1.5 74 | use_fixed_thresholds: true # Parameter for fixed thresholds ablation 75 | base_noise_std_range: [0, 0.1] 76 | hot_pixel_std_range: [0, 10] 77 | max_samples_per_shot: 10 78 | 79 | val: 80 | - data_file: config/evaid_test.txt 81 | class_name: data.testh5.TestH5Dataset 82 | dataset_name: evaid 83 | num_bins: 5 84 | sequence_length: 80 85 | interpolate_bins: false 86 | max_samples: 1 # Limit val time, 720p runs really slow 87 | image_range: 1 88 | - data_file: config/ijrr_test.txt 89 | class_name: data.testh5.TestH5Dataset 90 | dataset_name: ijrr 91 | num_bins: 5 92 | sequence_length: 80 93 | interpolate_bins: false 94 | image_range: 1 95 | - data_file: config/hqf_test.txt 96 | class_name: data.testh5.TestH5Dataset 97 | dataset_name: hqf 98 | num_bins: 5 99 | sequence_length: 80 100 | interpolate_bins: false 101 | image_range: 1 102 | - data_file: config/mvsec_test.txt 103 | class_name: data.testh5.TestH5Dataset 104 | dataset_name: mvsec 105 | num_bins: 5 106 | sequence_length: 80 107 | interpolate_bins: false 108 | image_range: 1 109 | 110 | test_stage: 111 | test_batch_size: 1 112 | test_num_workers: 16 113 | test: 114 | - data_file: config/evaid_test.txt 115 | class_name: data.testh5.TestH5Dataset 116 | dataset_name: evaid 117 | num_bins: 5 118 | sequence_length: 80 119 | interpolate_bins: false 120 | - data_file: 
config/ijrr_test.txt 121 | class_name: data.testh5.TestH5Dataset 122 | dataset_name: ijrr 123 | num_bins: 5 124 | sequence_length: 80 125 | interpolate_bins: false 126 | - data_file: config/hqf_test.txt 127 | class_name: data.testh5.TestH5Dataset 128 | dataset_name: hqf 129 | num_bins: 5 130 | sequence_length: 80 131 | interpolate_bins: false 132 | - data_file: config/mvsec_test.txt 133 | class_name: data.testh5.TestH5Dataset 134 | dataset_name: mvsec 135 | num_bins: 5 136 | sequence_length: 80 137 | interpolate_bins: false -------------------------------------------------------------------------------- /config/train_ablation_e2vid_hdr.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: ablation_e2vid_hdr 2 | check_val_every_n_epoch: 1 3 | test_output_dir: results/ablation_e2vid_hdr 4 | use_compile: false 5 | 6 | module: 7 | loss: 8 | lpips_weight: 1.0 9 | lpips_type: vgg 10 | l2_weight: 0 11 | l1_weight: 1.0 12 | ssim_weight: 0 13 | temporal_consistency_weight: 1.0 14 | optical_flow_source: raft_small 15 | temporal_consistency_L0: 20 16 | 17 | normalize_voxels: false 18 | model: 19 | target: model.model.E2VIDRecurrent 20 | params: 21 | unet_kwargs: 22 | num_bins: 5 23 | skip_type: sum 24 | recurrent_block_type: convlstm 25 | num_encoders: 3 26 | base_num_channels: 32 27 | num_residual_blocks: 2 28 | use_upsample_conv: true 29 | final_activation: "" 30 | norm: none 31 | 32 | train_stages: 33 | - stage_name: stage1 34 | max_epochs: 80 35 | 36 | optimizer: 37 | target: torch.optim.Adam 38 | params: 39 | lr: 0.0001 40 | weight_decay: 0 41 | amsgrad: true 42 | 43 | lr_scheduler: 44 | target: torch.optim.lr_scheduler.StepLR 45 | params: 46 | step_size: 50 47 | gamma: 1.0 # This is actually constant LR 48 | 49 | dataset: 50 | train_batch_size: 12 51 | num_workers: 9 52 | val_batch_size: 1 # Because test data has different sizes 53 | persistent_workers: true 54 | pin_memory: true 55 | 56 | train: 57 | - data_file: config/webvid_root.txt 58 | class_name: data.v2v_datasets.WebvidDatasetV2 59 | video_list_file: config/webvid10000_unfiltered.txt 60 | data_source_name: webvid 61 | video_reader: opencv 62 | sequence_length: 40 63 | pause_granularity: 5 64 | proba_pause_when_running: 0.0102 65 | proba_pause_when_paused: 0.9791 66 | crop_size: 128 67 | random_flip: true 68 | num_bins: 5 69 | min_resize_scale: 1 70 | max_resize_scale: 1 71 | frames_per_bin: 1 72 | threshold_range: [0.05, 2] 73 | max_thres_pos_neg_gap: 1.5 74 | base_noise_std_range: [0, 0.1] 75 | hot_pixel_std_range: [0, 10] 76 | max_samples_per_shot: 10 77 | video_degrade: hdr # Degrade the videos for ablation studies 78 | degrade_ratio: 0.8 79 | 80 | val: 81 | - data_file: config/evaid_test.txt 82 | class_name: data.testh5.TestH5Dataset 83 | dataset_name: evaid 84 | num_bins: 5 85 | sequence_length: 80 86 | interpolate_bins: false 87 | max_samples: 1 # Limit val time, 720p runs really slow 88 | image_range: 1 89 | - data_file: config/ijrr_test.txt 90 | class_name: data.testh5.TestH5Dataset 91 | dataset_name: ijrr 92 | num_bins: 5 93 | sequence_length: 80 94 | interpolate_bins: false 95 | image_range: 1 96 | - data_file: config/hqf_test.txt 97 | class_name: data.testh5.TestH5Dataset 98 | dataset_name: hqf 99 | num_bins: 5 100 | sequence_length: 80 101 | interpolate_bins: false 102 | image_range: 1 103 | - data_file: config/mvsec_test.txt 104 | class_name: data.testh5.TestH5Dataset 105 | dataset_name: mvsec 106 | num_bins: 5 107 | sequence_length: 80 108 | interpolate_bins: false 109 
| image_range: 1 110 | 111 | test_stage: 112 | test_batch_size: 1 113 | test_num_workers: 16 114 | test: 115 | - data_file: config/evaid_test.txt 116 | class_name: data.testh5.TestH5Dataset 117 | dataset_name: evaid 118 | num_bins: 5 119 | sequence_length: 80 120 | interpolate_bins: false 121 | - data_file: config/ijrr_test.txt 122 | class_name: data.testh5.TestH5Dataset 123 | dataset_name: ijrr 124 | num_bins: 5 125 | sequence_length: 80 126 | interpolate_bins: false 127 | - data_file: config/hqf_test.txt 128 | class_name: data.testh5.TestH5Dataset 129 | dataset_name: hqf 130 | num_bins: 5 131 | sequence_length: 80 132 | interpolate_bins: false 133 | - data_file: config/mvsec_test.txt 134 | class_name: data.testh5.TestH5Dataset 135 | dataset_name: mvsec 136 | num_bins: 5 137 | sequence_length: 80 138 | interpolate_bins: false -------------------------------------------------------------------------------- /model/eraft/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def __init__(self, hidden_dim=128, input_dim=192+128): 18 | super(ConvGRU, self).__init__() 19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | 23 | def forward(self, h, x): 24 | hx = torch.cat([h, x], dim=1) 25 | 26 | z = torch.sigmoid(self.convz(hx)) 27 | r = torch.sigmoid(self.convr(hx)) 28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 29 | 30 | h = (1-z) * h + z * q 31 | return h 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 | # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | 63 | class BasicMotionEncoder(nn.Module): 64 | def __init__(self, args): 65 | super(BasicMotionEncoder, self).__init__() 66 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 67 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 68 | self.convc2 = nn.Conv2d(256, 192, 3, 
padding=1) 69 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 70 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 71 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 72 | 73 | def forward(self, flow, corr): 74 | cor = F.relu(self.convc1(corr)) 75 | cor = F.relu(self.convc2(cor)) 76 | flo = F.relu(self.convf1(flow)) 77 | flo = F.relu(self.convf2(flo)) 78 | 79 | cor_flo = torch.cat([cor, flo], dim=1) 80 | out = F.relu(self.conv(cor_flo)) 81 | return torch.cat([out, flow], dim=1) 82 | 83 | 84 | class BasicUpdateBlock(nn.Module): 85 | def __init__(self, args, hidden_dim=128, input_dim=128): 86 | super(BasicUpdateBlock, self).__init__() 87 | self.args = args 88 | self.encoder = BasicMotionEncoder(args) 89 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 90 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 91 | 92 | self.mask = nn.Sequential( 93 | nn.Conv2d(128, 256, 3, padding=1), 94 | nn.ReLU(inplace=True), 95 | nn.Conv2d(256, 64*9, 1, padding=0)) 96 | 97 | def forward(self, net, inp, corr, flow, upsample=True): 98 | motion_features = self.encoder(flow, corr) 99 | inp = torch.cat([inp, motion_features], dim=1) 100 | 101 | net = self.gru(net, inp) 102 | delta_flow = self.flow_head(net) 103 | 104 | # scale mask to balance gradients 105 | mask = .25 * self.mask(net) 106 | return net, mask, delta_flow 107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /config/train_v2v_hyper_10k.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: v2v_hyper_10k 2 | test_output_dir: results/v2v_hyper_10k 3 | check_val_every_n_epoch: 1 4 | use_compile: false 5 | 6 | module: 7 | loss: 8 | lpips_weight: 1.0 9 | lpips_type: vgg 10 | l2_weight: 0 11 | l1_weight: 1.0 12 | ssim_weight: 0 13 | temporal_consistency_weight: 1.0 14 | optical_flow_source: raft_small 15 | temporal_consistency_L0: 20 16 | 17 | normalize_voxels: false 18 | model: 19 | target: model.hyper_model.HyperE2VID 20 | params: 21 | unet_kwargs: 22 | num_bins: 5 23 | skip_type: sum 24 | recurrent_block_type: convlstm 25 | kernel_size: 5 26 | channel_multiplier: 2 27 | num_encoders: 3 28 | base_num_channels: 32 29 | num_residual_blocks: 2 30 | use_upsample_conv: true 31 | norm: none 32 | num_output_channels: 1 33 | use_dynamic_decoder: true # Key difference of HyperE2VID 34 | hyper_epochs: 16 # Number of curriculum epochs 35 | 36 | train_stages: 37 | - stage_name: stage1 38 | max_epochs: 80 39 | 40 | optimizer: 41 | target: torch.optim.Adam 42 | params: 43 | lr: 0.001 44 | weight_decay: 0 45 | amsgrad: true 46 | 47 | lr_scheduler: 48 | target: torch.optim.lr_scheduler.StepLR 49 | params: 50 | step_size: 50 51 | gamma: 1.0 # This is actually constant LR 52 | 53 | dataset: 54 | train_batch_size: 12 55 | num_workers: 9 56 | val_batch_size: 1 # Because test data has different sizes 57 | persistent_workers: true 58 | pin_memory: true 59 | 60 | train: 61 | - data_file: config/webvid_root.txt 62 | class_name: data.v2v_datasets.WebvidDatasetV2 63 | video_list_file: config/webvid10000_unfiltered.txt 64 | data_source_name: webvid 65 | video_reader: opencv 66 | sequence_length: 40 67 | pause_granularity: 5 68 | proba_pause_when_running: 0.0102 69 | proba_pause_when_paused: 0.9791 70 | crop_size: 128 71 | random_flip: true 72 | num_bins: 5 73 | min_resize_scale: 1 74 | max_resize_scale: 1 75 | frames_per_bin: 1 76 | threshold_range: [0.05, 2] 77 | max_thres_pos_neg_gap: 1.5 78 | base_noise_std_range: [0, 0.1] 79 | 
hot_pixel_std_range: [0, 10] 80 | max_samples_per_shot: 10 81 | 82 | val: 83 | - data_file: config/evaid_test.txt 84 | class_name: data.testh5.TestH5Dataset 85 | dataset_name: evaid 86 | num_bins: 5 87 | sequence_length: 80 88 | interpolate_bins: false 89 | max_samples: 1 # Limit val time, 720p runs really slow 90 | image_range: 1 91 | - data_file: config/ijrr_test.txt 92 | class_name: data.testh5.TestH5Dataset 93 | dataset_name: ijrr 94 | num_bins: 5 95 | sequence_length: 80 96 | interpolate_bins: false 97 | image_range: 1 98 | - data_file: config/hqf_test.txt 99 | class_name: data.testh5.TestH5Dataset 100 | dataset_name: hqf 101 | num_bins: 5 102 | sequence_length: 80 103 | interpolate_bins: false 104 | image_range: 1 105 | - data_file: config/mvsec_test.txt 106 | class_name: data.testh5.TestH5Dataset 107 | dataset_name: mvsec 108 | num_bins: 5 109 | sequence_length: 80 110 | interpolate_bins: false 111 | image_range: 1 112 | 113 | test_stage: 114 | test_batch_size: 1 115 | test_num_workers: 4 116 | test: 117 | - data_file: config/evaid_test.txt 118 | class_name: data.testh5.TestH5Dataset 119 | dataset_name: evaid 120 | num_bins: 5 121 | sequence_length: 80 122 | interpolate_bins: false 123 | - data_file: config/ijrr_test.txt 124 | class_name: data.testh5.TestH5Dataset 125 | dataset_name: ijrr 126 | num_bins: 5 127 | sequence_length: 80 128 | interpolate_bins: false 129 | - data_file: config/hqf_test.txt 130 | class_name: data.testh5.TestH5Dataset 131 | dataset_name: hqf 132 | num_bins: 5 133 | sequence_length: 80 134 | interpolate_bins: false 135 | - data_file: config/mvsec_test.txt 136 | class_name: data.testh5.TestH5Dataset 137 | dataset_name: mvsec 138 | num_bins: 5 139 | sequence_length: 80 140 | interpolate_bins: false -------------------------------------------------------------------------------- /scripts/hs_ergb_to_h5.py: -------------------------------------------------------------------------------- 1 | # Convert HS-ERGB dataset to h5 format. 2 | # We didn't use the dataset for evaluation because (1) the backgrounds were too static, and (2) the GT images were too noisy. 3 | # HS-ERGB download site: https://rpg.ifi.uzh.ch/TimeLens.html 4 | 5 | import os 6 | import h5py 7 | import numpy as np 8 | import glob 9 | import cv2 10 | 11 | def convert(evaid_dir, h5_path): 12 | ''' 13 | Download the dataset from our project page. The dataset structure is as follows 14 | . 15 | ├── close 16 | │ └── test 17 | │ ├── baloon_popping 18 | │ │ ├── events_aligned 19 | │ │ └── images_corrected 20 | │ ├── candle 21 | │ │ ├── events_aligned 22 | │ │ └── images_corrected 23 | │ ... 24 | │ 25 | └── far 26 | └── test 27 | ├── bridge_lake_01 28 | │ ├── events_aligned 29 | │ └── images_corrected 30 | ├── bridge_lake_03 31 | │ ├── events_aligned 32 | │ └── images_corrected 33 | ... 34 | 35 | Each events_aligned folder contains events files with template filename %06d.npz, and images_corrected contains image files with template filename %06d.png. In events_aligned each event file with index n contains events between images with index n-1 and n, i.e. event file 000001.npz contains events between images 000000.png and 000001.png. Each event file contains keys for the x,y,t, and p event component. Note that x and y need to be divided by 32 before use. This is because they actually correspond to remapped events, which have floating point coordinates. 36 | 37 | Moreover, images_corrected also contains timestamp.txt where image timestamps are stored. 
Note that in some folders there are more image files than event files. However, the image stamps in timestamp.txt should match with the event files and the additional images can be ignored. 38 | ''' 39 | 40 | of = h5py.File(h5_path, 'w') 41 | 42 | # Read timestamps 43 | timestamps_path = os.path.join(evaid_dir, 'images/timestamp.txt') 44 | with open(timestamps_path, 'r') as f: 45 | timestamps = f.readlines() 46 | timestamps = [float(x.strip()) for x in timestamps] # The timestamps are integers e.g. 2810536.0 47 | 48 | all_img_paths = sorted(glob.glob(os.path.join(evaid_dir, "images/*.png"))) 49 | 50 | # Read shape from first image 51 | img0 = cv2.imread(all_img_paths[0]) 52 | H, W, C = img0.shape 53 | print("H, W, C: ", H, W, C) 54 | of.create_dataset('sensor_resolution', data=[H, W]) 55 | 56 | # Read events 57 | all_xs = [] 58 | all_ys = [] 59 | all_ts = [] 60 | all_ps = [] 61 | 62 | ev_paths = sorted(glob.glob(os.path.join(evaid_dir, 'events/*.npz'))) 63 | for evp in ev_paths: 64 | # Load the events for this frame interval from the npz file 65 | ev = np.load(evp) 66 | xs = ev['x'] // 32 # Drop the fractional part of the remapped coordinates; sub-pixel positions are handled later 67 | ys = ev['y'] // 32 68 | ts = ev["timestamp"] 69 | ps = ev["polarity"] 70 | 71 | # Filter out all events with xs >= W and ys >= H 72 | mask = np.logical_and(xs < W, ys < H) 73 | xs = xs[mask] 74 | ys = ys[mask] 75 | ts = ts[mask] 76 | ps = ps[mask] 77 | 78 | if xs.shape[0] > 0: 79 | all_xs.append(xs) 80 | all_ys.append(ys) 81 | all_ps.append(ps) 82 | all_ts.append(ts) 83 | 84 | all_xs = np.concatenate(all_xs) 85 | all_ys = np.concatenate(all_ys) 86 | all_ts = np.concatenate(all_ts) 87 | all_ps = np.concatenate(all_ps) 88 | 89 | event_idxs = np.searchsorted(all_ts, timestamps) 90 | basetime = all_ts[0] 91 | all_ts = (all_ts - basetime).astype(np.float64) / 1e6 92 | timestamps = (np.array(timestamps) - basetime).astype(np.float64) / 1e6 93 | 94 | of.create_dataset('events/ts', data=all_ts, dtype=np.float32) 95 | of.create_dataset('events/xs', data=all_xs, dtype=np.int16) 96 | of.create_dataset('events/ys', data=all_ys, dtype=np.int16) 97 | of.create_dataset('events/ps', data=all_ps, dtype=np.bool_) 98 | 99 | frame_cnt = min(len(timestamps), len(all_img_paths)) 100 | 101 | for i in range(frame_cnt): 102 | img = cv2.imread(all_img_paths[i], cv2.IMREAD_GRAYSCALE) 103 | of.create_dataset(f'images/{i:06d}', data=img) 104 | of["images"][f'{i:06d}'].attrs['timestamp'] = timestamps[i] 105 | of["images"][f'{i:06d}'].attrs['event_idx'] = event_idxs[i] 106 | of.close() 107 | 108 | def process(seqname): 109 | evaid_dir = f"/mnt/ssd/bs_ergb/1_TEST/{seqname}" 110 | h5_path = f"/mnt/ssd/bs_ergb/test/{seqname}.h5" 111 | os.makedirs(os.path.dirname(h5_path), exist_ok=True) 112 | convert(evaid_dir, h5_path) 113 | 114 | 115 | all_sequences = [ 116 | os.path.basename(x) for x in glob.glob("/mnt/ssd/bs_ergb/1_TEST/*") 117 | ] 118 | print(sorted(all_sequences)) 119 | for seqname in sorted(all_sequences): 120 | print(seqname) 121 | process(seqname) -------------------------------------------------------------------------------- /scripts/evaid_to_h5.py: -------------------------------------------------------------------------------- 1 | # Download EVAID-R from: https://sites.google.com/view/eventaid-benchmark/home 2 | 3 | import os 4 | import h5py 5 | import numpy as np 6 | import glob 7 | import pandas as pd 8 | import cv2 9 | import subprocess 10 | 11 | def convert(evaid_dir, h5_path, begin_second, end_second): 12 | # Evaid format: 13 | # event/*.txt: Events.
Each txt is a series of events, each events takes a line: "{timestamp}, {x}, {y}, {polarity}", e.g. "4775805 1131 644 0". 14 | # gt/*.png: Images. Each png is a frame. 15 | # shape.txt: Txt with a single line "{W} {H}". 16 | # timestamps.txt: The i-th line corresponds to the timestamp of the i-th image, e.g. "4775787". 17 | # event/000001.txt are the events between gt/000001_img.png and gt/000002_img.png. There are no events before the first image, so when converting to h5 we will discard the first image. 18 | of = h5py.File(h5_path, 'w') 19 | 20 | all_events = [] 21 | 22 | # Read timestamps 23 | timestamps_path = os.path.join(evaid_dir, 'timestamps.txt') 24 | with open(timestamps_path, 'r') as f: 25 | timestamps = f.readlines() 26 | timestamps = [int(x.strip()) for x in timestamps] 27 | 28 | timestamp_rel = np.array(timestamps) - timestamps[0] 29 | begin_idx = np.searchsorted(timestamp_rel, begin_second * 1e6) 30 | end_idx = np.searchsorted(timestamp_rel, end_second * 1e6) 31 | print("begin_idx", begin_idx) 32 | print("end_idx", end_idx) 33 | timestamps = timestamps[begin_idx:end_idx+1] 34 | 35 | image_paths = sorted(glob.glob(os.path.join(evaid_dir, 'gt/*.png'))) + sorted(glob.glob(os.path.join(evaid_dir, 'gt/*.jpg'))) 36 | image_paths = image_paths[begin_idx:end_idx+1] 37 | 38 | # Read shape 39 | shape_path = os.path.join(evaid_dir, 'shape.txt') 40 | with open(shape_path, 'r') as f: 41 | shape = f.readlines()[0].strip().split(' ') 42 | W = int(shape[0]) 43 | H = int(shape[1]) 44 | of.create_dataset('sensor_resolution', data=[H, W]) 45 | 46 | # Read events 47 | ev_paths = sorted(glob.glob(os.path.join(evaid_dir, 'event/*.txt')))[begin_idx:end_idx+2] 48 | for evp in ev_paths: 49 | # Use accelerated reading with pandas 50 | ev = pd.read_csv(evp, header=None, sep=' ', names=['timestamp', 'x', 'y', 'polarity']) 51 | ev = ev.to_numpy() 52 | if ev.shape[0] > 0: 53 | all_events.append(ev) 54 | all_events = np.concatenate(all_events) 55 | 56 | ts = all_events[:, 0] 57 | xs = all_events[:, 1] 58 | print("xs.shape", xs.shape) 59 | ys = all_events[:, 2] 60 | ps = all_events[:, 3] 61 | 62 | event_idxs = np.searchsorted(ts, timestamps) 63 | basetime = ts[0] 64 | ts = (ts - basetime).astype(np.float64) / 1e6 65 | timestamps = (np.array(timestamps) - basetime).astype(np.float64) / 1e6 66 | 67 | of.create_dataset('events/ts', data=ts, dtype=np.float32) 68 | of.create_dataset('events/xs', data=xs, dtype=np.int16) 69 | of.create_dataset('events/ys', data=ys, dtype=np.int16) 70 | of.create_dataset('events/ps', data=ps, dtype=np.bool_) 71 | 72 | 73 | 74 | for i, imgp in enumerate(image_paths): 75 | if i == 0: # Discard first image 76 | continue 77 | img = cv2.imread(imgp, cv2.IMREAD_GRAYSCALE) 78 | of.create_dataset(f'images/{i:06d}', data=img) 79 | of["images"][f'{i:06d}'].attrs['timestamp'] = timestamps[i] 80 | of["images"][f'{i:06d}'].attrs['event_idx'] = event_idxs[i] 81 | of.close() 82 | 83 | def process(seqname, begin_second, end_second): 84 | # First, for {seqname}, excecute: 85 | # unzip /mnt/ssd/EventAid-R/{seqname}.zip -d /mnt/ssd/EventAid-R/{seqname} 86 | # Check if directory already exists to avoid re-extraction 87 | if not os.path.exists(f"/mnt/ssd/EventAid-R/{seqname}"): 88 | subprocess.run(["unzip", f"/mnt/ssd/EventAid-R/{seqname}.zip", "-d", f"/mnt/ssd/EventAid-R/{seqname}"], check=True) 89 | else: 90 | print(f"Directory for {seqname} already exists, skipping extraction") 91 | 92 | evaid_dir = f"/mnt/ssd/EventAid-R/{seqname}" 93 | h5_path = f"/mnt/ssd/EventAid-R-h5/{seqname}.h5" 94 | 
os.makedirs(os.path.dirname(h5_path), exist_ok=True) 95 | convert(evaid_dir, h5_path, begin_second, end_second) 96 | 97 | ''' 98 | ball: usable 99 | bear: usable 100 | blocks: not usable, there is no background at all 101 | box: usable 102 | building: usable 103 | outdoor: usable 104 | playball: usable, but take a clip from later in the sequence 105 | room1: usable 106 | room2: duplicates room1 107 | sculpture: usable 108 | toy: usable 109 | traffic: usable 110 | umbrella: not usable, there is no background at all 111 | wall: usable 112 | ''' 113 | 114 | use_seqs = { 115 | "ball": [0, 5], 116 | "bear": [0, 5], 117 | "box": [0, 5], 118 | "building": [0, 5], 119 | "outdoor": [0, 5], 120 | "playball": [25, 30], 121 | "room1": [0, 5], 122 | "sculpture": [0, 5], 123 | "toy": [0, 5], 124 | "traffic": [0, 5], 125 | "wall": [0, 5] 126 | } 127 | 128 | for seqname, (begin_second, end_second) in use_seqs.items(): 129 | process(seqname, begin_second, end_second) -------------------------------------------------------------------------------- /scripts/ijrr_to_h5.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import h5py 4 | import numpy as np 5 | import glob 6 | import cv2 7 | import tqdm 8 | 9 | def unzip(file_path): 10 | # Use the system unzip command to extract the archive. 11 | # If the zip already contains a folder named "name", unzip it directly. Else, unzip it into file_dir/name. 12 | file_dir = os.path.dirname(file_path) 13 | name = os.path.basename(file_path).split(".")[0] 14 | out_dir = os.path.join(file_dir, name) 15 | os.system(f"unzip -d {out_dir} {file_path}") 16 | # If this creates a directory file_dir/name/name, move the contents of file_dir/name/name to file_dir/name 17 | if os.path.isdir(os.path.join(out_dir, name)): 18 | os.system(f"mv {os.path.join(out_dir, name)}/* {out_dir}") 19 | os.system(f"rm -r {os.path.join(out_dir, name)}") 20 | 21 | CUT_SECONDS = { 22 | "boxes_6dof": (5, 20), 23 | "calibration": (5, 20), 24 | "dynamic_6dof": (5, 20), 25 | "office_zigzag": (5, 12), 26 | "poster_6dof": (5, 20), 27 | "shapes_6dof": (5, 20), 28 | "slider_depth": (1, 2.5) 29 | } 30 | 31 | IN_DIR = "/mnt/ssd/IJRR" 32 | OUT_DIR = "/mnt/ssd/IJRR_cut" 33 | os.makedirs(OUT_DIR, exist_ok=True) 34 | 35 | for seq_name in CUT_SECONDS.keys(): 36 | zip_path = f"{IN_DIR}/{seq_name}.zip" 37 | unzip(zip_path) 38 | 39 | out_h5path = f"{OUT_DIR}/{seq_name}.h5" 40 | in_root = f"{IN_DIR}/{seq_name}" 41 | 42 | img_timestamp_txt = f"{in_root}/images.txt" 43 | # The txt file has N rows; each row is a timestamp followed by an image filepath, such as: 44 | # 1468941032.255472635 images/frame_00000000.png 45 | timestamps = [] 46 | img_paths = [] 47 | with open(img_timestamp_txt, "r") as f: 48 | for line in f: 49 | timestamp, img_path = line.strip().split(" ") 50 | timestamps.append(float(timestamp)) 51 | img_paths.append(img_path) 52 | # The events are stored in a txt file.
Each line is [t x y p], such as: 53 | # 1468941032.229165635 128 154 1 54 | events_txt = f"{in_root}/events.txt" 55 | events = np.loadtxt(events_txt, dtype=np.float64) 56 | 57 | ts = events[:, 0] 58 | 59 | event_begin_idx = np.searchsorted(ts, CUT_SECONDS[seq_name][0] + timestamps[0]) 60 | event_end_idx = np.searchsorted(ts, CUT_SECONDS[seq_name][1] + timestamps[0]) 61 | print("timestamps[0]", timestamps[0]) 62 | print("CUT_SECONDS[seq_name][1]", CUT_SECONDS[seq_name][1]) 63 | print("ts[0]", ts[0]) 64 | print("event_begin_idx", event_begin_idx, "event_end_idx", event_end_idx) 65 | image_begin_idx = np.searchsorted(timestamps, CUT_SECONDS[seq_name][0] + timestamps[0]) 66 | image_end_idx = np.searchsorted(timestamps, CUT_SECONDS[seq_name][1] + timestamps[0]) 67 | 68 | img_ev_idx = [] 69 | for i in range(image_begin_idx, image_end_idx): 70 | img_ev_idx.append(np.searchsorted(ts[event_begin_idx:event_end_idx], timestamps[i])) 71 | 72 | # Extract the event data and images first 73 | event_xs = events[event_begin_idx:event_end_idx, 1].astype(np.uint16) 74 | event_ys = events[event_begin_idx:event_end_idx, 2].astype(np.uint16) 75 | event_ts = events[event_begin_idx:event_end_idx, 0].astype(np.float64) 76 | event_ps = events[event_begin_idx:event_end_idx, 3].astype(np.uint8) 77 | 78 | images = [] 79 | for img_path in img_paths[image_begin_idx:image_end_idx]: 80 | images.append(cv2.imread(f"{in_root}/{img_path}", cv2.IMREAD_GRAYSCALE)) 81 | images = np.stack(images) 82 | N, H, W = images.shape 83 | 84 | # Output in HQF format directly 85 | with h5py.File(out_h5path, "w") as f: 86 | # Save metadata as attributes 87 | f.attrs["sensor_resolution"] = (H, W) 88 | f.attrs["num_events"] = event_ts.shape[0] 89 | f.attrs["num_imgs"] = N 90 | f.attrs["data_source"] = "ijrr" 91 | 92 | # Save event data 93 | f.create_dataset("events/xs", data=event_xs) 94 | f.create_dataset("events/ys", data=event_ys) 95 | f.create_dataset("events/ts", data=event_ts) 96 | f.create_dataset("events/ps", data=event_ps) 97 | 98 | # Save images with proper attributes 99 | img_timestamps = np.array(timestamps[image_begin_idx:image_end_idx]) 100 | for idx in range(N): 101 | image_name = f"images/image{idx:09d}" 102 | f.create_dataset(image_name, data=images[idx]) 103 | f[image_name].attrs["event_idx"] = img_ev_idx[idx] 104 | f[image_name].attrs["timestamp"] = img_timestamps[idx] 105 | 106 | print(f"Processed {seq_name} - saved to {out_h5path}") 107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /config/webvid100_unfiltered.txt: -------------------------------------------------------------------------------- 1 | 000451_000500/34954225.mp4 299 1.379 1.884 2 | 000001_000050/1066687471.mp4 1433 1.050 1.574 3 | 000351_000400/5569250.mp4 1260 0.830 0.956 4 | 000401_000450/3271574.mp4 795 1.089 0.983 5 | 000051_000100/1066662943.mp4 400 1.883 2.372 6 | 000001_000050/1066695484.mp4 600 1.333 1.803 7 | 000351_000400/1017766681.mp4 300 1.738 1.678 8 | 000351_000400/2409980.mp4 726 0.245 0.287 9 | 000001_000050/1066693069.mp4 463 0.472 0.425 10 | 000051_000100/1066662949.mp4 400 0.932 0.709 11 | 000151_000200/1035762251.mp4 235 1.771 1.640 12 | 000351_000400/34174477.mp4 689 1.928 1.779 13 | 000301_000350/21315283.mp4 238 1.167 0.842 14 | 000201_000250/1048743643.mp4 303 2.540 1.716 15 | 000401_000450/8506630.mp4 300 0.315 0.421 16 | 000351_000400/1017400204.mp4 259 1.357 1.914 17 | 000001_000050/1066683316.mp4 541 0.359 0.501 18 | 000201_000250/1042129498.mp4 344 0.350 0.283 19 | 
000351_000400/31411729.mp4 925 0.979 0.804 20 | 000051_000100/1066650397.mp4 389 0.648 0.867 21 | 000351_000400/13149671.mp4 743 1.481 2.118 22 | 000151_000200/31345138.mp4 219 0.078 0.074 23 | 000001_000050/1066694605.mp4 400 0.633 0.852 24 | 000301_000350/20767336.mp4 226 1.932 2.729 25 | 000351_000400/1013741759.mp4 512 1.547 1.707 26 | 000051_000100/1066657489.mp4 2732 1.502 1.717 27 | 000251_000300/1052293579.mp4 654 0.699 0.498 28 | 000201_000250/1022242063.mp4 288 0.125 0.186 29 | 000351_000400/1047441931.mp4 458 0.576 0.824 30 | 000301_000350/19838812.mp4 819 1.460 1.721 31 | 000101_000150/1054929692.mp4 350 1.012 1.106 32 | 000451_000500/1024662227.mp4 932 0.559 0.518 33 | 000301_000350/23156584.mp4 634 1.793 2.062 34 | 000201_000250/16688431.mp4 528 0.761 0.552 35 | 000401_000450/33051685.mp4 264 1.506 1.299 36 | 000401_000450/29984371.mp4 541 0.204 0.186 37 | 000101_000150/1020350419.mp4 755 1.220 1.650 38 | 000451_000500/9892616.mp4 416 0.436 0.514 39 | 000401_000450/1036252877.mp4 392 2.249 1.659 40 | 000101_000150/1024461695.mp4 716 1.619 1.338 41 | 000201_000250/15524014.mp4 753 0.506 0.394 42 | 000401_000450/1028604152.mp4 299 2.096 1.478 43 | 000001_000050/1066679668.mp4 2041 0.211 0.186 44 | 000301_000350/1016812786.mp4 791 0.237 0.246 45 | 000351_000400/1014494702.mp4 552 2.659 1.940 46 | 000401_000450/5239799.mp4 360 2.257 1.672 47 | 000201_000250/32073559.mp4 1212 1.380 1.581 48 | 000051_000100/1066650964.mp4 435 1.553 1.666 49 | 000351_000400/1035629579.mp4 540 0.214 0.309 50 | 000451_000500/8793694.mp4 685 0.313 0.424 51 | 000151_000200/1031472524.mp4 647 0.092 0.101 52 | 000351_000400/29541901.mp4 255 1.231 1.709 53 | 000101_000150/1045899340.mp4 663 0.833 0.656 54 | 000051_000100/1066660393.mp4 754 0.381 0.453 55 | 000301_000350/1007336329.mp4 485 1.480 1.469 56 | 000251_000300/22056640.mp4 685 0.902 0.664 57 | 000201_000250/12276743.mp4 624 1.070 1.574 58 | 000351_000400/10459523.mp4 426 1.314 1.757 59 | 000051_000100/1066666075.mp4 220 1.666 2.422 60 | 000251_000300/19371796.mp4 540 0.337 0.297 61 | 000301_000350/1787213.mp4 448 0.917 1.306 62 | 000301_000350/7969804.mp4 375 0.902 0.719 63 | 000151_000200/23240677.mp4 390 1.332 1.247 64 | 000401_000450/1032774740.mp4 386 1.820 2.597 65 | 000301_000350/1017656467.mp4 396 0.063 0.053 66 | 000301_000350/1026522881.mp4 390 1.604 1.133 67 | 000351_000400/1030039427.mp4 313 0.759 1.110 68 | 000201_000250/1009399472.mp4 733 1.575 1.188 69 | 000101_000150/1054965770.mp4 559 1.633 2.374 70 | 000401_000450/11353394.mp4 373 2.035 1.760 71 | 000301_000350/1035636821.mp4 1129 0.809 0.650 72 | 000051_000100/1066648597.mp4 1501 0.561 0.403 73 | 000051_000100/1066662052.mp4 629 1.189 1.394 74 | 000001_000050/1066697728.mp4 271 1.894 2.013 75 | 000051_000100/1066665967.mp4 600 1.483 1.160 76 | 000251_000300/5956955.mp4 379 1.620 1.320 77 | 000101_000150/34383022.mp4 1456 1.597 1.415 78 | 000301_000350/1025135294.mp4 582 0.095 0.121 79 | 000101_000150/1034265602.mp4 370 1.674 1.728 80 | 000051_000100/1066656688.mp4 646 0.744 0.604 81 | 000401_000450/4725545.mp4 465 0.482 0.394 82 | 000151_000200/16379059.mp4 353 0.322 0.221 83 | 000151_000200/10782767.mp4 367 0.482 0.470 84 | 000251_000300/1032863486.mp4 945 0.906 0.768 85 | 000051_000100/1066655575.mp4 620 1.341 1.280 86 | 000151_000200/20283382.mp4 808 1.096 1.479 87 | 000251_000300/1065084745.mp4 539 0.822 0.638 88 | 000251_000300/16137157.mp4 1807 1.909 1.276 89 | 000451_000500/1038132419.mp4 1155 1.001 1.104 90 | 000251_000300/1053259499.mp4 300 0.304 0.371 91 | 
000151_000200/15836395.mp4 354 0.850 1.201 92 | 000301_000350/1039855880.mp4 264 1.395 1.028 93 | 000451_000500/27427918.mp4 380 2.560 1.807 94 | 000051_000100/1066660408.mp4 857 0.383 0.476 95 | 000001_000050/1066706878.mp4 366 0.223 0.166 96 | 000001_000050/1066693258.mp4 313 0.766 0.833 97 | 000151_000200/22994827.mp4 322 1.303 1.174 98 | 000351_000400/1027088822.mp4 600 2.534 1.691 99 | 000101_000150/1031009618.mp4 450 1.400 1.328 100 | 000351_000400/1021789513.mp4 755 1.591 1.505 101 | -------------------------------------------------------------------------------- /model/hyper/hyper_dynamic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from .fourier_bessel import bases_list 5 | 6 | 7 | class ConvolutionalContextFusion(nn.Module): 8 | """ This module takes the event tensor and the previous reconstructions and fuses them together. 9 | First, these tensors are concatenated in the channel dimension. Then, the tensor is downsampled. 10 | Finally, a convolution is applied to the downsampled tensor. 11 | """ 12 | 13 | def __init__(self, in_channels, out_channels, downsample_factor=4, kernel_size=3, padding="same"): 14 | super().__init__() 15 | self.scale = 1.0 / downsample_factor 16 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 17 | kernel_size=kernel_size, padding=padding) 18 | 19 | def forward(self, ev_tensor, prev_recs): 20 | context = torch.cat((ev_tensor, prev_recs), dim=1) 21 | context = nn.functional.interpolate(context, scale_factor=self.scale, mode='bilinear', align_corners=False) 22 | context = self.conv(context) 23 | return context 24 | 25 | 26 | class DynamicAtomGeneration(nn.Module): 27 | """ This module takes the context tensor and generates a set of dynamic atoms for each pixel. 28 | The context tensor is first passed through a convolutional network. The output of this network is a set of 29 | coefficients for the set of multiscale Fourier Bessel basis elements. These coefficients are then used to generate 30 | the dynamic atoms. 
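Shape sketch (inferred from the forward pass below; illustrative only):
    context                  -> (N, in_context_channels, H, W)
    basis_coefficients       -> (N, num_atoms, num_multiscale_bases, H/stride, W/stride)
    per_pixel_dynamic_atoms  -> (N, num_atoms, kernel_size**2, H/stride, W/stride)
With the defaults kernel_size=3 and num_atoms=6, every pixel therefore receives 6 atoms of 9 taps each, which DynamicConv later contracts against unfolded 3x3 patches of its input.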
31 | """ 32 | def __init__(self, kernel_size=3, num_atoms=6, num_bases=6, in_context_channels=32, hid_channels=64, stride=1): 33 | super().__init__() 34 | self.stride = stride 35 | self.num_atoms = num_atoms 36 | bases = bases_list(kernel_size, num_bases) # This is the list of multiscale Fourier Bessel basis elements 37 | self.register_buffer('bases', torch.Tensor(bases).float()) # Tensor for the multiscale Fourier Bessel bases 38 | self.num_multiscale_bases = len(bases) # This is the total number of multiscale Fourier Bessel basis elements 39 | num_basis_coeff = num_atoms * self.num_multiscale_bases # This is the number of basis coefficients per pixel 40 | 41 | self.bases_net = nn.Sequential( 42 | nn.Conv2d(in_context_channels, hid_channels, kernel_size=3, padding="same", stride=stride), 43 | nn.BatchNorm2d(hid_channels), 44 | nn.Tanh(), 45 | nn.Conv2d(hid_channels, num_basis_coeff, kernel_size=3, padding="same"), 46 | nn.BatchNorm2d(num_basis_coeff), 47 | nn.Tanh() 48 | ) 49 | 50 | def forward(self, context): 51 | N, _, H, W = context.shape 52 | H = H // self.stride 53 | W = W // self.stride 54 | basis_coefficients = self.bases_net(context) 55 | basis_coefficients = basis_coefficients.view(N, self.num_atoms, self.num_multiscale_bases, H, W) 56 | per_pixel_dynamic_atoms = torch.einsum('bmkhw,kl->bmlhw', basis_coefficients, self.bases) 57 | return per_pixel_dynamic_atoms 58 | 59 | 60 | class DynamicConv(nn.Module): 61 | """ 62 | This module takes an input tensor and convolves it with dynamic per-pixel kernels. The dynamic kernels are generated 63 | by multiplying the dynamic atoms with the learned compositional coefficients. 64 | """ 65 | 66 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, num_atoms=6): 67 | super().__init__() 68 | self.in_channels = in_channels 69 | self.out_channels = out_channels 70 | self.kernel_size = kernel_size 71 | self.stride = stride 72 | self.padding = padding 73 | self.num_atoms = num_atoms 74 | self.compositional_coefficients = nn.Parameter(torch.Tensor(out_channels, in_channels * num_atoms, 1, 1)) 75 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 76 | self.reset_parameters() 77 | 78 | def reset_parameters(self): 79 | nn.init.kaiming_normal_(self.compositional_coefficients, mode='fan_out', nonlinearity='relu') 80 | if self.bias is not None: 81 | self.bias.data.zero_() 82 | 83 | def forward(self, input_tensor, per_pixel_dynamic_atoms): 84 | N, C, H, W = input_tensor.shape 85 | H = H // self.stride 86 | W = W // self.stride 87 | x = nn.functional.unfold(input_tensor, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding) 88 | x = x.view(N, self.in_channels, self.kernel_size * self.kernel_size, H, W) 89 | intermediate_features = torch.einsum('bmlhw,bclhw->bcmhw', per_pixel_dynamic_atoms, x) 90 | intermediate_features = intermediate_features.reshape(N, self.in_channels * self.num_atoms, H, W) 91 | out = nn.functional.conv2d(intermediate_features, self.compositional_coefficients, self.bias) 92 | return out 93 | -------------------------------------------------------------------------------- /utils/myutil.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | from math import fabs, ceil, floor 4 | import torch 5 | import os 6 | from torch.nn import ZeroPad2d 7 | from utils.parse_config import ConfigParser 8 | from utils.default_config import default_config 9 | 10 | 11 | def skip_concat(x1, x2): 12 | return torch.cat([x1, x2], dim=1) 13 | 
14 | 15 | def skip_sum(x1, x2): 16 | return x1 + x2 17 | 18 | 19 | def mean(l): 20 | return 0 if len(l) == 0 else sum(l) / len(l) 21 | 22 | 23 | def quick_norm(img): 24 | return (img - torch.min(img))/(torch.max(img) - torch.min(img) + 1e-5) 25 | 26 | 27 | def robust_min(img, p=5): 28 | return np.percentile(img.ravel(), p) 29 | 30 | 31 | def robust_max(img, p=95): 32 | return np.percentile(img.ravel(), p) 33 | 34 | 35 | def normalize(img, m=10, M=90): 36 | return np.clip((img - robust_min(img, m)) / (robust_max(img, M) - robust_min(img, m)), 0.0, 1.0) 37 | 38 | 39 | def ffmpeg_glob_cmd(input_folder, output_path=None): 40 | if output_path is None: 41 | output_path = os.path.join(input_folder, 'a_video.mp4') 42 | return ['ffmpeg', '-y', '-pattern_type', 'glob', '-i', 43 | os.path.join(input_folder, '*.png'), '-framerate', '20', 44 | output_path] 45 | 46 | 47 | def optimal_crop_size(max_size, max_subsample_factor, safety_margin=0): 48 | """ Find the optimal crop size for a given max_size and subsample_factor. 49 | The optimal crop size is the smallest integer which is greater or equal than max_size, 50 | while being divisible by 2^max_subsample_factor. 51 | """ 52 | crop_size = int(pow(2, max_subsample_factor) * ceil(max_size / pow(2, max_subsample_factor))) 53 | crop_size += safety_margin * pow(2, max_subsample_factor) 54 | return crop_size 55 | 56 | 57 | class CropParameters: 58 | """ Helper class to compute and store useful parameters for pre-processing and post-processing 59 | of images in and out of E2VID. 60 | Pre-processing: finding the best image size for the network, and padding the input image with zeros 61 | Post-processing: Crop the output image back to the original image size 62 | """ 63 | 64 | def __init__(self, width, height, num_encoders, safety_margin=0): 65 | 66 | self.height = height 67 | self.width = width 68 | self.num_encoders = num_encoders 69 | self.width_crop_size = optimal_crop_size(self.width, num_encoders, safety_margin) 70 | self.height_crop_size = optimal_crop_size(self.height, num_encoders, safety_margin) 71 | 72 | self.padding_top = ceil(0.5 * (self.height_crop_size - self.height)) 73 | self.padding_bottom = floor(0.5 * (self.height_crop_size - self.height)) 74 | self.padding_left = ceil(0.5 * (self.width_crop_size - self.width)) 75 | self.padding_right = floor(0.5 * (self.width_crop_size - self.width)) 76 | self.pad = ZeroPad2d((self.padding_left, self.padding_right, self.padding_top, self.padding_bottom)) 77 | 78 | self.cx = floor(self.width_crop_size / 2) 79 | self.cy = floor(self.height_crop_size / 2) 80 | 81 | self.ix0 = self.cx - floor(self.width / 2) 82 | self.ix1 = self.cx + ceil(self.width / 2) 83 | self.iy0 = self.cy - floor(self.height / 2) 84 | self.iy1 = self.cy + ceil(self.height / 2) 85 | 86 | def crop(self, img): 87 | return img[..., self.iy0:self.iy1, self.ix0:self.ix1] 88 | 89 | 90 | def format_power(size): 91 | power = 1e3 92 | n = 0 93 | power_labels = {0 : '', 1: 'K', 2: 'M', 3: 'G', 4: 'T'} 94 | while size > power: 95 | size /= power 96 | n += 1 97 | return size, power_labels[n] 98 | 99 | def make_henri_compatible(checkpoint): 100 | """ 101 | Checkpoints have ConfigParser type configs, whereas Henri checkpoints have 102 | dictionary type configs or "arch, model" dicts. 103 | We will generate and add a ConfigParser to the checkpoint and return it. 
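Illustrative transformation (a sketch; the concrete values are placeholders, only the key layout follows the code below):
    checkpoint = {'arch': 'E2VIDRecurrent', 'model': {...unet kwargs...}, 'state_dict': ...}
    checkpoint = make_henri_compatible(checkpoint)
    checkpoint['config']['arch']['type']                 -> 'E2VIDRecurrent'
    checkpoint['config']['arch']['args']['unet_kwargs']  -> the old 'model' dict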
104 | """ 105 | assert ('config' in checkpoint or ('arch' in checkpoint and 'model' in checkpoint)) 106 | check_config = checkpoint['config'] if 'config' in checkpoint else checkpoint 107 | new_config = copy.deepcopy(default_config) 108 | new_config['arch']['type'] = check_config['arch'] 109 | new_config['arch']['args']['unet_kwargs'] = check_config['model'] 110 | config = ConfigParser(new_config) 111 | checkpoint['config'] = config 112 | print(new_config) 113 | return checkpoint 114 | 115 | 116 | def recursive_clone(tensor): 117 | """ 118 | Assumes tensor is a torch.tensor with 'clone()' method, possibly 119 | inside nested iterable. 120 | E.g., tensor = [(pytorch_tensor, pytorch_tensor), ...] 121 | """ 122 | if hasattr(tensor, 'clone'): 123 | return tensor.clone() 124 | try: 125 | return type(tensor)(recursive_clone(t) for t in tensor) 126 | except TypeError: 127 | print('{} is not iterable and has no clone() method.'.format(tensor)) 128 | -------------------------------------------------------------------------------- /scripts/qwen_vl_annotate.py: -------------------------------------------------------------------------------- 1 | # The annotation results can be found along with the V2V checkpoints (checkpoints/webvid_annots.txt). 2 | 3 | import cv2 4 | import torch 5 | from PIL import Image 6 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor 7 | from qwen_vl_utils import process_vision_info 8 | import glob 9 | import tqdm 10 | import os 11 | import json 12 | import concurrent.futures 13 | import threading 14 | 15 | questions = [ 16 | ("Is this image a photo without any post-production artefacts (such as subtitles/CG/...)? (True/False) ", "real"), 17 | ("Is there an outdoor scene in this image? (True/False)", "outdoor"), 18 | ("Is there an indoor scene in this image? (True/False)", "indoor"), 19 | ("Is it daytime in this image? (True/False)", "day"), 20 | ("Is it nighttime in this image? (True/False)", "night"), 21 | ("Is there water in this image? (True/False)", "water"), 22 | ("Are there humans in this image? (True/False)", "human"), 23 | ("Is the sky in this image? (True/False)", "sky"), 24 | ("Is this image blank (pure white / black / ... without any significant object)? (True/False)", "blank"), 25 | ("Is there out-of-focus blur in this image? (True/False)", "defocus"), 26 | ("Is there motion blur in this image? (True/False)", "motion"), 27 | ("Is there any object with text (such as a book cover) in this image? (True/False)", "text"), 28 | ("Describe the content of the photo.", "description") 29 | ] 30 | 31 | # 1. 
加载本地模型和tokenizer 32 | model_dir = "pretrained/qwen_2_5_vl" 33 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained( 34 | model_dir, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" 35 | ) 36 | processor = AutoProcessor.from_pretrained(model_dir, use_fast=True, padding_side="left") 37 | 38 | 39 | def infer(img_path, text): 40 | messages = [ 41 | { 42 | "role": "user", 43 | "content": [ 44 | { 45 | "type": "image", 46 | "image": img_path, 47 | }, 48 | {"type": "text", "text": text}, 49 | ], 50 | } 51 | ] 52 | 53 | # Preparation for inference 54 | text = processor.apply_chat_template( 55 | messages, tokenize=False, add_generation_prompt=True 56 | ) 57 | image_inputs, video_inputs = process_vision_info(messages) 58 | inputs = processor( 59 | text=[text], 60 | images=image_inputs, 61 | videos=video_inputs, 62 | padding=True, 63 | return_tensors="pt", 64 | ) 65 | inputs = inputs.to("cuda") 66 | 67 | # Inference: Generation of the output 68 | generated_ids = model.generate(**inputs, max_new_tokens=128) 69 | generated_ids_trimmed = [ 70 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) 71 | ] 72 | output_text = processor.batch_decode( 73 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 74 | ) 75 | 76 | return output_text 77 | 78 | def infer_batch(img_path): 79 | messages = [ 80 | [{ 81 | "role": "user", 82 | "content": [ 83 | { 84 | "type": "image", 85 | "image": img_path, 86 | }, 87 | {"type": "text", "text": question}, 88 | ], 89 | } 90 | ] for question, key in questions 91 | ] 92 | 93 | # Preparation for batch inference 94 | texts = [ 95 | processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) 96 | for msg in messages 97 | ] 98 | image_inputs, video_inputs = process_vision_info(messages) 99 | inputs = processor( 100 | text=texts, 101 | images=image_inputs, 102 | videos=video_inputs, 103 | padding=True, 104 | return_tensors="pt", 105 | ) 106 | inputs = inputs.to("cuda") 107 | 108 | # Batch Inference 109 | generated_ids = model.generate(**inputs, max_new_tokens=128) 110 | generated_ids_trimmed = [ 111 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) 112 | ] 113 | output_texts = processor.batch_decode( 114 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 115 | ) 116 | print(output_texts) 117 | 118 | return output_texts 119 | 120 | 121 | def annotate_img(image_path): 122 | img_id = os.path.basename(image_path).split(".")[0] 123 | annot = {"id": img_id} 124 | 125 | all_answers = infer_batch(image_path) 126 | for idx, (question, key) in enumerate(questions): 127 | #response = infer(image_path, question)[0] 128 | response = all_answers[idx] 129 | print(response) 130 | if key == "description": 131 | annot[key] = response 132 | else: 133 | if "true" in response.lower(): 134 | annot[key] = True 135 | else: 136 | annot[key] = False 137 | return annot 138 | 139 | def process_image(image_path, lock, annot_path): 140 | try: 141 | text = annotate_img(image_path) 142 | with lock: 143 | with open(annot_path, "a", encoding="UTF-8") as f: 144 | f.write(str(text) + "\n") 145 | f.flush() # 立即写入磁盘 146 | except Exception as e: 147 | print("Exception", e, "with", image_path, ":", e) 148 | 149 | def main(): 150 | all_img_paths = sorted(glob.glob("../data/webvid_imgs/*.png")) 151 | annot_path = "../data/webvid/annots.txt" 152 | 153 | # # 清空文件 154 | # with open(annot_path, "w") as f: 155 | # pass 156 | 157 | 
lock = threading.Lock() 158 | 159 | with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: 160 | futures = [executor.submit(process_image, ip, lock, annot_path) for ip in all_img_paths] 161 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(all_img_paths)): 162 | future.result() # 获取结果,如果出现异常,会在这里抛出 163 | 164 | if __name__ == "__main__": 165 | main() -------------------------------------------------------------------------------- /utils/parse_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from functools import reduce, partial 4 | from operator import getitem 5 | from datetime import datetime 6 | from .util import read_json, write_json 7 | 8 | 9 | class ConfigParser: 10 | def __init__(self, config, resume=None, modification=None, run_id=None): 11 | """ 12 | class to parse configuration json file. Handles hyperparameters for training, initializations of modules, checkpoint saving 13 | :param config: Dict containing configurations, hyperparameters for training. contents of `config.json` file for example. 14 | :param resume: String, path to the checkpoint being loaded. 15 | :param modification: Dict keychain:value, specifying position values to be replaced from config dict. 16 | :param run_id: Unique Identifier for training processes. Used to save checkpoints. Timestamp is being used as default 17 | """ 18 | # load config file and apply modification 19 | self._config = _update_config(config, modification) 20 | self.resume = resume 21 | 22 | # set save_dir where trained model and log will be saved. 23 | save_dir = Path(self.config['trainer']['save_dir']) 24 | 25 | exper_name = self.config['name'] 26 | if run_id is None: # use timestamp as default run-id 27 | run_id = datetime.now().strftime(r'%m%d_%H%M%S') 28 | self._save_dir = save_dir / 'models' / exper_name / run_id 29 | 30 | # make directory for saving checkpoints 31 | exist_ok = run_id == '' 32 | self.save_dir.mkdir(parents=True, exist_ok=exist_ok) 33 | 34 | # save updated config file to the checkpoint dir 35 | write_json(self.config, self.save_dir / 'config.json') 36 | 37 | @classmethod 38 | def from_args(cls, args, options=''): 39 | """ 40 | Initialize this class from some cli arguments. Used in train, test. 41 | """ 42 | for opt in options: 43 | args.add_argument(*opt.flags, default=None, type=opt.type) 44 | if not isinstance(args, tuple): 45 | args = args.parse_args() 46 | 47 | if args.device is not None: 48 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 49 | if args.resume is not None: 50 | resume = Path(args.resume) 51 | cfg_fname = resume.parent / 'config.json' 52 | else: 53 | msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example." 54 | assert args.config is not None, msg_no_cfg 55 | resume = None 56 | cfg_fname = Path(args.config) 57 | 58 | config = read_json(cfg_fname) 59 | if args.config and resume: 60 | # update new config for fine-tuning 61 | config.update(read_json(args.config)) 62 | 63 | # parse custom cli options into dictionary 64 | modification = {opt.target: getattr(args, _get_opt_name(opt.flags)) for opt in options} 65 | return cls(config, resume, modification) 66 | 67 | def init_obj(self, name, module, *args, **kwargs): 68 | """ 69 | Finds a function handle with the name given as 'type' in config, and returns the 70 | instance initialized with corresponding arguments given. 
71 | 72 | `object = config.init_obj('name', module, a, b=1)` 73 | is equivalent to 74 | `object = module.name(a, b=1)` 75 | """ 76 | module_name = self[name]['type'] 77 | module_args = dict(self[name]['args']) 78 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 79 | module_args.update(kwargs) 80 | return getattr(module, module_name)(*args, **module_args) 81 | 82 | def init_ftn(self, name, module, *args, **kwargs): 83 | """ 84 | Finds a function handle with the name given as 'type' in config, and returns the 85 | function with given arguments fixed with functools.partial. 86 | 87 | `function = config.init_ftn('name', module, a, b=1)` 88 | is equivalent to 89 | `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`. 90 | """ 91 | module_name = self[name]['type'] 92 | module_args = dict(self[name]['args']) 93 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 94 | module_args.update(kwargs) 95 | return partial(getattr(module, module_name), *args, **module_args) 96 | 97 | def __getitem__(self, name): 98 | """Access items like ordinary dict.""" 99 | return self.config[name] 100 | 101 | # setting read-only attributes 102 | @property 103 | def config(self): 104 | return self._config 105 | 106 | @property 107 | def save_dir(self): 108 | return self._save_dir 109 | 110 | # helper functions to update config dict with custom cli options 111 | def _update_config(config, modification): 112 | if modification is None: 113 | return config 114 | 115 | for k, v in modification.items(): 116 | if v is not None: 117 | _set_by_path(config, k, v) 118 | return config 119 | 120 | 121 | def _get_opt_name(flags): 122 | for flg in flags: 123 | if flg.startswith('--'): 124 | return flg.replace('--', '') 125 | return flags[0].replace('--', '') 126 | 127 | 128 | def _set_by_path(tree, keys, value): 129 | """Set a value in a nested object in tree by sequence of keys.""" 130 | keys = keys.split(';') 131 | _get_by_path(tree, keys[:-1])[keys[-1]] = value 132 | 133 | 134 | def _get_by_path(tree, keys): 135 | """Access a nested object in tree by sequence of keys.""" 136 | return reduce(getitem, keys, tree) 137 | -------------------------------------------------------------------------------- /model/eraft/eraft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .update import BasicUpdateBlock 7 | from .extractor import BasicEncoder 8 | from .corr import CorrBlock 9 | from .utils import coords_grid, upflow8 10 | from argparse import Namespace 11 | from .image_utils import ImagePadder, forward_interpolate_pytorch 12 | from torch.amp import autocast 13 | 14 | 15 | def get_args(): 16 | # This is an adapter function that converts the arguments given in out config file to the format, which the ERAFT 17 | # expects. 
18 | args = Namespace(small=False, 19 | dropout=False, 20 | mixed_precision=False, 21 | clip=1.0) 22 | return args 23 | 24 | 25 | 26 | class ERAFT(nn.Module): 27 | def __init__(self, config, n_first_channels): 28 | # args: 29 | super(ERAFT, self).__init__() 30 | args = get_args() 31 | self.args = args 32 | self.image_padder = ImagePadder(min_size=32) 33 | self.subtype = config['subtype'].lower() 34 | 35 | assert (self.subtype == 'standard' or self.subtype == 'warm_start') 36 | 37 | self.hidden_dim = hdim = 128 38 | self.context_dim = cdim = 128 39 | args.corr_levels = 4 40 | args.corr_radius = 4 41 | 42 | # feature network, context network, and update block 43 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=0, n_first_channels=n_first_channels) 44 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=0, n_first_channels=n_first_channels) 45 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 46 | 47 | self.flow_init = None 48 | 49 | def reset_states(self): 50 | self.flow_init = None 51 | 52 | def freeze_bn(self): 53 | for m in self.modules(): 54 | if isinstance(m, nn.BatchNorm2d): 55 | m.eval() 56 | 57 | def initialize_flow(self, img): 58 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 59 | N, C, H, W = img.shape 60 | coords0 = coords_grid(N, H//8, W//8).to(img.device) 61 | coords1 = coords_grid(N, H//8, W//8).to(img.device) 62 | 63 | # optical flow computed as difference: flow = coords1 - coords0 64 | return coords0, coords1 65 | 66 | def upsample_flow(self, flow, mask): 67 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 68 | N, _, H, W = flow.shape 69 | mask = mask.view(N, 1, 9, 8, 8, H, W) 70 | mask = torch.softmax(mask, dim=2) 71 | 72 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 73 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 74 | 75 | up_flow = torch.sum(mask * up_flow, dim=2) 76 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 77 | return up_flow.reshape(N, 2, 8*H, 8*W) 78 | 79 | 80 | def forward_(self, image1, image2, iters=12): 81 | """ Estimate optical flow between pair of frames """ 82 | # Pad Image (for flawless up&downsampling) 83 | image1 = self.image_padder.pad(image1) 84 | image2 = self.image_padder.pad(image2) 85 | 86 | image1 = image1.contiguous() 87 | image2 = image2.contiguous() 88 | 89 | hdim = self.hidden_dim 90 | cdim = self.context_dim 91 | 92 | # run the feature network 93 | with autocast(enabled=self.args.mixed_precision, device_type="cuda"): 94 | fmap1, fmap2 = self.fnet([image1, image2]) 95 | 96 | fmap1 = fmap1.float() 97 | fmap2 = fmap2.float() 98 | 99 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 100 | 101 | # run the context network 102 | with autocast(enabled=self.args.mixed_precision, device_type="cuda"): 103 | if (self.subtype == 'standard' or self.subtype == 'warm_start'): 104 | cnet = self.cnet(image2) 105 | else: 106 | raise Exception 107 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 108 | net = torch.tanh(net) 109 | inp = torch.relu(inp) 110 | 111 | # Initialize Grids. First channel: x, 2nd channel: y. 
Image is just used to get the shape 112 | coords0, coords1 = self.initialize_flow(image1) 113 | 114 | if self.flow_init is not None: 115 | coords1 = coords1 + self.flow_init 116 | 117 | flow_predictions = [] 118 | for itr in range(iters): 119 | coords1 = coords1.detach() 120 | corr = corr_fn(coords1) # index correlation volume 121 | 122 | flow = coords1 - coords0 123 | with autocast(enabled=self.args.mixed_precision, device_type="cuda"): 124 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 125 | 126 | # F(t+1) = F(t) + \Delta(t) 127 | coords1 = coords1 + delta_flow 128 | 129 | # upsample predictions 130 | if up_mask is None: 131 | flow_up = upflow8(coords1 - coords0) 132 | else: 133 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 134 | 135 | flow_predictions.append(self.image_padder.unpad(flow_up)) 136 | 137 | return coords1 - coords0, flow_predictions 138 | 139 | 140 | # Wrap in forward code from E-RAFT/test.py/TestRaftEventsWarm.run_network 141 | def forward(self, image1, image2, iters=12): 142 | flow_low_res, flow_list = self.forward_(image1=image1, image2=image2, iters=iters) 143 | # Keep results for next recurrent call 144 | self.flow_init = forward_interpolate_pytorch(flow_low_res) 145 | flow_est = flow_list[-1] 146 | return flow_est 147 | -------------------------------------------------------------------------------- /model/eitr/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | # from typing import Optional 5 | import copy 6 | 7 | 8 | class transformer(nn.Module): 9 | def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, activation='relu', 10 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1): 11 | super().__init__() 12 | self.d_model = d_model 13 | self.nhead = nhead 14 | 15 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 16 | dropout, activation) 17 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, d_model) 18 | 19 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 20 | dropout, activation) 21 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, d_model) 22 | 23 | self._reset_parameters() 24 | 25 | def _reset_parameters(self): 26 | for p in self.parameters(): 27 | if p.dim() > 1: 28 | nn.init.xavier_uniform_(p) 29 | 30 | def forward(self, src, pos, task_embed=None): 31 | tgt = memory = self.encoder(src, pos) 32 | output = self.decoder(tgt, memory, task_embed) 33 | 34 | return output 35 | 36 | 37 | class TransformerEncoder(nn.Module): 38 | def __init__(self, encoder_layer, num_layers, d_model): 39 | super().__init__() 40 | self.layers = _get_clones(encoder_layer, num_layers) 41 | 42 | def forward(self, src, pos): 43 | output = src 44 | 45 | for layer in self.layers: 46 | output = layer(output, pos) 47 | 48 | return output 49 | 50 | 51 | class TransformerDecoder(nn.Module): 52 | def __init__(self, encoder_layer, num_layers, d_model): 53 | super().__init__() 54 | self.layers = _get_clones(encoder_layer, num_layers) 55 | 56 | def forward(self, tgt, memory, task_embed): 57 | output = tgt 58 | 59 | for layer in self.layers: 60 | output = layer(output, memory, task_embed) 61 | 62 | return output 63 | 64 | 65 | class TransformerEncoderLayer(nn.Module): 66 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 67 | activation="relu"): 68 | super().__init__() 69 | self.self_attn = nn.MultiheadAttention(d_model, 
nhead, dropout=dropout) 70 | self.attn_dropout = nn.Dropout(dropout) 71 | self.norm1 = nn.LayerNorm(d_model) 72 | self.linear1 = nn.Linear(d_model, dim_feedforward) 73 | self.activation = _get_activation_fn(activation) 74 | self.ffn_dropout1 = nn.Dropout(dropout) 75 | self.linear2 = nn.Linear(dim_feedforward, d_model) 76 | self.ffn_dropout2 = nn.Dropout(dropout) 77 | self.norm2 = nn.LayerNorm(d_model) 78 | 79 | def with_embed(self, tensor, pos): 80 | return tensor if pos is None else tensor + pos 81 | 82 | def forward(self, src, pos): 83 | # self attention 84 | q = k = self.with_embed(src, pos) 85 | v = src 86 | src2 = self.self_attn(q, k, v)[0] 87 | src3 = src + self.attn_dropout(src2) 88 | src4 = self.norm1(src3) 89 | 90 | # FFN 91 | src5 = self.linear2(self.ffn_dropout1(self.activation(self.linear1(src4)))) 92 | src6 = src4 + self.ffn_dropout2(src5) 93 | src7 = self.norm2(src6) 94 | 95 | return src7 96 | 97 | 98 | class TransformerDecoderLayer(nn.Module): 99 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 100 | activation="relu"): 101 | super().__init__() 102 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 103 | self.sattn_dropout = nn.Dropout(dropout) 104 | self.norm1 = nn.LayerNorm(d_model) 105 | self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 106 | self.cattn_dropout = nn.Dropout(dropout) 107 | self.norm2 = nn.LayerNorm(d_model) 108 | self.linear1 = nn.Linear(d_model, dim_feedforward) 109 | self.activation = _get_activation_fn(activation) 110 | self.ffn_dropout1 = nn.Dropout(dropout) 111 | self.linear2 = nn.Linear(dim_feedforward, d_model) 112 | self.ffn_dropout2 = nn.Dropout(dropout) 113 | self.norm3 = nn.LayerNorm(d_model) 114 | 115 | def with_embed(self, tensor, pos): 116 | return tensor if pos is None else tensor + pos 117 | 118 | def forward(self, tgt, memory, task_embed): 119 | # self attention 120 | v = tgt 121 | q = k = self.with_embed(tgt, task_embed) 122 | tgt2 = self.self_attn(q, k, v)[0] 123 | tgt3 = tgt + self.sattn_dropout(tgt2) 124 | tgt4 = self.norm1(tgt3) 125 | 126 | # cross attention 127 | q = self.with_embed(tgt4, task_embed) 128 | k = v = memory 129 | tgt5 = self.cross_attn(q, k, v)[0] 130 | tgt6 = tgt4 + self.cattn_dropout(tgt5) 131 | tgt7 = self.norm2(tgt6) 132 | 133 | # FFN 134 | tgt8 = self.linear2(self.ffn_dropout1(self.activation(self.linear1(tgt7)))) 135 | tgt9 = tgt7 + self.ffn_dropout2(tgt8) 136 | tgt10 = self.norm3(tgt9) 137 | 138 | return tgt10 139 | 140 | 141 | def _get_clones(module, N): 142 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 143 | 144 | 145 | def _get_activation_fn(activation): 146 | """Return an activation function given a string""" 147 | if activation == "relu": 148 | return F.relu 149 | if activation == "gelu": 150 | return F.gelu 151 | if activation == "glu": 152 | return F.glu 153 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 154 | 155 | 156 | def build_transformer(args): 157 | return transformer(**args) 158 | -------------------------------------------------------------------------------- /model/eraft/image_utils.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import torch 3 | from torch import nn 4 | from torch.nn.functional import grid_sample 5 | from scipy.spatial import transform 6 | from scipy import interpolate 7 | from matplotlib import pyplot as plt 8 | 9 | 10 | def grid_sample_values(input, height, width): 11 | # ================================ 
Grid Sample Values ============================= # 12 | # Input: Torch Tensor [3,H*W]m where the 3 Dimensions mean [x,y,z] # 13 | # Height: Image Height # 14 | # Width: Image Width # 15 | # --------------------------------------------------------------------------------- # 16 | # Output: tuple(value_ipl, valid_mask) # 17 | # value_ipl -> [H,W]: Interpolated values # 18 | # valid_mask -> [H,W]: 1: Point is valid, 0: Point is invalid # 19 | # ================================================================================= # 20 | device = input.device 21 | ceil = torch.stack([torch.ceil(input[0,:]), torch.ceil(input[1,:]), input[2,:]]) 22 | floor = torch.stack([torch.floor(input[0,:]), torch.floor(input[1,:]), input[2,:]]) 23 | z = input[2,:].clone() 24 | 25 | values_ipl = torch.zeros(height*width, device=device) 26 | weights_acc = torch.zeros(height*width, device=device) 27 | # Iterate over all ceil/floor points 28 | for x_vals in [floor[0], ceil[0]]: 29 | for y_vals in [floor[1], ceil[1]]: 30 | # Mask Points that are in the image 31 | in_bounds_mask = (x_vals < width) & (x_vals >=0) & (y_vals < height) & (y_vals >= 0) 32 | 33 | # Calculate weights, according to their real distance to the floored/ceiled value 34 | weights = (1 - (input[0]-x_vals).abs()) * (1 - (input[1]-y_vals).abs()) 35 | 36 | # Put them into the right grid 37 | indices = (x_vals + width * y_vals).long() 38 | values_ipl.put_(indices[in_bounds_mask], (z * weights)[in_bounds_mask], accumulate=True) 39 | weights_acc.put_(indices[in_bounds_mask], weights[in_bounds_mask], accumulate=True) 40 | 41 | # Mask of valid pixels -> Everywhere where we have an interpolated value 42 | valid_mask = weights_acc.clone() 43 | valid_mask[valid_mask > 0] = 1 44 | valid_mask= valid_mask.bool().reshape([height,width]) 45 | 46 | # Divide by weights to get interpolated values 47 | values_ipl = values_ipl / (weights_acc + 1e-15) 48 | values_rs = values_ipl.reshape([height,width]) 49 | 50 | return values_rs.unsqueeze(0).clone(), valid_mask.unsqueeze(0).clone() 51 | 52 | def forward_interpolate_pytorch(flow_in): 53 | # Same as the numpy implementation, but differentiable :) 54 | # Flow: [B,2,H,W] 55 | flow = flow_in.clone() 56 | if len(flow.shape) < 4: 57 | flow = flow.unsqueeze(0) 58 | 59 | b, _, h, w = flow.shape 60 | device = flow.device 61 | 62 | dx ,dy = flow[:,0], flow[:,1] 63 | y0, x0 = torch.meshgrid(torch.arange(0, h, 1), torch.arange(0, w, 1)) 64 | x0 = torch.stack([x0]*b).to(device) 65 | y0 = torch.stack([y0]*b).to(device) 66 | 67 | x1 = x0 + dx 68 | y1 = y0 + dy 69 | 70 | x1 = x1.flatten(start_dim=1) 71 | y1 = y1.flatten(start_dim=1) 72 | dx = dx.flatten(start_dim=1) 73 | dy = dy.flatten(start_dim=1) 74 | 75 | # Interpolate Griddata... 76 | # Note that a Nearest Neighbor Interpolation would be better. But there does not exist a pytorch fcn yet. 77 | # See issue: https://github.com/pytorch/pytorch/issues/50339 78 | flow_new = torch.zeros(flow.shape, device=device) 79 | for i in range(b): 80 | flow_new[i,0] = grid_sample_values(torch.stack([x1[i],y1[i],dx[i]]), h, w)[0] 81 | flow_new[i,1] = grid_sample_values(torch.stack([x1[i],y1[i],dy[i]]), h, w)[0] 82 | 83 | return flow_new 84 | 85 | class ImagePadder(object): 86 | # =================================================================== # 87 | # In some networks, the image gets downsized. This is a problem, if # 88 | # the to-be-downsized image has odd dimensions ([15x20]->[7.5x10]). 
# 89 | # To prevent this, the input image of the network needs to be a # 90 | # multiple of a minimum size (min_size) # 91 | # The ImagePadder makes sure, that the input image is of such a size, # 92 | # and if not, it pads the image accordingly. # 93 | # =================================================================== # 94 | 95 | def __init__(self, min_size=64): 96 | # --------------------------------------------------------------- # 97 | # The min_size additionally ensures, that the smallest image # 98 | # does not get too small # 99 | # --------------------------------------------------------------- # 100 | self.min_size = min_size 101 | self.pad_height = None 102 | self.pad_width = None 103 | 104 | def pad(self, image): 105 | # --------------------------------------------------------------- # 106 | # If necessary, this function pads the image on the left & top # 107 | # --------------------------------------------------------------- # 108 | height, width = image.shape[-2:] 109 | 110 | self.pad_height = (self.min_size - height % self.min_size)%self.min_size 111 | self.pad_width = (self.min_size - width % self.min_size)%self.min_size 112 | 113 | return nn.ZeroPad2d((self.pad_width, 0, self.pad_height, 0))(image) 114 | 115 | def unpad(self, image): 116 | # --------------------------------------------------------------- # 117 | # Removes the padded rows & columns # 118 | # --------------------------------------------------------------- # 119 | return image[..., self.pad_height:, self.pad_width:] 120 | -------------------------------------------------------------------------------- /data/esim_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import h5py 3 | import random 4 | import numpy as np 5 | from utils.data import data_sources 6 | 7 | def add_hot_pixels_to_voxels(voxels, hot_pixel_std=1.0, max_hot_pixel_fraction=0.001, integer_noise=False): 8 | # voxels.shape = (T, C, H, W) 9 | T, C, H, W = voxels.shape 10 | hot_pixel_fraction = random.uniform(0, max_hot_pixel_fraction) 11 | num_hot_pixels = int(hot_pixel_fraction * H * W) 12 | x = np.random.randint(0, W, num_hot_pixels) 13 | y = np.random.randint(0, H, num_hot_pixels) 14 | if integer_noise: 15 | # Model the noise N = y * sign, where y is a poisson distribution with lambda, and sign is 50% prob +1, 50% prob -1. 16 | # Then var(N) will have variance lambda**2 + lambda. 17 | # We hope lambda**2 + lambda == hot_pixel_std**2. 18 | # So lambda = (-1 + sqrt(1 + 4s**2)) / 2. 19 | lmb = (-1 + np.sqrt(1 + 4 * hot_pixel_std**2)) / 2 20 | y = np.random.poisson(lam=lmb, size=num_hot_pixels) 21 | sign = 2 * np.random.randint(0, 2, size=num_hot_pixels) - 1 22 | val = y * sign 23 | else: 24 | val = np.random.randn(num_hot_pixels) 25 | val *= hot_pixel_std 26 | noise = np.zeros((H, W)) 27 | np.add.at(noise, (y, x), val) 28 | noise = noise[np.newaxis, np.newaxis, ...] 
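# Note (added): after the newaxis expansion `noise` has shape (1, 1, H, W), so the
# in-place add below broadcasts one hot-pixel pattern over every time step and voxel
# channel, which matches real hot pixels that fire continuously over a sequence.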
29 | voxels += noise 30 | return voxels 31 | 32 | 33 | def add_noise_to_voxel(voxel, noise_std=1.0, noise_fraction=0.1, integer_noise=False): 34 | if integer_noise: 35 | # lambda-poisson * (50% +1, 50% -1) 36 | lmb = (-1 + np.sqrt(1 + 4 * noise_std**2)) / 2 37 | y = np.random.poisson(lam=lmb, size=voxel.shape) 38 | sign = 2 * np.random.randint(0, 2, size=voxel.shape) - 1 39 | noise = y * sign 40 | else: 41 | noise = noise_std * np.random.randn(*voxel.shape) # mean = 0, std = noise_std 42 | 43 | if noise_fraction < 1.0: 44 | mask = np.random.rand(*voxel.shape) >= noise_fraction 45 | noise = np.where(mask, 0, noise) 46 | return voxel + noise 47 | 48 | 49 | class ESIMH5Dataset(torch.utils.data.Dataset): 50 | # The original codebase did not provide HDF5Dataset, which is used in the training configuration. 51 | # This dataset is similar to DynamicH5Dataset, except that it caches the voxels in h5 file. 52 | """ 53 | Dataloader for events saved in the Monash University HDF5 events format 54 | (see https://github.com/TimoStoff/event_utils for code to convert datasets) 55 | """ 56 | 57 | def __init__(self, h5_path, configs): 58 | self.h5_path = h5_path 59 | self.sequence_length = configs.get('sequence_length', 40) 60 | self.step_size = configs.get('step_size', self.sequence_length) 61 | self.proba_pause_when_running = configs.get('proba_pause_when_running', 0.05) 62 | self.proba_pause_when_paused = configs.get('proba_pause_when_paused', 0.9) 63 | self.noise_std = configs.get('noise_std', 0.1) 64 | self.noise_fraction = configs.get('noise_fraction', 1.0) 65 | self.hot_pixel_std = configs.get('hot_pixel_std', 0.1) 66 | self.max_hot_pixel_fraction = configs.get('max_hot_pixel_fraction', 0.001) 67 | self.random_crop_size = configs.get('random_crop_size', 112) 68 | self.random_flip = configs.get('random_flip', True) 69 | self.integer_noise = configs.get('integer_noise', False) 70 | 71 | self.h5_file = h5py.File(h5_path, 'r') 72 | self.sensor_resolution = self.h5_file.attrs['sensor_resolution'][0:2] 73 | self.num_frames = self.h5_file['frames'].shape[0] 74 | self.data_source_name = "esim" 75 | self.data_source_idx = data_sources.index(self.data_source_name) 76 | 77 | self.samples = [] 78 | for i in range(0, self.num_frames - self.sequence_length, self.step_size): 79 | self.samples.append((i, i + self.sequence_length)) 80 | 81 | def __len__(self): 82 | return len(self.samples) 83 | 84 | def __getitem__(self, index): 85 | begin_i, end_i = self.samples[index] 86 | 87 | all_frame = self.h5_file["frames"][begin_i:end_i] 88 | all_frame = all_frame # in [0, 1] 89 | all_flow = self.h5_file["flow"][begin_i:end_i] 90 | all_voxel = self.h5_file["events"][begin_i:end_i] 91 | 92 | # Random crop 93 | T, one, H, W = all_frame.shape 94 | 95 | if self.random_crop_size is not None: 96 | # Random crop[] 97 | th, tw = self.random_crop_size, self.random_crop_size 98 | i = random.randint(0, H - th) 99 | j = random.randint(0, W - tw) 100 | all_frame = all_frame[:, :, i:i+th, j:j+tw] 101 | all_flow = all_flow[:, :, i:i+th, j:j+tw] 102 | all_voxel = all_voxel[:, :, i:i+th, j:j+tw] 103 | 104 | # Random flip 105 | if self.random_flip and random.random() > 0.5: 106 | all_frame = np.flip(all_frame, axis=3) 107 | all_flow = np.flip(all_flow, axis=3) 108 | all_voxel = np.flip(all_voxel, axis=3) 109 | 110 | # Random pause 111 | frame = np.zeros_like(all_frame) 112 | flow = np.zeros_like(all_flow) 113 | voxel = np.zeros_like(all_voxel) 114 | timestamp = [] 115 | 116 | paused = False 117 | k = 0 118 | for t_idx in 
range(self.sequence_length): 119 | # decide whether we should make a "pause" at this step 120 | # the probability of "pause" is conditioned on the previous state (to encourage long sequences) 121 | u = np.random.rand() 122 | if paused: 123 | probability_pause = self.proba_pause_when_paused 124 | else: 125 | probability_pause = self.proba_pause_when_running 126 | paused = (u < probability_pause) 127 | if t_idx > 0 and paused: # Cannot pause at the first frame 128 | # add a tensor filled with zeros, paired with the last frame 129 | # do not increase the counter 130 | frame[t_idx] = frame[t_idx - 1] 131 | # Leave the flow and voxel as zeros 132 | 133 | else: 134 | # normal case: append the next item to the list 135 | frame[t_idx] = all_frame[k] 136 | flow[t_idx] = all_flow[k] 137 | voxel[t_idx] = all_voxel[k] 138 | k += 1 139 | 140 | # add noise 141 | voxel[t_idx] = add_noise_to_voxel(voxel[t_idx], self.noise_std, self.noise_fraction, integer_noise=self.integer_noise) 142 | 143 | voxel = add_hot_pixels_to_voxels(voxel, self.hot_pixel_std, self.max_hot_pixel_fraction, integer_noise=self.integer_noise) 144 | 145 | 146 | item = { 147 | 'frame': torch.Tensor(frame), 148 | 'flow': torch.Tensor(flow), 149 | 'events': torch.Tensor(voxel), 150 | 'data_source_idx': torch.tensor(self.data_source_idx), 151 | } 152 | 153 | return item -------------------------------------------------------------------------------- /utils/training_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from torchvision import utils 5 | # local modules 6 | from utils.myutil import quick_norm 7 | 8 | 9 | def make_tc_vis(tc_output): 10 | imgs0 = [tc_output['image1'], tc_output['image1'], tc_output['visibility_mask']] 11 | imgs1 = [tc_output['image0'], tc_output['image0_warped_to1'], tc_output['visibility_mask']] 12 | frames = [] 13 | for imgs in [imgs0, imgs1]: 14 | imgs = [i[0, ...].expand(3, -1, -1) for i in imgs] 15 | frames.append(utils.make_grid(imgs, nrow=3)) 16 | return torch.stack(frames, dim=0).unsqueeze(0) 17 | 18 | def make_vw_vis(tc_output): 19 | event_preview = torch.sum(tc_output['voxel_grid'], dim=1, keepdim=True)[0, ...].expand(3, -1, -1) 20 | events_warped = tc_output['voxel_grid_warped'][0, ...].expand(3, -1, -1) 21 | frames = [] 22 | frames.append(utils.make_grid([event_preview, events_warped], nrow=2)) 23 | frames.append(utils.make_grid([events_warped, event_preview], nrow=2)) 24 | return torch.stack(frames, dim=0).unsqueeze(0) 25 | 26 | def make_flow_movie(event_previews, predicted_frames, groundtruth_frames, predicted_flows, groundtruth_flows): 27 | # event_previews: a list of [1 x 1 x H x W] event previews 28 | # predicted_frames: a list of [1 x 1 x H x W] predicted frames 29 | # flows: a list of [1 x 2 x H x W] predicted frames 30 | # for movie, we need to pass [1 x T x 1 x H x W] where T is the time dimension max_magnitude = 40 31 | if groundtruth_flows is None: 32 | groundtruth_flows = [] 33 | max_magnitude = None 34 | movie_frames = [] 35 | for i, flow in enumerate(predicted_flows): 36 | voxel = quick_norm(event_previews[i][0, ...]).expand(3, -1, -1) 37 | pred_frame = quick_norm(predicted_frames[i][0, ...]).expand(3, -1, -1) 38 | gt_frame = groundtruth_frames[i][0, ...].expand(3, -1, -1) 39 | #cv2.imwrite("temp/gt_frame.png", gt_frame.cpu().numpy().squeeze()) 40 | #pred_flow_rgb = flow2rgb(flow[0, 0, :, :], flow[0, 1, :, :], max_magnitude) 41 | #blank = torch.zeros_like(gt_frame) 42 | #imgs = [voxel, 
pred_frame, gt_frame, blank, pred_flow_rgb.float()] 43 | #if groundtruth_flows: 44 | # gt_flow = groundtruth_flows[i] 45 | # gt_flow_rgb = flow2rgb(gt_flow[0, 0, :, :], gt_flow[0, 1, :, :], max_magnitude) 46 | # imgs.append(gt_flow_rgb.float()) 47 | imgs = [-voxel, -pred_frame, -gt_frame] 48 | movie_frame = utils.make_grid(imgs, nrow=3) 49 | movie_frames.append(movie_frame) 50 | return torch.stack(movie_frames, dim=0).unsqueeze(0) 51 | 52 | 53 | def flush(summary_writer): 54 | for writer in summary_writer.all_writers.values(): 55 | writer.flush() 56 | 57 | 58 | def count_parameters(model): 59 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 60 | 61 | 62 | def select_evenly_spaced_elements(num_elements, sequence_length): 63 | return [i * sequence_length // num_elements + sequence_length // (2 * num_elements) for i in range(num_elements)] 64 | 65 | 66 | def flow2bgr_np(disp_x, disp_y, max_magnitude=None): 67 | """ 68 | Convert an optic flow tensor to an RGB color map for visualization 69 | Code adapted from: https://github.com/ClementPinard/FlowNetPytorch/blob/master/main.py#L339 70 | 71 | :param disp_x: a [H x W] NumPy array containing the X displacement 72 | :param disp_x: a [H x W] NumPy array containing the Y displacement 73 | :returns bgr: a [H x W x 3] NumPy array containing a color-coded representation of the flow [0, 255] 74 | """ 75 | assert(disp_x.shape == disp_y.shape) 76 | H, W = disp_x.shape 77 | 78 | # X, Y = np.meshgrid(np.linspace(-1, 1, H), np.linspace(-1, 1, W)) 79 | 80 | # flow_x = (X - disp_x) * float(W) / 2 81 | # flow_y = (Y - disp_y) * float(H) / 2 82 | # magnitude, angle = cv2.cartToPolar(flow_x, flow_y) 83 | # magnitude, angle = cv2.cartToPolar(disp_x, disp_y) 84 | 85 | # follow alex zhu color convention https://github.com/daniilidis-group/EV-FlowNet 86 | 87 | flows = np.stack((disp_x, disp_y), axis=2) 88 | magnitude = np.linalg.norm(flows, axis=2) 89 | 90 | angle = np.arctan2(disp_y, disp_x) 91 | angle += np.pi 92 | angle *= 180. / np.pi / 2. 93 | angle = angle.astype(np.uint8) 94 | 95 | if max_magnitude is None: 96 | v = np.zeros(magnitude.shape, dtype=np.uint8) 97 | cv2.normalize(src=magnitude, dst=v, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U) 98 | else: 99 | v = np.clip(255.0 * magnitude / max_magnitude, 0, 255) 100 | v = v.astype(np.uint8) 101 | 102 | hsv = np.zeros((H, W, 3), dtype=np.uint8) 103 | hsv[..., 1] = 255 104 | hsv[..., 0] = angle 105 | hsv[..., 2] = v 106 | bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) 107 | 108 | return bgr 109 | 110 | 111 | def flow2rgb(disp_x, disp_y, max_magnitude=None): 112 | return flow2bgr(disp_x, disp_y, max_magnitude)[[2, 1, 0], ...] 
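# Usage sketch (added; not part of the original file). Visualize a predicted flow tensor
# `flow` of shape [2, H, W] on any device; `max_magnitude` is optional and only pins the
# colour scale so it stays consistent across frames:
#   rgb = flow2rgb(flow[0], flow[1], max_magnitude=40)   # 3 x H x W, float in [0, 1]
#   utils.save_image(rgb, "flow_vis.png")                # torchvision.utils, imported above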
113 | 114 | 115 | def flow2bgr(disp_x, disp_y, max_magnitude=None): 116 | device = disp_x.device 117 | bgr = flow2bgr_np(disp_x.cpu().numpy(), disp_y.cpu().numpy(), max_magnitude) 118 | bgr = bgr.astype(float) / 255 119 | return torch.tensor(bgr).permute(2, 0, 1).to(device) # 3 x H x W 120 | 121 | 122 | def make_movie(event_previews, predicted_frames, groundtruth_frames): 123 | # event_previews: a list of [1 x 1 x H x W] event previews 124 | # predicted_frames: a list of [1 x 1 x H x W] predicted frames 125 | # for movie, we need to pass [1 x T x 1 x H x W] where T is the time dimension 126 | 127 | video_tensor = None 128 | for i in torch.arange(len(event_previews)): 129 | voxel = quick_norm(event_previews[i]) 130 | predicted_frame = quick_norm(predicted_frames[i]) 131 | movie_frame = torch.cat([voxel, 132 | predicted_frame, 133 | groundtruth_frames[i]], 134 | dim=-1) 135 | movie_frame.unsqueeze_(dim=0) 136 | video_tensor = movie_frame if video_tensor is None else \ 137 | torch.cat((video_tensor, movie_frame), dim=1) 138 | return video_tensor 139 | -------------------------------------------------------------------------------- /PerceptualSimilarity/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | from skimage.metrics import structural_similarity as compare_ssim 8 | # from skimage.metrics import structural_similarity 9 | import torch 10 | from torch.autograd import Variable 11 | 12 | # import sys, os 13 | # sys.path.append(os.getcwd()) 14 | from PerceptualSimilarity.models import dist_model 15 | # import .dist_model 16 | # import dist_model 17 | 18 | 19 | class PerceptualLoss(torch.nn.Module): 20 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric) 21 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 22 | super(PerceptualLoss, self).__init__() 23 | print('Setting up Perceptual loss...') 24 | self.use_gpu = use_gpu 25 | self.spatial = spatial 26 | self.gpu_ids = gpu_ids 27 | self.model = dist_model.DistModel() 28 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids) 29 | print('...[%s] initialized'%self.model.name()) 30 | print('...Done') 31 | 32 | def forward(self, pred, target, normalize=False): 33 | """ 34 | Pred and target are Variables. 35 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 36 | If normalize is False, assumes the images are already between [-1,+1] 37 | 38 | Inputs pred and target are Nx3xHxW 39 | Output pytorch Variable N long 40 | """ 41 | 42 | if normalize: 43 | target = 2 * target - 1 44 | pred = 2 * pred - 1 45 | 46 | return self.model.forward(target, pred) 47 | 48 | def normalize_tensor(in_feat,eps=1e-10): 49 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 50 | return in_feat/(norm_factor+eps) 51 | 52 | def l2(p0, p1, range=255.): 53 | return .5*np.mean((p0 / range - p1 / range)**2) 54 | 55 | def psnr(p0, p1, peak=255.): 56 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 57 | 58 | def dssim(p0, p1, range=255.): 59 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 
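# Note (added): newer scikit-image releases deprecate the `multichannel` keyword used in
# dssim above in favour of `channel_axis`; if compare_ssim raises a TypeError in such an
# environment, the equivalent call is likely
#   compare_ssim(p0, p1, data_range=range, channel_axis=-1)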
60 | 61 | def rgb2lab(in_img,mean_cent=False): 62 | from skimage import color 63 | img_lab = color.rgb2lab(in_img) 64 | if(mean_cent): 65 | img_lab[:,:,0] = img_lab[:,:,0]-50 66 | return img_lab 67 | 68 | def tensor2np(tensor_obj): 69 | # change dimension of a tensor object into a numpy array 70 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 71 | 72 | def np2tensor(np_obj): 73 | # change dimenion of np array into tensor array 74 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 75 | 76 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 77 | # image tensor to lab tensor 78 | from skimage import color 79 | 80 | img = tensor2im(image_tensor) 81 | img_lab = color.rgb2lab(img) 82 | if(mc_only): 83 | img_lab[:,:,0] = img_lab[:,:,0]-50 84 | if(to_norm and not mc_only): 85 | img_lab[:,:,0] = img_lab[:,:,0]-50 86 | img_lab = img_lab/100. 87 | 88 | return np2tensor(img_lab) 89 | 90 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 91 | from skimage import color 92 | import warnings 93 | warnings.filterwarnings("ignore") 94 | 95 | lab = tensor2np(lab_tensor)*100. 96 | lab[:,:,0] = lab[:,:,0]+50 97 | 98 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 99 | if(return_inbnd): 100 | # convert back to lab, see if we match 101 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 102 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 103 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 104 | return (im2tensor(rgb_back),mask) 105 | else: 106 | return im2tensor(rgb_back) 107 | 108 | def rgb2lab(input): 109 | from skimage import color 110 | return color.rgb2lab(input / 255.) 111 | 112 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 113 | image_numpy = image_tensor[0].cpu().float().numpy() 114 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 115 | return image_numpy.astype(imtype) 116 | 117 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 118 | return torch.Tensor((image / factor - cent) 119 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 120 | 121 | def tensor2vec(vector_tensor): 122 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 123 | 124 | def voc_ap(rec, prec, use_07_metric=False): 125 | """ ap = voc_ap(rec, prec, [use_07_metric]) 126 | Compute VOC AP given precision and recall. 127 | If use_07_metric is true, uses the 128 | VOC 07 11 point method (default:False). 129 | """ 130 | if use_07_metric: 131 | # 11 point metric 132 | ap = 0. 133 | for t in np.arange(0., 1.1, 0.1): 134 | if np.sum(rec >= t) == 0: 135 | p = 0 136 | else: 137 | p = np.max(prec[rec >= t]) 138 | ap = ap + p / 11. 
139 | else: 140 | # correct AP calculation 141 | # first append sentinel values at the end 142 | mrec = np.concatenate(([0.], rec, [1.])) 143 | mpre = np.concatenate(([0.], prec, [0.])) 144 | 145 | # compute the precision envelope 146 | for i in range(mpre.size - 1, 0, -1): 147 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 148 | 149 | # to calculate area under PR curve, look for points 150 | # where X axis (recall) changes value 151 | i = np.where(mrec[1:] != mrec[:-1])[0] 152 | 153 | # and sum (\Delta recall) * prec 154 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 155 | return ap 156 | 157 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 158 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 159 | image_numpy = image_tensor[0].cpu().float().numpy() 160 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 161 | return image_numpy.astype(imtype) 162 | 163 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 164 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 165 | return torch.Tensor((image / factor - cent) 166 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 167 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | # local modules 4 | from utils import loss 5 | from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM 6 | from PerceptualSimilarity.models import PerceptualLoss 7 | 8 | class combined_perceptual_loss(): 9 | def __init__(self, weight=1.0, use_gpu=True): 10 | """ 11 | Flow wrapper for perceptual_loss 12 | """ 13 | self.loss = perceptual_loss(weight=1.0, use_gpu=use_gpu) 14 | self.weight = weight 15 | 16 | def __call__(self, pred_img, pred_flow, target_img, target_flow): 17 | """ 18 | image is tensor of N x 2 x H x W, flow of N x 2 x H x W 19 | These are concatenated, as perceptualLoss expects N x 3 x H x W. 20 | """ 21 | pred = torch.cat([pred_img, pred_flow], dim=1) 22 | target = torch.cat([target_img, target_flow], dim=1) 23 | dist = self.loss(pred, target, normalize=False) 24 | return dist * self.weight 25 | 26 | 27 | class warping_flow_loss(): 28 | def __init__(self, weight=1.0, L0=1): 29 | assert L0 > 0 30 | self.loss = loss.warping_flow_loss 31 | self.weight = weight 32 | self.L0 = L0 33 | self.default_return = None 34 | 35 | def __call__(self, i, image1, flow): 36 | """ 37 | flow is from image0 to image1 (reversed when passed to 38 | warping_flow_loss function) 39 | """ 40 | loss = self.default_return if i < self.L0 else self.weight * self.loss( 41 | self.image0, image1, -flow) 42 | self.image0 = image1 43 | return loss 44 | 45 | 46 | class voxel_warp_flow_loss(): 47 | def __init__(self, weight=1.0): 48 | self.loss = loss.voxel_warping_flow_loss 49 | self.weight = weight 50 | 51 | def __call__(self, voxel, displacement, output_images=False): 52 | """ 53 | Warp the voxel grid by the displacement map. 
Variance 54 | of resulting image is loss 55 | """ 56 | loss = self.loss(voxel, displacement, output_images) 57 | if output_images: 58 | loss = (self.weight * loss[0], loss[1]) 59 | else: 60 | loss *= self.weight 61 | return loss 62 | 63 | 64 | class flow_perceptual_loss(): 65 | def __init__(self, weight=1.0, use_gpu=True): 66 | """ 67 | Flow wrapper for perceptual_loss 68 | """ 69 | self.loss = perceptual_loss(weight=1.0, use_gpu=use_gpu) 70 | self.weight = weight 71 | 72 | def __call__(self, pred, target): 73 | """ 74 | pred and target are Tensors with shape N x 2 x H x W 75 | PerceptualLoss expects N x 3 x H x W. 76 | """ 77 | dist_x = self.loss(pred[:, 0:1, :, :], target[:, 0:1, :, :], normalize=False) 78 | dist_y = self.loss(pred[:, 1:2, :, :], target[:, 1:2, :, :], normalize=False) 79 | return (dist_x + dist_y) / 2 * self.weight 80 | 81 | 82 | class flow_l1_loss(): 83 | def __init__(self, weight=1.0): 84 | self.loss = F.l1_loss 85 | self.weight = weight 86 | 87 | def __call__(self, pred, target): 88 | return self.weight * self.loss(pred, target) 89 | 90 | 91 | # keep for compatibility 92 | flow_loss = flow_l1_loss 93 | 94 | 95 | class perceptual_loss(): 96 | def __init__(self, weight=1.0, net='alex', use_gpu=True): 97 | """ 98 | Wrapper for PerceptualSimilarity.models.PerceptualLoss 99 | """ 100 | self.model = PerceptualLoss(net=net, use_gpu=use_gpu) 101 | self.weight = weight 102 | 103 | def __call__(self, pred, target, normalize=True, reduce_batch=True): 104 | """ 105 | pred and target are Tensors with shape N x C x H x W (C {1, 3}) 106 | normalize scales images from [0, 1] to [-1, 1] (default: True) 107 | PerceptualLoss expects N x 3 x H x W. 108 | """ 109 | if pred.shape[1] == 1: 110 | pred = torch.cat([pred, pred, pred], dim=1) 111 | if target.shape[1] == 1: 112 | target = torch.cat([target, target, target], dim=1) 113 | dist = self.model.forward(pred, target, normalize=normalize) 114 | if reduce_batch: 115 | return self.weight * dist.mean() 116 | B, _, _, _ = dist.shape 117 | dist = dist.reshape((B, -1)) 118 | dist = dist.mean(axis=1) 119 | return self.weight * dist 120 | 121 | class l2_loss(): 122 | def __init__(self, weight=1.0): 123 | self.weight = weight 124 | 125 | def __call__(self, pred, target, reduce_batch=True): 126 | loss = (pred - target)**2 127 | B = pred.shape[0] 128 | if reduce_batch: 129 | loss = torch.mean(loss) 130 | else: 131 | loss = loss.reshape((B, -1)) 132 | loss = torch.mean(loss, dim=1) 133 | return self.weight * loss 134 | 135 | class l1_loss(): 136 | def __init__(self, weight=1.0): 137 | self.weight = weight 138 | 139 | def __call__(self, pred, target, reduce_batch=True): 140 | loss = torch.abs(pred - target) 141 | B = pred.shape[0] 142 | if reduce_batch: 143 | loss = torch.mean(loss) 144 | else: 145 | loss = loss.reshape((B, -1)) 146 | loss = torch.mean(loss, dim=1) 147 | return self.weight * loss 148 | 149 | class ssim_loss(): 150 | def __init__(self, weight=1.0, model=None): 151 | assert model is not None 152 | self.loss = model 153 | self.weight = weight 154 | 155 | def __call__(self, pred, target): 156 | # SSIM: larger is better 157 | assert False, "This function causes multi-GPU issues." 
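# Sketch (added, not the repo's implementation): if the multi-GPU issue were resolved, a
# working __call__ could wrap the torchmetrics SSIM instance passed in as `model`, turning
# the similarity score (higher is better) into a loss, along the lines of
#   def __call__(self, pred, target):
#       return self.weight * (1 - self.loss(pred, target))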
158 | 159 | class temporal_consistency_loss(): 160 | def __init__(self, weight=1.0, L0=1): 161 | assert L0 > 0 162 | self.loss = loss.temporal_consistency_loss 163 | self.weight = weight 164 | self.L0 = L0 165 | 166 | def __call__(self, i, image1, processed1, flow, output_images=False, reduce_batch=True): 167 | """ 168 | flow is from image0 to image1 (reversed when passed to 169 | temporal_consistency_loss function) 170 | """ 171 | if i >= self.L0: 172 | loss = self.loss(self.image0, image1, self.processed0, processed1, 173 | -flow, output_images=output_images, reduce_batch=reduce_batch) 174 | if output_images: 175 | loss = (self.weight * loss[0], loss[1]) 176 | else: 177 | loss *= self.weight 178 | else: 179 | loss = 0 180 | self.image0 = image1 181 | self.processed0 = processed1 182 | return loss 183 | -------------------------------------------------------------------------------- /test_flow.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import yaml 4 | import torch 5 | import tqdm 6 | import cv2 7 | import numpy as np 8 | from collections import defaultdict 9 | from utils.data import data_sources 10 | 11 | from model.train_flow_utils import FlowModelInterface, flow2rgb_np 12 | from train import convert_to_compiled 13 | from test_e2vid import create_test_dataloader 14 | 15 | metrics = ["dense_EPE", "dense_1PE", "dense_3PE", "sparse_EPE", "sparse_1PE", "sparse_3PE"] 16 | sequences = { 17 | #"IJRR": ["boxes_6dof", "calibration", "dynamic_6dof", "office_zigzag", "poster_6dof", "shapes_6dof", "slider_depth"], 18 | "MVSEC": ["indoor_flying1", "indoor_flying2", "indoor_flying3", "outdoor_day1", "outdoor_day2"], 19 | #"HQF": ["bike_bay_hdr", "boxes", "desk", "desk_fast", "desk_hand_only", "desk_slow", "engineering_posters", "high_texture_plants", "poster_pillar_1", "poster_pillar_2", "reflective_materials", "slow_and_fast_desk", "slow_hand", "still_life"], 20 | #"EVAID": ["ball", "bear", "box", "building", "outdoor", "playball", "room1", "sculpture", "toy", "traffic", "wall"] 21 | } 22 | all_metric_names = [] 23 | for k, v in sequences.items(): 24 | for seqname in v: 25 | for m in metrics: 26 | all_metric_names.append(f"{k}/{seqname}/{m}") 27 | 28 | def run_test(model_interface, dataloader, device, configs): 29 | 30 | output_dir = configs["test_output_dir"] 31 | save_npy = configs.get("save_npy", False) 32 | save_png = configs.get("save_png", True) 33 | 34 | # Actually a flow model, but variable name is still e2vid_model 35 | model_interface.e2vid_model.eval() 36 | 37 | previous_test_sequence = None 38 | all_metrics = defaultdict(list) 39 | 40 | with torch.no_grad(): 41 | 42 | for batch_idx, batch in enumerate(tqdm.tqdm(dataloader)): 43 | sequence_name = batch["sequence_name"][0][0] 44 | 45 | if previous_test_sequence is None or previous_test_sequence != sequence_name: 46 | model_interface.e2vid_model.reset_states() 47 | output_img_idx = 0 48 | if output_dir is not None: 49 | data_source_idx = batch["data_source_idx"][0] 50 | data_source = data_sources[data_source_idx].upper() 51 | seq_output_dir = os.path.join(output_dir, data_source, sequence_name) 52 | #print("seq_output_dir:", seq_output_dir) 53 | os.makedirs(seq_output_dir, exist_ok=True) 54 | 55 | for k, v in batch.items(): 56 | if torch.is_tensor(v): 57 | batch[k] = v.to(device) 58 | 59 | pred = model_interface.forward_sequence(batch, reset_states=False, test=True, val=True) # Reset manually according to sequence name 60 | # pred: (B, T, 2, H, W) optical flow 61 | 62 
| #pred = torch.zeros_like(pred) 63 | metrics = model_interface.compute_metrics(pred, batch) 64 | for k, v in metrics.items(): 65 | if k in all_metric_names: 66 | all_metrics[k] += v # v is also a list 67 | 68 | 69 | if output_dir is not None: 70 | one, T, C, H, W = pred.shape 71 | for t in range(T): 72 | flow = pred[0, t, :].detach().cpu().numpy() 73 | # gt_flow = batch["flow"][0, t, :].detach().cpu().numpy() 74 | # gt_flow = np.where(np.isnan(gt_flow), 0, gt_flow) 75 | # flow = gt_flow 76 | #cat_flow = np.concatenate([flow, gt_flow], axis=2) 77 | if save_npy: 78 | np.save(os.path.join(seq_output_dir, f"{output_img_idx:06d}.npy"), flow) 79 | if save_png: 80 | flow_vis = flow2rgb_np(flow[0, :, :], flow[1, :, :]) 81 | #flow_vis = flow2rgb_np(cat_flow[0, :, :], cat_flow[1, :, :]) 82 | cv2.imwrite(os.path.join(seq_output_dir, f"{output_img_idx:06d}_flow.png"), flow_vis) 83 | 84 | output_img_idx += 1 85 | 86 | previous_test_sequence = sequence_name 87 | 88 | output_metric_txt = os.path.join("tensorboard_logs", configs["experiment_name"], "test_metrics.txt") 89 | with open(output_metric_txt, "w") as f: 90 | for k, v in all_metrics.items(): 91 | all_metrics[k] = np.mean(v) 92 | print(f"{k}: {all_metrics[k]}") 93 | f.write(f"{k}: {all_metrics[k]}\n") 94 | 95 | return all_metrics 96 | 97 | 98 | def main(): 99 | # Add two arguments. 100 | # Argument 1: config_path 101 | # Argument 2 (optional): test_all_pths (default=False) 102 | if len(sys.argv) > 1: 103 | config_path = sys.argv[1] 104 | else: 105 | config_path = "configs/template.yaml" 106 | 107 | if len(sys.argv) > 2: 108 | test_all_pths = True 109 | else: 110 | test_all_pths = False 111 | 112 | with open(config_path) as f: 113 | config = yaml.load(f, Loader=yaml.Loader) 114 | 115 | assert config.get("task", "e2vid") == "flow", "e2vid should be tested with test_torch.py" 116 | 117 | ckpt_paths_file = f"ckpt_paths/{config['experiment_name']}.txt" 118 | output_csv = os.path.join("tensorboard_logs", config['experiment_name'], f"all_test_results_new.csv") 119 | os.makedirs(os.path.dirname(output_csv), exist_ok=True) 120 | done_checkpoints = [] 121 | if os.path.exists(output_csv): 122 | with open(output_csv, "r", encoding="utf-8") as f: 123 | lines = f.readlines() 124 | for line in lines[1:]: 125 | ckpt_path = line.split(",")[0] 126 | done_checkpoints.append(ckpt_path) 127 | 128 | # First row: all the metric names 129 | # Each row: subpath, metric1, metric2, .... 130 | if not os.path.exists(output_csv): 131 | with open(output_csv, "w", encoding="UTF-8") as f: 132 | f.write("Checkpoint_path,") 133 | for key in all_metric_names: 134 | f.write(f"{key},") 135 | f.write("\n") 136 | 137 | all_results = [] 138 | if os.path.exists(ckpt_paths_file) and os.path.getsize(ckpt_paths_file) > 0: 139 | with open(ckpt_paths_file, "r") as f: 140 | paths = [p.strip() for p in f.readlines() if p.strip()] 141 | assert len(paths) > 0, "No checkpoint paths found in the file." 
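# Note (added): unless a second command-line argument is given, only the last entry of the
# ckpt_paths file is evaluated, so the newest/preferred checkpoint is expected to be listed last.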
142 | if not test_all_pths: 143 | paths = paths[-1:] 144 | 145 | for path in paths: 146 | subpath = path.split("/")[-1] 147 | # If I only request testing the last line, don't skip, it is probably retesting 148 | if not test_all_pths or subpath not in done_checkpoints: 149 | result = run_single_test(path, config) 150 | all_results.append((result, subpath)) 151 | 152 | with open(output_csv, "a", encoding="UTF-8") as f: 153 | f.write(f"{subpath},") 154 | for key in all_metric_names: 155 | f.write(f"{result[key]},") 156 | f.write("\n") 157 | f.flush() 158 | 159 | else: 160 | print("No checkpoint paths file found or it is empty.") 161 | 162 | def run_single_test(checkpoint_path, config): 163 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 164 | model_interface = FlowModelInterface(config["module"], device=device, local_rank=None) 165 | 166 | if checkpoint_path is not None: 167 | saved = torch.load(checkpoint_path, map_location=device, weights_only=False) 168 | state_dict = saved["state_dict"] 169 | 170 | # Don't use torch.compile, because the test is fast enough. 171 | new_state_dict = convert_to_compiled(state_dict=state_dict, local_rank=None, use_compile=False) 172 | 173 | model_interface.e2vid_model.load_state_dict(new_state_dict, strict=False) 174 | print("Loaded checkpoint:", checkpoint_path) 175 | 176 | model_interface.e2vid_model.to(device) 177 | 178 | test_dataloader = create_test_dataloader(config["test_stage"]) 179 | return run_test(model_interface, test_dataloader, device, config) 180 | 181 | if __name__ == "__main__": 182 | main() 183 | --------------------------------------------------------------------------------
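Usage sketch (added for illustration). test_flow.py reads a YAML config whose experiment_name selects a checkpoint list under ckpt_paths/; the snippet below shows a roughly equivalent programmatic call. The config path is an assumption (any flow-task config such as config/test_eraft_original.yaml should work), and the tensorboard_logs/<experiment_name>/ directory is created explicitly because run_test() writes test_metrics.txt there (main() normally does this before run_single_test is reached).

import os
import yaml
from test_flow import run_single_test

with open("config/test_eraft_original.yaml") as f:           # assumed flow-task config
    config = yaml.load(f, Loader=yaml.Loader)

# run_test() writes test_metrics.txt under this directory, so make sure it exists
os.makedirs(os.path.join("tensorboard_logs", config["experiment_name"]), exist_ok=True)

with open(f"ckpt_paths/{config['experiment_name']}.txt") as f:
    ckpt_path = [p.strip() for p in f if p.strip()][-1]       # newest checkpoint listed last

metrics = run_single_test(ckpt_path, config)                  # {"MVSEC/<seq>/<metric>": value, ...}
print(metrics)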