├── README.md
├── data
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── base_dataset.cpython-36.pyc
│   │   └── custom_dataset.cpython-36.pyc
│   ├── base_dataset.py
│   ├── custom_dataset.py
│   └── iHarmony4_dataset.py
├── demo_test.py
├── demo_test_iHarmony4.sh
├── demo_test_single.sh
├── examples
│   ├── f1510_1_2.jpg
│   ├── f1510_1_2.png
│   ├── f1601_1_1.jpg
│   ├── f1601_1_1.png
│   ├── f1679_1_2.jpg
│   ├── f1679_1_2.png
│   ├── f1721_1_2.jpg
│   ├── f1721_1_2.png
│   ├── f1806_1_2.jpg
│   ├── f1806_1_2.png
│   ├── f2062_1_1.jpg
│   ├── f2062_1_1.png
│   ├── f2313_1_2.jpg
│   ├── f2313_1_2.png
│   ├── f2374_1_1.jpg
│   ├── f2374_1_1.png
│   ├── f2699_1_2.jpg
│   ├── f2699_1_2.png
│   ├── f2865_1_1.jpg
│   ├── f2865_1_1.png
│   ├── f3017_1_1.jpg
│   ├── f3017_1_1.png
│   ├── f3040_1_1.jpg
│   ├── f3040_1_1.png
│   ├── f3110_1_2.jpg
│   ├── f3110_1_2.png
│   ├── f3309_1_1.jpg
│   ├── f3309_1_1.png
│   ├── f436_1_1.jpg
│   ├── f436_1_1.png
│   ├── f445_1_2.jpg
│   ├── f445_1_2.png
│   ├── f4516_1_1.jpg
│   ├── f4516_1_1.png
│   ├── f4521_1_2.jpg
│   ├── f4521_1_2.png
│   ├── f5067_1_1.jpg
│   ├── f5067_1_1.png
│   ├── f5107_1_1.jpg
│   ├── f5107_1_1.png
│   ├── f523_1_1.jpg
│   ├── f523_1_1.png
│   ├── f685_1_1.jpg
│   ├── f685_1_1.png
│   ├── f727_1_1.jpg
│   ├── f727_1_1.png
│   ├── f76_1_1.jpg
│   ├── f76_1_1.png
│   ├── f814_1_2.jpg
│   └── f814_1_2.png
├── figures
│   ├── augmentation_examples.jpg
│   └── flowchart..jpg
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── base_model.cpython-36.pyc
│   │   ├── laBaseLUTs_model.cpython-36.pyc
│   │   ├── models_x.cpython-36.pyc
│   │   ├── modules.cpython-36.pyc
│   │   └── networks.cpython-36.pyc
│   ├── base_model.py
│   ├── laBaseLUTs_model.py
│   └── modules.py
├── options
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── base_options.cpython-36.pyc
│   │   └── test_options.cpython-36.pyc
│   ├── base_options.py
│   └── test_options.py
├── requirements.txt
├── trilinear_cpp
│   ├── build
│   │   ├── lib.linux-x86_64-3.6
│   │   │   └── trilinear.cpython-36m-x86_64-linux-gnu.so
│   │   └── temp.linux-x86_64-3.6
│   │       └── src
│   │           ├── trilinear_cuda.o
│   │           └── trilinear_kernel.o
│   ├── dist
│   │   └── trilinear-0.0.0-py3.6-linux-x86_64.egg
│   ├── setup.py
│   ├── setup.sh
│   ├── src
│   │   ├── trilinear.cpp
│   │   ├── trilinear.h
│   │   ├── trilinear_cuda.cpp
│   │   ├── trilinear_cuda.h
│   │   ├── trilinear_kernel.cu
│   │   └── trilinear_kernel.h
│   └── trilinear.egg-info
│       ├── PKG-INFO
│       ├── SOURCES.txt
│       ├── dependency_links.txt
│       └── top_level.txt
└── util
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-36.pyc
    │   └── util.cpython-36.pyc
    └── util.py

/README.md: -------------------------------------------------------------------------------- 1 | # SycoNet: Domain Adaptive Image Harmonization 2 | 3 | This is the official repository for the following paper: 4 | 5 | > **Deep Image Harmonization with Learnable Augmentation** [[arXiv]](https://arxiv.org/pdf/2308.00376.pdf)
6 | > 7 | > Li Niu, Junyan Cao, Wenyan Cong, Liqing Zhang
8 | > Accepted by **ICCV 2023**. 9 | > 10 | Given a real image and a foreground mask, SycoNet generates multiple plausible synthetic composite images, which is useful for constructing pairs of synthetic composite images and real images for harmonization. We release the SycoNet inference code and model. **The released model is first trained on [iHarmony4](https://github.com/bcmi/Image-Harmonization-Dataset-iHarmony4) and then finetuned on [ccHarmony](https://github.com/bcmi/Image-Harmonization-Dataset-ccHarmony), because ccHarmony reflects illumination variation more faithfully.** 11 | 12 |
13 | SycoNet 14 |
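As a rough illustration of what a synthetic composite is, the sketch below recolors only the foreground region of a real image and keeps the real background unchanged. It is not SycoNet: the released model predicts the color change with learned 3D LUTs (`pretrained_net_LUTs.pth`), whereas here a random per-channel gain and offset stands in for that transformation, and `make_synthetic_composite` / `sample_composites` are hypothetical helper names.

```python
import numpy as np


def make_synthetic_composite(real, mask, gains, offsets):
    """Recolor the foreground of a real image and keep its background.

    real:    float array (H, W, 3) with values in [0, 1]
    mask:    array (H, W), 1 for foreground and 0 for background
    gains, offsets: per-channel color change, a hypothetical stand-in for
                    the learned 3D-LUT transformation used by SycoNet
    """
    transformed = np.clip(real * gains + offsets, 0.0, 1.0)
    m = mask[..., np.newaxis].astype(real.dtype)  # (H, W, 1) broadcasts over RGB
    # Foreground pixels take the transformed colors, background pixels stay real.
    return m * transformed + (1.0 - m) * real


def sample_composites(real, mask, num, seed=0):
    """Draw `num` composites, each with a different random color change."""
    rng = np.random.RandomState(seed)
    composites = []
    for _ in range(num):
        gains = rng.uniform(0.7, 1.3, size=3)
        offsets = rng.uniform(-0.1, 0.1, size=3)
        composites.append(make_synthetic_composite(real, mask, gains, offsets))
    return composites
```

Each call to `sample_composites` returns several different composites for the same real image and mask, the same one-to-many relation described above; only the foreground changes, so every composite pairs naturally with the untouched real image.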
15 | 16 | **The released model can be used to generate high-quality synthetic composite images from real images in order to augment a small-scale training set.** The examples below show several real images together with their generated synthetic composite images. 17 | 18 |
19 | SycoNet 20 |
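To use the generated images as extra training pairs, each output has to be matched back to its input. A minimal sketch, assuming the naming used by `demo_test.py` (each output is saved as `<input image name>_<j>.jpg`, with `j` running from 0 to `augment_num - 1`, inside `results/<name>/<phase>_<epoch>`); the helper names are hypothetical.

```python
import os
import re
from collections import defaultdict

# demo_test.py saves output j of an input as "<input image name>_<j>.jpg"
OUTPUT_NAME = re.compile(r"^(?P<stem>.+)_(?P<j>\d+)\.jpg$")


def group_generated_composites(filenames):
    """Map each input image name to its generated composites, ordered by j."""
    groups = defaultdict(list)
    for path in filenames:
        match = OUTPUT_NAME.match(os.path.basename(path))
        if match is None:
            continue  # not an output of demo_test.py
        groups[match.group("stem")].append((int(match.group("j")), path))
    return {stem: [path for _, path in sorted(items)]
            for stem, items in groups.items()}


def list_generated_composites(results_dir="results/syco/test_pretrained"):
    """Group the files that demo_test.py wrote to results_dir."""
    return group_generated_composites(os.listdir(results_dir))
```

Sorting by the integer `j` keeps `_10` after `_2`. For the single-image demo the stem is the real image name; for the iHarmony4 demo, `img_path` is the composite image path, so the stem is the original composite name.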
21 | 22 | 23 | 24 | # Setup 25 | 26 | Clone the repository: 27 | ``` 28 | git clone git@github.com:bcmi/SycoNet-Adaptive-Image-Harmonization.git 29 | ``` 30 | Install Anaconda and create a virtual environment: 31 | ``` 32 | conda create -n syconet python=3.6 33 | conda activate syconet 34 | ``` 35 | Install PyTorch: 36 | ``` 37 | conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge 38 | ``` 39 | Install the required packages: 40 | ``` 41 | pip install -r requirements.txt 42 | ``` 43 | Build Trilinear: 44 | ``` 45 | cd trilinear_cpp 46 | sh setup.sh 47 | ``` 48 | Before running `setup.sh`, set `CUDA_HOME` in it to your own CUDA installation path. If the build fails, you can refer to [this repository](https://github.com/HuiZeng/Image-Adaptive-3DLUT) for more solutions. 49 | 50 | # Inference 51 | 52 | Download the SycoNet model `pretrained_net_Er.pth` and the 3D LUTs `pretrained_net_LUTs.pth` from [Baidu Cloud](https://pan.baidu.com/s/1wIWxb37yIVccxB0kM-FnnQ) (access code: o4rt) or [Dropbox](https://www.dropbox.com/scl/fo/zo5bbzotkc70psg3dlzz9/AGG6z69_qaRC5N3MydjgjjY?rlkey=39l5uixbym7xhf5sp6tg9515c&st=3o5qatdh&dl=0). Put them in the folder `checkpoints/syco`. 53 | 54 | ## Test on a single image 55 | In `demo_test_single.sh`, set `--real` and `--mask` to the paths of your own real image and foreground mask, respectively, and set `--augment_num` to the number of composite images to generate for each pair of real image and foreground mask. Then, run the following command: 56 | ``` 57 | sh demo_test_single.sh 58 | ``` 59 | SycoNet saves the generated composite images for the input real image and foreground mask in the folder `results/syco/test_pretrained`. 60 | 61 | ## Test on iHarmony4 dataset 62 | 63 | 64 | Download [iHarmony4](https://github.com/bcmi/Image-Harmonization-Dataset-iHarmony4) and set `--dataset_root` and `--dataset_name` in `demo_test_iHarmony4.sh` to your own dataset path and the subset to process (one of `HAdobe5k`, `HCOCO`, `HFlickr`, `Hday2night`).
Then, run the following command: 65 | 66 | ``` 67 | sh demo_test_iHarmony4.sh 68 | ``` 69 | 70 | SycoNet saves the generated composite images for the real images and foreground masks of the specified dataset in the folder `results/syco/test_pretrained`. 71 | 72 | 73 | # Other Resources 74 | 75 | + [Image-Harmonization-Dataset-iHarmony4](https://github.com/bcmi/Image-Harmonization-Dataset-iHarmony4) 76 | + [Awesome-Image-Harmonization](https://github.com/bcmi/Awesome-Image-Harmonization) 77 | + [Awesome-Image-Composition](https://github.com/bcmi/Awesome-Object-Insertion) 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes all the modules related to data loading and preprocessing 2 | 3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. 4 | You need to implement four functions: 5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 6 | -- <__len__>: return the size of dataset. 7 | -- <__getitem__>: get a data point from data loader. 8 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options. 9 | 10 | Now you can use the dataset class by specifying the flag '--dataset_mode dummy'. 11 | See the dataset class in 'custom_dataset.py' for an example. 12 | """ 13 | import importlib 14 | import torch.utils.data 15 | from data.base_dataset import BaseDataset 16 | 17 | 18 | def find_dataset_using_name(dataset_name): 19 | """Import the module "data/[dataset_name]_dataset.py". 20 | 21 | In the file, the class called DatasetNameDataset() will 22 | be instantiated. It has to be a subclass of BaseDataset, 23 | and the name matching is case-insensitive. 24 | """ 25 | dataset_filename = "data.
+ dataset_name + "_dataset" 26 | datasetlib = importlib.import_module(dataset_filename) 27 | 28 | dataset = None 29 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 30 | for name, cls in datasetlib.__dict__.items(): 31 | if name.lower() == target_dataset_name.lower() \ 32 | and issubclass(cls, BaseDataset): 33 | dataset = cls 34 | 35 | if dataset is None: 36 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 37 | 38 | return dataset 39 | 40 | 41 | def get_option_setter(dataset_name): 42 | """Return the static method of the dataset class.""" 43 | dataset_class = find_dataset_using_name(dataset_name) 44 | return dataset_class.modify_commandline_options 45 | 46 | 47 | def create_dataset(opt): 48 | """Create a dataset given the option. 49 | 50 | This function wraps the class CustomDatasetDataLoader. 51 | This is the main interface between this package and 'train.py'/'test.py' 52 | 53 | Example: 54 | >>> from data import create_dataset 55 | >>> dataset = create_dataset(opt) 56 | """ 57 | data_loader = CustomDatasetDataLoader(opt) 58 | dataset = data_loader.load_data() 59 | return dataset 60 | 61 | 62 | class CustomDatasetDataLoader(): 63 | """Wrapper class of Dataset class that performs multi-threaded data loading""" 64 | 65 | def __init__(self, opt): 66 | """Initialize this class 67 | 68 | Step 1: create a dataset instance given the name [dataset_mode] 69 | Step 2: create a multi-threaded data loader. 
70 | """ 71 | self.opt = opt 72 | dataset_class = find_dataset_using_name(opt.dataset_mode) 73 | self.dataset = dataset_class(opt) 74 | print("dataset [%s] was created" % type(self.dataset).__name__) 75 | self.dataloader = torch.utils.data.DataLoader( 76 | self.dataset, 77 | batch_size=opt.batch_size, 78 | shuffle=not opt.serial_batches, 79 | drop_last=True, 80 | num_workers=int(opt.num_threads)) 81 | 82 | def load_data(self): 83 | return self 84 | 85 | def __len__(self): 86 | """Return the number of data in the dataset""" 87 | return min(len(self.dataset), self.opt.max_dataset_size) 88 | 89 | def __iter__(self): 90 | """Return a batch of data""" 91 | for i, data in enumerate(self.dataloader): 92 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 93 | break 94 | yield data 95 | -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/base_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/data/__pycache__/base_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/custom_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/data/__pycache__/custom_dataset.cpython-36.pyc -------------------------------------------------------------------------------- 
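`find_dataset_using_name` turns a `--dataset_mode` value into a module path and a lowercase class-name target. The sketch below restates that string handling so it can be checked without importing the package; `pick_dataset_class` is a hypothetical helper that runs the same loop over a plain dictionary instead of a module's `__dict__`, and it additionally skips objects that are not classes.

```python
def dataset_module_name(dataset_mode):
    """Module that find_dataset_using_name() imports, e.g. data.custom_dataset."""
    return "data." + dataset_mode + "_dataset"


def matches_dataset_class(dataset_mode, class_name):
    """Case-insensitive check: underscores are dropped and 'dataset' appended."""
    target = dataset_mode.replace("_", "") + "dataset"
    return class_name.lower() == target.lower()


def pick_dataset_class(dataset_mode, namespace, base_class):
    """Return the matching subclass of base_class found in namespace (a dict)."""
    found = None
    for name, obj in namespace.items():
        if (matches_dataset_class(dataset_mode, name)
                and isinstance(obj, type) and issubclass(obj, base_class)):
            found = obj  # as in the original loop, a later match wins
    if found is None:
        raise NotImplementedError(
            "no subclass of %s named %sDataset (case-insensitive)"
            % (base_class.__name__, dataset_mode.replace("_", "")))
    return found
```

This is why `--dataset_mode iHarmony4` resolves to `iHarmony4Dataset` and `--dataset_mode custom` resolves to `CUSTOMDataset`, even though the capitalization differs.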
/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 2 | 3 | It also includes common transformation functions (e.g., get_params, get_transform), which can be later used in subclasses. 4 | """ 5 | import random 6 | import cv2 7 | import numpy as np 8 | import torch.utils.data as data 9 | import albumentations.augmentations.transforms as transforms 10 | from abc import ABC, abstractmethod 11 | from albumentations import HorizontalFlip, RandomResizedCrop, Compose, DualTransform 12 | 13 | 14 | class BaseDataset(data.Dataset, ABC): 15 | """This class is an abstract base class (ABC) for datasets. 16 | 17 | To create a subclass, you need to implement the following four functions: 18 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 19 | -- <__len__>: return the size of dataset. 20 | -- <__getitem__>: get a data point. 21 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options. 22 | """ 23 | 24 | def __init__(self, opt): 25 | """Initialize the class; save the options in the class 26 | 27 | Parameters: 28 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 29 | """ 30 | self.opt = opt 31 | self.root = opt.dataset_root # root directory of the dataset 32 | 33 | @staticmethod 34 | def modify_commandline_options(parser, is_train): 35 | """Add new dataset-specific options, and rewrite default values for existing options. 36 | 37 | Parameters: 38 | parser -- original option parser 39 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 40 | 41 | Returns: 42 | the modified parser.
43 | """ 44 | return parser 45 | 46 | @abstractmethod 47 | def __len__(self): 48 | """Return the total number of images in the dataset.""" 49 | return 0 50 | 51 | @abstractmethod 52 | def __getitem__(self, index): 53 | """Return a data point and its metadata information. 54 | 55 | Parameters: 56 | index -- a random integer for data indexing 57 | 58 | Returns: 59 | a dictionary of data with their names. It usually contains the data itself and its metadata information. 60 | """ 61 | pass 62 | 63 | class HCompose(Compose): 64 | def __init__(self, transforms, *args, additional_targets=None, no_nearest_for_masks=True, **kwargs): 65 | if additional_targets is None: 66 | additional_targets = { 67 | 'real': 'image', 68 | 'mask': 'mask' 69 | } 70 | self.additional_targets = additional_targets 71 | super().__init__(transforms, *args, additional_targets=additional_targets, **kwargs) 72 | if no_nearest_for_masks: 73 | for t in transforms: 74 | if isinstance(t, DualTransform): 75 | t._additional_targets['mask'] = 'image' # transform masks like images (no nearest-neighbor interpolation) 76 | 77 | 78 | def get_params(opt, size): 79 | w, h = size 80 | new_h = h 81 | new_w = w 82 | if opt.preprocess == 'resize_and_crop': 83 | new_h = new_w = opt.load_size 84 | elif opt.preprocess == 'scale_width_and_crop': 85 | new_w = opt.load_size 86 | new_h = opt.load_size * h // w 87 | 88 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 89 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 90 | 91 | flip = random.random() > 0.5 92 | 93 | return {'crop_pos': (x, y), 'flip': flip} 94 | 95 | 96 | def get_transform(opt, params=None, grayscale=False, convert=True): 97 | transform_list = [] 98 | if grayscale: 99 | transform_list.append(transforms.ToGray()) 100 | if opt.preprocess == 'resize_and_crop': 101 | if params is None: 102 | transform_list.append(RandomResizedCrop(256, 256, scale=(0.5, 1.0))) 103 | elif opt.preprocess == 'resize': 104 | transform_list.append(transforms.Resize(256, 256)) 105 | 106 | 107 | if not opt.no_flip: 108 | if
params is None: 109 | transform_list.append(HorizontalFlip()) 110 | 111 | return HCompose(transform_list) 112 | 113 | 114 | def __make_power_2(img, base): 115 | oh, ow = img.shape[:2] # img is a numpy array (as loaded by cv2), so shape is (height, width, ...) 116 | h = int(round(oh / base) * base) 117 | w = int(round(ow / base) * base) 118 | if (h == oh) and (w == ow): 119 | return img 120 | 121 | __print_size_warning(ow, oh, w, h) 122 | return cv2.resize(img, (w, h), interpolation = cv2.INTER_LINEAR) 123 | 124 | 125 | 126 | def __print_size_warning(ow, oh, w, h): 127 | """Print warning information about image size (printed only once)""" 128 | if not hasattr(__print_size_warning, 'has_printed'): 129 | print("The image size needs to be a multiple of 4. " 130 | "The loaded image size was (%d, %d), so it was adjusted to " 131 | "(%d, %d). This adjustment will be done to all images " 132 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 133 | __print_size_warning.has_printed = True 134 | -------------------------------------------------------------------------------- /data/custom_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torchvision.transforms as transforms 4 | from data.base_dataset import BaseDataset, get_transform 5 | from copy import deepcopy 6 | 7 | class CUSTOMDataset(BaseDataset): 8 | """Test-only dataset for a single real image and foreground mask, given by --real and --mask.""" 9 | 10 | @staticmethod 11 | def modify_commandline_options(parser, is_train): 12 | """Add new dataset-specific options, and rewrite default values for existing options. 13 | 14 | Parameters: 15 | parser -- original option parser 16 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 17 | 18 | Returns: 19 | the modified parser.
20 | """ 21 | parser.add_argument('--is_train', type=int, default=1, help='1 for the training phase, 0 otherwise') # int, not bool: bool('0') would be True 22 | parser.set_defaults(max_dataset_size=float("inf"), new_dataset_option=2.0) # specify dataset-specific default values 23 | return parser 24 | 25 | def __init__(self, opt): 26 | """Initialize this dataset class. 27 | 28 | Parameters: 29 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 30 | 31 | A few things can be done here. 32 | - save the options (already done in BaseDataset) 33 | - get image paths and meta information of the dataset. 34 | - define the image transformation. 35 | """ 36 | # save the option and dataset root 37 | BaseDataset.__init__(self, opt) 38 | self.real_path, self.mask_path = '', '' 39 | self.isTrain = opt.isTrain 40 | if not opt.isTrain: 41 | print('loading test file: %s' % opt.real) 42 | self.real_path = opt.real 43 | self.mask_path = opt.mask 44 | else: 45 | raise NotImplementedError('Sorry, the training code has not been released.') 46 | 47 | self.transform = get_transform(opt) 48 | self.input_transform = transforms.Compose([ 49 | transforms.ToTensor(), 50 | ]) 51 | 52 | def __getitem__(self, index): 53 | sample = self.get_sample() 54 | self.check_sample_types(sample) 55 | sample_raw = deepcopy(sample) 56 | sample = self.augment_sample(sample) 57 | real = self.input_transform(sample['real']) 58 | mask = sample['mask'].astype(np.float32) 59 | 60 | real_raw = self.input_transform(sample_raw['real']) 61 | mask_raw = sample_raw['mask'].astype(np.float32) 62 | 63 | output = { 64 | 'mask': mask[np.newaxis, ...].astype(np.float32), 65 | 'real': real, 66 | 'mask_raw': mask_raw, 67 | 'real_raw': real_raw, 68 | 'img_path':sample['img_path'] 69 | } 70 | 71 | return output 72 | 73 | 74 | def check_sample_types(self, sample): 75 | assert sample['real'].dtype == 'uint8' 76 | 77 | 78 | def augment_sample(self, sample): 79 | if self.transform is None: 80 | return sample 81 | 82 | additional_targets =
{target_name: sample[target_name] 83 | for target_name in self.transform.additional_targets.keys()} 84 | 85 | valid_augmentation = False 86 | while not valid_augmentation: 87 | aug_output = self.transform(image=sample['real'], **additional_targets) 88 | valid_augmentation = self.check_augmented_sample(aug_output) 89 | 90 | for target_name, transformed_target in aug_output.items(): 91 | sample[target_name] = transformed_target 92 | 93 | return sample 94 | 95 | def check_augmented_sample(self, aug_output): 96 | return aug_output['mask'].sum() > 1.0 # a mask with fewer than two foreground pixels never passes, so augment_sample() would loop forever 97 | 98 | 99 | def get_sample(self): 100 | real = cv2.imread(self.real_path) 101 | real = cv2.cvtColor(real, cv2.COLOR_BGR2RGB) 102 | mask = cv2.imread(self.mask_path) 103 | mask = mask[:, :, 0].astype(np.float32) / 255. 104 | mask = mask.astype(np.uint8) # only pixels equal to 255 become 1, so the mask file should use 0/255 105 | 106 | return {'mask': mask, 'real': real, 'img_path': self.real_path} 107 | 108 | def __len__(self): 109 | """Return the total number of images.""" 110 | return 1 111 | -------------------------------------------------------------------------------- /data/iHarmony4_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | import numpy as np 4 | import os.path 5 | import torchvision.transforms as transforms 6 | from data.base_dataset import BaseDataset, get_transform 7 | from copy import deepcopy 8 | 9 | class iHarmony4Dataset(BaseDataset): 10 | """Test-only dataset for one iHarmony4 subset (HAdobe5k, HCOCO, HFlickr or Hday2night).""" 11 | @staticmethod 12 | def modify_commandline_options(parser, is_train): 13 | """Add new dataset-specific options, and rewrite default values for existing options. 14 | 15 | Parameters: 16 | parser -- original option parser 17 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 18 | 19 | Returns: 20 | the modified parser.
21 | """ 22 | parser.add_argument('--is_train', type=int, default=1, help='1 for the training phase, 0 otherwise') # int, not bool: bool('0') would be True 23 | parser.set_defaults(max_dataset_size=float("inf"), new_dataset_option=2.0) # specify dataset-specific default values 24 | return parser 25 | 26 | def __init__(self, opt): 27 | """Initialize this dataset class. 28 | 29 | Parameters: 30 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 31 | 32 | A few things can be done here. 33 | - save the options (already done in BaseDataset) 34 | - get image paths and meta information of the dataset. 35 | - define the image transformation. 36 | """ 37 | # save the option and dataset root 38 | BaseDataset.__init__(self, opt) 39 | self.image_paths = [] 40 | self.isTrain = opt.isTrain 41 | if not opt.isTrain: 42 | print('loading test file: ') 43 | self.keep_background_prob = -1 44 | if opt.dataset_name in ['HAdobe5k', 'HCOCO', 'HFlickr', 'Hday2night']: 45 | self.testfile = os.path.join(opt.dataset_root, opt.dataset_name, opt.dataset_name + '_test.txt') 46 | with open(self.testfile,'r') as f: 47 | for line in f.readlines(): 48 | self.image_paths.append(os.path.join(opt.dataset_root, opt.dataset_name, 'composite_images', line.rstrip())) 49 | else: 50 | raise NotImplementedError('%s not implemented.'
% (opt.dataset_name)) 51 | else: 52 | raise NotImplementedError('Sorry, the training code has not been released.') 53 | 54 | self.transform = get_transform(opt) 55 | self.input_transform = transforms.Compose([ 56 | transforms.ToTensor(), 57 | ]) 58 | 59 | def __getitem__(self, index): 60 | sample = self.get_sample(index) 61 | self.check_sample_types(sample) 62 | sample_raw = deepcopy(sample) 63 | sample = self.augment_sample(sample) 64 | real = self.input_transform(sample['real']) 65 | mask = sample['mask'].astype(np.float32) 66 | 67 | real_raw = self.input_transform(sample_raw['real']) 68 | mask_raw = sample_raw['mask'].astype(np.float32) 69 | 70 | output = { 71 | 'mask': mask[np.newaxis, ...].astype(np.float32), 72 | 'real': real, 73 | 'mask_raw': mask_raw, 74 | 'real_raw': real_raw, 75 | 'img_path':sample['img_path'] 76 | } 77 | 78 | return output 79 | 80 | 81 | def check_sample_types(self, sample): 82 | assert sample['real'].dtype == 'uint8' 83 | 84 | def augment_sample(self, sample): 85 | if self.transform is None: 86 | return sample 87 | 88 | additional_targets = {target_name: sample[target_name] 89 | for target_name in self.transform.additional_targets.keys()} 90 | 91 | valid_augmentation = False 92 | while not valid_augmentation: 93 | aug_output = self.transform(image=sample['real'], **additional_targets) 94 | valid_augmentation = self.check_augmented_sample(sample, aug_output) 95 | 96 | for target_name, transformed_target in aug_output.items(): 97 | sample[target_name] = transformed_target 98 | 99 | return sample 100 | 101 | def check_augmented_sample(self, sample, aug_output): 102 | if self.keep_background_prob < 0.0 or random.random() < self.keep_background_prob: 103 | return True 104 | 105 | return aug_output['mask'].sum() > 1.0 106 | 107 | def get_sample(self, index): 108 | path = self.image_paths[index] 109 | mask_path = self.image_paths[index].replace('composite_images','masks') 110 | mask_path = '_'.join(mask_path.split('_')[:-1]) + '.png' 111 | 
111 | real_path = self.image_paths[index].replace('composite_images','real_images') 112 | real_path = '_'.join(real_path.split('_')[:-2]) + '.jpg' 113 | 114 | real = cv2.imread(real_path) 115 | real = cv2.cvtColor(real, cv2.COLOR_BGR2RGB) 116 | mask = cv2.imread(mask_path) 117 | mask = mask[:, :, 0].astype(np.float32) / 255. 118 | mask = mask.astype(np.uint8) # only pixels equal to 255 become 1 119 | 120 | return {'mask': mask, 'real': real, 'img_path': path} 121 | 122 | def __len__(self): 123 | """Return the total number of images.""" 124 | return len(self.image_paths) 125 | -------------------------------------------------------------------------------- /demo_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from util import util 3 | from options.test_options import TestOptions 4 | from data import create_dataset 5 | from models import create_model 6 | 7 | if __name__ == '__main__': 8 | opt = TestOptions().parse() # get test options 9 | # hard-code some parameters for test 10 | opt.num_threads = 0 # load data in the main process (no worker processes) 11 | opt.batch_size = 1 # test code only supports batch_size = 1 12 | opt.serial_batches = True # disable data shuffling; comment out this line if results on randomly chosen images are needed. 13 | opt.no_flip = True # no flip; comment out this line if results on flipped images are needed. 14 | opt.display_id = -1 # no visdom display; the test code saves the results as image files.
15 | dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options 16 | model = create_model(opt) # create a model given opt.model and other options 17 | model.setup(opt) # regular setup: load and print networks; create schedulers 18 | div_num = opt.augment_num # number of composites generated per input 19 | 20 | print('total number of test images: %d' % len(dataset)) 21 | 22 | save_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch)) # directory for the generated composites, e.g. results/syco/test_pretrained 23 | os.makedirs(save_dir, exist_ok=True) 24 | if opt.eval: 25 | model.eval() 26 | 27 | for i, data in enumerate(dataset): 28 | if i >= opt.num_test: # only apply our model to opt.num_test images. 29 | break 30 | 31 | for j in range(div_num): 32 | model.set_input(data) # unpack data from data loader 33 | model.test() # run inference 34 | visuals = model.get_current_visuals() # get image results 35 | img_path = str(data['img_path'][0]) 36 | _, image_name = os.path.split(img_path) 37 | image_name, _ = os.path.splitext(image_name) 38 | 39 | for label, im_data in visuals.items(): 40 | if label=='transfer_img_c': 41 | save_path = os.path.join(save_dir, image_name + '_' + str(j) + '.jpg') 42 | output_c = util.tensor2im(im_data) 43 | util.save_image(output_c, save_path, aspect_ratio=opt.aspect_ratio) 44 | 45 | print(f'[{i}], {image_name}, z num: {j}') 46 | 47 | -------------------------------------------------------------------------------- /demo_test_iHarmony4.sh: -------------------------------------------------------------------------------- 1 | python demo_test.py \ 2 | --name syco \ 3 | --checkpoints_dir checkpoints \ 4 | --model laBaseLUTs \ 5 | --netEr Syco \ 6 | --epoch pretrained \ 7 | --dataset_mode iHarmony4 \ 8 | --gpu_ids 0 \ 9 | --is_train 0 \ 10 | --preprocess resize \ 11 | --norm batch \ 12 | --nz 32 \ 13 | --dataset_root ./iHarmony4_dataset \ 14 | --dataset_name Hday2night \ 15 | --results_dir results \ 16 | --augment_num 10 \ 17 | --keep_res \ 18 | --eval \ 19 | 20 |
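With `--dataset_root ./iHarmony4_dataset --dataset_name Hday2night`, `iHarmony4Dataset` reads the list file `<root>/<name>/<name>_test.txt` and derives the mask and real-image paths from each composite file name. The sketch below restates that mapping, assuming composites sit directly in `composite_images/` and are named `<real name>_<mask id>_<composite id>.jpg`; the function names and the file name `a0001_1_2.jpg` are hypothetical. `original_style_paths` repeats the string operations from `get_sample()` so the two versions can be compared.

```python
import os


def iharmony4_list_path(dataset_root, dataset_name):
    """List of composite file names read by iHarmony4Dataset in test mode."""
    return os.path.join(dataset_root, dataset_name, dataset_name + "_test.txt")


def iharmony4_paths(composite_path):
    """Return (mask_path, real_path) for one composite image.

    Dropping the last "_"-separated field of the file name gives the mask
    name; dropping the last two fields gives the real image name.
    """
    directory, filename = os.path.split(composite_path)
    stem, _ = os.path.splitext(filename)
    fields = stem.split("_")
    subset_dir = os.path.dirname(directory)  # parent of composite_images/
    mask_path = os.path.join(subset_dir, "masks", "_".join(fields[:-1]) + ".png")
    real_path = os.path.join(subset_dir, "real_images", "_".join(fields[:-2]) + ".jpg")
    return mask_path, real_path


def original_style_paths(composite_path):
    """The same mapping written with the string operations of get_sample()."""
    mask_path = composite_path.replace("composite_images", "masks")
    mask_path = "_".join(mask_path.split("_")[:-1]) + ".png"
    real_path = composite_path.replace("composite_images", "real_images")
    real_path = "_".join(real_path.split("_")[:-2]) + ".jpg"
    return mask_path, real_path
```

Both versions agree on the standard layout. The restated one splits only the file name, so it does not depend on how often the substring `composite_images` appears elsewhere in the path.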
-------------------------------------------------------------------------------- /demo_test_single.sh: -------------------------------------------------------------------------------- 1 | python demo_test.py \ 2 | --name syco \ 3 | --checkpoints_dir checkpoints \ 4 | --model laBaseLUTs \ 5 | --netEr Syco \ 6 | --epoch pretrained \ 7 | --dataset_mode custom \ 8 | --gpu_ids 0 \ 9 | --is_train 0 \ 10 | --real examples/f436_1_1.jpg \ 11 | --mask examples/f436_1_1.png \ 12 | --results_dir results \ 13 | --augment_num 10 \ 14 | --keep_res \ 15 | --eval \ 16 | -------------------------------------------------------------------------------- /examples/f1510_1_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f1510_1_2.jpg -------------------------------------------------------------------------------- /examples/f1510_1_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f1510_1_2.png -------------------------------------------------------------------------------- /examples/f1601_1_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f1601_1_1.jpg -------------------------------------------------------------------------------- /examples/f1601_1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f1601_1_1.png -------------------------------------------------------------------------------- /examples/f1679_1_2.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f1679_1_2.jpg -------------------------------------------------------------------------------- /examples/f1679_1_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f1679_1_2.png -------------------------------------------------------------------------------- /examples/f1721_1_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f1721_1_2.jpg -------------------------------------------------------------------------------- /examples/f1721_1_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f1721_1_2.png -------------------------------------------------------------------------------- /examples/f1806_1_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f1806_1_2.jpg -------------------------------------------------------------------------------- /examples/f1806_1_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f1806_1_2.png -------------------------------------------------------------------------------- /examples/f2062_1_1.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f2062_1_1.jpg -------------------------------------------------------------------------------- /examples/f2062_1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f2062_1_1.png -------------------------------------------------------------------------------- /examples/f2313_1_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f2313_1_2.jpg -------------------------------------------------------------------------------- /examples/f2313_1_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f2313_1_2.png -------------------------------------------------------------------------------- /examples/f2374_1_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f2374_1_1.jpg -------------------------------------------------------------------------------- /examples/f2374_1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f2374_1_1.png -------------------------------------------------------------------------------- /examples/f2699_1_2.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f2699_1_2.jpg -------------------------------------------------------------------------------- /examples/f2699_1_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f2699_1_2.png -------------------------------------------------------------------------------- /examples/f2865_1_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f2865_1_1.jpg -------------------------------------------------------------------------------- /examples/f2865_1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f2865_1_1.png -------------------------------------------------------------------------------- /examples/f3017_1_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f3017_1_1.jpg -------------------------------------------------------------------------------- /examples/f3017_1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f3017_1_1.png -------------------------------------------------------------------------------- /examples/f3040_1_1.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f3040_1_1.jpg -------------------------------------------------------------------------------- /examples/f3040_1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f3040_1_1.png -------------------------------------------------------------------------------- /examples/f3110_1_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f3110_1_2.jpg -------------------------------------------------------------------------------- /examples/f3110_1_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f3110_1_2.png -------------------------------------------------------------------------------- /examples/f3309_1_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f3309_1_1.jpg -------------------------------------------------------------------------------- /examples/f3309_1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f3309_1_1.png -------------------------------------------------------------------------------- /examples/f436_1_1.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f436_1_1.jpg -------------------------------------------------------------------------------- /examples/f436_1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f436_1_1.png -------------------------------------------------------------------------------- /examples/f445_1_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f445_1_2.jpg -------------------------------------------------------------------------------- /examples/f445_1_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f445_1_2.png -------------------------------------------------------------------------------- /examples/f4516_1_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f4516_1_1.jpg -------------------------------------------------------------------------------- /examples/f4516_1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f4516_1_1.png -------------------------------------------------------------------------------- /examples/f4521_1_2.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f4521_1_2.jpg -------------------------------------------------------------------------------- /examples/f4521_1_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f4521_1_2.png -------------------------------------------------------------------------------- /examples/f5067_1_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f5067_1_1.jpg -------------------------------------------------------------------------------- /examples/f5067_1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f5067_1_1.png -------------------------------------------------------------------------------- /examples/f5107_1_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f5107_1_1.jpg -------------------------------------------------------------------------------- /examples/f5107_1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f5107_1_1.png -------------------------------------------------------------------------------- /examples/f523_1_1.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f523_1_1.jpg -------------------------------------------------------------------------------- /examples/f523_1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f523_1_1.png -------------------------------------------------------------------------------- /examples/f685_1_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f685_1_1.jpg -------------------------------------------------------------------------------- /examples/f685_1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f685_1_1.png -------------------------------------------------------------------------------- /examples/f727_1_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f727_1_1.jpg -------------------------------------------------------------------------------- /examples/f727_1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f727_1_1.png -------------------------------------------------------------------------------- /examples/f76_1_1.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f76_1_1.jpg -------------------------------------------------------------------------------- /examples/f76_1_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f76_1_1.png -------------------------------------------------------------------------------- /examples/f814_1_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f814_1_2.jpg -------------------------------------------------------------------------------- /examples/f814_1_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f814_1_2.png -------------------------------------------------------------------------------- /figures/augmentation_examples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/figures/augmentation_examples.jpg -------------------------------------------------------------------------------- /figures/flowchart..jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/figures/flowchart..jpg -------------------------------------------------------------------------------- /models/__init__.py: 
--------------------------------------------------------------------------------
"""This package contains modules related to objective functions, optimizations, and network architectures.

To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
    -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
    -- <set_input>:                     unpack data from dataset and apply preprocessing.
    -- <forward>:                       produce intermediate results.
    -- <optimize_parameters>:           calculate loss, gradients, and update network weights.
    -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.

In the function <__init__>, you need to define four lists:
    -- self.loss_names (str list):          specify the training losses that you want to plot and save.
    -- self.model_names (str list):         define networks used in our training.
    -- self.visual_names (str list):        specify the images that you want to display and save.
    -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py in the upstream pytorch-CycleGAN-and-pix2pix project for an example.

Now you can use the model class by specifying flag '--model dummy'.
See the template model class 'template_model.py' in the upstream pytorch-CycleGAN-and-pix2pix project for more details (it is not included in this repository).
"""

import importlib
from models.base_model import BaseModel


def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py".

    In the file, the class called [ModelName]Model() will
    be instantiated. It has to be a subclass of BaseModel,
    and the name match is case-insensitive.
    """
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
           and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        exit(1)  # non-zero exit status: no matching model class was found

    return model


def get_option_setter(model_name):
    """Return the static method <modify_commandline_options> of the model class."""
    model_class = find_model_using_name(model_name)
    return model_class.modify_commandline_options


def create_model(opt):
    """Create a model given the option.

    This function wraps find_model_using_name() and instantiates the class it returns.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from models import create_model
        >>> model = create_model(opt)
    """
    model = find_model_using_name(opt.model)
    instance = model(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance

--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/base_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/models/__pycache__/base_model.cpython-36.pyc
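The lookup rule in `find_model_using_name` can be exercised without importing the real `models` package. The following is a minimal sketch, assuming an in-memory stand-in module built with `types.ModuleType` and a placeholder `BaseModel`; the helpers `module_path` and `find_model_class` and the object `fake_module` are hypothetical names for illustration, not part of this repository. It shows why the model name `laBaseLUTs` resolves to the module `models.laBaseLUTs_model` and to the class `LABASELUTSModel`.

```python
import types
from abc import ABC


class BaseModel(ABC):
    """Placeholder for models.base_model.BaseModel (assumption for this sketch)."""


def module_path(model_name):
    """Module that find_model_using_name() imports for a given model name."""
    return "models." + model_name + "_model"


def find_model_class(module, model_name):
    """Return the BaseModel subclass in `module` whose name equals
    model_name with underscores removed plus 'model', ignoring case."""
    target = (model_name.replace('_', '') + 'model').lower()
    for name, obj in vars(module).items():
        # isinstance(obj, type) is a defensive extra: issubclass() raises
        # TypeError if a non-class object happens to match the name.
        if name.lower() == target and isinstance(obj, type) \
                and issubclass(obj, BaseModel):
            return obj
    return None


# Hypothetical in-memory stand-in for models/laBaseLUTs_model.py
class LABASELUTSModel(BaseModel):
    pass


fake_module = types.ModuleType(module_path("laBaseLUTs"))
fake_module.LABASELUTSModel = LABASELUTSModel

print(module_path("laBaseLUTs"))                             # models.laBaseLUTs_model
print(find_model_class(fake_module, "laBaseLUTs").__name__)  # LABASELUTSModel
```

Because underscores are dropped before comparing, `la_base_LUTs` would resolve to the same class. The sketch returns the first match, whereas the original loop keeps scanning and keeps the last one; with a single matching class the result is the same.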
-------------------------------------------------------------------------------- /models/__pycache__/laBaseLUTs_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/models/__pycache__/laBaseLUTs_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/models_x.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/models/__pycache__/models_x.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/modules.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/models/__pycache__/modules.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/networks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/models/__pycache__/networks.cpython-36.pyc -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from abc import ABC, abstractmethod 5 | 6 | 7 | class BaseModel(ABC): 8 | """This class is an abstract base class (ABC) for models. 
9 | To create a subclass, you need to implement the following five functions: 10 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 11 | -- : unpack data from dataset and apply preprocessing. 12 | -- : produce intermediate results. 13 | -- : (optionally) add model-specific options and set default options. 14 | """ 15 | 16 | def __init__(self, opt): 17 | """Initialize the BaseModel class. 18 | 19 | Parameters: 20 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 21 | 22 | When creating your custom class, you need to implement your own initialization. 23 | In this fucntion, you should first call 24 | Then, you need to define four lists: 25 | -- self.model_names (str list): define networks used in our training. 26 | -- self.visual_names (str list): specify the images that you want to display and save. 27 | """ 28 | self.opt = opt 29 | self.gpu_ids = opt.gpu_ids 30 | self.isTrain = opt.isTrain 31 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU 32 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir 33 | if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark. 34 | torch.backends.cudnn.benchmark = True 35 | self.model_names = [] 36 | self.visual_names = [] 37 | self.image_paths = [] 38 | self.metric = 0 # used for learning rate policy 'plateau' 39 | 40 | @staticmethod 41 | def modify_commandline_options(parser, is_train): 42 | """Add new model-specific options, and rewrite default values for existing options. 43 | 44 | Parameters: 45 | parser -- original option parser 46 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 47 | 48 | Returns: 49 | the modified parser. 
50 | """ 51 | return parser 52 | 53 | @abstractmethod 54 | def set_input(self, input): 55 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 56 | 57 | Parameters: 58 | input (dict): includes the data itself and its metadata information. 59 | """ 60 | pass 61 | 62 | @abstractmethod 63 | def forward(self): 64 | """Run forward pass; called by both functions and .""" 65 | pass 66 | 67 | def setup(self, opt): 68 | """Load and print networks; create schedulers 69 | 70 | Parameters: 71 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 72 | """ 73 | 74 | if not self.isTrain: 75 | load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch 76 | self.load_networks(load_suffix) 77 | self.print_networks(opt.verbose) 78 | 79 | def eval(self): 80 | """Make models eval mode during test time""" 81 | for name in self.model_names: 82 | if isinstance(name, str): 83 | net = getattr(self, 'net' + name) 84 | net.eval() 85 | 86 | def test(self): 87 | """Forward function used in test time. 88 | 89 | This function wraps function in no_grad() so we don't save intermediate steps for backprop 90 | It also calls to produce additional visualization results 91 | """ 92 | with torch.no_grad(): 93 | self.forward() 94 | self.compute_visuals() 95 | 96 | def compute_visuals(self): 97 | """Calculate additional output images for visdom and HTML visualization""" 98 | pass 99 | 100 | def get_image_paths(self): 101 | """ Return image paths that are used to load current data""" 102 | return self.image_paths 103 | 104 | def get_current_visuals(self): 105 | """Return visualization images. 
train.py will display these images with visdom, and save the images to a HTML""" 106 | visual_ret = OrderedDict() 107 | for name in self.visual_names: 108 | if isinstance(name, str): 109 | visual_ret[name] = getattr(self, name) 110 | return visual_ret 111 | 112 | 113 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 114 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" 115 | key = keys[i] 116 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 117 | if module.__class__.__name__.startswith('InstanceNorm') and \ 118 | (key == 'running_mean' or key == 'running_var'): 119 | if getattr(module, key) is None: 120 | state_dict.pop('.'.join(keys)) 121 | if module.__class__.__name__.startswith('InstanceNorm') and \ 122 | (key == 'num_batches_tracked'): 123 | state_dict.pop('.'.join(keys)) 124 | else: 125 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 126 | 127 | def load_networks(self, epoch): 128 | """Load all the networks from the disk. 
129 | 130 | Parameters: 131 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 132 | """ 133 | 134 | load_filename = '%s_net_LUTs.pth' % epoch 135 | load_path = os.path.join(self.save_dir, load_filename) 136 | print('loading the model from %s' % load_path) 137 | LUTs_state_dict = torch.load(load_path, map_location=str(self.device)) 138 | LUTs = getattr(self, 'baseLUTs') 139 | for ii in range(len(LUTs)): 140 | LUTs[ii].load_state_dict(LUTs_state_dict[str(ii)]) 141 | 142 | for name in self.model_names: 143 | if isinstance(name, str): 144 | load_filename = '%s_net_%s.pth' % (epoch, name) 145 | load_path = os.path.join(self.save_dir, load_filename) 146 | net = getattr(self, 'net' + name) 147 | 148 | if isinstance(net, torch.nn.DataParallel): 149 | net = net.module 150 | print('loading the model from %s' % load_path) 151 | 152 | state_dict = torch.load(load_path, map_location=str(self.device)) 153 | if hasattr(state_dict, '_metadata'): 154 | del state_dict._metadata 155 | 156 | current_state_dict = net.state_dict() 157 | 158 | pretrained_dict = state_dict 159 | current_state_dict.update(pretrained_dict) 160 | net.load_state_dict(current_state_dict, strict=True) 161 | 162 | 163 | 164 | def print_networks(self, verbose): 165 | """Print the total number of parameters in the network and (if verbose) network architecture 166 | 167 | Parameters: 168 | verbose (bool) -- if verbose: print the network architecture 169 | """ 170 | print('---------- Networks initialized -------------') 171 | for name in self.model_names: 172 | if isinstance(name, str): 173 | net = getattr(self, 'net' + name) 174 | num_params = 0 175 | for param in net.parameters(): 176 | num_params += param.numel() 177 | if verbose: 178 | print(net) 179 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 180 | print('-----------------------------------------------') 181 | 182 | def set_requires_grad(self, nets, requires_grad=False): 183 | """Set 
requies_grad=Fasle for all the networks to avoid unnecessary computations 184 | Parameters: 185 | nets (network list) -- a list of networks 186 | requires_grad (bool) -- whether the networks require gradients or not 187 | """ 188 | if not isinstance(nets, list): 189 | nets = [nets] 190 | for net in nets: 191 | if net is not None: 192 | for param in net.parameters(): 193 | param.requires_grad = requires_grad 194 | -------------------------------------------------------------------------------- /models/laBaseLUTs_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_model import BaseModel 3 | from . import modules 4 | from torch import cuda 5 | 6 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 7 | 8 | 9 | class LABASELUTSModel(BaseModel): 10 | def __init__(self, opt): 11 | BaseModel.__init__(self, opt) 12 | self.visual_names = ['mask', 'real', 'transfer_img_c'] 13 | self.model_names = ['Er'] 14 | 15 | self.LUT_num = opt.LUT_num 16 | 17 | self.baseLUTs = [] 18 | for _ in range(self.LUT_num): 19 | self.baseLUTs.append(modules.Get3DLUT_identity(dim=17).to(self.device)) 20 | self.LUT_num = len(self.baseLUTs) 21 | print('total LUT numbers: %d' % self.LUT_num) 22 | 23 | self.netEr = modules.define_E(opt.input_nc, opt.nz, opt.nef, opt.nwf, opt.netEr, opt.norm, 24 | 'lrelu', self.gpu_ids, True, self.LUT_num) 25 | 26 | 27 | self.TV3 = modules.TV_3D().to(self.device) 28 | self.TV3.weight_r = self.TV3.weight_r.type(Tensor) 29 | self.TV3.weight_g = self.TV3.weight_g.type(Tensor) 30 | self.TV3.weight_b = self.TV3.weight_b.type(Tensor) 31 | self.trilinear_ = modules.TrilinearInterpolation() 32 | 33 | 34 | def set_input(self, tgt_input): 35 | self.mask = tgt_input['mask'].to(self.device) 36 | self.real = tgt_input['real'].to(self.device) 37 | self.inputs = torch.cat([self.real,self.mask], 1) 38 | 39 | self.real_raw = tgt_input['real_raw'].to(self.device) 40 | self.mask_raw = 
tgt_input['mask_raw'].to(self.device) 41 | 42 | 43 | def get_z_random(self, size, random_type='gauss'): 44 | if random_type == 'uni': 45 | z = torch.rand(size) * 2.0 - 1.0 46 | elif random_type == 'gauss': 47 | z = torch.randn(size) 48 | return z.detach().to(self.device) 49 | 50 | 51 | def generator_eval(self, pred, img, mask, LUTs): 52 | 53 | pred = pred.squeeze() 54 | 55 | for ii in range(self.LUT_num): 56 | self.baseLUTs[ii].eval() 57 | 58 | LUT = pred[0] * self.baseLUTs[0].LUT 59 | for idx in range(1, self.LUT_num): 60 | LUT += pred[idx] * self.baseLUTs[idx].LUT 61 | 62 | 63 | _, combine_A = self.trilinear_(LUT,img) 64 | 65 | combine_A = combine_A * mask + img * (1 - mask) 66 | combine_A = torch.clamp(combine_A, 0, 1) 67 | 68 | return combine_A 69 | 70 | 71 | def forward(self): 72 | self.z_size = [self.inputs.size(0), self.opt.nz, 1, 1] 73 | self.z_random = self.get_z_random(self.z_size) 74 | 75 | self.features_c, self.weightPred_c = self.netEr(self.inputs, self.z_random) 76 | if self.opt.keep_res: 77 | self.transfer_img_c = self.generator_eval(self.weightPred_c, self.real_raw, self.mask_raw, self.baseLUTs) 78 | else: 79 | self.transfer_img_c = self.generator_eval(self.weightPred_c, self.real, self.mask, self.baseLUTs) 80 | 81 | 82 | -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import functools 3 | import trilinear 4 | import numpy as np 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | 10 | class Identity(nn.Module): 11 | def forward(self, x): 12 | return x 13 | 14 | 15 | def get_norm_layer(norm_type='instance'): 16 | """Return a normalization layer 17 | 18 | Parameters: 19 | norm_type (str) -- the name of the normalization layer: batch | instance | none 20 | 21 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). 
22 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 23 | """ 24 | if norm_type == 'batch': 25 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 26 | elif norm_type == 'instance': 27 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 28 | elif norm_type == 'none': 29 | norm_layer = lambda x: Identity() 30 | else: 31 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 32 | return norm_layer 33 | 34 | def get_non_linearity(layer_type='relu'): 35 | if layer_type == 'relu': 36 | nl_layer = functools.partial(nn.ReLU, inplace=True) 37 | elif layer_type == 'lrelu': 38 | nl_layer = functools.partial( 39 | nn.LeakyReLU, negative_slope=0.2, inplace=True) 40 | elif layer_type == 'elu': 41 | nl_layer = functools.partial(nn.ELU, inplace=True) 42 | else: 43 | raise NotImplementedError( 44 | 'nonlinearity activitation [%s] is not found' % layer_type) 45 | return nl_layer 46 | 47 | 48 | 49 | def init_net(net, gpu_ids=[]): 50 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 51 | Parameters: 52 | net (network) -- the network to be initialized 53 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 54 | 55 | Return an initialized network. 
    """

    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(gpu_ids[0])
        #net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    return net



def define_E(input_nc, output_nc, nef, nwf, netE, norm='batch', nl='lrelu', gpu_ids=[], linear=True, LUT_num=5):
    net = None
    norm_layer = get_norm_layer(norm_type=norm)
    nl = 'lrelu'  # use leaky relu for E
    nl_layer = get_non_linearity(layer_type=nl)

    if netE == 'Syco':  #Joy
        net = SycoNet(input_nc, output_nc, nef, nwf, n_blocks=5, norm_layer=norm_layer, nl_layer=nl_layer, linear=linear, LUT_num=LUT_num)
    else:
        raise NotImplementedError('Encoder model name [%s] is not recognized' % netE)

    return init_net(net, gpu_ids)



def conv3x3(in_planes, out_planes):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1,
                     padding=1, bias=True)


def upsampleConv(inplanes, outplanes, kw, padw):
    sequence = []
    sequence += [nn.Upsample(scale_factor=2, mode='nearest')]
    sequence += [nn.Conv2d(inplanes, outplanes, kernel_size=kw,
                           stride=1, padding=padw, bias=True)]
    return nn.Sequential(*sequence)


def meanpoolConv(inplanes, outplanes):
    sequence = []
    sequence += [nn.AvgPool2d(kernel_size=2, stride=2)]
    sequence += [nn.Conv2d(inplanes, outplanes,
                           kernel_size=1, stride=1, padding=0, bias=True)]
    return nn.Sequential(*sequence)


def convMeanpool(inplanes, outplanes):
    sequence = []
    sequence += [conv3x3(inplanes, outplanes)]
    sequence += [nn.AvgPool2d(kernel_size=2, stride=2)]
    return nn.Sequential(*sequence)

class BasicBlock(nn.Module):
    def __init__(self, inplanes, outplanes, norm_layer=None, nl_layer=None):
        super(BasicBlock, self).__init__()
        layers = []
        if norm_layer is not None:
            layers += [norm_layer(inplanes)]
        layers += [nl_layer()]
        layers += [conv3x3(inplanes, inplanes)]
        if norm_layer is not None:
            layers += [norm_layer(inplanes)]
        layers += [nl_layer()]
        layers += [convMeanpool(inplanes, outplanes)]
        self.conv = nn.Sequential(*layers)
        self.shortcut = meanpoolConv(inplanes, outplanes)

    def forward(self, x):
        out = self.conv(x) + self.shortcut(x)
        return out



class SycoNet(nn.Module):
    def __init__(self, input_nc=3, nz=1, nef=64, nwf=128, n_blocks=4, norm_layer=None, nl_layer=None, linear=False, LUT_num=5):
        super(SycoNet, self).__init__()
        self.isLinear = linear
        max_nef = 4
        self.block0 = nn.Conv2d(input_nc+nz, nef, kernel_size=4, stride=2, padding=1, bias=True)

        self.block1 = BasicBlock(nef * min(max_nef, 1)+nz, nef * min(max_nef, 2), norm_layer, nl_layer)
        self.block2 = BasicBlock(nef * min(max_nef, 2)+nz, nef * min(max_nef, 3), norm_layer, nl_layer)
        self.block3 = BasicBlock(nef * min(max_nef, 3)+nz, nef * min(max_nef, 4), norm_layer, nl_layer)
        self.block4 = BasicBlock(nef * min(max_nef, 4)+nz, nef * min(max_nef, 4), norm_layer, nl_layer)

        self.nl = nn.Sequential(nl_layer(), nn.AvgPool2d(8))

        self.weight_predictor = nn.Conv2d(nwf, LUT_num, 1, padding=0)

    def forward(self, img_input, random_z):
        # random_z must have shape (B, nz, 1, 1): it is expanded to the spatial size
        # of each feature map and concatenated along the channel axis before every block.
        z_img = random_z.expand(random_z.size(0), random_z.size(1), img_input.size(2), img_input.size(3))
        inputs = torch.cat([img_input, z_img], 1)
        x0 = self.block0(inputs)
        z0 = random_z.expand(random_z.size(0), random_z.size(1), x0.size(2), x0.size(3))
        x1 = torch.cat([x0, z0], 1)
        x1 = self.block1(x1)
        z1 = random_z.expand(random_z.size(0), random_z.size(1), x1.size(2), x1.size(3))
        x2 = torch.cat([x1, z1], 1)
        x2 = self.block2(x2)
        z2 = random_z.expand(random_z.size(0), random_z.size(1), x2.size(2), x2.size(3))
        x3 = torch.cat([x2, z2], 1)
        x3 = self.block3(x3)
        z3 = random_z.expand(random_z.size(0), random_z.size(1), x3.size(2), x3.size(3))
        x4 = torch.cat([x3, z3], 1)
        x4 = self.block4(x4)
        features = self.nl(x4)
        outputs = self.weight_predictor(features)
        if self.isLinear:
            outputs = F.softmax(outputs, dim=1)
        return features, outputs

class Get3DLUT_identity(nn.Module):
    def __init__(self, dim=17):
        super(Get3DLUT_identity, self).__init__()
        buffer = np.zeros((3,dim,dim,dim), dtype=np.float32)

        # Identity LUT: buffer[0] (R) grows along the last axis k and buffer[2] (B) along the
        # first axis i, matching the r_id + g_id*dim + b_id*dim*dim indexing of the trilinear kernels.
        for i in range(0,dim):
            for j in range(0,dim):
                for k in range(0,dim):
                    n = i * dim*dim + j * dim + k
                    #x = lines[n].split()
                    buffer[0,i,j,k] = 1.0/(dim-1)*k  #float(x[0])
                    buffer[1,i,j,k] = 1.0/(dim-1)*j  #float(x[1])
                    buffer[2,i,j,k] = 1.0/(dim-1)*i  #float(x[2])
                    #print(i,j,k,":",buffer[0,i,j,k],buffer[1,i,j,k],buffer[2,i,j,k])
        self.LUT = nn.Parameter(torch.from_numpy(buffer).requires_grad_(True))
        self.TrilinearInterpolation = TrilinearInterpolation()

    def forward(self, x):
        _, output = self.TrilinearInterpolation(self.LUT, x)
        return output

class TrilinearInterpolationFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, lut, x):
        x = x.contiguous()

        output = x.new(x.size())
        dim = lut.size()[-1]
        shift = dim ** 3
        # 1.000001 (rather than 1) keeps an input of exactly 1.0 inside the last bin,
        # so the upper neighbour index stays within the LUT.
        binsize = 1.000001 / (dim-1)
        W = x.size(2)
        H = x.size(3)
        batch = x.size(0)

        assert 1 == trilinear.forward(lut, x, output, dim, shift, binsize, W, H, batch)

        int_package = torch.IntTensor([dim, shift, W, H, batch])
        float_package = torch.FloatTensor([binsize])
        variables = [lut, x, int_package, float_package]

        ctx.save_for_backward(*variables)

        return lut, output

    @staticmethod
    def backward(ctx, lut_grad, x_grad):

        lut, x, int_package, float_package = ctx.saved_tensors
        dim, shift, W, H, batch = int_package
        dim, shift, W, H, batch = int(dim), int(shift), int(W), int(H), int(batch)
        binsize = float(float_package[0])

        assert 1 == trilinear.backward(x, x_grad, lut_grad, dim, shift, binsize, W, H, batch)
        return lut_grad, x_grad



class TrilinearInterpolation(torch.nn.Module):
    def __init__(self):
        super(TrilinearInterpolation, self).__init__()

    def forward(self, lut, x):
        return TrilinearInterpolationFunction.apply(lut, x)


class TV_3D(nn.Module):
    # Regularizers for a learned 3D LUT: tv penalizes squared differences between neighbouring
    # entries (the first and last difference along each axis are weighted 2x), and mn penalizes
    # entries that decrease along an axis (relu(dif) > 0 where LUT[i] > LUT[i+1]).
    def __init__(self, dim=33):
        super(TV_3D,self).__init__()

        self.weight_r = torch.ones(3,dim,dim,dim-1, dtype=torch.float)
        self.weight_r[:,:,:,(0,dim-2)] *= 2.0
        self.weight_g = torch.ones(3,dim,dim-1,dim, dtype=torch.float)
        self.weight_g[:,:,(0,dim-2),:] *= 2.0
        self.weight_b = torch.ones(3,dim-1,dim,dim, dtype=torch.float)
        self.weight_b[:,(0,dim-2),:,:] *= 2.0
        self.relu = torch.nn.ReLU()

    def forward(self, LUT):

        dif_r = LUT.LUT[:,:,:,:-1] - LUT.LUT[:,:,:,1:]
        dif_g = LUT.LUT[:,:,:-1,:] - LUT.LUT[:,:,1:,:]
        dif_b = LUT.LUT[:,:-1,:,:] - LUT.LUT[:,1:,:,:]
        tv = torch.mean(torch.mul((dif_r ** 2),self.weight_r)) + torch.mean(torch.mul((dif_g ** 2),self.weight_g)) + torch.mean(torch.mul((dif_b ** 2),self.weight_b))

        mn = torch.mean(self.relu(dif_r)) + torch.mean(self.relu(dif_g)) + torch.mean(self.relu(dif_b))

        return tv, mn
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
"""This package options includes option modules: training options, test options, and basic options (used in both training and test)."""

--------------------------------------------------------------------------------
/options/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/options/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/options/__pycache__/base_options.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/options/__pycache__/base_options.cpython-36.pyc
--------------------------------------------------------------------------------
/options/__pycache__/test_options.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/options/__pycache__/test_options.cpython-36.pyc
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
import argparse
import os
from util import util
import torch
import models
import data


class BaseOptions():
    """This class defines options used during both training and test time.

    It also implements several helper functions such as parsing, printing, and saving the options.
    It also gathers additional options defined in functions in both dataset class and model class.
    """

    def __init__(self):
        """Reset the class; indicates the class hasn't been initialized"""
        self.initialized = False

    def initialize(self, parser):
        """Define the common options that are used in both training and test."""
        # basic parameters
        parser.add_argument('--dataset_root', type=str, default='/data/caojunyan/datasets/IHD/', help='path to iHarmony4 dataset')  #mia
        parser.add_argument('--dataset_name', type=str, default='', help='which sub-dataset to load [Hday2night | HVIDIT]')  #mia
        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        # model parameters
        parser.add_argument('--model', type=str, default='la', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')
        parser.add_argument('--input_nc', type=int, default=4, help='# of input image channels: 4 for concatenated comp and mask')  #mia
        parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--nwf', type=int, default=256, help='# of weight predictor filters in the first conv layer')
        parser.add_argument('--nef', type=int, default=64, help='# of encoder filters in the first conv layer')  # Joy
        parser.add_argument('--LUT_num', type=int, default=20, help='# the number of LUTs. maximum: 20')  #Joy
        parser.add_argument('--nz', type=int, default=32, help='# latent code dim')  #Joy
        parser.add_argument('--netEr', type=str, default='lut_spacial', help='Encoder architecture. [lut_special]')  #Joy
        parser.add_argument('--norm', type=str, default='batch', help='instance normalization or batch normalization [instance | batch | none]')
        # dataset parameters
        parser.add_argument('--dataset_mode', type=str, default='ihd', help='load iHarmony4 dataset')  #mia
        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
        parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
        parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
        parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
        parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
        parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
        parser.add_argument('--preprocess', type=str, default='resize', help='scaling images at load time')
        parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
        parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
        # additional parameters
        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        parser.add_argument('--load_iter', type=int, default=0, help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
        parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
        self.initialized = True
        return parser

    def gather_options(self):
        """Initialize our parser with basic options (only once).
        Add additional model-specific and dataset-specific options.
        These options are defined in the function
        in model and dataset classes.
        """
        if not self.initialized:  # check if it has been initialized
            parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
            parser = self.initialize(parser)

        # get the basic options
        opt, _ = parser.parse_known_args()

        # modify model-related parser options
        model_name = opt.model
        model_option_setter = models.get_option_setter(model_name)
        parser = model_option_setter(parser, self.isTrain)
        opt, _ = parser.parse_known_args()  # parse again with new defaults

        # modify dataset-related parser options
        if opt.dataset_mode is not None:
            dataset_name = opt.dataset_mode
            dataset_option_setter = data.get_option_setter(dataset_name)
            parser = dataset_option_setter(parser, self.isTrain)
        else:
            dataset_name = 'ihd'
            dataset_option_setter = data.get_option_setter(dataset_name)
            parser = dataset_option_setter(parser, self.isTrain)

        # save and return the parser
        self.parser = parser
        return parser.parse_args()

    def print_options(self, opt):
        """Print and save options

        It will print both current options and default values (if different).
        It will save options into a text file / [checkpoints_dir] / opt.txt
        """
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)

        # save to the disk
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
        with open(file_name, 'wt') as opt_file:
            opt_file.write(message)
            opt_file.write('\n')

    def parse(self):
        """Parse our options, create checkpoints directory suffix, and set up gpu device."""
        opt = self.gather_options()
        opt.isTrain = self.isTrain  # train or test

        # process opt.suffix
        if opt.suffix:
            suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
            opt.name = opt.name + suffix

        self.print_options(opt)

        # set gpu ids
        str_ids = opt.gpu_ids.split(',')
        opt.gpu_ids = []
        for str_id in str_ids:
            id = int(str_id)
            if id >= 0:
                opt.gpu_ids.append(id)
        if len(opt.gpu_ids) > 0:
            torch.cuda.set_device(opt.gpu_ids[0])

        self.opt = opt
        return self.opt

--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
from .base_options import BaseOptions


class TestOptions(BaseOptions):
    """This class includes test options.

    It also includes shared options defined in BaseOptions.
    """

    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)  # define shared options
        parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
        parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
        parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
        # Dropout and Batchnorm have different behavior during training and test.
        parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
        parser.add_argument('--num_test', type=int, default=7404, help='how many test images to run; for iHarmony4, the number is 7404')  #mia
        parser.add_argument('--augment_num', type=int, default=5, help='how many augmented results to generate for each real image.')
        parser.add_argument('--keep_res', action='store_true', help='keep the input\'s resolution for augmented output.')
        parser.add_argument('--real', type=str, default='', help='the real image for learnable augmentation.')
        parser.add_argument('--mask', type=str, default='', help='the foreground mask of the real image.')
        # override default values
        parser.set_defaults(model='test')
        # To avoid cropping, the load_size should be the same as crop_size
        parser.set_defaults(load_size=parser.get_default('crop_size'))
        self.isTrain = False
        return parser

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
opencv_python==4.1.1.26
albumentations==0.5.2 --no-binary albumentations

--------------------------------------------------------------------------------
/trilinear_cpp/build/lib.linux-x86_64-3.6/trilinear.cpython-36m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/trilinear_cpp/build/lib.linux-x86_64-3.6/trilinear.cpython-36m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/trilinear_cpp/build/temp.linux-x86_64-3.6/src/trilinear_cuda.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/trilinear_cpp/build/temp.linux-x86_64-3.6/src/trilinear_cuda.o
--------------------------------------------------------------------------------
/trilinear_cpp/build/temp.linux-x86_64-3.6/src/trilinear_kernel.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/trilinear_cpp/build/temp.linux-x86_64-3.6/src/trilinear_kernel.o
--------------------------------------------------------------------------------
/trilinear_cpp/dist/trilinear-0.0.0-py3.6-linux-x86_64.egg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/trilinear_cpp/dist/trilinear-0.0.0-py3.6-linux-x86_64.egg
--------------------------------------------------------------------------------
/trilinear_cpp/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension

if torch.cuda.is_available():
    print('Including CUDA code.')
    setup(
        name='trilinear',
        ext_modules=[
            CUDAExtension('trilinear', [
                'src/trilinear_cuda.cpp',
                'src/trilinear_kernel.cu',
            ])
        ],
        cmdclass={
            'build_ext': BuildExtension
        })
else:
    print('NO CUDA is found. Fall back to CPU.')
    setup(name='trilinear',
        ext_modules=[CppExtension('trilinear', ['src/trilinear.cpp'])],
        cmdclass={'build_ext': BuildExtension})

--------------------------------------------------------------------------------
/trilinear_cpp/setup.sh:
--------------------------------------------------------------------------------
export CUDA_HOME=/usr/local/cuda-11.3 && python3 setup.py install

--------------------------------------------------------------------------------
/trilinear_cpp/src/trilinear.cpp:
--------------------------------------------------------------------------------
#include "trilinear.h"


void TriLinearForwardCpu(const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int channels);

void TriLinearBackwardCpu(const float* image, const float* image_grad, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int channels);

int trilinear_forward(torch::Tensor lut, torch::Tensor image, torch::Tensor output,
                      int lut_dim, int shift, float binsize, int width, int height, int batch)
{
    // Grab the input tensor
    float * lut_flat = lut.data_ptr<float>();
    float * image_flat = image.data_ptr<float>();
    float * output_flat = output.data_ptr<float>();

    // whether color image
    auto image_size = image.sizes();
    int channels = image_size[1];
    if (channels != 3)
    {
        return 0;
    }

    TriLinearForwardCpu(lut_flat, image_flat, output_flat, lut_dim, shift, binsize, width, height, channels);

    return 1;
}

int trilinear_backward(torch::Tensor image, torch::Tensor image_grad, torch::Tensor lut_grad,
                       int lut_dim, int shift, float
                       binsize, int width, int height, int batch)
{
    // Grab the input tensor
    float * image_grad_flat = image_grad.data_ptr<float>();
    float * image_flat = image.data_ptr<float>();
    float * lut_grad_flat = lut_grad.data_ptr<float>();

    // whether color image
    auto image_size = image.sizes();
    int channels = image_size[1];
    if (channels != 3)
    {
        return 0;
    }

    TriLinearBackwardCpu(image_flat, image_grad_flat, lut_grad_flat, lut_dim, shift, binsize, width, height, channels);

    return 1;
}

void TriLinearForwardCpu(const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int channels)
{
    // This CPU path processes a single height*width image; unlike the CUDA launcher,
    // it does not receive the batch size.
    const int output_size = height * width;

    int index = 0;
    for (index = 0; index < output_size; ++index)
    {
        float r = image[index];
        float g = image[index + width * height];
        float b = image[index + width * height * 2];

        int r_id = floor(r / binsize);
        int g_id = floor(g / binsize);
        int b_id = floor(b / binsize);

        float r_d = fmod(r,binsize) / binsize;
        float g_d = fmod(g,binsize) / binsize;
        float b_d = fmod(b,binsize) / binsize;

        int id000 = r_id + g_id * dim + b_id * dim * dim;
        int id100 = r_id + 1 + g_id * dim + b_id * dim * dim;
        int id010 = r_id + (g_id + 1) * dim + b_id * dim * dim;
        int id110 = r_id + 1 + (g_id + 1) * dim + b_id * dim * dim;
        int id001 = r_id + g_id * dim + (b_id + 1) * dim * dim;
        int id101 = r_id + 1 + g_id * dim + (b_id + 1) * dim * dim;
        int id011 = r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim;
        int id111 = r_id + 1 + (g_id + 1) * dim + (b_id + 1) * dim * dim;

        float w000 = (1-r_d)*(1-g_d)*(1-b_d);
        float w100 = r_d*(1-g_d)*(1-b_d);
        float w010 = (1-r_d)*g_d*(1-b_d);
        float w110 = r_d*g_d*(1-b_d);
        float w001 = (1-r_d)*(1-g_d)*b_d;
        float w101 = r_d*(1-g_d)*b_d;
        float w011 =
            (1-r_d)*g_d*b_d;
        float w111 = r_d*g_d*b_d;

        output[index] = w000 * lut[id000] + w100 * lut[id100] +
                        w010 * lut[id010] + w110 * lut[id110] +
                        w001 * lut[id001] + w101 * lut[id101] +
                        w011 * lut[id011] + w111 * lut[id111];

        output[index + width * height] = w000 * lut[id000 + shift] + w100 * lut[id100 + shift] +
                                         w010 * lut[id010 + shift] + w110 * lut[id110 + shift] +
                                         w001 * lut[id001 + shift] + w101 * lut[id101 + shift] +
                                         w011 * lut[id011 + shift] + w111 * lut[id111 + shift];

        output[index + width * height * 2] = w000 * lut[id000 + shift * 2] + w100 * lut[id100 + shift * 2] +
                                             w010 * lut[id010 + shift * 2] + w110 * lut[id110 + shift * 2] +
                                             w001 * lut[id001 + shift * 2] + w101 * lut[id101 + shift * 2] +
                                             w011 * lut[id011 + shift * 2] + w111 * lut[id111 + shift * 2];
    }
}

void TriLinearBackwardCpu(const float* image, const float* image_grad, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int channels)
{
    const int output_size = height * width;

    int index = 0;
    for (index = 0; index < output_size; ++index)
    {
        float r = image[index];
        float g = image[index + width * height];
        float b = image[index + width * height * 2];

        int r_id = floor(r / binsize);
        int g_id = floor(g / binsize);
        int b_id = floor(b / binsize);

        float r_d = fmod(r,binsize) / binsize;
        float g_d = fmod(g,binsize) / binsize;
        float b_d = fmod(b,binsize) / binsize;

        int id000 = r_id + g_id * dim + b_id * dim * dim;
        int id100 = r_id + 1 + g_id * dim + b_id * dim * dim;
        int id010 = r_id + (g_id + 1) * dim + b_id * dim * dim;
        int id110 = r_id + 1 + (g_id + 1) * dim + b_id * dim * dim;
        int id001 = r_id + g_id * dim + (b_id + 1) * dim * dim;
        int id101 = r_id + 1 + g_id * dim + (b_id + 1) * dim * dim;
        int id011 = r_id
                    + (g_id + 1) * dim + (b_id + 1) * dim * dim;
        int id111 = r_id + 1 + (g_id + 1) * dim + (b_id + 1) * dim * dim;

        float w000 = (1-r_d)*(1-g_d)*(1-b_d);
        float w100 = r_d*(1-g_d)*(1-b_d);
        float w010 = (1-r_d)*g_d*(1-b_d);
        float w110 = r_d*g_d*(1-b_d);
        float w001 = (1-r_d)*(1-g_d)*b_d;
        float w101 = r_d*(1-g_d)*b_d;
        float w011 = (1-r_d)*g_d*b_d;
        float w111 = r_d*g_d*b_d;

        lut_grad[id000] += w000 * image_grad[index];
        lut_grad[id100] += w100 * image_grad[index];
        lut_grad[id010] += w010 * image_grad[index];
        lut_grad[id110] += w110 * image_grad[index];
        lut_grad[id001] += w001 * image_grad[index];
        lut_grad[id101] += w101 * image_grad[index];
        lut_grad[id011] += w011 * image_grad[index];
        lut_grad[id111] += w111 * image_grad[index];

        lut_grad[id000 + shift] += w000 * image_grad[index + width * height];
        lut_grad[id100 + shift] += w100 * image_grad[index + width * height];
        lut_grad[id010 + shift] += w010 * image_grad[index + width * height];
        lut_grad[id110 + shift] += w110 * image_grad[index + width * height];
        lut_grad[id001 + shift] += w001 * image_grad[index + width * height];
        lut_grad[id101 + shift] += w101 * image_grad[index + width * height];
        lut_grad[id011 + shift] += w011 * image_grad[index + width * height];
        lut_grad[id111 + shift] += w111 * image_grad[index + width * height];

        lut_grad[id000 + shift* 2] += w000 * image_grad[index + width * height * 2];
        lut_grad[id100 + shift* 2] += w100 * image_grad[index + width * height * 2];
        lut_grad[id010 + shift* 2] += w010 * image_grad[index + width * height * 2];
        lut_grad[id110 + shift* 2] += w110 * image_grad[index + width * height * 2];
        lut_grad[id001 + shift* 2] += w001 * image_grad[index + width * height * 2];
        lut_grad[id101 + shift* 2] += w101 * image_grad[index + width * height * 2];
        lut_grad[id011 + shift* 2] += w011
            * image_grad[index + width * height * 2];
        lut_grad[id111 + shift* 2] += w111 * image_grad[index + width * height * 2];
    }
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &trilinear_forward, "Trilinear forward");
    m.def("backward", &trilinear_backward, "Trilinear backward");
}

--------------------------------------------------------------------------------
/trilinear_cpp/src/trilinear.h:
--------------------------------------------------------------------------------
#ifndef TRILINEAR_H
#define TRILINEAR_H

#include <torch/extension.h>

int trilinear_forward(torch::Tensor lut, torch::Tensor image, torch::Tensor output,
                      int lut_dim, int shift, float binsize, int width, int height, int batch);

int trilinear_backward(torch::Tensor image, torch::Tensor image_grad, torch::Tensor lut_grad,
                       int lut_dim, int shift, float binsize, int width, int height, int batch);

#endif

--------------------------------------------------------------------------------
/trilinear_cpp/src/trilinear_cuda.cpp:
--------------------------------------------------------------------------------
#include "trilinear_kernel.h"
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

int trilinear_forward_cuda(torch::Tensor lut, torch::Tensor image, torch::Tensor output,
                           int lut_dim, int shift, float binsize, int width, int height, int batch)
{
    // Grab the input tensor
    float * lut_flat = lut.data_ptr<float>();
    float * image_flat = image.data_ptr<float>();
    float * output_flat = output.data_ptr<float>();

    TriLinearForwardLaucher(lut_flat, image_flat, output_flat, lut_dim, shift, binsize, width, height, batch, at::cuda::getCurrentCUDAStream());

    return 1;
}

int trilinear_backward_cuda(torch::Tensor image, torch::Tensor image_grad, torch::Tensor lut_grad,
                            int lut_dim, int shift, float binsize, int width, int height, int batch)
{
    // Grab the input tensor
    float * image_grad_flat =
        image_grad.data_ptr<float>();
    float * image_flat = image.data_ptr<float>();
    float * lut_grad_flat = lut_grad.data_ptr<float>();

    TriLinearBackwardLaucher(image_flat, image_grad_flat, lut_grad_flat, lut_dim, shift, binsize, width, height, batch, at::cuda::getCurrentCUDAStream());

    return 1;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &trilinear_forward_cuda, "Trilinear forward");
    m.def("backward", &trilinear_backward_cuda, "Trilinear backward");
}


--------------------------------------------------------------------------------
/trilinear_cpp/src/trilinear_cuda.h:
--------------------------------------------------------------------------------
#ifndef TRILINEAR_CUDA_H
#define TRILINEAR_CUDA_H

#include <torch/extension.h>

int trilinear_forward_cuda(torch::Tensor lut, torch::Tensor image, torch::Tensor output,
                           int lut_dim, int shift, float binsize, int width, int height, int batch);

int trilinear_backward_cuda(torch::Tensor image, torch::Tensor image_grad, torch::Tensor lut_grad,
                            int lut_dim, int shift, float binsize, int width, int height, int batch);

#endif

--------------------------------------------------------------------------------
/trilinear_cpp/src/trilinear_kernel.cu:
--------------------------------------------------------------------------------
#include <math.h>
#include <stdio.h>
#include "trilinear_kernel.h"

#define CUDA_1D_KERNEL_LOOP(i, n)                            \
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
         i += blockDim.x * gridDim.x)


__global__ void TriLinearForward(const int nthreads, const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int batch) {
    CUDA_1D_KERNEL_LOOP(index, nthreads) {

        // Channel planes are read width * height * batch apart, i.e. a [3, batch, H, W] layout;
        // for NCHW input this matches only when batch == 1.
        float r = image[index];
        float g = image[index + width * height * batch];
        float b = image[index + width * height * batch * 2];

        int r_id
            = floor(r / binsize);
        int g_id = floor(g / binsize);
        int b_id = floor(b / binsize);

        float r_d = fmod(r,binsize) / binsize;
        float g_d = fmod(g,binsize) / binsize;
        float b_d = fmod(b,binsize) / binsize;

        int id000 = r_id + g_id * dim + b_id * dim * dim;
        int id100 = r_id + 1 + g_id * dim + b_id * dim * dim;
        int id010 = r_id + (g_id + 1) * dim + b_id * dim * dim;
        int id110 = r_id + 1 + (g_id + 1) * dim + b_id * dim * dim;
        int id001 = r_id + g_id * dim + (b_id + 1) * dim * dim;
        int id101 = r_id + 1 + g_id * dim + (b_id + 1) * dim * dim;
        int id011 = r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim;
        int id111 = r_id + 1 + (g_id + 1) * dim + (b_id + 1) * dim * dim;

        float w000 = (1-r_d)*(1-g_d)*(1-b_d);
        float w100 = r_d*(1-g_d)*(1-b_d);
        float w010 = (1-r_d)*g_d*(1-b_d);
        float w110 = r_d*g_d*(1-b_d);
        float w001 = (1-r_d)*(1-g_d)*b_d;
        float w101 = r_d*(1-g_d)*b_d;
        float w011 = (1-r_d)*g_d*b_d;
        float w111 = r_d*g_d*b_d;

        output[index] = w000 * lut[id000] + w100 * lut[id100] +
                        w010 * lut[id010] + w110 * lut[id110] +
                        w001 * lut[id001] + w101 * lut[id101] +
                        w011 * lut[id011] + w111 * lut[id111];

        output[index + width * height * batch] = w000 * lut[id000 + shift] + w100 * lut[id100 + shift] +
                                                 w010 * lut[id010 + shift] + w110 * lut[id110 + shift] +
                                                 w001 * lut[id001 + shift] + w101 * lut[id101 + shift] +
                                                 w011 * lut[id011 + shift] + w111 * lut[id111 + shift];

        output[index + width * height * batch * 2] = w000 * lut[id000 + shift * 2] + w100 * lut[id100 + shift * 2] +
                                                     w010 * lut[id010 + shift * 2] + w110 * lut[id110 + shift * 2] +
                                                     w001 * lut[id001 + shift * 2] + w101 * lut[id101 + shift * 2] +
                                                     w011 * lut[id011 + shift * 2] + w111 * lut[id111 + shift * 2];

    }
}


int TriLinearForwardLaucher(const float* lut, const float* image, float* output, const int lut_dim, const int
                            shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream) {
    const int kThreadsPerBlock = 1024;
    const int output_size = height * width * batch;
    cudaError_t err;


    TriLinearForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(output_size, lut, image, output, lut_dim, shift, binsize, width, height, batch);

    err = cudaGetLastError();
    if(cudaSuccess != err) {
        fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
        exit( -1 );
    }

    return 1;
}


__global__ void TriLinearBackward(const int nthreads, const float* image, const float* image_grad, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int batch) {
    CUDA_1D_KERNEL_LOOP(index, nthreads) {

        float r = image[index];
        float g = image[index + width * height * batch];
        float b = image[index + width * height * batch * 2];

        int r_id = floor(r / binsize);
        int g_id = floor(g / binsize);
        int b_id = floor(b / binsize);

        float r_d = fmod(r,binsize) / binsize;
        float g_d = fmod(g,binsize) / binsize;
        float b_d = fmod(b,binsize) / binsize;

        int id000 = r_id + g_id * dim + b_id * dim * dim;
        int id100 = r_id + 1 + g_id * dim + b_id * dim * dim;
        int id010 = r_id + (g_id + 1) * dim + b_id * dim * dim;
        int id110 = r_id + 1 + (g_id + 1) * dim + b_id * dim * dim;
        int id001 = r_id + g_id * dim + (b_id + 1) * dim * dim;
        int id101 = r_id + 1 + g_id * dim + (b_id + 1) * dim * dim;
        int id011 = r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim;
        int id111 = r_id + 1 + (g_id + 1) * dim + (b_id + 1) * dim * dim;

        float w000 = (1-r_d)*(1-g_d)*(1-b_d);
        float w100 = r_d*(1-g_d)*(1-b_d);
        float w010 = (1-r_d)*g_d*(1-b_d);
        float w110 = r_d*g_d*(1-b_d);
        float w001 =
(1-r_d)*(1-g_d)*b_d; 109 | float w101 = r_d*(1-g_d)*b_d; 110 | float w011 = (1-r_d)*g_d*b_d; 111 | float w111 = r_d*g_d*b_d; 112 | 113 | atomicAdd(lut_grad + id000, image_grad[index] * w000); 114 | atomicAdd(lut_grad + id100, image_grad[index] * w100); 115 | atomicAdd(lut_grad + id010, image_grad[index] * w010); 116 | atomicAdd(lut_grad + id110, image_grad[index] * w110); 117 | atomicAdd(lut_grad + id001, image_grad[index] * w001); 118 | atomicAdd(lut_grad + id101, image_grad[index] * w101); 119 | atomicAdd(lut_grad + id011, image_grad[index] * w011); 120 | atomicAdd(lut_grad + id111, image_grad[index] * w111); 121 | 122 | atomicAdd(lut_grad + id000 + shift, image_grad[index + width * height * batch] * w000); 123 | atomicAdd(lut_grad + id100 + shift, image_grad[index + width * height * batch] * w100); 124 | atomicAdd(lut_grad + id010 + shift, image_grad[index + width * height * batch] * w010); 125 | atomicAdd(lut_grad + id110 + shift, image_grad[index + width * height * batch] * w110); 126 | atomicAdd(lut_grad + id001 + shift, image_grad[index + width * height * batch] * w001); 127 | atomicAdd(lut_grad + id101 + shift, image_grad[index + width * height * batch] * w101); 128 | atomicAdd(lut_grad + id011 + shift, image_grad[index + width * height * batch] * w011); 129 | atomicAdd(lut_grad + id111 + shift, image_grad[index + width * height * batch] * w111); 130 | 131 | atomicAdd(lut_grad + id000 + shift * 2, image_grad[index + width * height * batch * 2] * w000); 132 | atomicAdd(lut_grad + id100 + shift * 2, image_grad[index + width * height * batch * 2] * w100); 133 | atomicAdd(lut_grad + id010 + shift * 2, image_grad[index + width * height * batch * 2] * w010); 134 | atomicAdd(lut_grad + id110 + shift * 2, image_grad[index + width * height * batch * 2] * w110); 135 | atomicAdd(lut_grad + id001 + shift * 2, image_grad[index + width * height * batch * 2] * w001); 136 | atomicAdd(lut_grad + id101 + shift * 2, image_grad[index + width * height * batch * 2] * w101); 137 
| atomicAdd(lut_grad + id011 + shift * 2, image_grad[index + width * height * batch * 2] * w011); 138 | atomicAdd(lut_grad + id111 + shift * 2, image_grad[index + width * height * batch * 2] * w111); 139 | } 140 | } 141 | 142 | int TriLinearBackwardLaucher(const float* image, const float* image_grad, float* lut_grad, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream) { 143 | const int kThreadsPerBlock = 1024; 144 | const int output_size = height * width * batch; 145 | cudaError_t err; 146 | 147 | TriLinearBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(output_size, image, image_grad, lut_grad, lut_dim, shift, binsize, width, height, batch); 148 | 149 | err = cudaGetLastError(); 150 | if(cudaSuccess != err) { 151 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 152 | exit( -1 ); 153 | } 154 | 155 | return 1; 156 | } 157 | -------------------------------------------------------------------------------- /trilinear_cpp/src/trilinear_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _TRILINEAR_KERNEL 2 | #define _TRILINEAR_KERNEL 3 | 4 | #include 5 | 6 | __global__ void TriLinearForward(const int nthreads, const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int batch); 7 | 8 | int TriLinearForwardLaucher(const float* lut, const float* image, float* output, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream); 9 | 10 | __global__ void TriLinearBackward(const int nthreads, const float* image, const float* image_grad, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int batch); 11 | 12 | int TriLinearBackwardLaucher(const float* 
image, const float* image_grad, float* lut_grad, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream); 13 | 14 | 15 | #endif 16 | 17 | -------------------------------------------------------------------------------- /trilinear_cpp/trilinear.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: trilinear 3 | Version: 0.0.0 4 | Summary: UNKNOWN 5 | Home-page: UNKNOWN 6 | License: UNKNOWN 7 | Platform: UNKNOWN 8 | 9 | UNKNOWN 10 | 11 | -------------------------------------------------------------------------------- /trilinear_cpp/trilinear.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | src/trilinear_cuda.cpp 3 | src/trilinear_kernel.cu 4 | trilinear.egg-info/PKG-INFO 5 | trilinear.egg-info/SOURCES.txt 6 | trilinear.egg-info/dependency_links.txt 7 | trilinear.egg-info/top_level.txt -------------------------------------------------------------------------------- /trilinear_cpp/trilinear.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /trilinear_cpp/trilinear.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | trilinear 2 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes a miscellaneous collection of useful helper functions.""" 2 | -------------------------------------------------------------------------------- /util/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/util/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/util/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/util/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | """This module contains simple helper functions """
2 |
3 | import os
4 | import torch
5 | import numpy as np
6 | from PIL import Image
7 |
8 |
9 | def tensor2im(input_image, imtype=np.uint8):
10 |     """Converts a Tensor array into a numpy image array.
11 |
12 |     Parameters:
13 |         input_image (tensor) --  the input image tensor array
14 |         imtype (type)        --  the desired type of the converted numpy array
15 |     """
16 |
17 |     if not isinstance(input_image, np.ndarray):
18 |         if isinstance(input_image, torch.Tensor):  # get the data from a variable
19 |             image_tensor = input_image.data
20 |         else:
21 |             return input_image
22 |
23 |         image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
24 |         if image_numpy.shape[0] == 1:  # grayscale to RGB
25 |             image_numpy = np.tile(image_numpy, (3, 1, 1))
26 |
27 |         image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0
28 |
29 |     else:  # if it is a numpy array, do nothing
30 |         image_numpy = input_image
31 |     image_numpy = np.clip(image_numpy, 0, 255)
32 |     return image_numpy.astype(imtype)
33 |
34 |
35 | def save_image(image_numpy, image_path, aspect_ratio=1.0):
36 |     """Save a numpy image to the disk
37 |
38 |     Parameters:
39 |         image_numpy (numpy array) -- input numpy array
40 |         image_path (str)          -- the path of the image
41 |     """
42 |
43 |     image_pil = Image.fromarray(image_numpy)
44 |     h, w, _ = image_numpy.shape
45 |
46 |     if aspect_ratio > 1.0:
47 |         image_pil = image_pil.resize((w, int(h * aspect_ratio)), Image.BICUBIC)  # PIL's resize takes (width, height)
48 |     if aspect_ratio < 1.0:
49 |         image_pil = image_pil.resize((int(w / aspect_ratio), h), Image.BICUBIC)
50 |     image_pil.save(image_path, quality=100)  # added by Mia (quality)
51 |
52 |
53 | def mkdirs(paths):
54 |     """create empty directories if they don't exist
55 |
56 |     Parameters:
57 |         paths (str or str list) -- a directory path or a list of directory paths
58 |     """
59 |     if isinstance(paths, list) and not isinstance(paths, str):
60 |         for path in paths:
61 |             mkdir(path)
62 |     else:
63 |         mkdir(paths)
64 |
65 |
66 | def mkdir(path):
67 |     """create a single empty directory if it does not exist
68 |
69 |     Parameters:
70 |         path (str) -- a single directory path
71 |     """
72 |     if not os.path.exists(path):
73 |         os.makedirs(path)
74 |
75 |
--------------------------------------------------------------------------------
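The trilinear lookup in `TriLinearForward` and the gradient scatter in `TriLinearBackward` (in `trilinear_cpp/src/trilinear_kernel.cu`) can be checked against a plain CPU reference. The following is a minimal NumPy sketch, not code from the SycoNet repository; the names `trilinear_forward`, `trilinear_backward` and `_corner_indices_and_weights` are hypothetical. It assumes the LUT is a `(3, dim, dim, dim)` array indexed `lut[c, b, g, r]`, so that its flat index matches the kernel's `r_id + g_id * dim + b_id * dim * dim` with `shift = dim**3` between channels. It also assumes `binsize = 1 / (dim - 1)` and an image stored as a `(3, N)` array with values in `[0, 1)`, where `N` stands in for `width * height * batch`, the channel stride the kernels use. The half-open range matters: the kernels do not clamp indices, so with this `binsize` an input of exactly 1.0 gives `r_id = dim - 1`, and the `r_id + 1` corners fall outside the grid along that axis.

```python
import numpy as np


def _corner_indices_and_weights(image, dim):
    """Return the 8 (flat_index, weight) pairs per pixel, computed as in the kernels.

    Assumption: binsize = 1 / (dim - 1), i.e. the LUT nodes evenly span [0, 1].
    """
    binsize = 1.0 / (dim - 1)
    ids = [np.floor(ch / binsize).astype(np.int64) for ch in image]  # r_id, g_id, b_id
    fracs = [np.fmod(ch, binsize) / binsize for ch in image]  # r_d, g_d, b_d
    corners = []
    for db in (0, 1):
        for dg in (0, 1):
            for dr in (0, 1):
                # Same layout as the kernel: r varies fastest, then g, then b.
                idx = (ids[0] + dr) + (ids[1] + dg) * dim + (ids[2] + db) * dim * dim
                w = ((fracs[0] if dr else 1.0 - fracs[0])
                     * (fracs[1] if dg else 1.0 - fracs[1])
                     * (fracs[2] if db else 1.0 - fracs[2]))
                corners.append((idx, w))
    return corners


def trilinear_forward(lut, image):
    """Reference for TriLinearForward.

    lut:   (3, dim, dim, dim) array indexed lut[c, b, g, r] (assumed layout).
    image: (3, N) array with values in [0, 1).
    """
    dim = lut.shape[1]
    flat = lut.reshape(3, -1)  # row c corresponds to offset c * shift, shift = dim**3
    out = np.zeros(image.shape, dtype=np.float64)
    for idx, w in _corner_indices_and_weights(image, dim):
        out += w * flat[:, idx]  # one weight per pixel, shared by all 3 output channels
    return out


def trilinear_backward(image, grad_out, dim):
    """Reference for TriLinearBackward: scatter grad_out into a LUT-shaped gradient."""
    grad_flat = np.zeros((3, dim ** 3))
    for idx, w in _corner_indices_and_weights(image, dim):
        for c in range(3):
            # np.add.at accumulates repeated indices, like atomicAdd in the kernel.
            np.add.at(grad_flat[c], idx, grad_out[c] * w)
    return grad_flat.reshape(3, dim, dim, dim)


if __name__ == "__main__":
    dim = 9
    nodes = np.linspace(0.0, 1.0, dim)
    B, G, R = np.meshgrid(nodes, nodes, nodes, indexing="ij")
    identity_lut = np.stack([R, G, B])  # lut[c, b, g, r] = value of channel c at that node
    image = np.random.default_rng(0).random((3, 1000))
    out = trilinear_forward(identity_lut, image)
    print("max abs error vs. identity:", np.abs(out - image).max())
```

Looping over the eight corner offsets replaces the kernel's hand-unrolled `id000` to `id111` and `w000` to `w111` pairs. The backward sketch uses `np.add.at` because `grad[idx] += v` counts a repeated index only once; it plays the role that `atomicAdd` plays when many pixels fall into the same LUT cell. Trilinear interpolation reproduces linear functions exactly, so with an identity LUT the forward pass should return its input up to rounding, which makes a convenient sanity check.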