├── .gitignore
├── Network.png
├── README.md
├── TonemapReinhard_npy.py
├── base
│   ├── __init__.py
│   ├── base_data_loader.py
│   ├── base_model.py
│   └── base_trainer.py
├── config
│   ├── edge_module.json
│   ├── edge_module_gray.json
│   ├── mask_module.json
│   └── mask_module_gray.json
├── data_loader
│   ├── __init__.py
│   ├── data_loader_LearnEdgeNet.py
│   ├── data_loader_LearnMaskNet.py
│   └── dataset
│       ├── __init__.py
│       ├── dataset_LearnEdgeNet.py
│       └── dataset_LearnMaskNet.py
├── execute
│   ├── infer_LearnEdgeNet.py
│   ├── infer_LearnMaskNet.py
│   └── train.py
├── model
│   ├── __init__.py
│   ├── layer_utils
│   │   ├── __init__.py
│   │   ├── funcs.py
│   │   ├── non_local_block.py
│   │   ├── region_non_local_block.py
│   │   ├── resnet.py
│   │   ├── se_block.py
│   │   └── unet.py
│   ├── loss_LearnEdgeNet.py
│   ├── loss_LearnMaskNet.py
│   ├── metric_LearnEdgeNet.py
│   ├── metric_LearnMaskNet.py
│   ├── metric_utils
│   │   ├── __init__.py
│   │   ├── per_pixel.py
│   │   ├── psnr.py
│   │   └── ssim.py
│   ├── model_LearnEdgeNet.py
│   └── model_LearnMaskNet.py
├── poster.pdf
├── scripts
│   ├── __init__.py
│   ├── make_dataset.py
│   ├── make_dataset_original_resolution.py
│   └── make_edge_map.py
├── trainer
│   ├── __init__.py
│   ├── trainer_LearnEdgeNet.py
│   └── trainer_LearnMaskNet.py
└── utils
    ├── __init__.py
    ├── logger.py
    ├── util.py
    └── visualization.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | others/
2 | 

--------------------------------------------------------------------------------
/Network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fourson/UnModNet/70b72c595382556b96b66333ea3f82079c48dfa6/Network.png

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UnModNet: Learning to Unwrap a Modulo Image for High Dynamic Range Imaging
2 | 
3 | By [Chu Zhou](https://fourson.github.io/), Hang Zhao, Jin Han, Chang Xu, Chao Xu, Tiejun Huang, [Boxin Shi](http://ci.idm.pku.edu.cn/)
4 | ![Network](Network.png)
5 | 
6 | [PDF](https://proceedings.neurips.cc/paper/2020/file/1102a326d5f7c9e04fc3c89d0ede88c9-Paper.pdf) | [SUPP](https://proceedings.neurips.cc/paper/2020/file/1102a326d5f7c9e04fc3c89d0ede88c9-Supplemental.pdf)
7 | 
8 | ## Abstract
9 | A conventional camera often suffers from over- or under-exposure when recording a real-world scene with a very high dynamic range (HDR). In contrast, a modulo camera with a Markov random field (MRF) based unwrapping algorithm can theoretically achieve unbounded dynamic range, but shows degenerate performance when there is modulus-intensity ambiguity, strong local contrast, or color misalignment. In this paper, we reformulate the modulo image unwrapping problem into a series of binary labeling problems and propose a modulo edge-aware model, named UnModNet, to iteratively estimate the binary rollover masks of the modulo image for unwrapping. Experimental results show that our approach can generate 12-bit HDR images from 8-bit modulo images reliably, and runs much faster than the previous MRF-based algorithm thanks to GPU acceleration.
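To make the rollover idea concrete, here is a schematic NumPy illustration (not code from this repository) of how a modulo image relates to the underlying HDR image:

```
import numpy as np

# Schematic illustration only: an 8-bit modulo camera keeps intensity mod 256,
# and unwrapping adds 256 back for every rollover.
hdr = np.array([100., 300., 900.])   # true scene intensities (12-bit range)
modulo = hdr % 256                   # what the modulo camera records
# UnModNet iteratively predicts binary rollover masks; their sum is the
# per-pixel fold number k, so that hdr = modulo + 256 * k.
k = hdr // 256                       # ground-truth fold numbers
assert np.allclose(modulo + 256 * k, hdr)
```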
10 | ## Prerequisites
11 | 
12 | * Linux distributions (tested on Ubuntu 18.04)
13 | * NVIDIA GPU with CUDA and cuDNN
14 | * Python >= 3.7
15 | * PyTorch >= 1.1.0
16 | * cv2
17 | * numpy
18 | * tqdm
19 | * tensorboardX (for training visualization)
20 | 
21 | ## Inference
22 | 
23 | * To unwrap RGB modulo images (in `.npy` format and in `(H, W, 3)` shape):
24 | ```
25 | python execute/infer_LearnMaskNet.py -r checkpoint/checkpoint-mask.pth --data_dir <input_data_dir> --result_dir <result_dir> --resume_edge_module checkpoint/checkpoint-edge.pth default
26 | ```
27 | 
28 | * To unwrap grayscale modulo images (in `.npy` format and in `(H, W, 1)` shape):
29 | ```
30 | python execute/infer_LearnMaskNet.py -r checkpoint/checkpoint-mask-gray.pth --data_dir <input_data_dir> --result_dir <result_dir> --resume_edge_module checkpoint/checkpoint-edge-gray.pth default
31 | ```
32 | 
33 | * Use `TonemapReinhard_npy.py` to visualize the results. Note that the default tonemapping method we use is `cv2.createTonemapReinhard(intensity=-1.0, light_adapt=0.8, color_adapt=0.0)`.
34 | 
35 | ## Pre-trained models and test examples
36 | 
37 | https://drive.google.com/drive/folders/10Y8MOr2o2TZzTI5RZUQZQ-0RBezbzhIV?usp=sharing
38 | 
39 | ## Training your own model
40 | 
41 | 1. Make the dataset from the original data (HDR images in `.npy` format):
42 |     * make the dataset:
43 |     ```
44 |     python scripts/make_dataset.py --data_dir <original_data_dir> --train_dir <train_dir> --test_dir <test_dir> --training_sample <num_training_samples>
45 |     ```
46 |     * make the edge maps:
47 |     ```
48 |     python scripts/make_edge_map.py --data_dir <data_dir>
49 |     ```
50 | 
51 | 2. Configure the training parameters:
52 |     * write your own `config.json` or use ours: `config/edge_module.json` and `config/mask_module.json` for the two stages respectively
53 |     * edit the learning rate schedule function (LambdaLR) at `get_lr_lambda` in `utils/util.py` (see the sketch after this list)
54 | 
55 | 3. Run:
56 | ```
57 | python execute/train.py -c <config_file>
58 | ```
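Note: `utils/util.py` is not included in this listing, so the exact schedules behind the `lr_lambda_tag` values (`original`, `grayscale`) are unknown; the sketch below only illustrates the expected shape of `get_lr_lambda` (the linear-decay choice is an assumption):

```
# Hypothetical sketch of utils/util.py:get_lr_lambda (the real file is not
# shown here); train.py passes the returned callable to LambdaLR.
def get_lr_lambda(lr_lambda_tag):
    if lr_lambda_tag == 'original':
        # assumption: hold the base lr for 200 epochs, then decay linearly to 0 at 400
        return lambda epoch: min(1.0, 2.0 - 2.0 * epoch / 400)
    elif lr_lambda_tag == 'grayscale':
        return lambda epoch: min(1.0, 2.0 - 2.0 * epoch / 300)
    else:
        raise NotImplementedError('lr_lambda_tag [%s] is not found' % lr_lambda_tag)
```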
59 | 
60 | ## Citation
61 | 
62 | If you find this work helpful to your research, please cite:
63 | ```
64 | @inproceedings{NEURIPS2020_1102a326,
65 |  author = {Zhou, Chu and Zhao, Hang and Han, Jin and Xu, Chang and Xu, Chao and Huang, Tiejun and Shi, Boxin},
66 |  booktitle = {Advances in Neural Information Processing Systems},
67 |  editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
68 |  pages = {1559--1570},
69 |  publisher = {Curran Associates, Inc.},
70 |  title = {UnModNet: Learning to Unwrap a Modulo Image for High Dynamic Range Imaging},
71 |  url = {https://proceedings.neurips.cc/paper/2020/file/1102a326d5f7c9e04fc3c89d0ede88c9-Paper.pdf},
72 |  volume = {33},
73 |  year = {2020}
74 | }
75 | ```
76 | 

--------------------------------------------------------------------------------
/TonemapReinhard_npy.py:
--------------------------------------------------------------------------------
1 | # tonemap all .npy images in the current dir to .jpg using the OpenCV Reinhard method
2 | import os
3 | import cv2
4 | import numpy as np
5 | 
6 | for f in os.listdir(os.getcwd()):
7 |     if f.endswith('.npy'):
8 |         hdr = np.load(f)
9 |         hdr = hdr.astype('float32')
10 |         hdr = (hdr - np.min(hdr)) / (np.max(hdr) - np.min(hdr) + 1e-8)  # avoid division by zero on constant images
11 |         grayscale = True
12 |         if hdr.ndim == 3:
13 |             if hdr.shape[2] == 3:
14 |                 # RGB image (H, W, 3)
15 |                 hdr = cv2.cvtColor(hdr, cv2.COLOR_RGB2BGR)
16 |                 grayscale = False
17 |             elif hdr.shape[2] == 1:
18 |                 # grayscale image (H, W, 1)
19 |                 hdr = hdr[:, :, 0]
20 |         if grayscale:
21 |             hdr = np.stack([hdr, hdr, hdr], axis=2)  # the tonemapper expects a 3-channel image
22 | 
23 |         tmo = cv2.createTonemapReinhard(intensity=-1.0, light_adapt=0.8, color_adapt=0.0)
24 |         tonemapped = tmo.process(hdr)
25 |         f_ = os.path.splitext(f)[0] + '.jpg'
26 |         cv2.imwrite(f_, np.clip(tonemapped * 255, 0, 255).astype(np.uint8))  # JPEG needs 8-bit data
27 | 

--------------------------------------------------------------------------------
/base/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fourson/UnModNet/70b72c595382556b96b66333ea3f82079c48dfa6/base/__init__.py

--------------------------------------------------------------------------------
/base/base_data_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import DataLoader
3 | from torch.utils.data.sampler import SubsetRandomSampler
4 | 
5 | 
6 | class BaseDataLoader(DataLoader):
7 |     """
8 |     Base class for all data loaders
9 |     """
10 |     def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers):
11 |         self.validation_split = validation_split
12 |         self.shuffle = shuffle
13 | 
14 |         self.batch_idx = 0
15 |         self.n_samples = len(dataset)
16 | 
17 |         self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)
18 | 
19 |         self.init_kwargs = {
20 |             'dataset': dataset,
21 |             'batch_size': batch_size,
22 |             'shuffle': self.shuffle,
23 |             'num_workers': num_workers
24 |         }
25 |         super(BaseDataLoader, self).__init__(sampler=self.sampler, **self.init_kwargs)
26 | 
27 |     def _split_sampler(self, split):
28 |         if split == 0.0:
29 |             return None, None
30 | 
31 |         idx_full = np.arange(self.n_samples)
32 | 
33 |         np.random.seed(0)
34 |         np.random.shuffle(idx_full)
35 | 
36 |         len_valid = int(self.n_samples * split)
37 | 
38 |         valid_idx = idx_full[0:len_valid]
39 |         train_idx = np.delete(idx_full, np.arange(0, len_valid))
40 | 
41 |         train_sampler = SubsetRandomSampler(train_idx)
42 |         valid_sampler = SubsetRandomSampler(valid_idx)
43 | 
44 |         # turn off the shuffle option, which is mutually exclusive with sampler
45 |         self.shuffle = False
46 |         self.n_samples = len(train_idx)
47 | 
48 |         return train_sampler, valid_sampler
49 | 
50 |     def split_validation(self):
51 |         if self.valid_sampler is None:
52 |             return None
53 |         else:
54 |             return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
55 | 

--------------------------------------------------------------------------------
/base/base_model.py:
-------------------------------------------------------------------------------- 1 | import logging 2 | import torch.nn as nn 3 | 4 | 5 | class BaseModel(nn.Module): 6 | """ 7 | Base class for all models 8 | """ 9 | 10 | def __init__(self): 11 | super(BaseModel, self).__init__() 12 | self.logger = logging.getLogger(self.__class__.__name__) 13 | 14 | def forward(self, *input): 15 | """ 16 | Forward pass logic 17 | :return: Model output 18 | """ 19 | raise NotImplementedError 20 | 21 | def summary(self): 22 | """ 23 | Model summary 24 | """ 25 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 26 | params = sum([p.numel() for p in model_parameters]) 27 | self.logger.info('Trainable parameters: {}'.format(params)) 28 | self.logger.info(self) 29 | 30 | def __str__(self): 31 | """ 32 | Model prints with number of trainable parameters 33 | """ 34 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 35 | params = sum([p.numel() for p in model_parameters]) 36 | return super(BaseModel, self).__str__() + '\nTrainable parameters: {}'.format(params) 37 | -------------------------------------------------------------------------------- /base/base_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import json 4 | import logging 5 | import datetime 6 | 7 | import torch 8 | 9 | from utils import util 10 | from utils.visualization import WriterTensorboardX 11 | from model.layer_utils.funcs import init_weights 12 | 13 | 14 | class BaseTrainer: 15 | """ 16 | Base class for all trainers 17 | """ 18 | 19 | def __init__(self, config, model, loss, metrics, optimizer, lr_scheduler, resume, train_logger): 20 | self.config = config 21 | self.logger = logging.getLogger(self.__class__.__name__) 22 | 23 | # setup GPU device if available, move model into configured device and init the weights 24 | self.device, device_ids = self._prepare_device(config['n_gpu']) 25 | self.model = model.to(self.device) 26 | self.model.apply(init_weights) 27 | self.data_parallel = (len(device_ids) > 1) 28 | if self.data_parallel: 29 | self.model = torch.nn.DataParallel(model, device_ids=device_ids) 30 | 31 | self.loss = loss 32 | self.metrics = metrics 33 | self.optimizer = optimizer 34 | self.lr_scheduler = lr_scheduler 35 | self.train_logger = train_logger 36 | 37 | trainer_args = config['trainer']['args'] 38 | self.epochs = trainer_args['epochs'] 39 | self.save_period = trainer_args['save_period'] 40 | self.verbosity = trainer_args['verbosity'] 41 | self.monitor = trainer_args.get('monitor', 'off') 42 | 43 | # configuration to monitor model performance and save best 44 | if self.monitor == 'off': 45 | self.mnt_mode = 'off' 46 | self.mnt_best = 0 47 | else: 48 | self.mnt_mode, self.mnt_metric = self.monitor.split() 49 | assert self.mnt_mode in ['min', 'max'] 50 | 51 | self.mnt_best = math.inf if self.mnt_mode == 'min' else -math.inf 52 | self.early_stop = trainer_args.get('early_stop', math.inf) 53 | 54 | self.start_epoch = 1 55 | 56 | # setup directory for checkpoint saving 57 | start_time = datetime.datetime.now().strftime('%m%d_%H%M%S') 58 | self.checkpoint_dir = os.path.join(trainer_args['save_dir'], config['module'], config['name'], start_time) 59 | 60 | # setup visualization writer instance 61 | writer_dir = os.path.join(trainer_args['log_dir'], config['module'], config['name'], start_time) 62 | self.writer = WriterTensorboardX(writer_dir, self.logger, trainer_args['tensorboardX']) 63 | 64 | # Save configuration 
file into checkpoint directory
65 |         util.ensure_dir(self.checkpoint_dir)
66 |         config_save_path = os.path.join(self.checkpoint_dir, 'config.json')
67 |         with open(config_save_path, 'w') as handle:
68 |             json.dump(config, handle, indent=4)
69 | 
70 |         if resume:
71 |             self._resume_checkpoint(resume)
72 | 
73 |     def _prepare_device(self, n_gpu_use):
74 |         """
75 |         setup GPU device if available, move model into the configured device
76 |         """
77 |         n_gpu = torch.cuda.device_count()
78 |         if n_gpu_use > 0 and n_gpu == 0:
79 |             self.logger.warning("Warning: There's no GPU available on this machine, training will be performed on CPU.")
80 |             n_gpu_use = 0
81 |         if n_gpu_use > n_gpu:
82 |             self.logger.warning(
83 |                 "Warning: The number of GPUs configured to use is {}, but only {} are available "
84 |                 "on this machine.".format(n_gpu_use, n_gpu))
85 |             n_gpu_use = n_gpu
86 |         device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
87 |         device_ids = list(range(n_gpu_use))
88 |         return device, device_ids
89 | 
90 |     def train(self):
91 |         """
92 |         Full training logic
93 |         """
94 |         not_improved_count = 0
95 |         for epoch in range(self.start_epoch, self.epochs + 1):
96 |             result = self._train_epoch(epoch)
97 | 
98 |             # save the logged information into the log dict
99 |             log = {'epoch': epoch}
100 |             for key, value in result.items():
101 |                 if key == 'metrics':
102 |                     log.update({mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
103 |                 elif key == 'val_metrics':
104 |                     log.update({'val_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
105 |                 else:
106 |                     log[key] = value
107 | 
108 |             # print the logged information to the screen
109 |             if self.train_logger is not None:
110 |                 self.train_logger.add_entry(log)
111 |                 if self.verbosity >= 1:
112 |                     for key, value in log.items():
113 |                         self.logger.info('    {:15s}: {}'.format(str(key), value))
114 | 
115 |             # evaluate model performance according to the configured metric, and save the best checkpoint as model_best
116 |             is_best = False
117 |             if self.mnt_mode != 'off':
118 |                 try:
119 |                     # check whether the model performance improved or not, according to the specified metric (mnt_metric)
120 |                     improved = (self.mnt_mode == 'min' and log[self.mnt_metric] < self.mnt_best) or \
121 |                                (self.mnt_mode == 'max' and log[self.mnt_metric] > self.mnt_best)
122 |                 except KeyError:
123 |                     self.logger.warning("Warning: Metric '{}' is not found. "
124 |                                         "Model performance monitoring is disabled.".format(self.mnt_metric))
125 |                     self.mnt_mode = 'off'
126 |                     improved = False
127 |                     not_improved_count = 0
128 | 
129 |                 if improved:
130 |                     self.mnt_best = log[self.mnt_metric]
131 |                     not_improved_count = 0
132 |                     is_best = True
133 |                 else:
134 |                     not_improved_count += 1
135 | 
136 |                 if not_improved_count > self.early_stop:
137 |                     self.logger.info("Validation performance didn't improve for {} epochs. 
" 138 | "Training stops.".format(self.early_stop)) 139 | break 140 | 141 | if epoch % self.save_period == 0: 142 | self._save_checkpoint(epoch, save_best=is_best) 143 | 144 | def _train_epoch(self, epoch): 145 | """ 146 | Training logic for an epoch 147 | 148 | :param epoch: Current epoch number 149 | """ 150 | raise NotImplementedError 151 | 152 | def _save_checkpoint(self, epoch, save_best=False): 153 | """ 154 | Saving checkpoints 155 | 156 | :param epoch: current epoch number 157 | :param save_best: if True, rename the saved checkpoint to 'model_best.pth' 158 | """ 159 | state = { 160 | 'epoch': epoch, 161 | 'logger': self.train_logger, 162 | 'optimizer': self.optimizer.state_dict(), 163 | 'lr_scheduler': self.lr_scheduler.state_dict(), 164 | 'monitor_best': self.mnt_best, 165 | 'config': self.config 166 | } 167 | if self.data_parallel: 168 | state['model'] = self.model.module.state_dict() 169 | else: 170 | state['model'] = self.model.state_dict() 171 | filename = os.path.join(self.checkpoint_dir, 'checkpoint-epoch{}.pth'.format(epoch)) 172 | torch.save(state, filename) 173 | self.logger.info("Saving checkpoint: {} ...".format(filename)) 174 | if save_best: 175 | best_path = os.path.join(self.checkpoint_dir, 'model_best.pth') 176 | torch.save(state, best_path) 177 | self.logger.info("Saving current best: {} ...".format('model_best.pth')) 178 | 179 | def _resume_checkpoint(self, resume_path): 180 | """ 181 | Resume from saved checkpoints 182 | 183 | :param resume_path: Checkpoint path to be resumed 184 | """ 185 | self.logger.info("Loading checkpoint: {} ...".format(resume_path)) 186 | checkpoint = torch.load(resume_path) 187 | self.start_epoch = checkpoint['epoch'] + 1 188 | self.mnt_best = checkpoint['monitor_best'] 189 | 190 | # load params from checkpoint 191 | if checkpoint['config']['module'] != self.config['module']: 192 | self.logger.warning("Warning: Module configuration given in config file is different from that of " 193 | "checkpoint. This may yield an exception while state_dict is being loaded.") 194 | if checkpoint['config']['model']['type'] != self.config['model']['type']: 195 | self.logger.warning("Warning: Architecture configuration given in config file is different from that of " 196 | "checkpoint. This may yield an exception while state_dict is being loaded.") 197 | if self.data_parallel: 198 | self.model.module.load_state_dict(checkpoint['model']) 199 | else: 200 | self.model.load_state_dict(checkpoint['model']) 201 | 202 | # load optimizer state from checkpoint only when optimizer type is not changed. 203 | if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']: 204 | self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. " 205 | "Optimizer parameters not being resumed.") 206 | else: 207 | self.optimizer.load_state_dict(checkpoint['optimizer']) 208 | 209 | # load learning scheduler state from checkpoint only when learning scheduler type is not changed. 210 | if checkpoint['config']['lr_scheduler']['type'] != self.config['lr_scheduler']['type']: 211 | self.logger.warning( 212 | "Warning: Learning scheduler type given in config file is different from that of checkpoint. 
" 213 | "Learning scheduler parameters not being resumed.") 214 | else: 215 | self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 216 | 217 | self.train_logger = checkpoint['logger'] 218 | self.logger.info("Checkpoint '{}' (epoch {}) loaded".format(resume_path, self.start_epoch)) 219 | -------------------------------------------------------------------------------- /config/edge_module.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "non_local_resnet_dropout", 3 | "n_gpu": 2, 4 | "module": "LearnEdgeNet", 5 | "data_loader": { 6 | "type": "TrainDataLoader", 7 | "args": { 8 | "data_dir": "data/train", 9 | "batch_size": 4, 10 | "shuffle": true, 11 | "validation_split": 0.0, 12 | "num_workers": 4 13 | } 14 | }, 15 | "model": { 16 | "type": "DefaultModel", 17 | "args": { 18 | "input_nc": 3, 19 | "use_dropout": true, 20 | "mode": "residual" 21 | } 22 | }, 23 | "loss": { 24 | "type": "bce_with_logits", 25 | "args": { 26 | } 27 | }, 28 | "metrics": [ 29 | "accuracy", 30 | "precision", 31 | "recall", 32 | "f1_score" 33 | ], 34 | "optimizer": { 35 | "type": "Adam", 36 | "args": { 37 | "lr": 0.0001, 38 | "betas": [ 39 | 0.5, 40 | 0.999 41 | ], 42 | "weight_decay": 0, 43 | "amsgrad": true 44 | } 45 | }, 46 | "lr_scheduler": { 47 | "type": "LambdaLR", 48 | "args": { 49 | "lr_lambda_tag": "original" 50 | } 51 | }, 52 | "trainer": { 53 | "type": "DefaultTrainer", 54 | "args": { 55 | "epochs": 400, 56 | "save_dir": "saved", 57 | "save_period": 20, 58 | "verbosity": 2, 59 | "monitor": "off", 60 | "tensorboardX": true, 61 | "log_dir": "saved/runs" 62 | } 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /config/edge_module_gray.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "non_local_resnet_dropout_gray", 3 | "n_gpu": 4, 4 | "module": "LearnEdgeNet", 5 | "data_loader": { 6 | "type": "TrainDataLoader", 7 | "args": { 8 | "data_dir": "data/grayscale/train", 9 | "batch_size": 8, 10 | "shuffle": true, 11 | "validation_split": 0.0, 12 | "num_workers": 8 13 | } 14 | }, 15 | "model": { 16 | "type": "DefaultModel", 17 | "args": { 18 | "input_nc": 1, 19 | "use_dropout": true, 20 | "mode": "residual" 21 | } 22 | }, 23 | "loss": { 24 | "type": "bce_with_logits", 25 | "args": { 26 | } 27 | }, 28 | "metrics": [ 29 | "accuracy", 30 | "precision", 31 | "recall", 32 | "f1_score" 33 | ], 34 | "optimizer": { 35 | "type": "Adam", 36 | "args": { 37 | "lr": 0.0002, 38 | "betas": [ 39 | 0.5, 40 | 0.999 41 | ], 42 | "weight_decay": 0, 43 | "amsgrad": true 44 | } 45 | }, 46 | "lr_scheduler": { 47 | "type": "LambdaLR", 48 | "args": { 49 | "lr_lambda_tag": "grayscale" 50 | } 51 | }, 52 | "trainer": { 53 | "type": "DefaultTrainer", 54 | "args": { 55 | "epochs": 300, 56 | "save_dir": "saved", 57 | "save_period": 10, 58 | "verbosity": 2, 59 | "monitor": "off", 60 | "tensorboardX": true, 61 | "log_dir": "saved/runs" 62 | } 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /config/mask_module.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "se_attention_unet_dropout", 3 | "n_gpu": 2, 4 | "module": "LearnMaskNet", 5 | "data_loader": { 6 | "type": "TrainDataLoader", 7 | "args": { 8 | "data_dir": "data/train", 9 | "batch_size": 4, 10 | "shuffle": true, 11 | "validation_split": 0.0, 12 | "num_workers": 4 13 | } 14 | }, 15 | "model": { 16 | "type": "DefaultModel", 17 
| "args": { 18 | "input_nc": 3, 19 | "use_dropout": true 20 | } 21 | }, 22 | "loss": { 23 | "type": "bce_with_logits", 24 | "args": { 25 | } 26 | }, 27 | "metrics": [ 28 | "accuracy", 29 | "precision", 30 | "recall", 31 | "f1_score" 32 | ], 33 | "optimizer": { 34 | "type": "Adam", 35 | "args": { 36 | "lr": 0.0001, 37 | "betas": [ 38 | 0.5, 39 | 0.999 40 | ], 41 | "weight_decay": 0, 42 | "amsgrad": true 43 | } 44 | }, 45 | "lr_scheduler": { 46 | "type": "LambdaLR", 47 | "args": { 48 | "lr_lambda_tag": "original" 49 | } 50 | }, 51 | "trainer": { 52 | "type": "DefaultTrainer", 53 | "args": { 54 | "epochs": 400, 55 | "save_dir": "saved", 56 | "save_period": 20, 57 | "verbosity": 2, 58 | "monitor": "off", 59 | "tensorboardX": true, 60 | "log_dir": "saved/runs" 61 | } 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /config/mask_module_gray.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "se_attention_unet_dropout_gray", 3 | "n_gpu": 4, 4 | "module": "LearnMaskNet", 5 | "data_loader": { 6 | "type": "TrainDataLoader", 7 | "args": { 8 | "data_dir": "data/grayscale/train", 9 | "batch_size": 8, 10 | "shuffle": true, 11 | "validation_split": 0.0, 12 | "num_workers": 8 13 | } 14 | }, 15 | "model": { 16 | "type": "DefaultModel", 17 | "args": { 18 | "input_nc": 1, 19 | "use_dropout": true 20 | } 21 | }, 22 | "loss": { 23 | "type": "bce_with_logits", 24 | "args": { 25 | } 26 | }, 27 | "metrics": [ 28 | "accuracy", 29 | "precision", 30 | "recall", 31 | "f1_score" 32 | ], 33 | "optimizer": { 34 | "type": "Adam", 35 | "args": { 36 | "lr": 0.0002, 37 | "betas": [ 38 | 0.5, 39 | 0.999 40 | ], 41 | "weight_decay": 0, 42 | "amsgrad": true 43 | } 44 | }, 45 | "lr_scheduler": { 46 | "type": "LambdaLR", 47 | "args": { 48 | "lr_lambda_tag": "grayscale" 49 | } 50 | }, 51 | "trainer": { 52 | "type": "DefaultTrainer", 53 | "args": { 54 | "epochs": 300, 55 | "save_dir": "saved", 56 | "save_period": 10, 57 | "verbosity": 2, 58 | "monitor": "off", 59 | "tensorboardX": true, 60 | "log_dir": "saved/runs" 61 | } 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /data_loader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fourson/UnModNet/70b72c595382556b96b66333ea3f82079c48dfa6/data_loader/__init__.py -------------------------------------------------------------------------------- /data_loader/data_loader_LearnEdgeNet.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | from .dataset import dataset_LearnEdgeNet 4 | from base.base_data_loader import BaseDataLoader 5 | 6 | 7 | class TrainDataLoader(BaseDataLoader): 8 | def __init__(self, data_dir, batch_size, shuffle, validation_split, num_workers): 9 | transform = None 10 | self.dataset = dataset_LearnEdgeNet.TrainDataset(data_dir, transform=transform) 11 | 12 | super(TrainDataLoader, self).__init__(self.dataset, batch_size, shuffle, validation_split, num_workers) 13 | 14 | 15 | class InferDataLoader(DataLoader): 16 | def __init__(self, data_dir): 17 | transform = None 18 | self.dataset = dataset_LearnEdgeNet.InferDataset(data_dir, transform=transform) 19 | 20 | super(InferDataLoader, self).__init__(self.dataset) 21 | 22 | 23 | -------------------------------------------------------------------------------- /data_loader/data_loader_LearnMaskNet.py: 
-------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | from .dataset import dataset_LearnMaskNet 4 | from base.base_data_loader import BaseDataLoader 5 | 6 | 7 | class TrainDataLoader(BaseDataLoader): 8 | def __init__(self, data_dir, batch_size, shuffle, validation_split, num_workers): 9 | transform = None 10 | self.dataset = dataset_LearnMaskNet.TrainDataset(data_dir, transform=transform) 11 | 12 | super(TrainDataLoader, self).__init__(self.dataset, batch_size, shuffle, validation_split, num_workers) 13 | 14 | 15 | class InferDataLoader(DataLoader): 16 | def __init__(self, data_dir): 17 | transform = None 18 | self.dataset = dataset_LearnMaskNet.InferDataset(data_dir, transform=transform) 19 | 20 | super(InferDataLoader, self).__init__(self.dataset) 21 | 22 | -------------------------------------------------------------------------------- /data_loader/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fourson/UnModNet/70b72c595382556b96b66333ea3f82079c48dfa6/data_loader/dataset/__init__.py -------------------------------------------------------------------------------- /data_loader/dataset/dataset_LearnEdgeNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import fnmatch 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class TrainDataset(Dataset): 10 | """ 11 | 256*256 RGB images 12 | 13 | as input: 14 | modulo: [0, 1] float, as float32 15 | modulo_edge: [0, 1] float, as float32 16 | 17 | as target: 18 | fold_number_edge: binary, as float32 19 | """ 20 | 21 | def __init__(self, data_dir='data', transform=None): 22 | self.modulo_dir = os.path.join(data_dir, 'modulo') 23 | self.modulo_edge_dir = os.path.join(data_dir, 'modulo_edge') 24 | self.fold_number_edge_dir = os.path.join(data_dir, 'fold_number_edge') 25 | 26 | self.names = fnmatch.filter(os.listdir(self.modulo_dir), '*.npy') 27 | 28 | self.transform = transform 29 | 30 | def __len__(self): 31 | return len(self.names) 32 | 33 | def __getitem__(self, index): 34 | # (H, W, C) 35 | modulo = np.load(os.path.join(self.modulo_dir, self.names[index])) # positive int, as float32 36 | modulo_edge = np.load(os.path.join(self.modulo_edge_dir, self.names[index])) # [0, 1] float, as float32 37 | fold_number_edge = np.load(os.path.join(self.fold_number_edge_dir, self.names[index])) # binary, as float32 38 | 39 | name = self.names[index].split('.')[0] 40 | assert modulo.ndim == 3 # for RGB image 41 | 42 | # (C, H, W) 43 | modulo = torch.tensor(np.transpose(modulo / np.max(modulo), (2, 0, 1)), dtype=torch.float32) 44 | modulo_edge = torch.tensor(np.transpose(modulo_edge, (2, 0, 1)), dtype=torch.float32) 45 | fold_number_edge = torch.tensor(np.transpose(fold_number_edge, (2, 0, 1)), dtype=torch.float32) 46 | 47 | if self.transform: 48 | modulo = self.transform(modulo) 49 | modulo_edge = self.transform(modulo_edge) 50 | fold_number_edge = self.transform(fold_number_edge) 51 | 52 | return {'modulo': modulo, 'modulo_edge': modulo_edge, 'fold_number_edge': fold_number_edge, 'name': name} 53 | 54 | 55 | class InferDataset(Dataset): 56 | """ 57 | 256*256 RGB images 58 | 59 | modulo: positive int, as float32 60 | """ 61 | 62 | def __init__(self, data_dir, transform=None): 63 | self.data_dir = data_dir 64 | 65 | self.names = fnmatch.filter(os.listdir(self.data_dir), '*.npy') 66 | 67 | 
self.transform = transform 68 | 69 | def __len__(self): 70 | return len(self.names) 71 | 72 | def __getitem__(self, index): 73 | # (H, W, C) 74 | modulo = np.load(os.path.join(self.data_dir, self.names[index])) # positive int, as float32 75 | 76 | name = self.names[index].split('.')[0] 77 | assert modulo.ndim == 3 # for RGB image 78 | 79 | # (C, H, W) 80 | modulo = torch.tensor(np.transpose(modulo, (2, 0, 1)), dtype=torch.float32) 81 | 82 | if self.transform: 83 | modulo = self.transform(modulo) 84 | 85 | return {'modulo': modulo, 'name': name} 86 | 87 | -------------------------------------------------------------------------------- /data_loader/dataset/dataset_LearnMaskNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import fnmatch 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class TrainDataset(Dataset): 10 | """ 11 | 256*256 RGB images 12 | 13 | as input: 14 | modulo: [0, 1] float, as float32 15 | fold_number_edge: binary, as float32 16 | 17 | as target: 18 | mask: binary, as float32 19 | """ 20 | 21 | def __init__(self, data_dir='data', transform=None): 22 | self.modulo_dir = os.path.join(data_dir, 'modulo') 23 | self.fold_number_edge_dir = os.path.join(data_dir, 'fold_number_edge') 24 | self.mask_dir = os.path.join(data_dir, 'mask') 25 | 26 | self.names = fnmatch.filter(os.listdir(self.modulo_dir), '*.npy') 27 | 28 | self.transform = transform 29 | 30 | def __len__(self): 31 | return len(self.names) 32 | 33 | def __getitem__(self, index): 34 | # (H, W, C) 35 | modulo = np.load(os.path.join(self.modulo_dir, self.names[index])) # positive int, as float32 36 | fold_number_edge = np.load(os.path.join(self.fold_number_edge_dir, self.names[index])) # binary, as float32 37 | mask = np.load(os.path.join(self.mask_dir, self.names[index])) # binary, as float32 38 | 39 | name = self.names[index].split('.')[0] 40 | assert modulo.ndim == 3 # for RGB image 41 | 42 | # (C, H, W) 43 | modulo = torch.tensor(np.transpose(modulo / np.max(modulo), (2, 0, 1)), dtype=torch.float32) 44 | fold_number_edge = torch.tensor(np.transpose(fold_number_edge, (2, 0, 1)), dtype=torch.float32) 45 | mask = torch.tensor(np.transpose(mask, (2, 0, 1)), dtype=torch.float32) 46 | 47 | if self.transform: 48 | modulo = self.transform(modulo) 49 | fold_number_edge = self.transform(fold_number_edge) 50 | mask = self.transform(mask) 51 | 52 | return {'modulo': modulo, 'fold_number_edge': fold_number_edge, 'mask': mask, 'name': name} 53 | 54 | 55 | class InferDataset(Dataset): 56 | """ 57 | 256*256 RGB images 58 | 59 | modulo: positive int, as float32 60 | """ 61 | 62 | def __init__(self, data_dir, transform=None): 63 | self.data_dir = data_dir 64 | 65 | self.names = fnmatch.filter(os.listdir(self.data_dir), '*.npy') 66 | 67 | self.transform = transform 68 | 69 | def __len__(self): 70 | return len(self.names) 71 | 72 | def __getitem__(self, index): 73 | # (H, W, C) 74 | modulo = np.load(os.path.join(self.data_dir, self.names[index])) # positive int, as float32 75 | 76 | name = self.names[index].split('.')[0] 77 | assert modulo.ndim == 3 # for RGB image 78 | 79 | # (C, H, W) 80 | modulo = torch.tensor(np.transpose(modulo, (2, 0, 1)), dtype=torch.float32) 81 | 82 | if self.transform: 83 | modulo = self.transform(modulo) 84 | 85 | return {'modulo': modulo, 'name': name} 86 | 87 | -------------------------------------------------------------------------------- /execute/infer_LearnEdgeNet.py: 
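The inference scripts below call `util.torch_laplacian` from `utils/util.py`, which is not part of this listing. A plausible implementation (an assumption, not the repository's actual code) filters each channel with a fixed 3x3 Laplacian kernel:

```
import torch
import torch.nn.functional as F

# Assumed implementation of utils.util.torch_laplacian (utils/util.py is not
# shown here): per-channel 3x3 Laplacian filtering of a (B, C, H, W) tensor.
def torch_laplacian(x):
    kernel = torch.tensor([[0., 1., 0.],
                           [1., -4., 1.],
                           [0., 1., 0.]], device=x.device).view(1, 1, 3, 3)
    c = x.shape[1]
    return F.conv2d(x, kernel.repeat(c, 1, 1, 1), padding=1, groups=c)
```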
-------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import importlib 4 | import sys 5 | 6 | import torch 7 | from tqdm import tqdm 8 | import numpy as np 9 | 10 | 11 | def infer_default(): 12 | with torch.no_grad(): 13 | for batch_idx, sample in enumerate(tqdm(data_loader, ascii=True)): 14 | name = sample['name'][0] 15 | 16 | # get data and send them to GPU 17 | modulo = sample['modulo'].to(device) # positive int, as float32 18 | modulo_edge = torch.abs(util.torch_laplacian(modulo)) # positive int, as float32 19 | 20 | output = model(modulo / torch.max(modulo), modulo_edge / torch.max(modulo_edge)) 21 | fold_number_edge_pred = torch.round(torch.sigmoid(output)) 22 | fold_number_edge_pred_numpy = fold_number_edge_pred.squeeze(0).permute(1, 2, 0).cpu().numpy() 23 | np.save(os.path.join(result_dir, name + '.npy'), fold_number_edge_pred_numpy) 24 | 25 | 26 | if __name__ == '__main__': 27 | MODULE = 'LearnEdgeNet' 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('-r', '--resume', required=True, type=str, help='path to latest checkpoint') 30 | parser.add_argument('-d', '--device', default=None, type=str, help='indices of GPUs to enable (default: all)') 31 | parser.add_argument('--data_dir', required=True, type=str, help='dir of input data') 32 | parser.add_argument('--result_dir', required=True, type=str, help='dir to save result') 33 | parser.add_argument('--data_loader_type', default='InferDataLoader', type=str, help='which data loader to use') 34 | subparsers = parser.add_subparsers(help='which func to run', dest='func') 35 | 36 | # add subparsers and their args for each func 37 | subparser = subparsers.add_parser("default") 38 | 39 | args = parser.parse_args() 40 | 41 | if args.device: 42 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 43 | 44 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) # add project root to PATH 45 | from utils import util 46 | 47 | # load checkpoint 48 | checkpoint = torch.load(args.resume) 49 | config = checkpoint['config'] 50 | assert config['module'] == MODULE 51 | 52 | # setup data_loader instances 53 | # we choose batch_size=1(default value) 54 | module_data = importlib.import_module('.data_loader_' + MODULE, package='data_loader') 55 | data_loader_class = getattr(module_data, args.data_loader_type) 56 | data_loader = data_loader_class(data_dir=args.data_dir) 57 | 58 | # build model architecture 59 | module_arch = importlib.import_module('.model_' + MODULE, package='model') 60 | model_class = getattr(module_arch, config['model']['type']) 61 | model = model_class(**config['model']['args']) 62 | 63 | # prepare model 64 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 65 | model = model.to(device) 66 | model.load_state_dict(checkpoint['model']) 67 | 68 | # set the model to validation mode 69 | model.eval() 70 | 71 | # ensure result_dir 72 | result_dir = args.result_dir 73 | util.ensure_dir(result_dir) 74 | 75 | # run the selected func 76 | if args.func == 'default': 77 | infer_default() 78 | else: 79 | # run the default 80 | infer_default() 81 | -------------------------------------------------------------------------------- /execute/infer_LearnMaskNet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import importlib 4 | import sys 5 | 6 | import torch 7 | from tqdm import tqdm 8 | import numpy as np 9 | 10 | 11 | def infer_default(iter_max): 12 | # load edge prediction 
model checkpoint 13 | EDGE_MODULE = 'LearnEdgeNet' 14 | edge_prediction_checkpoint = torch.load(resume_edge_module) 15 | edge_prediction_config = edge_prediction_checkpoint['config'] 16 | assert edge_prediction_config['module'] == EDGE_MODULE 17 | edge_prediction_module_arch = importlib.import_module('.model_' + EDGE_MODULE, package='model') 18 | edge_prediction_model_class = getattr(edge_prediction_module_arch, edge_prediction_config['model']['type']) 19 | edge_prediction_model = edge_prediction_model_class(**edge_prediction_config['model']['args']) 20 | edge_prediction_model = edge_prediction_model.to(device) 21 | edge_prediction_model.load_state_dict(edge_prediction_checkpoint['model']) 22 | edge_prediction_model.eval() 23 | 24 | # make dirs 25 | unwrapped_dir = os.path.join(result_dir, 'unwrapped') 26 | util.ensure_dir(unwrapped_dir) 27 | 28 | with torch.no_grad(): 29 | for batch_idx, sample in enumerate(tqdm(data_loader, ascii=True)): 30 | name = sample['name'][0] 31 | img_dir = os.path.join(result_dir, 'steps', name) 32 | util.ensure_dir(img_dir) 33 | 34 | # get data and send them to GPU 35 | modulo = sample['modulo'].to(device) # positive int, as float32 36 | mask_pred = torch.ones_like(modulo).to(device) 37 | 38 | i = 0 39 | while torch.sum(mask_pred) > 0 and i <= iter_max: 40 | modulo_numpy = modulo.squeeze(0).permute(1, 2, 0).cpu().numpy() 41 | np.save(os.path.join(img_dir, str(i) + '.npy'), modulo_numpy) 42 | 43 | modulo_edge = torch.abs(util.torch_laplacian(modulo)) # positive int, as float32 44 | edge_out = edge_prediction_model(modulo / torch.max(modulo), modulo_edge / torch.max(modulo_edge)) 45 | fold_number_edge = torch.round(torch.sigmoid(edge_out)) # binary, as float32 46 | output = model(modulo / torch.max(modulo), fold_number_edge) 47 | 48 | if confine: 49 | mask_pred *= torch.round(torch.sigmoid(output)) 50 | else: 51 | mask_pred = torch.round(torch.sigmoid(output)) 52 | 53 | modulo += 256 * mask_pred 54 | i += 1 55 | 56 | unwrapped_numpy = modulo.squeeze(0).permute(1, 2, 0).cpu().numpy() # positive int, as float32 57 | np.save(os.path.join(unwrapped_dir, name + '.npy'), unwrapped_numpy) 58 | 59 | 60 | if __name__ == '__main__': 61 | MODULE = 'LearnMaskNet' 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('-r', '--resume', required=True, type=str, help='path to latest checkpoint') 64 | parser.add_argument('-d', '--device', default=None, type=str, help='indices of GPUs to enable (default: all)') 65 | parser.add_argument('--data_dir', required=True, type=str, help='dir of input data') 66 | parser.add_argument('--result_dir', required=True, type=str, help='dir to save result') 67 | parser.add_argument('--data_loader_type', default='InferDataLoader', type=str, help='which data loader to use') 68 | parser.add_argument('--confine', default=0, type=int, help='confine mode') 69 | parser.add_argument('--resume_edge_module', required=True, type=str, 70 | help='path to latest checkpoint of edge prediction model') 71 | subparsers = parser.add_subparsers(help='which func to run', dest='func') 72 | 73 | # add subparsers and their args for each func 74 | subparser = subparsers.add_parser("default") 75 | subparser.add_argument('--iter_max', default=15, type=int, help='iteration limit') 76 | 77 | args = parser.parse_args() 78 | 79 | if args.device: 80 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 81 | 82 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) # add project root to PATH 83 | from utils import util 84 | 85 | # load checkpoint 86 | checkpoint = 
torch.load(args.resume) 87 | config = checkpoint['config'] 88 | assert config['module'] == MODULE 89 | 90 | # setup data_loader instances 91 | # we choose batch_size=1(default value) 92 | module_data = importlib.import_module('.data_loader_' + MODULE, package='data_loader') 93 | data_loader_class = getattr(module_data, args.data_loader_type) 94 | data_loader = data_loader_class(data_dir=args.data_dir) 95 | 96 | # build model architecture 97 | module_arch = importlib.import_module('.model_' + MODULE, package='model') 98 | model_class = getattr(module_arch, config['model']['type']) 99 | model = model_class(**config['model']['args']) 100 | 101 | # prepare model 102 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 103 | model = model.to(device) 104 | model.load_state_dict(checkpoint['model']) 105 | 106 | # set the model to validation mode 107 | model.eval() 108 | 109 | # ensure result_dir 110 | result_dir = args.result_dir 111 | util.ensure_dir(result_dir) 112 | 113 | # use the previous mask as the confinement of the current mask 114 | confine = bool(args.confine) 115 | # path to latest checkpoint of edge prediction model 116 | resume_edge_module = args.resume_edge_module 117 | 118 | # run the selected func 119 | if args.func == 'default': 120 | infer_default(args.iter_max) 121 | else: 122 | # run the default 123 | infer_default(args.iter_max) 124 | -------------------------------------------------------------------------------- /execute/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import argparse 5 | import importlib 6 | from functools import partial 7 | 8 | import torch 9 | 10 | 11 | def train(config, resume): 12 | # prepare logger 13 | train_logger = Logger() 14 | 15 | # setup data_loader instances 16 | data_loader_class = getattr(module_data, config['data_loader']['type']) 17 | data_loader = data_loader_class(**config['data_loader']['args']) 18 | valid_data_loader = data_loader.split_validation() 19 | 20 | # build model architecture 21 | model_class = getattr(module_arch, config['model']['type']) 22 | model = model_class(**config['model']['args']) 23 | 24 | # show model structure 25 | print(model) 26 | 27 | # get function handles of loss and metrics 28 | loss = partial(getattr(module_loss, config['loss']['type']), **config['loss']['args']) 29 | metrics = [getattr(module_metric, met) for met in config['metrics']] 30 | 31 | # build optimizer for model parameters 32 | trainable_params = filter(lambda p: p.requires_grad, model.parameters()) 33 | optimizer_class = getattr(torch.optim, config['optimizer']['type']) 34 | optimizer = optimizer_class(trainable_params, **config['optimizer']['args']) 35 | 36 | # build learning rate scheduler for optimizer 37 | lr_scheduler_class = getattr(torch.optim.lr_scheduler, config['lr_scheduler']['type']) 38 | if config['lr_scheduler']['type'] == 'LambdaLR': 39 | lr_lambda = util.get_lr_lambda(config['lr_scheduler']['args']['lr_lambda_tag']) 40 | lr_scheduler = lr_scheduler_class(optimizer, lr_lambda) 41 | else: 42 | lr_scheduler = lr_scheduler_class(optimizer, **config['lr_scheduler']['args']) 43 | 44 | # build trainer and train the network 45 | trainer_class = getattr(module_trainer, config['trainer']['type']) 46 | trainer = trainer_class(config, model, loss, metrics, optimizer, lr_scheduler, resume, data_loader, 47 | valid_data_loader, train_logger) 48 | trainer.train() 49 | 50 | 51 | if __name__ == '__main__': 52 | parser = 
argparse.ArgumentParser(description='UnwrapNet')
53 |     parser.add_argument('-c', '--config', default=None, type=str, help='config file path (default: None)')
54 |     parser.add_argument('-r', '--resume', default=None, type=str, help='path to latest checkpoint (default: None)')
55 |     parser.add_argument('-d', '--device', default=None, type=str, help='indices of GPUs to enable (default: all)')
56 |     args = parser.parse_args()
57 | 
58 |     sys.path.append(os.path.dirname(os.path.dirname(__file__)))  # add project root to PATH
59 | 
60 |     if args.config:
61 |         # load config file
62 |         with open(args.config) as handle:
63 |             config = json.load(handle)
64 |     elif args.resume:
65 |         # load config from checkpoint if a new config file is not given.
66 |         # Use '--config' and '--resume' together to fine-tune a trained model with changed configurations.
67 |         config = torch.load(args.resume)['config']
68 |     else:
69 |         raise AssertionError("A configuration file needs to be specified. Add '-c config.json', for example.")
70 | 
71 |     if args.device:
72 |         os.environ["CUDA_VISIBLE_DEVICES"] = args.device
73 | 
74 |     from utils.logger import Logger
75 |     from utils import util
76 | 
77 |     module = config['module']
78 |     if module == 'full':
79 |         postfix = ''
80 |     else:
81 |         postfix = '_' + module
82 |     module_data = importlib.import_module('.data_loader' + postfix, package='data_loader')
83 |     module_arch = importlib.import_module('.model' + postfix, package='model')
84 |     module_loss = importlib.import_module('.loss' + postfix, package='model')
85 |     module_metric = importlib.import_module('.metric' + postfix, package='model')
86 |     module_trainer = importlib.import_module('.trainer' + postfix, package='trainer')
87 | 
88 |     train(config, args.resume)

--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fourson/UnModNet/70b72c595382556b96b66333ea3f82079c48dfa6/model/__init__.py

--------------------------------------------------------------------------------
/model/layer_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fourson/UnModNet/70b72c595382556b96b66333ea3f82079c48dfa6/model/layer_utils/__init__.py

--------------------------------------------------------------------------------
/model/layer_utils/funcs.py:
--------------------------------------------------------------------------------
1 | import functools
2 | 
3 | import torch.nn as nn
4 | from torchvision import models
5 | 
6 | VGG19_FEATURES = models.vgg19(pretrained=True).features  # loaded once at import time (downloads weights on first use)
7 | CONV3_3_IN_VGG_19 = VGG19_FEATURES[0:15].cuda()  # note: moved to the GPU at import time, so CUDA is required
8 | VGG19_0to8 = VGG19_FEATURES[0:9].cuda()
9 | VGG19_9to13 = VGG19_FEATURES[9:14].cuda()
10 | VGG19_14to22 = VGG19_FEATURES[14:23].cuda()
11 | VGG19_23to31 = VGG19_FEATURES[23:32].cuda()
12 | 
13 | 
14 | def get_norm_layer(norm_type='instance'):
15 |     if norm_type == 'batch':
16 |         norm_layer = nn.BatchNorm2d
17 |     elif norm_type == 'instance':
18 |         norm_layer = functools.partial(nn.InstanceNorm2d)
19 |     else:
20 |         raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
21 |     return norm_layer
22 | 
23 | 
24 | def init_weights(m):
25 |     if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
26 |         nn.init.normal_(m.weight, 0, 0.02)
27 |         if m.bias is not None:
28 |             nn.init.zeros_(m.bias)
29 |     if isinstance(m, nn.BatchNorm2d):
30 |         nn.init.normal_(m.weight, 1, 0.02)
31 |         nn.init.zeros_(m.bias)
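The VGG-19 slices exported above are presumably consumed by the loss modules (`model/loss_*.py`, not shown in this listing) as fixed feature extractors; a hedged sketch of a typical perceptual-loss use (the function name and loss choice are assumptions):

```
import torch
import torch.nn.functional as F

from model.layer_utils.funcs import CONV3_3_IN_VGG_19

# Hypothetical perceptual loss on VGG-19 conv3_3 features; the repository's
# actual loss_*.py files are not shown, so this is an illustration only.
def perceptual_l1(pred, target):
    # pred/target: (B, 3, H, W) tensors on the same CUDA device as the extractor
    with torch.no_grad():
        target_feat = CONV3_3_IN_VGG_19(target)
    pred_feat = CONV3_3_IN_VGG_19(pred)
    return F.l1_loss(pred_feat, target_feat)
```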
--------------------------------------------------------------------------------
/model/layer_utils/non_local_block.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | # An implementation of non-local blocks
6 | 
7 | class Self_Attn_FM(nn.Module):
8 |     """ Self attention Layer for the Feature Map dimension"""
9 | 
10 |     def __init__(self, in_dim, latent_dim=8, subsample=True):
11 |         super(Self_Attn_FM, self).__init__()
12 |         self.channel_latent = in_dim // latent_dim
13 |         self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.channel_latent, kernel_size=1, stride=1)
14 |         self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.channel_latent, kernel_size=1, stride=1)
15 |         self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.channel_latent, kernel_size=1, stride=1)
16 |         self.out_conv = nn.Conv2d(in_channels=self.channel_latent, out_channels=in_dim, kernel_size=1, stride=1)
17 |         self.gamma = nn.Parameter(torch.zeros(1))
18 |         self.softmax = nn.Softmax(dim=-1)
19 | 
20 |         if subsample:
21 |             self.key_conv = nn.Sequential(
22 |                 self.key_conv,
23 |                 nn.MaxPool2d(2)
24 |             )
25 |             self.value_conv = nn.Sequential(
26 |                 self.value_conv,
27 |                 nn.MaxPool2d(2)
28 |             )
29 | 
30 |     def forward(self, x):
31 |         """
32 |             inputs :
33 |                 x : input feature maps (B x C x H x W)
34 |             returns :
35 |                 out : self attention value + input feature
36 |         """
37 |         batchsize, C, height, width = x.size()
38 |         c = self.channel_latent
39 |         # proj_query: reshape to B x N x c, N = H x W
40 |         proj_query = self.query_conv(x).view(batchsize, c, -1).permute(0, 2, 1)
41 |         # proj_key: reshape to B x c x N_, N_ = H_ x W_
42 |         proj_key = self.key_conv(x).view(batchsize, c, -1)
43 |         # energy: B x N x N_, N = H x W, N_ = H_ x W_
44 |         energy = torch.bmm(proj_query, proj_key)
45 |         # attention: B x N_ x N, N = H x W, N_ = H_ x W_
46 |         attention = self.softmax(energy).permute(0, 2, 1)
47 |         # proj_value: B x c x N_, N_ = H_ x W_
48 |         proj_value = self.value_conv(x).view(batchsize, c, -1)
49 |         # attention_out: B x c x N, N = H x W
50 |         attention_out = torch.bmm(proj_value, attention)
51 |         # out: B x C x H x W
52 |         out = self.out_conv(attention_out.view(batchsize, c, height, width))
53 | 
54 |         out = self.gamma * out + x
55 |         return out
56 | 
57 | 
58 | class Self_Attn_C(nn.Module):
59 |     """ Self attention Layer for the Channel dimension"""
60 | 
61 |     def __init__(self, in_dim, latent_dim=8):
62 |         super(Self_Attn_C, self).__init__()
63 |         self.channel_latent = in_dim // latent_dim
64 |         self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.channel_latent, kernel_size=1, stride=1)
65 |         self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.channel_latent, kernel_size=1, stride=1)
66 |         self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.channel_latent, kernel_size=1, stride=1)
67 |         self.out_conv = nn.Conv2d(in_channels=self.channel_latent, out_channels=in_dim, kernel_size=1, stride=1)
68 |         self.gamma = nn.Parameter(torch.zeros(1))
69 |         self.softmax = nn.Softmax(dim=-1)
70 | 
71 |     def forward(self, x):
72 |         """
73 |             inputs :
74 |                 x : input feature maps (B x C x H x W)
75 |             returns :
76 |                 out : self attention value + input feature
77 |         """
78 |         batchsize, C, height, width = x.size()
79 |         # proj_query: reshape to B x N x c, N = H x W
80 |         proj_query = self.query_conv(x).view(batchsize, -1, height * width).permute(0, 2, 1)
81 |         # proj_key: reshape to B x c x N, N = H x W
82 |         proj_key = 
self.key_conv(x).view(batchsize, -1, height * width) 83 | # energy: B x c x c 84 | energy = torch.bmm(proj_key, proj_query) 85 | # attention: B x c x c 86 | attention = self.softmax(energy) 87 | # proj_value: B x c x N 88 | proj_value = self.value_conv(x).view(batchsize, -1, height * width) 89 | # attention_out: B x c x N 90 | attention_out = torch.bmm(attention.permute(0, 2, 1), proj_value) 91 | # out: B x C x H x W 92 | out = self.out_conv(attention_out.view(batchsize, self.channel_latent, height, width)) 93 | 94 | out = self.gamma * out + x 95 | return out 96 | -------------------------------------------------------------------------------- /model/layer_utils/region_non_local_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .non_local_block import Self_Attn_FM 4 | 5 | 6 | class RegionNonLocalBlock(nn.Module): 7 | """ 8 | region non-local block 9 | in_channel -> in_channel 10 | """ 11 | 12 | def __init__(self, in_channel, latent_dim=2, subsample=True, grid=(8, 8)): 13 | super(RegionNonLocalBlock, self).__init__() 14 | 15 | self.non_local_block = Self_Attn_FM(in_channel, latent_dim=latent_dim, subsample=subsample) 16 | self.grid = grid 17 | 18 | def forward(self, x): 19 | input_row_list = x.chunk(self.grid[0], dim=2) 20 | output_row_list = [] 21 | for i, row in enumerate(input_row_list): 22 | input_grid_list_of_a_row = row.chunk(self.grid[1], dim=3) 23 | output_grid_list_of_a_row = [] 24 | for j, grid in enumerate(input_grid_list_of_a_row): 25 | grid = self.non_local_block(grid) 26 | output_grid_list_of_a_row.append(grid) 27 | output_row = torch.cat(output_grid_list_of_a_row, dim=3) 28 | output_row_list.append(output_row) 29 | output = torch.cat(output_row_list, dim=2) 30 | return output 31 | 32 | 33 | class RegionNonLocalEnhancedDenseBlock(nn.Module): 34 | """ 35 | region non-local enhanced dense block 36 | in_channel -> in_channel 37 | """ 38 | 39 | def __init__(self, in_channel=64, inter_channel=32, n_blocks=3, latent_dim=2, subsample=True, grid=(8, 8)): 40 | super(RegionNonLocalEnhancedDenseBlock, self).__init__() 41 | 42 | self.region_non_local = RegionNonLocalBlock(in_channel, latent_dim, subsample, grid) 43 | self.conv_blocks = nn.ModuleList() 44 | 45 | dim = in_channel 46 | for i in range(n_blocks): 47 | self.conv_blocks.append( 48 | nn.Sequential( 49 | nn.Conv2d(in_channels=dim, out_channels=inter_channel, kernel_size=3, stride=1, padding=1), 50 | nn.ReLU(), 51 | ) 52 | ) 53 | dim += inter_channel 54 | 55 | self.fusion = nn.Conv2d(in_channels=dim, out_channels=in_channel, kernel_size=1, stride=1) 56 | 57 | def forward(self, x): 58 | feature_list = [self.region_non_local(x)] 59 | for conv_block in self.conv_blocks: 60 | feature_list.append(conv_block(torch.cat(feature_list, dim=1))) 61 | out = self.fusion(torch.cat(feature_list, dim=1)) + x 62 | return out 63 | 64 | -------------------------------------------------------------------------------- /model/layer_utils/resnet.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .funcs import get_norm_layer 7 | from .region_non_local_block import RegionNonLocalBlock 8 | 9 | 10 | class ResnetBlock(nn.Module): 11 | """ 12 | Resnet block using bottleneck structure 13 | dim -> dim 14 | """ 15 | 16 | def __init__(self, dim, norm_layer, use_dropout, use_bias): 17 | super(ResnetBlock, self).__init__() 18 | 19 | sequence = [ 20 | 
nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=use_bias), 21 | norm_layer(dim), 22 | nn.ReLU(inplace=True), 23 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias), 24 | norm_layer(dim), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=use_bias), 27 | norm_layer(dim), 28 | nn.ReLU(inplace=True) 29 | ] 30 | if use_dropout: 31 | sequence += [nn.Dropout(0.5)] 32 | 33 | self.model = nn.Sequential(*sequence) 34 | 35 | def forward(self, x): 36 | out = x + self.model(x) 37 | return out 38 | 39 | 40 | class ResnetBackbone(nn.Module): 41 | """ 42 | Resnet backbone 43 | input_nc -> output_nc 44 | """ 45 | 46 | def __init__(self, input_nc, output_nc=64, n_downsampling=2, n_blocks=3, norm_type='instance', use_dropout=False): 47 | super(ResnetBackbone, self).__init__() 48 | 49 | norm_layer = get_norm_layer(norm_type) 50 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 51 | use_bias = norm_layer.func != nn.BatchNorm2d 52 | else: 53 | use_bias = norm_layer != nn.BatchNorm2d 54 | 55 | sequence = [ 56 | nn.Conv2d(input_nc, output_nc, kernel_size=7, stride=1, padding=3, bias=use_bias), 57 | norm_layer(output_nc), 58 | nn.ReLU(True) 59 | ] 60 | 61 | dim = output_nc 62 | for i in range(n_downsampling): # downsample the feature map 63 | sequence += [ 64 | nn.Conv2d(dim, 2 * dim, kernel_size=3, stride=2, padding=1, bias=use_bias), 65 | norm_layer(2 * dim), 66 | nn.ReLU(True) 67 | ] 68 | dim *= 2 69 | 70 | for i in range(n_blocks): # ResBlock 71 | sequence += [ 72 | ResnetBlock(dim, norm_layer, use_dropout, use_bias) 73 | ] 74 | 75 | for i in range(n_downsampling): # upsample the feature map 76 | sequence += [ 77 | nn.ConvTranspose2d(dim, dim // 2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias), 78 | norm_layer(dim // 2), 79 | nn.ReLU(True) 80 | ] 81 | dim //= 2 82 | 83 | self.model = nn.Sequential(*sequence) 84 | 85 | def forward(self, x): 86 | out = self.model(x) 87 | return out 88 | 89 | 90 | class NonLocalResnetDownsamplingBlock(nn.Module): 91 | """ 92 | non-local Resnet downsampling block 93 | in_channel -> out_channel 94 | """ 95 | 96 | def __init__(self, in_channel, out_channel, norm_layer, use_dropout, use_bias, latent_dim): 97 | super(NonLocalResnetDownsamplingBlock, self).__init__() 98 | 99 | self.projection = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1) 100 | self.non_local = RegionNonLocalBlock(out_channel, latent_dim) 101 | self.bottleneck = nn.Sequential( 102 | nn.Conv2d(out_channel, out_channel, kernel_size=1, stride=1, bias=use_bias), 103 | norm_layer(out_channel), 104 | nn.ReLU(inplace=True), 105 | nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=use_bias), 106 | norm_layer(out_channel), 107 | nn.ReLU(inplace=True), 108 | nn.Conv2d(out_channel, out_channel, kernel_size=1, stride=1, bias=use_bias), 109 | ) 110 | out_sequence = [ 111 | norm_layer(out_channel), 112 | nn.ReLU(inplace=True) 113 | ] 114 | 115 | if use_dropout: 116 | out_sequence += [nn.Dropout(0.5)] 117 | out_sequence += [nn.MaxPool2d(2)] 118 | 119 | self.out_block = nn.Sequential(*out_sequence) 120 | 121 | def forward(self, x): 122 | x_ = self.projection(x) 123 | x_ = self.non_local(x_) 124 | out = self.out_block(x_ + self.bottleneck(x_)) 125 | return out 126 | 127 | 128 | class NonLocalResnetUpsamplingBlock(nn.Module): 129 | """ 130 | non-local Resnet upsampling block 131 | x1:in_channel1 x2:in_channel2 --> out_channel 132 | """ 133 | 134 | def __init__(self, 
in_channel1, in_channel2, out_channel, norm_layer, use_dropout, use_bias, latent_dim):
135 |         super(NonLocalResnetUpsamplingBlock, self).__init__()
136 |         # in_channel1: number of channels of the input to be upsampled
137 |         # in_channel2: number of channels coming from the skip link
138 |         self.upsample = nn.ConvTranspose2d(in_channel1, in_channel1 // 2, kernel_size=4, stride=2, padding=1,
139 |                                            bias=use_bias)
140 |         self.projection = nn.Conv2d(in_channel1 // 2 + in_channel2, out_channel, kernel_size=1, stride=1)
141 |         self.non_local = RegionNonLocalBlock(out_channel, latent_dim)
142 |         self.bottleneck = nn.Sequential(
143 |             nn.Conv2d(out_channel, out_channel, kernel_size=1, stride=1, bias=use_bias),
144 |             norm_layer(out_channel),
145 |             nn.ReLU(inplace=True),
146 |             nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=use_bias),
147 |             norm_layer(out_channel),
148 |             nn.ReLU(inplace=True),
149 |             nn.Conv2d(out_channel, out_channel, kernel_size=1, stride=1, bias=use_bias),
150 |         )
151 |         out_sequence = [
152 |             norm_layer(out_channel),
153 |             nn.ReLU(inplace=True)
154 |         ]
155 | 
156 |         if use_dropout:
157 |             out_sequence += [nn.Dropout(0.5)]
158 | 
159 |         self.out_block = nn.Sequential(*out_sequence)
160 | 
161 |     def forward(self, x1, x2):
162 |         # x1: the input to be upsampled
163 |         # x2: the input coming from the skip link
164 |         x_ = self.projection(torch.cat([x2, self.upsample(x1)], dim=1))
165 |         x_ = self.non_local(x_)
166 |         out = self.out_block(x_ + self.bottleneck(x_))
167 |         return out
168 | 
169 | 
170 | class NonLocalResnetBackbone(nn.Module):
171 |     """
172 |     non-local Resnet backbone
173 |     input_nc -> output_nc
174 |     """
175 | 
176 |     def __init__(self, input_nc, output_nc=64, n_downsampling=2, n_blocks=6, norm_type='instance', use_dropout=False,
177 |                  latent_dim=8):
178 |         super(NonLocalResnetBackbone, self).__init__()
179 | 
180 |         norm_layer = get_norm_layer(norm_type)
181 |         if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
182 |             use_bias = norm_layer.func != nn.BatchNorm2d
183 |         else:
184 |             use_bias = norm_layer != nn.BatchNorm2d
185 | 
186 |         self.n_downsampling = n_downsampling
187 |         self.n_blocks = n_blocks
188 | 
189 |         self.projection = nn.Sequential(
190 |             nn.Conv2d(input_nc, output_nc, kernel_size=7, stride=1, padding=3, bias=use_bias),
191 |             norm_layer(output_nc),
192 |             nn.ReLU(True)
193 |         )
194 |         self.in_conv = nn.Sequential(
195 |             nn.Conv2d(output_nc, output_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
196 |             norm_layer(output_nc),
197 |             nn.ReLU(True)
198 |         )
199 |         self.out_conv = nn.Sequential(
200 |             nn.Conv2d(2 * output_nc, output_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
201 |             norm_layer(output_nc),
202 |             nn.ReLU(True)
203 |         )
204 |         self.downsampling_blocks = nn.ModuleList()
205 |         self.upsampling_blocks = nn.ModuleList()
206 | 
207 |         dim = output_nc
208 |         for i in range(n_downsampling):
209 |             self.downsampling_blocks.append(
210 |                 NonLocalResnetDownsamplingBlock(dim, 2 * dim, norm_layer, use_dropout, use_bias, latent_dim)
211 |             )
212 |             dim *= 2
213 | 
214 |         res_blocks_seq = [ResnetBlock(dim, norm_layer, use_dropout, use_bias) for _ in range(n_blocks)]  # independent blocks (list multiplication would reuse one shared block)
215 |         self.res_blocks = nn.Sequential(*res_blocks_seq)
216 | 
217 |         for i in range(n_downsampling):
218 |             self.upsampling_blocks.append(
219 |                 NonLocalResnetUpsamplingBlock(dim, dim // 2, dim // 2, norm_layer, use_dropout, use_bias, latent_dim)
220 |             )
221 |             dim //= 2
222 | 
223 |     def forward(self, x):
224 |         x_ = self.projection(x)
225 |         out = self.in_conv(x_)
226 | 
227 |         skip_links = list()
228 |         for i in range(self.n_downsampling):
229 |             skip_links.append(out)
230 |             out 
= self.downsampling_blocks[i](out) 231 | 232 | out = self.res_blocks(out) 233 | 234 | for i in range(self.n_downsampling): 235 | out = self.upsampling_blocks[i](out, skip_links[-i - 1]) 236 | 237 | out = self.out_conv(torch.cat([x_, out], dim=1)) 238 | return out 239 | -------------------------------------------------------------------------------- /model/layer_utils/se_block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class GlobalAvgPool(nn.Module): 5 | """(N,C,H,W) -> (N,C)""" 6 | 7 | def __init__(self): 8 | super(GlobalAvgPool, self).__init__() 9 | 10 | def forward(self, x): 11 | N, C, H, W = x.shape 12 | return x.view(N, C, -1).mean(-1) 13 | 14 | 15 | class SEBlock(nn.Module): 16 | """(N,C,H,W) -> (N,C,H,W)""" 17 | def __init__(self, in_channel, r): 18 | super(SEBlock, self).__init__() 19 | self.se = nn.Sequential( 20 | GlobalAvgPool(), 21 | nn.Linear(in_channel, in_channel // r), 22 | nn.ReLU(inplace=True), 23 | nn.Linear(in_channel // r, in_channel), 24 | nn.Sigmoid() 25 | ) 26 | 27 | def forward(self, x): 28 | se_weight = self.se(x).unsqueeze(-1).unsqueeze(-1) # (N, C, 1, 1) 29 | return x * se_weight # (N, C, H, W) 30 | -------------------------------------------------------------------------------- /model/layer_utils/unet.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .funcs import get_norm_layer 7 | 8 | 9 | class UnetDoubleConvBlock(nn.Module): 10 | """ 11 | Unet double Conv block 12 | in_channel -> out_channel 13 | """ 14 | 15 | def __init__(self, in_channel, out_channel, norm_layer, use_dropout, use_bias, mode='default'): 16 | super(UnetDoubleConvBlock, self).__init__() 17 | 18 | self.mode = mode 19 | 20 | if self.mode == 'default': 21 | self.model = nn.Sequential( 22 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=use_bias), 23 | norm_layer(out_channel), 24 | nn.ReLU(inplace=True), 25 | nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=use_bias), 26 | norm_layer(out_channel), 27 | nn.ReLU(inplace=True) 28 | ) 29 | out_sequence = [] 30 | elif self.mode == 'bottleneck': 31 | self.model = nn.Sequential( 32 | nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, bias=use_bias), 33 | norm_layer(out_channel), 34 | nn.ReLU(inplace=True), 35 | nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=use_bias), 36 | norm_layer(out_channel), 37 | nn.ReLU(inplace=True), 38 | nn.Conv2d(out_channel, out_channel, kernel_size=1, stride=1, bias=use_bias), 39 | norm_layer(out_channel), 40 | nn.ReLU(inplace=True) 41 | ) 42 | out_sequence = [] 43 | elif self.mode == 'res-bottleneck': 44 | self.projection = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1) 45 | self.bottleneck = nn.Sequential( 46 | nn.Conv2d(out_channel, out_channel, kernel_size=1, stride=1, bias=use_bias), 47 | norm_layer(out_channel), 48 | nn.ReLU(inplace=True), 49 | nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=use_bias), 50 | norm_layer(out_channel), 51 | nn.ReLU(inplace=True), 52 | nn.Conv2d(out_channel, out_channel, kernel_size=1, stride=1, bias=use_bias), 53 | ) 54 | out_sequence = [ 55 | norm_layer(out_channel), 56 | nn.ReLU(inplace=True) 57 | ] 58 | else: 59 | raise NotImplementedError('mode [%s] is not found' % self.mode) 60 | 61 | if use_dropout: 62 | out_sequence += [nn.Dropout(0.5)] 63 | 64 | 
self.out_block = nn.Sequential(*out_sequence)
65 | 
66 |     def forward(self, x):
67 |         if self.mode == 'res-bottleneck':
68 |             x_ = self.projection(x)
69 |             out = self.out_block(x_ + self.bottleneck(x_))
70 |         else:
71 |             out = self.out_block(self.model(x))
72 |         return out
73 | 
74 | 
75 | class UnetDownsamplingBlock(nn.Module):
76 |     """
77 |     Unet downsampling block
78 |     in_channel -> out_channel
79 |     """
80 | 
81 |     def __init__(self, in_channel, out_channel, norm_layer, use_dropout, use_bias, use_conv, mode='default'):
82 |         super(UnetDownsamplingBlock, self).__init__()
83 | 
84 |         downsampling_layers = list()
85 |         if use_conv:
86 |             downsampling_layers += [
87 |                 nn.Conv2d(in_channel, in_channel, kernel_size=4, stride=2, padding=1, bias=use_bias),
88 |                 norm_layer(in_channel),  # the strided conv keeps in_channel channels, so normalize over in_channel
89 |                 nn.ReLU(inplace=True)
90 |             ]
91 |         else:
92 |             downsampling_layers += [nn.MaxPool2d(2)]
93 | 
94 |         self.model = nn.Sequential(
95 |             nn.Sequential(*downsampling_layers),
96 |             UnetDoubleConvBlock(in_channel, out_channel, norm_layer, use_dropout, use_bias, mode=mode)
97 |         )
98 | 
99 |     def forward(self, x):
100 |         out = self.model(x)
101 |         return out
102 | 
103 | 
104 | class UnetUpsamplingBlock(nn.Module):
105 |     """
106 |     Unet upsampling block
107 |     x1:in_channel1 x2:in_channel2 --> out_channel
108 |     """
109 | 
110 |     def __init__(self, in_channel1, in_channel2, out_channel, norm_layer, use_dropout, use_bias, mode='default'):
111 |         super(UnetUpsamplingBlock, self).__init__()
112 |         # in_channel1: number of channels of the input to be upsampled
113 |         # in_channel2: number of channels coming from the skip link
114 |         self.upsample = nn.ConvTranspose2d(in_channel1, in_channel1 // 2, kernel_size=4, stride=2, padding=1,
115 |                                            bias=use_bias)
116 |         self.double_conv = UnetDoubleConvBlock(in_channel1 // 2 + in_channel2, out_channel, norm_layer, use_dropout,
117 |                                                use_bias, mode=mode)
118 | 
119 |     def forward(self, x1, x2):
120 |         # x1: the input to be upsampled
121 |         # x2: the input coming from the skip link
122 |         out = torch.cat([x2, self.upsample(x1)], dim=1)
123 |         out = self.double_conv(out)
124 |         return out
125 | 
126 | 
127 | class UnetBackbone(nn.Module):
128 |     """
129 |     Unet backbone
130 |     input_nc -> output_nc
131 |     """
132 | 
133 |     def __init__(self, input_nc, output_nc=64, n_downsampling=4, use_conv_to_downsample=True, norm_type='instance',
134 |                  use_dropout=False, mode='default'):
135 |         super(UnetBackbone, self).__init__()
136 | 
137 |         self.n_downsampling = n_downsampling
138 | 
139 |         norm_layer = get_norm_layer(norm_type)
140 |         if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
141 |             use_bias = norm_layer.func != nn.BatchNorm2d
142 |         else:
143 |             use_bias = norm_layer != nn.BatchNorm2d
144 | 
145 |         self.double_conv_block = UnetDoubleConvBlock(input_nc, output_nc, norm_layer, use_dropout, use_bias, mode=mode)
146 |         self.downsampling_blocks = nn.ModuleList()
147 |         self.upsampling_blocks = nn.ModuleList()
148 | 
149 |         dim = output_nc
150 |         for i in range(n_downsampling):
151 |             self.downsampling_blocks.append(
152 |                 UnetDownsamplingBlock(dim, 2 * dim, norm_layer, use_dropout, use_bias, use_conv_to_downsample,
153 |                                       mode=mode)
154 |             )
155 |             dim *= 2
156 | 
157 |         for i in range(n_downsampling):
158 |             self.upsampling_blocks.append(
159 |                 UnetUpsamplingBlock(dim, dim // 2, dim // 2, norm_layer, use_dropout, use_bias, mode=mode)
160 |             )
161 |             dim //= 2
162 | 
163 |     def forward(self, x):
164 |         double_conv_block_out = self.double_conv_block(x)
165 | 
166 |         downsampling_blocks_out = list()
167 |         downsampling_blocks_out.append(
168 |             self.downsampling_blocks[0](double_conv_block_out)
169 |         )
170 |         for i in range(1, self.n_downsampling):
171 |             downsampling_blocks_out.append(
172 |                 self.downsampling_blocks[i](downsampling_blocks_out[-1])
173 |             )
174 | 
175 |         upsampling_blocks_out = list()
176 |         upsampling_blocks_out.append(
177 |             self.upsampling_blocks[0](downsampling_blocks_out[-1], downsampling_blocks_out[-2])
178 |         )
179 |         for i in range(1, self.n_downsampling - 1):
180 |             upsampling_blocks_out.append(
181 |                 self.upsampling_blocks[i](upsampling_blocks_out[-1], downsampling_blocks_out[-2 - i])
182 |             )
183 |         upsampling_blocks_out.append(
184 |             self.upsampling_blocks[-1](upsampling_blocks_out[-1], double_conv_block_out)
185 |         )
186 | 
187 |         out = upsampling_blocks_out[-1]
188 |         return out
189 | 
190 | 
191 | class AttentionBlock(nn.Module):
192 |     """
193 |     attention block
194 |     x:in_channel_x g:in_channel_g --> in_channel_x
195 |     """
196 | 
197 |     def __init__(self, in_channel_x, in_channel_g, channel_t, norm_layer, use_bias):
198 |         # in_channel_x: number of input channels (from the skip link)
199 |         # in_channel_g: number of channels of the gating signal (the upsampled features)
200 |         super(AttentionBlock, self).__init__()
201 |         self.x_block = nn.Sequential(
202 |             nn.Conv2d(in_channel_x, channel_t, kernel_size=1, stride=1, padding=0, bias=use_bias),
203 |             norm_layer(channel_t)
204 |         )
205 | 
206 |         self.g_block = nn.Sequential(
207 |             nn.Conv2d(in_channel_g, channel_t, kernel_size=1, stride=1, padding=0, bias=use_bias),
208 |             norm_layer(channel_t)
209 |         )
210 | 
211 |         self.t_block = nn.Sequential(
212 |             nn.Conv2d(channel_t, 1, kernel_size=1, stride=1, padding=0, bias=use_bias),
213 |             norm_layer(1),
214 |             nn.Sigmoid()
215 |         )
216 | 
217 |         self.relu = nn.ReLU(inplace=True)
218 | 
219 |     def forward(self, x, g):
220 |         # x: (N, in_channel_x, H, W) input (from the skip link)
221 |         # g: (N, in_channel_g, H, W) gating signal (the upsampled features)
222 |         # x and g must have the same H and W
223 |         x_out = self.x_block(x)  # (N, channel_t, H, W)
224 |         g_out = self.g_block(g)  # (N, channel_t, H, W)
225 |         t_in = self.relu(x_out + g_out)  # (N, channel_t, H, W)
226 |         attention_map = self.t_block(t_in)  # (N, 1, H, W)
227 |         return x * attention_map  # (N, in_channel_x, H, W)
228 | 
229 | 
230 | class AttentionUnetUpsamplingBlock(nn.Module):
231 |     """
232 |     attention Unet upsampling block
233 |     x1:in_channel1 x2:in_channel2 --> out_channel
234 |     """
235 | 
236 |     def __init__(self, in_channel1, in_channel2, out_channel, norm_layer, use_dropout, use_bias, mode='default'):
237 |         super(AttentionUnetUpsamplingBlock, self).__init__()
238 |         # in_channel1: number of channels of the input to be upsampled
239 |         # in_channel2: number of channels coming from the skip link
240 |         self.upsample = nn.Sequential(
241 |             nn.ConvTranspose2d(in_channel1, in_channel1 // 2, kernel_size=4, stride=2, padding=1, bias=use_bias),
242 |             norm_layer(in_channel1 // 2),  # the transposed conv outputs in_channel1 // 2 channels
243 |             nn.ReLU(inplace=True)
244 |         )
245 |         self.attention = AttentionBlock(in_channel1 // 2, in_channel2, in_channel1 // 2, norm_layer, use_bias)
246 |         self.double_conv = UnetDoubleConvBlock(in_channel1 // 2 + in_channel2, out_channel, norm_layer, use_dropout,
247 |                                                use_bias, mode=mode)
248 | 
249 |     def forward(self, x1, x2):
250 |         # x1: the input to be upsampled
251 |         # x2: the input coming from the skip link
252 |         upsampled_x1 = self.upsample(x1)  # serves as the gating signal of the attention block
253 |         attentioned_x2 = self.attention(x2, upsampled_x1)
254 |         out = torch.cat([attentioned_x2, upsampled_x1], dim=1)
255 |         out = self.double_conv(out)
256 |         return out
257 | 
258 | 
259 | class AttentionUnetBackbone(nn.Module):
260 |     """
261 |     attention Unet backbone
262 |     input_nc -> output_nc
263 |     """
264 | 
265 |     def __init__(self, input_nc, output_nc=64, n_downsampling=4, use_conv_to_downsample=False, norm_type='instance',
266 |                  use_dropout=False, 
mode='default'): 267 | super(AttentionUnetBackbone, self).__init__() 268 | 269 | self.n_downsampling = n_downsampling 270 | 271 | norm_layer = get_norm_layer(norm_type) 272 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 273 | use_bias = norm_layer.func != nn.BatchNorm2d 274 | else: 275 | use_bias = norm_layer != nn.BatchNorm2d 276 | 277 | self.double_conv_block = UnetDoubleConvBlock(input_nc, output_nc, norm_layer, use_dropout, use_bias, mode=mode) 278 | self.downsampling_blocks = nn.ModuleList() 279 | self.upsampling_blocks = nn.ModuleList() 280 | 281 | dim = output_nc 282 | for i in range(n_downsampling): 283 | self.downsampling_blocks.append( 284 | UnetDownsamplingBlock(dim, 2 * dim, norm_layer, use_dropout, use_bias, use_conv_to_downsample, 285 | mode=mode) 286 | ) 287 | dim *= 2 288 | 289 | for i in range(n_downsampling): 290 | self.upsampling_blocks.append( 291 | AttentionUnetUpsamplingBlock(dim, dim // 2, dim // 2, norm_layer, use_dropout, use_bias, mode=mode) 292 | ) 293 | dim //= 2 294 | 295 | def forward(self, x): 296 | double_conv_block_out = self.double_conv_block(x) 297 | 298 | downsampling_blocks_out = list() 299 | downsampling_blocks_out.append( 300 | self.downsampling_blocks[0](double_conv_block_out) 301 | ) 302 | for i in range(1, self.n_downsampling): 303 | downsampling_blocks_out.append( 304 | self.downsampling_blocks[i](downsampling_blocks_out[-1]) 305 | ) 306 | 307 | upsampling_blocks_out = list() 308 | upsampling_blocks_out.append( 309 | self.upsampling_blocks[0](downsampling_blocks_out[-1], downsampling_blocks_out[-2]) 310 | ) 311 | for i in range(1, self.n_downsampling - 1): 312 | upsampling_blocks_out.append( 313 | self.upsampling_blocks[i](upsampling_blocks_out[-1], downsampling_blocks_out[-2 - i]) 314 | ) 315 | upsampling_blocks_out.append( 316 | self.upsampling_blocks[-1](upsampling_blocks_out[-1], double_conv_block_out) 317 | ) 318 | 319 | out = upsampling_blocks_out[-1] 320 | return out 321 | -------------------------------------------------------------------------------- /model/loss_LearnEdgeNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | # should be torch.ones([1, 256, 256]).cuda() while training with the grayscale images 5 | pos_weight_dummy = torch.ones([3, 256, 256]).cuda() 6 | 7 | def bce_with_logits(output, target, **kwargs): 8 | pos_weight = kwargs.get('pos_weight', 1) * pos_weight_dummy 9 | loss = F.binary_cross_entropy_with_logits(output, target, pos_weight=pos_weight) 10 | return loss 11 | -------------------------------------------------------------------------------- /model/loss_LearnMaskNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | # should be torch.ones([1, 256, 256]).cuda() while training with the grayscale images 5 | pos_weight_dummy = torch.ones([3, 256, 256]).cuda() 6 | 7 | 8 | def bce_with_logits(output, target, **kwargs): 9 | pos_weight = kwargs.get('pos_weight', 1) * pos_weight_dummy 10 | loss = F.binary_cross_entropy_with_logits(output, target, pos_weight=pos_weight) 11 | return loss 12 | -------------------------------------------------------------------------------- /model/metric_LearnEdgeNet.py: -------------------------------------------------------------------------------- 1 | from .metric_utils.per_pixel import accuracy, precision, recall, f1_score 2 | 
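3 | # Usage note: these per-pixel metrics expect binarized tensors; in the trainers
4 | # the raw network scores are binarized before evaluation, e.g.:
5 | #     pred = torch.round(torch.sigmoid(output))
6 | #     score = f1_score(pred, target)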
-------------------------------------------------------------------------------- /model/metric_LearnMaskNet.py: --------------------------------------------------------------------------------
1 | from .metric_utils.per_pixel import accuracy, precision, recall, f1_score
2 | 
-------------------------------------------------------------------------------- /model/metric_utils/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/fourson/UnModNet/70b72c595382556b96b66333ea3f82079c48dfa6/model/metric_utils/__init__.py
-------------------------------------------------------------------------------- /model/metric_utils/per_pixel.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.jit
3 | 
4 | 
5 | def accuracy(pred, target):
6 |     # accuracy = (TP + TN) / (TP + TN + FP + FN)
7 |     # the accuracy over all pixels
8 |     true = torch.sum(pred == target).double()  # TP + TN
9 |     total = torch.numel(pred)  # TP + TN + FP + FN
10 |     return true / total
11 | 
12 | 
13 | def precision(pred, target):
14 |     # precision = TP / (TP + FP)
15 |     # the accuracy over the pixels which are predicted as positive
16 |     true_positive = torch.sum(pred * target).double()  # TP
17 |     pred_positive = torch.sum(pred)  # TP + FP
18 |     return true_positive / (pred_positive + 1e-8)
19 | 
20 | 
21 | def recall(pred, target):
22 |     # recall = TP / (TP + FN)
23 |     # the accuracy over the pixels which are positive
24 |     true_positive = torch.sum(pred * target).double()  # TP
25 |     target_positive = torch.sum(target)  # TP + FN
26 |     return true_positive / (target_positive + 1e-8)
27 | 
28 | 
29 | def f1_score(pred, target):
30 |     # f1_score = 2 * precision * recall / (precision + recall)
31 |     p = precision(pred, target)
32 |     r = recall(pred, target)
33 |     if p == 0 and r == 0:
34 |         return 0
35 |     else:
36 |         return 2 * p * r / (p + r)
37 | 
38 | 
39 | def MIoU(pred, target):
40 |     # MIoU = TP / (TP + FN + FP)
41 |     # Mean Intersection over Union of the **positive pixels** (TP + FN) and the **predicted positive pixels** (TP + FP)
42 |     true_positive = torch.sum(pred * target).double()  # TP
43 |     target_positive = torch.sum(target)  # TP + FN
44 |     false_positive = torch.sum(pred * (1 - target))  # FP
45 |     return (true_positive + 1e-8) / (target_positive + false_positive + 1e-8)
46 | 
-------------------------------------------------------------------------------- /model/metric_utils/psnr.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.jit
3 | 
4 | 
5 | @torch.jit.script
6 | def psnr(X, Y, data_range: float):
7 |     """
8 |     Peak Signal to Noise Ratio
9 |     """
10 | 
11 |     mse = torch.mean((X - Y) ** 2)
12 |     return 10 * torch.log10(data_range ** 2 / mse)
13 | 
14 | 
15 | class PSNR(torch.jit.ScriptModule):
16 |     __constants__ = ['data_range', 'avg']
17 | 
18 |     def __init__(self, data_range=1., avg=True):
19 |         super().__init__()
20 |         self.data_range = data_range
21 |         self.avg = avg
22 | 
23 |     @torch.jit.script_method
24 |     def forward(self, X, Y):
25 |         r = psnr(X, Y, self.data_range)
26 |         if self.avg:
27 |             return r.mean()
28 |         else:
29 |             return r
30 | 
-------------------------------------------------------------------------------- /model/metric_utils/ssim.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.jit
3 | import torch.nn.functional as F
4 | 
5 | 
6 | @torch.jit.script
7 | def create_window(window_size: int, sigma: float, 
channel: int):
8 |     '''
9 |     Create 1-D gauss kernel
10 |     :param window_size: the size of gauss kernel
11 |     :param sigma: sigma of normal distribution
12 |     :param channel: input channel
13 |     :return: 1-D kernel of shape (channel, 1, 1, window_size)
14 |     '''
15 |     coords = torch.arange(window_size, dtype=torch.float)
16 |     coords -= window_size // 2
17 | 
18 |     g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
19 |     g /= g.sum()
20 | 
21 |     g = g.reshape(1, 1, 1, -1).repeat(channel, 1, 1, 1)
22 |     return g
23 | 
24 | 
25 | @torch.jit.script
26 | def _gaussian_filter(x, window_1d, use_padding: bool):
27 |     '''
28 |     Blur input with 1-D kernel
29 |     :param x: batch of tensors to be blurred
30 |     :param window_1d: 1-D gauss kernel
31 |     :param use_padding: padding image before conv
32 |     :return: blurred tensors
33 |     '''
34 |     C = x.shape[1]
35 |     padding = 0
36 |     if use_padding:
37 |         window_size = window_1d.shape[3]
38 |         padding = window_size // 2
39 |     out = F.conv2d(x, window_1d, stride=1, padding=(0, padding), groups=C)
40 |     out = F.conv2d(out, window_1d.transpose(2, 3), stride=1, padding=(padding, 0), groups=C)
41 |     return out
42 | 
43 | 
44 | @torch.jit.script
45 | def ssim(X, Y, window, data_range: float, use_padding: bool = False):
46 |     '''
47 |     Calculate ssim index for X and Y
48 |     :param X: images
49 |     :param Y: images
50 |     :param window: 1-D gauss kernel
51 |     :param data_range: value range of input images. (usually 1.0 or 255)
52 |     :param use_padding: padding image before conv
53 |     :return: ssim_val and cs, each of shape (N,)
54 |     '''
55 | 
56 |     K1 = 0.01
57 |     K2 = 0.03
58 |     compensation = 1.0
59 | 
60 |     C1 = (K1 * data_range) ** 2
61 |     C2 = (K2 * data_range) ** 2
62 | 
63 |     mu1 = _gaussian_filter(X, window, use_padding)
64 |     mu2 = _gaussian_filter(Y, window, use_padding)
65 |     sigma1_sq = _gaussian_filter(X * X, window, use_padding)
66 |     sigma2_sq = _gaussian_filter(Y * Y, window, use_padding)
67 |     sigma12 = _gaussian_filter(X * Y, window, use_padding)
68 | 
69 |     mu1_sq = mu1.pow(2)
70 |     mu2_sq = mu2.pow(2)
71 |     mu1_mu2 = mu1 * mu2
72 | 
73 |     sigma1_sq = compensation * (sigma1_sq - mu1_sq)
74 |     sigma2_sq = compensation * (sigma2_sq - mu2_sq)
75 |     sigma12 = compensation * (sigma12 - mu1_mu2)
76 | 
77 |     cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
78 |     ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
79 | 
80 |     ssim_val = ssim_map.mean(dim=(1, 2, 3))  # reduce along CHW
81 |     cs = cs_map.mean(dim=(1, 2, 3))
82 | 
83 |     return ssim_val, cs
84 | 
85 | 
86 | @torch.jit.script
87 | def ms_ssim(X, Y, window, data_range: float, weights, use_padding: bool = False):
88 |     '''
89 |     interface of ms-ssim
90 |     :param X: a batch of images, (N,C,H,W)
91 |     :param Y: a batch of images, (N,C,H,W)
92 |     :param window: 1-D gauss kernel
93 |     :param data_range: value range of input images. 
(usually 1.0 or 255)
94 |     :param weights: weights for different levels
95 |     :param use_padding: padding image before conv
96 |     :return: (N,)
97 |     '''
98 |     levels = weights.shape[0]
99 |     cs_vals = []
100 |     ssim_vals = []
101 |     for _ in range(levels):
102 |         ssim_val, cs = ssim(X, Y, window=window, data_range=data_range, use_padding=use_padding)
103 |         cs_vals.append(cs)
104 |         ssim_vals.append(ssim_val)
105 |         padding = (X.shape[2] % 2, X.shape[3] % 2)
106 |         X = F.avg_pool2d(X, kernel_size=2, stride=2, padding=padding)
107 |         Y = F.avg_pool2d(Y, kernel_size=2, stride=2, padding=padding)
108 | 
109 |     cs_vals = torch.stack(cs_vals, dim=0)
110 |     ms_ssim_val = torch.prod(cs_vals[:-1] ** weights[:-1].unsqueeze(1), dim=0) * (ssim_vals[-1] ** weights[-1])  # cs terms at the coarser scales times the ssim term at the last scale
111 |     return ms_ssim_val
112 | 
113 | 
114 | class SSIM(torch.jit.ScriptModule):
115 |     __constants__ = ['data_range', 'use_padding', 'avg']
116 | 
117 |     def __init__(self, window_size=11, window_sigma=1.5, data_range=1., channel=3, use_padding=False, avg=True):
118 |         '''
119 |         Structural Similarity Index
120 |         :param window_size: the size of gauss kernel
121 |         :param window_sigma: sigma of normal distribution
122 |         :param data_range: value range of input images. (usually 1.0 or 255)
123 |         :param channel: input channels (default: 3)
124 |         :param use_padding: padding image before conv
125 |         :param avg: average between the batch
126 |         '''
127 |         super().__init__()
128 |         assert window_size % 2 == 1, 'Window size must be odd.'
129 |         window = create_window(window_size, window_sigma, channel)
130 |         self.register_buffer('window', window)
131 |         self.data_range = data_range
132 |         self.use_padding = use_padding
133 |         self.avg = avg
134 | 
135 |     @torch.jit.script_method
136 |     def forward(self, X, Y):
137 |         r = ssim(X, Y, window=self.window, data_range=self.data_range, use_padding=self.use_padding)[0]
138 |         if self.avg:
139 |             return r.mean()
140 |         else:
141 |             return r
142 | 
143 | 
144 | class MS_SSIM(torch.jit.ScriptModule):
145 |     __constants__ = ['data_range', 'use_padding', 'avg']
146 | 
147 |     def __init__(self, window_size=11, window_sigma=1.5, data_range=1., channel=3, use_padding=False, weights=None,
148 |                  levels=None, avg=True):
149 |         '''
150 |         Multi-Scale Structural Similarity Index
151 |         :param window_size: the size of gauss kernel
152 |         :param window_sigma: sigma of normal distribution
153 |         :param data_range: value range of input images. (usually 1.0 or 255)
154 |         :param channel: input channels
155 |         :param use_padding: padding image before conv
156 |         :param weights: weights for different levels. (default [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
157 |         :param levels: number of downsampling
158 |         :param avg: average between the batch
159 |         '''
160 |         super().__init__()
161 |         assert window_size % 2 == 1, 'Window size must be odd.'
162 |         self.data_range = data_range
163 |         self.use_padding = use_padding
164 |         self.avg = avg
165 | 
166 |         window = create_window(window_size, window_sigma, channel)
167 |         self.register_buffer('window', window)
168 | 
169 |         if weights is None:
170 |             weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
171 |         weights = torch.tensor(weights, dtype=torch.float)
172 | 
173 |         if levels is not None:
174 |             weights = weights[:levels]
175 |             weights = weights / weights.sum()
176 | 
177 |         self.register_buffer('weights', weights)
178 | 
179 |     @torch.jit.script_method
180 |     def forward(self, X, Y):
181 |         r = ms_ssim(X, Y, window=self.window, data_range=self.data_range, weights=self.weights,
182 |                     use_padding=self.use_padding)
183 |         if self.avg:
184 |             return r.mean()
185 |         else:
186 |             return r
187 | 
-------------------------------------------------------------------------------- /model/model_LearnEdgeNet.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from .layer_utils.resnet import NonLocalResnetBackbone
5 | from base.base_model import BaseModel
6 | 
7 | 
8 | class DefaultModel(BaseModel):
9 |     """
10 |     Define the network to learn the fold_number_edge
11 |     x1:input_nc x2:input_nc --> input_nc
12 |     output: scores (we should use sigmoid() and round() to get the fold_number_edge in the trainer)
13 |     """
14 | 
15 |     def __init__(self, input_nc, init_dim=64, n_downsampling=2, n_blocks=6, norm_type='instance', use_dropout=False,
16 |                  mode='residual'):
17 |         super(DefaultModel, self).__init__()
18 | 
19 |         self.mode = mode
20 | 
21 |         if self.mode == 'residual':
22 |             self.backbone = nn.Sequential(
23 |                 NonLocalResnetBackbone(input_nc * 2, output_nc=init_dim, n_downsampling=n_downsampling,
24 |                                        n_blocks=n_blocks,
25 |                                        norm_type=norm_type, use_dropout=use_dropout),
26 |                 nn.Conv2d(init_dim, input_nc, kernel_size=7, stride=1, padding=3),
27 |                 nn.Tanh()
28 |             )
29 |         elif self.mode == 'end2end':
30 |             self.backbone = nn.Sequential(
31 |                 NonLocalResnetBackbone(input_nc, output_nc=init_dim, n_downsampling=n_downsampling,
32 |                                        n_blocks=n_blocks,
33 |                                        norm_type=norm_type, use_dropout=use_dropout),
34 |                 nn.Conv2d(init_dim, input_nc, kernel_size=7, stride=1, padding=3),
35 |                 nn.Tanh()
36 |             )
37 |         elif self.mode == 'end2end_without_tanh':
38 |             self.backbone = nn.Sequential(
39 |                 NonLocalResnetBackbone(input_nc, output_nc=init_dim, n_downsampling=n_downsampling,
40 |                                        n_blocks=n_blocks,
41 |                                        norm_type=norm_type, use_dropout=use_dropout),
42 |                 nn.Conv2d(init_dim, input_nc, kernel_size=7, stride=1, padding=3),
43 |             )
44 |         else:
45 |             raise NotImplementedError('mode [%s] is not found' % self.mode)
46 | 
47 |     def forward(self, x1, x2):
48 |         # x1: (N, input_nc, H, W) input modulo img([0, 1] float, as float32)
49 |         # x2: (N, input_nc, H, W) laplacian of input modulo img([0, 1] float, as float32)
50 |         if self.mode == 'residual':
51 |             out = self.backbone(torch.cat([x1, x2], dim=1)) + x2
52 |         else:
53 |             out = self.backbone(x1)
54 |         return out
55 | 
-------------------------------------------------------------------------------- /model/model_LearnMaskNet.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from .layer_utils.unet import AttentionUnetBackbone
5 | from .layer_utils.se_block import SEBlock
6 | from .layer_utils.region_non_local_block import RegionNonLocalEnhancedDenseBlock
7 | from base.base_model import BaseModel
8 | 
9 | 
10 | class DefaultModel(BaseModel):
11 |     """
12 |     Define the 
network to learn the binary mask
13 |     x1:input_nc x2:input_nc --> input_nc
14 |     output: scores (we should use sigmoid() and round() to get the binary mask in the trainer)
15 |     """
16 | 
17 |     def __init__(self, input_nc, init_dim=64, n_downsampling=4, use_conv_to_downsample=True, norm_type='instance',
18 |                  use_dropout=False, mode='res-bottleneck'):
19 |         super(DefaultModel, self).__init__()
20 | 
21 |         self.modulo_feature_extraction = nn.Sequential(
22 |             nn.Conv2d(input_nc, init_dim // 2, kernel_size=7, stride=1, padding=3, bias=True),
23 |             nn.InstanceNorm2d(init_dim // 2),
24 |             nn.ReLU(True)
25 |         )
26 |         self.edge_feature_extraction = nn.Sequential(
27 |             nn.Conv2d(input_nc, init_dim // 2, kernel_size=7, stride=1, padding=3, bias=True),
28 |             nn.InstanceNorm2d(init_dim // 2),
29 |             nn.ReLU(True),
30 |             RegionNonLocalEnhancedDenseBlock(in_channel=init_dim // 2, inter_channel=init_dim // 4, n_blocks=3,
31 |                                              latent_dim=2, subsample=True, grid=(8, 8))
32 |         )
33 |         self.fusion = nn.Sequential(
34 |             nn.Conv2d(init_dim, init_dim, kernel_size=1, stride=1),
35 |             SEBlock(init_dim, 8)
36 |         )
37 |         self.backbone = AttentionUnetBackbone(init_dim, output_nc=init_dim, n_downsampling=n_downsampling,
38 |                                               use_conv_to_downsample=use_conv_to_downsample, norm_type=norm_type,
39 |                                               use_dropout=use_dropout, mode=mode)
40 |         self.out_block = nn.Sequential(
41 |             nn.Conv2d(init_dim, input_nc, kernel_size=1, stride=1)
42 |         )
43 | 
44 |     def forward(self, x1, x2):
45 |         # x1: (N, input_nc, H, W) input modulo img([0, 1] float, as float32)
46 |         # x2: (N, input_nc, H, W) fold number edge of input modulo img(binary, as float32)
47 |         modulo_feature = self.modulo_feature_extraction(x1)
48 |         edge_feature = self.edge_feature_extraction(x2)
49 |         fusion_out = self.fusion(torch.cat([modulo_feature, edge_feature], dim=1))
50 |         backbone_out = self.backbone(fusion_out)
51 |         out = self.out_block(backbone_out)
52 |         return out
53 | 
-------------------------------------------------------------------------------- /poster.pdf: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/fourson/UnModNet/70b72c595382556b96b66333ea3f82079c48dfa6/poster.pdf
-------------------------------------------------------------------------------- /scripts/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/fourson/UnModNet/70b72c595382556b96b66333ea3f82079c48dfa6/scripts/__init__.py
-------------------------------------------------------------------------------- /scripts/make_dataset.py: --------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import argparse
4 | import numpy as np
5 | from tqdm import tqdm
6 | 
7 | 
8 | class RandomExposer:
9 |     def __init__(self, hdr, overexposure_rate_ubound=0.35, overexposure_rate_lbound=0.05, iter_max=15):
10 |         self.hdr = (hdr - np.min(hdr)) / (np.max(hdr) - np.min(hdr))  # (H, W, C) image(scaled to [0, 1])
11 | 
12 |         # first, we assume that the pixel values (scaled to [0, 1]) are uniformly distributed,
13 |         # so the pixel tolerance v equals the not-overexposed rate
14 |         # pixel tolerance definition: a pixel whose value exceeds the tolerance is considered overexposed
15 | 
16 |         # we want the overexposure rate to lie in [overexposure_rate_lbound, overexposure_rate_ubound]
17 |         self.overexposure_rate_ubound = overexposure_rate_ubound
18 |         self.overexposure_rate_lbound = overexposure_rate_lbound
19 | 
20 |         # so the not-overexposed rate lies in [1-overexposure_rate_ubound, 1-overexposure_rate_lbound], and so does v
21 |         v0_ubound = 1 - overexposure_rate_lbound
22 |         v0_lbound = 1 - overexposure_rate_ubound
23 | 
24 |         # we set the initial value of v randomly within [v0_lbound, v0_ubound]
25 |         self.v = np.random.random() * (v0_ubound - v0_lbound) + v0_lbound
26 | 
27 |         # set the maximum number of iterations
28 |         self.iter_max = iter_max
29 | 
30 |         self.success = True
31 | 
32 |         # for log
33 |         self.iter_cnt = 0
34 |         self.final_overexposure_rate = 0
35 |         self.exposure = 0
36 |         self.log = "iter_cnt exceeds the limit({})\n".format(self.iter_max)
37 | 
38 |     @property
39 |     def overexposure_rate(self):
40 |         # calculate the overexposure_rate
41 |         # a pixel counts as "not overexposed" only if none of its channels is overexposed
42 |         not_overexposure_mask = np.prod((self.hdr < self.v), axis=2)
43 |         return 1 - np.sum(not_overexposure_mask) / not_overexposure_mask.size
44 | 
45 |     def _find_v(self):
46 |         # use binary search to find a reasonable v
47 |         overexposure_rate = self.overexposure_rate
48 |         if overexposure_rate > self.overexposure_rate_ubound:
49 |             v_lb, v_ub = self.v, 1
50 |         elif overexposure_rate < self.overexposure_rate_lbound:
51 |             v_lb, v_ub = 0, self.v
52 |         else:
53 |             v_lb, v_ub = 0, 0
54 | 
55 |         iter_cnt = 0
56 |         while v_lb < v_ub:
57 |             if iter_cnt > self.iter_max:
58 |                 # exceeds the iteration limit
59 |                 self.success = False
60 |                 break
61 |             iter_cnt += 1
62 |             self.v = (v_lb + v_ub) / 2
63 |             overexposure_rate = self.overexposure_rate
64 |             if overexposure_rate > self.overexposure_rate_ubound:
65 |                 v_lb = self.v
66 |             elif overexposure_rate < self.overexposure_rate_lbound:
67 |                 v_ub = self.v
68 |             else:
69 |                 break
70 | 
71 |         self.iter_cnt = iter_cnt
72 |         self.final_overexposure_rate = overexposure_rate
73 | 
74 |     def expose(self, modulo_pixel_max, ref_pixel_max):
75 |         # simulate the sensor exposure and the ADC (we assume that the CRF is linear)
76 |         self._find_v()
77 |         if self.success:
78 |             # from the equation: v = (modulo_pixel_max + 1) / (ref_pixel_max * exposure)
79 |             self.exposure = (modulo_pixel_max + 1) / (ref_pixel_max * self.v)
80 |             self.hdr = np.clip(self.hdr * self.exposure, 0, 1)  # sensor exposure
81 |             self.hdr = np.floor(self.hdr * ref_pixel_max)  # ADC
82 |             self.log = "iter_cnt: {}\nfinal_overexposure_rate: {}\nexposure: {}\n".format(self.iter_cnt,
83 |                                                                                           self.final_overexposure_rate,
84 |                                                                                           self.exposure)
85 | 
86 | 
87 | class DataItem:
88 |     """
89 |     (H, W, C)
90 |     modulo: positive int, as float32
91 |     fold_number: positive int, as float32
92 |     mask: binary, as float32
93 |     ref: positive int, as float32
94 |     ldr: [0, modulo_pixel_max] int, as float32
95 |     """
96 | 
97 |     def __init__(self, origin, modulo_pixel_max):
98 |         self.origin = origin
99 |         self.modulo_pixel_max = modulo_pixel_max
100 | 
101 |         self.modulo = self.origin % (self.modulo_pixel_max + 1)  # modulo image
102 |         self.fold_number = self.origin // (self.modulo_pixel_max + 1)  # fold number map
103 |         self.mask = np.float32((self.fold_number > 0))  # binary mask for the pixels which are overexposed
104 |         self.ref = self.modulo + (self.modulo_pixel_max + 1) * self.mask  # reference image(modulo "plus 1 fold")
105 |         self.ldr = (self.modulo_pixel_max + 1) * self.mask + self.origin * (1 - self.mask)  # ldr image
106 | 
107 |         self.fold_number_max = int(np.max(self.fold_number))
108 | 
109 |     def add_one_fold(self):
110 |         self.modulo = self.ref
111 |         self.fold_number = self.fold_number - self.mask
112 |         self.mask = np.float32((self.fold_number > 0))
113 |         self.ref = self.modulo + 
(self.modulo_pixel_max + 1) * self.mask 114 | 115 | 116 | class DatasetMaker: 117 | def __init__(self, data_dir, train_dir, test_dir, training_sample=400, n_cut=5, crop_size=256, modulo_bits=8, 118 | ref_bits=12, multi_fold_for_training=False, **random_exposer_args): 119 | self.data_dir = data_dir # contains the original 512*512 hdr images 120 | self.train_dir = train_dir 121 | self.test_dir = test_dir 122 | self.n_cut = n_cut 123 | self.crop_size = crop_size 124 | self.modulo_pixel_max = 2 ** modulo_bits - 1 125 | self.ref_pixel_max = 2 ** ref_bits - 1 126 | self.multi_fold_for_training = multi_fold_for_training 127 | 128 | self.random_exposer_args = random_exposer_args 129 | 130 | self.data_list = os.listdir(data_dir) 131 | random.shuffle(self.data_list) 132 | self.train_data_list = self.data_list[:training_sample] 133 | self.test_data_list = self.data_list[training_sample:] 134 | 135 | self.train_subdir = {"origin": "", "modulo": "", "fold_number": "", "mask": "", "ref": "", "ldr": ""} 136 | self.test_subdir = {"origin": "", "modulo": "", "fold_number": "", "ldr": ""} 137 | 138 | self.save_cnt = 0 139 | 140 | def _ensure_dir(self, mode): 141 | assert mode in ("train", "test") 142 | mode_dir = getattr(self, mode + "_dir") 143 | mode_subdir = getattr(self, mode + "_subdir") 144 | for key in mode_subdir: 145 | mode_subdir[key] = os.path.join(mode_dir, key) 146 | if not os.path.exists(mode_subdir[key]): 147 | os.makedirs(mode_subdir[key]) 148 | 149 | def _cut(self, img, name): 150 | imgs = dict() 151 | for i in range(self.n_cut): 152 | h_offset = random.randint(0, 512 - self.crop_size) 153 | w_offset = random.randint(0, 512 - self.crop_size) 154 | imgs[name + "_" + str(i)] = img[h_offset:h_offset + self.crop_size, w_offset:w_offset + self.crop_size, :] 155 | return imgs 156 | 157 | def _save(self, data_item, name, mode): 158 | assert mode in ("train", "test") 159 | mode_subdir = getattr(self, mode + "_subdir") 160 | for key in mode_subdir: 161 | np.save(os.path.join(mode_subdir[key], name + ".npy"), getattr(data_item, key)) 162 | self.save_cnt += 1 163 | 164 | def make(self, mode): 165 | assert mode in ("train", "test") 166 | print("mode: {}\n".format(mode)) 167 | mode_data_list = getattr(self, mode + "_data_list") 168 | if not mode_data_list: 169 | return 170 | log = "" 171 | self._ensure_dir(mode) 172 | for hdr_file in tqdm(mode_data_list, ascii=True): 173 | hdr_name = hdr_file.split(".")[0] 174 | hdr_img = np.load(os.path.join(self.data_dir, hdr_file)) 175 | for name, img in self._cut(hdr_img, hdr_name).items(): 176 | log += "--------------------\nname: {}\n".format(name) 177 | random_exposer = RandomExposer(img, **self.random_exposer_args) 178 | random_exposer.expose(self.modulo_pixel_max, self.ref_pixel_max) 179 | log += random_exposer.log 180 | if random_exposer.success: 181 | data_item = DataItem(random_exposer.hdr, self.modulo_pixel_max) 182 | log += "fold_number_max: {}\n".format(data_item.fold_number_max) 183 | self._save(data_item, name, mode) 184 | if mode == "train" and self.multi_fold_for_training: 185 | for i in range(1, data_item.fold_number_max): 186 | data_item.add_one_fold() 187 | name_new = name + "_plus" + str(i) 188 | self._save(data_item, name_new, mode) 189 | log += "\n----------Summary----------\n" 190 | log += "This is the dataset for *{}*\n".format((mode + "ing").upper()) 191 | log += "{} images in total".format(self.save_cnt) 192 | self.save_cnt = 0 193 | with open(mode + "_dataset.log", "w") as f: 194 | f.write(log) 195 | 196 | 197 | if __name__ == "__main__": 
198 |     parser = argparse.ArgumentParser(description="Make dataset: from RGB(H, W, C) or grayscale(H, W, 1)")
199 |     parser.add_argument("--data_dir", required=True, type=str, help="dir of original 512*512 hdr images")
200 |     parser.add_argument("--train_dir", default="train", type=str, help="dir to save training dataset")
201 |     parser.add_argument("--test_dir", default="test", type=str, help="dir to save test dataset")
202 |     parser.add_argument("--training_sample", default=400, type=int,
203 |                         help="number of training samples")  # the rest will be used as test samples
204 |     parser.add_argument("--n_cut", default=5, type=int, help="number of random crops per image")
205 |     parser.add_argument("--modulo_bits", default=8, type=int, help="modulo image bits")
206 |     parser.add_argument("--ref_bits", default=12, type=int, help="reference image(ground truth) bits")
207 |     parser.add_argument("--multi_fold_for_training", default=1, type=int,
208 |                         help="make multi-fold data for training (set this to 0 when directly learning the label in the ablation)")
209 |     parser.add_argument("--overexposure_rate_ubound", default=0.35, type=float, help="overexposure rate upper bound")
210 |     parser.add_argument("--overexposure_rate_lbound", default=0.05, type=float, help="overexposure rate lower bound")
211 |     parser.add_argument("--iter_max", default=15, type=int, help="the iteration limit of the random exposer")
212 |     args = parser.parse_args()
213 | 
214 |     dataset_maker = DatasetMaker(**vars(args))
215 |     dataset_maker.make("train")
216 |     # for test, we don't need to crop, just keep the original 512*512 resolution
217 |     dataset_maker.crop_size = 512
218 |     dataset_maker.n_cut = 1
219 |     dataset_maker.make("test")
220 | 
-------------------------------------------------------------------------------- /scripts/make_dataset_original_resolution.py: --------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import argparse
4 | import numpy as np
5 | from tqdm import tqdm
6 | 
7 | 
8 | class RandomExposer:
9 |     def __init__(self, hdr, overexposure_rate_ubound=0.35, overexposure_rate_lbound=0.05, iter_max=15):
10 |         self.hdr = (hdr - np.min(hdr)) / (np.max(hdr) - np.min(hdr))  # (H, W, C) image(scaled to [0, 1])
11 | 
12 |         # first, we assume that the pixel values (scaled to [0, 1]) are uniformly distributed,
13 |         # so the pixel tolerance v equals the not-overexposed rate
14 |         # pixel tolerance definition: a pixel whose value exceeds the tolerance is considered overexposed
15 | 
16 |         # we want the overexposure rate to lie in [overexposure_rate_lbound, overexposure_rate_ubound]
17 |         self.overexposure_rate_ubound = overexposure_rate_ubound
18 |         self.overexposure_rate_lbound = overexposure_rate_lbound
19 | 
20 |         # so the not-overexposed rate lies in [1-overexposure_rate_ubound, 1-overexposure_rate_lbound], and so does v
21 |         v0_ubound = 1 - overexposure_rate_lbound
22 |         v0_lbound = 1 - overexposure_rate_ubound
23 | 
24 |         # we set the initial value of v randomly within [v0_lbound, v0_ubound]
25 |         self.v = np.random.random() * (v0_ubound - v0_lbound) + v0_lbound
26 | 
27 |         # set the maximum number of iterations
28 |         self.iter_max = iter_max
29 | 
30 |         self.success = True
31 | 
32 |         # for log
33 |         self.iter_cnt = 0
34 |         self.final_overexposure_rate = 0
35 |         self.exposure = 0
36 |         self.log = "iter_cnt exceeds the limit({})\n".format(self.iter_max)
37 | 
38 |     @property
39 |     def overexposure_rate(self):
40 |         # calculate the overexposure_rate
41 |         # a pixel counts as "not overexposed" only if none of its channels is overexposed
42 |         not_overexposure_mask = np.prod((self.hdr < self.v), axis=2)
43 |         return 1 - np.sum(not_overexposure_mask) / not_overexposure_mask.size
44 | 
45 |     def _find_v(self):
46 |         # use binary search to find a reasonable v
47 |         overexposure_rate = self.overexposure_rate
48 |         if overexposure_rate > self.overexposure_rate_ubound:
49 |             v_lb, v_ub = self.v, 1
50 |         elif overexposure_rate < self.overexposure_rate_lbound:
51 |             v_lb, v_ub = 0, self.v
52 |         else:
53 |             v_lb, v_ub = 0, 0
54 | 
55 |         iter_cnt = 0
56 |         while v_lb < v_ub:
57 |             if iter_cnt > self.iter_max:
58 |                 # exceeds the iteration limit
59 |                 self.success = False
60 |                 break
61 |             iter_cnt += 1
62 |             self.v = (v_lb + v_ub) / 2
63 |             overexposure_rate = self.overexposure_rate
64 |             if overexposure_rate > self.overexposure_rate_ubound:
65 |                 v_lb = self.v
66 |             elif overexposure_rate < self.overexposure_rate_lbound:
67 |                 v_ub = self.v
68 |             else:
69 |                 break
70 | 
71 |         self.iter_cnt = iter_cnt
72 |         self.final_overexposure_rate = overexposure_rate
73 | 
74 |     def expose(self, modulo_pixel_max, ref_pixel_max):
75 |         # simulate the sensor exposure and the ADC (we assume that the CRF is linear)
76 |         self._find_v()
77 |         if self.success:
78 |             # from the equation: v = (modulo_pixel_max + 1) / (ref_pixel_max * exposure)
79 |             self.exposure = (modulo_pixel_max + 1) / (ref_pixel_max * self.v)
80 |             self.hdr = np.clip(self.hdr * self.exposure, 0, 1)  # sensor exposure
81 |             self.hdr = np.floor(self.hdr * ref_pixel_max)  # ADC
82 |             self.log = "iter_cnt: {}\nfinal_overexposure_rate: {}\nexposure: {}\n".format(self.iter_cnt,
83 |                                                                                           self.final_overexposure_rate,
84 |                                                                                           self.exposure)
85 | 
86 | 
87 | class DataItem:
88 |     """
89 |     (H, W, C)
90 |     modulo: positive int, as float32
91 |     fold_number: positive int, as float32
92 |     mask: binary, as float32
93 |     ref: positive int, as float32
94 |     ldr: [0, modulo_pixel_max] int, as float32
95 |     """
96 | 
97 |     def __init__(self, origin, modulo_pixel_max):
98 |         self.origin = origin
99 |         self.modulo_pixel_max = modulo_pixel_max
100 | 
101 |         self.modulo = self.origin % (self.modulo_pixel_max + 1)  # modulo image
102 |         self.fold_number = self.origin // (self.modulo_pixel_max + 1)  # fold number map
103 |         self.mask = np.float32((self.fold_number > 0))  # binary mask for the pixels which are overexposed
104 |         self.ref = self.modulo + (self.modulo_pixel_max + 1) * self.mask  # reference image(modulo "plus 1 fold")
105 |         self.ldr = (self.modulo_pixel_max + 1) * self.mask + self.origin * (1 - self.mask)  # ldr image
106 | 
107 |         self.fold_number_max = int(np.max(self.fold_number))
108 | 
109 |     def add_one_fold(self):
110 |         self.modulo = self.ref
111 |         self.fold_number = self.fold_number - self.mask
112 |         self.mask = np.float32((self.fold_number > 0))
113 |         self.ref = self.modulo + (self.modulo_pixel_max + 1) * self.mask
114 | 
115 | 
116 | class DatasetMaker:
117 |     def __init__(self, data_dir, train_dir, test_dir, training_sample=400, modulo_bits=8, ref_bits=12,
118 |                  multi_fold_for_training=False, **random_exposer_args):
119 |         self.data_dir = data_dir  # contains the original hdr images
120 |         self.train_dir = train_dir
121 |         self.test_dir = test_dir
122 |         self.modulo_pixel_max = 2 ** modulo_bits - 1
123 |         self.ref_pixel_max = 2 ** ref_bits - 1
124 |         self.multi_fold_for_training = multi_fold_for_training
125 | 
126 |         self.random_exposer_args = random_exposer_args
127 | 
128 |         self.data_list = os.listdir(data_dir)
129 |         random.shuffle(self.data_list)
130 |         self.train_data_list = self.data_list[:training_sample]
131 |         self.test_data_list = 
self.data_list[training_sample:]
132 | 
133 |         self.train_subdir = {"origin": "", "modulo": "", "fold_number": "", "mask": "", "ref": "", "ldr": ""}
134 |         self.test_subdir = {"origin": "", "modulo": "", "fold_number": "", "ldr": ""}
135 | 
136 |         self.save_cnt = 0
137 | 
138 |     def _ensure_dir(self, mode):
139 |         assert mode in ("train", "test")
140 |         mode_dir = getattr(self, mode + "_dir")
141 |         mode_subdir = getattr(self, mode + "_subdir")
142 |         for key in mode_subdir:
143 |             mode_subdir[key] = os.path.join(mode_dir, key)
144 |             if not os.path.exists(mode_subdir[key]):
145 |                 os.makedirs(mode_subdir[key])
146 | 
147 |     def _save(self, data_item, name, mode):
148 |         assert mode in ("train", "test")
149 |         mode_subdir = getattr(self, mode + "_subdir")
150 |         for key in mode_subdir:
151 |             np.save(os.path.join(mode_subdir[key], name + ".npy"), getattr(data_item, key))
152 |         self.save_cnt += 1
153 | 
154 |     def make(self, mode):
155 |         assert mode in ("train", "test")
156 |         print("mode: {}\n".format(mode))
157 |         mode_data_list = getattr(self, mode + "_data_list")
158 |         if not mode_data_list:
159 |             return
160 |         log = ""
161 |         self._ensure_dir(mode)
162 |         for hdr_file in tqdm(mode_data_list, ascii=True):
163 |             hdr_name = hdr_file.split(".")[0]
164 |             hdr_img = np.load(os.path.join(self.data_dir, hdr_file))
165 |             log += "--------------------\nname: {}\n".format(hdr_name)
166 |             random_exposer = RandomExposer(hdr_img, **self.random_exposer_args)
167 |             random_exposer.expose(self.modulo_pixel_max, self.ref_pixel_max)
168 |             log += random_exposer.log
169 |             if random_exposer.success:
170 |                 data_item = DataItem(random_exposer.hdr, self.modulo_pixel_max)
171 |                 log += "fold_number_max: {}\n".format(data_item.fold_number_max)
172 |                 self._save(data_item, hdr_name, mode)
173 |                 if mode == "train" and self.multi_fold_for_training:
174 |                     for i in range(1, data_item.fold_number_max):
175 |                         data_item.add_one_fold()
176 |                         name_new = hdr_name + "_plus" + str(i)
177 |                         self._save(data_item, name_new, mode)
178 | 
179 |         log += "\n----------Summary----------\n"
180 |         log += "This is the dataset for *{}*\n".format((mode + "ing").upper())
181 |         log += "{} images in total".format(self.save_cnt)
182 |         self.save_cnt = 0
183 |         with open(mode + "_dataset.log", "w") as f:
184 |             f.write(log)
185 | 
186 | 
187 | if __name__ == "__main__":
188 |     parser = argparse.ArgumentParser(description="Make dataset")
189 |     parser.add_argument("--data_dir", required=True, type=str, help="dir of original hdr images")
190 |     parser.add_argument("--train_dir", default="train", type=str, help="dir to save training dataset")
191 |     parser.add_argument("--test_dir", default="test", type=str, help="dir to save test dataset")
192 |     parser.add_argument("--training_sample", default=400, type=int, help="number of training samples")
193 |     parser.add_argument("--modulo_bits", default=8, type=int, help="modulo image bits")
194 |     parser.add_argument("--ref_bits", default=12, type=int, help="reference image(ground truth) bits")
195 |     parser.add_argument("--multi_fold_for_training", default=1, type=int,
196 |                         help="make multi-fold data for training (set this to 0 when directly learning the label in the ablation)")
197 |     parser.add_argument("--overexposure_rate_ubound", default=0.35, type=float, help="overexposure rate upper bound")
198 |     parser.add_argument("--overexposure_rate_lbound", default=0.05, type=float, help="overexposure rate lower bound")
199 |     parser.add_argument("--iter_max", default=15, type=int, help="the iteration limit of the random exposer")
200 |     args = 
parser.parse_args() 201 | 202 | dataset_maker = DatasetMaker(**vars(args)) 203 | dataset_maker.make("train") 204 | # for test 205 | dataset_maker.make("test") 206 | -------------------------------------------------------------------------------- /scripts/make_edge_map.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import cv2 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | 9 | def main(data_dir): 10 | modulo_dir = os.path.join(data_dir, 'modulo') 11 | fold_number_dir = os.path.join(data_dir, 'fold_number') 12 | modulo_edge_dir = os.path.join(data_dir, 'modulo_edge_dir') 13 | fold_number_edge_dir = os.path.join(data_dir, 'fold_number_edge') 14 | 15 | if not os.path.exists(modulo_edge_dir): 16 | os.mkdir(modulo_edge_dir) 17 | if not os.path.exists(fold_number_edge_dir): 18 | os.mkdir(fold_number_edge_dir) 19 | 20 | for name in tqdm(os.listdir(modulo_dir), ascii=True): 21 | # input 22 | modulo = np.load(os.path.join(modulo_dir, name)) # positive int, as float32 23 | fold_number = np.load(os.path.join(fold_number_dir, name)) # positive int, as float32 24 | 25 | laplacian_modulo = np.abs(cv2.Laplacian(modulo, -1)) 26 | laplacian_fold_number = np.abs(cv2.Laplacian(fold_number, -1)) 27 | 28 | if modulo.shape[2] == 1: 29 | # for grayscale image (H, W, 1), laplacian will output (H, W) 30 | # we need to expand the lost dim 31 | laplacian_modulo = laplacian_modulo[:, :, np.newaxis] 32 | laplacian_fold_number = laplacian_fold_number[:, :, np.newaxis] 33 | 34 | # to save 35 | modulo_edge = laplacian_modulo / np.max(laplacian_modulo) # [0, 1] float, as float32 36 | fold_number_edge = np.float32(laplacian_fold_number > 0) # binary, as float32 37 | np.save(os.path.join(modulo_edge_dir, name), modulo_edge) 38 | np.save(os.path.join(fold_number_edge_dir, name), fold_number_edge) 39 | 40 | 41 | if __name__ == "__main__": 42 | parser = argparse.ArgumentParser(description="Make edge map: from RGB(H, W, C) or grayscale(H, W, 1)") 43 | parser.add_argument("--data_dir", default="data", type=str, help="dir of modulo image") 44 | args = parser.parse_args() 45 | main(**vars(args)) 46 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fourson/UnModNet/70b72c595382556b96b66333ea3f82079c48dfa6/trainer/__init__.py -------------------------------------------------------------------------------- /trainer/trainer_LearnEdgeNet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision.utils import make_grid 4 | 5 | from base.base_trainer import BaseTrainer 6 | from utils import util 7 | 8 | 9 | class DefaultTrainer(BaseTrainer): 10 | """ 11 | Trainer class 12 | 13 | Note: 14 | Inherited from BaseTrainer. 
15 | """ 16 | 17 | def __init__(self, config, model, loss, metrics, optimizer, lr_scheduler, resume, data_loader, 18 | valid_data_loader=None, train_logger=None): 19 | super(DefaultTrainer, self).__init__(config, model, loss, metrics, optimizer, lr_scheduler, resume, 20 | train_logger) 21 | 22 | self.data_loader = data_loader 23 | self.valid_data_loader = valid_data_loader 24 | self.do_validation = self.valid_data_loader is not None 25 | self.log_step = int(np.sqrt(data_loader.batch_size)) 26 | 27 | def _eval_metrics(self, mask_pred, target): 28 | acc_metrics = np.zeros(len(self.metrics)) 29 | for i, metric in enumerate(self.metrics): 30 | acc_metrics[i] += metric(mask_pred, target) 31 | self.writer.add_scalar('{}'.format(metric.__name__), acc_metrics[i]) 32 | return acc_metrics 33 | 34 | def _train_epoch(self, epoch): 35 | """ 36 | Training logic for an epoch 37 | 38 | :param epoch: Current training epoch. 39 | :return: A log that contains all information you want to save. 40 | 41 | Note: 42 | If you have additional information to record, for example: 43 | > additional_log = {"x": x, "y": y} 44 | merge it with log before return. i.e. 45 | > log = {**log, **additional_log} 46 | > return log 47 | 48 | The metrics in log must have the key 'metrics'. 49 | """ 50 | # set the model to train mode 51 | self.model.train() 52 | 53 | total_loss = 0 54 | total_metrics = np.zeros(len(self.metrics)) 55 | 56 | # start training 57 | for batch_idx, sample in enumerate(self.data_loader): 58 | self.writer.set_step((epoch - 1) * len(self.data_loader) + batch_idx) 59 | 60 | # get data and send them to GPU 61 | # (N, C, H, W) GPU tensor 62 | modulo = sample['modulo'].to(self.device) # [0, 1] float, as float32 63 | modulo_edge = sample['modulo_edge'].to(self.device) # [0, 1] float, as float32 64 | fold_number_edge = sample['fold_number_edge'].to(self.device) # binary, as float32 65 | 66 | # get network output 67 | # (N, C, H, W) GPU tensor 68 | output = self.model(modulo, modulo_edge) 69 | 70 | # visualization 71 | with torch.no_grad(): 72 | # (N, C, H, W) GPU tensor 73 | fold_number_edge_pred = torch.round(torch.sigmoid(output)) 74 | 75 | if batch_idx % 100 == 0: 76 | modulo_tonemapped = util.tonemap(modulo.cpu()) 77 | 78 | # save images to tensorboardX 79 | self.writer.add_image('1_modulo', make_grid(modulo_tonemapped)) 80 | self.writer.add_image('2_modulo_edge', make_grid(modulo_edge)) 81 | self.writer.add_image('3_fold_number_edge_pred', make_grid(fold_number_edge_pred)) 82 | self.writer.add_image('4_fold_number_edge', make_grid(fold_number_edge)) 83 | 84 | # train model 85 | self.optimizer.zero_grad() 86 | model_loss = self.loss(output, fold_number_edge) 87 | model_loss.backward() 88 | self.optimizer.step() 89 | 90 | # calculate total loss/metrics and add scalar to tensorboard 91 | self.writer.add_scalar('loss', model_loss.item()) 92 | total_loss += model_loss.item() 93 | total_metrics += self._eval_metrics(fold_number_edge_pred, fold_number_edge) 94 | 95 | # show current training step info 96 | if self.verbosity >= 2 and batch_idx % self.log_step == 0: 97 | self.logger.info( 98 | 'Train Epoch: {} [{}/{} ({:.0f}%)] loss: {:.6f}'.format( 99 | epoch, 100 | batch_idx * self.data_loader.batch_size, 101 | self.data_loader.n_samples, 102 | 100.0 * batch_idx / len(self.data_loader), 103 | model_loss.item(), # it's a tensor, so we call .item() method 104 | ) 105 | ) 106 | 107 | # turn the learning rate 108 | self.lr_scheduler.step() 109 | 110 | # get batch average loss/metrics as log and do validation 111 | log = 
            'loss': total_loss / len(self.data_loader),
            'metrics': (total_metrics / len(self.data_loader)).tolist()
        }
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log = {**log, **val_log}

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :return: A log that contains information about validation

        Note:
            The validation metrics in log must have the key 'val_metrics'.
        """
        # set the model to validation mode
        self.model.eval()

        total_val_loss = 0
        total_val_metrics = np.zeros(len(self.metrics))

        # start validating
        with torch.no_grad():
            for batch_idx, sample in enumerate(self.valid_data_loader):
                self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')

                # get data and send them to GPU
                # (N, C, H, W) GPU tensor
                modulo = sample['modulo'].to(self.device)  # [0, 1] float, as float32
                modulo_edge = sample['modulo_edge'].to(self.device)  # [0, 1] float, as float32
                fold_number_edge = sample['fold_number_edge'].to(self.device)  # binary, as float32

                # infer and calculate the loss
                # (N, C, H, W) GPU tensor
                output = self.model(modulo, modulo_edge)
                fold_number_edge_pred = torch.round(torch.sigmoid(output))
                loss = self.loss(output, fold_number_edge)

                # calculate total loss/metrics and add scalar to tensorboardX
                self.writer.add_scalar('loss', loss.item())
                total_val_loss += loss.item()
                total_val_metrics += self._eval_metrics(fold_number_edge_pred, fold_number_edge)

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')

        return {
            'val_loss': total_val_loss / len(self.valid_data_loader),
            'val_metrics': (total_val_metrics / len(self.valid_data_loader)).tolist()
        }
--------------------------------------------------------------------------------
/trainer/trainer_LearnMaskNet.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from torchvision.utils import make_grid

from base.base_trainer import BaseTrainer
from utils import util


class DefaultTrainer(BaseTrainer):
    """
    Trainer class

    Note:
        Inherited from BaseTrainer.
    """

    def __init__(self, config, model, loss, metrics, optimizer, lr_scheduler, resume, data_loader,
                 valid_data_loader=None, train_logger=None):
        super(DefaultTrainer, self).__init__(config, model, loss, metrics, optimizer, lr_scheduler, resume,
                                             train_logger)

        self.data_loader = data_loader
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.log_step = int(np.sqrt(data_loader.batch_size))

    def _eval_metrics(self, mask_pred, target):
        acc_metrics = np.zeros(len(self.metrics))
        for i, metric in enumerate(self.metrics):
            acc_metrics[i] += metric(mask_pred, target)
            self.writer.add_scalar('{}'.format(metric.__name__), acc_metrics[i])
        return acc_metrics

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current training epoch.
        :return: A log that contains all information you want to save.

        Note:
            If you have additional information to record, for example:
                > additional_log = {"x": x, "y": y}
            merge it with log before returning, i.e.
                > log = {**log, **additional_log}
                > return log

            The metrics in log must have the key 'metrics'.
        """
        # set the model to train mode
        self.model.train()

        total_loss = 0
        total_metrics = np.zeros(len(self.metrics))

        # start training
        for batch_idx, sample in enumerate(self.data_loader):
            self.writer.set_step((epoch - 1) * len(self.data_loader) + batch_idx)

            # get data and send them to GPU
            # (N, C, H, W) GPU tensor
            modulo = sample['modulo'].to(self.device)  # [0, 1] float, as float32
            fold_number_edge = sample['fold_number_edge'].to(self.device)  # binary, as float32
            mask = sample['mask'].to(self.device)  # binary, as float32

            # get network output
            # (N, C, H, W) GPU tensor
            output = self.model(modulo, fold_number_edge)

            # visualization
            with torch.no_grad():
                # (N, C, H, W) GPU tensor
                mask_pred = torch.round(torch.sigmoid(output))

                if batch_idx % 100 == 0:
                    modulo_tonemapped = util.tonemap(modulo.cpu())

                    # save images to tensorboardX
                    self.writer.add_image('1_modulo', make_grid(modulo_tonemapped))
                    self.writer.add_image('2_fold_number_edge', make_grid(fold_number_edge))
                    self.writer.add_image('3_mask_pred', make_grid(mask_pred))
                    self.writer.add_image('4_mask', make_grid(mask))

            # train model
            self.optimizer.zero_grad()
            model_loss = self.loss(output, mask)
            model_loss.backward()
            self.optimizer.step()

            # calculate total loss/metrics and add scalar to tensorboard
            self.writer.add_scalar('loss', model_loss.item())
            total_loss += model_loss.item()
            total_metrics += self._eval_metrics(mask_pred, mask)

            # show current training step info
            if self.verbosity >= 2 and batch_idx % self.log_step == 0:
                self.logger.info(
                    'Train Epoch: {} [{}/{} ({:.0f}%)] loss: {:.6f}'.format(
                        epoch,
                        batch_idx * self.data_loader.batch_size,
                        self.data_loader.n_samples,
                        100.0 * batch_idx / len(self.data_loader),
                        model_loss.item(),  # model_loss is a tensor, so we call its .item() method
                    )
                )

        # adjust the learning rate
        self.lr_scheduler.step()

        # get batch average loss/metrics as log and do validation
        log = {
            'loss': total_loss / len(self.data_loader),
            'metrics': (total_metrics / len(self.data_loader)).tolist()
        }
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log = {**log, **val_log}

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :return: A log that contains information about validation

        Note:
            The validation metrics in log must have the key 'val_metrics'.
129 | """ 130 | # set the model to validation mode 131 | self.model.eval() 132 | 133 | total_val_loss = 0 134 | total_val_metrics = np.zeros(len(self.metrics)) 135 | 136 | # start validating 137 | with torch.no_grad(): 138 | for batch_idx, sample in enumerate(self.valid_data_loader): 139 | self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid') 140 | 141 | # get data and send them to GPU 142 | # (N, C, H, W) GPU tensor 143 | modulo = sample['modulo'].to(self.device) # [0, 1] float, as float32 144 | fold_number_edge = sample['fold_number_edge'].to(self.device) # binary, as float32 145 | mask = sample['mask'].to(self.device) # binary, as float32 146 | 147 | # infer and calculate the loss 148 | # (N, C, H, W) GPU tensor 149 | output = self.model(modulo, fold_number_edge) 150 | mask_pred = torch.round(torch.sigmoid(output)) 151 | loss = self.loss(output, mask) 152 | 153 | # calculate total loss/metrics and add scalar to tensorboardX 154 | self.writer.add_scalar('loss', loss.item()) 155 | total_val_loss += loss.item() 156 | total_val_metrics += self._eval_metrics(mask_pred, mask) 157 | 158 | # add histogram of model parameters to the tensorboard 159 | for name, p in self.model.named_parameters(): 160 | self.writer.add_histogram(name, p, bins='auto') 161 | 162 | return { 163 | 'val_loss': total_val_loss / len(self.valid_data_loader), 164 | 'val_metrics': (total_val_metrics / len(self.valid_data_loader)).tolist() 165 | } 166 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fourson/UnModNet/70b72c595382556b96b66333ea3f82079c48dfa6/utils/__init__.py -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | 4 | logging.basicConfig(level=logging.INFO, format='') 5 | 6 | 7 | class Logger: 8 | """ 9 | Training process logger 10 | 11 | Note: 12 | Used by BaseTrainer to save training history. 
13 | """ 14 | def __init__(self): 15 | self.entries = {} 16 | 17 | def add_entry(self, entry): 18 | self.entries[len(self.entries) + 1] = entry 19 | 20 | def __str__(self): 21 | return json.dumps(self.entries, sort_keys=True, indent=4) 22 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import cv2 6 | import numpy as np 7 | 8 | TonemapReinhard = cv2.createTonemapReinhard(intensity=-1.0, light_adapt=0.8, color_adapt=0.0) 9 | Laplacian = torch.tensor([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=torch.float32).cuda() 10 | 11 | 12 | def ensure_dir(path): 13 | if not os.path.exists(path): 14 | os.makedirs(path) 15 | 16 | 17 | def get_lr_lambda(lr_lambda_tag): 18 | if lr_lambda_tag == 'original': 19 | # 400 epoch 20 | # 1~200 ep: 1 21 | # 201~400 ep: linear decays to 0.5 22 | return lambda epoch: (600 - epoch) / 400 if epoch > 200 else 1 23 | elif lr_lambda_tag == 'grayscale': 24 | # 300 epoch 25 | # 1~200 ep: 1 26 | # 201~300 ep: linear decays to 0 27 | return lambda epoch: (300 - epoch) / 100 if epoch > 100 else 1 28 | elif lr_lambda_tag == 'temp': 29 | return lambda epoch: 1 30 | else: 31 | raise NotImplementedError('lr_lambda_tag [%s] is not found' % lr_lambda_tag) 32 | 33 | 34 | def tonemap(hdr_tensor): 35 | # tonemap hdr image tensor(N, C, H, W) for visualization 36 | tonemapped_tensor = torch.zeros(hdr_tensor.shape, dtype=torch.float32) 37 | for i in range(hdr_tensor.shape[0]): 38 | hdr = hdr_tensor[i].numpy().transpose((1, 2, 0)) # (H, W, C) 39 | is_rgb = (hdr.shape[2] == 3) 40 | if is_rgb: 41 | # if RGB (H, W, 3) , we should convert to an (H, W, 3) numpy array in order of BGR before tonemapping 42 | hdr = cv2.cvtColor(hdr, cv2.COLOR_RGB2BGR) 43 | else: 44 | # if grayscale (H ,W, 1), we should copy the image 3 times to an (H, W, 3) numpy array before tonemapping 45 | hdr = cv2.merge([hdr, hdr, hdr]) 46 | hdr = (hdr - np.min(hdr)) / (np.max(hdr) - np.min(hdr)) 47 | tonemapped = TonemapReinhard.process(hdr) 48 | if is_rgb: 49 | # back to (C, H, W) tensor in order of RGB 50 | tonemapped_tensor[i] = torch.from_numpy(cv2.cvtColor(tonemapped, cv2.COLOR_BGR2RGB).transpose((2, 0, 1))) 51 | else: 52 | tonemapped_tensor[i] = torch.from_numpy(tonemapped[:, :, 0:1].transpose((2, 0, 1))) 53 | return tonemapped_tensor 54 | 55 | 56 | def torch_laplacian(img_tensor): 57 | # (N, C, H, W) image tensor -> (N, C, H, W) edge tensor, the same as cv2.Laplacian 58 | pad = [1, 1, 1, 1] 59 | laplacian_kernel = Laplacian.view(1, 1, 3, 3) 60 | edge_tensor = torch.zeros(img_tensor.shape, dtype=torch.float32).cuda() 61 | for i in range(img_tensor.shape[1]): 62 | padded = F.pad(img_tensor[:, i:i + 1, :, :], pad, mode='reflect') 63 | edge_tensor[:, i:i + 1, :, :] = F.conv2d(padded, laplacian_kernel) 64 | return edge_tensor 65 | 66 | 67 | def torch_convertScaleAbs(img_tensor, alpha=1.0, beta=0.0): 68 | # (N, C, H, W) tensor -> (N, C, H, W) tensor, the same as cv2.convertScaleAbs(but return as float) 69 | scaled = img_tensor * alpha + beta 70 | abs = torch.abs(scaled) 71 | return torch.clamp(torch.round(abs), max=255.) 
--------------------------------------------------------------------------------
/utils/visualization.py:
--------------------------------------------------------------------------------
import importlib


class WriterTensorboardX:
    def __init__(self, writer_dir, logger, enable):
        self.writer = None
        if enable:
            log_path = writer_dir
            try:
                self.writer = importlib.import_module('tensorboardX').SummaryWriter(log_path)
            except ImportError:
                message = "Warning: TensorboardX visualization is configured to be used, but it is not " \
                          "currently installed on this machine. Please install the package via the " \
                          "'pip install tensorboardx' command or turn off the option in the 'config.json' file."
                logger.warning(message)
        self.step = 0
        self.mode = ''

        self.tb_writer_ftns = [
            'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
            'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
        ]
        self.tag_mode_exceptions = ['add_histogram', 'add_embedding']

    def set_step(self, step, mode='train'):
        self.mode = mode
        self.step = step

    def __getattr__(self, name):
        """
        If visualization is configured to be used:
            return the add_data() methods of tensorboardX with additional information (step, tag) added.
        Otherwise:
            return a blank function handle that does nothing
        """
        if name in self.tb_writer_ftns:
            add_data = getattr(self.writer, name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # prefix the tag with the mode (train/valid)
                    if name not in self.tag_mode_exceptions:
                        tag = '{}/{}'.format(self.mode, tag)
                    add_data(tag, data, self.step, *args, **kwargs)
            return wrapper
        else:
            # __getattr__ is only reached when normal attribute lookup fails,
            # so re-raise the AttributeError with a clearer message
            try:
                attr = object.__getattribute__(self, name)
            except AttributeError:
                raise AttributeError("type object 'WriterTensorboardX' has no attribute '{}'".format(name))
            return attr
--------------------------------------------------------------------------------
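
A minimal sketch of how the `WriterTensorboardX` wrapper above is typically driven from a training loop (illustrative only: the `saved/runs/demo` directory, the logger, and the scalar values are hypothetical placeholders, not part of the repository):

```
import logging

from utils.visualization import WriterTensorboardX

logger = logging.getLogger(__name__)
# enable=True requires tensorboardX; if it is missing, all add_* calls silently no-op
writer = WriterTensorboardX('saved/runs/demo', logger, enable=True)

for step in range(3):
    writer.set_step(step, mode='train')          # scalars get tagged as 'train/loss'
    writer.add_scalar('loss', 0.5 / (step + 1))  # dummy loss value, logged at `step`
writer.set_step(0, mode='valid')                 # switch the tag prefix to 'valid/'
writer.add_scalar('loss', 0.42)                  # dummy validation loss
```

Because `__getattr__` returns a no-op wrapper when `tensorboardX` is unavailable, the trainers can call `add_scalar`, `add_image`, and `add_histogram` unconditionally.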