├── README.md
├── config.yml
├── examples
│   └── DucoNet.png
├── iharm
│   ├── data
│   │   ├── base.py
│   │   ├── compose.py
│   │   ├── hdataset.py
│   │   └── transforms.py
│   ├── engine
│   │   ├── event_loop.py
│   │   ├── optimizer.py
│   │   └── simple_trainer.py
│   ├── inference
│   │   ├── evaluation.py
│   │   ├── metrics.py
│   │   ├── predictor.py
│   │   ├── transforms.py
│   │   └── utils.py
│   ├── mconfigs
│   │   ├── __init__.py
│   │   ├── backboned.py
│   │   └── base.py
│   ├── model
│   │   ├── backboned
│   │   │   ├── __init__.py
│   │   │   ├── deeplab.py
│   │   │   ├── hrnet.py
│   │   │   └── ih_model.py
│   │   ├── base
│   │   │   ├── Control_encoder.py
│   │   │   ├── DucoNet_model.py
│   │   │   ├── __init__.py
│   │   │   ├── dih_model.py
│   │   │   ├── iseunet_v1.py
│   │   │   └── ssam_model.py
│   │   ├── initializer.py
│   │   ├── losses.py
│   │   ├── metrics.py
│   │   ├── modeling
│   │   │   ├── basic_blocks.py
│   │   │   ├── conv_autoencoder.py
│   │   │   ├── deeplab_v3.py
│   │   │   ├── hrnet_ocr.py
│   │   │   ├── ocr.py
│   │   │   ├── resnet.py
│   │   │   ├── resnetv1b.py
│   │   │   ├── styleganv2.py
│   │   │   ├── unet.py
│   │   │   └── unet_v1.py
│   │   ├── modifiers.py
│   │   ├── ops.py
│   │   └── syncbn
│   │       ├── LICENSE
│   │       ├── README.md
│   │       ├── __init__.py
│   │       └── modules
│   │           ├── __init__.py
│   │           ├── functional
│   │           │   ├── __init__.py
│   │           │   ├── _csrc.py
│   │           │   ├── csrc
│   │           │   │   ├── bn.h
│   │           │   │   ├── cuda
│   │           │   │   │   ├── bn_cuda.cu
│   │           │   │   │   ├── common.h
│   │           │   │   │   └── ext_lib.h
│   │           │   │   └── ext_lib.cpp
│   │           │   └── syncbn.py
│   │           └── nn
│   │               ├── __init__.py
│   │               └── syncbn.py
│   └── utils
│       ├── exp.py
│       ├── log.py
│       └── misc.py
├── models
│   ├── DucoNet_1024.py
│   ├── DucoNet_256.py
│   └── improved_ssam.py
├── requirements.txt
├── scripts
│   ├── evaluate_model.py
│   ├── evaluate_model_fg_ratios.py
│   └── predict_for_dir.py
├── test.sh
├── train.py
└── train.sh
/README.md:
--------------------------------------------------------------------------------
1 | # [ACM MM-23] Deep Image Harmonization in Dual Color Spaces
2 |
3 | [](https://paperswithcode.com/sota/image-harmonization-on-iharmony4?p=deep-image-harmonization-in-dual-color-spaces)
4 |
5 |
6 | This is the official repository for the following paper:
7 |
8 | > **Deep Image Harmonization in Dual Color Spaces** [[arXiv]](https://arxiv.org/abs/2308.02813)
9 | >
10 | > Linfeng Tan, Jiangtong Li, Li Niu, Liqing Zhang
11 | > Accepted by **ACMMM2023**.
12 |
13 | > Image harmonization is an essential step in image composition that adjusts the appearance of the composite foreground to address the inconsistency between foreground and background. Existing methods primarily operate in the correlated $RGB$ color space, leading to entangled features and limited representation ability. In contrast, a decorrelated color space ($Lab$) has decorrelated channels that provide disentangled color and illumination statistics. In this paper, we explore image harmonization in dual color spaces, which supplements entangled $RGB$ features with disentangled $L$, $a$, $b$ features to alleviate the workload of the harmonization process. The network comprises an $RGB$ harmonization backbone, an $Lab$ encoding module, and an $Lab$ control module. The backbone is a U-Net that translates the composite image to the harmonized image. Three encoders in the $Lab$ encoding module independently extract three control codes from the $L$, $a$, $b$ channels, which are used to manipulate the decoder features in the harmonization backbone via the $Lab$ control module.
14 |
15 | 
16 |
17 | ## Getting Started
18 |
19 | ### Prerequisites
20 | Please refer to [iSSAM](https://github.com/saic-vul/image_harmonization) for guidance on setting up the environment.
21 |
22 | ### Installation
23 | + Clone this repo:
24 | ```
25 | git clone https://github.com/bcmi/DucoNet-Image-Harmonization.git
26 | cd ./DucoNet-Image-Harmonization
27 | ```
28 | + Download the [iHarmony4](https://github.com/bcmi/Image-Harmonization-Dataset-iHarmony4) dataset, and configure the paths to the datasets in [config.yml](./config.yml).
29 |
30 | - Install PyTorch and dependencies from http://pytorch.org.
31 |
32 | - Install python requirements:
33 |
34 | ```bash
35 | pip install -r requirements.txt
36 | ```
37 |
38 | ### Training
39 | If you want to train DucoNet on the [iHarmony4](https://github.com/bcmi/Image-Harmonization-Dataset-iHarmony4) dataset, you can run the following commands:
40 |
41 | ```
42 | ## for low-resolution
43 |
44 | python train.py models/DucoNet_256.py --workers=8 --gpus=0,1 --exp-name=DucoNet_256 --batch-size=64
45 |
46 | ## for high-resolution
47 |
48 | python train.py models/DucoNet_1024.py --workers=8 --gpus=2,3 --exp-name=DucoNet_1024 --batch-size=4
49 | ```
50 |
51 | We also provide example commands in `train.sh` for convenience.
52 |
53 | ### Testing
54 | You can test a pretrained model with the following command. The pretrained models we released can be downloaded from [Dropbox](https://www.dropbox.com/scl/fo/jnq5sgokct3n0l2ix8knl/AGwf0vaEHULz2qjg1iK81jA?rlkey=v3djzqf3v1upiddb1wc7t0481&st=vz6vi7uc&dl=0) or [Baidu Cloud](https://pan.baidu.com/s/1lnDOnmN1tLeoIcvjWvWFkQ?pwd=bcmi):
55 | ```
56 | python scripts/evaluate_model.py DucoNet ./checkpoints/last_model/DucoNet256.pth \
57 | --resize-strategy Fixed256 \
58 | --gpu 0
59 |
60 | #python scripts/evaluate_model.py DucoNet ./checkpoints/last_model/DucoNet1024.pth \
61 | #--resize-strategy Fixed1024 \
62 | #--gpu 1 \
63 | #--datasets HAdobe5k1
64 | ```
65 |
66 | We also provide example commands in `test.sh` for convenience.
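The first argument to `evaluate_model.py` (e.g. `DucoNet`) selects a model configuration from `iharm/mconfigs`. As a rough illustration, the snippet below builds the network from that configuration and wraps it in the `Predictor` used by the evaluation scripts (the paths are examples that assume the checkpoint layout from the command above):

```python
import torch

from iharm.inference.predictor import Predictor
from iharm.inference.utils import find_checkpoint, load_model

# Resolve the checkpoint file and build the network from the 'DucoNet' config.
checkpoint_path = find_checkpoint('./checkpoints/last_model', 'DucoNet256.pth')
net = load_model('DucoNet', checkpoint_path, verbose=True)

# Wrap the network in the Predictor used by scripts/evaluate_model.py.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
predictor = Predictor(net, device)
```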
67 |
68 | ## Results and Pretrained model
69 |
70 | We test DucoNet on the iHarmony4 dataset at resolution 256×256 and on the HAdobe5k dataset at resolution 1024×1024, and report results on the evaluation metrics MSE, fMSE, and PSNR.
71 | We also release the pretrained models corresponding to these results; you can download them from the links below.
72 |
73 | | Image Size | fMSE | MSE | PSNR | Dropbox | Baidu Cloud |
74 | | ------------------ | ------ | ----- | ----- | ---------------- | ---------------- |
75 | | 256 $\times$ 256 | 212.53 | 18.47 | 39.17 | [Dropbox](https://www.dropbox.com/scl/fi/8rf3or6gt35wizyxpgf0y/DucoNet256.pth?rlkey=jjun6hywcy2wte8k5idgh28jv&st=l3jwirjp&dl=0) | [Baidu Cloud](https://pan.baidu.com/s/1lnDOnmN1tLeoIcvjWvWFkQ?pwd=bcmi) |
76 | | 1024 $\times$ 1024 | 80.69 | 10.94 | 41.37 | [Dropbox](https://www.dropbox.com/scl/fi/tpowgk1ezf090ezb6o4ib/DucoNet1024.pth?rlkey=gyt042f5h5igc6ypk24stnwtd&st=gk10qjv2&dl=0) | [Baidu Cloud](https://pan.baidu.com/s/1lnDOnmN1tLeoIcvjWvWFkQ?pwd=bcmi) |
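The metrics follow the definitions in `iharm/inference/metrics.py`: MSE is averaged over the whole image, fMSE is normalized by the foreground area, and PSNR is derived from MSE. A minimal NumPy sketch of these formulas (array shapes and value ranges are assumptions for illustration):

```python
import numpy as np

def harmonization_metrics(pred, target, mask, eps=1e-6):
    """pred/target: (H, W, 3) float arrays in [0, 255]; mask: (H, W) binary foreground mask."""
    squared_error = (pred - target) ** 2
    mse = squared_error.mean()
    # fMSE: squared error summed over the foreground, normalized by foreground area per channel.
    fmse = (mask[..., None] * squared_error).sum() / (3 * mask.sum() + eps)
    # PSNR uses the maximum target intensity as the peak signal, as in the evaluation code.
    psnr = 10 * np.log10(target.max() ** 2 / (mse + eps))
    return mse, fmse, psnr
```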
77 |
78 | ## Other Resources
79 |
80 | + [Awesome-Image-Harmonization](https://github.com/bcmi/Awesome-Image-Harmonization)
81 | + [Awesome-Image-Composition](https://github.com/bcmi/Awesome-Object-Insertion)
82 |
83 | ## Acknowledgement
84 |
85 | Our code borrows heavily from [iSSAM](https://github.com/saic-vul/image_harmonization) and the PyTorch implementation of [styleGANv2](https://github.com/labmlai/annotated_deep_learning_paper_implementations).
86 |
87 | ## Bibtex
88 | If you are interested in our work, please consider citing the following:
89 |
90 | ```
91 | @article{tan2023deep,
92 | title={Deep Image Harmonization in Dual Color Spaces},
93 | author={Tan, Linfeng and Li, Jiangtong and Niu, Li and Zhang, Liqing},
94 | journal={arXiv preprint arXiv:2308.02813},
95 | year={2023}
96 | }
97 | ```
--------------------------------------------------------------------------------
/config.yml:
--------------------------------------------------------------------------------
1 | MODELS_PATH: "./"
2 | EXPS_PATH: "./checkpoints/output"
3 |
4 | HFLICKR_PATH: "/home/dataset/IHD/HFlickr"
5 | HDAY2NIGHT_PATH: "/home/dataset/IHD/Hday2night"
6 | HCOCO_PATH: "/home/dataset/IHD/HCOCO"
7 | HADOBE5K1_PATH: "/data/IHD/HAdobe5k"
8 | HADOBE5K_PATH: "/home/dataset/IHD/HAdobe5k_resized1024"
9 |
10 |
11 | IMAGENET_PRETRAINED_MODELS:
12 | HRNETV2_W18_SMALL: "./pretrained_models/hrnet_w18_small_model_v2.pth"
13 | HRNETV2_W18: "./pretrained_models/hrnetv2_w18_imagenet_pretrained.pth"
14 | HRNETV2_W32: "./pretrained_models/hrnetv2_w32_imagenet_pretrained.pth"
15 | HRNETV2_W40: "./pretrained_models/hrnetv2_w40_imagenet_pretrained.pth"
16 | HRNETV2_W48: "./pretrained_models/hrnetv2_w48_imagenet_pretrained.pth"
17 | HistNet: "./checkpoints/last_model/hen_train/last_checkpoint.pth"
18 | # HistNet_old: "./checkpoints/last_model/hen_train/last_checkpoint.pth"
19 |
--------------------------------------------------------------------------------
/examples/DucoNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/DucoNet-Image-Harmonization/167cd720d91f0f5d9e64b54c3c8d4300d8915be2/examples/DucoNet.png
--------------------------------------------------------------------------------
/iharm/data/base.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | import cv2
5 |
6 | from torchvision import transforms
7 |
8 | class BaseHDataset(torch.utils.data.dataset.Dataset):
9 | def __init__(self,
10 | augmentator=None,
11 | input_transform=None,
12 | keep_background_prob=0.0,
13 | with_image_info=False,
14 | epoch_len=-1,
15 | ):
16 | super(BaseHDataset, self).__init__()
17 | self.epoch_len = epoch_len
18 | self.input_transform = input_transform
19 | self.augmentator = augmentator
20 | self.keep_background_prob = keep_background_prob
21 | self.with_image_info = with_image_info
22 |
23 | if input_transform is None:
24 | input_transform = lambda x: x
25 |
26 | self.input_transform = input_transform
27 | self.dataset_samples = None
28 |
29 | def __getitem__(self, index):
30 | if self.epoch_len > 0:
31 | index = random.randrange(0, len(self.dataset_samples))
32 |
33 | sample = self.get_sample(index)
34 | self.check_sample_types(sample)
35 | sample = self.augment_sample(sample)
36 |
37 | image = self.input_transform(sample['image'])
38 | target_image = self.input_transform(sample['target_image'])
39 | obj_mask = sample['object_mask'].astype(np.float32)
40 |
41 | transform = transforms.Compose([
42 | transforms.ToTensor(),
43 | ])
44 | comp_img_lab = sample['comp_img_lab']
45 | real_img_lab = sample['real_img_lab']
46 |
47 | comp_image_lab = transform(comp_img_lab)
48 | real_image_lab = transform(real_img_lab)
49 |
50 | output = {
51 | 'images': image,
52 | 'masks': obj_mask[np.newaxis, ...].astype(np.float32),
53 | 'target_images': target_image,
54 | 'comp_img_lab' : comp_image_lab,
55 | 'real_img_lab' : real_image_lab,
56 | }
57 |
58 | if self.with_image_info and 'image_id' in sample:
59 | output['image_info'] = sample['image_id']
60 |
61 | return output
62 |
63 | def check_sample_types(self, sample):
64 | assert sample['image'].dtype == 'uint8'
65 | if 'target_image' in sample:
66 | assert sample['target_image'].dtype == 'uint8'
67 |
68 | def augment_sample(self, sample):
69 | if self.augmentator is None:
70 | return sample
71 |
72 | additional_targets = {target_name: sample[target_name]
73 | for target_name in self.augmentator.additional_targets.keys()}
74 |
75 | valid_augmentation = False
76 | while not valid_augmentation:
77 | aug_output = self.augmentator(image=sample['image'], **additional_targets)
78 | valid_augmentation = self.check_augmented_sample(sample, aug_output)
79 |
80 | for target_name, transformed_target in aug_output.items():
81 | sample[target_name] = transformed_target
82 |
83 | return sample
84 |
85 | def check_augmented_sample(self, sample, aug_output):
86 | if self.keep_background_prob < 0.0 or random.random() < self.keep_background_prob:
87 | return True
88 |
89 | return aug_output['object_mask'].sum() > 10.0
90 |
91 | def get_sample(self, index):
92 | raise NotImplementedError
93 |
94 | def __len__(self):
95 | if self.epoch_len > 0:
96 | return self.epoch_len
97 | else:
98 | return len(self.dataset_samples)
99 |
100 |
--------------------------------------------------------------------------------
/iharm/data/compose.py:
--------------------------------------------------------------------------------
1 | from .base import BaseHDataset
2 |
3 |
4 | class ComposeDataset(BaseHDataset):
5 | def __init__(self, datasets, **kwargs):
6 | super(ComposeDataset, self).__init__(**kwargs)
7 |
8 | self._datasets = datasets
9 | self.dataset_samples = []
10 | for dataset_indx, dataset in enumerate(self._datasets):
11 | self.dataset_samples.extend([(dataset_indx, i) for i in range(len(dataset))])
12 |
13 | def get_sample(self, index):
14 | dataset_indx, sample_indx = self.dataset_samples[index]
15 | return self._datasets[dataset_indx].get_sample(sample_indx)
16 |
--------------------------------------------------------------------------------
/iharm/data/hdataset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import os
3 | import cv2
4 | import numpy as np
5 |
6 | from .base import BaseHDataset
7 |
8 |
9 | class HDataset(BaseHDataset):
10 | def __init__(self, dataset_path, split, blur_target=False, **kwargs):
11 | super(HDataset, self).__init__(**kwargs)
12 |
13 | self.dataset_path = Path(dataset_path)
14 | self.blur_target = blur_target
15 | self._split = split
16 | self._real_images_path = self.dataset_path / 'real_images'
17 | self._composite_images_path = self.dataset_path / 'composite_images'
18 | self._masks_path = self.dataset_path / 'masks'
19 |
20 | images_lists_paths = [x for x in self.dataset_path.glob('*.txt') if x.stem.endswith(split)]
21 |
22 | assert len(images_lists_paths) == 1
23 |
24 | with open(images_lists_paths[0], 'r') as f:
25 | self.dataset_samples = [x.strip() for x in f.readlines()]
26 |
27 | def get_sample(self, index):
28 | composite_image_name = self.dataset_samples[index]
29 | real_image_name = composite_image_name.split('_')[0] + '.jpg'
30 | mask_name = '_'.join(composite_image_name.split('_')[:-1]) + '.png'
31 |
32 |
33 | composite_image_path = str(self._composite_images_path / composite_image_name)
34 | real_image_path = str(self._real_images_path / real_image_name)
35 | mask_path = str(self._masks_path / mask_name)
36 |
37 | composite_image = cv2.imread(composite_image_path)
38 | composite_image = cv2.cvtColor(composite_image, cv2.COLOR_BGR2RGB)
39 |
40 | composite_image_lab = cv2.cvtColor(composite_image, cv2.COLOR_RGB2Lab)
41 |
42 | real_image = cv2.imread(real_image_path)
43 | real_image = cv2.cvtColor(real_image, cv2.COLOR_BGR2RGB)
44 |
45 | real_image_lab = cv2.cvtColor(real_image, cv2.COLOR_RGB2Lab)
46 |
47 | object_mask_image = cv2.imread(mask_path)
48 | object_mask = object_mask_image[:, :, 0].astype(np.float32) / 255.
49 | if self.blur_target:
50 | object_mask = cv2.GaussianBlur(object_mask, (7, 7), 0)
51 |
52 | return {
53 | 'image': composite_image,
54 | 'object_mask': object_mask,
55 | 'target_image': real_image,
56 | 'image_id': index,
57 | 'comp_img_lab' :composite_image_lab,
58 | 'real_img_lab' :real_image_lab,
59 | 'comp_img_name':composite_image_name,
60 | }
61 |
--------------------------------------------------------------------------------
/iharm/data/transforms.py:
--------------------------------------------------------------------------------
1 | from albumentations import Compose, LongestMaxSize, DualTransform
2 | import albumentations.augmentations.functional as F
3 | import cv2
4 |
5 |
6 | class HCompose(Compose):
7 | def __init__(self, transforms, *args, additional_targets=None, no_nearest_for_masks=True, **kwargs):
8 | if additional_targets is None:
9 | additional_targets = {
10 | 'target_image': 'image',
11 | 'object_mask': 'mask',
12 | 'comp_img_lab' :'image',
13 | 'real_img_lab' :'image',
14 | }
15 | self.additional_targets = additional_targets
16 | super().__init__(transforms, *args, additional_targets=additional_targets, **kwargs)
17 | if no_nearest_for_masks:
18 | for t in transforms:
19 | if isinstance(t, DualTransform):
20 | t._additional_targets['object_mask'] = 'image'
21 |
22 |
23 | class LongestMaxSizeIfLarger(LongestMaxSize):
24 | """
25 | Rescale an image so that its longest side is less than or equal to max_size, keeping the aspect ratio of the initial image.
26 | If both image sides are already smaller than the given max_size, no rescaling is applied.
27 |
28 | Args:
29 | max_size (int): maximum size of the longest side of the image after the transformation.
30 | interpolation (OpenCV flag): interpolation method. Default: cv2.INTER_LINEAR.
31 | p (float): probability of applying the transform. Default: 1.
32 |
33 | Targets:
34 | image, mask, bboxes, keypoints
35 |
36 | Image types:
37 | uint8, float32
38 | """
39 | def apply(self, img, interpolation=cv2.INTER_LINEAR, **params):
40 | if max(img.shape[:2]) < self.max_size:
41 | return img
42 | return F.longest_max_size(img, max_size=self.max_size, interpolation=interpolation)
43 |
44 | def apply_to_keypoint(self, keypoint, **params):
45 | height = params["rows"]
46 | width = params["cols"]
47 |
48 | scale = self.max_size / max([height, width])
49 | if scale > 1.0:
50 | return keypoint
51 | return F.keypoint_scale(keypoint, scale, scale)
52 |
--------------------------------------------------------------------------------
/iharm/engine/event_loop.py:
--------------------------------------------------------------------------------
1 | import time
2 | import math
3 | from collections import deque
4 |
5 |
6 | class EventLoop(object):
7 | def __init__(self, console_logger=None, tb_writer=None, period_length=1000, start_tick=0,
8 | speed_counter='ma_10', speed_tb_period=-1):
9 | self.period_length = period_length
10 | self.total_ticks = start_tick
11 | self.events = []
12 | self.metrics = {}
13 | self.last_time = None
14 | self.tb_writer = tb_writer
15 | self.console_logger = console_logger
16 | self._speed_metric = '_speed_'
17 | self._max_delta_time = 60
18 | self._prev_total_ticks = 0
19 |
20 | self.register_metric(
21 | self._speed_metric,
22 | console_name='Speed', console_period=-1, console_format='{:.2f} samples/s',
23 | tb_name='Misc/Speed', tb_period=speed_tb_period,
24 | counter=speed_counter
25 | )
26 |
27 | def register_event(self, event, period, func_inputs=(),
28 | last_step=0, associated_metric=None, onetime=False):
29 | event = {
30 | 'event': event,
31 | 'last_step': last_step,
32 | 'period': period,
33 | 'func_inputs': func_inputs,
34 | 'onetime': onetime
35 | }
36 | if associated_metric is not None:
37 | event['metric'] = associated_metric
38 |
39 | self.events.append(event)
40 |
41 | def register_metric(self, metric_name,
42 | console_name=None, console_period=-1, console_format='{:.3f}',
43 | tb_name=None, tb_period=-1, tb_global_step='n_ticks',
44 | counter=None):
45 |
46 | self.metrics[metric_name] = {
47 | 'console_name': console_name if console_name is not None else metric_name,
48 | 'console_event': {'last_step': 0, 'period': console_period},
49 | 'console_format': console_format,
50 | 'tb_name': tb_name if tb_name is not None else metric_name,
51 | 'tb_event': {'last_step': 0, 'period': tb_period},
52 | 'counters': {counter: parse_counter(counter)},
53 | 'default_counter': counter,
54 | 'tb_global_step': tb_global_step
55 | }
56 |
57 | def register_metric_event(self, event, metric_name, period, func_inputs=(),
58 | console_period=0, tb_period=0,
59 | **metric_kwargs):
60 | self.register_event(event, period, func_inputs=func_inputs, associated_metric=metric_name)
61 | self.register_metric(metric_name, console_period=console_period, tb_period=tb_period,
62 | **metric_kwargs)
63 |
64 | def add_metric_value(self, metric_name, value):
65 | metric = self.metrics[metric_name]
66 | for counter in metric['counters'].values():
67 | counter.add(value)
68 | metric['console_event']['relaxed'] = False
69 | metric['tb_event']['relaxed'] = False
70 |
71 | def add_custom_metric_counter(self, metric_name, counter):
72 | metric = self.metrics[metric_name]
73 | if counter not in metric['counters']:
74 | metric['counters'][counter] = parse_counter(counter)
75 |
76 | def get_metric_value(self, metric_name, counter_name=None):
77 | metric = self.metrics[metric_name]
78 | if counter_name is None:
79 | return metric['counters'][metric['default_counter']].value
80 | else:
81 | return metric['counters'][counter_name].value
82 |
83 | def step(self, step_size):
84 | self._prev_total_ticks = self.total_ticks
85 | self.total_ticks += step_size
86 |
87 | self._update_time(step_size)
88 | self._check_events()
89 | self._check_metrics()
90 |
91 | def get_states(self):
92 | return {
93 | 'metrics': self.metrics,
94 | 'total_ticks': self.total_ticks,
95 | '_prev_total_ticks': self._prev_total_ticks,
96 | 'base_period': self.period_length
97 | }
98 |
99 | def set_states(self, states):
100 | for k, v in states.items():
101 | if k == 'metrics':
102 | self.metrics.update(v)
103 | else:
104 | setattr(self, k, v)
105 |
106 | @property
107 | def n_periods(self):
108 | return self.total_ticks // self.period_length
109 |
110 | @property
111 | def f_periods(self):
112 | return self.total_ticks / self.period_length
113 |
114 | @property
115 | def n_ticks(self):
116 | return self.total_ticks
117 |
118 | def _check_events(self):
119 | triggered_events = []
120 | for event in self.events:
121 | if self._check_event(event, ignore_relaxed=True):
122 | triggered_events.append(event)
123 | if event['onetime']:
124 | event['remove'] = True
125 |
126 | self.events = [event for event in self.events if not event.get('remove', False)]
127 |
128 | for event in triggered_events:
129 | event_metric = event.get('metric', None)
130 | inputs = (getattr(self, input_name) for input_name in event['func_inputs'])
131 | if event_metric is None:
132 | event['event'](*inputs)
133 | else:
134 | metric_value = event['event'](*inputs)
135 | self.add_metric_value(event_metric, metric_value)
136 |
137 | def _check_metrics(self):
138 | print_list = []
139 | for metric_name, metric in self.metrics.items():
140 | if self._check_event(metric['console_event']):
141 | print_list.append(metric_name)
142 |
143 | if self.tb_writer is not None and self._check_event(metric['tb_event']):
144 | global_step = getattr(self, metric['tb_global_step'])
145 | self.tb_writer.add_scalar(metric['tb_name'], self.get_metric_value(metric_name),
146 | global_step=global_step)
147 |
148 | if self.console_logger is not None and print_list:
149 | print_list.append(self._speed_metric)
150 |
151 | log_str = f'[{self.n_periods:06d}] '
152 | for metric_name in print_list:
153 | metric = self.metrics[metric_name]
154 | metric_value = self.get_metric_value(metric_name)
155 | log_str += metric['console_name'] + ': ' + metric['console_format'].format(metric_value) + ' '
156 |
157 | self.console_logger.info(log_str)
158 |
159 | def _check_event(self, event, ignore_relaxed=False):
160 | if not ignore_relaxed and event.get('relaxed', True):
161 | return False
162 |
163 | triggered = False
164 | event_period = event['period']
165 | if isinstance(event_period, dict):
166 | pstates = event.get('period_states', None)
167 | if pstates is None:
168 | sorted_periods = list(sorted(event['period'].items()))
169 | pstates = {'sorted_periods': sorted_periods, 'p_indx': 0}
170 | event['period_states'] = pstates
171 |
172 | sorted_periods = pstates['sorted_periods']
173 | p_indx = pstates['p_indx']
174 | while p_indx + 1 < len(sorted_periods) and self.n_periods >= sorted_periods[p_indx + 1][0]:
175 | p_indx += 1
176 | pstates['p_indx'] = p_indx
177 | event['last_step'] = self.period_length * (sorted_periods[p_indx][0] - sorted_periods[p_indx][1])
178 | event_period = sorted_periods[p_indx][1]
179 |
180 | if event_period == 0:
181 | triggered = True
182 | elif event_period > 0:
183 | event_period_ticks = self.period_length * event_period
184 | last_step = event['last_step']
185 | k = (self.total_ticks - last_step) / event_period_ticks
186 | if k >= 1.0 or k < 0:
187 | event['last_step'] = last_step + math.floor(k) * event_period_ticks
188 | prev_k = (self._prev_total_ticks - last_step) / event_period_ticks
189 | triggered = k >= 1.0 and prev_k < 1.0
190 |
191 | if triggered:
192 | event['relaxed'] = True
193 | return triggered
194 |
195 | def _update_time(self, step_size):
196 | current_time = time.time()
197 | if self.last_time is None:
198 | delta_time = 0
199 | else:
200 | delta_time = current_time - self.last_time
201 | self.last_time = current_time
202 | if delta_time > self._max_delta_time:
203 | delta_time = 0
204 |
205 | if delta_time > 0:
206 | speed = step_size / delta_time
207 | self.add_metric_value(self._speed_metric, speed)
208 |
209 |
210 | def parse_counter(counter_desc):
211 | if counter_desc is not None:
212 | smoothing_name, smoothing_period = counter_desc.split('_')
213 | smoothing_period = int(smoothing_period)
214 |
215 | if smoothing_name == 'ma':
216 | counter = MovingAverage(smoothing_period)
217 | else:
218 | raise NotImplementedError
219 | else:
220 | counter = SimpleLastValue()
221 |
222 | return counter
223 |
224 |
225 | class MovingAverage(object):
226 | def __init__(self, window_size):
227 | self.window_size = window_size
228 | self.window = deque()
229 | self.sum = 0
230 | self.cnt = 0
231 |
232 | def add(self, value):
233 | if len(self.window) < self.window_size:
234 | self.sum += value
235 | self.window.append(value)
236 | else:
237 | first = self.window.popleft()
238 | self.sum -= first
239 | self.sum += value
240 | self.window.append(value)
241 | self.cnt += 1
242 |
243 | @property
244 | def value(self):
245 | if self.window:
246 | return self.sum / len(self.window)
247 | else:
248 | return 0
249 |
250 | def __len__(self):
251 | return self.cnt
252 |
253 |
254 | class SimpleLastValue(object):
255 | def __init__(self):
256 | self.last_value = 0
257 | self.cnt = 0
258 |
259 | def add(self, value):
260 | self.last_value = value
261 | self.cnt += 1
262 |
263 | @property
264 | def value(self):
265 | return self.last_value
266 |
--------------------------------------------------------------------------------
/iharm/engine/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from iharm.utils.log import logger
4 |
5 |
6 | def get_optimizer(model, opt_name, opt_kwargs):
7 | params = []
8 | base_lr = opt_kwargs['lr']
9 | for name, param in model.named_parameters():
10 | param_group = {'params': [param]}
11 | if not param.requires_grad:
12 | params.append(param_group)
13 | continue
14 |
15 | if not math.isclose(getattr(param, 'lr_mult', 1.0), 1.0):
16 | logger.info(f'Applied lr_mult={param.lr_mult} to "{name}" parameter.')
17 | param_group['lr'] = param_group.get('lr', base_lr) * param.lr_mult
18 |
19 | params.append(param_group)
20 |
21 | optimizer = {
22 | 'sgd': torch.optim.SGD,
23 | 'adam': torch.optim.Adam,
24 | 'adamw': torch.optim.AdamW
25 | }[opt_name.lower()](params, **opt_kwargs)
26 |
27 | return optimizer
28 |
--------------------------------------------------------------------------------
/iharm/inference/evaluation.py:
--------------------------------------------------------------------------------
1 | import os
2 | from time import time
3 | from tqdm import trange
4 | import torch
5 | import torch.nn as nn
6 | import random
7 | import cv2
8 | import numpy as np
9 |
10 |
11 | def evaluate_dataset(dataset, predictor, metrics_hub):
12 |
13 | for sample_i in trange(len(dataset), desc=f'Testing on {metrics_hub.name}'):
14 | sample = dataset.get_sample(sample_i)
15 | sample = dataset.augment_sample(sample)
16 |
17 | sample_mask = sample['object_mask']
18 | image_lab = sample['comp_img_lab']
19 |
20 | predict_start = time()
21 |
22 | pred = predictor.predict(sample['image'], sample_mask, image_lab=image_lab, return_numpy=False)
23 |
24 | torch.cuda.synchronize()
25 | metrics_hub.update_time(time() - predict_start)
26 |
27 | target_image = torch.as_tensor(sample['target_image'], dtype=torch.float32).to(predictor.device)
28 | sample_mask = torch.as_tensor(sample_mask, dtype=torch.float32).to(predictor.device)
29 | with torch.no_grad():
30 | metrics_hub.compute_and_add(pred, target_image, sample_mask)
31 |
32 |
--------------------------------------------------------------------------------
/iharm/inference/metrics.py:
--------------------------------------------------------------------------------
1 | from copy import copy
2 | import math
3 |
4 |
5 | class MetricsHub:
6 | def __init__(self, metrics, name='', name_width=20):
7 | self.metrics = metrics
8 | self.name = name
9 | self.name_width = name_width
10 |
11 | def compute_and_add(self, *args):
12 | for m in self.metrics:
13 | if not isinstance(m, TimeMetric):
14 | m.compute_and_add(*args)
15 |
16 | def update_time(self, time_value):
17 | for m in self.metrics:
18 | if isinstance(m, TimeMetric):
19 | m.update_time(time_value)
20 |
21 | def get_table_header(self):
22 | table_header = ' ' * self.name_width + '|'
23 | for m in self.metrics:
24 | table_header += f'{m.name:^{m.cwidth}}|'
25 | splitter = len(table_header) * '-'
26 | return f'{splitter}\n{table_header}\n{splitter}'
27 |
28 | def __add__(self, another_hub):
29 | merged_metrics = []
30 | for a, b in zip(self.metrics, another_hub.metrics):
31 | merged_metrics.append(a + b)
32 | if not merged_metrics:
33 | merged_metrics = copy(another_hub.metrics)
34 |
35 | return MetricsHub(merged_metrics, name=self.name, name_width=self.name_width)
36 |
37 | def __repr__(self):
38 | table_row = f'{self.name:<{self.name_width}}|'
39 | for m in self.metrics:
40 | table_row += f'{str(m):^{m.cwidth}}|'
41 | return table_row
42 |
43 |
44 | class EvalMetric:
45 | def __init__(self):
46 | self._values_sum = 0.0
47 | self._count = 0
48 | self.cwidth = 10
49 |
50 | def compute_and_add(self, pred, target_image, mask):
51 | self._values_sum += self._compute_metric(pred, target_image, mask)
52 | self._count += 1
53 |
54 | def _compute_metric(self, pred, target_image, mask):
55 | raise NotImplementedError
56 |
57 | def __add__(self, another_eval_metric):
58 | comb_metric = copy(self)
59 | comb_metric._count += another_eval_metric._count
60 | comb_metric._values_sum += another_eval_metric._values_sum
61 | return comb_metric
62 |
63 | @property
64 | def value(self):
65 | return self._values_sum / self._count if self._count > 0 else None
66 |
67 | @property
68 | def name(self):
69 | return type(self).__name__
70 |
71 | def __repr__(self):
72 | return f'{self.value:.2f}'
73 |
74 | def __len__(self):
75 | return self._count
76 |
77 |
78 | class MSE(EvalMetric):
79 | def _compute_metric(self, pred, target_image, mask):
80 | return ((pred - target_image) ** 2).mean().item()
81 |
82 |
83 | class fMSE(EvalMetric):
84 | def _compute_metric(self, pred, target_image, mask):
85 | diff = mask.unsqueeze(2) * ((pred - target_image) ** 2)
86 | return (diff.sum() / (diff.size(2) * mask.sum() + 1e-6)).item()
87 |
88 |
89 | class PSNR(MSE):
90 | def __init__(self, epsilon=1e-6):
91 | super().__init__()
92 | self._epsilon = epsilon
93 |
94 | def _compute_metric(self, pred, target_image, mask):
95 | mse = super()._compute_metric(pred, target_image, mask)
96 | squared_max = target_image.max().item() ** 2
97 |
98 | return 10 * math.log10(squared_max / (mse + self._epsilon))
99 |
100 |
101 | class N(EvalMetric):
102 | def _compute_metric(self, pred, target_image, mask):
103 | return 0
104 |
105 | @property
106 | def value(self):
107 | return self._count
108 |
109 | def __repr__(self):
110 | return str(self.value)
111 |
112 |
113 | class TimeMetric(EvalMetric):
114 | def update_time(self, time_value):
115 | self._values_sum += time_value
116 | self._count += 1
117 |
118 |
119 | class AvgPredictTime(TimeMetric):
120 | def __init__(self):
121 | super().__init__()
122 | self.cwidth = 14
123 |
124 | @property
125 | def name(self):
126 | return 'AvgTime, ms'
127 |
128 | def __repr__(self):
129 | return f'{1000 * self.value:.1f}'
130 |
--------------------------------------------------------------------------------
/iharm/inference/predictor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from iharm.inference.transforms import NormalizeTensor, PadToDivisor, ToTensor, AddFlippedTensor
4 | from torchvision import transforms
5 | import numpy as np
6 | import cv2
7 |
8 | class Predictor(object):
9 | def __init__(self, net, device, with_flip=False,
10 | mean=(.485, .456, .406), std=(.229, .224, .225)):
11 | self.device = device
12 | self.net = net.to(self.device)
13 | self.net.eval()
14 |
15 | if hasattr(net, 'depth'):
16 | size_divisor = 2 ** (net.depth + 1)
17 | else:
18 | size_divisor = 1
19 |
20 | mean = torch.tensor(mean, dtype=torch.float32)
21 | std = torch.tensor(std, dtype=torch.float32)
22 | self.transforms = [
23 | PadToDivisor(divisor=size_divisor, border_mode=0),
24 | ToTensor(self.device),
25 | NormalizeTensor(mean, std, self.device),
26 | ]
27 | if with_flip:
28 | self.transforms.append(AddFlippedTensor())
29 |
30 | def predict(self, image, mask, image_lab=None, return_numpy=True):
31 | with torch.no_grad():
32 | for transform in self.transforms:
33 | image, mask = transform.transform(image, mask)
34 |
35 | transform = transforms.Compose([
36 | transforms.ToTensor(),
37 | ])
38 | comp_image_lab = transform(image_lab).unsqueeze(0).to(self.device)
39 | predicted_output = self.net(image, mask, image_lab=comp_image_lab)
40 | predicted_image = predicted_output['images']
41 |
42 |
43 | for transform in reversed(self.transforms):
44 | predicted_image = transform.inv_transform(predicted_image)
45 |
46 | predicted_image = torch.clamp(predicted_image, 0, 255)
47 |
48 | if return_numpy:
49 | return predicted_image.cpu().numpy()
50 | else:
51 | return predicted_image
--------------------------------------------------------------------------------
/iharm/inference/transforms.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | from collections import namedtuple
4 |
5 |
6 | class EvalTransform:
7 | def __init__(self):
8 | pass
9 |
10 | def transform(self, image, mask):
11 | raise NotImplementedError
12 |
13 | def inv_transform(self, image):
14 | raise NotImplementedError
15 |
16 |
17 | class PadToDivisor(EvalTransform):
18 | """
19 | Pad side of the image so that its side is divisible by divisor.
20 |
21 | Args:
22 | divisor (int): desirable image size divisor
23 | border_mode (OpenCV flag): OpenCV border mode.
24 | fill_value (int, float, list of int, list of float): padding value if border_mode is cv2.BORDER_CONSTANT.
25 | """
26 | PadParams = namedtuple('PadParams', ['top', 'bottom', 'left', 'right'])
27 |
28 | def __init__(self, divisor, border_mode=cv2.BORDER_CONSTANT, fill_value=0):
29 | super().__init__()
30 | self.border_mode = border_mode
31 | self.fill_value = fill_value
32 | self.divisor = divisor
33 | self._pads = None
34 |
35 | def transform(self, image, mask):
36 | self._pads = PadToDivisor.PadParams(*self._get_dim_padding(image.shape[0]),
37 | *self._get_dim_padding(image.shape[1]))
38 |
39 | image = cv2.copyMakeBorder(image, *self._pads, self.border_mode, value=self.fill_value)
40 | mask = cv2.copyMakeBorder(mask, *self._pads, self.border_mode, value=self.fill_value)
41 |
42 | return image, mask
43 |
44 | def inv_transform(self, image):
45 | assert self._pads is not None,\
46 | 'Something went wrong, inv_transform(...) should be called after transform(...)'
47 | return self._remove_padding(image)
48 |
49 | def _get_dim_padding(self, dim_size):
50 | pad = (self.divisor - dim_size % self.divisor) % self.divisor
51 | pad_upper = pad // 2
52 | pad_lower = pad - pad_upper
53 |
54 | return pad_upper, pad_lower
55 |
56 | def _remove_padding(self, tensor):
57 | tensor_h, tensor_w = tensor.shape[:2]
58 | cropped = tensor[self._pads.top:tensor_h - self._pads.bottom,
59 | self._pads.left:tensor_w - self._pads.right, :]
60 | return cropped
61 |
62 |
63 | class NormalizeTensor(EvalTransform):
64 | def __init__(self, mean, std, device):
65 | super().__init__()
66 | self.mean = torch.as_tensor(mean).reshape(1, 3, 1, 1).to(device)
67 | self.std = torch.as_tensor(std).reshape(1, 3, 1, 1).to(device)
68 |
69 | def transform(self, image, mask):
70 | image.sub_(self.mean).div_(self.std)
71 | return image, mask
72 |
73 | def inv_transform(self, image):
74 | image.mul_(self.std).add_(self.mean)
75 | return image
76 |
77 |
78 | class ToTensor(EvalTransform):
79 | def __init__(self, device):
80 | super().__init__()
81 | self.device = device
82 |
83 | def transform(self, image, mask):
84 | image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
85 | mask = torch.as_tensor(mask, device=self.device)
86 | image.unsqueeze_(0)
87 | mask.unsqueeze_(0).unsqueeze_(0)
88 | return image.permute(0, 3, 1, 2) / 255.0, mask
89 |
90 | def inv_transform(self, image):
91 | image.squeeze_(0)
92 | return 255 * image.permute(1, 2, 0)
93 |
94 |
95 | class AddFlippedTensor(EvalTransform):
96 | def transform(self, image, mask):
97 | flipped_image = torch.flip(image, dims=(3,))
98 | flipped_mask = torch.flip(mask, dims=(3,))
99 | image = torch.cat((image, flipped_image), dim=0)
100 | mask = torch.cat((mask, flipped_mask), dim=0)
101 | return image, mask
102 |
103 | def inv_transform(self, image):
104 | return 0.5 * (image[:1] + torch.flip(image[1:], dims=(3,)))
105 |
--------------------------------------------------------------------------------
/iharm/inference/utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from iharm.utils.misc import load_weights
4 | from iharm.mconfigs import ALL_MCONFIGS
5 |
6 |
7 | def load_model(model_type, checkpoint_path, verbose=False):
8 | net = ALL_MCONFIGS[model_type]['model'](**ALL_MCONFIGS[model_type]['params'])
9 | load_weights(net, checkpoint_path, verbose=verbose)
10 | return net
11 |
12 |
13 | def find_checkpoint(weights_folder, checkpoint_name):
14 | weights_folder = Path(weights_folder)
15 | if ':' in checkpoint_name:
16 | model_name, checkpoint_name = checkpoint_name.split(':')
17 | models_candidates = [x for x in weights_folder.glob(f'{model_name}*') if x.is_dir()]
18 | assert len(models_candidates) == 1
19 | model_folder = models_candidates[0]
20 | else:
21 | model_folder = weights_folder
22 |
23 | if checkpoint_name.endswith('.pth'):
24 | if Path(checkpoint_name).exists():
25 | checkpoint_path = checkpoint_name
26 | else:
27 | checkpoint_path = weights_folder / checkpoint_name
28 | else:
29 | model_checkpoints = list(model_folder.rglob(f'{checkpoint_name}*.pth'))
30 | assert len(model_checkpoints) == 1
31 | checkpoint_path = model_checkpoints[0]
32 | return str(checkpoint_path)
33 |
34 |
--------------------------------------------------------------------------------
/iharm/mconfigs/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BMCONFIGS
2 | from .backboned import MCONFIGS
3 |
4 |
5 | ALL_MCONFIGS = dict(**BMCONFIGS, **MCONFIGS)
6 |
--------------------------------------------------------------------------------
/iharm/mconfigs/backboned.py:
--------------------------------------------------------------------------------
1 | from .base import BMCONFIGS
2 | from iharm.model.backboned import DeepLabIHModel, HRNetIHModel
3 |
4 |
5 | MCONFIGS = {
6 | 'hrnet18s_idih256': {
7 | 'model': HRNetIHModel,
8 | 'params': {'base_config': BMCONFIGS['improved_dih256']}
9 | },
10 | 'hrnet18s_v2p_idih256': {
11 | 'model': HRNetIHModel,
12 | 'params': {'base_config': BMCONFIGS['improved_dih256'], 'pyramid_channels': 256}
13 | },
14 | 'hrnet18_idih256': {
15 | 'model': HRNetIHModel,
16 | 'params': {'base_config': BMCONFIGS['improved_dih256'], 'small': False}
17 | },
18 | 'hrnet18_v2p_idih256': {
19 | 'model': HRNetIHModel,
20 | 'params': {'base_config': BMCONFIGS['improved_dih256'], 'small': False, 'pyramid_channels': 256}
21 | },
22 | 'hrnet32_idih256': {
23 | 'model': HRNetIHModel,
24 | 'params': {'base_config': BMCONFIGS['improved_dih256'], 'width': 32, 'small': False}
25 | },
26 | 'deeplab_r34_idih256': {
27 | 'model': DeepLabIHModel,
28 | 'params': {'base_config': BMCONFIGS['improved_dih256']}
29 | },
30 | 'hrnet18_idih512': {
31 | 'model': HRNetIHModel,
32 | 'params': {'base_config': BMCONFIGS['improved_dih512'], 'small': False, 'downsize_hrnet_input': True}
33 | },
34 | 'hrnet18_sedih512': {
35 | 'model': HRNetIHModel,
36 | 'params': {'base_config': BMCONFIGS['improved_sedih512'], 'small': False, 'downsize_hrnet_input': True}
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
/iharm/mconfigs/base.py:
--------------------------------------------------------------------------------
1 | from iharm.model.base import DeepImageHarmonization, SSAMImageHarmonization, ISEUNetV1, DucoNet_model
2 |
3 | BMCONFIGS = {
4 | 'dih256': {
5 | 'model': DeepImageHarmonization,
6 | 'params': {'depth': 7}
7 | },
8 | 'improved_dih256': {
9 | 'model': DeepImageHarmonization,
10 | 'params': {'depth': 7, 'batchnorm_from': 2, 'image_fusion': True}
11 | },
12 | 'improved_sedih256': {
13 | 'model': DeepImageHarmonization,
14 | 'params': {'depth': 7, 'batchnorm_from': 2, 'image_fusion': True, 'attend_from': 5}
15 | },
16 | 'ssam256': {
17 | 'model': SSAMImageHarmonization,
18 | 'params': {'depth': 4, 'batchnorm_from': 2, 'attend_from': 2}
19 | },
20 | 'improved_ssam256': {
21 | 'model': SSAMImageHarmonization,
22 | 'params': {'depth': 4, 'ch': 32, 'image_fusion': True, 'attention_mid_k': 0.5,
23 | 'batchnorm_from': 2, 'attend_from': 2}
24 | },
25 | 'iseunetv1_256': {
26 | 'model': ISEUNetV1,
27 | 'params': {'depth': 4, 'batchnorm_from': 2, 'attend_from': 1, 'ch': 64}
28 | },
29 | 'dih512': {
30 | 'model': DeepImageHarmonization,
31 | 'params': {'depth': 8}
32 | },
33 | 'improved_dih512': {
34 | 'model': DeepImageHarmonization,
35 | 'params': {'depth': 8, 'batchnorm_from': 2, 'image_fusion': True}
36 | },
37 | 'improved_ssam512': {
38 | 'model': SSAMImageHarmonization,
39 | 'params': {'depth': 6, 'ch': 32, 'image_fusion': True, 'attention_mid_k': 0.5,
40 | 'batchnorm_from': 2, 'attend_from': 3}
41 | },
42 | 'improved_sedih512': {
43 | 'model': DeepImageHarmonization,
44 | 'params': {'depth': 8, 'batchnorm_from': 2, 'image_fusion': True, 'attend_from': 6}
45 | },
46 |
47 | 'DucoNet':{
48 | 'model': DucoNet_model,
49 | 'params': {'depth': 4, 'ch': 32, 'image_fusion': True, 'attention_mid_k': 0.5,
50 | 'batchnorm_from': 2, 'attend_from': 2,'w_dim':256,'control_module_start':-1}
51 | },
52 |
53 | }
54 |
--------------------------------------------------------------------------------
/iharm/model/backboned/__init__.py:
--------------------------------------------------------------------------------
1 | from .deeplab import DeepLabIHModel
2 | from .hrnet import HRNetIHModel
3 |
--------------------------------------------------------------------------------
/iharm/model/backboned/deeplab.py:
--------------------------------------------------------------------------------
1 | from torch import nn as nn
2 |
3 | from iharm.model.modeling.deeplab_v3 import DeepLabV3Plus
4 | from iharm.model.backboned.ih_model import IHModelWithBackbone
5 | from iharm.model.modifiers import LRMult
6 | from iharm.model.modeling.basic_blocks import MaxPoolDownSize
7 |
8 |
9 | class DeepLabIHModel(IHModelWithBackbone):
10 | def __init__(
11 | self,
12 | base_config,
13 | mask_fusion='sum',
14 | deeplab_backbone='resnet34',
15 | lr_mult=0.1,
16 | pyramid_channels=-1, deeplab_ch=256,
17 | mode='cat',
18 | **base_kwargs
19 | ):
20 | """
21 | Creates image harmonization model supported by the features extracted from the pre-trained DeepLab backbone.
22 |
23 | Parameters
24 | ----------
25 | base_config : dict
26 | Configuration dict for the base model, to which the backbone features are incorporated.
27 | base_config contains the model class and init parameters; examples can be found in iharm.mconfigs.base
28 | mask_fusion : str
29 | How to fuse the binary mask with the backbone input:
30 | 'sum': apply convolution to the mask and sum it with the output of the first convolution in the backbone
31 | 'rgb': concatenate the mask to the input image and translate it back to 3 channels with convolution
32 | otherwise: do not fuse mask with the backbone input
33 | deeplab_backbone : str
34 | ResNet backbone name.
35 | lr_mult : float
36 | Multiply learning rate to lr_mult when updating the weights of the backbone.
37 | pyramid_channels : int
38 | The DeepLab output can be consequently downsized to produce a feature pyramid.
39 | The pyramid features are then fused with the encoder outputs in the base model on multiple layers.
40 | Each pyramid feature map contains the same number of channels, equal to pyramid_channels.
41 | If pyramid_channels <= 0, the feature pyramid is not constructed.
42 | deeplab_ch : int
43 | Number of channels for output DeepLab layer and some in the middle.
44 | mode : str
45 | How to fuse the backbone features with the encoder outputs in the base model:
46 | 'sum': apply convolution to the backbone feature map obtaining number of channels
47 | same as in the encoder output and sum them
48 | 'cat': concatenate the backbone feature map with the encoder output
49 | 'catc': concatenate the backbone feature map with the encoder output and apply convolution obtaining
50 | number of channels same as in the encoder output
51 | otherwise: the backbone features are not incorporated into the base model
52 | base_kwargs : dict
53 | any kwargs associated with the base model
54 | """
55 | params = base_config['params']
56 | params.update(base_kwargs)
57 | depth = params['depth']
58 |
59 | backbone = DeepLabBB(pyramid_channels, deeplab_ch, deeplab_backbone, lr_mult)
60 |
61 | downsize_input = depth > 7
62 | params.update(dict(
63 | backbone_from=3 if downsize_input else 2,
64 | backbone_channels=backbone.output_channels,
65 | backbone_mode=mode
66 | ))
67 | base_model = base_config['model'](**params)
68 |
69 | super(DeepLabIHModel, self).__init__(base_model, backbone, downsize_input, mask_fusion)
70 |
71 |
72 | class DeepLabBB(nn.Module):
73 | def __init__(
74 | self,
75 | pyramid_channels=256,
76 | deeplab_ch=256,
77 | backbone='resnet34',
78 | backbone_lr_mult=0.1,
79 | ):
80 | super(DeepLabBB, self).__init__()
81 | self.pyramid_on = pyramid_channels > 0
82 | if self.pyramid_on:
83 | self.output_channels = [pyramid_channels] * 4
84 | else:
85 | self.output_channels = [deeplab_ch]
86 |
87 | self.deeplab = DeepLabV3Plus(backbone=backbone,
88 | ch=deeplab_ch,
89 | project_dropout=0.2,
90 | norm_layer=nn.BatchNorm2d,
91 | backbone_norm_layer=nn.BatchNorm2d)
92 | self.deeplab.backbone.apply(LRMult(backbone_lr_mult))
93 |
94 | if self.pyramid_on:
95 | self.downsize = MaxPoolDownSize(deeplab_ch, pyramid_channels, pyramid_channels, 4)
96 |
97 | def forward(self, image, mask, mask_features):
98 | outputs = list(self.deeplab(image, mask_features))
99 | if self.pyramid_on:
100 | outputs = self.downsize(outputs[0])
101 | return outputs
102 |
103 | def load_pretrained_weights(self):
104 | self.deeplab.load_pretrained_weights()
105 |
--------------------------------------------------------------------------------
/iharm/model/backboned/hrnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from iharm.model.modeling.hrnet_ocr import HighResolutionNet
4 | from iharm.model.backboned.ih_model import IHModelWithBackbone
5 | from iharm.model.modifiers import LRMult
6 | from iharm.model.modeling.basic_blocks import MaxPoolDownSize
7 |
8 |
9 | class HRNetIHModel(IHModelWithBackbone):
10 | def __init__(
11 | self,
12 | base_config,
13 | downsize_hrnet_input=False, mask_fusion='sum',
14 | lr_mult=0.1, cat_hrnet_outputs=True, pyramid_channels=-1,
15 | ocr=64, width=18, small=True,
16 | mode='cat',
17 | **base_kwargs
18 | ):
19 | """
20 | Creates image harmonization model supported by the features extracted from the pre-trained HRNet backbone.
21 | HRNet outputs feature maps on 4 different resolutions.
22 |
23 | Parameters
24 | ----------
25 | base_config : dict
26 | Configuration dict for the base model, to which the backbone features are incorporated.
27 | base_config contains the model class and init parameters; examples can be found in iharm.mconfigs.base
28 | downsize_hrnet_input : bool
29 | If the input image should be half-sized for the backbone.
30 | mask_fusion : str
31 | How to fuse the binary mask with the backbone input:
32 | 'sum': apply convolution to the mask and sum it with the output of the first convolution in the backbone
33 | 'rgb': concatenate the mask to the input image and translate it back to 3 channels with convolution
34 | otherwise: do not fuse mask with the backbone input
35 | lr_mult : float
36 | Multiply learning rate to lr_mult when updating the weights of the backbone.
37 | cat_hrnet_outputs : bool
38 | If 4 HRNet outputs should be resized and concatenated to a single tensor.
39 | pyramid_channels : int
40 | When HRNet outputs are concatenated to a single one, it can be consequently downsized
41 | to produce a feature pyramid.
42 | The pyramid features are then fused with the encoder outputs in the base model on multiple layers.
43 | Each pyramid feature map contains the same number of channels, equal to pyramid_channels.
44 | If pyramid_channels <= 0, the feature pyramid is not constructed.
45 | ocr : int
46 | When HRNet outputs are concatenated to a single one, the OCR module can be applied
47 | resulting in feature map with (2 * ocr) channels. If ocr <= 0 the OCR module is not applied.
48 | width : int
49 | Width of the HRNet blocks.
50 | small : bool
51 | If True, HRNet contains 2 blocks at each stage and 4 otherwise.
52 | mode : str
53 | How to fuse the backbone features with the encoder outputs in the base model:
54 | 'sum': apply convolution to the backbone feature map obtaining number of channels
55 | same as in the encoder output and sum them
56 | 'cat': concatenate the backbone feature map with the encoder output
57 | 'catc': concatenate the backbone feature map with the encoder output and apply convolution obtaining
58 | number of channels same as in the encoder output
59 | otherwise: the backbone features are not incorporated into the base model
60 | base_kwargs : dict
61 | any kwargs associated with the base model
62 | """
63 | params = base_config['params']
64 | params.update(base_kwargs)
65 | depth = params['depth']
66 |
67 | backbone = HRNetBB(
68 | cat_outputs=cat_hrnet_outputs,
69 | pyramid_channels=pyramid_channels,
70 | pyramid_depth=min(depth - 2 if not downsize_hrnet_input else depth - 3, 4),
71 | width=width, ocr=ocr, small=small,
72 | lr_mult=lr_mult,
73 | )
74 |
75 | params.update(dict(
76 | backbone_from=3 if downsize_hrnet_input else 2,
77 | backbone_channels=backbone.output_channels,
78 | backbone_mode=mode
79 | ))
80 | base_model = base_config['model'](**params)
81 |
82 | super(HRNetIHModel, self).__init__(base_model, backbone, downsize_hrnet_input, mask_fusion)
83 |
84 |
85 | class HRNetBB(nn.Module):
86 | def __init__(
87 | self,
88 | cat_outputs=True,
89 | pyramid_channels=256, pyramid_depth=4,
90 | width=18, ocr=64, small=True,
91 | lr_mult=0.1,
92 | ):
93 | super(HRNetBB, self).__init__()
94 | self.cat_outputs = cat_outputs
95 | self.ocr_on = ocr > 0 and cat_outputs
96 | self.pyramid_on = pyramid_channels > 0 and cat_outputs
97 |
98 | self.hrnet = HighResolutionNet(width, 2, ocr_width=ocr, small=small)
99 | self.hrnet.apply(LRMult(lr_mult))
100 | if self.ocr_on:
101 | self.hrnet.ocr_distri_head.apply(LRMult(1.0))
102 | self.hrnet.ocr_gather_head.apply(LRMult(1.0))
103 | self.hrnet.conv3x3_ocr.apply(LRMult(1.0))
104 |
105 | hrnet_cat_channels = [width * 2 ** i for i in range(4)]
106 | if self.pyramid_on:
107 | self.output_channels = [pyramid_channels] * 4
108 | elif self.ocr_on:
109 | self.output_channels = [ocr * 2]
110 | elif self.cat_outputs:
111 | self.output_channels = [sum(hrnet_cat_channels)]
112 | else:
113 | self.output_channels = hrnet_cat_channels
114 |
115 | if self.pyramid_on:
116 | downsize_in_channels = ocr * 2 if self.ocr_on else sum(hrnet_cat_channels)
117 | self.downsize = MaxPoolDownSize(downsize_in_channels, pyramid_channels, pyramid_channels, pyramid_depth)
118 |
119 | def forward(self, image, mask, mask_features):
120 | if not self.cat_outputs:
121 | return self.hrnet.compute_hrnet_feats(image, mask_features, return_list=True)
122 |
123 | outputs = list(self.hrnet(image, mask, mask_features))
124 | if self.pyramid_on:
125 | outputs = self.downsize(outputs[0])
126 | return outputs
127 |
128 | def load_pretrained_weights(self, pretrained_path):
129 | self.hrnet.load_pretrained_weights(pretrained_path)
130 |
--------------------------------------------------------------------------------
/iharm/model/backboned/ih_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from iharm.model.ops import SimpleInputFusion, ScaleLayer
5 |
6 |
7 | class IHModelWithBackbone(nn.Module):
8 | def __init__(
9 | self,
10 | model, backbone,
11 | downsize_backbone_input=False,
12 | mask_fusion='sum',
13 | backbone_conv1_channels=64,
14 | ):
15 | """
16 | Creates image harmonization model supported by the features extracted from the pre-trained backbone.
17 |
18 | Parameters
19 | ----------
20 | model : nn.Module
21 | Image harmonization model that takes an image and a mask as input and incorporates the feature maps produced by the backbone network.
22 | backbone : nn.Module
23 | Backbone model accepts RGB image and returns a list of features.
24 | downsize_backbone_input : bool
25 | If the input image should be half-sized for the backbone.
26 | mask_fusion : str
27 | How to fuse the binary mask with the backbone input:
28 | 'sum': apply convolution to the mask and sum it with the output of the first convolution in the backbone
29 | 'rgb': concatenate the mask to the input image and translate it back to 3 channels with convolution
30 | otherwise: do not fuse mask with the backbone input
31 | backbone_conv1_channels : int
32 | If mask_fusion is 'sum', define the number of channels for the convolution applied to the mask.
33 | """
34 | super(IHModelWithBackbone, self).__init__()
35 | self.downsize_backbone_input = downsize_backbone_input
36 | self.mask_fusion = mask_fusion
37 |
38 | self.backbone = backbone
39 | self.model = model
40 |
41 | if mask_fusion == 'rgb':
42 | self.fusion = SimpleInputFusion()
43 | elif mask_fusion == 'sum':
44 | self.mask_conv = nn.Sequential(
45 | nn.Conv2d(1, backbone_conv1_channels, kernel_size=3, stride=2, padding=1, bias=True),
46 | ScaleLayer(init_value=0.1, lr_mult=1)
47 | )
48 |
49 | def forward(self, image, mask):
50 | """
51 | Forward the backbone model and then the base model, supported by the backbone feature maps.
52 | Return model predictions.
53 |
54 | Parameters
55 | ----------
56 | image : torch.Tensor
57 | Input RGB image.
58 | mask : torch.Tensor
59 | Binary mask of the foreground region.
60 |
61 | Returns
62 | -------
63 | dict
64 | Base model output; the harmonized RGB image is stored under the 'images' key.
65 | """
66 | backbone_image = image
67 | backbone_mask = torch.cat((mask, 1.0 - mask), dim=1)
68 | if self.downsize_backbone_input:
69 | backbone_image = nn.functional.interpolate(
70 | backbone_image, scale_factor=0.5,
71 | mode='bilinear', align_corners=True
72 | )
73 | backbone_mask = nn.functional.interpolate(
74 | backbone_mask, backbone_image.size()[2:],
75 | mode='bilinear', align_corners=True
76 | )
77 | backbone_image = (
78 | self.fusion(backbone_image, backbone_mask[:, :1])
79 | if self.mask_fusion == 'rgb' else
80 | backbone_image
81 | )
82 | backbone_mask_features = self.mask_conv(backbone_mask[:, :1]) if self.mask_fusion == 'sum' else None
83 | backbone_features = self.backbone(backbone_image, backbone_mask, backbone_mask_features)
84 |
85 | output = self.model(image, mask, backbone_features)
86 | return output
87 |
--------------------------------------------------------------------------------
/iharm/model/base/Control_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from functools import partial
3 |
4 | from torch import nn as nn
5 |
6 | from iharm.model.modeling.basic_blocks import ConvBlock, GaussianSmoothing
7 | from iharm.model.modeling.unet_v1 import UNetEncoder, UNetDecoder
8 | from iharm.model.ops import ChannelAttention
9 | from iharm.model.modeling.styleganv2 import ResidualBlock
10 |
11 |
12 | class Control_encoder(nn.Module):
13 | def __init__(
14 | self,
15 | depth,
16 | norm_layer=nn.BatchNorm2d, batchnorm_from=2,
17 | ch=64, max_channels=512,
18 | lab_encoder_mask_channel=1,
19 | L_or_Lab_dim=3,w_dim=256,
20 | backbone_from=-1, backbone_channels=None, backbone_mode='',
21 | ):
22 | super(Control_encoder, self).__init__()
23 | self.depth = depth
24 | self.w_dim = w_dim
25 | self.lab_encoder_mask_channel = lab_encoder_mask_channel
26 |
27 | self.lab_encoder = UNetEncoder(
28 | depth, ch,
29 | norm_layer, batchnorm_from, max_channels,
30 | backbone_from, backbone_channels, backbone_mode,
31 | _in_channels=L_or_Lab_dim+self.lab_encoder_mask_channel
32 | )
33 |
34 |         self.map2w = map_net(int_dim=256, out_dim=self.w_dim)
35 |
36 |
37 |     def forward(self, image_lab, mask, backbone_features=None):
38 |         x_lab = image_lab
39 |         if self.lab_encoder_mask_channel == 1:
40 |             x_lab = torch.cat((x_lab, mask), dim=1)
41 |         intermediates_lab = self.lab_encoder(x_lab, backbone_features)
42 |         w = self.map2w(intermediates_lab[0], mask)
43 |
44 | return w
45 |
46 | class map_net(nn.Module):
47 |     def __init__(self, int_dim=256, out_dim=256):
48 |         super(map_net, self).__init__()
49 | 
50 |         self.int_dim = int_dim
51 |         self.pooling = nn.AdaptiveAvgPool2d((1, 1))
52 | 
53 |         self.net = nn.Sequential(
54 |             nn.Linear(int_dim, 256),
55 |             nn.ReLU(inplace=True),
56 |             nn.Linear(256, out_dim),
57 |         )
58 | 
59 |     def forward(self, feature_map, mask):
60 | 
61 |         w_bg = self.pooling(feature_map).view(-1, self.int_dim)
62 | 
63 |         return self.net(w_bg)
64 |
65 |
66 | class SpatialSeparatedAttention(nn.Module):
67 | def __init__(self, in_channels, norm_layer, activation, mid_k=2.0):
68 | super(SpatialSeparatedAttention, self).__init__()
69 | self.background_gate = ChannelAttention(in_channels)
70 | self.foreground_gate = ChannelAttention(in_channels)
71 | self.mix_gate = ChannelAttention(in_channels)
72 |
73 | mid_channels = int(mid_k * in_channels)
74 | self.learning_block = nn.Sequential(
75 | ConvBlock(
76 | in_channels, mid_channels,
77 | kernel_size=3, stride=1, padding=1,
78 | norm_layer=norm_layer, activation=activation,
79 | bias=False,
80 | ),
81 | ConvBlock(
82 | mid_channels, in_channels,
83 | kernel_size=3, stride=1, padding=1,
84 | norm_layer=norm_layer, activation=activation,
85 | bias=False,
86 | ),
87 | )
88 | self.mask_blurring = GaussianSmoothing(1, 7, 1, padding=3)
89 |
90 | def forward(self, x, mask):
91 | mask = self.mask_blurring(nn.functional.interpolate(
92 | mask, size=x.size()[-2:],
93 | mode='bilinear', align_corners=True
94 | ))
95 | background = self.background_gate(x)
96 | foreground = self.learning_block(self.foreground_gate(x))
97 | mix = self.mix_gate(x)
98 | output = mask * (foreground + mix) + (1 - mask) * background
99 | return output
100 |
--------------------------------------------------------------------------------
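
A minimal usage sketch for `Control_encoder`, assuming the repository root is on `PYTHONPATH`. The `depth=4` and `ch=32` values are illustrative (not taken from the released configs) and are chosen so that the deepest encoder stage has 256 channels, which `map_net`'s fixed `int_dim=256` expects.

```python
import torch
from iharm.model.base.Control_encoder import Control_encoder

# ch=32, depth=4 gives encoder widths 32 -> 64 -> 128 -> 256,
# so the pooled feature fed to map_net has 256 channels.
encoder = Control_encoder(depth=4, ch=32, L_or_Lab_dim=1, w_dim=256)

l_channel = torch.rand(2, 1, 256, 256)   # e.g. the L channel of a Lab composite
mask = torch.rand(2, 1, 256, 256)        # foreground mask

w_l = encoder(l_channel, mask)           # one control code per sample
print(w_l.shape)                         # torch.Size([2, 256])
```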
/iharm/model/base/DucoNet_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from functools import partial
3 |
4 | from torch import nn as nn
5 |
6 | from iharm.model.modeling.basic_blocks import ConvBlock, GaussianSmoothing
7 | from iharm.model.modeling.unet_v1 import UNetEncoder, UNetDecoder
8 | from iharm.model.ops import ChannelAttention
9 | from iharm.model.modeling.styleganv2 import ResidualBlock
10 | from .Control_encoder import Control_encoder
11 |
12 |
13 | class DucoNet_model(nn.Module):
14 | def __init__(
15 | self,
16 | depth,
17 | norm_layer=nn.BatchNorm2d, batchnorm_from=2,
18 | attend_from=3, attention_mid_k=2.0,
19 | image_fusion=False,
20 | ch=64, max_channels=512,
21 | backbone_from=-1, backbone_channels=None, backbone_mode='',
22 |             control_module_start=-1,
23 |             lab_encoder_mask_channel=1,
24 |             w_dim=256,
25 |
26 | ):
27 | super(DucoNet_model, self).__init__()
28 | self.depth = depth
29 | self.control_module_start = control_module_start
30 | self.w_dim = w_dim
31 |
32 | self.lab_encoder_mask_channel = lab_encoder_mask_channel
33 |
34 | self.encoder = UNetEncoder(
35 | depth, ch,
36 | norm_layer, batchnorm_from, max_channels,
37 | backbone_from, backbone_channels, backbone_mode
38 | )
39 | self.decoder = UNetDecoder(
40 | depth, self.encoder.block_channels,
41 | norm_layer,
42 | attention_layer=partial(SpatialSeparatedAttention, mid_k=attention_mid_k),
43 | attend_from=attend_from,
44 | image_fusion=image_fusion,
45 | control_module_start=self.control_module_start,
46 | w_dim = self.w_dim,
47 | control_module_layer=Control_Module,
48 | )
49 |
50 | self.l_encoder = Control_encoder(
51 | depth=depth, ch=ch, batchnorm_from=2,
52 | lab_encoder_mask_channel=1,
53 |             L_or_Lab_dim=1, w_dim=self.w_dim
54 | )
55 |
56 | self.a_encoder = Control_encoder(
57 | depth=depth, ch=ch, batchnorm_from=2,
58 | lab_encoder_mask_channel=1,
59 | L_or_Lab_dim=1, w_dim=self.w_dim
60 | )
61 |
62 | self.b_encoder = Control_encoder(
63 | depth=depth, ch=ch, batchnorm_from=2,
64 | lab_encoder_mask_channel=1,
65 | L_or_Lab_dim=1, w_dim=self.w_dim
66 | )
67 |
68 |
69 | def forward(self, image, mask, image_lab=None, backbone_features=None):
70 |
71 | x = torch.cat((image, mask), dim=1)
72 | intermediates = self.encoder(x, backbone_features)
73 |
74 |         w_l = self.l_encoder(image_lab[:, 0, :, :].unsqueeze(1), mask)
75 |         w_a = self.a_encoder(image_lab[:, 1, :, :].unsqueeze(1), mask)
76 |         w_b = self.b_encoder(image_lab[:, 2, :, :].unsqueeze(1), mask)
77 | 
78 |         ws = {'w_l': w_l, 'w_a': w_a, 'w_b': w_b}
79 |
80 | output = self.decoder(intermediates, image, mask, ws)
81 |
82 | return {'images': output}
83 |
84 |
85 | class Control_Module(nn.Module):
86 | def __init__(self, w_dim, feature_dim):
87 | super(Control_Module,self).__init__()
88 |
89 | self.w_dim = w_dim
90 | self.feature_dim = feature_dim
91 |
92 | self.l_styleblock = ResidualBlock(self.w_dim, self.feature_dim, self.feature_dim)
93 | self.a_styleblock = ResidualBlock(self.w_dim, self.feature_dim, self.feature_dim)
94 | self.b_styleblock = ResidualBlock(self.w_dim, self.feature_dim, self.feature_dim)
95 |
96 | self.G_weight = nn.Sequential(
97 | nn.Conv2d(self.feature_dim * 3, 32, kernel_size=1, stride=1, padding=0, bias=True),
98 | nn.ReLU(inplace=True),
99 | nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0, bias=True),
100 | )
101 |
102 |
103 | def forward(self, x, sl, sa, sb, mask):
104 |
105 | f_l = self.l_styleblock(x, sl, noise=None)
106 | f_a = self.a_styleblock(x, sa, noise=None)
107 | f_b = self.b_styleblock(x, sb, noise=None)
108 |
109 | f_lab = torch.cat((f_l, f_a, f_b), dim=1)
110 | f_weight = self.G_weight(f_lab)
111 | f_weight = nn.functional.softmax(f_weight, dim=1)
112 |
113 | weight_l = f_weight[:, 0, :, :].unsqueeze(1)
114 | weight_a = f_weight[:, 1, :, :].unsqueeze(1)
115 | weight_b = f_weight[:, 2, :, :].unsqueeze(1)
116 |
117 | out = weight_l * f_l + weight_a * f_a + weight_b * f_b
118 |
119 | out = x * (1 - mask) + mask * out
120 | return out
121 |
122 | class SpatialSeparatedAttention(nn.Module):
123 | def __init__(self, in_channels, norm_layer, activation, mid_k=2.0):
124 | super(SpatialSeparatedAttention, self).__init__()
125 | self.background_gate = ChannelAttention(in_channels)
126 | self.foreground_gate = ChannelAttention(in_channels)
127 | self.mix_gate = ChannelAttention(in_channels)
128 |
129 | mid_channels = int(mid_k * in_channels)
130 | self.learning_block = nn.Sequential(
131 | ConvBlock(
132 | in_channels, mid_channels,
133 | kernel_size=3, stride=1, padding=1,
134 | norm_layer=norm_layer, activation=activation,
135 | bias=False,
136 | ),
137 | ConvBlock(
138 | mid_channels, in_channels,
139 | kernel_size=3, stride=1, padding=1,
140 | norm_layer=norm_layer, activation=activation,
141 | bias=False,
142 | ),
143 | )
144 | self.mask_blurring = GaussianSmoothing(1, 7, 1, padding=3)
145 |
146 | def forward(self, x, mask):
147 | mask = self.mask_blurring(nn.functional.interpolate(
148 | mask, size=x.size()[-2:],
149 | mode='bilinear', align_corners=True
150 | ))
151 | background = self.background_gate(x)
152 | foreground = self.learning_block(self.foreground_gate(x))
153 | mix = self.mix_gate(x)
154 | output = mask * (foreground + mix) + (1 - mask) * background
155 | return output
156 |
--------------------------------------------------------------------------------
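
A usage sketch for the full `DucoNet_model`, again with illustrative hyper-parameters rather than the released config. `image_fusion=True` is passed because the decoder only builds `to_rgb` in that branch, and the Lab tensor is a random stand-in for the Lab-converted composite produced elsewhere in the data pipeline.

```python
import torch
from iharm.model.base import DucoNet_model

# ch=32 keeps the deepest encoder stage at 256 channels, matching map_net's int_dim.
model = DucoNet_model(depth=4, ch=32, attend_from=2, image_fusion=True, w_dim=256).eval()

image = torch.rand(2, 3, 256, 256)       # composite RGB image
mask = torch.rand(2, 1, 256, 256)        # foreground mask
image_lab = torch.rand(2, 3, 256, 256)   # stand-in for the Lab version of the composite

with torch.no_grad():
    out = model(image, mask, image_lab=image_lab)
print(out['images'].shape)               # torch.Size([2, 3, 256, 256])
```

Each of the three `Control_encoder`s sees one Lab channel plus the mask, and the resulting `w_l`, `w_a`, `w_b` codes drive a `Control_Module` in every decoder stage, since `control_module_start` defaults to -1.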
/iharm/model/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .dih_model import DeepImageHarmonization
2 | from .ssam_model import SSAMImageHarmonization
3 | from .iseunet_v1 import ISEUNetV1
4 | from .DucoNet_model import DucoNet_model
5 |
--------------------------------------------------------------------------------
/iharm/model/base/dih_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from iharm.model.modeling.conv_autoencoder import ConvEncoder, DeconvDecoder
5 |
6 |
7 | class DeepImageHarmonization(nn.Module):
8 | def __init__(
9 | self,
10 | depth,
11 | norm_layer=nn.BatchNorm2d, batchnorm_from=0,
12 | attend_from=-1,
13 | image_fusion=False,
14 | ch=64, max_channels=512,
15 | backbone_from=-1, backbone_channels=None, backbone_mode=''
16 | ):
17 | super(DeepImageHarmonization, self).__init__()
18 | self.depth = depth
19 | self.encoder = ConvEncoder(
20 | depth, ch,
21 | norm_layer, batchnorm_from, max_channels,
22 | backbone_from, backbone_channels, backbone_mode
23 | )
24 | self.decoder = DeconvDecoder(depth, self.encoder.blocks_channels, norm_layer, attend_from, image_fusion)
25 |
26 | def forward(self, image, mask, backbone_features=None):
27 | x = torch.cat((image, mask), dim=1)
28 | intermediates = self.encoder(x, backbone_features)
29 | output = self.decoder(intermediates, image, mask)
30 | return {'images': output}
31 |
--------------------------------------------------------------------------------
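
A usage sketch for the plain DIH baseline, assuming the repository is importable. `depth=7` is illustrative (seven stride-2 blocks reduce a 256x256 input to a 1x1 bottleneck), and `image_fusion=True` is used so the decoder's `to_rgb` layer is created.

```python
import torch
from iharm.model.base import DeepImageHarmonization

model = DeepImageHarmonization(depth=7, image_fusion=True).eval()

image = torch.rand(2, 3, 256, 256)
mask = torch.rand(2, 1, 256, 256)
with torch.no_grad():
    out = model(image, mask)
print(out['images'].shape)   # torch.Size([2, 3, 256, 256])
```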
/iharm/model/base/iseunet_v1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from iharm.model.modeling.unet import UNetEncoder, UNetDecoder
5 | from iharm.model.ops import MaskedChannelAttention
6 |
7 |
8 | class ISEUNetV1(nn.Module):
9 | def __init__(
10 | self,
11 | depth,
12 | norm_layer=nn.BatchNorm2d, batchnorm_from=2,
13 | attend_from=3,
14 | image_fusion=False,
15 | ch=64, max_channels=512,
16 | backbone_from=-1, backbone_channels=None, backbone_mode=''
17 | ):
18 | super(ISEUNetV1, self).__init__()
19 | self.depth = depth
20 | self.encoder = UNetEncoder(
21 | depth, ch,
22 | norm_layer, batchnorm_from, max_channels,
23 | backbone_from, backbone_channels, backbone_mode
24 | )
25 | self.decoder = UNetDecoder(
26 | depth, self.encoder.block_channels,
27 | norm_layer,
28 | attention_layer=MaskedChannelAttention,
29 | attend_from=attend_from,
30 | image_fusion=image_fusion
31 | )
32 |
33 | def forward(self, image, mask, backbone_features=None):
34 | x = torch.cat((image, mask), dim=1)
35 | intermediates = self.encoder(x, backbone_features)
36 | output = self.decoder(intermediates, image, mask)
37 | return {'images': output}
38 |
--------------------------------------------------------------------------------
/iharm/model/base/ssam_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from functools import partial
3 |
4 | from torch import nn as nn
5 |
6 | from iharm.model.modeling.basic_blocks import ConvBlock, GaussianSmoothing
7 | from iharm.model.modeling.unet import UNetEncoder, UNetDecoder
8 | from iharm.model.ops import ChannelAttention
9 |
10 |
11 | class SSAMImageHarmonization(nn.Module):
12 | def __init__(
13 | self,
14 | depth,
15 | norm_layer=nn.BatchNorm2d, batchnorm_from=2,
16 | attend_from=3, attention_mid_k=2.0,
17 | image_fusion=False,
18 | ch=64, max_channels=512,
19 | backbone_from=-1, backbone_channels=None, backbone_mode=''
20 | ):
21 | super(SSAMImageHarmonization, self).__init__()
22 | self.depth = depth
23 | self.encoder = UNetEncoder(
24 | depth, ch,
25 | norm_layer, batchnorm_from, max_channels,
26 | backbone_from, backbone_channels, backbone_mode
27 | )
28 | self.decoder = UNetDecoder(
29 | depth, self.encoder.block_channels,
30 | norm_layer,
31 | attention_layer=partial(SpatialSeparatedAttention, mid_k=attention_mid_k),
32 | attend_from=attend_from,
33 | image_fusion=image_fusion
34 | )
35 |
36 | def forward(self, image, mask, backbone_features=None):
37 |
38 | x = torch.cat((image, mask), dim=1)
39 | intermediates = self.encoder(x, backbone_features)
40 | output = self.decoder(intermediates, image, mask)
41 | return {'images': output}
42 |
43 |
44 | class SpatialSeparatedAttention(nn.Module):
45 | def __init__(self, in_channels, norm_layer, activation, mid_k=2.0):
46 | super(SpatialSeparatedAttention, self).__init__()
47 | self.background_gate = ChannelAttention(in_channels)
48 | self.foreground_gate = ChannelAttention(in_channels)
49 | self.mix_gate = ChannelAttention(in_channels)
50 |
51 | mid_channels = int(mid_k * in_channels)
52 | self.learning_block = nn.Sequential(
53 | ConvBlock(
54 | in_channels, mid_channels,
55 | kernel_size=3, stride=1, padding=1,
56 | norm_layer=norm_layer, activation=activation,
57 | bias=False,
58 | ),
59 | ConvBlock(
60 | mid_channels, in_channels,
61 | kernel_size=3, stride=1, padding=1,
62 | norm_layer=norm_layer, activation=activation,
63 | bias=False,
64 | ),
65 | )
66 | self.mask_blurring = GaussianSmoothing(1, 7, 1, padding=3)
67 |
68 | def forward(self, x, mask):
69 | mask = self.mask_blurring(nn.functional.interpolate(
70 | mask, size=x.size()[-2:],
71 | mode='bilinear', align_corners=True
72 | ))
73 | background = self.background_gate(x)
74 | foreground = self.learning_block(self.foreground_gate(x))
75 | mix = self.mix_gate(x)
76 | output = mask * (foreground + mix) + (1 - mask) * background
77 | return output
78 |
--------------------------------------------------------------------------------
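
A small sketch of the spatial-separated attention block in isolation (assuming the repository is importable). The mask is resized to the feature resolution and blurred internally, then foreground and background go through different channel-attention branches.

```python
import torch
from torch import nn
from iharm.model.base.ssam_model import SpatialSeparatedAttention

att = SpatialSeparatedAttention(in_channels=64, norm_layer=nn.BatchNorm2d, activation=nn.ReLU)

features = torch.rand(2, 64, 32, 32)   # a decoder feature map
mask = torch.rand(2, 1, 256, 256)      # full-resolution foreground mask

# output = mask * (foreground + mix) + (1 - mask) * background
out = att(features, mask)
print(out.shape)                       # torch.Size([2, 64, 32, 32])
```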
/iharm/model/initializer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | class Initializer(object):
7 | def __init__(self, local_init=True, gamma=None):
8 | self.local_init = local_init
9 | self.gamma = gamma
10 |
11 | def __call__(self, m):
12 | if getattr(m, '__initialized', False):
13 | return
14 |
15 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
16 | nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
17 | nn.GroupNorm, nn.SyncBatchNorm)) or 'BatchNorm' in m.__class__.__name__:
18 | if m.weight is not None:
19 | self._init_gamma(m.weight.data)
20 | if m.bias is not None:
21 | self._init_beta(m.bias.data)
22 | else:
23 | if getattr(m, 'weight', None) is not None:
24 | self._init_weight(m.weight.data)
25 | if getattr(m, 'bias', None) is not None:
26 | self._init_bias(m.bias.data)
27 |
28 | if self.local_init:
29 | object.__setattr__(m, '__initialized', True)
30 |
31 | def _init_weight(self, data):
32 | nn.init.uniform_(data, -0.07, 0.07)
33 |
34 | def _init_bias(self, data):
35 | nn.init.constant_(data, 0)
36 |
37 | def _init_gamma(self, data):
38 | if self.gamma is None:
39 | nn.init.constant_(data, 1.0)
40 | else:
41 | nn.init.normal_(data, 1.0, self.gamma)
42 |
43 | def _init_beta(self, data):
44 | nn.init.constant_(data, 0)
45 |
46 |
47 | class Bilinear(Initializer):
48 | def __init__(self, scale, groups, in_channels, **kwargs):
49 | super().__init__(**kwargs)
50 | self.scale = scale
51 | self.groups = groups
52 | self.in_channels = in_channels
53 |
54 | def _init_weight(self, data):
55 |         """Initialize the weight with a bilinear upsampling kernel."""
56 | bilinear_kernel = self.get_bilinear_kernel(self.scale)
57 | weight = torch.zeros_like(data)
58 | for i in range(self.in_channels):
59 | if self.groups == 1:
60 | j = i
61 | else:
62 | j = 0
63 | weight[i, j] = bilinear_kernel
64 | data[:] = weight
65 |
66 | @staticmethod
67 | def get_bilinear_kernel(scale):
68 | """Generate a bilinear upsampling kernel."""
69 | kernel_size = 2 * scale - scale % 2
70 | scale = (kernel_size + 1) // 2
71 | center = scale - 0.5 * (1 + kernel_size % 2)
72 |
73 | og = np.ogrid[:kernel_size, :kernel_size]
74 | kernel = (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - center) / scale)
75 |
76 | return torch.tensor(kernel, dtype=torch.float32)
77 |
78 |
79 | class XavierGluon(Initializer):
80 | def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3, **kwargs):
81 | super().__init__(**kwargs)
82 |
83 | self.rnd_type = rnd_type
84 | self.factor_type = factor_type
85 | self.magnitude = float(magnitude)
86 |
87 | def _init_weight(self, arr):
88 | fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr)
89 |
90 | if self.factor_type == 'avg':
91 | factor = (fan_in + fan_out) / 2.0
92 | elif self.factor_type == 'in':
93 | factor = fan_in
94 | elif self.factor_type == 'out':
95 | factor = fan_out
96 | else:
97 | raise ValueError('Incorrect factor type')
98 | scale = np.sqrt(self.magnitude / factor)
99 |
100 | if self.rnd_type == 'uniform':
101 | nn.init.uniform_(arr, -scale, scale)
102 | elif self.rnd_type == 'gaussian':
103 | nn.init.normal_(arr, 0, scale)
104 | else:
105 | raise ValueError('Unknown random type')
106 |
--------------------------------------------------------------------------------
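
The initializers are plain callables applied per module, e.g. via `Module.apply`. A short sketch with illustrative arguments:

```python
import torch.nn as nn
from iharm.model.initializer import XavierGluon

net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(inplace=True),
    nn.Conv2d(16, 3, kernel_size=1),
)
# Conv weights are re-drawn Xavier-style; BatchNorm gets gamma=1, beta=0.
net.apply(XavierGluon(rnd_type='gaussian', magnitude=2.0))
```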
/iharm/model/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from iharm.utils import misc
5 |
6 |
7 | class Loss(nn.Module):
8 | def __init__(self, pred_outputs, gt_outputs):
9 | super().__init__()
10 | self.pred_outputs = pred_outputs
11 | self.gt_outputs = gt_outputs
12 |
13 |
14 | class MSE(Loss):
15 | def __init__(self, pred_name='images', gt_image_name='target_images'):
16 | super(MSE, self).__init__(pred_outputs=(pred_name,), gt_outputs=(gt_image_name,))
17 |
18 | def forward(self, pred, label):
19 | label = label.view(pred.size())
20 | loss = torch.mean((pred - label) ** 2, dim=misc.get_dims_with_exclusion(label.dim(), 0))
21 | return loss
22 |
23 |
24 | class MaskWeightedMSE(Loss):
25 | def __init__(self, min_area=1000.0, pred_name='images',
26 | gt_image_name='target_images', gt_mask_name='masks'):
27 | super(MaskWeightedMSE, self).__init__(pred_outputs=(pred_name, ),
28 | gt_outputs=(gt_image_name, gt_mask_name))
29 | self.min_area = min_area
30 |
31 | def forward(self, pred, label, mask):
32 | label = label.view(pred.size())
33 | reduce_dims = misc.get_dims_with_exclusion(label.dim(), 0)
34 |
35 | loss = (pred - label) ** 2
36 |         delimiter = pred.size(1) * torch.clamp_min(torch.sum(mask, dim=reduce_dims), self.min_area)
37 |         loss = torch.sum(loss, dim=reduce_dims) / delimiter
38 |
39 | return loss
40 |
--------------------------------------------------------------------------------
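
A sketch of `MaskWeightedMSE` on toy tensors (repository assumed importable): the squared error is summed over the whole image but divided by `channels * max(foreground_area, min_area)`, so small composite regions are not drowned out by the background.

```python
import torch
from iharm.model.losses import MaskWeightedMSE

criterion = MaskWeightedMSE(min_area=100.0)   # min_area value is illustrative

pred = torch.rand(2, 3, 64, 64)
target = torch.rand(2, 3, 64, 64)
mask = torch.zeros(2, 1, 64, 64)
mask[:, :, 16:48, 16:48] = 1.0                # 32 * 32 = 1024 foreground pixels

per_sample = criterion(pred, target, mask)    # shape (2,), one value per sample
loss = per_sample.mean()
```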
/iharm/model/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | class TrainMetric(object):
6 | def __init__(self, pred_outputs, gt_outputs, epsilon=1e-6):
7 | self.pred_outputs = pred_outputs
8 | self.gt_outputs = gt_outputs
9 | self.epsilon = epsilon
10 | self._last_batch_metric = 0.0
11 | self._epoch_metric_sum = 0.0
12 | self._epoch_batch_count = 0
13 |
14 | def compute(self, *args, **kwargs):
15 | raise NotImplementedError
16 |
17 | def update(self, *args, **kwargs):
18 | self._last_batch_metric = self.compute(*args, **kwargs)
19 | self._epoch_metric_sum += self._last_batch_metric
20 | self._epoch_batch_count += 1
21 |
22 | def get_epoch_value(self):
23 | if self._epoch_batch_count > 0:
24 | return self._epoch_metric_sum / self._epoch_batch_count
25 | else:
26 | return 0.0
27 |
28 | def reset_epoch_stats(self):
29 | self._epoch_metric_sum = 0.0
30 | self._epoch_batch_count = 0
31 |
32 | def log_states(self, sw, tag_prefix, global_step):
33 | sw.add_scalar(tag=tag_prefix, value=self._last_batch_metric, global_step=global_step)
34 |
35 | @property
36 | def name(self):
37 | return type(self).__name__
38 |
39 |
40 | class PSNRMetric(TrainMetric):
41 | def __init__(self, pred_output='instances', gt_output='instances'):
42 | super(PSNRMetric, self).__init__((pred_output, ), (gt_output, ))
43 |
44 | def compute(self, pred, gt):
45 | mse = F.mse_loss(pred, gt)
46 | squared_max = gt.max() ** 2
47 | psnr = 10 * torch.log10(squared_max / (mse + self.epsilon))
48 | return psnr.item()
49 |
50 |
51 | class DenormalizedTrainMetric(TrainMetric):
52 | def __init__(self, pred_outputs, gt_outputs, mean=None, std=None):
53 | super(DenormalizedTrainMetric, self).__init__(pred_outputs, gt_outputs)
54 | self.mean = torch.zeros(1) if mean is None else mean
55 | self.std = torch.ones(1) if std is None else std
56 | self.device = None
57 |
58 | def init_device(self, input_device):
59 | if self.device is None:
60 | self.device = input_device
61 | self.mean = self.mean.to(self.device)
62 | self.std = self.std.to(self.device)
63 |
64 | def denormalize(self, tensor):
65 | self.init_device(tensor.device)
66 | return tensor * self.std + self.mean
67 |
68 | def update(self, *args, **kwargs):
69 | self._last_batch_metric = self.compute(*args, **kwargs)
70 | self._epoch_metric_sum += self._last_batch_metric
71 | self._epoch_batch_count += 1
72 |
73 |
74 | class DenormalizedPSNRMetric(DenormalizedTrainMetric):
75 | def __init__(
76 | self,
77 | pred_output='instances', gt_output='instances',
78 | mean=None, std=None,
79 | ):
80 | super(DenormalizedPSNRMetric, self).__init__((pred_output, ), (gt_output, ), mean, std)
81 |
82 | def compute(self, pred, gt):
83 | denormalized_pred = torch.clamp(self.denormalize(pred), 0, 1)
84 | denormalized_gt = self.denormalize(gt)
85 | return PSNRMetric.compute(self, denormalized_pred, denormalized_gt)
86 |
87 |
88 | class DenormalizedMSEMetric(DenormalizedTrainMetric):
89 | def __init__(
90 | self,
91 | pred_output='instances', gt_output='instances',
92 | mean=None, std=None,
93 | ):
94 | super(DenormalizedMSEMetric, self).__init__((pred_output, ), (gt_output, ), mean, std)
95 |
96 | def compute(self, pred, gt):
97 | denormalized_pred = self.denormalize(pred) * 255
98 | denormalized_gt = self.denormalize(gt) * 255
99 | return F.mse_loss(denormalized_pred, denormalized_gt).item()
100 |
--------------------------------------------------------------------------------
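
A sketch of the training metrics on toy tensors. `PSNRMetric` computes `10 * log10(max^2 / MSE)`; the denormalized variants first map tensors back with `x * std + mean` (the mean/std below are illustrative ImageNet-style values, not necessarily the repo's).

```python
import torch
from iharm.model.metrics import PSNRMetric, DenormalizedPSNRMetric

pred = torch.rand(2, 3, 64, 64)
gt = (pred + 0.01 * torch.randn_like(pred)).clamp(0, 1)

psnr = PSNRMetric(pred_output='images', gt_output='target_images')
psnr.update(pred, gt)
print(psnr.get_epoch_value())   # roughly 40 dB for noise with std 0.01

denorm_psnr = DenormalizedPSNRMetric(
    mean=torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1),
    std=torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1),
)
denorm_psnr.update(pred, gt)
```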
/iharm/model/modeling/basic_blocks.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numbers
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn as nn
7 |
8 |
9 | class ConvHead(nn.Module):
10 | def __init__(self, out_channels, in_channels=32, num_layers=1,
11 | kernel_size=3, padding=1,
12 | norm_layer=nn.BatchNorm2d):
13 | super(ConvHead, self).__init__()
14 | convhead = []
15 |
16 | for i in range(num_layers):
17 | convhead.extend([
18 | nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding),
19 | nn.ReLU(),
20 | norm_layer(in_channels) if norm_layer is not None else nn.Identity()
21 | ])
22 | convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0))
23 |
24 | self.convhead = nn.Sequential(*convhead)
25 |
26 | def forward(self, *inputs):
27 | return self.convhead(inputs[0])
28 |
29 |
30 | class SepConvHead(nn.Module):
31 | def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1,
32 | kernel_size=3, padding=1, dropout_ratio=0.0, dropout_indx=0,
33 | norm_layer=nn.BatchNorm2d):
34 | super(SepConvHead, self).__init__()
35 |
36 | sepconvhead = []
37 |
38 | for i in range(num_layers):
39 | sepconvhead.append(
40 | SeparableConv2d(in_channels=in_channels if i == 0 else mid_channels,
41 | out_channels=mid_channels,
42 | dw_kernel=kernel_size, dw_padding=padding,
43 | norm_layer=norm_layer, activation='relu')
44 | )
45 | if dropout_ratio > 0 and dropout_indx == i:
46 | sepconvhead.append(nn.Dropout(dropout_ratio))
47 |
48 | sepconvhead.append(
49 | nn.Conv2d(in_channels=mid_channels, out_channels=num_outputs, kernel_size=1, padding=0)
50 | )
51 |
52 | self.layers = nn.Sequential(*sepconvhead)
53 |
54 | def forward(self, *inputs):
55 | x = inputs[0]
56 |
57 | return self.layers(x)
58 |
59 |
60 | def select_activation_function(activation):
61 | if isinstance(activation, str):
62 | if activation.lower() == 'relu':
63 | return nn.ReLU
64 | elif activation.lower() == 'softplus':
65 | return nn.Softplus
66 | else:
67 | raise ValueError(f"Unknown activation type {activation}")
68 | elif isinstance(activation, nn.Module):
69 | return activation
70 | else:
71 | raise ValueError(f"Unknown activation type {activation}")
72 |
73 |
74 | class SeparableConv2d(nn.Module):
75 | def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, dw_stride=1,
76 | activation=None, use_bias=False, norm_layer=None):
77 | super(SeparableConv2d, self).__init__()
78 | _activation = select_activation_function(activation)
79 | self.body = nn.Sequential(
80 | nn.Conv2d(in_channels, in_channels, kernel_size=dw_kernel, stride=dw_stride,
81 | padding=dw_padding, bias=use_bias, groups=in_channels),
82 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias),
83 | norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
84 | _activation()
85 | )
86 |
87 | def forward(self, x):
88 | return self.body(x)
89 |
90 |
91 | class ConvBlock(nn.Module):
92 | def __init__(
93 | self,
94 | in_channels, out_channels,
95 | kernel_size=4, stride=2, padding=1,
96 | norm_layer=nn.BatchNorm2d, activation=nn.ELU,
97 | bias=True,
98 | ):
99 | super(ConvBlock, self).__init__()
100 | self.block = nn.Sequential(
101 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
102 | norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
103 | activation(),
104 | )
105 |
106 | def forward(self, x):
107 | return self.block(x)
108 |
109 |
110 | class GaussianSmoothing(nn.Module):
111 | """
112 | https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/10
113 |     Apply Gaussian smoothing to a 1d, 2d or 3d tensor.
114 |     Filtering is performed separately for each channel in the input using a depthwise convolution.
115 | Arguments:
116 | channels (int, sequence): Number of channels of the input tensors.
117 | Output will have this number of channels as well.
118 | kernel_size (int, sequence): Size of the gaussian kernel.
119 | sigma (float, sequence): Standard deviation of the gaussian kernel.
120 | dim (int, optional): The number of dimensions of the data. Default value is 2 (spatial).
121 | """
122 | def __init__(self, channels, kernel_size, sigma, padding=0, dim=2):
123 | super(GaussianSmoothing, self).__init__()
124 | if isinstance(kernel_size, numbers.Number):
125 | kernel_size = [kernel_size] * dim
126 | if isinstance(sigma, numbers.Number):
127 | sigma = [sigma] * dim
128 |
129 | # The gaussian kernel is the product of the gaussian function of each dimension.
130 | kernel = 1.
131 | meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
132 | for size, std, grid in zip(kernel_size, sigma, meshgrids):
133 | mean = (size - 1) / 2.
134 | kernel *= torch.exp(-((grid - mean) / std) ** 2 / 2) / (std * (2 * math.pi) ** 0.5)
135 | # Make sure sum of values in gaussian kernel equals 1.
136 | kernel = kernel / torch.sum(kernel)
137 | # Reshape to depthwise convolutional weight.
138 | kernel = kernel.view(1, 1, *kernel.size())
139 | kernel = torch.repeat_interleave(kernel, channels, 0)
140 |
141 | self.register_buffer('weight', kernel)
142 | self.groups = channels
143 | self.padding = padding
144 |
145 | if dim == 1:
146 | self.conv = F.conv1d
147 | elif dim == 2:
148 | self.conv = F.conv2d
149 | elif dim == 3:
150 | self.conv = F.conv3d
151 | else:
152 | raise RuntimeError('Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim))
153 |
154 | def forward(self, input):
155 | """
156 | Apply gaussian filter to input.
157 | Arguments:
158 | input (torch.Tensor): Input to apply gaussian filter on.
159 | Returns:
160 | filtered (torch.Tensor): Filtered output.
161 | """
162 | return self.conv(input, weight=self.weight, padding=self.padding, groups=self.groups)
163 |
164 |
165 | class MaxPoolDownSize(nn.Module):
166 | def __init__(self, in_channels, mid_channels, out_channels, depth):
167 | super(MaxPoolDownSize, self).__init__()
168 | self.depth = depth
169 | self.reduce_conv = ConvBlock(in_channels, mid_channels, kernel_size=1, stride=1, padding=0)
170 | self.convs = nn.ModuleList([
171 | ConvBlock(mid_channels, out_channels, kernel_size=3, stride=1, padding=1)
172 | for conv_i in range(depth)
173 | ])
174 | self.pool2d = nn.MaxPool2d(kernel_size=2)
175 |
176 | def forward(self, x):
177 | outputs = []
178 |
179 | output = self.reduce_conv(x)
180 |
181 | for conv_i, conv in enumerate(self.convs):
182 | output = output if conv_i == 0 else self.pool2d(output)
183 | outputs.append(conv(output))
184 |
185 | return outputs
186 |
--------------------------------------------------------------------------------
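
A sketch of `GaussianSmoothing` with the same configuration used by `SpatialSeparatedAttention` above (a 7x7 kernel, sigma 1, padding 3, so the output keeps its spatial size):

```python
import torch
from iharm.model.modeling.basic_blocks import GaussianSmoothing

blur = GaussianSmoothing(channels=1, kernel_size=7, sigma=1, padding=3)

mask = (torch.rand(1, 1, 64, 64) > 0.5).float()
soft_mask = blur(mask)           # hard 0/1 edges become smooth transitions
print(soft_mask.shape)           # torch.Size([1, 1, 64, 64])
```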
/iharm/model/modeling/conv_autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 |
4 | from iharm.model.modeling.basic_blocks import ConvBlock
5 | from iharm.model.ops import MaskedChannelAttention, FeaturesConnector
6 |
7 |
8 | class ConvEncoder(nn.Module):
9 | def __init__(
10 | self,
11 | depth, ch,
12 | norm_layer, batchnorm_from, max_channels,
13 | backbone_from, backbone_channels=None, backbone_mode=''
14 | ):
15 | super(ConvEncoder, self).__init__()
16 | self.depth = depth
17 | self.backbone_from = backbone_from
18 | backbone_channels = [] if backbone_channels is None else backbone_channels[::-1]
19 |
20 | in_channels = 4
21 | out_channels = ch
22 |
23 | self.block0 = ConvBlock(in_channels, out_channels, norm_layer=norm_layer if batchnorm_from == 0 else None)
24 | self.block1 = ConvBlock(out_channels, out_channels, norm_layer=norm_layer if 0 <= batchnorm_from <= 1 else None)
25 | self.blocks_channels = [out_channels, out_channels]
26 |
27 | self.blocks_connected = nn.ModuleDict()
28 | self.connectors = nn.ModuleDict()
29 | for block_i in range(2, depth):
30 | if block_i % 2:
31 | in_channels = out_channels
32 | else:
33 | in_channels, out_channels = out_channels, min(2 * out_channels, max_channels)
34 |
35 | if 0 <= backbone_from <= block_i and len(backbone_channels):
36 | stage_channels = backbone_channels.pop()
37 | connector = FeaturesConnector(backbone_mode, in_channels, stage_channels, in_channels)
38 | self.connectors[f'connector{block_i}'] = connector
39 | in_channels = connector.output_channels
40 |
41 | self.blocks_connected[f'block{block_i}'] = ConvBlock(
42 | in_channels, out_channels,
43 | norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None,
44 | padding=int(block_i < depth - 1)
45 | )
46 | self.blocks_channels += [out_channels]
47 |
48 | def forward(self, x, backbone_features):
49 | backbone_features = [] if backbone_features is None else backbone_features[::-1]
50 |
51 | outputs = [self.block0(x)]
52 | outputs += [self.block1(outputs[-1])]
53 |
54 | for block_i in range(2, self.depth):
55 | block = self.blocks_connected[f'block{block_i}']
56 | output = outputs[-1]
57 | connector_name = f'connector{block_i}'
58 | if connector_name in self.connectors:
59 | stage_features = backbone_features.pop()
60 | connector = self.connectors[connector_name]
61 | output = connector(output, stage_features)
62 | outputs += [block(output)]
63 |
64 | return outputs[::-1]
65 |
66 |
67 | class DeconvDecoder(nn.Module):
68 | def __init__(self, depth, encoder_blocks_channels, norm_layer, attend_from=-1, image_fusion=False):
69 | super(DeconvDecoder, self).__init__()
70 | self.image_fusion = image_fusion
71 | self.deconv_blocks = nn.ModuleList()
72 |
73 | in_channels = encoder_blocks_channels.pop()
74 | out_channels = in_channels
75 | for d in range(depth):
76 | out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2
77 | self.deconv_blocks.append(SEDeconvBlock(
78 | in_channels, out_channels,
79 | norm_layer=norm_layer,
80 | padding=0 if d == 0 else 1,
81 | with_se=0 <= attend_from <= d
82 | ))
83 | in_channels = out_channels
84 |
85 | if self.image_fusion:
86 | self.conv_attention = nn.Conv2d(out_channels, 1, kernel_size=1)
87 | self.to_rgb = nn.Conv2d(out_channels, 3, kernel_size=1)
88 |
89 | def forward(self, encoder_outputs, image, mask=None):
90 | output = encoder_outputs[0]
91 | for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]):
92 | output = block(output, mask)
93 | output = output + skip_output
94 | output = self.deconv_blocks[-1](output, mask)
95 |
96 | if self.image_fusion:
97 | attention_map = torch.sigmoid(3.0 * self.conv_attention(output))
98 | output = attention_map * image + (1.0 - attention_map) * self.to_rgb(output)
99 | else:
100 | output = self.to_rgb(output)
101 |
102 | return output
103 |
104 |
105 | class SEDeconvBlock(nn.Module):
106 | def __init__(
107 | self,
108 | in_channels, out_channels,
109 | kernel_size=4, stride=2, padding=1,
110 | norm_layer=nn.BatchNorm2d, activation=nn.ELU,
111 | with_se=False
112 | ):
113 | super(SEDeconvBlock, self).__init__()
114 | self.with_se = with_se
115 | self.block = nn.Sequential(
116 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding),
117 | norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
118 | activation(),
119 | )
120 | if self.with_se:
121 | self.se = MaskedChannelAttention(out_channels)
122 |
123 | def forward(self, x, mask=None):
124 | out = self.block(x)
125 | if self.with_se:
126 | out = self.se(out, mask)
127 | return out
128 |
--------------------------------------------------------------------------------
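
A sketch wiring `ConvEncoder` and `DeconvDecoder` together (illustrative depth and width, repository assumed importable). The encoder returns its feature maps deepest-first and the decoder adds each skip connection to the corresponding upsampled map; a copy of the channel list is passed because the decoder pops from it.

```python
import torch
from torch import nn
from iharm.model.modeling.conv_autoencoder import ConvEncoder, DeconvDecoder

encoder = ConvEncoder(depth=7, ch=64, norm_layer=nn.BatchNorm2d, batchnorm_from=0,
                      max_channels=512, backbone_from=-1)
decoder = DeconvDecoder(depth=7, encoder_blocks_channels=list(encoder.blocks_channels),
                        norm_layer=nn.BatchNorm2d, image_fusion=True)

image = torch.rand(2, 3, 256, 256)
mask = torch.rand(2, 1, 256, 256)
feats = encoder(torch.cat((image, mask), dim=1), None)   # deepest feature map first
out = decoder(feats, image, mask)
print(out.shape)                                         # torch.Size([2, 3, 256, 256])
```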
/iharm/model/modeling/deeplab_v3.py:
--------------------------------------------------------------------------------
1 | from contextlib import ExitStack
2 |
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 |
7 | from iharm.model.modeling.basic_blocks import select_activation_function
8 | from .basic_blocks import SeparableConv2d
9 | from .resnet import ResNetBackbone
10 |
11 |
12 | class DeepLabV3Plus(nn.Module):
13 | def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d,
14 | backbone_norm_layer=None,
15 | ch=256,
16 | project_dropout=0.5,
17 | inference_mode=False,
18 | **kwargs):
19 | super(DeepLabV3Plus, self).__init__()
20 | if backbone_norm_layer is None:
21 | backbone_norm_layer = norm_layer
22 |
23 | self.backbone_name = backbone
24 | self.norm_layer = norm_layer
25 | self.backbone_norm_layer = backbone_norm_layer
26 | self.inference_mode = False
27 | self.ch = ch
28 | self.aspp_in_channels = 2048
29 | self.skip_project_in_channels = 256 # layer 1 out_channels
30 |
31 | self._kwargs = kwargs
32 | if backbone == 'resnet34':
33 | self.aspp_in_channels = 512
34 | self.skip_project_in_channels = 64
35 |
36 | self.backbone = ResNetBackbone(backbone=self.backbone_name, pretrained_base=False,
37 | norm_layer=self.backbone_norm_layer, **kwargs)
38 |
39 | self.head = _DeepLabHead(in_channels=ch + 32, mid_channels=ch, out_channels=ch,
40 | norm_layer=self.norm_layer)
41 | self.skip_project = _SkipProject(self.skip_project_in_channels, 32, norm_layer=self.norm_layer)
42 | self.aspp = _ASPP(in_channels=self.aspp_in_channels,
43 | atrous_rates=[12, 24, 36],
44 | out_channels=ch,
45 | project_dropout=project_dropout,
46 | norm_layer=self.norm_layer)
47 |
48 | if inference_mode:
49 | self.set_prediction_mode()
50 |
51 | def load_pretrained_weights(self):
52 | pretrained = ResNetBackbone(backbone=self.backbone_name, pretrained_base=True,
53 | norm_layer=self.backbone_norm_layer, **self._kwargs)
54 | backbone_state_dict = self.backbone.state_dict()
55 | pretrained_state_dict = pretrained.state_dict()
56 |
57 | backbone_state_dict.update(pretrained_state_dict)
58 | self.backbone.load_state_dict(backbone_state_dict)
59 |
60 | if self.inference_mode:
61 | for param in self.backbone.parameters():
62 | param.requires_grad = False
63 |
64 | def set_prediction_mode(self):
65 | self.inference_mode = True
66 | self.eval()
67 |
68 | def forward(self, x, mask_features=None):
69 | with ExitStack() as stack:
70 | if self.inference_mode:
71 | stack.enter_context(torch.no_grad())
72 |
73 | c1, _, c3, c4 = self.backbone(x, mask_features)
74 | c1 = self.skip_project(c1)
75 |
76 | x = self.aspp(c4)
77 | x = F.interpolate(x, c1.size()[2:], mode='bilinear', align_corners=True)
78 | x = torch.cat((x, c1), dim=1)
79 | x = self.head(x)
80 |
81 | return x,
82 |
83 |
84 | class _SkipProject(nn.Module):
85 | def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d):
86 | super(_SkipProject, self).__init__()
87 | _activation = select_activation_function("relu")
88 |
89 | self.skip_project = nn.Sequential(
90 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
91 | norm_layer(out_channels),
92 | _activation()
93 | )
94 |
95 | def forward(self, x):
96 | return self.skip_project(x)
97 |
98 |
99 | class _DeepLabHead(nn.Module):
100 | def __init__(self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d):
101 | super(_DeepLabHead, self).__init__()
102 |
103 | self.block = nn.Sequential(
104 | SeparableConv2d(in_channels=in_channels, out_channels=mid_channels, dw_kernel=3,
105 | dw_padding=1, activation='relu', norm_layer=norm_layer),
106 | SeparableConv2d(in_channels=mid_channels, out_channels=mid_channels, dw_kernel=3,
107 | dw_padding=1, activation='relu', norm_layer=norm_layer),
108 | nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1)
109 | )
110 |
111 | def forward(self, x):
112 | return self.block(x)
113 |
114 |
115 | class _ASPP(nn.Module):
116 | def __init__(self, in_channels, atrous_rates, out_channels=256,
117 | project_dropout=0.5, norm_layer=nn.BatchNorm2d):
118 | super(_ASPP, self).__init__()
119 |
120 | b0 = nn.Sequential(
121 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False),
122 | norm_layer(out_channels),
123 | nn.ReLU()
124 | )
125 |
126 | rate1, rate2, rate3 = tuple(atrous_rates)
127 | b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer)
128 | b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer)
129 | b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer)
130 | b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer)
131 |
132 | self.concurent = nn.ModuleList([b0, b1, b2, b3, b4])
133 |
134 | project = [
135 | nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels,
136 | kernel_size=1, bias=False),
137 | norm_layer(out_channels),
138 | nn.ReLU()
139 | ]
140 | if project_dropout > 0:
141 | project.append(nn.Dropout(project_dropout))
142 | self.project = nn.Sequential(*project)
143 |
144 | def forward(self, x):
145 | x = torch.cat([block(x) for block in self.concurent], dim=1)
146 |
147 | return self.project(x)
148 |
149 |
150 | class _AsppPooling(nn.Module):
151 | def __init__(self, in_channels, out_channels, norm_layer):
152 | super(_AsppPooling, self).__init__()
153 |
154 | self.gap = nn.Sequential(
155 | nn.AdaptiveAvgPool2d((1, 1)),
156 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
157 | kernel_size=1, bias=False),
158 | norm_layer(out_channels),
159 | nn.ReLU()
160 | )
161 |
162 | def forward(self, x):
163 | pool = self.gap(x)
164 | return F.interpolate(pool, x.size()[2:], mode='bilinear', align_corners=True)
165 |
166 |
167 | def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer):
168 | block = nn.Sequential(
169 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
170 | kernel_size=3, padding=atrous_rate,
171 | dilation=atrous_rate, bias=False),
172 | norm_layer(out_channels),
173 | nn.ReLU()
174 | )
175 |
176 | return block
177 |
--------------------------------------------------------------------------------
/iharm/model/modeling/ocr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch._utils
4 | import torch.nn.functional as F
5 |
6 |
7 | class SpatialGather_Module(nn.Module):
8 | """
9 | Aggregate the context features according to the initial
10 | predicted probability distribution.
11 | Employ the soft-weighted method to aggregate the context.
12 | """
13 |
14 | def __init__(self, cls_num=0, scale=1):
15 | super(SpatialGather_Module, self).__init__()
16 | self.cls_num = cls_num
17 | self.scale = scale
18 |
19 | def forward(self, feats, probs):
20 | batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3)
21 | probs = probs.view(batch_size, c, -1)
22 | feats = feats.view(batch_size, feats.size(1), -1)
23 | feats = feats.permute(0, 2, 1) # batch x hw x c
24 | probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw
25 | ocr_context = torch.matmul(probs, feats) \
26 |             .permute(0, 2, 1).unsqueeze(3).contiguous()  # batch x c x k x 1
27 | return ocr_context
28 |
29 |
30 | class SpatialOCR_Module(nn.Module):
31 | """
32 | Implementation of the OCR module:
33 | We aggregate the global object representation to update the representation for each pixel.
34 | """
35 |
36 | def __init__(self,
37 | in_channels,
38 | key_channels,
39 | out_channels,
40 | scale=1,
41 | dropout=0.1,
42 | norm_layer=nn.BatchNorm2d,
43 | align_corners=True):
44 | super(SpatialOCR_Module, self).__init__()
45 | self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale,
46 | norm_layer, align_corners)
47 | _in_channels = 2 * in_channels
48 |
49 | self.conv_bn_dropout = nn.Sequential(
50 | nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False),
51 | nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)),
52 | nn.Dropout2d(dropout)
53 | )
54 |
55 | def forward(self, feats, proxy_feats):
56 | context = self.object_context_block(feats, proxy_feats)
57 |
58 | output = self.conv_bn_dropout(torch.cat([context, feats], 1))
59 |
60 | return output
61 |
62 |
63 | class ObjectAttentionBlock2D(nn.Module):
64 | '''
65 | The basic implementation for object context block
66 | Input:
67 | N X C X H X W
68 | Parameters:
69 | in_channels : the dimension of the input feature map
70 | key_channels : the dimension after the key/query transform
71 |         scale : downsampling factor applied to the input feature maps (to save memory)
72 |         norm_layer : the normalization layer used inside the block
73 | Return:
74 | N X C X H X W
75 | '''
76 |
77 | def __init__(self,
78 | in_channels,
79 | key_channels,
80 | scale=1,
81 | norm_layer=nn.BatchNorm2d,
82 | align_corners=True):
83 | super(ObjectAttentionBlock2D, self).__init__()
84 | self.scale = scale
85 | self.in_channels = in_channels
86 | self.key_channels = key_channels
87 | self.align_corners = align_corners
88 |
89 | self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
90 | self.f_pixel = nn.Sequential(
91 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
92 | kernel_size=1, stride=1, padding=0, bias=False),
93 | nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
94 | nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
95 | kernel_size=1, stride=1, padding=0, bias=False),
96 | nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
97 | )
98 | self.f_object = nn.Sequential(
99 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
100 | kernel_size=1, stride=1, padding=0, bias=False),
101 | nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
102 | nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
103 | kernel_size=1, stride=1, padding=0, bias=False),
104 | nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
105 | )
106 | self.f_down = nn.Sequential(
107 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
108 | kernel_size=1, stride=1, padding=0, bias=False),
109 | nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
110 | )
111 | self.f_up = nn.Sequential(
112 | nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels,
113 | kernel_size=1, stride=1, padding=0, bias=False),
114 | nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True))
115 | )
116 |
117 | def forward(self, x, proxy):
118 | batch_size, h, w = x.size(0), x.size(2), x.size(3)
119 | if self.scale > 1:
120 | x = self.pool(x)
121 |
122 | query = self.f_pixel(x).view(batch_size, self.key_channels, -1)
123 | query = query.permute(0, 2, 1)
124 | key = self.f_object(proxy).view(batch_size, self.key_channels, -1)
125 | value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
126 | value = value.permute(0, 2, 1)
127 |
128 | sim_map = torch.matmul(query, key)
129 | sim_map = (self.key_channels ** -.5) * sim_map
130 | sim_map = F.softmax(sim_map, dim=-1)
131 |
132 | # add bg context ...
133 | context = torch.matmul(sim_map, value)
134 | context = context.permute(0, 2, 1).contiguous()
135 | context = context.view(batch_size, self.key_channels, *x.size()[2:])
136 | context = self.f_up(context)
137 | if self.scale > 1:
138 | context = F.interpolate(input=context, size=(h, w),
139 | mode='bilinear', align_corners=self.align_corners)
140 |
141 | return context
142 |
--------------------------------------------------------------------------------
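
A shape-level sketch of `SpatialGather_Module` (repository assumed importable): it turns coarse per-class score maps into one soft-weighted feature descriptor per class.

```python
import torch
from iharm.model.modeling.ocr import SpatialGather_Module

gather = SpatialGather_Module(scale=1)

feats = torch.rand(2, 48, 32, 32)    # pixel features: B x C x H x W
probs = torch.rand(2, 5, 32, 32)     # coarse per-class scores: B x K x H x W

context = gather(feats, probs)       # one C-dim descriptor per class
print(context.shape)                 # torch.Size([2, 48, 5, 1])
```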
/iharm/model/modeling/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s
3 |
4 |
5 | class ResNetBackbone(torch.nn.Module):
6 | def __init__(self, backbone='resnet50', pretrained_base=True, dilated=True, **kwargs):
7 | super(ResNetBackbone, self).__init__()
8 |
9 | if backbone == 'resnet34':
10 | pretrained = resnet34_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs)
11 | elif backbone == 'resnet50':
12 | pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
13 | elif backbone == 'resnet101':
14 | pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
15 | elif backbone == 'resnet152':
16 | pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
17 | else:
18 | raise RuntimeError(f'unknown backbone: {backbone}')
19 |
20 | self.conv1 = pretrained.conv1
21 | self.bn1 = pretrained.bn1
22 | self.relu = pretrained.relu
23 | self.maxpool = pretrained.maxpool
24 | self.layer1 = pretrained.layer1
25 | self.layer2 = pretrained.layer2
26 | self.layer3 = pretrained.layer3
27 | self.layer4 = pretrained.layer4
28 |
29 | def forward(self, x, mask_features=None):
30 | x = self.conv1(x)
31 | x = self.bn1(x)
32 | x = self.relu(x)
33 | if mask_features is not None:
34 | x = x + mask_features
35 | x = self.maxpool(x)
36 | c1 = self.layer1(x)
37 | c2 = self.layer2(c1)
38 | c3 = self.layer3(c2)
39 | c4 = self.layer4(c3)
40 |
41 | return c1, c2, c3, c4
42 |
--------------------------------------------------------------------------------
/iharm/model/modeling/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | from functools import partial
4 |
5 | from iharm.model.modeling.basic_blocks import ConvBlock
6 | from iharm.model.ops import FeaturesConnector
7 |
8 |
9 | class UNetEncoder(nn.Module):
10 | def __init__(
11 | self,
12 | depth, ch,
13 | norm_layer, batchnorm_from, max_channels,
14 | backbone_from, backbone_channels=None, backbone_mode=''
15 | ):
16 | super(UNetEncoder, self).__init__()
17 | self.depth = depth
18 | self.backbone_from = backbone_from
19 | self.block_channels = []
20 | backbone_channels = [] if backbone_channels is None else backbone_channels[::-1]
21 | relu = partial(nn.ReLU, inplace=True)
22 |
23 | in_channels = 4
24 | out_channels = ch
25 |
26 | self.block0 = UNetDownBlock(
27 | in_channels, out_channels,
28 | norm_layer=norm_layer if batchnorm_from == 0 else None,
29 | activation=relu,
30 | pool=True, padding=1,
31 | )
32 | self.block_channels.append(out_channels)
33 | in_channels, out_channels = out_channels, min(2 * out_channels, max_channels)
34 | self.block1 = UNetDownBlock(
35 | in_channels, out_channels,
36 | norm_layer=norm_layer if 0 <= batchnorm_from <= 1 else None,
37 | activation=relu,
38 | pool=True, padding=1,
39 | )
40 | self.block_channels.append(out_channels)
41 |
42 | self.blocks_connected = nn.ModuleDict()
43 | self.connectors = nn.ModuleDict()
44 | for block_i in range(2, depth):
45 | in_channels, out_channels = out_channels, min(2 * out_channels, max_channels)
46 | if 0 <= backbone_from <= block_i and len(backbone_channels):
47 | stage_channels = backbone_channels.pop()
48 | connector = FeaturesConnector(backbone_mode, in_channels, stage_channels, in_channels)
49 | self.connectors[f'connector{block_i}'] = connector
50 | in_channels = connector.output_channels
51 | self.blocks_connected[f'block{block_i}'] = UNetDownBlock(
52 | in_channels, out_channels,
53 | norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None,
54 | activation=relu, padding=1,
55 | pool=block_i < depth - 1,
56 | )
57 | self.block_channels.append(out_channels)
58 |
59 | def forward(self, x, backbone_features):
60 | backbone_features = [] if backbone_features is None else backbone_features[::-1]
61 | outputs = []
62 |
63 | block_input = x
64 | output, block_input = self.block0(block_input)
65 | outputs.append(output)
66 | output, block_input = self.block1(block_input)
67 | outputs.append(output)
68 |
69 | for block_i in range(2, self.depth):
70 | block = self.blocks_connected[f'block{block_i}']
71 | connector_name = f'connector{block_i}'
72 | if connector_name in self.connectors:
73 | stage_features = backbone_features.pop()
74 | connector = self.connectors[connector_name]
75 | block_input = connector(block_input, stage_features)
76 | output, block_input = block(block_input)
77 | outputs.append(output)
78 |
79 | return outputs[::-1]
80 |
81 |
82 | class UNetDecoder(nn.Module):
83 | def __init__(self, depth, encoder_blocks_channels, norm_layer,
84 | attention_layer=None, attend_from=3, image_fusion=False):
85 | super(UNetDecoder, self).__init__()
86 | self.up_blocks = nn.ModuleList()
87 | self.image_fusion = image_fusion
88 | in_channels = encoder_blocks_channels.pop()
89 | out_channels = in_channels
90 |         # The last encoder block doesn't pool, so there are only (depth - 1) up blocks
91 | for d in range(depth - 1):
92 | out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2
93 | stage_attention_layer = attention_layer if 0 <= attend_from <= d else None
94 | self.up_blocks.append(UNetUpBlock(
95 | in_channels, out_channels, out_channels,
96 | norm_layer=norm_layer, activation=partial(nn.ReLU, inplace=True),
97 | padding=1,
98 | attention_layer=stage_attention_layer,
99 | ))
100 | in_channels = out_channels
101 |
102 | if self.image_fusion:
103 | self.conv_attention = nn.Conv2d(out_channels, 1, kernel_size=1)
104 | self.to_rgb = nn.Conv2d(out_channels, 3, kernel_size=1)
105 |
106 | def forward(self, encoder_outputs, input_image, mask):
107 | output = encoder_outputs[0]
108 | for block, skip_output in zip(self.up_blocks, encoder_outputs[1:]):
109 | output = block(output, skip_output, mask)
110 |
111 | if self.image_fusion:
112 | attention_map = torch.sigmoid(3.0 * self.conv_attention(output))
113 | output = attention_map * input_image + (1.0 - attention_map) * self.to_rgb(output)
114 | else:
115 | output = self.to_rgb(output)
116 |
117 | return output
118 |
119 |
120 | class UNetDownBlock(nn.Module):
121 | def __init__(self, in_channels, out_channels, norm_layer, activation, pool, padding):
122 | super(UNetDownBlock, self).__init__()
123 | self.convs = UNetDoubleConv(
124 | in_channels, out_channels,
125 | norm_layer=norm_layer, activation=activation, padding=padding,
126 | )
127 | self.pooling = nn.MaxPool2d(2, 2) if pool else nn.Identity()
128 |
129 | def forward(self, x):
130 | conv_x = self.convs(x)
131 | return conv_x, self.pooling(conv_x)
132 |
133 |
134 | class UNetUpBlock(nn.Module):
135 | def __init__(
136 | self,
137 | in_channels_decoder, in_channels_encoder, out_channels,
138 | norm_layer, activation, padding,
139 | attention_layer,
140 | ):
141 | super(UNetUpBlock, self).__init__()
142 | self.upconv = nn.Sequential(
143 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
144 | ConvBlock(
145 | in_channels_decoder, out_channels,
146 | kernel_size=3, stride=1, padding=1,
147 | norm_layer=None, activation=activation,
148 | )
149 | )
150 | self.convs = UNetDoubleConv(
151 | in_channels_encoder + out_channels, out_channels,
152 | norm_layer=norm_layer, activation=activation, padding=padding,
153 | )
154 | if attention_layer is not None:
155 | self.attention = attention_layer(in_channels_encoder + out_channels, norm_layer, activation)
156 | else:
157 | self.attention = None
158 |
159 | def forward(self, x, encoder_out, mask=None):
160 | upsample_x = self.upconv(x)
161 | x_cat_encoder = torch.cat([encoder_out, upsample_x], dim=1)
162 | if self.attention is not None:
163 | x_cat_encoder = self.attention(x_cat_encoder, mask)
164 | return self.convs(x_cat_encoder)
165 |
166 |
167 | class UNetDoubleConv(nn.Module):
168 | def __init__(self, in_channels, out_channels, norm_layer, activation, padding):
169 | super(UNetDoubleConv, self).__init__()
170 | self.block = nn.Sequential(
171 | ConvBlock(
172 | in_channels, out_channels,
173 | kernel_size=3, stride=1, padding=padding,
174 | norm_layer=norm_layer, activation=activation,
175 | ),
176 | ConvBlock(
177 | out_channels, out_channels,
178 | kernel_size=3, stride=1, padding=padding,
179 | norm_layer=norm_layer, activation=activation,
180 | ),
181 | )
182 |
183 | def forward(self, x):
184 | return self.block(x)
185 |
--------------------------------------------------------------------------------
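
A sketch wiring this U-Net encoder/decoder pair (illustrative depth and width). The encoder records `block_channels` so the decoder can size its up blocks; a copy is passed because the decoder pops from the list, and `image_fusion=True` is used so `to_rgb` is created.

```python
import torch
from torch import nn
from iharm.model.modeling.unet import UNetEncoder, UNetDecoder

encoder = UNetEncoder(depth=4, ch=32, norm_layer=nn.BatchNorm2d, batchnorm_from=2,
                      max_channels=512, backbone_from=-1)
decoder = UNetDecoder(depth=4, encoder_blocks_channels=list(encoder.block_channels),
                      norm_layer=nn.BatchNorm2d, image_fusion=True)

image = torch.rand(2, 3, 256, 256)
mask = torch.rand(2, 1, 256, 256)
feats = encoder(torch.cat((image, mask), dim=1), None)   # deepest first
out = decoder(feats, image, mask)
print(out.shape)                                         # torch.Size([2, 3, 256, 256])
```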
/iharm/model/modeling/unet_v1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | from functools import partial
4 |
5 | from iharm.model.modeling.basic_blocks import ConvBlock,GaussianSmoothing
6 | from iharm.model.ops import FeaturesConnector
7 |
8 |
9 | class UNetEncoder(nn.Module):
10 | def __init__(
11 | self,
12 | depth, ch,
13 | norm_layer, batchnorm_from, max_channels,
14 |         backbone_from, backbone_channels=None, backbone_mode='', _in_channels=4
15 | ):
16 | super(UNetEncoder, self).__init__()
17 | self.depth = depth
18 | self.backbone_from = backbone_from
19 | self.block_channels = []
20 | backbone_channels = [] if backbone_channels is None else backbone_channels[::-1]
21 | relu = partial(nn.ReLU, inplace=True)
22 |
23 | in_channels = _in_channels
24 | out_channels = ch
25 |
26 | # print(in_channels,out_channels)
27 |
28 | self.block0 = UNetDownBlock(
29 | in_channels, out_channels,
30 | norm_layer=norm_layer if batchnorm_from == 0 else None,
31 | activation=relu,
32 | pool=True, padding=1,
33 | )
34 | self.block_channels.append(out_channels)
35 | in_channels, out_channels = out_channels, min(2 * out_channels, max_channels)
36 | self.block1 = UNetDownBlock(
37 | in_channels, out_channels,
38 | norm_layer=norm_layer if 0 <= batchnorm_from <= 1 else None,
39 | activation=relu,
40 | pool=True, padding=1,
41 | )
42 | self.block_channels.append(out_channels)
43 |
44 | self.blocks_connected = nn.ModuleDict()
45 | self.connectors = nn.ModuleDict()
46 | for block_i in range(2, depth):
47 | in_channels, out_channels = out_channels, min(2 * out_channels, max_channels)
48 | if 0 <= backbone_from <= block_i and len(backbone_channels):
49 | stage_channels = backbone_channels.pop()
50 | connector = FeaturesConnector(backbone_mode, in_channels, stage_channels, in_channels)
51 | self.connectors[f'connector{block_i}'] = connector
52 | in_channels = connector.output_channels
53 | self.blocks_connected[f'block{block_i}'] = UNetDownBlock(
54 | in_channels, out_channels,
55 | norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None,
56 | activation=relu, padding=1,
57 | pool=block_i < depth - 1,
58 | )
59 | self.block_channels.append(out_channels)
60 |
61 | def forward(self, x, backbone_features):
62 | backbone_features = [] if backbone_features is None else backbone_features[::-1]
63 | outputs = []
64 |
65 | block_input = x
66 | output, block_input = self.block0(block_input)
67 | outputs.append(output)
68 | output, block_input = self.block1(block_input)
69 | outputs.append(output)
70 |
71 | for block_i in range(2, self.depth):
72 | block = self.blocks_connected[f'block{block_i}']
73 | connector_name = f'connector{block_i}'
74 | if connector_name in self.connectors:
75 | stage_features = backbone_features.pop()
76 | connector = self.connectors[connector_name]
77 | block_input = connector(block_input, stage_features)
78 | output, block_input = block(block_input)
79 | outputs.append(output)
80 |
81 | return outputs[::-1]
82 |
83 |
84 | class UNetDecoder(nn.Module):
85 | def __init__(self, depth, encoder_blocks_channels, norm_layer,
86 |                  attention_layer=None, attend_from=3, image_fusion=False, control_module_start=-1,
87 |                  control_module_layer=None,
88 |                  w_dim=256):
89 | super(UNetDecoder, self).__init__()
90 | self.up_blocks = nn.ModuleList()
91 | self.image_fusion = image_fusion
92 | self.w_dim = w_dim
93 | in_channels = encoder_blocks_channels.pop()
94 | out_channels = in_channels
 95 |         # The last encoder layer doesn't pool, so there are only (depth - 1) deconvs
96 | for d in range(depth - 1):
97 | out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2
98 | use_control_module = True if control_module_start == -1 else d >= control_module_start
99 | stage_attention_layer = attention_layer if 0 <= attend_from <= d else None
100 | self.up_blocks.append(UNetUpBlock(
101 | in_channels, out_channels, out_channels,
102 | norm_layer=norm_layer, activation=partial(nn.ReLU, inplace=True),
103 | padding=1,
104 | attention_layer=stage_attention_layer,
105 | use_control_module = use_control_module,
106 | w_dim = self.w_dim,
107 | control_module_layer = control_module_layer,
108 | ))
109 | in_channels = out_channels
110 |
111 | if self.image_fusion:
112 | self.conv_attention = nn.Conv2d(out_channels, 1, kernel_size=1)
113 | self.to_rgb = nn.Conv2d(out_channels, 3, kernel_size=1)
114 |
115 | def forward(self, encoder_outputs, input_image, mask, ws):
116 | output = encoder_outputs[0]
117 | count = 0
118 | for block, skip_output in zip(self.up_blocks, encoder_outputs[1:]):
119 | output = block(output, skip_output, mask,ws)
120 | count += 1
121 |
122 | if self.image_fusion:
123 | attention_map = torch.sigmoid(3.0 * self.conv_attention(output))
124 | output = attention_map * input_image + (1.0 - attention_map) * self.to_rgb(output)
125 | else:
126 | output = self.to_rgb(output)
127 |
128 | return output
129 |
130 |
131 | class UNetDownBlock(nn.Module):
132 | def __init__(self, in_channels, out_channels, norm_layer, activation, pool, padding):
133 | super(UNetDownBlock, self).__init__()
134 | self.convs = UNetDoubleConv(
135 | in_channels, out_channels,
136 | norm_layer=norm_layer, activation=activation, padding=padding,
137 | )
138 | self.pooling = nn.MaxPool2d(2, 2) if pool else nn.Identity()
139 |
140 | def forward(self, x):
141 | conv_x = self.convs(x)
142 | return conv_x, self.pooling(conv_x)
143 |
144 |
145 | class UNetUpBlock(nn.Module):
146 | def __init__(
147 | self,
148 | in_channels_decoder, in_channels_encoder, out_channels,
149 | norm_layer, activation, padding,
150 | attention_layer,
151 | use_control_module,
152 | w_dim,
153 | control_module_layer,
154 | ):
155 | super(UNetUpBlock, self).__init__()
156 | self.upconv = nn.Sequential(
157 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
158 | ConvBlock(
159 | in_channels_decoder, out_channels,
160 | kernel_size=3, stride=1, padding=1,
161 | norm_layer=None, activation=activation,
162 | )
163 | )
164 | self.convs = UNetDoubleConv(
165 | in_channels_encoder + out_channels, out_channels,
166 | norm_layer=norm_layer, activation=activation, padding=padding,
167 | )
168 | if attention_layer is not None:
169 | self.attention = attention_layer(in_channels_encoder + out_channels, norm_layer, activation)
170 | else:
171 | self.attention = None
172 |
173 | self.w_dim = w_dim
174 | self.use_control_module = use_control_module
175 | self.out_channel = out_channels
176 |
177 | if self.use_control_module:
178 | self.control_module = control_module_layer(
179 | w_dim = self.w_dim,
180 | feature_dim = out_channels,
181 | )
182 |
183 |
184 | def forward(self, x, encoder_out, mask, ws):
185 | upsample_x = self.upconv(x)
186 |
187 | _mask = nn.functional.interpolate(
188 | mask, size=encoder_out.size()[-2:],
189 | mode='bilinear', align_corners=True
190 | )
191 |
192 | x_cat_encoder = torch.cat([encoder_out, upsample_x], dim=1)
193 | if self.attention is not None:
194 | x_cat_encoder = self.attention(x_cat_encoder, mask)
195 | dec_out = self.convs(x_cat_encoder)
196 | if self.use_control_module:
197 | wl, wa, wb = ws['w_l'], ws['w_a'], ws['w_b']
198 | dec_out = self.control_module(dec_out, wl, wa, wb, _mask)
199 | return dec_out
200 |
201 |
202 | class UNetDoubleConv(nn.Module):
203 | def __init__(self, in_channels, out_channels, norm_layer, activation, padding):
204 | super(UNetDoubleConv, self).__init__()
205 | self.block = nn.Sequential(
206 | ConvBlock(
207 | in_channels, out_channels,
208 | kernel_size=3, stride=1, padding=padding,
209 | norm_layer=norm_layer, activation=activation,
210 | ),
211 | ConvBlock(
212 | out_channels, out_channels,
213 | kernel_size=3, stride=1, padding=padding,
214 | norm_layer=norm_layer, activation=activation,
215 | ),
216 | )
217 |
218 | def forward(self, x):
219 | return self.block(x)
220 |
--------------------------------------------------------------------------------
/iharm/model/modifiers.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | class LRMult(object):
4 | def __init__(self, lr_mult=1.):
5 | self.lr_mult = lr_mult
6 |
7 | def __call__(self, m):
8 | if getattr(m, 'weight', None) is not None:
9 | m.weight.lr_mult = self.lr_mult
10 | if getattr(m, 'bias', None) is not None:
11 | m.bias.lr_mult = self.lr_mult
12 |
--------------------------------------------------------------------------------
/iharm/model/ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 |
4 |
5 | class SimpleInputFusion(nn.Module):
6 | def __init__(self, add_ch=1, rgb_ch=3, ch=8, norm_layer=nn.BatchNorm2d):
7 | super(SimpleInputFusion, self).__init__()
8 |
9 | self.fusion_conv = nn.Sequential(
10 | nn.Conv2d(in_channels=add_ch + rgb_ch, out_channels=ch, kernel_size=1),
11 | nn.LeakyReLU(negative_slope=0.2),
12 | norm_layer(ch),
13 | nn.Conv2d(in_channels=ch, out_channels=rgb_ch, kernel_size=1),
14 | )
15 |
16 | def forward(self, image, additional_input):
17 | return self.fusion_conv(torch.cat((image, additional_input), dim=1))
18 |
19 |
20 | class ChannelAttention(nn.Module):
21 | def __init__(self, in_channels):
22 | super(ChannelAttention, self).__init__()
23 | self.global_pools = nn.ModuleList([
24 | nn.AdaptiveAvgPool2d(1),
25 | nn.AdaptiveMaxPool2d(1),
26 | ])
27 | intermediate_channels_count = max(in_channels // 16, 8)
28 | self.attention_transform = nn.Sequential(
29 | nn.Linear(len(self.global_pools) * in_channels, intermediate_channels_count),
30 | nn.ReLU(),
31 | nn.Linear(intermediate_channels_count, in_channels),
32 | nn.Sigmoid(),
33 | )
34 |
35 | def forward(self, x):
36 | pooled_x = []
37 | for global_pool in self.global_pools:
38 | pooled_x.append(global_pool(x))
39 | pooled_x = torch.cat(pooled_x, dim=1).flatten(start_dim=1)
40 | channel_attention_weights = self.attention_transform(pooled_x)[..., None, None]
41 | return channel_attention_weights * x
42 |
43 |
44 | class MaskedChannelAttention(nn.Module):
45 | def __init__(self, in_channels, *args, **kwargs):
46 | super(MaskedChannelAttention, self).__init__()
47 | self.global_max_pool = MaskedGlobalMaxPool2d()
48 | self.global_avg_pool = FastGlobalAvgPool2d()
49 |
50 | intermediate_channels_count = max(in_channels // 16, 8)
51 | self.attention_transform = nn.Sequential(
52 | nn.Linear(3 * in_channels, intermediate_channels_count),
53 | nn.ReLU(inplace=True),
54 | nn.Linear(intermediate_channels_count, in_channels),
55 | nn.Sigmoid(),
56 | )
57 |
58 | def forward(self, x, mask):
 59 |         if mask.shape[2:] != x.shape[2:]:
60 | mask = nn.functional.interpolate(
61 | mask, size=x.size()[-2:],
62 | mode='bilinear', align_corners=True
63 | )
64 | pooled_x = torch.cat([
65 | self.global_max_pool(x, mask),
66 | self.global_avg_pool(x)
67 | ], dim=1)
68 | channel_attention_weights = self.attention_transform(pooled_x)[..., None, None]
69 |
70 | return channel_attention_weights * x
71 |
72 |
73 | class MaskedGlobalMaxPool2d(nn.Module):
74 | def __init__(self):
75 | super().__init__()
76 | self.global_max_pool = FastGlobalMaxPool2d()
77 |
78 | def forward(self, x, mask):
79 | return torch.cat((
80 | self.global_max_pool(x * mask),
81 | self.global_max_pool(x * (1.0 - mask))
82 | ), dim=1)
83 |
84 |
85 | class FastGlobalAvgPool2d(nn.Module):
86 | def __init__(self):
87 | super(FastGlobalAvgPool2d, self).__init__()
88 |
89 | def forward(self, x):
90 | in_size = x.size()
91 | return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
92 |
93 |
94 | class FastGlobalMaxPool2d(nn.Module):
95 | def __init__(self):
96 | super(FastGlobalMaxPool2d, self).__init__()
97 |
98 | def forward(self, x):
99 | in_size = x.size()
100 | return x.view((in_size[0], in_size[1], -1)).max(dim=2)[0]
101 |
102 |
103 | class ScaleLayer(nn.Module):
104 | def __init__(self, init_value=1.0, lr_mult=1):
105 | super().__init__()
106 | self.lr_mult = lr_mult
107 | self.scale = nn.Parameter(
108 | torch.full((1,), init_value / lr_mult, dtype=torch.float32)
109 | )
110 |
111 | def forward(self, x):
112 | scale = torch.abs(self.scale * self.lr_mult)
113 | return x * scale
114 |
115 |
116 | class FeaturesConnector(nn.Module):
117 | def __init__(self, mode, in_channels, feature_channels, out_channels):
118 | super(FeaturesConnector, self).__init__()
119 | self.mode = mode if feature_channels else ''
120 |
121 | if self.mode == 'catc':
122 | self.reduce_conv = nn.Conv2d(in_channels + feature_channels, out_channels, kernel_size=1)
123 | elif self.mode == 'sum':
124 | self.reduce_conv = nn.Conv2d(feature_channels, out_channels, kernel_size=1)
125 |
126 | self.output_channels = out_channels if self.mode != 'cat' else in_channels + feature_channels
127 |
128 | def forward(self, x, features):
129 | if self.mode == 'cat':
130 | return torch.cat((x, features), 1)
131 | if self.mode == 'catc':
132 | return self.reduce_conv(torch.cat((x, features), 1))
133 | if self.mode == 'sum':
134 | return self.reduce_conv(features) + x
135 | return x
136 |
137 | def extra_repr(self):
138 | return self.mode
139 |
--------------------------------------------------------------------------------
/iharm/model/syncbn/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Tamaki Kojima
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/iharm/model/syncbn/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-syncbn
2 |
3 | Tamaki Kojima(tamakoji@gmail.com)
4 |
5 | ## Announcement
6 |
7 | **Pytorch 1.0 support**
8 |
9 | ## Overview
10 | This is an alternative implementation of "Synchronized Multi-GPU Batch Normalization" which computes global statistics across GPUs instead of using locally computed ones. SyncBN becomes important when the input images are large and multiple GPUs must be used to obtain a reasonable minibatch size for training.
11 |
12 | The code was inspired by [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding) and [Inplace-ABN](https://github.com/mapillary/inplace_abn)
13 |
14 | ## Remarks
15 | - Unlike [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding), you don't need a custom `nn.DataParallel`
16 | - Unlike [Inplace-ABN](https://github.com/mapillary/inplace_abn), you can simply replace your `nn.BatchNorm2d` with this module, since it does not perform in-place operations
17 | - You can plug it into any module written in PyTorch to enable synchronized BatchNorm
18 | - Backward computation is rewritten and tested against behavior of `nn.BatchNorm2d`
19 |
20 | ## Requirements
21 | For PyTorch, please refer to https://pytorch.org/
22 |
23 | NOTE: The code has been tested only with PyTorch v1.0.0, CUDA 10 / cuDNN 7.4.2 on Ubuntu 18.04
24 |
25 | It utilizes the PyTorch JIT mechanism to compile seamlessly, using ninja. Please install ninja-build before use.
26 |
27 | ```
28 | sudo apt-get install ninja-build
29 | ```
30 |
31 | Also install all Python dependencies. For pip, run:
32 |
33 |
34 | ```
35 | pip install -U -r requirements.txt
36 | ```
37 |
38 | ## Build
39 |
40 | There is no need to build ahead of time: just run, and the JIT will take care of compilation.
41 | JIT and cpp extensions are supported since PyTorch 0.4; however, it is highly recommended to use PyTorch >= 1.0 due to major design changes.
42 |
43 | ## Usage
44 |
45 | Please refer to [`test.py`](./test.py) for testing the difference between `nn.BatchNorm2d` and `modules.nn.BatchNorm2d`
46 |
47 | ```
48 | import torch
49 | import torch.nn as nn
50 | from modules import nn as NN
51 |
52 | num_gpu = torch.cuda.device_count()
53 | model = nn.Sequential(
54 |     nn.Conv2d(3, 3, 1, 1, bias=False),
55 |     NN.BatchNorm2d(3),
56 |     nn.ReLU(inplace=True),
57 |     nn.Conv2d(3, 3, 1, 1, bias=False),
58 |     NN.BatchNorm2d(3),
59 | ).cuda()
60 | model = nn.DataParallel(model, device_ids=range(num_gpu))
61 | x = torch.rand(num_gpu, 3, 2, 2).cuda()
62 | z = model(x)
63 | ```
64 |
65 | ## Math
66 |
67 | ### Forward
68 | 1. Compute the per-channel sums $\sum_i x_i$ and $\sum_i x_i^2$ on each GPU.
69 | 2. Gather all local sums from the workers to the master and compute the global statistics
70 |
71 |    $\mu = \frac{1}{N}\sum_i x_i$
72 |
73 |    and
74 |
75 |    $\sigma^2 = \frac{1}{N}\sum_i x_i^2 - \mu^2,$
76 |
77 |    where $N$ is the total number of elements per channel across all GPUs. The global statistics are then shared with all GPUs, and `running_mean` / `running_var` are updated by a moving average using them.
78 |
79 | 3. Forward batchnorm using the global statistics:
80 |
81 |    $\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$
82 |
83 |    and then
84 |
85 |    $y_i = \gamma \hat{x}_i + \beta,$
86 |
87 |    where $\gamma$ is the weight parameter and $\beta$ is the bias parameter.
88 |
89 | 4. Save $x$, $\mu$ and $\sigma^2$ for the backward pass.
90 |
91 | ### Backward
92 |
93 | 1. Restore the saved $x$, $\mu$ and $\sigma^2$.
94 |
95 | 2. Compute the following sums on each GPU:
96 |
97 |    $\sum_{j}\frac{dJ}{dy_j}$
98 |
99 |    and
100 |
101 |    $\sum_{j}\frac{dJ}{dy_j}\hat{x}_j,$
102 |
103 |    where $\hat{x}_j = \frac{x_j - \mu}{\sqrt{\sigma^2 + \epsilon}}$,
104 |
105 |    then gather them at the master node, sum them up globally, and normalize by $N$, the total number of elements per channel. The global sums are then shared among all GPUs.
106 |
107 | 3. Compute the gradients using the global statistics:
108 |
109 |    $\frac{dJ}{dx_i} = \frac{\gamma}{N\sqrt{\sigma^2 + \epsilon}}\left(N\frac{dJ}{dy_i} - \sum_{j=1}^{N}\frac{dJ}{dy_j} - \hat{x}_i\sum_{j=1}^{N}\frac{dJ}{dy_j}\hat{x}_j\right),$
110 |
111 |    where the two sums are the globally reduced quantities from step (2),
112 |
113 |    and $\frac{dJ}{d\gamma} = \sum_{i}\frac{dJ}{dy_i}\hat{x}_i$
114 |
115 |    and finally $\frac{dJ}{d\beta} = \sum_{i}\frac{dJ}{dy_i}.$
116 |
117 | Note that in the implementation the normalization by $N$ is already performed at step (2), so the equations above and the code are not exactly the same, but they are mathematically equivalent.
118 |
119 | You can go deeper into the explanation above at [Kevin Zakka's Blog](https://kevinzakka.github.io/2016/09/14/batch_normalization/)
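120 |
121 | The snippet below is purely illustrative and is not part of the extension: it emulates, on a single device, the two-step reduction described above (per-GPU sums of $x$ and $x^2$, then a global mean and variance) and checks the normalized result against `nn.BatchNorm2d` over the full batch. The chunks of the batch stand in for per-GPU shards, and the variable names (`xsum`, `xsqsum`, `mean`, `var`, `N`) mirror those in `modules/functional/syncbn.py`.
122 |
123 | ```
124 | import torch
125 |
126 | def emulated_syncbn_stats(x_chunks):
127 |     # 1. each "GPU" computes per-channel sum(x) and sum(x^2)
128 |     xsums = [x.sum(dim=(0, 2, 3)) for x in x_chunks]
129 |     xsqsums = [(x * x).sum(dim=(0, 2, 3)) for x in x_chunks]
130 |     # 2. the master reduces them and derives the global mean and (biased) variance
131 |     N = sum(x.numel() // x.size(1) for x in x_chunks)
132 |     xsum, xsqsum = sum(xsums), sum(xsqsums)
133 |     mean = xsum / N
134 |     var = xsqsum / N - mean * mean
135 |     return mean, var
136 |
137 | x = torch.randn(8, 3, 16, 16)
138 | mean, var = emulated_syncbn_stats(x.chunk(4))  # pretend the batch lives on 4 GPUs
139 |
140 | # 3. forward batchnorm with the global stats
141 | xhat = (x - mean.view(1, -1, 1, 1)) / torch.sqrt(var.view(1, -1, 1, 1) + 1e-5)
142 |
143 | # reference: plain BatchNorm2d over the whole batch (affine disabled)
144 | bn = torch.nn.BatchNorm2d(3, affine=False, eps=1e-5)
145 | print(torch.allclose(bn(x), xhat, atol=1e-4))  # expected: True
146 | ```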
--------------------------------------------------------------------------------
/iharm/model/syncbn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/DucoNet-Image-Harmonization/167cd720d91f0f5d9e64b54c3c8d4300d8915be2/iharm/model/syncbn/__init__.py
--------------------------------------------------------------------------------
/iharm/model/syncbn/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/DucoNet-Image-Harmonization/167cd720d91f0f5d9e64b54c3c8d4300d8915be2/iharm/model/syncbn/modules/__init__.py
--------------------------------------------------------------------------------
/iharm/model/syncbn/modules/functional/__init__.py:
--------------------------------------------------------------------------------
1 | from .syncbn import batchnorm2d_sync
2 |
--------------------------------------------------------------------------------
/iharm/model/syncbn/modules/functional/_csrc.py:
--------------------------------------------------------------------------------
1 | """
2 | /*****************************************************************************/
3 |
4 | Extension module loader
5 |
6 | code referenced from : https://github.com/facebookresearch/maskrcnn-benchmark
7 |
8 | /*****************************************************************************/
9 | """
10 | from __future__ import absolute_import
11 | from __future__ import division
12 | from __future__ import print_function
13 |
14 | import glob
15 | import os.path
16 |
17 | import torch
18 |
19 | try:
20 | from torch.utils.cpp_extension import load
21 | from torch.utils.cpp_extension import CUDA_HOME
22 | except ImportError:
23 | raise ImportError(
24 | "The cpp layer extensions requires PyTorch 0.4 or higher")
25 |
26 |
27 | def _load_C_extensions():
28 | this_dir = os.path.dirname(os.path.abspath(__file__))
29 | this_dir = os.path.join(this_dir, "csrc")
30 |
31 | main_file = glob.glob(os.path.join(this_dir, "*.cpp"))
32 | sources_cpu = glob.glob(os.path.join(this_dir, "cpu", "*.cpp"))
33 | sources_cuda = glob.glob(os.path.join(this_dir, "cuda", "*.cu"))
34 |
35 | sources = main_file + sources_cpu
36 |
37 | extra_cflags = []
38 | extra_cuda_cflags = []
39 | if torch.cuda.is_available() and CUDA_HOME is not None:
40 | sources.extend(sources_cuda)
41 | extra_cflags = ["-O3", "-DWITH_CUDA"]
42 | extra_cuda_cflags = ["--expt-extended-lambda"]
43 | sources = [os.path.join(this_dir, s) for s in sources]
44 | extra_include_paths = [this_dir]
45 | return load(
46 | name="ext_lib",
47 | sources=sources,
48 | extra_cflags=extra_cflags,
49 | extra_include_paths=extra_include_paths,
50 | extra_cuda_cflags=extra_cuda_cflags,
51 | )
52 |
53 |
54 | _backend = _load_C_extensions()
55 |
--------------------------------------------------------------------------------
/iharm/model/syncbn/modules/functional/csrc/bn.h:
--------------------------------------------------------------------------------
1 | /*****************************************************************************
2 |
3 | SyncBN
4 |
5 | *****************************************************************************/
6 | #pragma once
7 |
8 | #ifdef WITH_CUDA
9 | #include "cuda/ext_lib.h"
10 | #endif
11 |
12 | /// SyncBN
13 |
14 | std::vector<at::Tensor> syncbn_sum_sqsum(const at::Tensor& x) {
15 | if (x.is_cuda()) {
16 | #ifdef WITH_CUDA
17 | return syncbn_sum_sqsum_cuda(x);
18 | #else
19 | AT_ERROR("Not compiled with GPU support");
20 | #endif
21 | } else {
22 | AT_ERROR("CPU implementation not supported");
23 | }
24 | }
25 |
26 | at::Tensor syncbn_forward(const at::Tensor& x, const at::Tensor& weight,
27 | const at::Tensor& bias, const at::Tensor& mean,
28 | const at::Tensor& var, bool affine, float eps) {
29 | if (x.is_cuda()) {
30 | #ifdef WITH_CUDA
31 | return syncbn_forward_cuda(x, weight, bias, mean, var, affine, eps);
32 | #else
33 | AT_ERROR("Not compiled with GPU support");
34 | #endif
35 | } else {
36 | AT_ERROR("CPU implementation not supported");
37 | }
38 | }
39 |
40 | std::vector<at::Tensor> syncbn_backward_xhat(const at::Tensor& dz,
41 | const at::Tensor& x,
42 | const at::Tensor& mean,
43 | const at::Tensor& var, float eps) {
44 | if (dz.is_cuda()) {
45 | #ifdef WITH_CUDA
46 | return syncbn_backward_xhat_cuda(dz, x, mean, var, eps);
47 | #else
48 | AT_ERROR("Not compiled with GPU support");
49 | #endif
50 | } else {
51 | AT_ERROR("CPU implementation not supported");
52 | }
53 | }
54 |
55 | std::vector<at::Tensor> syncbn_backward(
56 | const at::Tensor& dz, const at::Tensor& x, const at::Tensor& weight,
57 | const at::Tensor& bias, const at::Tensor& mean, const at::Tensor& var,
58 | const at::Tensor& sum_dz, const at::Tensor& sum_dz_xhat, bool affine,
59 | float eps) {
60 | if (dz.is_cuda()) {
61 | #ifdef WITH_CUDA
62 | return syncbn_backward_cuda(dz, x, weight, bias, mean, var, sum_dz,
63 | sum_dz_xhat, affine, eps);
64 | #else
65 | AT_ERROR("Not compiled with GPU support");
66 | #endif
67 | } else {
68 | AT_ERROR("CPU implementation not supported");
69 | }
70 | }
71 |
--------------------------------------------------------------------------------
/iharm/model/syncbn/modules/functional/csrc/cuda/bn_cuda.cu:
--------------------------------------------------------------------------------
1 | /*****************************************************************************
2 |
3 | CUDA SyncBN code
4 |
5 | code referenced from : https://github.com/mapillary/inplace_abn
6 |
7 | *****************************************************************************/
8 | #include <ATen/ATen.h>
9 | #include <cuda.h>
10 | #include <cuda_runtime.h>
11 | #include <vector>
12 | #include "cuda/common.h"
13 |
14 | // Utilities
15 | void get_dims(at::Tensor x, int64_t &num, int64_t &chn, int64_t &sp) {
16 | num = x.size(0);
17 | chn = x.size(1);
18 | sp = 1;
19 | for (int64_t i = 2; i < x.ndimension(); ++i) sp *= x.size(i);
20 | }
21 |
22 | /// SyncBN
23 |
24 | template <typename T>
25 | struct SqSumOp {
26 | __device__ SqSumOp(const T *t, int c, int s) : tensor(t), chn(c), sp(s) {}
27 |   __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
28 | T x = tensor[(batch * chn + plane) * sp + n];
29 |     return Pair<T>(x, x * x); // x, x^2
30 | }
31 | const T *tensor;
32 | const int chn;
33 | const int sp;
34 | };
35 |
36 | template <typename T>
37 | __global__ void syncbn_sum_sqsum_kernel(const T *x, T *sum, T *sqsum,
38 | int num, int chn, int sp) {
39 | int plane = blockIdx.x;
40 |   Pair<T> res =
41 |       reduce<Pair<T>, SqSumOp<T>>(SqSumOp<T>(x, chn, sp), plane, num, chn, sp);
42 | __syncthreads();
43 | if (threadIdx.x == 0) {
44 | sum[plane] = res.v1;
45 | sqsum[plane] = res.v2;
46 | }
47 | }
48 |
49 | std::vector<at::Tensor> syncbn_sum_sqsum_cuda(const at::Tensor &x) {
50 | CHECK_INPUT(x);
51 |
52 | // Extract dimensions
53 | int64_t num, chn, sp;
54 | get_dims(x, num, chn, sp);
55 |
56 | // Prepare output tensors
57 | auto sum = at::empty({chn}, x.options());
58 | auto sqsum = at::empty({chn}, x.options());
59 |
60 | // Run kernel
61 | dim3 blocks(chn);
62 | dim3 threads(getNumThreads(sp));
63 | AT_DISPATCH_FLOATING_TYPES(
64 | x.type(), "syncbn_sum_sqsum_cuda", ([&] {
65 |         syncbn_sum_sqsum_kernel<scalar_t><<<blocks, threads>>>(
66 |             x.data<scalar_t>(), sum.data<scalar_t>(),
67 |             sqsum.data<scalar_t>(), num, chn, sp);
68 | }));
69 | return {sum, sqsum};
70 | }
71 |
72 | template <typename T>
73 | __global__ void syncbn_forward_kernel(T *z, const T *x, const T *weight,
74 | const T *bias, const T *mean,
75 | const T *var, bool affine, float eps,
76 | int num, int chn, int sp) {
77 | int plane = blockIdx.x;
78 | T _mean = mean[plane];
79 | T _var = var[plane];
80 | T _weight = affine ? weight[plane] : T(1);
81 | T _bias = affine ? bias[plane] : T(0);
82 | float _invstd = T(0);
83 | if (_var || eps) {
84 | _invstd = rsqrt(_var + eps);
85 | }
86 | for (int batch = 0; batch < num; ++batch) {
87 | for (int n = threadIdx.x; n < sp; n += blockDim.x) {
88 | T _x = x[(batch * chn + plane) * sp + n];
89 | T _xhat = (_x - _mean) * _invstd;
90 | T _z = _xhat * _weight + _bias;
91 | z[(batch * chn + plane) * sp + n] = _z;
92 | }
93 | }
94 | }
95 |
96 | at::Tensor syncbn_forward_cuda(const at::Tensor &x, const at::Tensor &weight,
97 | const at::Tensor &bias, const at::Tensor &mean,
98 | const at::Tensor &var, bool affine, float eps) {
99 | CHECK_INPUT(x);
100 | CHECK_INPUT(weight);
101 | CHECK_INPUT(bias);
102 | CHECK_INPUT(mean);
103 | CHECK_INPUT(var);
104 |
105 | // Extract dimensions
106 | int64_t num, chn, sp;
107 | get_dims(x, num, chn, sp);
108 |
109 | auto z = at::zeros_like(x);
110 |
111 | // Run kernel
112 | dim3 blocks(chn);
113 | dim3 threads(getNumThreads(sp));
114 | AT_DISPATCH_FLOATING_TYPES(
115 | x.type(), "syncbn_forward_cuda", ([&] {
116 |         syncbn_forward_kernel<scalar_t><<<blocks, threads>>>(
117 |             z.data<scalar_t>(), x.data<scalar_t>(),
118 |             weight.data<scalar_t>(), bias.data<scalar_t>(),
119 |             mean.data<scalar_t>(), var.data<scalar_t>(),
120 | affine, eps, num, chn, sp);
121 | }));
122 | return z;
123 | }
124 |
125 | template <typename T>
126 | struct XHatOp {
127 | __device__ XHatOp(T _weight, T _bias, const T *_dz, const T *_x, int c, int s)
128 | : weight(_weight), bias(_bias), x(_x), dz(_dz), chn(c), sp(s) {}
129 |   __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
130 | // xhat = (x - bias) * weight
131 | T _xhat = (x[(batch * chn + plane) * sp + n] - bias) * weight;
132 | // dxhat * x_hat
133 | T _dz = dz[(batch * chn + plane) * sp + n];
134 |     return Pair<T>(_dz, _dz * _xhat);
135 | }
136 | const T weight;
137 | const T bias;
138 | const T *dz;
139 | const T *x;
140 | const int chn;
141 | const int sp;
142 | };
143 |
144 | template <typename T>
145 | __global__ void syncbn_backward_xhat_kernel(const T *dz, const T *x,
146 | const T *mean, const T *var,
147 | T *sum_dz, T *sum_dz_xhat,
148 | float eps, int num, int chn,
149 | int sp) {
150 | int plane = blockIdx.x;
151 | T _mean = mean[plane];
152 | T _var = var[plane];
153 | T _invstd = T(0);
154 | if (_var || eps) {
155 | _invstd = rsqrt(_var + eps);
156 | }
157 |   Pair<T> res = reduce<Pair<T>, XHatOp<T>>(
158 |       XHatOp<T>(_invstd, _mean, dz, x, chn, sp), plane, num, chn, sp);
159 | __syncthreads();
160 | if (threadIdx.x == 0) {
161 | // \sum(\frac{dJ}{dy_i})
162 | sum_dz[plane] = res.v1;
163 | // \sum(\frac{dJ}{dy_i}*\hat{x_i})
164 | sum_dz_xhat[plane] = res.v2;
165 | }
166 | }
167 |
168 | std::vector<at::Tensor> syncbn_backward_xhat_cuda(const at::Tensor &dz,
169 | const at::Tensor &x,
170 | const at::Tensor &mean,
171 | const at::Tensor &var,
172 | float eps) {
173 | CHECK_INPUT(dz);
174 | CHECK_INPUT(x);
175 | CHECK_INPUT(mean);
176 | CHECK_INPUT(var);
177 | // Extract dimensions
178 | int64_t num, chn, sp;
179 | get_dims(x, num, chn, sp);
180 | // Prepare output tensors
181 | auto sum_dz = at::empty({chn}, x.options());
182 | auto sum_dz_xhat = at::empty({chn}, x.options());
183 | // Run kernel
184 | dim3 blocks(chn);
185 | dim3 threads(getNumThreads(sp));
186 | AT_DISPATCH_FLOATING_TYPES(
187 | x.type(), "syncbn_backward_xhat_cuda", ([&] {
188 |         syncbn_backward_xhat_kernel<scalar_t><<<blocks, threads>>>(
189 |             dz.data<scalar_t>(), x.data<scalar_t>(), mean.data<scalar_t>(),
190 |             var.data<scalar_t>(), sum_dz.data<scalar_t>(),
191 |             sum_dz_xhat.data<scalar_t>(), eps, num, chn, sp);
192 | }));
193 | return {sum_dz, sum_dz_xhat};
194 | }
195 |
196 | template <typename T>
197 | __global__ void syncbn_backward_kernel(const T *dz, const T *x, const T *weight,
198 | const T *bias, const T *mean,
199 | const T *var, const T *sum_dz,
200 | const T *sum_dz_xhat, T *dx, T *dweight,
201 | T *dbias, bool affine, float eps,
202 | int num, int chn, int sp) {
203 | int plane = blockIdx.x;
204 | T _mean = mean[plane];
205 | T _var = var[plane];
206 | T _weight = affine ? weight[plane] : T(1);
207 | T _sum_dz = sum_dz[plane];
208 | T _sum_dz_xhat = sum_dz_xhat[plane];
209 | T _invstd = T(0);
210 | if (_var || eps) {
211 | _invstd = rsqrt(_var + eps);
212 | }
213 | /*
214 | \frac{dJ}{dx_i} = \frac{1}{N\sqrt{(\sigma^2+\epsilon)}} (
215 | N\frac{dJ}{d\hat{x_i}} -
216 | \sum_{j=1}^{N}(\frac{dJ}{d\hat{x_j}}) -
217 | \hat{x_i}\sum_{j=1}^{N}(\frac{dJ}{d\hat{x_j}}\hat{x_j})
218 | )
219 | Note : N is omitted here since it will be accumulated and
220 | _sum_dz and _sum_dz_xhat expected to be already normalized
221 | before the call.
222 | */
223 | if (dx) {
224 | T _mul = _weight * _invstd;
225 | for (int batch = 0; batch < num; ++batch) {
226 | for (int n = threadIdx.x; n < sp; n += blockDim.x) {
227 | T _dz = dz[(batch * chn + plane) * sp + n];
228 | T _xhat = (x[(batch * chn + plane) * sp + n] - _mean) * _invstd;
229 | T _dx = (_dz - _sum_dz - _xhat * _sum_dz_xhat) * _mul;
230 | dx[(batch * chn + plane) * sp + n] = _dx;
231 | }
232 | }
233 | }
234 | __syncthreads();
235 | if (threadIdx.x == 0) {
236 | if (affine) {
237 | T _norm = num * sp;
238 | dweight[plane] += _sum_dz_xhat * _norm;
239 | dbias[plane] += _sum_dz * _norm;
240 | }
241 | }
242 | }
243 |
244 | std::vector<at::Tensor> syncbn_backward_cuda(
245 | const at::Tensor &dz, const at::Tensor &x, const at::Tensor &weight,
246 | const at::Tensor &bias, const at::Tensor &mean, const at::Tensor &var,
247 | const at::Tensor &sum_dz, const at::Tensor &sum_dz_xhat, bool affine,
248 | float eps) {
249 | CHECK_INPUT(dz);
250 | CHECK_INPUT(x);
251 | CHECK_INPUT(weight);
252 | CHECK_INPUT(bias);
253 | CHECK_INPUT(mean);
254 | CHECK_INPUT(var);
255 | CHECK_INPUT(sum_dz);
256 | CHECK_INPUT(sum_dz_xhat);
257 |
258 | // Extract dimensions
259 | int64_t num, chn, sp;
260 | get_dims(x, num, chn, sp);
261 |
262 | // Prepare output tensors
263 | auto dx = at::zeros_like(dz);
264 | auto dweight = at::zeros_like(weight);
265 | auto dbias = at::zeros_like(bias);
266 |
267 | // Run kernel
268 | dim3 blocks(chn);
269 | dim3 threads(getNumThreads(sp));
270 | AT_DISPATCH_FLOATING_TYPES(
271 | x.type(), "syncbn_backward_cuda", ([&] {
272 |         syncbn_backward_kernel<scalar_t><<<blocks, threads>>>(
273 |             dz.data<scalar_t>(), x.data<scalar_t>(), weight.data<scalar_t>(),
274 |             bias.data<scalar_t>(), mean.data<scalar_t>(), var.data<scalar_t>(),
275 |             sum_dz.data<scalar_t>(), sum_dz_xhat.data<scalar_t>(),
276 |             dx.data<scalar_t>(), dweight.data<scalar_t>(),
277 |             dbias.data<scalar_t>(), affine, eps, num, chn, sp);
278 | }));
279 | return {dx, dweight, dbias};
280 | }
--------------------------------------------------------------------------------
/iharm/model/syncbn/modules/functional/csrc/cuda/common.h:
--------------------------------------------------------------------------------
1 | /*****************************************************************************
2 |
3 | CUDA utility funcs
4 |
5 | code referenced from : https://github.com/mapillary/inplace_abn
6 |
7 | *****************************************************************************/
8 | #pragma once
9 |
10 | #include <ATen/ATen.h>
11 |
12 | // Checks
13 | #ifndef AT_CHECK
14 | #define AT_CHECK AT_ASSERT
15 | #endif
16 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
17 | #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
18 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
19 |
20 | /*
21 | * General settings
22 | */
23 | const int WARP_SIZE = 32;
24 | const int MAX_BLOCK_SIZE = 512;
25 |
26 | template <typename T>
27 | struct Pair {
28 | T v1, v2;
29 | __device__ Pair() {}
30 | __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
31 | __device__ Pair(T v) : v1(v), v2(v) {}
32 | __device__ Pair(int v) : v1(v), v2(v) {}
33 | __device__ Pair &operator+=(const Pair &a) {
34 | v1 += a.v1;
35 | v2 += a.v2;
36 | return *this;
37 | }
38 | };
39 |
40 | /*
41 | * Utility functions
42 | */
43 | template <typename T>
44 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask,
45 | int width = warpSize,
46 | unsigned int mask = 0xffffffff) {
47 | #if CUDART_VERSION >= 9000
48 | return __shfl_xor_sync(mask, value, laneMask, width);
49 | #else
50 | return __shfl_xor(value, laneMask, width);
51 | #endif
52 | }
53 |
54 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
55 |
56 | static int getNumThreads(int nElem) {
57 | int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
58 | for (int i = 0; i != 5; ++i) {
59 | if (nElem <= threadSizes[i]) {
60 | return threadSizes[i];
61 | }
62 | }
63 | return MAX_BLOCK_SIZE;
64 | }
65 |
66 | template <typename T>
67 | static __device__ __forceinline__ T warpSum(T val) {
68 | #if __CUDA_ARCH__ >= 300
69 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
70 | val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
71 | }
72 | #else
73 | __shared__ T values[MAX_BLOCK_SIZE];
74 | values[threadIdx.x] = val;
75 | __threadfence_block();
76 | const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
77 | for (int i = 1; i < WARP_SIZE; i++) {
78 | val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
79 | }
80 | #endif
81 | return val;
82 | }
83 |
84 | template <typename T>
85 | static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) {
86 | value.v1 = warpSum(value.v1);
87 | value.v2 = warpSum(value.v2);
88 | return value;
89 | }
90 |
91 | template <typename T, typename Op>
92 | __device__ T reduce(Op op, int plane, int N, int C, int S) {
93 | T sum = (T)0;
94 | for (int batch = 0; batch < N; ++batch) {
95 | for (int x = threadIdx.x; x < S; x += blockDim.x) {
96 | sum += op(batch, plane, x);
97 | }
98 | }
99 |
100 | // sum over NumThreads within a warp
101 | sum = warpSum(sum);
102 |
103 | // 'transpose', and reduce within warp again
104 | __shared__ T shared[32];
105 | __syncthreads();
106 | if (threadIdx.x % WARP_SIZE == 0) {
107 | shared[threadIdx.x / WARP_SIZE] = sum;
108 | }
109 | if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
110 | // zero out the other entries in shared
111 | shared[threadIdx.x] = (T)0;
112 | }
113 | __syncthreads();
114 | if (threadIdx.x / WARP_SIZE == 0) {
115 | sum = warpSum(shared[threadIdx.x]);
116 | if (threadIdx.x == 0) {
117 | shared[0] = sum;
118 | }
119 | }
120 | __syncthreads();
121 |
122 | // Everyone picks it up, should be broadcast into the whole gradInput
123 | return shared[0];
124 | }
--------------------------------------------------------------------------------
/iharm/model/syncbn/modules/functional/csrc/cuda/ext_lib.h:
--------------------------------------------------------------------------------
1 | /*****************************************************************************
2 |
3 | CUDA SyncBN code
4 |
5 | *****************************************************************************/
6 | #pragma once
7 | #include <ATen/ATen.h>
8 | #include <vector>
9 |
10 | /// Sync-BN
11 | std::vector<at::Tensor> syncbn_sum_sqsum_cuda(const at::Tensor& x);
12 | at::Tensor syncbn_forward_cuda(const at::Tensor& x, const at::Tensor& weight,
13 | const at::Tensor& bias, const at::Tensor& mean,
14 | const at::Tensor& var, bool affine, float eps);
15 | std::vector<at::Tensor> syncbn_backward_xhat_cuda(const at::Tensor& dz,
16 | const at::Tensor& x,
17 | const at::Tensor& mean,
18 | const at::Tensor& var,
19 | float eps);
20 | std::vector<at::Tensor> syncbn_backward_cuda(
21 | const at::Tensor& dz, const at::Tensor& x, const at::Tensor& weight,
22 | const at::Tensor& bias, const at::Tensor& mean, const at::Tensor& var,
23 | const at::Tensor& sum_dz, const at::Tensor& sum_dz_xhat, bool affine,
24 | float eps);
25 |
--------------------------------------------------------------------------------
/iharm/model/syncbn/modules/functional/csrc/ext_lib.cpp:
--------------------------------------------------------------------------------
1 | #include "bn.h"
2 |
3 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
4 | m.def("syncbn_sum_sqsum", &syncbn_sum_sqsum, "Sum and Sum^2 computation");
5 | m.def("syncbn_forward", &syncbn_forward, "SyncBN forward computation");
6 | m.def("syncbn_backward_xhat", &syncbn_backward_xhat,
7 | "First part of SyncBN backward computation");
8 | m.def("syncbn_backward", &syncbn_backward,
9 | "Second part of SyncBN backward computation");
10 | }
--------------------------------------------------------------------------------
/iharm/model/syncbn/modules/functional/syncbn.py:
--------------------------------------------------------------------------------
1 | """
2 | /*****************************************************************************/
3 |
4 | BatchNorm2dSync with multi-gpu
5 |
6 | code referenced from : https://github.com/mapillary/inplace_abn
7 |
8 | /*****************************************************************************/
9 | """
10 | from __future__ import absolute_import
11 | from __future__ import division
12 | from __future__ import print_function
13 |
14 | import torch.cuda.comm as comm
15 | from torch.autograd import Function
16 | from torch.autograd.function import once_differentiable
17 | from ._csrc import _backend
18 |
19 |
20 | def _count_samples(x):
21 | count = 1
22 | for i, s in enumerate(x.size()):
23 | if i != 1:
24 | count *= s
25 | return count
26 |
27 |
28 | class BatchNorm2dSyncFunc(Function):
29 |
30 | @staticmethod
31 | def forward(ctx, x, weight, bias, running_mean, running_var,
32 | extra, compute_stats=True, momentum=0.1, eps=1e-05):
33 | def _parse_extra(ctx, extra):
34 | ctx.is_master = extra["is_master"]
35 | if ctx.is_master:
36 | ctx.master_queue = extra["master_queue"]
37 | ctx.worker_queues = extra["worker_queues"]
38 | ctx.worker_ids = extra["worker_ids"]
39 | else:
40 | ctx.master_queue = extra["master_queue"]
41 | ctx.worker_queue = extra["worker_queue"]
42 | # Save context
43 | if extra is not None:
44 | _parse_extra(ctx, extra)
45 | ctx.compute_stats = compute_stats
46 | ctx.momentum = momentum
47 | ctx.eps = eps
48 | ctx.affine = weight is not None and bias is not None
49 | if ctx.compute_stats:
50 | N = _count_samples(x) * (ctx.master_queue.maxsize + 1)
51 | assert N > 1
52 | # 1. compute sum(x) and sum(x^2)
53 | xsum, xsqsum = _backend.syncbn_sum_sqsum(x.detach())
54 | if ctx.is_master:
55 | xsums, xsqsums = [xsum], [xsqsum]
56 |                 # master : gather all sum(x) and sum(x^2) from slaves
57 | for _ in range(ctx.master_queue.maxsize):
58 | xsum_w, xsqsum_w = ctx.master_queue.get()
59 | ctx.master_queue.task_done()
60 | xsums.append(xsum_w)
61 | xsqsums.append(xsqsum_w)
62 | xsum = comm.reduce_add(xsums)
63 | xsqsum = comm.reduce_add(xsqsums)
64 | mean = xsum / N
65 | sumvar = xsqsum - xsum * mean
66 | var = sumvar / N
67 | uvar = sumvar / (N - 1)
68 | # master : broadcast global mean, variance to all slaves
69 | tensors = comm.broadcast_coalesced(
70 | (mean, uvar, var), [mean.get_device()] + ctx.worker_ids)
71 | for ts, queue in zip(tensors[1:], ctx.worker_queues):
72 | queue.put(ts)
73 | else:
74 | # slave : send sum(x) and sum(x^2) to master
75 | ctx.master_queue.put((xsum, xsqsum))
76 | # slave : get global mean and variance
77 | mean, uvar, var = ctx.worker_queue.get()
78 | ctx.worker_queue.task_done()
79 |
80 | # Update running stats
81 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
82 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * uvar)
83 | ctx.N = N
84 | ctx.save_for_backward(x, weight, bias, mean, var)
85 | else:
86 | mean, var = running_mean, running_var
87 |
88 | # do batch norm forward
89 | z = _backend.syncbn_forward(x, weight, bias, mean, var,
90 | ctx.affine, ctx.eps)
91 | return z
92 |
93 | @staticmethod
94 | @once_differentiable
95 | def backward(ctx, dz):
96 | x, weight, bias, mean, var = ctx.saved_tensors
97 | dz = dz.contiguous()
98 |
99 | # 1. compute \sum(\frac{dJ}{dy_i}) and \sum(\frac{dJ}{dy_i}*\hat{x_i})
100 | sum_dz, sum_dz_xhat = _backend.syncbn_backward_xhat(
101 | dz, x, mean, var, ctx.eps)
102 | if ctx.is_master:
103 | sum_dzs, sum_dz_xhats = [sum_dz], [sum_dz_xhat]
104 |             # master : gather from slaves
105 | for _ in range(ctx.master_queue.maxsize):
106 | sum_dz_w, sum_dz_xhat_w = ctx.master_queue.get()
107 | ctx.master_queue.task_done()
108 | sum_dzs.append(sum_dz_w)
109 | sum_dz_xhats.append(sum_dz_xhat_w)
110 | # master : compute global stats
111 | sum_dz = comm.reduce_add(sum_dzs)
112 | sum_dz_xhat = comm.reduce_add(sum_dz_xhats)
113 | sum_dz /= ctx.N
114 | sum_dz_xhat /= ctx.N
115 | # master : broadcast global stats
116 | tensors = comm.broadcast_coalesced(
117 | (sum_dz, sum_dz_xhat), [mean.get_device()] + ctx.worker_ids)
118 | for ts, queue in zip(tensors[1:], ctx.worker_queues):
119 | queue.put(ts)
120 | else:
121 | # slave : send to master
122 | ctx.master_queue.put((sum_dz, sum_dz_xhat))
123 | # slave : get global stats
124 | sum_dz, sum_dz_xhat = ctx.worker_queue.get()
125 | ctx.worker_queue.task_done()
126 |
127 | # do batch norm backward
128 | dx, dweight, dbias = _backend.syncbn_backward(
129 | dz, x, weight, bias, mean, var, sum_dz, sum_dz_xhat,
130 | ctx.affine, ctx.eps)
131 |
132 | return dx, dweight, dbias, \
133 | None, None, None, None, None, None
134 |
135 | batchnorm2d_sync = BatchNorm2dSyncFunc.apply
136 |
137 | __all__ = ["batchnorm2d_sync"]
138 |
--------------------------------------------------------------------------------
/iharm/model/syncbn/modules/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .syncbn import *
2 |
--------------------------------------------------------------------------------
/iharm/model/syncbn/modules/nn/syncbn.py:
--------------------------------------------------------------------------------
1 | """
2 | /*****************************************************************************/
3 |
4 | BatchNorm2dSync with multi-gpu
5 |
6 | /*****************************************************************************/
7 | """
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | try:
13 | # python 3
14 | from queue import Queue
15 | except ImportError:
16 | # python 2
17 | from Queue import Queue
18 |
19 | import torch
20 | import torch.nn as nn
21 | from torch.nn import functional as F
22 | from torch.nn.parameter import Parameter
23 | from iharm.model.syncbn.modules.functional import batchnorm2d_sync
24 |
25 |
26 | class _BatchNorm(nn.Module):
27 | """
28 | Customized BatchNorm from nn.BatchNorm
29 | >> added freeze attribute to enable bn freeze.
30 | """
31 |
32 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
33 | track_running_stats=True):
34 | super(_BatchNorm, self).__init__()
35 | self.num_features = num_features
36 | self.eps = eps
37 | self.momentum = momentum
38 | self.affine = affine
39 | self.track_running_stats = track_running_stats
40 | self.freezed = False
41 | if self.affine:
42 | self.weight = Parameter(torch.Tensor(num_features))
43 | self.bias = Parameter(torch.Tensor(num_features))
44 | else:
45 | self.register_parameter('weight', None)
46 | self.register_parameter('bias', None)
47 | if self.track_running_stats:
48 | self.register_buffer('running_mean', torch.zeros(num_features))
49 | self.register_buffer('running_var', torch.ones(num_features))
50 | else:
51 | self.register_parameter('running_mean', None)
52 | self.register_parameter('running_var', None)
53 | self.reset_parameters()
54 |
55 | def reset_parameters(self):
56 | if self.track_running_stats:
57 | self.running_mean.zero_()
58 | self.running_var.fill_(1)
59 | if self.affine:
60 | self.weight.data.uniform_()
61 | self.bias.data.zero_()
62 |
63 | def _check_input_dim(self, input):
64 | return NotImplemented
65 |
66 | def forward(self, input):
67 | self._check_input_dim(input)
68 |
69 | compute_stats = not self.freezed and \
70 | self.training and self.track_running_stats
71 |
72 | ret = F.batch_norm(input, self.running_mean, self.running_var,
73 | self.weight, self.bias, compute_stats,
74 | self.momentum, self.eps)
75 | return ret
76 |
77 | def extra_repr(self):
78 | return '{num_features}, eps={eps}, momentum={momentum}, '\
79 | 'affine={affine}, ' \
80 | 'track_running_stats={track_running_stats}'.format(
81 | **self.__dict__)
82 |
83 |
84 | class BatchNorm2dNoSync(_BatchNorm):
85 | """
86 | Equivalent to nn.BatchNorm2d
87 | """
88 |
89 | def _check_input_dim(self, input):
90 | if input.dim() != 4:
91 | raise ValueError('expected 4D input (got {}D input)'
92 | .format(input.dim()))
93 |
94 |
95 | class BatchNorm2dSync(BatchNorm2dNoSync):
96 | """
97 | BatchNorm2d with automatic multi-GPU Sync
98 | """
99 |
100 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
101 | track_running_stats=True):
102 | super(BatchNorm2dSync, self).__init__(
103 | num_features, eps=eps, momentum=momentum, affine=affine,
104 | track_running_stats=track_running_stats)
105 | self.sync_enabled = True
106 | self.devices = list(range(torch.cuda.device_count()))
107 | if len(self.devices) > 1:
108 | # Initialize queues
109 | self.worker_ids = self.devices[1:]
110 | self.master_queue = Queue(len(self.worker_ids))
111 | self.worker_queues = [Queue(1) for _ in self.worker_ids]
112 |
113 | def forward(self, x):
114 | compute_stats = not self.freezed and \
115 | self.training and self.track_running_stats
116 | if self.sync_enabled and compute_stats and len(self.devices) > 1:
117 | if x.get_device() == self.devices[0]:
118 | # Master mode
119 | extra = {
120 | "is_master": True,
121 | "master_queue": self.master_queue,
122 | "worker_queues": self.worker_queues,
123 | "worker_ids": self.worker_ids
124 | }
125 | else:
126 | # Worker mode
127 | extra = {
128 | "is_master": False,
129 | "master_queue": self.master_queue,
130 | "worker_queue": self.worker_queues[
131 | self.worker_ids.index(x.get_device())]
132 | }
133 | return batchnorm2d_sync(x, self.weight, self.bias,
134 | self.running_mean, self.running_var,
135 | extra, compute_stats, self.momentum,
136 | self.eps)
137 | return super(BatchNorm2dSync, self).forward(x)
138 |
139 | def __repr__(self):
140 | """repr"""
141 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
142 | 'affine={affine}, ' \
143 | 'track_running_stats={track_running_stats},' \
144 | 'devices={devices})'
145 | return rep.format(name=self.__class__.__name__, **self.__dict__)
146 |
147 | #BatchNorm2d = BatchNorm2dNoSync
148 | BatchNorm2d = BatchNorm2dSync
149 |
--------------------------------------------------------------------------------
/iharm/utils/exp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import shutil
4 | import pprint
5 | from pathlib import Path
6 | from datetime import datetime
7 |
8 | import yaml
9 | import torch
10 | from easydict import EasyDict as edict
11 |
12 | from .log import logger, add_new_file_output_to_logger
13 |
14 |
15 | def init_experiment(args):
16 | model_path = Path(args.model_path)
17 | ftree = get_model_family_tree(model_path)
18 | if ftree is None:
19 | print('Models can only be located in the "models" directory in the root of the repository')
20 | sys.exit(1)
21 |
22 | cfg = load_config(model_path)
23 | update_config(cfg, args)
24 |
25 | experiments_path = Path(cfg.EXPS_PATH)
26 | exp_parent_path = experiments_path / '/'.join(ftree)
27 | exp_parent_path.mkdir(parents=True, exist_ok=True)
28 |
29 | if cfg.resume_exp:
30 | exp_path = find_resume_exp(exp_parent_path, cfg.resume_exp)
31 | else:
32 | last_exp_indx = find_last_exp_indx(exp_parent_path)
33 | exp_name = f'{last_exp_indx:03d}'
34 | if cfg.exp_name:
35 | exp_name += '_' + cfg.exp_name
36 | exp_path = exp_parent_path / exp_name
37 | exp_path.mkdir(parents=True)
38 |
39 | cfg.EXP_PATH = exp_path
40 | cfg.CHECKPOINTS_PATH = exp_path / 'checkpoints'
41 | cfg.VIS_PATH = exp_path / 'vis'
42 | cfg.LOGS_PATH = exp_path / 'logs'
43 |
44 | cfg.LOGS_PATH.mkdir(exist_ok=True)
45 | cfg.CHECKPOINTS_PATH.mkdir(exist_ok=True)
46 | cfg.VIS_PATH.mkdir(exist_ok=True)
47 |
48 | dst_script_path = exp_path / (model_path.stem + datetime.strftime(datetime.today(), '_%Y-%m-%d-%H-%M-%S.py'))
49 | shutil.copy(model_path, dst_script_path)
50 |
51 | if cfg.gpus != '':
52 | gpu_ids = [int(id) for id in cfg.gpus.split(',')]
53 | else:
54 | gpu_ids = list(range(cfg.ngpus))
55 | cfg.gpus = ','.join([str(id) for id in gpu_ids])
56 | cfg.gpu_ids = gpu_ids
57 | cfg.ngpus = len(gpu_ids)
58 | cfg.multi_gpu = cfg.ngpus > 1
59 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpus
60 |
61 | if cfg.multi_gpu:
62 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpus
63 | ngpus = torch.cuda.device_count()
64 | assert ngpus == cfg.ngpus
65 | # cfg.device = torch.device(f'cuda:{cfg.gpu_ids[0]}')
66 | cfg.device = torch.device('cuda') if cfg.gpu_ids else torch.device('cpu')
67 |
68 | add_new_file_output_to_logger(cfg.LOGS_PATH, prefix='train_')
69 |
70 | logger.info(f'Number of GPUs: {len(cfg.gpu_ids)}')
71 | cfg.gpu_ids = [int(id) for id in range(cfg.ngpus)]
72 | logger.info('Run experiment with config:')
73 | logger.info(pprint.pformat(cfg, indent=4))
74 |
75 | return cfg
76 |
77 |
78 | def get_model_family_tree(model_path, terminate_name='models'):
79 | model_name = model_path.stem
80 | family_tree = [model_name]
81 | for x in model_path.parents:
82 | if x.stem == terminate_name:
83 | break
84 | family_tree.append(x.stem)
85 | else:
86 | return None
87 |
88 | return family_tree[::-1]
89 |
90 |
91 | def find_last_exp_indx(exp_parent_path):
92 | indx = 0
93 | for x in exp_parent_path.iterdir():
94 | if not x.is_dir():
95 | continue
96 |
97 | exp_name = x.stem
98 | if exp_name[:3].isnumeric():
99 | indx = max(indx, int(exp_name[:3]) + 1)
100 |
101 | return indx
102 |
103 |
104 | def find_resume_exp(exp_parent_path, exp_pattern):
105 | candidates = sorted(exp_parent_path.glob(f'{exp_pattern}*'))
106 | if len(candidates) == 0:
107 | print(f'No experiments could be found that satisfies the pattern = "*{exp_pattern}"')
108 | sys.exit(1)
109 | elif len(candidates) > 1:
110 | print('More than one experiment found:')
111 | for x in candidates:
112 | print(x)
113 | sys.exit(1)
114 | else:
115 | exp_path = candidates[0]
116 | print(f'Continue with experiment "{exp_path}"')
117 |
118 | return exp_path
119 |
120 |
121 | def update_config(cfg, args):
122 | for param_name, value in vars(args).items():
123 | if param_name.lower() in cfg or param_name.upper() in cfg:
124 | continue
125 | cfg[param_name] = value
126 |
127 |
128 | def load_config(model_path):
129 | model_name = model_path.stem
130 | config_path = model_path.parent / (model_name + '.yml')
131 |
132 | if config_path.exists():
133 | cfg = load_config_file(config_path)
134 | else:
135 | cfg = dict()
136 |
137 | cwd = Path.cwd()
138 | config_parent = config_path.parent.absolute()
139 | while len(config_parent.parents) > 0:
140 | config_path = config_parent / 'config.yml'
141 |
142 | if config_path.exists():
143 | local_config = load_config_file(config_path, model_name=model_name)
144 | cfg.update({k: v for k, v in local_config.items() if k not in cfg})
145 |
146 | if config_parent.absolute() == cwd:
147 | break
148 | config_parent = config_parent.parent
149 |
150 | return edict(cfg)
151 |
152 |
153 | def load_config_file(config_path, model_name=None, return_edict=False):
154 | with open(config_path, 'r') as f:
155 | cfg = yaml.safe_load(f)
156 |
157 | if 'SUBCONFIGS' in cfg:
158 | if model_name is not None and model_name in cfg['SUBCONFIGS']:
159 | cfg.update(cfg['SUBCONFIGS'][model_name])
160 | del cfg['SUBCONFIGS']
161 |
162 | return edict(cfg) if return_edict else cfg
163 |
--------------------------------------------------------------------------------
/iharm/utils/log.py:
--------------------------------------------------------------------------------
1 | import io
2 | import time
3 | import logging
4 | from datetime import datetime
5 |
6 | import numpy as np
7 | from torch.utils.tensorboard import SummaryWriter
8 |
9 | LOGGER_NAME = 'root'
10 | LOGGER_DATEFMT = '%Y-%m-%d %H:%M:%S'
11 |
12 | handler = logging.StreamHandler()
13 |
14 | logger = logging.getLogger(LOGGER_NAME)
15 | logger.setLevel(logging.INFO)
16 | logger.addHandler(handler)
17 |
18 |
19 | def add_new_file_output_to_logger(logs_path, prefix, only_message=False):
20 | log_name = prefix + datetime.strftime(datetime.today(), '%Y-%m-%d_%H-%M-%S') + '.log'
21 | logs_path.mkdir(exist_ok=True, parents=True)
22 | stdout_log_path = logs_path / log_name
23 |
24 | fh = logging.FileHandler(str(stdout_log_path))
25 |
26 | fmt = '%(message)s' if only_message else '(%(levelname)s) %(asctime)s: %(message)s'
27 | formatter = logging.Formatter(fmt=fmt, datefmt=LOGGER_DATEFMT)
28 | fh.setFormatter(formatter)
29 | logger.addHandler(fh)
30 |
31 |
32 | class TqdmToLogger(io.StringIO):
33 | logger = None
34 | level = None
35 | buf = ''
36 |
37 | def __init__(self, logger, level=None, mininterval=5):
38 | super(TqdmToLogger, self).__init__()
39 | self.logger = logger
40 | self.level = level or logging.INFO
41 | self.mininterval = mininterval
42 | self.last_time = 0
43 |
44 | def write(self, buf):
45 | self.buf = buf.strip('\r\n\t ')
46 |
47 | def flush(self):
48 | if len(self.buf) > 0 and time.time() - self.last_time > self.mininterval:
49 | self.logger.log(self.level, self.buf)
50 | self.last_time = time.time()
51 |
52 |
53 | class SummaryWriterAvg(SummaryWriter):
54 | def __init__(self, *args, dump_period=20, **kwargs):
55 | super().__init__(*args, **kwargs)
56 | self._dump_period = dump_period
57 | self._avg_scalars = dict()
58 |
59 | def add_scalar(self, tag, value, global_step=None, disable_avg=False):
60 | if disable_avg or isinstance(value, (tuple, list, dict)):
61 | super().add_scalar(tag, np.array(value), global_step=global_step)
62 | else:
63 | if tag not in self._avg_scalars:
64 | self._avg_scalars[tag] = ScalarAccumulator(self._dump_period)
65 | avg_scalar = self._avg_scalars[tag]
66 | avg_scalar.add(value)
67 |
68 | if avg_scalar.is_full():
69 | super().add_scalar(tag, avg_scalar.value,
70 | global_step=global_step)
71 | avg_scalar.reset()
72 |
73 |
74 | class ScalarAccumulator(object):
75 | def __init__(self, period):
76 | self.sum = 0
77 | self.cnt = 0
78 | self.period = period
79 |
80 | def add(self, value):
81 | self.sum += value
82 | self.cnt += 1
83 |
84 | @property
85 | def value(self):
86 | if self.cnt > 0:
87 | return self.sum / self.cnt
88 | else:
89 | return 0
90 |
91 | def reset(self):
92 | self.cnt = 0
93 | self.sum = 0
94 |
95 | def is_full(self):
96 | return self.cnt >= self.period
97 |
98 | def __len__(self):
99 | return self.cnt
100 |
--------------------------------------------------------------------------------
/iharm/utils/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .log import logger
4 |
5 |
6 | def get_dims_with_exclusion(dim, exclude=None):
7 | dims = list(range(dim))
8 | if exclude is not None:
9 | dims.remove(exclude)
10 |
11 | return dims
12 |
13 |
14 | def save_checkpoint(net, checkpoints_path, epoch=None, prefix='', verbose=True, multi_gpu=False):
15 | if epoch is None:
16 | checkpoint_name = 'last_checkpoint.pth'
17 | else:
18 | checkpoint_name = f'{epoch:03d}.pth'
19 |
20 | if prefix:
21 | checkpoint_name = f'{prefix}_{checkpoint_name}'
22 |
23 | if not checkpoints_path.exists():
24 | checkpoints_path.mkdir(parents=True)
25 |
26 | checkpoint_path = checkpoints_path / checkpoint_name
27 | if verbose:
28 | logger.info(f'Save checkpoint to {str(checkpoint_path)}')
29 |
30 | state_dict = net.module.state_dict() if multi_gpu else net.state_dict()
31 | torch.save(state_dict, str(checkpoint_path))
32 |
33 |
34 | def load_weights(model, path_to_weights, verbose=False):
35 | if verbose:
36 | logger.info(f'Load checkpoint from path: {path_to_weights}')
37 |
38 | current_state_dict = model.state_dict()
39 | new_state_dict = torch.load(str(path_to_weights), map_location='cpu')
40 | current_state_dict.update(new_state_dict)
41 | model.load_state_dict(current_state_dict)
42 |
43 |
--------------------------------------------------------------------------------
/models/DucoNet_1024.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 | from torchvision import transforms
5 | from easydict import EasyDict as edict
6 | from albumentations import HorizontalFlip, Resize, RandomResizedCrop
7 |
8 | from iharm.data.compose import ComposeDataset
9 | from iharm.data.hdataset import HDataset
10 | from iharm.data.transforms import HCompose
11 | from iharm.engine.simple_trainer import SimpleHTrainer
12 | from iharm.model import initializer
13 | from iharm.model.base import DucoNet_model
14 | from iharm.model.losses import MaskWeightedMSE
15 | from iharm.model.metrics import DenormalizedMSEMetric, DenormalizedPSNRMetric
16 | from iharm.utils.log import logger
17 |
18 |
19 | def main(cfg):
20 | model, model_cfg = init_model(cfg)
21 | train(model, cfg, model_cfg, start_epoch=cfg.start_epoch)
22 |
23 | def _model_init(model,_skip_init_names,cnt,max_cnt=4):
24 | if cnt>max_cnt:
25 | return
26 | for name,module in model.named_children():
27 | if name in _skip_init_names[cnt]:
28 | # print("skip:",name)
29 | _model_init(module,_skip_init_names,cnt+1,max_cnt=max_cnt)
30 | else:
31 | # print(name)
32 | module.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=1.0))
33 |
34 |
35 | def init_model(cfg):
36 | model_cfg = edict()
37 | model_cfg.crop_size = (1024,1024)
38 | model_cfg.input_normalization = {
39 | 'mean': [.485, .456, .406],
40 | 'std': [.229, .224, .225]
41 | }
42 | model_cfg.depth = 4
43 |
44 | model_cfg.input_transform = transforms.Compose([
45 | transforms.ToTensor(),
46 | transforms.Normalize(model_cfg.input_normalization['mean'], model_cfg.input_normalization['std']),
47 | ])
48 |
49 | model = DucoNet_model(
50 | depth=4, ch=32, image_fusion=True, attention_mid_k=0.5,
51 | attend_from=2, batchnorm_from=2,w_dim=256,control_module_start=cfg.control_module_start
52 | )
53 |
54 | model.to(cfg.device)
55 |
56 | _skip_init_names={
57 | 0:['decoder'],
58 | 1:['up_blocks'],
59 | 2:['0','1','2'],
60 | 3:['control_module'],
61 | 4:['a_styleblock','b_styleblock','l_styleblock']
62 | }
63 |
64 | cnt=0
65 | _model_init(model,_skip_init_names,cnt,max_cnt=4)
66 |
67 | return model, model_cfg
68 |
69 |
70 | def train(model, cfg, model_cfg, start_epoch=0):
71 | cfg.batch_size = 16 if cfg.batch_size < 1 else cfg.batch_size
72 | cfg.val_batch_size = cfg.batch_size
73 |
74 | cfg.input_normalization = model_cfg.input_normalization
75 | crop_size = model_cfg.crop_size
76 |
77 | loss_cfg = edict()
78 | loss_cfg.pixel_loss = MaskWeightedMSE(min_area=100)
79 | loss_cfg.pixel_loss_weight = 1.0
80 |
81 | num_epochs = 120
82 |
83 | train_augmentator = HCompose([
84 | RandomResizedCrop(*crop_size, scale=(0.5, 1.0)),
85 | HorizontalFlip(),
86 | ])
87 |
88 | val_augmentator = HCompose([
89 | Resize(*crop_size)
90 | ])
91 |
92 |
93 | trainset = ComposeDataset(
94 | [
95 | # HDataset(cfg.HFLICKR_PATH, split='train'),
96 | # HDataset(cfg.HDAY2NIGHT_PATH, split='train'),
97 | # HDataset(cfg.HCOCO_PATH, split='train'),
98 | HDataset(cfg.HADOBE5K1_PATH, split='train'),
99 | ],
100 | augmentator=train_augmentator,
101 | input_transform=model_cfg.input_transform,
102 | keep_background_prob=0.05,
103 | )
104 |
105 | valset = ComposeDataset(
106 | [
107 | # HDataset(cfg.HFLICKR_PATH, split='test'),
108 | # HDataset(cfg.HDAY2NIGHT_PATH, split='test'),
109 | # HDataset(cfg.HCOCO_PATH, split='test'),
110 | HDataset(cfg.HADOBE5K1_PATH, split='test'),
111 | ],
112 | augmentator=val_augmentator,
113 | input_transform=model_cfg.input_transform,
114 | keep_background_prob=-1,
115 | )
116 |
117 | optimizer_params = {
118 | 'lr': cfg.lr,
119 | 'betas': (0.9, 0.999), 'eps': 1e-8
120 | }
121 | lr_g = (cfg.batch_size / 64) ** 0.5
122 | optimizer_params['lr'] = optimizer_params['lr'] * lr_g
123 |
124 | lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR,
125 | milestones=[105, 115], gamma=0.1)
126 | trainer = SimpleHTrainer(
127 | model, cfg, model_cfg, loss_cfg,
128 | trainset, valset,
129 | optimizer='adam',
130 | optimizer_params=optimizer_params,
131 | lr_scheduler=lr_scheduler,
132 | metrics=[
133 | DenormalizedPSNRMetric(
134 | 'images', 'target_images',
135 | mean=torch.tensor(cfg.input_normalization['mean'], dtype=torch.float32).view(1, 3, 1, 1),
136 | std=torch.tensor(cfg.input_normalization['std'], dtype=torch.float32).view(1, 3, 1, 1),
137 | ),
138 | DenormalizedMSEMetric(
139 | 'images', 'target_images',
140 | mean=torch.tensor(cfg.input_normalization['mean'], dtype=torch.float32).view(1, 3, 1, 1),
141 | std=torch.tensor(cfg.input_normalization['std'], dtype=torch.float32).view(1, 3, 1, 1),
142 | )
143 | ],
144 | checkpoint_interval=10,
145 | image_dump_interval=1000
146 | )
147 |
148 | logger.info(f'Starting Epoch: {start_epoch}')
149 | logger.info(f'Total Epochs: {num_epochs}')
150 | for epoch in range(start_epoch, num_epochs):
151 | trainer.training(epoch)
152 | trainer.validation(epoch)
153 |
--------------------------------------------------------------------------------
/models/DucoNet_256.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 | from torchvision import transforms
5 | from easydict import EasyDict as edict
6 | from albumentations import HorizontalFlip, Resize, RandomResizedCrop
7 |
8 | from iharm.data.compose import ComposeDataset
9 | from iharm.data.hdataset import HDataset
10 | from iharm.data.transforms import HCompose
11 | from iharm.engine.simple_trainer import SimpleHTrainer
12 | from iharm.model import initializer
13 | from iharm.model.base import DucoNet_model
14 | from iharm.model.losses import MaskWeightedMSE
15 | from iharm.model.metrics import DenormalizedMSEMetric, DenormalizedPSNRMetric
16 | from iharm.utils.log import logger
17 |
18 |
19 | def main(cfg):
20 | model, model_cfg = init_model(cfg)
21 | train(model, cfg, model_cfg, start_epoch=cfg.start_epoch)
22 |
23 | def _model_init(model, _skip_init_names, cnt, max_cnt=4):
24 |     if cnt > max_cnt:
25 |         return
26 |     for name, module in model.named_children():
27 |         if name in _skip_init_names[cnt]:
28 |             # listed submodule: recurse one level deeper instead of re-initializing it
29 |             _model_init(module, _skip_init_names, cnt + 1, max_cnt=max_cnt)
30 |         else:
31 |             # everything else gets Xavier initialization
32 |             module.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=1.0))
33 |
34 |
35 | def init_model(cfg):
36 | model_cfg = edict()
37 | model_cfg.crop_size = (256, 256)
38 | model_cfg.input_normalization = {
39 | 'mean': [.485, .456, .406],
40 | 'std': [.229, .224, .225]
41 | }
42 | model_cfg.depth = 4
43 |
44 | model_cfg.input_transform = transforms.Compose([
45 | transforms.ToTensor(),
46 | transforms.Normalize(model_cfg.input_normalization['mean'], model_cfg.input_normalization['std']),
47 | ])
48 |
49 | model = DucoNet_model(
50 | depth=4, ch=32, image_fusion=True, attention_mid_k=0.5,
51 |         attend_from=2, batchnorm_from=2, w_dim=256, control_module_start=cfg.control_module_start
52 | )
53 |
54 | model.to(cfg.device)
55 |
56 |     _skip_init_names = {
57 |         0: ['decoder'],
58 |         1: ['up_blocks'],
59 |         2: ['0', '1', '2'],
60 |         3: ['control_module'],
61 |         4: ['a_styleblock', 'b_styleblock', 'l_styleblock']
62 |     }
63 | 
64 |     cnt = 0
65 |     _model_init(model, _skip_init_names, cnt, max_cnt=4)
66 |
67 | return model, model_cfg
68 |
69 |
70 | def train(model, cfg, model_cfg, start_epoch=0):
71 | cfg.batch_size = 16 if cfg.batch_size < 1 else cfg.batch_size
72 | cfg.val_batch_size = cfg.batch_size
73 |
74 | cfg.input_normalization = model_cfg.input_normalization
75 | crop_size = model_cfg.crop_size
76 |
77 | loss_cfg = edict()
78 | loss_cfg.pixel_loss = MaskWeightedMSE(min_area=100)
79 | loss_cfg.pixel_loss_weight = 1.0
80 |
81 | num_epochs = 120
82 |
83 | train_augmentator = HCompose([
84 | RandomResizedCrop(*crop_size, scale=(0.5, 1.0)),
85 | HorizontalFlip(),
86 | ])
87 |
88 | val_augmentator = HCompose([
89 | Resize(*crop_size)
90 | ])
91 |
92 |
93 | trainset = ComposeDataset(
94 | [
95 | HDataset(cfg.HFLICKR_PATH, split='train'),
96 | HDataset(cfg.HDAY2NIGHT_PATH, split='train'),
97 | HDataset(cfg.HCOCO_PATH, split='train'),
98 | HDataset(cfg.HADOBE5K_PATH, split='train'),
99 | ],
100 | augmentator=train_augmentator,
101 | input_transform=model_cfg.input_transform,
102 | keep_background_prob=0.05,
103 | )
104 |
105 | valset = ComposeDataset(
106 | [
107 | HDataset(cfg.HFLICKR_PATH, split='test'),
108 | HDataset(cfg.HDAY2NIGHT_PATH, split='test'),
109 | HDataset(cfg.HCOCO_PATH, split='test'),
110 | ],
111 | augmentator=val_augmentator,
112 | input_transform=model_cfg.input_transform,
113 | keep_background_prob=-1,
114 | )
115 |
116 | optimizer_params = {
117 | 'lr': cfg.lr,
118 | 'betas': (0.9, 0.999), 'eps': 1e-8
119 | }
120 | lr_g = (cfg.batch_size / 64) ** 0.5
121 | optimizer_params['lr'] = optimizer_params['lr'] * lr_g
122 |
123 | lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR,
124 | milestones=[105, 115], gamma=0.1)
125 | trainer = SimpleHTrainer(
126 | model, cfg, model_cfg, loss_cfg,
127 | trainset, valset,
128 | optimizer='adam',
129 | optimizer_params=optimizer_params,
130 | lr_scheduler=lr_scheduler,
131 | metrics=[
132 | DenormalizedPSNRMetric(
133 | 'images', 'target_images',
134 | mean=torch.tensor(cfg.input_normalization['mean'], dtype=torch.float32).view(1, 3, 1, 1),
135 | std=torch.tensor(cfg.input_normalization['std'], dtype=torch.float32).view(1, 3, 1, 1),
136 | ),
137 | DenormalizedMSEMetric(
138 | 'images', 'target_images',
139 | mean=torch.tensor(cfg.input_normalization['mean'], dtype=torch.float32).view(1, 3, 1, 1),
140 | std=torch.tensor(cfg.input_normalization['std'], dtype=torch.float32).view(1, 3, 1, 1),
141 | )
142 | ],
143 | checkpoint_interval=10,
144 | image_dump_interval=1000
145 | )
146 |
147 | logger.info(f'Starting Epoch: {start_epoch}')
148 | logger.info(f'Total Epochs: {num_epochs}')
149 | for epoch in range(start_epoch, num_epochs):
150 | trainer.training(epoch)
151 | trainer.validation(epoch)
152 |
--------------------------------------------------------------------------------
/models/improved_ssam.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 | from torchvision import transforms
5 | from easydict import EasyDict as edict
6 | from albumentations import HorizontalFlip, Resize, RandomResizedCrop
7 |
8 | from iharm.data.compose import ComposeDataset
9 | from iharm.data.hdataset import HDataset
10 | from iharm.data.transforms import HCompose
11 | from iharm.engine.simple_trainer import SimpleHTrainer
12 | from iharm.model import initializer
13 | from iharm.model.base import SSAMImageHarmonization
14 | from iharm.model.losses import MaskWeightedMSE
15 | from iharm.model.metrics import DenormalizedMSEMetric, DenormalizedPSNRMetric
16 | from iharm.utils.log import logger
17 |
18 |
19 | def main(cfg):
20 | model, model_cfg = init_model(cfg)
21 | train(model, cfg, model_cfg, start_epoch=cfg.start_epoch)
22 |
23 |
24 | def init_model(cfg):
25 | model_cfg = edict()
26 | model_cfg.crop_size = (256, 256)
27 | model_cfg.input_normalization = {
28 | 'mean': [.485, .456, .406],
29 | 'std': [.229, .224, .225]
30 | }
31 | model_cfg.depth = 4
32 |
33 | model_cfg.input_transform = transforms.Compose([
34 | transforms.ToTensor(),
35 | transforms.Normalize(model_cfg.input_normalization['mean'], model_cfg.input_normalization['std']),
36 | ])
37 |
38 | model = SSAMImageHarmonization(
39 | depth=4, ch=32, image_fusion=True, attention_mid_k=0.5,
40 | attend_from=2, batchnorm_from=2
41 | )
42 |
43 | model.to(cfg.device)
44 | model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=1.0))
45 |
46 | return model, model_cfg
47 |
48 |
49 | def train(model, cfg, model_cfg, start_epoch=0):
50 | cfg.batch_size = 16 if cfg.batch_size < 1 else cfg.batch_size
51 | cfg.val_batch_size = cfg.batch_size
52 |
53 | cfg.input_normalization = model_cfg.input_normalization
54 | crop_size = model_cfg.crop_size
55 |
56 | loss_cfg = edict()
57 | loss_cfg.pixel_loss = MaskWeightedMSE()
58 | loss_cfg.pixel_loss_weight = 1.0
59 |
60 | num_epochs = 120
61 |
62 | train_augmentator = HCompose([
63 | RandomResizedCrop(*crop_size, scale=(0.5, 1.0)),
64 | HorizontalFlip(),
65 | ])
66 |
67 | val_augmentator = HCompose([
68 | Resize(*crop_size)
69 | ])
70 |
71 | trainset = ComposeDataset(
72 | [
73 | HDataset(cfg.HFLICKR_PATH, split='train'),
74 | HDataset(cfg.HDAY2NIGHT_PATH, split='train'),
75 | HDataset(cfg.HCOCO_PATH, split='train'),
76 | HDataset(cfg.HADOBE5K_PATH, split='train'),
77 | ],
78 | augmentator=train_augmentator,
79 | input_transform=model_cfg.input_transform,
80 | keep_background_prob=0.05,
81 | )
82 |
83 | valset = ComposeDataset(
84 | [
85 | HDataset(cfg.HFLICKR_PATH, split='test'),
86 | HDataset(cfg.HDAY2NIGHT_PATH, split='test'),
87 | HDataset(cfg.HCOCO_PATH, split='test'),
88 | ],
89 | augmentator=val_augmentator,
90 | input_transform=model_cfg.input_transform,
91 | keep_background_prob=-1,
92 | )
93 |
94 | optimizer_params = {
95 | 'lr': 1e-3,
96 | 'betas': (0.9, 0.999), 'eps': 1e-8
97 | }
98 |
99 | lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR,
100 | milestones=[105, 115], gamma=0.1)
101 | trainer = SimpleHTrainer(
102 | model, cfg, model_cfg, loss_cfg,
103 | trainset, valset,
104 | optimizer='adam',
105 | optimizer_params=optimizer_params,
106 | lr_scheduler=lr_scheduler,
107 | metrics=[
108 | DenormalizedPSNRMetric(
109 | 'images', 'target_images',
110 | mean=torch.tensor(cfg.input_normalization['mean'], dtype=torch.float32).view(1, 3, 1, 1),
111 | std=torch.tensor(cfg.input_normalization['std'], dtype=torch.float32).view(1, 3, 1, 1),
112 | ),
113 | DenormalizedMSEMetric(
114 | 'images', 'target_images',
115 | mean=torch.tensor(cfg.input_normalization['mean'], dtype=torch.float32).view(1, 3, 1, 1),
116 | std=torch.tensor(cfg.input_normalization['std'], dtype=torch.float32).view(1, 3, 1, 1),
117 | )
118 | ],
119 | checkpoint_interval=10,
120 | image_dump_interval=1000
121 | )
122 |
123 | logger.info(f'Starting Epoch: {start_epoch}')
124 | logger.info(f'Total Epochs: {num_epochs}')
125 | for epoch in range(start_epoch, num_epochs):
126 | trainer.training(epoch)
127 | trainer.validation(epoch)
128 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy
2 | numpy
3 | Cython
4 | scikit-image
5 | opencv-python-headless
6 | Pillow
7 | matplotlib
8 | imgaug
9 | albumentations
10 | graphviz
11 | tqdm
12 | pyyaml
13 | easydict
14 | tensorboard
15 | future
16 | cffi
17 | ninja
--------------------------------------------------------------------------------
/scripts/evaluate_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | # os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
3 |
4 | import argparse
5 | import sys
6 |
7 | sys.path.insert(0, '.')
8 |
9 | import torch
10 | from pathlib import Path
11 | from albumentations import Resize, NoOp
12 | from iharm.data.hdataset import HDataset
13 | from iharm.data.transforms import HCompose, LongestMaxSizeIfLarger
14 | from iharm.inference.predictor import Predictor
15 | from iharm.inference.evaluation import evaluate_dataset
16 | from iharm.inference.metrics import MetricsHub, MSE, fMSE, PSNR, N, AvgPredictTime
17 | from iharm.inference.utils import load_model, find_checkpoint
18 | from iharm.mconfigs import ALL_MCONFIGS
19 | from iharm.utils.exp import load_config_file
20 | from iharm.utils.log import logger, add_new_file_output_to_logger
21 |
22 |
23 | RESIZE_STRATEGIES = {
24 | 'None': NoOp(),
25 | 'LimitLongest1024': LongestMaxSizeIfLarger(1024),
26 | 'Fixed256': Resize(256, 256),
27 | 'Fixed512': Resize(512, 512),
28 | 'Fixed1024': Resize(1024, 1024)
29 | }
30 |
31 |
32 | def parse_args():
33 | parser = argparse.ArgumentParser()
34 |
35 | parser.add_argument('model_type', choices=ALL_MCONFIGS.keys())
36 | parser.add_argument('checkpoint', type=str,
37 | help='The path to the checkpoint. '
38 | 'This can be a relative path (relative to cfg.MODELS_PATH) '
39 | 'or an absolute path. The file extension can be omitted.')
40 | parser.add_argument('--datasets', type=str, default='HFlickr,HDay2Night,HCOCO,HAdobe5k1',
41 | help='Each dataset name must be one of the prefixes in config paths, '
42 | 'which look like DATASET_PATH.')
43 | parser.add_argument('--resize-strategy', type=str, choices=RESIZE_STRATEGIES.keys(), default='Fixed256')
44 | parser.add_argument('--use-flip', action='store_true', default=False,
45 | help='Use horizontal flip test-time augmentation.')
46 | parser.add_argument('--gpu', type=str, default=0, help='ID of used GPU.')
47 | parser.add_argument('--config-path', type=str, default='./config.yml',
48 | help='The path to the config file.')
49 |
50 | parser.add_argument('--eval-prefix', type=str, default='')
51 |
52 | args = parser.parse_args()
53 | cfg = load_config_file(args.config_path, return_edict=True)
54 | return args, cfg
55 |
56 |
57 | def main():
58 | args, cfg = parse_args()
59 | checkpoint_path = find_checkpoint(cfg.MODELS_PATH, args.checkpoint)
60 | add_new_file_output_to_logger(
61 | logs_path=Path(cfg.EXPS_PATH) / 'evaluation_logs',
62 | prefix=f'{Path(checkpoint_path).stem}_',
63 | only_message=True
64 | )
65 | logger.info(vars(args))
66 |
67 | device = torch.device(f'cuda:{args.gpu}')
68 | net = load_model(args.model_type, checkpoint_path, verbose=True)
69 | predictor = Predictor(net, device, with_flip=args.use_flip)
70 |
71 | datasets_names = args.datasets.split(',')
72 | datasets_metrics = []
73 | for dataset_indx, dataset_name in enumerate(datasets_names):
74 | dataset = HDataset(
75 | cfg.get(f'{dataset_name.upper()}_PATH'), split='test',
76 | augmentator=HCompose([RESIZE_STRATEGIES[args.resize_strategy]]),
77 | keep_background_prob=-1
78 | )
79 |
80 | dataset_metrics = MetricsHub([N(), MSE(), fMSE(), PSNR(), AvgPredictTime()],
81 | name=dataset_name)
82 |
83 | evaluate_dataset(dataset, predictor, dataset_metrics)
84 | datasets_metrics.append(dataset_metrics)
85 | if dataset_indx == 0:
86 | logger.info(dataset_metrics.get_table_header())
87 | logger.info(dataset_metrics)
88 |
89 | if len(datasets_metrics) > 1:
90 | overall_metrics = sum(datasets_metrics, MetricsHub([], 'Overall'))
91 | logger.info('-' * len(str(overall_metrics)))
92 | logger.info(overall_metrics)
93 |
94 |
95 | if __name__ == '__main__':
96 | main()
97 |
--------------------------------------------------------------------------------
/scripts/evaluate_model_fg_ratios.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 |
4 | sys.path.insert(0, '.')
5 |
6 | import torch
7 | from pathlib import Path
8 | from tqdm import trange
9 |
10 | from albumentations import Resize, NoOp
11 | from iharm.data.hdataset import HDataset
12 | from iharm.data.transforms import HCompose, LongestMaxSizeIfLarger
13 | from iharm.inference.predictor import Predictor
14 | from iharm.inference.metrics import MetricsHub, MSE, fMSE, PSNR, N
15 | from iharm.inference.utils import load_model, find_checkpoint
16 | from iharm.mconfigs import ALL_MCONFIGS
17 | from iharm.utils.exp import load_config_file
18 | from iharm.utils.log import logger, add_new_file_output_to_logger
19 |
20 |
21 | RESIZE_STRATEGIES = {
22 | 'None': NoOp(),
23 | 'LimitLongest1024': LongestMaxSizeIfLarger(1024),
24 | 'Fixed256': Resize(256, 256),
25 | 'Fixed512': Resize(512, 512)
26 | }
27 |
28 |
29 | def parse_args():
30 | parser = argparse.ArgumentParser()
31 |
32 | parser.add_argument('model_type', choices=ALL_MCONFIGS.keys())
33 | parser.add_argument('checkpoint', type=str,
34 | help='The path to the checkpoint. '
35 | 'This can be a relative path (relative to cfg.MODELS_PATH) '
36 | 'or an absolute path. The file extension can be omitted.')
37 | parser.add_argument('--datasets', type=str, default='HFlickr,HDay2Night,HCOCO,HAdobe5k',
38 | help='Each dataset name must be one of the prefixes in config paths, '
39 | 'which look like DATASET_PATH.')
40 | parser.add_argument('--resize-strategy', type=str, choices=RESIZE_STRATEGIES.keys(), default='Fixed256')
41 | parser.add_argument('--use-flip', action='store_true', default=False,
42 | help='Use horizontal flip test-time augmentation.')
43 | parser.add_argument('--gpu', type=str, default=0, help='ID of used GPU.')
44 | parser.add_argument('--config-path', type=str, default='./config.yml',
45 | help='The path to the config file.')
46 | parser.add_argument('--results-path', type=str, default='',
47 | help='The path to the evaluation results. '
48 | 'Default path: cfg.EXPS_PATH/evaluation_results.')
49 |
50 | parser.add_argument('--eval-prefix', type=str, default='')
51 |
52 | args = parser.parse_args()
53 | cfg = load_config_file(args.config_path, return_edict=True)
54 | return args, cfg
55 |
56 |
57 | def main():
58 | args, cfg = parse_args()
59 | checkpoint_path = find_checkpoint(cfg.MODELS_PATH, args.checkpoint)
60 | add_new_file_output_to_logger(
61 | logs_path=Path(cfg.EXPS_PATH) / 'evaluation_results',
62 | prefix=f'{Path(checkpoint_path).stem}_',
63 | only_message=True
64 | )
65 | logger.info(vars(args))
66 |
67 | device = torch.device(f'cuda:{args.gpu}')
68 | net = load_model(args.model_type, checkpoint_path, verbose=True)
69 | predictor = Predictor(net, device, with_flip=args.use_flip)
70 |
71 | fg_ratio_intervals = [(0.0, 0.05), (0.05, 0.15), (0.15, 1.0), (0.0, 1.00)]
72 |
73 | datasets_names = args.datasets.split(',')
74 | datasets_metrics = [[] for _ in fg_ratio_intervals]
75 | for dataset_indx, dataset_name in enumerate(datasets_names):
76 | dataset = HDataset(
77 | cfg.get(f'{dataset_name.upper()}_PATH'), split='test',
78 | augmentator=HCompose([RESIZE_STRATEGIES[args.resize_strategy]]),
79 | keep_background_prob=-1
80 | )
81 |
82 | dataset_metrics = []
83 | for fg_ratio_min, fg_ratio_max in fg_ratio_intervals:
84 | dataset_metrics.append(MetricsHub([N(), MSE(), fMSE(), PSNR()],
85 | name=f'{dataset_name} ({fg_ratio_min:.0%}-{fg_ratio_max:.0%})',
86 | name_width=28))
87 |
88 | for sample_i in trange(len(dataset), desc=f'Testing on {dataset_name}'):
89 | sample = dataset.get_sample(sample_i)
90 | sample = dataset.augment_sample(sample)
91 |
92 | sample_mask = sample['object_mask']
93 | sample_fg_ratio = (sample_mask > 0.5).sum() / (sample_mask.shape[0] * sample_mask.shape[1])
94 | pred = predictor.predict(sample['image'], sample_mask, return_numpy=False)
95 |
96 | target_image = torch.as_tensor(sample['target_image'], dtype=torch.float32).to(predictor.device)
97 | sample_mask = torch.as_tensor(sample_mask, dtype=torch.float32).to(predictor.device)
98 | with torch.no_grad():
99 | for metrics_hub, (fg_ratio_min, fg_ratio_max) in zip(dataset_metrics, fg_ratio_intervals):
100 | if fg_ratio_min <= sample_fg_ratio <= fg_ratio_max:
101 | metrics_hub.compute_and_add(pred, target_image, sample_mask)
102 |
103 | for indx, metrics_hub in enumerate(dataset_metrics):
104 | datasets_metrics[indx].append(metrics_hub)
105 | if dataset_indx == 0:
106 | logger.info(dataset_metrics[-1].get_table_header())
107 | for metrics_hub in dataset_metrics:
108 | logger.info(metrics_hub)
109 |
110 |     if len(datasets_names) > 1:
111 | overall_metrics = [sum(x, MetricsHub([], f'Overall ({fg_ratio_min:.0%}-{fg_ratio_max:.0%})', name_width=28))
112 | for x, (fg_ratio_min, fg_ratio_max) in zip(datasets_metrics, fg_ratio_intervals)]
113 | logger.info('-' * len(str(overall_metrics[-1])))
114 | for x in overall_metrics:
115 | logger.info(x)
116 |
117 |
118 | if __name__ == '__main__':
119 | main()
120 |
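121 | # Example invocation (hypothetical checkpoint path), run from the repository root:
122 | #   python scripts/evaluate_model_fg_ratios.py DucoNet ./checkpoints/last_model/DucoNet256.pth \
123 | #       --resize-strategy Fixed256 --gpu 0 --datasets HFlickr,HCOCO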
--------------------------------------------------------------------------------
/scripts/predict_for_dir.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import os.path as osp
4 | from pathlib import Path
5 | import sys
6 |
7 | import cv2
8 | import numpy as np
9 | import torch
10 | from tqdm import tqdm
11 |
12 | sys.path.insert(0, '.')
13 | from iharm.inference.predictor import Predictor
14 | from iharm.inference.utils import load_model, find_checkpoint
15 | from iharm.mconfigs import ALL_MCONFIGS
16 | from iharm.utils.log import logger
17 | from iharm.utils.exp import load_config_file
18 |
19 |
20 | def main():
21 | args, cfg = parse_args()
22 |
23 | device = torch.device(f'cuda:{args.gpu}')
24 | checkpoint_path = find_checkpoint(cfg.MODELS_PATH, args.checkpoint)
25 | net = load_model(args.model_type, checkpoint_path, verbose=True)
26 | predictor = Predictor(net, device)
27 |
28 | image_names = os.listdir(args.images)
29 |
30 |     def _save_image(image_name, rgb_image):
31 |         bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)  # cv2.imwrite expects BGR
32 |         cv2.imwrite(
33 |             str(cfg.RESULTS_PATH / f'{image_name}'),
34 |             bgr_image,
35 |             [cv2.IMWRITE_JPEG_QUALITY, 85]
36 |         )
37 |
38 | logger.info(f'Save images to {cfg.RESULTS_PATH}')
39 |
40 | resize_shape = (args.resize, ) * 2
41 | for image_name in tqdm(image_names):
42 | image_path = osp.join(args.images, image_name)
43 | image = cv2.imread(image_path)
44 |         composite_image_lab = cv2.cvtColor(image, cv2.COLOR_BGR2Lab)  # cv2.imread returns BGR
45 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
46 | image_size = image.shape
47 | if resize_shape[0] > 0:
48 |             image = cv2.resize(image, resize_shape, interpolation=cv2.INTER_LINEAR)
49 |
50 | mask_path = osp.join(args.masks, '_'.join(image_name.split('_')[:-1]) + '.png')
51 | mask_image = cv2.imread(mask_path)
52 | if resize_shape[0] > 0:
53 |             mask_image = cv2.resize(mask_image, resize_shape, interpolation=cv2.INTER_LINEAR)
54 | mask = mask_image[:, :, 0]
55 | mask[mask <= 100] = 0
56 | mask[mask > 100] = 1
57 | mask = mask.astype(np.float32)
58 |
59 | pred = predictor.predict(image, mask, composite_image_lab)
60 |
61 | if args.original_size:
62 | pred = cv2.resize(pred, image_size[:-1][::-1])
63 | _save_image(image_name, pred)
64 |
65 |
66 | def parse_args():
67 | parser = argparse.ArgumentParser()
68 | parser.add_argument('model_type', choices=ALL_MCONFIGS.keys())
69 | parser.add_argument('checkpoint', type=str,
70 | help='The path to the checkpoint. '
71 | 'This can be a relative path (relative to cfg.MODELS_PATH) '
72 | 'or an absolute path. The file extension can be omitted.')
73 | parser.add_argument(
74 | '--images', type=str,
75 | help='Path to directory with .jpg images to get predictions for.'
76 | )
77 | parser.add_argument(
78 | '--masks', type=str,
79 | help='Path to directory with .png binary masks for images, named exactly like images without last _postfix.'
80 | )
81 | parser.add_argument(
82 | '--resize', type=int, default=256,
83 | help='Resize image to a given size before feeding it into the network. If -1 the network input is not resized.'
84 | )
85 | parser.add_argument(
86 | '--original-size', action='store_true', default=False,
87 | help='Resize predicted image back to the original size.'
88 | )
89 | parser.add_argument('--gpu', type=str, default=0, help='ID of used GPU.')
90 | parser.add_argument('--config-path', type=str, default='./config.yml', help='The path to the config file.')
91 | parser.add_argument(
92 | '--results-path', type=str, default='',
93 | help='The path to the harmonized images. Default path: cfg.EXPS_PATH/predictions.'
94 | )
95 |
96 | args = parser.parse_args()
97 | cfg = load_config_file(args.config_path, return_edict=True)
98 | cfg.EXPS_PATH = Path(cfg.EXPS_PATH)
99 | cfg.RESULTS_PATH = Path(args.results_path) if len(args.results_path) else cfg.EXPS_PATH / 'predictions'
100 | cfg.RESULTS_PATH.mkdir(parents=True, exist_ok=True)
101 | logger.info(cfg)
102 | return args, cfg
103 |
104 |
105 | if __name__ == '__main__':
106 | main()
107 |
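108 | # Example invocation (hypothetical image/mask directories), run from the repository root:
109 | #   python scripts/predict_for_dir.py DucoNet ./checkpoints/last_model/DucoNet256.pth \
110 | #       --images ./test_data/composites --masks ./test_data/masks \
111 | #       --resize 256 --original-size --gpu 0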
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | python scripts/evaluate_model.py DucoNet ./checkpoints/last_model/DucoNet256.pth \
2 | --resize-strategy Fixed256 \
3 | --gpu 0
4 | #python scripts/evaluate_model.py DucoNet ./checkpoints/last_model/DucoNet1024.pth \
5 | #--resize-strategy Fixed1024 \
6 | #--gpu 1 \
7 | #--datasets HAdobe5k1
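8 | 
9 | ## Example (hypothetical): single-dataset evaluation with horizontal-flip test-time augmentation
10 | #python scripts/evaluate_model.py DucoNet ./checkpoints/last_model/DucoNet256.pth \
11 | #--resize-strategy Fixed256 \
12 | #--gpu 0 \
13 | #--datasets HCOCO \
14 | #--use-flip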
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import importlib.util
3 |
4 | import torch
5 | from iharm.utils.exp import init_experiment
6 |
7 |
8 | def main():
9 | args = parse_args()
10 | model_script = load_module(args.model_path)
11 |
12 | cfg = init_experiment(args)
13 |
14 | torch.backends.cudnn.benchmark = True
15 | torch.multiprocessing.set_sharing_strategy('file_system')
16 | model_script.main(cfg)
17 |
18 |
19 | def parse_args():
20 | parser = argparse.ArgumentParser()
21 |
22 | parser.add_argument('model_path', type=str,
23 | help='Path to the model script.')
24 |
25 | parser.add_argument('--exp-name', type=str, default='',
26 | help='Here you can specify the name of the experiment. '
27 | 'It will be added as a suffix to the experiment folder.')
28 |
29 | parser.add_argument('--workers', type=int, default=4,
30 | metavar='N', help='Dataloader threads.')
31 |
32 | parser.add_argument('--batch-size', type=int, default=-1,
33 |                         help='You can override the model batch size by specifying a positive number.')
34 |
35 | parser.add_argument('--ngpus', type=int, default=1,
36 | help='Number of GPUs. '
37 |                              'If you only specify the "--gpus" argument, the ngpus value will be calculated automatically. '
38 | 'You should use either this argument or "--gpus".')
39 |
40 | parser.add_argument('--gpus', type=str, default='', required=False,
41 | help='Ids of used GPUs. You should use either this argument or "--ngpus".')
42 |
43 | parser.add_argument('--resume-exp', type=str, default=None,
44 | help='The prefix of the name of the experiment to be continued. '
45 | 'If you use this field, you must specify the "--resume-prefix" argument.')
46 |
47 | parser.add_argument('--resume-prefix', type=str, default='latest',
48 | help='The prefix of the name of the checkpoint to be loaded.')
49 |
50 | parser.add_argument('--start-epoch', type=int, default=0,
51 | help='The number of the starting epoch from which training will continue. '
52 | '(it is important for correct logging and learning rate)')
53 |
54 | parser.add_argument('--weights', type=str, default=None,
55 | help='Model weights will be loaded from the specified path if you use this argument.')
56 |
57 | parser.add_argument('--lr', type=float, default=1e-3,
58 |                         help='Base learning rate. The DucoNet model scripts rescale it by sqrt(batch_size / 64).')
59 |
60 | parser.add_argument('--control_module_start', type=int, default=-1,
61 |                         help='Passed through cfg to DucoNet_model as the control_module_start argument.')
62 |
63 |
64 |
65 | return parser.parse_args()
66 |
67 |
68 | def load_module(script_path):
69 | spec = importlib.util.spec_from_file_location("model_script", script_path)
70 | model_script = importlib.util.module_from_spec(spec)
71 | spec.loader.exec_module(model_script)
72 |
73 | return model_script
74 |
75 |
76 | if __name__ == '__main__':
77 | main()
78 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | ## for low-resolution (256 * 256)
2 |
3 | python train.py models/DucoNet_256.py --workers=8 --gpus=0,1 --exp-name=DucoNet_256 --batch-size=32
4 |
5 | ## for high-resolution (1024 * 1024)
6 |
7 | #python train.py models/DucoNet_1024.py --workers=8 --gpus=0,1 --exp-name=DucoNet_1024 --batch-size=4
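8 | 
9 | ## Example (hypothetical experiment name): resume training from the latest checkpoint of a previous run
10 | #python train.py models/DucoNet_256.py --workers=8 --gpus=0,1 --exp-name=DucoNet_256 --batch-size=32 --resume-exp=DucoNet_256 --resume-prefix=latest --start-epoch=100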
--------------------------------------------------------------------------------