├── .gitignore ├── GamutNet_CIC2021_pretrained_model ├── checkpoints │ └── epoch=5.ckpt ├── events.out.tfevents.1603703529.228b4f259e16.2057.0 └── hparams.yaml ├── README.md ├── datasets ├── __init__.py ├── data_modules.py ├── datasets.py ├── frontend_dataset.py ├── iterable_patches_dataset.py └── patch_utils.py ├── figures └── overview_gamutnet.png ├── gen_split_dataset.py ├── icc_profiles ├── ProPhoto.icm └── __init__.py ├── inference ├── __init__.py └── inference_agent.py ├── metrics ├── __init__.py ├── calc_deltaE.py ├── calc_mae.py ├── calc_mse.py ├── calc_psnr.py └── calc_rmse.py ├── models ├── __init__.py ├── networks │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── deep_wb_blocks.cpython-38.pyc │ │ ├── wide_gamut_blocks.cpython-38.pyc │ │ └── wide_gamut_net.cpython-38.pyc │ ├── deep_wb_blocks.py │ ├── wide_gamut_blocks.py │ └── wide_gamut_net.py └── wide_gamut_net_pl.py ├── piecewise ├── __init__.py ├── piecewise.py ├── pw_mapping.npy └── utils.py ├── requirements.txt ├── run_inference.py ├── trainer_main.py └── utils ├── color.py ├── mask.py └── metric.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Models output 2 | output_images/ 3 | output_figures/ 4 | tensorboard_logs/ 5 | 6 | # Byte-compiled / optimized / DLL files 7 | **/__pycache__/ 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | 137 | .vscode 138 | .idea 139 | 140 | # PyTorch Lightning Logs 141 | lightning_logs 142 | -------------------------------------------------------------------------------- /GamutNet_CIC2021_pretrained_model/checkpoints/epoch=5.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gamut-mapping/GamutNet/68fe72858b7963f82b7b324edffb5a534394f46c/GamutNet_CIC2021_pretrained_model/checkpoints/epoch=5.ckpt -------------------------------------------------------------------------------- /GamutNet_CIC2021_pretrained_model/events.out.tfevents.1603703529.228b4f259e16.2057.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gamut-mapping/GamutNet/68fe72858b7963f82b7b324edffb5a534394f46c/GamutNet_CIC2021_pretrained_model/events.out.tfevents.1603703529.228b4f259e16.2057.0 -------------------------------------------------------------------------------- /GamutNet_CIC2021_pretrained_model/hparams.yaml: -------------------------------------------------------------------------------- 1 | accelerator: null 2 | accumulate_grad_batches: 1 3 | amp_backend: native 4 | amp_level: O2 5 | auto_lr_find: false 6 | auto_scale_batch_size: false 7 | auto_select_gpus: false 8 | automatic_optimization: true 9 | batch_size: 64 10 | benchmark: false 11 | check_val_every_n_epoch: 1 12 | checkpoint_callback: true 13 | default_root_dir: null 14 | deterministic: false 15 | distributed_backend: null 16 | fast_dev_run: false 17 | flush_logs_every_n_steps: 100 18 | gpus: 1 19 | gradient_clip_val: 0 20 | hint_mode: o2o_all 21 | input_channels: 4 22 | learning_rate: 0.0001 23 | limit_test_batches: 1.0 24 | limit_train_batches: 1.0 25 | limit_val_batches: 1.0 26 | limiting_output_range: true 27 | log_every_n_steps: 50 28 | log_gpu_memory: null 29 | logger: true 30 | max_epochs: 1000 31 | max_patches_per_image: 32000 32 | max_steps: null 33 | min_epochs: 1 34 | min_steps: null 35 | model_size: small 36 | num_nodes: 1 37 | num_processes: 1 38 | num_sanity_val_steps: 2 39 | overfit_batches: 0.0 40 | patch_size: 41 | - 64 42 | - 64 43 | precision: 16 44 | prepare_data_per_node: true 45 | process_position: 0 46 | profiler: null 47 | progress_bar_refresh_rate: 1 48 | reload_dataloaders_every_epoch: false 49 | replace_sampler_ddp: true 50 | resume_from_checkpoint: null 51 | split_path: /root/share/split 52 | sync_batchnorm: false 53 | terminate_on_nan: false 54 | test_input: test-input.txt 55 | test_num_workers: 32 56 | test_target: test-target.txt 57 | tpu_cores: !!python/name:pytorch_lightning.utilities.argparse_utils._gpus_arg_default '' 58 | track_grad_norm: -1 59 | train_input: train-input.txt 60 | train_num_workers: 32 61 | train_target: train-target.txt 62 | truncated_bptt_steps: null 63 | using_residual: true 64 | val_check_interval: 1.0 65 | val_input: val-input.txt 66 | val_num_workers: 32 67 | val_target: 
val-target.txt
68 | weights_save_path: null
69 | weights_summary: top
70 | hparams:
71 | accelerator: null
72 | accumulate_grad_batches: 1
73 | amp_backend: native
74 | amp_level: O2
75 | auto_lr_find: false
76 | auto_scale_batch_size: false
77 | auto_select_gpus: false
78 | automatic_optimization: true
79 | batch_size: 64
80 | benchmark: false
81 | check_val_every_n_epoch: 1
82 | checkpoint_callback: true
83 | default_root_dir: null
84 | deterministic: false
85 | distributed_backend: null
86 | fast_dev_run: false
87 | flush_logs_every_n_steps: 100
88 | gpus: 1
89 | gradient_clip_val: 0
90 | hint_mode: o2o_all
91 | input_channels: 4
92 | learning_rate: 0.0001
93 | limit_test_batches: 1.0
94 | limit_train_batches: 1.0
95 | limit_val_batches: 1.0
96 | limiting_output_range: true
97 | log_every_n_steps: 50
98 | log_gpu_memory: null
99 | logger: true
100 | max_epochs: 1000
101 | max_patches_per_image: 32000
102 | max_steps: null
103 | min_epochs: 1
104 | min_steps: null
105 | model_size: small
106 | num_nodes: 1
107 | num_processes: 1
108 | num_sanity_val_steps: 2
109 | overfit_batches: 0.0
110 | patch_size:
111 | - 64
112 | - 64
113 | precision: 16
114 | prepare_data_per_node: true
115 | process_position: 0
116 | profiler: null
117 | progress_bar_refresh_rate: 1
118 | reload_dataloaders_every_epoch: false
119 | replace_sampler_ddp: true
120 | resume_from_checkpoint: null
121 | split_path: /root/share/split
122 | sync_batchnorm: false
123 | terminate_on_nan: false
124 | test_input: test-input.txt
125 | test_num_workers: 32
126 | test_target: test-target.txt
127 | tpu_cores: !!python/name:pytorch_lightning.utilities.argparse_utils._gpus_arg_default ''
128 | track_grad_norm: -1
129 | train_input: train-input.txt
130 | train_num_workers: 32
131 | train_target: train-target.txt
132 | truncated_bptt_steps: null
133 | using_residual: true
134 | val_check_interval: 1.0
135 | val_input: val-input.txt
136 | val_num_workers: 32
137 | val_target: val-target.txt
138 | weights_save_path: null
139 | weights_summary: top
140 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GamutNet: Restoring Wide-Gamut Colors for Camera-Captured Images
2 | _[Hoang M. Le](https://www.linkedin.com/in/hminle/)_, _[Taehong Jeong](https://github.com/taehongjeong)_, _[Abdelrahman Abdelhamed](https://abdokamel.github.io/)_, _[Hyun Joon Shin](https://www.linkedin.com/in/hyun-joon-shin-35aa604b/)_ and _[Michael S. Brown](http://www.cse.yorku.ca/~mbrown/)_
3 |
4 | Repo for the paper:
5 | [GamutNet: Restoring Wide-Gamut Colors for Camera-Captured Images](https://library.imaging.org/admin/apis/public/api/ist/website/downloadArticle/cic/29/1/art00003)
6 |
7 | ## Overview
8 | ![overview_figure](./figures/overview_gamutnet.png)
9 |
10 | ## Dataset
11 |
12 | 1. Use these links to download our dataset:
13 | - ProPhoto Images: [prop-8bpc.zip 106.4GB](https://ln5.sync.com/dl/5b776c6e0/bkyjim85-3vtf5qdu-x2v7ymse-6jkjbr9i)
14 | - sRGB Images: [srgb-8bpc-1.rar 57.1GB](https://ln5.sync.com/dl/879f8b3f0/fj47dzcm-ewjnc9f9-zeh8paxa-9s4bdmbf) [srgb-8bpc-2.rar 55.9GB](https://ln5.sync.com/dl/748574d10/fjs2kb24-rv7ftvix-nt8amuan-vna52ah2)
15 | - split files: [split.zip 65KB](https://ln5.sync.com/dl/1ea117750/iyvpkiix-qiudb344-gy5ye2nc-x8ishes9)
16 |
17 | 2. Uncompress the ProPhoto archive into the folder `prop-8bpc` and both sRGB parts into the folder `srgb-8bpc`. Each folder contains 5,000 images.
18 | 3. Update the image paths in the split `.txt` files so they point to these folders (a typical layout is sketched below).
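For reference, one possible on-disk layout after extraction (the `data/` root is illustrative; the folder names follow step 2 and the split filenames match the defaults in `datasets/data_modules.py`):

```
data/
├── srgb-8bpc/   # 5,000 input sRGB images
├── prop-8bpc/   # 5,000 target ProPhoto images
└── split/       # {train,val,test}-{input,target}.txt files, passed via --split_path
```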
19 |
20 | ## How to start training
21 |
22 | 1. Basic training commands:
23 |
24 | ```
25 | python trainer_main.py --split_path SPLIT_PATH
26 | ```
27 |
28 | ```
29 | python trainer_main.py --split_path ./split_output --model_size small --patch_size 64 64 --batch_size 64 --gpus "1" --train_num_workers 32 --val_num_workers 32 --max_patches_per_image 1000 --max_epochs 20
30 | ```
31 |
32 | 2. Specifying `hint_mode`:
33 |
34 | ```
35 | ... --hint_mode HINT_MODE --input_channels INPUT_CHANNELS
36 | ```
37 |
38 | - `--hint_mode none --input_channels 3`: no hint will be provided.
39 | - `--hint_mode o2o_all --input_channels 4`: a one-to-one mask (white, in-gamut, and black) will be provided.
40 | - `--hint_mode o2o_rgb --input_channels 6`: a per-channel one-to-one mask will be provided.
41 | - `--hint_mode pw_27 --input_channels 4`: the pw_mask_27 normalized by 26 will be provided (note: this mode is not currently listed in the `--hint_mode` choices in `datasets/data_modules.py`).
42 |
43 |
44 | 3. Specifying `patch_size`:
45 | ```
46 | ... --patch_size PATCH_HEIGHT PATCH_WIDTH
47 | ```
48 |
49 | - e.g. `--patch_size 32 32`
50 | - e.g. `--patch_size 64 64`
51 | - e.g. `--patch_size 128 128`
52 |
53 |
54 | 4. Specifying `max_patches_per_image`:
55 | ```
56 | ... --max_patches_per_image MAX_PATCHES_PER_IMAGE
57 | ```
58 |
59 | - e.g. given `batch_size=32`, `--max_patches_per_image 32000` will yield 1000 iterations per image
60 | - e.g. given `batch_size=128`, `--max_patches_per_image 12800` will yield 100 iterations per image
61 |
62 |
63 | 5. Specifying `num_workers` for `DataLoader`:
64 |
65 | - `--train_num_workers`: default is 2
66 | - `--val_num_workers`: default is 1
67 | - `--test_num_workers`: default is 1
68 |
69 | ## Citation
70 |
71 | ```bibtex
72 | @inproceedings{le2021gamutnet,
73 | title={GamutNet: Restoring Wide-Gamut Colors for Camera-Captured Images},
74 | author={Le, Hoang and Jeong, Taehong and Abdelhamed, Abdelrahman and Shin, Hyun Joon and Brown, Michael S},
75 | booktitle={Color and Imaging Conference},
76 | volume={2021},
77 | number={29},
78 | pages={7--12},
79 | year={2021},
80 | organization={Society for Imaging Science and Technology}
81 | }
82 | ```
83 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gamut-mapping/GamutNet/68fe72858b7963f82b7b324edffb5a534394f46c/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/data_modules.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from argparse import ArgumentParser
3 | from pathlib import Path
4 | from abc import ABC
5 |
6 | from pytorch_lightning import LightningDataModule
7 | from torch.utils.data import DataLoader, Subset, get_worker_info
8 | from torchvision import transforms
9 | from typing import Optional
10 |
11 | from .datasets import Images, PairedDataset
12 | from .frontend_dataset import FrontendDataset
13 | from .iterable_patches_dataset import IterablePatchesDataset
14 |
15 | def _is_valid_file(filename):
16 | return any([filename.endswith(ext) for ext in ['.png', '.PNG', '.tif', '.TIF']])
17 |
18 |
19 | def _read_split(split_path, filename):
20 | return sorted(list(filter(len, Path(split_path, filename).read_text().split(sep='\n'))))
21 |
22 | class BaseGamutNetDataModule(LightningDataModule, ABC):
23 |
24 | def __init__(self, hparams):
25 |
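# note: 'hparams' is the argparse Namespace parsed in trainer_main.py;
# save_hyperparameters() below stores it as self.hparams and writes it to
# hparams.yaml next to the checkpoints.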
super().__init__()
26 | self.save_hyperparameters(hparams)
27 | self.train_dataset = None
28 | self.val_dataset = None
29 | self.test_dataset = None
30 | self.transform = None
31 |
32 | def _create_dataloader(self, dataset, num_workers, pin_memory=False):
33 | return DataLoader(dataset,
34 | batch_size=self.hparams.batch_size,
35 | num_workers=num_workers,
36 | worker_init_fn=self.worker_init_fn,
37 | pin_memory=pin_memory)
38 |
39 | def train_dataloader(self):
40 | return self._create_dataloader(self._create_patches(self.train_dataset),
41 | num_workers=self.hparams.train_num_workers, pin_memory=True)
42 |
43 | def val_dataloader(self):
44 | return self._create_dataloader(self._create_patches(self.val_dataset),
45 | num_workers=self.hparams.val_num_workers)
46 |
47 | def test_dataloader(self):
48 | return self._create_dataloader(self._create_patches(self.test_dataset),
49 | num_workers=self.hparams.test_num_workers)
50 |
51 | @staticmethod
52 | def worker_init_fn(_):
53 | worker_info = get_worker_info()
54 | worker_id = worker_info.id
55 | patch_dataset = worker_info.dataset
56 | frontend_dataset = patch_dataset.dataset
57 | num_samples = len(frontend_dataset)
58 | num_workers = worker_info.num_workers
59 | assert num_samples >= num_workers, '\'num_samples\' must be greater than or equal to \'num_workers\''
60 | split_size = num_samples // num_workers
61 | # using Subset instead of slicing
62 | indices = list(range(worker_id * split_size, (worker_id + 1) * split_size))
63 | patch_dataset.dataset = Subset(frontend_dataset, indices)
64 |
65 |
66 | class WideGamutNetDataModule(BaseGamutNetDataModule):
67 |
68 | def _create_patches(self, dataset):
69 | return IterablePatchesDataset(dataset=dataset,
70 | hint_mode=self.hparams.hint_mode,
71 | patch_size=self.hparams.patch_size,
72 | max_patches_per_image=self.hparams.max_patches_per_image,
73 | transform=self.transform, )
74 |
75 | # download, split, etc...
76 | # only called on 1 GPU/TPU in distributed 77 | def setup(self, stage: Optional[str] = None): 78 | split_path = self.hparams.split_path 79 | 80 | if 'fit' == stage or stage is None: 81 | train_input = _read_split(split_path, self.hparams.train_input) 82 | train_target = _read_split(split_path, self.hparams.train_target) 83 | train_paired_images = PairedDataset(Images(train_input), Images(train_target)) 84 | self.train_dataset = FrontendDataset(train_paired_images) 85 | 86 | val_input = _read_split(split_path, self.hparams.val_input) 87 | val_target = _read_split(split_path, self.hparams.val_target) 88 | val_paired_images = PairedDataset(Images(val_input), Images(val_target)) 89 | self.val_dataset = FrontendDataset(val_paired_images) 90 | 91 | if 'test' == stage or stage is None: 92 | test_input = _read_split(split_path, self.hparams.test_input) 93 | test_target = _read_split(split_path, self.hparams.test_target) 94 | test_paired_images = PairedDataset(Images(test_input), Images(test_target)) 95 | self.test_dataset = FrontendDataset(test_paired_images) 96 | 97 | self.transform = transforms.Compose([transforms.ToTensor()]) 98 | 99 | @staticmethod 100 | def add_datamodule_specific_args(parent_parser): 101 | model_parser = ArgumentParser(parents=[parent_parser], add_help=False) 102 | model_parser.add_argument('--split_path', type=str, required=True) 103 | model_parser.add_argument('--train_input', type=str, default='train-input.txt') 104 | model_parser.add_argument('--train_target', type=str, default='train-target.txt') 105 | model_parser.add_argument('--val_input', type=str, default='val-input.txt') 106 | model_parser.add_argument('--val_target', type=str, default='val-target.txt') 107 | model_parser.add_argument('--test_input', type=str, default='test-input.txt') 108 | model_parser.add_argument('--test_target', type=str, default='test-target.txt') 109 | 110 | model_parser.add_argument('--hint_mode', choices=['none', 'o2o_all', 'o2o_rgb', ], default='o2o_all') 111 | model_parser.add_argument('--patch_size', type=int, nargs=2, default=(32, 32)) 112 | model_parser.add_argument('--max_patches_per_image', type=int, default=32000) 113 | 114 | model_parser.add_argument('--batch_size', type=int, default=32) 115 | 116 | model_parser.add_argument('--train_num_workers', type=int, default=2) 117 | model_parser.add_argument('--val_num_workers', type=int, default=1) 118 | model_parser.add_argument('--test_num_workers', type=int, default=1) 119 | 120 | return model_parser -------------------------------------------------------------------------------- /datasets/datasets.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import numpy as np 3 | from os import scandir 4 | from imageio import imread 5 | from torch.utils.data import Dataset, get_worker_info 6 | from pathlib import Path 7 | import time 8 | from PIL import Image, ImageFile 9 | ImageFile.LOAD_TRUNCATED_IMAGES = True 10 | 11 | import cv2 12 | 13 | def read_image(path: Union[str, Path]) -> np.ndarray: 14 | img = cv2.imread(str(path), cv2.IMREAD_UNCHANGED) 15 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 16 | return img 17 | 18 | def _get_filenames(path, is_valid_file): 19 | filenames = [] 20 | for entry in scandir(path): 21 | if not entry.name.startswith('.') and entry.is_file(): 22 | if is_valid_file is None: 23 | filenames.append(entry.name) 24 | else: 25 | if is_valid_file(entry.name): 26 | filenames.append(entry.name) 27 | return filenames 28 | 29 | 30 | class 
Filenames(Dataset):
31 |
32 | def __init__(self, root_path, filename_list=None, sort=True, is_valid_file=None):
33 | self.root_path = root_path
34 | if not isinstance(root_path, Path):
35 | self.root_path = Path(root_path)
36 | assert self.root_path.exists(), f'{self.root_path} does not exist'
37 |
38 | self.filename_list = filename_list
39 | if not isinstance(filename_list, list):
40 | self.filename_list = _get_filenames(self.root_path, is_valid_file)
41 | assert len(self.filename_list) > 0, f'{self.root_path} is empty'
42 |
43 | if sort:
44 | self.filename_list = sorted(self.filename_list)
45 |
46 | def __len__(self):
47 | return len(self.filename_list)
48 |
49 | def __getitem__(self, indices):
50 | if isinstance(indices, int):
51 | # return a filename at a specific index
52 | filename = self.root_path / self.filename_list[indices]
53 | return filename
54 | else:
55 | raise TypeError(f'{type(self)} indices must be integers, not {type(indices)}')
56 |
57 |
58 | class Images(Dataset):
59 |
60 | def __init__(self, filenames, loader=read_image):
61 | assert len(filenames) > 0
62 | self.filenames = filenames
63 | assert callable(loader)
64 | self.loader = loader
65 |
66 | def __len__(self):
67 | return len(self.filenames)
68 |
69 | def __getitem__(self, indices):
70 | if isinstance(indices, int):
71 | # return an image at a specific index
72 | filename = self.filenames[indices]
73 | filename = filename if isinstance(filename, Path) else Path(filename)
74 | assert filename.is_file(), f"{filename} must be an existing file."
75 | # To debug multi-process loading, the load below is timed and
76 | # the worker info is fetched; uncomment the print statement
77 | # that follows to report per-worker loading times:
78 | started_at = time.time()
79 | try:
80 | image = self.loader(filename)
81 | except (ValueError, OSError):
82 | print(f"CANNOT OPEN THIS IMAGE: {filename}")
83 | raise  # re-raise so a failed load does not return an unbound 'image'
84 | worker_info = get_worker_info()
85 | # print(f'\nWorker {-1 if worker_info is None else worker_info.id}'
86 | #       f' loaded {filename} in {time.time() - started_at:.2f}s (ind={indices}, len={len(self)}).')
87 | return image
88 | else:
89 | raise TypeError(f'{type(self)} indices must be integers, not {type(indices)}')
90 |
91 |
92 | class PairedDataset(Dataset):
93 | """
94 | See: https://discuss.pytorch.org/t/train-simultaneously-on-two-datasets/649/2
95 | """
96 |
97 | def __init__(self, *datasets):
98 | self.datasets = datasets
99 |
100 | def __len__(self):
101 | return min(len(d) for d in self.datasets)
102 |
103 | def __getitem__(self, indices):
104 | if isinstance(indices, int):
105 | # return a paired data at a specific index
106 | paired_data = tuple(dataset[indices] for dataset in self.datasets)
107 | return paired_data
108 | else:
109 | raise TypeError(f'{type(self)} indices must be integers, not {type(indices)}')
110 |
--------------------------------------------------------------------------------
/datasets/frontend_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 |
4 | from utils.color import to_single, decode_srgb, srgb_to_prop_cat02, decode_prop
5 |
6 |
7 | class FrontendDataset(Dataset):
8 | def __init__(self, paired_images):
9 | assert isinstance(paired_images, Dataset)
10 | self.paired_images = paired_images
11 |
12 | def __len__(self):
13 | return len(self.paired_images)
14 |
15 | def __getitem__(self, indices):
16 | if isinstance(indices, int):
17 | # grab a pair of input and target images
18 | input_img, target_img = self.paired_images[indices]
19 |
20 | # make
the input and target image prepared 21 | prep_input_img = srgb_to_prop_cat02(decode_srgb(to_single(input_img))) 22 | prep_target_img = decode_prop(to_single(target_img)) 23 | 24 | # return the prepared input and target images along with the original input image 25 | # the input_img will be used for further processing 26 | return prep_input_img, prep_target_img, input_img 27 | else: 28 | raise TypeError(f'{type(self)} indices must be integers, not {type(indices)}') -------------------------------------------------------------------------------- /datasets/iterable_patches_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, IterableDataset, get_worker_info 2 | from .patch_utils import generate_patch 3 | import numpy as np 4 | import itertools 5 | from utils.mask import compute_masks 6 | 7 | 8 | class IterablePatchesDataset(IterableDataset): 9 | 10 | def __init__(self, dataset, hint_mode, patch_size=(128, 128), max_patches_per_image=1000, transform=None): 11 | assert isinstance(dataset, Dataset) 12 | self.dataset = dataset 13 | 14 | self.hint_mode = hint_mode 15 | 16 | assert patch_size[0] > 0 and patch_size[1] > 0 17 | self.patch_size = patch_size 18 | 19 | assert max_patches_per_image > 0 20 | self.max_patches_per_image = max_patches_per_image 21 | 22 | self.transform = transform 23 | 24 | worker_info = get_worker_info() 25 | if worker_info is not None: 26 | self.rng = np.random.default_rng(worker_info.seed) 27 | else: 28 | self.rng = np.random.default_rng(2021) 29 | 30 | def __iter__(self): 31 | return itertools.chain.from_iterable(map(self.process_data, self.dataset)) 32 | 33 | def process_data(self, data): 34 | prep_input_img, prep_target_img, input_img = data 35 | 36 | o2o_mask, m2o_mask, m_inner = compute_masks(input_img) 37 | 38 | # choose a hint 39 | hints = {'none': None, 'o2o_all': o2o_mask, 'o2o_rgb': m_inner} # all the hints here 40 | hint = hints[self.hint_mode] # choose a particular hint using hint_mode 41 | 42 | if hint is not None: # append hint to the input 43 | prep_hint_img = hint.astype(np.float32) # type-matching 44 | prep_input_img = np.dstack((prep_input_img, prep_hint_img)) 45 | 46 | patch_generator = generate_patch(m2o_mask, self.patch_size, self.rng) # generate patch using padded mask 47 | sliced_patch_generator = itertools.islice(patch_generator, self.max_patches_per_image) # length-limited 48 | 49 | for patch in sliced_patch_generator: 50 | _, (top, bottom, left, right) = patch 51 | input_patch = prep_input_img[top:bottom, left:right, :] 52 | target_patch = prep_target_img[top:bottom, left:right, :] 53 | # repeat mask along channel-axis to compute loss conveniently 54 | m2o_mask_patch = m2o_mask[top:bottom, left:right, None].repeat(3, axis=2) 55 | 56 | if self.transform: # apply transforms such as ToTensor() 57 | input_patch = self.transform(input_patch) 58 | target_patch = self.transform(target_patch) 59 | m2o_mask_patch = self.transform(m2o_mask_patch) 60 | 61 | yield input_patch, target_patch, m2o_mask_patch # yield a sample in a batch -------------------------------------------------------------------------------- /datasets/patch_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | 5 | def generate_patch(mask, size, random_generator=None): 6 | # patch dimensions 7 | rows, cols = size 8 | half_rows = int(rows / 2) 9 | half_cols = int(cols / 2) 10 | 11 | # crop the mask according to the 
given patch dimensions 12 | cropped_mask = mask[half_rows:-half_rows, half_cols:-half_cols] 13 | 14 | # find pixels where the mask value is True 15 | nonzero_indices = np.flatnonzero(cropped_mask) 16 | if random_generator is not None: 17 | assert isinstance(random_generator, np.random.Generator) 18 | # if a random number generator is provided, it shuffles the indices 19 | random_generator.shuffle(nonzero_indices) 20 | nonzero_multi_indices = np.unravel_index(nonzero_indices, cropped_mask.shape) 21 | 22 | # generate patches 23 | for i, j in zip(*nonzero_multi_indices): 24 | # The first 2-tuple is the center point (row_center, col_center); and 25 | # The second 4-tuple is the bounding box (top, bottom, left, right). 26 | # The center point is for indexing the groundtruth color. 27 | # e.g. y[row_center, col_center, :] 28 | # The bounding box is for indexing the input patch. 29 | # e.g. x[top:bottom, left:right, :] 30 | # The below line yields ((row_center, col_center), (top, bottom, left, right)). 31 | yield (i + half_rows, j + half_cols), (i, i + rows, j, j + cols) 32 | 33 | 34 | def get_pad_width(patch_size): 35 | # pad_width := ((before_axis_0, after_axis_0), (before_axis_1, after_axis_1), (before_axis_2, after_axis_2)) 36 | half_rows, half_cols = tuple(int(s / 2) for s in patch_size) 37 | pad_width = ((half_rows, half_rows), (half_cols, half_cols), (0, 0)) 38 | return pad_width 39 | -------------------------------------------------------------------------------- /figures/overview_gamutnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gamut-mapping/GamutNet/68fe72858b7963f82b7b324edffb5a534394f46c/figures/overview_gamutnet.png -------------------------------------------------------------------------------- /gen_split_dataset.py: -------------------------------------------------------------------------------- 1 | import time 2 | from argparse import ArgumentParser 3 | from pathlib import Path 4 | import random 5 | 6 | random.seed(2021) 7 | 8 | def main(args): 9 | split_output = Path('./split_output') 10 | split_output.mkdir(parents=True, exist_ok=True) 11 | 12 | input_dir = Path(args.input) 13 | target_dir = Path(args.target) 14 | img_names = list(target_dir.glob('*.png')) 15 | 16 | list_idx = list(range(len(img_names))) 17 | random.shuffle(list_idx) 18 | train_length = round(0.95*(len(img_names))) 19 | 20 | write_txt(img_names, args.input,list_idx[:train_length], split_output / "train-input.txt") 21 | write_txt(img_names, args.input,list_idx[train_length:], split_output / "val-input.txt") 22 | write_txt(img_names, args.target,list_idx[:train_length], split_output / "train-target.txt") 23 | write_txt(img_names, args.target,list_idx[train_length:], split_output / "val-target.txt") 24 | 25 | def write_txt(img_names, img_dir, list_idx, filename): 26 | with open(filename, 'w') as file: 27 | for idx in list_idx: 28 | file.write(img_dir + "/" + img_names[idx].name + "\n") 29 | 30 | if __name__ == '__main__': 31 | start_time = time.time() 32 | 33 | parser = ArgumentParser() 34 | parser.add_argument("-i", "--input", type=str, 35 | help="input sRGB dir") 36 | parser.add_argument("-t", "--target", type=str, 37 | help="target ProPhoto dir") 38 | main(parser.parse_args()) # parse args and start training 39 | 40 | end_time = time.time() 41 | duration = end_time - start_time 42 | duration = round(duration/3600, 2) 43 | print(f'---- FINISHED in {duration} hours ----') 
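# A minimal usage sketch (the data paths are illustrative):
#   python gen_split_dataset.py --input /data/srgb-8bpc --target /data/prop-8bpc
# This writes a random 95%/5% train/val split into ./split_output
# (train-input.txt, val-input.txt, train-target.txt, val-target.txt),
# which can then be passed to trainer_main.py via --split_path ./split_output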
-------------------------------------------------------------------------------- /icc_profiles/ProPhoto.icm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gamut-mapping/GamutNet/68fe72858b7963f82b7b324edffb5a534394f46c/icc_profiles/ProPhoto.icm -------------------------------------------------------------------------------- /icc_profiles/__init__.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageCms 2 | from pathlib import Path 3 | 4 | path = (Path(__file__).parent / 'ProPhoto.icm') 5 | ICC_PROPHOTO_RGB_PROFILE_BYTES = ImageCms.getOpenProfile(str(path)).tobytes() 6 | -------------------------------------------------------------------------------- /inference/__init__.py: -------------------------------------------------------------------------------- 1 | from .inference_agent import NoPWInferenceAgent 2 | -------------------------------------------------------------------------------- /inference/inference_agent.py: -------------------------------------------------------------------------------- 1 | import time 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import ImageCms 7 | from imageio import imread, imsave 8 | from torchvision.transforms.functional import to_tensor 9 | 10 | from models import WideGamutNetPL 11 | from utils.color import to_single, to_uint8, decode_srgb, srgb_to_prop_cat02, encode_prop 12 | from utils.mask import compute_masks 13 | 14 | # ICC_PROFILE_BYTES = ImageCms.getOpenProfile('icc_profiles/ProPhoto.icm').tobytes() 15 | 16 | 17 | class NoPWInferenceAgent: 18 | 19 | def __init__(self, 20 | version_path, 21 | ckpt_filename, 22 | ckpt_dirname='checkpoints', 23 | hparams_filename='hparams.yaml', 24 | device='cpu'): 25 | 26 | self.version_path = Path(version_path) 27 | assert self.version_path.is_dir(), 'version_path must be an existing directory' 28 | 29 | self.checkpoint_path = self.version_path / ckpt_dirname / ckpt_filename 30 | assert self.checkpoint_path.is_file(), 'checkpoint_path must be an existing file' 31 | 32 | self.hparams_file = self.version_path / hparams_filename 33 | assert self.hparams_file.is_file(), 'hparams_file must be an existing file' 34 | 35 | if torch.cuda.is_available() and device != 'cpu': 36 | map_location = lambda storage, loc: storage.cuda() 37 | else: 38 | map_location = 'cpu' 39 | 40 | self.model = WideGamutNetPL.load_from_checkpoint(checkpoint_path=str(self.checkpoint_path), 41 | hparams_file=str(self.hparams_file), 42 | map_location=map_location) 43 | self.device = device 44 | self.model.to(self.device) 45 | self.model.eval() 46 | 47 | def single_image_inference(self, img_path, output_img_path=None, icc_profile=None): 48 | started_at = time.time() 49 | input_img = imread(img_path)[:, :, :3] # drop alpha channel if it is provided 50 | print(f'[Reading Input Image] {time.time() - started_at:.2f} seconds') 51 | 52 | started_at = time.time() 53 | o2o_mask, _, m_inner = compute_masks(input_img) 54 | print(f'[Computing Masks] {time.time() - started_at:.2f} seconds') 55 | 56 | # make the input image prepared 57 | started_at = time.time() 58 | prep_input_img = srgb_to_prop_cat02(decode_srgb(to_single(input_img))) 59 | 60 | # choose a hint 61 | hints = {'none': None, 'o2o_all': o2o_mask, 'o2o_rgb': m_inner} # all the hints here 62 | hint = hints[self.hint_mode] # choose a particular hint using hint_mode 63 | 64 | if hint is not None: 65 | prep_hint_img = hint.astype(np.float32) # 
type-matching 66 | network_input = np.dstack((prep_input_img, prep_hint_img)) 67 | else: 68 | network_input = prep_input_img 69 | 70 | print(f'[Preparing Input Image] {time.time() - started_at:.2f} seconds') 71 | 72 | started_at = time.time() 73 | network_output = self.infer(network_input) 74 | print(f'[Inferring Output Image] {time.time() - started_at:.2f} seconds') 75 | 76 | started_at = time.time() 77 | network_output[o2o_mask] = prep_input_img[o2o_mask] # use 'prep_input_img' at 'one-to-one' pixels 78 | output = encode_prop(network_output) # encode ProPhoto RGB and clip from 0 to 1 79 | print(f'[Preparing Output Image] {time.time() - started_at:.2f} seconds') 80 | 81 | if output_img_path: 82 | started_at = time.time() 83 | imsave(output_img_path, to_uint8(output), optimize=False, compression=0, icc_profile=icc_profile) 84 | print(f'[Saving Output Image] {time.time() - started_at:.2f} seconds') 85 | 86 | return output 87 | 88 | def infer(self, network_input): 89 | network_input = to_tensor(network_input).to(self.device) # to Tensor 90 | network_input = torch.unsqueeze(network_input, 0) # add batch dimension 91 | with torch.no_grad(): 92 | network_output = self.model(network_input.float()).detach() 93 | network_output = network_output.squeeze() # remove batch dimension 94 | return network_output.transpose(0, 1).transpose(1, 2).cpu().numpy() # to NumPy 95 | 96 | def gamut_expansion(self, srgb_img): 97 | o2o_mask, _, m_inner = compute_masks(srgb_img) 98 | # make the input image prepared 99 | prep_input_img = srgb_to_prop_cat02(srgb_img) 100 | 101 | # choose a hint 102 | hints = {'none': None, 'o2o_all': o2o_mask, 'o2o_rgb': m_inner} # all the hints here 103 | hint = hints[self.hint_mode] # choose a particular hint using hint_mode 104 | 105 | if hint is not None: 106 | prep_hint_img = hint.astype(np.float32) # type-matching 107 | network_input = np.dstack((prep_input_img, prep_hint_img)) 108 | else: 109 | network_input = prep_input_img 110 | 111 | started_at = time.time() 112 | network_output = self.infer(network_input) 113 | 114 | started_at = time.time() 115 | network_output[o2o_mask] = prep_input_img[o2o_mask] # use 'prep_input_img' at 'one-to-one' pixels 116 | 117 | return network_output 118 | 119 | @property 120 | def hint_mode(self): 121 | return self.model.hparams.hint_mode 122 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .calc_mse import calc_mse 2 | from .calc_mae import calc_mae 3 | from .calc_deltaE import calc_deltaE 4 | from .calc_rmse import calc_rmse 5 | from .calc_psnr import calc_psnr 6 | -------------------------------------------------------------------------------- /metrics/calc_deltaE.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from colour import models 3 | import colour 4 | 5 | 6 | def calc_deltaE(source, target, color_space, method='CIE 2000', chromatic_adaptation_transform='CAT02'): 7 | # type: (numpy.ndarray, numpy.ndarray, str, str) -> float 8 | # method = 'CIE 2000' | 'CIE 1976' 9 | 10 | assert isinstance(color_space, str), "color_space should be string" 11 | COLORSPACE_DICT = {'ProPhotoRGB': models.RGB_COLOURSPACE_PROPHOTO_RGB, 12 | 'sRGB': models.RGB_COLOURSPACE_sRGB, 13 | 'AdobeRGB': models.RGB_COLOURSPACE_ADOBE_RGB1998, 14 | 'DisplayP3': models.RGB_COLOURSPACE_DISPLAY_P3, 15 | } 16 | assert COLORSPACE_DICT.get(color_space) != None, "color_space 
should be ProPhotoRGB, sRGB, AdobeRGB, DisplayP3" 17 | color_space = COLORSPACE_DICT.get(color_space) 18 | source = np.reshape(source, (-1,3)) 19 | target = np.reshape(target, (-1,3)) 20 | source_XYZ = models.RGB_to_XYZ( 21 | source, 22 | color_space.whitepoint, 23 | color_space.whitepoint, 24 | color_space.matrix_RGB_to_XYZ, 25 | chromatic_adaptation_transform=chromatic_adaptation_transform 26 | ) 27 | target_XYZ = models.RGB_to_XYZ( 28 | target, 29 | color_space.whitepoint, 30 | color_space.whitepoint, 31 | color_space.matrix_RGB_to_XYZ, 32 | chromatic_adaptation_transform=chromatic_adaptation_transform 33 | ) 34 | source_Lab = models.XYZ_to_Lab( 35 | source_XYZ, 36 | color_space.whitepoint, 37 | ) 38 | target_Lab = models.XYZ_to_Lab( 39 | target_XYZ, 40 | color_space.whitepoint, 41 | ) 42 | deltaE = colour.delta_E(source_Lab, target_Lab, method=method) 43 | # if source.shape[0] == 1: 44 | # return deltaE 45 | # deltaE = sum(deltaE)/deltaE.shape[0] 46 | return np.mean(deltaE) 47 | -------------------------------------------------------------------------------- /metrics/calc_mae.py: -------------------------------------------------------------------------------- 1 | from utils.metric import mae 2 | 3 | 4 | def calc_mae(source, target): 5 | return mae(source, target) 6 | -------------------------------------------------------------------------------- /metrics/calc_mse.py: -------------------------------------------------------------------------------- 1 | from utils.metric import mse 2 | 3 | 4 | def calc_mse(source, target): 5 | return mse(source, target) 6 | -------------------------------------------------------------------------------- /metrics/calc_psnr.py: -------------------------------------------------------------------------------- 1 | from utils.metric import psnr 2 | 3 | 4 | def calc_psnr(source, target): 5 | return psnr(source, target) 6 | -------------------------------------------------------------------------------- /metrics/calc_rmse.py: -------------------------------------------------------------------------------- 1 | from utils.metric import rmse 2 | 3 | 4 | def calc_rmse(source, target): 5 | return rmse(source, target) 6 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .wide_gamut_net_pl import WideGamutNetPL -------------------------------------------------------------------------------- /models/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gamut-mapping/GamutNet/68fe72858b7963f82b7b324edffb5a534394f46c/models/networks/__init__.py -------------------------------------------------------------------------------- /models/networks/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gamut-mapping/GamutNet/68fe72858b7963f82b7b324edffb5a534394f46c/models/networks/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/networks/__pycache__/deep_wb_blocks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gamut-mapping/GamutNet/68fe72858b7963f82b7b324edffb5a534394f46c/models/networks/__pycache__/deep_wb_blocks.cpython-38.pyc 
--------------------------------------------------------------------------------
/models/networks/__pycache__/wide_gamut_blocks.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gamut-mapping/GamutNet/68fe72858b7963f82b7b324edffb5a534394f46c/models/networks/__pycache__/wide_gamut_blocks.cpython-38.pyc
--------------------------------------------------------------------------------
/models/networks/__pycache__/wide_gamut_net.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gamut-mapping/GamutNet/68fe72858b7963f82b7b324edffb5a534394f46c/models/networks/__pycache__/wide_gamut_net.cpython-38.pyc
--------------------------------------------------------------------------------
/models/networks/deep_wb_blocks.py:
--------------------------------------------------------------------------------
1 | """
2 | Main blocks of the network
3 | Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
4 | If you use this code, please cite the following paper:
5 | Mahmoud Afifi and Michael S Brown. Deep White-Balance Editing. In CVPR, 2020.
6 | """
7 | __author__ = "Mahmoud Afifi"
8 | __credits__ = ["Mahmoud Afifi"]
9 |
10 | import torch
11 | import torch.nn as nn
12 |
13 |
14 | class DoubleConvBlock(nn.Module):
15 | """double conv layers block"""
16 | def __init__(self, in_channels, out_channels):
17 | super().__init__()
18 | self.double_conv = nn.Sequential(
19 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
20 | nn.ReLU(inplace=True),
21 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
22 | nn.ReLU(inplace=True)
23 | )
24 |
25 | def forward(self, x):
26 | return self.double_conv(x)
27 |
28 |
29 | class DownBlock(nn.Module):
30 | """Downscale block: maxpool -> double conv block"""
31 | def __init__(self, in_channels, out_channels):
32 | super().__init__()
33 | self.maxpool_conv = nn.Sequential(
34 | nn.MaxPool2d(2),
35 | DoubleConvBlock(in_channels, out_channels)
36 | )
37 |
38 | def forward(self, x):
39 | return self.maxpool_conv(x)
40 |
41 |
42 | class BridgeDown(nn.Module):
43 | """Downscale bottleneck block: maxpool -> conv"""
44 | def __init__(self, in_channels, out_channels):
45 | super().__init__()
46 | self.maxpool_conv = nn.Sequential(
47 | nn.MaxPool2d(2),
48 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
49 | nn.ReLU(inplace=True)
50 | )
51 |
52 | def forward(self, x):
53 | return self.maxpool_conv(x)
54 |
55 |
56 | class BridgeUP(nn.Module):
57 | """Upscale bottleneck block: conv -> transpose conv"""
58 | def __init__(self, in_channels, out_channels):
59 | super().__init__()
60 | self.conv_up = nn.Sequential(
61 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
62 | nn.ReLU(inplace=True),
63 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
64 | )
65 |
66 | def forward(self, x):
67 | return self.conv_up(x)
68 |
69 |
70 |
71 | class UpBlock(nn.Module):
72 | """Upscale block: double conv block -> transpose conv"""
73 | def __init__(self, in_channels, out_channels):
74 | super().__init__()
75 | self.conv = DoubleConvBlock(in_channels * 2, in_channels)
76 | self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
77 |
78 |
79 |
80 | def forward(self, x1, x2):
81 | x = torch.cat([x2, x1], dim=1)
82 | x = self.conv(x)
83 | return torch.relu(self.up(x))
84 |
85 |
86 | class OutputBlock(nn.Module):
87 | """Output
block: double conv block -> output conv"""
88 | def __init__(self, in_channels, out_channels):
89 | super().__init__()
90 | self.out_conv = nn.Sequential(
91 | DoubleConvBlock(in_channels * 2, in_channels),
92 | nn.Conv2d(in_channels, out_channels, kernel_size=1))
93 |
94 | def forward(self, x1, x2):
95 | x = torch.cat([x2, x1], dim=1)
96 | return self.out_conv(x)
97 |
98 |
99 |
--------------------------------------------------------------------------------
/models/networks/wide_gamut_blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .deep_wb_blocks import DoubleConvBlock
5 |
6 |
7 | def safe_depth_cat(x1, x2):
8 | """Safely concatenate given tensors along depth (i.e. channel) axis"""
9 | # https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py
10 | # input is (Batch, Channel, Height, Width)
11 | dy = x2.size()[2] - x1.size()[2]
12 | dx = x2.size()[3] - x1.size()[3]
13 | x1 = F.pad(x1, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
14 | return torch.cat([x2, x1], dim=1)
15 |
16 |
17 | class SafeUpBlock(nn.Module):
18 | """Upscale block: double conv block -> transpose conv"""
19 |
20 | def __init__(self, in_channels, out_channels):
21 | super().__init__()
22 | self.conv = DoubleConvBlock(in_channels * 2, in_channels)
23 | self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
24 |
25 | def forward(self, x1, x2):
26 | return torch.relu(self.up(self.conv(safe_depth_cat(x1, x2))))
27 |
28 |
29 | class SafeOutputBlock(nn.Module):
30 | """Output block: double conv block -> output conv"""
31 |
32 | def __init__(self, in_channels, out_channels):
33 | super().__init__()
34 | self.out_conv = nn.Sequential(
35 | DoubleConvBlock(in_channels * 2, in_channels),
36 | nn.Conv2d(in_channels, out_channels, kernel_size=1))
37 |
38 | def forward(self, x1, x2):
39 | return self.out_conv(safe_depth_cat(x1, x2))
40 |
--------------------------------------------------------------------------------
/models/networks/wide_gamut_net.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from .deep_wb_blocks import (DoubleConvBlock, DownBlock, BridgeDown, BridgeUP)
7 | from .wide_gamut_blocks import SafeUpBlock, SafeOutputBlock  # Concatenation-safe
8 |
9 |
10 | class BaseWideGamutNet(nn.Module, ABC):
11 | def __init__(self, n_in_channels, using_residual=True, limiting_output_range=False):
12 | super(BaseWideGamutNet, self).__init__()
13 | self.n_in_channels = n_in_channels
14 | self.n_out_channels = 3  # (R, G, B)
15 | self.using_residual = using_residual
16 | self.limiting_output_range = limiting_output_range
17 |
18 | @abstractmethod
19 | def _forward(self, x):
20 | raise NotImplementedError
21 |
22 | def forward(self, x):
23 | x0 = x[:, :3, :, :]  # drop the mask channel
24 | if not self.using_residual:
25 | x0 = torch.zeros_like(x0)  # zeros with the same device and dtype as the input
26 | out = self._forward(x) + x0  # residual learning
27 | if self.limiting_output_range:
28 | out = torch.sigmoid(out)
29 | return out
30 |
31 |
32 | class WideGamutNet(BaseWideGamutNet):
33 | def __init__(self, n_in_channels, using_residual=True, limiting_output_range=False):
34 | super(WideGamutNet, self).__init__(n_in_channels, using_residual, limiting_output_range)
35 | # Contracting
36 | self.encoder_inc = DoubleConvBlock(self.n_in_channels, 24)
37 | self.encoder_down1 = DownBlock(24, 48)
38 | self.encoder_down2 = DownBlock(48, 96) 39 | self.encoder_down3 = DownBlock(96, 192) 40 | self.encoder_bridge_down = BridgeDown(192, 384) 41 | # Expanding 42 | self.decoder_bridge_up = BridgeUP(384, 192) 43 | self.decoder_up1 = SafeUpBlock(192, 96) # Concatenation-safe 44 | self.decoder_up2 = SafeUpBlock(96, 48) # Concatenation-safe 45 | self.decoder_up3 = SafeUpBlock(48, 24) # Concatenation-safe 46 | self.decoder_out = SafeOutputBlock(24, self.n_out_channels) # Concatenation-safe 47 | 48 | def _forward(self, x): 49 | # Contracting 50 | x1 = self.encoder_inc(x) 51 | x2 = self.encoder_down1(x1) 52 | x3 = self.encoder_down2(x2) 53 | x4 = self.encoder_down3(x3) 54 | x5 = self.encoder_bridge_down(x4) 55 | # Expanding 56 | x = self.decoder_bridge_up(x5) 57 | x = self.decoder_up1(x, x4) 58 | x = self.decoder_up2(x, x3) 59 | x = self.decoder_up3(x, x2) 60 | x = self.decoder_out(x, x1) 61 | return x 62 | 63 | 64 | class SmallWideGamutNet(BaseWideGamutNet): 65 | def __init__(self, n_in_channels, using_residual=True, limiting_output_range=False): 66 | super(SmallWideGamutNet, self).__init__(n_in_channels, using_residual, limiting_output_range) 67 | # Contracting 68 | self.encoder_inc = DoubleConvBlock(self.n_in_channels, 24) 69 | self.encoder_down1 = DownBlock(24, 48) 70 | self.encoder_down2 = DownBlock(48, 96) 71 | self.encoder_bridge_down = BridgeDown(96, 192) 72 | # Expanding 73 | self.decoder_bridge_up = BridgeUP(192, 96) 74 | self.decoder_up1 = SafeUpBlock(96, 48) # Concatenation-safe 75 | self.decoder_up2 = SafeUpBlock(48, 24) # Concatenation-safe 76 | self.decoder_out = SafeOutputBlock(24, self.n_out_channels) # Concatenation-safe 77 | 78 | def _forward(self, x): 79 | # Contracting 80 | x1 = self.encoder_inc(x) 81 | x2 = self.encoder_down1(x1) 82 | x3 = self.encoder_down2(x2) 83 | x4 = self.encoder_bridge_down(x3) 84 | # Expanding 85 | x = self.decoder_bridge_up(x4) 86 | x = self.decoder_up1(x, x3) 87 | x = self.decoder_up2(x, x2) 88 | x = self.decoder_out(x, x1) 89 | return x 90 | 91 | 92 | class TinyWideGamutNet(BaseWideGamutNet): 93 | def __init__(self, n_in_channels, using_residual=True, limiting_output_range=False): 94 | super(TinyWideGamutNet, self).__init__(n_in_channels, using_residual, limiting_output_range) 95 | # Contracting 96 | self.encoder_inc = DoubleConvBlock(self.n_in_channels, 24) 97 | self.encoder_down1 = DownBlock(24, 48) 98 | self.encoder_bridge_down = BridgeDown(48, 96) 99 | # Expanding 100 | self.decoder_bridge_up = BridgeUP(96, 48) 101 | self.decoder_up1 = SafeUpBlock(48, 24) # Concatenation-safe 102 | self.decoder_out = SafeOutputBlock(24, self.n_out_channels) # Concatenation-safe 103 | 104 | def _forward(self, x): 105 | # Contracting 106 | x1 = self.encoder_inc(x) 107 | x2 = self.encoder_down1(x1) 108 | x3 = self.encoder_bridge_down(x2) 109 | # Expanding 110 | x = self.decoder_bridge_up(x3) 111 | x = self.decoder_up1(x, x2) 112 | x = self.decoder_out(x, x1) 113 | return x 114 | -------------------------------------------------------------------------------- /models/wide_gamut_net_pl.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import torch 3 | from torch.nn import functional as F 4 | from torch.optim import Adam 5 | from pytorch_lightning.core.lightning import LightningModule 6 | from models.networks.wide_gamut_net import WideGamutNet, SmallWideGamutNet, TinyWideGamutNet 7 | 8 | 9 | class WideGamutNetPL(LightningModule): 10 | 11 | def __init__(self, hparams, *args, **kwargs): 12 | 
super().__init__()
13 | self.save_hyperparameters(hparams)
14 | if self.hparams.model_size == 'default':  # RF=82
15 | self.net = WideGamutNet(self.hparams.input_channels,
16 | using_residual=self.hparams.using_residual,
17 | limiting_output_range=self.hparams.limiting_output_range)
18 | elif self.hparams.model_size == 'small':  # RF=68
19 | self.net = SmallWideGamutNet(self.hparams.input_channels,
20 | using_residual=self.hparams.using_residual,
21 | limiting_output_range=self.hparams.limiting_output_range)
22 | elif self.hparams.model_size == 'tiny':  # RF=32
23 | self.net = TinyWideGamutNet(self.hparams.input_channels,
24 | using_residual=self.hparams.using_residual,
25 | limiting_output_range=self.hparams.limiting_output_range)
26 | else:
27 | raise ValueError(f"Unknown model_size: {self.hparams.model_size!r}")
28 |
29 | # the input array has shape (batch_size, input_channels, patch_height, patch_width)
30 | self.example_input_array = torch.rand(self.hparams.batch_size,
31 | self.hparams.input_channels,
32 | *self.hparams.patch_size,
33 | device=self.device)
34 |
35 | def _loss(self, batch):
36 | input_patches, target_patches, m2o_mask_patches = batch
37 | predicted_patches = self(input_patches.float())  # get predictions
38 | target_patches = target_patches.float()  # get targets
39 | y_hat = predicted_patches[m2o_mask_patches]  # consider only many-to-one pixels
40 | y = target_patches[m2o_mask_patches]  # consider only many-to-one pixels
41 | return F.l1_loss(y_hat, y)
42 |
43 | def forward(self, x):
44 | return self.net(x)
45 |
46 | def training_step(self, batch, batch_idx):
47 | loss = self._loss(batch)
48 | self.log('train_loss', loss)
49 | return loss
50 |
51 | def validation_step(self, batch, batch_idx):
52 | loss = self._loss(batch)
53 | self.log('val_loss', loss)
54 |
55 | def test_step(self, batch, batch_idx):
56 | loss = self._loss(batch)
57 | self.log('test_loss', loss)
58 |
59 | def configure_optimizers(self):
60 | return Adam(self.parameters(), lr=self.hparams.learning_rate)
61 |
62 | @staticmethod
63 | def add_model_specific_args(parent_parser):
64 | model_parser = parent_parser.add_argument_group("WideGamutNetPL", "Arguments for WideGamutNetPL")
65 |
66 | # for model configuration
67 | model_parser.add_argument('--model_size', type=str, default='default', choices=['default', 'small', 'tiny'])
68 | model_parser.add_argument('--input_channels', type=int, default=4, choices=[3, 4, 6])
69 |
70 | # for optimizer configuration
71 | model_parser.add_argument('--learning_rate', type=float, default=0.0001)
72 |
73 | # caution: argparse's type=bool treats any non-empty string (including "False") as True
74 | model_parser.add_argument('--using_residual', type=bool, default=True)
75 | model_parser.add_argument('--limiting_output_range', type=bool, default=True)
76 |
77 | return parent_parser
78 |
--------------------------------------------------------------------------------
/piecewise/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from .piecewise import (apply_mapping_piecewise,
3 | fit_mapping_piecewise,
4 | make_mask_piecewise)
5 | from .utils import (linear_kernel,
6 | load_pw_mapping)
7 |
8 | __author__ = "Taehong Jeong"
9 | __email__ = "enjoyjade43@ajou.ac.kr"
10 | __all__ = []
11 |
12 | # piecewise
13 | __all__ += [
14 | 'apply_mapping_piecewise',
15 | 'fit_mapping_piecewise',
16 | 'make_mask_piecewise'
17 | ]
18 |
19 | # utils
20 | __all__ += [
21 | 'linear_kernel',
22 | 'load_pw_mapping'
23 | ]
24 |
--------------------------------------------------------------------------------
/piecewise/piecewise.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 |
4 |
5 | def make_mask_piecewise(x, bins=None):
6 | # the input must have its channel axis
7 | assert (3 == len(x.shape))
8 |
9 | # we handle either 8bpc or 16bpc images
10 | assert x.dtype in (np.uint8, np.uint16)
11 |
12 | # set the default bins
13 | if bins is None:
14 | # either [0, 1); [1, 255); [255, infinity),
15 | # or [0, 1); [1, 65535); [65535, infinity).
16 | bins = [0, 1, np.iinfo(x.dtype).max]
17 |
18 | height, width, channels = x.shape
19 | dims = tuple([len(bins)] * channels)  # the bin-dimensions
20 | multi_indices = np.digitize(x, bins, right=False) - 1  # the multi-dimensional indices (zero-based)
21 | indices = np.ravel_multi_index(multi_indices.reshape(-1, channels).T, dims)  # the raveled indices
22 | pw_mask = indices.reshape(height, width).astype(int)
23 | return pw_mask
24 |
25 |
26 | def fit_mapping_piecewise(x_img, y_img, pw_mask, kernel):
27 | mappings = list()
28 | for mask_value in np.unique(pw_mask):
29 | mask = pw_mask == mask_value
30 | x, y = x_img[mask], y_img[mask]
31 | mapping = np.linalg.lstsq(kernel(x), y, rcond=None)[0]
32 | mappings.append(mapping)
33 | pw_mapping = np.transpose(np.dstack(mappings), (2, 0, 1))
34 | return pw_mapping
35 |
36 |
37 | def apply_mapping_piecewise(x_img, pw_mapping, pw_mask, kernel,
38 | skip=(0, 13, 26)):  # skip *in-gamut* cases of 27-pieces
39 | # copy the input as-is because some cases could be skipped.
40 | prediction = x_img.copy()
41 | for mask_value in np.unique(pw_mask):
42 | if (skip is not None) and (mask_value in skip):
43 | continue  # skip a given set of mask_values
44 | mapping = pw_mapping[mask_value]
45 | mask = pw_mask == mask_value
46 | x = x_img[mask]
47 | y_hat = kernel(x).dot(mapping)
48 | prediction[mask] = np.clip(y_hat, 0, 1)
49 | return prediction
50 |
--------------------------------------------------------------------------------
/piecewise/pw_mapping.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gamut-mapping/GamutNet/68fe72858b7963f82b7b324edffb5a534394f46c/piecewise/pw_mapping.npy
--------------------------------------------------------------------------------
/piecewise/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 | from os.path import realpath
4 | from pathlib import Path
5 |
6 |
7 | def linear_kernel(x):
8 | return np.c_[np.ones(len(x)), x]
9 |
10 |
11 | def load_pw_mapping():
12 | return np.load(Path(realpath(__file__)).parent / 'pw_mapping.npy')
13 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.19.4
2 | pandas==1.0.5
3 | Pillow==9.0.0
4 | colour-science>=0.3.16  # assumed version; the code's 'import colour' refers to the colour-science package, not PyPI's 'colour'
5 | imageio==2.13.5
6 | pytorch_lightning==1.5.8
7 | torch==1.10.1
8 | torchvision==0.11.2
9 | tqdm==4.62.3
10 | typing==3.10.0.0
11 |
--------------------------------------------------------------------------------
/run_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from argparse import ArgumentParser
4 | from pathlib import Path
5 |
6 | from inference import NoPWInferenceAgent
7 |
8 |
9 | def main(input_path, output_path, version_path, ckpt_filename):
10 | input_path = Path(input_path)
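# note: 'input_path' may be either a directory of images or a text file
# listing one input image path per line; both branches below build the same
# list of input files.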
/requirements.txt:
--------------------------------------------------------------------------------
numpy==1.19.4
pandas==1.0.5
Pillow==9.0.0
# the code imports the colour-science package as `colour`; the PyPI package
# `colour` (0.1.5) is a different library and does not provide the API used here.
colour-science==0.3.16
imageio==2.13.5
pytorch_lightning==1.5.8
torch==1.10.1
torchvision==0.11.2
tqdm==4.62.3
typing==3.10.0.0
--------------------------------------------------------------------------------
/run_inference.py:
--------------------------------------------------------------------------------
import os
import time
from argparse import ArgumentParser
from pathlib import Path

from inference import NoPWInferenceAgent


def main(input_path, output_path, version_path, ckpt_filename):
    input_path = Path(input_path)
    if input_path.is_dir():
        input_files = [Path(entry.path) for entry in os.scandir(input_path)]
    elif input_path.is_file():
        input_files = [Path(line) for line in input_path.read_text().splitlines()]
    else:
        raise RuntimeError('input_path must be an existing directory or a text file listing image paths.')
    assert all(input_file.is_file() for input_file in input_files), 'all the input files must exist.'
    num_input_files = len(input_files)

    output_path = Path(output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    inference_agent = NoPWInferenceAgent(version_path=version_path, ckpt_filename=ckpt_filename)
    for i, input_img_path in enumerate(input_files):
        print(f"Processing image {i + 1} / {num_input_files}.")
        output_img_path = output_path / input_img_path.name
        inference_agent.single_image_inference_using_cnn(input_img_path, output_img_path)


# for dev and debug
if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--version_path', type=str, required=True)
    parser.add_argument('--ckpt_filename', type=str, required=True)
    parser.add_argument('-i', '--input_path', type=str, required=True, help='either a directory or a text file')
    parser.add_argument('-o', '--output_path', type=str, required=True)
    args = parser.parse_args()

    started_at = time.time()
    main(**vars(args))
    print(f'---- FINISHED in {time.time() - started_at:.1f} seconds ----')
--------------------------------------------------------------------------------
/trainer_main.py:
--------------------------------------------------------------------------------
import time
from argparse import ArgumentParser

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from datasets.data_modules import WideGamutNetDataModule
from models import WideGamutNetPL


def main(args):
    # checkpointing
    checkpoint_callback = ModelCheckpoint(save_top_k=-1,  # always save all the checkpoints
                                          verbose=True)
    trainer = Trainer.from_argparse_args(args, callbacks=[checkpoint_callback])
    model = WideGamutNetPL(hparams=args)  # pick model
    datamodule = WideGamutNetDataModule(hparams=args)  # pick datamodule
    trainer.fit(model, datamodule=datamodule)
    if args.run_test:
        trainer.test(ckpt_path=None)  # use the latest weights (we save all checkpoints, not only the best one)


if __name__ == '__main__':
    start_time = time.time()

    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser = WideGamutNetPL.add_model_specific_args(parser)  # model specific arguments
    parser = WideGamutNetDataModule.add_datamodule_specific_args(parser)  # datamodule specific arguments
    parser.add_argument('--run_test', action='store_true', help='run test or not')
    main(parser.parse_args())  # parse args and start training

    duration = round((time.time() - start_time) / 3600, 2)
    print(f'---- FINISHED in {duration} hours ----')
--------------------------------------------------------------------------------
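A sketch of typical invocations. The input/output paths are placeholders, the training flags beyond --run_test are assumptions (the full set is added by the Trainer, model, and datamodule argparse hooks), and it is assumed that NoPWInferenceAgent resolves the checkpoint under <version_path>/checkpoints/.

# training (a sketch; see the argparse hooks above for the full flag set)
python trainer_main.py --gpus 1 --run_test

# inference with the pretrained CIC 2021 model shipped in this repository
python run_inference.py --version_path GamutNet_CIC2021_pretrained_model \
                        --ckpt_filename epoch=5.ckpt \
                        -i ./input_images -o ./output_images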
/utils/color.py:
--------------------------------------------------------------------------------
from warnings import warn

import colour
import numpy as np
from colour.models import matrix_RGB_to_RGB, RGB_COLOURSPACE_sRGB, RGB_COLOURSPACE_PROPHOTO_RGB

ILLUMINANT_D50 = colour.CCS_ILLUMINANTS['CIE 1931 2 Degree Standard Observer']['D50']
ILLUMINANT_D65 = colour.CCS_ILLUMINANTS['CIE 1931 2 Degree Standard Observer']['D65']


def to_single(x):
    if x.dtype == np.uint8:
        return np.asarray(x, dtype=np.single) / 255.0  # we prefer float32
    elif x.dtype == np.uint16:
        return np.asarray(x, dtype=np.single) / 65535.0
    else:
        raise Exception(f"Wrong numpy array type. Expect uint8 or uint16 but got {x.dtype}")


def to_uint8(x):
    # note: truncates rather than rounds
    return np.asarray(x * 255, dtype=np.uint8)


def is_valid_type(x):
    return x.dtype in (np.float16, np.float32, np.float64)


def is_valid_domain(x):
    return 0.0 <= np.min(x) and np.max(x) <= 1.0


def encode_srgb(x):
    assert is_valid_type(x)
    if not is_valid_domain(x):
        warn('The given input is not in the valid domain. Please check the values.')
        x = np.clip(x, 0, 1)
    return colour.cctf_encoding(x, function='sRGB')


def decode_srgb(x):
    assert is_valid_type(x) and is_valid_domain(x)

    # sRGB de-gamma (https://en.wikipedia.org/wiki/SRGB); it is faster than using colour's function.
    return np.clip(np.where(x <= 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4), 0, 1)


def encode_prop(x):
    assert is_valid_type(x)

    if not is_valid_domain(x):
        warn('The given input is not in the valid domain. Please check the values.')
        x = np.clip(x, 0, 1)

    # see colour.models.cctf_encoding_ROMMRGB / colour.models.cctf_encoding_ProPhotoRGB;
    # this inline version is faster than using colour's function.
    return np.clip(np.where(x < 0.001953125, x * 16, np.power(x, 1 / 1.8)), 0, 1)


def decode_prop(x):
    assert is_valid_type(x) and is_valid_domain(x)

    # see colour.models.cctf_decoding_ROMMRGB; this inline version is faster than using colour's function.
    return np.clip(np.where(x < 0.03125, x / 16.0, np.power(x, 1.8)), 0, 1)


def decode_displayp3(x):
    return colour.models.cctf_decoding(x, function='sRGB')


def decode_adobergb(x):
    return colour.models.cctf_decoding(x, function='Gamma 2.2')


def srgb_to_prop_cat02(x):
    # assert is_valid_type(x) and is_valid_domain(x)
    m_srgb_to_prop = matrix_RGB_to_RGB(RGB_COLOURSPACE_sRGB, RGB_COLOURSPACE_PROPHOTO_RGB,
                                       chromatic_adaptation_transform='CAT02')
    return np.einsum('...ij,...j->...i', m_srgb_to_prop, x)


def srgb_to_prop_bradford(x):
    assert is_valid_type(x) and is_valid_domain(x)
    m_srgb_to_prop = matrix_RGB_to_RGB(RGB_COLOURSPACE_sRGB, RGB_COLOURSPACE_PROPHOTO_RGB,
                                       chromatic_adaptation_transform='Bradford')
    return np.einsum('...ij,...j->...i', m_srgb_to_prop, x)


def prop_to_srgb_cat02(x):
    assert is_valid_type(x) and is_valid_domain(x)
    m_prop_to_srgb = matrix_RGB_to_RGB(RGB_COLOURSPACE_PROPHOTO_RGB, RGB_COLOURSPACE_sRGB,
                                       chromatic_adaptation_transform='CAT02')
    return np.einsum('...ij,...j->...i', m_prop_to_srgb, x)


def srgb_to_Lab_D50(img_srgb):
    # img_srgb has dim MxNx3, scale 0-1, already decoded (linear)
    # output scale: L: 0-100; a, b: -100 - 100 (needs normalization downstream)
    # note: the returned array is flattened to Nx3
    srgb_colourspace = colour.models.RGB_COLOURSPACE_sRGB.chromatically_adapt(ILLUMINANT_D50)

    XYZ = colour.models.RGB_to_XYZ(img_srgb.reshape(-1, 3),
                                   srgb_colourspace.whitepoint,
                                   srgb_colourspace.whitepoint,
                                   srgb_colourspace.matrix_RGB_to_XYZ)
    Lab_srgb = colour.XYZ_to_Lab(XYZ, illuminant=ILLUMINANT_D50)

    return Lab_srgb


def lab_to_srgb_D50(Lab):
    # input scale: L: 0-100; a, b: -100 - 100
    srgb_colourspace = colour.models.RGB_COLOURSPACE_sRGB.chromatically_adapt(ILLUMINANT_D50)
    XYZ = colour.Lab_to_XYZ(Lab, illuminant=ILLUMINANT_D50)

    img = colour.models.XYZ_to_RGB(XYZ,
                                   srgb_colourspace.whitepoint,
                                   srgb_colourspace.whitepoint,
                                   srgb_colourspace.matrix_XYZ_to_RGB)
    # img has dim Nx3, and is not encoded yet.
    return img


def prop_to_Lab_D50(img_prop):
    # img_prop has dim MxNx3, scale 0-1, already decoded (linear)
    prophotorgb_colourspace = colour.models.RGB_COLOURSPACE_PROPHOTO_RGB

    XYZ = colour.models.RGB_to_XYZ(img_prop.reshape(-1, 3),
                                   prophotorgb_colourspace.whitepoint,
                                   prophotorgb_colourspace.whitepoint,
                                   prophotorgb_colourspace.matrix_RGB_to_XYZ)
    Lab_prop = colour.XYZ_to_Lab(XYZ, illuminant=ILLUMINANT_D50)
    # output scale: L: 0-100; a, b: -100 - 100 (needs normalization downstream)
    return Lab_prop


def lab_to_prop_D50(Lab):
    # input scale: L: 0-100; a, b: -100 - 100
    prophotorgb_colourspace = colour.models.RGB_COLOURSPACE_PROPHOTO_RGB
    XYZ = colour.Lab_to_XYZ(Lab, illuminant=ILLUMINANT_D50)

    img = colour.models.XYZ_to_RGB(XYZ,
                                   prophotorgb_colourspace.whitepoint,
                                   prophotorgb_colourspace.whitepoint,
                                   prophotorgb_colourspace.matrix_XYZ_to_RGB)
    # img has dim Nx3, and is not encoded yet.
    return img


def srgb_to_Lab(img_srgb, illuminant=ILLUMINANT_D65):
    # img_srgb has dim MxNx3, scale 0-1, already decoded (linear)
    # output scale: L: 0-100; a, b: -100 - 100 (needs normalization downstream)
    srgb_colourspace = colour.models.RGB_COLOURSPACE_sRGB

    XYZ = colour.models.RGB_to_XYZ(img_srgb.reshape(-1, 3),
                                   srgb_colourspace.whitepoint,
                                   srgb_colourspace.whitepoint,
                                   srgb_colourspace.matrix_RGB_to_XYZ)
    Lab_srgb = colour.XYZ_to_Lab(XYZ, illuminant=illuminant)

    return Lab_srgb


def lab_to_srgb(Lab, illuminant=ILLUMINANT_D65):
    # input scale: L: 0-100; a, b: -100 - 100
    srgb_colourspace = colour.models.RGB_COLOURSPACE_sRGB
    XYZ = colour.Lab_to_XYZ(Lab, illuminant=illuminant)

    img = colour.models.XYZ_to_RGB(XYZ,
                                   srgb_colourspace.whitepoint,
                                   srgb_colourspace.whitepoint,
                                   srgb_colourspace.matrix_XYZ_to_RGB)
    # img has dim Nx3, and is not encoded yet.
    return img


def prop_to_Lab_D65(img_prop):
    # img_prop has dim MxNx3, scale 0-1, already decoded (linear)
    prophotorgb_colourspace = colour.models.RGB_COLOURSPACE_PROPHOTO_RGB.chromatically_adapt(ILLUMINANT_D65)

    XYZ = colour.models.RGB_to_XYZ(img_prop.reshape(-1, 3),
                                   prophotorgb_colourspace.whitepoint,
                                   prophotorgb_colourspace.whitepoint,
                                   prophotorgb_colourspace.matrix_RGB_to_XYZ)
    Lab_prop = colour.XYZ_to_Lab(XYZ, illuminant=ILLUMINANT_D65)
    # output scale: L: 0-100; a, b: -100 - 100 (needs normalization downstream)
    return Lab_prop


def lab_to_prop_D65(Lab):
    # input scale: L: 0-100; a, b: -100 - 100
    prophotorgb_colourspace = colour.models.RGB_COLOURSPACE_PROPHOTO_RGB.chromatically_adapt(ILLUMINANT_D65)
    XYZ = colour.Lab_to_XYZ(Lab, illuminant=ILLUMINANT_D65)

    img = colour.models.XYZ_to_RGB(XYZ,
                                   prophotorgb_colourspace.whitepoint,
                                   prophotorgb_colourspace.whitepoint,
                                   prophotorgb_colourspace.matrix_XYZ_to_RGB)
    # img has dim Nx3, and is not encoded yet.
    return img
--------------------------------------------------------------------------------
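A minimal sketch of the intended round trip through these helpers: decode an sRGB-encoded image to linear values, move it into the wider ProPhoto gamut, and re-encode. The input `img` is a hypothetical 8-bit sRGB image, not repository data.

# a minimal sketch; `img` is a hypothetical 8-bit sRGB-encoded image.
import numpy as np
from utils.color import to_single, decode_srgb, srgb_to_prop_cat02, encode_prop, to_uint8

img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)

x = to_single(img)                    # uint8 -> float32 in [0, 1]
x_lin = decode_srgb(x)                # remove the sRGB gamma (linear sRGB)
x_prop = srgb_to_prop_cat02(x_lin)    # linear sRGB -> linear ProPhoto RGB (CAT02 adaptation)
x_prop = np.clip(x_prop, 0, 1)        # the matrix transform can leave [0, 1]
out = to_uint8(encode_prop(x_prop))   # re-apply the ProPhoto (ROMM) encoding and quantize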
/utils/mask.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import numpy as np


def saturated_mask_generator(pw_mask):
    # skip the in-gamut pieces 0, 13, and 26 (cf. the `skip` default in piecewise/piecewise.py)
    for mask_value in np.unique(pw_mask):
        if mask_value not in (0, 13, 26):
            mask = pw_mask == mask_value
            yield mask, mask_value


def is_black(x: np.ndarray) -> np.ndarray:
    if x.dtype == np.uint8:
        return (x == 0).all(axis=2)
    elif x.dtype in (np.float16, np.float32, np.float64):
        return (x == 0.0).all(axis=2)
    else:
        raise Exception(f"Wrong numpy array type. Expect float or uint8 but got {x.dtype}")


def is_white(x: np.ndarray) -> np.ndarray:
    if x.dtype == np.uint8:
        return (x == 255).all(axis=2)
    elif x.dtype in (np.float16, np.float32, np.float64):
        return (x == 1.0).all(axis=2)
    else:
        raise Exception(f"Wrong numpy array type. Expect float or uint8 but got {x.dtype}")


def is_inner(x: np.ndarray) -> np.ndarray:
    # per-channel (not per-pixel) mask of values strictly between the extremes
    if x.dtype == np.uint8:
        return (x != 0) & (x != 255)
    elif x.dtype in (np.float16, np.float32, np.float64):
        return (x != 0.0) & (x != 1.0)
    else:
        raise Exception(f"Wrong numpy array type. Expect float or uint8 but got {x.dtype}")


def compute_masks(input_img):
    m_black = is_black(input_img)
    m_white = is_white(input_img)
    m_inner = is_inner(input_img)
    m_inner_all: np.ndarray = m_inner.all(axis=2)  # all the R, G, and B are inner values
    o2o_mask: np.ndarray = m_black | m_white | m_inner_all  # either black, white, or inner_all
    m2o_mask: np.ndarray = ~o2o_mask  # neither black, white, nor inner_all
    return o2o_mask, m2o_mask, m_inner
--------------------------------------------------------------------------------
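A small sketch of how compute_masks separates one-to-one (o2o) pixels from many-to-one (m2o, i.e. partially clipped) pixels, using a tiny synthetic uint8 image made up for illustration:

# a small synthetic check of compute_masks on a 1x4 uint8 image.
import numpy as np
from utils.mask import compute_masks

img = np.array([[[0, 0, 0],          # black                -> o2o
                 [255, 255, 255],    # white                -> o2o
                 [10, 128, 200],     # all channels inner   -> o2o
                 [255, 30, 40]]],    # one clipped channel  -> m2o (saturated)
               dtype=np.uint8)

o2o_mask, m2o_mask, m_inner = compute_masks(img)
print(o2o_mask)  # [[ True  True  True False]]
print(m2o_mask)  # [[False False False  True]]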
/utils/metric.py:
--------------------------------------------------------------------------------
import numpy as np


def __error(a, b):
    return a.astype(np.float64) - b.astype(np.float64)


def __absolute_error(a, b):
    return np.abs(__error(a, b))


def __squared_error(a, b):
    return np.power(__error(a, b), 2.0)


def l2(a, b, axis=None):
    return np.sqrt(np.sum(__squared_error(a, b), axis=axis))


def mae(a, b, axis=None):
    return np.mean(__absolute_error(a, b), axis=axis)


def mse(a, b, axis=None):
    return np.mean(__squared_error(a, b), axis=axis)


def rmse(a, b, axis=None):
    return np.sqrt(mse(a, b, axis))


def psnr(a, b, axis=None):
    if a.dtype != b.dtype:
        raise Exception(f"Wrong numpy array type. 2 arrays should have the same dtype: {a.dtype} vs {b.dtype}")
    if a.dtype == np.uint8:
        max_value = 255
    elif a.dtype in (np.float16, np.float32, np.float64):
        max_value = 1
    else:
        raise Exception(f"Wrong numpy array type. Expect float or uint8 but got {a.dtype}")

    return 10 * np.log10(max_value ** 2 / mse(a, b, axis))
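For reference, psnr implements PSNR = 10 * log10(MAX^2 / MSE) with MAX inferred from the dtype (255 for uint8, 1 for floats). A quick sanity check on made-up arrays:

# quick sanity check of the metric helpers on uint8 data.
import numpy as np
from utils.metric import mse, rmse, psnr

a = np.full((4, 4), 100, dtype=np.uint8)
b = np.full((4, 4), 110, dtype=np.uint8)

print(mse(a, b))   # 100.0
print(rmse(a, b))  # 10.0
print(psnr(a, b))  # 10 * log10(255**2 / 100) ~= 28.13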
--------------------------------------------------------------------------------