├── README.md
├── data
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── base_dataset.cpython-36.pyc
│   │   └── custom_dataset.cpython-36.pyc
│   ├── base_dataset.py
│   ├── custom_dataset.py
│   └── iHarmony4_dataset.py
├── demo_test.py
├── demo_test_iHarmony4.sh
├── demo_test_single.sh
├── examples
│   ├── f1510_1_2.jpg
│   ├── f1510_1_2.png
│   ├── f1601_1_1.jpg
│   ├── f1601_1_1.png
│   ├── f1679_1_2.jpg
│   ├── f1679_1_2.png
│   ├── f1721_1_2.jpg
│   ├── f1721_1_2.png
│   ├── f1806_1_2.jpg
│   ├── f1806_1_2.png
│   ├── f2062_1_1.jpg
│   ├── f2062_1_1.png
│   ├── f2313_1_2.jpg
│   ├── f2313_1_2.png
│   ├── f2374_1_1.jpg
│   ├── f2374_1_1.png
│   ├── f2699_1_2.jpg
│   ├── f2699_1_2.png
│   ├── f2865_1_1.jpg
│   ├── f2865_1_1.png
│   ├── f3017_1_1.jpg
│   ├── f3017_1_1.png
│   ├── f3040_1_1.jpg
│   ├── f3040_1_1.png
│   ├── f3110_1_2.jpg
│   ├── f3110_1_2.png
│   ├── f3309_1_1.jpg
│   ├── f3309_1_1.png
│   ├── f436_1_1.jpg
│   ├── f436_1_1.png
│   ├── f445_1_2.jpg
│   ├── f445_1_2.png
│   ├── f4516_1_1.jpg
│   ├── f4516_1_1.png
│   ├── f4521_1_2.jpg
│   ├── f4521_1_2.png
│   ├── f5067_1_1.jpg
│   ├── f5067_1_1.png
│   ├── f5107_1_1.jpg
│   ├── f5107_1_1.png
│   ├── f523_1_1.jpg
│   ├── f523_1_1.png
│   ├── f685_1_1.jpg
│   ├── f685_1_1.png
│   ├── f727_1_1.jpg
│   ├── f727_1_1.png
│   ├── f76_1_1.jpg
│   ├── f76_1_1.png
│   ├── f814_1_2.jpg
│   └── f814_1_2.png
├── figures
│   ├── augmentation_examples.jpg
│   └── flowchart..jpg
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── base_model.cpython-36.pyc
│   │   ├── laBaseLUTs_model.cpython-36.pyc
│   │   ├── models_x.cpython-36.pyc
│   │   ├── modules.cpython-36.pyc
│   │   └── networks.cpython-36.pyc
│   ├── base_model.py
│   ├── laBaseLUTs_model.py
│   └── modules.py
├── options
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── base_options.cpython-36.pyc
│   │   └── test_options.cpython-36.pyc
│   ├── base_options.py
│   └── test_options.py
├── requirements.txt
├── trilinear_cpp
│   ├── build
│   │   ├── lib.linux-x86_64-3.6
│   │   │   └── trilinear.cpython-36m-x86_64-linux-gnu.so
│   │   └── temp.linux-x86_64-3.6
│   │       └── src
│   │           ├── trilinear_cuda.o
│   │           └── trilinear_kernel.o
│   ├── dist
│   │   └── trilinear-0.0.0-py3.6-linux-x86_64.egg
│   ├── setup.py
│   ├── setup.sh
│   ├── src
│   │   ├── trilinear.cpp
│   │   ├── trilinear.h
│   │   ├── trilinear_cuda.cpp
│   │   ├── trilinear_cuda.h
│   │   ├── trilinear_kernel.cu
│   │   └── trilinear_kernel.h
│   └── trilinear.egg-info
│       ├── PKG-INFO
│       ├── SOURCES.txt
│       ├── dependency_links.txt
│       └── top_level.txt
└── util
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-36.pyc
    │   └── util.cpython-36.pyc
    └── util.py
/README.md:
--------------------------------------------------------------------------------
1 | # SycoNet: Domain Adaptive Image Harmonization
2 |
3 | This is the official repository for the following paper:
4 |
5 | > **Deep Image Harmonization with Learnable Augmentation** [[arXiv]](https://arxiv.org/pdf/2308.00376.pdf)
6 | >
7 | > Li Niu, Junyan Cao, Wenyan Cong, Liqing Zhang
8 | > Accepted by **ICCV 2023**.
9 | >
10 | SycoNet can generate multiple plausible synthetic composite images based on a real image and a foreground mask, which is useful for constructing pairs of synthetic composite images and real images for harmonization. We release the SycoNet inference code and model. **The released model is first trained on [iHarmony4](https://github.com/bcmi/Image-Harmonization-Dataset-iHarmony4) and then finetuned on [ccHarmony](https://github.com/bcmi/Image-Harmonization-Dataset-ccHarmony), because ccHarmony reflects illumination variation more faithfully.**
11 |
12 |
13 |
![SycoNet flowchart](figures/flowchart..jpg)
14 |
15 |
16 | **The released model can be used to generate high-quality synthetic composite images for real images, augmenting a small-scale training set.** In the examples below, we show several real images and the generated synthetic composite images.
17 |
18 |
19 |
![Real images and generated synthetic composite images](figures/augmentation_examples.jpg)
20 |
21 |
22 |
23 |
24 | # Setup
25 |
26 | Clone the repository:
27 | ```
28 | git clone git@github.com:bcmi/SycoNet-Adaptive-Image-Harmonization.git
29 | ```
30 | Install Anaconda and create a virtual environment:
31 | ```
32 | conda create -n syconet python=3.6
33 | conda activate syconet
34 | ```
35 | Install PyTorch:
36 | ```
37 | conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge
38 | ```
39 | Install necessary packages:
40 | ```
41 | pip install -r requirements.txt
42 | ```
43 | Build Trilinear:
44 | ```
45 | cd trilinear_cpp
46 | sh setup.sh
47 | ```
48 | Note: set `CUDA_HOME` in `setup.sh` to your own CUDA installation path (e.g., `/usr/local/cuda`) before running it. You can refer to [this repository](https://github.com/HuiZeng/Image-Adaptive-3DLUT) for more solutions.
49 |
50 | # Inference
51 |
52 | Download the SycoNet model `pretrained_net_Er.pth` and the 3D LUTs `pretrained_net_LUTs.pth` from [Baidu Cloud](https://pan.baidu.com/s/1wIWxb37yIVccxB0kM-FnnQ) (access code: o4rt) or [Dropbox](https://www.dropbox.com/scl/fo/zo5bbzotkc70psg3dlzz9/AGG6z69_qaRC5N3MydjgjjY?rlkey=39l5uixbym7xhf5sp6tg9515c&st=3o5qatdh&dl=0). Put them in the folder `checkpoints/syco`.
53 |
54 | ## Test on a single image
55 | Set `real` and `mask` in `demo_test_single.sh` to your own real image path and foreground mask path, respectively, and set `augment_num` to the number of composite images you want to generate per real image and foreground mask pair. Then, run the following command:
56 | ```
57 | sh demo_test_single.sh
58 | ```
59 | Our SycoNet generates composite images for the input real image and foreground mask and saves them in the folder `results/syco/test_pretrained`.
60 |
61 | ## Test on iHarmony4 dataset
62 |
63 |
64 | Download [iHarmony4](https://github.com/bcmi/Image-Harmonization-Dataset-iHarmony4) and set `dataset_root` and `dataset_name` in `demo_test_iHarmony4.sh` to your own dataset path and sub-dataset name. Then, run the following command:
65 |
66 | ```
67 | sh demo_test_iHarmony4.sh
68 | ```
69 |
70 | Our SycoNet generates composite images for the real images and foreground masks of the specified dataset and saves them in the folder `results/syco/test_pretrained`.
71 |
72 |
73 | # Other Resources
74 |
75 | + [Image-Harmonization-Dataset-iHarmony4](https://github.com/bcmi/Image-Harmonization-Dataset-iHarmony4)
76 | + [Awesome-Image-Harmonization](https://github.com/bcmi/Awesome-Image-Harmonization)
77 | + [Awesome-Image-Composition](https://github.com/bcmi/Awesome-Object-Insertion)
78 |
79 |
80 |
81 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes all the modules related to data loading and preprocessing
2 |
3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4 | You need to implement four functions:
5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6 | -- <__len__>: return the size of dataset.
7 | -- <__getitem__>: get a data point from data loader.
8 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9 |
10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11 | See 'data/custom_dataset.py' and 'data/iHarmony4_dataset.py' in this repository for concrete examples; a minimal skeleton is also sketched right after this file.
12 | """
13 | import importlib
14 | import torch.utils.data
15 | from data.base_dataset import BaseDataset
16 |
17 |
18 | def find_dataset_using_name(dataset_name):
19 | """Import the module "data/[dataset_name]_dataset.py".
20 |
21 | In the file, the class called DatasetNameDataset() will
22 | be instantiated. It has to be a subclass of BaseDataset,
23 | and it is case-insensitive.
24 | """
25 | dataset_filename = "data." + dataset_name + "_dataset"
26 | datasetlib = importlib.import_module(dataset_filename)
27 |
28 | dataset = None
29 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
30 | for name, cls in datasetlib.__dict__.items():
31 | if name.lower() == target_dataset_name.lower() \
32 | and issubclass(cls, BaseDataset):
33 | dataset = cls
34 |
35 | if dataset is None:
36 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
37 |
38 | return dataset
39 |
40 |
41 | def get_option_setter(dataset_name):
42 | """Return the static method of the dataset class."""
43 | dataset_class = find_dataset_using_name(dataset_name)
44 | return dataset_class.modify_commandline_options
45 |
46 |
47 | def create_dataset(opt):
48 | """Create a dataset given the option.
49 |
50 | This function wraps the class CustomDatasetDataLoader.
51 | This is the main interface between this package and 'train.py'/'test.py'
52 |
53 | Example:
54 | >>> from data import create_dataset
55 | >>> dataset = create_dataset(opt)
56 | """
57 | data_loader = CustomDatasetDataLoader(opt)
58 | dataset = data_loader.load_data()
59 | return dataset
60 |
61 |
62 | class CustomDatasetDataLoader():
63 | """Wrapper class of Dataset class that performs multi-threaded data loading"""
64 |
65 | def __init__(self, opt):
66 | """Initialize this class
67 |
68 | Step 1: create a dataset instance given the name [dataset_mode]
69 | Step 2: create a multi-threaded data loader.
70 | """
71 | self.opt = opt
72 | dataset_class = find_dataset_using_name(opt.dataset_mode)
73 | self.dataset = dataset_class(opt)
74 | print("dataset [%s] was created" % type(self.dataset).__name__)
75 | self.dataloader = torch.utils.data.DataLoader(
76 | self.dataset,
77 | batch_size=opt.batch_size,
78 | shuffle=not opt.serial_batches,
79 | drop_last=True,
80 | num_workers=int(opt.num_threads))
81 |
82 | def load_data(self):
83 | return self
84 |
85 | def __len__(self):
86 | """Return the number of data in the dataset"""
87 | return min(len(self.dataset), self.opt.max_dataset_size)
88 |
89 | def __iter__(self):
90 | """Return a batch of data"""
91 | for i, data in enumerate(self.dataloader):
92 | if i * self.opt.batch_size >= self.opt.max_dataset_size:
93 | break
94 | yield data
95 |
--------------------------------------------------------------------------------
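The module docstring above describes how a new dataset class is registered. Below is a minimal sketch of a hypothetical `data/dummy_dataset.py` implementing the four required functions; the class name, option defaults, and placeholder data are illustrative and not part of the repository.

```python
# Hypothetical 'data/dummy_dataset.py' sketch: DummyDataset is picked up by
# find_dataset_using_name('dummy') because 'DummyDataset'.lower() == 'dummydataset'.
import numpy as np
from data.base_dataset import BaseDataset


class DummyDataset(BaseDataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        # (optional) add dataset-specific options and change defaults
        parser.set_defaults(preprocess='resize')
        return parser

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)   # stores opt and opt.dataset_root
        self.length = 4                   # placeholder dataset size

    def __getitem__(self, index):
        # return one data point; a real dataset would load and transform images here
        real = np.zeros((3, 256, 256), dtype=np.float32)
        mask = np.zeros((1, 256, 256), dtype=np.float32)
        return {'real': real, 'mask': mask, 'img_path': 'dummy_%d.jpg' % index}

    def __len__(self):
        return self.length
```

It would then be selected with `--dataset_mode dummy`.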
/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/base_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/data/__pycache__/base_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/custom_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/data/__pycache__/custom_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2 |
3 | It also includes common transformation functions (e.g., get_transform and __make_power_2), which can later be used in subclasses.
4 | """
5 | import random
6 | import cv2
7 | import numpy as np
8 | import torch.utils.data as data
9 | import albumentations.augmentations.transforms as transforms
10 | from abc import ABC, abstractmethod
11 | from albumentations import HorizontalFlip, RandomResizedCrop, Compose, DualTransform
12 |
13 |
14 | class BaseDataset(data.Dataset, ABC):
15 | """This class is an abstract base class (ABC) for datasets.
16 |
17 | To create a subclass, you need to implement the following four functions:
18 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
19 | -- <__len__>: return the size of dataset.
20 | -- <__getitem__>: get a data point.
21 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
22 | """
23 |
24 | def __init__(self, opt):
25 | """Initialize the class; save the options in the class
26 |
27 | Parameters:
28 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
29 | """
30 | self.opt = opt
31 | self.root = opt.dataset_root  # dataset root directory
32 |
33 | @staticmethod
34 | def modify_commandline_options(parser, is_train):
35 | """Add new dataset-specific options, and rewrite default values for existing options.
36 |
37 | Parameters:
38 | parser -- original option parser
39 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
40 |
41 | Returns:
42 | the modified parser.
43 | """
44 | return parser
45 |
46 | @abstractmethod
47 | def __len__(self):
48 | """Return the total number of images in the dataset."""
49 | return 0
50 |
51 | @abstractmethod
52 | def __getitem__(self, index):
53 | """Return a data point and its metadata information.
54 |
55 | Parameters:
56 | index - - a random integer for data indexing
57 |
58 | Returns:
59 | a dictionary of data with their names. It usually contains the data itself and its metadata information.
60 | """
61 | pass
62 |
63 | class HCompose(Compose):
64 | def __init__(self, transforms, *args, additional_targets=None, no_nearest_for_masks=True, **kwargs):
65 | if additional_targets is None:
66 | additional_targets = {
67 | 'real': 'image',
68 | 'mask': 'mask'
69 | }
70 | self.additional_targets = additional_targets
71 | super().__init__(transforms, *args, additional_targets=additional_targets, **kwargs)
72 | if no_nearest_for_masks:
73 | for t in transforms:
74 | if isinstance(t, DualTransform):
75 | t._additional_targets['mask'] = 'image'
76 |
77 |
78 | def get_params(opt, size):
79 | w, h = size
80 | new_h = h
81 | new_w = w
82 | if opt.preprocess == 'resize_and_crop':
83 | new_h = new_w = opt.load_size
84 | elif opt.preprocess == 'scale_width_and_crop':
85 | new_w = opt.load_size
86 | new_h = opt.load_size * h // w
87 |
88 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
89 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
90 |
91 | flip = random.random() > 0.5
92 |
93 | return {'crop_pos': (x, y), 'flip': flip}
94 |
95 |
96 | def get_transform(opt, params=None, grayscale=False, convert=True):
97 | transform_list = []
98 | if grayscale:
99 | transform_list.append(transforms.ToGray())
100 | if opt.preprocess == 'resize_and_crop':
101 | if params is None:
102 | transform_list.append(RandomResizedCrop(256, 256, scale=(0.5, 1.0)))
103 | elif opt.preprocess == 'resize':
104 | transform_list.append(transforms.Resize(256, 256))
105 |
106 |
107 | if not opt.no_flip:
108 | if params is None:
109 | transform_list.append(HorizontalFlip())
110 |
111 | return HCompose(transform_list)
112 |
113 |
114 | def __make_power_2(img, base):
115 | oh, ow = img.shape[:2]  # cv2 images are numpy arrays of shape (height, width, channels)
116 | h = int(round(oh / base) * base)
117 | w = int(round(ow / base) * base)
118 | if (h == oh) and (w == ow):
119 | return img
120 |
121 | __print_size_warning(ow, oh, w, h)
122 | return cv2.resize(img, (w, h), interpolation = cv2.INTER_LINEAR)
123 |
124 |
125 |
126 | def __print_size_warning(ow, oh, w, h):
127 | """Print warning information about image size(only print once)"""
128 | if not hasattr(__print_size_warning, 'has_printed'):
129 | print("The image size needs to be a multiple of 4. "
130 | "The loaded image size was (%d, %d), so it was adjusted to "
131 | "(%d, %d). This adjustment will be done to all images "
132 | "whose sizes are not multiples of 4" % (ow, oh, w, h))
133 | __print_size_warning.has_printed = True
134 |
--------------------------------------------------------------------------------
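For clarity, here is a minimal usage sketch of the transform pipeline defined above (get_transform / HCompose). It assumes the albumentations version pinned for this repository and uses synthetic arrays in place of a real photo and mask; the shapes in the comments are what the 256x256 RandomResizedCrop produces.

```python
# Usage sketch only: exercise HCompose with the same 'real'/'mask' additional targets
# that the dataset classes rely on.
import numpy as np
from albumentations import HorizontalFlip, RandomResizedCrop
from data.base_dataset import HCompose

real = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)  # H x W x 3 uint8 image
mask = np.random.randint(0, 2, (512, 512), dtype=np.uint8)       # H x W binary mask

# One call transforms the image, the duplicate 'real' target, and the mask consistently.
transform = HCompose([RandomResizedCrop(256, 256, scale=(0.5, 1.0)), HorizontalFlip()])
out = transform(image=real, real=real, mask=mask)
print(out['real'].shape, out['mask'].shape)  # (256, 256, 3) (256, 256)
```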
/data/custom_dataset.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torchvision.transforms as transforms
4 | from data.base_dataset import BaseDataset, get_transform
5 | from copy import deepcopy
6 |
7 | class CUSTOMDataset(BaseDataset):
8 | """Dataset class for a single user-provided real image and foreground mask (specified via --real and --mask)."""
9 |
10 | @staticmethod
11 | def modify_commandline_options(parser, is_train):
12 | """Add new dataset-specific options, and rewrite default values for existing options.
13 |
14 | Parameters:
15 | parser -- original option parser
16 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
17 |
18 | Returns:
19 | the modified parser.
20 | """
21 | parser.add_argument('--is_train', type=bool, default=True, help='whether in the training phase')
22 | parser.set_defaults(max_dataset_size=float("inf"), new_dataset_option=2.0) # specify dataset-specific default values
23 | return parser
24 |
25 | def __init__(self, opt):
26 | """Initialize this dataset class.
27 |
28 | Parameters:
29 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
30 |
31 | A few things can be done here.
32 | - save the options (have been done in BaseDataset)
33 | - get image paths and meta information of the dataset.
34 | - define the image transformation.
35 | """
36 | # save the option and dataset root
37 | BaseDataset.__init__(self, opt)
38 | self.real_path, self.mask_path = '', ''
39 | self.isTrain = opt.isTrain
40 | if opt.isTrain==False:
41 | print('loading test file: ')
42 | self.real_path = opt.real
43 | self.mask_path = opt.mask
44 | else:
45 | raise NotImplementedError('Sorry, the training code has not been released.')
46 |
47 | self.transform = get_transform(opt)
48 | self.input_transform = transforms.Compose([
49 | transforms.ToTensor(),
50 | ])
51 |
52 | def __getitem__(self, index):
53 | sample = self.get_sample()
54 | self.check_sample_types(sample)
55 | sample_raw = deepcopy(sample)
56 | sample = self.augment_sample(sample)
57 | real = self.input_transform(sample['real'])
58 | mask = sample['mask'].astype(np.float32)
59 |
60 | real_raw = self.input_transform(sample_raw['real'])
61 | mask_raw = sample_raw['mask'].astype(np.float32)
62 |
63 | output = {
64 | 'mask': mask[np.newaxis, ...].astype(np.float32),
65 | 'real': real,
66 | 'mask_raw': mask_raw,
67 | 'real_raw': real_raw,
68 | 'img_path':sample['img_path']
69 | }
70 |
71 | return output
72 |
73 |
74 | def check_sample_types(self, sample):
75 | assert sample['real'].dtype == 'uint8'
76 |
77 |
78 | def augment_sample(self, sample):
79 | if self.transform is None:
80 | return sample
81 |
82 | additional_targets = {target_name: sample[target_name]
83 | for target_name in self.transform.additional_targets.keys()}
84 |
85 | valid_augmentation = False
86 | while not valid_augmentation:
87 | aug_output = self.transform(image=sample['real'], **additional_targets)
88 | valid_augmentation = self.check_augmented_sample(aug_output)
89 |
90 | for target_name, transformed_target in aug_output.items():
91 | sample[target_name] = transformed_target
92 |
93 | return sample
94 |
95 | def check_augmented_sample(self, aug_output):
96 | return aug_output['mask'].sum() > 1.0
97 |
98 |
99 | def get_sample(self):
100 | real = cv2.imread(self.real_path)
101 | real = cv2.cvtColor(real, cv2.COLOR_BGR2RGB)
102 | mask = cv2.imread(self.mask_path)
103 | mask = mask[:, :, 0].astype(np.float32) / 255.
104 | mask = mask.astype(np.uint8)
105 |
106 | return {'mask': mask, 'real': real, 'img_path': self.real_path}
107 |
108 | def __len__(self):
109 | """Return the total number of images."""
110 | return 1
111 |
--------------------------------------------------------------------------------
/data/iHarmony4_dataset.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import random
3 | import numpy as np
4 | import os.path
5 | import torchvision.transforms as transforms
6 | from data.base_dataset import BaseDataset, get_transform
7 | from copy import deepcopy
8 |
9 | class iHarmony4Dataset(BaseDataset):
10 | """Dataset class for the iHarmony4 test sets (HAdobe5k, HCOCO, HFlickr, Hday2night)."""
11 | @staticmethod
12 | def modify_commandline_options(parser, is_train):
13 | """Add new dataset-specific options, and rewrite default values for existing options.
14 |
15 | Parameters:
16 | parser -- original option parser
17 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
18 |
19 | Returns:
20 | the modified parser.
21 | """
22 | parser.add_argument('--is_train', type=bool, default=True, help='whether in the training phase')
23 | parser.set_defaults(max_dataset_size=float("inf"), new_dataset_option=2.0) # specify dataset-specific default values
24 | return parser
25 |
26 | def __init__(self, opt):
27 | """Initialize this dataset class.
28 |
29 | Parameters:
30 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
31 |
32 | A few things can be done here.
33 | - save the options (have been done in BaseDataset)
34 | - get image paths and meta information of the dataset.
35 | - define the image transformation.
36 | """
37 | # save the option and dataset root
38 | BaseDataset.__init__(self, opt)
39 | self.image_paths = []
40 | self.isTrain = opt.isTrain
41 | if opt.isTrain==False:
42 | print('loading test file: ')
43 | self.keep_background_prob = -1
44 | if opt.dataset_name in ['HAdobe5k', 'HCOCO', 'HFlickr', 'Hday2night']:
45 | self.trainfile = os.path.join(opt.dataset_root, opt.dataset_name, opt.dataset_name + '_test.txt')
46 | with open(self.trainfile,'r') as f:
47 | for line in f.readlines():
48 | self.image_paths.append(os.path.join(opt.dataset_root, opt.dataset_name, 'composite_images', line.rstrip()))
49 | else:
50 | raise NotImplementedError('%s not implemented.' % (opt.dataset_name))
51 | else:
52 | raise NotImplementedError('Sorry, the training code has not been released.')
53 |
54 | self.transform = get_transform(opt)
55 | self.input_transform = transforms.Compose([
56 | transforms.ToTensor(),
57 | ])
58 |
59 | def __getitem__(self, index):
60 | sample = self.get_sample(index)
61 | self.check_sample_types(sample)
62 | sample_raw = deepcopy(sample)
63 | sample = self.augment_sample(sample)
64 | real = self.input_transform(sample['real'])
65 | mask = sample['mask'].astype(np.float32)
66 |
67 | real_raw = self.input_transform(sample_raw['real'])
68 | mask_raw = sample_raw['mask'].astype(np.float32)
69 |
70 | output = {
71 | 'mask': mask[np.newaxis, ...].astype(np.float32),
72 | 'real': real,
73 | 'mask_raw': mask_raw,
74 | 'real_raw': real_raw,
75 | 'img_path':sample['img_path']
76 | }
77 |
78 | return output
79 |
80 |
81 | def check_sample_types(self, sample):
82 | assert sample['real'].dtype == 'uint8'
83 |
84 | def augment_sample(self, sample):
85 | if self.transform is None:
86 | return sample
87 |
88 | additional_targets = {target_name: sample[target_name]
89 | for target_name in self.transform.additional_targets.keys()}
90 |
91 | valid_augmentation = False
92 | while not valid_augmentation:
93 | aug_output = self.transform(image=sample['real'], **additional_targets)
94 | valid_augmentation = self.check_augmented_sample(sample, aug_output)
95 |
96 | for target_name, transformed_target in aug_output.items():
97 | sample[target_name] = transformed_target
98 |
99 | return sample
100 |
101 | def check_augmented_sample(self, sample, aug_output):
102 | if self.keep_background_prob < 0.0 or random.random() < self.keep_background_prob:
103 | return True
104 |
105 | return aug_output['mask'].sum() > 1.0
106 |
107 | def get_sample(self, index):
108 | path = self.image_paths[index]
109 | mask_path = self.image_paths[index].replace('composite_images','masks')
110 | mask_path = '_'.join(mask_path.split('_')[:-1]) + '.png'  # e.g. composite 'f436_1_1.jpg' -> mask 'f436_1.png'
111 | real_path = self.image_paths[index].replace('composite_images','real_images')
112 | real_path = '_'.join(real_path.split('_')[:-2]) + '.jpg'  # e.g. composite 'f436_1_1.jpg' -> real image 'f436.jpg'
113 |
114 | real = cv2.imread(real_path)
115 | real = cv2.cvtColor(real, cv2.COLOR_BGR2RGB)
116 | mask = cv2.imread(mask_path)
117 | mask = mask[:, :, 0].astype(np.float32) / 255.
118 | mask = mask.astype(np.uint8)
119 |
120 | return {'mask': mask, 'real': real, 'img_path': path}
121 |
122 | def __len__(self):
123 | """Return the total number of images."""
124 | return len(self.image_paths)
125 |
--------------------------------------------------------------------------------
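For reference, the loader above implies the following directory layout under `--dataset_root`. This is a sketch inferred from `__init__` and `get_sample`; the folder names `iHarmony4_dataset` and `Hday2night` mirror the values used in `demo_test_iHarmony4.sh`, so please verify against the official iHarmony4 release.

```
iHarmony4_dataset/                  # --dataset_root
└── Hday2night/                     # --dataset_name
    ├── Hday2night_test.txt         # one composite-image file name per line
    ├── composite_images/           # e.g. f436_1_1.jpg (listed in the txt file)
    ├── masks/                      # e.g. f436_1.png
    └── real_images/                # e.g. f436.jpg
```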
/demo_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from util import util
3 | from options.test_options import TestOptions
4 | from data import create_dataset
5 | from models import create_model
6 |
7 | if __name__ == '__main__':
8 | opt = TestOptions().parse() # get test options
9 | # hard-code some parameters for test
10 | opt.num_threads = 0 # test code only supports num_threads = 0
11 | opt.batch_size = 1 # test code only supports batch_size = 1
12 | opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
13 | opt.no_flip = True # no flip; comment this line if results on flipped images are needed.
14 | opt.display_id = -1 # no visdom display; the test code saves results as image files.
15 | dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
16 | model = create_model(opt) # create a model given opt.model and other options
17 | model.setup(opt) # regular setup: load and print networks; create schedulers
18 | div_num = opt.augment_num
19 |
20 | print('total number of test images: %d' % len(dataset))
21 |
22 | save_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch)) # directory for saving the generated composite images
23 | os.makedirs(save_dir, exist_ok=True)
24 | if opt.eval:
25 | model.eval()
26 |
27 | for i, data in enumerate(dataset):
28 | if i >= opt.num_test: # only apply our model to opt.num_test images.
29 | break
30 |
31 | for j in range(div_num):
32 | model.set_input(data) # unpack data from data loader
33 | model.test() # run inference
34 | visuals = model.get_current_visuals() # get image results
35 | img_path = str(data['img_path'][0])
36 | _, image_name = os.path.split(img_path)
37 | image_name, _ = os.path.splitext(image_name)
38 |
39 | for label, im_data in visuals.items():
40 | if label=='transfer_img_c':
41 | save_path = os.path.join(save_dir, image_name + '_' + str(j) + '.jpg')
42 | output_c = util.tensor2im(im_data)
43 | util.save_image(output_c, save_path, aspect_ratio=opt.aspect_ratio)
44 |
45 | print(f'[{i}], {image_name}, z num: {j}')
46 |
47 |
--------------------------------------------------------------------------------
/demo_test_iHarmony4.sh:
--------------------------------------------------------------------------------
1 | python demo_test.py \
2 | --name syco \
3 | --checkpoints_dir checkpoints \
4 | --model laBaseLUTs \
5 | --netEr Syco \
6 | --epoch pretrained \
7 | --dataset_mode iHarmony4 \
8 | --gpu_ids 0 \
9 | --is_train 0 \
10 | --preprocess resize \
11 | --norm batch \
12 | --nz 32 \
13 | --dataset_root ./iHarmony4_dataset \
14 | --dataset_name Hday2night \
15 | --results_dir results \
16 | --augment_num 10 \
17 | --keep_res \
18 | --eval
19 |
20 |
--------------------------------------------------------------------------------
/demo_test_single.sh:
--------------------------------------------------------------------------------
1 | python demo_test.py \
2 | --name syco \
3 | --checkpoints_dir checkpoints \
4 | --model laBaseLUTs \
5 | --netEr Syco \
6 | --epoch pretrained \
7 | --dataset_mode custom \
8 | --gpu_ids 0 \
9 | --is_train 0 \
10 | --real examples/f436_1_1.jpg \
11 | --mask examples/f436_1_1.png \
12 | --results_dir results \
13 | --augment_num 10 \
14 | --keep_res \
15 | --eval
16 |
--------------------------------------------------------------------------------
/examples/f1510_1_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f1510_1_2.jpg
--------------------------------------------------------------------------------
/examples/f1510_1_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f1510_1_2.png
--------------------------------------------------------------------------------
/examples/f1601_1_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f1601_1_1.jpg
--------------------------------------------------------------------------------
/examples/f1601_1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f1601_1_1.png
--------------------------------------------------------------------------------
/examples/f1679_1_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f1679_1_2.jpg
--------------------------------------------------------------------------------
/examples/f1679_1_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f1679_1_2.png
--------------------------------------------------------------------------------
/examples/f1721_1_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f1721_1_2.jpg
--------------------------------------------------------------------------------
/examples/f1721_1_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f1721_1_2.png
--------------------------------------------------------------------------------
/examples/f1806_1_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f1806_1_2.jpg
--------------------------------------------------------------------------------
/examples/f1806_1_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f1806_1_2.png
--------------------------------------------------------------------------------
/examples/f2062_1_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f2062_1_1.jpg
--------------------------------------------------------------------------------
/examples/f2062_1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f2062_1_1.png
--------------------------------------------------------------------------------
/examples/f2313_1_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f2313_1_2.jpg
--------------------------------------------------------------------------------
/examples/f2313_1_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f2313_1_2.png
--------------------------------------------------------------------------------
/examples/f2374_1_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f2374_1_1.jpg
--------------------------------------------------------------------------------
/examples/f2374_1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f2374_1_1.png
--------------------------------------------------------------------------------
/examples/f2699_1_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f2699_1_2.jpg
--------------------------------------------------------------------------------
/examples/f2699_1_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f2699_1_2.png
--------------------------------------------------------------------------------
/examples/f2865_1_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f2865_1_1.jpg
--------------------------------------------------------------------------------
/examples/f2865_1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f2865_1_1.png
--------------------------------------------------------------------------------
/examples/f3017_1_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f3017_1_1.jpg
--------------------------------------------------------------------------------
/examples/f3017_1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f3017_1_1.png
--------------------------------------------------------------------------------
/examples/f3040_1_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f3040_1_1.jpg
--------------------------------------------------------------------------------
/examples/f3040_1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f3040_1_1.png
--------------------------------------------------------------------------------
/examples/f3110_1_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f3110_1_2.jpg
--------------------------------------------------------------------------------
/examples/f3110_1_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f3110_1_2.png
--------------------------------------------------------------------------------
/examples/f3309_1_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f3309_1_1.jpg
--------------------------------------------------------------------------------
/examples/f3309_1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f3309_1_1.png
--------------------------------------------------------------------------------
/examples/f436_1_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f436_1_1.jpg
--------------------------------------------------------------------------------
/examples/f436_1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f436_1_1.png
--------------------------------------------------------------------------------
/examples/f445_1_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f445_1_2.jpg
--------------------------------------------------------------------------------
/examples/f445_1_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f445_1_2.png
--------------------------------------------------------------------------------
/examples/f4516_1_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f4516_1_1.jpg
--------------------------------------------------------------------------------
/examples/f4516_1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f4516_1_1.png
--------------------------------------------------------------------------------
/examples/f4521_1_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f4521_1_2.jpg
--------------------------------------------------------------------------------
/examples/f4521_1_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f4521_1_2.png
--------------------------------------------------------------------------------
/examples/f5067_1_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f5067_1_1.jpg
--------------------------------------------------------------------------------
/examples/f5067_1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f5067_1_1.png
--------------------------------------------------------------------------------
/examples/f5107_1_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f5107_1_1.jpg
--------------------------------------------------------------------------------
/examples/f5107_1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f5107_1_1.png
--------------------------------------------------------------------------------
/examples/f523_1_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f523_1_1.jpg
--------------------------------------------------------------------------------
/examples/f523_1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f523_1_1.png
--------------------------------------------------------------------------------
/examples/f685_1_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f685_1_1.jpg
--------------------------------------------------------------------------------
/examples/f685_1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f685_1_1.png
--------------------------------------------------------------------------------
/examples/f727_1_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f727_1_1.jpg
--------------------------------------------------------------------------------
/examples/f727_1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f727_1_1.png
--------------------------------------------------------------------------------
/examples/f76_1_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f76_1_1.jpg
--------------------------------------------------------------------------------
/examples/f76_1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f76_1_1.png
--------------------------------------------------------------------------------
/examples/f814_1_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f814_1_2.jpg
--------------------------------------------------------------------------------
/examples/f814_1_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/examples/f814_1_2.png
--------------------------------------------------------------------------------
/figures/augmentation_examples.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/figures/augmentation_examples.jpg
--------------------------------------------------------------------------------
/figures/flowchart..jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/figures/flowchart..jpg
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 |
3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4 | You need to implement the following five functions:
5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6 | -- <set_input>: unpack data from dataset and apply preprocessing.
7 | -- <forward>: produce intermediate results.
8 | -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10 |
11 | In the function <__init__>, you need to define four lists:
12 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
13 | -- self.model_names (str list): define networks used in our training.
14 | -- self.visual_names (str list): specify the images that you want to display and save.
15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
16 |
17 | Now you can use the model class by specifying flag '--model dummy'.
18 | See 'models/laBaseLUTs_model.py' in this repository for a concrete example; a minimal skeleton is also sketched right after this file.
19 | """
20 |
21 | import importlib
22 | from models.base_model import BaseModel
23 |
24 |
25 | def find_model_using_name(model_name):
26 | """Import the module "models/[model_name]_model.py".
27 |
28 | In the file, the class called DatasetNameModel() will
29 | be instantiated. It has to be a subclass of BaseModel,
30 | and it is case-insensitive.
31 | """
32 | model_filename = "models." + model_name + "_model"
33 | modellib = importlib.import_module(model_filename)
34 | model = None
35 | target_model_name = model_name.replace('_', '') + 'model'
36 | for name, cls in modellib.__dict__.items():
37 | if name.lower() == target_model_name.lower() \
38 | and issubclass(cls, BaseModel):
39 | model = cls
40 |
41 | if model is None:
42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43 | exit(1)
44 |
45 | return model
46 |
47 |
48 | def get_option_setter(model_name):
49 | """Return the static method of the model class."""
50 | model_class = find_model_using_name(model_name)
51 | return model_class.modify_commandline_options
52 |
53 |
54 | def create_model(opt):
55 | """Create a model given the option.
56 |
57 | This function instantiates the model class specified by the flag '--model'.
58 | This is the main interface between this package and 'train.py'/'test.py'
59 |
60 | Example:
61 | >>> from models import create_model
62 | >>> model = create_model(opt)
63 | """
64 | model = find_model_using_name(opt.model)
65 | instance = model(opt)
66 | print("model [%s] was created" % type(instance).__name__)
67 | return instance
68 |
--------------------------------------------------------------------------------
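To complement the docstring above, here is a minimal, inference-only sketch of a hypothetical `models/dummy_model.py`. The class name, option, and tensors are illustrative; the training-specific pieces (`optimize_parameters`, `self.loss_names`, `self.optimizers`) are omitted because the training code has not been released.

```python
# Hypothetical 'models/dummy_model.py' sketch: DummyModel matches '--model dummy'
# under the case-insensitive lookup in find_model_using_name above.
from .base_model import BaseModel


class DummyModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        # (optional) add model-specific options and change defaults
        parser.add_argument('--dummy_scale', type=float, default=1.0, help='illustrative option')
        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)            # always call the base initializer first
        self.model_names = []                     # networks to load with load_networks, e.g. ['Er']
        self.visual_names = ['real', 'output']    # tensors returned by get_current_visuals()

    def set_input(self, input):
        # unpack data from the dataloader and move it to the right device
        self.real = input['real'].to(self.device)

    def forward(self):
        # produce intermediate results; a pass-through placeholder here
        self.output = self.real * self.opt.dummy_scale
```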
/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/base_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/models/__pycache__/base_model.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/laBaseLUTs_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/models/__pycache__/laBaseLUTs_model.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/models_x.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/models/__pycache__/models_x.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/modules.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/models/__pycache__/modules.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/networks.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/models/__pycache__/networks.cpython-36.pyc
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from abc import ABC, abstractmethod
5 |
6 |
7 | class BaseModel(ABC):
8 | """This class is an abstract base class (ABC) for models.
9 | To create a subclass, you need to implement the following four functions:
10 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
11 | -- <set_input>: unpack data from dataset and apply preprocessing.
12 | -- <forward>: produce intermediate results.
13 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
14 | """
15 |
16 | def __init__(self, opt):
17 | """Initialize the BaseModel class.
18 |
19 | Parameters:
20 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
21 |
22 | When creating your custom class, you need to implement your own initialization.
23 | In this function, you should first call <BaseModel.__init__(self, opt)>
24 | Then, you need to define the following lists:
25 | -- self.model_names (str list): define networks used in our training.
26 | -- self.visual_names (str list): specify the images that you want to display and save.
27 | """
28 | self.opt = opt
29 | self.gpu_ids = opt.gpu_ids
30 | self.isTrain = opt.isTrain
31 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
32 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
33 | if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
34 | torch.backends.cudnn.benchmark = True
35 | self.model_names = []
36 | self.visual_names = []
37 | self.image_paths = []
38 | self.metric = 0 # used for learning rate policy 'plateau'
39 |
40 | @staticmethod
41 | def modify_commandline_options(parser, is_train):
42 | """Add new model-specific options, and rewrite default values for existing options.
43 |
44 | Parameters:
45 | parser -- original option parser
46 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
47 |
48 | Returns:
49 | the modified parser.
50 | """
51 | return parser
52 |
53 | @abstractmethod
54 | def set_input(self, input):
55 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
56 |
57 | Parameters:
58 | input (dict): includes the data itself and its metadata information.
59 | """
60 | pass
61 |
62 | @abstractmethod
63 | def forward(self):
64 | """Run forward pass; called by both functions <optimize_parameters> and <test>."""
65 | pass
66 |
67 | def setup(self, opt):
68 | """Load and print networks; create schedulers
69 |
70 | Parameters:
71 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
72 | """
73 |
74 | if not self.isTrain:
75 | load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
76 | self.load_networks(load_suffix)
77 | self.print_networks(opt.verbose)
78 |
79 | def eval(self):
80 | """Make models eval mode during test time"""
81 | for name in self.model_names:
82 | if isinstance(name, str):
83 | net = getattr(self, 'net' + name)
84 | net.eval()
85 |
86 | def test(self):
87 | """Forward function used in test time.
88 |
89 | This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
90 | It also calls <compute_visuals> to produce additional visualization results
91 | """
92 | with torch.no_grad():
93 | self.forward()
94 | self.compute_visuals()
95 |
96 | def compute_visuals(self):
97 | """Calculate additional output images for visdom and HTML visualization"""
98 | pass
99 |
100 | def get_image_paths(self):
101 | """ Return image paths that are used to load current data"""
102 | return self.image_paths
103 |
104 | def get_current_visuals(self):
105 | """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
106 | visual_ret = OrderedDict()
107 | for name in self.visual_names:
108 | if isinstance(name, str):
109 | visual_ret[name] = getattr(self, name)
110 | return visual_ret
111 |
112 |
113 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
114 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
115 | key = keys[i]
116 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
117 | if module.__class__.__name__.startswith('InstanceNorm') and \
118 | (key == 'running_mean' or key == 'running_var'):
119 | if getattr(module, key) is None:
120 | state_dict.pop('.'.join(keys))
121 | if module.__class__.__name__.startswith('InstanceNorm') and \
122 | (key == 'num_batches_tracked'):
123 | state_dict.pop('.'.join(keys))
124 | else:
125 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
126 |
127 | def load_networks(self, epoch):
128 | """Load all the networks from the disk.
129 |
130 | Parameters:
131 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
132 | """
133 |
134 | load_filename = '%s_net_LUTs.pth' % epoch
135 | load_path = os.path.join(self.save_dir, load_filename)
136 | print('loading the model from %s' % load_path)
137 | LUTs_state_dict = torch.load(load_path, map_location=str(self.device))
138 | LUTs = getattr(self, 'baseLUTs')
139 | for ii in range(len(LUTs)):
140 | LUTs[ii].load_state_dict(LUTs_state_dict[str(ii)])
141 |
142 | for name in self.model_names:
143 | if isinstance(name, str):
144 | load_filename = '%s_net_%s.pth' % (epoch, name)
145 | load_path = os.path.join(self.save_dir, load_filename)
146 | net = getattr(self, 'net' + name)
147 |
148 | if isinstance(net, torch.nn.DataParallel):
149 | net = net.module
150 | print('loading the model from %s' % load_path)
151 |
152 | state_dict = torch.load(load_path, map_location=str(self.device))
153 | if hasattr(state_dict, '_metadata'):
154 | del state_dict._metadata
155 |
156 | current_state_dict = net.state_dict()
157 |
158 | pretrained_dict = state_dict
159 | current_state_dict.update(pretrained_dict)
160 | net.load_state_dict(current_state_dict, strict=True)
161 |
162 |
163 |
164 | def print_networks(self, verbose):
165 | """Print the total number of parameters in the network and (if verbose) network architecture
166 |
167 | Parameters:
168 | verbose (bool) -- if verbose: print the network architecture
169 | """
170 | print('---------- Networks initialized -------------')
171 | for name in self.model_names:
172 | if isinstance(name, str):
173 | net = getattr(self, 'net' + name)
174 | num_params = 0
175 | for param in net.parameters():
176 | num_params += param.numel()
177 | if verbose:
178 | print(net)
179 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
180 | print('-----------------------------------------------')
181 |
182 | def set_requires_grad(self, nets, requires_grad=False):
183 |         """Set requires_grad=False for all the networks to avoid unnecessary computations
184 | Parameters:
185 | nets (network list) -- a list of networks
186 | requires_grad (bool) -- whether the networks require gradients or not
187 | """
188 | if not isinstance(nets, list):
189 | nets = [nets]
190 | for net in nets:
191 | if net is not None:
192 | for param in net.parameters():
193 | param.requires_grad = requires_grad
194 |
--------------------------------------------------------------------------------
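`load_networks()` above expects one checkpoint per named sub-network plus a shared file for the LUT bank, all resolved under `self.save_dir` using the `'%s_net_%s.pth'` convention. A minimal sketch of that path resolution (the experiment folder name below is hypothetical and only mirrors how `save_dir` is composed from `--checkpoints_dir` and `--name`):

```python
# Sketch only: reproduce the checkpoint paths that load_networks() looks for.
# './checkpoints/experiment_name' is a hypothetical save_dir; 'Er' matches
# model_names in laBaseLUTs_model.py.
import os

save_dir = './checkpoints/experiment_name'
epoch = 'latest'
model_names = ['Er']

lut_path = os.path.join(save_dir, '%s_net_LUTs.pth' % epoch)          # shared LUT bank
net_paths = [os.path.join(save_dir, '%s_net_%s.pth' % (epoch, name))
             for name in model_names]                                 # one file per sub-network

print(lut_path)    # ./checkpoints/experiment_name/latest_net_LUTs.pth
print(net_paths)   # ['./checkpoints/experiment_name/latest_net_Er.pth']
```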
/models/laBaseLUTs_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base_model import BaseModel
3 | from . import modules
4 | from torch import cuda
5 |
6 | Tensor = torch.cuda.FloatTensor if cuda.is_available() else torch.FloatTensor
7 |
8 |
9 | class LABASELUTSModel(BaseModel):
10 | def __init__(self, opt):
11 | BaseModel.__init__(self, opt)
12 | self.visual_names = ['mask', 'real', 'transfer_img_c']
13 | self.model_names = ['Er']
14 |
15 | self.LUT_num = opt.LUT_num
16 |
17 | self.baseLUTs = []
18 | for _ in range(self.LUT_num):
19 | self.baseLUTs.append(modules.Get3DLUT_identity(dim=17).to(self.device))
20 | self.LUT_num = len(self.baseLUTs)
21 |         print('total number of LUTs: %d' % self.LUT_num)
22 |
23 | self.netEr = modules.define_E(opt.input_nc, opt.nz, opt.nef, opt.nwf, opt.netEr, opt.norm,
24 | 'lrelu', self.gpu_ids, True, self.LUT_num)
25 |
26 |
27 | self.TV3 = modules.TV_3D().to(self.device)
28 | self.TV3.weight_r = self.TV3.weight_r.type(Tensor)
29 | self.TV3.weight_g = self.TV3.weight_g.type(Tensor)
30 | self.TV3.weight_b = self.TV3.weight_b.type(Tensor)
31 | self.trilinear_ = modules.TrilinearInterpolation()
32 |
33 |
34 | def set_input(self, tgt_input):
35 | self.mask = tgt_input['mask'].to(self.device)
36 | self.real = tgt_input['real'].to(self.device)
37 | self.inputs = torch.cat([self.real,self.mask], 1)
38 |
39 | self.real_raw = tgt_input['real_raw'].to(self.device)
40 | self.mask_raw = tgt_input['mask_raw'].to(self.device)
41 |
42 |
43 | def get_z_random(self, size, random_type='gauss'):
44 | if random_type == 'uni':
45 | z = torch.rand(size) * 2.0 - 1.0
46 | elif random_type == 'gauss':
47 | z = torch.randn(size)
48 | return z.detach().to(self.device)
49 |
50 |
51 | def generator_eval(self, pred, img, mask, LUTs):
52 |
53 | pred = pred.squeeze()
54 |
55 | for ii in range(self.LUT_num):
56 | self.baseLUTs[ii].eval()
57 |
58 | LUT = pred[0] * self.baseLUTs[0].LUT
59 | for idx in range(1, self.LUT_num):
60 | LUT += pred[idx] * self.baseLUTs[idx].LUT
61 |
62 |
63 | _, combine_A = self.trilinear_(LUT,img)
64 |
65 | combine_A = combine_A * mask + img * (1 - mask)
66 | combine_A = torch.clamp(combine_A, 0, 1)
67 |
68 | return combine_A
69 |
70 |
71 | def forward(self):
72 | self.z_size = [self.inputs.size(0), self.opt.nz, 1, 1]
73 | self.z_random = self.get_z_random(self.z_size)
74 |
75 | self.features_c, self.weightPred_c = self.netEr(self.inputs, self.z_random)
76 | if self.opt.keep_res:
77 | self.transfer_img_c = self.generator_eval(self.weightPred_c, self.real_raw, self.mask_raw, self.baseLUTs)
78 | else:
79 | self.transfer_img_c = self.generator_eval(self.weightPred_c, self.real, self.mask, self.baseLUTs)
80 |
81 |
82 |
--------------------------------------------------------------------------------
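The inference path above boils down to `generator_eval()`: the encoder's softmax weights select a convex combination of the learnable 3D LUTs, the combined LUT recolors the image via trilinear interpolation, and only the masked foreground is replaced. A minimal sketch of that blending step, with a hypothetical `apply_lut` placeholder standing in for the compiled `TrilinearInterpolation` op:

```python
# Sketch only (no CUDA extension needed): weighted LUT combination + masked blend,
# mirroring LABASELUTSModel.generator_eval(). apply_lut is a placeholder for
# modules.TrilinearInterpolation, which requires the compiled 'trilinear' extension.
import torch

LUT_num, dim = 4, 17
base_luts = [torch.rand(3, dim, dim, dim) for _ in range(LUT_num)]  # stand-ins for Get3DLUT_identity().LUT
weights = torch.softmax(torch.randn(LUT_num), dim=0)                # stand-in for the predicted weights

lut = weights[0] * base_luts[0]
for idx in range(1, LUT_num):                                       # same accumulation as generator_eval()
    lut = lut + weights[idx] * base_luts[idx]

def apply_lut(lut, img):
    return img                                                      # placeholder: the real op does a 3D LUT lookup

img = torch.rand(1, 3, 256, 256)
mask = (torch.rand(1, 1, 256, 256) > 0.5).float()

out = apply_lut(lut, img)
out = out * mask + img * (1 - mask)                                 # only the foreground region is recolored
out = torch.clamp(out, 0, 1)
print(out.shape)   # torch.Size([1, 3, 256, 256])
```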
/models/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import functools
3 | import trilinear
4 | import numpy as np
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 |
10 | class Identity(nn.Module):
11 | def forward(self, x):
12 | return x
13 |
14 |
15 | def get_norm_layer(norm_type='instance'):
16 | """Return a normalization layer
17 |
18 | Parameters:
19 | norm_type (str) -- the name of the normalization layer: batch | instance | none
20 |
21 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
22 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
23 | """
24 | if norm_type == 'batch':
25 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
26 | elif norm_type == 'instance':
27 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
28 | elif norm_type == 'none':
29 | norm_layer = lambda x: Identity()
30 | else:
31 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
32 | return norm_layer
33 |
34 | def get_non_linearity(layer_type='relu'):
35 | if layer_type == 'relu':
36 | nl_layer = functools.partial(nn.ReLU, inplace=True)
37 | elif layer_type == 'lrelu':
38 | nl_layer = functools.partial(
39 | nn.LeakyReLU, negative_slope=0.2, inplace=True)
40 | elif layer_type == 'elu':
41 | nl_layer = functools.partial(nn.ELU, inplace=True)
42 | else:
43 | raise NotImplementedError(
44 |             'nonlinearity activation [%s] is not found' % layer_type)
45 | return nl_layer
46 |
47 |
48 |
49 | def init_net(net, gpu_ids=[]):
50 |     """Initialize a network: register the CPU/GPU device (with multi-GPU support) and move the network onto it.
51 | Parameters:
52 | net (network) -- the network to be initialized
53 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
54 |
55 | Return an initialized network.
56 | """
57 |
58 | if len(gpu_ids) > 0:
59 | assert(torch.cuda.is_available())
60 | net.to(gpu_ids[0])
61 | #net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
62 | return net
63 |
64 |
65 |
66 | def define_E(input_nc, output_nc, nef, nwf, netE, norm='batch', nl='lrelu', gpu_ids=[], linear=True, LUT_num=5):
67 | net = None
68 | norm_layer = get_norm_layer(norm_type=norm)
69 | nl = 'lrelu' # use leaky relu for E
70 | nl_layer = get_non_linearity(layer_type=nl)
71 |
72 | if netE == 'Syco': #Joy
73 | net = SycoNet(input_nc, output_nc, nef, nwf, n_blocks=5, norm_layer=norm_layer, nl_layer=nl_layer, linear=linear, LUT_num=LUT_num)
74 | else:
75 | raise NotImplementedError('Encoder model name [%s] is not recognized' % netE)
76 |
77 | return init_net(net, gpu_ids)
78 |
79 |
80 |
81 | def conv3x3(in_planes, out_planes):
82 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1,
83 | padding=1, bias=True)
84 |
85 |
86 | def upsampleConv(inplanes, outplanes, kw, padw):
87 | sequence = []
88 | sequence += [nn.Upsample(scale_factor=2, mode='nearest')]
89 | sequence += [nn.Conv2d(inplanes, outplanes, kernel_size=kw,
90 | stride=1, padding=padw, bias=True)]
91 | return nn.Sequential(*sequence)
92 |
93 |
94 | def meanpoolConv(inplanes, outplanes):
95 | sequence = []
96 | sequence += [nn.AvgPool2d(kernel_size=2, stride=2)]
97 | sequence += [nn.Conv2d(inplanes, outplanes,
98 | kernel_size=1, stride=1, padding=0, bias=True)]
99 | return nn.Sequential(*sequence)
100 |
101 |
102 | def convMeanpool(inplanes, outplanes):
103 | sequence = []
104 | sequence += [conv3x3(inplanes, outplanes)]
105 | sequence += [nn.AvgPool2d(kernel_size=2, stride=2)]
106 | return nn.Sequential(*sequence)
107 |
108 | class BasicBlock(nn.Module):
109 | def __init__(self, inplanes, outplanes, norm_layer=None, nl_layer=None):
110 | super(BasicBlock, self).__init__()
111 | layers = []
112 | if norm_layer is not None:
113 | layers += [norm_layer(inplanes)]
114 | layers += [nl_layer()]
115 | layers += [conv3x3(inplanes, inplanes)]
116 | if norm_layer is not None:
117 | layers += [norm_layer(inplanes)]
118 | layers += [nl_layer()]
119 | layers += [convMeanpool(inplanes, outplanes)]
120 | self.conv = nn.Sequential(*layers)
121 | self.shortcut = meanpoolConv(inplanes, outplanes)
122 |
123 | def forward(self, x):
124 | out = self.conv(x) + self.shortcut(x)
125 | return out
126 |
127 |
128 |
129 | class SycoNet(nn.Module):
130 | def __init__(self, input_nc=3, nz=1, nef=64, nwf=128, n_blocks=4, norm_layer=None, nl_layer=None, linear=False, LUT_num=5):
131 | super(SycoNet, self).__init__()
132 | self.isLinear = linear
133 | max_nef = 4
134 | self.block0 = nn.Conv2d(input_nc+nz, nef, kernel_size=4, stride=2, padding=1, bias=True)
135 |
136 | self.block1 = BasicBlock(nef * min(max_nef, 1)+nz, nef * min(max_nef, 2), norm_layer, nl_layer)
137 | self.block2 = BasicBlock(nef * min(max_nef, 2)+nz, nef * min(max_nef, 3), norm_layer, nl_layer)
138 | self.block3 = BasicBlock(nef * min(max_nef, 3)+nz, nef * min(max_nef, 4), norm_layer, nl_layer)
139 | self.block4 = BasicBlock(nef * min(max_nef, 4)+nz, nef * min(max_nef, 4), norm_layer, nl_layer)
140 |
141 | self.nl = nn.Sequential(nl_layer(), nn.AvgPool2d(8))
142 |
143 | self.weight_predictor = nn.Conv2d(nwf, LUT_num, 1, padding=0)
144 |
145 | def forward(self, img_input, random_z):
146 | z_img = random_z.expand(random_z.size(0), random_z.size(1), img_input.size(2), img_input.size(3))
147 | inputs = torch.cat([img_input, z_img], 1)
148 | x0 = self.block0(inputs)
149 | z0 = random_z.expand(random_z.size(0), random_z.size(1), x0.size(2), x0.size(3))
150 | x1 = torch.cat([x0, z0], 1)
151 | x1 = self.block1(x1)
152 | z1 = random_z.expand(random_z.size(0), random_z.size(1), x1.size(2), x1.size(3))
153 | x2 = torch.cat([x1, z1], 1)
154 | x2 = self.block2(x2)
155 | z2 = random_z.expand(random_z.size(0), random_z.size(1), x2.size(2), x2.size(3))
156 | x3 = torch.cat([x2, z2], 1)
157 | x3 = self.block3(x3)
158 | z3 = random_z.expand(random_z.size(0), random_z.size(1), x3.size(2), x3.size(3))
159 | x4 = torch.cat([x3, z3], 1)
160 | x4 = self.block4(x4)
161 | features = self.nl(x4)
162 | outputs = self.weight_predictor(features)
163 | if self.isLinear:
164 | outputs = F.softmax(outputs, dim=1)
165 | return features, outputs
166 |
167 | class Get3DLUT_identity(nn.Module):
168 | def __init__(self, dim=17):
169 | super(Get3DLUT_identity, self).__init__()
170 | buffer = np.zeros((3,dim,dim,dim), dtype=np.float32)
171 |
172 | for i in range(0,dim):
173 | for j in range(0,dim):
174 | for k in range(0,dim):
175 | n = i * dim*dim + j * dim + k
176 | #x = lines[n].split()
177 | buffer[0,i,j,k] = 1.0/(dim-1)*k #float(x[0])
178 | buffer[1,i,j,k] = 1.0/(dim-1)*j #float(x[1])
179 | buffer[2,i,j,k] = 1.0/(dim-1)*i #float(x[2])
180 | #print(i,j,k,":",buffer[0,i,j,k],buffer[1,i,j,k],buffer[2,i,j,k])
181 | self.LUT = nn.Parameter(torch.from_numpy(buffer).requires_grad_(True))
182 | self.TrilinearInterpolation = TrilinearInterpolation()
183 |
184 | def forward(self, x):
185 | _, output = self.TrilinearInterpolation(self.LUT, x)
186 | return output
187 |
188 | class TrilinearInterpolationFunction(torch.autograd.Function):
189 | @staticmethod
190 | def forward(ctx, lut, x):
191 | x = x.contiguous()
192 |
193 | output = x.new(x.size())
194 | dim = lut.size()[-1]
195 | shift = dim ** 3
196 | binsize = 1.000001 / (dim-1)
197 | W = x.size(2)
198 | H = x.size(3)
199 | batch = x.size(0)
200 |
201 | assert 1 == trilinear.forward(lut, x, output, dim, shift, binsize, W, H, batch)
202 |
203 | int_package = torch.IntTensor([dim, shift, W, H, batch])
204 | float_package = torch.FloatTensor([binsize])
205 | variables = [lut, x, int_package, float_package]
206 |
207 | ctx.save_for_backward(*variables)
208 |
209 | return lut, output
210 |
211 | @staticmethod
212 | def backward(ctx, lut_grad, x_grad):
213 |
214 | lut, x, int_package, float_package = ctx.saved_variables
215 | dim, shift, W, H, batch = int_package
216 | dim, shift, W, H, batch = int(dim), int(shift), int(W), int(H), int(batch)
217 | binsize = float(float_package[0])
218 |
219 | assert 1 == trilinear.backward(x, x_grad, lut_grad, dim, shift, binsize, W, H, batch)
220 | return lut_grad, x_grad
221 |
222 |
223 |
224 | class TrilinearInterpolation(torch.nn.Module):
225 | def __init__(self):
226 | super(TrilinearInterpolation, self).__init__()
227 |
228 | def forward(self, lut, x):
229 | return TrilinearInterpolationFunction.apply(lut, x)
230 |
231 |
232 | class TV_3D(nn.Module):
233 | def __init__(self, dim=33):
234 | super(TV_3D,self).__init__()
235 |
236 | self.weight_r = torch.ones(3,dim,dim,dim-1, dtype=torch.float)
237 | self.weight_r[:,:,:,(0,dim-2)] *= 2.0
238 | self.weight_g = torch.ones(3,dim,dim-1,dim, dtype=torch.float)
239 | self.weight_g[:,:,(0,dim-2),:] *= 2.0
240 | self.weight_b = torch.ones(3,dim-1,dim,dim, dtype=torch.float)
241 | self.weight_b[:,(0,dim-2),:,:] *= 2.0
242 | self.relu = torch.nn.ReLU()
243 |
244 | def forward(self, LUT):
245 |
246 | dif_r = LUT.LUT[:,:,:,:-1] - LUT.LUT[:,:,:,1:]
247 | dif_g = LUT.LUT[:,:,:-1,:] - LUT.LUT[:,:,1:,:]
248 | dif_b = LUT.LUT[:,:-1,:,:] - LUT.LUT[:,1:,:,:]
249 | tv = torch.mean(torch.mul((dif_r ** 2),self.weight_r)) + torch.mean(torch.mul((dif_g ** 2),self.weight_g)) + torch.mean(torch.mul((dif_b ** 2),self.weight_b))
250 |
251 | mn = torch.mean(self.relu(dif_r)) + torch.mean(self.relu(dif_g)) + torch.mean(self.relu(dif_b))
252 |
253 | return tv, mn
--------------------------------------------------------------------------------
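One detail worth making explicit: the buffer built in `Get3DLUT_identity` is an identity color mapping, so before any training each base LUT leaves its input unchanged. A small NumPy check, independent of the repo code:

```python
# Sanity sketch: rebuild the Get3DLUT_identity buffer and read one cell.
# The cell indexed by grid coordinates (i=blue, j=green, k=red) stores the
# color (k, j, i) / (dim - 1), i.e. the identity mapping.
import numpy as np

dim = 17
buf = np.zeros((3, dim, dim, dim), dtype=np.float32)
for i in range(dim):            # blue axis
    for j in range(dim):        # green axis
        for k in range(dim):    # red axis
            buf[0, i, j, k] = k / (dim - 1)   # R output
            buf[1, i, j, k] = j / (dim - 1)   # G output
            buf[2, i, j, k] = i / (dim - 1)   # B output

print(buf[:, 4, 8, 16])   # [1.   0.5  0.25] -> the grid point (r=1.0, g=0.5, b=0.25) maps to itself
```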
/options/__init__.py:
--------------------------------------------------------------------------------
1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
2 |
--------------------------------------------------------------------------------
/options/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/options/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/options/__pycache__/base_options.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/options/__pycache__/base_options.cpython-36.pyc
--------------------------------------------------------------------------------
/options/__pycache__/test_options.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/options/__pycache__/test_options.cpython-36.pyc
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 | import models
6 | import data
7 |
8 |
9 | class BaseOptions():
10 | """This class defines options used during both training and test time.
11 |
12 | It also implements several helper functions such as parsing, printing, and saving the options.
13 | It also gathers additional options defined in functions in both dataset class and model class.
14 | """
15 |
16 | def __init__(self):
17 |         """Reset the class; indicates the class hasn't been initialized"""
18 | self.initialized = False
19 |
20 | def initialize(self, parser):
21 | """Define the common options that are used in both training and test."""
22 | # basic parameters
23 | parser.add_argument('--dataset_root',type=str, default='/data/caojunyan/datasets/IHD/', help='path to iHarmony4 dataset') #mia
24 | parser.add_argument('--dataset_name',type=str, default='', help='which sub-dataset to load [Hday2night | HVIDIT]') #mia
25 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
26 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
27 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
28 | # model parameters
29 | parser.add_argument('--model', type=str, default='la', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')
30 |         parser.add_argument('--input_nc', type=int, default=4, help='# of input image channels: 4 for the concatenated image and mask') #mia
31 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
32 | parser.add_argument('--nwf', type=int, default=256, help='# of weight predictor filters in the first conv layer')
33 | parser.add_argument('--nef', type=int, default=64, help='# of encoder filters in the first conv layer') # Joy
34 | parser.add_argument('--LUT_num', type=int, default=20, help='# the number of LUTs. maximum: 20') #Joy
35 | parser.add_argument('--nz', type=int, default=32, help='#latent code dim') #Joy
36 | parser.add_argument('--netEr', type=str, default='lut_spacial', help='Encoder architecture. [lut_special]') #Joy
37 | parser.add_argument('--norm', type=str, default='batch', help='instance normalization or batch normalization [instance | batch | none]')
38 | # dataset parameters
39 | parser.add_argument('--dataset_mode', type=str, default='ihd', help='load iHarmony4 dataset') #mia
40 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
41 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
42 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
43 | parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
44 | parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
45 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
46 | parser.add_argument('--preprocess', type=str, default='resize', help='scaling images at load time')
47 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
48 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
49 | # additional parameters
50 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
51 | parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
52 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
53 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
54 | self.initialized = True
55 | return parser
56 |
57 | def gather_options(self):
58 |         """Initialize our parser with basic options (only once).
59 |         Add additional model-specific and dataset-specific options.
60 |         These options are defined in the <modify_commandline_options> function
61 |         in model and dataset classes.
62 | """
63 | if not self.initialized: # check if it has been initialized
64 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
65 | parser = self.initialize(parser)
66 |
67 | # get the basic options
68 | opt, _ = parser.parse_known_args()
69 |
70 | # modify model-related parser options
71 | model_name = opt.model
72 | model_option_setter = models.get_option_setter(model_name)
73 | parser = model_option_setter(parser, self.isTrain)
74 | opt, _ = parser.parse_known_args() # parse again with new defaults
75 |
76 | # modify dataset-related parser options
77 | if opt.dataset_mode is not None:
78 | dataset_name = opt.dataset_mode
79 | dataset_option_setter = data.get_option_setter(dataset_name)
80 | parser = dataset_option_setter(parser, self.isTrain)
81 | else:
82 | dataset_name = 'ihd'
83 | dataset_option_setter = data.get_option_setter(dataset_name)
84 | parser = dataset_option_setter(parser, self.isTrain)
85 |
86 | # save and return the parser
87 | self.parser = parser
88 | return parser.parse_args()
89 |
90 | def print_options(self, opt):
91 | """Print and save options
92 |
93 | It will print both current options and default values(if different).
94 | It will save options into a text file / [checkpoints_dir] / opt.txt
95 | """
96 | message = ''
97 | message += '----------------- Options ---------------\n'
98 | for k, v in sorted(vars(opt).items()):
99 | comment = ''
100 | default = self.parser.get_default(k)
101 | if v != default:
102 | comment = '\t[default: %s]' % str(default)
103 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
104 | message += '----------------- End -------------------'
105 | print(message)
106 |
107 | # save to the disk
108 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
109 | util.mkdirs(expr_dir)
110 | file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
111 | with open(file_name, 'wt') as opt_file:
112 | opt_file.write(message)
113 | opt_file.write('\n')
114 |
115 | def parse(self):
116 | """Parse our options, create checkpoints directory suffix, and set up gpu device."""
117 | opt = self.gather_options()
118 | opt.isTrain = self.isTrain # train or test
119 |
120 | # process opt.suffix
121 | if opt.suffix:
122 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
123 | opt.name = opt.name + suffix
124 |
125 | self.print_options(opt)
126 |
127 | # set gpu ids
128 | str_ids = opt.gpu_ids.split(',')
129 | opt.gpu_ids = []
130 | for str_id in str_ids:
131 | id = int(str_id)
132 | if id >= 0:
133 | opt.gpu_ids.append(id)
134 | if len(opt.gpu_ids) > 0:
135 | torch.cuda.set_device(opt.gpu_ids[0])
136 |
137 | self.opt = opt
138 | return self.opt
139 |
--------------------------------------------------------------------------------
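The `--gpu_ids` handling in `parse()` above is easy to get wrong when scripting: the string is split on commas and only non-negative ids are kept, so `-1` means CPU. A standalone sketch of the same logic:

```python
# Sketch of the --gpu_ids parsing in BaseOptions.parse(): '-1' yields an empty
# list, which the rest of the code treats as CPU mode.
def parse_gpu_ids(gpu_ids_str):
    ids = []
    for str_id in gpu_ids_str.split(','):
        i = int(str_id)
        if i >= 0:
            ids.append(i)
    return ids

print(parse_gpu_ids('0'))     # [0]
print(parse_gpu_ids('0,2'))   # [0, 2]
print(parse_gpu_ids('-1'))    # []  -> CPU
```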
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | """This class includes test options.
6 |
7 | It also includes shared options defined in BaseOptions.
8 | """
9 |
10 | def initialize(self, parser):
11 | parser = BaseOptions.initialize(self, parser) # define shared options
12 | parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
13 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
14 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
15 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
16 |         # Dropout and Batchnorm have different behavior during training and test.
17 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
18 | parser.add_argument('--num_test', type=int, default=7404, help='how many test images to run, for iHarmony4, the number is 7404') #mia
19 |         parser.add_argument('--augment_num', type=int, default=5, help='how many augmented results to generate for each real image.')
20 | parser.add_argument('--keep_res', action='store_true', help='keep the input\'s resolution for augmented output.')
21 | parser.add_argument('--real', type=str, default='', help='the real image for learnable augmentation.')
22 | parser.add_argument('--mask', type=str, default='', help='the foreground mask of real image.')
23 |         # rewrite default values
24 | parser.set_defaults(model='test')
25 | # To avoid cropping, the load_size should be the same as crop_size
26 | parser.set_defaults(load_size=parser.get_default('crop_size'))
27 | self.isTrain = False
28 | return parser
29 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv_python==4.1.1.26
2 | albumentations==0.5.2 --no-binary albumentations
3 |
--------------------------------------------------------------------------------
/trilinear_cpp/build/lib.linux-x86_64-3.6/trilinear.cpython-36m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/trilinear_cpp/build/lib.linux-x86_64-3.6/trilinear.cpython-36m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/trilinear_cpp/build/temp.linux-x86_64-3.6/src/trilinear_cuda.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/trilinear_cpp/build/temp.linux-x86_64-3.6/src/trilinear_cuda.o
--------------------------------------------------------------------------------
/trilinear_cpp/build/temp.linux-x86_64-3.6/src/trilinear_kernel.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/trilinear_cpp/build/temp.linux-x86_64-3.6/src/trilinear_kernel.o
--------------------------------------------------------------------------------
/trilinear_cpp/dist/trilinear-0.0.0-py3.6-linux-x86_64.egg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/trilinear_cpp/dist/trilinear-0.0.0-py3.6-linux-x86_64.egg
--------------------------------------------------------------------------------
/trilinear_cpp/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | import torch
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
4 |
5 | if torch.cuda.is_available():
6 | print('Including CUDA code.')
7 | setup(
8 | name='trilinear',
9 | ext_modules=[
10 | CUDAExtension('trilinear', [
11 | 'src/trilinear_cuda.cpp',
12 | 'src/trilinear_kernel.cu',
13 | ])
14 | ],
15 | cmdclass={
16 | 'build_ext': BuildExtension
17 | })
18 | else:
19 |     print('No CUDA found. Falling back to the CPU build.')
20 | setup(name='trilinear',
21 | ext_modules=[CppExtension('trilinear', ['src/trilinear.cpp'])],
22 | cmdclass={'build_ext': BuildExtension})
23 |
--------------------------------------------------------------------------------
/trilinear_cpp/setup.sh:
--------------------------------------------------------------------------------
1 | export CUDA_HOME=/usr/local/cuda-11.3 && python3 setup.py install
2 |
--------------------------------------------------------------------------------
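After building with `setup.sh` (adjust `CUDA_HOME` if the local CUDA toolkit is not 11.3), the extension should import as a plain module exposing the two functions bound in the sources that follow. A quick smoke test, assuming the build succeeded:

```python
# Smoke test only: verify the compiled 'trilinear' extension is importable and
# exposes the forward/backward bindings declared in PYBIND11_MODULE.
import trilinear

print(trilinear.__file__)                 # path of the compiled extension
print(hasattr(trilinear, 'forward'))      # True
print(hasattr(trilinear, 'backward'))     # True
```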
/trilinear_cpp/src/trilinear.cpp:
--------------------------------------------------------------------------------
1 | #include "trilinear.h"
2 |
3 |
4 | void TriLinearForwardCpu(const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int channels);
5 |
6 | void TriLinearBackwardCpu(const float* image, const float* image_grad, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int channels);
7 |
8 | int trilinear_forward(torch::Tensor lut, torch::Tensor image, torch::Tensor output,
9 | int lut_dim, int shift, float binsize, int width, int height, int batch)
10 | {
11 | // Grab the input tensor
12 | float * lut_flat = lut.data();
13 | float * image_flat = image.data();
14 | float * output_flat = output.data();
15 |
16 | // whether color image
17 | auto image_size = image.sizes();
18 | int channels = image_size[1];
19 | if (channels != 3)
20 | {
21 | return 0;
22 | }
23 |
24 | TriLinearForwardCpu(lut_flat, image_flat, output_flat, lut_dim, shift, binsize, width, height, channels);
25 |
26 | return 1;
27 | }
28 |
29 | int trilinear_backward(torch::Tensor image, torch::Tensor image_grad, torch::Tensor lut_grad,
30 | int lut_dim, int shift, float binsize, int width, int height, int batch)
31 | {
32 | // Grab the input tensor
33 | float * image_grad_flat = image_grad.data();
34 | float * image_flat = image.data();
35 | float * lut_grad_flat = lut_grad.data();
36 |
37 | // whether color image
38 | auto image_size = image.sizes();
39 | int channels = image_size[1];
40 | if (channels != 3)
41 | {
42 | return 0;
43 | }
44 |
45 | TriLinearBackwardCpu(image_flat, image_grad_flat, lut_grad_flat, lut_dim, shift, binsize, width, height, channels);
46 |
47 | return 1;
48 | }
49 |
50 | void TriLinearForwardCpu(const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int channels)
51 | {
52 |     const int output_size = height * width;
53 |
54 | int index = 0;
55 | for (index = 0; index < output_size; ++index)
56 | {
57 | float r = image[index];
58 | float g = image[index + width * height];
59 | float b = image[index + width * height * 2];
60 |
61 | int r_id = floor(r / binsize);
62 | int g_id = floor(g / binsize);
63 | int b_id = floor(b / binsize);
64 |
65 | float r_d = fmod(r,binsize) / binsize;
66 | float g_d = fmod(g,binsize) / binsize;
67 | float b_d = fmod(b,binsize) / binsize;
68 |
69 | int id000 = r_id + g_id * dim + b_id * dim * dim;
70 | int id100 = r_id + 1 + g_id * dim + b_id * dim * dim;
71 | int id010 = r_id + (g_id + 1) * dim + b_id * dim * dim;
72 | int id110 = r_id + 1 + (g_id + 1) * dim + b_id * dim * dim;
73 | int id001 = r_id + g_id * dim + (b_id + 1) * dim * dim;
74 | int id101 = r_id + 1 + g_id * dim + (b_id + 1) * dim * dim;
75 | int id011 = r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim;
76 | int id111 = r_id + 1 + (g_id + 1) * dim + (b_id + 1) * dim * dim;
77 |
78 | float w000 = (1-r_d)*(1-g_d)*(1-b_d);
79 | float w100 = r_d*(1-g_d)*(1-b_d);
80 | float w010 = (1-r_d)*g_d*(1-b_d);
81 | float w110 = r_d*g_d*(1-b_d);
82 | float w001 = (1-r_d)*(1-g_d)*b_d;
83 | float w101 = r_d*(1-g_d)*b_d;
84 | float w011 = (1-r_d)*g_d*b_d;
85 | float w111 = r_d*g_d*b_d;
86 |
87 | output[index] = w000 * lut[id000] + w100 * lut[id100] +
88 | w010 * lut[id010] + w110 * lut[id110] +
89 | w001 * lut[id001] + w101 * lut[id101] +
90 | w011 * lut[id011] + w111 * lut[id111];
91 |
92 | output[index + width * height] = w000 * lut[id000 + shift] + w100 * lut[id100 + shift] +
93 | w010 * lut[id010 + shift] + w110 * lut[id110 + shift] +
94 | w001 * lut[id001 + shift] + w101 * lut[id101 + shift] +
95 | w011 * lut[id011 + shift] + w111 * lut[id111 + shift];
96 |
97 | output[index + width * height * 2] = w000 * lut[id000 + shift * 2] + w100 * lut[id100 + shift * 2] +
98 | w010 * lut[id010 + shift * 2] + w110 * lut[id110 + shift * 2] +
99 | w001 * lut[id001 + shift * 2] + w101 * lut[id101 + shift * 2] +
100 | w011 * lut[id011 + shift * 2] + w111 * lut[id111 + shift * 2];
101 | }
102 | }
103 |
104 | void TriLinearBackwardCpu(const float* image, const float* image_grad, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int channels)
105 | {
106 | const int output_size = height * width;
107 |
108 | int index = 0;
109 | for (index = 0; index < output_size; ++index)
110 | {
111 | float r = image[index];
112 | float g = image[index + width * height];
113 | float b = image[index + width * height * 2];
114 |
115 | int r_id = floor(r / binsize);
116 | int g_id = floor(g / binsize);
117 | int b_id = floor(b / binsize);
118 |
119 | float r_d = fmod(r,binsize) / binsize;
120 | float g_d = fmod(g,binsize) / binsize;
121 | float b_d = fmod(b,binsize) / binsize;
122 |
123 | int id000 = r_id + g_id * dim + b_id * dim * dim;
124 | int id100 = r_id + 1 + g_id * dim + b_id * dim * dim;
125 | int id010 = r_id + (g_id + 1) * dim + b_id * dim * dim;
126 | int id110 = r_id + 1 + (g_id + 1) * dim + b_id * dim * dim;
127 | int id001 = r_id + g_id * dim + (b_id + 1) * dim * dim;
128 | int id101 = r_id + 1 + g_id * dim + (b_id + 1) * dim * dim;
129 | int id011 = r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim;
130 | int id111 = r_id + 1 + (g_id + 1) * dim + (b_id + 1) * dim * dim;
131 |
132 | float w000 = (1-r_d)*(1-g_d)*(1-b_d);
133 | float w100 = r_d*(1-g_d)*(1-b_d);
134 | float w010 = (1-r_d)*g_d*(1-b_d);
135 | float w110 = r_d*g_d*(1-b_d);
136 | float w001 = (1-r_d)*(1-g_d)*b_d;
137 | float w101 = r_d*(1-g_d)*b_d;
138 | float w011 = (1-r_d)*g_d*b_d;
139 | float w111 = r_d*g_d*b_d;
140 |
141 | lut_grad[id000] += w000 * image_grad[index];
142 | lut_grad[id100] += w100 * image_grad[index];
143 | lut_grad[id010] += w010 * image_grad[index];
144 | lut_grad[id110] += w110 * image_grad[index];
145 | lut_grad[id001] += w001 * image_grad[index];
146 | lut_grad[id101] += w101 * image_grad[index];
147 | lut_grad[id011] += w011 * image_grad[index];
148 | lut_grad[id111] += w111 * image_grad[index];
149 |
150 | lut_grad[id000 + shift] += w000 * image_grad[index + width * height];
151 | lut_grad[id100 + shift] += w100 * image_grad[index + width * height];
152 | lut_grad[id010 + shift] += w010 * image_grad[index + width * height];
153 | lut_grad[id110 + shift] += w110 * image_grad[index + width * height];
154 | lut_grad[id001 + shift] += w001 * image_grad[index + width * height];
155 | lut_grad[id101 + shift] += w101 * image_grad[index + width * height];
156 | lut_grad[id011 + shift] += w011 * image_grad[index + width * height];
157 | lut_grad[id111 + shift] += w111 * image_grad[index + width * height];
158 |
159 | lut_grad[id000 + shift* 2] += w000 * image_grad[index + width * height * 2];
160 | lut_grad[id100 + shift* 2] += w100 * image_grad[index + width * height * 2];
161 | lut_grad[id010 + shift* 2] += w010 * image_grad[index + width * height * 2];
162 | lut_grad[id110 + shift* 2] += w110 * image_grad[index + width * height * 2];
163 | lut_grad[id001 + shift* 2] += w001 * image_grad[index + width * height * 2];
164 | lut_grad[id101 + shift* 2] += w101 * image_grad[index + width * height * 2];
165 | lut_grad[id011 + shift* 2] += w011 * image_grad[index + width * height * 2];
166 | lut_grad[id111 + shift* 2] += w111 * image_grad[index + width * height * 2];
167 | }
168 | }
169 |
170 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
171 | m.def("forward", &trilinear_forward, "Trilinear forward");
172 | m.def("backward", &trilinear_backward, "Trilinear backward");
173 | }
174 |
--------------------------------------------------------------------------------
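`TriLinearForwardCpu` above (and the CUDA kernel later in `trilinear_kernel.cu`) is standard trilinear interpolation over a `dim x dim x dim` grid: each channel value is split into a cell index and a fractional offset, and the eight neighboring LUT entries are blended. A plain-Python mirror of the index/weight computation, for explanation only:

```python
# Sketch only: per-pixel cell index and trilinear weights, mirroring the C++ loop.
# binsize matches the value set in TrilinearInterpolationFunction.forward().
def cell_and_weights(r, g, b, dim, binsize):
    r_id, g_id, b_id = int(r // binsize), int(g // binsize), int(b // binsize)
    r_d, g_d, b_d = (r % binsize) / binsize, (g % binsize) / binsize, (b % binsize) / binsize
    id000 = r_id + g_id * dim + b_id * dim * dim          # flat index of the lower LUT corner
    weights = [
        (1 - r_d) * (1 - g_d) * (1 - b_d),   # w000
        r_d * (1 - g_d) * (1 - b_d),         # w100
        (1 - r_d) * g_d * (1 - b_d),         # w010
        r_d * g_d * (1 - b_d),               # w110
        (1 - r_d) * (1 - g_d) * b_d,         # w001
        r_d * (1 - g_d) * b_d,               # w101
        (1 - r_d) * g_d * b_d,               # w011
        r_d * g_d * b_d,                     # w111
    ]
    return id000, weights

dim = 17
binsize = 1.000001 / (dim - 1)
id000, w = cell_and_weights(0.5, 0.25, 0.75, dim, binsize)
print(id000, abs(sum(w) - 1.0) < 1e-6)   # the eight weights always sum to 1
```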
/trilinear_cpp/src/trilinear.h:
--------------------------------------------------------------------------------
1 | #ifndef TRILINEAR_H
2 | #define TRILINEAR_H
3 |
4 | #include <torch/extension.h>
5 |
6 | int trilinear_forward(torch::Tensor lut, torch::Tensor image, torch::Tensor output,
7 | int lut_dim, int shift, float binsize, int width, int height, int batch);
8 |
9 | int trilinear_backward(torch::Tensor image, torch::Tensor image_grad, torch::Tensor lut_grad,
10 | int lut_dim, int shift, float binsize, int width, int height, int batch);
11 |
12 | #endif
13 |
--------------------------------------------------------------------------------
/trilinear_cpp/src/trilinear_cuda.cpp:
--------------------------------------------------------------------------------
1 | #include "trilinear_kernel.h"
2 | #include <torch/extension.h>
3 | #include <ATen/cuda/CUDAContext.h>
4 |
5 | int trilinear_forward_cuda(torch::Tensor lut, torch::Tensor image, torch::Tensor output,
6 | int lut_dim, int shift, float binsize, int width, int height, int batch)
7 | {
8 | // Grab the input tensor
9 | float * lut_flat = lut.data();
10 | float * image_flat = image.data();
11 | float * output_flat = output.data();
12 |
13 | TriLinearForwardLaucher(lut_flat, image_flat, output_flat, lut_dim, shift, binsize, width, height, batch, at::cuda::getCurrentCUDAStream());
14 |
15 | return 1;
16 | }
17 |
18 | int trilinear_backward_cuda(torch::Tensor image, torch::Tensor image_grad, torch::Tensor lut_grad,
19 | int lut_dim, int shift, float binsize, int width, int height, int batch)
20 | {
21 | // Grab the input tensor
22 | float * image_grad_flat = image_grad.data();
23 | float * image_flat = image.data();
24 | float * lut_grad_flat = lut_grad.data();
25 |
26 | TriLinearBackwardLaucher(image_flat, image_grad_flat, lut_grad_flat, lut_dim, shift, binsize, width, height, batch, at::cuda::getCurrentCUDAStream());
27 |
28 | return 1;
29 | }
30 |
31 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
32 | m.def("forward", &trilinear_forward_cuda, "Trilinear forward");
33 | m.def("backward", &trilinear_backward_cuda, "Trilinear backward");
34 | }
35 |
36 |
--------------------------------------------------------------------------------
/trilinear_cpp/src/trilinear_cuda.h:
--------------------------------------------------------------------------------
1 | #ifndef TRILINEAR_CUDA_H
2 | #define TRILINEAR_CUDA_H
3 |
4 | #include <torch/extension.h>
5 |
6 | int trilinear_forward_cuda(torch::Tensor lut, torch::Tensor image, torch::Tensor output,
7 | int lut_dim, int shift, float binsize, int width, int height, int batch);
8 |
9 | int trilinear_backward_cuda(torch::Tensor image, torch::Tensor image_grad, torch::Tensor lut_grad,
10 | int lut_dim, int shift, float binsize, int width, int height, int batch);
11 |
12 | #endif
13 |
--------------------------------------------------------------------------------
/trilinear_cpp/src/trilinear_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <math.h>
2 | #include <float.h>
3 | #include "trilinear_kernel.h"
4 |
5 | #define CUDA_1D_KERNEL_LOOP(i, n) \
6 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
7 | i += blockDim.x * gridDim.x)
8 |
9 |
10 | __global__ void TriLinearForward(const int nthreads, const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int batch) {
11 | CUDA_1D_KERNEL_LOOP(index, nthreads) {
12 |
13 | float r = image[index];
14 | float g = image[index + width * height * batch];
15 | float b = image[index + width * height * batch * 2];
16 |
17 | int r_id = floor(r / binsize);
18 | int g_id = floor(g / binsize);
19 | int b_id = floor(b / binsize);
20 |
21 | float r_d = fmod(r,binsize) / binsize;
22 | float g_d = fmod(g,binsize) / binsize;
23 | float b_d = fmod(b,binsize) / binsize;
24 |
25 | int id000 = r_id + g_id * dim + b_id * dim * dim;
26 | int id100 = r_id + 1 + g_id * dim + b_id * dim * dim;
27 | int id010 = r_id + (g_id + 1) * dim + b_id * dim * dim;
28 | int id110 = r_id + 1 + (g_id + 1) * dim + b_id * dim * dim;
29 | int id001 = r_id + g_id * dim + (b_id + 1) * dim * dim;
30 | int id101 = r_id + 1 + g_id * dim + (b_id + 1) * dim * dim;
31 | int id011 = r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim;
32 | int id111 = r_id + 1 + (g_id + 1) * dim + (b_id + 1) * dim * dim;
33 |
34 | float w000 = (1-r_d)*(1-g_d)*(1-b_d);
35 | float w100 = r_d*(1-g_d)*(1-b_d);
36 | float w010 = (1-r_d)*g_d*(1-b_d);
37 | float w110 = r_d*g_d*(1-b_d);
38 | float w001 = (1-r_d)*(1-g_d)*b_d;
39 | float w101 = r_d*(1-g_d)*b_d;
40 | float w011 = (1-r_d)*g_d*b_d;
41 | float w111 = r_d*g_d*b_d;
42 |
43 | output[index] = w000 * lut[id000] + w100 * lut[id100] +
44 | w010 * lut[id010] + w110 * lut[id110] +
45 | w001 * lut[id001] + w101 * lut[id101] +
46 | w011 * lut[id011] + w111 * lut[id111];
47 |
48 | output[index + width * height * batch] = w000 * lut[id000 + shift] + w100 * lut[id100 + shift] +
49 | w010 * lut[id010 + shift] + w110 * lut[id110 + shift] +
50 | w001 * lut[id001 + shift] + w101 * lut[id101 + shift] +
51 | w011 * lut[id011 + shift] + w111 * lut[id111 + shift];
52 |
53 | output[index + width * height * batch * 2] = w000 * lut[id000 + shift * 2] + w100 * lut[id100 + shift * 2] +
54 | w010 * lut[id010 + shift * 2] + w110 * lut[id110 + shift * 2] +
55 | w001 * lut[id001 + shift * 2] + w101 * lut[id101 + shift * 2] +
56 | w011 * lut[id011 + shift * 2] + w111 * lut[id111 + shift * 2];
57 |
58 | }
59 | }
60 |
61 |
62 | int TriLinearForwardLaucher(const float* lut, const float* image, float* output, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream) {
63 | const int kThreadsPerBlock = 1024;
64 | const int output_size = height * width * batch;
65 | cudaError_t err;
66 |
67 |
68 | TriLinearForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(output_size, lut, image, output, lut_dim, shift, binsize, width, height, batch);
69 |
70 | err = cudaGetLastError();
71 | if(cudaSuccess != err) {
72 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
73 | exit( -1 );
74 | }
75 |
76 | return 1;
77 | }
78 |
79 |
80 | __global__ void TriLinearBackward(const int nthreads, const float* image, const float* image_grad, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int batch) {
81 | CUDA_1D_KERNEL_LOOP(index, nthreads) {
82 |
83 | float r = image[index];
84 | float g = image[index + width * height * batch];
85 | float b = image[index + width * height * batch * 2];
86 |
87 | int r_id = floor(r / binsize);
88 | int g_id = floor(g / binsize);
89 | int b_id = floor(b / binsize);
90 |
91 | float r_d = fmod(r,binsize) / binsize;
92 | float g_d = fmod(g,binsize) / binsize;
93 | float b_d = fmod(b,binsize) / binsize;
94 |
95 | int id000 = r_id + g_id * dim + b_id * dim * dim;
96 | int id100 = r_id + 1 + g_id * dim + b_id * dim * dim;
97 | int id010 = r_id + (g_id + 1) * dim + b_id * dim * dim;
98 | int id110 = r_id + 1 + (g_id + 1) * dim + b_id * dim * dim;
99 | int id001 = r_id + g_id * dim + (b_id + 1) * dim * dim;
100 | int id101 = r_id + 1 + g_id * dim + (b_id + 1) * dim * dim;
101 | int id011 = r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim;
102 | int id111 = r_id + 1 + (g_id + 1) * dim + (b_id + 1) * dim * dim;
103 |
104 | float w000 = (1-r_d)*(1-g_d)*(1-b_d);
105 | float w100 = r_d*(1-g_d)*(1-b_d);
106 | float w010 = (1-r_d)*g_d*(1-b_d);
107 | float w110 = r_d*g_d*(1-b_d);
108 | float w001 = (1-r_d)*(1-g_d)*b_d;
109 | float w101 = r_d*(1-g_d)*b_d;
110 | float w011 = (1-r_d)*g_d*b_d;
111 | float w111 = r_d*g_d*b_d;
112 |
113 | atomicAdd(lut_grad + id000, image_grad[index] * w000);
114 | atomicAdd(lut_grad + id100, image_grad[index] * w100);
115 | atomicAdd(lut_grad + id010, image_grad[index] * w010);
116 | atomicAdd(lut_grad + id110, image_grad[index] * w110);
117 | atomicAdd(lut_grad + id001, image_grad[index] * w001);
118 | atomicAdd(lut_grad + id101, image_grad[index] * w101);
119 | atomicAdd(lut_grad + id011, image_grad[index] * w011);
120 | atomicAdd(lut_grad + id111, image_grad[index] * w111);
121 |
122 | atomicAdd(lut_grad + id000 + shift, image_grad[index + width * height * batch] * w000);
123 | atomicAdd(lut_grad + id100 + shift, image_grad[index + width * height * batch] * w100);
124 | atomicAdd(lut_grad + id010 + shift, image_grad[index + width * height * batch] * w010);
125 | atomicAdd(lut_grad + id110 + shift, image_grad[index + width * height * batch] * w110);
126 | atomicAdd(lut_grad + id001 + shift, image_grad[index + width * height * batch] * w001);
127 | atomicAdd(lut_grad + id101 + shift, image_grad[index + width * height * batch] * w101);
128 | atomicAdd(lut_grad + id011 + shift, image_grad[index + width * height * batch] * w011);
129 | atomicAdd(lut_grad + id111 + shift, image_grad[index + width * height * batch] * w111);
130 |
131 | atomicAdd(lut_grad + id000 + shift * 2, image_grad[index + width * height * batch * 2] * w000);
132 | atomicAdd(lut_grad + id100 + shift * 2, image_grad[index + width * height * batch * 2] * w100);
133 | atomicAdd(lut_grad + id010 + shift * 2, image_grad[index + width * height * batch * 2] * w010);
134 | atomicAdd(lut_grad + id110 + shift * 2, image_grad[index + width * height * batch * 2] * w110);
135 | atomicAdd(lut_grad + id001 + shift * 2, image_grad[index + width * height * batch * 2] * w001);
136 | atomicAdd(lut_grad + id101 + shift * 2, image_grad[index + width * height * batch * 2] * w101);
137 | atomicAdd(lut_grad + id011 + shift * 2, image_grad[index + width * height * batch * 2] * w011);
138 | atomicAdd(lut_grad + id111 + shift * 2, image_grad[index + width * height * batch * 2] * w111);
139 | }
140 | }
141 |
142 | int TriLinearBackwardLaucher(const float* image, const float* image_grad, float* lut_grad, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream) {
143 | const int kThreadsPerBlock = 1024;
144 | const int output_size = height * width * batch;
145 | cudaError_t err;
146 |
147 | TriLinearBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(output_size, image, image_grad, lut_grad, lut_dim, shift, binsize, width, height, batch);
148 |
149 | err = cudaGetLastError();
150 | if(cudaSuccess != err) {
151 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
152 | exit( -1 );
153 | }
154 |
155 | return 1;
156 | }
157 |
--------------------------------------------------------------------------------
/trilinear_cpp/src/trilinear_kernel.h:
--------------------------------------------------------------------------------
1 | #ifndef _TRILINEAR_KERNEL
2 | #define _TRILINEAR_KERNEL
3 |
4 | #include <cuda_runtime.h>
5 |
6 | __global__ void TriLinearForward(const int nthreads, const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int batch);
7 |
8 | int TriLinearForwardLaucher(const float* lut, const float* image, float* output, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream);
9 |
10 | __global__ void TriLinearBackward(const int nthreads, const float* image, const float* image_grad, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int batch);
11 |
12 | int TriLinearBackwardLaucher(const float* image, const float* image_grad, float* lut_grad, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream);
13 |
14 |
15 | #endif
16 |
17 |
--------------------------------------------------------------------------------
/trilinear_cpp/trilinear.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: trilinear
3 | Version: 0.0.0
4 | Summary: UNKNOWN
5 | Home-page: UNKNOWN
6 | License: UNKNOWN
7 | Platform: UNKNOWN
8 |
9 | UNKNOWN
10 |
11 |
--------------------------------------------------------------------------------
/trilinear_cpp/trilinear.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | setup.py
2 | src/trilinear_cuda.cpp
3 | src/trilinear_kernel.cu
4 | trilinear.egg-info/PKG-INFO
5 | trilinear.egg-info/SOURCES.txt
6 | trilinear.egg-info/dependency_links.txt
7 | trilinear.egg-info/top_level.txt
--------------------------------------------------------------------------------
/trilinear_cpp/trilinear.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/trilinear_cpp/trilinear.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | trilinear
2 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes a miscellaneous collection of useful helper functions."""
2 |
--------------------------------------------------------------------------------
/util/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/util/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/util/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/SycoNet-Adaptive-Image-Harmonization/4a7ab365a5a2cfa83de5525959982290d81bc31e/util/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | """This module contains simple helper functions """
2 |
3 | import os
4 | import torch
5 | import numpy as np
6 | from PIL import Image
7 |
8 |
9 | def tensor2im(input_image, imtype=np.uint8):
10 | """"Converts a Tensor array into a numpy image array.
11 |
12 | Parameters:
13 | input_image (tensor) -- the input image tensor array
14 | imtype (type) -- the desired type of the converted numpy array
15 | """
16 |
17 | if not isinstance(input_image, np.ndarray):
18 | if isinstance(input_image, torch.Tensor): # get the data from a variable
19 | image_tensor = input_image.data
20 | else:
21 | return input_image
22 |
23 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
24 | if image_numpy.shape[0] == 1: # grayscale to RGB
25 | image_numpy = np.tile(image_numpy, (3, 1, 1))
26 |
27 | image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0
28 |
29 | else: # if it is a numpy array, do nothing
30 | image_numpy = input_image
31 | image_numpy = np.clip(image_numpy, 0, 255)
32 | return image_numpy.astype(imtype)
33 |
34 |
35 | def save_image(image_numpy, image_path, aspect_ratio=1.0):
36 | """Save a numpy image to the disk
37 |
38 | Parameters:
39 | image_numpy (numpy array) -- input numpy array
40 | image_path (str) -- the path of the image
41 | """
42 |
43 | image_pil = Image.fromarray(image_numpy)
44 | h, w, _ = image_numpy.shape
45 |
46 | if aspect_ratio > 1.0:
47 | image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
48 | if aspect_ratio < 1.0:
49 | image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
50 | image_pil.save(image_path,quality=100) #added by Mia (quality)
51 |
52 |
53 | def mkdirs(paths):
54 | """create empty directories if they don't exist
55 |
56 | Parameters:
57 | paths (str list) -- a list of directory paths
58 | """
59 | if isinstance(paths, list) and not isinstance(paths, str):
60 | for path in paths:
61 | mkdir(path)
62 | else:
63 | mkdir(paths)
64 |
65 |
66 | def mkdir(path):
67 |     """create a single empty directory if it doesn't exist
68 |
69 | Parameters:
70 | path (str) -- a single directory path
71 | """
72 | if not os.path.exists(path):
73 | os.makedirs(path)
74 |
75 |
--------------------------------------------------------------------------------
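For completeness, a hypothetical end-to-end use of the helpers above: convert one of the model's visuals to a uint8 array and write it to disk. The output folder and file name are placeholders, and the snippet assumes it is run from the repository root so that `util.util` is importable.

```python
# Illustrative usage of tensor2im / save_image / mkdir; the tensor here is a
# random stand-in for a value returned by BaseModel.get_current_visuals().
import torch
from util.util import tensor2im, save_image, mkdir

fake_visual = torch.rand(1, 3, 256, 256)        # values in [0, 1], NCHW
image_numpy = tensor2im(fake_visual)            # (256, 256, 3) uint8 array in [0, 255]
mkdir('./results_demo')                         # hypothetical output folder
save_image(image_numpy, './results_demo/example.jpg')
```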