├── README.md
├── config.yml
├── examples
│   └── DucoNet.png
├── iharm
│   ├── data
│   │   ├── base.py
│   │   ├── compose.py
│   │   ├── hdataset.py
│   │   └── transforms.py
│   ├── engine
│   │   ├── event_loop.py
│   │   ├── optimizer.py
│   │   └── simple_trainer.py
│   ├── inference
│   │   ├── evaluation.py
│   │   ├── metrics.py
│   │   ├── predictor.py
│   │   ├── transforms.py
│   │   └── utils.py
│   ├── mconfigs
│   │   ├── __init__.py
│   │   ├── backboned.py
│   │   └── base.py
│   ├── model
│   │   ├── backboned
│   │   │   ├── __init__.py
│   │   │   ├── deeplab.py
│   │   │   ├── hrnet.py
│   │   │   └── ih_model.py
│   │   ├── base
│   │   │   ├── Control_encoder.py
│   │   │   ├── DucoNet_model.py
│   │   │   ├── __init__.py
│   │   │   ├── dih_model.py
│   │   │   ├── iseunet_v1.py
│   │   │   └── ssam_model.py
│   │   ├── initializer.py
│   │   ├── losses.py
│   │   ├── metrics.py
│   │   ├── modeling
│   │   │   ├── basic_blocks.py
│   │   │   ├── conv_autoencoder.py
│   │   │   ├── deeplab_v3.py
│   │   │   ├── hrnet_ocr.py
│   │   │   ├── ocr.py
│   │   │   ├── resnet.py
│   │   │   ├── resnetv1b.py
│   │   │   ├── styleganv2.py
│   │   │   ├── unet.py
│   │   │   └── unet_v1.py
│   │   ├── modifiers.py
│   │   ├── ops.py
│   │   └── syncbn
│   │       ├── LICENSE
│   │       ├── README.md
│   │       ├── __init__.py
│   │       └── modules
│   │           ├── __init__.py
│   │           ├── functional
│   │           │   ├── __init__.py
│   │           │   ├── _csrc.py
│   │           │   ├── csrc
│   │           │   │   ├── bn.h
│   │           │   │   ├── cuda
│   │           │   │   │   ├── bn_cuda.cu
│   │           │   │   │   ├── common.h
│   │           │   │   │   └── ext_lib.h
│   │           │   │   └── ext_lib.cpp
│   │           │   └── syncbn.py
│   │           └── nn
│   │               ├── __init__.py
│   │               └── syncbn.py
│   └── utils
│       ├── exp.py
│       ├── log.py
│       └── misc.py
├── models
│   ├── DucoNet_1024.py
│   ├── DucoNet_256.py
│   └── improved_ssam.py
├── requirements.txt
├── scripts
│   ├── evaluate_model.py
│   ├── evaluate_model_fg_ratios.py
│   └── predict_for_dir.py
├── test.sh
├── train.py
└── train.sh


/README.md:
--------------------------------------------------------------------------------
# Deep Image Harmonization in Dual Color Spaces [ACM MM-23]

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/deep-image-harmonization-in-dual-color-spaces/image-harmonization-on-iharmony4)](https://paperswithcode.com/sota/image-harmonization-on-iharmony4?p=deep-image-harmonization-in-dual-color-spaces)

This is the official repository for the following paper:

> **Deep Image Harmonization in Dual Color Spaces** [[arXiv]](https://arxiv.org/abs/2308.02813)
>
> Linfeng Tan, Jiangtong Li, Li Niu, Liqing Zhang
> Accepted by **ACM MM 2023**.

> Image harmonization is an essential step in image composition that adjusts the appearance of the composite foreground to address the inconsistency between foreground and background. Existing methods primarily operate in the correlated $RGB$ color space, leading to entangled features and limited representation ability. In contrast, a decorrelated color space (e.g., $Lab$) has decorrelated channels that provide disentangled color and illumination statistics. In this paper, we explore image harmonization in dual color spaces, which supplements entangled $RGB$ features with disentangled $L$, $a$, $b$ features to alleviate the workload in the harmonization process. The network comprises an $RGB$ harmonization backbone, an $Lab$ encoding module, and an $Lab$ control module. The backbone is a U-Net network that translates the composite image into a harmonized image. Three encoders in the $Lab$ encoding module independently extract three control codes from the $L$, $a$, and $b$ channels; these codes manipulate the decoder features of the harmonization backbone via the $Lab$ control module.

![Overview of DucoNet](./examples/DucoNet.png)

## Getting Started

### Prerequisites
Please refer to [iSSAM](https://github.com/saic-vul/image_harmonization) for guidance on setting up the environment.

### Installation
+ Clone this repo:
```
git clone https://github.com/bcmi/DucoNet-Image-Harmonization.git
cd ./DucoNet-Image-Harmonization
```
+ Download the [iHarmony4](https://github.com/bcmi/Image-Harmonization-Dataset-iHarmony4) dataset and configure the dataset paths in [config.yml](./config.yml).

+ Install PyTorch and its dependencies by following the instructions at https://pytorch.org.

+ Install the Python requirements:

```bash
pip install -r requirements.txt
```

### Training
To train DucoNet on the [iHarmony4](https://github.com/bcmi/Image-Harmonization-Dataset-iHarmony4) dataset, run:

```
## for low-resolution

python train.py models/DucoNet_256.py --workers=8 --gpus=0,1 --exp-name=DucoNet_256 --batch-size=64

## for high-resolution

python train.py models/DucoNet_1024.py --workers=8 --gpus=2,3 --exp-name=DucoNet_1024 --batch-size=4
```

More example commands are provided in "train.sh".

### Testing
Download the released pretrained models from [Dropbox](https://www.dropbox.com/scl/fo/jnq5sgokct3n0l2ix8knl/AGwf0vaEHULz2qjg1iK81jA?rlkey=v3djzqf3v1upiddb1wc7t0481&st=vz6vi7uc&dl=0) or [Baidu Cloud](https://pan.baidu.com/s/1lnDOnmN1tLeoIcvjWvWFkQ?pwd=bcmi), then run the following command to evaluate the 256×256 model:
```
python scripts/evaluate_model.py DucoNet ./checkpoints/last_model/DucoNet256.pth \
--resize-strategy Fixed256 \
--gpu 0

#python scripts/evaluate_model.py DucoNet ./checkpoints/last_model/DucoNet1024.pth \
#--resize-strategy Fixed1024 \
#--gpu 1 \
#--datasets HAdobe5k1
```

To evaluate the 1024×1024 model on HAdobe5k instead, uncomment the second command. More example commands are provided in "test.sh".

## Results and Pretrained Models

We evaluate DucoNet on the iHarmony4 dataset at 256×256 and on the HAdobe5k dataset at 1024×1024, reporting MSE, fMSE, and PSNR. The pretrained models that produce these results can be downloaded from the links in the table below.
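The three metrics can be sketched in NumPy as follows. This is a minimal re-statement that assumes it mirrors `iharm/inference/metrics.py` (which operates on PyTorch tensors on the 0 to 255 scale); the function names `mse`, `fmse`, and `psnr` are illustrative and not part of the repository's API. Images are assumed to be H×W×3 float arrays and the mask an H×W array with values in [0, 1]. Note that, as in the repository code, PSNR uses the maximum value of the target image as the peak signal rather than a fixed 255.

```python
import math

import numpy as np


def mse(pred, target):
    """Mean squared error over all pixels and channels (H x W x C arrays)."""
    return float(((pred - target) ** 2).mean())


def fmse(pred, target, mask, eps=1e-6):
    """Foreground MSE: squared error inside the mask only, divided by
    (number of channels x mask area) so background pixels do not dilute it."""
    diff = mask[..., np.newaxis] * (pred - target) ** 2
    return float(diff.sum() / (diff.shape[2] * mask.sum() + eps))


def psnr(pred, target, eps=1e-6):
    """PSNR in dB, using the maximum value of the target image as the peak."""
    squared_max = float(target.max()) ** 2
    return 10 * math.log10(squared_max / (mse(pred, target) + eps))


# Toy example: a 4x4 RGB image whose 2x2 top-left foreground is off by 10.
target = np.full((4, 4, 3), 100.0)
pred = target.copy()
pred[:2, :2] += 10.0
mask = np.zeros((4, 4))
mask[:2, :2] = 1.0

print(f"MSE:  {mse(pred, target):.2f}")          # 1200 / 48 = 25.00
print(f"fMSE: {fmse(pred, target, mask):.2f}")   # 1200 / (3 * 4) = 100.00
print(f"PSNR: {psnr(pred, target):.2f} dB")      # 10 * log10(100^2 / 25) = 26.02
```

Because fMSE averages only over the foreground region, it is more sensitive than MSE to errors on small composited objects. In the repository, each metric is computed per image and then averaged over the test set.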

| Dataset | Image Size | fMSE | MSE | PSNR (dB) | Dropbox | Baidu Cloud |
| --------- | ------------------ | ------ | ----- | ----- | ---------------- | ---------------- |
| iHarmony4 | 256 $\times$ 256 | 212.53 | 18.47 | 39.17 | [Dropbox](https://www.dropbox.com/scl/fi/8rf3or6gt35wizyxpgf0y/DucoNet256.pth?rlkey=jjun6hywcy2wte8k5idgh28jv&st=l3jwirjp&dl=0) | [Baidu Cloud](https://pan.baidu.com/s/1lnDOnmN1tLeoIcvjWvWFkQ?pwd=bcmi) |
| HAdobe5k | 1024 $\times$ 1024 | 80.69 | 10.94 | 41.37 | [Dropbox](https://www.dropbox.com/scl/fi/tpowgk1ezf090ezb6o4ib/DucoNet1024.pth?rlkey=gyt042f5h5igc6ypk24stnwtd&st=gk10qjv2&dl=0) | [Baidu Cloud](https://pan.baidu.com/s/1lnDOnmN1tLeoIcvjWvWFkQ?pwd=bcmi) |

## Other Resources

+ [Awesome-Image-Harmonization](https://github.com/bcmi/Awesome-Image-Harmonization)
+ [Awesome-Image-Composition](https://github.com/bcmi/Awesome-Object-Insertion)

## Acknowledgement

Our code borrows heavily from [iSSAM](https://github.com/saic-vul/image_harmonization) and the PyTorch implementation of [styleGANv2](https://github.com/labmlai/annotated_deep_learning_paper_implementations).

## BibTeX
If you are interested in our work, please consider citing:

```
@article{tan2023deep,
  title={Deep Image Harmonization in Dual Color Spaces},
  author={Tan, Linfeng and Li, Jiangtong and Niu, Li and Zhang, Liqing},
  journal={arXiv preprint arXiv:2308.02813},
  year={2023}
}
```
--------------------------------------------------------------------------------
/config.yml:
--------------------------------------------------------------------------------
MODELS_PATH: "./"
EXPS_PATH: "./checkpoints/output"

HFLICKR_PATH: "/home/dataset/IHD/HFlickr"
HDAY2NIGHT_PATH: "/home/dataset/IHD/Hday2night"
HCOCO_PATH: "/home/dataset/IHD/HCOCO"
HADOBE5K1_PATH: "/data/IHD/HAdobe5k"
HADOBE5K_PATH: "/home/dataset/IHD/HAdobe5k_resized1024"


IMAGENET_PRETRAINED_MODELS:
  HRNETV2_W18_SMALL: "./pretrained_models/hrnet_w18_small_model_v2.pth"
  HRNETV2_W18: "./pretrained_models/hrnetv2_w18_imagenet_pretrained.pth"
  HRNETV2_W32: "./pretrained_models/hrnetv2_w32_imagenet_pretrained.pth"
  HRNETV2_W40: "./pretrained_models/hrnetv2_w40_imagenet_pretrained.pth"
  HRNETV2_W48: "./pretrained_models/hrnetv2_w48_imagenet_pretrained.pth"
  HistNet: "./checkpoints/last_model/hen_train/last_checkpoint.pth"
  # HistNet_old: "./checkpoints/last_model/hen_train/last_checkpoint.pth"
--------------------------------------------------------------------------------
/examples/DucoNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/DucoNet-Image-Harmonization/167cd720d91f0f5d9e64b54c3c8d4300d8915be2/examples/DucoNet.png
--------------------------------------------------------------------------------
/iharm/data/base.py:
--------------------------------------------------------------------------------
import random
import numpy as np
import torch
import cv2

from torchvision import transforms

class BaseHDataset(torch.utils.data.dataset.Dataset):
    def __init__(self,
                 augmentator=None,
                 input_transform=None,
                 keep_background_prob=0.0,
                 with_image_info=False,
                 epoch_len=-1,
                 ):
        super(BaseHDataset, self).__init__()
        self.epoch_len = epoch_len
        self.input_transform = input_transform
        self.augmentator = augmentator
        self.keep_background_prob = keep_background_prob
        self.with_image_info = with_image_info

        if input_transform is None:
            input_transform = lambda x: x

        self.input_transform = input_transform
        self.dataset_samples = None

    def __getitem__(self, index):
        if self.epoch_len > 0:
            index = random.randrange(0, len(self.dataset_samples))

        sample = self.get_sample(index)
        self.check_sample_types(sample)
        sample = self.augment_sample(sample)

        image = self.input_transform(sample['image'])
        target_image = self.input_transform(sample['target_image'])
        obj_mask = sample['object_mask'].astype(np.float32)

        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        comp_img_lab = sample['comp_img_lab']
        real_img_lab = sample['real_img_lab']

        comp_image_lab = transform(comp_img_lab)
        real_image_lab = transform(real_img_lab)

        output = {
            'images': image,
            'masks': obj_mask[np.newaxis, ...].astype(np.float32),
            'target_images': target_image,
            'comp_img_lab': comp_image_lab,
            'real_img_lab': real_image_lab,
        }

        if self.with_image_info and 'image_id' in sample:
            output['image_info'] = sample['image_id']

        return output

    def check_sample_types(self, sample):
        assert sample['image'].dtype == 'uint8'
        if 'target_image' in sample:
            assert sample['target_image'].dtype == 'uint8'

    def augment_sample(self, sample):
        if self.augmentator is None:
            return sample

        additional_targets = {target_name: sample[target_name]
                              for target_name in self.augmentator.additional_targets.keys()}

        valid_augmentation = False
        while not valid_augmentation:
            aug_output = self.augmentator(image=sample['image'], **additional_targets)
            valid_augmentation = self.check_augmented_sample(sample, aug_output)

        for target_name, transformed_target in aug_output.items():
            sample[target_name] = transformed_target

        return sample

    def check_augmented_sample(self, sample, aug_output):
        if self.keep_background_prob < 0.0 or random.random() < self.keep_background_prob:
            return True

        return aug_output['object_mask'].sum() > 10.0

    def get_sample(self, index):
        raise NotImplementedError

    def __len__(self):
        if self.epoch_len > 0:
            return self.epoch_len
        else:
            return len(self.dataset_samples)

--------------------------------------------------------------------------------
/iharm/data/compose.py:
--------------------------------------------------------------------------------
from .base import BaseHDataset


class ComposeDataset(BaseHDataset):
    def __init__(self, datasets, **kwargs):
        super(ComposeDataset, self).__init__(**kwargs)

        self._datasets = datasets
        self.dataset_samples = []
        for dataset_indx, dataset in enumerate(self._datasets):
            self.dataset_samples.extend([(dataset_indx, i) for i in range(len(dataset))])

    def get_sample(self, index):
        dataset_indx, sample_indx = self.dataset_samples[index]
        return self._datasets[dataset_indx].get_sample(sample_indx)
--------------------------------------------------------------------------------
/iharm/data/hdataset.py:
--------------------------------------------------------------------------------
from pathlib import Path
import os
import cv2
import numpy as np

from .base import BaseHDataset


class HDataset(BaseHDataset):
    def __init__(self, dataset_path, split, blur_target=False, **kwargs):
        super(HDataset, self).__init__(**kwargs)

        self.dataset_path = Path(dataset_path)
        self.blur_target = blur_target
        self._split = split
        self._real_images_path = self.dataset_path / 'real_images'
        self._composite_images_path = self.dataset_path / 'composite_images'
        self._masks_path = self.dataset_path / 'masks'

        images_lists_paths = [x for x in self.dataset_path.glob('*.txt') if x.stem.endswith(split)]

        assert len(images_lists_paths) == 1

        with open(images_lists_paths[0], 'r') as f:
            self.dataset_samples = [x.strip() for x in f.readlines()]

    def get_sample(self, index):
        composite_image_name = self.dataset_samples[index]
        real_image_name = composite_image_name.split('_')[0] + '.jpg'
        mask_name = '_'.join(composite_image_name.split('_')[:-1]) + '.png'

        composite_image_path = str(self._composite_images_path / composite_image_name)
        real_image_path = str(self._real_images_path / real_image_name)
        mask_path = str(self._masks_path / mask_name)

        composite_image = cv2.imread(composite_image_path)
        composite_image = cv2.cvtColor(composite_image, cv2.COLOR_BGR2RGB)

        composite_image_lab = cv2.cvtColor(composite_image, cv2.COLOR_RGB2Lab)

        real_image = cv2.imread(real_image_path)
        real_image = cv2.cvtColor(real_image, cv2.COLOR_BGR2RGB)

        real_image_lab = cv2.cvtColor(real_image, cv2.COLOR_RGB2Lab)

        object_mask_image = cv2.imread(mask_path)
        object_mask = object_mask_image[:, :, 0].astype(np.float32) / 255.
        if self.blur_target:
            object_mask = cv2.GaussianBlur(object_mask, (7, 7), 0)

        return {
            'image': composite_image,
            'object_mask': object_mask,
            'target_image': real_image,
            'image_id': index,
            'comp_img_lab': composite_image_lab,
            'real_img_lab': real_image_lab,
            'comp_img_name': composite_image_name,
        }
--------------------------------------------------------------------------------
/iharm/data/transforms.py:
--------------------------------------------------------------------------------
from albumentations import Compose, LongestMaxSize, DualTransform
import albumentations.augmentations.functional as F
import cv2


class HCompose(Compose):
    def __init__(self, transforms, *args, additional_targets=None, no_nearest_for_masks=True, **kwargs):
        if additional_targets is None:
            additional_targets = {
                'target_image': 'image',
                'object_mask': 'mask',
                'comp_img_lab': 'image',
                'real_img_lab': 'image',
            }
        self.additional_targets = additional_targets
        super().__init__(transforms, *args, additional_targets=additional_targets, **kwargs)
        if no_nearest_for_masks:
            for t in transforms:
                if isinstance(t, DualTransform):
                    t._additional_targets['object_mask'] = 'image'


class LongestMaxSizeIfLarger(LongestMaxSize):
    """
    Rescale an image so that its longest side is less than or equal to max_size, keeping the aspect ratio of the original image.
    If the longest side is already smaller than max_size, no rescaling is applied.

    Args:
        max_size (int): maximum size of the longest side of the image after the transformation.
        interpolation (OpenCV flag): interpolation method. Default: cv2.INTER_LINEAR.
        p (float): probability of applying the transform. Default: 1.

    Targets:
        image, mask, bboxes, keypoints

    Image types:
        uint8, float32
    """
    def apply(self, img, interpolation=cv2.INTER_LINEAR, **params):
        if max(img.shape[:2]) < self.max_size:
            return img
        return F.longest_max_size(img, max_size=self.max_size, interpolation=interpolation)

    def apply_to_keypoint(self, keypoint, **params):
        height = params["rows"]
        width = params["cols"]

        scale = self.max_size / max([height, width])
        if scale > 1.0:
            return keypoint
        return F.keypoint_scale(keypoint, scale, scale)
--------------------------------------------------------------------------------
/iharm/engine/event_loop.py:
--------------------------------------------------------------------------------
import time
import math
from collections import deque


class EventLoop(object):
    def __init__(self, console_logger=None, tb_writer=None, period_length=1000, start_tick=0,
                 speed_counter='ma_10', speed_tb_period=-1):
        self.period_length = period_length
        self.total_ticks = start_tick
        self.events = []
        self.metrics = {}
        self.last_time = None
        self.tb_writer = tb_writer
        self.console_logger = console_logger
        self._speed_metric = '_speed_'
        self._max_delta_time = 60
        self._prev_total_ticks = 0

        self.register_metric(
            self._speed_metric,
            console_name='Speed', console_period=-1, console_format='{:.2f} samples/s',
            tb_name='Misc/Speed', tb_period=speed_tb_period,
            counter=speed_counter
        )

    def register_event(self, event, period, func_inputs=(),
                       last_step=0, associated_metric=None, onetime=False):
        event = {
            'event': event,
            'last_step': last_step,
            'period': period,
            'func_inputs': func_inputs,
            'onetime': onetime
        }
        if associated_metric is not None:
            event['metric'] = associated_metric

        self.events.append(event)

    def register_metric(self, metric_name,
                        console_name=None, console_period=-1, console_format='{:.3f}',
                        tb_name=None, tb_period=-1, tb_global_step='n_ticks',
                        counter=None):

        self.metrics[metric_name] = {
            'console_name': console_name if console_name is not None else metric_name,
            'console_event': {'last_step': 0, 'period': console_period},
            'console_format': console_format,
            'tb_name': tb_name if tb_name is not None else metric_name,
            'tb_event': {'last_step': 0, 'period': tb_period},
            'counters': {counter: parse_counter(counter)},
            'default_counter': counter,
            'tb_global_step': tb_global_step
        }

    def register_metric_event(self, event, metric_name, period, func_inputs=(),
                              console_period=0, tb_period=0,
                              **metric_kwargs):
        self.register_event(event, period, func_inputs=func_inputs, associated_metric=metric_name)
        self.register_metric(metric_name, console_period=console_period, tb_period=tb_period,
                             **metric_kwargs)

    def add_metric_value(self, metric_name, value):
        metric = self.metrics[metric_name]
        for counter in metric['counters'].values():
            counter.add(value)
        metric['console_event']['relaxed'] = False
        metric['tb_event']['relaxed'] = False

    def add_custom_metric_counter(self, metric_name, counter):
        metric = self.metrics[metric_name]
        if counter not in metric['counters']:
            metric['counters'][counter] = parse_counter(counter)

    def get_metric_value(self, metric_name, counter_name=None):
        metric = self.metrics[metric_name]
        if counter_name is None:
            return metric['counters'][metric['default_counter']].value
        else:
            return metric['counters'][counter_name].value

    def step(self, step_size):
        self._prev_total_ticks = self.total_ticks
        self.total_ticks += step_size

        self._update_time(step_size)
        self._check_events()
        self._check_metrics()

    def get_states(self):
        return {
            'metrics': self.metrics,
            'total_ticks': self.total_ticks,
            '_prev_total_ticks': self._prev_total_ticks,
            'base_period': self.period_length
        }

    def set_states(self, states):
        for k, v in states.items():
            if k == 'metrics':
                self.metrics.update(v)
            else:
                setattr(self, k, v)

    @property
    def n_periods(self):
        return self.total_ticks // self.period_length

    @property
    def f_periods(self):
        return self.total_ticks / self.period_length

    @property
    def n_ticks(self):
        return self.total_ticks

    def _check_events(self):
        triggered_events = []
        for event in self.events:
            if self._check_event(event, ignore_relaxed=True):
                triggered_events.append(event)
                if event['onetime']:
                    event['remove'] = True

        self.events = [event for event in self.events if not event.get('remove', False)]

        for event in triggered_events:
            event_metric = event.get('metric', None)
            inputs = (getattr(self, input_name) for input_name in event['func_inputs'])
            if event_metric is None:
                event['event'](*inputs)
            else:
                metric_value = event['event'](*inputs)
                self.add_metric_value(event_metric, metric_value)

    def _check_metrics(self):
        print_list = []
        for metric_name, metric in self.metrics.items():
            if self._check_event(metric['console_event']):
                print_list.append(metric_name)

            if self.tb_writer is not None and self._check_event(metric['tb_event']):
                global_step = getattr(self, metric['tb_global_step'])
                self.tb_writer.add_scalar(metric['tb_name'], self.get_metric_value(metric_name),
                                          global_step=global_step)

        if self.console_logger is not None and print_list:
            print_list.append(self._speed_metric)

            log_str = f'[{self.n_periods:06d}] '
            for metric_name in print_list:
                metric = self.metrics[metric_name]
                metric_value = self.get_metric_value(metric_name)
                log_str += metric['console_name'] + ': ' + metric['console_format'].format(metric_value) + ' '

            self.console_logger.info(log_str)

    def _check_event(self, event, ignore_relaxed=False):
        if not ignore_relaxed and event.get('relaxed', True):
            return False

        triggered = False
        event_period = event['period']
        if isinstance(event_period, dict):
            pstates = event.get('period_states', None)
            if pstates is None:
                sorted_periods = list(sorted(event['period'].items()))
                pstates = {'sorted_periods': sorted_periods, 'p_indx': 0}
                event['period_states'] = pstates

            sorted_periods = pstates['sorted_periods']
            p_indx = pstates['p_indx']
            while p_indx + 1 < len(sorted_periods) and self.n_periods >= sorted_periods[p_indx + 1][0]:
                p_indx += 1
                pstates['p_indx'] = p_indx
                event['last_step'] = self.period_length * (sorted_periods[p_indx][0] - sorted_periods[p_indx][1])
            event_period = sorted_periods[p_indx][1]

        if event_period == 0:
            triggered = True
        elif event_period > 0:
            event_period_ticks = self.period_length * event_period
            last_step = event['last_step']
            k = (self.total_ticks - last_step) / event_period_ticks
            if k >= 1.0 or k < 0:
                event['last_step'] = last_step + math.floor(k) * event_period_ticks
                prev_k = (self._prev_total_ticks - last_step) / event_period_ticks
                triggered = k >= 1.0 and prev_k < 1.0

        if triggered:
            event['relaxed'] = True
        return triggered

    def _update_time(self, step_size):
        current_time = time.time()
        if self.last_time is None:
            delta_time = 0
        else:
            delta_time = current_time - self.last_time
        self.last_time = current_time
        if delta_time > self._max_delta_time:
            delta_time = 0

        if delta_time > 0:
            speed = step_size / delta_time
            self.add_metric_value(self._speed_metric, speed)


def parse_counter(counter_desc):
    if counter_desc is not None:
        smoothing_name, smoothing_period = counter_desc.split('_')
        smoothing_period = int(smoothing_period)

        if smoothing_name == 'ma':
            counter = MovingAverage(smoothing_period)
        else:
            raise NotImplementedError
    else:
        counter = SimpleLastValue()

    return counter


class MovingAverage(object):
    def __init__(self, window_size):
        self.window_size = window_size
        self.window = deque()
        self.sum = 0
        self.cnt = 0

    def add(self, value):
        if len(self.window) < self.window_size:
            self.sum += value
            self.window.append(value)
        else:
            first = self.window.popleft()
            self.sum -= first
            self.sum += value
            self.window.append(value)
        self.cnt += 1

    @property
    def value(self):
        if self.window:
            return self.sum / len(self.window)
        else:
            return 0

    def __len__(self):
        return self.cnt


class SimpleLastValue(object):
    def __init__(self):
        self.last_value = 0
        self.cnt = 0

    def add(self, value):
        self.last_value = value
        self.cnt += 1

    @property
    def value(self):
        return self.last_value
--------------------------------------------------------------------------------
/iharm/engine/optimizer.py:
--------------------------------------------------------------------------------
import torch
import math
from iharm.utils.log import logger


def get_optimizer(model, opt_name, opt_kwargs):
    params = []
    base_lr = opt_kwargs['lr']
    for name, param in model.named_parameters():
        param_group = {'params': [param]}
        if not param.requires_grad:
            params.append(param_group)
            continue

        if not math.isclose(getattr(param, 'lr_mult', 1.0), 1.0):
            logger.info(f'Applied lr_mult={param.lr_mult} to "{name}" parameter.')
            param_group['lr'] = param_group.get('lr', base_lr) * param.lr_mult

        params.append(param_group)

    optimizer = {
        'sgd': torch.optim.SGD,
        'adam': torch.optim.Adam,
        'adamw': torch.optim.AdamW
    }[opt_name.lower()](params, **opt_kwargs)

    return optimizer
--------------------------------------------------------------------------------
/iharm/inference/evaluation.py:
--------------------------------------------------------------------------------
import os
from time import time
from tqdm import trange
import torch
import torch.nn as nn
import random
import cv2
import numpy as np


def evaluate_dataset(dataset, predictor, metrics_hub):

    for sample_i in trange(len(dataset), desc=f'Testing on {metrics_hub.name}'):
        sample = dataset.get_sample(sample_i)
        sample = dataset.augment_sample(sample)

        sample_mask = sample['object_mask']
        image_lab = sample['comp_img_lab']

        predict_start = time()

        pred = predictor.predict(sample['image'], sample_mask, image_lab=image_lab, return_numpy=False)

        torch.cuda.synchronize()
        metrics_hub.update_time(time() - predict_start)

        target_image = torch.as_tensor(sample['target_image'], dtype=torch.float32).to(predictor.device)
        sample_mask = torch.as_tensor(sample_mask, dtype=torch.float32).to(predictor.device)
        with torch.no_grad():
            metrics_hub.compute_and_add(pred, target_image, sample_mask)

--------------------------------------------------------------------------------
/iharm/inference/metrics.py:
--------------------------------------------------------------------------------
from copy import copy
import math


class MetricsHub:
    def __init__(self, metrics, name='', name_width=20):
        self.metrics = metrics
        self.name = name
        self.name_width = name_width

    def compute_and_add(self, *args):
        for m in self.metrics:
            if not isinstance(m, TimeMetric):
                m.compute_and_add(*args)

    def update_time(self, time_value):
        for m in self.metrics:
            if isinstance(m, TimeMetric):
                m.update_time(time_value)

    def get_table_header(self):
        table_header = ' ' * self.name_width + '|'
        for m in self.metrics:
            table_header += f'{m.name:^{m.cwidth}}|'
        splitter = len(table_header) * '-'
        return f'{splitter}\n{table_header}\n{splitter}'

    def __add__(self, another_hub):
        merged_metrics = []
        for a, b in zip(self.metrics, another_hub.metrics):
            merged_metrics.append(a + b)
        if not merged_metrics:
            merged_metrics = copy(another_hub.metrics)

        return MetricsHub(merged_metrics, name=self.name, name_width=self.name_width)

    def __repr__(self):
        table_row = f'{self.name:<{self.name_width}}|'
        for m in self.metrics:
            table_row += f'{str(m):^{m.cwidth}}|'
        return table_row


class EvalMetric:
    def __init__(self):
        self._values_sum = 0.0
        self._count = 0
        self.cwidth = 10

    def compute_and_add(self, pred, target_image, mask):
        self._values_sum += self._compute_metric(pred, target_image, mask)
        self._count += 1

    def _compute_metric(self, pred, target_image, mask):
        raise NotImplementedError

    def __add__(self, another_eval_metric):
        comb_metric = copy(self)
        comb_metric._count += another_eval_metric._count
        comb_metric._values_sum += another_eval_metric._values_sum
        return comb_metric

    @property
    def value(self):
        return self._values_sum / self._count if self._count > 0 else None

    @property
    def name(self):
        return type(self).__name__

    def __repr__(self):
        return f'{self.value:.2f}'

    def __len__(self):
        return self._count


class MSE(EvalMetric):
    def _compute_metric(self, pred, target_image, mask):
        return ((pred - target_image) ** 2).mean().item()


class fMSE(EvalMetric):
    def _compute_metric(self, pred, target_image, mask):
        diff = mask.unsqueeze(2) * ((pred - target_image) ** 2)
        return (diff.sum() / (diff.size(2) * mask.sum() + 1e-6)).item()


class PSNR(MSE):
    def __init__(self, epsilon=1e-6):
        super().__init__()
        self._epsilon = epsilon

    def _compute_metric(self, pred, target_image, mask):
        mse = super()._compute_metric(pred, target_image, mask)
        squared_max = target_image.max().item() ** 2

        return 10 * math.log10(squared_max / (mse + self._epsilon))


class N(EvalMetric):
    def _compute_metric(self, pred, target_image, mask):
        return 0

    @property
    def value(self):
        return self._count

    def __repr__(self):
        return str(self.value)


class TimeMetric(EvalMetric):
    def update_time(self, time_value):
        self._values_sum += time_value
        self._count += 1


class AvgPredictTime(TimeMetric):
    def __init__(self):
        super().__init__()
        self.cwidth = 14

    @property
    def name(self):
        return 'AvgTime, ms'

    def __repr__(self):
        return f'{1000 * self.value:.1f}'
--------------------------------------------------------------------------------
/iharm/inference/predictor.py:
--------------------------------------------------------------------------------
import os
import torch
from iharm.inference.transforms import NormalizeTensor, PadToDivisor, ToTensor, AddFlippedTensor
from torchvision import transforms
import numpy as np
import cv2

class Predictor(object):
    def __init__(self, net, device, with_flip=False,
                 mean=(.485, .456, .406), std=(.229, .224, .225)):
        self.device = device
        self.net = net.to(self.device)
        self.net.eval()

        if hasattr(net, 'depth'):
            size_divisor = 2 ** (net.depth + 1)
        else:
            size_divisor = 1

        mean = torch.tensor(mean, dtype=torch.float32)
        std = torch.tensor(std, dtype=torch.float32)
        self.transforms = [
            PadToDivisor(divisor=size_divisor, border_mode=0),
            ToTensor(self.device),
            NormalizeTensor(mean, std, self.device),
        ]
        if with_flip:
            self.transforms.append(AddFlippedTensor())

    def predict(self, image, mask, image_lab=None, return_numpy=True):
        with torch.no_grad():
            for transform in self.transforms:
                image, mask = transform.transform(image, mask)

            transform = transforms.Compose([
                transforms.ToTensor(),
            ])
            comp_image_lab = transform(image_lab).unsqueeze(0).to(self.device)
            predicted_output = self.net(image, mask, image_lab=comp_image_lab)
            predicted_image = predicted_output['images']

            for transform in reversed(self.transforms):
                predicted_image = transform.inv_transform(predicted_image)

            predicted_image = torch.clamp(predicted_image, 0, 255)

            if return_numpy:
                return predicted_image.cpu().numpy()
            else:
                return predicted_image
--------------------------------------------------------------------------------
/iharm/inference/transforms.py:
--------------------------------------------------------------------------------
import cv2
import torch
from collections import namedtuple


class EvalTransform:
    def __init__(self):
        pass

    def transform(self, image, mask):
        raise NotImplementedError

    def inv_transform(self, image):
        raise NotImplementedError


class PadToDivisor(EvalTransform):
    """
    Pad the image so that its height and width are divisible by divisor.

    Args:
        divisor (int): desired divisor of the image height and width.
        border_mode (OpenCV flag): OpenCV border mode.
        fill_value (int, float, list of int, list of float): padding value if border_mode is cv2.BORDER_CONSTANT.
    """
    PadParams = namedtuple('PadParams', ['top', 'bottom', 'left', 'right'])

    def __init__(self, divisor, border_mode=cv2.BORDER_CONSTANT, fill_value=0):
        super().__init__()
        self.border_mode = border_mode
        self.fill_value = fill_value
        self.divisor = divisor
        self._pads = None

    def transform(self, image, mask):
        self._pads = PadToDivisor.PadParams(*self._get_dim_padding(image.shape[0]),
                                            *self._get_dim_padding(image.shape[1]))

        image = cv2.copyMakeBorder(image, *self._pads, self.border_mode, value=self.fill_value)
        mask = cv2.copyMakeBorder(mask, *self._pads, self.border_mode, value=self.fill_value)

        return image, mask

    def inv_transform(self, image):
        assert self._pads is not None,\
            'Something went wrong, inv_transform(...) should be called after transform(...)'
        return self._remove_padding(image)

    def _get_dim_padding(self, dim_size):
        pad = (self.divisor - dim_size % self.divisor) % self.divisor
        pad_upper = pad // 2
        pad_lower = pad - pad_upper

        return pad_upper, pad_lower

    def _remove_padding(self, tensor):
        tensor_h, tensor_w = tensor.shape[:2]
        cropped = tensor[self._pads.top:tensor_h - self._pads.bottom,
                         self._pads.left:tensor_w - self._pads.right, :]
        return cropped


class NormalizeTensor(EvalTransform):
    def __init__(self, mean, std, device):
        super().__init__()
        self.mean = torch.as_tensor(mean).reshape(1, 3, 1, 1).to(device)
        self.std = torch.as_tensor(std).reshape(1, 3, 1, 1).to(device)

    def transform(self, image, mask):
        image.sub_(self.mean).div_(self.std)
        return image, mask

    def inv_transform(self, image):
        image.mul_(self.std).add_(self.mean)
        return image


class ToTensor(EvalTransform):
    def __init__(self, device):
        super().__init__()
        self.device = device

    def transform(self, image, mask):
        image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
        mask = torch.as_tensor(mask, device=self.device)
        image.unsqueeze_(0)
        mask.unsqueeze_(0).unsqueeze_(0)
        return image.permute(0, 3, 1, 2) / 255.0, mask

    def inv_transform(self, image):
        image.squeeze_(0)
        return 255 * image.permute(1, 2, 0)


class AddFlippedTensor(EvalTransform):
    def transform(self, image, mask):
        flipped_image = torch.flip(image, dims=(3,))
        flipped_mask = torch.flip(mask, dims=(3,))
        image = torch.cat((image, flipped_image), dim=0)
        mask = torch.cat((mask, flipped_mask), dim=0)
        return image, mask

    def inv_transform(self, image):
        return 0.5 * (image[:1] + torch.flip(image[1:], dims=(3,)))

--------------------------------------------------------------------------------
/iharm/inference/utils.py:
--------------------------------------------------------------------------------
from pathlib import Path

from iharm.utils.misc import load_weights
from iharm.mconfigs import ALL_MCONFIGS


def load_model(model_type, checkpoint_path, verbose=False):
    net = ALL_MCONFIGS[model_type]['model'](**ALL_MCONFIGS[model_type]['params'])
    load_weights(net, checkpoint_path, verbose=verbose)
    return net


def find_checkpoint(weights_folder, checkpoint_name):
    weights_folder = Path(weights_folder)
    if ':' in checkpoint_name:
        model_name, checkpoint_name = checkpoint_name.split(':')
        models_candidates = [x for x in weights_folder.glob(f'{model_name}*') if x.is_dir()]
        assert len(models_candidates) == 1
        model_folder = models_candidates[0]
    else:
        model_folder = weights_folder

    if checkpoint_name.endswith('.pth'):
        if Path(checkpoint_name).exists():
            checkpoint_path = checkpoint_name
        else:
            checkpoint_path = weights_folder / checkpoint_name
    else:
        model_checkpoints = list(model_folder.rglob(f'{checkpoint_name}*.pth'))
        assert len(model_checkpoints) == 1
        checkpoint_path = model_checkpoints[0]
    return str(checkpoint_path)

--------------------------------------------------------------------------------
/iharm/mconfigs/__init__.py:
--------------------------------------------------------------------------------
from .base import BMCONFIGS
from .backboned import MCONFIGS


ALL_MCONFIGS = dict(**BMCONFIGS, **MCONFIGS)

--------------------------------------------------------------------------------
/iharm/mconfigs/backboned.py:
--------------------------------------------------------------------------------
from .base import BMCONFIGS
from iharm.model.backboned import DeepLabIHModel, HRNetIHModel


MCONFIGS = {
    'hrnet18s_idih256': {
        'model': HRNetIHModel,
        'params': {'base_config': BMCONFIGS['improved_dih256']}
    },
    'hrnet18s_v2p_idih256': {
        'model': HRNetIHModel,
        'params': {'base_config': BMCONFIGS['improved_dih256'], 'pyramid_channels': 256}
    },
    'hrnet18_idih256': {
        'model': HRNetIHModel,
        'params': {'base_config': BMCONFIGS['improved_dih256'], 'small': False}
    },
    'hrnet18_v2p_idih256': {
        'model': HRNetIHModel,
        'params': {'base_config': BMCONFIGS['improved_dih256'], 'small': False, 'pyramid_channels': 256}
    },
    'hrnet32_idih256': {
        'model': HRNetIHModel,
        'params': {'base_config': BMCONFIGS['improved_dih256'], 'width': 32, 'small': False}
    },
    'deeplab_r34_idih256': {
        'model': DeepLabIHModel,
        'params': {'base_config': BMCONFIGS['improved_dih256']}
    },
    'hrnet18_idih512': {
        'model': HRNetIHModel,
        'params': {'base_config': BMCONFIGS['improved_dih512'], 'small': False, 'downsize_hrnet_input': True}
    },
    'hrnet18_sedih512': {
        'model': HRNetIHModel,
        'params': {'base_config': BMCONFIGS['improved_sedih512'], 'small': False, 'downsize_hrnet_input': True}
    }
}

--------------------------------------------------------------------------------
/iharm/mconfigs/base.py:
--------------------------------------------------------------------------------
from iharm.model.base import DeepImageHarmonization, SSAMImageHarmonization, ISEUNetV1, DucoNet_model

BMCONFIGS = {
    'dih256': {
        'model': DeepImageHarmonization,
        'params': {'depth': 7}
    },
    'improved_dih256': {
        'model': DeepImageHarmonization,
        'params': {'depth': 7, 'batchnorm_from': 2, 'image_fusion': True}
    },
    'improved_sedih256': {
        'model': DeepImageHarmonization,
        'params': {'depth': 7, 'batchnorm_from': 2, 'image_fusion': True, 'attend_from': 5}
    },
    'ssam256': {
        'model': SSAMImageHarmonization,
        'params': {'depth': 4, 'batchnorm_from': 2, 'attend_from': 2}
    },
    'improved_ssam256': {
        'model': SSAMImageHarmonization,
        'params': {'depth': 4, 'ch': 32, 'image_fusion': True, 'attention_mid_k': 0.5,
                   'batchnorm_from': 2, 'attend_from': 2}
    },
    'iseunetv1_256': {
        'model': ISEUNetV1,
        'params': {'depth': 4, 'batchnorm_from': 2, 'attend_from': 1, 'ch': 64}
    },
    'dih512': {
        'model': DeepImageHarmonization,
        'params': {'depth': 8}
    },
    'improved_dih512': {
        'model': DeepImageHarmonization,
        'params': {'depth': 8, 'batchnorm_from': 2, 'image_fusion': True}
    },
    'improved_ssam512': {
        'model': SSAMImageHarmonization,
        'params': {'depth': 6, 'ch': 32, 'image_fusion': True, 'attention_mid_k': 0.5,
                   'batchnorm_from': 2, 'attend_from': 3}
    },
    'improved_sedih512': {
        'model': DeepImageHarmonization,
        'params': {'depth': 8, 'batchnorm_from': 2, 'image_fusion': True, 'attend_from': 6}
    },
    'DucoNet': {
        'model': DucoNet_model,
        'params': {'depth': 4, 'ch': 32, 'image_fusion': True, 'attention_mid_k': 0.5,
                   'batchnorm_from': 2, 'attend_from': 2, 'w_dim': 256, 'control_module_start': -1}
    },
}

--------------------------------------------------------------------------------
/iharm/model/backboned/__init__.py:
--------------------------------------------------------------------------------
from .deeplab import DeepLabIHModel
from .hrnet import HRNetIHModel

--------------------------------------------------------------------------------
/iharm/model/backboned/deeplab.py:
--------------------------------------------------------------------------------
from torch import nn as nn

from iharm.model.modeling.deeplab_v3 import DeepLabV3Plus
from iharm.model.backboned.ih_model import IHModelWithBackbone
from iharm.model.modifiers import LRMult
from iharm.model.modeling.basic_blocks import MaxPoolDownSize


class DeepLabIHModel(IHModelWithBackbone):
    def __init__(
        self,
        base_config,
        mask_fusion='sum',
        deeplab_backbone='resnet34',
        lr_mult=0.1,
        pyramid_channels=-1, deeplab_ch=256,
        mode='cat',
        **base_kwargs
    ):
        """
        Creates an image harmonization model supported by features extracted from a pre-trained DeepLab backbone.

        Parameters
        ----------
        base_config : dict
            Configuration dict for the base model into which the backbone features are incorporated.
            It contains the model class and its init parameters; examples can be found in iharm.mconfigs.base.
        mask_fusion : str
            How to fuse the binary mask with the backbone input:
            'sum': apply a convolution to the mask and sum it with the output of the first convolution in the backbone
            'rgb': concatenate the mask to the input image and map it back to 3 channels with a convolution
            otherwise: do not fuse the mask with the backbone input
        deeplab_backbone : str
            ResNet backbone name.
        lr_mult : float
            Multiply the learning rate by lr_mult when updating the backbone weights.
        pyramid_channels : int
            The DeepLab output can be successively downsized to produce a feature pyramid.
            The pyramid features are then fused with the encoder outputs of the base model at multiple layers.
            Each pyramid feature map has pyramid_channels channels.
            If pyramid_channels <= 0, the feature pyramid is not constructed.
        deeplab_ch : int
            Number of channels in the DeepLab output layer and in some intermediate layers.
        mode : str
            How to fuse the backbone features with the encoder outputs of the base model:
            'sum': apply a convolution to the backbone feature map to match the number of channels
            in the encoder output, then sum them
            'cat': concatenate the backbone feature map with the encoder output
            'catc': concatenate the backbone feature map with the encoder output and apply a convolution
            to restore the number of channels of the encoder output
            otherwise: the backbone features are not incorporated into the base model
        base_kwargs : dict
            Any kwargs associated with the base model.
        """
        # Copy so the shared config dict in iharm.mconfigs is not mutated by the updates below.
        params = dict(base_config['params'])
        params.update(base_kwargs)
        depth = params['depth']

        backbone = DeepLabBB(pyramid_channels, deeplab_ch, deeplab_backbone, lr_mult)

        downsize_input = depth > 7
        params.update(dict(
            backbone_from=3 if downsize_input else 2,
            backbone_channels=backbone.output_channels,
            backbone_mode=mode
        ))
        base_model = base_config['model'](**params)

        super(DeepLabIHModel, self).__init__(base_model, backbone, downsize_input, mask_fusion)


class DeepLabBB(nn.Module):
    def __init__(
        self,
        pyramid_channels=256,
        deeplab_ch=256,
        backbone='resnet34',
        backbone_lr_mult=0.1,
    ):
        super(DeepLabBB, self).__init__()
        self.pyramid_on = pyramid_channels > 0
        if self.pyramid_on:
            self.output_channels = [pyramid_channels] * 4
        else:
            self.output_channels = [deeplab_ch]

        self.deeplab = DeepLabV3Plus(backbone=backbone,
                                     ch=deeplab_ch,
                                     project_dropout=0.2,
                                     norm_layer=nn.BatchNorm2d,
                                     backbone_norm_layer=nn.BatchNorm2d)
        self.deeplab.backbone.apply(LRMult(backbone_lr_mult))

        if self.pyramid_on:
            self.downsize = MaxPoolDownSize(deeplab_ch, pyramid_channels, pyramid_channels, 4)

    def forward(self, image, mask, mask_features):
        outputs = list(self.deeplab(image, mask_features))
        if self.pyramid_on:
            outputs = self.downsize(outputs[0])
        return outputs

    def load_pretrained_weights(self):
        self.deeplab.load_pretrained_weights()

--------------------------------------------------------------------------------
/iharm/model/backboned/hrnet.py:
--------------------------------------------------------------------------------
import torch.nn as nn

from iharm.model.modeling.hrnet_ocr import HighResolutionNet
from iharm.model.backboned.ih_model import IHModelWithBackbone
from iharm.model.modifiers import LRMult
from iharm.model.modeling.basic_blocks import MaxPoolDownSize


class HRNetIHModel(IHModelWithBackbone):
    def __init__(
        self,
        base_config,
        downsize_hrnet_input=False, mask_fusion='sum',
        lr_mult=0.1, cat_hrnet_outputs=True, pyramid_channels=-1,
        ocr=64, width=18, small=True,
        mode='cat',
        **base_kwargs
    ):
        """
        Creates an image harmonization model supported by features extracted from a pre-trained HRNet backbone.
        HRNet outputs feature maps at 4 different resolutions.

        Parameters
        ----------
        base_config : dict
            Configuration dict for the base model into which the backbone features are incorporated.
            It contains the model class and its init parameters; examples can be found in iharm.mconfigs.base.
        downsize_hrnet_input : bool
            Whether to downsize the input image by half before passing it to the backbone.
        mask_fusion : str
            How to fuse the binary mask with the backbone input:
            'sum': apply a convolution to the mask and sum it with the output of the first convolution in the backbone
            'rgb': concatenate the mask to the input image and map it back to 3 channels with a convolution
            otherwise: do not fuse the mask with the backbone input
        lr_mult : float
            Multiply the learning rate by lr_mult when updating the backbone weights.
        cat_hrnet_outputs : bool
            Whether the 4 HRNet outputs should be resized and concatenated into a single tensor.
        pyramid_channels : int
            When the HRNet outputs are concatenated into a single tensor, it can be successively downsized
            to produce a feature pyramid.
            The pyramid features are then fused with the encoder outputs of the base model at multiple layers.
            Each pyramid feature map has pyramid_channels channels.
            If pyramid_channels <= 0, the feature pyramid is not constructed.
        ocr : int
            When the HRNet outputs are concatenated into a single tensor, the OCR module can be applied,
            producing a feature map with (2 * ocr) channels. If ocr <= 0, the OCR module is not applied.
        width : int
            Width of the HRNet blocks.
        small : bool
            If True, HRNet contains 2 blocks at each stage; otherwise 4.
        mode : str
            How to fuse the backbone features with the encoder outputs of the base model:
            'sum': apply a convolution to the backbone feature map to match the number of channels
            in the encoder output, then sum them
            'cat': concatenate the backbone feature map with the encoder output
            'catc': concatenate the backbone feature map with the encoder output and apply a convolution
            to restore the number of channels of the encoder output
            otherwise: the backbone features are not incorporated into the base model
        base_kwargs : dict
            Any kwargs associated with the base model.
        """
        # Copy so the shared config dict in iharm.mconfigs is not mutated by the updates below.
        params = dict(base_config['params'])
        params.update(base_kwargs)
        depth = params['depth']

        backbone = HRNetBB(
            cat_outputs=cat_hrnet_outputs,
            pyramid_channels=pyramid_channels,
            pyramid_depth=min(depth - 2 if not downsize_hrnet_input else depth - 3, 4),
            width=width, ocr=ocr, small=small,
            lr_mult=lr_mult,
        )

        params.update(dict(
            backbone_from=3 if downsize_hrnet_input else 2,
            backbone_channels=backbone.output_channels,
            backbone_mode=mode
        ))
        base_model = base_config['model'](**params)

        super(HRNetIHModel, self).__init__(base_model, backbone, downsize_hrnet_input, mask_fusion)


class HRNetBB(nn.Module):
    def __init__(
        self,
        cat_outputs=True,
        pyramid_channels=256, pyramid_depth=4,
        width=18, ocr=64, small=True,
        lr_mult=0.1,
    ):
        super(HRNetBB, self).__init__()
        self.cat_outputs = cat_outputs
        self.ocr_on = ocr > 0 and cat_outputs
        self.pyramid_on = pyramid_channels > 0 and cat_outputs

        self.hrnet = HighResolutionNet(width, 2, ocr_width=ocr, small=small)
        self.hrnet.apply(LRMult(lr_mult))
        if self.ocr_on:
            self.hrnet.ocr_distri_head.apply(LRMult(1.0))
            self.hrnet.ocr_gather_head.apply(LRMult(1.0))
            self.hrnet.conv3x3_ocr.apply(LRMult(1.0))

        hrnet_cat_channels = [width * 2 ** i for i in range(4)]
        if self.pyramid_on:
            self.output_channels = [pyramid_channels] * 4
        elif self.ocr_on:
            self.output_channels = [ocr * 2]
        elif self.cat_outputs:
            self.output_channels = [sum(hrnet_cat_channels)]
        else:
            self.output_channels = hrnet_cat_channels

        if self.pyramid_on:
            downsize_in_channels = ocr * 2 if self.ocr_on else sum(hrnet_cat_channels)
            self.downsize = MaxPoolDownSize(downsize_in_channels, pyramid_channels, pyramid_channels, pyramid_depth)

    def forward(self, image, mask, mask_features):
        if not self.cat_outputs:
            return self.hrnet.compute_hrnet_feats(image, mask_features, return_list=True)

        outputs = list(self.hrnet(image, mask, mask_features))
        if self.pyramid_on:
            outputs = self.downsize(outputs[0])
        return outputs

    def load_pretrained_weights(self, pretrained_path):
        self.hrnet.load_pretrained_weights(pretrained_path)

--------------------------------------------------------------------------------
/iharm/model/backboned/ih_model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from iharm.model.ops import SimpleInputFusion, ScaleLayer


class IHModelWithBackbone(nn.Module):
    def __init__(
        self,
        model, backbone,
        downsize_backbone_input=False,
        mask_fusion='sum',
        backbone_conv1_channels=64,
    ):
        """
        Creates an image harmonization model supported by features extracted from a pre-trained backbone.

        Parameters
        ----------
        model : nn.Module
            Image harmonization model that takes an image and a mask as input and consumes the backbone features.
        backbone : nn.Module
            Backbone model that accepts an RGB image and returns a list of feature maps.
        downsize_backbone_input : bool
            Whether to downsize the input image by half before passing it to the backbone.
        mask_fusion : str
            How to fuse the binary mask with the backbone input:
            'sum': apply a convolution to the mask and sum it with the output of the first convolution in the backbone
            'rgb': concatenate the mask to the input image and map it back to 3 channels with a convolution
            otherwise: do not fuse the mask with the backbone input
        backbone_conv1_channels : int
            If mask_fusion is 'sum', the number of output channels of the convolution applied to the mask.
        """
        super(IHModelWithBackbone, self).__init__()
        self.downsize_backbone_input = downsize_backbone_input
        self.mask_fusion = mask_fusion

        self.backbone = backbone
        self.model = model

        if mask_fusion == 'rgb':
            self.fusion = SimpleInputFusion()
        elif mask_fusion == 'sum':
            self.mask_conv = nn.Sequential(
                nn.Conv2d(1, backbone_conv1_channels, kernel_size=3, stride=2, padding=1, bias=True),
                ScaleLayer(init_value=0.1, lr_mult=1)
            )

    def forward(self, image, mask):
        """
        Run the backbone and then the base model, which is supported by the backbone feature maps,
        and return the model predictions.

        Parameters
        ----------
        image : torch.Tensor
            Input RGB image.
        mask : torch.Tensor
            Binary mask of the foreground region.

        Returns
        -------
        dict
            Output of the base model; the harmonized RGB image is stored under the 'images' key.
        """
        backbone_image = image
        backbone_mask = torch.cat((mask, 1.0 - mask), dim=1)
        if self.downsize_backbone_input:
            backbone_image = nn.functional.interpolate(
                backbone_image, scale_factor=0.5,
                mode='bilinear', align_corners=True
            )
            backbone_mask = nn.functional.interpolate(
                backbone_mask, backbone_image.size()[2:],
                mode='bilinear', align_corners=True
            )
        backbone_image = (
            self.fusion(backbone_image, backbone_mask[:, :1])
            if self.mask_fusion == 'rgb' else
            backbone_image
        )
        backbone_mask_features = self.mask_conv(backbone_mask[:, :1]) if self.mask_fusion == 'sum' else None
        backbone_features = self.backbone(backbone_image, backbone_mask, backbone_mask_features)

        output = self.model(image, mask, backbone_features)
        return output

--------------------------------------------------------------------------------
/iharm/model/base/Control_encoder.py:
--------------------------------------------------------------------------------
import torch
from functools import partial

from torch import nn as nn

from iharm.model.modeling.basic_blocks import ConvBlock, GaussianSmoothing
from iharm.model.modeling.unet_v1 import UNetEncoder, UNetDecoder
from iharm.model.ops import ChannelAttention
from iharm.model.modeling.styleganv2 import ResidualBlock


class Control_encoder(nn.Module):
    def __init__(
        self,
        depth,
        norm_layer=nn.BatchNorm2d, batchnorm_from=2,
        ch=64, max_channels=512,
        lab_encoder_mask_channel=1,
        L_or_Lab_dim=3, w_dim=256,
        backbone_from=-1, backbone_channels=None, backbone_mode='',
    ):
        super(Control_encoder, self).__init__()
        self.depth = depth
        self.w_dim = w_dim
        self.lab_encoder_mask_channel = lab_encoder_mask_channel

        self.lab_encoder = UNetEncoder(
            depth, ch,
            norm_layer, batchnorm_from, max_channels,
            backbone_from, backbone_channels, backbone_mode,
            _in_channels=L_or_Lab_dim + self.lab_encoder_mask_channel
        )

        self.map2w = map_net(int_dim=256, out_dim=self.w_dim)

    def forward(self, image_lab, mask, backbone_features=None):
        x_lab = image_lab
        if self.lab_encoder_mask_channel == 1:
            # append the foreground mask as an extra input channel
            x_lab = torch.cat((x_lab, mask), dim=1)
        intermediates_lab = self.lab_encoder(x_lab, backbone_features)
        # map the first encoder output to a w_dim-dimensional control vector
        w = self.map2w(intermediates_lab[0], mask)

        return w


class map_net(nn.Module):
    def __init__(self, int_dim=256, out_dim=256):
        super(map_net, self).__init__()

        self.int_dim = int_dim
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))

        self.net = nn.Sequential(
            nn.Linear(int_dim, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, out_dim),
        )

    def forward(self, feature_map, mask):
        # note: mask is unused here; the pooling averages over the whole feature map
        w_bg = self.pooling(feature_map).view(-1, self.int_dim)

        return self.net(w_bg)


class SpatialSeparatedAttention(nn.Module):
    def __init__(self, in_channels, norm_layer, activation, mid_k=2.0):
        super(SpatialSeparatedAttention, self).__init__()
        self.background_gate = ChannelAttention(in_channels)
        self.foreground_gate = ChannelAttention(in_channels)
        self.mix_gate = ChannelAttention(in_channels)

        mid_channels = int(mid_k * in_channels)
        self.learning_block = nn.Sequential(
            ConvBlock(
                in_channels, mid_channels,
                kernel_size=3, stride=1, padding=1,
                norm_layer=norm_layer, activation=activation,
                bias=False,
            ),
            ConvBlock(
                mid_channels, in_channels,
                kernel_size=3, stride=1, padding=1,
                norm_layer=norm_layer, activation=activation,
                bias=False,
            ),
        )
        self.mask_blurring = GaussianSmoothing(1, 7, 1, padding=3)

    def forward(self, x, mask):
        mask = self.mask_blurring(nn.functional.interpolate(
            mask, size=x.size()[-2:],
            mode='bilinear', align_corners=True
        ))
        background = self.background_gate(x)
        foreground = self.learning_block(self.foreground_gate(x))
        mix = self.mix_gate(x)
        output = mask * (foreground + mix) + (1 - mask) * background
        return output

--------------------------------------------------------------------------------
/iharm/model/base/DucoNet_model.py:
--------------------------------------------------------------------------------
import torch
from functools import partial

from torch import nn as nn

from iharm.model.modeling.basic_blocks import ConvBlock, GaussianSmoothing
from iharm.model.modeling.unet_v1 import UNetEncoder, UNetDecoder
from iharm.model.ops import ChannelAttention
from iharm.model.modeling.styleganv2 import ResidualBlock
from .Control_encoder import Control_encoder


class DucoNet_model(nn.Module):
    def __init__(
        self,
        depth,
        norm_layer=nn.BatchNorm2d, batchnorm_from=2,
        attend_from=3, attention_mid_k=2.0,
        image_fusion=False,
        ch=64, max_channels=512,
        backbone_from=-1, backbone_channels=None, backbone_mode='',
        control_module_start=-1,
        lab_encoder_mask_channel=1,
        w_dim=256,
    ):
        super(DucoNet_model, self).__init__()
        self.depth = depth
        self.control_module_start = control_module_start
        self.w_dim = w_dim

        self.lab_encoder_mask_channel = lab_encoder_mask_channel

        self.encoder = UNetEncoder(
            depth, ch,
            norm_layer, batchnorm_from, max_channels,
            backbone_from, backbone_channels, backbone_mode
        )
        self.decoder = UNetDecoder(
            depth, self.encoder.block_channels,
            norm_layer,
            attention_layer=partial(SpatialSeparatedAttention, mid_k=attention_mid_k),
            attend_from=attend_from,
            image_fusion=image_fusion,
            control_module_start=self.control_module_start,
            w_dim=self.w_dim,
            control_module_layer=Control_Module,
        )

        # one control encoder per Lab channel (L, a, b), each taking 1 channel + mask
        self.l_encoder = Control_encoder(
            depth=depth, ch=ch, batchnorm_from=2,
            lab_encoder_mask_channel=1,
            L_or_Lab_dim=1, w_dim=self.w_dim
        )

        self.a_encoder = Control_encoder(
            depth=depth, ch=ch, batchnorm_from=2,
            lab_encoder_mask_channel=1,
            L_or_Lab_dim=1, w_dim=self.w_dim
        )

        self.b_encoder = Control_encoder(
            depth=depth, ch=ch, batchnorm_from=2,
            lab_encoder_mask_channel=1,
            L_or_Lab_dim=1, w_dim=self.w_dim
        )

    def forward(self, image, mask, image_lab=None, backbone_features=None):
        # image_lab is required despite its None default: each Lab channel feeds its own encoder
        x = torch.cat((image, mask), dim=1)
        intermediates = self.encoder(x, backbone_features)

        w_l = self.l_encoder(image_lab[:, 0, :, :].unsqueeze(1), mask)
        w_a = self.a_encoder(image_lab[:, 1, :, :].unsqueeze(1), mask)
        w_b = self.b_encoder(image_lab[:, 2, :, :].unsqueeze(1), mask)

        ws = {'w_l': w_l, 'w_a': w_a, 'w_b': w_b}

        output = self.decoder(intermediates, image, mask, ws)

        return {'images': output}


class Control_Module(nn.Module):
    def __init__(self, w_dim, feature_dim):
        super(Control_Module, self).__init__()

        self.w_dim = w_dim
        self.feature_dim = feature_dim

        self.l_styleblock = ResidualBlock(self.w_dim, self.feature_dim, self.feature_dim)
        self.a_styleblock = ResidualBlock(self.w_dim, self.feature_dim, self.feature_dim)
        self.b_styleblock = ResidualBlock(self.w_dim, self.feature_dim, self.feature_dim)

        self.G_weight = nn.Sequential(
            nn.Conv2d(self.feature_dim * 3, 32, kernel_size=1, stride=1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0, bias=True),
        )

    def forward(self, x, sl, sa, sb, mask):
        # modulate the decoder features separately with the L, a and b control vectors
        f_l = self.l_styleblock(x, sl, noise=None)
        f_a = self.a_styleblock(x, sa, noise=None)
        f_b = self.b_styleblock(x, sb, noise=None)

        # predict one weight per branch at every pixel; softmax makes them sum to 1
        f_lab = torch.cat((f_l, f_a, f_b), dim=1)
        f_weight = self.G_weight(f_lab)
        f_weight = nn.functional.softmax(f_weight, dim=1)

        weight_l = f_weight[:, 0, :, :].unsqueeze(1)
        weight_a = f_weight[:, 1, :, :].unsqueeze(1)
        weight_b = f_weight[:, 2, :, :].unsqueeze(1)

        out = weight_l * f_l + weight_a * f_a + weight_b * f_b

        # keep background features; use the fused features only inside the mask
        out = x * (1 - mask) + mask * out
        return out


class SpatialSeparatedAttention(nn.Module):
    def __init__(self, in_channels, norm_layer, activation, mid_k=2.0):
        super(SpatialSeparatedAttention, self).__init__()
        self.background_gate = ChannelAttention(in_channels)
        self.foreground_gate = ChannelAttention(in_channels)
        self.mix_gate = ChannelAttention(in_channels)

        mid_channels = int(mid_k * in_channels)
        self.learning_block = nn.Sequential(
            ConvBlock(
                in_channels, mid_channels,
                kernel_size=3, stride=1, padding=1,
                norm_layer=norm_layer, activation=activation,
                bias=False,
            ),
            ConvBlock(
                mid_channels, in_channels,
                kernel_size=3, stride=1, padding=1,
                norm_layer=norm_layer, activation=activation,
                bias=False,
            ),
        )
        self.mask_blurring = GaussianSmoothing(1, 7, 1, padding=3)

    def forward(self, x, mask):
        mask = self.mask_blurring(nn.functional.interpolate(
            mask, size=x.size()[-2:],
            mode='bilinear', align_corners=True
        ))
        background = self.background_gate(x)
        foreground = self.learning_block(self.foreground_gate(x))
        mix = self.mix_gate(x)
        output = mask * (foreground + mix) + (1 - mask) * background
        return output

--------------------------------------------------------------------------------
/iharm/model/base/__init__.py:
--------------------------------------------------------------------------------
from .dih_model import DeepImageHarmonization
from .ssam_model import SSAMImageHarmonization
from .iseunet_v1 import ISEUNetV1
from .DucoNet_model import DucoNet_model
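
To make the fusion step in `Control_Module.forward` easier to follow, here is a scalar, pure-Python sketch of what happens at a single pixel and channel. It is an illustration only, not the repository's code: the helper names `softmax3` and `fuse_pixel` are hypothetical, the real module works on `(N, C, H, W)` tensors with `nn.functional.softmax(..., dim=1)`, and the branch outputs (`f_l`, `f_a`, `f_b`) and the three `G_weight` logits are taken as given inputs instead of being computed by `ResidualBlock` and `G_weight`. The sketch also assumes the mask value already corresponds to the feature resolution.

```python
import math


def softmax3(logits):
    """Numerically stable softmax over the three branch logits (L, a, b)."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def fuse_pixel(x, f_l, f_a, f_b, logits, mask):
    """Hypothetical scalar version of Control_Module.forward at one pixel/channel.

    x           -- decoder feature before modulation
    f_l/f_a/f_b -- outputs of the L, a and b style branches
    logits      -- the three G_weight outputs at this pixel
    mask        -- foreground mask value in [0, 1]
    """
    w_l, w_a, w_b = softmax3(logits)
    fused = w_l * f_l + w_a * f_a + w_b * f_b  # convex combination of the branches
    return x * (1 - mask) + mask * fused       # background keeps x, foreground gets fused


# Equal logits give equal weights, so the fused value is the mean (6.0),
# and a half-foreground mask averages it with x = 2.0.
print(fuse_pixel(2.0, 3.0, 6.0, 9.0, [0.0, 0.0, 0.0], 0.5))  # -> 4.0
```

Because the weights come from a softmax, they are non-negative and sum to one at every pixel, so the fused feature is always a convex combination of the three branch outputs, and the final blend leaves background features (mask = 0) untouched.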
-------------------------------------------------------------------------------- /iharm/model/base/dih_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from iharm.model.modeling.conv_autoencoder import ConvEncoder, DeconvDecoder 5 | 6 | 7 | class DeepImageHarmonization(nn.Module): 8 | def __init__( 9 | self, 10 | depth, 11 | norm_layer=nn.BatchNorm2d, batchnorm_from=0, 12 | attend_from=-1, 13 | image_fusion=False, 14 | ch=64, max_channels=512, 15 | backbone_from=-1, backbone_channels=None, backbone_mode='' 16 | ): 17 | super(DeepImageHarmonization, self).__init__() 18 | self.depth = depth 19 | self.encoder = ConvEncoder( 20 | depth, ch, 21 | norm_layer, batchnorm_from, max_channels, 22 | backbone_from, backbone_channels, backbone_mode 23 | ) 24 | self.decoder = DeconvDecoder(depth, self.encoder.blocks_channels, norm_layer, attend_from, image_fusion) 25 | 26 | def forward(self, image, mask, backbone_features=None): 27 | x = torch.cat((image, mask), dim=1) 28 | intermediates = self.encoder(x, backbone_features) 29 | output = self.decoder(intermediates, image, mask) 30 | return {'images': output} 31 | -------------------------------------------------------------------------------- /iharm/model/base/iseunet_v1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from iharm.model.modeling.unet import UNetEncoder, UNetDecoder 5 | from iharm.model.ops import MaskedChannelAttention 6 | 7 | 8 | class ISEUNetV1(nn.Module): 9 | def __init__( 10 | self, 11 | depth, 12 | norm_layer=nn.BatchNorm2d, batchnorm_from=2, 13 | attend_from=3, 14 | image_fusion=False, 15 | ch=64, max_channels=512, 16 | backbone_from=-1, backbone_channels=None, backbone_mode='' 17 | ): 18 | super(ISEUNetV1, self).__init__() 19 | self.depth = depth 20 | self.encoder = UNetEncoder( 21 | depth, ch, 22 | norm_layer, batchnorm_from, 
max_channels, 23 | backbone_from, backbone_channels, backbone_mode 24 | ) 25 | self.decoder = UNetDecoder( 26 | depth, self.encoder.block_channels, 27 | norm_layer, 28 | attention_layer=MaskedChannelAttention, 29 | attend_from=attend_from, 30 | image_fusion=image_fusion 31 | ) 32 | 33 | def forward(self, image, mask, backbone_features=None): 34 | x = torch.cat((image, mask), dim=1) 35 | intermediates = self.encoder(x, backbone_features) 36 | output = self.decoder(intermediates, image, mask) 37 | return {'images': output} 38 | -------------------------------------------------------------------------------- /iharm/model/base/ssam_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import partial 3 | 4 | from torch import nn as nn 5 | 6 | from iharm.model.modeling.basic_blocks import ConvBlock, GaussianSmoothing 7 | from iharm.model.modeling.unet import UNetEncoder, UNetDecoder 8 | from iharm.model.ops import ChannelAttention 9 | 10 | 11 | class SSAMImageHarmonization(nn.Module): 12 | def __init__( 13 | self, 14 | depth, 15 | norm_layer=nn.BatchNorm2d, batchnorm_from=2, 16 | attend_from=3, attention_mid_k=2.0, 17 | image_fusion=False, 18 | ch=64, max_channels=512, 19 | backbone_from=-1, backbone_channels=None, backbone_mode='' 20 | ): 21 | super(SSAMImageHarmonization, self).__init__() 22 | self.depth = depth 23 | self.encoder = UNetEncoder( 24 | depth, ch, 25 | norm_layer, batchnorm_from, max_channels, 26 | backbone_from, backbone_channels, backbone_mode 27 | ) 28 | self.decoder = UNetDecoder( 29 | depth, self.encoder.block_channels, 30 | norm_layer, 31 | attention_layer=partial(SpatialSeparatedAttention, mid_k=attention_mid_k), 32 | attend_from=attend_from, 33 | image_fusion=image_fusion 34 | ) 35 | 36 | def forward(self, image, mask, backbone_features=None): 37 | 38 | x = torch.cat((image, mask), dim=1) 39 | intermediates = self.encoder(x, backbone_features) 40 | output = 
self.decoder(intermediates, image, mask) 41 | return {'images': output} 42 | 43 | 44 | class SpatialSeparatedAttention(nn.Module): 45 | def __init__(self, in_channels, norm_layer, activation, mid_k=2.0): 46 | super(SpatialSeparatedAttention, self).__init__() 47 | self.background_gate = ChannelAttention(in_channels) 48 | self.foreground_gate = ChannelAttention(in_channels) 49 | self.mix_gate = ChannelAttention(in_channels) 50 | 51 | mid_channels = int(mid_k * in_channels) 52 | self.learning_block = nn.Sequential( 53 | ConvBlock( 54 | in_channels, mid_channels, 55 | kernel_size=3, stride=1, padding=1, 56 | norm_layer=norm_layer, activation=activation, 57 | bias=False, 58 | ), 59 | ConvBlock( 60 | mid_channels, in_channels, 61 | kernel_size=3, stride=1, padding=1, 62 | norm_layer=norm_layer, activation=activation, 63 | bias=False, 64 | ), 65 | ) 66 | self.mask_blurring = GaussianSmoothing(1, 7, 1, padding=3) 67 | 68 | def forward(self, x, mask): 69 | mask = self.mask_blurring(nn.functional.interpolate( 70 | mask, size=x.size()[-2:], 71 | mode='bilinear', align_corners=True 72 | )) 73 | background = self.background_gate(x) 74 | foreground = self.learning_block(self.foreground_gate(x)) 75 | mix = self.mix_gate(x) 76 | output = mask * (foreground + mix) + (1 - mask) * background 77 | return output 78 | -------------------------------------------------------------------------------- /iharm/model/initializer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class Initializer(object): 7 | def __init__(self, local_init=True, gamma=None): 8 | self.local_init = local_init 9 | self.gamma = gamma 10 | 11 | def __call__(self, m): 12 | if getattr(m, '__initialized', False): 13 | return 14 | 15 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, 16 | nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, 17 | nn.GroupNorm, nn.SyncBatchNorm)) or 
'BatchNorm' in m.__class__.__name__: 18 | if m.weight is not None: 19 | self._init_gamma(m.weight.data) 20 | if m.bias is not None: 21 | self._init_beta(m.bias.data) 22 | else: 23 | if getattr(m, 'weight', None) is not None: 24 | self._init_weight(m.weight.data) 25 | if getattr(m, 'bias', None) is not None: 26 | self._init_bias(m.bias.data) 27 | 28 | if self.local_init: 29 | object.__setattr__(m, '__initialized', True) 30 | 31 | def _init_weight(self, data): 32 | nn.init.uniform_(data, -0.07, 0.07) 33 | 34 | def _init_bias(self, data): 35 | nn.init.constant_(data, 0) 36 | 37 | def _init_gamma(self, data): 38 | if self.gamma is None: 39 | nn.init.constant_(data, 1.0) 40 | else: 41 | nn.init.normal_(data, 1.0, self.gamma) 42 | 43 | def _init_beta(self, data): 44 | nn.init.constant_(data, 0) 45 | 46 | 47 | class Bilinear(Initializer): 48 | def __init__(self, scale, groups, in_channels, **kwargs): 49 | super().__init__(**kwargs) 50 | self.scale = scale 51 | self.groups = groups 52 | self.in_channels = in_channels 53 | 54 | def _init_weight(self, data): 55 | """Initialize the weight with a bilinear upsampling kernel (the bias is set by _init_bias).""" 56 | bilinear_kernel = self.get_bilinear_kernel(self.scale) 57 | weight = torch.zeros_like(data) 58 | for i in range(self.in_channels): 59 | if self.groups == 1: 60 | j = i 61 | else: 62 | j = 0 63 | weight[i, j] = bilinear_kernel 64 | data[:] = weight 65 | 66 | @staticmethod 67 | def get_bilinear_kernel(scale): 68 | """Generate a bilinear upsampling kernel.""" 69 | kernel_size = 2 * scale - scale % 2 70 | scale = (kernel_size + 1) // 2 71 | center = scale - 0.5 * (1 + kernel_size % 2) 72 | 73 | og = np.ogrid[:kernel_size, :kernel_size] 74 | kernel = (1 - np.abs(og[0] - center) / scale) * (1 - np.abs(og[1] - center) / scale) 75 | 76 | return torch.tensor(kernel, dtype=torch.float32) 77 | 78 | 79 | class XavierGluon(Initializer): 80 | def __init__(self, rnd_type='uniform', factor_type='avg', magnitude=3, **kwargs): 81 | super().__init__(**kwargs) 82 | 83 | self.rnd_type = rnd_type 84
| self.factor_type = factor_type 85 | self.magnitude = float(magnitude) 86 | 87 | def _init_weight(self, arr): 88 | fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(arr) 89 | 90 | if self.factor_type == 'avg': 91 | factor = (fan_in + fan_out) / 2.0 92 | elif self.factor_type == 'in': 93 | factor = fan_in 94 | elif self.factor_type == 'out': 95 | factor = fan_out 96 | else: 97 | raise ValueError('Incorrect factor type') 98 | scale = np.sqrt(self.magnitude / factor) 99 | 100 | if self.rnd_type == 'uniform': 101 | nn.init.uniform_(arr, -scale, scale) 102 | elif self.rnd_type == 'gaussian': 103 | nn.init.normal_(arr, 0, scale) 104 | else: 105 | raise ValueError('Unknown random type') 106 | -------------------------------------------------------------------------------- /iharm/model/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from iharm.utils import misc 5 | 6 | 7 | class Loss(nn.Module): 8 | def __init__(self, pred_outputs, gt_outputs): 9 | super().__init__() 10 | self.pred_outputs = pred_outputs 11 | self.gt_outputs = gt_outputs 12 | 13 | 14 | class MSE(Loss): 15 | def __init__(self, pred_name='images', gt_image_name='target_images'): 16 | super(MSE, self).__init__(pred_outputs=(pred_name,), gt_outputs=(gt_image_name,)) 17 | 18 | def forward(self, pred, label): 19 | label = label.view(pred.size()) 20 | loss = torch.mean((pred - label) ** 2, dim=misc.get_dims_with_exclusion(label.dim(), 0)) 21 | return loss 22 | 23 | 24 | class MaskWeightedMSE(Loss): 25 | def __init__(self, min_area=1000.0, pred_name='images', 26 | gt_image_name='target_images', gt_mask_name='masks'): 27 | super(MaskWeightedMSE, self).__init__(pred_outputs=(pred_name, ), 28 | gt_outputs=(gt_image_name, gt_mask_name)) 29 | self.min_area = min_area 30 | 31 | def forward(self, pred, label, mask): 32 | label = label.view(pred.size()) 33 | reduce_dims = misc.get_dims_with_exclusion(label.dim(), 0) 34 | 
35 | loss = (pred - label) ** 2 36 | denominator = pred.size(1) * torch.clamp_min(torch.sum(mask, dim=reduce_dims), self.min_area) 37 | loss = torch.sum(loss, dim=reduce_dims) / denominator 38 | 39 | return loss 40 | -------------------------------------------------------------------------------- /iharm/model/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class TrainMetric(object): 6 | def __init__(self, pred_outputs, gt_outputs, epsilon=1e-6): 7 | self.pred_outputs = pred_outputs 8 | self.gt_outputs = gt_outputs 9 | self.epsilon = epsilon 10 | self._last_batch_metric = 0.0 11 | self._epoch_metric_sum = 0.0 12 | self._epoch_batch_count = 0 13 | 14 | def compute(self, *args, **kwargs): 15 | raise NotImplementedError 16 | 17 | def update(self, *args, **kwargs): 18 | self._last_batch_metric = self.compute(*args, **kwargs) 19 | self._epoch_metric_sum += self._last_batch_metric 20 | self._epoch_batch_count += 1 21 | 22 | def get_epoch_value(self): 23 | if self._epoch_batch_count > 0: 24 | return self._epoch_metric_sum / self._epoch_batch_count 25 | else: 26 | return 0.0 27 | 28 | def reset_epoch_stats(self): 29 | self._epoch_metric_sum = 0.0 30 | self._epoch_batch_count = 0 31 | 32 | def log_states(self, sw, tag_prefix, global_step): 33 | sw.add_scalar(tag=tag_prefix, value=self._last_batch_metric, global_step=global_step) 34 | 35 | @property 36 | def name(self): 37 | return type(self).__name__ 38 | 39 | 40 | class PSNRMetric(TrainMetric): 41 | def __init__(self, pred_output='instances', gt_output='instances'): 42 | super(PSNRMetric, self).__init__((pred_output, ), (gt_output, )) 43 | 44 | def compute(self, pred, gt): 45 | mse = F.mse_loss(pred, gt) 46 | squared_max = gt.max() ** 2 47 | psnr = 10 * torch.log10(squared_max / (mse + self.epsilon)) 48 | return psnr.item() 49 | 50 | 51 | class DenormalizedTrainMetric(TrainMetric): 52 | def __init__(self, pred_outputs,
gt_outputs, mean=None, std=None): 53 | super(DenormalizedTrainMetric, self).__init__(pred_outputs, gt_outputs) 54 | self.mean = torch.zeros(1) if mean is None else mean 55 | self.std = torch.ones(1) if std is None else std 56 | self.device = None 57 | 58 | def init_device(self, input_device): 59 | if self.device is None: 60 | self.device = input_device 61 | self.mean = self.mean.to(self.device) 62 | self.std = self.std.to(self.device) 63 | 64 | def denormalize(self, tensor): 65 | self.init_device(tensor.device) 66 | return tensor * self.std + self.mean 67 | 68 | def update(self, *args, **kwargs): 69 | self._last_batch_metric = self.compute(*args, **kwargs) 70 | self._epoch_metric_sum += self._last_batch_metric 71 | self._epoch_batch_count += 1 72 | 73 | 74 | class DenormalizedPSNRMetric(DenormalizedTrainMetric): 75 | def __init__( 76 | self, 77 | pred_output='instances', gt_output='instances', 78 | mean=None, std=None, 79 | ): 80 | super(DenormalizedPSNRMetric, self).__init__((pred_output, ), (gt_output, ), mean, std) 81 | 82 | def compute(self, pred, gt): 83 | denormalized_pred = torch.clamp(self.denormalize(pred), 0, 1) 84 | denormalized_gt = self.denormalize(gt) 85 | return PSNRMetric.compute(self, denormalized_pred, denormalized_gt) 86 | 87 | 88 | class DenormalizedMSEMetric(DenormalizedTrainMetric): 89 | def __init__( 90 | self, 91 | pred_output='instances', gt_output='instances', 92 | mean=None, std=None, 93 | ): 94 | super(DenormalizedMSEMetric, self).__init__((pred_output, ), (gt_output, ), mean, std) 95 | 96 | def compute(self, pred, gt): 97 | denormalized_pred = self.denormalize(pred) * 255 98 | denormalized_gt = self.denormalize(gt) * 255 99 | return F.mse_loss(denormalized_pred, denormalized_gt).item() 100 | -------------------------------------------------------------------------------- /iharm/model/modeling/basic_blocks.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | 4 | 
import torch 5 | import torch.nn.functional as F 6 | from torch import nn as nn 7 | 8 | 9 | class ConvHead(nn.Module): 10 | def __init__(self, out_channels, in_channels=32, num_layers=1, 11 | kernel_size=3, padding=1, 12 | norm_layer=nn.BatchNorm2d): 13 | super(ConvHead, self).__init__() 14 | convhead = [] 15 | 16 | for i in range(num_layers): 17 | convhead.extend([ 18 | nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding), 19 | nn.ReLU(), 20 | norm_layer(in_channels) if norm_layer is not None else nn.Identity() 21 | ]) 22 | convhead.append(nn.Conv2d(in_channels, out_channels, 1, padding=0)) 23 | 24 | self.convhead = nn.Sequential(*convhead) 25 | 26 | def forward(self, *inputs): 27 | return self.convhead(inputs[0]) 28 | 29 | 30 | class SepConvHead(nn.Module): 31 | def __init__(self, num_outputs, in_channels, mid_channels, num_layers=1, 32 | kernel_size=3, padding=1, dropout_ratio=0.0, dropout_indx=0, 33 | norm_layer=nn.BatchNorm2d): 34 | super(SepConvHead, self).__init__() 35 | 36 | sepconvhead = [] 37 | 38 | for i in range(num_layers): 39 | sepconvhead.append( 40 | SeparableConv2d(in_channels=in_channels if i == 0 else mid_channels, 41 | out_channels=mid_channels, 42 | dw_kernel=kernel_size, dw_padding=padding, 43 | norm_layer=norm_layer, activation='relu') 44 | ) 45 | if dropout_ratio > 0 and dropout_indx == i: 46 | sepconvhead.append(nn.Dropout(dropout_ratio)) 47 | 48 | sepconvhead.append( 49 | nn.Conv2d(in_channels=mid_channels, out_channels=num_outputs, kernel_size=1, padding=0) 50 | ) 51 | 52 | self.layers = nn.Sequential(*sepconvhead) 53 | 54 | def forward(self, *inputs): 55 | x = inputs[0] 56 | 57 | return self.layers(x) 58 | 59 | 60 | def select_activation_function(activation): 61 | if isinstance(activation, str): 62 | if activation.lower() == 'relu': 63 | return nn.ReLU 64 | elif activation.lower() == 'softplus': 65 | return nn.Softplus 66 | else: 67 | raise ValueError(f"Unknown activation type {activation}") 68 | elif isinstance(activation, 
type) and issubclass(activation, nn.Module): 69 | return activation 70 | else: 71 | raise ValueError(f"Unknown activation type {activation}") 72 | 73 | 74 | class SeparableConv2d(nn.Module): 75 | def __init__(self, in_channels, out_channels, dw_kernel, dw_padding, dw_stride=1, 76 | activation=None, use_bias=False, norm_layer=None): 77 | super(SeparableConv2d, self).__init__() 78 | _activation = select_activation_function(activation) 79 | self.body = nn.Sequential( 80 | nn.Conv2d(in_channels, in_channels, kernel_size=dw_kernel, stride=dw_stride, 81 | padding=dw_padding, bias=use_bias, groups=in_channels), 82 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=use_bias), 83 | norm_layer(out_channels) if norm_layer is not None else nn.Identity(), 84 | _activation() 85 | ) 86 | 87 | def forward(self, x): 88 | return self.body(x) 89 | 90 | 91 | class ConvBlock(nn.Module): 92 | def __init__( 93 | self, 94 | in_channels, out_channels, 95 | kernel_size=4, stride=2, padding=1, 96 | norm_layer=nn.BatchNorm2d, activation=nn.ELU, 97 | bias=True, 98 | ): 99 | super(ConvBlock, self).__init__() 100 | self.block = nn.Sequential( 101 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias), 102 | norm_layer(out_channels) if norm_layer is not None else nn.Identity(), 103 | activation(), 104 | ) 105 | 106 | def forward(self, x): 107 | return self.block(x) 108 | 109 | 110 | class GaussianSmoothing(nn.Module): 111 | """ 112 | https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/10 113 | Apply gaussian smoothing on a tensor (1d, 2d, 3d). 114 | Filtering is performed separately for each channel in the input using a depthwise convolution. 115 | Arguments: 116 | channels (int): Number of channels of the input tensors. 117 | Output will have this number of channels as well. 118 | kernel_size (int, sequence): Size of the gaussian kernel.
119 | sigma (float, sequence): Standard deviation of the gaussian kernel. 120 | dim (int, optional): The number of dimensions of the data. Default value is 2 (spatial). 121 | """ 122 | def __init__(self, channels, kernel_size, sigma, padding=0, dim=2): 123 | super(GaussianSmoothing, self).__init__() 124 | if isinstance(kernel_size, numbers.Number): 125 | kernel_size = [kernel_size] * dim 126 | if isinstance(sigma, numbers.Number): 127 | sigma = [sigma] * dim 128 | 129 | # The gaussian kernel is the product of the gaussian function of each dimension. 130 | kernel = 1. 131 | meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size]) 132 | for size, std, grid in zip(kernel_size, sigma, meshgrids): 133 | mean = (size - 1) / 2. 134 | kernel *= torch.exp(-((grid - mean) / std) ** 2 / 2) / (std * (2 * math.pi) ** 0.5) 135 | # Make sure sum of values in gaussian kernel equals 1. 136 | kernel = kernel / torch.sum(kernel) 137 | # Reshape to depthwise convolutional weight. 138 | kernel = kernel.view(1, 1, *kernel.size()) 139 | kernel = torch.repeat_interleave(kernel, channels, 0) 140 | 141 | self.register_buffer('weight', kernel) 142 | self.groups = channels 143 | self.padding = padding 144 | 145 | if dim == 1: 146 | self.conv = F.conv1d 147 | elif dim == 2: 148 | self.conv = F.conv2d 149 | elif dim == 3: 150 | self.conv = F.conv3d 151 | else: 152 | raise RuntimeError('Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)) 153 | 154 | def forward(self, input): 155 | """ 156 | Apply gaussian filter to input. 157 | Arguments: 158 | input (torch.Tensor): Input to apply gaussian filter on. 159 | Returns: 160 | filtered (torch.Tensor): Filtered output. 
161 | """ 162 | return self.conv(input, weight=self.weight, padding=self.padding, groups=self.groups) 163 | 164 | 165 | class MaxPoolDownSize(nn.Module): 166 | def __init__(self, in_channels, mid_channels, out_channels, depth): 167 | super(MaxPoolDownSize, self).__init__() 168 | self.depth = depth 169 | self.reduce_conv = ConvBlock(in_channels, mid_channels, kernel_size=1, stride=1, padding=0) 170 | self.convs = nn.ModuleList([ 171 | ConvBlock(mid_channels, out_channels, kernel_size=3, stride=1, padding=1) 172 | for conv_i in range(depth) 173 | ]) 174 | self.pool2d = nn.MaxPool2d(kernel_size=2) 175 | 176 | def forward(self, x): 177 | outputs = [] 178 | 179 | output = self.reduce_conv(x) 180 | 181 | for conv_i, conv in enumerate(self.convs): 182 | output = output if conv_i == 0 else self.pool2d(output) 183 | outputs.append(conv(output)) 184 | 185 | return outputs 186 | -------------------------------------------------------------------------------- /iharm/model/modeling/conv_autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | from iharm.model.modeling.basic_blocks import ConvBlock 5 | from iharm.model.ops import MaskedChannelAttention, FeaturesConnector 6 | 7 | 8 | class ConvEncoder(nn.Module): 9 | def __init__( 10 | self, 11 | depth, ch, 12 | norm_layer, batchnorm_from, max_channels, 13 | backbone_from, backbone_channels=None, backbone_mode='' 14 | ): 15 | super(ConvEncoder, self).__init__() 16 | self.depth = depth 17 | self.backbone_from = backbone_from 18 | backbone_channels = [] if backbone_channels is None else backbone_channels[::-1] 19 | 20 | in_channels = 4 21 | out_channels = ch 22 | 23 | self.block0 = ConvBlock(in_channels, out_channels, norm_layer=norm_layer if batchnorm_from == 0 else None) 24 | self.block1 = ConvBlock(out_channels, out_channels, norm_layer=norm_layer if 0 <= batchnorm_from <= 1 else None) 25 | self.blocks_channels = [out_channels, 
out_channels] 26 | 27 | self.blocks_connected = nn.ModuleDict() 28 | self.connectors = nn.ModuleDict() 29 | for block_i in range(2, depth): 30 | if block_i % 2: 31 | in_channels = out_channels 32 | else: 33 | in_channels, out_channels = out_channels, min(2 * out_channels, max_channels) 34 | 35 | if 0 <= backbone_from <= block_i and len(backbone_channels): 36 | stage_channels = backbone_channels.pop() 37 | connector = FeaturesConnector(backbone_mode, in_channels, stage_channels, in_channels) 38 | self.connectors[f'connector{block_i}'] = connector 39 | in_channels = connector.output_channels 40 | 41 | self.blocks_connected[f'block{block_i}'] = ConvBlock( 42 | in_channels, out_channels, 43 | norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None, 44 | padding=int(block_i < depth - 1) 45 | ) 46 | self.blocks_channels += [out_channels] 47 | 48 | def forward(self, x, backbone_features): 49 | backbone_features = [] if backbone_features is None else backbone_features[::-1] 50 | 51 | outputs = [self.block0(x)] 52 | outputs += [self.block1(outputs[-1])] 53 | 54 | for block_i in range(2, self.depth): 55 | block = self.blocks_connected[f'block{block_i}'] 56 | output = outputs[-1] 57 | connector_name = f'connector{block_i}' 58 | if connector_name in self.connectors: 59 | stage_features = backbone_features.pop() 60 | connector = self.connectors[connector_name] 61 | output = connector(output, stage_features) 62 | outputs += [block(output)] 63 | 64 | return outputs[::-1] 65 | 66 | 67 | class DeconvDecoder(nn.Module): 68 | def __init__(self, depth, encoder_blocks_channels, norm_layer, attend_from=-1, image_fusion=False): 69 | super(DeconvDecoder, self).__init__() 70 | self.image_fusion = image_fusion 71 | self.deconv_blocks = nn.ModuleList() 72 | 73 | in_channels = encoder_blocks_channels.pop() 74 | out_channels = in_channels 75 | for d in range(depth): 76 | out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2 77 | 
self.deconv_blocks.append(SEDeconvBlock( 78 | in_channels, out_channels, 79 | norm_layer=norm_layer, 80 | padding=0 if d == 0 else 1, 81 | with_se=0 <= attend_from <= d 82 | )) 83 | in_channels = out_channels 84 | 85 | if self.image_fusion: 86 | self.conv_attention = nn.Conv2d(out_channels, 1, kernel_size=1) 87 | self.to_rgb = nn.Conv2d(out_channels, 3, kernel_size=1) 88 | 89 | def forward(self, encoder_outputs, image, mask=None): 90 | output = encoder_outputs[0] 91 | for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]): 92 | output = block(output, mask) 93 | output = output + skip_output 94 | output = self.deconv_blocks[-1](output, mask) 95 | 96 | if self.image_fusion: 97 | attention_map = torch.sigmoid(3.0 * self.conv_attention(output)) 98 | output = attention_map * image + (1.0 - attention_map) * self.to_rgb(output) 99 | else: 100 | output = self.to_rgb(output) 101 | 102 | return output 103 | 104 | 105 | class SEDeconvBlock(nn.Module): 106 | def __init__( 107 | self, 108 | in_channels, out_channels, 109 | kernel_size=4, stride=2, padding=1, 110 | norm_layer=nn.BatchNorm2d, activation=nn.ELU, 111 | with_se=False 112 | ): 113 | super(SEDeconvBlock, self).__init__() 114 | self.with_se = with_se 115 | self.block = nn.Sequential( 116 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding), 117 | norm_layer(out_channels) if norm_layer is not None else nn.Identity(), 118 | activation(), 119 | ) 120 | if self.with_se: 121 | self.se = MaskedChannelAttention(out_channels) 122 | 123 | def forward(self, x, mask=None): 124 | out = self.block(x) 125 | if self.with_se: 126 | out = self.se(out, mask) 127 | return out 128 | -------------------------------------------------------------------------------- /iharm/model/modeling/deeplab_v3.py: -------------------------------------------------------------------------------- 1 | from contextlib import ExitStack 2 | 3 | import torch 4 | from torch import nn 5 | import 
torch.nn.functional as F 6 | 7 | from iharm.model.modeling.basic_blocks import select_activation_function 8 | from .basic_blocks import SeparableConv2d 9 | from .resnet import ResNetBackbone 10 | 11 | 12 | class DeepLabV3Plus(nn.Module): 13 | def __init__(self, backbone='resnet50', norm_layer=nn.BatchNorm2d, 14 | backbone_norm_layer=None, 15 | ch=256, 16 | project_dropout=0.5, 17 | inference_mode=False, 18 | **kwargs): 19 | super(DeepLabV3Plus, self).__init__() 20 | if backbone_norm_layer is None: 21 | backbone_norm_layer = norm_layer 22 | 23 | self.backbone_name = backbone 24 | self.norm_layer = norm_layer 25 | self.backbone_norm_layer = backbone_norm_layer 26 | self.inference_mode = False 27 | self.ch = ch 28 | self.aspp_in_channels = 2048 29 | self.skip_project_in_channels = 256 # layer 1 out_channels 30 | 31 | self._kwargs = kwargs 32 | if backbone == 'resnet34': 33 | self.aspp_in_channels = 512 34 | self.skip_project_in_channels = 64 35 | 36 | self.backbone = ResNetBackbone(backbone=self.backbone_name, pretrained_base=False, 37 | norm_layer=self.backbone_norm_layer, **kwargs) 38 | 39 | self.head = _DeepLabHead(in_channels=ch + 32, mid_channels=ch, out_channels=ch, 40 | norm_layer=self.norm_layer) 41 | self.skip_project = _SkipProject(self.skip_project_in_channels, 32, norm_layer=self.norm_layer) 42 | self.aspp = _ASPP(in_channels=self.aspp_in_channels, 43 | atrous_rates=[12, 24, 36], 44 | out_channels=ch, 45 | project_dropout=project_dropout, 46 | norm_layer=self.norm_layer) 47 | 48 | if inference_mode: 49 | self.set_prediction_mode() 50 | 51 | def load_pretrained_weights(self): 52 | pretrained = ResNetBackbone(backbone=self.backbone_name, pretrained_base=True, 53 | norm_layer=self.backbone_norm_layer, **self._kwargs) 54 | backbone_state_dict = self.backbone.state_dict() 55 | pretrained_state_dict = pretrained.state_dict() 56 | 57 | backbone_state_dict.update(pretrained_state_dict) 58 | self.backbone.load_state_dict(backbone_state_dict) 59 | 60 | if 
self.inference_mode: 61 | for param in self.backbone.parameters(): 62 | param.requires_grad = False 63 | 64 | def set_prediction_mode(self): 65 | self.inference_mode = True 66 | self.eval() 67 | 68 | def forward(self, x, mask_features=None): 69 | with ExitStack() as stack: 70 | if self.inference_mode: 71 | stack.enter_context(torch.no_grad()) 72 | 73 | c1, _, c3, c4 = self.backbone(x, mask_features) 74 | c1 = self.skip_project(c1) 75 | 76 | x = self.aspp(c4) 77 | x = F.interpolate(x, c1.size()[2:], mode='bilinear', align_corners=True) 78 | x = torch.cat((x, c1), dim=1) 79 | x = self.head(x) 80 | 81 | return x, 82 | 83 | 84 | class _SkipProject(nn.Module): 85 | def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d): 86 | super(_SkipProject, self).__init__() 87 | _activation = select_activation_function("relu") 88 | 89 | self.skip_project = nn.Sequential( 90 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), 91 | norm_layer(out_channels), 92 | _activation() 93 | ) 94 | 95 | def forward(self, x): 96 | return self.skip_project(x) 97 | 98 | 99 | class _DeepLabHead(nn.Module): 100 | def __init__(self, out_channels, in_channels, mid_channels=256, norm_layer=nn.BatchNorm2d): 101 | super(_DeepLabHead, self).__init__() 102 | 103 | self.block = nn.Sequential( 104 | SeparableConv2d(in_channels=in_channels, out_channels=mid_channels, dw_kernel=3, 105 | dw_padding=1, activation='relu', norm_layer=norm_layer), 106 | SeparableConv2d(in_channels=mid_channels, out_channels=mid_channels, dw_kernel=3, 107 | dw_padding=1, activation='relu', norm_layer=norm_layer), 108 | nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1) 109 | ) 110 | 111 | def forward(self, x): 112 | return self.block(x) 113 | 114 | 115 | class _ASPP(nn.Module): 116 | def __init__(self, in_channels, atrous_rates, out_channels=256, 117 | project_dropout=0.5, norm_layer=nn.BatchNorm2d): 118 | super(_ASPP, self).__init__() 119 | 120 | b0 = nn.Sequential( 121 | 
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=False), 122 | norm_layer(out_channels), 123 | nn.ReLU() 124 | ) 125 | 126 | rate1, rate2, rate3 = tuple(atrous_rates) 127 | b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer) 128 | b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer) 129 | b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer) 130 | b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer) 131 | 132 | self.concurent = nn.ModuleList([b0, b1, b2, b3, b4]) 133 | 134 | project = [ 135 | nn.Conv2d(in_channels=5*out_channels, out_channels=out_channels, 136 | kernel_size=1, bias=False), 137 | norm_layer(out_channels), 138 | nn.ReLU() 139 | ] 140 | if project_dropout > 0: 141 | project.append(nn.Dropout(project_dropout)) 142 | self.project = nn.Sequential(*project) 143 | 144 | def forward(self, x): 145 | x = torch.cat([block(x) for block in self.concurent], dim=1) 146 | 147 | return self.project(x) 148 | 149 | 150 | class _AsppPooling(nn.Module): 151 | def __init__(self, in_channels, out_channels, norm_layer): 152 | super(_AsppPooling, self).__init__() 153 | 154 | self.gap = nn.Sequential( 155 | nn.AdaptiveAvgPool2d((1, 1)), 156 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 157 | kernel_size=1, bias=False), 158 | norm_layer(out_channels), 159 | nn.ReLU() 160 | ) 161 | 162 | def forward(self, x): 163 | pool = self.gap(x) 164 | return F.interpolate(pool, x.size()[2:], mode='bilinear', align_corners=True) 165 | 166 | 167 | def _ASPPConv(in_channels, out_channels, atrous_rate, norm_layer): 168 | block = nn.Sequential( 169 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 170 | kernel_size=3, padding=atrous_rate, 171 | dilation=atrous_rate, bias=False), 172 | norm_layer(out_channels), 173 | nn.ReLU() 174 | ) 175 | 176 | return block 177 | -------------------------------------------------------------------------------- /iharm/model/modeling/ocr.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch._utils 4 | import torch.nn.functional as F 5 | 6 | 7 | class SpatialGather_Module(nn.Module): 8 | """ 9 | Aggregate the context features according to the initial 10 | predicted probability distribution. 11 | Employ the soft-weighted method to aggregate the context. 12 | """ 13 | 14 | def __init__(self, cls_num=0, scale=1): 15 | super(SpatialGather_Module, self).__init__() 16 | self.cls_num = cls_num 17 | self.scale = scale 18 | 19 | def forward(self, feats, probs): 20 | batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3) 21 | probs = probs.view(batch_size, c, -1) 22 | feats = feats.view(batch_size, feats.size(1), -1) 23 | feats = feats.permute(0, 2, 1) # batch x hw x c 24 | probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw 25 | ocr_context = torch.matmul(probs, feats) \ 26 | .permute(0, 2, 1).unsqueeze(3).contiguous() # batch x c x k x 1 27 | return ocr_context 28 | 29 | 30 | class SpatialOCR_Module(nn.Module): 31 | """ 32 | Implementation of the OCR module: 33 | We aggregate the global object representation to update the representation for each pixel.
34 | """ 35 | 36 | def __init__(self, 37 | in_channels, 38 | key_channels, 39 | out_channels, 40 | scale=1, 41 | dropout=0.1, 42 | norm_layer=nn.BatchNorm2d, 43 | align_corners=True): 44 | super(SpatialOCR_Module, self).__init__() 45 | self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale, 46 | norm_layer, align_corners) 47 | _in_channels = 2 * in_channels 48 | 49 | self.conv_bn_dropout = nn.Sequential( 50 | nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False), 51 | nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)), 52 | nn.Dropout2d(dropout) 53 | ) 54 | 55 | def forward(self, feats, proxy_feats): 56 | context = self.object_context_block(feats, proxy_feats) 57 | 58 | output = self.conv_bn_dropout(torch.cat([context, feats], 1)) 59 | 60 | return output 61 | 62 | 63 | class ObjectAttentionBlock2D(nn.Module): 64 | ''' 65 | Basic implementation of the object context block 66 | Input: 67 | N X C X H X W 68 | Parameters: 69 | in_channels : the dimension of the input feature map 70 | key_channels : the dimension after the key/query transform 71 | scale : downsampling factor applied to the input feature maps (saves memory) 72 | norm_layer : normalization layer class, e.g. nn.BatchNorm2d 73 | Return: 74 | N X C X H X W 75 | ''' 76 | 77 | def __init__(self, 78 | in_channels, 79 | key_channels, 80 | scale=1, 81 | norm_layer=nn.BatchNorm2d, 82 | align_corners=True): 83 | super(ObjectAttentionBlock2D, self).__init__() 84 | self.scale = scale 85 | self.in_channels = in_channels 86 | self.key_channels = key_channels 87 | self.align_corners = align_corners 88 | 89 | self.pool = nn.MaxPool2d(kernel_size=(scale, scale)) 90 | self.f_pixel = nn.Sequential( 91 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, 92 | kernel_size=1, stride=1, padding=0, bias=False), 93 | nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)), 94 | 
nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, 95 |
kernel_size=1, stride=1, padding=0, bias=False), 96 | nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) 97 | ) 98 | self.f_object = nn.Sequential( 99 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, 100 | kernel_size=1, stride=1, padding=0, bias=False), 101 | nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)), 102 | nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, 103 | kernel_size=1, stride=1, padding=0, bias=False), 104 | nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) 105 | ) 106 | self.f_down = nn.Sequential( 107 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, 108 | kernel_size=1, stride=1, padding=0, bias=False), 109 | nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)) 110 | ) 111 | self.f_up = nn.Sequential( 112 | nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels, 113 | kernel_size=1, stride=1, padding=0, bias=False), 114 | nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True)) 115 | ) 116 | 117 | def forward(self, x, proxy): 118 | batch_size, h, w = x.size(0), x.size(2), x.size(3) 119 | if self.scale > 1: 120 | x = self.pool(x) 121 | 122 | query = self.f_pixel(x).view(batch_size, self.key_channels, -1) 123 | query = query.permute(0, 2, 1) 124 | key = self.f_object(proxy).view(batch_size, self.key_channels, -1) 125 | value = self.f_down(proxy).view(batch_size, self.key_channels, -1) 126 | value = value.permute(0, 2, 1) 127 | 128 | sim_map = torch.matmul(query, key) 129 | sim_map = (self.key_channels ** -.5) * sim_map 130 | sim_map = F.softmax(sim_map, dim=-1) 131 | 132 | # add bg context ... 
133 | context = torch.matmul(sim_map, value) 134 | context = context.permute(0, 2, 1).contiguous() 135 | context = context.view(batch_size, self.key_channels, *x.size()[2:]) 136 | context = self.f_up(context) 137 | if self.scale > 1: 138 | context = F.interpolate(input=context, size=(h, w), 139 | mode='bilinear', align_corners=self.align_corners) 140 | 141 | return context 142 | -------------------------------------------------------------------------------- /iharm/model/modeling/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s 3 | 4 | 5 | class ResNetBackbone(torch.nn.Module): 6 | def __init__(self, backbone='resnet50', pretrained_base=True, dilated=True, **kwargs): 7 | super(ResNetBackbone, self).__init__() 8 | 9 | if backbone == 'resnet34': 10 | pretrained = resnet34_v1b(pretrained=pretrained_base, dilated=dilated, **kwargs) 11 | elif backbone == 'resnet50': 12 | pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) 13 | elif backbone == 'resnet101': 14 | pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) 15 | elif backbone == 'resnet152': 16 | pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) 17 | else: 18 | raise RuntimeError(f'unknown backbone: {backbone}') 19 | 20 | self.conv1 = pretrained.conv1 21 | self.bn1 = pretrained.bn1 22 | self.relu = pretrained.relu 23 | self.maxpool = pretrained.maxpool 24 | self.layer1 = pretrained.layer1 25 | self.layer2 = pretrained.layer2 26 | self.layer3 = pretrained.layer3 27 | self.layer4 = pretrained.layer4 28 | 29 | def forward(self, x, mask_features=None): 30 | x = self.conv1(x) 31 | x = self.bn1(x) 32 | x = self.relu(x) 33 | if mask_features is not None: 34 | x = x + mask_features 35 | x = self.maxpool(x) 36 | c1 = self.layer1(x) 37 | c2 = self.layer2(c1) 38 | c3 = self.layer3(c2) 39 | c4 = 
self.layer4(c3) 40 | 41 | return c1, c2, c3, c4 42 | -------------------------------------------------------------------------------- /iharm/model/modeling/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from functools import partial 4 | 5 | from iharm.model.modeling.basic_blocks import ConvBlock 6 | from iharm.model.ops import FeaturesConnector 7 | 8 | 9 | class UNetEncoder(nn.Module): 10 | def __init__( 11 | self, 12 | depth, ch, 13 | norm_layer, batchnorm_from, max_channels, 14 | backbone_from, backbone_channels=None, backbone_mode='' 15 | ): 16 | super(UNetEncoder, self).__init__() 17 | self.depth = depth 18 | self.backbone_from = backbone_from 19 | self.block_channels = [] 20 | backbone_channels = [] if backbone_channels is None else backbone_channels[::-1] 21 | relu = partial(nn.ReLU, inplace=True) 22 | 23 | in_channels = 4 24 | out_channels = ch 25 | 26 | self.block0 = UNetDownBlock( 27 | in_channels, out_channels, 28 | norm_layer=norm_layer if batchnorm_from == 0 else None, 29 | activation=relu, 30 | pool=True, padding=1, 31 | ) 32 | self.block_channels.append(out_channels) 33 | in_channels, out_channels = out_channels, min(2 * out_channels, max_channels) 34 | self.block1 = UNetDownBlock( 35 | in_channels, out_channels, 36 | norm_layer=norm_layer if 0 <= batchnorm_from <= 1 else None, 37 | activation=relu, 38 | pool=True, padding=1, 39 | ) 40 | self.block_channels.append(out_channels) 41 | 42 | self.blocks_connected = nn.ModuleDict() 43 | self.connectors = nn.ModuleDict() 44 | for block_i in range(2, depth): 45 | in_channels, out_channels = out_channels, min(2 * out_channels, max_channels) 46 | if 0 <= backbone_from <= block_i and len(backbone_channels): 47 | stage_channels = backbone_channels.pop() 48 | connector = FeaturesConnector(backbone_mode, in_channels, stage_channels, in_channels) 49 | self.connectors[f'connector{block_i}'] = connector 50 | in_channels = 
connector.output_channels 51 | self.blocks_connected[f'block{block_i}'] = UNetDownBlock( 52 | in_channels, out_channels, 53 | norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None, 54 | activation=relu, padding=1, 55 | pool=block_i < depth - 1, 56 | ) 57 | self.block_channels.append(out_channels) 58 | 59 | def forward(self, x, backbone_features): 60 | backbone_features = [] if backbone_features is None else backbone_features[::-1] 61 | outputs = [] 62 | 63 | block_input = x 64 | output, block_input = self.block0(block_input) 65 | outputs.append(output) 66 | output, block_input = self.block1(block_input) 67 | outputs.append(output) 68 | 69 | for block_i in range(2, self.depth): 70 | block = self.blocks_connected[f'block{block_i}'] 71 | connector_name = f'connector{block_i}' 72 | if connector_name in self.connectors: 73 | stage_features = backbone_features.pop() 74 | connector = self.connectors[connector_name] 75 | block_input = connector(block_input, stage_features) 76 | output, block_input = block(block_input) 77 | outputs.append(output) 78 | 79 | return outputs[::-1] 80 | 81 | 82 | class UNetDecoder(nn.Module): 83 | def __init__(self, depth, encoder_blocks_channels, norm_layer, 84 | attention_layer=None, attend_from=3, image_fusion=False): 85 | super(UNetDecoder, self).__init__() 86 | self.up_blocks = nn.ModuleList() 87 | self.image_fusion = image_fusion 88 | in_channels = encoder_blocks_channels.pop() 89 | out_channels = in_channels 90 | # Last encoder layer doesn't pool, so there're only (depth - 1) deconvs 91 | for d in range(depth - 1): 92 | out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2 93 | stage_attention_layer = attention_layer if 0 <= attend_from <= d else None 94 | self.up_blocks.append(UNetUpBlock( 95 | in_channels, out_channels, out_channels, 96 | norm_layer=norm_layer, activation=partial(nn.ReLU, inplace=True), 97 | padding=1, 98 | attention_layer=stage_attention_layer, 99 | )) 100 | 
in_channels = out_channels 101 | 102 | if self.image_fusion: 103 | self.conv_attention = nn.Conv2d(out_channels, 1, kernel_size=1) 104 | self.to_rgb = nn.Conv2d(out_channels, 3, kernel_size=1) 105 | 106 | def forward(self, encoder_outputs, input_image, mask): 107 | output = encoder_outputs[0] 108 | for block, skip_output in zip(self.up_blocks, encoder_outputs[1:]): 109 | output = block(output, skip_output, mask) 110 | 111 | if self.image_fusion: 112 | attention_map = torch.sigmoid(3.0 * self.conv_attention(output)) 113 | output = attention_map * input_image + (1.0 - attention_map) * self.to_rgb(output) 114 | else: 115 | output = self.to_rgb(output) 116 | 117 | return output 118 | 119 | 120 | class UNetDownBlock(nn.Module): 121 | def __init__(self, in_channels, out_channels, norm_layer, activation, pool, padding): 122 | super(UNetDownBlock, self).__init__() 123 | self.convs = UNetDoubleConv( 124 | in_channels, out_channels, 125 | norm_layer=norm_layer, activation=activation, padding=padding, 126 | ) 127 | self.pooling = nn.MaxPool2d(2, 2) if pool else nn.Identity() 128 | 129 | def forward(self, x): 130 | conv_x = self.convs(x) 131 | return conv_x, self.pooling(conv_x) 132 | 133 | 134 | class UNetUpBlock(nn.Module): 135 | def __init__( 136 | self, 137 | in_channels_decoder, in_channels_encoder, out_channels, 138 | norm_layer, activation, padding, 139 | attention_layer, 140 | ): 141 | super(UNetUpBlock, self).__init__() 142 | self.upconv = nn.Sequential( 143 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 144 | ConvBlock( 145 | in_channels_decoder, out_channels, 146 | kernel_size=3, stride=1, padding=1, 147 | norm_layer=None, activation=activation, 148 | ) 149 | ) 150 | self.convs = UNetDoubleConv( 151 | in_channels_encoder + out_channels, out_channels, 152 | norm_layer=norm_layer, activation=activation, padding=padding, 153 | ) 154 | if attention_layer is not None: 155 | self.attention = attention_layer(in_channels_encoder + out_channels, 
norm_layer, activation) 156 | else: 157 | self.attention = None 158 | 159 | def forward(self, x, encoder_out, mask=None): 160 | upsample_x = self.upconv(x) 161 | x_cat_encoder = torch.cat([encoder_out, upsample_x], dim=1) 162 | if self.attention is not None: 163 | x_cat_encoder = self.attention(x_cat_encoder, mask) 164 | return self.convs(x_cat_encoder) 165 | 166 | 167 | class UNetDoubleConv(nn.Module): 168 | def __init__(self, in_channels, out_channels, norm_layer, activation, padding): 169 | super(UNetDoubleConv, self).__init__() 170 | self.block = nn.Sequential( 171 | ConvBlock( 172 | in_channels, out_channels, 173 | kernel_size=3, stride=1, padding=padding, 174 | norm_layer=norm_layer, activation=activation, 175 | ), 176 | ConvBlock( 177 | out_channels, out_channels, 178 | kernel_size=3, stride=1, padding=padding, 179 | norm_layer=norm_layer, activation=activation, 180 | ), 181 | ) 182 | 183 | def forward(self, x): 184 | return self.block(x) 185 | -------------------------------------------------------------------------------- /iharm/model/modeling/unet_v1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from functools import partial 4 | 5 | from iharm.model.modeling.basic_blocks import ConvBlock,GaussianSmoothing 6 | from iharm.model.ops import FeaturesConnector 7 | 8 | 9 | class UNetEncoder(nn.Module): 10 | def __init__( 11 | self, 12 | depth, ch, 13 | norm_layer, batchnorm_from, max_channels, 14 | backbone_from, backbone_channels=None, backbone_mode='',_in_channels = 4 15 | ): 16 | super(UNetEncoder, self).__init__() 17 | self.depth = depth 18 | self.backbone_from = backbone_from 19 | self.block_channels = [] 20 | backbone_channels = [] if backbone_channels is None else backbone_channels[::-1] 21 | relu = partial(nn.ReLU, inplace=True) 22 | 23 | in_channels = _in_channels 24 | out_channels = ch 25 | 26 | # print(in_channels,out_channels) 27 | 28 | self.block0 = 
UNetDownBlock( 29 | in_channels, out_channels, 30 | norm_layer=norm_layer if batchnorm_from == 0 else None, 31 | activation=relu, 32 | pool=True, padding=1, 33 | ) 34 | self.block_channels.append(out_channels) 35 | in_channels, out_channels = out_channels, min(2 * out_channels, max_channels) 36 | self.block1 = UNetDownBlock( 37 | in_channels, out_channels, 38 | norm_layer=norm_layer if 0 <= batchnorm_from <= 1 else None, 39 | activation=relu, 40 | pool=True, padding=1, 41 | ) 42 | self.block_channels.append(out_channels) 43 | 44 | self.blocks_connected = nn.ModuleDict() 45 | self.connectors = nn.ModuleDict() 46 | for block_i in range(2, depth): 47 | in_channels, out_channels = out_channels, min(2 * out_channels, max_channels) 48 | if 0 <= backbone_from <= block_i and len(backbone_channels): 49 | stage_channels = backbone_channels.pop() 50 | connector = FeaturesConnector(backbone_mode, in_channels, stage_channels, in_channels) 51 | self.connectors[f'connector{block_i}'] = connector 52 | in_channels = connector.output_channels 53 | self.blocks_connected[f'block{block_i}'] = UNetDownBlock( 54 | in_channels, out_channels, 55 | norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None, 56 | activation=relu, padding=1, 57 | pool=block_i < depth - 1, 58 | ) 59 | self.block_channels.append(out_channels) 60 | 61 | def forward(self, x, backbone_features): 62 | backbone_features = [] if backbone_features is None else backbone_features[::-1] 63 | outputs = [] 64 | 65 | block_input = x 66 | output, block_input = self.block0(block_input) 67 | outputs.append(output) 68 | output, block_input = self.block1(block_input) 69 | outputs.append(output) 70 | 71 | for block_i in range(2, self.depth): 72 | block = self.blocks_connected[f'block{block_i}'] 73 | connector_name = f'connector{block_i}' 74 | if connector_name in self.connectors: 75 | stage_features = backbone_features.pop() 76 | connector = self.connectors[connector_name] 77 | block_input = connector(block_input, 
stage_features) 78 | output, block_input = block(block_input) 79 | outputs.append(output) 80 | 81 | return outputs[::-1] 82 | 83 | 84 | class UNetDecoder(nn.Module): 85 | def __init__(self, depth, encoder_blocks_channels, norm_layer, 86 | attention_layer=None, attend_from=3, image_fusion=False, control_module_start = -1, 87 | control_module_layer=None, 88 | w_dim = 256): 89 | super(UNetDecoder, self).__init__() 90 | self.up_blocks = nn.ModuleList() 91 | self.image_fusion = image_fusion 92 | self.w_dim = w_dim 93 | in_channels = encoder_blocks_channels.pop() 94 | out_channels = in_channels 95 | # Last encoder layer doesn't pool, so there're only (depth - 1) deconvs 96 | for d in range(depth - 1): 97 | out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2 98 | use_control_module = True if control_module_start == -1 else d >= control_module_start 99 | stage_attention_layer = attention_layer if 0 <= attend_from <= d else None 100 | self.up_blocks.append(UNetUpBlock( 101 | in_channels, out_channels, out_channels, 102 | norm_layer=norm_layer, activation=partial(nn.ReLU, inplace=True), 103 | padding=1, 104 | attention_layer=stage_attention_layer, 105 | use_control_module = use_control_module, 106 | w_dim = self.w_dim, 107 | control_module_layer = control_module_layer, 108 | )) 109 | in_channels = out_channels 110 | 111 | if self.image_fusion: 112 | self.conv_attention = nn.Conv2d(out_channels, 1, kernel_size=1) 113 | self.to_rgb = nn.Conv2d(out_channels, 3, kernel_size=1) 114 | 115 | def forward(self, encoder_outputs, input_image, mask, ws): 116 | output = encoder_outputs[0] 117 | count = 0 118 | for block, skip_output in zip(self.up_blocks, encoder_outputs[1:]): 119 | output = block(output, skip_output, mask,ws) 120 | count += 1 121 | 122 | if self.image_fusion: 123 | attention_map = torch.sigmoid(3.0 * self.conv_attention(output)) 124 | output = attention_map * input_image + (1.0 - attention_map) * self.to_rgb(output) 125 | 
else: 126 | output = self.to_rgb(output) 127 | 128 | return output 129 | 130 | 131 | class UNetDownBlock(nn.Module): 132 | def __init__(self, in_channels, out_channels, norm_layer, activation, pool, padding): 133 | super(UNetDownBlock, self).__init__() 134 | self.convs = UNetDoubleConv( 135 | in_channels, out_channels, 136 | norm_layer=norm_layer, activation=activation, padding=padding, 137 | ) 138 | self.pooling = nn.MaxPool2d(2, 2) if pool else nn.Identity() 139 | 140 | def forward(self, x): 141 | conv_x = self.convs(x) 142 | return conv_x, self.pooling(conv_x) 143 | 144 | 145 | class UNetUpBlock(nn.Module): 146 | def __init__( 147 | self, 148 | in_channels_decoder, in_channels_encoder, out_channels, 149 | norm_layer, activation, padding, 150 | attention_layer, 151 | use_control_module, 152 | w_dim, 153 | control_module_layer, 154 | ): 155 | super(UNetUpBlock, self).__init__() 156 | self.upconv = nn.Sequential( 157 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 158 | ConvBlock( 159 | in_channels_decoder, out_channels, 160 | kernel_size=3, stride=1, padding=1, 161 | norm_layer=None, activation=activation, 162 | ) 163 | ) 164 | self.convs = UNetDoubleConv( 165 | in_channels_encoder + out_channels, out_channels, 166 | norm_layer=norm_layer, activation=activation, padding=padding, 167 | ) 168 | if attention_layer is not None: 169 | self.attention = attention_layer(in_channels_encoder + out_channels, norm_layer, activation) 170 | else: 171 | self.attention = None 172 | 173 | self.w_dim = w_dim 174 | self.use_control_module = use_control_module 175 | self.out_channel = out_channels 176 | 177 | if self.use_control_module: 178 | self.control_module = control_module_layer( 179 | w_dim = self.w_dim, 180 | feature_dim = out_channels, 181 | ) 182 | 183 | 184 | def forward(self, x, encoder_out, mask, ws): 185 | upsample_x = self.upconv(x) 186 | 187 | _mask = nn.functional.interpolate( 188 | mask, size=encoder_out.size()[-2:], 189 | mode='bilinear', 
align_corners=True 190 | ) 191 | 192 | x_cat_encoder = torch.cat([encoder_out, upsample_x], dim=1) 193 | if self.attention is not None: 194 | x_cat_encoder = self.attention(x_cat_encoder, mask) 195 | dec_out = self.convs(x_cat_encoder) 196 | if self.use_control_module: 197 | wl, wa, wb = ws['w_l'], ws['w_a'], ws['w_b'] 198 | dec_out = self.control_module(dec_out, wl, wa, wb, _mask) 199 | return dec_out 200 | 201 | 202 | class UNetDoubleConv(nn.Module): 203 | def __init__(self, in_channels, out_channels, norm_layer, activation, padding): 204 | super(UNetDoubleConv, self).__init__() 205 | self.block = nn.Sequential( 206 | ConvBlock( 207 | in_channels, out_channels, 208 | kernel_size=3, stride=1, padding=padding, 209 | norm_layer=norm_layer, activation=activation, 210 | ), 211 | ConvBlock( 212 | out_channels, out_channels, 213 | kernel_size=3, stride=1, padding=padding, 214 | norm_layer=norm_layer, activation=activation, 215 | ), 216 | ) 217 | 218 | def forward(self, x): 219 | return self.block(x) 220 | -------------------------------------------------------------------------------- /iharm/model/modifiers.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class LRMult(object): 4 | def __init__(self, lr_mult=1.): 5 | self.lr_mult = lr_mult 6 | 7 | def __call__(self, m): 8 | if getattr(m, 'weight', None) is not None: 9 | m.weight.lr_mult = self.lr_mult 10 | if getattr(m, 'bias', None) is not None: 11 | m.bias.lr_mult = self.lr_mult 12 | -------------------------------------------------------------------------------- /iharm/model/ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | 5 | class SimpleInputFusion(nn.Module): 6 | def __init__(self, add_ch=1, rgb_ch=3, ch=8, norm_layer=nn.BatchNorm2d): 7 | super(SimpleInputFusion, self).__init__() 8 | 9 | self.fusion_conv = nn.Sequential( 10 | nn.Conv2d(in_channels=add_ch + rgb_ch, 
out_channels=ch, kernel_size=1), 11 | nn.LeakyReLU(negative_slope=0.2), 12 | norm_layer(ch), 13 | nn.Conv2d(in_channels=ch, out_channels=rgb_ch, kernel_size=1), 14 | ) 15 | 16 | def forward(self, image, additional_input): 17 | return self.fusion_conv(torch.cat((image, additional_input), dim=1)) 18 | 19 | 20 | class ChannelAttention(nn.Module): 21 | def __init__(self, in_channels): 22 | super(ChannelAttention, self).__init__() 23 | self.global_pools = nn.ModuleList([ 24 | nn.AdaptiveAvgPool2d(1), 25 | nn.AdaptiveMaxPool2d(1), 26 | ]) 27 | intermediate_channels_count = max(in_channels // 16, 8) 28 | self.attention_transform = nn.Sequential( 29 | nn.Linear(len(self.global_pools) * in_channels, intermediate_channels_count), 30 | nn.ReLU(), 31 | nn.Linear(intermediate_channels_count, in_channels), 32 | nn.Sigmoid(), 33 | ) 34 | 35 | def forward(self, x): 36 | pooled_x = [] 37 | for global_pool in self.global_pools: 38 | pooled_x.append(global_pool(x)) 39 | pooled_x = torch.cat(pooled_x, dim=1).flatten(start_dim=1) 40 | channel_attention_weights = self.attention_transform(pooled_x)[..., None, None] 41 | return channel_attention_weights * x 42 | 43 | 44 | class MaskedChannelAttention(nn.Module): 45 | def __init__(self, in_channels, *args, **kwargs): 46 | super(MaskedChannelAttention, self).__init__() 47 | self.global_max_pool = MaskedGlobalMaxPool2d() 48 | self.global_avg_pool = FastGlobalAvgPool2d() 49 | 50 | intermediate_channels_count = max(in_channels // 16, 8) 51 | self.attention_transform = nn.Sequential( 52 | nn.Linear(3 * in_channels, intermediate_channels_count), 53 | nn.ReLU(inplace=True), 54 | nn.Linear(intermediate_channels_count, in_channels), 55 | nn.Sigmoid(), 56 | ) 57 | 58 | def forward(self, x, mask): 59 | if mask.shape[2:] != x.shape[2:]: # resize the mask only when its spatial size differs from x 60 | mask = nn.functional.interpolate( 61 | mask, size=x.size()[-2:], 62 | mode='bilinear', align_corners=True 63 | ) 64 | pooled_x = torch.cat([ 65 | self.global_max_pool(x, mask), 
66 | self.global_avg_pool(x) 67 | ],
dim=1) 68 | channel_attention_weights = self.attention_transform(pooled_x)[..., None, None] 69 | 70 | return channel_attention_weights * x 71 | 72 | 73 | class MaskedGlobalMaxPool2d(nn.Module): 74 | def __init__(self): 75 | super().__init__() 76 | self.global_max_pool = FastGlobalMaxPool2d() 77 | 78 | def forward(self, x, mask): 79 | return torch.cat(( 80 | self.global_max_pool(x * mask), 81 | self.global_max_pool(x * (1.0 - mask)) 82 | ), dim=1) 83 | 84 | 85 | class FastGlobalAvgPool2d(nn.Module): 86 | def __init__(self): 87 | super(FastGlobalAvgPool2d, self).__init__() 88 | 89 | def forward(self, x): 90 | in_size = x.size() 91 | return x.view((in_size[0], in_size[1], -1)).mean(dim=2) 92 | 93 | 94 | class FastGlobalMaxPool2d(nn.Module): 95 | def __init__(self): 96 | super(FastGlobalMaxPool2d, self).__init__() 97 | 98 | def forward(self, x): 99 | in_size = x.size() 100 | return x.view((in_size[0], in_size[1], -1)).max(dim=2)[0] 101 | 102 | 103 | class ScaleLayer(nn.Module): 104 | def __init__(self, init_value=1.0, lr_mult=1): 105 | super().__init__() 106 | self.lr_mult = lr_mult 107 | self.scale = nn.Parameter( 108 | torch.full((1,), init_value / lr_mult, dtype=torch.float32) 109 | ) 110 | 111 | def forward(self, x): 112 | scale = torch.abs(self.scale * self.lr_mult) 113 | return x * scale 114 | 115 | 116 | class FeaturesConnector(nn.Module): 117 | def __init__(self, mode, in_channels, feature_channels, out_channels): 118 | super(FeaturesConnector, self).__init__() 119 | self.mode = mode if feature_channels else '' 120 | 121 | if self.mode == 'catc': 122 | self.reduce_conv = nn.Conv2d(in_channels + feature_channels, out_channels, kernel_size=1) 123 | elif self.mode == 'sum': 124 | self.reduce_conv = nn.Conv2d(feature_channels, out_channels, kernel_size=1) 125 | 126 | self.output_channels = out_channels if self.mode != 'cat' else in_channels + feature_channels 127 | 128 | def forward(self, x, features): 129 | if self.mode == 'cat': 130 | return torch.cat((x, 
features), 1) 131 | if self.mode == 'catc': 132 | return self.reduce_conv(torch.cat((x, features), 1)) 133 | if self.mode == 'sum': 134 | return self.reduce_conv(features) + x 135 | return x 136 | 137 | def extra_repr(self): 138 | return self.mode 139 | -------------------------------------------------------------------------------- /iharm/model/syncbn/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Tamaki Kojima 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /iharm/model/syncbn/README.md: -------------------------------------------------------------------------------- 1 | # pytorch-syncbn 2 | 3 | Tamaki Kojima (tamakoji@gmail.com) 4 | 5 | ## Announcement 6 | 7 | **PyTorch 1.0 support** 8 | 9 | ## Overview 10 | This is an alternative implementation of "Synchronized Multi-GPU Batch Normalization", which computes global statistics across GPUs instead of per-GPU local statistics. SyncBN becomes important when the input images are large and multiple GPUs must be used to increase the minibatch size for training. 11 | 12 | The code was inspired by [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding) and [Inplace-ABN](https://github.com/mapillary/inplace_abn). 13 | 14 | ## Remarks 15 | - Unlike [Pytorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding), you don't need a custom `nn.DataParallel` 16 | - Unlike [Inplace-ABN](https://github.com/mapillary/inplace_abn), you can simply replace `nn.BatchNorm2d` with this module implementation, since it does not mark tensors for in-place operation 17 | - You can plug it into any module written in PyTorch to enable Synchronized BatchNorm 18 | - Backward computation is rewritten and tested against the behavior of `nn.BatchNorm2d` 19 | 20 | ## Requirements 21 | For PyTorch, please refer to https://pytorch.org/ 22 | 23 | NOTE: The code has been tested only with PyTorch v1.0.0 and CUDA 10/cuDNN 7.4.2 on Ubuntu 18.04 24 | 25 | It uses PyTorch's JIT extension loading (`torch.utils.cpp_extension.load`) to compile seamlessly with ninja. Please install ninja-build before use. 26 | 27 | ``` 28 | sudo apt-get install ninja-build 29 | ``` 30 | 31 | Also install all Python dependencies. For pip, run: 32 | 33 | 34 | ``` 35 | pip install -U -r requirements.txt 36 | ``` 37 | 38 | ## Build 39 | 40 | There is no need to build; just run the code and the JIT will take care of compilation.
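41 | JIT and C++ extensions are supported since PyTorch 0.4; however, PyTorch 1.0 or later is highly recommended because of major design changes. 42 | 43 | ## Usage 44 | 45 | Please refer to [`test.py`](./test.py) for a test comparing `nn.BatchNorm2d` with `modules.nn.BatchNorm2d`. 46 | 47 | ```python 48 | import torch 49 | from modules import nn as NN 50 | num_gpu = torch.cuda.device_count() 51 | model = torch.nn.Sequential( 52 | torch.nn.Conv2d(3, 3, 1, 1, bias=False), 53 | NN.BatchNorm2d(3), 54 | torch.nn.ReLU(inplace=True), 55 | torch.nn.Conv2d(3, 3, 1, 1, bias=False), 56 | NN.BatchNorm2d(3), 57 | ).cuda() 58 | model = torch.nn.DataParallel(model, device_ids=list(range(num_gpu))) 59 | x = torch.rand(num_gpu, 3, 2, 2).cuda() 60 | z = model(x) 61 | ``` 62 | 63 | ## Math 64 | 65 | ### Forward 66 | 1. compute $\sum x_i$ and $\sum x_i^2$ on each GPU 67 | 2. gather all $\sum x_i$ and $\sum x_i^2$ from the workers to the master and compute $\mu$ and $\sigma^2$, where 68 | 69 | $$\mu = \frac{1}{N}\sum_{i=1}^{N} x_i$$ 70 | 71 | and 72 | 73 | $$\sigma^2 = \frac{1}{N}\sum_{i=1}^{N} x_i^2 - \mu^2$$ 74 | 75 | The global statistics above are then shared with all GPUs, and `running_mean` and `running_var` are updated by a moving average of the global statistics. 76 | 77 | 3. forward batchnorm using the global statistics by 78 | 79 | $$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$ 80 | 81 | and then 82 | 83 | $$y_i = \gamma \hat{x}_i + \beta$$ 84 | 85 | where $\gamma$ is the weight parameter and $\beta$ is the bias parameter. 86 | 87 | 4. save $x$, $\gamma$, $\beta$, $\mu$, $\sigma^2$ for backward 88 | 89 | ### Backward 90 | 91 | 1. Restore the saved $x$, $\gamma$, $\beta$, $\mu$, $\sigma^2$ 92 | 93 | 2. 
Compute the following sums on each GPU 94 | 95 | $$\sum_{i=1}^{N_j} \frac{\partial J}{\partial y_i}$$ 96 | 97 | and 98 | 99 | $$\sum_{i=1}^{N_j} \frac{\partial J}{\partial y_i}\hat{x}_i$$ 100 | 101 | where $J$ is the loss, $j$ indexes the GPUs, and $N_j$ is the number of elements per channel on GPU $j$. 102 | 103 | Then gather them at the master node to obtain the global sums, and normalize by $N$, the total number of elements per channel. The global sums are then shared among all GPUs. 104 | 105 | 3. compute the gradients using the global statistics 106 | 107 | $$\frac{\partial J}{\partial x_i},\quad \frac{\partial J}{\partial \gamma},\quad \frac{\partial J}{\partial \beta}$$ 108 | 109 | where 110 | 111 | $$\frac{\partial J}{\partial \gamma} = \sum_{i=1}^{N} \frac{\partial J}{\partial y_i}\hat{x}_i$$ 112 | 113 | and 114 | 115 | $$\frac{\partial J}{\partial \beta} = \sum_{i=1}^{N} \frac{\partial J}{\partial y_i}$$ 116 | 117 | and finally, 118 | 119 | $$\frac{\partial J}{\partial x_i} = \frac{\partial J}{\partial \hat{x}_i}\frac{\partial \hat{x}_i}{\partial x_i} + \frac{\partial J}{\partial \mu}\frac{\partial \mu}{\partial x_i} + \frac{\partial J}{\partial \sigma^2}\frac{\partial \sigma^2}{\partial x_i}$$ 120 | 121 | $$= \frac{1}{N\sqrt{\sigma^2 + \epsilon}}\left(N\frac{\partial J}{\partial \hat{x}_i} - \sum_{k=1}^{N}\frac{\partial J}{\partial \hat{x}_k} - \hat{x}_i\sum_{k=1}^{N}\frac{\partial J}{\partial \hat{x}_k}\hat{x}_k\right)$$ 122 | 123 | $$= \frac{\gamma}{N\sqrt{\sigma^2 + \epsilon}}\left(N\frac{\partial J}{\partial y_i} - \sum_{k=1}^{N}\frac{\partial J}{\partial y_k} - \hat{x}_i\sum_{k=1}^{N}\frac{\partial J}{\partial y_k}\hat{x}_k\right)$$ 124 | 125 | Note that in the implementation, normalization by $N$ is performed in step (2), so the implementation does not match the equations above term by term, but it is mathematically equivalent.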
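The forward and backward steps above can be sanity-checked without any GPUs. The following NumPy sketch is an illustration only, not code from this package: it simulates three GPUs by splitting one batch into chunks, reduces only the per-channel sums named in the steps, and compares the result with ordinary full-batch batch normalization and with a finite-difference gradient. The function names (`sync_bn_forward`, `sync_bn_backward`, `full_batch_loss`) and the test loss $J = \sum w \cdot y$ are assumptions made for the example.

```python
import numpy as np

EPS = 1e-5


def sync_bn_forward(chunks, gamma, beta, eps=EPS):
    """Forward pass over a batch split across simulated GPUs.

    `chunks` is a list of (n_j, C) arrays, one per simulated GPU; only the
    per-channel sums and sums of squares are "communicated".
    """
    n = sum(c.shape[0] for c in chunks)
    # Forward steps 1-2: local sums, reduced into the global mean and variance.
    total = sum(c.sum(axis=0) for c in chunks)
    total_sq = sum((c ** 2).sum(axis=0) for c in chunks)
    mean = total / n
    var = total_sq / n - mean ** 2
    # Forward step 3: each device normalizes its own chunk with the shared stats.
    inv_std = 1.0 / np.sqrt(var + eps)
    xhats = [(c - mean) * inv_std for c in chunks]
    ys = [gamma * xh + beta for xh in xhats]
    return ys, xhats, inv_std, n


def sync_bn_backward(dys, xhats, gamma, inv_std, n):
    """Backward pass: only sum(dy) and sum(dy * xhat) are reduced globally."""
    sum_dy = sum(dy.sum(axis=0) for dy in dys)
    sum_dy_xhat = sum((dy * xh).sum(axis=0) for dy, xh in zip(dys, xhats))
    # Backward step 3: closed-form input gradient, computed locally per chunk.
    dxs = [gamma * inv_std / n * (n * dy - sum_dy - xh * sum_dy_xhat)
           for dy, xh in zip(dys, xhats)]
    return dxs, sum_dy_xhat, sum_dy  # dJ/dx, dJ/dgamma, dJ/dbeta


def full_batch_loss(x, gamma, beta, w, eps=EPS):
    """Reference: ordinary batch norm on the whole batch, with J = sum(w * y)."""
    xhat = (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)
    return float((w * (gamma * xhat + beta)).sum())


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 3))            # 8 samples, 3 channels
gamma, beta = rng.normal(size=3), rng.normal(size=3)
w = rng.normal(size=(8, 3))            # dJ/dy for the loss J = sum(w * y)
splits = [3, 6]                        # three "GPUs" holding 3, 3 and 2 samples

ys, xhats, inv_std, n = sync_bn_forward(np.split(x, splits), gamma, beta)
y = np.concatenate(ys)
dxs, dgamma, dbeta = sync_bn_backward(np.split(w, splits), xhats, gamma, inv_std, n)
dx = np.concatenate(dxs)

# Central finite differences of the full-batch loss, for comparison.
h = 1e-6
dx_fd = np.zeros_like(x)
for idx in np.ndindex(*x.shape):
    xp, xm = x.copy(), x.copy()
    xp[idx] += h
    xm[idx] -= h
    dx_fd[idx] = (full_batch_loss(xp, gamma, beta, w)
                  - full_batch_loss(xm, gamma, beta, w)) / (2 * h)

print("max |dx - dx_fd| =", np.abs(dx - dx_fd).max())
```

Because each direction reduces only two length-$C$ vectors per GPU, the communication volume does not depend on the per-GPU batch size. The one-pass variance $\frac{1}{N}\sum x_i^2 - \mu^2$ can lose precision in float32 when $|\mu|$ is large relative to $\sigma$; the sketch avoids this by working in float64.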
126 | 127 | A more detailed explanation is available on [Kevin Zakka's Blog](https://kevinzakka.github.io/2016/09/14/batch_normalization/) -------------------------------------------------------------------------------- /iharm/model/syncbn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/DucoNet-Image-Harmonization/167cd720d91f0f5d9e64b54c3c8d4300d8915be2/iharm/model/syncbn/__init__.py -------------------------------------------------------------------------------- /iharm/model/syncbn/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/DucoNet-Image-Harmonization/167cd720d91f0f5d9e64b54c3c8d4300d8915be2/iharm/model/syncbn/modules/__init__.py -------------------------------------------------------------------------------- /iharm/model/syncbn/modules/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .syncbn import batchnorm2d_sync 2 | -------------------------------------------------------------------------------- /iharm/model/syncbn/modules/functional/_csrc.py: -------------------------------------------------------------------------------- 1 | """ 2 | /*****************************************************************************/ 3 | 4 | Extension module loader 5 | 6 | code referenced from : https://github.com/facebookresearch/maskrcnn-benchmark 7 | 8 | /*****************************************************************************/ 9 | """ 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import glob 15 | import os.path 16 | 17 | import torch 18 | 19 | try: 20 | from torch.utils.cpp_extension import load 21 | from torch.utils.cpp_extension import CUDA_HOME 22 | except ImportError: 23 | raise ImportError( 24 | "The cpp layer extensions require PyTorch
0.4 or higher") 25 | 26 | 27 | def _load_C_extensions(): 28 | this_dir = os.path.dirname(os.path.abspath(__file__)) 29 | this_dir = os.path.join(this_dir, "csrc") 30 | 31 | main_file = glob.glob(os.path.join(this_dir, "*.cpp")) 32 | sources_cpu = glob.glob(os.path.join(this_dir, "cpu", "*.cpp")) 33 | sources_cuda = glob.glob(os.path.join(this_dir, "cuda", "*.cu")) 34 | 35 | sources = main_file + sources_cpu 36 | 37 | extra_cflags = [] 38 | extra_cuda_cflags = [] 39 | if torch.cuda.is_available() and CUDA_HOME is not None: 40 | sources.extend(sources_cuda) 41 | extra_cflags = ["-O3", "-DWITH_CUDA"] 42 | extra_cuda_cflags = ["--expt-extended-lambda"] 43 | sources = [os.path.join(this_dir, s) for s in sources] 44 | extra_include_paths = [this_dir] 45 | return load( 46 | name="ext_lib", 47 | sources=sources, 48 | extra_cflags=extra_cflags, 49 | extra_include_paths=extra_include_paths, 50 | extra_cuda_cflags=extra_cuda_cflags, 51 | ) 52 | 53 | 54 | _backend = _load_C_extensions() 55 | -------------------------------------------------------------------------------- /iharm/model/syncbn/modules/functional/csrc/bn.h: -------------------------------------------------------------------------------- 1 | /***************************************************************************** 2 | 3 | SyncBN 4 | 5 | *****************************************************************************/ 6 | #pragma once 7 | 8 | #ifdef WITH_CUDA 9 | #include "cuda/ext_lib.h" 10 | #endif 11 | 12 | /// SyncBN 13 | 14 | std::vector<at::Tensor> syncbn_sum_sqsum(const at::Tensor& x) { 15 | if (x.is_cuda()) { 16 | #ifdef WITH_CUDA 17 | return syncbn_sum_sqsum_cuda(x); 18 | #else 19 | AT_ERROR("Not compiled with GPU support"); 20 | #endif 21 | } else { 22 | AT_ERROR("CPU implementation not supported"); 23 | } 24 | } 25 | 26 | at::Tensor syncbn_forward(const at::Tensor& x, const at::Tensor& weight, 27 | const at::Tensor& bias, const at::Tensor& mean, 28 | const at::Tensor& var, bool affine, float eps) { 29 | if
(x.is_cuda()) { 30 | #ifdef WITH_CUDA 31 | return syncbn_forward_cuda(x, weight, bias, mean, var, affine, eps); 32 | #else 33 | AT_ERROR("Not compiled with GPU support"); 34 | #endif 35 | } else { 36 | AT_ERROR("CPU implementation not supported"); 37 | } 38 | } 39 | 40 | std::vector<at::Tensor> syncbn_backward_xhat(const at::Tensor& dz, 41 | const at::Tensor& x, 42 | const at::Tensor& mean, 43 | const at::Tensor& var, float eps) { 44 | if (dz.is_cuda()) { 45 | #ifdef WITH_CUDA 46 | return syncbn_backward_xhat_cuda(dz, x, mean, var, eps); 47 | #else 48 | AT_ERROR("Not compiled with GPU support"); 49 | #endif 50 | } else { 51 | AT_ERROR("CPU implementation not supported"); 52 | } 53 | } 54 | 55 | std::vector<at::Tensor> syncbn_backward( 56 | const at::Tensor& dz, const at::Tensor& x, const at::Tensor& weight, 57 | const at::Tensor& bias, const at::Tensor& mean, const at::Tensor& var, 58 | const at::Tensor& sum_dz, const at::Tensor& sum_dz_xhat, bool affine, 59 | float eps) { 60 | if (dz.is_cuda()) { 61 | #ifdef WITH_CUDA 62 | return syncbn_backward_cuda(dz, x, weight, bias, mean, var, sum_dz, 63 | sum_dz_xhat, affine, eps); 64 | #else 65 | AT_ERROR("Not compiled with GPU support"); 66 | #endif 67 | } else { 68 | AT_ERROR("CPU implementation not supported"); 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /iharm/model/syncbn/modules/functional/csrc/cuda/bn_cuda.cu: -------------------------------------------------------------------------------- 1 | /***************************************************************************** 2 | 3 | CUDA SyncBN code 4 | 5 | code referenced from : https://github.com/mapillary/inplace_abn 6 | 7 | *****************************************************************************/ 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include "cuda/common.h" 13 | 14 | // Utilities 15 | void get_dims(at::Tensor x, int64_t &num, int64_t &chn, int64_t &sp) { 16 | num = x.size(0); 17 | chn = x.size(1); 18 | sp =
1; 19 | for (int64_t i = 2; i < x.ndimension(); ++i) sp *= x.size(i); 20 | } 21 | 22 | /// SyncBN 23 | 24 | template <typename T> 25 | struct SqSumOp { 26 | __device__ SqSumOp(const T *t, int c, int s) : tensor(t), chn(c), sp(s) {} 27 | __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) { 28 | T x = tensor[(batch * chn + plane) * sp + n]; 29 | return Pair<T>(x, x * x); // x, x^2 30 | } 31 | const T *tensor; 32 | const int chn; 33 | const int sp; 34 | }; 35 | 36 | template <typename T> 37 | __global__ void syncbn_sum_sqsum_kernel(const T *x, T *sum, T *sqsum, 38 | int num, int chn, int sp) { 39 | int plane = blockIdx.x; 40 | Pair<T> res = 41 | reduce<Pair<T>, SqSumOp<T>>(SqSumOp<T>(x, chn, sp), plane, num, chn, sp); 42 | __syncthreads(); 43 | if (threadIdx.x == 0) { 44 | sum[plane] = res.v1; 45 | sqsum[plane] = res.v2; 46 | } 47 | } 48 | 49 | std::vector<at::Tensor> syncbn_sum_sqsum_cuda(const at::Tensor &x) { 50 | CHECK_INPUT(x); 51 | 52 | // Extract dimensions 53 | int64_t num, chn, sp; 54 | get_dims(x, num, chn, sp); 55 | 56 | // Prepare output tensors 57 | auto sum = at::empty({chn}, x.options()); 58 | auto sqsum = at::empty({chn}, x.options()); 59 | 60 | // Run kernel 61 | dim3 blocks(chn); 62 | dim3 threads(getNumThreads(sp)); 63 | AT_DISPATCH_FLOATING_TYPES( 64 | x.type(), "syncbn_sum_sqsum_cuda", ([&] { 65 | syncbn_sum_sqsum_kernel<scalar_t><<<blocks, threads>>>( 66 | x.data<scalar_t>(), sum.data<scalar_t>(), 67 | sqsum.data<scalar_t>(), num, chn, sp); 68 | })); 69 | return {sum, sqsum}; 70 | } 71 | 72 | template <typename T> 73 | __global__ void syncbn_forward_kernel(T *z, const T *x, const T *weight, 74 | const T *bias, const T *mean, 75 | const T *var, bool affine, float eps, 76 | int num, int chn, int sp) { 77 | int plane = blockIdx.x; 78 | T _mean = mean[plane]; 79 | T _var = var[plane]; 80 | T _weight = affine ? weight[plane] : T(1); 81 | T _bias = affine ?
bias[plane] : T(0); 82 | T _invstd = T(0); 83 | if (_var || eps) { 84 | _invstd = rsqrt(_var + eps); 85 | } 86 | for (int batch = 0; batch < num; ++batch) { 87 | for (int n = threadIdx.x; n < sp; n += blockDim.x) { 88 | T _x = x[(batch * chn + plane) * sp + n]; 89 | T _xhat = (_x - _mean) * _invstd; 90 | T _z = _xhat * _weight + _bias; 91 | z[(batch * chn + plane) * sp + n] = _z; 92 | } 93 | } 94 | } 95 | 96 | at::Tensor syncbn_forward_cuda(const at::Tensor &x, const at::Tensor &weight, 97 | const at::Tensor &bias, const at::Tensor &mean, 98 | const at::Tensor &var, bool affine, float eps) { 99 | CHECK_INPUT(x); 100 | CHECK_INPUT(weight); 101 | CHECK_INPUT(bias); 102 | CHECK_INPUT(mean); 103 | CHECK_INPUT(var); 104 | 105 | // Extract dimensions 106 | int64_t num, chn, sp; 107 | get_dims(x, num, chn, sp); 108 | 109 | auto z = at::zeros_like(x); 110 | 111 | // Run kernel 112 | dim3 blocks(chn); 113 | dim3 threads(getNumThreads(sp)); 114 | AT_DISPATCH_FLOATING_TYPES( 115 | x.type(), "syncbn_forward_cuda", ([&] { 116 | syncbn_forward_kernel<scalar_t><<<blocks, threads>>>( 117 | z.data<scalar_t>(), x.data<scalar_t>(), 118 | weight.data<scalar_t>(), bias.data<scalar_t>(), 119 | mean.data<scalar_t>(), var.data<scalar_t>(), 120 | affine, eps, num, chn, sp); 121 | })); 122 | return z; 123 | } 124 | 125 | template <typename T> 126 | struct XHatOp { 127 | __device__ XHatOp(T _weight, T _bias, const T *_dz, const T *_x, int c, int s) 128 | : weight(_weight), bias(_bias), x(_x), dz(_dz), chn(c), sp(s) {} 129 | __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) { 130 | // xhat = (x - bias) * weight 131 | T _xhat = (x[(batch * chn + plane) * sp + n] - bias) * weight; 132 | // dxhat * x_hat 133 | T _dz = dz[(batch * chn + plane) * sp + n]; 134 | return Pair<T>(_dz, _dz * _xhat); 135 | } 136 | const T weight; 137 | const T bias; 138 | const T *dz; 139 | const T *x; 140 | const int chn; 141 | const int sp; 142 | }; 143 | 144 | template <typename T> 145 | __global__ void syncbn_backward_xhat_kernel(const T *dz, const T *x, 146 | const T *mean, const T *var, 147 | T
*sum_dz, T *sum_dz_xhat, 148 | float eps, int num, int chn, 149 | int sp) { 150 | int plane = blockIdx.x; 151 | T _mean = mean[plane]; 152 | T _var = var[plane]; 153 | T _invstd = T(0); 154 | if (_var || eps) { 155 | _invstd = rsqrt(_var + eps); 156 | } 157 | Pair<T> res = reduce<Pair<T>, XHatOp<T>>( 158 | XHatOp<T>(_invstd, _mean, dz, x, chn, sp), plane, num, chn, sp); 159 | __syncthreads(); 160 | if (threadIdx.x == 0) { 161 | // \sum(\frac{dJ}{dy_i}) 162 | sum_dz[plane] = res.v1; 163 | // \sum(\frac{dJ}{dy_i}*\hat{x_i}) 164 | sum_dz_xhat[plane] = res.v2; 165 | } 166 | } 167 | 168 | std::vector<at::Tensor> syncbn_backward_xhat_cuda(const at::Tensor &dz, 169 | const at::Tensor &x, 170 | const at::Tensor &mean, 171 | const at::Tensor &var, 172 | float eps) { 173 | CHECK_INPUT(dz); 174 | CHECK_INPUT(x); 175 | CHECK_INPUT(mean); 176 | CHECK_INPUT(var); 177 | // Extract dimensions 178 | int64_t num, chn, sp; 179 | get_dims(x, num, chn, sp); 180 | // Prepare output tensors 181 | auto sum_dz = at::empty({chn}, x.options()); 182 | auto sum_dz_xhat = at::empty({chn}, x.options()); 183 | // Run kernel 184 | dim3 blocks(chn); 185 | dim3 threads(getNumThreads(sp)); 186 | AT_DISPATCH_FLOATING_TYPES( 187 | x.type(), "syncbn_backward_xhat_cuda", ([&] { 188 | syncbn_backward_xhat_kernel<scalar_t><<<blocks, threads>>>( 189 | dz.data<scalar_t>(), x.data<scalar_t>(), mean.data<scalar_t>(), 190 | var.data<scalar_t>(), sum_dz.data<scalar_t>(), 191 | sum_dz_xhat.data<scalar_t>(), eps, num, chn, sp); 192 | })); 193 | return {sum_dz, sum_dz_xhat}; 194 | } 195 | 196 | template <typename T> 197 | __global__ void syncbn_backward_kernel(const T *dz, const T *x, const T *weight, 198 | const T *bias, const T *mean, 199 | const T *var, const T *sum_dz, 200 | const T *sum_dz_xhat, T *dx, T *dweight, 201 | T *dbias, bool affine, float eps, 202 | int num, int chn, int sp) { 203 | int plane = blockIdx.x; 204 | T _mean = mean[plane]; 205 | T _var = var[plane]; 206 | T _weight = affine ?
weight[plane] : T(1); 207 | T _sum_dz = sum_dz[plane]; 208 | T _sum_dz_xhat = sum_dz_xhat[plane]; 209 | T _invstd = T(0); 210 | if (_var || eps) { 211 | _invstd = rsqrt(_var + eps); 212 | } 213 | /* 214 | \frac{dJ}{dx_i} = \frac{1}{N\sqrt{(\sigma^2+\epsilon)}} ( 215 | N\frac{dJ}{d\hat{x_i}} - 216 | \sum_{j=1}^{N}(\frac{dJ}{d\hat{x_j}}) - 217 | \hat{x_i}\sum_{j=1}^{N}(\frac{dJ}{d\hat{x_j}}\hat{x_j}) 218 | ) 219 | Note : N is omitted here since it will be accumulated and 220 | _sum_dz and _sum_dz_xhat are expected to be already normalized 221 | before the call. 222 | */ 223 | if (dx) { 224 | T _mul = _weight * _invstd; 225 | for (int batch = 0; batch < num; ++batch) { 226 | for (int n = threadIdx.x; n < sp; n += blockDim.x) { 227 | T _dz = dz[(batch * chn + plane) * sp + n]; 228 | T _xhat = (x[(batch * chn + plane) * sp + n] - _mean) * _invstd; 229 | T _dx = (_dz - _sum_dz - _xhat * _sum_dz_xhat) * _mul; 230 | dx[(batch * chn + plane) * sp + n] = _dx; 231 | } 232 | } 233 | } 234 | __syncthreads(); 235 | if (threadIdx.x == 0) { 236 | if (affine) { 237 | T _norm = num * sp; 238 | dweight[plane] += _sum_dz_xhat * _norm; 239 | dbias[plane] += _sum_dz * _norm; 240 | } 241 | } 242 | } 243 | 244 | std::vector<at::Tensor> syncbn_backward_cuda( 245 | const at::Tensor &dz, const at::Tensor &x, const at::Tensor &weight, 246 | const at::Tensor &bias, const at::Tensor &mean, const at::Tensor &var, 247 | const at::Tensor &sum_dz, const at::Tensor &sum_dz_xhat, bool affine, 248 | float eps) { 249 | CHECK_INPUT(dz); 250 | CHECK_INPUT(x); 251 | CHECK_INPUT(weight); 252 | CHECK_INPUT(bias); 253 | CHECK_INPUT(mean); 254 | CHECK_INPUT(var); 255 | CHECK_INPUT(sum_dz); 256 | CHECK_INPUT(sum_dz_xhat); 257 | 258 | // Extract dimensions 259 | int64_t num, chn, sp; 260 | get_dims(x, num, chn, sp); 261 | 262 | // Prepare output tensors 263 | auto dx = at::zeros_like(dz); 264 | auto dweight = at::zeros_like(weight); 265 | auto dbias = at::zeros_like(bias); 266 | 267 | // Run kernel 268 | dim3 blocks(chn); 269
| dim3 threads(getNumThreads(sp)); 270 | AT_DISPATCH_FLOATING_TYPES( 271 | x.type(), "syncbn_backward_cuda", ([&] { 272 | syncbn_backward_kernel<scalar_t><<<blocks, threads>>>( 273 | dz.data<scalar_t>(), x.data<scalar_t>(), weight.data<scalar_t>(), 274 | bias.data<scalar_t>(), mean.data<scalar_t>(), var.data<scalar_t>(), 275 | sum_dz.data<scalar_t>(), sum_dz_xhat.data<scalar_t>(), 276 | dx.data<scalar_t>(), dweight.data<scalar_t>(), 277 | dbias.data<scalar_t>(), affine, eps, num, chn, sp); 278 | })); 279 | return {dx, dweight, dbias}; 280 | } -------------------------------------------------------------------------------- /iharm/model/syncbn/modules/functional/csrc/cuda/common.h: -------------------------------------------------------------------------------- 1 | /***************************************************************************** 2 | 3 | CUDA utility funcs 4 | 5 | code referenced from : https://github.com/mapillary/inplace_abn 6 | 7 | *****************************************************************************/ 8 | #pragma once 9 | 10 | #include 11 | 12 | // Checks 13 | #ifndef AT_CHECK 14 | #define AT_CHECK AT_ASSERT 15 | #endif 16 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 17 | #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous") 18 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 19 | 20 | /* 21 | * General settings 22 | */ 23 | const int WARP_SIZE = 32; 24 | const int MAX_BLOCK_SIZE = 512; 25 | 26 | template <typename T> 27 | struct Pair { 28 | T v1, v2; 29 | __device__ Pair() {} 30 | __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {} 31 | __device__ Pair(T v) : v1(v), v2(v) {} 32 | __device__ Pair(int v) : v1(v), v2(v) {} 33 | __device__ Pair &operator+=(const Pair &a) { 34 | v1 += a.v1; 35 | v2 += a.v2; 36 | return *this; 37 | } 38 | }; 39 | 40 | /* 41 | * Utility functions 42 | */ 43 | template <typename T> 44 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, 45 | int width = warpSize, 46 | unsigned int mask = 0xffffffff) { 47 | #if CUDART_VERSION >= 9000 48 | return __shfl_xor_sync(mask, value, laneMask,
width); 49 | #else 50 | return __shfl_xor(value, laneMask, width); 51 | #endif 52 | } 53 | 54 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); } 55 | 56 | static int getNumThreads(int nElem) { 57 | int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE}; 58 | for (int i = 0; i != 5; ++i) { 59 | if (nElem <= threadSizes[i]) { 60 | return threadSizes[i]; 61 | } 62 | } 63 | return MAX_BLOCK_SIZE; 64 | } 65 | 66 | template <typename T> 67 | static __device__ __forceinline__ T warpSum(T val) { 68 | #if __CUDA_ARCH__ >= 300 69 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) { 70 | val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE); 71 | } 72 | #else 73 | __shared__ T values[MAX_BLOCK_SIZE]; 74 | values[threadIdx.x] = val; 75 | __threadfence_block(); 76 | const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE; 77 | for (int i = 1; i < WARP_SIZE; i++) { 78 | val += values[base + ((i + threadIdx.x) % WARP_SIZE)]; 79 | } 80 | #endif 81 | return val; 82 | } 83 | 84 | template <typename T> 85 | static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) { 86 | value.v1 = warpSum(value.v1); 87 | value.v2 = warpSum(value.v2); 88 | return value; 89 | } 90 | 91 | template <typename T, typename Op> 92 | __device__ T reduce(Op op, int plane, int N, int C, int S) { 93 | T sum = (T)0; 94 | for (int batch = 0; batch < N; ++batch) { 95 | for (int x = threadIdx.x; x < S; x += blockDim.x) { 96 | sum += op(batch, plane, x); 97 | } 98 | } 99 | 100 | // sum over NumThreads within a warp 101 | sum = warpSum(sum); 102 | 103 | // 'transpose', and reduce within warp again 104 | __shared__ T shared[32]; 105 | __syncthreads(); 106 | if (threadIdx.x % WARP_SIZE == 0) { 107 | shared[threadIdx.x / WARP_SIZE] = sum; 108 | } 109 | if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) { 110 | // zero out the other entries in shared 111 | shared[threadIdx.x] = (T)0; 112 | } 113 | __syncthreads(); 114 | if (threadIdx.x / WARP_SIZE == 0) { 115 | sum = warpSum(shared[threadIdx.x]); 116 | if (threadIdx.x == 0) { 117 |
shared[0] = sum; 118 | } 119 | } 120 | __syncthreads(); 121 | 122 | // Everyone picks it up, should be broadcast into the whole gradInput 123 | return shared[0]; 124 | } -------------------------------------------------------------------------------- /iharm/model/syncbn/modules/functional/csrc/cuda/ext_lib.h: -------------------------------------------------------------------------------- 1 | /***************************************************************************** 2 | 3 | CUDA SyncBN code 4 | 5 | *****************************************************************************/ 6 | #pragma once 7 | #include 8 | #include 9 | 10 | /// Sync-BN 11 | std::vector<at::Tensor> syncbn_sum_sqsum_cuda(const at::Tensor& x); 12 | at::Tensor syncbn_forward_cuda(const at::Tensor& x, const at::Tensor& weight, 13 | const at::Tensor& bias, const at::Tensor& mean, 14 | const at::Tensor& var, bool affine, float eps); 15 | std::vector<at::Tensor> syncbn_backward_xhat_cuda(const at::Tensor& dz, 16 | const at::Tensor& x, 17 | const at::Tensor& mean, 18 | const at::Tensor& var, 19 | float eps); 20 | std::vector<at::Tensor> syncbn_backward_cuda( 21 | const at::Tensor& dz, const at::Tensor& x, const at::Tensor& weight, 22 | const at::Tensor& bias, const at::Tensor& mean, const at::Tensor& var, 23 | const at::Tensor& sum_dz, const at::Tensor& sum_dz_xhat, bool affine, 24 | float eps); 25 | -------------------------------------------------------------------------------- /iharm/model/syncbn/modules/functional/csrc/ext_lib.cpp: -------------------------------------------------------------------------------- 1 | #include "bn.h" 2 | 3 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 4 | m.def("syncbn_sum_sqsum", &syncbn_sum_sqsum, "Sum and Sum^2 computation"); 5 | m.def("syncbn_forward", &syncbn_forward, "SyncBN forward computation"); 6 | m.def("syncbn_backward_xhat", &syncbn_backward_xhat, 7 | "First part of SyncBN backward computation"); 8 | m.def("syncbn_backward", &syncbn_backward, 9 | "Second part of SyncBN backward
computation"); 10 | } -------------------------------------------------------------------------------- /iharm/model/syncbn/modules/functional/syncbn.py: -------------------------------------------------------------------------------- 1 | """ 2 | /*****************************************************************************/ 3 | 4 | BatchNorm2dSync with multi-gpu 5 | 6 | code referenced from : https://github.com/mapillary/inplace_abn 7 | 8 | /*****************************************************************************/ 9 | """ 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import torch.cuda.comm as comm 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | from ._csrc import _backend 18 | 19 | 20 | def _count_samples(x): 21 | count = 1 22 | for i, s in enumerate(x.size()): 23 | if i != 1: 24 | count *= s 25 | return count 26 | 27 | 28 | class BatchNorm2dSyncFunc(Function): 29 | 30 | @staticmethod 31 | def forward(ctx, x, weight, bias, running_mean, running_var, 32 | extra, compute_stats=True, momentum=0.1, eps=1e-05): 33 | def _parse_extra(ctx, extra): 34 | ctx.is_master = extra["is_master"] 35 | if ctx.is_master: 36 | ctx.master_queue = extra["master_queue"] 37 | ctx.worker_queues = extra["worker_queues"] 38 | ctx.worker_ids = extra["worker_ids"] 39 | else: 40 | ctx.master_queue = extra["master_queue"] 41 | ctx.worker_queue = extra["worker_queue"] 42 | # Save context 43 | if extra is not None: 44 | _parse_extra(ctx, extra) 45 | ctx.compute_stats = compute_stats 46 | ctx.momentum = momentum 47 | ctx.eps = eps 48 | ctx.affine = weight is not None and bias is not None 49 | if ctx.compute_stats: 50 | N = _count_samples(x) * (ctx.master_queue.maxsize + 1) 51 | assert N > 1 52 | # 1. 
compute sum(x) and sum(x^2) 53 | xsum, xsqsum = _backend.syncbn_sum_sqsum(x.detach()) 54 | if ctx.is_master: 55 | xsums, xsqsums = [xsum], [xsqsum] 56 | # master : gather all sum(x) and sum(x^2) from slaves 57 | for _ in range(ctx.master_queue.maxsize): 58 | xsum_w, xsqsum_w = ctx.master_queue.get() 59 | ctx.master_queue.task_done() 60 | xsums.append(xsum_w) 61 | xsqsums.append(xsqsum_w) 62 | xsum = comm.reduce_add(xsums) 63 | xsqsum = comm.reduce_add(xsqsums) 64 | mean = xsum / N 65 | sumvar = xsqsum - xsum * mean 66 | var = sumvar / N 67 | uvar = sumvar / (N - 1) 68 | # master : broadcast global mean, variance to all slaves 69 | tensors = comm.broadcast_coalesced( 70 | (mean, uvar, var), [mean.get_device()] + ctx.worker_ids) 71 | for ts, queue in zip(tensors[1:], ctx.worker_queues): 72 | queue.put(ts) 73 | else: 74 | # slave : send sum(x) and sum(x^2) to master 75 | ctx.master_queue.put((xsum, xsqsum)) 76 | # slave : get global mean and variance 77 | mean, uvar, var = ctx.worker_queue.get() 78 | ctx.worker_queue.task_done() 79 | 80 | # Update running stats 81 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) 82 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * uvar) 83 | ctx.N = N 84 | ctx.save_for_backward(x, weight, bias, mean, var) 85 | else: 86 | mean, var = running_mean, running_var 87 | 88 | # do batch norm forward 89 | z = _backend.syncbn_forward(x, weight, bias, mean, var, 90 | ctx.affine, ctx.eps) 91 | return z 92 | 93 | @staticmethod 94 | @once_differentiable 95 | def backward(ctx, dz): 96 | x, weight, bias, mean, var = ctx.saved_tensors 97 | dz = dz.contiguous() 98 | 99 | # 1.
compute \sum(\frac{dJ}{dy_i}) and \sum(\frac{dJ}{dy_i}*\hat{x_i}) 100 | sum_dz, sum_dz_xhat = _backend.syncbn_backward_xhat( 101 | dz, x, mean, var, ctx.eps) 102 | if ctx.is_master: 103 | sum_dzs, sum_dz_xhats = [sum_dz], [sum_dz_xhat] 104 | # master : gather from slaves 105 | for _ in range(ctx.master_queue.maxsize): 106 | sum_dz_w, sum_dz_xhat_w = ctx.master_queue.get() 107 | ctx.master_queue.task_done() 108 | sum_dzs.append(sum_dz_w) 109 | sum_dz_xhats.append(sum_dz_xhat_w) 110 | # master : compute global stats 111 | sum_dz = comm.reduce_add(sum_dzs) 112 | sum_dz_xhat = comm.reduce_add(sum_dz_xhats) 113 | sum_dz /= ctx.N 114 | sum_dz_xhat /= ctx.N 115 | # master : broadcast global stats 116 | tensors = comm.broadcast_coalesced( 117 | (sum_dz, sum_dz_xhat), [mean.get_device()] + ctx.worker_ids) 118 | for ts, queue in zip(tensors[1:], ctx.worker_queues): 119 | queue.put(ts) 120 | else: 121 | # slave : send to master 122 | ctx.master_queue.put((sum_dz, sum_dz_xhat)) 123 | # slave : get global stats 124 | sum_dz, sum_dz_xhat = ctx.worker_queue.get() 125 | ctx.worker_queue.task_done() 126 | 127 | # do batch norm backward 128 | dx, dweight, dbias = _backend.syncbn_backward( 129 | dz, x, weight, bias, mean, var, sum_dz, sum_dz_xhat, 130 | ctx.affine, ctx.eps) 131 | 132 | return dx, dweight, dbias, \ 133 | None, None, None, None, None, None 134 | 135 | batchnorm2d_sync = BatchNorm2dSyncFunc.apply 136 | 137 | __all__ = ["batchnorm2d_sync"] 138 | -------------------------------------------------------------------------------- /iharm/model/syncbn/modules/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .syncbn import * 2 | -------------------------------------------------------------------------------- /iharm/model/syncbn/modules/nn/syncbn.py: -------------------------------------------------------------------------------- 1 | """ 2 | /*****************************************************************************/ 3
| 4 | BatchNorm2dSync with multi-gpu 5 | 6 | /*****************************************************************************/ 7 | """ 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | try: 13 | # python 3 14 | from queue import Queue 15 | except ImportError: 16 | # python 2 17 | from Queue import Queue 18 | 19 | import torch 20 | import torch.nn as nn 21 | from torch.nn import functional as F 22 | from torch.nn.parameter import Parameter 23 | from iharm.model.syncbn.modules.functional import batchnorm2d_sync 24 | 25 | 26 | class _BatchNorm(nn.Module): 27 | """ 28 | Customized BatchNorm from nn.BatchNorm 29 | >> added freeze attribute to enable bn freeze. 30 | """ 31 | 32 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 33 | track_running_stats=True): 34 | super(_BatchNorm, self).__init__() 35 | self.num_features = num_features 36 | self.eps = eps 37 | self.momentum = momentum 38 | self.affine = affine 39 | self.track_running_stats = track_running_stats 40 | self.freezed = False 41 | if self.affine: 42 | self.weight = Parameter(torch.Tensor(num_features)) 43 | self.bias = Parameter(torch.Tensor(num_features)) 44 | else: 45 | self.register_parameter('weight', None) 46 | self.register_parameter('bias', None) 47 | if self.track_running_stats: 48 | self.register_buffer('running_mean', torch.zeros(num_features)) 49 | self.register_buffer('running_var', torch.ones(num_features)) 50 | else: 51 | self.register_parameter('running_mean', None) 52 | self.register_parameter('running_var', None) 53 | self.reset_parameters() 54 | 55 | def reset_parameters(self): 56 | if self.track_running_stats: 57 | self.running_mean.zero_() 58 | self.running_var.fill_(1) 59 | if self.affine: 60 | self.weight.data.uniform_() 61 | self.bias.data.zero_() 62 | 63 | def _check_input_dim(self, input): 64 | return NotImplemented 65 | 66 | def forward(self, input): 67 | self._check_input_dim(input) 
68 | 69 | compute_stats = not self.freezed and \ 70 | self.training and self.track_running_stats 71 | 72 | ret = F.batch_norm(input, self.running_mean, self.running_var, 73 | self.weight, self.bias, compute_stats, 74 | self.momentum, self.eps) 75 | return ret 76 | 77 | def extra_repr(self): 78 | return '{num_features}, eps={eps}, momentum={momentum}, '\ 79 | 'affine={affine}, ' \ 80 | 'track_running_stats={track_running_stats}'.format( 81 | **self.__dict__) 82 | 83 | 84 | class BatchNorm2dNoSync(_BatchNorm): 85 | """ 86 | Equivalent to nn.BatchNorm2d 87 | """ 88 | 89 | def _check_input_dim(self, input): 90 | if input.dim() != 4: 91 | raise ValueError('expected 4D input (got {}D input)' 92 | .format(input.dim())) 93 | 94 | 95 | class BatchNorm2dSync(BatchNorm2dNoSync): 96 | """ 97 | BatchNorm2d with automatic multi-GPU Sync 98 | """ 99 | 100 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 101 | track_running_stats=True): 102 | super(BatchNorm2dSync, self).__init__( 103 | num_features, eps=eps, momentum=momentum, affine=affine, 104 | track_running_stats=track_running_stats) 105 | self.sync_enabled = True 106 | self.devices = list(range(torch.cuda.device_count())) 107 | if len(self.devices) > 1: 108 | # Initialize queues 109 | self.worker_ids = self.devices[1:] 110 | self.master_queue = Queue(len(self.worker_ids)) 111 | self.worker_queues = [Queue(1) for _ in self.worker_ids] 112 | 113 | def forward(self, x): 114 | compute_stats = not self.freezed and \ 115 | self.training and self.track_running_stats 116 | if self.sync_enabled and compute_stats and len(self.devices) > 1: 117 | if x.get_device() == self.devices[0]: 118 | # Master mode 119 | extra = { 120 | "is_master": True, 121 | "master_queue": self.master_queue, 122 | "worker_queues": self.worker_queues, 123 | "worker_ids": self.worker_ids 124 | } 125 | else: 126 | # Worker mode 127 | extra = { 128 | "is_master": False, 129 | "master_queue": self.master_queue, 130 | "worker_queue": 
self.worker_queues[ 131 | self.worker_ids.index(x.get_device())] 132 | } 133 | return batchnorm2d_sync(x, self.weight, self.bias, 134 | self.running_mean, self.running_var, 135 | extra, compute_stats, self.momentum, 136 | self.eps) 137 | return super(BatchNorm2dSync, self).forward(x) 138 | 139 | def __repr__(self): 140 | """repr""" 141 | rep = '{name}({num_features}, eps={eps}, momentum={momentum}, ' \ 142 | 'affine={affine}, ' \ 143 | 'track_running_stats={track_running_stats}, ' \ 144 | 'devices={devices})' 145 | return rep.format(name=self.__class__.__name__, **self.__dict__) 146 | 147 | #BatchNorm2d = BatchNorm2dNoSync 148 | BatchNorm2d = BatchNorm2dSync 149 | -------------------------------------------------------------------------------- /iharm/utils/exp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import pprint 5 | from pathlib import Path 6 | from datetime import datetime 7 | 8 | import yaml 9 | import torch 10 | from easydict import EasyDict as edict 11 | 12 | from .log import logger, add_new_file_output_to_logger 13 | 14 | 15 | def init_experiment(args): 16 | model_path = Path(args.model_path) 17 | ftree = get_model_family_tree(model_path) 18 | if ftree is None: 19 | print('Models can only be located in the "models" directory in the root of the repository') 20 | sys.exit(1) 21 | 22 | cfg = load_config(model_path) 23 | update_config(cfg, args) 24 | 25 | experiments_path = Path(cfg.EXPS_PATH) 26 | exp_parent_path = experiments_path / '/'.join(ftree) 27 | exp_parent_path.mkdir(parents=True, exist_ok=True) 28 | 29 | if cfg.resume_exp: 30 | exp_path = find_resume_exp(exp_parent_path, cfg.resume_exp) 31 | else: 32 | last_exp_indx = find_last_exp_indx(exp_parent_path) 33 | exp_name = f'{last_exp_indx:03d}' 34 | if cfg.exp_name: 35 | exp_name += '_' + cfg.exp_name 36 | exp_path = exp_parent_path / exp_name 37 | exp_path.mkdir(parents=True) 38 | 39 | cfg.EXP_PATH = exp_path 40
| cfg.CHECKPOINTS_PATH = exp_path / 'checkpoints' 41 | cfg.VIS_PATH = exp_path / 'vis' 42 | cfg.LOGS_PATH = exp_path / 'logs' 43 | 44 | cfg.LOGS_PATH.mkdir(exist_ok=True) 45 | cfg.CHECKPOINTS_PATH.mkdir(exist_ok=True) 46 | cfg.VIS_PATH.mkdir(exist_ok=True) 47 | 48 | dst_script_path = exp_path / (model_path.stem + datetime.strftime(datetime.today(), '_%Y-%m-%d-%H-%M-%S.py')) 49 | shutil.copy(model_path, dst_script_path) 50 | 51 | if cfg.gpus != '': 52 | gpu_ids = [int(id) for id in cfg.gpus.split(',')] 53 | else: 54 | gpu_ids = list(range(cfg.ngpus)) 55 | cfg.gpus = ','.join([str(id) for id in gpu_ids]) 56 | cfg.gpu_ids = gpu_ids 57 | cfg.ngpus = len(gpu_ids) 58 | cfg.multi_gpu = cfg.ngpus > 1 59 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpus 60 | 61 | if cfg.multi_gpu: 62 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.gpus 63 | ngpus = torch.cuda.device_count() 64 | assert ngpus == cfg.ngpus 65 | # cfg.device = torch.device(f'cuda:{cfg.gpu_ids[0]}') 66 | cfg.device = torch.device('cuda') if cfg.gpu_ids else torch.device('cpu') 67 | 68 | add_new_file_output_to_logger(cfg.LOGS_PATH, prefix='train_') 69 | 70 | logger.info(f'Number of GPUs: {len(cfg.gpu_ids)}') 71 | cfg.gpu_ids = [int(id) for id in range(cfg.ngpus)] 72 | logger.info('Run experiment with config:') 73 | logger.info(pprint.pformat(cfg, indent=4)) 74 | 75 | return cfg 76 | 77 | 78 | def get_model_family_tree(model_path, terminate_name='models'): 79 | model_name = model_path.stem 80 | family_tree = [model_name] 81 | for x in model_path.parents: 82 | if x.stem == terminate_name: 83 | break 84 | family_tree.append(x.stem) 85 | else: 86 | return None 87 | 88 | return family_tree[::-1] 89 | 90 | 91 | def find_last_exp_indx(exp_parent_path): 92 | indx = 0 93 | for x in exp_parent_path.iterdir(): 94 | if not x.is_dir(): 95 | continue 96 | 97 | exp_name = x.stem 98 | if exp_name[:3].isnumeric(): 99 | indx = max(indx, int(exp_name[:3]) + 1) 100 | 101 | return indx 102 | 103 | 104 | def find_resume_exp(exp_parent_path, 
exp_pattern): 105 | candidates = sorted(exp_parent_path.glob(f'{exp_pattern}*')) 106 | if len(candidates) == 0: 107 | print(f'No experiments could be found that satisfy the pattern "{exp_pattern}*"') 108 | sys.exit(1) 109 | elif len(candidates) > 1: 110 | print('More than one experiment found:') 111 | for x in candidates: 112 | print(x) 113 | sys.exit(1) 114 | else: 115 | exp_path = candidates[0] 116 | print(f'Continue with experiment "{exp_path}"') 117 | 118 | return exp_path 119 | 120 | 121 | def update_config(cfg, args): 122 | for param_name, value in vars(args).items(): 123 | if param_name.lower() in cfg or param_name.upper() in cfg: 124 | continue 125 | cfg[param_name] = value 126 | 127 | 128 | def load_config(model_path): 129 | model_name = model_path.stem 130 | config_path = model_path.parent / (model_name + '.yml') 131 | 132 | if config_path.exists(): 133 | cfg = load_config_file(config_path) 134 | else: 135 | cfg = dict() 136 | 137 | cwd = Path.cwd() 138 | config_parent = config_path.parent.absolute() 139 | while len(config_parent.parents) > 0: 140 | config_path = config_parent / 'config.yml' 141 | 142 | if config_path.exists(): 143 | local_config = load_config_file(config_path, model_name=model_name) 144 | cfg.update({k: v for k, v in local_config.items() if k not in cfg}) 145 | 146 | if config_parent.absolute() == cwd: 147 | break 148 | config_parent = config_parent.parent 149 | 150 | return edict(cfg) 151 | 152 | 153 | def load_config_file(config_path, model_name=None, return_edict=False): 154 | with open(config_path, 'r') as f: 155 | cfg = yaml.safe_load(f) 156 | 157 | if 'SUBCONFIGS' in cfg: 158 | if model_name is not None and model_name in cfg['SUBCONFIGS']: 159 | cfg.update(cfg['SUBCONFIGS'][model_name]) 160 | del cfg['SUBCONFIGS'] 161 | 162 | return edict(cfg) if return_edict else cfg 163 | -------------------------------------------------------------------------------- /iharm/utils/log.py:
-------------------------------------------------------------------------------- 1 | import io 2 | import time 3 | import logging 4 | from datetime import datetime 5 | 6 | import numpy as np 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | LOGGER_NAME = 'root' 10 | LOGGER_DATEFMT = '%Y-%m-%d %H:%M:%S' 11 | 12 | handler = logging.StreamHandler() 13 | 14 | logger = logging.getLogger(LOGGER_NAME) 15 | logger.setLevel(logging.INFO) 16 | logger.addHandler(handler) 17 | 18 | 19 | def add_new_file_output_to_logger(logs_path, prefix, only_message=False): 20 | log_name = prefix + datetime.strftime(datetime.today(), '%Y-%m-%d_%H-%M-%S') + '.log' 21 | logs_path.mkdir(exist_ok=True, parents=True) 22 | stdout_log_path = logs_path / log_name 23 | 24 | fh = logging.FileHandler(str(stdout_log_path)) 25 | 26 | fmt = '%(message)s' if only_message else '(%(levelname)s) %(asctime)s: %(message)s' 27 | formatter = logging.Formatter(fmt=fmt, datefmt=LOGGER_DATEFMT) 28 | fh.setFormatter(formatter) 29 | logger.addHandler(fh) 30 | 31 | 32 | class TqdmToLogger(io.StringIO): 33 | logger = None 34 | level = None 35 | buf = '' 36 | 37 | def __init__(self, logger, level=None, mininterval=5): 38 | super(TqdmToLogger, self).__init__() 39 | self.logger = logger 40 | self.level = level or logging.INFO 41 | self.mininterval = mininterval 42 | self.last_time = 0 43 | 44 | def write(self, buf): 45 | self.buf = buf.strip('\r\n\t ') 46 | 47 | def flush(self): 48 | if len(self.buf) > 0 and time.time() - self.last_time > self.mininterval: 49 | self.logger.log(self.level, self.buf) 50 | self.last_time = time.time() 51 | 52 | 53 | class SummaryWriterAvg(SummaryWriter): 54 | def __init__(self, *args, dump_period=20, **kwargs): 55 | super().__init__(*args, **kwargs) 56 | self._dump_period = dump_period 57 | self._avg_scalars = dict() 58 | 59 | def add_scalar(self, tag, value, global_step=None, disable_avg=False): 60 | if disable_avg or isinstance(value, (tuple, list, dict)): 61 | 
super().add_scalar(tag, np.array(value), global_step=global_step) 62 | else: 63 | if tag not in self._avg_scalars: 64 | self._avg_scalars[tag] = ScalarAccumulator(self._dump_period) 65 | avg_scalar = self._avg_scalars[tag] 66 | avg_scalar.add(value) 67 | 68 | if avg_scalar.is_full(): 69 | super().add_scalar(tag, avg_scalar.value, 70 | global_step=global_step) 71 | avg_scalar.reset() 72 | 73 | 74 | class ScalarAccumulator(object): 75 | def __init__(self, period): 76 | self.sum = 0 77 | self.cnt = 0 78 | self.period = period 79 | 80 | def add(self, value): 81 | self.sum += value 82 | self.cnt += 1 83 | 84 | @property 85 | def value(self): 86 | if self.cnt > 0: 87 | return self.sum / self.cnt 88 | else: 89 | return 0 90 | 91 | def reset(self): 92 | self.cnt = 0 93 | self.sum = 0 94 | 95 | def is_full(self): 96 | return self.cnt >= self.period 97 | 98 | def __len__(self): 99 | return self.cnt 100 | -------------------------------------------------------------------------------- /iharm/utils/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .log import logger 4 | 5 | 6 | def get_dims_with_exclusion(dim, exclude=None): 7 | dims = list(range(dim)) 8 | if exclude is not None: 9 | dims.remove(exclude) 10 | 11 | return dims 12 | 13 | 14 | def save_checkpoint(net, checkpoints_path, epoch=None, prefix='', verbose=True, multi_gpu=False): 15 | if epoch is None: 16 | checkpoint_name = 'last_checkpoint.pth' 17 | else: 18 | checkpoint_name = f'{epoch:03d}.pth' 19 | 20 | if prefix: 21 | checkpoint_name = f'{prefix}_{checkpoint_name}' 22 | 23 | if not checkpoints_path.exists(): 24 | checkpoints_path.mkdir(parents=True) 25 | 26 | checkpoint_path = checkpoints_path / checkpoint_name 27 | if verbose: 28 | logger.info(f'Save checkpoint to {str(checkpoint_path)}') 29 | 30 | state_dict = net.module.state_dict() if multi_gpu else net.state_dict() 31 | torch.save(state_dict, str(checkpoint_path)) 32 | 33 | 34 | def 
load_weights(model, path_to_weights, verbose=False): 35 | if verbose: 36 | logger.info(f'Load checkpoint from path: {path_to_weights}') 37 | 38 | current_state_dict = model.state_dict() 39 | new_state_dict = torch.load(str(path_to_weights), map_location='cpu') 40 | current_state_dict.update(new_state_dict) 41 | model.load_state_dict(current_state_dict) 42 | 43 | -------------------------------------------------------------------------------- /models/DucoNet_1024.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from torchvision import transforms 5 | from easydict import EasyDict as edict 6 | from albumentations import HorizontalFlip, Resize, RandomResizedCrop 7 | 8 | from iharm.data.compose import ComposeDataset 9 | from iharm.data.hdataset import HDataset 10 | from iharm.data.transforms import HCompose 11 | from iharm.engine.simple_trainer import SimpleHTrainer 12 | from iharm.model import initializer 13 | from iharm.model.base import DucoNet_model 14 | from iharm.model.losses import MaskWeightedMSE 15 | from iharm.model.metrics import DenormalizedMSEMetric, DenormalizedPSNRMetric 16 | from iharm.utils.log import logger 17 | 18 | 19 | def main(cfg): 20 | model, model_cfg = init_model(cfg) 21 | train(model, cfg, model_cfg, start_epoch=cfg.start_epoch) 22 | 23 | def _model_init(model,_skip_init_names,cnt,max_cnt=4): 24 | if cnt>max_cnt: 25 | return 26 | for name,module in model.named_children(): 27 | if name in _skip_init_names[cnt]: 28 | # print("skip:",name) 29 | _model_init(module,_skip_init_names,cnt+1,max_cnt=max_cnt) 30 | else: 31 | # print(name) 32 | module.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=1.0)) 33 | 34 | 35 | def init_model(cfg): 36 | model_cfg = edict() 37 | model_cfg.crop_size = (1024,1024) 38 | model_cfg.input_normalization = { 39 | 'mean': [.485, .456, .406], 40 | 'std': [.229, .224, .225] 41 | } 42 | model_cfg.depth = 4 43 | 44 | 
model_cfg.input_transform = transforms.Compose([ 45 | transforms.ToTensor(), 46 | transforms.Normalize(model_cfg.input_normalization['mean'], model_cfg.input_normalization['std']), 47 | ]) 48 | 49 | model = DucoNet_model( 50 | depth=4, ch=32, image_fusion=True, attention_mid_k=0.5, 51 | attend_from=2, batchnorm_from=2,w_dim=256,control_module_start=cfg.control_module_start 52 | ) 53 | 54 | model.to(cfg.device) 55 | 56 | _skip_init_names={ 57 | 0:['decoder'], 58 | 1:['up_blocks'], 59 | 2:['0','1','2'], 60 | 3:['control_module'], 61 | 4:['a_styleblock','b_styleblock','l_styleblock'] 62 | } 63 | 64 | cnt=0 65 | _model_init(model,_skip_init_names,cnt,max_cnt=4) 66 | 67 | return model, model_cfg 68 | 69 | 70 | def train(model, cfg, model_cfg, start_epoch=0): 71 | cfg.batch_size = 16 if cfg.batch_size < 1 else cfg.batch_size 72 | cfg.val_batch_size = cfg.batch_size 73 | 74 | cfg.input_normalization = model_cfg.input_normalization 75 | crop_size = model_cfg.crop_size 76 | 77 | loss_cfg = edict() 78 | loss_cfg.pixel_loss = MaskWeightedMSE(min_area=100) 79 | loss_cfg.pixel_loss_weight = 1.0 80 | 81 | num_epochs = 120 82 | 83 | train_augmentator = HCompose([ 84 | RandomResizedCrop(*crop_size, scale=(0.5, 1.0)), 85 | HorizontalFlip(), 86 | ]) 87 | 88 | val_augmentator = HCompose([ 89 | Resize(*crop_size) 90 | ]) 91 | 92 | 93 | trainset = ComposeDataset( 94 | [ 95 | # HDataset(cfg.HFLICKR_PATH, split='train'), 96 | # HDataset(cfg.HDAY2NIGHT_PATH, split='train'), 97 | # HDataset(cfg.HCOCO_PATH, split='train'), 98 | HDataset(cfg.HADOBE5K1_PATH, split='train'), 99 | ], 100 | augmentator=train_augmentator, 101 | input_transform=model_cfg.input_transform, 102 | keep_background_prob=0.05, 103 | ) 104 | 105 | valset = ComposeDataset( 106 | [ 107 | # HDataset(cfg.HFLICKR_PATH, split='test'), 108 | # HDataset(cfg.HDAY2NIGHT_PATH, split='test'), 109 | # HDataset(cfg.HCOCO_PATH, split='test'), 110 | HDataset(cfg.HADOBE5K1_PATH, split='test'), 111 | ], 112 | augmentator=val_augmentator, 
113 | input_transform=model_cfg.input_transform, 114 | keep_background_prob=-1, 115 | ) 116 | 117 | optimizer_params = { 118 | 'lr': cfg.lr, 119 | 'betas': (0.9, 0.999), 'eps': 1e-8 120 | } 121 | lr_g = (cfg.batch_size / 64) ** 0.5 122 | optimizer_params['lr'] = optimizer_params['lr'] * lr_g 123 | 124 | lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, 125 | milestones=[105, 115], gamma=0.1) 126 | trainer = SimpleHTrainer( 127 | model, cfg, model_cfg, loss_cfg, 128 | trainset, valset, 129 | optimizer='adam', 130 | optimizer_params=optimizer_params, 131 | lr_scheduler=lr_scheduler, 132 | metrics=[ 133 | DenormalizedPSNRMetric( 134 | 'images', 'target_images', 135 | mean=torch.tensor(cfg.input_normalization['mean'], dtype=torch.float32).view(1, 3, 1, 1), 136 | std=torch.tensor(cfg.input_normalization['std'], dtype=torch.float32).view(1, 3, 1, 1), 137 | ), 138 | DenormalizedMSEMetric( 139 | 'images', 'target_images', 140 | mean=torch.tensor(cfg.input_normalization['mean'], dtype=torch.float32).view(1, 3, 1, 1), 141 | std=torch.tensor(cfg.input_normalization['std'], dtype=torch.float32).view(1, 3, 1, 1), 142 | ) 143 | ], 144 | checkpoint_interval=10, 145 | image_dump_interval=1000 146 | ) 147 | 148 | logger.info(f'Starting Epoch: {start_epoch}') 149 | logger.info(f'Total Epochs: {num_epochs}') 150 | for epoch in range(start_epoch, num_epochs): 151 | trainer.training(epoch) 152 | trainer.validation(epoch) 153 | -------------------------------------------------------------------------------- /models/DucoNet_256.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from torchvision import transforms 5 | from easydict import EasyDict as edict 6 | from albumentations import HorizontalFlip, Resize, RandomResizedCrop 7 | 8 | from iharm.data.compose import ComposeDataset 9 | from iharm.data.hdataset import HDataset 10 | from iharm.data.transforms import HCompose 11 | from 
iharm.engine.simple_trainer import SimpleHTrainer 12 | from iharm.model import initializer 13 | from iharm.model.base import DucoNet_model 14 | from iharm.model.losses import MaskWeightedMSE 15 | from iharm.model.metrics import DenormalizedMSEMetric, DenormalizedPSNRMetric 16 | from iharm.utils.log import logger 17 | 18 | 19 | def main(cfg): 20 | model, model_cfg = init_model(cfg) 21 | train(model, cfg, model_cfg, start_epoch=cfg.start_epoch) 22 | 23 | def _model_init(model,_skip_init_names,cnt,max_cnt=4): 24 | if cnt>max_cnt: 25 | return 26 | for name,module in model.named_children(): 27 | if name in _skip_init_names[cnt]: 28 | # print("skip:",name) 29 | _model_init(module,_skip_init_names,cnt+1,max_cnt=max_cnt) 30 | else: 31 | # print(name) 32 | module.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=1.0)) 33 | 34 | 35 | def init_model(cfg): 36 | model_cfg = edict() 37 | model_cfg.crop_size = (256, 256) 38 | model_cfg.input_normalization = { 39 | 'mean': [.485, .456, .406], 40 | 'std': [.229, .224, .225] 41 | } 42 | model_cfg.depth = 4 43 | 44 | model_cfg.input_transform = transforms.Compose([ 45 | transforms.ToTensor(), 46 | transforms.Normalize(model_cfg.input_normalization['mean'], model_cfg.input_normalization['std']), 47 | ]) 48 | 49 | model = DucoNet_model( 50 | depth=4, ch=32, image_fusion=True, attention_mid_k=0.5, 51 | attend_from=2, batchnorm_from=2,w_dim=256,control_module_start=cfg.control_module_start 52 | ) 53 | 54 | model.to(cfg.device) 55 | 56 | _skip_init_names={ 57 | 0:['decoder'], 58 | 1:['up_blocks'], 59 | 2:['0','1','2'], 60 | 3:['control_module'], 61 | 4:['a_styleblock','b_styleblock','l_styleblock'] 62 | } 63 | 64 | cnt=0 65 | _model_init(model,_skip_init_names,cnt,max_cnt=4) 66 | 67 | return model, model_cfg 68 | 69 | 70 | def train(model, cfg, model_cfg, start_epoch=0): 71 | cfg.batch_size = 16 if cfg.batch_size < 1 else cfg.batch_size 72 | cfg.val_batch_size = cfg.batch_size 73 | 74 | cfg.input_normalization = 
model_cfg.input_normalization 75 | crop_size = model_cfg.crop_size 76 | 77 | loss_cfg = edict() 78 | loss_cfg.pixel_loss = MaskWeightedMSE(min_area=100) 79 | loss_cfg.pixel_loss_weight = 1.0 80 | 81 | num_epochs = 120 82 | 83 | train_augmentator = HCompose([ 84 | RandomResizedCrop(*crop_size, scale=(0.5, 1.0)), 85 | HorizontalFlip(), 86 | ]) 87 | 88 | val_augmentator = HCompose([ 89 | Resize(*crop_size) 90 | ]) 91 | 92 | 93 | trainset = ComposeDataset( 94 | [ 95 | HDataset(cfg.HFLICKR_PATH, split='train'), 96 | HDataset(cfg.HDAY2NIGHT_PATH, split='train'), 97 | HDataset(cfg.HCOCO_PATH, split='train'), 98 | HDataset(cfg.HADOBE5K_PATH, split='train'), 99 | ], 100 | augmentator=train_augmentator, 101 | input_transform=model_cfg.input_transform, 102 | keep_background_prob=0.05, 103 | ) 104 | 105 | valset = ComposeDataset( 106 | [ 107 | HDataset(cfg.HFLICKR_PATH, split='test'), 108 | HDataset(cfg.HDAY2NIGHT_PATH, split='test'), 109 | HDataset(cfg.HCOCO_PATH, split='test'), 110 | ], 111 | augmentator=val_augmentator, 112 | input_transform=model_cfg.input_transform, 113 | keep_background_prob=-1, 114 | ) 115 | 116 | optimizer_params = { 117 | 'lr': cfg.lr, 118 | 'betas': (0.9, 0.999), 'eps': 1e-8 119 | } 120 | lr_g = (cfg.batch_size / 64) ** 0.5 121 | optimizer_params['lr'] = optimizer_params['lr'] * lr_g 122 | 123 | lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, 124 | milestones=[105, 115], gamma=0.1) 125 | trainer = SimpleHTrainer( 126 | model, cfg, model_cfg, loss_cfg, 127 | trainset, valset, 128 | optimizer='adam', 129 | optimizer_params=optimizer_params, 130 | lr_scheduler=lr_scheduler, 131 | metrics=[ 132 | DenormalizedPSNRMetric( 133 | 'images', 'target_images', 134 | mean=torch.tensor(cfg.input_normalization['mean'], dtype=torch.float32).view(1, 3, 1, 1), 135 | std=torch.tensor(cfg.input_normalization['std'], dtype=torch.float32).view(1, 3, 1, 1), 136 | ), 137 | DenormalizedMSEMetric( 138 | 'images', 'target_images', 139 | 
mean=torch.tensor(cfg.input_normalization['mean'], dtype=torch.float32).view(1, 3, 1, 1), 140 | std=torch.tensor(cfg.input_normalization['std'], dtype=torch.float32).view(1, 3, 1, 1), 141 | ) 142 | ], 143 | checkpoint_interval=10, 144 | image_dump_interval=1000 145 | ) 146 | 147 | logger.info(f'Starting Epoch: {start_epoch}') 148 | logger.info(f'Total Epochs: {num_epochs}') 149 | for epoch in range(start_epoch, num_epochs): 150 | trainer.training(epoch) 151 | trainer.validation(epoch) 152 | -------------------------------------------------------------------------------- /models/improved_ssam.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | from torchvision import transforms 5 | from easydict import EasyDict as edict 6 | from albumentations import HorizontalFlip, Resize, RandomResizedCrop 7 | 8 | from iharm.data.compose import ComposeDataset 9 | from iharm.data.hdataset import HDataset 10 | from iharm.data.transforms import HCompose 11 | from iharm.engine.simple_trainer import SimpleHTrainer 12 | from iharm.model import initializer 13 | from iharm.model.base import SSAMImageHarmonization 14 | from iharm.model.losses import MaskWeightedMSE 15 | from iharm.model.metrics import DenormalizedMSEMetric, DenormalizedPSNRMetric 16 | from iharm.utils.log import logger 17 | 18 | 19 | def main(cfg): 20 | model, model_cfg = init_model(cfg) 21 | train(model, cfg, model_cfg, start_epoch=cfg.start_epoch) 22 | 23 | 24 | def init_model(cfg): 25 | model_cfg = edict() 26 | model_cfg.crop_size = (256, 256) 27 | model_cfg.input_normalization = { 28 | 'mean': [.485, .456, .406], 29 | 'std': [.229, .224, .225] 30 | } 31 | model_cfg.depth = 4 32 | 33 | model_cfg.input_transform = transforms.Compose([ 34 | transforms.ToTensor(), 35 | transforms.Normalize(model_cfg.input_normalization['mean'], model_cfg.input_normalization['std']), 36 | ]) 37 | 38 | model = SSAMImageHarmonization( 39 | depth=4, 
ch=32, image_fusion=True, attention_mid_k=0.5, 40 | attend_from=2, batchnorm_from=2 41 | ) 42 | 43 | model.to(cfg.device) 44 | model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=1.0)) 45 | 46 | return model, model_cfg 47 | 48 | 49 | def train(model, cfg, model_cfg, start_epoch=0): 50 | cfg.batch_size = 16 if cfg.batch_size < 1 else cfg.batch_size 51 | cfg.val_batch_size = cfg.batch_size 52 | 53 | cfg.input_normalization = model_cfg.input_normalization 54 | crop_size = model_cfg.crop_size 55 | 56 | loss_cfg = edict() 57 | loss_cfg.pixel_loss = MaskWeightedMSE() 58 | loss_cfg.pixel_loss_weight = 1.0 59 | 60 | num_epochs = 120 61 | 62 | train_augmentator = HCompose([ 63 | RandomResizedCrop(*crop_size, scale=(0.5, 1.0)), 64 | HorizontalFlip(), 65 | ]) 66 | 67 | val_augmentator = HCompose([ 68 | Resize(*crop_size) 69 | ]) 70 | 71 | trainset = ComposeDataset( 72 | [ 73 | HDataset(cfg.HFLICKR_PATH, split='train'), 74 | HDataset(cfg.HDAY2NIGHT_PATH, split='train'), 75 | HDataset(cfg.HCOCO_PATH, split='train'), 76 | HDataset(cfg.HADOBE5K_PATH, split='train'), 77 | ], 78 | augmentator=train_augmentator, 79 | input_transform=model_cfg.input_transform, 80 | keep_background_prob=0.05, 81 | ) 82 | 83 | valset = ComposeDataset( 84 | [ 85 | HDataset(cfg.HFLICKR_PATH, split='test'), 86 | HDataset(cfg.HDAY2NIGHT_PATH, split='test'), 87 | HDataset(cfg.HCOCO_PATH, split='test'), 88 | ], 89 | augmentator=val_augmentator, 90 | input_transform=model_cfg.input_transform, 91 | keep_background_prob=-1, 92 | ) 93 | 94 | optimizer_params = { 95 | 'lr': 1e-3, 96 | 'betas': (0.9, 0.999), 'eps': 1e-8 97 | } 98 | 99 | lr_scheduler = partial(torch.optim.lr_scheduler.MultiStepLR, 100 | milestones=[105, 115], gamma=0.1) 101 | trainer = SimpleHTrainer( 102 | model, cfg, model_cfg, loss_cfg, 103 | trainset, valset, 104 | optimizer='adam', 105 | optimizer_params=optimizer_params, 106 | lr_scheduler=lr_scheduler, 107 | metrics=[ 108 | DenormalizedPSNRMetric( 109 | 'images', 
'target_images', 110 | mean=torch.tensor(cfg.input_normalization['mean'], dtype=torch.float32).view(1, 3, 1, 1), 111 | std=torch.tensor(cfg.input_normalization['std'], dtype=torch.float32).view(1, 3, 1, 1), 112 | ), 113 | DenormalizedMSEMetric( 114 | 'images', 'target_images', 115 | mean=torch.tensor(cfg.input_normalization['mean'], dtype=torch.float32).view(1, 3, 1, 1), 116 | std=torch.tensor(cfg.input_normalization['std'], dtype=torch.float32).view(1, 3, 1, 1), 117 | ) 118 | ], 119 | checkpoint_interval=10, 120 | image_dump_interval=1000 121 | ) 122 | 123 | logger.info(f'Starting Epoch: {start_epoch}') 124 | logger.info(f'Total Epochs: {num_epochs}') 125 | for epoch in range(start_epoch, num_epochs): 126 | trainer.training(epoch) 127 | trainer.validation(epoch) 128 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy 2 | numpy 3 | Cython 4 | scikit-image 5 | opencv-python-headless 6 | Pillow 7 | matplotlib 8 | imgaug 9 | albumentations 10 | graphviz 11 | tqdm 12 | pyyaml 13 | easydict 14 | tensorboard 15 | future 16 | cffi 17 | ninja -------------------------------------------------------------------------------- /scripts/evaluate_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | # os.environ["CUDA_VISIBLE_DEVICES"] = "2,3" 3 | 4 | import argparse 5 | import sys 6 | 7 | sys.path.insert(0, '.') 8 | 9 | import torch 10 | from pathlib import Path 11 | from albumentations import Resize, NoOp 12 | from iharm.data.hdataset import HDataset 13 | from iharm.data.transforms import HCompose, LongestMaxSizeIfLarger 14 | from iharm.inference.predictor import Predictor 15 | from iharm.inference.evaluation import evaluate_dataset 16 | from iharm.inference.metrics import MetricsHub, MSE, fMSE, PSNR, N, AvgPredictTime 17 | from iharm.inference.utils import load_model, find_checkpoint 
18 | from iharm.mconfigs import ALL_MCONFIGS 19 | from iharm.utils.exp import load_config_file 20 | from iharm.utils.log import logger, add_new_file_output_to_logger 21 | 22 | 23 | RESIZE_STRATEGIES = { 24 | 'None': NoOp(), 25 | 'LimitLongest1024': LongestMaxSizeIfLarger(1024), 26 | 'Fixed256': Resize(256, 256), 27 | 'Fixed512': Resize(512, 512), 28 | 'Fixed1024': Resize(1024, 1024) 29 | } 30 | 31 | 32 | def parse_args(): 33 | parser = argparse.ArgumentParser() 34 | 35 | parser.add_argument('model_type', choices=ALL_MCONFIGS.keys()) 36 | parser.add_argument('checkpoint', type=str, 37 | help='The path to the checkpoint. ' 38 | 'This can be a relative path (relative to cfg.MODELS_PATH) ' 39 | 'or an absolute path. The file extension can be omitted.') 40 | parser.add_argument('--datasets', type=str, default='HFlickr,HDay2Night,HCOCO,HAdobe5k1', 41 | help='Each dataset name must be one of the prefixes in config paths, ' 42 | 'which look like DATASET_PATH.') 43 | parser.add_argument('--resize-strategy', type=str, choices=RESIZE_STRATEGIES.keys(), default='Fixed256') 44 | parser.add_argument('--use-flip', action='store_true', default=False, 45 | help='Use horizontal flip test-time augmentation.') 46 | parser.add_argument('--gpu', type=str, default=0, help='ID of used GPU.') 47 | parser.add_argument('--config-path', type=str, default='./config.yml', 48 | help='The path to the config file.') 49 | 50 | parser.add_argument('--eval-prefix', type=str, default='') 51 | 52 | args = parser.parse_args() 53 | cfg = load_config_file(args.config_path, return_edict=True) 54 | return args, cfg 55 | 56 | 57 | def main(): 58 | args, cfg = parse_args() 59 | checkpoint_path = find_checkpoint(cfg.MODELS_PATH, args.checkpoint) 60 | add_new_file_output_to_logger( 61 | logs_path=Path(cfg.EXPS_PATH) / 'evaluation_logs', 62 | prefix=f'{Path(checkpoint_path).stem}_', 63 | only_message=True 64 | ) 65 | logger.info(vars(args)) 66 | 67 | device = torch.device(f'cuda:{args.gpu}') 68 | net = 
load_model(args.model_type, checkpoint_path, verbose=True) 69 | predictor = Predictor(net, device, with_flip=args.use_flip) 70 | 71 | datasets_names = args.datasets.split(',') 72 | datasets_metrics = [] 73 | for dataset_indx, dataset_name in enumerate(datasets_names): 74 | dataset = HDataset( 75 | cfg.get(f'{dataset_name.upper()}_PATH'), split='test', 76 | augmentator=HCompose([RESIZE_STRATEGIES[args.resize_strategy]]), 77 | keep_background_prob=-1 78 | ) 79 | 80 | dataset_metrics = MetricsHub([N(), MSE(), fMSE(), PSNR(), AvgPredictTime()], 81 | name=dataset_name) 82 | 83 | evaluate_dataset(dataset, predictor, dataset_metrics) 84 | datasets_metrics.append(dataset_metrics) 85 | if dataset_indx == 0: 86 | logger.info(dataset_metrics.get_table_header()) 87 | logger.info(dataset_metrics) 88 | 89 | if len(datasets_metrics) > 1: 90 | overall_metrics = sum(datasets_metrics, MetricsHub([], 'Overall')) 91 | logger.info('-' * len(str(overall_metrics))) 92 | logger.info(overall_metrics) 93 | 94 | 95 | if __name__ == '__main__': 96 | main() 97 | -------------------------------------------------------------------------------- /scripts/evaluate_model_fg_ratios.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | sys.path.insert(0, '.') 5 | 6 | import torch 7 | from pathlib import Path 8 | from tqdm import trange 9 | 10 | from albumentations import Resize, NoOp 11 | from iharm.data.hdataset import HDataset 12 | from iharm.data.transforms import HCompose, LongestMaxSizeIfLarger 13 | from iharm.inference.predictor import Predictor 14 | from iharm.inference.metrics import MetricsHub, MSE, fMSE, PSNR, N 15 | from iharm.inference.utils import load_model, find_checkpoint 16 | from iharm.mconfigs import ALL_MCONFIGS 17 | from iharm.utils.exp import load_config_file 18 | from iharm.utils.log import logger, add_new_file_output_to_logger 19 | 20 | 21 | RESIZE_STRATEGIES = { 22 | 'None': NoOp(), 23 | 
'LimitLongest1024': LongestMaxSizeIfLarger(1024), 24 | 'Fixed256': Resize(256, 256), 25 | 'Fixed512': Resize(512, 512) 26 | } 27 | 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser() 31 | 32 | parser.add_argument('model_type', choices=ALL_MCONFIGS.keys()) 33 | parser.add_argument('checkpoint', type=str, 34 | help='The path to the checkpoint. ' 35 | 'This can be a relative path (relative to cfg.MODELS_PATH) ' 36 | 'or an absolute path. The file extension can be omitted.') 37 | parser.add_argument('--datasets', type=str, default='HFlickr,HDay2Night,HCOCO,HAdobe5k', 38 | help='Each dataset name must be one of the prefixes in config paths, ' 39 | 'which look like DATASET_PATH.') 40 | parser.add_argument('--resize-strategy', type=str, choices=RESIZE_STRATEGIES.keys(), default='Fixed256') 41 | parser.add_argument('--use-flip', action='store_true', default=False, 42 | help='Use horizontal flip test-time augmentation.') 43 | parser.add_argument('--gpu', type=str, default=0, help='ID of used GPU.') 44 | parser.add_argument('--config-path', type=str, default='./config.yml', 45 | help='The path to the config file.') 46 | parser.add_argument('--results-path', type=str, default='', 47 | help='The path to the evaluation results. 
' 48 | 'Default path: cfg.EXPS_PATH/evaluation_results.') 49 | 50 | parser.add_argument('--eval-prefix', type=str, default='') 51 | 52 | args = parser.parse_args() 53 | cfg = load_config_file(args.config_path, return_edict=True) 54 | return args, cfg 55 | 56 | 57 | def main(): 58 | args, cfg = parse_args() 59 | checkpoint_path = find_checkpoint(cfg.MODELS_PATH, args.checkpoint) 60 | add_new_file_output_to_logger( 61 | logs_path=Path(cfg.EXPS_PATH) / 'evaluation_results', 62 | prefix=f'{Path(checkpoint_path).stem}_', 63 | only_message=True 64 | ) 65 | logger.info(vars(args)) 66 | 67 | device = torch.device(f'cuda:{args.gpu}') 68 | net = load_model(args.model_type, checkpoint_path, verbose=True) 69 | predictor = Predictor(net, device, with_flip=args.use_flip) 70 | 71 | fg_ratio_intervals = [(0.0, 0.05), (0.05, 0.15), (0.15, 1.0), (0.0, 1.00)] 72 | 73 | datasets_names = args.datasets.split(',') 74 | datasets_metrics = [[] for _ in fg_ratio_intervals] 75 | for dataset_indx, dataset_name in enumerate(datasets_names): 76 | dataset = HDataset( 77 | cfg.get(f'{dataset_name.upper()}_PATH'), split='test', 78 | augmentator=HCompose([RESIZE_STRATEGIES[args.resize_strategy]]), 79 | keep_background_prob=-1 80 | ) 81 | 82 | dataset_metrics = [] 83 | for fg_ratio_min, fg_ratio_max in fg_ratio_intervals: 84 | dataset_metrics.append(MetricsHub([N(), MSE(), fMSE(), PSNR()], 85 | name=f'{dataset_name} ({fg_ratio_min:.0%}-{fg_ratio_max:.0%})', 86 | name_width=28)) 87 | 88 | for sample_i in trange(len(dataset), desc=f'Testing on {dataset_name}'): 89 | sample = dataset.get_sample(sample_i) 90 | sample = dataset.augment_sample(sample) 91 | 92 | sample_mask = sample['object_mask'] 93 | sample_fg_ratio = (sample_mask > 0.5).sum() / (sample_mask.shape[0] * sample_mask.shape[1]) 94 | pred = predictor.predict(sample['image'], sample_mask, return_numpy=False) 95 | 96 | target_image = torch.as_tensor(sample['target_image'], dtype=torch.float32).to(predictor.device) 97 | sample_mask = 
torch.as_tensor(sample_mask, dtype=torch.float32).to(predictor.device) 98 | with torch.no_grad(): 99 | for metrics_hub, (fg_ratio_min, fg_ratio_max) in zip(dataset_metrics, fg_ratio_intervals): 100 | if fg_ratio_min <= sample_fg_ratio <= fg_ratio_max: 101 | metrics_hub.compute_and_add(pred, target_image, sample_mask) 102 | 103 | for indx, metrics_hub in enumerate(dataset_metrics): 104 | datasets_metrics[indx].append(metrics_hub) 105 | if dataset_indx == 0: 106 | logger.info(dataset_metrics[-1].get_table_header()) 107 | for metrics_hub in dataset_metrics: 108 | logger.info(metrics_hub) 109 | 110 | if len(datasets_names) > 1: 111 | overall_metrics = [sum(x, MetricsHub([], f'Overall ({fg_ratio_min:.0%}-{fg_ratio_max:.0%})', name_width=28)) 112 | for x, (fg_ratio_min, fg_ratio_max) in zip(datasets_metrics, fg_ratio_intervals)] 113 | logger.info('-' * len(str(overall_metrics[-1]))) 114 | for x in overall_metrics: 115 | logger.info(x) 116 | 117 | 118 | if __name__ == '__main__': 119 | main() 120 | -------------------------------------------------------------------------------- /scripts/predict_for_dir.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | from pathlib import Path 5 | import sys 6 | 7 | import cv2 8 | import numpy as np 9 | import torch 10 | from tqdm import tqdm 11 | 12 | sys.path.insert(0, '.') 13 | from iharm.inference.predictor import Predictor 14 | from iharm.inference.utils import load_model, find_checkpoint 15 | from iharm.mconfigs import ALL_MCONFIGS 16 | from iharm.utils.log import logger 17 | from iharm.utils.exp import load_config_file 18 | 19 | 20 | def main(): 21 | args, cfg = parse_args() 22 | 23 | device = torch.device(f'cuda:{args.gpu}') 24 | checkpoint_path = find_checkpoint(cfg.MODELS_PATH, args.checkpoint) 25 | net = load_model(args.model_type, checkpoint_path, verbose=True) 26 | predictor = Predictor(net, device) 27 | 28 | image_names =
os.listdir(args.images) 29 | 30 | def _save_image(image_name, rgb_image): 31 | bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) 32 | cv2.imwrite( 33 | str(cfg.RESULTS_PATH / f'{image_name}'), 34 | bgr_image, 35 | [cv2.IMWRITE_JPEG_QUALITY, 85] 36 | ) 37 | 38 | logger.info(f'Saving images to {cfg.RESULTS_PATH}') 39 | 40 | resize_shape = (args.resize, ) * 2 41 | for image_name in tqdm(image_names): 42 | image_path = osp.join(args.images, image_name) 43 | image = cv2.imread(image_path) 44 | composite_image_lab = cv2.cvtColor(image, cv2.COLOR_BGR2Lab) 45 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 46 | image_size = image.shape 47 | if resize_shape[0] > 0: 48 | image = cv2.resize(image, resize_shape, interpolation=cv2.INTER_LINEAR) 49 | 50 | mask_path = osp.join(args.masks, '_'.join(image_name.split('_')[:-1]) + '.png') 51 | mask_image = cv2.imread(mask_path) 52 | if resize_shape[0] > 0: 53 | mask_image = cv2.resize(mask_image, resize_shape, interpolation=cv2.INTER_LINEAR) 54 | mask = mask_image[:, :, 0] 55 | mask[mask <= 100] = 0 56 | mask[mask > 100] = 1 57 | mask = mask.astype(np.float32) 58 | 59 | pred = predictor.predict(image, mask, composite_image_lab) 60 | 61 | if args.original_size: 62 | pred = cv2.resize(pred, image_size[:-1][::-1]) 63 | _save_image(image_name, pred) 64 | 65 | 66 | def parse_args(): 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('model_type', choices=ALL_MCONFIGS.keys()) 69 | parser.add_argument('checkpoint', type=str, 70 | help='The path to the checkpoint. ' 71 | 'This can be a relative path (relative to cfg.MODELS_PATH) ' 72 | 'or an absolute path. The file extension can be omitted.') 73 | parser.add_argument( 74 | '--images', type=str, 75 | help='Path to the directory of .jpg composite images to harmonize.' 76 | ) 77 | parser.add_argument( 78 | '--masks', type=str, 79 | help='Path to the directory of .png binary masks, each named like its image with the last _suffix removed.'
80 | ) 81 | parser.add_argument( 82 | '--resize', type=int, default=256, 83 | help='Resize the image to a square of this size before feeding it into the network. If non-positive (e.g. -1), the input is not resized.' 84 | ) 85 | parser.add_argument( 86 | '--original-size', action='store_true', default=False, 87 | help='Resize the predicted image back to the original size.' 88 | ) 89 | parser.add_argument('--gpu', type=str, default=0, help='ID of the GPU to use.') 90 | parser.add_argument('--config-path', type=str, default='./config.yml', help='The path to the config file.') 91 | parser.add_argument( 92 | '--results-path', type=str, default='', 93 | help='Directory for the harmonized images. Default: cfg.EXPS_PATH/predictions.' 94 | ) 95 | 96 | args = parser.parse_args() 97 | cfg = load_config_file(args.config_path, return_edict=True) 98 | cfg.EXPS_PATH = Path(cfg.EXPS_PATH) 99 | cfg.RESULTS_PATH = Path(args.results_path) if len(args.results_path) else cfg.EXPS_PATH / 'predictions' 100 | cfg.RESULTS_PATH.mkdir(parents=True, exist_ok=True) 101 | logger.info(cfg) 102 | return args, cfg 103 | 104 | 105 | if __name__ == '__main__': 106 | main() 107 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | python scripts/evaluate_model.py DucoNet ./checkpoints/last_model/DucoNet256.pth \ 2 | --resize-strategy Fixed256 \ 3 | --gpu 0 4 | #python scripts/evaluate_model.py DucoNet ./checkpoints/last_model/DucoNet1024.pth \ 5 | #--resize-strategy Fixed1024 \ 6 | #--gpu 1 \ 7 | #--datasets HAdobe5k1 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import importlib.util 3 | 4 | import torch 5 | from iharm.utils.exp import init_experiment 6 | 7 | 8 | def main(): 9 | args = parse_args() 10 | model_script = load_module(args.model_path) 11 | 12 | cfg =
init_experiment(args) 13 | 14 | torch.backends.cudnn.benchmark = True 15 | torch.multiprocessing.set_sharing_strategy('file_system') 16 | model_script.main(cfg) 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser() 21 | 22 | parser.add_argument('model_path', type=str, 23 | help='Path to the model script.') 24 | 25 | parser.add_argument('--exp-name', type=str, default='', 26 | help='Here you can specify the name of the experiment. ' 27 | 'It will be added as a suffix to the experiment folder.') 28 | 29 | parser.add_argument('--workers', type=int, default=4, 30 | metavar='N', help='Dataloader threads.') 31 | 32 | parser.add_argument('--batch-size', type=int, default=-1, 33 | help='You can override model batch size by specify positive number.') 34 | 35 | parser.add_argument('--ngpus', type=int, default=1, 36 | help='Number of GPUs. ' 37 | 'If you only specify "--gpus" argument, the ngpus value will be calculated automatically. ' 38 | 'You should use either this argument or "--gpus".') 39 | 40 | parser.add_argument('--gpus', type=str, default='', required=False, 41 | help='Ids of used GPUs. You should use either this argument or "--ngpus".') 42 | 43 | parser.add_argument('--resume-exp', type=str, default=None, 44 | help='The prefix of the name of the experiment to be continued. ' 45 | 'If you use this field, you must specify the "--resume-prefix" argument.') 46 | 47 | parser.add_argument('--resume-prefix', type=str, default='latest', 48 | help='The prefix of the name of the checkpoint to be loaded.') 49 | 50 | parser.add_argument('--start-epoch', type=int, default=0, 51 | help='The number of the starting epoch from which training will continue. 
' 52 | '(it is important for correct logging and learning rate)') 53 | 54 | parser.add_argument('--weights', type=str, default=None, 55 | help='Model weights will be loaded from the specified path if you use this argument.') 56 | 57 | parser.add_argument('--lr', type=float, default=1e-3, 58 | help='') 59 | 60 | parser.add_argument('--control_module_start', type=int, default=-1, 61 | help='') 62 | 63 | 64 | 65 | return parser.parse_args() 66 | 67 | 68 | def load_module(script_path): 69 | spec = importlib.util.spec_from_file_location("model_script", script_path) 70 | model_script = importlib.util.module_from_spec(spec) 71 | spec.loader.exec_module(model_script) 72 | 73 | return model_script 74 | 75 | 76 | if __name__ == '__main__': 77 | main() 78 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | ## for low-resolution (256 * 256) 2 | 3 | python train.py models/DucoNet_256.py --workers=8 --gpus=0,1 --exp-name=DucoNet_256 --batch-size=32 4 | 5 | ## for high-resolution (1024 * 1024) 6 | 7 | #python train.py models/DucoNet_1024.py --workers=8 --gpus=0,1 --exp-name=DucoNet_1024 --batch-size=4 --------------------------------------------------------------------------------
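`train.py` does not import a model class itself: `load_module` executes the script passed as `model_path` (for example `models/DucoNet_256.py`) and then calls that script's `main(cfg)`. The sketch below reproduces this plugin-style loading in isolation. The demo script, its return string, and the plain dict standing in for the experiment config are assumptions for illustration, not code from the repository; the only contract assumed is that a model script defines `main(cfg)`.

```python
import importlib.util
import tempfile
from pathlib import Path


def load_module(script_path):
    """Import a Python file by path, mirroring train.py's load_module."""
    spec = importlib.util.spec_from_file_location("model_script", script_path)
    model_script = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(model_script)  # runs the script's top-level code
    return model_script


# Hypothetical stand-in for a model script such as models/DucoNet_256.py.
# Assumption: the only interface train.py relies on is a main(cfg) function.
DEMO_SCRIPT = '''
def main(cfg):
    return f"training with batch size {cfg['batch_size']}"
'''

if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as tmp:
        script_path = Path(tmp) / 'demo_model.py'
        script_path.write_text(DEMO_SCRIPT)
        model_script = load_module(str(script_path))
        # train.py passes the config returned by init_experiment;
        # a plain dict stands in for it here.
        print(model_script.main({'batch_size': 32}))
```

Because the model script is chosen on the command line, switching between `models/DucoNet_256.py` and `models/DucoNet_1024.py` (as `train.sh` does) needs no change to `train.py`.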
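The `--ngpus` help text says the GPU count is calculated automatically when only `--gpus` is given, as in `train.sh`, which passes `--gpus=0,1`. That logic is not part of this excerpt (it presumably lives in `iharm.utils.exp`, from which `train.py` imports `init_experiment`), so the helper below is a hypothetical sketch of the documented behavior. It assumes `--gpus` is a comma-separated list of integer IDs and that, when only `--ngpus` is set, the first `ngpus` devices are used; `resolve_gpus` is an illustrative name, not a repository function.

```python
from typing import List, Tuple


def resolve_gpus(gpus: str, ngpus: int = 1) -> Tuple[List[int], int]:
    """Hypothetical helper (not from the repository): turn train.py's
    --gpus / --ngpus values into a list of GPU ids and a GPU count."""
    if gpus:
        # "--gpus=0,1" -> [0, 1]; the count is inferred from the list.
        gpu_ids = [int(g) for g in gpus.split(',') if g.strip()]
        return gpu_ids, len(gpu_ids)
    # Only --ngpus given: assume the first `ngpus` devices are used.
    return list(range(ngpus)), ngpus


if __name__ == '__main__':
    print(resolve_gpus('0,1'))        # ([0, 1], 2)
    print(resolve_gpus('', ngpus=1))  # ([0], 1)
```

Giving `--gpus` precedence keeps the two flags from disagreeing: the count is always derived from the explicit ID list when one is supplied.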