├── README.md
├── data
│   ├── __init__.py
│   ├── common.py
│   ├── dataset_indoor.py
│   ├── demo.py
│   ├── dsec.py
│   ├── eventslicer.py
│   ├── provider.py
│   ├── representations.py
│   ├── sampler.py
│   ├── sequence.py
│   └── visualization.py
├── disp_loss
│   ├── __init__.py
│   ├── adversarial.py
│   ├── disp_loss.py
│   └── metric.py
├── launch.py
├── loss
│   ├── __init__.py
│   ├── adversarial.py
│   └── metric.py
├── main.py
├── model
│   ├── Decoder.py
│   ├── ResNet.py
│   ├── __init__.py
│   ├── affinity_module.py
│   ├── common.py
│   ├── gwc_pertu_noise_with_affinity.py
│   ├── image_recon.py
│   ├── intensity_MSResNet.py
│   ├── pertu_select_recon.py
│   ├── perturbations.py
│   ├── structure.py
│   └── submodule.py
├── optim
│   ├── __init__.py
│   └── warm_multi_step_lr.py
├── option.py
├── train.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
 1 | # Selection and Cross Similarity for Event-Image Deep Stereo (SCSNet), ECCV 2022
 2 | 
 3 | This is the official code for our ECCV 2022 paper "Selection and Cross Similarity for Event-Image Deep Stereo".
 4 | 
 5 | # Dataset
 6 | DSEC
 7 | https://dsec.ifi.uzh.ch/uzh/disparity-benchmark/
 8 | 
 9 | # Train
10 | python main.py --n_GPUs 4 --batch_size 8 --dataset indoor_flying_1 --split 1 --data_root ../../DSEC_data --save_dir max_disp_120_homo_batch_8 --model pertu_select_recon --loss 1*L1+1*LPIPS --lr 1e-4 --test_every 200 --save_every 1 --disp_model gwc_pertu_noise_with_affinity --end_epoch 160 --validate_every 10
11 | 
12 | # Inference for benchmark
13 | CUDA_VISIBLE_DEVICES=3 python main.py --n_GPUs 1 --batch_size 1 --split 1 --data_root ../../DSEC_data --save_dir max_disp_120_homo_batch_8 --model pertu_select_recon --loss 1*L1+1*LPIPS --lr 1e-4 --test_every 100 --save_every 1 --disp_model gwc_pertu_noise_with_affinity --end_epoch 99 --validate_every 1 --load_epoch 77
14 | 
15 | 
16 | # Paper Reference
17 | @inproceedings{cho2022selection,
18 |   title={Selection and Cross Similarity for Event-Image Deep Stereo},
19 |   author={Cho, Hoonhee and Yoon, Kuk-Jin},
20 |   booktitle={Computer Vision--ECCV 2022: 17th European Conference, Tel Aviv, Israel, October 23--27, 2022, Proceedings, Part XXXII},
21 |   pages={470--486},
22 |   year={2022},
23 |   organization={Springer}
24 | }
25 | 
26 | We borrow code from the following repositories. Thanks for the excellent code!
27 | - Nah, Seungjun, Tae Hyun Kim, and Kyoung Mu Lee. "Deep multi-scale convolutional neural network for dynamic scene deblurring." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2017.
28 |   https://github.com/SeungjunNah/DeepDeblur_release
29 | - Berthet, Quentin, et al. "Learning with differentiable perturbed optimizers." Advances in Neural Information Processing Systems 33 (2020): 9508-9519.
30 |   https://github.com/tuero/perturbations-differential-pytorch
31 | 
32 | # TBD
33 | - Pretrained model
34 | - MVSEC dataloader
35 | - There may be minor code errors due to accidental deletion of some parts, but performance was confirmed to be reproducible.
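# Data layout and splits
The DSEC loader (`data/provider.py`) builds the train/val datasets from subdirectories of `--data_root`, so a layout along the following lines is expected (a sketch inferred from the dataloader code; sequence names are placeholders):

    DSEC_data/
    ├── train/
    │   └── <sequence_name>/
    └── val/
        └── <sequence_name>/

For the MVSEC indoor_flying loaders (`data/dataset_indoor.py`, `data/dsec.py`), `--split 1` trains on indoor_flying_2/3 and evaluates on indoor_flying_1, while `--split 3` trains on indoor_flying_1/2 and evaluates on indoor_flying_3.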
36 | 
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
 1 | """Generic dataset loader"""
 2 | 
 3 | from importlib import import_module
 4 | 
 5 | from torch.utils.data import DataLoader
 6 | from torch.utils.data import SequentialSampler, RandomSampler
 7 | from torch.utils.data.distributed import DistributedSampler
 8 | from .sampler import DistributedEvalSampler
 9 | from .provider import DatasetProvider
10 | from pathlib import Path
11 | class Data():
12 |     def __init__(self, args):
13 | 
14 |         self.modes = ['train', 'val', 'test', 'demo']
15 | 
16 |         self.action = {
17 |             'train': args.do_train,
18 |             'val': args.do_validate,
19 |             'test': args.do_test,
20 |             'demo': args.demo
21 |         }
22 | 
23 |         self.dataset_name = {
24 |             'train': args.data_train,
25 |             'val': args.data_val,
26 |             'test': args.data_test,
27 |             'demo': 'Demo'
28 |         }
29 | 
30 |         self.args = args
31 | 
32 |         def _get_data_loader(mode='train'):
33 |             dataset_name = self.dataset_name[mode]
34 |             # dataset = import_module('data.' + dataset_name.lower())
35 |             # dataset = getattr(dataset, dataset_name)(args, mode)
36 |             dataset_provider = DatasetProvider(Path(args.data_root))
37 |             if mode == 'train':
38 |                 dataset = dataset_provider.get_train_dataset()
39 |             elif mode == 'val':
40 |                 dataset = dataset_provider.get_val_dataset()
41 |             else:
42 |                 dataset = dataset_provider.get_test_dataset()
43 |             dataset[0]  # eagerly fetch one sample (fails fast if the dataset path/layout is wrong)
44 | 
45 | 
46 |             if mode == 'train':
47 |                 if args.distributed:
48 |                     batch_size = int(args.batch_size / args.n_GPUs)  # batch size per GPU (single-node training)
49 |                     sampler = DistributedSampler(dataset, shuffle=True, num_replicas=args.world_size, rank=args.rank)
50 |                     num_workers = int((args.num_workers + args.n_GPUs - 1) / args.n_GPUs)  # num_workers per GPU (single-node training)
51 |                 else:
52 |                     batch_size = args.batch_size
53 |                     sampler = RandomSampler(dataset, replacement=False)
54 |                     num_workers = args.num_workers
55 |                 drop_last = True
56 | 
57 |             elif mode in ('val', 'test', 'demo'):
58 |                 if args.distributed:
59 |                     batch_size = 1  # 1 image per GPU
60 |                     sampler = DistributedEvalSampler(dataset, shuffle=False, num_replicas=args.world_size, rank=args.rank)
61 |                     num_workers = int((args.num_workers + args.n_GPUs - 1) / args.n_GPUs)  # num_workers per GPU (single-node training)
62 |                 else:
63 |                     batch_size = args.n_GPUs  # 1 image per GPU
64 |                     sampler = SequentialSampler(dataset)
65 |                     num_workers = args.num_workers
66 |                 drop_last = False
67 | 
68 |             loader = DataLoader(
69 |                 dataset=dataset,
70 |                 batch_size=batch_size,
71 |                 shuffle=False,
72 |                 sampler=sampler,
73 |                 num_workers=num_workers,
74 |                 pin_memory=True,
75 |                 drop_last=drop_last,
76 |             )
77 | 
78 |             return loader
79 | 
80 |         self.loaders = {}
81 |         for mode in self.modes:
82 |             if self.action[mode]:
83 |                 self.loaders[mode] = _get_data_loader(mode)
84 |                 print('===> Loading {} dataset: {}'.format(mode, self.dataset_name[mode]))
85 |             else:
86 |                 self.loaders[mode] = None
87 | 
88 |     def get_loader(self):
89 |         return self.loaders
90 | 
--------------------------------------------------------------------------------
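A minimal usage sketch for the `Data` wrapper defined above (assuming `args` is the namespace parsed by `option.py`; the `do_train`/`do_validate`/`do_test`/`demo` flags follow the code above):

    from data import Data

    loaders = Data(args).get_loader()  # dict keyed by 'train'/'val'/'test'/'demo'
    train_loader = loaders['train']    # None unless args.do_train is set
    for batch in train_loader:
        pass  # each batch is produced by the DatasetProvider datasets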
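Two idioms recur in `data/common.py` below: `_apply` maps a function over arbitrarily nested lists/tuples/dicts of arrays, and padding amounts are computed with Python's non-negative modulo, where `-h % divisor` is the smallest padding that makes `h` divisible by `divisor`. A worked example of the padding arithmetic:

    # h = 437, divisor = 64: -437 % 64 == 11, and (437 + 11) % 64 == 0
    h, divisor = 437, 64
    pad_h = -h % divisor
    assert (h + pad_h) % divisor == 0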
/data/common.py:
--------------------------------------------------------------------------------
  1 | import random
  2 | import numpy as np
  3 | from skimage.color import rgb2hsv, hsv2rgb
  4 | from skimage.transform import pyramid_gaussian
  5 | from skimage.measure import block_reduce
  6 | 
  7 | import torch
  8 | 
  9 | def _apply(func, x):
 10 | 
 11 |     if isinstance(x, (list, tuple)):
 12 |         return [_apply(func, x_i) for x_i in x]
 13 |     elif isinstance(x, dict):
 14 |         y = {}
 15 |         for key, value in x.items():
 16 |             y[key] = _apply(func, value)
 17 |         return y
 18 |     else:
 19 |         return func(x)
 20 | 
 21 | def crop(args, ps, py, px): # patch_size
 22 |     # args = [input, target]
 23 |     def _get_shape(*args):
 24 |         if isinstance(args[0], (list, tuple)):
 25 |             return _get_shape(args[0][0])
 26 |         elif isinstance(args[0], dict):
 27 |             return _get_shape(list(args[0].values())[0])
 28 |         else:
 29 |             return args[0].shape
 30 | 
 31 |     h, w, _ = _get_shape(args)
 32 |     # print(_get_shape(args))
 33 |     # print(ps[1])
 34 |     # import pdb; pdb.set_trace()
 35 | 
 36 |     # py = random.randrange(0, h-ps+1)
 37 |     # px = random.randrange(0, w-ps+1)
 38 | 
 39 |     def _crop(img):
 40 |         if img.ndim == 2:
 41 |             return img[py:py+ps[1], px:px+ps[0], np.newaxis]
 42 |         else:
 43 |             return img[py:py+ps[1], px:px+ps[0], :]
 44 | 
 45 |     return _apply(_crop, args)
 46 | 
 47 | def crop_with_event(*args, left_event, right_event, ps=256): # patch_size
 48 |     # args = [input, target]
 49 |     def _get_shape(*args):
 50 |         if isinstance(args[0], (list, tuple)):
 51 |             return _get_shape(args[0][0])
 52 |         elif isinstance(args[0], dict):
 53 |             return _get_shape(list(args[0].values())[0])
 54 |         else:
 55 |             return args[0].shape
 56 | 
 57 |     h, w, _ = _get_shape(args)
 58 | 
 59 |     py = random.randrange(0, h-ps+1)
 60 |     px = random.randrange(0, w-ps+1)
 61 | 
 62 |     def _crop(img):
 63 |         if img.ndim == 2:
 64 |             return img[py:py+ps, px:px+ps, np.newaxis]
 65 |         else:
 66 |             return img[py:py+ps, px:px+ps, :]
 67 | 
 68 |     def _event_crop(event):
 69 |         return event[:, py:py+ps, px:px+ps]
 70 | 
 71 |     return _apply(_crop, args), _apply(_event_crop, left_event), _apply(_event_crop, right_event)
 72 | 
 73 | def crop_event(args, ps, py, px): # patch_size
 74 |     # args = [input, target]
 75 |     def _get_shape(*args):
 76 |         if isinstance(args[0], (list, tuple)):
 77 |             return _get_shape(args[0][0])
 78 |         elif isinstance(args[0], dict):
 79 |             return _get_shape(list(args[0].values())[0])
 80 |         else:
 81 |             return args[0].shape
 82 | 
 83 |     _, h, w = _get_shape(args)
 84 | 
 85 |     # py = random.randrange(0, h-ps+1)
 86 |     # px = random.randrange(0, w-ps+1)
 87 | 
 88 |     def _crop(img):
 89 |         if img.ndim == 2:
 90 |             return img[py:py+ps, px:px+ps, np.newaxis]
 91 |         else:
 92 |             return img[py:py+ps, px:px+ps, :]
 93 | 
 94 |     def _event_crop(event):
 95 |         return event[:, py:py+ps[1], px:px+ps[0]]
 96 | 
 97 |     return _apply(_event_crop, args)
 98 | 
 99 | 
100 | def crop_disp(args, ps, py, px): # patch_size
101 |     # args = [input, target]
102 |     def _get_shape(*args):
103 |         if isinstance(args[0], (list, tuple)):
104 |             return
_get_shape(args[0][0]) 105 | elif isinstance(args[0], dict): 106 | return _get_shape(list(args[0].values())[0]) 107 | else: 108 | return args[0].shape 109 | 110 | h, w = _get_shape(args) 111 | 112 | # py = random.randrange(0, h-ps+1) 113 | # px = random.randrange(0, w-ps+1) 114 | 115 | 116 | def _crop(img): 117 | if img.ndim == 2: 118 | return img[py:py+ps[1], px:px+ps[0], np.newaxis] 119 | else: 120 | return img[py:py+ps[1], px:px+ps[0], :] 121 | 122 | def _disp_crop(event): 123 | return event[py:py+ps[1], px:px+ps[0]] 124 | 125 | return _apply(_disp_crop, args) 126 | 127 | def add_noise(*args, sigma_sigma=2, rgb_range=255): 128 | 129 | if len(args) == 1: # usually there is only a single input 130 | args = args[0] 131 | 132 | sigma = np.random.normal() * sigma_sigma * rgb_range/255 133 | 134 | def _add_noise(img): 135 | noise = np.random.randn(*img.shape).astype(np.float32) * sigma 136 | return (img + noise).clip(0, rgb_range) 137 | 138 | return _apply(_add_noise, args) 139 | 140 | def augment(*args, hflip=True, rot=True, shuffle=True, change_saturation=True, rgb_range=255): 141 | """augmentation consistent to input and target""" 142 | 143 | choices = (False, True) 144 | 145 | hflip = hflip and random.choice(choices) 146 | vflip = rot and random.choice(choices) 147 | rot90 = rot and random.choice(choices) 148 | # shuffle = shuffle 149 | 150 | if shuffle: 151 | rgb_order = list(range(3)) 152 | random.shuffle(rgb_order) 153 | if rgb_order == list(range(3)): 154 | shuffle = False 155 | 156 | if change_saturation: 157 | amp_factor = np.random.uniform(0.5, 1.5) 158 | 159 | def _augment(img): 160 | if hflip: img = img[:, ::-1, :] 161 | if vflip: img = img[::-1, :, :] 162 | if rot90: img = img.transpose(1, 0, 2) 163 | if shuffle and img.ndim > 2: 164 | if img.shape[-1] == 3: # RGB image only 165 | img = img[..., rgb_order] 166 | 167 | if change_saturation: 168 | hsv_img = rgb2hsv(img) 169 | hsv_img[..., 1] *= amp_factor 170 | 171 | img = hsv2rgb(hsv_img).clip(0, 1) * rgb_range 172 | 173 | return img.astype(np.float32) 174 | 175 | return _apply(_augment, args) 176 | 177 | def pad(img, divisor=4, pad_width=None, negative=False): 178 | 179 | def _pad_numpy(img, divisor=4, pad_width=None, negative=False): 180 | if pad_width is None: 181 | (h, w, _) = img.shape 182 | pad_h = -h % divisor 183 | pad_w = -w % divisor 184 | pad_width = ((0, pad_h), (0, pad_w), (0, 0)) 185 | 186 | img = np.pad(img, pad_width, mode='edge') 187 | 188 | return img, pad_width 189 | 190 | def _pad_tensor(img, divisor=4, pad_width=None, negative=False): 191 | 192 | n, c, h, w = img.shape 193 | if pad_width is None: 194 | pad_h = -h % divisor 195 | pad_w = -w % divisor 196 | pad_width = (0, pad_w, 0, pad_h) 197 | else: 198 | try: 199 | pad_h = pad_width[0][1] 200 | pad_w = pad_width[1][1] 201 | if isinstance(pad_h, torch.Tensor): 202 | pad_h = pad_h.item() 203 | if isinstance(pad_w, torch.Tensor): 204 | pad_w = pad_w.item() 205 | 206 | pad_width = (0, pad_w, 0, pad_h) 207 | except: 208 | pass 209 | 210 | if negative: 211 | pad_width = [-val for val in pad_width] 212 | 213 | img = torch.nn.functional.pad(img, pad_width, 'reflect') 214 | 215 | return img, pad_width 216 | 217 | if isinstance(img, np.ndarray): 218 | return _pad_numpy(img, divisor, pad_width, negative) 219 | else: # torch.Tensor 220 | return _pad_tensor(img, divisor, pad_width, negative) 221 | 222 | 223 | def disp_pad(img, divisor=4, pad_width=None, negative=False): 224 | 225 | def _pad_numpy(img, divisor=4, pad_width=None, negative=False): 226 | if pad_width is 
None: 227 | (h, w, _) = img.shape 228 | pad_h = -h % divisor 229 | pad_w = -w % divisor 230 | pad_width = ((0, pad_h), (0, pad_w), (0, 0)) 231 | 232 | img = np.pad(img, pad_width, mode='edge') 233 | 234 | return img, pad_width 235 | 236 | def _pad_tensor(img, divisor=4, pad_width=None, negative=False): 237 | 238 | if isinstance(img, list): 239 | img = img[0].unsqueeze(1) 240 | n, c, h, w = img.shape 241 | else: 242 | img = img.unsqueeze(1) 243 | n, c, h, w = img.shape 244 | if pad_width is None: 245 | pad_h = -h % divisor 246 | pad_w = -w % divisor 247 | pad_width = (0, pad_w, 0, pad_h) 248 | 249 | 250 | else: 251 | try: 252 | # import pdb; pdb.set_trace() 253 | pad_h = pad_width[0][1][0] 254 | pad_w = pad_width[1][1][0] 255 | if isinstance(pad_h, torch.Tensor): 256 | pad_h = pad_h.item() 257 | if isinstance(pad_w, torch.Tensor): 258 | pad_w = pad_w.item() 259 | 260 | pad_width = (0, pad_w, 0, pad_h) 261 | except: 262 | pass 263 | 264 | if negative: 265 | pad_width = [-val for val in pad_width] 266 | 267 | 268 | img = torch.nn.functional.pad(img, pad_width, 'reflect') 269 | 270 | 271 | return img 272 | 273 | if isinstance(img, np.ndarray): 274 | return _pad_numpy(img, divisor, pad_width, negative) 275 | else: # torch.Tensor 276 | return _pad_tensor(img, divisor, pad_width, negative) 277 | 278 | def event_pad(img, divisor=4, pad_width=None, negative=False): 279 | 280 | def _pad_numpy(img, divisor=4, pad_width=None, negative=False): 281 | 282 | if pad_width is None: 283 | (_, h, w) = img.shape 284 | pad_h = -h % divisor 285 | pad_w = -w % divisor 286 | pad_width = ((0, 0), (0, pad_h), (0, pad_w)) 287 | 288 | img = np.pad(img, pad_width, mode='edge') 289 | 290 | return img, pad_width 291 | 292 | def _pad_tensor(img, divisor=4, pad_width=None, negative=False): 293 | 294 | n, c, h, w = img.shape 295 | if pad_width is None: 296 | pad_h = -h % divisor 297 | pad_w = -w % divisor 298 | pad_width = (0, pad_w, 0, pad_h) 299 | else: 300 | try: 301 | pad_h = pad_width[0][1] 302 | pad_w = pad_width[1][1] 303 | if isinstance(pad_h, torch.Tensor): 304 | pad_h = pad_h.item() 305 | if isinstance(pad_w, torch.Tensor): 306 | pad_w = pad_w.item() 307 | 308 | pad_width = (0, pad_w, 0, pad_h) 309 | except: 310 | pass 311 | 312 | if negative: 313 | pad_width = [-val for val in pad_width] 314 | 315 | img = torch.nn.functional.pad(img, pad_width, 'reflect') 316 | 317 | return img, pad_width 318 | 319 | if isinstance(img, np.ndarray): 320 | return _pad_numpy(img, divisor, pad_width, negative) 321 | else: # torch.Tensor 322 | return _pad_tensor(img, divisor, pad_width, negative) 323 | 324 | def generate_pyramid(*args, n_scales): 325 | 326 | def _generate_pyramid(img): 327 | if img.dtype != np.float32: 328 | img = img.astype(np.float32) 329 | pyramid = list(pyramid_gaussian(img, n_scales-1, multichannel=True)) 330 | 331 | return pyramid 332 | 333 | return _apply(_generate_pyramid, args) 334 | 335 | def generate_event_pyramid(*args, n_scales): 336 | def _generate_event_pyramid(event): 337 | event = np.array(event) 338 | 339 | if event.dtype != np.float32: 340 | event = event.astype(np.float32) 341 | 342 | # import pdb 343 | # pdb.set_trace() 344 | # print("len event") 345 | # print(event) 346 | # print([len(a) for a in event]) 347 | 348 | # print(event.shape) 349 | pyramid = [] 350 | for i in range(n_scales): 351 | w, h, c = event.shape 352 | # scale_event = block_reduce(event, (w//(2**i) , h//(2**i), c), np.max) 353 | scale_event = block_reduce(event, (1, 2**i, 2**i), np.max) 354 | # print(scale_event) 355 | # 
print(event) 356 | pyramid.append(scale_event) 357 | # if i == 2: 358 | # print(scale_event.shape) 359 | # import pdb 360 | # pdb.set_trace() 361 | 362 | 363 | # pyramid = list(pyramid_gaussian(img, n_scales-1, multichannel=True)) 364 | return pyramid 365 | return _generate_event_pyramid(args) 366 | 367 | 368 | def np2tensor(*args): 369 | def _np2tensor(x): 370 | np_transpose = np.ascontiguousarray(x.transpose(2, 0, 1)) 371 | tensor = torch.from_numpy(np_transpose) 372 | 373 | return tensor 374 | 375 | return _apply(_np2tensor, args) 376 | 377 | def image2tensor(x): 378 | np_transpose = np.ascontiguousarray(x.transpose(2, 0, 1)) 379 | tensor = torch.from_numpy(np_transpose) 380 | 381 | return tensor 382 | 383 | def event2tensor(*args): 384 | def _np2tensor(x): 385 | # np_transpose = np.ascontiguousarray(x.transpose(2, 0, 1)) 386 | tensor = torch.from_numpy(x) 387 | return tensor 388 | 389 | 390 | return _apply(_np2tensor, args) 391 | 392 | def to(*args, device=None, dtype=torch.float): 393 | 394 | def _to(x): 395 | return x.to(device=device, dtype=dtype, non_blocking=True, copy=False) 396 | 397 | return _apply(_to, args) 398 | -------------------------------------------------------------------------------- /data/dataset_indoor.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pdb import Pdb 3 | import random 4 | import imageio 5 | import numpy as np 6 | import torch.utils.data as data 7 | 8 | from data import common 9 | 10 | from utils import interact 11 | import PIL.Image 12 | 13 | 14 | class Dataset(data.Dataset): 15 | """Basic dataloader class 16 | """ 17 | def __init__(self, args, mode='train'): 18 | super(Dataset, self).__init__() 19 | self.args = args 20 | self.mode = mode 21 | 22 | self.modes = () 23 | self.set_modes() 24 | self._check_mode() 25 | 26 | self.set_keys() 27 | 28 | if self.mode == 'train': 29 | dataset = args.data_train 30 | elif self.mode == 'val': 31 | dataset = args.data_val 32 | elif self.mode == 'test': 33 | dataset = args.data_test 34 | elif self.mode == 'demo': 35 | pass 36 | else: 37 | raise NotImplementedError('not implemented for this mode: {}!'.format(self.mode)) 38 | 39 | if self.mode == 'demo': 40 | self.subset_root = args.demo_input_dir 41 | else: 42 | # self.subset_root = os.path.join(args.data_root, dataset, self.mode) 43 | if args.split == 1: 44 | if self.mode == 'train': 45 | dataset = ['indoor_flying_2', 'indoor_flying_3'] 46 | else: 47 | dataset = 'indoor_flying_1' 48 | elif args.split == 3: 49 | if self.mode == 'train': 50 | dataset = ['indoor_flying_1', 'indoor_flying_2'] 51 | else: 52 | dataset = 'indoor_flying_3' 53 | 54 | if self.mode == 'train': 55 | self.subset_root = [] 56 | for set in dataset: 57 | self.subset_root.append(os.path.join(args.data_root, set)) 58 | else: 59 | self.subset_root = os.path.join(args.data_root, dataset) 60 | 61 | 62 | # import pdb 63 | # pdb.set_trace() 64 | 65 | # FRAMES_FILTER_FOR_TEST = { 66 | # 'indoor_flying': { 67 | # 1: list(range(140, 1201)), 68 | # 2: list(range(120, 1421)), 69 | # 3: list(range(73, 1616)), 70 | # 4: list(range(190, 290)) 71 | # } 72 | # } 73 | 74 | # FRAMES_FILTER_FOR_TRAINING = { 75 | # 'indoor_flying': { 76 | # 1: list(range(80, 1260)), 77 | # 2: list(range(160, 1580)), 78 | # 3: list(range(125, 1815)), 79 | # 4: list(range(190, 290)) 80 | # } 81 | # } 82 | 83 | self.left_image_list = [] 84 | self.right_image_list = [] 85 | self.left_event_list = [] 86 | self.right_event_list = [] 87 | self.disp_list = [] 88 | 89 | 90 | 91 | 
self._scan() 92 | 93 | 94 | def set_modes(self): 95 | self.modes = ('train', 'val', 'test', 'demo') 96 | 97 | def _check_mode(self): 98 | """Should be called in the child class __init__() after super 99 | """ 100 | 101 | if self.mode not in self.modes: 102 | raise NotImplementedError('mode error: not for {}'.format(self.mode)) 103 | 104 | return 105 | 106 | def set_keys(self): 107 | 108 | self.left_image_key = 'image0' 109 | self.right_image_key = 'image1' 110 | # self.left_event_key = 'select0' 111 | # self.right_event_key = 'select1' 112 | # self.left_event_key = 'numvox0' 113 | # self.right_event_key = 'numvox1' 114 | self.left_event_key = 'voxel0_orig' 115 | self.right_event_key = 'voxel1_orig' 116 | self.disp_key = 'disparity_image' 117 | 118 | 119 | self.non_left_image_keys = [] 120 | self.non_right_image_keys = [] 121 | self.non_left_event_keys = [] 122 | self.non_right_event_keys = [] 123 | self.non_disp_keys = [] 124 | 125 | return 126 | 127 | def _scan(self, root=None): 128 | """Should be called in the child class __init__() after super 129 | """ 130 | if root is None: 131 | root = self.subset_root 132 | 133 | # if self.blur_key in self.non_blur_keys: 134 | # self.non_blur_keys.remove(self.blur_key) 135 | # if self.sharp_key in self.non_sharp_keys: 136 | # self.non_sharp_keys.remove(self.sharp_key) 137 | # if self.event_key in self.non_event_keys: 138 | # self.non_event_keys.remove(self.event_key) 139 | 140 | if self.left_image_key in self.non_left_image_keys: 141 | self.non_left_image_keys.remove(self.left_image_key) 142 | if self.right_image_key in self.non_right_image_keys: 143 | self.non_right_image_keys.remove(self.right_image_key) 144 | if self.left_event_key in self.non_left_event_keys: 145 | self.non_left_event_keys.remove(self.left_event_key) 146 | if self.right_event_key in self.non_right_event_keys: 147 | self.non_right_event_keys.remove(self.right_event_key) 148 | if self.disp_key in self.non_disp_keys: 149 | self.non_disp_keys.remove(self.disp_key) 150 | 151 | 152 | 153 | def _key_check(path, true_key, false_keys): 154 | path = os.path.join(path, '') 155 | if path.find(true_key) >= 0: 156 | for false_key in false_keys: 157 | if path.find(false_key) >= 0: 158 | return False 159 | 160 | return True 161 | else: 162 | return False 163 | # FRAMES_FILTER_FOR_TEST = { 164 | # 'indoor_flying': { 165 | # 1: list(range(140, 1201)), 166 | # 2: list(range(120, 1421)), 167 | # 3: list(range(73, 1616)), 168 | # 4: list(range(190, 290)) 169 | # } 170 | # } 171 | 172 | # FRAMES_FILTER_FOR_TRAINING = { 173 | # 'indoor_flying': { 174 | # 1: list(range(80, 1260)), 175 | # 2: list(range(160, 1580)), 176 | # 3: list(range(125, 1815)), 177 | # 4: list(range(190, 290)) 178 | # } 179 | # } 180 | 181 | # original 182 | FILTER_TEST = { 183 | # 'indoor_flying_1': list(range(140, 1201)), 184 | # 'indoor_flying_2': list(range(120, 1421)), 185 | # 'indoor_flying_3': list(range(73, 1616)) 186 | # 'indoor_flying_1': list(range(140, 1001)), 187 | 'indoor_flying_1': list(range(140, 1001)), 188 | 'indoor_flying_2': list(range(120, 1221)), 189 | 'indoor_flying_3': list(range(273, 1616)) 190 | } 191 | 192 | # test 193 | # FILTER_TEST = { 194 | # # 'indoor_flying_1': list(range(140, 1201)), 195 | # # 'indoor_flying_2': list(range(120, 1421)), 196 | # # 'indoor_flying_3': list(range(73, 1616)) 197 | # 'indoor_flying_1': list(range(140, 161)), 198 | # 'indoor_flying_2': list(range(120, 141)), 199 | # 'indoor_flying_3': list(range(73, 93)) 200 | # } 201 | 202 | # original 203 | FILTER_TRAIN = { 204 | 
'indoor_flying_1': list(range(80, 1260)), 205 | 'indoor_flying_2': list(range(160, 1580)), 206 | 'indoor_flying_3': list(range(125, 1815)) 207 | } 208 | 209 | # test 210 | # FILTER_TRAIN = { 211 | # 'indoor_flying_1': list(range(80, 90)), 212 | # 'indoor_flying_2': list(range(160, 180)), 213 | # 'indoor_flying_3': list(range(125, 145)) 214 | # } 215 | 216 | def _get_list_by_key(root, true_key, false_keys): 217 | data_list = [] 218 | if isinstance(root, (list, tuple)): 219 | for rt in root: 220 | for sub, dirs, files in os.walk(rt): 221 | if not dirs: 222 | file_list = [os.path.join(sub, f) for f in files if int(f.split('.')[0]) in FILTER_TRAIN[rt.split('/')[-1]]] 223 | if _key_check(sub, true_key, false_keys): 224 | data_list += file_list 225 | else: 226 | for sub, dirs, files in os.walk(root): 227 | if not dirs: 228 | file_list = [os.path.join(sub, f) for f in files if int(f.split('.')[0]) in FILTER_TEST[root.split('/')[-1]]] 229 | if _key_check(sub, true_key, false_keys): 230 | data_list += file_list 231 | 232 | 233 | data_list.sort() 234 | return data_list 235 | 236 | def _rectify_keys(): 237 | 238 | self.left_image_key = os.path.join(self.left_image_key, '') 239 | self.non_left_image_keys = [os.path.join(non_left_image_key, '') for non_left_image_key in self.non_left_image_keys] 240 | self.left_event_key = os.path.join(self.left_event_key, '') 241 | self.non_left_event_keys = [os.path.join(non_left_event_key, '') for non_left_event_key in self.non_left_event_keys] 242 | self.right_image_key = os.path.join(self.right_image_key, '') 243 | self.non_right_image_keys = [os.path.join(non_right_image_key, '') for non_right_image_key in self.non_right_image_keys] 244 | self.right_event_key = os.path.join(self.right_event_key, '') 245 | self.non_right_event_keys = [os.path.join(non_right_event_key, '') for non_right_event_key in self.non_right_event_keys] 246 | self.disp_key = os.path.join(self.disp_key, '') 247 | self.non_disp_keys = [os.path.join(non_disp_key, '') for non_disp_key in self.non_disp_keys] 248 | 249 | 250 | _rectify_keys() 251 | 252 | self.left_image_list = _get_list_by_key(root, self.left_image_key, self.non_left_image_keys) 253 | self.left_event_list = _get_list_by_key(root, self.left_event_key, self.non_left_event_keys) 254 | self.right_image_list = _get_list_by_key(root, self.right_image_key, self.non_right_image_keys) 255 | self.right_event_list = _get_list_by_key(root, self.right_event_key, self.non_right_event_keys) 256 | self.disp_list = _get_list_by_key(root, self.disp_key, self.non_disp_keys) 257 | 258 | 259 | if len(self.left_image_list) > 0: 260 | assert(len(self.left_image_list) == len(self.left_event_list)) 261 | if len(self.right_image_list) > 0: 262 | assert(len(self.right_image_list) == len(self.right_event_list)) 263 | if len(self.left_image_list) > 0: 264 | assert(len(self.left_image_list) == len(self.right_image_list)) 265 | if len(self.disp_list) > 0: 266 | assert(len(self.disp_list) == len(self.left_image_list)) 267 | 268 | return 269 | 270 | def __getitem__(self, idx): 271 | 272 | left_image = imageio.imread(self.left_image_list[idx], pilmode='RGB') 273 | right_image = imageio.imread(self.right_image_list[idx], pilmode='RGB') 274 | imgs = [left_image, right_image] 275 | 276 | left_event = np.load(self.left_event_list[idx]) 277 | right_event = np.load(self.right_event_list[idx]) 278 | # print(left_event.shape) 279 | 280 | disp = np.array(PIL.Image.open(self.disp_list[idx])).astype(np.uint8) 281 | invalid_disparity = (disp == 255.0) 282 | disparity_image = 
(disp / 7.0)
283 |         disparity_image[invalid_disparity] = float('inf')
284 | 
285 | 
286 | 
287 |         pad_width = 0 # dummy value
288 |         # if self.mode == 'train':
289 |         #     # imgs, left_event, right_event = common.crop_with_event(*imgs, left_event = left_event, right_event = right_event, ps=self.args.patch_size)
290 |         #     imgs[0], pad_width = common.pad(imgs[0], divisor=64)
291 |         # elif self.mode == 'demo':
292 |         #     imgs[0], pad_width = common.pad(imgs[0], divisor=2**(self.args.n_scales-1)) # pad in case of non-divisible size
293 |         # else:
294 |         #     # imgs[0], pad_width = common.pad(imgs[0], divisor=2**(self.args.n_scales-1))
295 |         #     # event, pad_width = common.pad(event, divisor=2**(self.args.n_scales-1))
296 |         #     pass # deliver test image as is.
297 | 
298 |         ## padding
299 |         # imgs[0], pad_width = common.pad(imgs[0], divisor=64)
300 |         # imgs[1], pad_width = common.pad(imgs[1], divisor=64)
301 |         # left_event, _ = common.event_pad(left_event, divisor=64)
302 |         # right_event, _ = common.event_pad(right_event, divisor=64)
303 | 
304 |         # print(imgs[0].shape)
305 |         # print(imgs[1].shape)
306 |         # print(left_event.shape)
307 |         # print(right_event.shape)
308 |         # import pdb
309 |         # pdb.set_trace()
310 | 
311 |         noise_imgs = [imgs[0], imgs[1]]
312 |         noise_imgs[0] = common.add_noise(imgs[0], sigma_sigma=2, rgb_range=self.args.rgb_range)
313 |         noise_imgs[1] = common.add_noise(imgs[1], sigma_sigma=2, rgb_range=self.args.rgb_range)
314 | 
315 | 
316 |         # print(event.shape)
317 | 
318 |         if self.args.gaussian_pyramid:
319 |             if self.mode in ('train', 'demo'):
320 | 
321 |                 imgs = common.generate_pyramid(*imgs, n_scales=self.args.n_scales)
322 |                 left_event = common.generate_event_pyramid(left_event, n_scales=self.args.n_scales)
323 |                 right_event = common.generate_event_pyramid(right_event, n_scales=self.args.n_scales)
324 |             else:
325 |                 # left_event, pad_width = common.event_pad(left_event, divisor=2**(self.args.n_scales-1))
326 |                 # right_event, pad_width = common.event_pad(right_event, divisor=2**(self.args.n_scales-1))
327 |                 # imgs[0], pad_width = common.pad(imgs[0], divisor=2**(self.args.n_scales-1))
328 |                 # imgs[1], pad_width = common.pad(imgs[1], divisor=2**(self.args.n_scales-1))
329 |                 imgs = common.generate_pyramid(*imgs, n_scales=self.args.n_scales)
330 |                 left_event = common.generate_event_pyramid(*left_event, n_scales=self.args.n_scales)
331 |                 right_event = common.generate_event_pyramid(*right_event, n_scales=self.args.n_scales)
332 | 
333 | 
334 |         imgs = common.np2tensor(*imgs)
335 |         noise_imgs = common.np2tensor(*noise_imgs)
336 | 
337 |         left_event = common.event2tensor(left_event)[0]
338 |         right_event = common.event2tensor(right_event)[0]
339 | 
340 |         if self.mode == 'train':
341 |             relpath = os.path.relpath(self.left_image_list[idx], self.subset_root[0])
342 |         else:
343 |             relpath = os.path.relpath(self.left_image_list[idx], self.subset_root)
344 | 
345 | 
346 |         # blur = imgs[0]
347 | 
348 |         left_img = imgs[0]
349 |         right_img = imgs[1]
350 |         # sharp = imgs[1] if len(imgs) > 1 else False
351 |         left_noise = noise_imgs[0]
352 |         right_noise = noise_imgs[1]
353 | 
354 | 
355 |         return left_img, right_img, pad_width, idx, relpath, left_event, right_event, left_noise, right_noise, disparity_image
356 | 
357 |     def __len__(self):
358 |         return len(self.left_image_list)
359 |         # return 32
360 | 
361 | 
362 | 
363 | 
364 | 
365 | 
--------------------------------------------------------------------------------
/data/demo.py:
--------------------------------------------------------------------------------
 1 | from data.dataset import Dataset
 2 | 
 3 | from utils import interact
 4 | 
 5 | class Demo(Dataset):
 6 |     """Demo train, test subset class
 7 |     """
 8 |     def __init__(self, args, mode='demo'):
 9 |         super(Demo, self).__init__(args, mode)
10 | 
11 |     def set_modes(self):
12 |         self.modes = ('demo',)  # one-element tuple of supported modes
13 | 
14 |     def set_keys(self):
15 |         super(Demo, self).set_keys()
16 |         self.blur_key = '' # all the files
17 |         self.non_sharp_keys = [''] # no files
18 | 
19 |     def __getitem__(self, idx):
20 |         blur, sharp, pad_width, idx, relpath, event = super(Demo, self).__getitem__(idx)
21 | 
22 |         return blur, sharp, pad_width, idx, relpath, event
23 | 
--------------------------------------------------------------------------------
/data/dsec.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | from pdb import Pdb
  3 | import random
  4 | import imageio
  5 | import numpy as np
  6 | import torch.utils.data as data
  7 | 
  8 | from data import common
  9 | 
 10 | from utils import interact
 11 | import PIL.Image
 12 | 
 13 | 
 14 | class Dataset(data.Dataset):
 15 |     """Basic dataloader class
 16 |     """
 17 |     def __init__(self, args, mode='train'):
 18 |         super(Dataset, self).__init__()
 19 |         self.args = args
 20 |         self.mode = mode
 21 | 
 22 |         self.modes = ()
 23 |         self.set_modes()
 24 |         self._check_mode()
 25 | 
 26 |         self.set_keys()
 27 | 
 28 |         if self.mode == 'train':
 29 |             dataset = args.data_train
 30 |         elif self.mode == 'val':
 31 |             dataset = args.data_val
 32 |         elif self.mode == 'test':
 33 |             dataset = args.data_test
 34 |         elif self.mode == 'demo':
 35 |             pass
 36 |         else:
 37 |             raise NotImplementedError('not implemented for this mode: {}!'.format(self.mode))
 38 | 
 39 |         if self.mode == 'demo':
 40 |             self.subset_root = args.demo_input_dir
 41 |         else:
 42 |             # self.subset_root = os.path.join(args.data_root, dataset, self.mode)
 43 |             if args.split == 1:
 44 |                 if self.mode == 'train':
 45 |                     dataset = ['indoor_flying_2', 'indoor_flying_3']
 46 |                 else:
 47 |                     dataset = 'indoor_flying_1'
 48 |             elif args.split == 3:
 49 |                 if self.mode == 'train':
 50 |                     dataset = ['indoor_flying_1', 'indoor_flying_2']
 51 |                 else:
 52 |                     dataset = 'indoor_flying_3'
 53 | 
 54 |             if self.mode == 'train':
 55 |                 self.subset_root = []
 56 |                 for set in dataset:
 57 |                     self.subset_root.append(os.path.join(args.data_root, set))
 58 |             else:
 59 |                 self.subset_root = os.path.join(args.data_root, dataset)
 60 | 
 61 | 
 62 |         # import pdb
 63 |         # pdb.set_trace()
 64 | 
 65 |         # FRAMES_FILTER_FOR_TEST = {
 66 |         #     'indoor_flying': {
 67 |         #         1: list(range(140, 1201)),
 68 |         #         2: list(range(120, 1421)),
 69 |         #         3: list(range(73, 1616)),
 70 |         #         4: list(range(190, 290))
 71 |         #     }
 72 |         # }
 73 | 
 74 |         # FRAMES_FILTER_FOR_TRAINING = {
 75 |         #     'indoor_flying': {
 76 |         #         1: list(range(80, 1260)),
 77 |         #         2: list(range(160, 1580)),
 78 |         #         3: list(range(125, 1815)),
 79 |         #         4: list(range(190, 290))
 80 |         #     }
 81 |         # }
 82 | 
 83 |         self.left_image_list = []
 84 |         self.right_image_list = []
 85 |         self.left_event_list = []
 86 |         self.right_event_list = []
 87 |         self.disp_list = []
 88 | 
 89 | 
 90 | 
 91 |         self._scan()
 92 | 
 93 | 
 94 |     def set_modes(self):
 95 |         self.modes = ('train', 'val', 'test', 'demo')
 96 | 
 97 |     def _check_mode(self):
 98 |         """Should be called in the child class __init__() after super
 99 |         """
100 | 
101 |         if self.mode not in self.modes:
102 |             raise NotImplementedError('mode error: not for {}'.format(self.mode))
103 | 
104 |         return
105 | 
106 |     def set_keys(self):
107 | 
108 |         self.left_image_key = 'image0'
109 |         self.right_image_key = 'image1'
110 |         # self.left_event_key = 'select0'
111 |         # self.right_event_key = 'select1'
112 |         # self.left_event_key =
'numvox0' 113 | # self.right_event_key = 'numvox1' 114 | self.left_event_key = 'voxel0_orig' 115 | self.right_event_key = 'voxel1_orig' 116 | self.disp_key = 'disparity_image' 117 | 118 | 119 | self.non_left_image_keys = [] 120 | self.non_right_image_keys = [] 121 | self.non_left_event_keys = [] 122 | self.non_right_event_keys = [] 123 | self.non_disp_keys = [] 124 | 125 | return 126 | 127 | def _scan(self, root=None): 128 | """Should be called in the child class __init__() after super 129 | """ 130 | if root is None: 131 | root = self.subset_root 132 | 133 | # if self.blur_key in self.non_blur_keys: 134 | # self.non_blur_keys.remove(self.blur_key) 135 | # if self.sharp_key in self.non_sharp_keys: 136 | # self.non_sharp_keys.remove(self.sharp_key) 137 | # if self.event_key in self.non_event_keys: 138 | # self.non_event_keys.remove(self.event_key) 139 | 140 | if self.left_image_key in self.non_left_image_keys: 141 | self.non_left_image_keys.remove(self.left_image_key) 142 | if self.right_image_key in self.non_right_image_keys: 143 | self.non_right_image_keys.remove(self.right_image_key) 144 | if self.left_event_key in self.non_left_event_keys: 145 | self.non_left_event_keys.remove(self.left_event_key) 146 | if self.right_event_key in self.non_right_event_keys: 147 | self.non_right_event_keys.remove(self.right_event_key) 148 | if self.disp_key in self.non_disp_keys: 149 | self.non_disp_keys.remove(self.disp_key) 150 | 151 | 152 | 153 | def _key_check(path, true_key, false_keys): 154 | path = os.path.join(path, '') 155 | if path.find(true_key) >= 0: 156 | for false_key in false_keys: 157 | if path.find(false_key) >= 0: 158 | return False 159 | 160 | return True 161 | else: 162 | return False 163 | # FRAMES_FILTER_FOR_TEST = { 164 | # 'indoor_flying': { 165 | # 1: list(range(140, 1201)), 166 | # 2: list(range(120, 1421)), 167 | # 3: list(range(73, 1616)), 168 | # 4: list(range(190, 290)) 169 | # } 170 | # } 171 | 172 | # FRAMES_FILTER_FOR_TRAINING = { 173 | # 'indoor_flying': { 174 | # 1: list(range(80, 1260)), 175 | # 2: list(range(160, 1580)), 176 | # 3: list(range(125, 1815)), 177 | # 4: list(range(190, 290)) 178 | # } 179 | # } 180 | 181 | # original 182 | FILTER_TEST = { 183 | # 'indoor_flying_1': list(range(140, 1201)), 184 | # 'indoor_flying_2': list(range(120, 1421)), 185 | # 'indoor_flying_3': list(range(73, 1616)) 186 | # 'indoor_flying_1': list(range(140, 1001)), 187 | 'indoor_flying_1': list(range(140, 1001)), 188 | 'indoor_flying_2': list(range(120, 1221)), 189 | 'indoor_flying_3': list(range(273, 1616)) 190 | } 191 | 192 | # test 193 | # FILTER_TEST = { 194 | # # 'indoor_flying_1': list(range(140, 1201)), 195 | # # 'indoor_flying_2': list(range(120, 1421)), 196 | # # 'indoor_flying_3': list(range(73, 1616)) 197 | # 'indoor_flying_1': list(range(140, 161)), 198 | # 'indoor_flying_2': list(range(120, 141)), 199 | # 'indoor_flying_3': list(range(73, 93)) 200 | # } 201 | 202 | # original 203 | FILTER_TRAIN = { 204 | 'indoor_flying_1': list(range(80, 1260)), 205 | 'indoor_flying_2': list(range(160, 1580)), 206 | 'indoor_flying_3': list(range(125, 1815)) 207 | } 208 | 209 | # test 210 | # FILTER_TRAIN = { 211 | # 'indoor_flying_1': list(range(80, 90)), 212 | # 'indoor_flying_2': list(range(160, 180)), 213 | # 'indoor_flying_3': list(range(125, 145)) 214 | # } 215 | 216 | def _get_list_by_key(root, true_key, false_keys): 217 | data_list = [] 218 | if isinstance(root, (list, tuple)): 219 | for rt in root: 220 | for sub, dirs, files in os.walk(rt): 221 | if not dirs: 222 | file_list = 
[os.path.join(sub, f) for f in files if int(f.split('.')[0]) in FILTER_TRAIN[rt.split('/')[-1]]]
223 |                         if _key_check(sub, true_key, false_keys):
224 |                             data_list += file_list
225 |             else:
226 |                 for sub, dirs, files in os.walk(root):
227 |                     if not dirs:
228 |                         file_list = [os.path.join(sub, f) for f in files if int(f.split('.')[0]) in FILTER_TEST[root.split('/')[-1]]]
229 |                         if _key_check(sub, true_key, false_keys):
230 |                             data_list += file_list
231 | 
232 | 
233 |             data_list.sort()
234 |             return data_list
235 | 
236 |         def _rectify_keys():
237 | 
238 |             self.left_image_key = os.path.join(self.left_image_key, '')
239 |             self.non_left_image_keys = [os.path.join(non_left_image_key, '') for non_left_image_key in self.non_left_image_keys]
240 |             self.left_event_key = os.path.join(self.left_event_key, '')
241 |             self.non_left_event_keys = [os.path.join(non_left_event_key, '') for non_left_event_key in self.non_left_event_keys]
242 |             self.right_image_key = os.path.join(self.right_image_key, '')
243 |             self.non_right_image_keys = [os.path.join(non_right_image_key, '') for non_right_image_key in self.non_right_image_keys]
244 |             self.right_event_key = os.path.join(self.right_event_key, '')
245 |             self.non_right_event_keys = [os.path.join(non_right_event_key, '') for non_right_event_key in self.non_right_event_keys]
246 |             self.disp_key = os.path.join(self.disp_key, '')
247 |             self.non_disp_keys = [os.path.join(non_disp_key, '') for non_disp_key in self.non_disp_keys]
248 | 
249 | 
250 |         _rectify_keys()
251 | 
252 |         self.left_image_list = _get_list_by_key(root, self.left_image_key, self.non_left_image_keys)
253 |         self.left_event_list = _get_list_by_key(root, self.left_event_key, self.non_left_event_keys)
254 |         self.right_image_list = _get_list_by_key(root, self.right_image_key, self.non_right_image_keys)
255 |         self.right_event_list = _get_list_by_key(root, self.right_event_key, self.non_right_event_keys)
256 |         self.disp_list = _get_list_by_key(root, self.disp_key, self.non_disp_keys)
257 | 
258 | 
259 |         if len(self.left_image_list) > 0:
260 |             assert(len(self.left_image_list) == len(self.left_event_list))
261 |         if len(self.right_image_list) > 0:
262 |             assert(len(self.right_image_list) == len(self.right_event_list))
263 |         if len(self.left_image_list) > 0:
264 |             assert(len(self.left_image_list) == len(self.right_image_list))
265 |         if len(self.disp_list) > 0:
266 |             assert(len(self.disp_list) == len(self.left_image_list))
267 | 
268 |         return
269 | 
270 |     def __getitem__(self, idx):
271 | 
272 |         left_image = imageio.imread(self.left_image_list[idx], pilmode='RGB')
273 |         right_image = imageio.imread(self.right_image_list[idx], pilmode='RGB')
274 |         imgs = [left_image, right_image]
275 | 
276 |         left_event = np.load(self.left_event_list[idx])
277 |         right_event = np.load(self.right_event_list[idx])
278 |         # print(left_event.shape)
279 | 
280 |         disp = np.array(PIL.Image.open(self.disp_list[idx])).astype(np.uint8)
281 |         invalid_disparity = (disp == 255.0)
282 |         disparity_image = (disp / 7.0)
283 |         disparity_image[invalid_disparity] = float('inf')
284 | 
285 | 
286 | 
287 |         pad_width = 0 # dummy value
288 | 
289 |         noise_imgs = [imgs[0], imgs[1]]
290 |         noise_imgs[0] = common.add_noise(imgs[0], sigma_sigma=2, rgb_range=self.args.rgb_range)
291 |         noise_imgs[1] = common.add_noise(imgs[1], sigma_sigma=2, rgb_range=self.args.rgb_range)
292 | 
293 | 
294 |         # print(event.shape)
295 | 
296 |         if self.args.gaussian_pyramid:
297 |             if self.mode in ('train', 'demo'):
298 | 
299 |                 imgs = common.generate_pyramid(*imgs, n_scales=self.args.n_scales)
300 |                 left_event =
common.generate_event_pyramid(left_event, n_scales=self.args.n_scales) 301 | right_event = common.generate_event_pyramid(right_event, n_scales=self.args.n_scales) 302 | else: 303 | # left_event, pad_width = common.event_pad(left_event, divisor=2**(self.args.n_scales-1)) 304 | # right_event, pad_width = common.event_pad(right_event, divisor=2**(self.args.n_scales-1)) 305 | # imgs[0], pad_width = common.pad(imgs[0], divisor=2**(self.args.n_scales-1)) 306 | # imgs[1], pad_width = common.pad(imgs[1], divisor=2**(self.args.n_scales-1)) 307 | imgs = common.generate_pyramid(*imgs, n_scales=self.args.n_scales) 308 | left_event = common.generate_event_pyramid(*left_event, n_scales=self.args.n_scales) 309 | right_event = common.generate_event_pyramid(*right_event, n_scales=self.args.n_scales) 310 | 311 | 312 | imgs = common.np2tensor(*imgs) 313 | noise_imgs = common.np2tensor(*noise_imgs) 314 | 315 | left_event = common.event2tensor(left_event)[0] 316 | right_event = common.event2tensor(right_event)[0] 317 | 318 | if self.mode == 'train': 319 | relpath = os.path.relpath(self.left_image_list[idx], self.subset_root[0]) 320 | else: 321 | relpath = os.path.relpath(self.left_image_list[idx], self.subset_root) 322 | 323 | 324 | # blur = imgs[0] 325 | 326 | left_img = imgs[0] 327 | right_img = imgs[1] 328 | # sharp = imgs[1] if len(imgs) > 1 else False 329 | left_noise = noise_imgs[0] 330 | right_noise = noise_imgs[1] 331 | 332 | 333 | return left_img, right_img, pad_width, idx, relpath, left_event, right_event, left_noise, right_noise, disparity_image 334 | 335 | def __len__(self): 336 | return len(self.left_image_list) 337 | # return 32 338 | 339 | 340 | 341 | 342 | 343 | -------------------------------------------------------------------------------- /data/eventslicer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict, Tuple 3 | 4 | import h5py 5 | from numba import jit 6 | import numpy as np 7 | 8 | 9 | class EventSlicer: 10 | def __init__(self, h5f: h5py.File): 11 | self.h5f = h5f 12 | 13 | self.events = dict() 14 | for dset_str in ['p', 'x', 'y', 't']: 15 | self.events[dset_str] = self.h5f['events/{}'.format(dset_str)] 16 | 17 | # This is the mapping from milliseconds to event index: 18 | # It is defined such that 19 | # (1) t[ms_to_idx[ms]] >= ms*1000 20 | # (2) t[ms_to_idx[ms] - 1] < ms*1000 21 | # ,where 'ms' is the time in milliseconds and 't' the event timestamps in microseconds. 
22 |         #
23 |         # As an example, given 't' and 'ms':
24 |         # t:  0    500   2100   5000   5000   7100   7200   7200   8100   9000
25 |         # ms: 0    1     2      3      4      5      6      7      8      9
26 |         #
27 |         # we get
28 |         #
29 |         # ms_to_idx:
30 |         #     0    2     2      3      3      3      5      5      8      9
31 |         self.ms_to_idx = np.asarray(self.h5f['ms_to_idx'], dtype='int64')
32 | 
33 |         if "t_offset" in list(h5f.keys()):
34 |             self.t_offset = int(h5f['t_offset'][()])
35 |         else:
36 |             self.t_offset = 0
37 |         self.t_final = int(self.events['t'][-1]) + self.t_offset
38 | 
39 |     def get_start_time_us(self):
40 |         return self.t_offset
41 | 
42 |     def get_final_time_us(self):
43 |         return self.t_final
44 | 
45 |     def get_events(self, t_start_us: int, t_end_us: int) -> Dict[str, np.ndarray]:
46 |         """Get events (p, x, y, t) within the specified time window
47 |         Parameters
48 |         ----------
49 |         t_start_us: start time in microseconds
50 |         t_end_us: end time in microseconds
51 |         Returns
52 |         -------
53 |         events: dictionary of (p, x, y, t) or None if the time window cannot be retrieved
54 |         """
55 |         assert t_start_us < t_end_us
56 | 
57 |         # We assume that the times are time-of-day, hence subtract the offset:
58 |         t_start_us -= self.t_offset
59 |         t_end_us -= self.t_offset
60 | 
61 |         t_start_ms, t_end_ms = self.get_conservative_window_ms(t_start_us, t_end_us)
62 |         t_start_ms_idx = self.ms2idx(t_start_ms)
63 |         t_end_ms_idx = self.ms2idx(t_end_ms)
64 | 
65 |         if t_start_ms_idx is None or t_end_ms_idx is None:
66 |             # Cannot guarantee window size anymore
67 |             return None
68 | 
69 |         events = dict()
70 |         time_array_conservative = np.asarray(self.events['t'][t_start_ms_idx:t_end_ms_idx])
71 |         idx_start_offset, idx_end_offset = self.get_time_indices_offsets(time_array_conservative, t_start_us, t_end_us)
72 |         t_start_us_idx = t_start_ms_idx + idx_start_offset
73 |         t_end_us_idx = t_start_ms_idx + idx_end_offset
74 |         # Again add t_offset to get gps time
75 |         events['t'] = time_array_conservative[idx_start_offset:idx_end_offset] + self.t_offset
76 |         for dset_str in ['p', 'x', 'y']:
77 |             events[dset_str] = np.asarray(self.events[dset_str][t_start_us_idx:t_end_us_idx])
78 |             assert events[dset_str].size == events['t'].size
79 |         return events
80 | 
81 | 
82 |     @staticmethod
83 |     def get_conservative_window_ms(ts_start_us: int, ts_end_us) -> Tuple[int, int]:
84 |         """Compute a conservative time window with millisecond resolution.
85 |         We have a time to index mapping for each millisecond. Hence, we need
86 |         to compute the lower and upper millisecond to retrieve events.
87 | Parameters 88 | ---------- 89 | ts_start_us: start time in microseconds 90 | ts_end_us: end time in microseconds 91 | Returns 92 | ------- 93 | window_start_ms: conservative start time in milliseconds 94 | window_end_ms: conservative end time in milliseconds 95 | """ 96 | assert ts_end_us > ts_start_us 97 | window_start_ms = math.floor(ts_start_us/1000) 98 | window_end_ms = math.ceil(ts_end_us/1000) 99 | return window_start_ms, window_end_ms 100 | 101 | @staticmethod 102 | @jit(nopython=True) 103 | def get_time_indices_offsets( 104 | time_array: np.ndarray, 105 | time_start_us: int, 106 | time_end_us: int) -> Tuple[int, int]: 107 | """Compute index offset of start and end timestamps in microseconds 108 | Parameters 109 | ---------- 110 | time_array: timestamps (in us) of the events 111 | time_start_us: start timestamp (in us) 112 | time_end_us: end timestamp (in us) 113 | Returns 114 | ------- 115 | idx_start: Index within this array corresponding to time_start_us 116 | idx_end: Index within this array corresponding to time_end_us 117 | such that (in non-edge cases) 118 | time_array[idx_start] >= time_start_us 119 | time_array[idx_end] >= time_end_us 120 | time_array[idx_start - 1] < time_start_us 121 | time_array[idx_end - 1] < time_end_us 122 | this means that 123 | time_start_us <= time_array[idx_start:idx_end] < time_end_us 124 | """ 125 | 126 | assert time_array.ndim == 1 127 | 128 | idx_start = -1 129 | if time_array[-1] < time_start_us: 130 | # This can happen in extreme corner cases. E.g. 131 | # time_array[0] = 1016 132 | # time_array[-1] = 1984 133 | # time_start_us = 1990 134 | # time_end_us = 2000 135 | 136 | # Return same index twice: array[x:x] is empty. 137 | return time_array.size, time_array.size 138 | else: 139 | for idx_from_start in range(0, time_array.size, 1): 140 | if time_array[idx_from_start] >= time_start_us: 141 | idx_start = idx_from_start 142 | break 143 | assert idx_start >= 0 144 | 145 | idx_end = time_array.size 146 | for idx_from_end in range(time_array.size - 1, -1, -1): 147 | if time_array[idx_from_end] >= time_end_us: 148 | idx_end = idx_from_end 149 | else: 150 | break 151 | 152 | assert time_array[idx_start] >= time_start_us 153 | if idx_end < time_array.size: 154 | assert time_array[idx_end] >= time_end_us 155 | if idx_start > 0: 156 | assert time_array[idx_start - 1] < time_start_us 157 | if idx_end > 0: 158 | assert time_array[idx_end - 1] < time_end_us 159 | return idx_start, idx_end 160 | 161 | def ms2idx(self, time_ms: int) -> int: 162 | assert time_ms >= 0 163 | if time_ms >= self.ms_to_idx.size: 164 | return None 165 | return self.ms_to_idx[time_ms] 166 | -------------------------------------------------------------------------------- /data/provider.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | 5 | from .sequence import Sequence 6 | 7 | class DatasetProvider: 8 | def __init__(self, dataset_path: Path, delta_t_ms: int=50, num_bins=5): 9 | 10 | train_path = dataset_path / 'train' 11 | 12 | assert dataset_path.is_dir(), str(dataset_path) 13 | assert train_path.is_dir(), str(train_path) 14 | 15 | train_sequences = list() 16 | for child in train_path.iterdir(): 17 | 18 | train_sequences.append(Sequence(child, 'train', delta_t_ms, num_bins)) 19 | 20 | self.train_dataset = torch.utils.data.ConcatDataset(train_sequences) 21 | 22 | val_path = dataset_path / 'val' 23 | 24 | assert val_path.is_dir(), str(val_path) 25 | 26 | val_sequences = list() 27 | for 
child in val_path.iterdir(): 28 | 29 | val_sequences.append(Sequence(child, 'val', delta_t_ms, num_bins)) 30 | 31 | 32 | self.val_dataset = torch.utils.data.ConcatDataset(val_sequences) 33 | 34 | test_path = dataset_path / 'test' 35 | 36 | assert test_path.is_dir(), str(test_path) 37 | 38 | test_sequences = list() 39 | for child in test_path.iterdir(): 40 | 41 | test_sequences.append(Sequence(child, 'test', delta_t_ms, num_bins)) 42 | 43 | 44 | self.test_dataset = torch.utils.data.ConcatDataset(test_sequences) 45 | 46 | def get_train_dataset(self): 47 | return self.train_dataset 48 | 49 | def get_val_dataset(self): 50 | # Implement this according to your needs. 51 | return self.val_dataset 52 | # raise NotImplementedError 53 | 54 | def get_test_dataset(self): 55 | # Implement this according to your needs. 56 | # raise NotImplementedError 57 | return self.test_dataset 58 | -------------------------------------------------------------------------------- /data/representations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class EventRepresentation: 5 | def convert(self, x: torch.Tensor, y: torch.Tensor, pol: torch.Tensor, time: torch.Tensor): 6 | raise NotImplementedError 7 | 8 | 9 | class VoxelGrid(EventRepresentation): 10 | def __init__(self, channels: int, height: int, width: int, normalize: bool): 11 | self.voxel_grid = torch.zeros((channels, height, width), dtype=torch.float, requires_grad=False) 12 | self.nb_channels = channels 13 | self.normalize = normalize 14 | 15 | def convert(self, x: torch.Tensor, y: torch.Tensor, pol: torch.Tensor, time: torch.Tensor): 16 | assert x.shape == y.shape == pol.shape == time.shape 17 | assert x.ndim == 1 18 | 19 | C, H, W = self.voxel_grid.shape 20 | with torch.no_grad(): 21 | 22 | self.voxel_grid = self.voxel_grid.to(pol.device) 23 | voxel_grid = self.voxel_grid.clone() 24 | 25 | t_norm = time 26 | t_norm = (C - 1) * (t_norm-t_norm[0]) / (t_norm[-1]-t_norm[0]) 27 | 28 | x0 = x.int() 29 | y0 = y.int() 30 | t0 = t_norm.int() 31 | 32 | value = 2*pol-1 33 | 34 | for xlim in [x0,x0+1]: 35 | for ylim in [y0,y0+1]: 36 | for tlim in [t0,t0+1]: 37 | 38 | mask = (xlim < W) & (xlim >= 0) & (ylim < H) & (ylim >= 0) & (tlim >= 0) & (tlim < self.nb_channels) 39 | interp_weights = value * (1 - (xlim-x).abs()) * (1 - (ylim-y).abs()) * (1 - (tlim - t_norm).abs()) 40 | 41 | index = H * W * tlim.long() + \ 42 | W * ylim.long() + \ 43 | xlim.long() 44 | 45 | voxel_grid.put_(index[mask], interp_weights[mask], accumulate=True) 46 | 47 | if self.normalize: 48 | mask = torch.nonzero(voxel_grid, as_tuple=True) 49 | if mask[0].size()[0] > 0: 50 | mean = voxel_grid[mask].mean() 51 | std = voxel_grid[mask].std() 52 | if std > 0: 53 | voxel_grid[mask] = (voxel_grid[mask] - mean) / std 54 | else: 55 | voxel_grid[mask] = voxel_grid[mask] - mean 56 | 57 | return voxel_grid 58 | -------------------------------------------------------------------------------- /data/sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data import Sampler 4 | import torch.distributed as dist 5 | 6 | 7 | class DistributedEvalSampler(Sampler): 8 | r""" 9 | DistributedEvalSampler is different from DistributedSampler. 10 | It does NOT add extra samples to make it evenly divisible. 11 | DistributedEvalSampler should NOT be used for training. The distributed processes could hang forever. 
12 | See this issue for details: https://github.com/pytorch/pytorch/issues/22584 13 | shuffle is disabled by default 14 | 15 | DistributedEvalSampler is for evaluation purpose where synchronization does not happen every epoch. 16 | Synchronization should be done outside the dataloader loop. 17 | 18 | Sampler that restricts data loading to a subset of the dataset. 19 | 20 | It is especially useful in conjunction with 21 | :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each 22 | process can pass a :class`~torch.utils.data.DistributedSampler` instance as a 23 | :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the 24 | original dataset that is exclusive to it. 25 | 26 | .. note:: 27 | Dataset is assumed to be of constant size. 28 | 29 | Arguments: 30 | dataset: Dataset used for sampling. 31 | num_replicas (int, optional): Number of processes participating in 32 | distributed training. By default, :attr:`rank` is retrieved from the 33 | current distributed group. 34 | rank (int, optional): Rank of the current process within :attr:`num_replicas`. 35 | By default, :attr:`rank` is retrieved from the current distributed 36 | group. 37 | shuffle (bool, optional): If ``True`` (default), sampler will shuffle the 38 | indices. 39 | seed (int, optional): random seed used to shuffle the sampler if 40 | :attr:`shuffle=True`. This number should be identical across all 41 | processes in the distributed group. Default: ``0``. 42 | 43 | .. warning:: 44 | In distributed mode, calling the :meth`set_epoch(epoch) ` method at 45 | the beginning of each epoch **before** creating the :class:`DataLoader` iterator 46 | is necessary to make shuffling work properly across multiple epochs. Otherwise, 47 | the same ordering will be always used. 48 | 49 | Example:: 50 | 51 | >>> sampler = DistributedSampler(dataset) if is_distributed else None 52 | >>> loader = DataLoader(dataset, shuffle=(sampler is None), 53 | ... sampler=sampler) 54 | >>> for epoch in range(start_epoch, n_epochs): 55 | ... if is_distributed: 56 | ... sampler.set_epoch(epoch) 57 | ... 
train(loader) 58 | """ 59 | 60 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False, seed=0): 61 | if num_replicas is None: 62 | if not dist.is_available(): 63 | raise RuntimeError("Requires distributed package to be available") 64 | num_replicas = dist.get_world_size() 65 | if rank is None: 66 | if not dist.is_available(): 67 | raise RuntimeError("Requires distributed package to be available") 68 | rank = dist.get_rank() 69 | self.dataset = dataset 70 | self.num_replicas = num_replicas 71 | self.rank = rank 72 | self.epoch = 0 73 | # self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 74 | # self.total_size = self.num_samples * self.num_replicas 75 | self.total_size = len(self.dataset) # true value without extra samples 76 | indices = list(range(self.total_size)) 77 | indices = indices[self.rank:self.total_size:self.num_replicas] 78 | self.num_samples = len(indices) # true value without extra samples 79 | 80 | self.shuffle = shuffle 81 | self.seed = seed 82 | 83 | def __iter__(self): 84 | if self.shuffle: 85 | # deterministically shuffle based on epoch and seed 86 | g = torch.Generator() 87 | g.manual_seed(self.seed + self.epoch) 88 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 89 | else: 90 | indices = list(range(len(self.dataset))) 91 | 92 | 93 | # # add extra samples to make it evenly divisible 94 | # indices += indices[:(self.total_size - len(indices))] 95 | # assert len(indices) == self.total_size 96 | 97 | # subsample 98 | indices = indices[self.rank:self.total_size:self.num_replicas] 99 | assert len(indices) == self.num_samples 100 | 101 | return iter(indices) 102 | 103 | def __len__(self): 104 | return self.num_samples 105 | 106 | def set_epoch(self, epoch): 107 | r""" 108 | Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas 109 | use a different random ordering for each epoch. Otherwise, the next iteration of this 110 | sampler will yield the same ordering. 111 | 112 | Arguments: 113 | epoch (int): _epoch number. 
114 | """ 115 | self.epoch = epoch 116 | -------------------------------------------------------------------------------- /data/visualization.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import matplotlib as mpl 3 | import matplotlib.cm as cm 4 | import numpy as np 5 | 6 | 7 | def disp_img_to_rgb_img(disp_array: np.ndarray): 8 | disp_pixels = np.argwhere(disp_array > 0) 9 | u_indices = disp_pixels[:, 1] 10 | v_indices = disp_pixels[:, 0] 11 | disp = disp_array[v_indices, u_indices] 12 | max_disp = 80 13 | 14 | norm = mpl.colors.Normalize(vmin=0, vmax=max_disp, clip=True) 15 | mapper = cm.ScalarMappable(norm=norm, cmap='inferno') 16 | 17 | disp_color = mapper.to_rgba(disp)[..., :3] 18 | output_image = np.zeros((disp_array.shape[0], disp_array.shape[1], 3)) 19 | output_image[v_indices, u_indices, :] = disp_color 20 | output_image = (255 * output_image).astype("uint8") 21 | output_image = cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR) 22 | return output_image 23 | 24 | def show_image(image): 25 | cv2.namedWindow('viz', cv2.WND_PROP_FULLSCREEN) 26 | cv2.imshow('viz', image) 27 | cv2.waitKey(0) 28 | 29 | def show_multi_image(overlay, image, event): 30 | cv2.namedWindow('viz', cv2.WND_PROP_FULLSCREEN) 31 | # import pdb; pdb.set_trace() 32 | image = np.repeat(image[..., np.newaxis], 3, axis=2) 33 | image = np.hstack((overlay, image)) 34 | image = np.hstack((image, event)) 35 | cv2.imshow('viz', image) 36 | cv2.waitKey(0) 37 | 38 | def get_disp_overlay(image_1c, disp_rgb_image, height, width): 39 | image = np.repeat(image_1c[..., np.newaxis], 3, axis=2) 40 | overlay = cv2.addWeighted(image, 0.1, disp_rgb_image, 0.9, 0) 41 | return overlay 42 | 43 | def show_disp_overlay(image_1c, disp_rgb_image, height, width): 44 | overlay = get_disp_overlay(image_1c, disp_rgb_image, height, width) 45 | show_image(overlay) 46 | 47 | def show_multi_overlay(image_1c, disp_rgb_image, height, width): 48 | overlay = get_disp_overlay(image_1c, disp_rgb_image, height, width) 49 | show_multi_image(overlay,image_1c, disp_rgb_image) 50 | -------------------------------------------------------------------------------- /disp_loss/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/disp_loss/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /disp_loss/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/disp_loss/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /disp_loss/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/disp_loss/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /disp_loss/__pycache__/disp_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/disp_loss/__pycache__/disp_loss.cpython-36.pyc 
-------------------------------------------------------------------------------- /disp_loss/adversarial.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from utils import interact
5 | 
6 | import torch.cuda.amp as amp
7 | 
8 | class Adversarial(nn.modules.loss._Loss):
9 |     # pure loss function without saving & loading option
10 |     # but trains the discriminator
11 |     def __init__(self, args, model, optimizer):
12 |         super(Adversarial, self).__init__()
13 |         self.args = args
14 |         self.model = model.model
15 |         self.optimizer = optimizer
16 |         self.scaler = amp.GradScaler(
17 |             init_scale=self.args.init_scale,
18 |             enabled=self.args.amp
19 |         )
20 | 
21 |         self.gan_k = 1
22 | 
23 |         self.BCELoss = nn.BCEWithLogitsLoss()
24 | 
25 |     def forward(self, fake, real, training=False):
26 |         if training:
27 |             # update discriminator
28 |             fake_detach = fake.detach()
29 |             for _ in range(self.gan_k):
30 |                 self.optimizer.D.zero_grad()
31 |                 # d: B x 1 tensor
32 |                 with amp.autocast(self.args.amp):
33 |                     d_fake = self.model.D(fake_detach)
34 |                     d_real = self.model.D(real)
35 | 
36 |                     label_fake = torch.zeros_like(d_fake)
37 |                     label_real = torch.ones_like(d_real)
38 | 
39 |                     loss_d = self.BCELoss(d_fake, label_fake) + self.BCELoss(d_real, label_real)
40 | 
41 |                 self.scaler.scale(loss_d).backward(retain_graph=False)
42 |                 self.scaler.step(self.optimizer.D)
43 |                 self.scaler.update()
44 |         else:
45 |             d_real = self.model.D(real)
46 |             label_real = torch.ones_like(d_real)
47 | 
48 |         # update generator (outside here)
49 |         d_fake_bp = self.model.D(fake)
50 |         loss_g = self.BCELoss(d_fake_bp, label_real)
51 | 
52 |         return loss_g
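How this loss is meant to be driven: a single `forward(..., training=True)` call performs `gan_k` discriminator steps internally (through `optimizer.D` and the AMP grad scaler) and returns the non-saturating generator loss for the caller to backpropagate. A hedged sketch under those assumptions; `generator`, `inputs`, `target`, and `optimizer.G` are stand-ins for whatever the surrounding trainer provides, not names taken from this repository:

```python
# Sketch only: the trainer wiring below is assumed, not copied from train.py.
from disp_loss.adversarial import Adversarial

adv = Adversarial(args, model, optimizer)   # requires model.model.D to be a discriminator

fake = generator(inputs)                    # hypothetical generator forward pass
loss_g = adv(fake, target, training=True)   # updates D internally, returns G's BCE loss

optimizer.G.zero_grad()
loss_g.backward()                           # only the generator receives this gradient
optimizer.G.step()
```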
-------------------------------------------------------------------------------- /disp_loss/disp_loss.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.modules.loss import SmoothL1Loss
5 | 
6 | #
7 | # Disparity Loss
8 | #
9 | def EPE(output, target, maxdisp):  # NOTE: despite the name, this computes a masked MSE, not the end-point error
10 |     mask = (target < maxdisp) & (0 < target)
11 |     mask = mask.detach()
12 |     # d_diff = output[mask] - target[mask]
13 |     # EPE_map = torch.abs(d_diff)
14 |     # EPE_mean = torch.sum(EPE_map)/N
15 | 
16 |     output = output.squeeze(1)
17 |     criterion = nn.MSELoss()
18 | 
19 |     # EPE_mean = F.smooth_l1_loss(output[mask], target[mask], size_average=True)
20 |     EPE_mean = criterion(output[mask], target[mask])
21 |     # print(EPE_mean)
22 |     return EPE_mean
23 | 
24 | class multiscaleLoss(nn.modules.loss._Loss):
25 |     def __init__(self, maxdisp):
26 |         super(multiscaleLoss, self).__init__()
27 | 
28 |         self.maxdisp = maxdisp
29 | 
30 |     def forward(self, outputs, target):
31 |         if type(outputs) not in [tuple, list]:
32 |             outputs = [outputs]
33 | 
34 |         weights = [0.3, 0.3, 0.2, 0.1, 0.1]
35 |         # weights = [0.5, 0.3, 0.2]
36 |         loss = 0
37 | 
38 |         for output, weight in zip(outputs, weights):
39 |             loss += weight * self.one_scale(output, target, self.maxdisp)
40 |         return loss
41 | 
42 | 
43 |     def one_scale(self, output, target, maxdisp):
44 | 
45 | 
46 |         B, _, h, w = output.size()
47 | 
48 |         target_scaled = nn.functional.adaptive_max_pool2d(target, (h, w))
49 | 
50 |         return EPE(output, target_scaled, maxdisp)
51 | 
52 | def CEloss(disp_gt, max_disp, gt_distribute, pred_distribute):
53 |     mask = (disp_gt > 0) & (disp_gt < max_disp)
54 | 
55 |     pred_distribute = torch.log(pred_distribute + 1e-8)
56 | 
57 |     ce_loss = torch.sum(-gt_distribute * pred_distribute, dim=1)
58 |     ce_loss = torch.mean(ce_loss[mask])
59 |     return ce_loss
60 | 
61 | def disp2distribute(disp_gt, max_disp, b=2):
62 |     disp_gt = disp_gt.unsqueeze(1)
63 |     disp_range = torch.arange(0, max_disp).view(1, -1, 1, 1).float().cuda()
64 |     gt_distribute = torch.exp(-torch.abs(disp_range - disp_gt) / b)
65 |     gt_distribute = gt_distribute / (torch.sum(gt_distribute, dim=1, keepdim=True) + 1e-8)
66 |     return gt_distribute
67 | 
68 | 
69 | class smoothL1Loss(nn.modules.loss._Loss):
70 |     def __init__(self):
71 |         super(smoothL1Loss, self).__init__()
72 | 
73 |     def forward(self, outputs, target):
74 |         # if type(outputs) not in [tuple, list]:
75 |         #     outputs = [outputs]
76 | 
77 | 
78 |         if len(outputs) == 5:
79 |             weights = [1 / 3, 2 / 3, 1.0, 1.0, 1.0]
80 |         elif len(outputs) == 4:
81 |             weights = [0.5, 0.5, 0.7, 1.0]
82 |         elif len(outputs) == 3:
83 |             weights = [0.5, 0.5, 1.0]
84 |         elif len(outputs) == 8:
85 |             weights = [0.5, 0.5, 1.0, 1.0]
86 | 
87 | 
88 |         else:
89 |             weights = 1
90 |             # weights = [0.7, 1.0]
91 |         loss = 0
92 |         mask = (target < 80) & (0 < target)
93 |         # mask = target != float('inf')
94 | 
95 | 
96 |         # criterion = nn.MSELoss()
97 |         if isinstance(weights, (list, tuple)):
98 |             if len(outputs) == 8:
99 |                 outputs2 = outputs[4:]
100 |                 outputs1 = outputs[:4]
101 |                 for output, weight in zip(outputs1, weights):
102 |                     output = torch.squeeze(output, 1)
103 |                     loss += weight * F.smooth_l1_loss(output[mask], target[mask], size_average=True)
104 | 
105 |                 target_distribute = disp2distribute(target, 48, b=2)
106 |                 for output, weight in zip(outputs2, weights):
107 |                     loss += weight * CEloss(target, 48, target_distribute, output)
108 |             else:
109 |                 for output, weight in zip(outputs, weights):
110 |                     # import pdb; pdb.set_trace()
111 |                     output = torch.squeeze(output, 1)
112 | 
113 | 
114 | 
115 |                     loss += weight * 
F.smooth_l1_loss(output[mask], target[mask], size_average=True) 116 | else: 117 | loss = F.smooth_l1_loss(outputs[0][mask], target[mask], size_average=True) 118 | # loss += weight * criterion(output[mask], target[mask], size_average=True) 119 | # print(loss) 120 | return loss 121 | 122 | class pasmLoss(nn.modules.loss._Loss): 123 | def __init__(self): 124 | super(pasmLoss, self).__init__() 125 | 126 | def forward(self, outputs, target): 127 | 128 | if len(outputs) == 1: 129 | loss = 0 130 | mask = target != float('inf') 131 | loss += F.smooth_l1_loss(outputs[mask], target[mask], size_average=True) 132 | else: 133 | loss = 0 134 | # mask = (target < 36) & (0 < target) 135 | mask = target != float('inf') 136 | 137 | 138 | output_disps, att, att_cycle, valid_mask = outputs 139 | loss += F.smooth_l1_loss(output_disps[mask], target[mask], size_average=True) 140 | 141 | loss_PAM_C = loss_pam_cycle(att_cycle, valid_mask).mean() 142 | loss_PAM_S = loss_pam_smoothness(att).mean() 143 | loss_PAM = loss_PAM_S + loss_PAM_C 144 | 145 | loss += loss_PAM 146 | return loss 147 | 148 | 149 | def loss_pam_cycle(att_cycle, valid_mask): 150 | weight = [0.2, 0.3, 0.5] 151 | loss = torch.zeros(1).to(att_cycle[0][0].device) 152 | 153 | for idx_scale in range(len(att_cycle)): 154 | b, c, h, w = valid_mask[idx_scale][0].shape 155 | I = torch.eye(w, w).repeat(b, h, 1, 1).to(att_cycle[0][0].device) 156 | 157 | att_left2right2left = att_cycle[idx_scale][0] 158 | att_right2left2right = att_cycle[idx_scale][1] 159 | valid_mask_left = valid_mask[idx_scale][0] 160 | valid_mask_right = valid_mask[idx_scale][1] 161 | 162 | loss_scale = L1Loss(att_left2right2left * valid_mask_left.permute(0, 2, 3, 1), I * valid_mask_left.permute(0, 2, 3, 1)) + \ 163 | L1Loss(att_right2left2right * valid_mask_right.permute(0, 2, 3, 1), I * valid_mask_right.permute(0, 2, 3, 1)) 164 | 165 | loss = loss + weight[idx_scale] * loss_scale 166 | 167 | return loss 168 | 169 | 170 | def loss_pam_smoothness(att): 171 | weight = [0.2, 0.3, 0.5] 172 | loss = torch.zeros(1).to(att[0][0].device) 173 | 174 | for idx_scale in range(len(att)): 175 | att_right2left = att[idx_scale][0] 176 | att_left2right = att[idx_scale][1] 177 | 178 | loss_scale = L1Loss(att_right2left[:, :-1, :, :], att_right2left[:, 1:, :, :]) + \ 179 | L1Loss(att_left2right[:, :-1, :, :], att_left2right[:, 1:, :, :]) + \ 180 | L1Loss(att_right2left[:, :, :-1, :-1], att_right2left[:, :, 1:, 1:]) + \ 181 | L1Loss(att_left2right[:, :, :-1, :-1], att_left2right[:, :, 1:, 1:]) 182 | 183 | loss = loss + weight[idx_scale] * loss_scale 184 | 185 | return loss 186 | 187 | 188 | def L1Loss(input, target): 189 | return (input - target).abs().mean() 190 | -------------------------------------------------------------------------------- /disp_loss/metric.py: -------------------------------------------------------------------------------- 1 | # from skimage.metrics import peak_signal_noise_ratio, structural_similarity 2 | 3 | from re import I 4 | import torch 5 | from torch import nn 6 | import copy 7 | import numpy as np 8 | 9 | 10 | 11 | class One_PA(nn.Module): 12 | def __init__(self, device_type='cpu', dtype=torch.float32): 13 | super(One_PA, self).__init__() 14 | 15 | self.device_type = device_type 16 | self.dtype = dtype # SSIM in half precision could be inaccurate 17 | 18 | def forward(self, input, target, maxdisp): 19 | """Implementation adopted from skimage.metrics.structural_similarity 20 | Default arguments set to multichannel=True, gaussian_weight=True, use_sample_covariance=False 21 | """ 
22 | 23 | input = input.squeeze(1).to(self.device_type) 24 | target = target.to(self.device_type) 25 | 26 | true_disp = copy.deepcopy(target) 27 | index = np.argwhere((true_disp > 0) & (true_disp < maxdisp)) 28 | 29 | 30 | true_disp[index[0][:], index[1][:], index[2][:]] = np.abs( 31 | target[index[0][:], index[1][:], index[2][:]] - input[index[0][:], index[1][:], index[2][:]]) 32 | 33 | correct_1 = (true_disp[index[0][:], index[1][:], index[2][:]] < 1) 34 | 35 | return (1 - (float(torch.sum(correct_1)) / float(len(index[0])))) * 100 36 | 37 | class Two_PA(nn.Module): 38 | def __init__(self, device_type='cpu', dtype=torch.float32): 39 | super(Two_PA, self).__init__() 40 | 41 | self.device_type = device_type 42 | self.dtype = dtype # SSIM in half precision could be inaccurate 43 | 44 | def forward(self, input, target, maxdisp): 45 | """Implementation adopted from skimage.metrics.structural_similarity 46 | Default arguments set to multichannel=True, gaussian_weight=True, use_sample_covariance=False 47 | """ 48 | 49 | input = input.squeeze(1).to(self.device_type) 50 | target = target.to(self.device_type) 51 | 52 | true_disp = copy.deepcopy(target) 53 | index = np.argwhere((true_disp > 0) & (true_disp < maxdisp)) 54 | 55 | 56 | true_disp[index[0][:], index[1][:], index[2][:]] = np.abs( 57 | target[index[0][:], index[1][:], index[2][:]] - input[index[0][:], index[1][:], index[2][:]]) 58 | 59 | correct_1 = (true_disp[index[0][:], index[1][:], index[2][:]] < 2) 60 | 61 | return (1 - (float(torch.sum(correct_1)) / float(len(index[0])))) * 100 62 | 63 | class Three_PA(nn.Module): 64 | def __init__(self, device_type='cpu', dtype=torch.float32): 65 | super(Three_PA, self).__init__() 66 | 67 | self.device_type = device_type 68 | self.dtype = dtype # SSIM in half precision could be inaccurate 69 | 70 | def forward(self, input, target, maxdisp): 71 | """Implementation adopted from skimage.metrics.structural_similarity 72 | Default arguments set to multichannel=True, gaussian_weight=True, use_sample_covariance=False 73 | """ 74 | 75 | input = input.squeeze(1).to(self.device_type) 76 | target = target.to(self.device_type) 77 | 78 | true_disp = copy.deepcopy(target) 79 | index = np.argwhere((true_disp > 0) & (true_disp < maxdisp)) 80 | 81 | 82 | true_disp[index[0][:], index[1][:], index[2][:]] = np.abs( 83 | target[index[0][:], index[1][:], index[2][:]] - input[index[0][:], index[1][:], index[2][:]]) 84 | 85 | correct_1 = (true_disp[index[0][:], index[1][:], index[2][:]] < 3) 86 | return (1 - (float(torch.sum(correct_1)) / float(len(index[0])))) * 100 87 | 88 | class MAE(nn.Module): 89 | def __init__(self, device_type='cpu', dtype=torch.float32): 90 | super(MAE, self).__init__() 91 | 92 | self.device_type = device_type 93 | self.dtype = dtype # SSIM in half precision could be inaccurate 94 | 95 | def forward(self, input, target, maxdisp): 96 | """Implementation adopted from skimage.metrics.structural_similarity 97 | Default arguments set to multichannel=True, gaussian_weight=True, use_sample_covariance=False 98 | """ 99 | 100 | input = input.squeeze(1).to(self.device_type) 101 | target = target.to(self.device_type) 102 | 103 | true_disp = copy.deepcopy(target) 104 | index = np.argwhere((true_disp > 0) & (true_disp < maxdisp)) 105 | 106 | true_disp[index[0][:], index[1][:], index[2][:]] = np.abs( 107 | target[index[0][:], index[1][:], index[2][:]] - input[index[0][:], index[1][:], index[2][:]]) 108 | 109 | mae = torch.mean(true_disp[index[0][:], index[1][:], index[2][:]]) 110 | 111 | return mae 112 | 113 | 
class RMSE(nn.Module): 114 | def __init__(self, device_type='cpu', dtype=torch.float32): 115 | super(RMSE, self).__init__() 116 | 117 | self.device_type = device_type 118 | self.dtype = dtype # SSIM in half precision could be inaccurate 119 | 120 | def forward(self, input, target, maxdisp): 121 | """Implementation adopted from skimage.metrics.structural_similarity 122 | Default arguments set to multichannel=True, gaussian_weight=True, use_sample_covariance=False 123 | """ 124 | 125 | input = input.squeeze(1).to(self.device_type) 126 | target = target.to(self.device_type) 127 | 128 | true_disp = copy.deepcopy(target) 129 | index = np.argwhere((true_disp > 0) & (true_disp < maxdisp)) 130 | 131 | true_disp[index[0][:], index[1][:], index[2][:]] = np.abs( 132 | target[index[0][:], index[1][:], index[2][:]] - input[index[0][:], index[1][:], index[2][:]]) 133 | 134 | # rmse = torch.sqrt(torch.sum(torch.mul(true_disp[index[0][:], index[1][:], index[2][:]], \ 135 | # true_disp[index[0][:], index[1][:], index[2][:]]))/true_disp[index[0][:], index[1][:], index[2][:]].size()[0]) 136 | rmse = torch.sqrt(torch.sum(torch.mul(true_disp[index[0][:], index[1][:], index[2][:]], \ 137 | true_disp[index[0][:], index[1][:], index[2][:]])) / float(len(index[0]))) 138 | return rmse 139 | 140 | def disparity_to_depth(disparity_image): 141 | 142 | # unknown_disparity = disparity_image == float('inf') 143 | unknown_disparity = disparity_image == 0.0 144 | depth_image = \ 145 | 0.6 / (disparity_image + 1e-7) 146 | depth_image[unknown_disparity] = float('inf') 147 | return depth_image 148 | 149 | def compute_absolute_error(estimated_disparity, 150 | ground_truth_disparity, 151 | use_mean=True): 152 | 153 | absolute_difference = (estimated_disparity - ground_truth_disparity).abs() 154 | locations_without_ground_truth = torch.isinf(ground_truth_disparity) 155 | pixelwise_absolute_error = absolute_difference.clone() 156 | pixelwise_absolute_error[locations_without_ground_truth] = 0 157 | absolute_differece_with_ground_truth = absolute_difference[ 158 | ~locations_without_ground_truth] 159 | if absolute_differece_with_ground_truth.numel() == 0: 160 | average_absolute_error = 0.0 161 | else: 162 | if use_mean: 163 | average_absolute_error = absolute_differece_with_ground_truth.mean( 164 | ).item() 165 | else: 166 | average_absolute_error = absolute_differece_with_ground_truth.median( 167 | ).item() 168 | 169 | return pixelwise_absolute_error, average_absolute_error 170 | 171 | class Mean_Depth(nn.Module): 172 | def __init__(self, device_type = 'cpu', dtype = torch.float32): 173 | super(Mean_Depth, self).__init__() 174 | self.device_type = device_type 175 | self.dtype = dtype 176 | 177 | def forward(self, input, target, maxdisp): 178 | input = input.squeeze(1).to(self.device_type) 179 | target = target.to(self.device_type) 180 | 181 | input = disparity_to_depth(input) 182 | target = disparity_to_depth(target) 183 | 184 | error = compute_absolute_error(input, target)[1] * 100.0 185 | 186 | return error 187 | 188 | class Mean_Disp(nn.Module): 189 | def __init__(self, device_type = 'cpu', dtype = torch.float32): 190 | super(Mean_Disp, self).__init__() 191 | self.device_type = device_type 192 | self.dtype = dtype 193 | 194 | def forward(self, input, target, maxdisp): 195 | input = input.squeeze(1).to(self.device_type) 196 | target = target.to(self.device_type) 197 | 198 | unknown_disparity = target == 0.0 199 | target[unknown_disparity] = float('inf') 200 | error = compute_absolute_error(input, target)[1] 201 | 202 | return 
error
203 | 
-------------------------------------------------------------------------------- /launch.py: --------------------------------------------------------------------------------
1 | """ distributed launcher adopted from torch.distributed.launch
2 | usage example: https://github.com/facebookresearch/maskrcnn-benchmark
3 | This enables using multiprocessing for each spawned process (as they are treated as main processes)
4 | """
5 | import sys
6 | import subprocess
7 | from argparse import ArgumentParser, REMAINDER
8 | 
9 | from utils import str2bool, int2str
10 | 
11 | def parse_args():
12 |     parser = ArgumentParser(description="PyTorch distributed training launch "
13 |                                         "helper utility that will spawn up "
14 |                                         "multiple distributed processes")
15 | 
16 | 
17 |     parser.add_argument('--n_GPUs', type=int, default=1, help='the number of GPUs for training')
18 | 
19 |     # positional
20 |     parser.add_argument("training_script", type=str,
21 |                         help="The full path to the single GPU training "
22 |                              "program/script to be launched in parallel, "
23 |                              "followed by all the arguments for the "
24 |                              "training script")
25 | 
26 |     # rest from the training program
27 |     parser.add_argument('training_script_args', nargs=REMAINDER)
28 |     return parser.parse_args()
29 | 
30 | def main():
31 |     args = parse_args()
32 | 
33 |     processes = []
34 |     for rank in range(0, args.n_GPUs):
35 |         cmd = [sys.executable]
36 | 
37 |         cmd.append(args.training_script)
38 |         cmd.extend(args.training_script_args)
39 | 
40 |         cmd += ['--distributed', 'True']
41 |         cmd += ['--launched', 'True']
42 |         cmd += ['--n_GPUs', str(args.n_GPUs)]
43 |         cmd += ['--rank', str(rank)]
44 | 
45 |         process = subprocess.Popen(cmd)
46 |         processes.append(process)
47 | 
48 |     for process in processes:
49 |         process.wait()
50 |         if process.returncode != 0:
51 |             raise subprocess.CalledProcessError(returncode=process.returncode,
52 |                                                 cmd=cmd)
53 | 
54 | if __name__ == "__main__":
55 |     main()
56 | 
-------------------------------------------------------------------------------- /loss/adversarial.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from utils import interact 5
| 6 | import torch.cuda.amp as amp 7 | 8 | class Adversarial(nn.modules.loss._Loss): 9 | # pure loss function without saving & loading option 10 | # but trains deiscriminator 11 | def __init__(self, args, model, optimizer): 12 | super(Adversarial, self).__init__() 13 | self.args = args 14 | self.model = model.model 15 | self.optimizer = optimizer 16 | self.scaler = amp.GradScaler( 17 | init_scale=self.args.init_scale, 18 | enabled=self.args.amp 19 | ) 20 | 21 | self.gan_k = 1 22 | 23 | self.BCELoss = nn.BCEWithLogitsLoss() 24 | 25 | def forward(self, fake, real, training=False): 26 | if training: 27 | # update discriminator 28 | fake_detach = fake.detach() 29 | for _ in range(self.gan_k): 30 | self.optimizer.D.zero_grad() 31 | # d: B x 1 tensor 32 | with amp.autocast(self.args.amp): 33 | d_fake = self.model.D(fake_detach) 34 | d_real = self.model.D(real) 35 | 36 | label_fake = torch.zeros_like(d_fake) 37 | label_real = torch.ones_like(d_real) 38 | 39 | loss_d = self.BCELoss(d_fake, label_fake) + self.BCELoss(d_real, label_real) 40 | 41 | self.scaler.scale(loss_d).backward(retain_graph=False) 42 | self.scaler.step(self.optimizer.D) 43 | self.scaler.update() 44 | else: 45 | d_real = self.model.D(real) 46 | label_real = torch.ones_like(d_real) 47 | 48 | # update generator (outside here) 49 | d_fake_bp = self.model.D(fake) 50 | loss_g = self.BCELoss(d_fake_bp, label_real) 51 | 52 | return loss_g -------------------------------------------------------------------------------- /loss/metric.py: -------------------------------------------------------------------------------- 1 | # from skimage.metrics import peak_signal_noise_ratio, structural_similarity 2 | 3 | import torch 4 | from torch import nn 5 | 6 | def _expand(img): 7 | if img.ndim < 4: 8 | img = img.expand([1] * (4-img.ndim) + list(img.shape)) 9 | 10 | return img 11 | 12 | class PSNR(nn.Module): 13 | def __init__(self): 14 | super(PSNR, self).__init__() 15 | 16 | def forward(self, im1, im2, data_range=None): 17 | # tensor input, constant output 18 | 19 | if data_range is None: 20 | data_range = 255 if im1.max() > 1 else 1 21 | 22 | se = (im1-im2)**2 23 | se = _expand(se) 24 | 25 | mse = se.mean(dim=list(range(1, se.ndim))) 26 | psnr = 10 * (data_range**2/mse).log10().mean() 27 | 28 | return psnr 29 | 30 | class SSIM(nn.Module): 31 | def __init__(self, device_type='cpu', dtype=torch.float32): 32 | super(SSIM, self).__init__() 33 | 34 | self.device_type = device_type 35 | self.dtype = dtype # SSIM in half precision could be inaccurate 36 | 37 | def _get_ssim_weight(): 38 | truncate = 3.5 39 | sigma = 1.5 40 | r = int(truncate * sigma + 0.5) # radius as in ndimage 41 | win_size = 2 * r + 1 42 | nch = 3 43 | 44 | weight = torch.Tensor([-(x - win_size//2)**2/float(2*sigma**2) for x in range(win_size)]).exp().unsqueeze(1) 45 | weight = weight.mm(weight.t()) 46 | weight /= weight.sum() 47 | weight = weight.repeat(nch, 1, 1, 1) 48 | 49 | return weight 50 | 51 | self.weight = _get_ssim_weight().to(self.device_type, dtype=self.dtype, non_blocking=True) 52 | 53 | def forward(self, im1, im2, data_range=None): 54 | """Implementation adopted from skimage.metrics.structural_similarity 55 | Default arguments set to multichannel=True, gaussian_weight=True, use_sample_covariance=False 56 | """ 57 | 58 | im1 = im1.to(self.device_type, dtype=self.dtype, non_blocking=True) 59 | im2 = im2.to(self.device_type, dtype=self.dtype, non_blocking=True) 60 | 61 | K1 = 0.01 62 | K2 = 0.03 63 | sigma = 1.5 64 | 65 | truncate = 3.5 66 | r = int(truncate * sigma + 0.5) # 
radius as in ndimage
67 |         win_size = 2 * r + 1
68 | 
69 |         im1 = _expand(im1)
70 |         im2 = _expand(im2)
71 | 
72 |         nch = im1.shape[1]
73 | 
74 |         if im1.shape[2] < win_size or im1.shape[3] < win_size:
75 |             raise ValueError(
76 |                 "win_size exceeds image extent. If the input is a multichannel "
77 |                 "(color) image, set multichannel=True.")
78 | 
79 |         if data_range is None:
80 |             data_range = 255 if im1.max() > 1 else 1
81 | 
82 |         def filter_func(img):  # no padding
83 |             return nn.functional.conv2d(img, self.weight, groups=nch).to(self.dtype)
84 |             # return torch.conv2d(img, self.weight, groups=nch).to(self.dtype)
85 | 
86 |         # compute (weighted) means
87 |         ux = filter_func(im1)
88 |         uy = filter_func(im2)
89 | 
90 |         # compute (weighted) variances and covariances
91 |         uxx = filter_func(im1 * im1)
92 |         uyy = filter_func(im2 * im2)
93 |         uxy = filter_func(im1 * im2)
94 |         vx = (uxx - ux * ux)
95 |         vy = (uyy - uy * uy)
96 |         vxy = (uxy - ux * uy)
97 | 
98 |         R = data_range
99 |         C1 = (K1 * R) ** 2
100 |         C2 = (K2 * R) ** 2
101 | 
102 |         A1, A2, B1, B2 = ((2 * ux * uy + C1,
103 |                            2 * vxy + C2,
104 |                            ux ** 2 + uy ** 2 + C1,
105 |                            vx + vy + C2))
106 |         D = B1 * B2
107 |         S = (A1 * A2) / D
108 | 
109 |         # compute (weighted) mean of ssim
110 |         mssim = S.mean()
111 | 
112 |         return mssim
113 | 
-------------------------------------------------------------------------------- /main.py: --------------------------------------------------------------------------------
1 | """main file that does everything"""
2 | from utils import interact
3 | 
4 | from option import args, setup, cleanup
5 | from data import Data
6 | from model import Model
7 | from loss import Loss
8 | from optim import Optimizer
9 | from train import Trainer
10 | import torch
11 | from disp_loss import Loss as Disp_Loss
12 | 
13 | 
14 | ## train
15 | # python main.py --n_GPUs 4 --batch_size 8 --dataset indoor_flying_1 --split 1 --data_root ../../DSEC_data --save_dir max_disp_120_homo_batch_8 --model pertu_select_recon --gaussian_pyramid False --loss 1*L1+1*LPIPS --lr 1e-4 --test_every 200 --save_every 1 --disp_model gwc_pertu_noise_with_affinity --end_epoch 160 --validate_every 10
16 | 
17 | ## test
18 | # CUDA_VISIBLE_DEVICES=3 python main.py --n_GPUs 1 --batch_size 1 --dataset indoor_flying_1 --split 1 --data_root ../../DSEC_data --save_dir max_disp_120_homo_batch_8 --model pertu_select_recon --gaussian_pyramid False --loss 1*L1+1*LPIPS --lr 1e-4 --test_every 100 --save_every 1 --disp_model gwc_pertu_noise_with_affinity --end_epoch 99 --validate_every 1 --load_epoch 76
19 | 
20 | 
21 | def main_worker(rank, args):
22 |     args.rank = rank
23 |     args = setup(args)
24 | 
25 |     loaders = Data(args).get_loader()
26 | 
27 | 
28 |     model = Model(args)
29 |     model.parallelize()
30 |     if args.load_dir is not None:
31 |         checkpoint = torch.load(args.load_dir)
32 |         # import pdb
33 |         # pdb.set_trace()
34 |         model.load_state_dict(checkpoint)
35 |         print('load the checkpoint {}'.format(args.load_dir))
36 | 
37 |     optimizer = Optimizer(args, model)
38 | 
39 |     criterion = Loss(args, model=model, optimizer=optimizer)
40 | 
41 |     disp_criterion = Disp_Loss(args, model=model, optimizer=optimizer)
42 | 
43 |     trainer = Trainer(args, model, criterion, disp_criterion, optimizer, loaders)
44 | 
45 |     if args.stay:
46 |         interact(local=locals())
47 |         exit()
48 | 
49 |     if args.demo:
50 |         trainer.evaluate(epoch=args.start_epoch, mode='demo')
51 |         exit()
52 | 
53 |     # for epoch in range(1, args.start_epoch):
54 |     #     if args.do_validate:
55 |     #         if epoch % args.validate_every == 0:
56 |     #             trainer.fill_evaluation(epoch, 
'val') 57 | # if args.do_test: 58 | # if epoch % args.test_every == 0: 59 | # trainer.fill_evaluation(epoch, 'test') 60 | 61 | for epoch in range(args.start_epoch, args.end_epoch+1): 62 | 63 | # epoch = 24 64 | 65 | # if args.do_test: 66 | # # if epoch % args.test_every == 0: 67 | # # if trainer.epoch != epoch: 68 | # # trainer.load(epoch) 69 | # trainer.test(epoch) 70 | # import pdb; pdb.set_trace() 71 | 72 | 73 | # if args.do_validate: 74 | # if epoch % args.validate_every == 0: 75 | # if trainer.epoch != epoch: 76 | # trainer.load(epoch) 77 | # trainer.validate(epoch) 78 | # import pdb; pdb.set_trace() 79 | 80 | if args.do_train: 81 | trainer.train(epoch) 82 | # import pdb; pdb.set_trace() 83 | if args.do_validate: 84 | if epoch % args.validate_every == 0: 85 | if trainer.epoch != epoch: 86 | trainer.load(epoch) 87 | trainer.validate(epoch) 88 | 89 | if args.do_test: 90 | if epoch % args.test_every == 0: 91 | if trainer.epoch != epoch: 92 | trainer.load(epoch) 93 | trainer.test(epoch) 94 | 95 | if args.rank == 0 or not args.launched: 96 | print('') 97 | 98 | trainer.imsaver.join_background() 99 | 100 | cleanup(args) 101 | 102 | def main(): 103 | main_worker(args.rank, args) 104 | 105 | if __name__ == "__main__": 106 | main() -------------------------------------------------------------------------------- /model/ResNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from . import common 4 | 5 | def build_model(args): 6 | return ResNet(args) 7 | 8 | class ResNet(nn.Module): 9 | def __init__(self, args, in_channels=3, out_channels=3, n_feats=None, kernel_size=None, n_resblocks=None, mean_shift=True): 10 | super(ResNet, self).__init__() 11 | 12 | self.in_channels = in_channels 13 | self.out_channels = out_channels 14 | 15 | self.n_feats = args.n_feats if n_feats is None else n_feats 16 | self.kernel_size = args.kernel_size if kernel_size is None else kernel_size 17 | self.n_resblocks = args.n_resblocks if n_resblocks is None else n_resblocks 18 | 19 | self.mean_shift = mean_shift 20 | self.rgb_range = args.rgb_range 21 | self.mean = self.rgb_range / 2 22 | 23 | modules = [] 24 | modules.append(common.default_conv(self.in_channels, self.n_feats, self.kernel_size)) 25 | for _ in range(self.n_resblocks): 26 | modules.append(common.ResBlock(self.n_feats, self.kernel_size)) 27 | modules.append(common.default_conv(self.n_feats, self.out_channels, self.kernel_size)) 28 | 29 | self.body = nn.Sequential(*modules) 30 | 31 | def forward(self, input): 32 | if self.mean_shift: 33 | input = input - self.mean 34 | 35 | output = self.body(input) 36 | 37 | if self.mean_shift: 38 | output = output + self.mean 39 | 40 | return output 41 | 42 | def build_model(args): 43 | return ResNet(args) 44 | 45 | class ResNet_event(nn.Module): 46 | def __init__(self, args, in_channels=18, out_channels=3, n_feats=None, kernel_size=None, n_resblocks=None, mean_shift=True): 47 | super(ResNet_event, self).__init__() 48 | 49 | self.in_channels = in_channels 50 | self.out_channels = out_channels 51 | 52 | self.n_feats = args.n_feats if n_feats is None else n_feats 53 | # import pdb 54 | # pdb.set_trace() 55 | self.kernel_size = args.kernel_size if kernel_size is None else kernel_size 56 | self.n_resblocks = args.n_resblocks if n_resblocks is None else n_resblocks 57 | 58 | self.mean_shift = mean_shift 59 | self.rgb_range = args.rgb_range 60 | self.mean = self.rgb_range / 2 61 | 62 | modules = [] 63 | modules.append(common.default_conv(self.in_channels, 
self.n_feats, self.kernel_size)) 64 | for _ in range(self.n_resblocks): 65 | modules.append(common.ResBlock(self.n_feats, self.kernel_size)) 66 | modules.append(common.default_conv(self.n_feats, self.out_channels, self.kernel_size)) 67 | 68 | self.body = nn.Sequential(*modules) 69 | 70 | def forward(self, input): 71 | 72 | output = self.body(input) 73 | 74 | return output 75 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from importlib import import_module 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn.parallel import DataParallel, DistributedDataParallel 8 | 9 | import torch.distributed as dist 10 | from torch.nn.utils import parameters_to_vector, vector_to_parameters 11 | 12 | 13 | from utils import interact 14 | 15 | 16 | class Model(nn.Module): 17 | def __init__(self, args): 18 | super(Model, self).__init__() 19 | 20 | self.args = args 21 | self.device = args.device 22 | self.n_GPUs = args.n_GPUs 23 | self.save_dir = os.path.join(args.save_dir, 'models') 24 | os.makedirs(self.save_dir, exist_ok=True) 25 | 26 | self.model = nn.ModuleDict() 27 | disp_module = import_module('model.' + args.disp_model) 28 | self.model.disp = disp_module.Model() 29 | 30 | self.model.D = None 31 | self.to(args.device, dtype=args.dtype, non_blocking=True) 32 | self.load(args.load_epoch, path=args.pretrained) 33 | 34 | 35 | def parallelize(self): 36 | if self.args.device_type == 'cuda': 37 | if self.args.distributed: 38 | Parallel = DistributedDataParallel 39 | parallel_args = { 40 | "device_ids": [self.args.rank], 41 | "output_device": self.args.rank, 42 | } 43 | else: 44 | Parallel = DataParallel 45 | parallel_args = { 46 | 'device_ids': list(range(self.n_GPUs)), 47 | 'output_device': self.args.rank # always 0 48 | } 49 | 50 | for model_key in self.model: 51 | if self.model[model_key] is not None: 52 | self.model[model_key] = Parallel(self.model[model_key], **parallel_args) 53 | 54 | def forward(self, input): 55 | return self.model.G(input) 56 | 57 | def disp_forward(self, input): 58 | return self.model.disp(input) 59 | 60 | def _save_path(self, epoch): 61 | model_path = os.path.join(self.save_dir, 'model-{:d}.pt'.format(epoch)) 62 | return model_path 63 | 64 | def state_dict(self): 65 | state_dict = {} 66 | for model_key in self.model: 67 | if self.model[model_key] is not None: 68 | parallelized = isinstance(self.model[model_key], (DataParallel, DistributedDataParallel)) 69 | if parallelized: 70 | state_dict[model_key] = self.model[model_key].module.state_dict() 71 | else: 72 | state_dict[model_key] = self.model[model_key].state_dict() 73 | 74 | return state_dict 75 | 76 | def load_state_dict(self, state_dict, strict=True): 77 | for model_key in self.model: 78 | parallelized = isinstance(self.model[model_key], (DataParallel, DistributedDataParallel)) 79 | if model_key in state_dict: 80 | if parallelized: 81 | self.model[model_key].module.load_state_dict(state_dict[model_key], strict) 82 | else: 83 | self.model[model_key].load_state_dict(state_dict[model_key], strict) 84 | 85 | def save(self, epoch): 86 | torch.save(self.state_dict(), self._save_path(epoch)) 87 | 88 | def load(self, epoch=None, path=None): 89 | 90 | if path: 91 | model_name = path 92 | elif isinstance(epoch, int): 93 | if epoch < 0: 94 | epoch = self.get_last_epoch() 95 | if epoch == 0: # epoch 0 96 | # make sure model parameters are synchronized at initial 
97 |                 # for multi-node training (not in current implementation)
98 |                 # self.synchronize()
99 | 
100 |                 return # leave model as initialized
101 | 
102 |             model_name = self._save_path(epoch)
103 |         else:
104 |             raise Exception('no epoch number or model path specified!')
105 | 
106 |         print('Loading model from {}'.format(model_name))
107 | 
108 |         state_dict = torch.load(model_name, map_location=self.args.device)
109 |         self.load_state_dict(state_dict)
110 | 
111 |         return
112 | 
113 |     def synchronize(self):
114 |         if self.args.distributed:
115 |             # synchronize model parameters across nodes
116 |             vector = parameters_to_vector(self.parameters())
117 | 
118 |             dist.broadcast(vector, 0) # broadcast parameters to other processes
119 |             if self.args.rank != 0:
120 |                 vector_to_parameters(vector, self.parameters())
121 | 
122 |             del vector
123 | 
124 |         return
125 | 
126 |     def get_last_epoch(self):
127 |         model_list = sorted(os.listdir(self.save_dir))  # NOTE: lexicographic sort, so model-9.pt sorts after model-10.pt
128 |         if len(model_list) == 0:
129 |             epoch = 0
130 |         else:
131 |             epoch = int(re.findall('\\d+', model_list[-1])[0]) # model example name model-100.pt
132 | 
133 |         return epoch
134 | 
135 |     def print(self):
136 |         print(self.model)
137 | 
138 |         return
139 | 
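A short sketch of the checkpoint round trip implied by `Model` above: weights live under `<save_dir>/models/model-<epoch>.pt`, a negative epoch resolves to the newest file via `get_last_epoch()`, and an explicit `path` overrides the epoch argument. Here `args` stands for whatever `option.py` produces; the epoch values and the pretrained file name are assumptions for illustration:

```python
# Hedged usage sketch for model/__init__.py; `args` is assumed to come from option.py.
from model import Model

model = Model(args)
model.save(10)                    # writes <save_dir>/models/model-10.pt via _save_path(10)
model.load(epoch=10)              # restores every registered sub-module from that file
model.load(epoch=-1)              # negative epoch -> get_last_epoch()
model.load(path='pretrained.pt')  # an explicit path takes priority over the epoch argument
```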
-------------------------------------------------------------------------------- /model/__pycache__/affinity_module.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/affinity_module.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/affinity_module.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/affinity_module.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/common.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/common.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/common.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/common.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/discriminator.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/discriminator.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/discriminator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/discriminator.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/gwc_event.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/gwc_event.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/gwc_image.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/gwc_image.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/gwc_pertu.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/gwc_pertu.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/gwc_pertu_noise.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/gwc_pertu_noise.cpython-37.pyc -------------------------------------------------------------------------------- 
/model/__pycache__/gwc_pertu_noise_KD.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/gwc_pertu_noise_KD.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/gwc_pertu_noise_affinity.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/gwc_pertu_noise_affinity.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/gwc_pertu_noise_deform.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/gwc_pertu_noise_deform.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/gwc_pertu_noise_with_affinity.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/gwc_pertu_noise_with_affinity.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/gwc_pertu_noise_with_affinity.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/gwc_pertu_noise_with_affinity.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/gwcnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/gwcnet.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/image_recon.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/image_recon.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/intensity_MSResNet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/intensity_MSResNet.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/intensity_MSResNet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/intensity_MSResNet.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/pasm_pertu_noise.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/pasm_pertu_noise.cpython-36.pyc 
-------------------------------------------------------------------------------- /model/__pycache__/pasm_pertu_noise.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/pasm_pertu_noise.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/pertu_select_recon.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/pertu_select_recon.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/pertu_select_recon.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/pertu_select_recon.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/perturbations.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/perturbations.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/perturbations.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/perturbations.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/submodule.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/submodule.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/submodule.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/submodule.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/t2t_vit.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/t2t_vit.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/t2t_vit.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/t2t_vit.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/token_performer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/token_performer.cpython-36.pyc 
-------------------------------------------------------------------------------- /model/__pycache__/token_performer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/token_performer.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/token_transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/token_transformer.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/token_transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/token_transformer.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/transformer_block.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/transformer_block.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/transformer_block.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chohoonhee/SCSNet/5c6ba773af1513d246bcf0ec47f97f14fe48eb89/model/__pycache__/transformer_block.cpython-37.pyc -------------------------------------------------------------------------------- /model/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | def default_conv(in_channels, out_channels, kernel_size, bias=True, groups=1): 7 | return nn.Conv2d( 8 | in_channels, out_channels, kernel_size, 9 | padding=(kernel_size // 2), bias=bias, groups=groups) 10 | 11 | def default_norm(n_feats): 12 | return nn.BatchNorm2d(n_feats) 13 | 14 | def default_act(): 15 | return nn.ReLU(True) 16 | 17 | import torch.nn as nn 18 | import torch 19 | 20 | 21 | class ConvLSTMCell(nn.Module): 22 | 23 | def __init__(self, input_dim, hidden_dim, kernel_size, bias): 24 | """ 25 | Initialize ConvLSTM cell. 26 | Parameters 27 | ---------- 28 | input_dim: int 29 | Number of channels of input tensor. 30 | hidden_dim: int 31 | Number of channels of hidden state. 32 | kernel_size: (int, int) 33 | Size of the convolutional kernel. 34 | bias: bool 35 | Whether or not to add the bias. 
36 | """ 37 | 38 | super(ConvLSTMCell, self).__init__() 39 | 40 | self.input_dim = input_dim 41 | self.hidden_dim = hidden_dim 42 | 43 | self.kernel_size = kernel_size 44 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2 45 | self.bias = bias 46 | 47 | self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim, 48 | out_channels=4 * self.hidden_dim, 49 | kernel_size=self.kernel_size, 50 | padding=self.padding, 51 | bias=self.bias) 52 | 53 | def forward(self, input_tensor, cur_state): 54 | h_cur, c_cur = cur_state 55 | 56 | combined = torch.cat([input_tensor, h_cur], dim=1) # concatenate along channel axis 57 | 58 | combined_conv = self.conv(combined) 59 | cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1) 60 | i = torch.sigmoid(cc_i) 61 | f = torch.sigmoid(cc_f) 62 | o = torch.sigmoid(cc_o) 63 | g = torch.tanh(cc_g) 64 | 65 | c_next = f * c_cur + i * g 66 | h_next = o * torch.tanh(c_next) 67 | 68 | return h_next, c_next 69 | 70 | def init_hidden(self, batch_size, image_size): 71 | height, width = image_size 72 | return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device), 73 | torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device)) 74 | 75 | 76 | class ConvLSTM(nn.Module): 77 | 78 | """ 79 | Parameters: 80 | input_dim: Number of channels in input 81 | hidden_dim: Number of hidden channels 82 | kernel_size: Size of kernel in convolutions 83 | num_layers: Number of LSTM layers stacked on each other 84 | batch_first: Whether or not dimension 0 is the batch or not 85 | bias: Bias or no bias in Convolution 86 | return_all_layers: Return the list of computations for all layers 87 | Note: Will do same padding. 88 | Input: 89 | A tensor of size B, T, C, H, W or T, B, C, H, W 90 | Output: 91 | A tuple of two lists of length num_layers (or length 1 if return_all_layers is False). 
92 |             0 - layer_output_list is the list of lists of length T of each output
93 |             1 - last_state_list is the list of last states
94 |                 each element of the list is a tuple (h, c) for hidden state and memory
95 |     Example:
96 |         >>> x = torch.rand((32, 10, 64, 128, 128))
97 |         >>> convlstm = ConvLSTM(64, 16, (3, 3), 1, True, True, False)
98 |         >>> _, last_states = convlstm(x)
99 |         >>> h = last_states[0][0]  # 0 for layer index, 0 for h index
100 |     """
101 | 
102 |     def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
103 |                  batch_first=False, bias=True, return_all_layers=False):
104 |         super(ConvLSTM, self).__init__()
105 | 
106 |         self._check_kernel_size_consistency(kernel_size)
107 | 
108 |         # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
109 |         kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
110 |         hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
111 |         if not len(kernel_size) == len(hidden_dim) == num_layers:
112 |             raise ValueError('Inconsistent list length.')
113 | 
114 |         self.input_dim = input_dim
115 |         self.hidden_dim = hidden_dim
116 |         self.kernel_size = kernel_size
117 |         self.num_layers = num_layers
118 |         self.batch_first = batch_first
119 |         self.bias = bias
120 |         self.return_all_layers = return_all_layers
121 | 
122 |         cell_list = []
123 |         for i in range(0, self.num_layers):
124 |             cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
125 | 
126 |             cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
127 |                                           hidden_dim=self.hidden_dim[i],
128 |                                           kernel_size=self.kernel_size[i],
129 |                                           bias=self.bias))
130 | 
131 |         self.cell_list = nn.ModuleList(cell_list)
132 | 
133 |     def forward(self, input_tensor, hidden_state=None):
134 |         """
135 |         Parameters
136 |         ----------
137 |         input_tensor:
138 |             5-D tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
139 |         hidden_state:
140 |             None for now; stateful operation is not implemented yet
141 |         Returns
142 |         -------
143 |         layer_output_list, last_state_list
144 |         """
145 |         if not self.batch_first:
146 |             # (t, b, c, h, w) -> (b, t, c, h, w)
147 |             input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
148 | 
149 |         b, _, _, h, w = input_tensor.size()
150 | 
151 |         # Implement stateful ConvLSTM
152 |         if hidden_state is not None:
153 |             raise NotImplementedError()
154 |         else:
155 |             # Since the init is done in forward.
Can send image size here 156 | hidden_state = self._init_hidden(batch_size=b, 157 | image_size=(h, w)) 158 | 159 | layer_output_list = [] 160 | last_state_list = [] 161 | 162 | seq_len = input_tensor.size(1) 163 | cur_layer_input = input_tensor 164 | 165 | for layer_idx in range(self.num_layers): 166 | 167 | h, c = hidden_state[layer_idx] 168 | output_inner = [] 169 | for t in range(seq_len): 170 | h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :], 171 | cur_state=[h, c]) 172 | output_inner.append(h) 173 | 174 | layer_output = torch.stack(output_inner, dim=1) 175 | cur_layer_input = layer_output 176 | 177 | layer_output_list.append(layer_output) 178 | last_state_list.append([h, c]) 179 | 180 | if not self.return_all_layers: 181 | layer_output_list = layer_output_list[-1:] 182 | last_state_list = last_state_list[-1:] 183 | 184 | return layer_output_list, last_state_list 185 | 186 | def _init_hidden(self, batch_size, image_size): 187 | init_states = [] 188 | for i in range(self.num_layers): 189 | init_states.append(self.cell_list[i].init_hidden(batch_size, image_size)) 190 | return init_states 191 | 192 | @staticmethod 193 | def _check_kernel_size_consistency(kernel_size): 194 | if not (isinstance(kernel_size, tuple) or 195 | (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))): 196 | raise ValueError('`kernel_size` must be tuple or list of tuples') 197 | 198 | @staticmethod 199 | def _extend_for_multilayer(param, num_layers): 200 | if not isinstance(param, list): 201 | param = [param] * num_layers 202 | return param 203 | 204 | def empty_h(x, n_feats): 205 | ''' 206 | create an empty hidden state 207 | 208 | input 209 | x: B x T x 3 x H x W 210 | 211 | output 212 | h: B x C x H/4 x W/4 213 | ''' 214 | b = x.size(0) 215 | h, w = x.size()[-2:] 216 | return x.new_zeros((b, n_feats, h//4, w//4)) 217 | 218 | class Normalization(nn.Conv2d): 219 | """Normalize input tensor value with convolutional layer""" 220 | def __init__(self, mean=(0, 0, 0), std=(1, 1, 1)): 221 | super(Normalization, self).__init__(3, 3, kernel_size=1) 222 | tensor_mean = torch.Tensor(mean) 223 | tensor_inv_std = torch.Tensor(std).reciprocal() 224 | 225 | self.weight.data = torch.eye(3).mul(tensor_inv_std).view(3, 3, 1, 1) 226 | self.bias.data = torch.Tensor(-tensor_mean.mul(tensor_inv_std)) 227 | 228 | for params in self.parameters(): 229 | params.requires_grad = False 230 | 231 | class BasicBlock(nn.Sequential): 232 | """Convolution layer + Activation layer""" 233 | def __init__( 234 | self, in_channels, out_channels, kernel_size, bias=True, 235 | conv=default_conv, norm=False, act=default_act): 236 | 237 | modules = [] 238 | modules.append( 239 | conv(in_channels, out_channels, kernel_size, bias=bias)) 240 | if norm: modules.append(norm(out_channels)) 241 | if act: modules.append(act()) 242 | 243 | super(BasicBlock, self).__init__(*modules) 244 | 245 | class ResBlock(nn.Module): 246 | def __init__( 247 | self, n_feats, kernel_size, bias=True, 248 | conv=default_conv, norm=False, act=default_act): 249 | 250 | super(ResBlock, self).__init__() 251 | 252 | modules = [] 253 | for i in range(2): 254 | modules.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 255 | if norm: modules.append(norm(n_feats)) 256 | if act and i == 0: modules.append(act()) 257 | 258 | self.body = nn.Sequential(*modules) 259 | 260 | def forward(self, x): 261 | res = self.body(x) 262 | res += x 263 | 264 | return res 265 | 266 | class ResBlock_mobile(nn.Module): 267 | def __init__( 268 
| self, n_feats, kernel_size, bias=True,
269 |         conv=default_conv, norm=False, act=default_act, dropout=False):
270 | 
271 |         super(ResBlock_mobile, self).__init__()
272 | 
273 |         modules = []
274 |         for i in range(2):
275 |             modules.append(conv(n_feats, n_feats, kernel_size, bias=False, groups=n_feats))
276 |             modules.append(conv(n_feats, n_feats, 1, bias=False))
277 |             if dropout and i == 0: modules.append(nn.Dropout2d(dropout))
278 |             if norm: modules.append(norm(n_feats))
279 |             if act and i == 0: modules.append(act())
280 | 
281 |         self.body = nn.Sequential(*modules)
282 | 
283 |     def forward(self, x):
284 |         res = self.body(x)
285 |         res += x
286 | 
287 |         return res
288 | 
289 | class Upsampler(nn.Sequential):
290 |     def __init__(
291 |         self, scale, n_feats, bias=True,
292 |         conv=default_conv, norm=False, act=False):
293 | 
294 |         modules = []
295 |         if (scale & (scale - 1)) == 0:  # Is scale = 2^n?
296 |             for _ in range(int(math.log(scale, 2))):
297 |                 modules.append(conv(n_feats, 4 * n_feats, 3, bias))
298 |                 modules.append(nn.PixelShuffle(2))
299 |                 if norm: modules.append(norm(n_feats))
300 |                 if act: modules.append(act())
301 |         elif scale == 3:
302 |             modules.append(conv(n_feats, 9 * n_feats, 3, bias))
303 |             modules.append(nn.PixelShuffle(3))
304 |             if norm: modules.append(norm(n_feats))
305 |             if act: modules.append(act())
306 |         else:
307 |             raise NotImplementedError
308 | 
309 |         super(Upsampler, self).__init__(*modules)
310 | 
311 | # Only support 1 / 2
312 | class PixelSort(nn.Module):
313 |     """The inverse operation of PixelShuffle.
314 |     Reduces the spatial resolution, increasing the number of channels.
315 |     Currently, only scale 0.5 is supported.
316 |     torch.nn.functional.pixel_unshuffle now provides the same operation.
317 |     Reference:
318 |         http://pytorch.org/docs/0.3.0/_modules/torch/nn/modules/pixelshuffle.html#PixelShuffle
319 |         http://pytorch.org/docs/0.3.0/_modules/torch/nn/functional.html#pixel_shuffle
320 |     """
321 |     def __init__(self, upscale_factor=0.5):
322 |         super(PixelSort, self).__init__()
323 |         self.upscale_factor = upscale_factor
324 | 
325 |     def forward(self, x):
326 |         b, c, h, w = x.size()
327 |         x = x.view(b, c, h // 2, 2, w // 2, 2)  # split each spatial dim into (size // 2, 2)
328 |         x = x.permute(0, 1, 3, 5, 2, 4).contiguous()  # move the 2x2 sub-pixel offsets next to the channel dim
329 |         x = x.view(b, 4 * c, h // 2, w // 2)
330 | 
331 |         return x
332 | 
333 | class Downsampler(nn.Sequential):
334 |     def __init__(
335 |         self, scale, n_feats, bias=True,
336 |         conv=default_conv, norm=False, act=False):
337 | 
338 |         modules = []
339 |         if scale == 0.5:
340 |             modules.append(PixelSort())
341 |             modules.append(conv(4 * n_feats, n_feats, 3, bias))
342 |             if norm: modules.append(norm(n_feats))
343 |             if act: modules.append(act())
344 |         else:
345 |             raise NotImplementedError
346 | 
347 |         super(Downsampler, self).__init__(*modules)
348 | 
349 | 
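The PixelSort rearrangement above is the standard space-to-depth inverse of PixelShuffle, so a round-trip check can verify it; on recent PyTorch versions the same result is produced by torch.nn.functional.pixel_unshuffle(x, 2). A toy self-check (shapes are arbitrary, nothing here is repository-specific):

import torch
import torch.nn as nn

x = torch.randn(2, 3, 8, 8)
b, c, h, w = x.shape
# Same rearrangement as PixelSort above (scale 0.5): each 2x2 spatial block
# becomes four channels.
down = x.view(b, c, h // 2, 2, w // 2, 2).permute(0, 1, 3, 5, 2, 4).contiguous()
down = down.view(b, 4 * c, h // 2, w // 2)
restored = nn.PixelShuffle(2)(down)  # the inverse mapping
assert torch.equal(restored, x)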
19 | """ 20 | 21 | def __init__(self, minimum_size=64): 22 | self._minimum_size = minimum_size 23 | self._pixels_pad_to_width = None 24 | self._pixels_pad_to_height = None 25 | 26 | def _closest_larger_multiple_of_minimum_size(self, size): 27 | return int(math.ceil(size / self._minimum_size) * self._minimum_size) 28 | 29 | def pad(self, network_input): 30 | """Returns "network_input" paded with zeros to the "standard" size. 31 | 32 | The "standard" size correspond to the height and width that 33 | are closest multiples of "minimum_size". The method pads 34 | height and width and and saves padded values. These 35 | values are then used by "unpad_output" method. 36 | """ 37 | height, width = network_input.size()[-2:] 38 | self._pixels_pad_to_height = ( 39 | self._closest_larger_multiple_of_minimum_size(height) - height) 40 | self._pixels_pad_to_width = ( 41 | self._closest_larger_multiple_of_minimum_size(width) - width) 42 | return nn.ZeroPad2d((self._pixels_pad_to_width, 0, 43 | self._pixels_pad_to_height, 0))(network_input) 44 | 45 | def unpad(self, network_output): 46 | """Returns "network_output" cropped to the original size. 47 | 48 | The cropping is performed using values save by the "pad_input" 49 | method. 50 | """ 51 | return network_output[..., self._pixels_pad_to_height:, self. 52 | _pixels_pad_to_width:] 53 | 54 | 55 | def build_model(args): 56 | return MSResNet_with_event(args) 57 | # return MSResNet(args) 58 | 59 | class conv_end(nn.Module): 60 | def __init__(self, in_channels=3, out_channels=3, kernel_size=5, ratio=2): 61 | super(conv_end, self).__init__() 62 | 63 | modules = [ 64 | common.default_conv(in_channels, out_channels, kernel_size), 65 | nn.PixelShuffle(ratio) 66 | ] 67 | 68 | self.uppath = nn.Sequential(*modules) 69 | 70 | def forward(self, x): 71 | return self.uppath(x) 72 | 73 | class MSResNet(nn.Module): 74 | def __init__(self, args): 75 | super(MSResNet, self).__init__() 76 | 77 | self.rgb_range = args.rgb_range 78 | self.mean = self.rgb_range / 2 79 | 80 | self.n_resblocks = args.n_resblocks 81 | self.n_feats = args.n_feats 82 | self.kernel_size = args.kernel_size 83 | 84 | self.n_scales = args.n_scales 85 | 86 | self.body_models = nn.ModuleList([ 87 | ResNet(args, 3, 3, mean_shift=False), 88 | ]) 89 | for _ in range(1, self.n_scales): 90 | self.body_models.insert(0, ResNet(args, 6, 3, mean_shift=False)) 91 | 92 | self.conv_end_models = nn.ModuleList([None]) 93 | for _ in range(1, self.n_scales): 94 | self.conv_end_models += [conv_end(3, 12)] 95 | 96 | 97 | 98 | def forward(self, input_pyramid): 99 | 100 | scales = range(self.n_scales-1, -1, -1) # 0: fine, 2: coarse 101 | 102 | for s in scales: 103 | input_pyramid[s] = input_pyramid[s] - self.mean 104 | 105 | output_pyramid = [None] * self.n_scales 106 | 107 | input_s = input_pyramid[-1] 108 | for s in scales: # [2, 1, 0] 109 | output_pyramid[s] = self.body_models[s](input_s) 110 | if s > 0: 111 | up_feat = self.conv_end_models[s](output_pyramid[s]) 112 | input_s = torch.cat((input_pyramid[s-1], up_feat), 1) 113 | 114 | for s in scales: 115 | output_pyramid[s] = output_pyramid[s] + self.mean 116 | 117 | return output_pyramid 118 | 119 | class MSResNet_with_event(nn.Module): 120 | def __init__(self, args): 121 | super(MSResNet_with_event, self).__init__() 122 | 123 | self.rgb_range = args.rgb_range 124 | self.mean = self.rgb_range / 2 125 | 126 | self.n_resblocks = args.n_resblocks 127 | self.n_feats = args.n_feats 128 | self.kernel_size = args.kernel_size 129 | 130 | self.n_scales = args.n_scales 131 | 132 | # 
132 |         # self.body_models = nn.ModuleList([
133 |         #     ResNet_event(args, 8, 3, mean_shift=False),
134 |         # ])
135 |         self.body_models = nn.ModuleList([
136 |             ResNet_event(args, 5, 3, mean_shift=False),
137 |         ])
138 |         # for _ in range(1, self.n_scales):
139 |         #     self.body_models.insert(0, ResNet_event(args, 11, 3, mean_shift=False))
140 |         for _ in range(1, self.n_scales):
141 |             self.body_models.insert(0, ResNet_event(args, 8, 3, mean_shift=False))
142 | 
143 |         self.conv_end_models = nn.ModuleList([None])
144 |         for _ in range(1, self.n_scales):
145 |             self.conv_end_models += [conv_end(3, 12)]
146 | 
147 | 
148 |     def forward(self, input_pyramid):
149 | 
150 |         image_pyramid, event_pyramid = input_pyramid[0], input_pyramid[1]
151 |         scales = range(self.n_scales-1, -1, -1)  # 0: fine, 2: coarse
152 | 
153 | 
154 | 
155 |         input_pyramid = [None] * self.n_scales
156 |         output_pyramid = [None] * self.n_scales
157 | 
158 | 
159 |         # for s in scales:
160 |         #     input_pyramid[s] = torch.cat((image_pyramid[s], event_pyramid[s]), 1)
161 | 
162 |         for s in scales:
163 |             input_pyramid[s] = event_pyramid[s]
164 | 
165 | 
166 |         input_s = input_pyramid[-1]
167 |         for s in scales:  # [2, 1, 0]
168 |             output_pyramid[s] = self.body_models[s](input_s)
169 |             # import pdb
170 |             # pdb.set_trace()
171 |             if s > 0:
172 |                 up_feat = self.conv_end_models[s](output_pyramid[s])
173 |                 input_s = torch.cat((input_pyramid[s-1], up_feat), 1)
174 | 
175 | 
176 | 
177 |         return output_pyramid
178 | 
--------------------------------------------------------------------------------
/model/intensity_MSResNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from . import common
5 | from .ResNet import ResNet
6 | from .ResNet import ResNet_event
7 | from data import common as cm
8 | 
9 | import math
10 | 
11 | class SizeAdapter(object):
12 |     """Converts size of input to standard size.
13 | 
14 |     Practical deep networks work only with input images
15 |     whose height and width are multiples of a minimum size.
16 |     This class allows passing images of arbitrary size
17 |     to the network, by padding the input to the closest multiple
18 |     and unpadding the network's output to the original size.
19 |     """
20 | 
21 |     def __init__(self, minimum_size=64):
22 |         self._minimum_size = minimum_size
23 |         self._pixels_pad_to_width = None
24 |         self._pixels_pad_to_height = None
25 | 
26 |     def _closest_larger_multiple_of_minimum_size(self, size):
27 |         return int(math.ceil(size / self._minimum_size) * self._minimum_size)
28 | 
29 |     def pad(self, network_input):
30 |         """Returns "network_input" padded with zeros to the "standard" size.
31 | 
32 |         The "standard" size corresponds to the height and width that
33 |         are the closest multiples of "minimum_size". The method pads
34 |         height and width and saves the padded values. These
35 |         values are then used by the "unpad" method.
36 |         """
37 |         height, width = network_input.size()[-2:]
38 |         self._pixels_pad_to_height = (
39 |             self._closest_larger_multiple_of_minimum_size(height) - height)
40 |         self._pixels_pad_to_width = (
41 |             self._closest_larger_multiple_of_minimum_size(width) - width)
42 |         return nn.ZeroPad2d((self._pixels_pad_to_width, 0,
43 |                              self._pixels_pad_to_height, 0))(network_input)
44 | 
45 |     def unpad(self, network_output):
46 |         """Returns "network_output" cropped to the original size.
47 | 
48 |         The cropping is performed using values saved by the "pad"
49 |         method.
50 |         """
51 |         return network_output[..., self._pixels_pad_to_height:, self.
52 | _pixels_pad_to_width:] 53 | 54 | 55 | def build_model(args): 56 | return MSResNet_with_event(args) 57 | # return MSResNet(args) 58 | 59 | class conv_end(nn.Module): 60 | def __init__(self, in_channels=3, out_channels=3, kernel_size=5, ratio=2): 61 | super(conv_end, self).__init__() 62 | 63 | modules = [ 64 | common.default_conv(in_channels, out_channels, kernel_size), 65 | nn.PixelShuffle(ratio) 66 | ] 67 | 68 | self.uppath = nn.Sequential(*modules) 69 | 70 | def forward(self, x): 71 | return self.uppath(x) 72 | 73 | class MSResNet(nn.Module): 74 | def __init__(self, args): 75 | super(MSResNet, self).__init__() 76 | 77 | self.rgb_range = args.rgb_range 78 | self.mean = self.rgb_range / 2 79 | 80 | self.n_resblocks = args.n_resblocks 81 | self.n_feats = args.n_feats 82 | self.kernel_size = args.kernel_size 83 | 84 | self.n_scales = args.n_scales 85 | 86 | self.body_models = nn.ModuleList([ 87 | ResNet(args, 3, 3, mean_shift=False), 88 | ]) 89 | for _ in range(1, self.n_scales): 90 | self.body_models.insert(0, ResNet(args, 6, 3, mean_shift=False)) 91 | 92 | self.conv_end_models = nn.ModuleList([None]) 93 | for _ in range(1, self.n_scales): 94 | self.conv_end_models += [conv_end(3, 12)] 95 | 96 | 97 | 98 | def forward(self, input_pyramid): 99 | 100 | scales = range(self.n_scales-1, -1, -1) # 0: fine, 2: coarse 101 | 102 | for s in scales: 103 | input_pyramid[s] = input_pyramid[s] - self.mean 104 | 105 | output_pyramid = [None] * self.n_scales 106 | 107 | input_s = input_pyramid[-1] 108 | for s in scales: # [2, 1, 0] 109 | output_pyramid[s] = self.body_models[s](input_s) 110 | if s > 0: 111 | up_feat = self.conv_end_models[s](output_pyramid[s]) 112 | input_s = torch.cat((input_pyramid[s-1], up_feat), 1) 113 | 114 | for s in scales: 115 | output_pyramid[s] = output_pyramid[s] + self.mean 116 | 117 | return output_pyramid 118 | 119 | class MSResNet_with_event(nn.Module): 120 | def __init__(self, args): 121 | super(MSResNet_with_event, self).__init__() 122 | 123 | self.rgb_range = args.rgb_range 124 | self.mean = self.rgb_range / 2 125 | 126 | self.n_resblocks = args.n_resblocks 127 | self.n_feats = args.n_feats 128 | self.kernel_size = args.kernel_size 129 | 130 | self.n_scales = args.n_scales 131 | 132 | self.body_models = nn.ModuleList([ 133 | ResNet_event(args, 8, 3, mean_shift=False), 134 | ]) 135 | for _ in range(1, self.n_scales): 136 | self.body_models.insert(0, ResNet_event(args, 11, 3, mean_shift=False)) 137 | 138 | self.conv_end_models = nn.ModuleList([None]) 139 | for _ in range(1, self.n_scales): 140 | self.conv_end_models += [conv_end(3, 12)] 141 | 142 | # self.size_adaptor = SizeAdapter() 143 | 144 | # def forward(self, event_pyramid, input_pyramid): 145 | def forward(self, input_pyramid): 146 | 147 | event_pyramid, input_pyramid, mode, n_scales = input_pyramid[1], input_pyramid[0], input_pyramid[2], input_pyramid[3] 148 | 149 | # import pdb 150 | # pdb.set_trace() 151 | scales = range(self.n_scales-1, -1, -1) # 0: fine, 2: coarse 152 | 153 | for s in scales: 154 | input_pyramid[s] = input_pyramid[s] - self.mean 155 | 156 | output_pyramid = [None] * self.n_scales 157 | 158 | for s in scales: 159 | input_pyramid[s] = torch.cat((input_pyramid[s], event_pyramid[s]), 1) 160 | 161 | input_s = input_pyramid[-1] 162 | for s in scales: # [2, 1, 0] 163 | # import pdb 164 | # pdb.set_trace() 165 | output_pyramid[s] = self.body_models[s](input_s) 166 | if s > 0: 167 | up_feat = self.conv_end_models[s](output_pyramid[s]) 168 | # import pdb 169 | # pdb.set_trace() 170 | input_s = 
torch.cat((input_pyramid[s-1], up_feat), 1) 171 | 172 | for s in scales: 173 | output_pyramid[s] = output_pyramid[s] + self.mean 174 | 175 | return output_pyramid 176 | -------------------------------------------------------------------------------- /model/pertu_select_recon.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from . import common 5 | from .ResNet import ResNet 6 | # from .ResNet import ResNet_event 7 | from data import common as cm 8 | from model import perturbations 9 | import math 10 | import torch.nn.functional as F 11 | 12 | class ResNet_event(nn.Module): 13 | def __init__(self, args, in_channels=18, out_channels=3, n_feats=None, kernel_size=None, n_resblocks=None, mean_shift=True): 14 | super(ResNet_event, self).__init__() 15 | 16 | self.in_channels = in_channels 17 | self.out_channels = out_channels 18 | 19 | self.n_feats = args.n_feats if n_feats is None else n_feats 20 | # import pdb 21 | # pdb.set_trace() 22 | self.kernel_size = args.kernel_size if kernel_size is None else kernel_size 23 | self.n_resblocks = args.n_resblocks if n_resblocks is None else n_resblocks 24 | 25 | self.mean_shift = mean_shift 26 | self.rgb_range = args.rgb_range 27 | self.mean = self.rgb_range / 2 28 | 29 | modules = [] 30 | modules.append(common.default_conv(self.in_channels, self.n_feats, self.kernel_size)) 31 | for _ in range(self.n_resblocks): 32 | modules.append(common.ResBlock(self.n_feats, self.kernel_size)) 33 | modules.append(common.default_conv(self.n_feats, self.out_channels, self.kernel_size)) 34 | 35 | self.body = nn.Sequential(*modules) 36 | 37 | def forward(self, input): 38 | 39 | output = self.body(input) 40 | 41 | return output 42 | 43 | def argtopk(x, axis=-1): 44 | _, index = torch.topk(x, k=3, dim=axis) 45 | 46 | # expand = torch.nn.functional.one_hot(index.squeeze()) 47 | # print(expand.shape) 48 | # output = expand.float() 49 | 50 | return F.one_hot(index, list(x.shape)[axis]).float() 51 | 52 | def argmax(x, axis=-1): 53 | return F.one_hot(torch.argmax(x, dim=axis), list(x.shape)[axis]).float() 54 | 55 | 56 | def build_model(args): 57 | return MSResNet_with_event(args) 58 | # return MSResNet(args) 59 | 60 | class conv_end(nn.Module): 61 | def __init__(self, in_channels=3, out_channels=3, kernel_size=5, ratio=2): 62 | super(conv_end, self).__init__() 63 | 64 | modules = [ 65 | common.default_conv(in_channels, out_channels, kernel_size), 66 | nn.PixelShuffle(ratio) 67 | ] 68 | 69 | self.uppath = nn.Sequential(*modules) 70 | 71 | def forward(self, x): 72 | return self.uppath(x) 73 | 74 | class MSResNet(nn.Module): 75 | def __init__(self, args): 76 | super(MSResNet, self).__init__() 77 | 78 | self.rgb_range = args.rgb_range 79 | self.mean = self.rgb_range / 2 80 | 81 | self.n_resblocks = args.n_resblocks 82 | self.n_feats = args.n_feats 83 | self.kernel_size = args.kernel_size 84 | 85 | self.n_scales = args.n_scales 86 | 87 | self.body_models = nn.ModuleList([ 88 | ResNet(args, 3, 3, mean_shift=False), 89 | ]) 90 | for _ in range(1, self.n_scales): 91 | self.body_models.insert(0, ResNet(args, 6, 3, mean_shift=False)) 92 | 93 | self.conv_end_models = nn.ModuleList([None]) 94 | for _ in range(1, self.n_scales): 95 | self.conv_end_models += [conv_end(3, 12)] 96 | 97 | 98 | 99 | def forward(self, input_pyramid): 100 | 101 | scales = range(self.n_scales-1, -1, -1) # 0: fine, 2: coarse 102 | 103 | for s in scales: 104 | input_pyramid[s] = input_pyramid[s] - self.mean 105 | 106 | output_pyramid 
= [None] * self.n_scales
107 | 
108 |         input_s = input_pyramid[-1]
109 |         for s in scales:  # [2, 1, 0]
110 |             output_pyramid[s] = self.body_models[s](input_s)
111 |             if s > 0:
112 |                 up_feat = self.conv_end_models[s](output_pyramid[s])
113 |                 input_s = torch.cat((input_pyramid[s-1], up_feat), 1)
114 | 
115 |         for s in scales:
116 |             output_pyramid[s] = output_pyramid[s] + self.mean
117 | 
118 |         return output_pyramid
119 | 
120 | class MSResNet_with_event(nn.Module):
121 |     def __init__(self, args):
122 |         super(MSResNet_with_event, self).__init__()
123 | 
124 |         self.rgb_range = args.rgb_range
125 |         self.mean = self.rgb_range / 2
126 | 
127 |         self.n_resblocks = args.n_resblocks
128 |         self.n_feats = args.n_feats
129 |         self.kernel_size = args.kernel_size
130 | 
131 |         self.n_scales = args.n_scales
132 | 
133 | 
134 | 
135 | 
136 |         self.encoding_with_image = nn.ModuleList()
137 |         for s in range(self.n_scales):
138 |             out_dim = 32*(2**s)
139 |             self.encoding_with_image.append(ResNet_event(args, 8, mean_shift=False, n_resblocks = 7, n_feats = 16*(2**s), out_channels = out_dim))
140 | 
141 |         self.conv_score = []
142 |         self.conv_score.append(nn.Conv2d(224, 128, kernel_size=3, padding=(3 // 2), stride = 2, bias=True))
143 |         self.conv_score.append(nn.Conv2d(128, 32, kernel_size=3, padding=(3 // 2), stride = 2, bias=True))
144 |         self.conv_score.append(nn.Conv2d(32, 5, kernel_size=3, padding=(3 // 2), stride = 2, bias=True))
145 |         self.conv_score.append(nn.AvgPool2d(32))
146 |         self.conv_score = nn.Sequential(*self.conv_score)
147 |         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
148 |         self.pert_argtopk = perturbations.perturbed(argtopk,
149 |                                                     num_samples=200,
150 |                                                     sigma=0.05,
151 |                                                     noise='gumbel',
152 |                                                     batched=True,
153 |                                                     device=self.device)
154 | 
155 |         self.body_models = [ResNet_event(args, 3, 3, mean_shift=False)]
156 |         self.body_models.append(ResNet_event(args, 3, 3, mean_shift=False))
157 |         self.body_models = nn.Sequential(*self.body_models)
158 | 
159 | 
160 | 
161 |     def forward(self, input_pyramid):
162 | 
163 |         image_pyramid, event_pyramid = input_pyramid[0], input_pyramid[1]
164 |         scales = range(self.n_scales-1, -1, -1)  # 0: fine, 2: coarse
165 | 
166 |         concat_pyramid = [None] * self.n_scales
167 |         encoded_pyramid = [None] * self.n_scales
168 |         for s in scales:
169 |             concat_pyramid[s] = torch.cat((image_pyramid[s], event_pyramid[s]), 1)
170 | 
171 |         for s in range(self.n_scales):  # [0, 1, 2], fine to coarse
172 |             input_s = concat_pyramid[s]
173 |             encoded_pyramid[s] = self.encoding_with_image[s](input_s)
174 |             if s > 0:
175 |                 upsample = nn.Upsample(scale_factor = 2**s, mode = 'bilinear')
176 |                 encoded_pyramid[s] = upsample(encoded_pyramid[s])
177 |                 encoded_feature = torch.cat((encoded_feature, encoded_pyramid[s]), 1)
178 |             else:
179 |                 encoded_feature = encoded_pyramid[s]
180 |         b, c, h, w = event_pyramid[0].shape
181 |         score = self.conv_score(encoded_feature).squeeze(-1).squeeze(-1)
182 |         if self.training:
183 |             one_hot = self.pert_argtopk(score)
184 |         else:
185 |             one_hot = argtopk(score)
186 | 
187 | 
188 | 
189 |         # print("score shape")
190 |         # print(score.shape)
191 |         # print(score)
192 |         # print("one hot shape")
193 |         # print(one_hot.shape)
194 |         # print(one_hot)
195 | 
196 |         input_event = event_pyramid[0].view(b, c, -1)
197 | 
198 |         select_event = torch.bmm(one_hot, input_event).view(b, -1, h, w)
199 | 
200 |         output = self.body_models(select_event)
201 | 
202 |         return output
203 | 
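The forward pass above is the event-slice selection step: the score head produces one score per candidate slice, argtopk turns the scores into k one-hot rows, and a batched matrix product picks the corresponding slices; during training the hard top-k is replaced by its perturbed, differentiable version so gradients reach the scores. A toy sketch of that mechanism in isolation (CPU device and shapes are illustrative; assumes the repository root is on PYTHONPATH so that model.perturbations imports):

import torch
import torch.nn.functional as F
from model import perturbations

def argtopk(x, axis=-1, k=3):
    # Hard top-k as a stack of k one-hot rows over the candidates.
    _, index = torch.topk(x, k=k, dim=axis)
    return F.one_hot(index, x.shape[axis]).float()

pert_argtopk = perturbations.perturbed(
    argtopk, num_samples=200, sigma=0.05, noise='gumbel',
    batched=True, device=torch.device('cpu'))

scores = torch.randn(2, 5, requires_grad=True)  # B x candidate slices
one_hot = pert_argtopk(scores)                  # B x 3 x 5, smoothed over noise samples
events = torch.randn(2, 5, 7 * 7)               # B x C x (H*W), flattened event slices
selected = torch.bmm(one_hot, events)           # B x 3 x (H*W): soft slice selection
selected.sum().backward()                       # gradients flow back into the scores
print(scores.grad.shape)                        # torch.Size([2, 5])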
--------------------------------------------------------------------------------
/model/perturbations.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | #
3 | # Modifications from original work
4 | # 29-03-2021 (tuero@ualberta.ca) : Convert Tensorflow code to PyTorch
5 | #
6 | # Copyright 2021 The Google Research Authors.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | #     http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 | 
20 | # Lint as: python3
21 | """Introduces differentiation via perturbations.
22 | 
23 | Example of usage:
24 | 
25 |     @perturbed
26 |     def sign_or(x, axis=-1):
27 |         s = ((torch.sign(x) + 1) / 2.0).type(torch.bool)
28 |         result = torch.any(s, dim=-1)
29 |         return result.type(torch.float) * 2.0 - 1
30 | 
31 | 
32 | Then sign_or is differentiable (unlike what it seems).
33 | 
34 | It is possible to specify the parameters of the perturbations using:
35 |     @perturbed(num_samples=1000, sigma=0.1, noise='gumbel')
36 |     ...
37 | 
38 | The decorator can also be used directly as a function, for example:
39 |     soft_argsort = perturbed(torch.argsort, num_samples=200, sigma=0.01)
40 | """
41 | 
42 | import functools
43 | from typing import Tuple
44 | import torch
45 | from torch.distributions.gumbel import Gumbel
46 | from torch.distributions.normal import Normal
47 | 
48 | _GUMBEL = 'gumbel'
49 | _NORMAL = 'normal'
50 | SUPPORTED_NOISES = (_GUMBEL, _NORMAL)
51 | 
52 | 
53 | def sample_noise_with_gradients(noise, shape):
54 |     """Samples a noise tensor according to a distribution with its gradient.
55 | 
56 |     Args:
57 |         noise: (str) a type of supported noise distribution.
58 |         shape: torch.tensor, the shape of the tensor to sample.
59 | 
60 |     Returns:
61 |         A tuple Tensor[shape], Tensor[shape] that corresponds to the
62 |         sampled noise and the gradient of the log of the underlying probability
63 |         distribution function. For instance, for a gaussian noise (normal), the
64 |         gradient is equal to the noise itself.
65 | 
66 |     Raises:
67 |         ValueError in case the requested noise distribution is not supported.
68 |         See perturbations.SUPPORTED_NOISES for the list of supported distributions.
69 |     """
70 |     if noise not in SUPPORTED_NOISES:
71 |         raise ValueError('{} noise is not supported. Use one of [{}]'.format(
72 |             noise, SUPPORTED_NOISES))
73 | 
74 |     if noise == _GUMBEL:
75 |         sampler = Gumbel(0.0, 1.0)
76 |         samples = sampler.sample(shape)
77 |         gradients = 1 - torch.exp(-samples)
78 |     elif noise == _NORMAL:
79 |         sampler = Normal(0.0, 1.0)
80 |         samples = sampler.sample(shape)
81 |         gradients = samples
82 | 
83 |     return samples, gradients
84 | 
85 | 
86 | def perturbed(func=None,
87 |               num_samples=1000,
88 |               sigma=0.05,
89 |               noise=_NORMAL,
90 |               batched=True,
91 |               device=None):
92 |     """Turns a function into a differentiable one via perturbations.
93 | 
94 |     The input function has to be the solution to a linear program for the trick
95 |     to work. For instance the maximum function, the logical operators or the ranks
96 |     can be expressed as solutions to some linear programs on some polytopes.
97 | If this condition is violated though, the result would not hold and there is 98 | no guarantee on the validity of the obtained gradients. 99 | 100 | This function can be used directly or as a decorator. 101 | 102 | Args: 103 | func: the function to be turned into a perturbed and differentiable one. 104 | Four I/O signatures for func are currently supported: 105 | If batched is True, 106 | (1) input [B, D1, ..., Dk], output [B, D1, ..., Dk], k >= 1 107 | (2) input [B, D1, ..., Dk], output [B], k >= 1 108 | If batched is False, 109 | (3) input [D1, ..., Dk], output [D1, ..., Dk], k >= 1 110 | (4) input [D1, ..., Dk], output [], k >= 1. 111 | num_samples: the number of samples to use for the expectation computation. 112 | sigma: the scale of the perturbation. 113 | noise: a string representing the noise distribution to be used to sample 114 | perturbations. 115 | batched: whether inputs to the perturbed function will have a leading batch 116 | dimension (True) or consist of a single example (False). Defaults to True. 117 | device: The device to create tensors on (cpu/gpu). If None given, it will 118 | default to gpu:0 if available, cpu otherwise. 119 | 120 | Returns: 121 | a function has the same signature as func but that can be back propagated. 122 | """ 123 | # If device not supplied, auto detect 124 | if device is None: 125 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 126 | # This is a trick to have the decorator work both with and without arguments. 127 | if func is None: 128 | return functools.partial( 129 | perturbed, num_samples=num_samples, sigma=sigma, noise=noise, 130 | batched=batched, device=device) 131 | 132 | @functools.wraps(func) 133 | def wrapper(input_tensor, *args): 134 | class PerturbedFunc(torch.autograd.Function): 135 | 136 | @staticmethod 137 | def forward(ctx, input_tensor, *args): 138 | original_input_shape = input_tensor.shape 139 | if batched: 140 | if not input_tensor.dim() >= 2: 141 | raise ValueError('Batched inputs must have at least rank two') 142 | else: # Adds dummy batch dimension internally. 143 | input_tensor = input_tensor.unsqueeze(0) 144 | input_shape = input_tensor.shape # [B, D1, ... Dk], k >= 1 145 | perturbed_input_shape = [num_samples] + list(input_shape) 146 | 147 | noises = sample_noise_with_gradients(noise, perturbed_input_shape) 148 | additive_noise, noise_gradient = tuple( 149 | [noise.type(input_tensor.dtype) for noise in noises]) 150 | additive_noise = additive_noise.to(device) 151 | noise_gradient = noise_gradient.to(device) 152 | perturbed_input = input_tensor.unsqueeze(0) + sigma * additive_noise 153 | 154 | # [N, B, D1, ..., Dk] -> [NB, D1, ..., Dk]. 155 | flat_batch_dim_shape = [-1] + list(input_shape)[1:] 156 | perturbed_input = torch.reshape(perturbed_input, flat_batch_dim_shape) 157 | # Calls user-defined function in a perturbation agnostic manner. 158 | perturbed_output = func(perturbed_input, *args) 159 | # [NB, D1, ..., Dk] -> [N, B, D1, ..., Dk]. 160 | perturbed_input = torch.reshape(perturbed_input, perturbed_input_shape) 161 | # Either 162 | # (Default case): [NB, D1, ..., Dk] -> [N, B, D1, ..., Dk] 163 | # or 164 | # (Full-reduce case) [NB] -> [N, B] 165 | perturbed_output_shape = [num_samples, -1] + list(perturbed_output.shape)[1:] 166 | perturbed_output = torch.reshape(perturbed_output, perturbed_output_shape) 167 | # import pdb 168 | # pdb.set_trace() 169 | forward_output = torch.mean(perturbed_output, dim=0) 170 | 171 | if not batched: # Removes dummy batch dimension. 
172 | forward_output = forward_output[0] 173 | 174 | # Save context for backward pass 175 | ctx.save_for_backward(perturbed_input, perturbed_output, noise_gradient) 176 | ctx.original_input_shape = original_input_shape 177 | 178 | return forward_output 179 | 180 | @staticmethod 181 | def backward(ctx, dy): 182 | # Pull saved tensors 183 | original_input_shape = ctx.original_input_shape 184 | perturbed_input, perturbed_output, noise_gradient = ctx.saved_tensors 185 | output, noise_grad = perturbed_output, noise_gradient 186 | # Adds dummy feature/channel dimension internally. 187 | if perturbed_input.dim() > output.dim(): 188 | dy = dy.unsqueeze(-1) 189 | output = output.unsqueeze(-1) 190 | # Adds dummy batch dimension internally. 191 | if not batched: 192 | dy = dy.unsqueeze(0) 193 | # Flattens [D1, ..., Dk] to a single feat dim [D]. 194 | flatten = lambda t: torch.reshape(t, (list(t.shape)[0], list(t.shape)[1], -1)) 195 | dy = torch.reshape(dy, (list(dy.shape)[0], -1)) # (B, D) 196 | output = flatten(output) # (N, B, D) 197 | noise_grad = flatten(noise_grad) # (N, B, D) 198 | 199 | g = torch.einsum('nbd,nb->bd', noise_grad, torch.einsum('nbd,bd->nb', output, dy)) 200 | g /= sigma * num_samples 201 | return torch.reshape(g, original_input_shape) 202 | 203 | return PerturbedFunc.apply(input_tensor, *args) 204 | 205 | return wrapper 206 | -------------------------------------------------------------------------------- /model/structure.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .common import ResBlock, default_conv 4 | 5 | def encoder(in_channels, n_feats): 6 | """RGB / IR feature encoder 7 | """ 8 | 9 | # in_channels == 1 or 3 or 4 or .... 10 | # After 1st conv, B x n_feats x H x W 11 | # After 2nd conv, B x 2n_feats x H/2 x W/2 12 | # After 3rd conv, B x 3n_feats x H/4 x W/4 13 | return nn.Sequential( 14 | nn.Conv2d(in_channels, 1 * n_feats, 5, stride=1, padding=2), 15 | nn.Conv2d(1 * n_feats, 2 * n_feats, 5, stride=2, padding=2), 16 | nn.Conv2d(2 * n_feats, 3 * n_feats, 5, stride=2, padding=2), 17 | ) 18 | 19 | def decoder(out_channels, n_feats): 20 | """RGB / IR / Depth decoder 21 | """ 22 | # After 1st deconv, B x 2n_feats x H/2 x W/2 23 | # After 2nd deconv, B x n_feats x H x W 24 | # After 3rd conv, B x out_channels x H x W 25 | deconv_kargs = {'stride': 2, 'padding': 1, 'output_padding': 1} 26 | 27 | return nn.Sequential( 28 | nn.ConvTranspose2d(3 * n_feats, 2 * n_feats, 3, **deconv_kargs), 29 | nn.ConvTranspose2d(2 * n_feats, 1 * n_feats, 3, **deconv_kargs), 30 | nn.Conv2d(n_feats, out_channels, 5, stride=1, padding=2), 31 | ) 32 | 33 | # def ResNet(n_feats, in_channels=None, out_channels=None): 34 | def ResNet(n_feats, kernel_size, n_blocks, in_channels=None, out_channels=None): 35 | """sequential ResNet 36 | """ 37 | 38 | # if in_channels is None: 39 | # in_channels = n_feats 40 | # if out_channels is None: 41 | # out_channels = n_feats 42 | # # currently not implemented 43 | 44 | m = [] 45 | 46 | if in_channels is not None: 47 | m += [default_conv(in_channels, n_feats, kernel_size)] 48 | 49 | m += [ResBlock(n_feats, 3)] * n_blocks 50 | 51 | if out_channels is not None: 52 | m += [default_conv(n_feats, out_channels, kernel_size)] 53 | 54 | 55 | return nn.Sequential(*m) 56 | 57 | -------------------------------------------------------------------------------- /model/submodule.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 
| import torch 3 | import torch.nn as nn 4 | import torch.utils.data 5 | from torch.autograd import Variable 6 | from torch.autograd.function import Function 7 | import torch.nn.functional as F 8 | import numpy as np 9 | 10 | 11 | def convbn(in_channels, out_channels, kernel_size, stride, pad, dilation): 12 | return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, 13 | padding=dilation if dilation > 1 else pad, dilation=dilation, bias=False), 14 | nn.BatchNorm2d(out_channels)) 15 | 16 | 17 | def convbn_3d(in_channels, out_channels, kernel_size, stride, pad): 18 | return nn.Sequential(nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, 19 | padding=pad, bias=False), 20 | nn.BatchNorm3d(out_channels)) 21 | 22 | def transposebn_3d(in_channels, out_channels, kernel_size, stride, pad): 23 | return nn.Sequential(nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, 24 | padding=pad, bias=False), 25 | nn.BatchNorm3d(out_channels)) 26 | 27 | 28 | def disparity_regression(x, maxdisp): 29 | assert len(x.shape) == 4 30 | disp_values = torch.arange(0, maxdisp, dtype=x.dtype, device=x.device) 31 | disp_values = disp_values.view(1, maxdisp, 1, 1) 32 | return torch.sum(x * disp_values, 1, keepdim=False) 33 | 34 | 35 | def build_concat_volume(refimg_fea, targetimg_fea, maxdisp): 36 | B, C, H, W = refimg_fea.shape 37 | volume = refimg_fea.new_zeros([B, 2 * C, maxdisp, H, W]) 38 | for i in range(maxdisp): 39 | if i > 0: 40 | volume[:, :C, i, :, i:] = refimg_fea[:, :, :, i:] 41 | volume[:, C:, i, :, i:] = targetimg_fea[:, :, :, :-i] 42 | else: 43 | volume[:, :C, i, :, :] = refimg_fea 44 | volume[:, C:, i, :, :] = targetimg_fea 45 | volume = volume.contiguous() 46 | return volume 47 | 48 | 49 | def groupwise_correlation(fea1, fea2, num_groups): 50 | B, C, H, W = fea1.shape 51 | assert C % num_groups == 0 52 | channels_per_group = C // num_groups 53 | cost = (fea1 * fea2).view([B, num_groups, channels_per_group, H, W]).mean(dim=2) 54 | assert cost.shape == (B, num_groups, H, W) 55 | return cost 56 | 57 | 58 | def build_gwc_volume(refimg_fea, targetimg_fea, maxdisp, num_groups): 59 | B, C, H, W = refimg_fea.shape 60 | volume = refimg_fea.new_zeros([B, num_groups, maxdisp, H, W]) 61 | for i in range(maxdisp): 62 | if i > 0: 63 | volume[:, :, i, :, i:] = groupwise_correlation(refimg_fea[:, :, :, i:], targetimg_fea[:, :, :, :-i], 64 | num_groups) 65 | else: 66 | volume[:, :, i, :, :] = groupwise_correlation(refimg_fea, targetimg_fea, num_groups) 67 | volume = volume.contiguous() 68 | return volume 69 | 70 | 71 | class BasicBlock(nn.Module): 72 | expansion = 1 73 | 74 | def __init__(self, inplanes, planes, stride, downsample, pad, dilation): 75 | super(BasicBlock, self).__init__() 76 | 77 | self.conv1 = nn.Sequential(convbn(inplanes, planes, 3, stride, pad, dilation), 78 | nn.ReLU(inplace=True)) 79 | 80 | self.conv2 = convbn(planes, planes, 3, 1, pad, dilation) 81 | 82 | self.downsample = downsample 83 | self.stride = stride 84 | 85 | def forward(self, x): 86 | out = self.conv1(x) 87 | out = self.conv2(out) 88 | 89 | if self.downsample is not None: 90 | x = self.downsample(x) 91 | 92 | out += x 93 | 94 | return out -------------------------------------------------------------------------------- /optim/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.optim.lr_scheduler as lrs 4 | 5 | import os 6 | from 
collections import Counter 7 | 8 | from model import Model 9 | from utils import interact, Map 10 | 11 | class Optimizer(object): 12 | def __init__(self, args, model): 13 | self.args = args 14 | 15 | self.save_dir = os.path.join(self.args.save_dir, 'optim') 16 | os.makedirs(self.save_dir, exist_ok=True) 17 | 18 | if isinstance(model, Model): 19 | model = model.model 20 | 21 | # set base arguments 22 | kwargs_optimizer = { 23 | 'lr': args.lr, 24 | 'weight_decay': args.weight_decay 25 | } 26 | 27 | if args.optimizer == 'SGD': 28 | optimizer_class = optim.SGD 29 | kwargs_optimizer['momentum'] = args.momentum 30 | elif args.optimizer == 'ADAM': 31 | optimizer_class = optim.Adam 32 | kwargs_optimizer['betas'] = args.betas 33 | kwargs_optimizer['eps'] = args.epsilon 34 | elif args.optimizer == 'RMSPROP': 35 | optimizer_class = optim.RMSprop 36 | kwargs_optimizer['eps'] = args.epsilon 37 | 38 | # scheduler 39 | if args.scheduler == 'step': 40 | scheduler_class = lrs.MultiStepLR 41 | kwargs_scheduler = { 42 | 'milestones': args.milestones, 43 | 'gamma': args.gamma, 44 | } 45 | elif args.scheduler == 'plateau': 46 | scheduler_class = lrs.ReduceLROnPlateau 47 | kwargs_scheduler = { 48 | 'mode': 'min', 49 | 'factor': args.gamma, 50 | 'patience': 10, 51 | 'verbose': True, 52 | 'threshold': 0, 53 | 'threshold_mode': 'abs', 54 | 'cooldown': 10, 55 | } 56 | 57 | self.kwargs_optimizer = kwargs_optimizer 58 | self.scheduler_class = scheduler_class 59 | self.kwargs_scheduler = kwargs_scheduler 60 | 61 | def _get_optimizer(model): 62 | 63 | class _Optimizer(optimizer_class): 64 | def __init__(self, model, args, scheduler_class, kwargs_scheduler): 65 | trainable = filter(lambda x: x.requires_grad, model.parameters()) 66 | super(_Optimizer, self).__init__(trainable, **kwargs_optimizer) 67 | 68 | self.args = args 69 | 70 | self._register_scheduler(scheduler_class, kwargs_scheduler) 71 | 72 | def _register_scheduler(self, scheduler_class, kwargs_scheduler): 73 | self.scheduler = scheduler_class(self, **kwargs_scheduler) 74 | 75 | def schedule(self, metrics=None): 76 | if isinstance(self, lrs.ReduceLROnPlateau): 77 | self.scheduler.step(metrics) 78 | else: 79 | self.scheduler.step() 80 | 81 | def get_last_epoch(self): 82 | return self.scheduler.last_epoch 83 | 84 | def get_lr(self): 85 | return self.param_groups[0]['lr'] 86 | 87 | def get_last_lr(self): 88 | return self.scheduler.get_last_lr()[0] 89 | 90 | def state_dict(self): 91 | state_dict = super(_Optimizer, self).state_dict() # {'state': ..., 'param_groups': ...} 92 | state_dict['scheduler'] = self.scheduler.state_dict() 93 | 94 | return state_dict 95 | 96 | def load_state_dict(self, state_dict, epoch=None): 97 | # optimizer 98 | super(_Optimizer, self).load_state_dict(state_dict) # load 'state' and 'param_groups' only 99 | # scheduler 100 | self.scheduler.load_state_dict(state_dict['scheduler']) # should work for plateau or simple resuming 101 | 102 | reschedule = False 103 | if isinstance(self.scheduler, lrs.MultiStepLR): 104 | if self.args.milestones != list(self.scheduler.milestones) or self.args.gamma != self.scheduler.gamma: 105 | reschedule = True 106 | 107 | if reschedule: 108 | if epoch is None: 109 | if self.scheduler.last_epoch > 1: 110 | epoch = self.scheduler.last_epoch 111 | else: 112 | epoch = self.args.start_epoch - 1 113 | 114 | # if False: 115 | # # option 1. 
                        # if False:
                        #     # option 1. new scheduler
                        #     for i, group in enumerate(self.param_groups):
                        #         self.param_groups[i]['lr'] = group['initial_lr']  # reset optimizer learning rate to initial
                        #     # self.scheduler = None
                        #     self._register_scheduler(scheduler_class, kwargs_scheduler)

                        #     self.zero_grad()
                        #     self.step()
                        #     for _ in range(epoch):
                        #         self.scheduler.step()
                        #     self._step_count -= 1

                        # else:
                        # option 2. modify the existing scheduler:
                        # recompute lr = initial_lr * gamma ** (number of milestones already passed)
                        self.scheduler.milestones = Counter(self.args.milestones)
                        self.scheduler.gamma = self.args.gamma
                        for i, group in enumerate(self.param_groups):
                            self.param_groups[i]['lr'] = group['initial_lr']  # reset optimizer learning rate to initial
                            multiplier = 1
                            for milestone in self.scheduler.milestones:
                                if epoch >= milestone:
                                    multiplier *= self.scheduler.gamma

                            self.param_groups[i]['lr'] *= multiplier

            return _Optimizer(model, args, scheduler_class, kwargs_scheduler)

        # self.G = _get_optimizer(model.G)
        self.G = _get_optimizer(model)
        if model.D is not None:
            self.D = _get_optimizer(model.D)
        else:
            self.D = None

        self.load(args.load_epoch)

    def zero_grad(self):
        self.G.zero_grad()

    def step(self):
        self.G.step()

    def schedule(self, metrics=None):
        self.G.schedule(metrics)
        if self.D is not None:
            self.D.schedule(metrics)

    def get_last_epoch(self):
        return self.G.get_last_epoch()

    def get_lr(self):
        return self.G.get_lr()

    def get_last_lr(self):
        return self.G.get_last_lr()

    def state_dict(self):
        state_dict = Map()
        state_dict.G = self.G.state_dict()
        if self.D is not None:
            state_dict.D = self.D.state_dict()

        return state_dict.toDict()

    def load_state_dict(self, state_dict, epoch=None):
        state_dict = Map(**state_dict)
        self.G.load_state_dict(state_dict.G, epoch)
        if self.D is not None:
            self.D.load_state_dict(state_dict.D, epoch)

    def _save_path(self, epoch=None):
        epoch = epoch if epoch is not None else self.get_last_epoch()
        save_path = os.path.join(self.save_dir, 'optim-{:d}.pt'.format(epoch))

        return save_path

    def save(self, epoch=None):
        if epoch is None:
            epoch = self.G.scheduler.last_epoch
        torch.save(self.state_dict(), self._save_path(epoch))

    def load(self, epoch):
        if epoch > 0:
            print('Loading optimizer from {}'.format(self._save_path(epoch)))
            self.load_state_dict(torch.load(self._save_path(epoch), map_location=self.args.device), epoch=epoch)
        elif epoch == 0:
            pass
        else:
            raise NotImplementedError

        return
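# --- Usage sketch (not a file from this repo): driving the Optimizer wrapper
# from a training loop. Values mirror option.py defaults; `model` is a
# placeholder that is assumed to expose a `.D` discriminator attribute
# (possibly None), as Optimizer.__init__ requires.
from utils import Map
from optim import Optimizer

args = Map(save_dir='experiment/demo', lr=1e-4, weight_decay=1e-4,
           optimizer='ADAM', betas=(0.9, 0.999), epsilon=1e-8,
           scheduler='step', milestones=[20, 40, 60], gamma=0.5,
           load_epoch=0, start_epoch=1, device='cpu')

optimizer = Optimizer(args, model)   # `model` built elsewhere (see model/)

for epoch in range(args.start_epoch, args.start_epoch + 2):
    # ... forward pass and loss.backward() over the training loader ...
    optimizer.step()
    optimizer.zero_grad()
    optimizer.schedule()             # advances the G (and D) lr schedulers
    optimizer.save(epoch)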
--------------------------------------------------------------------------------
/optim/warm_multi_step_lr.py:
--------------------------------------------------------------------------------
from bisect import bisect_right
from torch.optim.lr_scheduler import _LRScheduler

# MultiStep learning rate scheduler with a linear warm-up phase
class WarmMultiStepLR(_LRScheduler):
    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, scale=1):
        if not list(milestones) == sorted(milestones):
            raise ValueError(
                'Milestones should be a list of increasing integers. Got {}'.format(milestones)
            )

        self.milestones = milestones
        self.gamma = gamma
        self.scale = scale

        # ramp linearly from base_lr / scale up to base_lr over the first
        # warmup_epochs epochs, then fall back to standard multi-step decay
        self.warmup_epochs = 5
        self.gradual = (self.scale - 1) / self.warmup_epochs
        super(WarmMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            return [
                base_lr * (1 + self.last_epoch * self.gradual) / self.scale
                for base_lr in self.base_lrs
            ]
        else:
            return [
                base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch)
                for base_lr in self.base_lrs
            ]
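# --- Usage sketch (not part of the repo): tracing the warm-up trajectory.
# The throwaway parameter and base learning rate are arbitrary assumptions.
import torch
from optim.warm_multi_step_lr import WarmMultiStepLR

param = torch.nn.Parameter(torch.zeros(1))
sgd = torch.optim.SGD([param], lr=1e-3)
sched = WarmMultiStepLR(sgd, milestones=[10, 20], gamma=0.5, scale=4)

for epoch in range(12):
    # epochs 0-4 ramp linearly from 1e-3/4; the full 1e-3 applies from
    # epoch 5; from epoch 10 the multi-step decay halves it to 5e-4.
    print(epoch, sgd.param_groups[0]['lr'])
    sgd.step()
    sched.step()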
--------------------------------------------------------------------------------
/option.py:
--------------------------------------------------------------------------------
"""optional argument parsing"""
# pylint: disable=C0103, C0301
import argparse
import datetime
import os
import re
import shutil
import time

import torch
import torch.distributed as dist
import torch.backends.cudnn as cudnn

from utils import interact
from utils import str2bool, int2str


# Training settings
parser = argparse.ArgumentParser(description='Selection and Cross Similarity for Event-Image Deep Stereo (SCSNet)')

# Device specifications
group_device = parser.add_argument_group('Device specs')
group_device.add_argument('--seed', type=int, default=123, help='random seed')
group_device.add_argument('--num_workers', type=int, default=7, help='the number of dataloader workers')
group_device.add_argument('--device_type', type=str, choices=('cpu', 'cuda'), default='cuda', help='device to run models')
group_device.add_argument('--device_index', type=int, default=0, help='device id to run models')
group_device.add_argument('--n_GPUs', type=int, default=1, help='the number of GPUs for training')
group_device.add_argument('--distributed', type=str2bool, default=False, help='use DistributedDataParallel instead of DataParallel for better speed')
group_device.add_argument('--launched', type=str2bool, default=False, help='set automatically when main.py is executed via launch.py; do not set it manually')

group_device.add_argument('--master_addr', type=str, default='127.0.0.1', help='master address for distributed')
group_device.add_argument('--master_port', type=int2str, default='8023', help='master port for distributed')
group_device.add_argument('--dist_backend', type=str, default='nccl', help='distributed backend')
group_device.add_argument('--init_method', type=str, default='env://', help='distributed init method URL to discover peers')
group_device.add_argument('--rank', type=int, default=0, help='rank of the distributed process (gpu id). 0 is the master process.')
group_device.add_argument('--world_size', type=int, default=1, help='world_size for distributed training (number of GPUs)')

# Data
group_data = parser.add_argument_group('Data specs')
group_data.add_argument('--data_root', type=str, default='~/Research/dataset', help='dataset root location')
group_data.add_argument('--dataset', type=str, default=None, help='training/validation/test dataset name, has priority if not None')
group_data.add_argument('--data_train', type=str, default='DSEC', help='training dataset name')
group_data.add_argument('--data_val', type=str, default='DSEC', help='validation dataset name')
group_data.add_argument('--data_test', type=str, default='DSEC', help='test dataset name')
group_data.add_argument('--blur_key', type=str, default='blur_gamma', choices=('blur', 'blur_gamma'), help='blur type from camera response function for GOPRO_Large dataset')
group_data.add_argument('--rgb_range', type=int, default=255, help='maximum RGB pixel value (pixels range from 0 to rgb_range)')
group_data.add_argument('--split', type=int, default=1, help='dataset split index')

# Model
group_model = parser.add_argument_group('Model specs')
group_model.add_argument('--model', type=str, default='MSResNet', help='model architecture')
group_model.add_argument('--disp_model', type=str, default='GWCNet', help='disparity estimation model architecture')
group_model.add_argument('--pretrained', type=str, default='', help='pretrained model location')


# amp
group_amp = parser.add_argument_group('AMP specs')
group_amp.add_argument('--amp', type=str2bool, default=False, help='use automatic mixed precision training')
group_amp.add_argument('--init_scale', type=float, default=1024., help='initial loss scale')

# Training
group_train = parser.add_argument_group('Training specs')
group_train.add_argument('--patch_size', type=int, default=256, help='training patch size')
group_train.add_argument('--batch_size', type=int, default=16, help='input batch size for training')
group_train.add_argument('--split_batch', type=int, default=1, help='split a minibatch into smaller chunks')
group_train.add_argument('--augment', type=str2bool, default=True, help='train with data augmentation')

# Testing
group_test = parser.add_argument_group('Testing specs')
group_test.add_argument('--validate_every', type=int, default=1, help='do validation at every N epochs')
group_test.add_argument('--test_every', type=int, default=10, help='do test at every N epochs')
# group_test.add_argument('--chop', type=str2bool, default=False, help='memory-efficient forward')
# group_test.add_argument('--self_ensemble', type=str2bool, default=False, help='self-ensembled testing')

# Action
group_action = parser.add_argument_group('Source behavior')
group_action.add_argument('--do_train', type=str2bool, default=True, help='do train the model')
group_action.add_argument('--do_validate', type=str2bool, default=True, help='do validate the model')
group_action.add_argument('--do_test', type=str2bool, default=True, help='do test the model')
group_action.add_argument('--demo', type=str2bool, default=False, help='demo')
group_action.add_argument('--demo_input_dir', type=str, default='', help='demo input directory')
group_action.add_argument('--demo_output_dir', type=str, default='', help='demo output directory')

# Optimization
group_optim = parser.add_argument_group('Optimization specs')
group_optim.add_argument('--lr', type=float, default=1e-3, help='learning rate')
group_optim.add_argument('--milestones', type=int, nargs='+', default=[20, 40, 60, 80, 100, 120], help='epochs at which the learning rate decays')
group_optim.add_argument('--scheduler', default='step', choices=('step', 'plateau'), help='learning rate scheduler type')
group_optim.add_argument('--gamma', type=float, default=0.5, help='learning rate decay factor for step decay')
group_optim.add_argument('--optimizer', default='ADAM', choices=('SGD', 'ADAM', 'RMSprop'), help='optimizer to use (SGD | ADAM | RMSprop)')
group_optim.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
group_optim.add_argument('--betas', type=float, nargs=2, default=(0.9, 0.999), help='ADAM betas')
group_optim.add_argument('--epsilon', type=float, default=1e-8, help='ADAM epsilon')
group_optim.add_argument('--weight_decay', type=float, default=0.0001, help='weight decay')

# Loss
group_loss = parser.add_argument_group('Loss specs')
group_loss.add_argument('--loss', type=str, default='1*L1', help='loss function configuration')
group_loss.add_argument('--metric', type=str, default='PSNR,SSIM', help='metric function configuration. ex) None | PSNR | SSIM | PSNR,SSIM')
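# --- Illustration (not part of option.py): the --loss string above encodes
# weight*name terms joined by '+', e.g. '1*L1+1*LPIPS' in the README commands.
# The real parser lives in the loss package, which is outside this listing,
# so this helper is only an assumed re-implementation of the format.
def parse_loss_spec(spec):
    terms = []
    for term in spec.split('+'):
        weight, name = term.split('*')
        terms.append((float(weight), name))
    return terms

assert parse_loss_spec('1*L1+1*LPIPS') == [(1.0, 'L1'), (1.0, 'LPIPS')]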
# Logging
group_log = parser.add_argument_group('Logging specs')
group_log.add_argument('--save_dir', type=str, default='', help='subdirectory to save experiment logs')
group_log.add_argument('--load_dir', type=str, default=None, help='subdirectory to load experiment logs')
group_log.add_argument('--start_epoch', type=int, default=-1, help='(re)starting epoch number')
group_log.add_argument('--end_epoch', type=int, default=1000, help='ending epoch number')
group_log.add_argument('--load_epoch', type=int, default=-1, help='epoch number to load model (start_epoch-1 for training, start_epoch for testing)')
group_log.add_argument('--save_every', type=int, default=10, help='save model/optimizer at every N epochs')
group_log.add_argument('--save_results', type=str, default='part', choices=('none', 'part', 'all'), help='save none/part/all of result images')

# Debugging
group_debug = parser.add_argument_group('Debug specs')
group_debug.add_argument('--stay', type=str2bool, default=False, help='stay at interactive console after trainer initialization')

args = parser.parse_args()

args.data_root = os.path.expanduser(args.data_root)  # recognize home directory
now = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
if args.save_dir == '':
    args.save_dir = now
args.save_dir = os.path.join('experiment', args.save_dir)
os.makedirs(args.save_dir, exist_ok=True)

if args.start_epoch < 0:  # start from scratch or continue from the last epoch
    # check if there are any models saved before
    model_dir = os.path.join(args.save_dir, 'models')
    model_prefix = 'model-'
    if os.path.exists(model_dir):
        model_list = [name for name in os.listdir(model_dir) if name.startswith(model_prefix)]
        last_epoch = 0
        for name in model_list:
            epoch_number = int(re.findall(r'\d+', name)[0])  # example model name: model-100.pt
            if last_epoch < epoch_number:
                last_epoch = epoch_number

        args.start_epoch = last_epoch + 1
    else:
        # train from scratch
        args.start_epoch = 1
elif args.start_epoch == 0:
    # remove the existing directory and start over
    if args.rank == 0:  # only the master process touches the filesystem
        shutil.rmtree(args.save_dir, ignore_errors=True)
        os.makedirs(args.save_dir, exist_ok=True)
    args.start_epoch = 1

if args.load_epoch < 0:  # load_epoch == start_epoch when doing a post-training test for a specific epoch
    args.load_epoch = args.start_epoch - 1

if args.pretrained:
    if args.start_epoch <= 1:
        args.pretrained = os.path.join('experiment', args.pretrained)
    else:
        print('starting from epoch {}! ignoring pretrained model path..'.format(args.start_epoch))
        args.pretrained = ''

argname = os.path.join(args.save_dir, 'args.pt')
argname_txt = os.path.join(args.save_dir, 'args.txt')

if args.dataset is not None:
    args.data_train = args.dataset
    args.data_val = args.dataset if args.dataset not in ['GOPRO_Large', 'synthetic', 'synthetic_event', 'indoor_flying_1', 'indoor_flying_2', 'indoor_flying_3'] else None
    args.data_test = args.dataset


if args.demo_input_dir:
    args.demo = True

if args.demo:
    assert os.path.basename(args.save_dir) != now, 'You should specify a pretrained directory by setting --save_dir SAVE_DIR'

    args.data_train = ''
    args.data_val = ''
    args.data_test = ''

    args.do_train = False
    args.do_validate = False
    args.do_test = False

    assert len(args.demo_input_dir) > 0, 'Please specify demo_input_dir!'
    args.demo_input_dir = os.path.expanduser(args.demo_input_dir)
    if args.demo_output_dir:
        args.demo_output_dir = os.path.expanduser(args.demo_output_dir)

    args.save_results = 'all'

if args.amp:
    args.precision = 'single'  # model parameters should stay in fp32

if args.seed < 0:
    args.seed = int(time.time())

# save arguments
if args.rank == 0:
    torch.save(args, argname)
    with open(argname_txt, 'a') as file:
        file.write('execution at {}\n'.format(now))

        for key in args.__dict__:
            file.write(key + ': ' + str(args.__dict__[key]) + '\n')

        file.write('\n')

# device and type
if args.device_type == 'cuda' and not torch.cuda.is_available():
    raise Exception('GPU not available!')

if not args.distributed:
    args.rank = 0

def setup(args):
    cudnn.benchmark = True

    if args.distributed:
        os.environ['MASTER_ADDR'] = args.master_addr
        os.environ['MASTER_PORT'] = args.master_port

        args.device_index = args.rank
        args.world_size = args.n_GPUs  # single-node training is assumed

        # initialize the process group
        dist.init_process_group(args.dist_backend, init_method=args.init_method, rank=args.rank, world_size=args.world_size)

    args.device = torch.device(args.device_type, args.device_index) if args.device_type == 'cuda' else torch.device(args.device_type)
    args.dtype = torch.float32

    # set seed for processes (distributed: different seed for each process)
    # model parameters are synchronized explicitly at initialization
    torch.manual_seed(args.seed)
    if args.device_type == 'cuda':
        torch.cuda.set_device(args.device)
        if args.rank == 0:
            torch.cuda.manual_seed_all(args.seed)

    return args

def cleanup(args):
    if args.distributed:
        dist.destroy_process_group()
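# --- Usage sketch (not a file from this repo): setup()/cleanup() are meant to
# bracket each worker process. The real wiring lives in main.py / launch.py;
# `main_worker` and the spawn logic below are invented for illustration.
import torch.multiprocessing as mp
from option import args, setup, cleanup

def main_worker(rank, args):
    args.rank = rank
    args = setup(args)   # init process group (if distributed), device, seeds
    # ... build dataloaders, model and optimizer; run the epoch loop ...
    cleanup(args)        # tear down the process group

if __name__ == '__main__':
    if args.distributed:
        mp.spawn(main_worker, args=(args,), nprocs=args.n_GPUs)
    else:
        main_worker(0, args)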
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import readline
import rlcompleter
readline.parse_and_bind("tab: complete")
import code
import pdb

import time
import argparse
import os
import imageio
import torch
import torch.multiprocessing as mp
from matplotlib import pyplot as plt
from mpl_toolkits import axes_grid1
import numpy as np
import cv2

# debugging tools
def interact(local=None):
    """Interactive console with autocompletion. Useful for debugging:
    interact(locals())
    """
    if local is None:
        local = dict(globals(), **locals())

    readline.set_completer(rlcompleter.Completer(local).complete)
    code.interact(local=local)

def set_trace(local=None):
    """debugging with pdb
    """
    if local is None:
        local = dict(globals(), **locals())

    pdb.Pdb.complete = rlcompleter.Completer(local).complete
    pdb.set_trace()

# timer
class Timer():
    """Brought from https://github.com/thstkdgus35/EDSR-PyTorch
    """
    def __init__(self):
        self.acc = 0
        self.tic()

    def tic(self):
        # start (or restart) the stopwatch
        self.t0 = time.time()

    def toc(self):
        # seconds elapsed since the last tic()
        return time.time() - self.t0

    def hold(self):
        # accumulate the elapsed time
        self.acc += self.toc()

    def release(self):
        # return the accumulated time and reset the counter
        ret = self.acc
        self.acc = 0

        return ret

    def reset(self):
        self.acc = 0
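# --- Usage sketch (not part of utils.py): the accumulate/release pattern of
# the Timer above; the loop body is a placeholder.
from utils import Timer

data_timer = Timer()
for step in range(100):
    data_timer.tic()
    # ... fetch a batch ...
    data_timer.hold()            # add the elapsed time to the accumulator
print(data_timer.release())      # total seconds; the accumulator resets to 0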
# argument parser type casting functions
def str2bool(val):
    """enable default constant true arguments"""
    # https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
    if isinstance(val, bool):
        return val
    elif val.lower() == 'true':
        return True
    elif val.lower() == 'false':
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected')

def int2str(val):
    """convert int to str for environment variable related arguments"""
    if isinstance(val, int):
        return str(val)
    elif isinstance(val, str):
        return val
    else:
        raise argparse.ArgumentTypeError('number value expected')


# image saver using a multiprocessing queue
class MultiSaver():
    def __init__(self, result_dir=None):
        self.queue = None
        self.process = None
        self.result_dir = result_dir

    def begin_background(self):
        self.queue = mp.Queue()

        def t(queue):
            while True:
                img, name = queue.get()  # block until an item (or sentinel) arrives
                if not name:             # the (None, None) sentinel terminates the worker
                    return
                try:
                    basename, ext = os.path.splitext(name)
                    if ext != '.png':
                        name = '{}.png'.format(basename)
                    imageio.imwrite(name, img)
                except Exception as e:
                    print(e)

        worker = lambda: mp.Process(target=t, args=(self.queue,), daemon=False)
        cpu_count = min(8, mp.cpu_count() - 1)
        self.process = [worker() for _ in range(cpu_count)]
        for p in self.process:
            p.start()

    def end_background(self):
        if self.queue is None:
            return

        for _ in self.process:
            self.queue.put((None, None))

    def join_background(self):
        if self.queue is None:
            return

        while not self.queue.empty():
            time.sleep(0.5)

        for p in self.process:
            p.join()

        self.queue = None

    def save_image(self, output, save_names, result_dir=None):
        result_dir = self.result_dir if self.result_dir is not None else result_dir
        if result_dir is None:
            raise Exception('no result dir specified!')

        if self.queue is None:
            try:
                self.begin_background()
            except Exception as e:
                print(e)
                return

        # assume NCHW format
        if output.ndim == 2:
            output = output.expand([1, 1] + list(output.shape))
        elif output.ndim == 3:
            output = output.expand([1] + list(output.shape))

        for output_img, save_name in zip(output, save_names):
            # assume image range [0, 255]
            output_img = output_img.add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()

            save_name = os.path.join(result_dir, save_name)
            save_dir = os.path.dirname(save_name)
            os.makedirs(save_dir, exist_ok=True)

            self.queue.put((output_img, save_name))

        return

    def save_disp(self, output, save_names, result_dir=None):
        result_dir = self.result_dir if self.result_dir is not None else result_dir
        if result_dir is None:
            raise Exception('no result dir specified!')

        if self.queue is None:
            try:
                self.begin_background()
            except Exception as e:
                print(e)
                return

        # assume NCHW format
        if output.ndim == 2:
            output = output.expand([1, 1] + list(output.shape))
        elif output.ndim == 3:
            output = output.expand([1] + list(output.shape))

        for output_img, save_name in zip(output, save_names):
            # color-map the disparity with robust min/max taken from the finite values
            output_img = output_img.squeeze(0).cpu()
            figure = plt.figure()
            noninf_mask = output_img != float('inf')

            minimum_value = np.quantile(output_img[noninf_mask], 0.001)
            maximum_value = np.quantile(output_img[noninf_mask], 0.999)

            plot = plt.imshow(
                output_img.numpy(), 'jet', vmin=minimum_value, vmax=maximum_value)
            save_name = os.path.join(result_dir, save_name)
            save_dir = os.path.dirname(save_name)
            os.makedirs(save_dir, exist_ok=True)
            # self.add_scaled_colorbar(plot)
            plot.axes.get_xaxis().set_visible(False)
            plot.axes.get_yaxis().set_visible(False)
            figure.savefig(save_name, bbox_inches='tight', dpi=200)
            plt.close()

        return

    def save_disp_test(self, output, save_names, result_dir=None):
        result_dir = self.result_dir if self.result_dir is not None else result_dir
        if result_dir is None:
            raise Exception('no result dir specified!')

        if self.queue is None:
            try:
                self.begin_background()
            except Exception as e:
                print(e)
                return

        # assume NCHW format
        if output.ndim == 2:
            output = output.expand([1, 1] + list(output.shape))
        elif output.ndim == 3:
            output = output.expand([1] + list(output.shape))

        for output_img, save_name in zip(output, save_names):
            save_name = os.path.join(result_dir, save_name)
            save_dir = os.path.dirname(save_name)
            os.makedirs(save_dir, exist_ok=True)

            # the DSEC benchmark expects 16-bit PNG disparity scaled by 256
            output_img = output_img.squeeze(0).cpu()
            output_img = np.array(output_img * 256, dtype=np.uint16)
            cv2.imwrite(str(save_name), output_img)

        return

    def add_scaled_colorbar(self, plot, aspect=20, pad_fraction=0.5, **kwargs):
        """Adds a scaled colorbar to an existing plot."""
        divider = axes_grid1.make_axes_locatable(plot.axes)
        width = axes_grid1.axes_size.AxesY(plot.axes, aspect=1. / aspect)
        pad = axes_grid1.axes_size.Fraction(pad_fraction, width)
        current_axis = plt.gca()
        cax = divider.append_axes("right", size=width, pad=pad)
        plt.sca(current_axis)
        return plot.axes.figure.colorbar(plot, cax=cax, **kwargs)
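# --- Usage sketch (not part of utils.py): MultiSaver writes images from
# background worker processes, so the queue must be flushed before exit.
# The fake batch and output paths are illustrative assumptions; under a
# 'spawn' start method this should run inside an `if __name__ == '__main__':`
# guard.
import torch
from utils import MultiSaver

saver = MultiSaver(result_dir='results/demo')
batch = torch.rand(4, 3, 64, 64) * 255                   # NCHW in [0, 255]
names = ['img_{:02d}.png'.format(i) for i in range(4)]
saver.save_image(batch, names)

saver.end_background()    # enqueue one (None, None) sentinel per worker
saver.join_background()   # block until every queued image is written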
class Map(dict):
    """
    https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary
    Example:
    m = Map({'first_name': 'Eduardo'}, last_name='Pool', age=24, sports=['Soccer'])
    """
    def __init__(self, *args, **kwargs):
        super(Map, self).__init__(*args, **kwargs)
        for arg in args:
            if isinstance(arg, dict):
                for k, v in arg.items():
                    self[k] = v

        if kwargs:
            for k, v in kwargs.items():
                self[k] = v

    def __getattr__(self, attr):
        return self.get(attr)

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        super(Map, self).__setitem__(key, value)
        self.__dict__.update({key: value})

    def __delattr__(self, item):
        self.__delitem__(item)

    def __delitem__(self, key):
        super(Map, self).__delitem__(key)
        del self.__dict__[key]

    def toDict(self):
        return self.__dict__
--------------------------------------------------------------------------------