├── .gitignore ├── 202102_008_221_35_3200_108_00_Image_input.png ├── README.md ├── data ├── __init__.py ├── anno_ISR │ ├── test.txt │ ├── test_qualitative_pairs.txt │ ├── test_quantitative_pairs_10x.txt │ ├── train.txt │ ├── val.txt │ └── val_pairs.txt ├── anno_RSR │ ├── AnyLighttest_pairs.txt │ ├── AnyLighttest_pairs_qualitative.txt │ ├── AnyLightval_pairs.txt │ ├── test.txt │ ├── train.txt │ └── val.txt ├── anno_VIDIT │ └── any2any │ │ ├── AnyLight_test_pairs.txt │ │ ├── AnyLight_test_pairs_qualitative.txt │ │ ├── AnyLight_val_pairs.txt │ │ ├── test.txt │ │ ├── train.txt │ │ └── val.txt ├── base_dataset.py ├── multi_illumination │ ├── crop_resize.py │ ├── test.txt │ ├── test_qualitative.txt │ ├── train.txt │ └── val.txt ├── relighting_dataset_single_image.py ├── relighting_dataset_single_image_multilum.py ├── relighting_dataset_single_image_rsr.py ├── relighting_dataset_single_image_test.py └── relighting_dataset_single_image_vidit.py ├── models ├── __init__.py ├── base_model.py ├── drn_model.py ├── models.py ├── networks.py ├── networks_custom_func.py ├── networks_discriminator.py ├── networks_intrinsic.py ├── networks_one_to_one_rep.py └── two_stage_model.py ├── options ├── __init__.py ├── base_options.py ├── test_qualitative_options.py ├── test_quantitative_options.py ├── train_options_isr.py ├── train_options_isr_drn.py ├── train_options_isr_pix2pix.py ├── train_options_multilum_ours_f.py ├── train_options_rsr_ours_f.py └── train_options_vidit_ours_f.py ├── requirements.txt ├── test_qualitative.py ├── test_qualitative.txt ├── test_qualitative_animation.py ├── test_quantitative.py ├── train.py └── util ├── html.py ├── k_to_rgb.py ├── metric.py ├── pan_tilt_dial.png ├── util.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | __pycache__/ 3 | -------------------------------------------------------------------------------- /202102_008_221_35_3200_108_00_Image_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVC-CIC/DeepIntrinsicRelighting/b38a04ff9caecddd4e3551950cf6aaf5b17d960e/202102_008_221_35_3200_108_00_Image_input.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Relighting from a Single Image: Datasets and Deep Intrinsic-based Architecture 2 | Yixiong Yang, Hassan Ahmed Sial, Ramon Baldrich, Maria Vanrell 3 | 4 | [Arxiv](https://arxiv.org/abs/2409.18770) 5 | 6 | ## Description 7 | This repo contains the official code, data, and video results for the paper. 8 | 9 | 10 | https://github.com/liulisixin/save_videos/assets/49985369/c375b269-032e-4e2c-9685-18fa4f5e7a11 11 | 12 | ## Setup 13 | The code was tested on PyTorch 1.10.1, but it is not strict; other versions should also work. 
14 | ``` 15 | conda create --name DIR python=3.7 16 | conda activate DIR 17 | pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html 18 | pip install -r requirements.txt 19 | ``` 20 | 21 | 22 | ## Datasets 23 | ### ISR: Intrinsic Scene Relighting Dataset 24 | [Reflectance](https://cvcuab-my.sharepoint.com/:u:/g/personal/yixiong_cvc_uab_cat/Ed4PMW9cxJBKipa2GNlSvSQBD7try__Gz6Sk76Qbwcx7nA?e=lHvddG) 25 | [Shading](https://cvcuab-my.sharepoint.com/:u:/g/personal/yixiong_cvc_uab_cat/EdBYF2GUO35Hpsm_PBpEMsQBjstij4hOOn2HlxQ4ekDwqw?e=aQukUd) 26 | [Image](https://cvcuab-my.sharepoint.com/:u:/g/personal/yixiong_cvc_uab_cat/EfeCiWYw_P9BnVqub6ii8FYBjtRVyMwko3E-av8WZTTo1Q?e=EQdM6E) 27 | 28 | You only need to download the Reflectance and Shading; the code can compute the Image from them. 29 | 30 | **Note**: Each picture is named `{Index of scene (part 1)}_{Index of scene (part 2)}_{pan}_{tilt}_{light temperature}_{not use}`. The first two numbers identify the scene, while pan, tilt, and light temperature are the lighting parameters. 31 | 32 | ### RSR: Real Scene Relighting Dataset 33 | [Download](https://cvcuab-my.sharepoint.com/:u:/g/personal/yixiong_cvc_uab_cat/ETWcj5yBKgJLqUZDsT9Q39QBJ8GUJYEQuzNWpV5FS2lPRg?e=jEYI5t) 34 | 35 | The RSR dataset was created in our lab environment. The dataset above is in 256×256 resolution, as used in our paper. If you need high-resolution or raw images, please let us know. 36 | 37 | **Updated on Nov 23, 2024**: Here is the dataset at the original resolution [Download](https://cvcuab-my.sharepoint.com/:u:/g/personal/yixiong_cvc_uab_cat/ET7MLf3u27dMktC52VaMGL4Bu203mKfAcniBaPKhGktVIw?e=ZYPAjX). 38 | 39 | **Note**: 40 | 1. Each picture is named `{index of picture}_{index of group (different scene or different view)}_{pan}_{tilt}_{R}_{G}_{B}_{index of scene}_{not use}_{index of view}_{index of light position}`. The quantities that matter are pan, tilt, and color (RGB), which are the parameters of the light (see the small parsing sketch at the end of the Train section below). 41 | 2. The lights are ordered as follows: 42 | 43 | | 5 | 4 | 3 | 44 | | --- | --- | --- | 45 | | 6 | 1 | 2 | 46 | | 7 | 8 | 9 | 47 | 48 | 49 | 50 | ## Train 51 | In the code, the dataset paths can be modified via self.server_root in options/base_options.py. 52 | 53 | Train from scratch on the ISR dataset: 54 | ```bash 55 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 7777 train.py isr # For ISR dataset 56 | ``` 57 | Continue training on the other datasets (place the pre-trained model in checkpoints/{exp} with the name base_{}): 58 | ```bash 59 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 7777 train.py rsr_ours_f # For RSR dataset 60 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 7777 train.py vidit_ours_f # For VIDIT dataset 61 | python -m torch.distributed.launch --nproc_per_node=1 --master_port 7777 train.py multilum_ours_f # For Multi-illumination dataset 62 | ``` 63 | **Note**: 64 | 1. When using the VIDIT dataset, place all images into a single folder (named VIDIT_full) to create a complete version. 65 | 2. When using the Multi-illumination dataset, crop and resize the images to a resolution of 256x256. The code for this step is in ./data/multi_illumination/crop_resize.py.
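As a quick reference for the file-naming notes in the Datasets section above, here is a minimal sketch (not part of the repository; the helper names are hypothetical) of how an ISR or RSR file name can be split into its lighting parameters. The field order follows those notes; the dataset class in data/relighting_dataset_single_image.py performs the equivalent parsing internally when it builds the light-condition vector.
```python
# Minimal illustrative sketch -- not part of the repository.
# parse_isr_name / parse_rsr_name are hypothetical helpers; the field order follows the README notes.
import os

def parse_isr_name(filename):
    # e.g. '200023_086_35_40_3700_186.png'
    fields = os.path.splitext(filename)[0].split('_')
    scene_id = fields[0] + '_' + fields[1]        # scene index (two parts)
    pan, tilt = int(fields[2]), int(fields[3])    # light direction in degrees
    temperature = int(fields[4])                  # light temperature in Kelvin
    return scene_id, pan, tilt, temperature       # the last field is not used

def parse_rsr_name(filename):
    # e.g. '01193_0034_225_35_255_255_255_5_225_1_5.jpg'
    fields = os.path.splitext(filename)[0].split('_')
    pan, tilt = int(fields[2]), int(fields[3])    # light direction in degrees
    rgb = tuple(int(c) for c in fields[4:7])      # light color
    return pan, tilt, rgb

print(parse_isr_name('200023_086_35_40_3700_186.png'))                # ('200023_086', 35, 40, 3700)
print(parse_rsr_name('01193_0034_225_35_255_255_255_5_225_1_5.jpg'))  # (225, 35, (255, 255, 255))
```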
66 | 67 | 68 | ## Test 69 | You can download the checkpoints from this link [Checkpoints](https://cvcuab-my.sharepoint.com/:u:/g/personal/yixiong_cvc_uab_cat/EZGyWdlf5nhIpzMeCDpp2DwBsnLfmExrl9NeO_KY8w60Ow?e=540jgE), and run the quantitative or qualitative tests. 70 | ```bash 71 | python test_quantitative.py --name {$exp} 72 | ``` 73 | 74 | ```bash 75 | python test_qualitative.py --name {$exp} 76 | ``` 77 | The name can be 'exp_isr', 'exp_rsr_ours_f', 'exp_vidit_ours_f', or 'exp_multilum_ours_f'. 78 | 79 | The video results can be generated by: 80 | ```bash 81 | python test_qualitative_animation.py --name exp_isr 82 | ``` 83 | 84 | 85 | ## More video results 86 | 87 | https://github.com/liulisixin/save_videos/assets/49985369/6dc22788-87ee-4836-8677-7d4f4d8c3b7c 88 | 89 | https://github.com/liulisixin/save_videos/assets/49985369/57a1efd7-7f8b-460c-84a0-9f3dd596da56 90 | 91 | 92 | ## Citation 93 | If you find this repository helpful for your project, please cite it as follows :) 94 | ``` 95 | @article{yang2024relighting, 96 | title={Relighting from a Single Image: Datasets and Deep Intrinsic-based Architecture}, 97 | author={Yang, Yixiong and Sial, Hassan Ahmed and Baldrich, Ramon and Vanrell, Maria}, 98 | journal={arXiv preprint arXiv:2409.18770}, 99 | year={2024} 100 | } 101 | ``` 102 | 103 | ## Acknowledgments 104 | Some code is borrowed from [pix2pix](https://github.com/phillipi/pix2pix). Thanks for their great work. 105 | 106 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes all the modules related to data loading and preprocessing. 2 | 3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. 4 | You need to implement four functions: 5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 6 | -- <__len__>: return the size of dataset. 7 | -- <__getitem__>: get a data point from data loader. 8 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options. 9 | 10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'. 11 | See our template dataset class 'template_dataset.py' for more details. 12 | """ 13 | import importlib 14 | import torch.utils.data 15 | from data.base_dataset import BaseDataset 16 | # from data.relighting_collate import default_collate as relighting_collate 17 | 18 | 19 | def find_dataset_using_name(dataset_name): 20 | """Import the module "data/[dataset_name]_dataset.py". 21 | 22 | In the file, the class called DatasetNameDataset() will 23 | be instantiated. It has to be a subclass of BaseDataset, 24 | and it is case-insensitive. 25 | """ 26 | dataset_filename = "data." + dataset_name + "_dataset" 27 | datasetlib = importlib.import_module(dataset_filename) 28 | 29 | dataset = None 30 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 31 | for name, cls in datasetlib.__dict__.items(): 32 | if name.lower() == target_dataset_name.lower() \ 33 | and issubclass(cls, BaseDataset): 34 | dataset = cls 35 | 36 | if dataset is None: 37 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase."
% (dataset_filename, target_dataset_name)) 38 | 39 | return dataset 40 | 41 | 42 | def get_option_setter(dataset_name): 43 | """Return the static method of the dataset class.""" 44 | dataset_class = find_dataset_using_name(dataset_name) 45 | return dataset_class.modify_commandline_options 46 | 47 | 48 | def create_dataset(opt, validation=False): 49 | """Create a dataset given the option. 50 | 51 | This function wraps the class CustomDatasetDataLoader. 52 | This is the main interface between this package and 'train.py'/'test.py' 53 | 54 | Example: 55 | >>> from data import create_dataset 56 | >>> dataset = create_dataset(opt) 57 | """ 58 | data_loader = CustomDatasetDataLoader(opt, validation) 59 | dataset = data_loader.load_data() 60 | return dataset 61 | 62 | 63 | class CustomDatasetDataLoader(): 64 | """Wrapper class of Dataset class that performs multi-threaded data loading""" 65 | 66 | def __init__(self, opt, validation=False): 67 | """Initialize this class 68 | 69 | Step 1: create a dataset instance given the name [dataset_mode] 70 | Step 2: create a multi-threaded data loader. 71 | """ 72 | self.opt = opt 73 | 74 | if opt.dataset_mode == 'relighting_single_image': 75 | if validation: 76 | # relighting_single_image_test 77 | from data.relighting_dataset_single_image_test import RelightingDatasetSingleImageTest 78 | self.dataset = RelightingDatasetSingleImageTest(opt, validation=validation) 79 | if_shuffle = False 80 | else: 81 | from data.relighting_dataset_single_image import RelightingDatasetSingleImage 82 | self.dataset = RelightingDatasetSingleImage(opt) 83 | if_shuffle = not opt.serial_batches 84 | elif opt.dataset_mode == 'relighting_single_image_test': 85 | from data.relighting_dataset_single_image_test import RelightingDatasetSingleImageTest 86 | self.dataset = RelightingDatasetSingleImageTest(opt) 87 | if_shuffle = False 88 | elif opt.dataset_mode == 'relighting_single_image_multilum': 89 | from data.relighting_dataset_single_image_multilum import RelightingDatasetSingleImageMultilum 90 | self.dataset = RelightingDatasetSingleImageMultilum(opt, validation=validation) 91 | if_shuffle = not opt.serial_batches and not validation 92 | elif opt.dataset_mode == 'relighting_single_image_rsr': 93 | from data.relighting_dataset_single_image_rsr import RelightingDatasetSingleImageRSR 94 | self.dataset = RelightingDatasetSingleImageRSR(opt, validation=validation) 95 | if_shuffle = not opt.serial_batches and not validation 96 | elif opt.dataset_mode == 'relighting_single_image_vidit': 97 | if validation: 98 | # relighting_single_image_test 99 | from data.relighting_dataset_single_image_vidit import RelightingDatasetSingleImageVidit 100 | self.dataset = RelightingDatasetSingleImageVidit(opt, validation=validation) 101 | if_shuffle = False 102 | else: 103 | from data.relighting_dataset_single_image_vidit import RelightingDatasetSingleImageVidit 104 | self.dataset = RelightingDatasetSingleImageVidit(opt) 105 | if_shuffle = not opt.serial_batches 106 | else: 107 | raise Exception("Can not find dataset!") 108 | 109 | # drop_last 110 | if opt.isTrain and not validation and hasattr(opt, 'dataset_drop_last'): 111 | if_drop_last = opt.dataset_drop_last 112 | else: 113 | if_drop_last = False 114 | 115 | print("dataset [%s] was created (shuffle=%s, drop_last=%s)" % ( 116 | type(self.dataset).__name__, if_shuffle, if_drop_last)) 117 | 118 | if opt.parallel_method == 'DataParallel': 119 | self.dataloader = torch.utils.data.DataLoader( 120 | self.dataset, 121 | batch_size=opt.batch_size, 122 | 
shuffle=if_shuffle, 123 | num_workers=int(opt.num_threads), 124 | drop_last=if_drop_last) 125 | elif opt.parallel_method == 'DistributedDataParallel': 126 | distributed_sampler = torch.utils.data.distributed.DistributedSampler( 127 | self.dataset, 128 | num_replicas=opt.world_size, 129 | rank=opt.gpu_ids, 130 | shuffle=if_shuffle 131 | ) 132 | self.dataloader = torch.utils.data.DataLoader( 133 | dataset=self.dataset, 134 | batch_size=opt.batch_size, 135 | # shuffle=if_shuffle, 136 | num_workers=int(opt.num_threads), 137 | drop_last=if_drop_last, 138 | sampler=distributed_sampler) 139 | 140 | else: 141 | raise Exception("Parallel method type error!") 142 | 143 | 144 | def load_data(self): 145 | return self 146 | 147 | def __len__(self): 148 | """Return the number of data in the dataset""" 149 | return min(len(self.dataset), self.opt.max_dataset_size) 150 | 151 | def __iter__(self): 152 | """Return a batch of data""" 153 | for i, data in enumerate(self.dataloader): 154 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 155 | break 156 | yield data 157 | 158 | -------------------------------------------------------------------------------- /data/anno_ISR/test_qualitative_pairs.txt: -------------------------------------------------------------------------------- 1 | 200023_086_35_40_3700_186.png 200023_086_169_25_3700_186.png 2 | 200132_000_154_25_2400_154.png 200132_000_71_0_2400_154.png 3 | 200135_037_318_20_3400_182.png 200135_037_257_35_3400_182.png 4 | 200199_041_215_25_3600_124.png 200199_041_359_0_3600_124.png 5 | 200251_086_97_45_5600_106.png 200251_086_267_5_5600_106.png 6 | 200266_039_293_20_4000_104.png 200266_039_210_10_4000_104.png 7 | 200578_091_71_35_3900_111.png 200578_091_251_5_3900_111.png 8 | 200773_079_282_30_5600_170.png 200773_079_169_25_5600_170.png 9 | 200774_017_272_10_2600_116.png 200774_017_236_0_2600_116.png 10 | 200796_045_210_15_5600_115.png 200796_045_25_15_5600_115.png 11 | 200878_079_287_10_5000_144.png 200878_079_61_45_5000_144.png 12 | 200946_081_87_15_4400_102.png 200946_081_51_30_4400_102.png 13 | 201020_041_226_30_3900_159.png 201020_041_154_10_3900_159.png 14 | 201078_024_133_10_6300_142.png 201078_024_36_40_6300_142.png 15 | 201147_060_293_15_5300_172.png 201147_060_71_35_5300_172.png 16 | 201199_053_349_45_4400_121.png 201199_053_15_35_4400_121.png 17 | 201306_017_169_5_4500_117.png 201306_017_108_15_4500_117.png 18 | 201380_003_71_30_4100_158.png 201380_003_72_50_4100_158.png 19 | 201413_001_221_10_3600_126.png 201413_001_174_5_3600_126.png 20 | 201490_071_354_50_3000_160.png 201490_071_97_35_3000_160.png 21 | 201731_096_41_15_6400_194.png 201731_096_82_20_6400_194.png 22 | 201856_053_66_5_5800_136.png 201856_053_287_20_5800_136.png 23 | 202102_008_221_35_3200_108.png 202102_008_185_15_3200_108.png 24 | 202195_063_123_30_2500_180.png 202195_063_323_15_2500_180.png 25 | 202216_094_5_25_4500_146.png 202216_094_107_35_4500_146.png 26 | 202406_041_231_0_5200_146.png 202406_041_46_35_5200_146.png 27 | 202592_043_5_20_4400_170.png 202592_043_133_15_4400_170.png 28 | 202721_040_303_5_3000_194.png 202721_040_195_25_3000_194.png 29 | 202746_016_5_5_5600_132.png 202746_016_236_0_5600_132.png 30 | 203418_047_205_25_5800_125.png 203418_047_226_40_5800_125.png 31 | 203467_044_56_40_4000_138.png 203467_044_241_25_4000_138.png 32 | 203664_054_41_10_2600_157.png 203664_054_10_50_2600_157.png 33 | 203717_099_46_45_2400_175.png 203717_099_87_30_2400_175.png 34 | 203836_034_41_40_3100_199.png 203836_034_354_30_3100_199.png 35 | 204114_036_138_35_3600_140.png 
204114_036_51_40_3600_140.png 36 | 204158_009_72_5_2800_172.png 204158_009_179_20_2800_172.png 37 | 204412_038_308_50_4600_142.png 204412_038_138_25_4600_142.png 38 | 204449_009_236_15_3400_134.png 204449_009_30_50_3400_134.png 39 | 204500_017_61_35_2600_160.png 204500_017_221_15_2600_160.png 40 | 204562_030_169_25_4800_167.png 204562_030_324_5_4800_167.png 41 | 204893_091_128_40_6400_198.png 204893_091_221_40_6400_198.png 42 | 205008_075_216_40_6300_121.png 205008_075_123_50_6300_121.png 43 | 205050_019_133_10_4200_119.png 205050_019_180_20_4200_119.png 44 | 205066_023_143_45_2300_159.png 205066_023_293_20_2300_159.png 45 | 205225_000_25_5_4400_106.png 205225_000_87_20_4400_106.png 46 | 205251_075_267_50_3800_125.png 205251_075_77_30_3800_125.png 47 | 205340_064_180_25_4100_122.png 205340_064_128_5_4100_122.png 48 | 205376_015_288_50_3500_155.png 205376_015_349_5_3500_155.png 49 | 205491_006_113_0_5700_146.png 205491_006_318_20_5700_146.png 50 | 205608_021_288_45_6300_129.png 205608_021_252_30_6300_129.png 51 | 205743_019_359_50_4100_157.png 205743_019_185_15_4100_157.png 52 | 205810_045_215_30_3600_157.png 205810_045_77_0_3600_157.png 53 | 205880_067_20_45_2900_116.png 205880_067_252_5_2900_116.png 54 | 205892_071_113_45_6000_166.png 205892_071_144_25_6000_166.png 55 | 205920_029_313_0_4100_185.png 205920_029_164_15_4100_185.png 56 | 206001_076_359_10_3500_167.png 206001_076_5_35_3500_167.png 57 | 206064_026_180_5_5700_182.png 206064_026_267_50_5700_182.png 58 | 206080_009_174_5_6100_104.png 206080_009_282_20_6100_104.png 59 | 206387_048_339_50_2500_157.png 206387_048_174_15_2500_157.png 60 | 206616_028_92_30_4600_107.png 206616_028_339_35_4600_107.png 61 | 206784_075_344_35_4000_116.png 206784_075_221_40_4000_116.png 62 | 206860_059_195_5_3800_144.png 206860_059_277_20_3800_144.png 63 | 206947_033_143_10_5500_116.png 206947_033_174_10_5500_116.png 64 | 206960_008_339_50_4200_181.png 206960_008_30_40_4200_181.png 65 | 207032_081_329_50_6400_116.png 207032_081_298_35_6400_116.png 66 | 207049_066_282_45_4000_138.png 207049_066_56_50_4000_138.png 67 | 207265_036_46_10_4000_163.png 207265_036_185_10_4000_163.png 68 | 207358_031_51_5_4100_170.png 207358_031_277_45_4100_170.png 69 | 207383_032_349_20_5300_173.png 207383_032_133_10_5300_173.png 70 | 207748_050_128_15_5500_106.png 207748_050_329_10_5500_106.png 71 | 207891_059_236_20_4300_108.png 207891_059_190_15_4300_108.png 72 | 208114_212_344_50_5900_128.png 208114_212_236_45_5900_128.png 73 | 208259_232_56_20_5100_127.png 208259_232_30_10_5100_127.png 74 | 208323_248_257_10_2400_179.png 208323_248_226_5_2400_179.png 75 | 208376_237_0_20_4900_181.png 208376_237_133_30_4900_181.png 76 | 208385_208_313_50_5600_114.png 208385_208_77_25_5600_114.png 77 | 208523_212_92_50_5700_120.png 208523_212_144_45_5700_120.png 78 | 208663_221_123_10_4200_178.png 208663_221_313_40_4200_178.png 79 | 208683_242_118_20_2700_110.png 208683_242_277_5_2700_110.png 80 | 208787_219_323_40_5200_111.png 208787_219_107_45_5200_111.png 81 | 208846_234_231_20_5200_170.png 208846_234_252_10_5200_170.png 82 | 208934_201_97_30_5300_124.png 208934_201_30_30_5300_124.png 83 | 208987_218_56_25_3000_142.png 208987_218_257_30_3000_142.png 84 | 209077_238_216_40_4800_159.png 209077_238_143_40_4800_159.png 85 | 209308_229_354_25_3100_162.png 209308_229_56_15_3100_162.png 86 | 209366_222_231_50_2400_111.png 209366_222_154_10_2400_111.png 87 | 209401_210_149_10_6300_135.png 209401_210_252_45_6300_135.png 88 | 209456_249_30_35_3000_124.png 209456_249_339_5_3000_124.png 89 | 
209526_246_10_45_5400_143.png 209526_246_138_40_5400_143.png 90 | 209740_215_226_10_6100_107.png 209740_215_252_35_6100_107.png 91 | 209774_216_87_30_5800_130.png 209774_216_308_10_5800_130.png 92 | 209841_250_252_15_5200_147.png 209841_250_102_35_5200_147.png 93 | 209909_209_92_30_5000_197.png 209909_209_159_5_5000_197.png 94 | 209916_213_159_20_5500_138.png 209916_213_318_35_5500_138.png 95 | 209990_219_71_5_6300_167.png 209990_219_10_15_6300_167.png 96 | 209991_219_287_45_5900_122.png 209991_219_216_50_5900_122.png 97 | 210029_206_257_30_5200_138.png 210029_206_246_5_5200_138.png 98 | 210106_234_236_5_5900_140.png 210106_234_35_40_5900_140.png 99 | 210117_250_97_50_4400_109.png 210117_250_41_30_4400_109.png 100 | 210163_222_66_25_5700_119.png 210163_222_97_50_5700_119.png 101 | -------------------------------------------------------------------------------- /data/anno_RSR/AnyLighttest_pairs_qualitative.txt: -------------------------------------------------------------------------------- 1 | scene_05 01193_0034_225_35_255_255_255_5_225_1_5.jpg 01205_0034_0_27_255_212_47_5_225_2_8.jpg 2 | scene_05 01204_0034_315_35_255_212_47_5_225_2_7.jpg 01212_0034_270_27_255_159_255_5_225_3_6.jpg 3 | scene_05 01244_0035_90_27_255_159_255_5_270_3_2.jpg 01254_0035_135_35_255_0_255_5_270_4_3.jpg 4 | scene_05 01282_0036_180_27_255_159_255_5_315_3_4.jpg 01267_0036_315_35_255_255_255_5_315_1_7.jpg 5 | scene_05 01327_0037_180_27_255_0_255_5_0_4_4.jpg 01311_0037_270_27_255_212_47_5_0_2_6.jpg 6 | scene_05 01344_0038_135_35_255_212_47_5_45_2_3.jpg 01359_0038_45_35_255_159_255_5_45_3_9.jpg 7 | scene_05 01353_0038_135_35_255_159_255_5_45_3_3.jpg 01346_0038_225_35_255_212_47_5_45_2_5.jpg 8 | scene_05 01397_0039_90_27_255_0_255_5_90_4_2.jpg 01389_0039_135_35_255_159_255_5_90_3_3.jpg 9 | scene_05 01410_0040_270_27_255_255_255_5_135_1_6.jpg 01437_0040_270_27_255_0_255_5_135_4_6.jpg 10 | scene_05 01411_0040_315_35_255_255_255_5_135_1_7.jpg 01436_0040_225_35_255_0_255_5_135_4_5.jpg 11 | scene_06 01519_0043_315_35_255_255_255_6_270_1_7.jpg 01546_0043_315_35_255_166_255_6_270_4_7.jpg 12 | scene_06 01525_0043_180_27_255_239_153_6_270_2_4.jpg 01534_0043_180_27_255_122_170_6_270_3_4.jpg 13 | scene_06 01533_0043_135_35_255_122_170_6_270_3_3.jpg 01544_0043_225_35_255_166_255_6_270_4_5.jpg 14 | scene_06 01543_0043_180_27_255_166_255_6_270_4_4.jpg 01536_0043_270_27_255_122_170_6_270_3_6.jpg 15 | scene_06 01547_0043_0_27_255_166_255_6_270_4_8.jpg 01513_0043_0_0_255_255_255_6_270_1_1.jpg 16 | scene_06 01550_0044_90_27_255_255_255_6_315_1_2.jpg 01570_0044_180_27_255_122_170_6_315_3_4.jpg 17 | scene_06 01640_0046_90_27_255_122_170_6_45_3_2.jpg 01636_0046_315_35_255_239_153_6_45_2_7.jpg 18 | scene_06 01652_0046_225_35_255_166_255_6_45_4_5.jpg 01623_0046_135_35_255_255_255_6_45_1_3.jpg 19 | scene_06 01655_0046_0_27_255_166_255_6_45_4_8.jpg 01630_0046_0_0_255_239_153_6_45_2_1.jpg 20 | scene_06 01662_0047_270_27_255_255_255_6_90_1_6.jpg 01664_0047_0_27_255_255_255_6_90_1_8.jpg 21 | scene_06 01663_0047_315_35_255_255_255_6_90_1_7.jpg 01683_0047_45_35_255_122_170_6_90_3_9.jpg 22 | scene_06 01701_0048_45_35_255_255_255_6_135_1_9.jpg 01717_0048_315_35_255_122_170_6_135_3_7.jpg 23 | scene_11 02919_0082_135_35_255_255_255_11_225_1_3.jpg 02936_0082_90_27_255_250_88_11_225_3_2.jpg 24 | scene_11 02936_0082_90_27_255_250_88_11_225_3_2.jpg 02930_0082_225_35_255_242_170_11_225_2_5.jpg 25 | scene_11 02958_0083_270_27_255_255_255_11_270_1_6.jpg 02974_0083_180_27_255_250_88_11_270_3_4.jpg 26 | scene_11 03019_0084_180_27_255_129_255_11_315_4_4.jpg 
02989_0084_0_0_255_255_255_11_315_1_1.jpg 27 | scene_11 03030_0085_270_27_255_255_255_11_0_1_6.jpg 03045_0085_135_35_255_250_88_11_0_3_3.jpg 28 | scene_11 03033_0085_45_35_255_255_255_11_0_1_9.jpg 03026_0085_90_27_255_255_255_11_0_1_2.jpg 29 | scene_11 03038_0085_225_35_255_242_170_11_0_2_5.jpg 03041_0085_0_27_255_242_170_11_0_2_8.jpg 30 | scene_11 03101_0087_225_35_255_255_255_11_90_1_5.jpg 03115_0087_0_0_255_250_88_11_90_3_1.jpg 31 | scene_11 03140_0088_0_27_255_255_255_11_135_1_8.jpg 03133_0088_0_0_255_255_255_11_135_1_1.jpg 32 | scene_11 03168_0088_45_35_255_129_255_11_135_4_9.jpg 03150_0088_45_35_255_242_170_11_135_2_9.jpg 33 | scene_24 06694_0186_315_35_255_204_255_24_225_4_7.jpg 06680_0186_90_27_255_255_127_24_225_3_2.jpg 34 | scene_24 06712_0187_315_35_255_251_226_24_270_2_7.jpg 06724_0187_0_0_255_204_255_24_270_4_1.jpg 35 | scene_24 06719_0187_225_35_255_255_127_24_270_3_5.jpg 06703_0187_315_35_255_255_255_24_270_1_7.jpg 36 | scene_24 06725_0187_90_27_255_204_255_24_270_4_2.jpg 06724_0187_0_0_255_204_255_24_270_4_1.jpg 37 | scene_24 06732_0187_45_35_255_204_255_24_270_4_9.jpg 06698_0187_90_27_255_255_255_24_270_1_2.jpg 38 | scene_24 06749_0188_0_27_255_251_226_24_315_2_8.jpg 06755_0188_225_35_255_255_127_24_315_3_5.jpg 39 | scene_24 06874_0191_315_35_255_204_255_24_90_4_7.jpg 06861_0191_135_35_255_255_127_24_90_3_3.jpg 40 | scene_24 06877_0192_0_0_255_255_255_24_135_1_1.jpg 06911_0192_0_27_255_204_255_24_135_4_8.jpg 41 | scene_31 08676_0241_45_35_255_0_255_31_180_4_9.jpg 08661_0241_135_35_255_209_255_31_180_3_3.jpg 42 | scene_31 08718_0243_270_27_255_255_255_31_270_1_6.jpg 08723_0243_90_27_255_236_137_31_270_2_2.jpg 43 | scene_31 08719_0243_315_35_255_255_255_31_270_1_7.jpg 08717_0243_225_35_255_255_255_31_270_1_5.jpg 44 | scene_31 08741_0243_90_27_255_0_255_31_270_4_2.jpg 08739_0243_45_35_255_209_255_31_270_3_9.jpg 45 | scene_31 08748_0243_45_35_255_0_255_31_270_4_9.jpg 08719_0243_315_35_255_255_255_31_270_1_7.jpg 46 | scene_31 08762_0244_225_35_255_236_137_31_315_2_5.jpg 08759_0244_90_27_255_236_137_31_315_2_2.jpg 47 | scene_31 08776_0244_0_0_255_0_255_31_315_4_1.jpg 08751_0244_135_35_255_255_255_31_315_1_3.jpg 48 | scene_31 08811_0245_45_35_255_209_255_31_0_3_9.jpg 08818_0245_315_35_255_0_255_31_0_4_7.jpg 49 | scene_31 08832_0246_135_35_255_236_137_31_45_2_3.jpg 08842_0246_180_27_255_209_255_31_45_3_4.jpg 50 | scene_31 08834_0246_225_35_255_236_137_31_45_2_5.jpg 08850_0246_135_35_255_0_255_31_45_4_3.jpg 51 | scene_31 08895_0248_135_35_255_255_255_31_135_1_3.jpg 08902_0248_0_0_255_236_137_31_135_2_1.jpg 52 | scene_46 12969_0361_45_35_255_255_255_46_180_1_9.jpg 12961_0361_0_0_255_255_255_46_180_1_1.jpg 53 | scene_46 13004_0362_0_27_255_255_255_46_225_1_8.jpg 13000_0362_180_27_255_255_255_46_225_1_4.jpg 54 | scene_46 13016_0362_90_27_255_105_246_46_225_3_2.jpg 12997_0362_0_0_255_255_255_46_225_1_1.jpg 55 | scene_46 13061_0363_90_27_255_222_255_46_270_4_2.jpg 13067_0363_0_27_255_222_255_46_270_4_8.jpg 56 | scene_46 13150_0366_0_0_255_248_203_46_45_2_1.jpg 13162_0366_180_27_255_105_246_46_45_3_4.jpg 57 | scene_46 13172_0366_225_35_255_222_255_46_45_4_5.jpg 13149_0366_45_35_255_255_255_46_45_1_9.jpg 58 | scene_46 13181_0367_225_35_255_255_255_46_90_1_5.jpg 13200_0367_270_27_255_105_246_46_90_3_6.jpg 59 | scene_46 13183_0367_315_35_255_255_255_46_90_1_7.jpg 13178_0367_90_27_255_255_255_46_90_1_2.jpg 60 | scene_46 13212_0367_45_35_255_222_255_46_90_4_9.jpg 13186_0367_0_0_255_248_203_46_90_2_1.jpg 61 | scene_46 13225_0368_180_27_255_248_203_46_135_2_4.jpg 
13241_0368_90_27_255_222_255_46_135_4_2.jpg 62 | scene_46 13226_0368_225_35_255_248_203_46_135_2_5.jpg 13240_0368_0_0_255_222_255_46_135_4_1.jpg 63 | scene_51 14415_0401_270_27_255_211_41_51_180_2_6.jpg 14420_0401_90_27_255_198_158_51_180_3_2.jpg 64 | scene_51 14481_0403_45_35_255_255_255_51_270_1_9.jpg 14480_0403_0_27_255_255_255_51_270_1_8.jpg 65 | scene_51 14508_0403_45_35_255_192_206_51_270_4_9.jpg 14477_0403_225_35_255_255_255_51_270_1_5.jpg 66 | scene_51 14560_0405_315_35_255_211_41_51_0_2_7.jpg 14567_0405_225_35_255_198_158_51_0_3_5.jpg 67 | scene_51 14606_0406_0_27_255_198_158_51_45_3_8.jpg 14611_0406_180_27_255_192_206_51_45_4_4.jpg 68 | scene_51 14619_0407_135_35_255_255_255_51_90_1_3.jpg 14628_0407_135_35_255_211_41_51_90_2_3.jpg 69 | scene_51 14688_0408_45_35_255_192_206_51_135_4_9.jpg 14671_0408_0_0_255_198_158_51_135_3_1.jpg 70 | scene_58 16449_0457_270_27_255_161_255_58_180_4_6.jpg 16436_0457_90_27_255_118_186_58_180_3_2.jpg 71 | scene_58 16521_0459_270_27_255_161_255_58_270_4_6.jpg 16491_0459_135_35_255_255_255_58_270_1_3.jpg 72 | scene_58 16522_0459_315_35_255_161_255_58_270_4_7.jpg 16508_0459_90_27_255_118_186_58_270_3_2.jpg 73 | scene_58 16527_0460_135_35_255_255_255_58_315_1_3.jpg 16558_0460_315_35_255_161_255_58_315_4_7.jpg 74 | scene_58 16550_0460_0_27_255_118_186_58_315_3_8.jpg 16547_0460_225_35_255_118_186_58_315_3_5.jpg 75 | scene_58 16595_0461_0_27_255_161_255_58_0_4_8.jpg 16574_0461_225_35_255_248_205_58_0_2_5.jpg 76 | scene_58 16627_0462_180_27_255_161_255_58_45_4_4.jpg 16600_0462_180_27_255_255_255_58_45_1_4.jpg 77 | scene_58 16636_0463_180_27_255_255_255_58_90_1_4.jpg 16639_0463_315_35_255_255_255_58_90_1_7.jpg 78 | scene_58 16675_0464_315_35_255_255_255_58_135_1_7.jpg 16680_0464_135_35_255_248_205_58_135_2_3.jpg 79 | scene_58 16679_0464_90_27_255_248_205_58_135_2_2.jpg 16694_0464_0_27_255_118_186_58_135_3_8.jpg 80 | scene_58 16683_0464_270_27_255_248_205_58_135_2_6.jpg 16684_0464_315_35_255_248_205_58_135_2_7.jpg 81 | scene_70 19880_0553_0_27_255_255_255_70_180_1_8.jpg 19878_0553_270_27_255_255_255_70_180_1_6.jpg 82 | scene_70 19906_0553_315_35_255_189_132_70_180_4_7.jpg 19891_0553_0_0_255_255_90_70_180_3_1.jpg 83 | scene_70 19927_0554_0_0_255_255_90_70_225_3_1.jpg 19942_0554_315_35_255_189_132_70_225_4_7.jpg 84 | scene_70 19984_0556_180_27_255_255_255_70_315_1_4.jpg 19986_0556_270_27_255_255_255_70_315_1_6.jpg 85 | scene_70 19985_0556_225_35_255_255_255_70_315_1_5.jpg 20000_0556_90_27_255_255_90_70_315_3_2.jpg 86 | scene_70 20091_0559_135_35_255_255_255_70_90_1_3.jpg 20123_0559_0_27_255_189_132_70_90_4_8.jpg 87 | scene_70 20133_0560_45_35_255_255_255_70_135_1_9.jpg 20130_0560_270_27_255_255_255_70_135_1_6.jpg 88 | scene_74 21048_0585_270_27_253_255_153_74_180_3_6.jpg 21059_0585_0_27_255_154_255_74_180_4_8.jpg 89 | scene_74 21060_0585_45_35_255_154_255_74_180_4_9.jpg 21027_0585_135_35_255_255_255_74_180_1_3.jpg 90 | scene_74 21062_0586_90_27_255_255_255_74_225_1_2.jpg 21075_0586_270_27_255_240_155_74_225_2_6.jpg 91 | scene_74 21098_0587_90_27_255_255_255_74_270_1_2.jpg 21121_0587_315_35_253_255_153_74_270_3_7.jpg 92 | scene_74 21125_0587_90_27_255_154_255_74_270_4_2.jpg 21097_0587_0_0_255_255_255_74_270_1_1.jpg 93 | scene_74 21174_0589_270_27_255_255_255_74_0_1_6.jpg 21196_0589_0_0_255_154_255_74_0_4_1.jpg 94 | scene_74 21177_0589_45_35_255_255_255_74_0_1_9.jpg 21172_0589_180_27_255_255_255_74_0_1_4.jpg 95 | scene_74 21188_0589_90_27_253_255_153_74_0_3_2.jpg 21196_0589_0_0_255_154_255_74_0_4_1.jpg 96 | scene_74 21211_0590_315_35_255_255_255_74_45_1_7.jpg 
21240_0590_45_35_255_154_255_74_45_4_9.jpg 97 | scene_74 21241_0591_0_0_255_255_255_74_90_1_1.jpg 21262_0591_180_27_253_255_153_74_90_3_4.jpg 98 | scene_74 21243_0591_135_35_255_255_255_74_90_1_3.jpg 21266_0591_0_27_253_255_153_74_90_3_8.jpg 99 | scene_74 21253_0591_180_27_255_240_155_74_90_2_4.jpg 21244_0591_180_27_255_255_255_74_90_1_4.jpg 100 | scene_74 21261_0591_135_35_253_255_153_74_90_3_3.jpg 21258_0591_45_35_255_240_155_74_90_2_9.jpg 101 | -------------------------------------------------------------------------------- /data/anno_RSR/test.txt: -------------------------------------------------------------------------------- 1 | scene_05 2 | scene_06 3 | scene_11 4 | scene_24 5 | scene_31 6 | scene_46 7 | scene_51 8 | scene_58 9 | scene_70 10 | scene_74 11 | -------------------------------------------------------------------------------- /data/anno_RSR/train.txt: -------------------------------------------------------------------------------- 1 | scene_01 2 | scene_03 3 | scene_07 4 | scene_08 5 | scene_10 6 | scene_13 7 | scene_14 8 | scene_15 9 | scene_16 10 | scene_17 11 | scene_18 12 | scene_20 13 | scene_21 14 | scene_22 15 | scene_23 16 | scene_26 17 | scene_27 18 | scene_29 19 | scene_30 20 | scene_32 21 | scene_33 22 | scene_34 23 | scene_35 24 | scene_36 25 | scene_37 26 | scene_38 27 | scene_39 28 | scene_40 29 | scene_41 30 | scene_43 31 | scene_44 32 | scene_45 33 | scene_47 34 | scene_48 35 | scene_50 36 | scene_52 37 | scene_53 38 | scene_54 39 | scene_55 40 | scene_56 41 | scene_57 42 | scene_59 43 | scene_60 44 | scene_61 45 | scene_62 46 | scene_63 47 | scene_64 48 | scene_65 49 | scene_66 50 | scene_67 51 | scene_68 52 | scene_69 53 | scene_71 54 | scene_72 55 | scene_73 56 | scene_75 57 | scene_76 58 | scene_78 59 | scene_79 60 | -------------------------------------------------------------------------------- /data/anno_RSR/val.txt: -------------------------------------------------------------------------------- 1 | scene_12 2 | scene_42 3 | scene_80 4 | -------------------------------------------------------------------------------- /data/anno_VIDIT/any2any/AnyLight_test_pairs_qualitative.txt: -------------------------------------------------------------------------------- 1 | Image003_2500_E.png Image003_3500_E.png 2 | Image003_2500_N.png Image003_4500_N.png 3 | Image003_2500_W.png Image003_5500_N.png 4 | Image003_3500_NE.png Image003_6500_S.png 5 | Image008_2500_W.png Image008_3500_S.png 6 | Image008_4500_S.png Image008_6500_SW.png 7 | Image008_5500_NE.png Image008_5500_S.png 8 | Image008_6500_NE.png Image008_6500_SW.png 9 | Image014_2500_NE.png Image014_4500_W.png 10 | Image014_2500_SW.png Image014_2500_N.png 11 | Image014_3500_S.png Image014_4500_S.png 12 | Image014_5500_S.png Image014_3500_NW.png 13 | Image044_2500_W.png Image044_5500_NW.png 14 | Image044_3500_W.png Image044_2500_E.png 15 | Image044_5500_SW.png Image044_6500_N.png 16 | Image044_6500_S.png Image044_5500_S.png 17 | Image062_3500_W.png Image062_3500_SE.png 18 | Image062_4500_S.png Image062_2500_NW.png 19 | Image062_5500_SE.png Image062_2500_NW.png 20 | Image062_5500_W.png Image062_6500_W.png 21 | Image063_3500_NE.png Image063_4500_N.png 22 | Image063_5500_NE.png Image063_3500_SE.png 23 | Image063_6500_E.png Image063_4500_E.png 24 | Image063_6500_SE.png Image063_5500_SE.png 25 | Image078_2500_E.png Image078_3500_N.png 26 | Image078_2500_N.png Image078_2500_NW.png 27 | Image078_4500_S.png Image078_6500_SW.png 28 | Image078_6500_SE.png Image078_5500_N.png 29 | Image080_2500_SE.png 
Image080_4500_SW.png 30 | Image080_3500_N.png Image080_3500_E.png 31 | Image080_5500_N.png Image080_3500_SE.png 32 | Image080_6500_N.png Image080_3500_NE.png 33 | Image081_2500_S.png Image081_3500_W.png 34 | Image081_4500_NW.png Image081_3500_NW.png 35 | Image081_5500_E.png Image081_3500_N.png 36 | Image081_6500_NE.png Image081_4500_SW.png 37 | Image097_2500_NE.png Image097_3500_S.png 38 | Image097_5500_S.png Image097_2500_NE.png 39 | Image097_5500_SE.png Image097_3500_NW.png 40 | Image097_5500_W.png Image097_6500_N.png 41 | Image098_2500_SE.png Image098_2500_S.png 42 | Image098_4500_W.png Image098_6500_E.png 43 | Image098_5500_SE.png Image098_3500_NW.png 44 | Image099_2500_SE.png Image099_4500_SW.png 45 | Image099_5500_N.png Image099_2500_W.png 46 | Image099_6500_NW.png Image099_5500_SW.png 47 | Image109_2500_W.png Image109_2500_SW.png 48 | Image109_5500_E.png Image109_3500_SE.png 49 | Image109_5500_W.png Image109_6500_SW.png 50 | Image110_3500_NE.png Image110_2500_SE.png 51 | Image110_5500_E.png Image110_3500_N.png 52 | Image110_5500_NE.png Image110_4500_NW.png 53 | Image123_2500_NE.png Image123_4500_NW.png 54 | Image123_2500_NW.png Image123_3500_NE.png 55 | Image123_6500_W.png Image123_2500_SW.png 56 | Image134_4500_E.png Image134_2500_S.png 57 | Image134_4500_S.png Image134_2500_NW.png 58 | Image134_4500_SW.png Image134_5500_N.png 59 | Image140_2500_NE.png Image140_3500_NW.png 60 | Image140_3500_W.png Image140_4500_S.png 61 | Image140_5500_W.png Image140_4500_S.png 62 | Image141_2500_NW.png Image141_2500_SW.png 63 | Image141_4500_E.png Image141_6500_SW.png 64 | Image141_6500_SE.png Image141_3500_SW.png 65 | Image146_2500_S.png Image146_4500_N.png 66 | Image146_5500_SW.png Image146_5500_N.png 67 | Image146_6500_NW.png Image146_2500_E.png 68 | Image149_3500_E.png Image149_5500_NE.png 69 | Image149_4500_W.png Image149_3500_SE.png 70 | Image149_6500_E.png Image149_4500_SW.png 71 | Image150_2500_NE.png Image150_6500_SW.png 72 | Image150_2500_W.png Image150_3500_NE.png 73 | Image150_3500_SW.png Image150_5500_W.png 74 | Image203_2500_SW.png Image203_2500_NE.png 75 | Image203_2500_W.png Image203_2500_NW.png 76 | Image203_4500_SW.png Image203_6500_NW.png 77 | Image205_2500_E.png Image205_3500_N.png 78 | Image205_3500_SE.png Image205_4500_E.png 79 | Image205_6500_SW.png Image205_2500_E.png 80 | Image206_2500_NW.png Image206_3500_S.png 81 | Image206_3500_NW.png Image206_6500_NE.png 82 | Image206_5500_S.png Image206_3500_S.png 83 | Image209_5500_NW.png Image209_6500_SW.png 84 | Image209_6500_NE.png Image209_6500_SW.png 85 | Image209_6500_SW.png Image209_3500_NE.png 86 | Image217_2500_E.png Image217_6500_SW.png 87 | Image217_2500_SE.png Image217_6500_S.png 88 | Image217_4500_N.png Image217_4500_SE.png 89 | Image244_5500_S.png Image244_4500_E.png 90 | Image244_6500_S.png Image244_6500_W.png 91 | Image244_6500_SE.png Image244_3500_E.png 92 | Image248_2500_S.png Image248_2500_SW.png 93 | Image248_4500_E.png Image248_3500_NW.png 94 | Image248_5500_S.png Image248_6500_SE.png 95 | Image274_2500_SW.png Image274_5500_S.png 96 | Image274_5500_W.png Image274_4500_E.png 97 | Image274_6500_N.png Image274_3500_S.png 98 | Image286_2500_N.png Image286_3500_NE.png 99 | Image286_4500_S.png Image286_6500_NE.png 100 | Image286_6500_S.png Image286_5500_S.png 101 | Image003_6500_N.png Image003_4500_E.png 102 | Image008_6500_N.png Image008_4500_E.png 103 | Image014_6500_N.png Image014_4500_E.png 104 | Image044_6500_N.png Image044_4500_E.png 105 | Image062_6500_N.png Image062_4500_E.png 106 | Image063_6500_N.png 
Image063_4500_E.png 107 | Image078_6500_N.png Image078_4500_E.png 108 | Image080_6500_N.png Image080_4500_E.png 109 | Image081_6500_N.png Image081_4500_E.png 110 | Image097_6500_N.png Image097_4500_E.png 111 | Image098_6500_N.png Image098_4500_E.png 112 | Image099_6500_N.png Image099_4500_E.png 113 | Image109_6500_N.png Image109_4500_E.png 114 | Image110_6500_N.png Image110_4500_E.png 115 | Image123_6500_N.png Image123_4500_E.png 116 | Image134_6500_N.png Image134_4500_E.png 117 | Image140_6500_N.png Image140_4500_E.png 118 | Image141_6500_N.png Image141_4500_E.png 119 | Image146_6500_N.png Image146_4500_E.png 120 | Image149_6500_N.png Image149_4500_E.png 121 | Image150_6500_N.png Image150_4500_E.png 122 | Image203_6500_N.png Image203_4500_E.png 123 | Image205_6500_N.png Image205_4500_E.png 124 | Image206_6500_N.png Image206_4500_E.png 125 | Image209_6500_N.png Image209_4500_E.png 126 | Image217_6500_N.png Image217_4500_E.png 127 | Image244_6500_N.png Image244_4500_E.png 128 | Image248_6500_N.png Image248_4500_E.png 129 | Image274_6500_N.png Image274_4500_E.png 130 | Image286_6500_N.png Image286_4500_E.png 131 | -------------------------------------------------------------------------------- /data/anno_VIDIT/any2any/test.txt: -------------------------------------------------------------------------------- 1 | Image003 2 | Image008 3 | Image014 4 | Image044 5 | Image062 6 | Image063 7 | Image078 8 | Image080 9 | Image081 10 | Image097 11 | Image098 12 | Image099 13 | Image109 14 | Image110 15 | Image123 16 | Image134 17 | Image140 18 | Image141 19 | Image146 20 | Image149 21 | Image150 22 | Image203 23 | Image205 24 | Image206 25 | Image209 26 | Image217 27 | Image244 28 | Image248 29 | Image274 30 | Image286 31 | -------------------------------------------------------------------------------- /data/anno_VIDIT/any2any/train.txt: -------------------------------------------------------------------------------- 1 | Image000 2 | Image001 3 | Image002 4 | Image004 5 | Image005 6 | Image006 7 | Image007 8 | Image009 9 | Image010 10 | Image011 11 | Image012 12 | Image013 13 | Image015 14 | Image016 15 | Image017 16 | Image018 17 | Image019 18 | Image020 19 | Image021 20 | Image022 21 | Image023 22 | Image024 23 | Image025 24 | Image026 25 | Image027 26 | Image028 27 | Image029 28 | Image030 29 | Image031 30 | Image033 31 | Image034 32 | Image035 33 | Image036 34 | Image037 35 | Image039 36 | Image040 37 | Image041 38 | Image042 39 | Image043 40 | Image045 41 | Image046 42 | Image047 43 | Image048 44 | Image049 45 | Image050 46 | Image051 47 | Image052 48 | Image053 49 | Image054 50 | Image055 51 | Image056 52 | Image057 53 | Image058 54 | Image059 55 | Image060 56 | Image061 57 | Image064 58 | Image065 59 | Image066 60 | Image067 61 | Image068 62 | Image069 63 | Image070 64 | Image071 65 | Image072 66 | Image073 67 | Image074 68 | Image075 69 | Image076 70 | Image077 71 | Image079 72 | Image083 73 | Image084 74 | Image085 75 | Image086 76 | Image087 77 | Image088 78 | Image089 79 | Image091 80 | Image092 81 | Image093 82 | Image094 83 | Image095 84 | Image096 85 | Image100 86 | Image101 87 | Image102 88 | Image103 89 | Image104 90 | Image105 91 | Image106 92 | Image107 93 | Image108 94 | Image112 95 | Image113 96 | Image114 97 | Image115 98 | Image116 99 | Image117 100 | Image119 101 | Image120 102 | Image121 103 | Image122 104 | Image124 105 | Image125 106 | Image126 107 | Image127 108 | Image128 109 | Image129 110 | Image130 111 | Image131 112 | Image132 113 | Image133 114 | Image135 115 | 
Image136 116 | Image137 117 | Image138 118 | Image139 119 | Image142 120 | Image143 121 | Image145 122 | Image148 123 | Image151 124 | Image152 125 | Image153 126 | Image154 127 | Image155 128 | Image156 129 | Image157 130 | Image158 131 | Image159 132 | Image160 133 | Image161 134 | Image162 135 | Image163 136 | Image164 137 | Image165 138 | Image166 139 | Image168 140 | Image169 141 | Image170 142 | Image171 143 | Image172 144 | Image173 145 | Image174 146 | Image175 147 | Image176 148 | Image177 149 | Image179 150 | Image181 151 | Image182 152 | Image183 153 | Image184 154 | Image185 155 | Image186 156 | Image187 157 | Image188 158 | Image189 159 | Image190 160 | Image191 161 | Image192 162 | Image193 163 | Image194 164 | Image195 165 | Image196 166 | Image197 167 | Image198 168 | Image199 169 | Image200 170 | Image201 171 | Image202 172 | Image204 173 | Image207 174 | Image208 175 | Image210 176 | Image211 177 | Image212 178 | Image213 179 | Image214 180 | Image215 181 | Image216 182 | Image218 183 | Image219 184 | Image220 185 | Image221 186 | Image222 187 | Image223 188 | Image224 189 | Image225 190 | Image226 191 | Image227 192 | Image228 193 | Image229 194 | Image230 195 | Image231 196 | Image232 197 | Image233 198 | Image234 199 | Image235 200 | Image236 201 | Image237 202 | Image238 203 | Image239 204 | Image240 205 | Image241 206 | Image242 207 | Image243 208 | Image245 209 | Image246 210 | Image247 211 | Image249 212 | Image250 213 | Image251 214 | Image252 215 | Image253 216 | Image254 217 | Image255 218 | Image256 219 | Image257 220 | Image258 221 | Image259 222 | Image260 223 | Image261 224 | Image262 225 | Image263 226 | Image265 227 | Image266 228 | Image268 229 | Image269 230 | Image270 231 | Image271 232 | Image272 233 | Image273 234 | Image275 235 | Image278 236 | Image279 237 | Image280 238 | Image281 239 | Image282 240 | Image283 241 | Image284 242 | Image285 243 | Image287 244 | Image288 245 | Image289 246 | Image290 247 | Image291 248 | Image292 249 | Image293 250 | Image294 251 | Image295 252 | Image296 253 | Image297 254 | Image298 255 | Image299 256 | -------------------------------------------------------------------------------- /data/anno_VIDIT/any2any/val.txt: -------------------------------------------------------------------------------- 1 | Image032 2 | Image038 3 | Image082 4 | Image090 5 | Image111 6 | Image118 7 | Image144 8 | Image147 9 | Image167 10 | Image178 11 | Image180 12 | Image264 13 | Image267 14 | Image276 15 | Image277 16 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 2 | 3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 4 | """ 5 | import random 6 | import numpy as np 7 | import torch.utils.data as data 8 | from PIL import Image 9 | import torchvision.transforms as transforms 10 | from abc import ABC, abstractmethod 11 | 12 | 13 | class BaseDataset(data.Dataset, ABC): 14 | """This class is an abstract base class (ABC) for datasets. 15 | 16 | To create a subclass, you need to implement the following four functions: 17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 18 | -- <__len__>: return the size of dataset. 19 | -- <__getitem__>: get a data point. 
20 | -- : (optionally) add dataset-specific options and set default options. 21 | """ 22 | 23 | def __init__(self, opt): 24 | """Initialize the class; save the options in the class 25 | 26 | Parameters: 27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 28 | """ 29 | self.opt = opt 30 | self.dataroot = opt.dataroot 31 | 32 | @staticmethod 33 | def modify_commandline_options(parser, is_train): 34 | """Add new dataset-specific options, and rewrite default values for existing options. 35 | 36 | Parameters: 37 | parser -- original option parser 38 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 39 | 40 | Returns: 41 | the modified parser. 42 | """ 43 | return parser 44 | 45 | @abstractmethod 46 | def __len__(self): 47 | """Return the total number of images in the dataset.""" 48 | return 0 49 | 50 | @abstractmethod 51 | def __getitem__(self, index): 52 | """Return a data point and its metadata information. 53 | 54 | Parameters: 55 | index - - a random integer for data indexing 56 | 57 | Returns: 58 | a dictionary of data with their names. It ususally contains the data itself and its metadata information. 59 | """ 60 | pass 61 | 62 | 63 | def get_params(opt, size): 64 | w, h = size 65 | new_h = h 66 | new_w = w 67 | if opt.preprocess == 'resize_and_crop': 68 | new_h = new_w = opt.load_size 69 | elif opt.preprocess == 'scale_width_and_crop': 70 | new_w = opt.load_size 71 | new_h = opt.load_size * h // w 72 | 73 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 74 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 75 | 76 | flip = random.random() > 0.5 77 | 78 | return {'crop_pos': (x, y), 'flip': flip} 79 | 80 | 81 | def get_transform(opt, params=None, grayscale=False, method=transforms.InterpolationMode.BICUBIC, convert=True): 82 | transform_list = [] 83 | if grayscale: 84 | transform_list.append(transforms.Grayscale(1)) 85 | if 'resize' in opt.preprocess: 86 | osize = [opt.load_size, opt.load_size] 87 | transform_list.append(transforms.Resize(osize, method)) 88 | elif 'scale_width' in opt.preprocess: 89 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method))) 90 | 91 | if 'crop' in opt.preprocess: 92 | if params is None: 93 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 94 | else: 95 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 96 | 97 | if opt.preprocess == 'none': 98 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 99 | 100 | if not opt.no_flip: 101 | if params is None: 102 | transform_list.append(transforms.RandomHorizontalFlip()) 103 | elif params['flip']: 104 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 105 | 106 | if convert: 107 | transform_list += [transforms.ToTensor()] 108 | if opt.normalization_type == '[-1, 1]': 109 | if grayscale: 110 | transform_list += [transforms.Normalize((0.5,), (0.5,))] 111 | else: 112 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 113 | elif opt.normalization_type == '[0, 1]': 114 | pass 115 | else: 116 | raise Exception("normalization_type error") 117 | return transforms.Compose(transform_list) 118 | 119 | 120 | def __make_power_2(img, base, method=Image.BICUBIC): 121 | ow, oh = img.size 122 | h = int(round(oh / base) * base) 123 | w = int(round(ow / base) * base) 
124 | if h == oh and w == ow: 125 | return img 126 | 127 | __print_size_warning(ow, oh, w, h) 128 | return img.resize((w, h), method) 129 | 130 | 131 | def __scale_width(img, target_size, crop_size, method=Image.BICUBIC): 132 | ow, oh = img.size 133 | if ow == target_size and oh >= crop_size: 134 | return img 135 | w = target_size 136 | h = int(max(target_size * oh / ow, crop_size)) 137 | return img.resize((w, h), method) 138 | 139 | 140 | def __crop(img, pos, size): 141 | ow, oh = img.size 142 | x1, y1 = pos 143 | tw = th = size 144 | if (ow > tw or oh > th): 145 | return img.crop((x1, y1, x1 + tw, y1 + th)) 146 | return img 147 | 148 | 149 | def __flip(img, flip): 150 | if flip: 151 | return img.transpose(Image.FLIP_LEFT_RIGHT) 152 | return img 153 | 154 | 155 | def __print_size_warning(ow, oh, w, h): 156 | """Print warning information about image size(only print once)""" 157 | if not hasattr(__print_size_warning, 'has_printed'): 158 | print("The image size needs to be a multiple of 4. " 159 | "The loaded image size was (%d, %d), so it was adjusted to " 160 | "(%d, %d). This adjustment will be done to all images " 161 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 162 | __print_size_warning.has_printed = True 163 | -------------------------------------------------------------------------------- /data/multi_illumination/crop_resize.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is to crop and resize the images of Multi-illumination dataset 3 | """ 4 | import os 5 | import cv2 6 | from tqdm import tqdm 7 | 8 | def transfer_img(ori_name, new_name): 9 | ori_image = cv2.imread(ori_name, cv2.IMREAD_UNCHANGED) 10 | shape = ori_image.shape 11 | height = shape[0] 12 | width = shape[1] 13 | left = int(0.5 * width - 0.5 * height) 14 | right = left + height 15 | crop_image = ori_image[:, left:right, :] 16 | resized_image = cv2.resize(crop_image, (256, 256), interpolation=cv2.INTER_CUBIC) 17 | cv2.imwrite(new_name, resized_image) 18 | 19 | 20 | def transfer_probes(ori_name, new_name): 21 | for suffix in ["chrome", "gray"]: 22 | name1 = ori_name + suffix + "256.jpg" 23 | name2 = new_name + suffix + ".jpg" 24 | ori_image = cv2.imread(name1, cv2.IMREAD_UNCHANGED) 25 | resized_image = cv2.resize(ori_image, (64, 64), interpolation=cv2.INTER_CUBIC) 26 | cv2.imwrite(name2, resized_image) 27 | 28 | 29 | dataset_path_ori = '/home/yyang/dataset/Multi_Illumination/' 30 | dataset_path_new = '/home/yyang/dataset/Multi_Illumination_small/' 31 | num_img_scene = 25 32 | 33 | for subset in os.listdir(dataset_path_ori): 34 | # subset: train, test 35 | if not os.path.exists(os.path.join(dataset_path_new, subset)): 36 | os.mkdir(os.path.join(dataset_path_new, subset)) 37 | for scene in tqdm(os.listdir(os.path.join(dataset_path_ori, subset))): 38 | # scenes 39 | if not os.path.exists(os.path.join(dataset_path_new, subset, scene)): 40 | os.mkdir(os.path.join(dataset_path_new, subset, scene)) 41 | if not os.path.exists(os.path.join(dataset_path_new, subset, scene, 'probes')): 42 | os.mkdir(os.path.join(dataset_path_new, subset, scene, 'probes')) 43 | current_path_ori = os.path.join(dataset_path_ori, subset, scene) 44 | current_path_new = os.path.join(dataset_path_new, subset, scene) 45 | for img_id in range(num_img_scene): 46 | img_name = "dir_{}_mip2.jpg".format(str(img_id)) 47 | transfer_img(os.path.join(current_path_ori, img_name), os.path.join(current_path_new, img_name)) 48 | probes_name = "probes/dir_{}_".format(str(img_id)) 49 | 
transfer_probes(os.path.join(current_path_ori, probes_name), os.path.join(current_path_new, probes_name)) 50 | 51 | 52 | -------------------------------------------------------------------------------- /data/multi_illumination/test_qualitative.txt: -------------------------------------------------------------------------------- 1 | everett_dining1 1 7 2 | everett_dining1 17 0 3 | everett_dining1 22 14 4 | everett_dining1 3 15 5 | everett_dining2 0 8 6 | everett_dining2 1 12 7 | everett_dining2 6 24 8 | everett_kitchen12 13 6 9 | everett_kitchen12 16 2 10 | everett_kitchen12 9 16 11 | everett_kitchen14 17 19 12 | everett_kitchen14 19 6 13 | everett_kitchen14 24 18 14 | everett_kitchen14 8 0 15 | everett_kitchen17 11 23 16 | everett_kitchen17 15 17 17 | everett_kitchen17 7 12 18 | everett_kitchen18 12 1 19 | everett_kitchen18 24 9 20 | everett_kitchen2 19 13 21 | everett_kitchen2 23 2 22 | everett_kitchen2 3 10 23 | everett_kitchen2 8 1 24 | everett_kitchen4 17 13 25 | everett_kitchen4 21 4 26 | everett_kitchen5 0 8 27 | everett_kitchen5 10 1 28 | everett_kitchen5 3 15 29 | everett_kitchen6 0 2 30 | everett_kitchen6 1 10 31 | everett_kitchen6 5 24 32 | everett_kitchen7 0 11 33 | everett_kitchen7 13 16 34 | everett_kitchen8 1 20 35 | everett_kitchen8 8 13 36 | everett_kitchen9 0 12 37 | everett_kitchen9 11 15 38 | everett_kitchen9 19 8 39 | everett_kitchen9 2 9 40 | everett_kitchen9 20 21 41 | everett_kitchen9 21 8 42 | everett_living2 10 21 43 | everett_living2 11 9 44 | everett_living2 21 8 45 | everett_living4 14 13 46 | everett_living4 18 23 47 | everett_living4 2 23 48 | everett_living4 5 22 49 | everett_lobby1 13 1 50 | everett_lobby1 14 0 51 | everett_lobby1 22 12 52 | everett_lobby1 5 1 53 | everett_lobby11 17 6 54 | everett_lobby11 2 13 55 | everett_lobby11 21 22 56 | everett_lobby12 16 2 57 | everett_lobby12 17 9 58 | everett_lobby12 19 17 59 | everett_lobby12 4 12 60 | everett_lobby13 0 5 61 | everett_lobby13 17 19 62 | everett_lobby13 18 24 63 | everett_lobby13 23 0 64 | everett_lobby13 4 17 65 | everett_lobby14 1 17 66 | everett_lobby14 12 10 67 | everett_lobby14 19 12 68 | everett_lobby14 5 4 69 | everett_lobby14 7 22 70 | everett_lobby15 10 3 71 | everett_lobby15 23 6 72 | everett_lobby15 5 4 73 | everett_lobby16 16 13 74 | everett_lobby16 17 16 75 | everett_lobby16 23 13 76 | everett_lobby16 3 6 77 | everett_lobby16 9 22 78 | everett_lobby17 11 14 79 | everett_lobby17 18 14 80 | everett_lobby17 2 24 81 | everett_lobby17 7 23 82 | everett_lobby17 8 14 83 | everett_lobby19 11 20 84 | everett_lobby19 16 0 85 | everett_lobby19 2 23 86 | everett_lobby19 8 14 87 | everett_lobby20 3 17 88 | everett_lobby20 7 3 89 | everett_lobby20 8 19 90 | everett_lobby3 15 14 91 | everett_lobby3 16 10 92 | everett_lobby3 23 17 93 | everett_lobby3 6 17 94 | everett_lobby4 16 18 95 | everett_lobby4 8 10 96 | everett_lobby6 11 14 97 | everett_lobby6 18 23 98 | everett_lobby6 19 1 99 | everett_lobby6 23 3 100 | everett_lobby6 9 22 101 | -------------------------------------------------------------------------------- /data/relighting_dataset_single_image.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_params, get_transform 3 | from PIL import Image 4 | import torch 5 | import random 6 | import math 7 | from util.k_to_rgb import convert_K_to_RGB 8 | from util.util import PARA_NOR 9 | LIGHT_INDEX_IN_NAME = 2 10 | 11 | 12 | def read_anno_single_image(anno_filename): 13 | """Read the name of images 
from the anno file yielded by prepare_dataset.py. 14 | For each image, we must know which scene it belongs to. 15 | """ 16 | # read lines from anno 17 | f = open(anno_filename, 'r') 18 | file_names = [] 19 | for line in f.readlines(): 20 | line = line.strip('\n') 21 | file_names.append(line) 22 | # Don't sort because the test and val sets need the determined order. 23 | scene_all = [] 24 | scene_index = [] # record the scene index of each image 25 | last_scene = " " 26 | scene = [] 27 | for x in file_names: 28 | x_scene = "{}_{}".format(x.split('_')[0], x.split('_')[1]) 29 | if x_scene != last_scene: 30 | if len(scene) > 0: 31 | scene_all.append(scene) 32 | scene = [] 33 | scene.append(x) 34 | scene_index.append(len(scene_all)) 35 | last_scene = x_scene 36 | if len(scene) > 0: 37 | scene_all.append(scene) 38 | 39 | return file_names, scene_index, scene_all 40 | 41 | 42 | def image_name2light_condition(img_name): 43 | factor_deg2rad = math.pi / 180.0 44 | names = os.path.splitext(img_name)[0].split('_') 45 | pan = float(names[LIGHT_INDEX_IN_NAME]) * factor_deg2rad 46 | tilt = float(names[LIGHT_INDEX_IN_NAME+1]) * factor_deg2rad 47 | color_temp = int(names[LIGHT_INDEX_IN_NAME+2]) 48 | # transform light position to cos and sin 49 | light_position = [math.cos(pan), math.sin(pan), math.cos(tilt), math.sin(tilt)] 50 | # normalize the light position to [0, 1] 51 | light_position[:2] = [x * PARA_NOR['pan_a'] + PARA_NOR['pan_b'] for x in light_position[:2]] 52 | light_position[2:] = [x * PARA_NOR['tilt_a'] + PARA_NOR['tilt_b'] for x in light_position[2:]] 53 | # transform light temperature to RGB, and normalize it. 54 | light_color = list(map(lambda x: x / 255.0, convert_K_to_RGB(color_temp))) 55 | light_position_color = light_position + light_color 56 | return torch.tensor(light_position_color) 57 | 58 | 59 | def read_component(dataroot, component, file_name, img_transform, r_pil=False): 60 | component_path = "{}{}/{}".format(dataroot, component, file_name) 61 | if not os.path.exists(component_path): 62 | raise Exception("RelightingDataset __getitem__ error") 63 | 64 | img_component = Image.open(component_path).convert('RGB') 65 | img_tensor = img_transform(img_component) 66 | if r_pil: 67 | return img_tensor, img_component 68 | return img_tensor 69 | 70 | 71 | def get_data(file_name_input, file_name_output, dataroot, img_transform, multiple_replace_image): 72 | data = {} # output dictionary 73 | 74 | data['scene_label'] = file_name_input 75 | data['light_position_color_original'] = image_name2light_condition(file_name_input) 76 | data['light_position_color_new'] = image_name2light_condition(file_name_output) 77 | 78 | # Reflectance_output 79 | data['Reflectance_output'] = read_component(dataroot, 'Reflectance', file_name_input, img_transform) 80 | data['Shading_ori'], s_ori = read_component(dataroot, 'Shading', file_name_input, img_transform, r_pil=True) 81 | data['Shading_output'], s_output = read_component(dataroot, 'Shading', file_name_output, img_transform, r_pil=True) 82 | if multiple_replace_image: 83 | data['Image_input'] = torch.mul(data['Reflectance_output'], data['Shading_ori']) 84 | data['Image_relighted'] = torch.mul(data['Reflectance_output'], data['Shading_output']) 85 | else: 86 | data['Image_input'] = read_component(dataroot, 'Image', file_name_input, img_transform) 87 | data['Image_relighted'] = read_component(dataroot, 'Image', file_name_output, img_transform) 88 | 89 | return data 90 | 91 | class RelightingDatasetSingleImage(BaseDataset): 92 | """A dataset class for relighting 
dataset. 93 | This dataset read data image by image. 94 | """ 95 | 96 | def __init__(self, opt, validation=False): 97 | """Initialize this dataset class. 98 | 99 | Parameters: 100 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 101 | """ 102 | BaseDataset.__init__(self, opt) 103 | if validation: 104 | anno_file = opt.anno_validation 105 | else: 106 | anno_file = opt.anno 107 | self.file_names, self.scene_index, self.scenes_list = read_anno_single_image(anno_file) 108 | 109 | def __getitem__(self, index): 110 | """Return a data point and its metadata information. 111 | 112 | Parameters: 113 | index - - a random integer for data indexing 114 | 115 | Returns a dictionary that contains 116 | 'Image_input': , 117 | 'light_position_color_new': , 118 | 'light_position_color_original': , 119 | 'Reflectance_output': , 120 | 'Shading_output': , 121 | 'Shading_ori': , 122 | 'Image_relighted': , 123 | 'scene_label': , 124 | """ 125 | # get parameters 126 | dataroot = self.dataroot 127 | img_size = self.opt.img_size 128 | multiple_replace_image = self.opt.multiple_replace_image 129 | 130 | index_file_names = index 131 | 132 | # get one image 133 | file_name_input = self.file_names[index_file_names] 134 | # get the scene_index 135 | scene_id = self.scene_index[index_file_names] 136 | scene = self.scenes_list[scene_id].copy() 137 | # remove the input image. 138 | scene.remove(file_name_input) 139 | id_in_scene = random.randrange(0, len(scene)) 140 | file_name_output = scene[id_in_scene] 141 | 142 | # get the parameters of data augmentation 143 | transform_params = get_params(self.opt, img_size) 144 | img_transform = get_transform(self.opt, transform_params) 145 | 146 | data = get_data(file_name_input, file_name_output, dataroot, img_transform, multiple_replace_image) 147 | 148 | return data 149 | 150 | def __len__(self): 151 | """Return the total number of images in the dataset.""" 152 | return len(self.file_names) 153 | 154 | 155 | -------------------------------------------------------------------------------- /data/relighting_dataset_single_image_multilum.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from data.base_dataset import BaseDataset, get_params, get_transform, __make_power_2 4 | import torchvision.transforms as transforms 5 | import random 6 | import os 7 | from PIL import Image 8 | 9 | 10 | def read_anno(anno_filename): 11 | """Read the name of images from the anno file yielded by prepare_dataset.py. 12 | For each image, we must know which scene it belongs to. 13 | """ 14 | # read lines from anno 15 | f = open(anno_filename, 'r') 16 | annos = [] 17 | for line in f.readlines(): 18 | line = line.strip('\n') 19 | annos.append(line) 20 | return annos 21 | 22 | 23 | class RelightingDatasetSingleImageMultilum(BaseDataset): 24 | """A dataset class for relighting dataset. 25 | This dataset read data image by image. 26 | """ 27 | def __init__(self, opt, validation=False): 28 | """Initialize this dataset class. 
29 | Parameters: 30 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 31 | """ 32 | BaseDataset.__init__(self, opt) 33 | self.dataroot = self.opt.dataroot_multilum 34 | if validation: 35 | anno_file = opt.anno_validation 36 | else: 37 | anno_file = opt.anno 38 | 39 | self.fix_pairs = not (self.opt.isTrain and opt.dataset_setting == 'ALL' and not validation) 40 | 41 | if self.fix_pairs: 42 | self.pairs_list = read_anno(anno_file) 43 | self.length = self.pairs_list.__len__() 44 | else: 45 | self.scenes_list = read_anno(anno_file) 46 | self.image_per_scene = 25 47 | self.length = self.image_per_scene * self.scenes_list.__len__() 48 | # kingston_bigbathroom2/dir_3_mip2.jpg is wrong 49 | self.scenes_list.remove("kingston_bigbathroom2") 50 | self.scenes_list.append("kingston_bigbathroom2") 51 | self.length = self.length - 1 52 | self.special_index_kingston_bigbathroom2 = list(range(self.image_per_scene)) 53 | self.special_index_kingston_bigbathroom2.remove(3) 54 | 55 | def __getitem__(self, index): 56 | """Return a data point and its metadata information. 57 | 58 | Parameters: 59 | index - - a random integer for data indexing 60 | 61 | Returns a dictionary that contains 62 | 'Image_input': , 63 | 'light_position_color_new': , 64 | 'light_position_color_original': , 65 | 'Reflectance_output': , 66 | 'Shading_output': , 67 | 'Shading_ori': , 68 | 'Image_relighted': , 69 | 'scene_label': , 70 | """ 71 | # get parameters 72 | img_size = self.opt.img_size 73 | 74 | if self.fix_pairs: 75 | pair = self.pairs_list[index] 76 | pair_elements = pair.split() 77 | scene = pair_elements[0] 78 | input_id = pair_elements[1] 79 | target_id = pair_elements[2] 80 | else: 81 | div, mod = divmod(index, self.image_per_scene) 82 | scene = self.scenes_list[div] 83 | # random pairs 84 | if scene != "kingston_bigbathroom2": 85 | input_id = mod 86 | id_list = list(range(self.image_per_scene)) 87 | id_list.remove(mod) 88 | target_id = random.choice(id_list) 89 | else: 90 | input_id = self.special_index_kingston_bigbathroom2[mod] 91 | filtered_list = [x for x in self.special_index_kingston_bigbathroom2 if x != input_id] 92 | target_id = random.choice(filtered_list) 93 | 94 | # get the parameters of data augmentation 95 | transform_params = get_params(self.opt, img_size) 96 | self.img_transform = get_transform(self.opt, transform_params) 97 | self.probes_transform = self.get_transform_probes() 98 | 99 | data = {} 100 | data['scene_label'] = scene + '_' + str(input_id) + '_' + str(target_id) 101 | data['Image_input'], data['light_position_color_original'] = self.get_image(scene, input_id) 102 | data['Image_relighted'], data['light_position_color_new'] = self.get_image(scene, target_id) 103 | return data 104 | 105 | def get_image(self, scene, img_id): 106 | # image 107 | image_filename = os.path.join(self.dataroot, scene, "dir_{}_mip2.jpg".format(str(img_id))) 108 | if not os.path.exists(image_filename): 109 | raise Exception("RelightingDataset __getitem__ error") 110 | img = Image.open(image_filename).convert('RGB') 111 | img_tensor = self.img_transform(img) 112 | # the corresponding probes. 
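        # (The Multi-Illumination capture ships a chrome-sphere and a gray-sphere probe crop for every image; since no pan/tilt/color parameters exist for this dataset, the two 3-channel probe images are concatenated into a single 6-channel tensor below and returned as the light descriptor.)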
113 | probe1_name = os.path.join(self.dataroot, scene, "probes/dir_{}_chrome.jpg".format(str(img_id))) 114 | probe1 = Image.open(probe1_name).convert('RGB') 115 | probe1_tensor = self.probes_transform(probe1) 116 | probe2_name = os.path.join(self.dataroot, scene, "probes/dir_{}_gray.jpg".format(str(img_id))) 117 | probe2 = Image.open(probe2_name).convert('RGB') 118 | probe2_tensor = self.probes_transform(probe2) 119 | probes_tensor = torch.cat([probe1_tensor, probe2_tensor], dim=0) 120 | return img_tensor, probes_tensor 121 | 122 | def get_transform_probes(self, grayscale=False): 123 | transform_list = [] 124 | 125 | transform_list += [transforms.ToTensor()] 126 | if self.opt.normalization_type == '[-1, 1]': 127 | if grayscale: 128 | transform_list += [transforms.Normalize((0.5,), (0.5,))] 129 | else: 130 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 131 | elif self.opt.normalization_type == '[0, 1]': 132 | pass 133 | else: 134 | raise Exception("normalization_type error") 135 | return transforms.Compose(transform_list) 136 | 137 | def __len__(self): 138 | """Return the total number of images in the dataset.""" 139 | return self.length 140 | 141 | 142 | -------------------------------------------------------------------------------- /data/relighting_dataset_single_image_rsr.py: -------------------------------------------------------------------------------- 1 | from data.base_dataset import BaseDataset, get_params, get_transform 2 | import random 3 | import os 4 | from PIL import Image 5 | import math 6 | from util.util import PARA_NOR 7 | import torch 8 | 9 | 10 | def read_anno_pairs(anno_filename): 11 | # read lines from anno 12 | with open(anno_filename, 'r') as f: 13 | annos = [] 14 | for line in f.readlines(): 15 | line = line.strip('\n') 16 | this_pair = line.split(' ') 17 | annos.append(this_pair) 18 | return annos 19 | 20 | 21 | def read_anno_group(anno_filename, data_root, Type_select): 22 | scene_list = [] 23 | with open(anno_filename, 'r') as f: 24 | for line in f.readlines(): 25 | line = line.strip('\n') 26 | scene_list.append(line) 27 | 28 | groups = [] 29 | for scene_id in scene_list: 30 | scene_path = os.path.join(data_root, scene_id) 31 | file_list = os.listdir(scene_path) 32 | file_list.sort() 33 | if len(file_list) != 288: 34 | raise Exception("The number of files are wrong!") 35 | 36 | if Type_select == "LightColorOnly": 37 | for i in range(8): 38 | start_point = i * 9 * 4 39 | for j in range(9): 40 | one_part = [file_list[start_point + j + 9 * k] for k in range(4)] 41 | groups.append([scene_id, one_part]) 42 | elif Type_select == "LightPositionOnly": 43 | slice = 9 44 | for i in range(32): 45 | one_part = file_list[i * slice:(i + 1) * slice] 46 | groups.append([scene_id, one_part]) 47 | elif Type_select == "AnyLight": 48 | slice = 36 49 | for i in range(8): 50 | one_part = file_list[i * slice:(i + 1) * slice] 51 | groups.append([scene_id, one_part]) 52 | else: 53 | raise Exception("No Type_select!") 54 | return groups, len(groups) * len(groups[0][1]) 55 | 56 | 57 | 58 | class RelightingDatasetSingleImageRSR(BaseDataset): 59 | """A dataset class for relighting dataset. 60 | This dataset read data image by image. 61 | """ 62 | 63 | def __init__(self, opt, validation=False): 64 | """Initialize this dataset class. 
65 | 66 | Parameters: 67 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 68 | """ 69 | BaseDataset.__init__(self, opt) 70 | self.dataroot = self.opt.dataroot_RSR 71 | if validation: 72 | anno_file = opt.anno_validation 73 | else: 74 | anno_file = opt.anno 75 | self.fix_pair = validation or not opt.isTrain 76 | if self.fix_pair: 77 | self.pairs_list = read_anno_pairs(anno_file) 78 | self.length = self.pairs_list.__len__() 79 | else: 80 | self.group_type = self.opt.dataset_rsr_type 81 | self.groups_list, self.length = read_anno_group(anno_file, self.dataroot, self.group_type) 82 | self.image_per_group = len(self.groups_list[0][1]) 83 | 84 | def __getitem__(self, index): 85 | """Return a data point and its metadata information. 86 | 87 | Parameters: 88 | index - - a random integer for data indexing 89 | 90 | Returns a dictionary that contains 91 | 'Image_input': , 92 | 'light_position_color_new': , 93 | 'light_position_color_original': , 94 | 'Reflectance_output': , 95 | 'Shading_output': , 96 | 'Shading_ori': , 97 | 'Image_relighted': , 98 | 'scene_label': , 99 | """ 100 | # get parameters 101 | img_size = self.opt.img_size 102 | 103 | if self.fix_pair: 104 | pair = self.pairs_list[index] 105 | folder_id = pair[0] 106 | img_input = pair[1] 107 | img_target = pair[2] 108 | else: 109 | div, mod = divmod(index, self.image_per_group) 110 | group = self.groups_list[div] 111 | folder_id = group[0] 112 | img_input = group[1][mod] 113 | id_list = list(range(self.image_per_group)) 114 | id_list.remove(mod) 115 | id_target = random.choice(id_list) 116 | img_target = group[1][id_target] 117 | 118 | # get the parameters of data augmentation 119 | transform_params = get_params(self.opt, img_size) 120 | self.img_transform = get_transform(self.opt, transform_params) 121 | 122 | data = {} 123 | data['scene_label'] = img_input 124 | data['light_position_color_original'] = self.get_light_condition(img_input) 125 | data['light_position_color_new'] = self.get_light_condition(img_target) 126 | 127 | data['Image_input'] = self.get_image(folder_id, img_input) 128 | data['Image_relighted'] = self.get_image(folder_id, img_target) 129 | 130 | return data 131 | 132 | def get_light_condition(self, img_name): 133 | """ 134 | Name of image: 135 | {Image index}_{Relighting scene index}_{Pan}_{tilt}_{R}_{G}_{B}_{Object scene index}_{rotation of platform}_ 136 | {light color index}_{light position index}.jpg 137 | """ 138 | factor_deg2rad = math.pi / 180.0 139 | names = os.path.splitext(img_name)[0].split('_') 140 | pan = float(names[2]) * factor_deg2rad 141 | tilt = float(names[3]) * factor_deg2rad 142 | # transform light position to cos and sin 143 | light_position = [math.cos(pan), math.sin(pan), math.cos(tilt), math.sin(tilt)] 144 | # normalize the light position to [0, 1] 145 | light_position[:2] = [x * PARA_NOR['pan_a'] + PARA_NOR['pan_b'] for x in light_position[:2]] 146 | light_position[2:] = [x * PARA_NOR['tilt_a'] + PARA_NOR['tilt_b'] for x in light_position[2:]] 147 | # transform light temperature to RGB, and normalize it. 
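        # (For RSR the light color is not derived from a temperature: the R, G, B values are read directly from the file name and scaled to [0, 1]; e.g., an illustrative names[4:7] == ['255', '204', '153'] becomes [1.0, 0.8, 0.6].)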
148 | light_color = list(map(lambda x: int(x) / 255.0, names[4:7])) 149 | light_position_color = light_position + light_color 150 | return torch.tensor(light_position_color) 151 | 152 | def get_image(self, folder_name, img_name): 153 | image_filename = os.path.join(self.dataroot, folder_name, img_name) 154 | if not os.path.exists(image_filename): 155 | raise Exception("RelightingDataset __getitem__ error") 156 | 157 | img = Image.open(image_filename).convert('RGB') 158 | img_tensor = self.img_transform(img) 159 | return img_tensor 160 | 161 | def __len__(self): 162 | """Return the total number of images in the dataset.""" 163 | return self.length 164 | 165 | 166 | -------------------------------------------------------------------------------- /data/relighting_dataset_single_image_test.py: -------------------------------------------------------------------------------- 1 | from data.base_dataset import BaseDataset, get_params, get_transform 2 | from data.relighting_dataset_single_image import get_data 3 | 4 | 5 | def read_anno(file_name): 6 | anno_list = [] 7 | with open(file_name, 'r') as f: 8 | for x in f.readlines(): 9 | x = x.strip('\n') 10 | anno_list.append(x) 11 | return anno_list 12 | 13 | 14 | class RelightingDatasetSingleImageTest(BaseDataset): 15 | """A dataset class for relighting dataset. 16 | This dataset read data image by image (for test). 17 | Read test pairs from anno files. Each scene only has 2 images, one for input and one for relighting. 18 | """ 19 | 20 | def __init__(self, opt, validation=False): 21 | """Initialize this dataset class. 22 | 23 | Parameters: 24 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 25 | """ 26 | BaseDataset.__init__(self, opt) 27 | if validation: 28 | anno_file = opt.anno_validation 29 | else: 30 | anno_file = opt.anno 31 | self.pairs_list = read_anno(anno_file) 32 | 33 | def __getitem__(self, index): 34 | """Return a data point and its metadata information. 
35 | 36 | Parameters: 37 | index - - a random integer for data indexing 38 | 39 | Returns a dictionary that contains 40 | 'Image_input': , 41 | 'light_position_color_new': , 42 | 'light_position_color_original': , 43 | 'Reflectance_output': , 44 | 'Shading_output': , 45 | 'Shading_ori': , 46 | 'Image_relighted': , 47 | 'scene_label': , 48 | """ 49 | # get parameters 50 | dataroot = self.dataroot 51 | img_size = self.opt.img_size 52 | multiple_replace_image = self.opt.multiple_replace_image 53 | # get one pair 54 | pair = self.pairs_list[index].split() 55 | file_name_input = pair[0] 56 | file_name_output = pair[1] 57 | 58 | # get the parameters of data augmentation 59 | transform_params = get_params(self.opt, img_size) 60 | img_transform = get_transform(self.opt, transform_params) 61 | 62 | data = get_data(file_name_input, file_name_output, dataroot, img_transform, multiple_replace_image) 63 | 64 | return data 65 | 66 | def __len__(self): 67 | """Return the total number of images in the dataset.""" 68 | return len(self.pairs_list) 69 | 70 | 71 | -------------------------------------------------------------------------------- /data/relighting_dataset_single_image_vidit.py: -------------------------------------------------------------------------------- 1 | from data.base_dataset import BaseDataset, get_params, get_transform 2 | from util.util import PARA_NOR 3 | import math 4 | import torch 5 | from util.k_to_rgb import convert_K_to_RGB 6 | from PIL import Image 7 | import os.path 8 | import random 9 | DIRECTION = ['E', 'N', 'NE', 'NW', 'S', 'SE', 'SW', 'W'] 10 | TEMPERATURE_LIST = [2500, 3500, 4500, 5500, 6500] 11 | 12 | 13 | def read_anno_pairs(anno_filename): 14 | # read lines from anno 15 | with open(anno_filename, 'r') as f: 16 | annos = [] 17 | for line in f.readlines(): 18 | line = line.strip('\n') 19 | this_pair = line.split(' ') 20 | annos.append(this_pair) 21 | return annos 22 | 23 | 24 | def read_train_pairs(anno_filename): 25 | # read lines from anno 26 | with open(anno_filename, 'r') as f: 27 | annos = [] 28 | for line in f.readlines(): 29 | line = line.strip('\n') 30 | annos.append(["{}_6500_N.png".format(line), "{}_4500_E.png".format(line)]) 31 | return annos 32 | 33 | 34 | def read_anno_group(anno_filename, Type_select): 35 | scene_list = [] 36 | with open(anno_filename, 'r') as f: 37 | for line in f.readlines(): 38 | line = line.strip('\n') 39 | scene_list.append(line) 40 | 41 | groups = [] 42 | for scene_id in scene_list: 43 | if Type_select == "LightColorOnly": 44 | for direction in DIRECTION: 45 | groups.append( 46 | ['_'.join([scene_id, str(temperature), direction]) + '.png' for temperature in TEMPERATURE_LIST]) 47 | elif Type_select == "LightPositionOnly": 48 | for temperature in TEMPERATURE_LIST: 49 | groups.append( 50 | ['_'.join([scene_id, str(temperature), direction]) + '.png' for direction in DIRECTION]) 51 | elif Type_select == "AnyLight": 52 | groups.append( 53 | ['_'.join([scene_id, str(temperature), direction]) + '.png' for temperature in TEMPERATURE_LIST for 54 | direction in DIRECTION]) 55 | else: 56 | raise Exception("No Type_select!") 57 | return groups, len(groups) * len(groups[0]), len(groups[0]) 58 | 59 | 60 | class RelightingDatasetSingleImageVidit(BaseDataset): 61 | """A dataset class for relighting dataset. 62 | This dataset read data image by image. 63 | """ 64 | 65 | def __init__(self, opt, validation=False): 66 | """Initialize this dataset class. 
67 | 68 | Parameters: 69 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 70 | """ 71 | BaseDataset.__init__(self, opt) 72 | # for vidit dataset 73 | self.dataroot = self.opt.dataroot_vidit 74 | if validation: 75 | anno_file = opt.anno_validation 76 | else: 77 | anno_file = opt.anno 78 | self.fix_pair = validation or not opt.isTrain 79 | if self.fix_pair: 80 | self.pairs_list = read_anno_pairs(anno_file) 81 | self.length = self.pairs_list.__len__() 82 | elif self.opt.dataset_assignment_type == "6500_N_4500_E": 83 | self.pairs_list = read_train_pairs(anno_file) 84 | self.length = self.pairs_list.__len__() 85 | self.fix_pair = True 86 | else: 87 | self.group_type = self.opt.dataset_assignment_type.split('_')[0] 88 | # DCDP means different color and different position 89 | self.DCDP = 'DCDP' in self.opt.dataset_assignment_type 90 | self.groups_list, self.length, self.image_per_group = read_anno_group(anno_file, self.group_type) 91 | 92 | def __getitem__(self, index): 93 | # get parameters 94 | img_size = self.opt.img_size 95 | 96 | if self.fix_pair: 97 | pair = self.pairs_list[index] 98 | img_input = pair[0] 99 | img_target = pair[1] 100 | else: 101 | div, mod = divmod(index, self.image_per_group) 102 | group = self.groups_list[div] 103 | img_input = group[mod] 104 | if self.DCDP: 105 | color_input = img_input.split('.')[0].split('_')[1] 106 | position_input = img_input.split('.')[0].split('_')[2] 107 | filtered_group = [] 108 | for img in group: 109 | color, position = tuple(img.split('.')[0].split('_')[1:3]) 110 | if color != color_input and position != position_input: 111 | filtered_group.append(img) 112 | img_target = random.choice(filtered_group) 113 | else: 114 | id_list = list(range(self.image_per_group)) 115 | id_list.remove(mod) 116 | id_target = random.choice(id_list) 117 | img_target = group[id_target] 118 | 119 | # get the parameters of data augmentation 120 | transform_params = get_params(self.opt, img_size) 121 | self.img_transform = get_transform(self.opt, transform_params) 122 | 123 | data = {} 124 | data['scene_label'] = img_input 125 | data['light_position_color_original'] = self.get_light(img_input) 126 | data['light_position_color_new'] = self.get_light(img_target) 127 | 128 | data['Image_input'] = self.get_image(img_input) 129 | data['Image_relighted'] = self.get_image(img_target) 130 | 131 | return data 132 | 133 | def get_light(self, img_name): 134 | str2pan = {'E': 90, 135 | 'N': 180, 136 | 'NE': 135, 137 | 'NW': 225, 138 | 'S': 0, 139 | 'SE': 45, 140 | 'SW': 315, 141 | 'W': 270} 142 | img_name_split = img_name.split('.')[0].split('_') 143 | pan = str2pan[img_name_split[-1]] 144 | tilt = 45.0 145 | color_temp = int(img_name_split[-2]) 146 | # transform light position to cos and sin 147 | light_position = [math.cos(pan), math.sin(pan), math.cos(tilt), math.sin(tilt)] 148 | # normalize the light position to [0, 1] 149 | light_position[:2] = [x * PARA_NOR['pan_a'] + PARA_NOR['pan_b'] for x in light_position[:2]] 150 | light_position[2:] = [x * PARA_NOR['tilt_a'] + PARA_NOR['tilt_b'] for x in light_position[2:]] 151 | # transform light temperature to RGB, and normalize it. 
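        # (convert_K_to_RGB from util/k_to_rgb.py maps the colour temperature in Kelvin to an R, G, B triple on a 0-255 scale, which is normalized to [0, 1] below; roughly, 6500 K yields a near-white triple and 2500 K a warm orange one.)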
152 | light_color = list(map(lambda x: x / 255.0, convert_K_to_RGB(color_temp))) 153 | light_position_color = light_position + light_color 154 | return torch.tensor(light_position_color) 155 | 156 | def get_image(self, img_name): 157 | image_filename = os.path.join(self.dataroot, img_name) 158 | if not os.path.exists(image_filename): 159 | raise Exception("RelightingDataset __getitem__ error") 160 | 161 | img = Image.open(image_filename).convert('RGB') 162 | img_tensor = self.img_transform(img) 163 | return img_tensor 164 | 165 | def __len__(self): 166 | """Return the total number of images in the dataset.""" 167 | return self.length 168 | 169 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVC-CIC/DeepIntrinsicRelighting/b38a04ff9caecddd4e3551950cf6aaf5b17d960e/models/__init__.py -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | def create_model(opt): 2 | 3 | if opt.model_name == 'relighting_two_stage': 4 | from .two_stage_model import TwoStageModel 5 | model = TwoStageModel(opt) 6 | elif opt.model_name == 'relighting_two_stage_rs': 7 | from .two_stage_rs_model import TwoStageRSModel 8 | model = TwoStageRSModel(opt) 9 | elif opt.model_name == 'relighting_one_decoder': 10 | from .one_decoder_model import OneDecoderModel 11 | model = OneDecoderModel(opt) 12 | elif opt.model_name == 'drn': 13 | from .drn_model import DRNModel 14 | model = DRNModel(opt) 15 | else: 16 | raise Exception("Can not find model!") 17 | 18 | print("model [%s] was created" % (model.model_names)) 19 | return model 20 | -------------------------------------------------------------------------------- /models/networks_custom_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SigLinear(nn.Module): 6 | """ 7 | SigLinear is a custom PyTorch activation module that combines the Sigmoid function for the 8 | negative half-axis and a linear function for the positive half-axis. 9 | 10 | Specifically: 11 | - For input values x < 0, the output is given by the Sigmoid function. 12 | - For input values x >= 0, the output is given by the linear function y = 0.5 + 0.25 * x. 13 | 14 | The activation function is continuous and differentiable at x = 0. 
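
    As a quick check of that claim: at x = 0 the sigmoid branch gives sigmoid(0) = 0.5 with
    slope sigmoid'(0) = 0.25, and the linear branch gives 0.5 + 0.25 * 0 = 0.5 with slope 0.25,
    so the two pieces meet with matching value and derivative.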
15 | """ 16 | def __init__(self): 17 | super(SigLinear, self).__init__() 18 | 19 | def forward(self, x): 20 | sigmoid_mask = x < 0 21 | linear_mask = ~sigmoid_mask 22 | 23 | sigmoid_part = torch.sigmoid(x[sigmoid_mask]) 24 | linear_part = 0.5 + 0.25 * x[linear_mask] 25 | 26 | output = torch.empty_like(x) 27 | output[sigmoid_mask] = sigmoid_part 28 | output[linear_mask] = linear_part 29 | 30 | return output 31 | 32 | 33 | """ 34 | 35 | """ 36 | class ClampWithGradient(torch.autograd.Function): 37 | @staticmethod 38 | def forward(ctx, x, min_value, max_value): 39 | ctx.save_for_backward(x) 40 | ctx.min_value = min_value 41 | ctx.max_value = max_value 42 | return x.clamp(min_value, max_value) 43 | 44 | @staticmethod 45 | def backward(ctx, grad_output): 46 | x, = ctx.saved_tensors 47 | grad_input = grad_output.clone() 48 | grad_input[(x < ctx.min_value) | (x > ctx.max_value)] = 1 49 | return grad_input, None, None 50 | 51 | class CustomClamp(nn.Module): 52 | """ 53 | CustomClamp is a custom module that applies element-wise clamping to the input tensor. 54 | It ensures that the output values are within a specified range [min_value, max_value], 55 | while maintaining gradients outside the clamping range for learning purposes. 56 | Attributes: 57 | min_value (float): The lower bound of the clamping range. 58 | max_value (float): The upper bound of the clamping range. 59 | """ 60 | def __init__(self, min_value, max_value): 61 | super(CustomClamp, self).__init__() 62 | self.min_value = min_value 63 | self.max_value = max_value 64 | 65 | def forward(self, x): 66 | return ClampWithGradient.apply(x, self.min_value, self.max_value) 67 | 68 | -------------------------------------------------------------------------------- /models/networks_discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import functools 4 | from models.networks import get_norm_layer 5 | 6 | 7 | def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch'): 8 | """Create a discriminator 9 | 10 | Parameters: 11 | input_nc (int) -- the number of channels in input images 12 | ndf (int) -- the number of filters in the first conv layer 13 | netD (str) -- the architecture's name: basic | n_layers | pixel 14 | n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers' 15 | norm (str) -- the type of normalization layers used in the network. 16 | init_type (str) -- the name of the initialization method. 17 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 18 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 19 | 20 | Returns a discriminator 21 | 22 | Our current implementation provides three types of discriminators: 23 | [basic]: 'PatchGAN' classifier described in the original pix2pix paper. 24 | It can classify whether 70×70 overlapping patches are real or fake. 25 | Such a patch-level discriminator architecture has fewer parameters 26 | than a full-image discriminator and can work on arbitrarily-sized images 27 | in a fully convolutional fashion. 28 | 29 | [n_layers]: With this mode, you can specify the number of conv layers in the discriminator 30 | with the parameter (default=3 as used in [basic] (PatchGAN).) 31 | 32 | [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not. 33 | It encourages greater color diversity but has no effect on spatial statistics. 34 | 35 | The discriminator has been initialized by . 
It uses Leaky ReLU for non-linearity. 36 | """ 37 | net = None 38 | norm_layer = get_norm_layer(norm_type=norm) 39 | 40 | if netD == 'basic': # default PatchGAN classifier 41 | net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer) 42 | elif netD == 'n_layers': # more options 43 | net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer) 44 | elif netD == 'pixel': # classify if each pixel is real or fake 45 | net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer) 46 | elif netD == 'embedded_light': 47 | net = EmbeddedDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer) 48 | else: 49 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD) 50 | # return init_net(net, init_type, init_gain, gpu_ids, parallel_method) 51 | return net 52 | 53 | 54 | ############################################################################## 55 | # Classes 56 | ############################################################################## 57 | class GANLoss(nn.Module): 58 | """Define different GAN objectives. 59 | 60 | The GANLoss class abstracts away the need to create the target label tensor 61 | that has the same size as the input. 62 | """ 63 | 64 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): 65 | """ Initialize the GANLoss class. 66 | 67 | Parameters: 68 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. 69 | target_real_label (bool) - - label for a real image 70 | target_fake_label (bool) - - label of a fake image 71 | 72 | Note: Do not use sigmoid as the last layer of Discriminator. 73 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. 74 | """ 75 | super(GANLoss, self).__init__() 76 | self.register_buffer('real_label', torch.tensor(target_real_label)) 77 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 78 | self.gan_mode = gan_mode 79 | if gan_mode == 'lsgan': 80 | self.loss = nn.MSELoss() 81 | elif gan_mode == 'vanilla': 82 | self.loss = nn.BCEWithLogitsLoss() 83 | elif gan_mode in ['wgangp']: 84 | self.loss = None 85 | else: 86 | raise NotImplementedError('gan mode %s not implemented' % gan_mode) 87 | 88 | def get_target_tensor(self, prediction, target_is_real): 89 | """Create label tensors with the same size as the input. 90 | 91 | Parameters: 92 | prediction (tensor) - - typically the prediction from a discriminator 93 | target_is_real (bool) - - if the ground truth label is for real images or fake images 94 | 95 | Returns: 96 | A label tensor filled with ground truth label, and with the size of the input 97 | """ 98 | 99 | if target_is_real: 100 | target_tensor = self.real_label 101 | else: 102 | target_tensor = self.fake_label 103 | return target_tensor.expand_as(prediction) 104 | 105 | def __call__(self, prediction, target_is_real): 106 | """Calculate loss given Discriminator's output and ground truth labels. 107 | 108 | Parameters: 109 | prediction (tensor) - - typically the prediction output from a discriminator 110 | target_is_real (bool) - - if the ground truth label is for real images or fake images 111 | 112 | Returns: 113 | the calculated loss. 
114 | """ 115 | if self.gan_mode in ['lsgan', 'vanilla']: 116 | target_tensor = self.get_target_tensor(prediction, target_is_real) 117 | loss = self.loss(prediction, target_tensor) 118 | elif self.gan_mode == 'wgangp': 119 | if target_is_real: 120 | loss = -prediction.mean() 121 | else: 122 | loss = prediction.mean() 123 | return loss 124 | 125 | 126 | 127 | class NLayerDiscriminator(nn.Module): 128 | """Defines a PatchGAN discriminator""" 129 | 130 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): 131 | """Construct a PatchGAN discriminator 132 | 133 | Parameters: 134 | input_nc (int) -- the number of channels in input images 135 | ndf (int) -- the number of filters in the last conv layer 136 | n_layers (int) -- the number of conv layers in the discriminator 137 | norm_layer -- normalization layer 138 | """ 139 | super(NLayerDiscriminator, self).__init__() 140 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 141 | use_bias = norm_layer.func == nn.InstanceNorm2d 142 | else: 143 | use_bias = norm_layer == nn.InstanceNorm2d 144 | 145 | kw = 4 146 | padw = 1 147 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 148 | nf_mult = 1 149 | nf_mult_prev = 1 150 | for n in range(1, n_layers): # gradually increase the number of filters 151 | nf_mult_prev = nf_mult 152 | nf_mult = min(2 ** n, 8) 153 | sequence += [ 154 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 155 | norm_layer(ndf * nf_mult), 156 | nn.LeakyReLU(0.2, True) 157 | ] 158 | 159 | nf_mult_prev = nf_mult 160 | nf_mult = min(2 ** n_layers, 8) 161 | sequence += [ 162 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 163 | norm_layer(ndf * nf_mult), 164 | nn.LeakyReLU(0.2, True) 165 | ] 166 | 167 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 168 | self.model = nn.Sequential(*sequence) 169 | 170 | def forward(self, input): 171 | """Standard forward.""" 172 | return self.model(input) 173 | 174 | 175 | class PixelDiscriminator(nn.Module): 176 | """Defines a 1x1 PatchGAN discriminator (pixelGAN)""" 177 | 178 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d): 179 | """Construct a 1x1 PatchGAN discriminator 180 | 181 | Parameters: 182 | input_nc (int) -- the number of channels in input images 183 | ndf (int) -- the number of filters in the last conv layer 184 | norm_layer -- normalization layer 185 | """ 186 | super(PixelDiscriminator, self).__init__() 187 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 188 | use_bias = norm_layer.func == nn.InstanceNorm2d 189 | else: 190 | use_bias = norm_layer == nn.InstanceNorm2d 191 | 192 | self.net = [ 193 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), 194 | nn.LeakyReLU(0.2, True), 195 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), 196 | norm_layer(ndf * 2), 197 | nn.LeakyReLU(0.2, True), 198 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] 199 | 200 | self.net = nn.Sequential(*self.net) 201 | 202 | def forward(self, input): 203 | """Standard forward.""" 204 | return self.net(input) 205 | 206 | 207 | class EmbeddedDiscriminator(nn.Module): 208 | """Defines a PatchGAN discriminator with more conditions""" 209 | 210 | def __init__(self, input_nc, 
ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): 211 | """Construct a PatchGAN discriminator 212 | 213 | Parameters: 214 | input_nc (int) -- the number of channels in input images 215 | ndf (int) -- the number of filters in the last conv layer 216 | n_layers (int) -- the number of conv layers in the discriminator 217 | norm_layer -- normalization layer 218 | """ 219 | super(EmbeddedDiscriminator, self).__init__() 220 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 221 | use_bias = norm_layer.func == nn.InstanceNorm2d 222 | else: 223 | use_bias = norm_layer == nn.InstanceNorm2d 224 | 225 | kw = 4 226 | padw = 1 227 | sequence1 = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 228 | nf_mult = 1 229 | 230 | for n in range(1, n_layers): # gradually increase the number of filters 231 | nf_mult_prev = nf_mult 232 | nf_mult = min(2 ** n, 8) 233 | sequence1 += [ 234 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 235 | norm_layer(ndf * nf_mult), 236 | nn.LeakyReLU(0.2, True) 237 | ] 238 | 239 | nf_mult_prev = nf_mult 240 | nf_mult = min(2 ** n_layers, 8) 241 | sequence1 += [ 242 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 243 | norm_layer(ndf * nf_mult), 244 | nn.LeakyReLU(0.2, True) 245 | ] 246 | 247 | new_channel = 128 248 | light_add = LightAdd(new_channel) 249 | 250 | sequence2 = [] 251 | sequence2 += [nn.Conv2d(new_channel + ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 252 | self.model1 = nn.Sequential(*sequence1) 253 | self.light_embedded = light_add 254 | self.model2 = nn.Sequential(*sequence2) 255 | 256 | def forward(self, input, light_condition): 257 | """Standard forward.""" 258 | out = self.model1(input) 259 | out = self.light_embedded(out, light_condition) 260 | out = self.model2(out) 261 | return out 262 | 263 | 264 | class LightAdd(nn.Module): 265 | def __init__(self, channel): 266 | super(LightAdd, self).__init__() 267 | self.mlp = nn.Sequential( 268 | nn.Linear(7 * 2, 32), 269 | nn.LeakyReLU(0.2, True), 270 | nn.Linear(32, channel), 271 | nn.LeakyReLU(0.2, True) 272 | ) 273 | def forward(self, in_tensor, light_condition): 274 | light_vector = self.mlp(light_condition) 275 | light_vector = light_vector.unsqueeze(2).unsqueeze(3) 276 | light_vector = light_vector.expand(-1, -1, in_tensor.size()[2], in_tensor.size()[3]) 277 | 278 | return torch.cat((in_tensor, light_vector), 1) 279 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVC-CIC/DeepIntrinsicRelighting/b38a04ff9caecddd4e3551950cf6aaf5b17d960e/options/__init__.py -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import os 2 | from util import util 3 | 4 | class BaseOptions(): 5 | def __init__(self): 6 | self.server_root = '/ghome/yyang/dataset/' 7 | self.dataroot = self.server_root + 'ISR/' # path of the dataset 8 | self.dataroot_vidit = self.server_root + 'VIDIT_full/' 9 | self.dataroot_RSR = self.server_root + 'RSR_256/' 10 | self.dataroot_multilum = self.server_root + 'Multi_Illumination_small/train/' 11 | self.checkpoints_dir = './checkpoints/' # models are 
saved here 12 | self.max_dataset_size = float("inf") #float("inf") # Maximum number of samples allowed per dataset. If the dataset 13 | # directory contains more than max_dataset_size, only a subset is loaded. 14 | self.img_size = (256, 256) # size of the image 15 | self.input_nc = 3 # number of input image channels 16 | self.output_nc = 3 # number of output image channels 17 | self.ngf = 64 # number of gen filters in first conv layer 18 | self.init_type = 'normal' # network initialization [normal | xavier | kaiming | orthogonal] 19 | self.init_gain = 0.02 # scaling factor for normal, xavier and orthogonal. 20 | self.verbose = False # if specified, print more debugging information 21 | 22 | self.normalization_type = '[0, 1]' # '[0, 1]' or '[-1, 1]' if this is changed, the inverse normalization in 23 | # visualizer should also be changed manually. 24 | self.multiple_replace_image = True # if specified, the Image type of the dataset will not be read from the 25 | # dataset, it will be replaced by the multiple of reflectance and shading. 26 | self.pre_read_data = False # if specified, the dataset will be stored in memory before training. 27 | 28 | self.display_winsize = 256 # display window size for both visdom and HTML 29 | 30 | def parse(self, verbose=True): 31 | args = vars(self) 32 | 33 | if verbose: 34 | print('------------ Options -------------') 35 | for k, v in sorted(args.items()): 36 | print('%s: %s' % (str(k), str(v))) 37 | print('-------------- End ----------------') 38 | 39 | # save to the disk 40 | expr_dir = os.path.join(self.checkpoints_dir, self.name) 41 | util.mkdirs(expr_dir) 42 | if self.isTrain: 43 | options_name = 'options_train.txt' 44 | else: 45 | options_name = 'options_test.txt' 46 | file_name = os.path.join(expr_dir, options_name) 47 | with open(file_name, 'wt') as opt_file: 48 | opt_file.write('------------ Options -------------\n') 49 | for k, v in sorted(args.items()): 50 | opt_file.write('%s: %s\n' % (str(k), str(v))) 51 | opt_file.write('-------------- End ----------------\n') 52 | return self 53 | -------------------------------------------------------------------------------- /options/test_qualitative_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | import argparse 3 | 4 | class TestOptions(BaseOptions): 5 | def __init__(self): 6 | super(TestOptions, self).__init__() 7 | 8 | # get para 9 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 10 | parser.add_argument('--name', type=str) 11 | parser_value = parser.parse_args() 12 | self.name = parser_value.name 13 | 14 | # self.name = 'exp_isr' # name of the experiment. It decides where to store samples and models 15 | # self.name = 'exp_rsr_ours_f' # name of the experiment. 
It decides where to store samples and models 16 | # self.name = 'exp_vidit_ours_f' 17 | # self.name = 'exp_multilum_ours_f' 18 | 19 | self.special_test = False 20 | using_dataset = self.name.split('_')[1] 21 | if using_dataset == "isr": 22 | self.dataset_mode = 'relighting_single_image_test' # name of the dataset 23 | self.anno = 'data/anno_ISR/test_qualitative_pairs.txt' # the anno file from prepare_dataset.py 24 | self.preprocess = 'none' 25 | self.show_gt_intrinsic = True 26 | self.light_type = "pan_tilt_color" 27 | elif using_dataset == "rsr": 28 | self.dataset_mode = 'relighting_single_image_rsr' # name of the dataset 29 | self.anno = 'data/anno_RSR/AnyLighttest_pairs_qualitative.txt' # the anno file from prepare_dataset.py 30 | self.preprocess = 'none' 31 | self.show_gt_intrinsic = False 32 | self.light_type = "pan_tilt_color" 33 | elif using_dataset == "vidit": 34 | self.dataset_mode = 'relighting_single_image_vidit' # name of the dataset 35 | self.anno = 'data/anno_VIDIT/any2any/AnyLight_test_pairs_qualitative.txt' # the anno file from prepare_dataset.py 36 | self.preprocess = 'resize' 37 | self.show_gt_intrinsic = False 38 | self.light_type = "pan_tilt_color" 39 | elif using_dataset == "multilum": 40 | self.dataset_mode = 'relighting_single_image_multilum' 41 | self.dataroot_multilum = self.server_root + 'Multi_Illumination_small/test/' 42 | self.anno = 'data/multi_illumination/test_qualitative.txt' 43 | self.preprocess = 'none' 44 | self.show_gt_intrinsic = False 45 | self.light_type = "probes" 46 | elif using_dataset == "special": 47 | self.dataset_mode = 'relighting_single_image_special_test' # name of the dataset 48 | self.preprocess = 'resize' 49 | self.show_gt_intrinsic = False 50 | self.light_type = "pan_tilt_color" 51 | self.special_test = True # special test for pictures from other datasets. 
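        # (Each branch above selects the dataset class, annotation file, preprocessing and light representation from the dataset tag embedded in the experiment name, e.g. --name exp_rsr_ours_f picks the RSR settings; "special" handles images from external datasets.)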
52 | 53 | self.phase = 'test_' + using_dataset # str, default='test', help='train, val, test, etc') 54 | ##### 55 | if len(self.name.split('_')) > 2: 56 | using_model = self.name.split('_')[2] 57 | else: 58 | using_model = "ours" 59 | if using_model == "ours": 60 | self.model_name = 'relighting_two_stage' # ['relighting_two_stage' | 'relighting_one_decoder'] 61 | self.two_stage = True 62 | if using_dataset == "multilum": 63 | self.light_prediction = False 64 | else: 65 | self.light_prediction = True 66 | if self.two_stage and self.show_gt_intrinsic: 67 | self.metric_list = ['Relighted', 'Reflectance', 'Shading_ori', 'Shading_new', 68 | 'Reconstruct'] 69 | else: 70 | self.metric_list = ['Relighted'] 71 | if self.light_prediction: 72 | self.metric_list.append('light_position_color') 73 | self.netG = "resnet9_nonlocal" 74 | self.net_intrinsic = "resnet9" 75 | # self.netG = "unet" 76 | # self.net_intrinsic = "unet" 77 | self.infinite_range_sha = True 78 | if self.infinite_range_sha: 79 | self.net_intrinsic = self.net_intrinsic + "_InfRange" 80 | self.netG = self.netG + "_InfRange" 81 | self.introduce_ref_G_2 = False 82 | self.no_dropout = False # old option: no dropout for the model 83 | self.norm = 'batch' # instance normalization or batch normalization [instance | batch | none] 84 | elif using_model == "pix2pix": 85 | self.model_name = 'relighting_two_stage' # ['relighting' | 'intrinsic_decomposition'] 86 | self.two_stage = False 87 | if using_dataset == "multilum": 88 | self.light_prediction = False 89 | else: 90 | self.light_prediction = True 91 | self.no_dropout = False # old option: no dropout for the model 92 | self.norm = 'batch' # instance normalization or batch normalization [instance | batch | none] 93 | self.netG = "unet" 94 | self.infinite_range_sha = False 95 | self.metric_list = ['Relighted'] 96 | if self.light_prediction: 97 | self.metric_list.append('light_position_color') 98 | elif using_model == "drn": 99 | self.model_name = 'drn' # ['relighting' | 'intrinsic_decomposition'] 100 | self.light_prediction = False 101 | self.netG = 'global' 102 | self.n_downsample_global = 3 103 | self.n_blocks_global = 9 104 | self.n_local_enhancers = 1 105 | self.n_blocks_local = 3 106 | self.norm = 'instance' 107 | self.n_layers_D = 1 108 | self.use_sigmoid = False 109 | self.num_D = 1 110 | self.no_ganFeat_loss = False 111 | self.no_lsgan = False 112 | self.pool_size = 0 113 | self.ndf = 4 114 | self.lambda_feat = 10.0 115 | self.metric_list = ['Relighted'] 116 | elif using_model == "ian": 117 | self.model_name = 'IAN' 118 | self.metric_list = ['Relighted'] 119 | else: 120 | raise Exception("Using_model not exist. ") 121 | 122 | # train parameter 123 | self.batch_size = 1 # batch size # 6 124 | # select which model to load, set continue_train = True to load the weight 125 | self.continue_train = True # continue training: load the latest model 126 | self.epoch = 'save_best' # default='latest', which epoch to load? set to latest to use latest cached model 127 | self.load_iter = 0 # default='0', which iteration to load? 
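        # (With continue_train = True the tester restores the snapshot named by self.epoch, here the 'save_best' checkpoint kept during training; a positive load_iter would instead load an iteration-numbered snapshot (iter_[load_iter]), as noted in the training options.)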
128 | 129 | self.parallel_method = "DataParallel" 130 | self.use_amp = False 131 | self.use_discriminator = False 132 | self.cross_model = False 133 | 134 | self.model_modify_layer = [] 135 | self.modify_layer = len(self.model_modify_layer) != 0 136 | self.constrain_intrinsic = False 137 | 138 | self.aspect_ratio = 1.0 # float, default=1.0, help='aspect ratio of result images') 139 | self.results_dir = './results/' # str, default='./results/', help='saves results here.') 140 | self.isTrain = False 141 | self.gpu_ids = [0] 142 | 143 | # Dropout and Batchnorm have different behavior during training and testing. 144 | self.eval = True # use eval mode during test time. 145 | self.num_test = 130 # how many test images to run 146 | # dataloader 147 | self.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. 148 | self.num_threads = 1 # test code only supports num_threads = 1 149 | # data augmentation 150 | self.crop_size = 256 # then crop to this size 151 | self.load_size = self.crop_size # scale images to this size 152 | self.no_flip = True # if specified, do not flip the images for data augmentation 153 | 154 | self.display_id = -1 # no visdom display; the test code saves the results to an HTML file. 155 | 156 | -------------------------------------------------------------------------------- /options/test_quantitative_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | import argparse 3 | 4 | class TestQuantitiveOptions(BaseOptions): 5 | def __init__(self): 6 | super(TestQuantitiveOptions, self).__init__() 7 | 8 | # get para 9 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 10 | parser.add_argument('--name', type=str) 11 | parser_value = parser.parse_args() 12 | self.name = parser_value.name 13 | 14 | # self.name = 'exp_isr' # name of the experiment. It decides where to store samples and models 15 | # self.name = 'exp_rsr_ours_f' # name of the experiment. 
It decides where to store samples and models 16 | # self.name = 'exp_vidit_ours_f' 17 | # self.name = 'exp_multilum_ours_f' 18 | 19 | 20 | using_dataset = self.name.split('_')[1] 21 | if using_dataset == "isr": 22 | self.dataset_mode = 'relighting_single_image_test' # name of the dataset 23 | self.anno = 'data/anno_ISR/test_quantitative_pairs_10x.txt' # the anno file from prepare_dataset.py 24 | self.preprocess = 'none' # 'resize_and_crop' 25 | self.show_gt_intrinsic = True 26 | self.light_type = "pan_tilt_color" 27 | elif using_dataset == "rsr": 28 | self.dataset_mode = 'relighting_single_image_rsr' # name of the dataset 29 | self.anno = 'data/anno_RSR/AnyLighttest_pairs.txt' # the anno file from prepare_dataset.py 30 | self.preprocess = 'none' # 'resize_and_crop' 31 | self.show_gt_intrinsic = False 32 | self.light_type = "pan_tilt_color" 33 | elif using_dataset == "vidit": 34 | self.dataset_mode = 'relighting_single_image_vidit' # name of the dataset 35 | self.anno = 'data/anno_VIDIT/any2any/AnyLight_test_pairs.txt' # for table any-to-any 36 | self.preprocess = 'resize' 37 | self.show_gt_intrinsic = False 38 | self.light_type = "pan_tilt_color" 39 | elif using_dataset == "multilum": 40 | self.dataset_mode = 'relighting_single_image_multilum' 41 | self.dataroot_multilum = self.server_root + 'Multi_Illumination_small/test/' 42 | self.anno = 'data/multi_illumination/test.txt' 43 | self.preprocess = 'none' 44 | self.show_gt_intrinsic = False 45 | self.light_type = "probes" 46 | 47 | if len(self.name.split('_')) > 2: 48 | using_model = self.name.split('_')[2] 49 | else: 50 | using_model = "ours" 51 | if using_model == "ours": 52 | self.model_name = 'relighting_two_stage' # ['relighting_two_stage' | 'relighting_one_decoder'] 53 | self.two_stage = True 54 | if using_dataset == "multilum": 55 | self.light_prediction = False 56 | else: 57 | self.light_prediction = True 58 | if self.two_stage and self.show_gt_intrinsic: 59 | self.metric_list = ['Relighted', 'Reflectance', 'Shading_ori', 'Shading_new', 60 | 'Reconstruct'] 61 | else: 62 | self.metric_list = ['Relighted'] 63 | if self.light_prediction: 64 | self.metric_list.append('light_position_color') 65 | self.netG = "resnet9_nonlocal" 66 | self.net_intrinsic = "resnet9" 67 | # self.netG = "unet" 68 | # self.net_intrinsic = "unet" 69 | self.infinite_range_sha = True 70 | if self.infinite_range_sha: 71 | self.net_intrinsic = self.net_intrinsic + "_InfRange" 72 | self.netG = self.netG + "_InfRange" 73 | self.introduce_ref_G_2 = False 74 | self.no_dropout = False # old option: no dropout for the model 75 | self.norm = 'batch' # instance normalization or batch normalization [instance | batch | none] 76 | elif using_model == "pix2pix": 77 | self.model_name = 'relighting_two_stage' # ['relighting' | 'intrinsic_decomposition'] 78 | self.two_stage = False 79 | if using_dataset == "multilum": 80 | self.light_prediction = False 81 | else: 82 | self.light_prediction = True 83 | self.no_dropout = False # old option: no dropout for the model 84 | self.norm = 'batch' # instance normalization or batch normalization [instance | batch | none] 85 | self.netG = "unet" 86 | self.infinite_range_sha = False 87 | self.metric_list = ['Relighted'] 88 | if self.light_prediction: 89 | self.metric_list.append('light_position_color') 90 | elif using_model == "drn": 91 | self.model_name = 'drn' # ['relighting' | 'intrinsic_decomposition'] 92 | self.light_prediction = False 93 | self.netG = 'global' 94 | self.n_downsample_global = 3 95 | self.n_blocks_global = 9 96 | 
self.n_local_enhancers = 1 97 | self.n_blocks_local = 3 98 | self.norm = 'instance' 99 | self.n_layers_D = 1 100 | self.use_sigmoid = False 101 | self.num_D = 1 102 | self.no_ganFeat_loss = False 103 | self.no_lsgan = False 104 | self.pool_size = 0 105 | self.ndf = 4 106 | self.lambda_feat = 10.0 107 | self.metric_list = ['Relighted'] 108 | elif using_model == "ian": 109 | self.model_name = 'IAN' 110 | self.metric_list = ['Relighted'] 111 | else: 112 | raise Exception("Using_model not exist. ") 113 | 114 | # train parameter 115 | self.batch_size = 1 # batch size # 6 116 | # select which model to load, set continue_train = True to load the weight 117 | self.continue_train = True # continue training: load the latest model 118 | self.epoch = 'save_best' # default='latest', which epoch to load? set to latest to use latest cached model 119 | self.load_iter = 0 # default='0', which iteration to load? 120 | 121 | self.parallel_method = "DataParallel" 122 | self.use_amp = False 123 | self.use_discriminator = False 124 | self.cross_model = False 125 | 126 | self.model_modify_layer = [] 127 | self.modify_layer = len(self.model_modify_layer) != 0 128 | self.constrain_intrinsic = False 129 | 130 | self.results_dir = './results/' # str, default='./results/', help='saves results here.') 131 | self.isTrain = False 132 | self.gpu_ids = [0] 133 | 134 | # Dropout and Batchnorm has different behavioir during training and test. 135 | self.eval = True # use eval mode during test time. 136 | self.num_test = float("inf") # 100 # how many test images to run 137 | # dataloader 138 | self.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. 139 | self.num_threads = 1 # test code only supports num_threads = 1 140 | # data augmentation 141 | self.crop_size = 256 # then crop to this size 142 | self.load_size = self.crop_size # scale images to this size 143 | self.no_flip = True # if specified, do not flip the images for data augmentation 144 | 145 | self.display_id = -1 # no visdom display; the test code saves the results to a HTML file. 146 | 147 | self.special_test = False # special test for pictures from other datasets. 
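        # (Illustrative usage, assuming test_quantitative.py in the repository root consumes these options: `python test_quantitative.py --name exp_rsr_ours_f`; the metrics in self.metric_list are then computed over the full test split since num_test is unlimited.)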
148 | 149 | # # low resolution setting 150 | # self.img_size = (resolution, resolution) 151 | # self.preprocess = 'resize' 152 | # self.load_size = resolution # scale images to this size 153 | # self.crop_size = resolution # then crop to this size 154 | 155 | -------------------------------------------------------------------------------- /options/train_options_isr.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def __init__(self, which_experiment): 6 | super(TrainOptions, self).__init__() 7 | 8 | self.which_experiment = which_experiment 9 | self.isTrain = True 10 | 11 | # Setting for dataset 12 | self.anno = 'data/anno_ISR/train.txt' # the anno file from prepare_dataset.py 13 | self.anno_validation = 'data/anno_ISR/val_pairs.txt' 14 | # Setting for GPU 15 | self.parallel_method = "DistributedDataParallel" # "DataParallel" "DistributedDataParallel" 16 | if self.parallel_method == "DataParallel": 17 | number_gpus = 1 18 | self.gpu_ids = [i for i in range(number_gpus)] 19 | elif self.parallel_method == "DistributedDataParallel": 20 | self.world_size = None 21 | self.gpu_ids = None 22 | self.use_amp = False 23 | # parameters for batch 24 | self.batch_size = 6 25 | 26 | # Setting for the optimizer 27 | self.lr_policy = 'step' # learning rate policy. [linear | step | plateau | cosine] 28 | self.lr = 0.0002 # initial learning rate for adam 29 | self.lr_d = self.lr 30 | self.lr_decay_ratio = 0.5 # decay ratio in step scheduler. 31 | self.n_epochs = 150 #100 number of epochs with the initial learning rate 32 | self.n_epochs_decay = 0 #100 when using 'linear', number of epochs to linearly decay learning rate to zero 33 | self.lr_decay_iters = 100 # when using 'step', multiply by a gamma every lr_decay_iters iterations 34 | self.optimizer_type = 'Adam' # 'Adam', 'SGD' 35 | self.beta1 = 0.5 # momentum term of adam 36 | self.adam_eps = 1e-8 37 | 38 | # Setting for continuing the training. 39 | self.continue_train = False # continue training: load the latest model 40 | self.epoch = '75' # default='latest', which epoch to load? set to latest to use latest cached model 41 | self.load_iter = 0 # default='0', which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch] 42 | if self.continue_train: 43 | try: 44 | self.epoch_count = int(self.epoch) 45 | except: 46 | self.epoch_count = 1 # the starting epoch count, we save the model by , +, ...') 47 | else: 48 | self.epoch_count = 1 49 | 50 | self.model_modify_layer = [] 51 | self.modify_layer = len(self.model_modify_layer) != 0 52 | 53 | # Setting for the model 54 | self.dataset_mode = 'relighting_single_image' # name of the dataset 55 | self.name = self.which_experiment # name of the experiment. 
It decides where to store samples and models 56 | 57 | self.model_name = 'relighting_two_stage' # ['relighting' | 'intrinsic_decomposition'] 58 | self.two_stage = True 59 | self.no_dropout = False # old option: no dropout for the model 60 | self.norm = 'batch' # instance normalization or batch normalization [instance | batch | none] 61 | self.light_type = "pan_tilt_color" # ["pan_tilt_color" | "Spherical_harmonic"] 62 | self.light_prediction = True 63 | self.netG = "resnet9_nonlocal" 64 | self.net_intrinsic = "resnet9" 65 | self.introduce_ref_G_2 = False 66 | self.cross_model = True 67 | # range of sha 68 | self.infinite_range_sha = True 69 | if self.infinite_range_sha: 70 | self.net_intrinsic = self.net_intrinsic + "_InfRange" 71 | self.netG = self.netG + "_InfRange" 72 | # reflectance consistency 73 | self.flag_ref_consistency = False 74 | self.loss_weight_ref_consistency = 1.0 75 | self.flag_sha_consistency = False 76 | self.loss_weight_sha_consistency = 1.0 77 | # Regularizing chromaticity 78 | self.flag_sha_chromaticity_smooth = False 79 | # self.method_sha_chromaticity_smooth = "OPP" # "OPP", "LAB" 80 | # self.loss_weight_sha_chromaticity_smooth = 75.0 81 | # self.loss_weight_sha_overall_smooth = 0.5 82 | self.flag_sha_ref_regression = False 83 | # self.method_sha_ref_regression = 'm1' 84 | # self.loss_weight_sha_ref_regression_1 = 1.0 # for chromaticity 85 | # self.loss_weight_sha_ref_regression_2 = 0.1 # for all channels 86 | # self.sha_ref_regression_mean = [0.43, 0.61] 87 | self.method_sha_ref_regression = 'm2' 88 | self.para_sha_ref_regression = { 89 | 'R_I_c': 1.2119, 90 | 'R_I_a': 1.1603, 91 | 'S_I_c': 0.5254, 92 | 'S_I_a': 0.7089, 93 | 'elu_alpha': 0.0, 94 | 'elu_shift': 0.0, 95 | 'w_R_I_c': 0.0, 96 | 'w_R_I_a': 0.0, 97 | 'w_S_I_c': 1.0, 98 | 'w_S_I_a': 0.1, 99 | } 100 | # Regularizing init_ref 101 | self.flag_init_ref = False 102 | self.para_init_ref = { 103 | 'cross_ij': True, 104 | 'decay': False, 105 | 'method': "ORI" # "OPP", "ORI" 106 | } 107 | self.loss_weight_init_ref = 1.0 108 | 109 | # # discriminator 110 | self.use_discriminator = True 111 | # self.epoch_start_train_discriminator = -1 112 | self.netD = 'n_layers' # specify discriminator architecture [basic | n_layers | pixel]. 113 | # The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator' 114 | self.n_layers_D = 3 # only used if netD==n_layers 115 | self.gan_mode = 'lsgan' # 'the type of GAN objective. [vanilla| lsgan | wgangp]. 116 | # vanilla GAN loss is the cross-entropy objective used in the original GAN paper.' 
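        # (With 'lsgan' the GANLoss class in models/networks_discriminator.py falls back to nn.MSELoss, i.e. a least-squares GAN objective that regresses the patch predictions towards the real/fake labels instead of using a cross-entropy loss.)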
117 | self.ndf = 64 # number of discrim filters in the first conv layer 118 | self.loss_weight_GAN = 0.05 119 | 120 | # parameters for loss functions 121 | self.constrain_intrinsic = True 122 | self.show_gt_intrinsic = True 123 | self.main_loss_function = 'L1_DSSIM_LPIPS' # choose using L2 or L1 during the training 124 | self.flag_L1_DSSIM_LPIPS = [True, True, True] 125 | if self.cross_model: 126 | self.unbalanced = False 127 | self.unbalanced_para = None 128 | # Weights of losses 129 | self.loss_weight_angular = 1.0 130 | self.loss_weight_color = 1.0 131 | self.loss_weight_reflectance = 1.0 132 | self.loss_weight_shading_ori = 1.0 133 | self.loss_weight_reconstruct = 1.0 134 | self.loss_weight_shading_new = 1.0 135 | self.loss_weight_relighted = 5.0 136 | 137 | # data augmentation 138 | self.preprocess = 'none' # 'resize_and_crop' # scaling and cropping of images at load time 139 | # [resize_and_crop | crop | scale_width | scale_width_and_crop | none] 140 | self.load_size = 256 # scale images to this size 141 | self.crop_size = 256 # then crop to this size 142 | self.no_flip = True # if specified, do not flip the images for data augmentation 143 | 144 | # dataloader 145 | self.serial_batches = False # if true, takes images in order to make batches, otherwise takes them randomly 146 | self.num_threads = 10 # threads for loading data 147 | 148 | # save model and output images 149 | self.save_epoch_freq = 50 # frequency of saving checkpoints at the end of epochs 150 | self.save_latest = False 151 | self.save_optimizer = True 152 | self.load_optimizer = True 153 | self.load_scaler = False 154 | 155 | # visdom and HTML visualization parameters 156 | self.display_env = self.name 157 | self.save_and_show_by_epoch = True 158 | self.display_freq = 4000 # frequency of showing training results on screen') 159 | self.display_ncols = 5 # if positive, display all images in a single visdom web panel with certain number of images per row.') 160 | self.display_id = 1 # window id of the web display') 161 | self.display_server = "http://localhost" # visdom server of the web display') 162 | 163 | self.display_port = 8097 # visdom port of the web display') 164 | self.update_html_freq = 4000 # frequency of saving training results to html') 165 | self.print_freq = 4000 # frequency of showing training results on console') 166 | self.no_html = False # do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 167 | 168 | -------------------------------------------------------------------------------- /options/train_options_isr_drn.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def __init__(self, which_experiment): 6 | super(TrainOptions, self).__init__() 7 | 8 | self.which_experiment = which_experiment 9 | self.isTrain = True 10 | 11 | # Setting for dataset 12 | self.dataset_mode = 'relighting_single_image' # name of the dataset 13 | self.anno = 'data/check_dataset/SID2_new_train.txt' # the anno file from prepare_dataset.py 14 | self.anno_validation = 'data/check_dataset/SID2_new_val_pairs.txt' 15 | # Setting for GPU 16 | self.parallel_method = "DistributedDataParallel" # "DataParallel" "DistributedDataParallel" 17 | if self.parallel_method == "DataParallel": 18 | number_gpus = 1 19 | self.gpu_ids = [i for i in range(number_gpus)] 20 | elif self.parallel_method == "DistributedDataParallel": 21 | self.world_size = None 22 | self.gpu_ids = None 23 | self.use_amp 
= False 24 | # parameters for batch 25 | self.batch_size = 6 26 | 27 | # Setting for the optimizer 28 | self.lr_policy = 'step' # learning rate policy. [linear | step | plateau | cosine] 29 | self.lr = 0.0002 # initial learning rate for adam 30 | self.lr_d = self.lr 31 | self.lr_decay_ratio = 0.5 # decay ratio in step scheduler. 32 | self.n_epochs = 150 #100 number of epochs with the initial learning rate 33 | self.n_epochs_decay = 0 #100 when using 'linear', number of epochs to linearly decay learning rate to zero 34 | self.lr_decay_iters = 100 # when using 'step', multiply by a gamma every lr_decay_iters iterations 35 | self.optimizer_type = 'Adam' # 'Adam', 'SGD' 36 | self.beta1 = 0.5 # momentum term of adam 37 | self.adam_eps = 1e-8 38 | 39 | # Setting for continuing the training. 40 | self.continue_train = False # continue training: load the latest model 41 | self.epoch = '75' # default='latest', which epoch to load? set to latest to use latest cached model 42 | self.load_iter = 0 # default='0', which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch] 43 | if self.continue_train: 44 | try: 45 | self.epoch_count = int(self.epoch) 46 | except: 47 | self.epoch_count = 1 # the starting epoch count, we save the model by , +, ...') 48 | else: 49 | self.epoch_count = 1 50 | 51 | self.model_modify_layer = [] 52 | self.modify_layer = len(self.model_modify_layer) != 0 53 | 54 | # Setting for the model 55 | self.name = self.which_experiment # name of the experiment. It decides where to store samples and models 56 | 57 | self.model_name = 'drn' # ['relighting' | 'intrinsic_decomposition'] 58 | self.light_type = "pan_tilt_color" # ["pan_tilt_color" | "Spherical_harmonic"] 59 | self.light_prediction = False 60 | self.netG = 'global' 61 | self.use_discriminator = False 62 | self.n_downsample_global = 3 63 | self.n_blocks_global = 9 64 | self.n_local_enhancers = 1 65 | self.n_blocks_local = 3 66 | self.norm = 'instance' 67 | self.n_layers_D = 1 68 | self.use_sigmoid = False 69 | self.num_D = 1 70 | self.no_ganFeat_loss = False 71 | self.no_lsgan = False 72 | self.pool_size = 0 73 | self.ndf = 4 74 | self.lambda_feat = 10.0 75 | 76 | self.loss_weight_GAN = 0.05 77 | 78 | # data augmentation 79 | self.preprocess = 'none' # 'resize_and_crop' # scaling and cropping of images at load time 80 | # [resize_and_crop | crop | scale_width | scale_width_and_crop | none] 81 | self.load_size = 256 # scale images to this size 82 | self.crop_size = 256 # then crop to this size 83 | self.no_flip = True # if specified, do not flip the images for data augmentation 84 | 85 | # dataloader 86 | self.serial_batches = False # if true, takes images in order to make batches, otherwise takes them randomly 87 | self.num_threads = 10 # threads for loading data 88 | 89 | # save model and output images 90 | self.save_epoch_freq = 50 # frequency of saving checkpoints at the end of epochs 91 | self.save_latest = False 92 | self.save_optimizer = True 93 | self.load_optimizer = True 94 | self.load_scaler = False 95 | 96 | # visdom and HTML visualization parameters 97 | self.display_env = self.name 98 | self.save_and_show_by_epoch = True 99 | self.display_freq = 4000 # frequency of showing training results on screen') 100 | self.display_ncols = 5 # if positive, display all images in a single visdom web panel with certain number of images per row.') 101 | self.display_id = 1 # window id of the web display') 102 | self.display_server = "http://localhost" # 
visdom server of the web display') 103 | 104 | self.display_port = 8097 # visdom port of the web display') 105 | self.update_html_freq = 4000 # frequency of saving training results to html') 106 | self.print_freq = 4000 # frequency of showing training results on console') 107 | self.no_html = False # do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 108 | 109 | -------------------------------------------------------------------------------- /options/train_options_isr_pix2pix.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def __init__(self, which_experiment): 6 | super(TrainOptions, self).__init__() 7 | 8 | self.which_experiment = which_experiment 9 | self.isTrain = True 10 | 11 | # Setting for dataset 12 | self.anno = 'data/check_dataset/SID2_new_train.txt' # the anno file from prepare_dataset.py 13 | self.anno_validation = 'data/check_dataset/SID2_new_val_pairs.txt' 14 | # Setting for GPU 15 | self.parallel_method = "DistributedDataParallel" # "DataParallel" "DistributedDataParallel" 16 | if self.parallel_method == "DataParallel": 17 | number_gpus = 1 18 | self.gpu_ids = [i for i in range(number_gpus)] 19 | elif self.parallel_method == "DistributedDataParallel": 20 | self.world_size = None 21 | self.gpu_ids = None 22 | self.use_amp = False 23 | # parameters for batch 24 | self.batch_size = 18 25 | 26 | # Setting for the optimizer 27 | self.lr_policy = 'step' # learning rate policy. [linear | step | plateau | cosine] 28 | self.lr = 0.0002 # initial learning rate for adam 29 | self.lr_d = self.lr 30 | self.lr_decay_ratio = 0.5 # decay ratio in step scheduler. 31 | self.n_epochs = 150 #100 number of epochs with the initial learning rate 32 | self.n_epochs_decay = 0 #100 when using 'linear', number of epochs to linearly decay learning rate to zero 33 | self.lr_decay_iters = 100 # when using 'step', multiply by a gamma every lr_decay_iters iterations 34 | self.optimizer_type = 'Adam' # 'Adam', 'SGD' 35 | self.beta1 = 0.5 # momentum term of adam 36 | self.adam_eps = 1e-8 37 | 38 | # Setting for continuing the training. 39 | self.continue_train = False # continue training: load the latest model 40 | self.epoch = '75' # default='latest', which epoch to load? set to latest to use latest cached model 41 | self.load_iter = 0 # default='0', which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch] 42 | if self.continue_train: 43 | try: 44 | self.epoch_count = int(self.epoch) 45 | except: 46 | self.epoch_count = 1 # the starting epoch count, we save the model by , +, ...') 47 | else: 48 | self.epoch_count = 1 49 | 50 | self.model_modify_layer = [] 51 | self.modify_layer = len(self.model_modify_layer) != 0 52 | 53 | # Setting for the model 54 | self.dataset_mode = 'relighting_single_image' # name of the dataset 55 | self.name = self.which_experiment # name of the experiment. 
It decides where to store samples and models 56 | 57 | self.model_name = 'relighting_two_stage' # ['relighting' | 'intrinsic_decomposition'] 58 | self.two_stage = False 59 | self.no_dropout = False # old option: no dropout for the model 60 | self.norm = 'batch' # instance normalization or batch normalization [instance | batch | none] 61 | self.light_type = "pan_tilt_color" # ["pan_tilt_color" | "Spherical_harmonic"] 62 | self.light_prediction = True 63 | self.netG = "unet" 64 | self.infinite_range_sha = False 65 | 66 | self.cross_model = False 67 | 68 | # reflectance consistency 69 | self.flag_ref_consistency = False 70 | # Regularizing chromaticity 71 | self.flag_sha_chromaticity_smooth = False 72 | self.flag_sha_ref_regression = False 73 | # Regularizing init_ref 74 | self.flag_init_ref = False 75 | 76 | # # discriminator 77 | self.use_discriminator = True 78 | # self.epoch_start_train_discriminator = -1 79 | self.netD = 'n_layers' # specify discriminator architecture [basic | n_layers | pixel]. 80 | # The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator' 81 | self.n_layers_D = 3 # only used if netD==n_layers 82 | self.gan_mode = 'lsgan' # 'the type of GAN objective. [vanilla| lsgan | wgangp]. 83 | # vanilla GAN loss is the cross-entropy objective used in the original GAN paper.' 84 | self.ndf = 64 # number of discrim filters in the first conv layer 85 | self.loss_weight_GAN = 0.05 86 | 87 | # parameters for loss functions 88 | self.constrain_intrinsic = False 89 | self.show_gt_intrinsic = True 90 | self.main_loss_function = 'L1_DSSIM_LPIPS' # choose using L2 or L1 during the training 91 | self.flag_L1_DSSIM_LPIPS = [True, False, False] 92 | # Weights of losses 93 | self.loss_weight_angular = 1.0 94 | self.loss_weight_color = 1.0 95 | self.loss_weight_relighted = 5.0 96 | 97 | # data augmentation 98 | self.preprocess = 'none' # 'resize_and_crop' # scaling and cropping of images at load time 99 | # [resize_and_crop | crop | scale_width | scale_width_and_crop | none] 100 | self.load_size = 256 # scale images to this size 101 | self.crop_size = 256 # then crop to this size 102 | self.no_flip = True # if specified, do not flip the images for data augmentation 103 | 104 | # dataloader 105 | self.serial_batches = False # if true, takes images in order to make batches, otherwise takes them randomly 106 | self.num_threads = 10 # threads for loading data 107 | 108 | # save model and output images 109 | self.save_epoch_freq = 50 # frequency of saving checkpoints at the end of epochs 110 | self.save_latest = False 111 | self.save_optimizer = True 112 | self.load_optimizer = True 113 | self.load_scaler = False 114 | 115 | # visdom and HTML visualization parameters 116 | self.display_env = self.name 117 | self.save_and_show_by_epoch = True 118 | self.display_freq = 4000 # frequency of showing training results on screen') 119 | self.display_ncols = 5 # if positive, display all images in a single visdom web panel with certain number of images per row.') 120 | self.display_id = 1 # window id of the web display') 121 | self.display_server = "http://localhost" # visdom server of the web display') 122 | 123 | self.display_port = 8097 # visdom port of the web display') 124 | self.update_html_freq = 4000 # frequency of saving training results to html') 125 | self.print_freq = 4000 # frequency of showing training results on console') 126 | self.no_html = False # do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 127 | 128 | 
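*(Note)* Each `options/train_options_*.py` module above defines a `TrainOptions` class whose constructor hard-codes every hyperparameter for one experiment; `train.py` resolves the module to use from its positional `option_name` argument via `importlib`, instantiates it with the `exp_` prefix, and then calls `.parse()` on it. A minimal sketch of that lookup, mirroring the code in `train.py` (the helper name `load_train_options` is hypothetical and not part of the repo):

```python
# Sketch only: mirrors the importlib-based option lookup performed in train.py.
import importlib

def load_train_options(option_name: str):
    # e.g. option_name = "isr" -> options.train_options_isr.TrainOptions("exp_isr")
    module = importlib.import_module("options.train_options_{}".format(option_name))
    TrainOptions = getattr(module, "TrainOptions")
    return TrainOptions("exp_" + option_name)
```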
-------------------------------------------------------------------------------- /options/train_options_multilum_ours_f.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def __init__(self, which_experiment): 6 | super(TrainOptions, self).__init__() 7 | 8 | self.which_experiment = which_experiment 9 | self.isTrain = True 10 | 11 | # Setting for dataset 12 | self.dataset_mode = 'relighting_single_image_multilum' # name of the dataset 13 | self.dataset_setting = 'ALL' # 'IAN' 'SILT' 'ALL' 14 | # 'SILT' setting: SILT : Self-supervised Lighting Transfer Using Implicit Image Decomposition. 15 | # For this paper we picked nine illumination conditions, consistent across all scenes of the 16 | # dataset, and chose one unpaired image setting to act as target lighting domain. 17 | # 'IAN' setting: Designing an Illumination-Aware Network for Deep Image Relighting. 18 | # In our experiment, we take dir_0 as input illumination settings and dir_17 as output 19 | # illumination settings in dataset to train all methods for evaluation. 20 | if self.dataset_setting == 'ALL': 21 | self.anno = 'data/multi_illumination/train.txt' # the anno file from prepare_dataset.py 22 | self.anno_validation = 'data/multi_illumination/val.txt' 23 | elif self.dataset_setting == 'SILT': 24 | self.anno = 'data/multi_illumination/train_setting1.txt' # the anno file from prepare_dataset.py 25 | self.anno_validation = 'data/multi_illumination/val_setting1.txt' 26 | elif self.dataset_setting == 'IAN': 27 | self.anno = 'data/multi_illumination/train_setting2.txt' # the anno file from prepare_dataset.py 28 | self.anno_validation = 'data/multi_illumination/val_setting2.txt' 29 | else: 30 | raise Exception("Error: dataset_setting") 31 | # Setting for GPU 32 | self.parallel_method = "DistributedDataParallel" # "DataParallel" "DistributedDataParallel" 33 | if self.parallel_method == "DataParallel": 34 | number_gpus = 1 35 | self.gpu_ids = [i for i in range(number_gpus)] 36 | elif self.parallel_method == "DistributedDataParallel": 37 | self.world_size = None 38 | self.gpu_ids = None 39 | self.use_amp = False 40 | # parameters for batch 41 | self.batch_size = 6 42 | self.dataset_drop_last = True 43 | 44 | # Setting for the optimizer 45 | self.lr_policy = 'step' # learning rate policy. [linear | step | plateau | cosine] 46 | self.lr = 0.0001 # initial learning rate for adam 47 | self.lr_d = self.lr 48 | self.lr_decay_ratio = 0.5 # decay ratio in step scheduler. 49 | self.n_epochs = 150 #100 number of epochs with the initial learning rate 50 | self.n_epochs_decay = 0 #100 when using 'linear', number of epochs to linearly decay learning rate to zero 51 | self.lr_decay_iters = 100 # when using 'step', multiply by a gamma every lr_decay_iters iterations 52 | self.optimizer_type = 'Adam' # 'Adam', 'SGD' 53 | self.beta1 = 0.5 # momentum term of adam 54 | self.adam_eps = 1e-8 55 | 56 | # Setting for continuing the training. 57 | self.continue_train = True # continue training: load the latest model 58 | self.epoch = 'base_isr' # default='latest', which epoch to load? set to latest to use latest cached model 59 | self.load_iter = 0 # default='0', which iteration to load? 
if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch] 60 | if self.continue_train: 61 | try: 62 | self.epoch_count = int(self.epoch) 63 | except: 64 | self.epoch_count = 1 # the starting epoch count, we save the model by , +, ...') 65 | else: 66 | self.epoch_count = 1 67 | 68 | if self.continue_train: 69 | self.model_modify_layer = ['G_2'] 70 | self.modify_layer = len(self.model_modify_layer) != 0 71 | 72 | # Setting for the model 73 | self.name = self.which_experiment # name of the experiment. It decides where to store samples and models 74 | 75 | self.model_name = 'relighting_two_stage' # ['relighting' | 'intrinsic_decomposition'] 76 | self.two_stage = True 77 | self.no_dropout = False # old option: no dropout for the model 78 | self.norm = 'batch' # instance normalization or batch normalization [instance | batch | none] 79 | self.light_type = "probes" # ["pan_tilt_color" | "Spherical_harmonic"] 80 | self.light_prediction = False 81 | self.netG = "resnet9_nonlocal" 82 | self.net_intrinsic = "resnet9" 83 | self.introduce_ref_G_2 = False 84 | self.cross_model = True 85 | # range of sha 86 | self.infinite_range_sha = True 87 | if self.infinite_range_sha: 88 | self.net_intrinsic = self.net_intrinsic + "_InfRange" 89 | self.netG = self.netG + "_InfRange" 90 | # reflectance consistency 91 | self.flag_ref_consistency = True 92 | self.loss_weight_ref_consistency = 1.0 93 | self.flag_sha_consistency = True 94 | self.loss_weight_sha_consistency = 1.0 95 | # Regularizing chromaticity 96 | self.flag_sha_chromaticity_smooth = False 97 | # self.method_sha_chromaticity_smooth = "OPP" # "OPP", "LAB" 98 | # self.loss_weight_sha_chromaticity_smooth = 75.0 99 | # self.loss_weight_sha_overall_smooth = 0.5 100 | self.flag_sha_ref_regression = True 101 | # self.method_sha_ref_regression = 'm1' 102 | # self.loss_weight_sha_ref_regression_1 = 1.0 # for chromaticity 103 | # self.loss_weight_sha_ref_regression_2 = 0.1 # for all channels 104 | # self.sha_ref_regression_mean = [0.43, 0.61] 105 | self.method_sha_ref_regression = 'm2' 106 | self.para_sha_ref_regression = { 107 | # 'R_I_c': 1.2119, 108 | # 'R_I_a': 1.1603, 109 | 'S_I_c': 0.5254, 110 | 'S_I_a': 0.7089, 111 | # 'S_R_c': 0.4336, 112 | # 'S_R_a': 0.6109, 113 | 'elu_alpha': 0.1, 114 | 'elu_shift': 0.0, 115 | # 'w_R_I_c': 0.0, 116 | # 'w_R_I_a': 0.0, 117 | 'w_S_I_c': 2.0, 118 | 'w_S_I_a': 0.1, 119 | # 'w_S_R_c': 0.0, 120 | # 'w_S_R_a': 0.0, 121 | } 122 | # Regularizing init_ref 123 | self.flag_init_ref = True 124 | self.para_init_ref = { 125 | 'cross_ij': True, 126 | 'decay': True, 127 | 'method': "ORI" # "OPP", "ORI" 128 | } 129 | self.loss_weight_init_ref = 1.0 130 | 131 | # # discriminator 132 | self.use_discriminator = True 133 | # self.epoch_start_train_discriminator = -1 134 | self.netD = 'n_layers' # specify discriminator architecture [basic | n_layers | pixel]. 135 | # The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator' 136 | self.n_layers_D = 3 # only used if netD==n_layers 137 | self.gan_mode = 'lsgan' # 'the type of GAN objective. [vanilla| lsgan | wgangp]. 138 | # vanilla GAN loss is the cross-entropy objective used in the original GAN paper.' 
139 | self.ndf = 64 # number of discrim filters in the first conv layer 140 | self.loss_weight_GAN = 0.05 141 | 142 | # parameters for loss functions 143 | self.constrain_intrinsic = False 144 | self.show_gt_intrinsic = False 145 | self.main_loss_function = 'L1_DSSIM_LPIPS' # choose using L2 or L1 during the training 146 | self.flag_L1_DSSIM_LPIPS = [True, True, True] 147 | if self.cross_model: 148 | self.unbalanced = False 149 | self.unbalanced_para = None 150 | # Weights of losses 151 | self.loss_weight_angular = 1.0 152 | self.loss_weight_color = 1.0 153 | self.loss_weight_reflectance = 1.0 154 | self.loss_weight_shading_ori = 1.0 155 | self.loss_weight_reconstruct = 1.0 156 | self.loss_weight_shading_new = 1.0 157 | self.loss_weight_relighted = 5.0 158 | # SH light 159 | self.loss_weight_SHlight = 1.0 160 | 161 | # data augmentation 162 | self.preprocess = 'none' # 'resize_and_crop' # scaling and cropping of images at load time 163 | # [resize_and_crop | crop | scale_width | scale_width_and_crop | none] 164 | self.load_size = 256 # scale images to this size 165 | self.crop_size = 256 # then crop to this size 166 | self.no_flip = True # if specified, do not flip the images for data augmentation 167 | 168 | # dataloader 169 | self.serial_batches = False # if true, takes images in order to make batches, otherwise takes them randomly 170 | self.num_threads = 10 # threads for loading data 171 | 172 | # save model and output images 173 | self.save_epoch_freq = 50 # frequency of saving checkpoints at the end of epochs 174 | self.save_latest = False 175 | self.save_optimizer = True 176 | self.load_optimizer = False 177 | self.load_scaler = False 178 | 179 | # visdom and HTML visualization parameters 180 | self.display_env = self.name 181 | self.save_and_show_by_epoch = True 182 | self.display_freq = 4000 # frequency of showing training results on screen') 183 | self.display_ncols = -1 # if positive, display all images in a single visdom web panel with certain number of images per row.') 184 | self.display_id = 1 # window id of the web display') 185 | self.display_server = "http://localhost" # visdom server of the web display') 186 | 187 | self.display_port = 8097 # visdom port of the web display') 188 | self.update_html_freq = 4000 # frequency of saving training results to html') 189 | self.print_freq = 4000 # frequency of showing training results on console') 190 | self.no_html = False # do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 191 | 192 | -------------------------------------------------------------------------------- /options/train_options_rsr_ours_f.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def __init__(self, which_experiment): 6 | super(TrainOptions, self).__init__() 7 | 8 | self.which_experiment = which_experiment 9 | self.isTrain = True 10 | 11 | # Setting for dataset 12 | self.dataset_mode = 'relighting_single_image_rsr' # name of the dataset 13 | self.dataset_rsr_type = "AnyLight" 14 | self.anno = 'data/anno_RSR/train.txt' # the anno file from prepare_dataset.py 15 | self.anno_validation = 'data/anno_RSR/AnyLightval_pairs.txt' 16 | # Setting for GPU 17 | self.parallel_method = "DistributedDataParallel" # "DataParallel" "DistributedDataParallel" 18 | if self.parallel_method == "DataParallel": 19 | number_gpus = 1 20 | self.gpu_ids = [i for i in range(number_gpus)] 21 | elif self.parallel_method == 
"DistributedDataParallel": 22 | self.world_size = None 23 | self.gpu_ids = None 24 | self.use_amp = False 25 | # parameters for batch 26 | self.batch_size = 6 27 | 28 | # Setting for the optimizer 29 | self.lr_policy = 'step' # learning rate policy. [linear | step | plateau | cosine] 30 | self.lr = 0.0001 # initial learning rate for adam 31 | self.lr_d = self.lr 32 | self.lr_decay_ratio = 0.5 # decay ratio in step scheduler. 33 | self.n_epochs = 150 #100 number of epochs with the initial learning rate 34 | self.n_epochs_decay = 0 #100 when using 'linear', number of epochs to linearly decay learning rate to zero 35 | self.lr_decay_iters = 100 # when using 'step', multiply by a gamma every lr_decay_iters iterations 36 | self.optimizer_type = 'Adam' # 'Adam', 'SGD' 37 | self.beta1 = 0.5 # momentum term of adam 38 | self.adam_eps = 1e-8 39 | 40 | # Setting for continuing the training. 41 | self.continue_train = True # continue training: load the latest model 42 | self.epoch = 'base_isr' # default='latest', which epoch to load? set to latest to use latest cached model 43 | self.load_iter = 0 # default='0', which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch] 44 | if self.continue_train: 45 | try: 46 | self.epoch_count = int(self.epoch) 47 | except: 48 | self.epoch_count = 1 # the starting epoch count, we save the model by , +, ...') 49 | else: 50 | self.epoch_count = 1 51 | 52 | self.model_modify_layer = [] 53 | self.modify_layer = len(self.model_modify_layer) != 0 54 | 55 | # Setting for the model 56 | self.name = self.which_experiment # name of the experiment. It decides where to store samples and models 57 | 58 | self.model_name = 'relighting_two_stage' # ['relighting' | 'intrinsic_decomposition'] 59 | self.two_stage = True 60 | self.no_dropout = False # old option: no dropout for the model 61 | self.norm = 'batch' # instance normalization or batch normalization [instance | batch | none] 62 | self.light_type = "pan_tilt_color" # ["pan_tilt_color" | "Spherical_harmonic"] 63 | self.light_prediction = True 64 | self.netG = "resnet9_nonlocal" 65 | self.net_intrinsic = "resnet9" 66 | self.introduce_ref_G_2 = False 67 | self.cross_model = True 68 | # range of sha 69 | self.infinite_range_sha = True 70 | if self.infinite_range_sha: 71 | self.net_intrinsic = self.net_intrinsic + "_InfRange" 72 | self.netG = self.netG + "_InfRange" 73 | # reflectance consistency 74 | self.flag_ref_consistency = True 75 | self.loss_weight_ref_consistency = 1.0 76 | self.flag_sha_consistency = True 77 | self.loss_weight_sha_consistency = 1.0 78 | # Regularizing chromaticity 79 | self.flag_sha_chromaticity_smooth = False 80 | # self.method_sha_chromaticity_smooth = "OPP" # "OPP", "LAB" 81 | # self.loss_weight_sha_chromaticity_smooth = 75.0 82 | # self.loss_weight_sha_overall_smooth = 0.5 83 | self.flag_sha_ref_regression = True 84 | # self.method_sha_ref_regression = 'm1' 85 | # self.loss_weight_sha_ref_regression_1 = 1.0 # for chromaticity 86 | # self.loss_weight_sha_ref_regression_2 = 0.1 # for all channels 87 | # self.sha_ref_regression_mean = [0.43, 0.61] 88 | self.method_sha_ref_regression = 'm2' 89 | self.para_sha_ref_regression = { 90 | # 'R_I_c': 1.2119, 91 | # 'R_I_a': 1.1603, 92 | 'S_I_c': 0.5254, 93 | 'S_I_a': 0.7089, 94 | # 'S_R_c': 0.4336, 95 | # 'S_R_a': 0.6109, 96 | 'elu_alpha': 0.1, 97 | 'elu_shift': 0.0, 98 | # 'w_R_I_c': 0.0, 99 | # 'w_R_I_a': 0.0, 100 | 'w_S_I_c': 2.0, 101 | 'w_S_I_a': 0.1, 102 | # 'w_S_R_c': 0.0, 103 
| # 'w_S_R_a': 0.0, 104 | } 105 | # Regularizing init_ref 106 | self.flag_init_ref = True 107 | self.para_init_ref = { 108 | 'cross_ij': True, 109 | 'decay': True, 110 | 'method': "ORI" # "OPP", "ORI" 111 | } 112 | self.loss_weight_init_ref = 1.0 113 | 114 | # # discriminator 115 | self.use_discriminator = True 116 | # self.epoch_start_train_discriminator = -1 117 | self.netD = 'n_layers' # specify discriminator architecture [basic | n_layers | pixel]. 118 | # The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator' 119 | self.n_layers_D = 3 # only used if netD==n_layers 120 | self.gan_mode = 'lsgan' # 'the type of GAN objective. [vanilla| lsgan | wgangp]. 121 | # vanilla GAN loss is the cross-entropy objective used in the original GAN paper.' 122 | self.ndf = 64 # number of discrim filters in the first conv layer 123 | self.loss_weight_GAN = 0.05 124 | 125 | # parameters for loss functions 126 | self.constrain_intrinsic = False 127 | self.show_gt_intrinsic = False 128 | self.main_loss_function = 'L1_DSSIM_LPIPS' # choose using L2 or L1 during the training 129 | self.flag_L1_DSSIM_LPIPS = [True, True, True] 130 | if self.cross_model: 131 | self.unbalanced = False 132 | self.unbalanced_para = None 133 | # Weights of losses 134 | self.loss_weight_angular = 1.0 135 | self.loss_weight_color = 1.0 136 | self.loss_weight_reflectance = 1.0 137 | self.loss_weight_shading_ori = 1.0 138 | self.loss_weight_reconstruct = 1.0 139 | self.loss_weight_shading_new = 1.0 140 | self.loss_weight_relighted = 5.0 141 | 142 | # data augmentation 143 | self.preprocess = 'none' # 'resize_and_crop' # scaling and cropping of images at load time 144 | # [resize_and_crop | crop | scale_width | scale_width_and_crop | none] 145 | self.load_size = 256 # scale images to this size 146 | self.crop_size = 256 # then crop to this size 147 | self.no_flip = True # if specified, do not flip the images for data augmentation 148 | 149 | # dataloader 150 | self.serial_batches = False # if true, takes images in order to make batches, otherwise takes them randomly 151 | self.num_threads = 10 # threads for loading data 152 | 153 | # save model and output images 154 | self.save_epoch_freq = 50 # frequency of saving checkpoints at the end of epochs 155 | self.save_latest = False 156 | self.save_optimizer = True 157 | self.load_optimizer = False 158 | self.load_scaler = False 159 | 160 | # visdom and HTML visualization parameters 161 | self.display_env = self.name 162 | self.save_and_show_by_epoch = True 163 | self.display_freq = 4000 # frequency of showing training results on screen') 164 | self.display_ncols = 5 # if positive, display all images in a single visdom web panel with certain number of images per row.') 165 | self.display_id = 1 # window id of the web display') 166 | self.display_server = "http://localhost" # visdom server of the web display') 167 | 168 | self.display_port = 8097 # visdom port of the web display') 169 | self.update_html_freq = 4000 # frequency of saving training results to html') 170 | self.print_freq = 4000 # frequency of showing training results on console') 171 | self.no_html = False # do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 172 | 173 | -------------------------------------------------------------------------------- /options/train_options_vidit_ours_f.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class 
TrainOptions(BaseOptions): 5 | def __init__(self, which_experiment): 6 | super(TrainOptions, self).__init__() 7 | 8 | self.which_experiment = which_experiment 9 | self.isTrain = True 10 | 11 | # Setting for dataset 12 | self.dataset_mode = 'relighting_single_image_vidit' # name of the dataset 13 | self.dataset_assignment_type = "AnyLight" 14 | self.anno = 'data/anno_VIDIT/any2any/train.txt' # the anno file from prepare_dataset.py 15 | self.anno_validation = 'data/anno_VIDIT/any2any/AnyLight_val_pairs.txt' 16 | # Setting for GPU 17 | self.parallel_method = "DistributedDataParallel" # "DataParallel" "DistributedDataParallel" 18 | if self.parallel_method == "DataParallel": 19 | number_gpus = 1 20 | self.gpu_ids = [i for i in range(number_gpus)] 21 | elif self.parallel_method == "DistributedDataParallel": 22 | self.world_size = None 23 | self.gpu_ids = None 24 | self.use_amp = False 25 | # parameters for batch 26 | self.batch_size = 6 27 | 28 | # Setting for the optimizer 29 | self.lr_policy = 'step' # learning rate policy. [linear | step | plateau | cosine] 30 | self.lr = 0.0001 # initial learning rate for adam 31 | self.lr_d = self.lr 32 | self.lr_decay_ratio = 0.5 # decay ratio in step scheduler. 33 | self.n_epochs = 150 #100 number of epochs with the initial learning rate 34 | self.n_epochs_decay = 0 #100 when using 'linear', number of epochs to linearly decay learning rate to zero 35 | self.lr_decay_iters = 100 # when using 'step', multiply by a gamma every lr_decay_iters iterations 36 | self.optimizer_type = 'Adam' # 'Adam', 'SGD' 37 | self.beta1 = 0.5 # momentum term of adam 38 | self.adam_eps = 1e-8 39 | 40 | # Setting for continuing the training. 41 | self.continue_train = True # continue training: load the latest model 42 | self.epoch = 'base_isr' # default='latest', which epoch to load? set to latest to use latest cached model 43 | self.load_iter = 0 # default='0', which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch] 44 | if self.continue_train: 45 | try: 46 | self.epoch_count = int(self.epoch) 47 | except: 48 | self.epoch_count = 1 # the starting epoch count, we save the model by , +, ...') 49 | else: 50 | self.epoch_count = 1 51 | 52 | self.model_modify_layer = [] 53 | self.modify_layer = len(self.model_modify_layer) != 0 54 | 55 | # Setting for the model 56 | self.name = self.which_experiment # name of the experiment. 
It decides where to store samples and models 57 | 58 | self.model_name = 'relighting_two_stage' # ['relighting' | 'intrinsic_decomposition'] 59 | self.two_stage = True 60 | self.no_dropout = False # old option: no dropout for the model 61 | self.norm = 'batch' # instance normalization or batch normalization [instance | batch | none] 62 | self.light_type = "pan_tilt_color" # ["pan_tilt_color" | "Spherical_harmonic"] 63 | self.light_prediction = True 64 | self.netG = "resnet9_nonlocal" 65 | self.net_intrinsic = "resnet9" 66 | self.introduce_ref_G_2 = False 67 | self.cross_model = True 68 | # range of sha 69 | self.infinite_range_sha = True 70 | if self.infinite_range_sha: 71 | self.net_intrinsic = self.net_intrinsic + "_InfRange" 72 | self.netG = self.netG + "_InfRange" 73 | # reflectance consistency 74 | self.flag_ref_consistency = True 75 | self.loss_weight_ref_consistency = 1.0 76 | self.flag_sha_consistency = True 77 | self.loss_weight_sha_consistency = 1.0 78 | # Regularizing chromaticity 79 | self.flag_sha_chromaticity_smooth = False 80 | # self.method_sha_chromaticity_smooth = "OPP" # "OPP", "LAB" 81 | # self.loss_weight_sha_chromaticity_smooth = 75.0 82 | # self.loss_weight_sha_overall_smooth = 0.5 83 | self.flag_sha_ref_regression = True 84 | # self.method_sha_ref_regression = 'm1' 85 | # self.loss_weight_sha_ref_regression_1 = 1.0 # for chromaticity 86 | # self.loss_weight_sha_ref_regression_2 = 0.1 # for all channels 87 | # self.sha_ref_regression_mean = [0.43, 0.61] 88 | self.method_sha_ref_regression = 'm2' 89 | self.para_sha_ref_regression = { 90 | # 'R_I_c': 1.2119, 91 | # 'R_I_a': 1.1603, 92 | 'S_I_c': 0.5254, 93 | 'S_I_a': 0.7089, 94 | # 'S_R_c': 0.4336, 95 | # 'S_R_a': 0.6109, 96 | 'elu_alpha': 0.1, 97 | 'elu_shift': 0.0, 98 | # 'w_R_I_c': 0.0, 99 | # 'w_R_I_a': 0.0, 100 | 'w_S_I_c': 2.0, 101 | 'w_S_I_a': 0.1, 102 | # 'w_S_R_c': 0.0, 103 | # 'w_S_R_a': 0.0, 104 | } 105 | # Regularizing init_ref 106 | self.flag_init_ref = True 107 | self.para_init_ref = { 108 | 'cross_ij': True, 109 | 'decay': True, 110 | 'method': "ORI" # "OPP", "ORI" 111 | } 112 | self.loss_weight_init_ref = 1.0 113 | 114 | # # discriminator 115 | self.use_discriminator = True 116 | # self.epoch_start_train_discriminator = -1 117 | self.netD = 'n_layers' # specify discriminator architecture [basic | n_layers | pixel]. 118 | # The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator' 119 | self.n_layers_D = 3 # only used if netD==n_layers 120 | self.gan_mode = 'lsgan' # 'the type of GAN objective. [vanilla| lsgan | wgangp]. 121 | # vanilla GAN loss is the cross-entropy objective used in the original GAN paper.' 
122 | self.ndf = 64 # number of discrim filters in the first conv layer 123 | self.loss_weight_GAN = 0.05 124 | 125 | # parameters for loss functions 126 | self.constrain_intrinsic = False 127 | self.show_gt_intrinsic = False 128 | self.main_loss_function = 'L1_DSSIM_LPIPS' # choose using L2 or L1 during the training 129 | self.flag_L1_DSSIM_LPIPS = [True, True, True] 130 | if self.cross_model: 131 | self.unbalanced = False 132 | self.unbalanced_para = None 133 | # Weights of losses 134 | self.loss_weight_angular = 1.0 135 | self.loss_weight_color = 1.0 136 | self.loss_weight_reflectance = 1.0 137 | self.loss_weight_shading_ori = 1.0 138 | self.loss_weight_reconstruct = 1.0 139 | self.loss_weight_shading_new = 1.0 140 | self.loss_weight_relighted = 5.0 141 | 142 | # data augmentation 143 | self.preprocess = 'resize' # 'resize_and_crop' # scaling and cropping of images at load time 144 | # [resize_and_crop | crop | scale_width | scale_width_and_crop | none] 145 | self.load_size = 256 # scale images to this size 146 | self.crop_size = 256 # then crop to this size 147 | self.no_flip = True # if specified, do not flip the images for data augmentation 148 | 149 | # dataloader 150 | self.serial_batches = False # if true, takes images in order to make batches, otherwise takes them randomly 151 | self.num_threads = 10 # threads for loading data 152 | 153 | # save model and output images 154 | self.save_epoch_freq = 150 # frequency of saving checkpoints at the end of epochs 155 | self.save_latest = False 156 | self.save_optimizer = True 157 | self.load_optimizer = False 158 | self.load_scaler = False 159 | 160 | # visdom and HTML visualization parameters 161 | self.display_env = self.name 162 | self.save_and_show_by_epoch = True 163 | self.display_freq = 4000 # frequency of showing training results on screen') 164 | self.display_ncols = 5 # if positive, display all images in a single visdom web panel with certain number of images per row.') 165 | self.display_id = 1 # window id of the web display') 166 | self.display_server = "http://localhost" # visdom server of the web display') 167 | 168 | self.display_port = 8097 # visdom port of the web display') 169 | self.update_html_freq = 4000 # frequency of saving training results to html') 170 | self.print_freq = 4000 # frequency of showing training results on console') 171 | self.no_html = False # do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 172 | 173 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | visdom==0.1.8.9 2 | tqdm==4.62.3 3 | pytorch-msssim==0.2.1 4 | lpips==0.1.4 5 | kornia==0.6.5 6 | dominate==2.6.0 7 | thop==0.1.1.post2209072238 8 | opencv-python==4.5.5.64 9 | pillow==8.2.0 10 | imageio==2.9.0 11 | imageio-ffmpeg==0.5.1 12 | -------------------------------------------------------------------------------- /test_qualitative.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | os.environ["CUDA_VISIBLE_DEVICES"] = '5' 4 | from options.test_qualitative_options import TestOptions 5 | from data import create_dataset 6 | from models.models import create_model 7 | from util.visualizer import save_images_one_batch 8 | from util import html 9 | 10 | 11 | if __name__ == '__main__': 12 | opt = TestOptions().parse() # get test options 13 | 14 | dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other 
options 15 | model = create_model(opt) # create a model given opt.model and other options 16 | model.setup(opt) # regular setup: load and print networks; create schedulers 17 | # create a website 18 | web_dir = os.path.join(opt.results_dir, opt.name, '{}_{}'.format(opt.phase, opt.epoch)) # define the website directory 19 | if opt.load_iter > 0: # load_iter is 0 by default 20 | web_dir = '{:s}_iter{:d}'.format(web_dir, opt.load_iter) 21 | print('creating web directory', web_dir) 22 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch)) 23 | # test with eval mode. This only affects layers like batchnorm and dropout. 24 | # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode. 25 | # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout. 26 | if opt.eval: 27 | model.eval() 28 | for i, data in enumerate(dataset): 29 | if i >= opt.num_test: # only apply our model to opt.num_test images. 30 | break 31 | model.set_input(data) # unpack data from data loader 32 | model.test() # run inference 33 | visuals = model.get_current_visuals() # get image results 34 | # img_path should provide the image name of one batch in list format. 35 | img_path = model.get_image_paths() # get image paths 36 | # img_path = [x for x in img_path[0]] 37 | if i % 5 == 0: # save images to an HTML file 38 | print('processing (%04d)-th image... %s' % (i, img_path)) 39 | save_images_one_batch(webpage, visuals, img_path, opt.normalization_type, aspect_ratio=opt.aspect_ratio, 40 | width=opt.display_winsize) 41 | webpage.save() # save the HTML 42 | -------------------------------------------------------------------------------- /test_qualitative.txt: -------------------------------------------------------------------------------- 1 | ------------ Options ------------- 2 | anno: data/anno_RSR/AnyLighttest_pairs_qualitative.txt 3 | aspect_ratio: 1.0 4 | batch_size: 1 5 | checkpoints_dir: ./checkpoints/ 6 | constrain_intrinsic: False 7 | continue_train: True 8 | crop_size: 256 9 | cross_model: False 10 | dataroot: /ghome/yyang/dataset/ISR/ 11 | dataroot_RSR: /ghome/yyang/dataset/RSR_256/ 12 | dataroot_multilum: /ghome/yyang/dataset/Multi_Illumination_small/train/ 13 | dataroot_vidit: /ghome/yyang/dataset/VIDIT_full/ 14 | dataset_mode: relighting_single_image_rsr 15 | display_id: -1 16 | display_winsize: 256 17 | epoch: save_best 18 | eval: True 19 | gpu_ids: [0] 20 | img_size: (256, 256) 21 | infinite_range_sha: True 22 | init_gain: 0.02 23 | init_type: normal 24 | input_nc: 3 25 | introduce_ref_G_2: False 26 | isTrain: False 27 | light_prediction: True 28 | light_type: pan_tilt_color 29 | load_iter: 0 30 | load_size: 256 31 | max_dataset_size: inf 32 | metric_list: ['Relighted', 'light_position_color'] 33 | model_modify_layer: [] 34 | model_name: relighting_two_stage 35 | modify_layer: False 36 | multiple_replace_image: True 37 | name: exp_rsr_ours_f, 38 | netG: resnet9_nonlocal_InfRange 39 | net_intrinsic: resnet9_InfRange 40 | ngf: 64 41 | no_dropout: False 42 | no_flip: True 43 | norm: batch 44 | normalization_type: [0, 1] 45 | num_test: 130 46 | num_threads: 1 47 | output_nc: 3 48 | parallel_method: DataParallel 49 | phase: test_rsr 50 | pre_read_data: False 51 | preprocess: none 52 | results_dir: ./results/ 53 | serial_batches: True 54 | server_root: /ghome/yyang/dataset/ 55 | show_gt_intrinsic: False 56 | special_test: False 57 | two_stage: True 58 | use_amp: False 59 | 
use_discriminator: False 60 | verbose: False 61 | -------------- End ---------------- 62 | dataset [RelightingDatasetSingleImageRSR] was created (shuffle=False, drop_last=False) 63 | initialize network with normal 64 | ------------ Options ------------- 65 | anno: data/anno_VIDIT/any2any/AnyLight_test_pairs_qualitative.txt 66 | aspect_ratio: 1.0 67 | batch_size: 1 68 | checkpoints_dir: ./checkpoints/ 69 | constrain_intrinsic: False 70 | continue_train: True 71 | crop_size: 256 72 | cross_model: False 73 | dataroot: /ghome/yyang/dataset/ISR/ 74 | dataroot_RSR: /ghome/yyang/dataset/RSR_256/ 75 | dataroot_multilum: /ghome/yyang/dataset/Multi_Illumination_small/train/ 76 | dataroot_vidit: /ghome/yyang/dataset/VIDIT_full/ 77 | dataset_mode: relighting_single_image_vidit 78 | display_id: -1 79 | display_winsize: 256 80 | epoch: save_best 81 | eval: True 82 | gpu_ids: [0] 83 | img_size: (256, 256) 84 | infinite_range_sha: True 85 | init_gain: 0.02 86 | init_type: normal 87 | input_nc: 3 88 | introduce_ref_G_2: False 89 | isTrain: False 90 | light_prediction: True 91 | light_type: pan_tilt_color 92 | load_iter: 0 93 | load_size: 256 94 | max_dataset_size: inf 95 | metric_list: ['Relighted', 'light_position_color'] 96 | model_modify_layer: [] 97 | model_name: relighting_two_stage 98 | modify_layer: False 99 | multiple_replace_image: True 100 | name: exp_vidit_ours_f, 101 | netG: resnet9_nonlocal_InfRange 102 | net_intrinsic: resnet9_InfRange 103 | ngf: 64 104 | no_dropout: False 105 | no_flip: True 106 | norm: batch 107 | normalization_type: [0, 1] 108 | num_test: 130 109 | num_threads: 1 110 | output_nc: 3 111 | parallel_method: DataParallel 112 | phase: test_vidit 113 | pre_read_data: False 114 | preprocess: resize 115 | results_dir: ./results/ 116 | serial_batches: True 117 | server_root: /ghome/yyang/dataset/ 118 | show_gt_intrinsic: False 119 | special_test: False 120 | two_stage: True 121 | use_amp: False 122 | use_discriminator: False 123 | verbose: False 124 | -------------- End ---------------- 125 | dataset [RelightingDatasetSingleImageVidit] was created (shuffle=False, drop_last=False) 126 | initialize network with normal 127 | ------------ Options ------------- 128 | anno: data/multi_illumination/test_qualitative.txt 129 | aspect_ratio: 1.0 130 | batch_size: 1 131 | checkpoints_dir: ./checkpoints/ 132 | constrain_intrinsic: False 133 | continue_train: True 134 | crop_size: 256 135 | cross_model: False 136 | dataroot: /ghome/yyang/dataset/ISR/ 137 | dataroot_RSR: /ghome/yyang/dataset/RSR_256/ 138 | dataroot_multilum: /ghome/yyang/dataset/Multi_Illumination_small/test/ 139 | dataroot_vidit: /ghome/yyang/dataset/VIDIT_full/ 140 | dataset_mode: relighting_single_image_multilum 141 | display_id: -1 142 | display_winsize: 256 143 | epoch: save_best 144 | eval: True 145 | gpu_ids: [0] 146 | img_size: (256, 256) 147 | infinite_range_sha: True 148 | init_gain: 0.02 149 | init_type: normal 150 | input_nc: 3 151 | introduce_ref_G_2: False 152 | isTrain: False 153 | light_prediction: False 154 | light_type: probes 155 | load_iter: 0 156 | load_size: 256 157 | max_dataset_size: inf 158 | metric_list: ['Relighted'] 159 | model_modify_layer: [] 160 | model_name: relighting_two_stage 161 | modify_layer: False 162 | multiple_replace_image: True 163 | name: exp_multilum_ours_f 164 | netG: resnet9_nonlocal_InfRange 165 | net_intrinsic: resnet9_InfRange 166 | ngf: 64 167 | no_dropout: False 168 | no_flip: True 169 | norm: batch 170 | normalization_type: [0, 1] 171 | num_test: 130 172 | num_threads: 1 173 | 
output_nc: 3 174 | parallel_method: DataParallel 175 | phase: test_multilum 176 | pre_read_data: False 177 | preprocess: none 178 | results_dir: ./results/ 179 | serial_batches: True 180 | server_root: /ghome/yyang/dataset/ 181 | show_gt_intrinsic: False 182 | special_test: False 183 | two_stage: True 184 | use_amp: False 185 | use_discriminator: False 186 | verbose: False 187 | -------------- End ---------------- 188 | dataset [RelightingDatasetSingleImageMultilum] was created (shuffle=False, drop_last=False) 189 | initialize network with normal 190 | -------------------------------------------------------------------------------- /test_qualitative_animation.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | os.environ["CUDA_VISIBLE_DEVICES"] = '5' 4 | from options.test_qualitative_options import TestOptions 5 | from models.models import create_model 6 | from util.util import PARA_NOR 7 | from util.util import tensor2im 8 | from util.k_to_rgb import convert_K_to_RGB 9 | import cv2 10 | import numpy as np 11 | import torch 12 | from tqdm import tqdm 13 | from data.base_dataset import get_params, get_transform 14 | from PIL import Image 15 | import imageio 16 | 17 | 18 | def light_condition2tensor(pan_deg, tilt_deg, color, color_type = "temperature"): 19 | """ 20 | transform pan, tilt, color into the tensors for input. 21 | :param pan: in deg 22 | :param tilt: in deg 23 | :param color: in temperature 24 | :return: tensor size(7) 25 | """ 26 | factor_deg2rad = math.pi / 180.0 27 | pan = float(pan_deg) * factor_deg2rad 28 | tilt = float(tilt_deg) * factor_deg2rad 29 | 30 | # transform light position to cos and sin 31 | light_position = [math.cos(pan), math.sin(pan), math.cos(tilt), math.sin(tilt)] 32 | # normalize the light position to [0, 1] 33 | light_position[:2] = [x * PARA_NOR['pan_a'] + PARA_NOR['pan_b'] for x in light_position[:2]] 34 | light_position[2:] = [x * PARA_NOR['tilt_a'] + PARA_NOR['tilt_b'] for x in light_position[2:]] 35 | # transform light temperature to RGB, and normalize it. 
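        # Two color inputs are supported: a temperature in Kelvin, mapped to an RGB triple by
        # convert_K_to_RGB and scaled to [0, 1], or an RGB triple already given in 0-255,
        # which is only rescaled to [0, 1].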
36 | if color_type == "temperature": 37 | color_temp = int(color) 38 | light_color = list(map(lambda x: x / 255.0, convert_K_to_RGB(color_temp))) 39 | else: 40 | light_color = [x/255 for x in color] 41 | light_position_color = light_position + light_color 42 | return torch.tensor(light_position_color) 43 | 44 | 45 | 46 | def read_image(img_name, opt): 47 | transform_params = get_params(opt, opt.img_size) 48 | img_transform = get_transform(opt, transform_params) 49 | if not os.path.exists(img_name): 50 | raise Exception("RelightingDataset __getitem__ error") 51 | img_component = Image.open(img_name).convert('RGB') 52 | aspect = img_component.size[1] / img_component.size[0] 53 | img_component = img_transform(img_component) 54 | return img_component.unsqueeze(0), aspect 55 | 56 | 57 | class ImageDial(): 58 | def __init__(self, dial_img_name): 59 | dial_img = Image.open(dial_img_name).convert('RGB') 60 | dial_img = np.array(dial_img) 61 | scale_ratio = 256 / 880 62 | self.dial_img = cv2.resize(dial_img, None, fx=scale_ratio, fy=scale_ratio, 63 | interpolation=cv2.INTER_CUBIC) 64 | self.dial_center = [int(523 * scale_ratio), int(1320 * scale_ratio)] 65 | self.radius = 400 * scale_ratio 66 | # original size is (281, 768, 3), we need to fill to (288, 768, 3) to satisfy macro_block_size=16 in imageio 67 | self.h_pad, self.w_pad = tuple([x if x % 16 == 0 else x + 16 - x % 16 for x in self.dial_img.shape[:2]]) 68 | 69 | def insert_img(self, img_input, img_relit, pan, tilt): 70 | merged_img = np.copy(self.dial_img) 71 | # merged_img[-256:, :256, :] = cv2.cvtColor(img_input, cv2.COLOR_RGB2BGR) 72 | # merged_img[-256:, -256:, :] = cv2.cvtColor(img_relit, cv2.COLOR_RGB2BGR) 73 | merged_img[-256:, :256, :] = img_input 74 | merged_img[-256:, -256:, :] = img_relit 75 | # plot the point of pan and tilt. 
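        # The marker's distance from the dial center grows with sin(tilt), normalized so that
        # tilt = 50 degrees reaches the dial radius; pan sets the direction around the center.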
76 | length = math.sin(tilt/180*math.pi) / math.sin(50/180*math.pi) * self.radius 77 | position = (int(self.dial_center[1] + math.sin(pan/180*math.pi) * length), 78 | int(self.dial_center[0] + math.cos(pan/180*math.pi) * length)) 79 | # cv2.circle(merged_img, position, 3, (0, 0, 255), -1) 80 | cv2.circle(merged_img, position, 3, (255, 0, 0), -1) 81 | # add padding 82 | white_padding = np.full((self.h_pad, self.w_pad, 3), 255, dtype=np.uint8) 83 | white_padding[:merged_img.shape[0], :merged_img.shape[1]] = merged_img 84 | return white_padding 85 | 86 | 87 | def generate_path(points, steps, length): 88 | loop_path = [] 89 | for i in range(len(points)-1): 90 | point_a = points[i] 91 | point_b = points[i+1] 92 | this_step = [step if point_b[k] > point_a[k] else -step for k, step in enumerate(steps)] 93 | lists = [np.arange(point_a[k], point_b[k], this_step[k]) for k in range(len(point_a))] 94 | for j in range(max([len(lst) for lst in lists])): 95 | loop_path.append([lists[k][j] if j < len(lists[k]) else point_b[k] for k in range(len(lists))]) 96 | path = [loop_path[i % len(loop_path)] for i in range(length)] 97 | 98 | return path 99 | 100 | 101 | def create_pan_tilt_temperature_seq(length, seq_type): 102 | # pan, tilt, temperature 103 | # default_steps = [2, 1, 100] 104 | default_start = [90, 30, 4100] 105 | if seq_type == "cycle_tilt": 106 | points = [[default_start[0], 40.0, default_start[2]], 107 | [default_start[0], 0, default_start[2]], 108 | [-default_start[0], 40, default_start[2]], 109 | [-default_start[0], 0, default_start[2]], 110 | [default_start[0], 40.0, default_start[2]]] 111 | steps = [float('inf'), 1, float('inf')] 112 | elif seq_type == "cycle_pan": 113 | points = [[0, default_start[1], default_start[2]], 114 | [360, default_start[1], default_start[2]], ] 115 | steps = [2, float('inf'), float('inf')] 116 | elif seq_type == "cycle_temperature": 117 | points = [[default_start[0], default_start[1], 2300], 118 | [default_start[0], default_start[1], 6400],] 119 | steps = [float('inf'), float('inf'), 100] 120 | else: 121 | raise Exception("seq_type wrong!") 122 | 123 | sequence = generate_path(points, steps, length) 124 | return sequence 125 | 126 | 127 | if __name__ == '__main__': 128 | opt = TestOptions().parse() # get test options 129 | opt.special_test = True 130 | 131 | dial_img_name = "./util/pan_tilt_dial.png" 132 | img_dial = ImageDial(dial_img_name) 133 | 134 | data = {} 135 | img_name = "./202102_008_221_35_3200_108_00_Image_input.png" 136 | data['scene_label'] = img_name.split('/')[-1] 137 | data['Image_input'], _ = read_image(img_name, opt) 138 | 139 | model = create_model(opt) # create a model given opt.model and other options 140 | model.setup(opt) # regular setup: load and print networks; create schedulers 141 | if opt.eval: 142 | model.eval() 143 | 144 | frame_number = 320 145 | seq_type = "cycle_pan" 146 | seq_light = create_pan_tilt_temperature_seq(frame_number, seq_type=seq_type) 147 | suffix = '_' + seq_type 148 | 149 | out_dir = os.path.join(opt.results_dir, opt.name, opt.epoch) # define the website directory 150 | input_name = os.path.splitext(data['scene_label'])[0] 151 | fix_tilt = True 152 | video_reso = (768, 281) 153 | video_name = '{}_{}'.format(out_dir, input_name)+suffix 154 | fps = 25 155 | # out = cv2.VideoWriter(video_name + '.avi', cv2.VideoWriter_fourcc('I', '4', '2', '0'), fps, video_reso) 156 | # Use MPEG-4 encoding 157 | # fourcc = cv2.VideoWriter_fourcc(*'avc1') 158 | # out = cv2.VideoWriter(video_name + '.mp4', fourcc, fps, video_reso) 159 | 
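    # The mp4 is written through imageio's ffmpeg backend with the libx264 (H.264) codec;
    # ImageDial pads each frame to a multiple of 16 pixels to satisfy imageio's default
    # macro_block_size, as noted in its constructor.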
160 | writer = imageio.get_writer(video_name + '.mp4', fps=fps, codec='libx264') 161 | 162 | print("Create video at {}".format(video_name + '.mp4')) 163 | 164 | for frame in tqdm(range(frame_number)): 165 | pan, tilt, temperature = tuple(seq_light[frame]) 166 | 167 | data['light_position_color_new'] = light_condition2tensor(pan, tilt, temperature, color_type="temperature").unsqueeze(0) 168 | model.set_input(data) # unpack data from data loader 169 | model.test() # run inference 170 | visuals = model.get_current_visuals() # get image results 171 | 172 | im_input = tensor2im(visuals['Image_input'][0].unsqueeze(0), opt.normalization_type) 173 | im_relit = tensor2im(visuals['Relighted_predict'][0].unsqueeze(0), opt.normalization_type) 174 | im = img_dial.insert_img(im_input, im_relit, pan, tilt) 175 | 176 | # out.write(im) 177 | writer.append_data(im) 178 | # out.release() 179 | writer.close() 180 | 181 | 182 | -------------------------------------------------------------------------------- /test_quantitative.py: -------------------------------------------------------------------------------- 1 | """ 2 | test script which is used to get quantitive results 3 | """ 4 | import os 5 | # os.environ["CUDA_VISIBLE_DEVICES"] = '4' 6 | from options.test_quantitative_options import TestQuantitiveOptions 7 | from data import create_dataset 8 | from models.models import create_model 9 | import torch 10 | from tqdm import tqdm 11 | from util.metric import calculate_all_metrics 12 | 13 | 14 | if __name__ == '__main__': 15 | opt = TestQuantitiveOptions().parse() # get test options 16 | print(opt.name) 17 | opt.pre_read_data = False 18 | 19 | dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options 20 | model = create_model(opt) # create a model given opt.model and other options 21 | model.setup(opt) # regular setup: load and print networks; create schedulers 22 | 23 | # test with eval mode. This only affects layers like batchnorm and dropout. 24 | # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode. 25 | # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout. 26 | if opt.eval: 27 | model.eval() 28 | metric_results = {} 29 | metric_function = calculate_all_metrics() 30 | for key in opt.metric_list: 31 | metric_results[key] = [] 32 | for i, data in tqdm(enumerate(dataset)): 33 | if i >= opt.num_test: # only apply our model to opt.num_test images. 
34 | break 35 | model.set_input(data) # unpack data from data loader 36 | model.test() # run inference 37 | visuals = model.get_current_visuals() # get image results 38 | all_results = metric_function.run(visuals, metric_results.keys()) 39 | for key in metric_results.keys(): 40 | metric_results[key].append(all_results[key]) 41 | for key in metric_results.keys(): 42 | results = torch.tensor(metric_results[key]) 43 | results_mean = torch.mean(results, 0) 44 | results_std = torch.std(results, 0) 45 | if key in ['Reflectance', 'Shading_ori', 'Shading_new', 'Relighted', 'Reconstruct', 'Input_and_relighted_gt']: 46 | print("{}: MPS, SSIM, LPIPS, PSNR, MSE" 47 | " = ,{:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}, " 48 | "".format(key, results_mean[4], results_mean[1], results_mean[3], results_mean[2], results_mean[0])) 49 | print("{}(std): MPS, SSIM, LPIPS, PSNR, MSE" 50 | " = ,{:.4f}, {:.4f}, {:.4f}, {:.4f}, {:.4f}, " 51 | "".format(key, results_std[4], results_std[1], results_std[3], results_std[2], results_std[0])) 52 | elif key == 'light_position_color': 53 | print("{}: light_position_angle_error, pan_error, tilt_error, " 54 | "light_color_angle_error = ,{:.4f}, {:.4f}, {:.4f}, {:.4f}, ".format(key, results_mean[0], 55 | results_mean[1], 56 | results_mean[2], 57 | results_mean[3])) 58 | print("{}(std): light_position_angle_error, pan_error, tilt_error, " 59 | "light_color_angle_error = ,{:.4f}, {:.4f}, {:.4f}, {:.4f}, ".format(key, results_std[0], 60 | results_std[1], 61 | results_std[2], 62 | results_std[3])) 63 | else: 64 | raise Exception("key error") 65 | 66 | 67 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | # os.environ["CUDA_VISIBLE_DEVICES"] = '4,5,6,7' 4 | 5 | import time 6 | 7 | from data import create_dataset 8 | from models.models import create_model 9 | from util.visualizer import Visualizer 10 | import torch 11 | import torch.distributed as dist 12 | from util.metric import calculate_all_metrics 13 | 14 | from tqdm import tqdm 15 | import numpy as np 16 | import shutil 17 | import argparse 18 | import importlib 19 | 20 | if __name__ == '__main__': 21 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 22 | parser.add_argument('--local_rank') 23 | parser.add_argument('option_name', type=str) 24 | parser_value = parser.parse_args() 25 | option_name = parser_value.option_name 26 | # import package 27 | TrainOptions = getattr(importlib.import_module('options.train_options_{}'.format(option_name)), "TrainOptions") 28 | 29 | opt = TrainOptions('exp_'+option_name).parse() # set CUDA_VISIBLE_DEVICES before import torch 30 | # print(opt.which_experiment) 31 | 32 | if opt.parallel_method == "DistributedDataParallel": 33 | rank = int(os.environ["RANK"]) 34 | print(rank) 35 | world_size = int(os.environ['WORLD_SIZE']) 36 | opt.gpu_ids = rank 37 | opt.world_size = world_size 38 | torch.cuda.set_device(rank) 39 | dist.init_process_group(backend='nccl', init_method='env://', 40 | world_size=world_size, rank=rank) 41 | 42 | # flag_master controls to show result and save checkpoint. 
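    # With DistributedDataParallel only the rank-0 process passes this check, so a single
    # process displays results and writes checkpoints; any other parallel_method runs in one
    # process and always qualifies.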
43 | flag_master = opt.parallel_method != "DistributedDataParallel" or ( 44 | opt.parallel_method == "DistributedDataParallel" and rank == 0) 45 | 46 | model = create_model(opt) 47 | 48 | dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options 49 | dataset_validation = create_dataset(opt, validation=True) 50 | dataset_size = len(dataset) # get the number of images in the dataset. 51 | dataset_size_validation = len(dataset_validation) # get the number of images in the dataset. 52 | if flag_master: 53 | print('The number of training images = %d' % dataset_size) 54 | print('The number of validation images = %d' % dataset_size_validation) 55 | 56 | model.setup(opt) # regular setup: load and print networks; create schedulers 57 | # model.plot_model() # plot the model 58 | # model.get_macs() # plot the model 59 | if flag_master: 60 | visualizer = Visualizer(opt) # create a visualizer that display/save images and plots 61 | total_iters = 0 # the total number of training iterations 62 | best_loss_relit_validation = float('inf') 63 | # define the metric for monitor during training 64 | metric_function = calculate_all_metrics() 65 | if hasattr(opt, 'val_keys'): 66 | val_keys = opt.val_keys 67 | else: 68 | val_keys = ['Relighted'] 69 | 70 | # outer loop for different epochs; we save the model by , + 71 | 72 | for epoch in range(opt.epoch_count, 73 | opt.n_epochs + opt.n_epochs_decay + 1): 74 | epoch_start_time = time.time() # timer for entire epoch 75 | iter_data_time = time.time() # timer for data loading per iteration 76 | epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch 77 | if flag_master: 78 | visualizer.reset() # reset the visualizer: make sure it saves the results to HTML at least once every epoch 79 | 80 | print("Begin training") 81 | model.train() 82 | if opt.parallel_method == "DistributedDataParallel": 83 | dataset.dataloader.sampler.set_epoch(epoch) 84 | loss_epoch = [] 85 | for i, data in tqdm(enumerate(dataset)): # inner loop within one epoch 86 | # if i > 10: 87 | # break 88 | iter_start_time = time.time() # timer for computation per iteration 89 | if total_iters % opt.print_freq == 0: 90 | t_data = iter_start_time - iter_data_time 91 | total_iters += 1 # opt.batch_size 92 | epoch_iter += 1 # opt.batch_size 93 | model.set_input(data) # unpack data from dataset and apply preprocessing 94 | model.optimize_parameters(epoch, i) # calculate loss functions, get gradients, update network weights 95 | iter_data_time = time.time() 96 | # print training losses and save logging information to the disk 97 | current_loss = model.get_current_losses() 98 | current_loss['weighted_total'] = float(model.loss_weighted_total) 99 | loss_epoch.append(current_loss) 100 | losses = {} 101 | for key in current_loss: 102 | losses[key] = np.mean([x[key] for x in loss_epoch]) 103 | 104 | # display images on visdom and save images to a HTML file 105 | save_result = epoch_iter % opt.update_html_freq == 0 106 | model.compute_visuals() 107 | if flag_master: 108 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) 109 | 110 | # print("light_position_color_predict[0] = ", model.light_position_color_predict[0]) 111 | # print("light_position_color_original[0] = ", model.light_position_color_original[0]) 112 | 113 | print("Begin validation") 114 | loss_val = [] 115 | metric_val = [] 116 | model.eval() 117 | with torch.no_grad(): 118 | # test the last batch of training 119 | model.test() 120 | visuals = 
model.get_current_visuals() # get image results 121 | metric_train = metric_function.run(visuals, ['Relighted'])['Relighted'] 122 | # test the validation dataset 123 | for i, data in tqdm(enumerate(dataset_validation)): # inner loop within one epoch 124 | # if i > 10: 125 | # break 126 | model.set_input(data) 127 | model.test() # run inference 128 | visuals = model.get_current_visuals() # get image results 129 | val_results = [] 130 | for val_key in val_keys: 131 | val_results.extend(metric_function.run(visuals, [val_key])[val_key]) 132 | metric_val.append(torch.stack(val_results)) 133 | loss_val.append(model.calculate_val_loss()) 134 | # last batch of train 135 | metric_train = torch.stack(metric_train).unsqueeze(0) 136 | # vallidation 137 | metric_val = torch.stack(metric_val) 138 | loss_relit_validation = torch.stack(loss_val) 139 | if opt.parallel_method == "DistributedDataParallel": 140 | # print(losses) 141 | for key in losses: 142 | value = torch.tensor(losses[key]).cuda() 143 | dist.all_reduce(value) 144 | losses[key] = float(value) / float(world_size) 145 | # print(losses) 146 | # print(metric_train) 147 | dist.all_reduce(metric_train) 148 | # print("metric_train:", metric_train) 149 | dist.all_reduce(metric_val) 150 | dist.all_reduce(loss_relit_validation) 151 | torch.distributed.barrier() 152 | # all_reduce collects the sum of all GPUs results, so it needs to be averaged. 153 | metric_train = metric_train / float(world_size) 154 | metric_val = metric_val / float(world_size) 155 | loss_relit_validation = loss_relit_validation / float(world_size) 156 | # move to the cpu, otherwise Visdom cannot work. 157 | metric_train = torch.mean(metric_train, 0).cpu() 158 | metric_val_mean = torch.mean(metric_val, 0).cpu() 159 | loss_relit_validation = float(torch.mean(loss_relit_validation, 0)) 160 | 161 | # add loss 162 | losses['relit_validation'] = loss_relit_validation 163 | # add metric 164 | for metric_index, metric_key in enumerate(['MSE', 'SSIM', 'PSNR', 'LPIPS', 'MPS']): 165 | losses['_'.join(['train', metric_key])] = metric_train[metric_index] 166 | val_count = 0 167 | for val_key in val_keys: 168 | for metric_key in ['MSE', 'SSIM', 'PSNR', 'LPIPS', 'MPS']: 169 | losses['_'.join(['val', val_key, metric_key])] = metric_val_mean[val_count] 170 | val_count = val_count + 1 171 | 172 | t_comp = (time.time() - iter_start_time) / opt.batch_size 173 | 174 | if flag_master: 175 | visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data, 176 | model.optimizers[0].param_groups[0]['lr']) 177 | if opt.display_id > 0: 178 | visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses) 179 | 180 | # cache our latest model every iterations 181 | if opt.save_latest: 182 | print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters)) 183 | save_suffix = 'latest' 184 | model.save_networks(save_suffix) 185 | 186 | if epoch % opt.save_epoch_freq == 0: # cache our model every epochs 187 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters)) 188 | model.save_networks(epoch) 189 | if loss_relit_validation < best_loss_relit_validation: 190 | print('saving the best model at the end of epoch %d, iters %d' % (epoch, total_iters)) 191 | model.save_networks('best') 192 | best_loss_relit_validation = loss_relit_validation 193 | 194 | print('End of epoch %d / %d \t Time Taken: %d sec' % ( 195 | epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time)) 196 | model.update_learning_rate() # update learning rates at 
the end of every epoch.
197 | 
198 | if flag_master:
199 | # copy the best_* checkpoints to save_best_*
200 | save_dir = os.path.join(opt.checkpoints_dir, opt.name)
201 | for filename in os.listdir(save_dir):
202 | if "best" in filename:
203 | old_name = os.path.join(save_dir, filename)
204 | new_name = os.path.join(save_dir, "save_"+filename)
205 | shutil.copyfile(old_name, new_name)
206 | pass
207 | 
208 | 
-------------------------------------------------------------------------------- /util/html.py: --------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3 | import os
4 | 
5 | 
6 | class HTML:
7 | """This HTML class allows us to save images and write texts into a single HTML file.
8 | 
9 | It consists of functions such as add_header() (add a text header to the HTML file),
10 | add_images() (add a row of images to the HTML file), and save() (save the HTML to the disk).
11 | It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12 | """
13 | 
14 | def __init__(self, web_dir, title, refresh=0):
15 | """Initialize the HTML class
16 | 
17 | Parameters:
18 | web_dir (str) -- a directory that stores the webpage. The HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19 | title (str) -- the webpage name
20 | refresh (int) -- how often the website refreshes itself; if 0, no refreshing
21 | """
22 | self.title = title
23 | self.web_dir = web_dir
24 | self.img_dir = os.path.join(self.web_dir, 'images')
25 | if not os.path.exists(self.web_dir):
26 | os.makedirs(self.web_dir)
27 | if not os.path.exists(self.img_dir):
28 | os.makedirs(self.img_dir)
29 | 
30 | self.doc = dominate.document(title=title)
31 | if refresh > 0:
32 | with self.doc.head:
33 | meta(http_equiv="refresh", content=str(refresh))
34 | 
35 | def get_image_dir(self):
36 | """Return the directory that stores images"""
37 | return self.img_dir
38 | 
39 | def add_header(self, text):
40 | """Insert a header to the HTML file
41 | 
42 | Parameters:
43 | text (str) -- the header text
44 | """
45 | with self.doc:
46 | h3(text)
47 | 
48 | def add_images(self, ims, txts, links, width=400):
49 | """add images to the HTML file
50 | 
51 | Parameters:
52 | ims (str list) -- a list of image paths
53 | txts (str list) -- a list of image names shown on the website
54 | links (str list) -- a list of hyperlinks; when you click an image, it will redirect you to a new page
55 | """
56 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table
57 | self.doc.add(self.t)
58 | with self.t:
59 | with tr():
60 | for im, txt, link in zip(ims, txts, links):
61 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
62 | with p():
63 | with a(href=os.path.join('images', link)):
64 | img(style="width:%dpx" % width, src=os.path.join('images', im))
65 | br()
66 | p(txt)
67 | 
68 | def add_number(self, label, number):
69 | with self.doc:
70 | p(label)
71 | p(str(number))
72 | 
73 | def save(self):
74 | """save the current content to the HTML file"""
75 | html_file = '%s/index.html' % self.web_dir
76 | f = open(html_file, 'wt')
77 | f.write(self.doc.render())
78 | f.close()
79 | 
80 | 
81 | if __name__ == '__main__': # we show an example usage here.
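# Running this file directly (python util/html.py) writes web/index.html with one row of four
# placeholder image entries; the referenced image_*.png files are not created by the demo,
# so the thumbnails will show as broken links.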
82 | html = HTML('web/', 'test_html') 83 | html.add_header('hello world') 84 | 85 | ims, txts, links = [], [], [] 86 | for n in range(4): 87 | ims.append('image_%d.png' % n) 88 | txts.append('text_%d' % n) 89 | links.append('image_%d.png' % n) 90 | html.add_images(ims, txts, links) 91 | html.save() 92 | -------------------------------------------------------------------------------- /util/k_to_rgb.py: -------------------------------------------------------------------------------- 1 | """ 2 | From: https://gist.github.com/petrklus/b1f427accdf7438606a6#file-rgb_to_kelvin-py 3 | Based on: http://www.tannerhelland.com/4435/convert-temperature-rgb-algorithm-code/ 4 | Comments resceived: https://gist.github.com/petrklus/b1f427accdf7438606a6 5 | Original pseudo code: 6 | 7 | Set Temperature = Temperature \ 100 8 | 9 | Calculate Red: 10 | 11 | If Temperature <= 66 Then 12 | Red = 255 13 | Else 14 | Red = Temperature - 60 15 | Red = 329.698727446 * (Red ^ -0.1332047592) 16 | If Red < 0 Then Red = 0 17 | If Red > 255 Then Red = 255 18 | End If 19 | 20 | Calculate Green: 21 | 22 | If Temperature <= 66 Then 23 | Green = Temperature 24 | Green = 99.4708025861 * Ln(Green) - 161.1195681661 25 | If Green < 0 Then Green = 0 26 | If Green > 255 Then Green = 255 27 | Else 28 | Green = Temperature - 60 29 | Green = 288.1221695283 * (Green ^ -0.0755148492) 30 | If Green < 0 Then Green = 0 31 | If Green > 255 Then Green = 255 32 | End If 33 | 34 | Calculate Blue: 35 | 36 | If Temperature >= 66 Then 37 | Blue = 255 38 | Else 39 | 40 | If Temperature <= 19 Then 41 | Blue = 0 42 | Else 43 | Blue = Temperature - 10 44 | Blue = 138.5177312231 * Ln(Blue) - 305.0447927307 45 | If Blue < 0 Then Blue = 0 46 | If Blue > 255 Then Blue = 255 47 | End If 48 | 49 | End If 50 | """ 51 | 52 | import math 53 | 54 | 55 | def convert_K_to_RGB(colour_temperature): 56 | """ 57 | Converts from K to RGB, algorithm courtesy of 58 | http://www.tannerhelland.com/4435/convert-temperature-rgb-algorithm-code/ 59 | """ 60 | # range check 61 | if colour_temperature < 1000: 62 | colour_temperature = 1000 63 | elif colour_temperature > 40000: 64 | colour_temperature = 40000 65 | 66 | tmp_internal = colour_temperature / 100.0 67 | 68 | # red 69 | if tmp_internal <= 66: 70 | red = 255 71 | else: 72 | tmp_red = 329.698727446 * math.pow(tmp_internal - 60, -0.1332047592) 73 | if tmp_red < 0: 74 | red = 0 75 | elif tmp_red > 255: 76 | red = 255 77 | else: 78 | red = tmp_red 79 | 80 | # green 81 | if tmp_internal <= 66: 82 | tmp_green = 99.4708025861 * math.log(tmp_internal) - 161.1195681661 83 | if tmp_green < 0: 84 | green = 0 85 | elif tmp_green > 255: 86 | green = 255 87 | else: 88 | green = tmp_green 89 | else: 90 | tmp_green = 288.1221695283 * math.pow(tmp_internal - 60, -0.0755148492) 91 | if tmp_green < 0: 92 | green = 0 93 | elif tmp_green > 255: 94 | green = 255 95 | else: 96 | green = tmp_green 97 | 98 | # blue 99 | if tmp_internal >= 66: 100 | blue = 255 101 | elif tmp_internal <= 19: 102 | blue = 0 103 | else: 104 | tmp_blue = 138.5177312231 * math.log(tmp_internal - 10) - 305.0447927307 105 | if tmp_blue < 0: 106 | blue = 0 107 | elif tmp_blue > 255: 108 | blue = 255 109 | else: 110 | blue = tmp_blue 111 | 112 | return red, green, blue 113 | 114 | 115 | if __name__ == "__main__": 116 | print("Preview requires matplotlib") 117 | from matplotlib import pyplot as plt 118 | 119 | step_size = 100 120 | for i in range(0, 15000, step_size): 121 | color = list(map(lambda div: div / 255.0, convert_K_to_RGB(i))) + [1] 122 | print(color) 123 | 
plt.plot((i, i), (0, 1), linewidth=step_size / 2.0, linestyle="-", color=color)
124 | 
125 | plt.show()
126 | 
127 | 
128 | 
-------------------------------------------------------------------------------- /util/metric.py: --------------------------------------------------------------------------------
1 | """
2 | Metrics used in test_quantitative.py and train.py
3 | 
4 | """
5 | from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
6 | import lpips
7 | import torch
8 | import math
9 | from torch.nn import functional as torchF
10 | from util.util import PARA_NOR
11 | 
12 | class calculate_all_metrics():
13 | def __init__(self):
14 | self.loss_lpips = lpips.LPIPS(net='alex').cuda()
15 | def calculate_lpips(self, X, Y):
16 | """
17 | Calculate LPIPS.
18 | X and Y are PyTorch Tensors with shape Nx3xHxW
19 | (N patches of size HxW, RGB images scaled in [0, +1]; they are rescaled to [-1, +1] below).
20 | This returns the mean LPIPS distance over the batch.
21 | """
22 | L = self.loss_lpips.forward(X*2-1, Y*2-1) # change [0,+1] to [-1,+1]
23 | L = L.detach().mean()
24 | return L
25 | def run(self, visuals, metric_labels):
26 | all_results = {}
27 | for key in metric_labels:
28 | if key in ['Reflectance', 'Shading_ori', 'Shading_new', 'Relighted']:
29 | results = self.image_metric(visuals, key + '_gt', key + '_predict')
30 | elif key == 'Reconstruct':
31 | results = self.image_metric(visuals, 'Image_input', 'Reconstruct')
32 | elif key == 'Input_and_relighted_gt':
33 | results = self.image_metric(visuals, 'Image_input', 'Relighted_gt')
34 | elif key == 'light_position_color':
35 | pan_tilt_color_gt = inverse_normalize_pan_tilt_color(visuals['light_position_color_original'])
36 | pan_tilt_color_predict = inverse_normalize_pan_tilt_color(visuals['light_position_color_predict'])
37 | light_position_angle_error, pan_error, tilt_error = pan_tilt2angle(pan_tilt_color_predict[:, :2],
38 | pan_tilt_color_gt[:, :2])
39 | light_color_angle_error = angular_distance(pan_tilt_color_predict[:, 2:5], pan_tilt_color_gt[:, 2:5])
40 | results = [light_position_angle_error, pan_error, tilt_error, light_color_angle_error]
41 | else:
42 | raise Exception("key error")
43 | all_results[key] = results
44 | return all_results
45 | 
46 | def image_metric(self, tensors, gt_label, predict_label):
47 | # X: (N,3,H,W) a batch of non-negative RGB images (0~1)
48 | # Y: (N,3,H,W)
49 | ssim_value = ssim(tensors[gt_label], tensors[predict_label], data_range=1.0, size_average=True) # scalar, since size_average=True
50 | mse_value = torchF.mse_loss(tensors[gt_label], tensors[predict_label])
51 | psnr_value = 10 * torch.log10(1.0 / mse_value)
52 | lpips_value = self.calculate_lpips(tensors[gt_label], tensors[predict_label])
53 | mps_value = (1 - lpips_value + ssim_value) / 2.0
54 | return [mse_value, ssim_value, psnr_value, lpips_value, mps_value]
55 | 
56 | 
57 | def inverse_normalize_pan_tilt_color(data):
58 | """
59 | This function transforms the normalized light condition into the real value.
60 | The normalized data is [cos(pan)/2+0.5, sin(pan)/2+0.5, cos(tilt), sin(tilt), R/255, G/255, B/255]
61 | :param data: normalized light condition
62 | :return: pan(deg), tilt(deg), rgb(0~255)
63 | """
64 | # This function transforms the normalized light condition into the real value.
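# Example (illustrative values): pan = 90 deg, tilt = 30 deg, RGB = (255, 204, 102) arrives here as
# [0.5, 1.0, 0.866, 0.5, 1.0, 0.8, 0.4]; atan2 on each (cos, sin) pair recovers pan = 90 and tilt = 30,
# and the colour channels are rescaled by 255 back to (255, 204, 102).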
65 | # The normalized data is [cos(pan)/2+0.5, sin(pan)/2+0.5, cos(tilt), sin(tilt), R/255, G/255, B/255] 66 | data_angle = data[:, :4] 67 | data_angle[:, :2] = (data_angle[:, :2] - PARA_NOR['pan_b']) / PARA_NOR['pan_a'] 68 | data_angle[:, 2:] = (data_angle[:, 2:] - PARA_NOR['tilt_b']) / PARA_NOR['tilt_a'] 69 | data_rgb = data[:, 4:] * 255.0 70 | pan = torch.atan2(data_angle[:, 1], data_angle[:, 0]) * 180 / math.pi 71 | pan = pan.unsqueeze(1) 72 | tilt = torch.atan2(data_angle[:, 3], data_angle[:, 2]) * 180 / math.pi 73 | tilt = tilt.unsqueeze(1) 74 | pan_tilt_color = torch.cat((pan, tilt, data_rgb), dim=1) 75 | 76 | return pan_tilt_color 77 | 78 | 79 | def pan_tilt2angle(pan_tilt_pred, pan_tilt_target): 80 | """ 81 | :param pan_tilt_pred: prediction 82 | :param pan_tilt_target: target 83 | :return: 84 | """ 85 | # inputs and targets should be pan and tilt. 86 | def pan_tilt_to_vector(pan, tilt): 87 | pan = pan / 180 * math.pi 88 | tilt = tilt / 180 * math.pi 89 | vector = torch.zeros(pan.size()[0], 3).cuda() 90 | vector[:, 0] = torch.mul(torch.sin(tilt), torch.cos(pan)) 91 | vector[:, 1] = torch.mul(torch.sin(tilt), torch.sin(pan)) 92 | vector[:, 2] = torch.cos(tilt) 93 | return vector 94 | 95 | vector_pred = pan_tilt_to_vector(pan_tilt_pred[:, 0], pan_tilt_pred[:, 1]) 96 | vector_target = pan_tilt_to_vector(pan_tilt_target[:, 0], pan_tilt_target[:, 1]) 97 | 98 | result = torch.sum(torch.mul(vector_pred, vector_target), dim=1) 99 | # pytorch acos occurs nan error. 100 | eps = 1e-6 101 | result = torch.clamp(result, min=-1 + eps, max=1 - eps) 102 | result = torch.mean(torch.acos(result)) 103 | angle_result = result * 180 / math.pi 104 | # also return error of pan and tilt. 105 | pan_error = distance_angle(pan_tilt_pred[:, 0], pan_tilt_target[:, 0]) 106 | tilt_error = distance_angle(pan_tilt_pred[:, 1], pan_tilt_target[:, 1]) 107 | return angle_result, pan_error, tilt_error 108 | 109 | 110 | def distance_angle(predict, target): 111 | """ 112 | :param predict: angle in deg 113 | :param target: angle in deg 114 | :return: the angle between predict and target in deg 115 | """ 116 | predict_rad = predict / (180 / math.pi) 117 | target_rad = target / (180 / math.pi) 118 | angle_cos = torch.cos(predict_rad) * torch.cos(target_rad) + torch.sin(predict_rad) * torch.sin(target_rad) 119 | # pytorch acos occurs nan error. 120 | eps = 1e-6 121 | angle_cos = torch.clamp(angle_cos, min=-1 + eps, max=1 - eps) 122 | result_rad = torch.mean(torch.acos(angle_cos)) 123 | result_deg = result_rad * 180 / math.pi 124 | return result_deg 125 | 126 | 127 | def angular_distance(color_predict, color_gt): 128 | result = torch.cosine_similarity(color_predict, color_gt, dim=1) 129 | #torch.sum(torch.mul(color_predict, color_gt), dim=1) 130 | # pytorch acos occurs nan error. 
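# cosine_similarity can return values marginally outside [-1, 1] due to floating-point round-off,
# and acos would then produce NaN, so the result is clamped to [-1+eps, 1-eps] first.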
131 | eps = 1e-6 132 | result = torch.clamp(result, min=-1 + eps, max=1 - eps) 133 | result = torch.mean(torch.acos(result)) 134 | angle_result = result * 180 / math.pi 135 | return angle_result 136 | 137 | -------------------------------------------------------------------------------- /util/pan_tilt_dial.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CVC-CIC/DeepIntrinsicRelighting/b38a04ff9caecddd4e3551950cf6aaf5b17d960e/util/pan_tilt_dial.png -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | 8 | 9 | """ 10 | Normalize the trigonometric function value of pan and tilt to [0, 1] 11 | pan can be -180 ~ 180 deg, so cos(pan), sin(pan) are [-1, 1], which should be adapted. 12 | tilt can be 0 ~ 90 deg, so cos(pan), sin(pan) are [0, 1], which should not be adapted. 13 | Normalization: y = ax + b; 14 | denormalization: x = (y-b)/a 15 | """ 16 | PARA_NOR = { 17 | 'pan_a': 0.5, 18 | 'pan_b': 0.5, 19 | 'tilt_a': 1.0, 20 | 'tilt_b': 0.0, 21 | } 22 | 23 | 24 | def tensor2im(input_image, normalization_type, imtype=np.uint8): 25 | """"Converts a Tensor array into a numpy image array. 26 | 27 | Parameters: 28 | input_image (tensor) -- the input image tensor array 29 | imtype (type) -- the desired type of the converted numpy array 30 | """ 31 | if not isinstance(input_image, np.ndarray): 32 | if isinstance(input_image, torch.Tensor): # get the data from a variable 33 | image_tensor = input_image.data 34 | else: 35 | return input_image 36 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array 37 | if image_numpy.shape[0] == 1: # grayscale to RGB 38 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 39 | if normalization_type == '[-1, 1]': 40 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling 41 | elif normalization_type == '[0, 1]': 42 | if np.max(image_numpy) > 1.0: 43 | image_numpy = np.clip(image_numpy, a_min=0.0, a_max=1.0) 44 | raise Exception("A warning for possible exceeding. ") 45 | # image_numpy = image_numpy / np.max(image_numpy) 46 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 # post-processing: tranpose and scaling 47 | else: 48 | raise Exception("rewrite inverse normalization here.") 49 | else: # if it is a numpy array, do nothing 50 | image_numpy = input_image 51 | return image_numpy.astype(imtype) 52 | 53 | 54 | # def tensor2pan_tilt_color(pan_tilt_color, num_type=np.int16): 55 | # pan_tilt_color_numpy = pan_tilt_color.cpu().detach().numpy() 56 | # pan_tilt_color_numpy = pan_tilt_color_numpy / [math.pi / 180.0, math.pi / 180.0, 0.001] 57 | # pan_tilt_color_numpy = np.around(pan_tilt_color_numpy) 58 | # return pan_tilt_color_numpy.astype(num_type) 59 | 60 | 61 | def tensor2pan_tilt_color(data, num_type=np.int16): 62 | # This function transforms the normalized light condition into the real value. 
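# Expects the 7-dim vector [cos(pan)/2+0.5, sin(pan)/2+0.5, cos(tilt), sin(tilt), R/255, G/255, B/255]
# and returns rounded integers [pan, tilt, R, G, B]; used by util/visualizer.py (save_images_one_batch)
# to write the light parameters into the results webpage.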
63 | data = data.cpu().detach().numpy() 64 | data[:2] = (data[:2] - PARA_NOR['pan_b']) / PARA_NOR['pan_a'] 65 | data[2:4] = (data[2:4] - PARA_NOR['tilt_b']) / PARA_NOR['tilt_a'] 66 | pan = np.rad2deg(np.arctan2(data[1], data[0])) 67 | tilt = np.rad2deg(np.arctan2(data[3], data[2])) 68 | RGB = data[4:] * 255.0 69 | around_data = np.around(np.array([pan, tilt, RGB[0], RGB[1], RGB[2]])) 70 | return around_data.astype(num_type) 71 | 72 | 73 | def diagnose_network(net, name='network'): 74 | """Calculate and print the mean of average absolute(gradients) 75 | 76 | Parameters: 77 | net (torch network) -- Torch network 78 | name (str) -- the name of the network 79 | """ 80 | mean = 0.0 81 | count = 0 82 | for param in net.parameters(): 83 | if param.grad is not None: 84 | mean += torch.mean(torch.abs(param.grad.data)) 85 | count += 1 86 | if count > 0: 87 | mean = mean / count 88 | print(name) 89 | print(mean) 90 | 91 | 92 | def save_image(image_numpy, image_path, aspect_ratio=1.0): 93 | """Save a numpy image to the disk 94 | 95 | Parameters: 96 | image_numpy (numpy array) -- input numpy array 97 | image_path (str) -- the path of the image 98 | """ 99 | 100 | image_pil = Image.fromarray(image_numpy) 101 | h, w, _ = image_numpy.shape 102 | 103 | if aspect_ratio > 1.0: 104 | image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) 105 | if aspect_ratio < 1.0: 106 | image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) 107 | image_pil.save(image_path) 108 | 109 | 110 | def print_numpy(x, val=True, shp=False): 111 | """Print the mean, min, max, median, std, and size of a numpy array 112 | 113 | Parameters: 114 | val (bool) -- if print the values of the numpy array 115 | shp (bool) -- if print the shape of the numpy array 116 | """ 117 | x = x.astype(np.float64) 118 | if shp: 119 | print('shape,', x.shape) 120 | if val: 121 | x = x.flatten() 122 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 123 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 124 | 125 | 126 | def mkdirs(paths): 127 | """create empty directories if they don't exist 128 | 129 | Parameters: 130 | paths (str list) -- a list of directory paths 131 | """ 132 | if isinstance(paths, list) and not isinstance(paths, str): 133 | for path in paths: 134 | mkdir(path) 135 | else: 136 | mkdir(paths) 137 | 138 | 139 | def mkdir(path): 140 | """create a single empty directory if it didn't exist 141 | 142 | Parameters: 143 | path (str) -- a single directory path 144 | """ 145 | if not os.path.exists(path): 146 | os.makedirs(path) 147 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import ntpath 5 | import time 6 | from . import util, html 7 | from subprocess import Popen, PIPE 8 | 9 | 10 | if sys.version_info[0] == 2: 11 | VisdomExceptionBase = Exception 12 | else: 13 | VisdomExceptionBase = ConnectionError 14 | 15 | 16 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): 17 | """Save images to the disk. 
18 | 19 | Parameters: 20 | webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details) 21 | visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs 22 | image_path (str) -- the string is used to create image paths 23 | aspect_ratio (float) -- the aspect ratio of saved images 24 | width (int) -- the images will be resized to width x width 25 | 26 | This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. 27 | """ 28 | image_dir = webpage.get_image_dir() 29 | short_path = ntpath.basename(image_path[0]) 30 | name = os.path.splitext(short_path)[0] 31 | 32 | webpage.add_header(name) 33 | ims, txts, links = [], [], [] 34 | 35 | for label, im_data in visuals.items(): 36 | im = util.tensor2im(im_data) 37 | image_name = '%s_%s.png' % (name, label) 38 | save_path = os.path.join(image_dir, image_name) 39 | util.save_image(im, save_path, aspect_ratio=aspect_ratio) 40 | ims.append(image_name) 41 | txts.append(label) 42 | links.append(image_name) 43 | webpage.add_images(ims, txts, links, width=width) 44 | 45 | 46 | def save_images_one_batch(webpage, visuals, image_path, normalization_type, aspect_ratio=1.0, width=256): 47 | """Save images to the disk. 48 | 49 | Parameters: 50 | webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details) 51 | visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs 52 | image_path (str) -- the string is used to create image paths 53 | aspect_ratio (float) -- the aspect ratio of saved images 54 | width (int) -- the images will be resized to width x width 55 | 56 | This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. 57 | """ 58 | image_dir = webpage.get_image_dir() 59 | 60 | for i in range(len(image_path)): 61 | short_path = ntpath.basename(image_path[i]) 62 | name = os.path.splitext(short_path)[0] 63 | 64 | webpage.add_header(name) 65 | ims, txts, links = [], [], [] 66 | 67 | for label, im_data in visuals.items(): 68 | if im_data.size().__len__() < 3: 69 | if im_data.size()[-1] == 7: 70 | # im_data is pan tilt 71 | webpage.add_number(label, util.tensor2pan_tilt_color(im_data[i])) 72 | else: 73 | # im_data is SH light. 74 | webpage.add_number(label, im_data[i]) 75 | continue 76 | im = util.tensor2im(im_data[i].unsqueeze(0), normalization_type) 77 | image_name = '%s_%s.png' % (name, label) 78 | save_path = os.path.join(image_dir, image_name) 79 | util.save_image(im, save_path, aspect_ratio=aspect_ratio) 80 | ims.append(image_name) 81 | txts.append(label) 82 | links.append(image_name) 83 | webpage.add_images(ims, txts, links, width=width) 84 | 85 | 86 | class Visualizer(): 87 | """This class includes several functions that can display/save images and print/save logging information. 88 | 89 | It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. 
90 | """ 91 | 92 | def __init__(self, opt): 93 | """Initialize the Visualizer class 94 | 95 | Parameters: 96 | opt -- stores all the experiment flags; needs to be a subclass of BaseOptions 97 | Step 1: Cache the training/test options 98 | Step 2: connect to a visdom server 99 | Step 3: create an HTML object for saveing HTML filters 100 | Step 4: create a logging file to store training losses 101 | """ 102 | self.opt = opt # cache the option 103 | self.display_id = opt.display_id 104 | self.use_html = opt.isTrain and not opt.no_html 105 | self.win_size = opt.display_winsize 106 | self.name = opt.name 107 | self.port = opt.display_port 108 | self.saved = False 109 | self.normalization_type = opt.normalization_type 110 | if self.display_id > 0: # connect to a visdom server given and 111 | import visdom 112 | self.ncols = opt.display_ncols 113 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env) 114 | if not self.vis.check_connection(): 115 | self.create_visdom_connections() 116 | 117 | if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/ 118 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 119 | self.img_dir = os.path.join(self.web_dir, 'images') 120 | print('create web directory %s...' % self.web_dir) 121 | util.mkdirs([self.web_dir, self.img_dir]) 122 | # create a logging file to store training losses 123 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 124 | with open(self.log_name, "a") as log_file: 125 | now = time.strftime("%c") 126 | log_file.write('================ Training Loss (%s) ================\n' % now) 127 | 128 | def reset(self): 129 | """Reset the self.saved status""" 130 | self.saved = False 131 | 132 | def create_visdom_connections(self): 133 | """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """ 134 | cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port 135 | print('\n\nCould not connect to Visdom server. \n Trying to start a server....') 136 | print('Command: %s' % cmd) 137 | Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE) 138 | 139 | def display_current_results(self, visuals, epoch, save_result): 140 | """Display current results on visdom; save current results to an HTML file. 141 | 142 | Parameters: 143 | visuals (OrderedDict) - - dictionary of images to display or save 144 | epoch (int) - - the current epoch 145 | save_result (bool) - - if save the current results to an HTML file 146 | """ 147 | if self.display_id > 0: # show images in the browser using visdom 148 | ncols = self.ncols 149 | if ncols > 0: # show all the images in one visdom panel 150 | ncols = min(ncols, len(visuals)) 151 | h, w = next(iter(visuals.values())).shape[:2] 152 | table_css = """""" % (w, h) # create a table css 156 | # create a table of images. 
157 | title = self.name
158 | label_html = ''
159 | label_html_row = ''
160 | images = []
161 | idx = 0
162 | for label, image in visuals.items():
163 | image_numpy = util.tensor2im(image, self.normalization_type)
164 | label_html_row += '<td>%s</td>' % label
165 | images.append(image_numpy.transpose([2, 0, 1]))
166 | idx += 1
167 | if idx % ncols == 0:
168 | label_html += '<tr>%s</tr>' % label_html_row
169 | label_html_row = ''
170 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
171 | while idx % ncols != 0:
172 | images.append(white_image)
173 | label_html_row += '<td></td>'
174 | idx += 1
175 | if label_html_row != '':
176 | label_html += '<tr>%s</tr>' % label_html_row
177 | try:
178 | self.vis.images(images, nrow=ncols, win=self.display_id + 1,
179 | padding=2, opts=dict(title=title + ' images'))
180 | label_html = '<table>%s</table>
' % label_html 181 | self.vis.text(table_css + label_html, win=self.display_id + 2, 182 | opts=dict(title=title + ' labels')) 183 | except VisdomExceptionBase: 184 | self.create_visdom_connections() 185 | 186 | else: # show each image in a separate visdom panel; 187 | idx = 1 188 | try: 189 | for label, image in visuals.items(): 190 | image_numpy = util.tensor2im(image, self.normalization_type) 191 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), 192 | win=self.display_id + idx) 193 | idx += 1 194 | except VisdomExceptionBase: 195 | self.create_visdom_connections() 196 | 197 | if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved. 198 | self.saved = True 199 | # save images to the disk 200 | for label, image in visuals.items(): 201 | image_numpy = util.tensor2im(image, self.normalization_type) 202 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) 203 | util.save_image(image_numpy, img_path) 204 | 205 | # update website 206 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0) 207 | for n in range(epoch, 0, -1): 208 | webpage.add_header('epoch [%d]' % n) 209 | ims, txts, links = [], [], [] 210 | 211 | for label, image_numpy in visuals.items(): 212 | image_numpy = util.tensor2im(image, self.normalization_type) 213 | img_path = 'epoch%.3d_%s.png' % (n, label) 214 | ims.append(img_path) 215 | txts.append(label) 216 | links.append(img_path) 217 | webpage.add_images(ims, txts, links, width=self.win_size) 218 | webpage.save() 219 | 220 | def plot_current_losses(self, epoch, counter_ratio, losses): 221 | """display the current losses on visdom display: dictionary of error labels and values 222 | 223 | Parameters: 224 | epoch (int) -- current epoch 225 | counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1 226 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs 227 | """ 228 | if not hasattr(self, 'plot_data'): 229 | self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} 230 | self.plot_data['X'].append(epoch + counter_ratio) 231 | self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) 232 | try: 233 | self.vis.line( 234 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), 235 | Y=np.array(self.plot_data['Y']), 236 | opts={ 237 | 'title': self.name + ' loss over time', 238 | 'legend': self.plot_data['legend'], 239 | 'xlabel': 'epoch', 240 | 'ylabel': 'loss'}, 241 | win=self.display_id) 242 | except VisdomExceptionBase: 243 | self.create_visdom_connections() 244 | 245 | # losses: same format as |losses| of plot_current_losses 246 | def print_current_losses(self, epoch, iters, losses, t_comp, t_data, learning_rate): 247 | """print current losses on console; also save the losses to the disk 248 | 249 | Parameters: 250 | epoch (int) -- current epoch 251 | iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) 252 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs 253 | t_comp (float) -- computational time per data point (normalized by batch_size) 254 | t_data (float) -- data loading time per data point (normalized by batch_size) 255 | """ 256 | message = '(epoch: %d, iters: %d, time: %.3f, t_data: %.3f, learning_rate: %.8f) ' % (epoch, iters, t_comp, 257 | t_data, learning_rate) 258 | for k, v in losses.items(): 259 | message += '%s: %.3f ' % (k, v) 260 | 261 | 
print(message) # print the message 262 | with open(self.log_name, "a") as log_file: 263 | log_file.write('%s\n' % message) # save the message 264 | --------------------------------------------------------------------------------