├── datasets └── .gitkeep ├── requirements.txt ├── resources ├── bdd100k.jpg └── cover.png ├── unpaired_cycler2r ├── models │ ├── translation_models │ │ ├── __init__.py │ │ └── unpaired_cycler2r.py │ ├── architectures │ │ ├── __init__.py │ │ └── unpaired_cycler2r │ │ │ ├── __init__.py │ │ │ ├── modules.py │ │ │ └── generator_discriminator.py │ └── __init__.py ├── datasets │ ├── pipelines │ │ ├── __init__.py │ │ └── raw.py │ ├── __init__.py │ └── unpaired_cycler2r_dataset.py ├── version.py └── __init__.py ├── configs ├── unpaired_cycler2r │ ├── unpaired_cycler2r_in_bdd100k_rgb2asi_raw_20k.py │ ├── unpaired_cycler2r_in_bdd100k_rgb2huawei_raw_20k.py │ ├── unpaired_cycler2r_in_bdd100k_rgb2iphone_raw_20k.py │ ├── unpaired_cycler2r_in_flicker2w_rgb2asi_raw_20k.py │ ├── unpaired_cycler2r_in_bdd100k_rgb2oneplus_raw_20k.py │ ├── unpaired_cycler2r_in_cityscapes_rgb2iphone_raw_20k.py │ ├── unpaired_cycler2r_in_flicker2w_rgb2huawei_raw_20k.py │ ├── unpaired_cycler2r_in_flicker2w_rgb2iphone_raw_20k.py │ ├── unpaired_cycler2r_in_flicker2w_rgb2oneplus_raw_20k.py │ └── runtime.py └── _base_ │ ├── default_metrics.py │ ├── default_runtime.py │ ├── models │ └── unpaired_cycler2r │ │ └── unpaired_cycler2r.py │ └── datasets │ ├── bdd100k_rgb2asi_raw_512x512.py │ ├── bdd100k_rgb2oneplus_raw_512x512.py │ ├── flicker2w_rgb2asi_raw_512x512.py │ ├── flicker2w_rgb2oneplus_raw_512x512.py │ ├── bdd100k_rgb2huawei_raw_512x512.py │ ├── bdd100k_rgb2iphone_raw_512x512.py │ ├── flicker2w_rgb2huawei_raw_512x512.py │ ├── flicker2w_rgb2iphone_raw_512x512.py │ └── cityscapes_rgb2iphone_raw_512x512.py ├── LICENSE ├── .gitignore ├── inference.py ├── README.md └── train.py /datasets/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | rawpy 2 | mmcv-full==1.3.8 3 | mmgen==0.5.0 4 | -------------------------------------------------------------------------------- /resources/bdd100k.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJUVISION/rho-vision/HEAD/resources/bdd100k.jpg -------------------------------------------------------------------------------- /resources/cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJUVISION/rho-vision/HEAD/resources/cover.png -------------------------------------------------------------------------------- /unpaired_cycler2r/models/translation_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .unpaired_cycler2r import UnpairedCycleR2R 2 | 3 | __all__ = ['UnpairedCycleR2R'] 4 | -------------------------------------------------------------------------------- /unpaired_cycler2r/models/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | from .unpaired_cycler2r import inverseISP, HistAwareDiscriminator 2 | 3 | __all__ = [ 4 | 'inverseISP', 'HistAwareDiscriminator' 5 | ] 6 | -------------------------------------------------------------------------------- /unpaired_cycler2r/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .raw import LoadRAWFromFile, Demosaic, RAWNormalize 2 | 3 | __all__ = [ 4 | 'LoadRAWFromFile', 5 | 'Demosaic', 6 | 
'RAWNormalize' 7 | ] 8 | -------------------------------------------------------------------------------- /unpaired_cycler2r/models/architectures/unpaired_cycler2r/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .generator_discriminator import inverseISP, HistAwareDiscriminator 3 | __all__ = ['inverseISP', 'HistAwareDiscriminator'] 4 | -------------------------------------------------------------------------------- /unpaired_cycler2r/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .architectures import inverseISP, HistAwareDiscriminator # noqa: F401, F403 2 | from .translation_models import UnpairedCycleR2R # noqa: F401, F403 3 | 4 | __all__ = ['inverseISP', 'HistAwareDiscriminator', 'UnpairedCycleR2R'] 5 | -------------------------------------------------------------------------------- /unpaired_cycler2r/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .pipelines import LoadRAWFromFile, Demosaic, RAWNormalize 3 | from .unpaired_cycler2r_dataset import UnpairedCycleR2RDataset 4 | 5 | __all__ = [ 6 | 'LoadRAWFromFile', 'Demosaic', 'RAWNormalize', 'UnpairedCycleR2RDataset' 7 | ] 8 | -------------------------------------------------------------------------------- /configs/unpaired_cycler2r/unpaired_cycler2r_in_bdd100k_rgb2asi_raw_20k.py: -------------------------------------------------------------------------------- 1 | DATASET = 'bdd100k' 2 | CAMERA = 'asi' 3 | _base_ = [ 4 | './runtime.py', 5 | '../_base_/datasets/{}_rgb2{}_raw_512x512.py'.format(DATASET, CAMERA), 6 | ] 7 | 8 | exp_name = 'unpaired_cycler2r_{}_rgb2{}_raw'.format(DATASET, CAMERA) 9 | work_dir = f'./work_dirs/experiments/{exp_name}' 10 | -------------------------------------------------------------------------------- /configs/unpaired_cycler2r/unpaired_cycler2r_in_bdd100k_rgb2huawei_raw_20k.py: -------------------------------------------------------------------------------- 1 | DATASET = 'bdd100k' 2 | CAMERA = 'huawei' 3 | _base_ = [ 4 | './runtime.py', 5 | '../_base_/datasets/{}_rgb2{}_raw_512x512.py'.format(DATASET, CAMERA), 6 | ] 7 | 8 | exp_name = 'unpaired_cycler2r_{}_rgb2{}_raw'.format(DATASET, CAMERA) 9 | work_dir = f'./work_dirs/experiments/{exp_name}' 10 | -------------------------------------------------------------------------------- /configs/unpaired_cycler2r/unpaired_cycler2r_in_bdd100k_rgb2iphone_raw_20k.py: -------------------------------------------------------------------------------- 1 | DATASET = 'bdd100k' 2 | CAMERA = 'iphone' 3 | _base_ = [ 4 | './runtime.py', 5 | '../_base_/datasets/{}_rgb2{}_raw_512x512.py'.format(DATASET, CAMERA), 6 | ] 7 | 8 | exp_name = 'unpaired_cycler2r_{}_rgb2{}_raw'.format(DATASET, CAMERA) 9 | work_dir = f'./work_dirs/experiments/{exp_name}' 10 | -------------------------------------------------------------------------------- /configs/unpaired_cycler2r/unpaired_cycler2r_in_flicker2w_rgb2asi_raw_20k.py: -------------------------------------------------------------------------------- 1 | DATASET = 'flicker2w' 2 | CAMERA = 'asi' 3 | _base_ = [ 4 | './runtime.py', 5 | '../_base_/datasets/{}_rgb2{}_raw_512x512.py'.format(DATASET, CAMERA), 6 | ] 7 | 8 | exp_name = 'unpaired_cycler2r_{}_rgb2{}_raw'.format(DATASET, CAMERA) 9 | work_dir = f'./work_dirs/experiments/{exp_name}' 10 | 
-------------------------------------------------------------------------------- /configs/unpaired_cycler2r/unpaired_cycler2r_in_bdd100k_rgb2oneplus_raw_20k.py: -------------------------------------------------------------------------------- 1 | DATASET = 'bdd100k' 2 | CAMERA = 'oneplus' 3 | _base_ = [ 4 | './runtime.py', 5 | '../_base_/datasets/{}_rgb2{}_raw_512x512.py'.format(DATASET, CAMERA), 6 | ] 7 | 8 | exp_name = 'unpaired_cycler2r_{}_rgb2{}_raw'.format(DATASET, CAMERA) 9 | work_dir = f'./work_dirs/experiments/{exp_name}' 10 | -------------------------------------------------------------------------------- /configs/unpaired_cycler2r/unpaired_cycler2r_in_cityscapes_rgb2iphone_raw_20k.py: -------------------------------------------------------------------------------- 1 | DATASET = 'cityscapes' 2 | CAMERA = 'iphone' 3 | _base_ = [ 4 | './runtime.py', 5 | '../_base_/datasets/{}_rgb2{}_raw_512x512.py'.format(DATASET, CAMERA), 6 | ] 7 | 8 | exp_name = 'unpaired_cycler2r_{}_rgb2{}_raw'.format(DATASET, CAMERA) 9 | work_dir = f'./work_dirs/experiments/{exp_name}' 10 | -------------------------------------------------------------------------------- /configs/unpaired_cycler2r/unpaired_cycler2r_in_flicker2w_rgb2huawei_raw_20k.py: -------------------------------------------------------------------------------- 1 | DATASET = 'flicker2w' 2 | CAMERA = 'huawei' 3 | _base_ = [ 4 | './runtime.py', 5 | '../_base_/datasets/{}_rgb2{}_raw_512x512.py'.format(DATASET, CAMERA), 6 | ] 7 | 8 | exp_name = 'unpaired_cycler2r_{}_rgb2{}_raw'.format(DATASET, CAMERA) 9 | work_dir = f'./work_dirs/experiments/{exp_name}' 10 | -------------------------------------------------------------------------------- /configs/unpaired_cycler2r/unpaired_cycler2r_in_flicker2w_rgb2iphone_raw_20k.py: -------------------------------------------------------------------------------- 1 | DATASET = 'flicker2w' 2 | CAMERA = 'iphone' 3 | _base_ = [ 4 | './runtime.py', 5 | '../_base_/datasets/{}_rgb2{}_raw_512x512.py'.format(DATASET, CAMERA), 6 | ] 7 | 8 | exp_name = 'unpaired_cycler2r_{}_rgb2{}_raw'.format(DATASET, CAMERA) 9 | work_dir = f'./work_dirs/experiments/{exp_name}' 10 | -------------------------------------------------------------------------------- /configs/unpaired_cycler2r/unpaired_cycler2r_in_flicker2w_rgb2oneplus_raw_20k.py: -------------------------------------------------------------------------------- 1 | DATASET = 'flicker2w' 2 | CAMERA = 'oneplus' 3 | _base_ = [ 4 | './runtime.py', 5 | '../_base_/datasets/{}_rgb2{}_raw_512x512.py'.format(DATASET, CAMERA), 6 | ] 7 | 8 | exp_name = 'unpaired_cycler2r_{}_rgb2{}_raw'.format(DATASET, CAMERA) 9 | work_dir = f'./work_dirs/experiments/{exp_name}' 10 | -------------------------------------------------------------------------------- /configs/_base_/default_metrics.py: -------------------------------------------------------------------------------- 1 | metrics = dict( 2 | fid50k=dict(type='FID', num_images=50000), 3 | pr50k3=dict(type='PR', num_images=50000, k=3), 4 | is50k=dict(type='IS', num_images=50000), 5 | ppl_zfull=dict(type='PPL', space='Z', sampling='full', num_images=50000), 6 | ppl_wfull=dict(type='PPL', space='W', sampling='full', num_images=50000), 7 | ppl_zend=dict(type='PPL', space='Z', sampling='end', num_images=50000), 8 | ppl_wend=dict(type='PPL', space='W', sampling='end', num_images=50000), 9 | ms_ssim10k=dict(type='MS_SSIM', num_images=10000), 10 | swd16k=dict(type='SWD', num_images=16384)) 11 | 
-------------------------------------------------------------------------------- /unpaired_cycler2r/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | __version__ = '0.5.0' 3 | 4 | 5 | def parse_version_info(version_str): 6 | """Parse version information. 7 | 8 | Args: 9 | version_str (str): Version string. 10 | 11 | Returns: 12 | tuple: Version information in tuple. 13 | """ 14 | version_info = [] 15 | for x in version_str.split('.'): 16 | if x.isdigit(): 17 | version_info.append(int(x)) 18 | elif x.find('rc') != -1: 19 | patch_version = x.split('rc') 20 | version_info.append(int(patch_version[0])) 21 | version_info.append(f'rc{patch_version[1]}') 22 | return tuple(version_info) 23 | 24 | 25 | version_info = parse_version_info(__version__) 26 | -------------------------------------------------------------------------------- /configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | checkpoint_config = dict(interval=10000, by_epoch=False) 2 | # yapf:disable 3 | log_config = dict( 4 | interval=20, 5 | hooks=[ 6 | dict(type='TextLoggerHook'), 7 | # dict(type='TensorboardLoggerHook'), 8 | ]) 9 | # yapf:enable 10 | 11 | custom_hooks = [ 12 | dict( 13 | type='VisualizeUnconditionalSamples', 14 | output_dir='training_samples', 15 | interval=1000) 16 | ] 17 | 18 | # use dynamic runner 19 | runner = dict( 20 | type='DynamicIterBasedRunner', 21 | is_dynamic_ddp=True, 22 | pass_training_status=True) 23 | 24 | dist_params = dict(backend='nccl') 25 | log_level = 'INFO' 26 | load_from = None 27 | resume_from = None 28 | workflow = [('train', 10000)] 29 | find_unused_parameters = True 30 | cudnn_benchmark = True 31 | -------------------------------------------------------------------------------- /unpaired_cycler2r/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import mmcv 3 | 4 | from .version import __version__, parse_version_info, version_info 5 | 6 | 7 | def digit_version(version_str): 8 | digit_version = [] 9 | for x in version_str.split('.'): 10 | if x.isdigit(): 11 | digit_version.append(int(x)) 12 | elif x.find('rc') != -1: 13 | patch_version = x.split('rc') 14 | digit_version.append(int(patch_version[0]) - 1) 15 | digit_version.append(int(patch_version[1])) 16 | return digit_version 17 | 18 | 19 | mmcv_minimum_version = '1.3.0' 20 | mmcv_maximum_version = '1.7.0' 21 | mmcv_version = digit_version(mmcv.__version__) 22 | 23 | 24 | assert (mmcv_version >= digit_version(mmcv_minimum_version) 25 | and mmcv_version <= digit_version(mmcv_maximum_version)), \ 26 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \ 27 | f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.' 
28 | 29 | __all__ = ['__version__', 'version_info', 'parse_version_info'] 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 NJUVISION 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/unpaired_cycler2r/runtime.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/unpaired_cycler2r/unpaired_cycler2r.py', 3 | ] 4 | # log 5 | # yapf:disable 6 | log_config = dict( 7 | interval=20, 8 | hooks=[ 9 | dict(type='TextLoggerHook'), 10 | # dict(type='TensorboardLoggerHook'), 11 | ]) 12 | # yapf:enable 13 | 14 | # dist 15 | # use dynamic runner 16 | runner = None 17 | dist_params = dict(backend='nccl') 18 | log_level = 'INFO' 19 | load_from = None 20 | resume_from = None 21 | find_unused_parameters = False 22 | cudnn_benchmark = True 23 | use_ddp_wrapper = True 24 | total_iters = 20000 25 | workflow = [('train', 1)] 26 | 27 | # learning policy 28 | optimizer = dict( 29 | generator=dict(type='AdamW', lr=5e-4), 30 | discriminators=dict(type='Adam', lr=5e-5, betas=(0.5, 0.999))) 31 | 32 | lr_config = dict( 33 | policy='Linear', by_epoch=False, target_lr=0, start=10000, interval=400) 34 | 35 | # evaluation 36 | num_images = 20 37 | evaluation = dict( 38 | type='TranslationEvalHook', 39 | target_domain='raw', 40 | interval=10000, 41 | metrics=[ 42 | dict(type='FID', num_images=num_images, bgr2rgb=False), 43 | dict( 44 | type='IS', 45 | num_images=num_images, 46 | image_shape=(3, 256, 256), 47 | inception_args=dict(type='pytorch')) 48 | ], 49 | best_metric=['fid', 'is']) 50 | 51 | checkpoint_config = dict(interval=10000, save_optimizer=True, by_epoch=False) 52 | -------------------------------------------------------------------------------- /configs/_base_/models/unpaired_cycler2r/unpaired_cycler2r.py: -------------------------------------------------------------------------------- 1 | _domain_a = 'raw' # set by user 2 | _domain_b = 'rgb' # set by user 3 | model = dict( 4 | type='UnpairedCycleR2R', 5 | lambda_bright_diversity=30.0, 6 | lambda_color_diversity=3.0, 7 | generator=dict( 8 | type='inverseISP'), 9 | discriminator=dict( 10 | type='HistAwareDiscriminator'), 11 | gan_loss=dict( 12 | type='GANLoss', 13 | gan_type='lsgan', 14 | 
real_label_val=1.0, 15 | fake_label_val=0.0, 16 | loss_weight=1.0), 17 | default_domain=_domain_b, 18 | reachable_domains=[_domain_a, _domain_b], 19 | related_domains=[_domain_a, _domain_b], 20 | gen_auxiliary_loss=[ 21 | dict( 22 | type='L1Loss', 23 | loss_weight=10.0, 24 | loss_name='cycle_loss', 25 | data_info=dict( 26 | pred=f'cycle_{_domain_b}', 27 | target=f'real_{_domain_b}', 28 | ), 29 | reduction='mean') 30 | ]) 31 | train_cfg = dict(buffer_size=50) 32 | test_cfg = None 33 | custom_hooks = [ 34 | dict( 35 | type='MMGenVisualizationHook', 36 | output_dir='training_samples', 37 | res_name_list=[f'real_{_domain_a}', f'fake_{_domain_b}', f'cycle_{_domain_a}', f'real_{_domain_b}', 38 | f'fake_{_domain_a}', f'fake_{_domain_a}_diversity', f'cycle_{_domain_b}'], 39 | rerange=False, 40 | bgr2rgb=False, 41 | interval=200) 42 | ] 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # 132 | work_dirs 133 | visual 134 | *pth 135 | simulated_preview.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from collections import OrderedDict 4 | from imageio import imread, imwrite 5 | 6 | import torch 7 | from torch import nn 8 | import torch.distributions as D 9 | 10 | from mmgen.models import build_module 11 | from unpaired_cycler2r.models import * 12 | 13 | 14 | class DemoModel(nn.Module): 15 | 16 | def __init__(self, ckpt_path) -> None: 17 | super().__init__() 18 | # load checkpoint and keep only the generator weights 19 | ckpt = torch.load(ckpt_path, map_location='cpu') 20 | state_dict = OrderedDict() 21 | for k, v in ckpt['state_dict'].items(): 22 | if k.startswith('generator.'): 23 | state_dict[k[len('generator.'):]] = v 24 | 25 | # get invISP model 26 | self.model = build_module(dict(type='inverseISP')) 27 | self.model.load_state_dict(state_dict, strict=False) 28 | self.model.eval() 29 | 30 | def _get_illumination_condition(self, img):  # sample per-image color/brightness conditions from the predicted Gaussians 31 | mean_var = self.model.color_condition_gen(img) 32 | m = D.Normal(mean_var[:, 0], 33 | torch.clamp_min(torch.abs(mean_var[:, 1]), 1e-6)) 34 | color_condition = m.sample() 35 | mean_var = self.model.bright_condition_gen(img) 36 | m = D.Normal(mean_var[:, 0], 37 | torch.clamp_min(torch.abs(mean_var[:, 1]), 1e-6)) 38 | bright_condition = m.sample() 39 | condition = torch.cat( 40 | [color_condition[:, None], bright_condition[:, None]], 1) 41 | return condition 42 | 43 | def _mosaic(self, x):  # pack each 2x2 neighborhood into four half-resolution planes 44 | h, w = x.shape[2:] 45 | _x = torch.zeros(x.shape[0], 4, h // 2, w // 2, device=x.device) 46 | _x[:, 0] = x[:, 0, 0::2, 0::2] 47 | _x[:, 1] = x[:, 0, 0::2, 1::2] 48 | _x[:, 2] = x[:, 0, 1::2, 0::2] 49 | _x[:, 3] = x[:, 0, 1::2, 1::2] 50 | return _x 51 | 52 | def forward(self, rgb, mosaic=False): 53 | with torch.no_grad(): 54 | # get illumination condition 55 | condition = self._get_illumination_condition(rgb) 56 | # get simulated RAW image 57 | raw = self.model(rgb, condition, rev=False) 58 | raw = torch.clamp(raw, 0, 1) 59 | if mosaic: 60 | raw = self._mosaic(raw) 61 | return raw 62 | 63 | 64 | if __name__ == '__main__': 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('--ckpt', type=str) 67 | parser.add_argument('--rgb', type=str) 68 | args = parser.parse_args() 69 | 70 | model = DemoModel(args.ckpt) 71 | img = imread(args.rgb).astype(np.float32) / 255 72 | img = torch.from_numpy(img).permute(2, 0, 1)[None] 73 | 74 | model = model.cuda() 75 | img = img.cuda() 76 | 77 | x = model(img, mosaic=False) 78 | x = x[0].permute(1, 2, 0).cpu().numpy() 79 | x = (x * 255).astype(np.uint8) 80 | imwrite('simulated_preview.png', x) 81 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [T-PAMI
2024] Efficient Visual Computing with Camera RAW Snapshots 2 | 3 | ![](resources/cover.png) 4 | 5 | We propose **$\bf{\rho}$-Vision**, a novel framework that performs **high-level semantic understanding** and **low-level compression** directly on **RAW images**. The framework delivers better detection accuracy and compression than RGB-domain counterparts, generalizes across different camera sensors and task-specific models, and has the potential to **reduce ISP computation and processing time**. 6 | 7 | In this repo, we release our **Unpaired CycleR2R** code and pretrained models. With Unpaired CycleR2R, you can **train your RAW model on diverse, realistic simulated RAW images** and then **deploy it directly in the real world**. 8 | 9 | 10 | ## Requirements 11 | ```bash 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | ## Datasets 16 | (*required*) Download the [MultiRAW](https://box.nju.edu.cn/d/0f4b5206cf734bd889aa/) dataset and its [LUCID subset (passwd: mwdp)](https://pan.baidu.com/s/1x7kGOMEVhPpZYVjlkPIxEw). 17 | 18 | (*optional*) Download the [BDD100K](https://www.bdd100k.com/). 19 | 20 | (*optional*) Download the [Cityscapes](https://www.cityscapes-dataset.com/). 21 | 22 | (*optional*) Download the [Flicker2W](https://github.com/liujiaheng/CompressionData). 23 | 24 | The `datasets` folder should be organized as follows: 25 | ```bash 26 | datasets 27 | ├─multiRAW 28 | │ ├─iphone_xsmax 29 | │ ├─huawei_p30pro 30 | │ ├─asi_294mcpro 31 | │ └─oneplus_5t 32 | ├─(optional) bdd100k 33 | ├─(optional) cityscapes 34 | └─(optional) flicker 35 | ``` 36 | 37 | 38 | ## Pretrained Models 39 | | Source RGB | Target RAW | Model | 40 | | :---: | :----: | :---: | 41 | | BDD100K | iPhone XS Max | [link](https://box.nju.edu.cn/f/d1e199cd5ada49b88ad9/) | 42 | | BDD100K | Huawei P30 Pro | [link](https://box.nju.edu.cn/f/9eddaab558ce4b1e953e/) | 43 | | BDD100K | ASI 294MC Pro | [link](https://box.nju.edu.cn/f/c813bc63ddc14932aba3/) | 44 | | BDD100K | OnePlus 5T | [link](https://box.nju.edu.cn/f/ac70a11515834c43a1ac/) | 45 | | Cityscapes | iPhone XS Max | [link](https://box.nju.edu.cn/f/ec6fc364d90d482f9001/) | 46 | | Flicker2W | iPhone XS Max | [link](https://box.nju.edu.cn/f/acbdf1dded594587aac7/) | 47 | | Flicker2W | Huawei P30 Pro | [link](https://box.nju.edu.cn/f/647e979dc59d4f8c845c/) | 48 | | Flicker2W | ASI 294MC Pro | [link](https://box.nju.edu.cn/f/c26c0e5a84da49f884b7/) | 49 | | Flicker2W | OnePlus 5T | [link](https://box.nju.edu.cn/f/3a223fd40c3d4b548ca8/) | 50 | 51 | 52 | ## Training 53 | ```bash 54 | python train.py configs/unpaired_cycler2r/unpaired_cycler2r_in_bdd100k_rgb2iphone_raw_20k.py 55 | ``` 56 | Checkpoints and logs are written to `./work_dirs/experiments/<exp_name>`, as set by `work_dir` in each config; pass any other config from `configs/unpaired_cycler2r/` to train a different dataset/camera pair. 57 | ## Inference 58 | Please download the [pretrained model](https://box.nju.edu.cn/f/d1e199cd5ada49b88ad9/) first. 59 | 60 | You can run inference from the command line, 61 | ```bash 62 | python inference.py --ckpt bdd100k_rgb_to_iphone_raw.pth --rgb resources/bdd100k.jpg 63 | ``` 64 | 65 | or from your own code: 66 | ```python
import numpy as np
import torch
from imageio import imread
67 | from inference import DemoModel 68 | ckpt_path = 'bdd100k_rgb_to_iphone_raw.pth' 69 | rgb_path = 'resources/bdd100k.jpg' 70 | 71 | model = DemoModel(ckpt_path) 72 | rgb = imread(rgb_path).astype(np.float32) / 255 73 | rgb = torch.from_numpy(rgb).permute(2, 0, 1)[None] 74 | 75 | model = model.cuda() 76 | rgb = rgb.cuda() 77 | 78 | raw = model(rgb, mosaic=False) 79 | ``` 80 | 81 | 
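If you need the packed Bayer planes instead of the demosaiced 3-channel preview, pass `mosaic=True`. A minimal sketch continuing the snippet above (the `.npy` file name is just an example; the output shape follows `DemoModel._mosaic`):
```python
# four half-resolution planes packed from each 2x2 neighborhood: [1, 4, H/2, W/2]
raw_packed = model(rgb, mosaic=True)
# save the packed planes for later processing (example file name, not from the repo)
np.save('simulated_raw_packed.npy', raw_packed[0].cpu().numpy())
```
82 | ## Citation 83 | If you find our dataset and work helpful for your research, please cite our paper. 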
84 | ``` 85 | @article{li2022efficient, 86 |   title={Efficient Visual Computing with Camera RAW Snapshots}, 87 |   author={Zhihao Li and Ming Lu and Xu Zhang and Xin Feng and M. Salman Asif and Zhan Ma}, 88 |   journal={arXiv preprint arXiv:2212.07778}, 89 |   url={https://arxiv.org/pdf/2212.07778.pdf}, 90 |   year={2022}, 91 | } 92 | ``` 93 | -------------------------------------------------------------------------------- /unpaired_cycler2r/models/architectures/unpaired_cycler2r/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class TinyEncoder(nn.Module): 6 | def __init__(self, in_channels=3, out_channels=3, channnels=64, input_size=128): 7 | super().__init__() 8 | self.input_size = input_size 9 | self.net = torch.nn.Sequential( 10 | nn.Conv2d(in_channels, 32, 5, 1, 2), 11 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 12 | nn.AvgPool2d(2, 2), 13 | 14 | nn.Conv2d(32, channnels, 3, 1, 1), 15 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 16 | nn.AvgPool2d(2, 2), 17 | 18 | nn.Conv2d(channnels, channnels, 3, 1, 1), 19 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 20 | nn.AdaptiveAvgPool2d(1), 21 | ) 22 | 23 | self.linear = nn.Sequential( 24 | nn.Linear(channnels, 256), 25 | nn.Linear(256, out_channels) 26 | ) 27 | 28 | def forward(self, x): 29 | x = self.net(x) 30 | x = x.view(x.size(0), -1) 31 | x = self.linear(x) 32 | return x 33 | 34 | 35 | def differentiable_histogram(x, bins=255, min=0.0, max=1.0):  # soft histogram: triangular-kernel binning keeps the counts differentiable w.r.t. x 36 | if len(x.shape) == 4: 37 | n_samples, n_chns, _, _ = x.shape 38 | elif len(x.shape) == 2: 39 | n_samples, n_chns = 1, 1 40 | else: 41 | raise AssertionError('The dimension of input tensor should be 2 or 4.') 42 | 43 | hist_torch = torch.zeros(n_samples, n_chns, bins).to(x.device) 44 | delta = (max - min) / bins 45 | 46 | BIN_Table = torch.arange(start=0, end=bins + 1, step=1) * delta 47 | 48 | for dim in range(1, bins - 1, 1): 49 | h_r = BIN_Table[dim].item() # h_r 50 | h_r_sub_1 = BIN_Table[dim - 1].item() # h_(r-1) 51 | h_r_plus_1 = BIN_Table[dim + 1].item() # h_(r+1) 52 | 53 | mask_sub = ((h_r > x) & (x >= h_r_sub_1)).float() 54 | mask_plus = ((h_r_plus_1 > x) & (x >= h_r)).float() 55 | 56 | hist_torch[:, :, dim] += torch.sum(((x - h_r_sub_1) * mask_sub).view(n_samples, n_chns, -1), dim=-1) 57 | hist_torch[:, :, dim] += torch.sum(((h_r_plus_1 - x) * mask_plus).view(n_samples, n_chns, -1), dim=-1) 58 | 59 | return hist_torch / delta 60 | 61 | 62 | def differentiable_uv_histogram(x, bins=32):  # intensity-weighted log-chroma (u, v) histogram built with the same soft-binning idea 63 | b, c, _, _ = x.shape 64 | assert c == 3 65 | hist_uv_map = torch.zeros(b, 1, bins, bins, device=x.device) 66 | x = torch.clamp(x, 1e-6, 1) 67 | u, v = torch.log(x[:, 1] / x[:, 0]), torch.log(x[:, 1] / x[:, 2]) # [b, h, w] 68 | y = torch.sqrt(torch.pow(x, 2).sum(1)) # [b, h, w] 69 | y, u, v = y.view([b, -1]), u.view([b, -1]), v.view([b, -1]) 70 | eta = 0.025 * 256 / bins 71 | 72 | for u_c in range(1, bins + 1): 73 | for v_c in range(1, bins + 1): 74 | u_sub = (u_c - 0.5) * eta 75 | u_plus = (u_c + 0.5) * eta 76 | v_sub = (v_c - 0.5) * eta 77 | v_plus = (v_c + 0.5) * eta 78 | u_mask_sub = ((u_sub <= u) & (u < u_c)).float().detach() 79 | v_mask_sub = ((v_sub <= v) & (v < v_c)).float().detach() 80 | u_mask_plus = ((u_c <= u) & (u < u_plus)).float().detach() 81 | v_mask_plus = ((v_c <= v) & (v < v_plus)).float().detach() 82 | hist_uv_map[:, 0, u_c - 1, v_c - 1] += torch.sum(y * (u - u_sub) * u_mask_sub, -1) 83 | hist_uv_map[:, 0, u_c - 1, v_c - 1] += torch.sum(y * (v - v_sub) * v_mask_sub, -1) 84 | hist_uv_map[:, 0, u_c - 1, v_c - 1] 
+= torch.sum(y * (u - u_plus) * u_mask_plus, -1) 85 | hist_uv_map[:, 0, u_c - 1, v_c - 1] += torch.sum(y * (v - v_plus) * v_mask_plus, -1) 86 | hist_uv_map = hist_uv_map/hist_uv_map.view([b, -1]).sum(-1).view([b, 1, 1, 1]) 87 | hist_uv_map = torch.sqrt(hist_uv_map) 88 | return hist_uv_map 89 | -------------------------------------------------------------------------------- /configs/_base_/datasets/bdd100k_rgb2asi_raw_512x512.py: -------------------------------------------------------------------------------- 1 | BLC = 0 2 | SATURATION = 16484 3 | CAMERA = 'asi_294mcpro' 4 | H, W = 960, 1280 5 | 6 | dataset_type = 'UnpairedCycleR2RDataset' 7 | domain_a = 'raw' # set by user 8 | domain_b = 'rgb' # set by user 9 | # dataset a setting 10 | dataroot_a = 'datasets/multiRAW/{}/raw'.format(CAMERA) 11 | train_split_a = 'datasets/multiRAW/{}/train.txt'.format(CAMERA) 12 | test_split_a = 'datasets/multiRAW/{}/train.txt'.format(CAMERA) 13 | 14 | # dataset b setting 15 | train_dataroot_b = 'datasets/bdd100k/images/100k/train' 16 | test_dataroot_b = 'datasets/bdd100k/images/100k/test' 17 | split_b = None 18 | 19 | train_pipeline = [ 20 | dict( 21 | type='LoadRAWFromFile', 22 | key=f'img_{domain_a}'), 23 | dict( 24 | type='Demosaic', 25 | key=f'img_{domain_a}'), 26 | dict( 27 | type='RAWNormalize', 28 | blc=BLC, 29 | saturate=SATURATION, 30 | key=f'img_{domain_a}'), 31 | dict( 32 | type='LoadImageFromFile', 33 | io_backend='disk', 34 | key=f'img_{domain_b}', 35 | flag='color'), 36 | dict( 37 | type='Resize', 38 | keys=[f'img_{domain_a}'], 39 | scale=(W, H), 40 | interpolation='bicubic'), 41 | dict( 42 | type='Crop', 43 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 44 | crop_size=(512, 512), 45 | random_crop=True), 46 | dict(type='Flip', keys=[f'img_{domain_a}'], direction='horizontal'), 47 | dict(type='Flip', keys=[f'img_{domain_b}'], direction='horizontal'), 48 | dict(type='RescaleToZeroOne', keys=[f'img_{domain_b}']), 49 | dict( 50 | type='Normalize', 51 | keys=[f'img_{domain_b}'], 52 | to_rgb=True, 53 | mean=[0, 0, 0], 54 | std=[1, 1, 1]), 55 | dict(type='ImageToTensor', keys=[f'img_{domain_a}', f'img_{domain_b}']), 56 | dict( 57 | type='Collect', 58 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 59 | meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path']) 60 | ] 61 | 62 | test_pipeline = [ 63 | dict( 64 | type='LoadRAWFromFile', 65 | key=f'img_{domain_a}'), 66 | dict( 67 | type='Demosaic', 68 | key=f'img_{domain_a}'), 69 | dict( 70 | type='RAWNormalize', 71 | blc=BLC, 72 | saturate=SATURATION, 73 | key=f'img_{domain_a}'), 74 | dict( 75 | type='LoadImageFromFile', 76 | io_backend='disk', 77 | key=f'img_{domain_b}', 78 | flag='color'), 79 | dict( 80 | type='Resize', 81 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 82 | scale=(W, H), 83 | interpolation='bicubic'), 84 | dict(type='RescaleToZeroOne', keys=[f'img_{domain_b}']), 85 | dict( 86 | type='Normalize', 87 | keys=[f'img_{domain_b}'], 88 | to_rgb=True, 89 | mean=[0, 0, 0], 90 | std=[1, 1, 1]), 91 | dict(type='ImageToTensor', keys=[f'img_{domain_a}', f'img_{domain_b}']), 92 | dict( 93 | type='Collect', 94 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 95 | meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path']) 96 | ] 97 | 98 | data = dict( 99 | samples_per_gpu=4, 100 | workers_per_gpu=8, 101 | drop_last=True, 102 | val_samples_per_gpu=1, 103 | val_workers_per_gpu=0, 104 | train=dict( 105 | type=dataset_type, 106 | dataroot_a=dataroot_a, 107 | dataroot_b=train_dataroot_b, 108 | split_a=train_split_a, 109 | split_b=split_b, 110 | 
pipeline=train_pipeline, 111 | domain_a=domain_a, 112 | domain_b=domain_b), 113 | val=dict( 114 | type=dataset_type, 115 | dataroot_a=dataroot_a, 116 | dataroot_b=test_dataroot_b, 117 | split_a=test_split_a, 118 | split_b=split_b, 119 | domain_a=domain_a, 120 | domain_b=domain_b, 121 | pipeline=test_pipeline), 122 | test=dict( 123 | type=dataset_type, 124 | dataroot_a=dataroot_a, 125 | dataroot_b=test_dataroot_b, 126 | split_a=test_split_a, 127 | split_b=split_b, 128 | domain_a=domain_a, 129 | domain_b=domain_b, 130 | pipeline=test_pipeline)) 131 | -------------------------------------------------------------------------------- /configs/_base_/datasets/bdd100k_rgb2oneplus_raw_512x512.py: -------------------------------------------------------------------------------- 1 | BLC = 0 2 | SATURATION = 1024 3 | CAMERA = 'oneplus_5t' 4 | H, W = 960, 1280 5 | 6 | dataset_type = 'UnpairedCycleR2RDataset' 7 | domain_a = 'raw' # set by user 8 | domain_b = 'rgb' # set by user 9 | # dataset a setting 10 | dataroot_a = 'datasets/multiRAW/{}/raw'.format(CAMERA) 11 | train_split_a = 'datasets/multiRAW/{}/test.txt'.format(CAMERA) 12 | test_split_a = 'datasets/multiRAW/{}/test.txt'.format(CAMERA) 13 | 14 | # dataset b setting 15 | train_dataroot_b = 'datasets/bdd100k/images/100k/train' 16 | test_dataroot_b = 'datasets/bdd100k/images/100k/test' 17 | split_b = None 18 | 19 | train_pipeline = [ 20 | dict( 21 | type='LoadRAWFromFile', 22 | key=f'img_{domain_a}'), 23 | dict( 24 | type='Demosaic', 25 | key=f'img_{domain_a}'), 26 | dict( 27 | type='RAWNormalize', 28 | blc=BLC, 29 | saturate=SATURATION, 30 | key=f'img_{domain_a}'), 31 | dict( 32 | type='LoadImageFromFile', 33 | io_backend='disk', 34 | key=f'img_{domain_b}', 35 | flag='color'), 36 | dict( 37 | type='Resize', 38 | keys=[f'img_{domain_a}'], 39 | scale=(W, H), 40 | interpolation='bicubic'), 41 | dict( 42 | type='Crop', 43 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 44 | crop_size=(512, 512), 45 | random_crop=True), 46 | dict(type='Flip', keys=[f'img_{domain_a}'], direction='horizontal'), 47 | dict(type='Flip', keys=[f'img_{domain_b}'], direction='horizontal'), 48 | dict(type='RescaleToZeroOne', keys=[f'img_{domain_b}']), 49 | dict( 50 | type='Normalize', 51 | keys=[f'img_{domain_b}'], 52 | to_rgb=True, 53 | mean=[0, 0, 0], 54 | std=[1, 1, 1]), 55 | dict(type='ImageToTensor', keys=[f'img_{domain_a}', f'img_{domain_b}']), 56 | dict( 57 | type='Collect', 58 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 59 | meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path']) 60 | ] 61 | 62 | test_pipeline = [ 63 | dict( 64 | type='LoadRAWFromFile', 65 | key=f'img_{domain_a}'), 66 | dict( 67 | type='Demosaic', 68 | key=f'img_{domain_a}'), 69 | dict( 70 | type='RAWNormalize', 71 | blc=BLC, 72 | saturate=SATURATION, 73 | key=f'img_{domain_a}'), 74 | dict( 75 | type='LoadImageFromFile', 76 | io_backend='disk', 77 | key=f'img_{domain_b}', 78 | flag='color'), 79 | dict( 80 | type='Resize', 81 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 82 | scale=(W, H), 83 | interpolation='bicubic'), 84 | dict(type='RescaleToZeroOne', keys=[f'img_{domain_b}']), 85 | dict( 86 | type='Normalize', 87 | keys=[f'img_{domain_b}'], 88 | to_rgb=True, 89 | mean=[0, 0, 0], 90 | std=[1, 1, 1]), 91 | dict(type='ImageToTensor', keys=[f'img_{domain_a}', f'img_{domain_b}']), 92 | dict( 93 | type='Collect', 94 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 95 | meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path']) 96 | ] 97 | 98 | data = dict( 99 | samples_per_gpu=4, 100 | 
workers_per_gpu=8, 101 | drop_last=True, 102 | val_samples_per_gpu=1, 103 | val_workers_per_gpu=0, 104 | train=dict( 105 | type=dataset_type, 106 | dataroot_a=dataroot_a, 107 | dataroot_b=train_dataroot_b, 108 | split_a=train_split_a, 109 | split_b=split_b, 110 | pipeline=train_pipeline, 111 | domain_a=domain_a, 112 | domain_b=domain_b), 113 | val=dict( 114 | type=dataset_type, 115 | dataroot_a=dataroot_a, 116 | dataroot_b=test_dataroot_b, 117 | split_a=test_split_a, 118 | split_b=split_b, 119 | domain_a=domain_a, 120 | domain_b=domain_b, 121 | pipeline=test_pipeline), 122 | test=dict( 123 | type=dataset_type, 124 | dataroot_a=dataroot_a, 125 | dataroot_b=test_dataroot_b, 126 | split_a=test_split_a, 127 | split_b=split_b, 128 | domain_a=domain_a, 129 | domain_b=domain_b, 130 | pipeline=test_pipeline)) 131 | -------------------------------------------------------------------------------- /configs/_base_/datasets/flicker2w_rgb2asi_raw_512x512.py: -------------------------------------------------------------------------------- 1 | BLC = 0 2 | SATURATION = 16484 3 | CAMERA = 'asi_294mcpro' 4 | H, W = 960, 1280 5 | 6 | dataset_type = 'UnpairedCycleR2RDataset' 7 | domain_a = 'raw' # set by user 8 | domain_b = 'rgb' # set by user 9 | # dataset a setting 10 | dataroot_a = 'datasets/multiRAW/{}/raw'.format(CAMERA) 11 | train_split_a = 'datasets/multiRAW/{}/train.txt'.format(CAMERA) 12 | test_split_a = 'datasets/multiRAW/{}/train.txt'.format(CAMERA) 13 | 14 | # dataset b setting 15 | train_dataroot_b = 'datasets/flicker/train' 16 | test_dataroot_b = 'datasets/flicker/train' 17 | split_b = None 18 | 19 | train_pipeline = [ 20 | dict( 21 | type='LoadRAWFromFile', 22 | key=f'img_{domain_a}'), 23 | dict( 24 | type='Demosaic', 25 | key=f'img_{domain_a}'), 26 | dict( 27 | type='RAWNormalize', 28 | blc=BLC, 29 | saturate=SATURATION, 30 | key=f'img_{domain_a}'), 31 | dict( 32 | type='LoadImageFromFile', 33 | io_backend='disk', 34 | key=f'img_{domain_b}', 35 | flag='color'), 36 | dict( 37 | type='Resize', 38 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 39 | scale=(W, H), 40 | interpolation='bicubic'), 41 | dict( 42 | type='Crop', 43 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 44 | crop_size=(512, 512), 45 | random_crop=True), 46 | dict(type='Flip', keys=[f'img_{domain_a}'], direction='horizontal'), 47 | dict(type='Flip', keys=[f'img_{domain_b}'], direction='horizontal'), 48 | dict(type='RescaleToZeroOne', keys=[f'img_{domain_b}']), 49 | dict( 50 | type='Normalize', 51 | keys=[f'img_{domain_b}'], 52 | to_rgb=True, 53 | mean=[0, 0, 0], 54 | std=[1, 1, 1]), 55 | dict(type='ImageToTensor', keys=[f'img_{domain_a}', f'img_{domain_b}']), 56 | dict( 57 | type='Collect', 58 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 59 | meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path']) 60 | ] 61 | 62 | test_pipeline = [ 63 | dict( 64 | type='LoadRAWFromFile', 65 | key=f'img_{domain_a}'), 66 | dict( 67 | type='Demosaic', 68 | key=f'img_{domain_a}'), 69 | dict( 70 | type='RAWNormalize', 71 | blc=BLC, 72 | saturate=SATURATION, 73 | key=f'img_{domain_a}'), 74 | dict( 75 | type='LoadImageFromFile', 76 | io_backend='disk', 77 | key=f'img_{domain_b}', 78 | flag='color'), 79 | dict( 80 | type='Resize', 81 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 82 | scale=(W, H), 83 | interpolation='bicubic'), 84 | dict(type='RescaleToZeroOne', keys=[f'img_{domain_b}']), 85 | dict( 86 | type='Normalize', 87 | keys=[f'img_{domain_b}'], 88 | to_rgb=True, 89 | mean=[0, 0, 0], 90 | std=[1, 1, 1]), 91 | dict(type='ImageToTensor', 
keys=[f'img_{domain_a}', f'img_{domain_b}']), 92 | dict( 93 | type='Collect', 94 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 95 | meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path']) 96 | ] 97 | 98 | data = dict( 99 | samples_per_gpu=4, 100 | workers_per_gpu=8, 101 | drop_last=True, 102 | val_samples_per_gpu=1, 103 | val_workers_per_gpu=0, 104 | train=dict( 105 | type=dataset_type, 106 | dataroot_a=dataroot_a, 107 | dataroot_b=train_dataroot_b, 108 | split_a=train_split_a, 109 | split_b=split_b, 110 | pipeline=train_pipeline, 111 | domain_a=domain_a, 112 | domain_b=domain_b), 113 | val=dict( 114 | type=dataset_type, 115 | dataroot_a=dataroot_a, 116 | dataroot_b=test_dataroot_b, 117 | split_a=test_split_a, 118 | split_b=split_b, 119 | domain_a=domain_a, 120 | domain_b=domain_b, 121 | pipeline=test_pipeline), 122 | test=dict( 123 | type=dataset_type, 124 | dataroot_a=dataroot_a, 125 | dataroot_b=test_dataroot_b, 126 | split_a=test_split_a, 127 | split_b=split_b, 128 | domain_a=domain_a, 129 | domain_b=domain_b, 130 | pipeline=test_pipeline)) 131 | -------------------------------------------------------------------------------- /configs/_base_/datasets/flicker2w_rgb2oneplus_raw_512x512.py: -------------------------------------------------------------------------------- 1 | BLC = 0 2 | SATURATION = 1024 3 | CAMERA = 'oneplus_5t' 4 | H, W = 960, 1280 5 | 6 | dataset_type = 'UnpairedCycleR2RDataset' 7 | domain_a = 'raw' # set by user 8 | domain_b = 'rgb' # set by user 9 | # dataset a setting 10 | dataroot_a = 'datasets/multiRAW/{}/raw'.format(CAMERA) 11 | train_split_a = 'datasets/multiRAW/{}/test.txt'.format(CAMERA) 12 | test_split_a = 'datasets/multiRAW/{}/test.txt'.format(CAMERA) 13 | 14 | # dataset b setting 15 | train_dataroot_b = 'datasets/flicker/train' 16 | test_dataroot_b = 'datasets/flicker/train' 17 | split_b = None 18 | 19 | train_pipeline = [ 20 | dict( 21 | type='LoadRAWFromFile', 22 | key=f'img_{domain_a}'), 23 | dict( 24 | type='Demosaic', 25 | key=f'img_{domain_a}'), 26 | dict( 27 | type='RAWNormalize', 28 | blc=BLC, 29 | saturate=SATURATION, 30 | key=f'img_{domain_a}'), 31 | dict( 32 | type='LoadImageFromFile', 33 | io_backend='disk', 34 | key=f'img_{domain_b}', 35 | flag='color'), 36 | dict( 37 | type='Resize', 38 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 39 | scale=(W, H), 40 | interpolation='bicubic'), 41 | dict( 42 | type='Crop', 43 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 44 | crop_size=(512, 512), 45 | random_crop=True), 46 | dict(type='Flip', keys=[f'img_{domain_a}'], direction='horizontal'), 47 | dict(type='Flip', keys=[f'img_{domain_b}'], direction='horizontal'), 48 | dict(type='RescaleToZeroOne', keys=[f'img_{domain_b}']), 49 | dict( 50 | type='Normalize', 51 | keys=[f'img_{domain_b}'], 52 | to_rgb=True, 53 | mean=[0, 0, 0], 54 | std=[1, 1, 1]), 55 | dict(type='ImageToTensor', keys=[f'img_{domain_a}', f'img_{domain_b}']), 56 | dict( 57 | type='Collect', 58 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 59 | meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path']) 60 | ] 61 | 62 | test_pipeline = [ 63 | dict( 64 | type='LoadRAWFromFile', 65 | key=f'img_{domain_a}'), 66 | dict( 67 | type='Demosaic', 68 | key=f'img_{domain_a}'), 69 | dict( 70 | type='RAWNormalize', 71 | blc=BLC, 72 | saturate=SATURATION, 73 | key=f'img_{domain_a}'), 74 | dict( 75 | type='LoadImageFromFile', 76 | io_backend='disk', 77 | key=f'img_{domain_b}', 78 | flag='color'), 79 | dict( 80 | type='Resize', 81 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 82 | scale=(W, H), 83 | 
interpolation='bicubic'), 84 | dict(type='RescaleToZeroOne', keys=[f'img_{domain_b}']), 85 | dict( 86 | type='Normalize', 87 | keys=[f'img_{domain_b}'], 88 | to_rgb=True, 89 | mean=[0, 0, 0], 90 | std=[1, 1, 1]), 91 | dict(type='ImageToTensor', keys=[f'img_{domain_a}', f'img_{domain_b}']), 92 | dict( 93 | type='Collect', 94 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 95 | meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path']) 96 | ] 97 | 98 | data = dict( 99 | samples_per_gpu=4, 100 | workers_per_gpu=8, 101 | drop_last=True, 102 | val_samples_per_gpu=1, 103 | val_workers_per_gpu=0, 104 | train=dict( 105 | type=dataset_type, 106 | dataroot_a=dataroot_a, 107 | dataroot_b=train_dataroot_b, 108 | split_a=train_split_a, 109 | split_b=split_b, 110 | pipeline=train_pipeline, 111 | domain_a=domain_a, 112 | domain_b=domain_b), 113 | val=dict( 114 | type=dataset_type, 115 | dataroot_a=dataroot_a, 116 | dataroot_b=test_dataroot_b, 117 | split_a=test_split_a, 118 | split_b=split_b, 119 | domain_a=domain_a, 120 | domain_b=domain_b, 121 | pipeline=test_pipeline), 122 | test=dict( 123 | type=dataset_type, 124 | dataroot_a=dataroot_a, 125 | dataroot_b=test_dataroot_b, 126 | split_a=test_split_a, 127 | split_b=split_b, 128 | domain_a=domain_a, 129 | domain_b=domain_b, 130 | pipeline=test_pipeline)) 131 | -------------------------------------------------------------------------------- /configs/_base_/datasets/bdd100k_rgb2huawei_raw_512x512.py: -------------------------------------------------------------------------------- 1 | BLC = 256 2 | SATURATION = 4095 3 | CAMERA = 'huawei_p30pro' 4 | H, W = 960, 1280 5 | 6 | dataset_type = 'UnpairedCycleR2RDataset' 7 | domain_a = 'raw' # set by user 8 | domain_b = 'rgb' # set by user 9 | # dataset a setting 10 | dataroot_a = 'datasets/multiRAW/{}/raw'.format(CAMERA) 11 | train_split_a = 'datasets/multiRAW/{}/train.txt'.format(CAMERA) 12 | test_split_a = 'datasets/multiRAW/{}/train.txt'.format(CAMERA) 13 | 14 | # dataset b setting 15 | train_dataroot_b = 'datasets/bdd100k/images/100k/train' 16 | test_dataroot_b = 'datasets/bdd100k/images/100k/test' 17 | split_b = None 18 | 19 | train_pipeline = [ 20 | dict( 21 | type='LoadRAWFromFile', 22 | key=f'img_{domain_a}'), 23 | dict( 24 | type='Demosaic', 25 | key=f'img_{domain_a}'), 26 | dict( 27 | type='RAWNormalize', 28 | blc=BLC, 29 | saturate=SATURATION, 30 | key=f'img_{domain_a}'), 31 | dict( 32 | type='LoadImageFromFile', 33 | io_backend='disk', 34 | key=f'img_{domain_b}', 35 | flag='color'), 36 | dict( 37 | type='Resize', 38 | keys=[f'img_{domain_a}'], 39 | scale=(W, H), 40 | interpolation='bicubic'), 41 | dict( 42 | type='Crop', 43 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 44 | crop_size=(512, 512), 45 | random_crop=True), 46 | dict(type='Flip', keys=[f'img_{domain_a}'], direction='horizontal'), 47 | dict(type='Flip', keys=[f'img_{domain_b}'], direction='horizontal'), 48 | dict(type='RescaleToZeroOne', keys=[f'img_{domain_b}']), 49 | dict( 50 | type='Normalize', 51 | keys=[f'img_{domain_b}'], 52 | to_rgb=True, 53 | mean=[0, 0, 0], 54 | std=[1, 1, 1]), 55 | dict(type='ImageToTensor', keys=[f'img_{domain_a}', f'img_{domain_b}']), 56 | dict( 57 | type='Collect', 58 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 59 | meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path']) 60 | ] 61 | 62 | test_pipeline = [ 63 | dict( 64 | type='LoadRAWFromFile', 65 | key=f'img_{domain_a}'), 66 | dict( 67 | type='Demosaic', 68 | key=f'img_{domain_a}'), 69 | dict( 70 | type='RAWNormalize', 71 | blc=BLC, 72 | 
saturate=SATURATION, 73 | key=f'img_{domain_a}'), 74 | dict( 75 | type='LoadImageFromFile', 76 | io_backend='disk', 77 | key=f'img_{domain_b}', 78 | flag='color'), 79 | dict( 80 | type='Resize', 81 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 82 | scale=(W, H), 83 | interpolation='bicubic'), 84 | dict(type='RescaleToZeroOne', keys=[f'img_{domain_b}']), 85 | dict( 86 | type='Normalize', 87 | keys=[f'img_{domain_b}'], 88 | to_rgb=True, 89 | mean=[0, 0, 0], 90 | std=[1, 1, 1]), 91 | dict(type='ImageToTensor', keys=[f'img_{domain_a}', f'img_{domain_b}']), 92 | dict( 93 | type='Collect', 94 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 95 | meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path']) 96 | ] 97 | 98 | data = dict( 99 | samples_per_gpu=4, 100 | workers_per_gpu=8, 101 | drop_last=True, 102 | val_samples_per_gpu=1, 103 | val_workers_per_gpu=0, 104 | train=dict( 105 | type=dataset_type, 106 | dataroot_a=dataroot_a, 107 | dataroot_b=train_dataroot_b, 108 | split_a=train_split_a, 109 | split_b=split_b, 110 | pipeline=train_pipeline, 111 | domain_a=domain_a, 112 | domain_b=domain_b), 113 | val=dict( 114 | type=dataset_type, 115 | dataroot_a=dataroot_a, 116 | dataroot_b=test_dataroot_b, 117 | split_a=test_split_a, 118 | split_b=split_b, 119 | domain_a=domain_a, 120 | domain_b=domain_b, 121 | pipeline=test_pipeline), 122 | test=dict( 123 | type=dataset_type, 124 | dataroot_a=dataroot_a, 125 | dataroot_b=test_dataroot_b, 126 | split_a=test_split_a, 127 | split_b=split_b, 128 | domain_a=domain_a, 129 | domain_b=domain_b, 130 | pipeline=test_pipeline)) 131 | -------------------------------------------------------------------------------- /configs/_base_/datasets/bdd100k_rgb2iphone_raw_512x512.py: -------------------------------------------------------------------------------- 1 | BLC = 528 2 | SATURATION = 4095 3 | CAMERA = 'iphone_xsmax' 4 | H, W = 960, 1280 5 | 6 | dataset_type = 'UnpairedCycleR2RDataset' 7 | domain_a = 'raw' # set by user 8 | domain_b = 'rgb' # set by user 9 | # dataset a setting 10 | dataroot_a = 'datasets/multiRAW/{}/raw'.format(CAMERA) 11 | train_split_a = 'datasets/multiRAW/{}/train.txt'.format(CAMERA) 12 | test_split_a = 'datasets/multiRAW/{}/train.txt'.format(CAMERA) 13 | 14 | # dataset b setting 15 | train_dataroot_b = 'datasets/bdd100k/images/100k/train' 16 | test_dataroot_b = 'datasets/bdd100k/images/100k/test' 17 | split_b = None 18 | 19 | train_pipeline = [ 20 | dict( 21 | type='LoadRAWFromFile', 22 | key=f'img_{domain_a}'), 23 | dict( 24 | type='Demosaic', 25 | key=f'img_{domain_a}'), 26 | dict( 27 | type='RAWNormalize', 28 | blc=BLC, 29 | saturate=SATURATION, 30 | key=f'img_{domain_a}'), 31 | dict( 32 | type='LoadImageFromFile', 33 | io_backend='disk', 34 | key=f'img_{domain_b}', 35 | flag='color'), 36 | dict( 37 | type='Resize', 38 | keys=[f'img_{domain_a}'], 39 | scale=(W, H), 40 | interpolation='bicubic'), 41 | dict( 42 | type='Crop', 43 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 44 | crop_size=(512, 512), 45 | random_crop=True), 46 | dict(type='Flip', keys=[f'img_{domain_a}'], direction='horizontal'), 47 | dict(type='Flip', keys=[f'img_{domain_b}'], direction='horizontal'), 48 | dict(type='RescaleToZeroOne', keys=[f'img_{domain_b}']), 49 | dict( 50 | type='Normalize', 51 | keys=[f'img_{domain_b}'], 52 | to_rgb=True, 53 | mean=[0, 0, 0], 54 | std=[1, 1, 1]), 55 | dict(type='ImageToTensor', keys=[f'img_{domain_a}', f'img_{domain_b}']), 56 | dict( 57 | type='Collect', 58 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 59 | 
meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path']) 60 | ] 61 | 62 | test_pipeline = [ 63 | dict( 64 | type='LoadRAWFromFile', 65 | key=f'img_{domain_a}'), 66 | dict( 67 | type='Demosaic', 68 | key=f'img_{domain_a}'), 69 | dict( 70 | type='RAWNormalize', 71 | blc=BLC, 72 | saturate=SATURATION, 73 | key=f'img_{domain_a}'), 74 | dict( 75 | type='LoadImageFromFile', 76 | io_backend='disk', 77 | key=f'img_{domain_b}', 78 | flag='color'), 79 | dict( 80 | type='Resize', 81 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 82 | scale=(W, H), 83 | interpolation='bicubic'), 84 | dict(type='RescaleToZeroOne', keys=[f'img_{domain_b}']), 85 | dict( 86 | type='Normalize', 87 | keys=[f'img_{domain_b}'], 88 | to_rgb=True, 89 | mean=[0, 0, 0], 90 | std=[1, 1, 1]), 91 | dict(type='ImageToTensor', keys=[f'img_{domain_a}', f'img_{domain_b}']), 92 | dict( 93 | type='Collect', 94 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 95 | meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path']) 96 | ] 97 | 98 | data = dict( 99 | samples_per_gpu=4, 100 | workers_per_gpu=8, 101 | drop_last=True, 102 | val_samples_per_gpu=1, 103 | val_workers_per_gpu=0, 104 | train=dict( 105 | type=dataset_type, 106 | dataroot_a=dataroot_a, 107 | dataroot_b=train_dataroot_b, 108 | split_a=train_split_a, 109 | split_b=split_b, 110 | pipeline=train_pipeline, 111 | domain_a=domain_a, 112 | domain_b=domain_b), 113 | val=dict( 114 | type=dataset_type, 115 | dataroot_a=dataroot_a, 116 | dataroot_b=test_dataroot_b, 117 | split_a=test_split_a, 118 | split_b=split_b, 119 | domain_a=domain_a, 120 | domain_b=domain_b, 121 | pipeline=test_pipeline), 122 | test=dict( 123 | type=dataset_type, 124 | dataroot_a=dataroot_a, 125 | dataroot_b=test_dataroot_b, 126 | split_a=test_split_a, 127 | split_b=split_b, 128 | domain_a=domain_a, 129 | domain_b=domain_b, 130 | pipeline=test_pipeline)) 131 | -------------------------------------------------------------------------------- /configs/_base_/datasets/flicker2w_rgb2huawei_raw_512x512.py: -------------------------------------------------------------------------------- 1 | BLC = 256 2 | SATURATION = 4095 3 | CAMERA = 'huawei_p30pro' 4 | H, W = 960, 1280 5 | 6 | dataset_type = 'UnpairedCycleR2RDataset' 7 | domain_a = 'raw' # set by user 8 | domain_b = 'rgb' # set by user 9 | # dataset a setting 10 | dataroot_a = 'datasets/multiRAW/{}/raw'.format(CAMERA) 11 | train_split_a = 'datasets/multiRAW/{}/train.txt'.format(CAMERA) 12 | test_split_a = 'datasets/multiRAW/{}/train.txt'.format(CAMERA) 13 | 14 | # dataset b setting 15 | train_dataroot_b = 'datasets/flicker/train' 16 | test_dataroot_b = 'datasets/flicker/train' 17 | split_b = None 18 | 19 | train_pipeline = [ 20 | dict( 21 | type='LoadRAWFromFile', 22 | key=f'img_{domain_a}'), 23 | dict( 24 | type='Demosaic', 25 | key=f'img_{domain_a}'), 26 | dict( 27 | type='RAWNormalize', 28 | blc=BLC, 29 | saturate=SATURATION, 30 | key=f'img_{domain_a}'), 31 | dict( 32 | type='LoadImageFromFile', 33 | io_backend='disk', 34 | key=f'img_{domain_b}', 35 | flag='color'), 36 | dict( 37 | type='Resize', 38 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 39 | scale=(W, H), 40 | interpolation='bicubic'), 41 | dict( 42 | type='Crop', 43 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 44 | crop_size=(512, 512), 45 | random_crop=True), 46 | dict(type='Flip', keys=[f'img_{domain_a}'], direction='horizontal'), 47 | dict(type='Flip', keys=[f'img_{domain_b}'], direction='horizontal'), 48 | dict(type='RescaleToZeroOne', keys=[f'img_{domain_b}']), 49 | dict( 50 | type='Normalize', 51 | 
keys=[f'img_{domain_b}'], 52 | to_rgb=True, 53 | mean=[0, 0, 0], 54 | std=[1, 1, 1]), 55 | dict(type='ImageToTensor', keys=[f'img_{domain_a}', f'img_{domain_b}']), 56 | dict( 57 | type='Collect', 58 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 59 | meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path']) 60 | ] 61 | 62 | test_pipeline = [ 63 | dict( 64 | type='LoadRAWFromFile', 65 | key=f'img_{domain_a}'), 66 | dict( 67 | type='Demosaic', 68 | key=f'img_{domain_a}'), 69 | dict( 70 | type='RAWNormalize', 71 | blc=BLC, 72 | saturate=SATURATION, 73 | key=f'img_{domain_a}'), 74 | dict( 75 | type='LoadImageFromFile', 76 | io_backend='disk', 77 | key=f'img_{domain_b}', 78 | flag='color'), 79 | dict( 80 | type='Resize', 81 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 82 | scale=(W, H), 83 | interpolation='bicubic'), 84 | dict(type='RescaleToZeroOne', keys=[f'img_{domain_b}']), 85 | dict( 86 | type='Normalize', 87 | keys=[f'img_{domain_b}'], 88 | to_rgb=True, 89 | mean=[0, 0, 0], 90 | std=[1, 1, 1]), 91 | dict(type='ImageToTensor', keys=[f'img_{domain_a}', f'img_{domain_b}']), 92 | dict( 93 | type='Collect', 94 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 95 | meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path']) 96 | ] 97 | 98 | data = dict( 99 | samples_per_gpu=4, 100 | workers_per_gpu=8, 101 | drop_last=True, 102 | val_samples_per_gpu=1, 103 | val_workers_per_gpu=0, 104 | train=dict( 105 | type=dataset_type, 106 | dataroot_a=dataroot_a, 107 | dataroot_b=train_dataroot_b, 108 | split_a=train_split_a, 109 | split_b=split_b, 110 | pipeline=train_pipeline, 111 | domain_a=domain_a, 112 | domain_b=domain_b), 113 | val=dict( 114 | type=dataset_type, 115 | dataroot_a=dataroot_a, 116 | dataroot_b=test_dataroot_b, 117 | split_a=test_split_a, 118 | split_b=split_b, 119 | domain_a=domain_a, 120 | domain_b=domain_b, 121 | pipeline=test_pipeline), 122 | test=dict( 123 | type=dataset_type, 124 | dataroot_a=dataroot_a, 125 | dataroot_b=test_dataroot_b, 126 | split_a=test_split_a, 127 | split_b=split_b, 128 | domain_a=domain_a, 129 | domain_b=domain_b, 130 | pipeline=test_pipeline)) 131 | -------------------------------------------------------------------------------- /configs/_base_/datasets/flicker2w_rgb2iphone_raw_512x512.py: -------------------------------------------------------------------------------- 1 | BLC = 528 2 | SATURATION = 4095 3 | CAMERA = 'iphone_xsmax' 4 | H, W = 960, 1280 5 | 6 | dataset_type = 'UnpairedCycleR2RDataset' 7 | domain_a = 'raw' # set by user 8 | domain_b = 'rgb' # set by user 9 | # dataset a setting 10 | dataroot_a = 'datasets/multiRAW/{}/raw'.format(CAMERA) 11 | train_split_a = 'datasets/multiRAW/{}/train.txt'.format(CAMERA) 12 | test_split_a = 'datasets/multiRAW/{}/train.txt'.format(CAMERA) 13 | 14 | # dataset b setting 15 | train_dataroot_b = 'datasets/flicker/train' 16 | test_dataroot_b = 'datasets/flicker/train' 17 | split_b = None 18 | 19 | train_pipeline = [ 20 | dict( 21 | type='LoadRAWFromFile', 22 | key=f'img_{domain_a}'), 23 | dict( 24 | type='Demosaic', 25 | key=f'img_{domain_a}'), 26 | dict( 27 | type='RAWNormalize', 28 | blc=BLC, 29 | saturate=SATURATION, 30 | key=f'img_{domain_a}'), 31 | dict( 32 | type='LoadImageFromFile', 33 | io_backend='disk', 34 | key=f'img_{domain_b}', 35 | flag='color'), 36 | dict( 37 | type='Resize', 38 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 39 | scale=(W, H), 40 | interpolation='bicubic'), 41 | dict( 42 | type='Crop', 43 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 44 | crop_size=(512, 512), 45 | 
random_crop=True), 46 | dict(type='Flip', keys=[f'img_{domain_a}'], direction='horizontal'), 47 | dict(type='Flip', keys=[f'img_{domain_b}'], direction='horizontal'), 48 | dict(type='RescaleToZeroOne', keys=[f'img_{domain_b}']), 49 | dict( 50 | type='Normalize', 51 | keys=[f'img_{domain_b}'], 52 | to_rgb=True, 53 | mean=[0, 0, 0], 54 | std=[1, 1, 1]), 55 | dict(type='ImageToTensor', keys=[f'img_{domain_a}', f'img_{domain_b}']), 56 | dict( 57 | type='Collect', 58 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 59 | meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path']) 60 | ] 61 | 62 | test_pipeline = [ 63 | dict( 64 | type='LoadRAWFromFile', 65 | key=f'img_{domain_a}'), 66 | dict( 67 | type='Demosaic', 68 | key=f'img_{domain_a}'), 69 | dict( 70 | type='RAWNormalize', 71 | blc=BLC, 72 | saturate=SATURATION, 73 | key=f'img_{domain_a}'), 74 | dict( 75 | type='LoadImageFromFile', 76 | io_backend='disk', 77 | key=f'img_{domain_b}', 78 | flag='color'), 79 | dict( 80 | type='Resize', 81 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 82 | scale=(W, H), 83 | interpolation='bicubic'), 84 | dict(type='RescaleToZeroOne', keys=[f'img_{domain_b}']), 85 | dict( 86 | type='Normalize', 87 | keys=[f'img_{domain_b}'], 88 | to_rgb=True, 89 | mean=[0, 0, 0], 90 | std=[1, 1, 1]), 91 | dict(type='ImageToTensor', keys=[f'img_{domain_a}', f'img_{domain_b}']), 92 | dict( 93 | type='Collect', 94 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 95 | meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path']) 96 | ] 97 | 98 | data = dict( 99 | samples_per_gpu=4, 100 | workers_per_gpu=8, 101 | drop_last=True, 102 | val_samples_per_gpu=1, 103 | val_workers_per_gpu=0, 104 | train=dict( 105 | type=dataset_type, 106 | dataroot_a=dataroot_a, 107 | dataroot_b=train_dataroot_b, 108 | split_a=train_split_a, 109 | split_b=split_b, 110 | pipeline=train_pipeline, 111 | domain_a=domain_a, 112 | domain_b=domain_b), 113 | val=dict( 114 | type=dataset_type, 115 | dataroot_a=dataroot_a, 116 | dataroot_b=test_dataroot_b, 117 | split_a=test_split_a, 118 | split_b=split_b, 119 | domain_a=domain_a, 120 | domain_b=domain_b, 121 | pipeline=test_pipeline), 122 | test=dict( 123 | type=dataset_type, 124 | dataroot_a=dataroot_a, 125 | dataroot_b=test_dataroot_b, 126 | split_a=test_split_a, 127 | split_b=split_b, 128 | domain_a=domain_a, 129 | domain_b=domain_b, 130 | pipeline=test_pipeline)) 131 | -------------------------------------------------------------------------------- /configs/_base_/datasets/cityscapes_rgb2iphone_raw_512x512.py: -------------------------------------------------------------------------------- 1 | BLC = 528 2 | SATURATION = 4095 3 | CAMERA = 'iphone_xsmax' 4 | H, W = 1024, 2048 5 | 6 | dataset_type = 'UnpairedCycleR2RDataset' 7 | domain_a = 'raw' # set by user 8 | domain_b = 'rgb' # set by user 9 | # dataset a setting 10 | dataroot_a = 'datasets/multiRAW/{}/raw'.format(CAMERA) 11 | train_split_a = 'datasets/multiRAW/{}/train.txt'.format(CAMERA) 12 | test_split_a = 'datasets/multiRAW/{}/train.txt'.format(CAMERA) 13 | 14 | # dataset b setting 15 | train_dataroot_b = 'datasets/cityscapes/leftImg8bit/train' 16 | test_dataroot_b = 'datasets/cityscapes/leftImg8bit/test' 17 | split_b = None 18 | 19 | train_pipeline = [ 20 | dict( 21 | type='LoadRAWFromFile', 22 | key=f'img_{domain_a}'), 23 | dict( 24 | type='Demosaic', 25 | key=f'img_{domain_a}'), 26 | dict( 27 | type='RAWNormalize', 28 | blc=BLC, 29 | saturate=SATURATION, 30 | key=f'img_{domain_a}'), 31 | dict( 32 | type='LoadImageFromFile', 33 | io_backend='disk', 34 | 
key=f'img_{domain_b}', 35 | flag='color'), 36 | dict( 37 | type='Resize', 38 | keys=[f'img_{domain_a}'], 39 | scale=(W, H), 40 | interpolation='bicubic'), 41 | dict( 42 | type='Crop', 43 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 44 | crop_size=(512, 512), 45 | random_crop=True), 46 | dict(type='Flip', keys=[f'img_{domain_a}'], direction='horizontal'), 47 | dict(type='Flip', keys=[f'img_{domain_b}'], direction='horizontal'), 48 | dict(type='RescaleToZeroOne', keys=[f'img_{domain_b}']), 49 | dict( 50 | type='Normalize', 51 | keys=[f'img_{domain_b}'], 52 | to_rgb=True, 53 | mean=[0, 0, 0], 54 | std=[1, 1, 1]), 55 | dict(type='ImageToTensor', keys=[f'img_{domain_a}', f'img_{domain_b}']), 56 | dict( 57 | type='Collect', 58 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 59 | meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path']) 60 | ] 61 | 62 | test_pipeline = [ 63 | dict( 64 | type='LoadRAWFromFile', 65 | key=f'img_{domain_a}'), 66 | dict( 67 | type='Demosaic', 68 | key=f'img_{domain_a}'), 69 | dict( 70 | type='RAWNormalize', 71 | blc=BLC, 72 | saturate=SATURATION, 73 | key=f'img_{domain_a}'), 74 | dict( 75 | type='LoadImageFromFile', 76 | io_backend='disk', 77 | key=f'img_{domain_b}', 78 | flag='color'), 79 | dict( 80 | type='Resize', 81 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 82 | scale=(W, H), 83 | interpolation='bicubic'), 84 | dict(type='RescaleToZeroOne', keys=[f'img_{domain_b}']), 85 | dict( 86 | type='Normalize', 87 | keys=[f'img_{domain_b}'], 88 | to_rgb=True, 89 | mean=[0, 0, 0], 90 | std=[1, 1, 1]), 91 | dict(type='ImageToTensor', keys=[f'img_{domain_a}', f'img_{domain_b}']), 92 | dict( 93 | type='Collect', 94 | keys=[f'img_{domain_a}', f'img_{domain_b}'], 95 | meta_keys=[f'img_{domain_a}_path', f'img_{domain_b}_path']) 96 | ] 97 | 98 | data = dict( 99 | samples_per_gpu=4, 100 | workers_per_gpu=8, 101 | drop_last=True, 102 | val_samples_per_gpu=1, 103 | val_workers_per_gpu=0, 104 | train=dict( 105 | type=dataset_type, 106 | dataroot_a=dataroot_a, 107 | dataroot_b=train_dataroot_b, 108 | split_a=train_split_a, 109 | split_b=split_b, 110 | pipeline=train_pipeline, 111 | domain_a=domain_a, 112 | domain_b=domain_b), 113 | val=dict( 114 | type=dataset_type, 115 | dataroot_a=dataroot_a, 116 | dataroot_b=test_dataroot_b, 117 | split_a=test_split_a, 118 | split_b=split_b, 119 | domain_a=domain_a, 120 | domain_b=domain_b, 121 | pipeline=test_pipeline), 122 | test=dict( 123 | type=dataset_type, 124 | dataroot_a=dataroot_a, 125 | dataroot_b=test_dataroot_b, 126 | split_a=test_split_a, 127 | split_b=split_b, 128 | domain_a=domain_a, 129 | domain_b=domain_b, 130 | pipeline=test_pipeline)) 131 | -------------------------------------------------------------------------------- /unpaired_cycler2r/datasets/pipelines/raw.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from imageio import imread 3 | 4 | import numpy as np 5 | import rawpy 6 | 7 | from mmgen.datasets.builder import PIPELINES 8 | 9 | 10 | 11 | @PIPELINES.register_module() 12 | class LoadRAWFromFile: 13 | """Load image from file. 14 | 15 | Args: 16 | key (str): Keys in results to find corresponding path. Default: 'gt'. 17 | kwargs (dict): Args for file client. 18 | """ 19 | 20 | def __init__(self, 21 | key='gt'): 22 | self.key = key 23 | 24 | def __call__(self, results): 25 | """Call function. 26 | 27 | Args: 28 | results (dict): A dict containing the necessary information and 29 | data for augmentation. 
30 | 31 | Returns: 32 | dict: A dict containing the processed data and information. 33 | """ 34 | filepath = str(results[f'{self.key}_path']) 35 | if filepath.endswith('TIF'): 36 | img = imread(filepath).astype(np.float32) 37 | else: 38 | with rawpy.imread(filepath) as f: 39 | img = f.raw_image_visible.copy().astype(np.float32) 40 | 41 | results[self.key] = img 42 | results[f'{self.key}_path'] = filepath 43 | results[f'{self.key}_ori_shape'] = img.shape 44 | return results 45 | 46 | def __repr__(self): 47 | repr_str = self.__class__.__name__ 48 | repr_str += f'(key={self.key})' 49 | return repr_str 50 | 51 | 52 | @PIPELINES.register_module() 53 | class Demosaic: 54 | """Demosaic an RGGB Bayer RAW image by bilinear interpolation. 55 | """ 56 | 57 | def __init__(self, 58 | key='gt'): 59 | self.key = key 60 | 61 | def __call__(self, results): 62 | """Demosaic the RAW image stored in ``results``. 63 | 64 | Args: 65 | results (dict): Result dict containing the image to demosaic. 66 | 67 | Returns: 68 | dict: The dict with the demosaicked image and updated meta info. 69 | """ 70 | img = results[self.key] 71 | if len(img.shape) == 3: 72 | assert img.shape[2] == 1 73 | img = img[..., 0] 74 | h, w = img.shape 75 | rgb = np.zeros((h, w, 3), dtype=img.dtype) 76 | 77 | img = np.pad(img, ((2, 2), (2, 2)), mode='reflect')  # reflect-pad so border pixels have full Bayer neighborhoods 78 | r, gb, gr, b = img[0::2, 0::2], img[0::2, 1::2], img[1::2, 0::2], img[1::2, 1::2] 79 | 80 | rgb[0::2, 0::2, 0] = r[1:-1, 1:-1] 81 | rgb[0::2, 0::2, 1] = (gr[1:-1, 1:-1] + gr[:-2, 1:-1] + gb[1:-1, 1:-1] + gb[1:-1, :-2]) / 4 82 | rgb[0::2, 0::2, 2] = (b[1:-1, 1:-1] + b[:-2, :-2] + b[1:-1, :-2] + b[:-2, 1:-1]) / 4 83 | 84 | rgb[1::2, 0::2, 0] = (r[1:-1, 1:-1] + r[2:, 1:-1]) / 2 85 | rgb[1::2, 0::2, 1] = gr[1:-1, 1:-1] 86 | rgb[1::2, 0::2, 2] = (b[1:-1, 1:-1] + b[1:-1, :-2]) / 2 87 | 88 | rgb[0::2, 1::2, 0] = (r[1:-1, 1:-1] + r[1:-1, 2:]) / 2 89 | rgb[0::2, 1::2, 1] = gb[1:-1, 1:-1] 90 | rgb[0::2, 1::2, 2] = (b[1:-1, 1:-1] + b[:-2, 1:-1]) / 2 91 | 92 | rgb[1::2, 1::2, 0] = (r[1:-1, 1:-1] + r[2:, 2:] + r[1:-1, 2:] + r[2:, 1:-1]) / 4 93 | rgb[1::2, 1::2, 1] = (gr[1:-1, 1:-1] + gr[1:-1, 2:] + gb[1:-1, 1:-1] + gb[2:, 1:-1]) / 4 94 | rgb[1::2, 1::2, 2] = b[1:-1, 1:-1] 95 | 96 | results[self.key] = rgb 97 | results[f'{self.key}_ori_shape'] = rgb.shape 98 | return results 99 | 100 | def __repr__(self): 101 | repr_str = self.__class__.__name__ 102 | repr_str += f'(key={self.key})' 103 | return repr_str 104 | 105 | 106 | @PIPELINES.register_module() 107 | class RAWNormalize: 108 | """Normalize RAW values by black level and saturation into [0, 1]. 109 | """ 110 | 111 | def __init__(self, 112 | key='gt', 113 | blc=528, 114 | saturate=4095): 115 | self.key = key 116 | self.blc = blc 117 | self.saturate = saturate 118 | 119 | def __call__(self, results): 120 | """Normalize the RAW image stored in ``results``. 121 | 122 | Args: 123 | results (dict): Result dict containing the image to normalize. 124 | 125 | Returns: 126 | dict: The dict with the normalized image.
127 | """ 128 | img = results[self.key] 129 | img = (img-self.blc) / (self.saturate-self.blc) 130 | img = np.clip(img, 0, 1) 131 | results[self.key] = img 132 | return results 133 | 134 | def __repr__(self): 135 | repr_str = self.__class__.__name__ 136 | repr_str += f'(key={self.key}, ' \ 137 | f'(blc={self.blc}, ' \ 138 | f'(saturate={self.saturate})' 139 | return repr_str 140 | -------------------------------------------------------------------------------- /unpaired_cycler2r/datasets/unpaired_cycler2r_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | from mmcv import scandir 7 | from torch.utils.data import Dataset 8 | 9 | from pathlib import Path 10 | from mmgen.datasets.builder import DATASETS 11 | from mmgen.datasets.pipelines import Compose 12 | 13 | IMG_EXTENSIONS = ('.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', 14 | '.PPM', '.bmp', '.BMP', '.tif', '.raw.TIF', '.tiff', '.TIFF', 15 | '.DNG', '.dng') 16 | 17 | 18 | @DATASETS.register_module() 19 | class UnpairedCycleR2RDataset(Dataset): 20 | """General unpaired image folder dataset for image generation. 21 | 22 | It assumes that the training directory of images from domain A is 23 | '/path/to/data/trainA', and that from domain B is '/path/to/data/trainB', 24 | respectively. '/path/to/data' can be initialized by args 'dataroot'. 25 | During test time, the directory is '/path/to/data/testA' and 26 | '/path/to/data/testB', respectively. 27 | 28 | Args: 29 | dataroot (str | :obj:`Path`): Path to the folder root of unpaired 30 | images. 31 | pipeline (List[dict | callable]): A sequence of data transformations. 32 | test_mode (bool): Store `True` when building test dataset. 33 | Default: `False`. 34 | domain_a (str, optional): Domain of images in trainA / testA. 35 | Defaults to None. 36 | domain_b (str, optional): Domain of images in trainB / testB. 37 | Defaults to None. 38 | """ 39 | 40 | def __init__(self, 41 | dataroot_a, 42 | dataroot_b, 43 | split_a, 44 | split_b, 45 | pipeline, 46 | test_mode=False, 47 | domain_a=None, 48 | domain_b=None): 49 | super().__init__() 50 | self.data_infos_a = self.load_annotations(dataroot_a, split_a) 51 | self.data_infos_b = self.load_annotations(dataroot_b, split_b) 52 | self.len_a = len(self.data_infos_a) 53 | self.len_b = len(self.data_infos_b) 54 | self.test_mode = test_mode 55 | self.pipeline = Compose(pipeline) 56 | assert isinstance(domain_a, str) 57 | assert isinstance(domain_b, str) 58 | self.domain_a = domain_a 59 | self.domain_b = domain_b 60 | 61 | def load_annotations(self, dataroot, split): 62 | """Load unpaired image paths of one domain. 63 | 64 | Args: 65 | dataroot (str): Path to the folder root for unpaired images of 66 | one domain. 67 | 68 | Returns: 69 | list[dict]: List that contains unpaired image paths of one domain. 
70 | """ 71 | data_infos = [] 72 | paths = sorted(self.scan_folder(dataroot)) 73 | if len(paths)==0: 74 | for suffix in IMG_EXTENSIONS: 75 | paths += [str(s) for s in Path(dataroot).glob(f'*/*{suffix}')] 76 | paths = sorted(paths) 77 | if split is not None: 78 | with open(split, 'r') as f: 79 | split_names = [line.strip() for line in f.readlines()] 80 | paths = [path for path in paths if osp.basename(path).split('.')[0] in split_names] 81 | 82 | for path in paths: 83 | data_infos.append(dict(path=path)) 84 | return data_infos 85 | 86 | def prepare_train_data(self, idx): 87 | """Prepare unpaired training data. 88 | 89 | Args: 90 | idx (int): Index of current batch. 91 | 92 | Returns: 93 | dict: Prepared training data batch. 94 | """ 95 | img_a_path = self.data_infos_a[idx % self.len_a]['path'] 96 | idx_b = np.random.randint(0, self.len_b) 97 | img_b_path = self.data_infos_b[idx_b]['path'] 98 | results = dict() 99 | results[f'img_{self.domain_a}_path'] = img_a_path 100 | results[f'img_{self.domain_b}_path'] = img_b_path 101 | return self.pipeline(results) 102 | 103 | def prepare_test_data(self, idx): 104 | """Prepare unpaired test data. 105 | 106 | Args: 107 | idx (int): Index of current batch. 108 | 109 | Returns: 110 | list[dict]: Prepared test data batch. 111 | """ 112 | img_a_path = self.data_infos_a[idx % self.len_a]['path'] 113 | img_b_path = self.data_infos_b[idx % self.len_b]['path'] 114 | results = dict() 115 | results[f'img_{self.domain_a}_path'] = img_a_path 116 | results[f'img_{self.domain_b}_path'] = img_b_path 117 | return self.pipeline(results) 118 | 119 | def __len__(self): 120 | return max(self.len_a, self.len_b) 121 | 122 | @staticmethod 123 | def scan_folder(path): 124 | """Obtain image path list (including sub-folders) from a given folder. 125 | 126 | Args: 127 | path (str | :obj:`Path`): Folder path. 128 | 129 | Returns: 130 | list[str]: Image list obtained from the given folder. 131 | """ 132 | 133 | if isinstance(path, (str, Path)): 134 | path = str(path) 135 | else: 136 | raise TypeError("'path' must be a str or a Path object, " 137 | f'but received {type(path)}.') 138 | 139 | images = scandir(path, suffix=IMG_EXTENSIONS, recursive=True) 140 | images = [osp.join(path, v) for v in images] 141 | assert images, f'{path} has no valid image file.' 142 | return images 143 | 144 | def __getitem__(self, idx): 145 | """Get item at each call. 146 | 147 | Args: 148 | idx (int): Index for getting each item. 
149 | """ 150 | if not self.test_mode: 151 | return self.prepare_train_data(idx) 152 | 153 | return self.prepare_test_data(idx) 154 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import os 4 | import os.path as osp 5 | import time 6 | 7 | import cv2 8 | import mmcv 9 | import torch 10 | from mmcv import Config, DictAction 11 | from mmcv.runner import get_dist_info, init_dist 12 | from mmcv.utils import get_git_hash 13 | 14 | from mmgen import __version__ 15 | from mmgen.apis import set_random_seed, train_model 16 | from mmgen.datasets import build_dataset 17 | from mmgen.models import build_model 18 | from mmgen.utils import collect_env, get_root_logger 19 | 20 | from unpaired_cycler2r.datasets import * 21 | from unpaired_cycler2r.models import * 22 | 23 | cv2.setNumThreads(0) 24 | 25 | 26 | def parse_args(): 27 | parser = argparse.ArgumentParser(description='Train a GAN model') 28 | parser.add_argument('config', help='train config file path') 29 | parser.add_argument('--work-dir', help='the dir to save logs and models') 30 | parser.add_argument( 31 | '--resume-from', help='the checkpoint file to resume from') 32 | parser.add_argument( 33 | '--no-validate', 34 | action='store_true', 35 | help='whether not to evaluate the checkpoint during training') 36 | group_gpus = parser.add_mutually_exclusive_group() 37 | group_gpus.add_argument( 38 | '--gpus', 39 | type=int, 40 | help='number of gpus to use ' 41 | '(only applicable to non-distributed training)') 42 | group_gpus.add_argument( 43 | '--gpu-ids', 44 | type=int, 45 | nargs='+', 46 | help='ids of gpus to use ' 47 | '(only applicable to non-distributed training)') 48 | parser.add_argument('--seed', type=int, default=2021, help='random seed') 49 | parser.add_argument( 50 | '--deterministic', 51 | action='store_true', 52 | help='whether to set deterministic options for CUDNN backend.') 53 | parser.add_argument( 54 | '--cfg-options', 55 | nargs='+', 56 | action=DictAction, 57 | help='override some settings in the used config, the key-value pair ' 58 | 'in xxx=yyy format will be merged into config file.') 59 | parser.add_argument( 60 | '--launcher', 61 | choices=['none', 'pytorch', 'slurm', 'mpi'], 62 | default='none', 63 | help='job launcher') 64 | parser.add_argument('--local_rank', type=int, default=0) 65 | args = parser.parse_args() 66 | if 'LOCAL_RANK' not in os.environ: 67 | os.environ['LOCAL_RANK'] = str(args.local_rank) 68 | 69 | return args 70 | 71 | 72 | def main(): 73 | args = parse_args() 74 | 75 | cfg = Config.fromfile(args.config) 76 | if args.cfg_options is not None: 77 | cfg.merge_from_dict(args.cfg_options) 78 | # import modules from string list. 
79 | if cfg.get('custom_imports', None): 80 | from mmcv.utils import import_modules_from_strings 81 | import_modules_from_strings(**cfg['custom_imports']) 82 | # set cudnn_benchmark 83 | if cfg.get('cudnn_benchmark', False): 84 | torch.backends.cudnn.benchmark = True 85 | 86 | # work_dir is determined in this priority: CLI > segment in file > filename 87 | if args.work_dir is not None: 88 | # update configs according to CLI args if args.work_dir is not None 89 | cfg.work_dir = args.work_dir 90 | elif cfg.get('work_dir', None) is None: 91 | # use config filename as default work_dir if cfg.work_dir is None 92 | cfg.work_dir = osp.join('./work_dirs', 93 | osp.splitext(osp.basename(args.config))[0]) 94 | if args.resume_from is not None: 95 | cfg.resume_from = args.resume_from 96 | if args.gpu_ids is not None: 97 | cfg.gpu_ids = args.gpu_ids 98 | else: 99 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) 100 | 101 | # init distributed env first, since logger depends on the dist info. 102 | if args.launcher == 'none': 103 | distributed = False 104 | else: 105 | distributed = True 106 | init_dist(args.launcher, **cfg.dist_params) 107 | # re-set gpu_ids with distributed training mode 108 | _, world_size = get_dist_info() 109 | cfg.gpu_ids = range(world_size) 110 | 111 | # create work_dir 112 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 113 | # dump config 114 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 115 | # init the logger before other steps 116 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 117 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log') 118 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 119 | 120 | # init the meta dict to record some important information such as 121 | # environment info and seed, which will be logged 122 | meta = dict() 123 | # log env info 124 | env_info_dict = collect_env() 125 | env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) 126 | dash_line = '-' * 60 + '\n' 127 | logger.info('Environment info:\n' + dash_line + env_info + '\n' + 128 | dash_line) 129 | meta['env_info'] = env_info 130 | meta['config'] = cfg.pretty_text 131 | # log some basic info 132 | logger.info(f'Distributed training: {distributed}') 133 | logger.info(f'Config:\n{cfg.pretty_text}') 134 | 135 | # set random seeds 136 | if args.seed is not None: 137 | logger.info(f'Set random seed to {args.seed}, ' 138 | f'deterministic: {args.deterministic}') 139 | set_random_seed(args.seed, deterministic=args.deterministic) 140 | cfg.seed = args.seed 141 | meta['seed'] = args.seed 142 | meta['exp_name'] = osp.basename(args.config) 143 | 144 | model = build_model( 145 | cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg) 146 | 147 | datasets = [build_dataset(cfg.data.train)] 148 | if len(cfg.workflow) == 2: 149 | val_dataset = copy.deepcopy(cfg.data.val) 150 | val_dataset.pipeline = cfg.data.train.pipeline 151 | datasets.append(build_dataset(val_dataset)) 152 | if cfg.checkpoint_config is not None: 153 | # save mmgen version, config file content and class names in 154 | # checkpoints as meta data 155 | cfg.checkpoint_config.meta = dict(mmgen_version=__version__ + 156 | get_git_hash()[:7]) 157 | 158 | train_model( 159 | model, 160 | datasets, 161 | cfg, 162 | distributed=distributed, 163 | validate=(not args.no_validate), 164 | timestamp=timestamp, 165 | meta=meta) 166 | 167 | 168 | if __name__ == '__main__': 169 | main() 170 | 
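171 | # Example usage (illustrative; assumes the datasets are laid out as in 172 | # configs/_base_/datasets): 173 | #   python train.py configs/unpaired_cycler2r/unpaired_cycler2r_in_bdd100k_rgb2iphone_raw_20k.py \ 174 | #       --work-dir ./work_dirs/debug --gpus 1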
-------------------------------------------------------------------------------- /unpaired_cycler2r/models/architectures/unpaired_cycler2r/generator_discriminator.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.distributed as dist 5 | import torch.nn as nn 6 | import torch.nn.init as init 7 | from mmcv.runner import load_checkpoint 8 | 9 | from mmgen.models.builder import MODULES 10 | from mmgen.utils import get_root_logger 11 | from .modules import TinyEncoder, differentiable_histogram 12 | 13 | 14 | @MODULES.register_module() 15 | class inverseISP(nn.Module): 16 | def __init__(self, wb=None, ccm=None, global_size=128): 17 | super().__init__() 18 | wb = [ 19 | [2.0931, 1.6701], 20 | [2.1932, 1.7702], 21 | [2.2933, 1.8703], 22 | [2.3934, 1.9704], 23 | [2.4935, 1.9705] 24 | ] if wb is None else wb 25 | self.wb = nn.Parameter(torch.FloatTensor(wb), requires_grad=True) 26 | 27 | ccm = [ 28 | [[1.67557, -0.52636, -0.04920], 29 | [-0.16799, 1.32824, -0.36024], 30 | [0.03188, -0.22302, 1.59114]], 31 | [[1.57558, -0.52637, -0.04921], 32 | [-0.16798, 1.52823, -0.36023], 33 | [0.031885, -0.42303, 1.39115]] 34 | ] if ccm is None else ccm 35 | self.ccm = nn.Parameter(torch.FloatTensor(ccm), requires_grad=True) 36 | 37 | self.resize_fn = partial(nn.functional.interpolate, size=global_size) 38 | self.global_size = global_size 39 | 40 | self.color_condition_gen = TinyEncoder(3, 2) 41 | self.bright_condition_gen = TinyEncoder(3, 2) 42 | 43 | # inverse ISP 44 | self.ccm_estimator = TinyEncoder(7, 1) 45 | self.bright_estimator = TinyEncoder(3, 1) 46 | self.wb_estimator = TinyEncoder(7, 1) 47 | 48 | # ISP 49 | self.wb_evaluator = TinyEncoder(6, 1) 50 | self.bright_evaluator = TinyEncoder(2, 1) 51 | self.ccm_evaluator = TinyEncoder(6, 1) 52 | 53 | self.counter = 0 54 | 55 | def initialize(self): 56 | pass 57 | 58 | def init_weights(self, pretrained=None, strict=True): 59 | if isinstance(pretrained, str): 60 | logger = get_root_logger() 61 | load_checkpoint(self, pretrained, strict=strict, logger=logger) 62 | elif pretrained is None: 63 | pass 64 | else: 65 | raise TypeError("'pretrained' must be a str or None. 
" 66 | f'But received {type(pretrained)}.') 67 | 68 | def safe_inverse_gain(self, x, gain): 69 | gray = x.mean(1, keepdim=True) 70 | inflection = 0.9 71 | mask = torch.maximum(gray - inflection, torch.zeros_like(gray)) / (1 - inflection) 72 | mask = mask ** 2 73 | mask = torch.clamp(mask, 0, 1) 74 | safe_gain = torch.maximum(mask + (1 - mask) * gain, gain) 75 | x = x * safe_gain 76 | return x 77 | 78 | def rgb2raw(self, x, condition): 79 | b = x.shape[0] 80 | gs = self.global_size 81 | ccm_condition = torch.ones([b, self.ccm.shape[0], 1, gs, gs], device=x.device) 82 | ccm_condition *= condition[:, 0].view([b, 1, 1, 1, 1]) 83 | ccm_condition = ccm_condition.view([-1, 1, gs, gs]) 84 | 85 | b_condition = torch.ones([b, 1, gs, gs], device=x.device) 86 | b_condition *= condition[:, 1].view([b, 1, 1, 1]) 87 | 88 | wb_condition = torch.ones([b, self.wb.shape[0], 1, gs, gs], device=x.device) 89 | wb_condition *= condition[:, 0].view([b, 1, 1, 1, 1]) 90 | wb_condition = wb_condition.view([-1, 1, gs, gs]) 91 | 92 | ## inverse gamma 93 | x = torch.maximum(x, 1e-8 * torch.ones_like(x)) 94 | x = torch.where(x > 0.04045, ((x + 0.055) / 1.055) ** 2.4, x / 12.92) 95 | 96 | ## inverse ccm 97 | x_resize = self.resize_fn(x) 98 | ccm = self.ccm / self.ccm.reshape([-1, 3]).sum(1).reshape([-1, 3, 1]) 99 | inv_ccm = torch.linalg.pinv(ccm.transpose(-2, -1)) 100 | ccm_preview = torch.einsum('bchw,ncj->bnjhw', x_resize, inv_ccm) 101 | ccm_preview = ccm_preview.reshape([-1, 3, gs, gs]) 102 | ccm_preview = torch.cat([ccm_preview, ccm_preview**2, ccm_condition], 1) 103 | ccm_prob = self.ccm_estimator(ccm_preview).view([b, -1]) 104 | ccm_prob = torch.softmax(ccm_prob, 1) 105 | ccm = ccm[None] * ccm_prob.view([b, -1, 1, 1]) # [B, N, 3, 3] 106 | ccm = ccm.sum(1) # [B, 3, 3] 107 | inv_ccm = torch.linalg.pinv(ccm.transpose(-2, -1)) # [B, 3, 3] 108 | x = torch.einsum('bchw,bcj->bjhw', [x, inv_ccm]) # [B, 3, H, W] 109 | 110 | ## inverse brightness adjustment 111 | x_resize = self.resize_fn(x) 112 | x_resize = torch.mean(x_resize, 1, keepdim=True) 113 | x_resize = torch.cat([x_resize, x_resize**2, b_condition], 1) 114 | bright = torch.tanh(self.bright_estimator(x_resize)) 115 | bright_adjust = bright * 0.8 + 0.8 116 | x = self.safe_inverse_gain(x, bright_adjust.view([b, 1, 1, 1])) # [B, 3, H, W] 117 | 118 | ## inverse awb 119 | x_resize = self.resize_fn(x) 120 | gain = torch.ones([b, self.wb.shape[0], 3]).to(x.device) 121 | gain[..., (0, 2)] = 1 / self.wb[None] # [B, N, 3] 122 | wb_preview = x_resize[:, None].repeat([1, self.wb.shape[0], 1, 1, 1]) # [B, N, 3, 128, 128] 123 | wb_preview = wb_preview.reshape([-1, 3, gs, gs]) # [B*N, 3, 128, 128] 124 | gain = gain.reshape([-1, 3, 1, 1]) # [B*N, 3, 1, 1] 125 | wb_preview = self.safe_inverse_gain(wb_preview, gain) 126 | wb_preview = torch.cat([wb_preview, wb_preview**2, wb_condition], 1) 127 | 128 | wb_prob = self.wb_estimator(wb_preview).view([b, -1, 1, 1, 1]) 129 | wb_prob = torch.softmax(wb_prob, 1) 130 | gain = gain.view([b, -1, 3, 1, 1]) * wb_prob # [B, N, 3, 1, 1] 131 | gain = gain.sum(1) # [B, 3, 1, 1] 132 | x = self.safe_inverse_gain(x, gain) # [B, 3, H, W] 133 | 134 | # mosaic 135 | # x = self.mosaic(x) 136 | return x 137 | 138 | def raw2rgb(self, x): 139 | # global 140 | b = x.shape[0] 141 | gs = self.global_size 142 | 143 | ## demosaic 144 | # x = self.demosaic(x) 145 | 146 | ## white balance 147 | x_resize = self.resize_fn(x) 148 | x_resize = x_resize[:, None].repeat([1, self.wb.shape[0], 1, 1, 1]) # [B, N, 3, 128, 128] 149 | x_resize[:, :, (0, 2)] = x_resize[:, :, (0, 
2)] * self.wb.view([1, -1, 2, 1, 1]) # [B, N, 3, 128, 128] 150 | x_resize = x_resize.reshape([-1, 3, gs, gs]) # [B*N, 3, 128, 128] 151 | x_resize = torch.cat([x_resize, x_resize**2], 1) 152 | wb_prob = self.wb_evaluator(x_resize).reshape([b, self.wb.shape[0]]) # [B, N] 153 | wb_prob = torch.softmax(wb_prob, 1) 154 | wb_prob = wb_prob.view([b, -1, 1, 1, 1]) # [B, N, 1, 1, 1] 155 | wb = self.wb.view([1, -1, 2, 1, 1]) # [B, N, 2, 1, 1] 156 | wb = (wb * wb_prob).sum(1) # [B, 2, 1, 1] 157 | x[:, (0, 2)] = x[:, (0, 2)] * wb 158 | 159 | ## brightness adjustment 160 | x_resize = self.resize_fn(x) 161 | x_resize = torch.mean(x_resize, 1, keepdim=True) 162 | x_resize = torch.cat([x_resize, x_resize**2], 1) 163 | bright = torch.tanh(self.bright_evaluator(x_resize)) 164 | bright_adjust = 0.8 + bright * 0.8 165 | bright_adjust = torch.abs(bright_adjust) 166 | x = x / bright_adjust[:, :, None, None] 167 | 168 | ## ccm 169 | x_resize = self.resize_fn(x) 170 | ccm = self.ccm / self.ccm.sum(2, keepdims=True) 171 | ccm = ccm.transpose(1, 2) 172 | ccm_preview = torch.einsum('bchw,ncj->bnjhw', x_resize, ccm) # [B, N, 3, 128, 128] 173 | ccm_preview = ccm_preview.reshape([-1, 3, gs, gs]) # [B*N, 3, 128, 128] 174 | ccm_preview = torch.cat([ccm_preview, ccm_preview**2], 1) 175 | ccm_prob = self.ccm_evaluator(ccm_preview).reshape([b, self.ccm.shape[0]]) # [B, N] 176 | ccm_prob = torch.softmax(ccm_prob, 1) 177 | ccm_prob = ccm_prob.view([b, -1, 1, 1]) # [B, N, 1, 1] 178 | ccm = ccm[None] * ccm_prob # [B, N, 3, 3] 179 | ccm = ccm.sum(1) # [B, 3, 3] 180 | x = torch.einsum('bchw,bcj->bjhw', x, ccm) # [B, 3, H, W] 181 | 182 | ## gamma correction 183 | x = torch.maximum(x, 1e-8 * torch.ones_like(x)) 184 | x = torch.where(x <= 0.0031308, 12.92 * x, 1.055 * torch.pow(x, 1 / 2.4) - 0.055) 185 | return x 186 | 187 | def forward(self, x, condition=None, rev=False): 188 | x = x.clone() 189 | # inverse ISP 190 | if not rev: 191 | x = self.rgb2raw(x, condition) 192 | log_dict = {} 193 | log_dict['wb'] = self.wb.detach().cpu() 194 | log_dict['ccm'] = self.ccm.detach().cpu() 195 | # log 196 | if self.training and (not torch.distributed.is_initialized() or dist.get_rank() == 0): 197 | self.counter += 1 198 | if self.counter == 200: 199 | self.counter = 0 200 | for k, v in log_dict.items(): 201 | print(k, ':', '\n', v) 202 | # ISP 203 | else: 204 | x = self.raw2rgb(x) 205 | 206 | return x 207 | 208 | 209 | class FCDiscriminator(nn.Module): 210 | 211 | def __init__(self, in_channels=3, ndf=64): 212 | super(FCDiscriminator, self).__init__() 213 | self.conv1 = nn.Conv2d(in_channels, ndf, kernel_size=4, stride=2, padding=1) 214 | self.conv2 = nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1) 215 | self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1) 216 | self.conv4 = nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1) 217 | self.classifier = nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=2, padding=1) 218 | 219 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=False) 220 | 221 | def forward(self, x): 222 | x = self.conv1(x) 223 | x = self.leaky_relu(x) 224 | x = self.conv2(x) 225 | x = self.leaky_relu(x) 226 | x = self.conv3(x) 227 | x = self.leaky_relu(x) 228 | x = self.conv4(x) 229 | x = self.leaky_relu(x) 230 | x = self.classifier(x) 231 | x = x.reshape([x.shape[0], -1]).mean(1) 232 | return x 233 | 234 | 235 | @MODULES.register_module() 236 | class HistAwareDiscriminator(nn.Module): 237 | 238 | def __init__(self, in_channels=3, ndf=64, bins=255, global_size=128): 239 | 
super(HistAwareDiscriminator, self).__init__() 240 | self.bins = bins 241 | self.bright_fcd = nn.Sequential( 242 | nn.Linear(bins * in_channels, 1024), 243 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 244 | nn.Linear(1024, 1024), 245 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 246 | nn.Linear(1024, 256), 247 | nn.Linear(256, 1) 248 | ) 249 | self.local_fcd = FCDiscriminator(in_channels=in_channels, ndf=ndf) 250 | self.resize_fn = partial(nn.functional.interpolate, size=(global_size, global_size)) 251 | self.uv_fcd = FCDiscriminator(in_channels=2, ndf=ndf) 252 | 253 | def init_weights(self, pretrained=None, strict=True): 254 | if isinstance(pretrained, str): 255 | logger = get_root_logger() 256 | load_checkpoint(self, pretrained, strict=strict, logger=logger) 257 | elif pretrained is None: 258 | for m in self.modules(): 259 | if isinstance(m, nn.Conv2d): 260 | init.xavier_normal_(m.weight) 261 | m.weight.data *= 1. # for residual block 262 | if m.bias is not None: 263 | m.bias.data.zero_() 264 | elif isinstance(m, nn.Linear): 265 | init.xavier_normal_(m.weight) 266 | m.weight.data *= 1. 267 | if m.bias is not None: 268 | m.bias.data.zero_() 269 | elif isinstance(m, nn.BatchNorm2d): 270 | init.constant_(m.weight, 1) 271 | init.constant_(m.bias.data, 0.0) 272 | else: 273 | raise TypeError("'pretrained' must be a str or None. " 274 | f'But received {type(pretrained)}.') 275 | 276 | def forward(self, x): 277 | local_judge = self.local_fcd(x.clone()) 278 | 279 | hist = differentiable_histogram(x, self.bins) # [B, C, bins] 280 | hist /= x.shape[2] * x.shape[3] 281 | bright_judge = self.bright_fcd(hist.reshape([hist.shape[0], -1])) 282 | 283 | x = self.resize_fn(torch.clamp(x.clone(), 1e-6, 1))  # clamp so the log below stays finite 284 | u, v = torch.log(x[:, 1] / x[:, 0]), torch.log(x[:, 1] / x[:, 2]) # [b, h, w] 285 | uv_judge = self.uv_fcd(torch.cat([u[:, None], v[:, None]], 1)) 286 | 287 | combine_judge = local_judge + bright_judge + uv_judge  # sum of the patch, histogram and chrominance critics 288 | return combine_judge 289 | -------------------------------------------------------------------------------- /unpaired_cycler2r/models/translation_models/unpaired_cycler2r.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.distributions as D 6 | from torch.nn.parallel.distributed import _find_tensors 7 | 8 | from mmgen.models.common import set_requires_grad 9 | from mmgen.models.builder import MODELS, build_module 10 | from mmgen.models.translation_models.cyclegan import CycleGAN 11 | 12 | 13 | @MODELS.register_module() 14 | class UnpairedCycleR2R(CycleGAN): 15 | """CycleGAN-style model for unpaired RGB-to-RAW translation.
16 | 17 | Ref: 18 | Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial 19 | Networks 20 | """ 21 | 22 | def __init__(self, *args, **kwargs): 23 | self.lambda_bright_diversity = kwargs.pop('lambda_bright_diversity', None) 24 | self.lambda_color_diversity = kwargs.pop('lambda_color_diversity', None) 25 | self.c_condition = 2 26 | 27 | pretrained = kwargs.pop('pretrained', None) 28 | kwargs['pretrained'] = None 29 | super().__init__(*args, **kwargs) 30 | self.l1 = torch.nn.L1Loss() 31 | del self.generators 32 | torch.cuda.empty_cache() 33 | self.generator = build_module(kwargs['generator']) 34 | if 'pretrained' in kwargs['generator'].keys(): 35 | self.generator.init_weights(kwargs['generator']['pretrained']) 36 | 37 | if pretrained is not None: 38 | self.init_weights(pretrained) 39 | 40 | def init_weights(self, pretrained): 41 | """Load pretrained weights, falling back to a shape-aware partial load.""" 42 | if pretrained is not None: 43 | state_dict = torch.load(pretrained, map_location='cpu') 44 | if 'state_dict' in state_dict.keys(): 45 | state_dict = state_dict['state_dict'] 46 | print('Direct loading from {}'.format(pretrained)) 47 | try: 48 | self.load_state_dict(state_dict, strict=True) 49 | except Exception: 50 | pass  # fall through to the shape-aware partial loading below 51 | updated_state_dict = OrderedDict() 52 | cache_init_keys = [] 53 | non_same = {} 54 | for k in self.state_dict(): 55 | if k not in state_dict.keys(): 56 | non_same[k] = (None, self.state_dict()[k].shape) 57 | elif self.state_dict()[k].shape == state_dict[k].shape: 58 | if 'init' not in k: 59 | updated_state_dict[k] = state_dict[k] 60 | else: 61 | cache_init_keys.append(k) 62 | else: 63 | non_same[k] = (state_dict[k].shape, self.state_dict()[k].shape) 64 | if len(non_same.keys()) == 0: 65 | for k in cache_init_keys: 66 | updated_state_dict[k] = state_dict[k] 67 | else: 68 | print('Not all state dict shapes match') 69 | for k, v in non_same.items(): 70 | print('{}: {} -> {}'.format(k, v[0], v[1])) 71 | self.load_state_dict(updated_state_dict, strict=False) 72 | else: 73 | # todo add default init weights 74 | pass 75 | 76 | def translation(self, image, condition, target_domain=None, **kwargs): 77 | """Translate an image to the target style. 78 | 79 | Args: 80 | image (tensor): Image tensor with a shape of (N, C, H, W). 81 | condition (tensor): Per-image (color, brightness) condition. 82 | target_domain (str, optional): Target domain. Default to None. 83 | 84 | Returns: 85 | tensor: Image tensor in the target style.
86 | """ 87 | if target_domain is None: 88 | target_domain = self._default_domain 89 | if target_domain == 'raw': 90 | outputs = self.generator(image, condition=condition, rev=False) 91 | else: 92 | outputs = self.generator(image, condition=condition, rev=True) 93 | return outputs 94 | 95 | def _get_target_generator(self, domain): 96 | raise NotImplementedError 97 | 98 | def _get_condition(self, img): 99 | # condition = torch.randn([img.shape[0], self.c_condition]).to(img.device).detach() 100 | mean_var = self.generator.color_condition_gen(img) 101 | m = D.Normal(mean_var[:, 0], torch.clamp_min(torch.abs(mean_var[:, 1]), 1e-6)) 102 | color_condition = m.sample() 103 | mean_var = self.generator.bright_condition_gen(img) 104 | m = D.Normal(mean_var[:, 0], torch.clamp_min(torch.abs(mean_var[:, 1]), 1e-6)) 105 | bright_condition = m.sample() 106 | condition = torch.cat([color_condition[:, None], bright_condition[:, None]], 1) 107 | return condition 108 | 109 | def forward_train(self, img, target_domain, **kwargs): 110 | condition = self._get_condition(img) 111 | target= self.translation(img, condition, target_domain=target_domain, **kwargs) 112 | results = dict(source=img, target=target, condition=condition) 113 | return results 114 | 115 | def forward_test(self, img, target_domain, **kwargs): 116 | # This is a trick for CycleGAN 117 | # ref: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/e1bdf46198662b0f4d0b318e24568205ec4d7aee/test.py#L54 # noqa 118 | self.train() 119 | condition = self._get_condition(img).detach() 120 | target = self.translation(img, condition, target_domain=target_domain, **kwargs) 121 | results = dict(source=img.cpu(), target=target.cpu(), condition=condition.cpu()) 122 | return results 123 | 124 | def train_step(self, 125 | data_batch, 126 | optimizer, 127 | ddp_reducer=None, 128 | running_status=None): 129 | """Training step function. 130 | 131 | Args: 132 | data_batch (dict): Dict of the input data batch. 133 | optimizer (dict[torch.optim.Optimizer]): Dict of optimizers for 134 | the generators and discriminators. 135 | ddp_reducer (:obj:`Reducer` | None, optional): Reducer from ddp. 136 | It is used to prepare for ``backward()`` in ddp. Defaults to 137 | None. 138 | running_status (dict | None, optional): Contains necessary basic 139 | information for training, e.g., iteration number. Defaults to 140 | None. 141 | 142 | Returns: 143 | dict: Dict of loss, information for logger, the number of samples\ 144 | and results for visualization. 
145 | """ 146 | # get running status 147 | if running_status is not None: 148 | curr_iter = running_status['iteration'] 149 | else: 150 | # dirty walkround for not providing running status 151 | if not hasattr(self, 'iteration'): 152 | self.iteration = 0 153 | curr_iter = self.iteration 154 | 155 | # forward generators 156 | outputs = dict() 157 | for target_domain in self._reachable_domains: 158 | # fetch data by domain 159 | source_domain = self.get_other_domains(target_domain)[0] 160 | img = data_batch[f'img_{source_domain}'] 161 | # translation process 162 | results = self(img, test_mode=False, target_domain=target_domain) 163 | outputs[f'real_{source_domain}'] = results['source'] 164 | outputs[f'fake_{target_domain}'] = results['target'] 165 | outputs[f'fake_{target_domain}_condition'] = results['condition'] 166 | # cycle process 167 | results = self( 168 | results['target'], 169 | test_mode=False, 170 | target_domain=source_domain) 171 | outputs[f'cycle_{source_domain}'] = results['target'] 172 | results = self(data_batch['img_rgb'], test_mode=False, target_domain='raw') 173 | outputs['fake_raw_diversity'] = results['target'] 174 | outputs['fake_raw_diversity_condition'] = results['condition'] 175 | 176 | log_vars = dict() 177 | 178 | # discriminators 179 | set_requires_grad(self.discriminators, True) 180 | # optimize 181 | optimizer['discriminators'].zero_grad() 182 | loss_d, log_vars_d = self._get_disc_loss(outputs) 183 | log_vars.update(log_vars_d) 184 | if ddp_reducer is not None: 185 | ddp_reducer.prepare_for_backward(_find_tensors(loss_d)) 186 | loss_d.backward(retain_graph=True) 187 | optimizer['discriminators'].step() 188 | 189 | # generators, no updates to discriminator parameters. 190 | if (curr_iter % self.disc_steps == 0 191 | and curr_iter >= self.disc_init_steps): 192 | set_requires_grad(self.discriminators, False) 193 | # optimize 194 | optimizer['generator'].zero_grad() 195 | loss_g, log_vars_g = self._get_gen_loss(outputs) 196 | log_vars.update(log_vars_g) 197 | if ddp_reducer is not None: 198 | ddp_reducer.prepare_for_backward(_find_tensors(loss_g)) 199 | loss_g.backward(retain_graph=True) 200 | optimizer['generator'].step() 201 | 202 | if hasattr(self, 'iteration'): 203 | self.iteration += 1 204 | 205 | image_results = dict() 206 | for domain in self._reachable_domains: 207 | image_results[f'real_{domain}'] = outputs[f'real_{domain}'].cpu() 208 | image_results[f'fake_{domain}'] = outputs[f'fake_{domain}'].cpu() 209 | image_results[f'cycle_{domain}'] = outputs[f'cycle_{domain}'].cpu() 210 | image_results['fake_raw_diversity'] = outputs['fake_raw_diversity'].cpu() 211 | results = dict( 212 | log_vars=log_vars, 213 | num_samples=len(outputs[f'real_{domain}']), 214 | results=image_results) 215 | 216 | return results 217 | 218 | def _get_gen_loss(self, outputs): 219 | """Backward function for the generators. 220 | 221 | Args: 222 | outputs (dict): Dict of forward results. 223 | 224 | Returns: 225 | dict: Generators' loss and loss dict. 
226 | """ 227 | discriminators = self.get_module(self.discriminators) 228 | 229 | losses = dict() 230 | for domain in self._reachable_domains: 231 | # Identity reconstruction for generators 232 | # rev = True if domain == 'rgb' else False 233 | # outputs[f'identity_{domain}'] = self.generator(outputs[f'real_{domain}'], rev=rev) 234 | # GAN loss for generators 235 | fake_pred = discriminators[domain](outputs[f'fake_{domain}']) 236 | losses[f'loss_gan_g_{domain}'] = self.gan_loss( 237 | fake_pred, target_is_real=True, is_disc=False) 238 | 239 | # gen auxiliary loss 240 | if self.with_gen_auxiliary_loss: 241 | for loss_module in self.gen_auxiliary_losses: 242 | loss_ = loss_module(outputs) 243 | if loss_ is None: 244 | continue 245 | # the `loss_name()` function return name as 'loss_xxx' 246 | if loss_module.loss_name() in losses: 247 | losses[loss_module.loss_name( 248 | )] = losses[loss_module.loss_name()] + loss_ 249 | else: 250 | losses[loss_module.loss_name()] = loss_ 251 | 252 | if self.lambda_color_diversity is not None: 253 | b = outputs['fake_raw'].shape[0] 254 | 255 | im1, im2 = outputs['fake_raw'], outputs['fake_raw_diversity'] 256 | l1_bright = torch.abs(im1[:, 1]-im2[:, 1]).reshape([b, -1]).mean(1) 257 | bright_condition = torch.abs( 258 | outputs['fake_raw_condition'][:, 1] - outputs['fake_raw_diversity_condition'][:, 1]) 259 | _loss = torch.abs((l1_bright/bright_condition)-0.2).mean() 260 | losses['loss_bright_diversity'] = self.lambda_bright_diversity * _loss 261 | 262 | # im1, im2 = im1.mean(-1).mean(-1), im2.mean(-1).mean(-1) 263 | im1, im2 = torch.clamp_min(im1, 1e-6), torch.clamp(im2, 1e-6) 264 | u1, v1 = im1[:, 1] / im1[:, 0], im1[:, 1] / im1[:, 2] # [b, h, w] 265 | u2, v2 = im2[:, 1] / im2[:, 0], im2[:, 1] / im2[:, 2] # [b, h, w] 266 | u1, v1 = torch.log(u1), torch.log(v1) 267 | u2, v2 = torch.log(u2), torch.log(v2) 268 | l1_uv = torch.abs(u1-u2) + torch.abs(v1-v2) 269 | l1_uv = l1_uv.view([b, -1])/2 270 | color_condition = torch.abs( 271 | outputs['fake_raw_condition'][:, 0] - outputs['fake_raw_diversity_condition'][:, 0])[:, None] 272 | _loss = torch.abs((l1_uv/color_condition)-2).mean() 273 | losses['loss_color_diversity'] = self.lambda_color_diversity * _loss 274 | 275 | loss_g, log_vars_g = self._parse_losses(losses) 276 | log_vars_g['diversity'] = float(self.l1(outputs['fake_raw'], outputs['fake_raw_diversity'])) 277 | return loss_g, log_vars_g 278 | --------------------------------------------------------------------------------