├── scripts ├── datasets │ ├── __init__.py │ ├── COCO.py │ └── BIH.py ├── models │ ├── __init__.py │ ├── backbone_unet.py │ ├── vgg.py │ ├── rasc.py │ ├── discriminator.py │ ├── unet.py │ ├── blocks.py │ ├── vmu.py │ └── sa_resunet.py ├── utils │ ├── __init__.py │ ├── osutils.py │ ├── model_init.py │ ├── misc.py │ ├── evaluation.py │ ├── logger.py │ ├── transforms.py │ ├── imutils.py │ ├── losses.py │ └── parallel.py ├── machines │ ├── __init__.py │ ├── S2AM.py │ ├── BasicMachine.py │ └── VX.py └── __init__.py ├── requirements.txt ├── examples ├── test.sh └── evaluate.sh ├── test.py ├── README.md ├── main.py ├── watermark_synthesis.ipynb └── options.py /scripts/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .COCO import COCO 2 | from .BIH import BIH 3 | 4 | __all__ = ('COCO','BIH') -------------------------------------------------------------------------------- /scripts/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import * 2 | from .backbone_unet import * 3 | from .discriminator import * 4 | 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.19.1 2 | opencv-python==3.4.8.29 3 | Pillow 4 | scikit-image==0.14.5 5 | scikit-learn==0.23.1 6 | scipy==1.2.1 7 | sklearn==0.0 8 | tensorboardX 9 | torch>=1.0.0 10 | torchvision -------------------------------------------------------------------------------- /scripts/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .evaluation import * 4 | from .imutils import * 5 | from .logger import * 6 | from .misc import * 7 | from .osutils import * 8 | from .transforms import * 9 | -------------------------------------------------------------------------------- /scripts/machines/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .BasicMachine import BasicMachine 3 | from .VX import VX 4 | from .S2AM import S2AM 5 | 6 | def basic(**kwargs): 7 | return BasicMachine(**kwargs) 8 | 9 | def s2am(**kwargs): 10 | return S2AM(**kwargs) 11 | 12 | def vx(**kwargs): 13 | return VX(**kwargs) 14 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import datasets 4 | from . import models 5 | from . 
import utils 6 | 7 | # import os, sys 8 | # sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 9 | # from progress.bar import Bar as Bar 10 | 11 | # __version__ = '0.1.0' -------------------------------------------------------------------------------- /examples/test.sh: -------------------------------------------------------------------------------- 1 | 2 | set -ex 3 | 4 | CUDA_VISIBLE_DEVICES=0 python /data/home/yb87432/s2am/test.py \ 5 | -c test/10kgray_ssim\ 6 | --resume /data/home/yb87432/s2am/eval/10kgray/1e3_bs6_256_hybrid_ssim_vgg_vx__images_vvv4n/model_best.pth.tar\ 7 | --arch vvv4n\ 8 | --machine vx\ 9 | --input-size 256\ 10 | --test-batch 1\ 11 | --evaluate\ 12 | --base-dir $HOME/watermark/10kgray/\ 13 | --data _images -------------------------------------------------------------------------------- /scripts/utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import errno 5 | 6 | def mkdir_p(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | 13 | def isfile(fname): 14 | return os.path.isfile(fname) 15 | 16 | def isdir(dirname): 17 | return os.path.isdir(dirname) 18 | 19 | def join(path, *paths): 20 | return os.path.join(path, *paths) 21 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import argparse 4 | import torch 5 | 6 | torch.backends.cudnn.benchmark = True 7 | 8 | from scripts.utils.misc import save_checkpoint, adjust_learning_rate 9 | 10 | import scripts.datasets as datasets 11 | import scripts.machines as machines 12 | from options import Options 13 | 14 | def main(args): 15 | 16 | val_loader = torch.utils.data.DataLoader(datasets.COCO('val',args),batch_size=args.test_batch, shuffle=False, 17 | num_workers=args.workers, pin_memory=True) 18 | 19 | data_loaders = (None,val_loader) 20 | 21 | Machine = machines.__dict__[args.machine](datasets=data_loaders, args=args) 22 | 23 | Machine.test() 24 | 25 | if __name__ == '__main__': 26 | parser=Options().init(argparse.ArgumentParser(description='WaterMark Removal')) 27 | main(parser.parse_args()) 28 | -------------------------------------------------------------------------------- /examples/evaluate.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | 3 | 4 | 5 | # example training scripts for AAAI-21 6 | # Split then Refine: Stacked Attention-guided ResUNets for Blind Single Image Visible Watermark Removal 7 | 8 | 9 | CUDA_VISIBLE_DEVICES=0 python /data/home/yb87432/s2am/main.py --epochs 100\ 10 | --schedule 100\ 11 | --lr 1e-3\ 12 | -c eval/10kgray/1e3_bs4_256_hybrid_ssim_vgg\ 13 | --arch vvv4n\ 14 | --sltype vggx\ 15 | --style-loss 0.025\ 16 | --ssim-loss 0.15\ 17 | --masked True\ 18 | --loss-type hybrid\ 19 | --limited-dataset 1\ 20 | --machine vx\ 21 | --input-size 256\ 22 | --train-batch 4\ 23 | --test-batch 1\ 24 | --base-dir $HOME/watermark/10kgray/\ 25 | --data _images 26 | 27 | 28 | 29 | 30 | 31 | # example training scripts for TIP-20 32 | # Improving the Harmony of the Composite Image by Spatial-Separated Attention Module 33 | # * in the original version, the res = False 34 | # suitable for the iHarmony4 dataset. 
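# (per options.py, --res enables 'residual learning for s2am' and --limited-dataset 0 keeps the full training split)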
35 | 36 | python /data/home/yb87432/mypaper/s2am/main.py --epochs 200\ 37 | --schedule 150\ 38 | --lr 1e-3\ 39 | -c checkpoint/normal_rasc_HAdobe5k_res \ 40 | --arch rascv2\ 41 | --style-loss 0\ 42 | --ssim-loss 0\ 43 | --limited-dataset 0\ 44 | --res True\ 45 | --machine s2am\ 46 | --input-size 256\ 47 | --train-batch 16\ 48 | --test-batch 1\ 49 | --base-dir $HOME/Datasets/\ 50 | --data HAdobe5k -------------------------------------------------------------------------------- /scripts/models/backbone_unet.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torchvision 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import functools 9 | import math 10 | 11 | from scripts.utils.model_init import * 12 | from scripts.models.rasc import * 13 | from scripts.models.unet import UnetGenerator,MinimalUnetV2 14 | from scripts.models.vmu import UnetVM 15 | from scripts.models.sa_resunet import UnetVMS2AMv4 16 | 17 | 18 | # our method 19 | def vvv4n(**kwargs): 20 | return UnetVMS2AMv4(shared_depth=2, blocks=3, long_skip=True, use_vm_decoder=True,s2am='vms2am') 21 | 22 | 23 | # BVMR 24 | def vm3(**kwargs): 25 | return UnetVM(shared_depth=2, blocks=3, use_vm_decoder=True) 26 | 27 | 28 | # Blind version of S2AM 29 | def urasc(**kwargs): 30 | model = UnetGenerator(3,3,is_attention_layer=True,attention_model=URASC,basicblock=MinimalUnetV2) 31 | model.apply(weights_init_kaiming) 32 | return model 33 | 34 | 35 | # Improving the Harmony of the Composite Image by Spatial-Separated Attention Module 36 | # Xiaodong Cun and Chi-Man Pun 37 | # University of Macau 38 | # Trans. on Image Processing, vol. 29, pp. 4759-4771, 2020. 39 | def rascv2(**kwargs): 40 | model = UnetGenerator(4,3,is_attention_layer=True,attention_model=RASC,basicblock=MinimalUnetV2) 41 | model.apply(weights_init_kaiming) 42 | return model 43 | 44 | # just original unet 45 | def unet(**kwargs): 46 | model = UnetGenerator(3,3) 47 | model.apply(weights_init_kaiming) 48 | return model 49 | 50 | 51 | -------------------------------------------------------------------------------- /scripts/utils/model_init.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from torch.nn import init 4 | 5 | 6 | def weights_init_normal(m): 7 | classname = m.__class__.__name__ 8 | # print(classname) 9 | if classname.find('Conv') != -1: 10 | init.normal_(m.weight.data, 0.0, 0.02) 11 | elif classname.find('Linear') != -1: 12 | init.normal_(m.weight.data, 0.0, 0.02) 13 | elif classname.find('BatchNorm2d') != -1: 14 | init.normal_(m.weight.data, 1.0, 0.02) 15 | init.constant_(m.bias.data, 0.0) 16 | 17 | 18 | def weights_init_xavier(m): 19 | classname = m.__class__.__name__ 20 | # print(classname) 21 | if classname.find('Conv') != -1: 22 | init.xavier_normal(m.weight.data, gain=0.02) 23 | elif classname.find('Linear') != -1: 24 | init.xavier_normal(m.weight.data, gain=0.02) 25 | # elif classname.find('BatchNorm2d') != -1: 26 | # init.normal(m.weight.data, 1.0, 0.02) 27 | # init.constant(m.bias.data, 0.0) 28 | 29 | 30 | def weights_init_kaiming(m): 31 | classname = m.__class__.__name__ 32 | # print(classname) 33 | if classname.find('Conv') != -1 and m.weight.requires_grad == True: 34 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 35 | elif classname.find('Linear') != -1 and m.weight.requires_grad == True: 36 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 37 | elif classname.find('BatchNorm2d') != -1 
and m.weight.requires_grad == True: 38 | init.normal_(m.weight.data, 1.0, 0.02) 39 | init.constant_(m.bias.data, 0.0) 40 | 41 | 42 | def weights_init_orthogonal(m): 43 | classname = m.__class__.__name__ 44 | if classname.find('Conv') != -1: 45 | init.orthogonal(m.weight.data, gain=1) 46 | elif classname.find('Linear') != -1: 47 | init.orthogonal(m.weight.data, gain=1) 48 | # elif classname.find('BatchNorm2d') != -1: 49 | # init.normal(m.weight.data, 1.0, 0.02) 50 | # init.constant(m.bias.data, 0.0) -------------------------------------------------------------------------------- /scripts/utils/misc.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import shutil 5 | import torch 6 | import math 7 | import numpy as np 8 | import scipy.io 9 | import matplotlib.pyplot as plt 10 | import torch.nn.functional as F 11 | 12 | def to_numpy(tensor): 13 | if torch.is_tensor(tensor): 14 | return tensor.cpu().numpy() 15 | elif type(tensor).__module__ != 'numpy': 16 | raise ValueError("Cannot convert {} to numpy array" 17 | .format(type(tensor))) 18 | return tensor 19 | 20 | def resize_to_match(fm,to): 21 | # just use interpolate 22 | # [1,3] = (h,w) 23 | return F.interpolate(fm,to.size()[-2:],mode='bilinear',align_corners=False) 24 | 25 | def to_torch(ndarray): 26 | if type(ndarray).__module__ == 'numpy': 27 | return torch.from_numpy(ndarray) 28 | elif not torch.is_tensor(ndarray): 29 | raise ValueError("Cannot convert {} to torch tensor" 30 | .format(type(ndarray))) 31 | return ndarray 32 | 33 | 34 | def save_checkpoint(machine,filename='checkpoint.pth.tar', snapshot=None): 35 | is_best = True if machine.best_acc < machine.metric else False 36 | 37 | if is_best: 38 | machine.best_acc = machine.metric 39 | 40 | state = { 41 | 'epoch': machine.current_epoch + 1, 42 | 'arch': machine.args.arch, 43 | 'state_dict': machine.model.state_dict(), 44 | 'best_acc': machine.best_acc, 45 | 'optimizer' : machine.optimizer.state_dict(), 46 | } 47 | 48 | filepath = os.path.join(machine.args.checkpoint, filename) 49 | torch.save(state, filepath) 50 | 51 | if snapshot and state['epoch'] % snapshot == 0: 52 | shutil.copyfile(filepath, os.path.join(machine.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state.epoch))) 53 | 54 | if is_best: 55 | machine.best_acc = machine.metric 56 | print('Saving Best Metric with PSNR:%s'%machine.best_acc) 57 | shutil.copyfile(filepath, os.path.join(machine.args.checkpoint, 'model_best.pth.tar')) 58 | 59 | 60 | 61 | def save_pred(preds, checkpoint='checkpoint', filename='preds_valid.mat'): 62 | preds = to_numpy(preds) 63 | filepath = os.path.join(checkpoint, filename) 64 | scipy.io.savemat(filepath, mdict={'preds' : preds}) 65 | 66 | 67 | def adjust_learning_rate(datasets,optimizer, epoch, lr,args): 68 | """Sets the learning rate to the initial LR decayed by schedule""" 69 | if epoch in args.schedule: 70 | lr *= args.gamma 71 | for param_group in optimizer.param_groups: 72 | param_group['lr'] = lr 73 | 74 | # decay sigma 75 | for dset in datasets: 76 | if args.sigma_decay > 0: 77 | dset.dataset.sigma *= args.sigma_decay 78 | dset.dataset.sigma *= args.sigma_decay 79 | 80 | return lr 81 | 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repo contains the code and results of the AAAI 2021 paper: 2 | 3 | [Split then Refine: Stacked 
Attention-guided ResUNets for Blind Single Image Visible Watermark Removal](https://arxiv.org/abs/2012.07007)
4 | [Xiaodong Cun](http://vinthony.github.io), [Chi-Man Pun*](http://www.cis.umac.mo/~cmpun/)
5 | [University of Macau](http://um.edu.mo/) 6 | 7 | [Datasets](#Resources) | [Models](#Resources) | [Paper](https://arxiv.org/abs/2012.07007) | [🔥Online Demo!](https://colab.research.google.com/drive/1pYY7byBjM-7aFIWk8HcF9nK_s6pqGwww?usp=sharing)(Google CoLab) 8 | 9 |
10 | 11 | [Figure: overview of the proposed two-stage framework] 12 | 13 | The overview of the proposed two-stage framework. First, we propose a multi-task network, SplitNet, for watermark detection, removal and recovery. Then, we propose RefineNet to smooth the learned region using the predicted mask and the recovered background from the previous stage. As a result, the network can be trained in an end-to-end fashion without any manual intervention. Note that, for clarity, the skip-connections between the encoders and decoders are not shown. 14 |
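To make the data flow above concrete, here is a minimal, self-contained PyTorch sketch of the two-stage pipeline. The `SplitNet`/`RefineNet` classes below are toy stand-ins (single convolutions) used only to illustrate the interfaces, not the networks of the paper; the actual model behind `--arch vvv4n` is `UnetVMS2AMv4` from `scripts/models/sa_resunet.py`, constructed in `scripts/models/backbone_unet.py`.

```
# Toy sketch of the two-stage flow (stage 1 splits, stage 2 refines).
# The modules below are placeholders, NOT the networks used in the paper.
import torch
import torch.nn as nn

class SplitNet(nn.Module):
    """Stage 1 (placeholder): predict watermark mask, coarse background and watermark."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Conv2d(3, 16, 3, padding=1)   # stand-in for the shared encoder
        self.mask_head = nn.Conv2d(16, 1, 1)
        self.bg_head = nn.Conv2d(16, 3, 1)
        self.wm_head = nn.Conv2d(16, 3, 1)

    def forward(self, x):
        f = torch.relu(self.encoder(x))
        return torch.sigmoid(self.mask_head(f)), self.bg_head(f), self.wm_head(f)

class RefineNet(nn.Module):
    """Stage 2 (placeholder): refine the coarse background inside the predicted mask."""
    def __init__(self):
        super().__init__()
        self.refine = nn.Conv2d(7, 3, 3, padding=1)     # input(3) + coarse bg(3) + mask(1)

    def forward(self, x, coarse_bg, mask):
        refined = self.refine(torch.cat([x, coarse_bg, mask], dim=1))
        # keep the untouched background, replace only the watermarked region
        return mask * refined + (1 - mask) * x

x = torch.randn(1, 3, 256, 256)                         # watermarked input
mask, coarse_bg, wm = SplitNet()(x)                      # stage 1: split (detect / remove / recover)
restored = RefineNet()(x, coarse_bg, mask)               # stage 2: refine
```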
15 | 16 | > The whole project will be released in the January of 2021 (almost). 17 | 18 | 19 | ### Datasets 20 | 21 | We synthesized four different datasets for training and testing, you can download the dataset via [huggingface](https://huggingface.co/datasets/vinthony/watermark-removal-logo/tree/main). 22 | 23 | ![image](https://user-images.githubusercontent.com/4397546/104273158-74413900-54d9-11eb-95fa-c6bee94de0ea.png) 24 | 25 | 26 | ### Pre-trained Models 27 | 28 | * [27kpng_model_best.pth.tar (google drive)](https://drive.google.com/file/d/1KpSJ6385CHN6WlAINqB3CYrJdleQTJBc/view?usp=sharing) 29 | 30 | > Other Pre-trained Models are still reorganizing and uploading, it will be released soon. 31 | 32 | 33 | ### Demos 34 | 35 | An easy-to-use online demo can be founded in [google colab](https://colab.research.google.com/drive/1pYY7byBjM-7aFIWk8HcF9nK_s6pqGwww?usp=sharing). 36 | 37 | The local demo will be released soon. 38 | 39 | ### Pre-requirements 40 | 41 | ``` 42 | pip install -r requirements.txt 43 | ``` 44 | 45 | ### Train 46 | 47 | Besides training our methods, here, we also give an example of how to train the [s2am](https://github.com/vinthony/s2am) under our framework. More details can be found in the shell scripts. 48 | 49 | 50 | ``` 51 | bash examples/evaluation.sh 52 | ``` 53 | 54 | ### Test 55 | 56 | ``` 57 | bash examples/test.sh 58 | ``` 59 | 60 | ## **Acknowledgements** 61 | The author would like to thanks Nan Chen for her helpful discussion. 62 | 63 | Part of the code is based upon our previous work on image harmonization [s2am](https://github.com/vinthony/s2am) 64 | 65 | ## **Citation** 66 | 67 | If you find our work useful in your research, please consider citing: 68 | 69 | ``` 70 | @misc{cun2020split, 71 | title={Split then Refine: Stacked Attention-guided ResUNets for Blind Single Image Visible Watermark Removal}, 72 | author={Xiaodong Cun and Chi-Man Pun}, 73 | year={2020}, 74 | eprint={2012.07007}, 75 | archivePrefix={arXiv}, 76 | primaryClass={cs.CV} 77 | } 78 | ``` 79 | 80 | ## **Contact** 81 | Please contact me if there is any question (Xiaodong Cun yb87432@um.edu.mo) 82 | -------------------------------------------------------------------------------- /scripts/models/vgg.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import torch 4 | from torchvision import models 5 | 6 | 7 | class Vgg16(torch.nn.Module): 8 | def __init__(self, requires_grad=False): 9 | super(Vgg16, self).__init__() 10 | vgg_pretrained_features = models.vgg16(pretrained=True).features 11 | self.slice1 = torch.nn.Sequential() 12 | self.slice2 = torch.nn.Sequential() 13 | self.slice3 = torch.nn.Sequential() 14 | self.slice4 = torch.nn.Sequential() 15 | self.slice5 = torch.nn.Sequential() 16 | for x in range(4): 17 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 18 | for x in range(4, 9): 19 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 20 | for x in range(9, 16): 21 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 22 | for x in range(16, 23): 23 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 24 | for x in range(23,30): 25 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 26 | 27 | if not requires_grad: 28 | for param in self.parameters(): 29 | param.requires_grad = False 30 | 31 | def forward(self, X): 32 | h = self.slice1(X) 33 | h_relu1_2 = h 34 | h = self.slice2(h) 35 | h_relu2_2 = h 36 | h = self.slice3(h) 37 | h_relu3_3 = h 38 | h = 
self.slice4(h) 39 | h_relu4_3 = h 40 | h = self.slice5(h) 41 | h_relu5_3 = h 42 | # vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3','relu5_3']) 43 | # out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 44 | return (h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 45 | 46 | 47 | class Vgg19(torch.nn.Module): 48 | def __init__(self, requires_grad=False): 49 | super(Vgg19, self).__init__() 50 | # vgg_pretrained_features = models.vgg19(pretrained=True).features 51 | self.vgg_pretrained_features = models.vgg19(pretrained=True).features 52 | # self.slice1 = torch.nn.Sequential() 53 | # self.slice2 = torch.nn.Sequential() 54 | # self.slice3 = torch.nn.Sequential() 55 | # self.slice4 = torch.nn.Sequential() 56 | # self.slice5 = torch.nn.Sequential() 57 | # for x in range(2): 58 | # self.slice1.add_module(str(x), vgg_pretrained_features[x]) 59 | # for x in range(2, 7): 60 | # self.slice2.add_module(str(x), vgg_pretrained_features[x]) 61 | # for x in range(7, 12): 62 | # self.slice3.add_module(str(x), vgg_pretrained_features[x]) 63 | # for x in range(12, 21): 64 | # self.slice4.add_module(str(x), vgg_pretrained_features[x]) 65 | # for x in range(21, 30): 66 | # self.slice5.add_module(str(x), vgg_pretrained_features[x]) 67 | if not requires_grad: 68 | for param in self.parameters(): 69 | param.requires_grad = False 70 | 71 | def forward(self, X, indices=None): 72 | if indices is None: 73 | indices = [2, 7, 12, 21, 30] 74 | out = [] 75 | #indices = sorted(indices) 76 | for i in range(indices[-1]): 77 | X = self.vgg_pretrained_features[i](X) 78 | if (i+1) in indices: 79 | out.append(X) 80 | 81 | return out 82 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import argparse 4 | import torch,time,os 5 | 6 | torch.backends.cudnn.benchmark = True 7 | 8 | from scripts.utils.misc import save_checkpoint, adjust_learning_rate 9 | 10 | import scripts.datasets as datasets 11 | import scripts.machines as machines 12 | from options import Options 13 | 14 | def main(args): 15 | 16 | if 'HFlickr' or 'HCOCO' or 'Hday2night' or 'HAdobe5k' in args.base_dir: 17 | dataset_func = datasets.BIH 18 | else: 19 | dataset_func = datasets.COCO 20 | 21 | train_loader = torch.utils.data.DataLoader(dataset_func('train',args),batch_size=args.train_batch, shuffle=True, 22 | num_workers=args.workers, pin_memory=True) 23 | 24 | val_loader = torch.utils.data.DataLoader(dataset_func('val',args),batch_size=args.test_batch, shuffle=False, 25 | num_workers=args.workers, pin_memory=True) 26 | 27 | lr = args.lr 28 | data_loaders = (train_loader,val_loader) 29 | 30 | Machine = machines.__dict__[args.machine](datasets=data_loaders, args=args) 31 | print('============================ Initization Finish && Training Start =============================================') 32 | 33 | for epoch in range(Machine.args.start_epoch, Machine.args.epochs): 34 | 35 | print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr)) 36 | lr = adjust_learning_rate(data_loaders, Machine.optimizer, epoch, lr, args) 37 | 38 | Machine.record('lr',lr, epoch) 39 | Machine.train(epoch) 40 | 41 | if args.freq < 0: 42 | Machine.validate(epoch) 43 | Machine.flush() 44 | Machine.save_checkpoint() 45 | 46 | if __name__ == '__main__': 47 | parser=Options().init(argparse.ArgumentParser(description='WaterMark Removal')) 48 | args = 
parser.parse_args() 49 | print('==================================== WaterMark Removal =============================================') 50 | print('==> {:50}: {:<}'.format("Start Time",time.ctime(time.time()))) 51 | print('==> {:50}: {:<}'.format("USE GPU",os.environ['CUDA_VISIBLE_DEVICES'])) 52 | print('==================================== Stable Parameters =============================================') 53 | for arg in vars(args): 54 | if type(getattr(args, arg)) == type([]): 55 | if ','.join([ str(i) for i in getattr(args, arg)]) == ','.join([ str(i) for i in parser.get_default(arg)]): 56 | print('==> {:50}: {:<}({:<})'.format(arg,','.join([ str(i) for i in getattr(args, arg)]),','.join([ str(i) for i in parser.get_default(arg)]))) 57 | else: 58 | if getattr(args, arg) == parser.get_default(arg): 59 | print('==> {:50}: {:<}({:<})'.format(arg,getattr(args, arg),parser.get_default(arg))) 60 | print('==================================== Changed Parameters =============================================') 61 | for arg in vars(args): 62 | if type(getattr(args, arg)) == type([]): 63 | if ','.join([ str(i) for i in getattr(args, arg)]) != ','.join([ str(i) for i in parser.get_default(arg)]): 64 | print('==> {:50}: {:<}({:<})'.format(arg,','.join([ str(i) for i in getattr(args, arg)]),','.join([ str(i) for i in parser.get_default(arg)]))) 65 | else: 66 | if getattr(args, arg) != parser.get_default(arg): 67 | print('==> {:50}: {:<}({:<})'.format(arg,getattr(args, arg),parser.get_default(arg))) 68 | print('==================================== Start Init Model ===============================================') 69 | main(args) 70 | print('==================================== FINISH WITHOUT ERROR =============================================') 71 | -------------------------------------------------------------------------------- /scripts/datasets/COCO.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import os 4 | import csv 5 | import numpy as np 6 | import json 7 | import random 8 | import math 9 | import matplotlib.pyplot as plt 10 | from collections import namedtuple 11 | from os import listdir 12 | from os.path import isfile, join 13 | 14 | import torch 15 | import torch.utils.data as data 16 | 17 | from scripts.utils.osutils import * 18 | from scripts.utils.imutils import * 19 | from scripts.utils.transforms import * 20 | import torchvision.transforms as transforms 21 | from PIL import Image 22 | from PIL import ImageEnhance 23 | from PIL import ImageFilter 24 | from PIL import ImageFile 25 | ImageFile.LOAD_TRUNCATED_IMAGES = True 26 | 27 | class COCO(data.Dataset): 28 | def __init__(self,train,config=None, sample=[],gan_norm=False): 29 | 30 | self.train = [] 31 | self.anno = [] 32 | self.mask = [] 33 | self.wm = [] 34 | self.input_size = config.input_size 35 | self.normalized_input = config.normalized_input 36 | self.base_folder = config.base_dir 37 | self.dataset = train+config.data 38 | 39 | if config == None: 40 | self.data_augumentation = False 41 | else: 42 | self.data_augumentation = config.data_augumentation 43 | 44 | self.istrain = False if self.dataset.find('train') == -1 else True 45 | self.sample = sample 46 | self.gan_norm = gan_norm 47 | mypath = join(self.base_folder,self.dataset) 48 | file_names = sorted([f for f in listdir(join(mypath,'image')) if isfile(join(mypath,'image', f)) ]) 49 | 50 | if config.limited_dataset > 0: 51 | xtrain = sorted(list(set([ file_name.split('-')[0] for 
file_name in file_names ]))) 52 | tmp = [] 53 | for x in xtrain: 54 | # get the file_name by identifier 55 | tmp.append([y for y in file_names if x in y][0]) 56 | 57 | file_names = tmp 58 | else: 59 | file_names = file_names 60 | 61 | for file_name in file_names: 62 | self.train.append(os.path.join(mypath,'image',file_name)) 63 | self.mask.append(os.path.join(mypath,'mask',file_name)) 64 | self.wm.append(os.path.join(mypath,'wm',file_name)) 65 | self.anno.append(os.path.join(self.base_folder,'natural',file_name.split('-')[0]+'.jpg')) 66 | 67 | if len(self.sample) > 0 : 68 | self.train = [ self.train[i] for i in self.sample ] 69 | self.mask = [ self.mask[i] for i in self.sample ] 70 | self.anno = [ self.anno[i] for i in self.sample ] 71 | 72 | self.trans = transforms.Compose([ 73 | transforms.Resize((self.input_size,self.input_size)), 74 | transforms.ToTensor() 75 | ]) 76 | 77 | print('total Dataset of '+self.dataset+' is : ', len(self.train)) 78 | 79 | 80 | def __getitem__(self, index): 81 | img = Image.open(self.train[index]).convert('RGB') 82 | mask = Image.open(self.mask[index]).convert('L') 83 | anno = Image.open(self.anno[index]).convert('RGB') 84 | wm = Image.open(self.wm[index]).convert('RGB') 85 | 86 | return {"image": self.trans(img), 87 | "target": self.trans(anno), 88 | "mask": self.trans(mask), 89 | "wm": self.trans(wm), 90 | "name": self.train[index].split('/')[-1], 91 | "imgurl":self.train[index], 92 | "maskurl":self.mask[index], 93 | "targeturl":self.anno[index], 94 | "wmurl":self.wm[index] 95 | } 96 | 97 | def __len__(self): 98 | 99 | return len(self.train) 100 | -------------------------------------------------------------------------------- /scripts/datasets/BIH.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import os 4 | import csv 5 | import numpy as np 6 | import json 7 | import random 8 | import math 9 | import matplotlib.pyplot as plt 10 | from collections import namedtuple 11 | from os import listdir 12 | from os.path import isfile, join 13 | 14 | import torch 15 | import torch.utils.data as data 16 | 17 | from scripts.utils.osutils import * 18 | from scripts.utils.imutils import * 19 | from scripts.utils.transforms import * 20 | import torchvision.transforms as transforms 21 | from PIL import Image 22 | from PIL import ImageEnhance 23 | from PIL import ImageFilter 24 | from PIL import ImageFile 25 | ImageFile.LOAD_TRUNCATED_IMAGES = True 26 | 27 | class BIH(data.Dataset): 28 | def __init__(self,train,config=None, sample=[],gan_norm=False): 29 | 30 | self.train = [] 31 | self.anno = [] 32 | self.mask = [] 33 | self.wm = [] 34 | self.input_size = config.input_size 35 | self.normalized_input = config.normalized_input 36 | self.base_folder = config.base_dir +'/' + config.data 37 | self.dataset = config.data 38 | 39 | if config == None: 40 | self.data_augumentation = False 41 | else: 42 | self.data_augumentation = config.data_augumentation 43 | 44 | self.istrain = False if train.find('train') == -1 else True 45 | self.sample = sample 46 | self.gan_norm = gan_norm 47 | mypath = join(self.base_folder,self.dataset+'_'+train+'.txt') 48 | 49 | with open(mypath) as f: 50 | # here we get the filenames 51 | file_names = [ im.strip() for im in f.readlines() ] 52 | 53 | if config.limited_dataset > 0: 54 | xtrain = sorted(list(set([ file_name.split('-')[0] for file_name in file_names ]))) 55 | tmp = [] 56 | for x in xtrain: 57 | tmp.append([y for y in file_names if x in y][0]) 
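            # (i.e. when --limited-dataset > 0, only the first composite per identifier prefix is kept for this split)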
58 | 59 | file_names = tmp 60 | else: 61 | file_names = file_names 62 | 63 | for file_name in file_names: 64 | self.train.append(os.path.join(self.base_folder,'images',file_name)) 65 | self.mask.append(os.path.join(self.base_folder,'masks','_'.join(file_name.split('_')[0:2])+'.png')) 66 | self.anno.append(os.path.join(self.base_folder,'reals',file_name.split('_')[0]+'.jpg')) 67 | 68 | if len(self.sample) > 0 : 69 | self.train = [ self.train[i] for i in self.sample ] 70 | self.mask = [ self.mask[i] for i in self.sample ] 71 | self.anno = [ self.anno[i] for i in self.sample ] 72 | 73 | self.trans = transforms.Compose([ 74 | transforms.Resize((self.input_size,self.input_size)), 75 | transforms.ToTensor() 76 | ]) 77 | 78 | print('total Dataset of '+self.dataset+' is : ', len(self.train)) 79 | 80 | 81 | def __getitem__(self, index): 82 | img = Image.open(self.train[index]).convert('RGB') 83 | mask = Image.open(self.mask[index]).convert('L') 84 | anno = Image.open(self.anno[index]).convert('RGB') 85 | 86 | # for shadow removal and blind image harmonization, here is no ground truth wm 87 | # wm = Image.open(self.wm[index]).convert('RGB') 88 | 89 | return {"image": self.trans(img), 90 | "target": self.trans(anno), 91 | "mask": self.trans(mask), 92 | "name": self.train[index].split('/')[-1], 93 | "imgurl":self.train[index], 94 | "maskurl":self.mask[index], 95 | "targeturl":self.anno[index], 96 | } 97 | 98 | def __len__(self): 99 | 100 | return len(self.train) 101 | -------------------------------------------------------------------------------- /scripts/utils/evaluation.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import math 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from random import randint 7 | 8 | from .misc import * 9 | from .transforms import transform, transform_preds 10 | 11 | __all__ = ['accuracy', 'AverageMeter'] 12 | 13 | def get_preds(scores): 14 | ''' get predictions from score maps in torch Tensor 15 | return type: torch.LongTensor 16 | ''' 17 | assert scores.dim() == 4, 'Score maps should be 4-dim' 18 | maxval, idx = torch.max(scores.view(scores.size(0), scores.size(1), -1), 2) 19 | 20 | maxval = maxval.view(scores.size(0), scores.size(1), 1) 21 | idx = idx.view(scores.size(0), scores.size(1), 1) + 1 22 | 23 | preds = idx.repeat(1, 1, 2).float() 24 | 25 | preds[:,:,0] = (preds[:,:,0] - 1) % scores.size(3) + 1 26 | preds[:,:,1] = torch.floor((preds[:,:,1] - 1) / scores.size(2)) + 1 27 | 28 | pred_mask = maxval.gt(0).repeat(1, 1, 2).float() 29 | preds *= pred_mask 30 | return preds 31 | 32 | def calc_dists(preds, target, normalize): 33 | preds = preds.float() 34 | target = target.float() 35 | dists = torch.zeros(preds.size(1), preds.size(0)) 36 | for n in range(preds.size(0)): 37 | for c in range(preds.size(1)): 38 | if target[n,c,0] > 1 and target[n, c, 1] > 1: 39 | dists[c, n] = torch.dist(preds[n,c,:], target[n,c,:])/normalize[n] 40 | else: 41 | dists[c, n] = -1 42 | return dists 43 | 44 | def dist_acc(dists, thr=0.5): 45 | ''' Return percentage below threshold while ignoring values with a -1 ''' 46 | if dists.ne(-1).sum() > 0: 47 | return dists.le(thr).eq(dists.ne(-1)).sum()*1.0 / dists.ne(-1).sum() 48 | else: 49 | return -1 50 | 51 | 52 | 53 | def accuracy(output, target, thr=0.5): 54 | ''' Calculate accuracy according to PCK, but uses ground truth heatmap rather than x,y locations 55 | First value to be returned is average accuracy across 'idxs', followed by individual 
accuracies 56 | ''' 57 | # output_mask = torch.gt(output,thr); 58 | # target_mask = torch.gt(target,thr); 59 | # equal_mask = torch.eq(output_mask,target_mask); 60 | # fp_equal_mask = torch.lt(output_mask,target_mask); 61 | # fn_equal_mask = torch.gt(output_mask,target_mask); 62 | 63 | 64 | # tp = torch.sum(equal_mask); 65 | # fn = torch.sum(fn_equal_mask); 66 | # fp = torch.sum(fp_equal_mask); 67 | 68 | # return 2*tp / (2*tp+fn+fp) 69 | 70 | 71 | if output.dim() > 2: 72 | v,i = torch.max(output,1); 73 | else: 74 | v,i = torch.max(output,1); 75 | return torch.sum(target.long() == i).float()/target.numel() 76 | 77 | def final_preds(output, center, scale, res): 78 | coords = get_preds(output) # float type 79 | 80 | # pose-processing 81 | for n in range(coords.size(0)): 82 | for p in range(coords.size(1)): 83 | hm = output[n][p] 84 | px = int(math.floor(coords[n][p][0])) 85 | py = int(math.floor(coords[n][p][1])) 86 | if px > 1 and px < res[0] and py > 1 and py < res[1]: 87 | diff = torch.Tensor([hm[py - 1][px] - hm[py - 1][px - 2], hm[py][px - 1]-hm[py - 2][px - 1]]) 88 | coords[n][p] += diff.sign() * .25 89 | coords += 0.5 90 | preds = coords.clone() 91 | 92 | # Transform back 93 | for i in range(coords.size(0)): 94 | preds[i] = transform_preds(coords[i], center[i], scale[i], res) 95 | 96 | if preds.dim() < 3: 97 | preds = preds.view(1, preds.size()) 98 | 99 | return preds 100 | 101 | 102 | class AverageMeter(object): 103 | """Computes and stores the average and current value""" 104 | def __init__(self): 105 | self.reset() 106 | 107 | def reset(self): 108 | self.val = 0 109 | self.avg = 0 110 | self.sum = 0 111 | self.count = 0 112 | 113 | def update(self, val, n=1): 114 | self.val = val 115 | self.sum += val * n 116 | self.count += n 117 | self.avg = self.sum / self.count 118 | -------------------------------------------------------------------------------- /scripts/utils/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | 5 | import os 6 | import sys 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 11 | 12 | def savefig(fname, dpi=None): 13 | dpi = 150 if dpi == None else dpi 14 | plt.savefig(fname, dpi=dpi) 15 | 16 | def plot_overlap(logger, names=None): 17 | names = logger.names if names == None else names 18 | numbers = logger.numbers 19 | for _, name in enumerate(names): 20 | x = np.arange(len(numbers[name])) 21 | plt.plot(x, np.asarray(numbers[name])) 22 | return [logger.title + '(' + name + ')' for name in names] 23 | 24 | class Logger(object): 25 | '''Save training process to log file with simple plot function.''' 26 | def __init__(self, fpath, title=None, resume=False): 27 | self.file = None 28 | self.resume = resume 29 | self.title = '' if title == None else title 30 | if fpath is not None: 31 | if resume: 32 | self.file = open(fpath, 'r') 33 | name = self.file.readline() 34 | self.names = name.rstrip().split('\t') 35 | self.numbers = {} 36 | for _, name in enumerate(self.names): 37 | self.numbers[name] = [] 38 | 39 | for numbers in self.file: 40 | numbers = numbers.rstrip().split('\t') 41 | for i in range(0, len(numbers)): 42 | self.numbers[self.names[i]].append(numbers[i]) 43 | self.file.close() 44 | self.file = open(fpath, 'a') 45 | else: 46 | self.file = open(fpath, 'w') 47 | 48 | def set_names(self, names): 49 | if self.resume: 50 | pass 51 | # initialize numbers as 
empty list 52 | self.numbers = {} 53 | self.names = names 54 | for _, name in enumerate(self.names): 55 | self.file.write(name) 56 | self.file.write('\t') 57 | self.numbers[name] = [] 58 | self.file.write('\n') 59 | self.file.flush() 60 | 61 | 62 | def append(self, numbers): 63 | assert len(self.names) == len(numbers), 'Numbers do not match names' 64 | for index, num in enumerate(numbers): 65 | self.file.write("{0:.6f}".format(num)) 66 | self.file.write('\t') 67 | self.numbers[self.names[index]].append(num) 68 | self.file.write('\n') 69 | self.file.flush() 70 | 71 | def plot(self, names=None): 72 | names = self.names if names == None else names 73 | numbers = self.numbers 74 | for _, name in enumerate(names): 75 | x = np.arange(len(numbers[name])) 76 | plt.plot(x, np.asarray(numbers[name])) 77 | plt.legend([self.title + '(' + name + ')' for name in names]) 78 | plt.grid(True) 79 | 80 | def close(self): 81 | if self.file is not None: 82 | self.file.close() 83 | 84 | class LoggerMonitor(object): 85 | '''Load and visualize multiple logs.''' 86 | def __init__ (self, paths): 87 | '''paths is a distionary with {name:filepath} pair''' 88 | self.loggers = [] 89 | for title, path in paths.items(): 90 | logger = Logger(path, title=title, resume=True) 91 | self.loggers.append(logger) 92 | 93 | def plot(self, names=None): 94 | plt.figure() 95 | plt.subplot(121) 96 | legend_text = [] 97 | for logger in self.loggers: 98 | legend_text += plot_overlap(logger, names) 99 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 100 | plt.grid(True) 101 | 102 | if __name__ == '__main__': 103 | # # Example 104 | # logger = Logger('test.txt') 105 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 106 | 107 | # length = 100 108 | # t = np.arange(length) 109 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 110 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 111 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 112 | 113 | # for i in range(0, length): 114 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 115 | # logger.plot() 116 | 117 | # Example: logger monitor 118 | paths = { 119 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 120 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 121 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 122 | } 123 | 124 | field = ['Valid Acc.'] 125 | 126 | monitor = LoggerMonitor(paths) 127 | monitor.plot(names=field) 128 | savefig('test.eps') -------------------------------------------------------------------------------- /watermark_synthesis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "SAVE ALL THE SETTING\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "# watermark synthesis\n", 18 | "import os \n", 19 | "import random\n", 20 | "import shutil\n", 21 | "from PIL import Image\n", 22 | "import numpy as np\n", 23 | "\n", 24 | "def trans_paste(bg_img,fg_img,mask,box=(0,0)):\n", 25 | " fg_img_trans = Image.new(\"RGBA\",bg_img.size)\n", 26 | " fg_img_trans.paste(fg_img,box,mask=mask)\n", 27 | " new_img = Image.alpha_composite(bg_img,fg_img_trans)\n", 28 | " return new_img,fg_img_trans\n", 29 | "\n", 30 | "if 
os.path.isdir('dataset'):\n", 31 | " shutil.rmtree('dataset')\n", 32 | "\n", 33 | "os.mkdir('dataset')\n", 34 | "BASE_IMG_DIR = '/Users/oishii/Downloads/val2014/'\n", 35 | "WATERMARK_DIR = 'logos' #1080 \n", 36 | "images = sorted([os.path.join(BASE_IMG_DIR,x) for x in os.listdir(BASE_IMG_DIR) if '.jpg' in x])\n", 37 | "watermarks = sorted([os.path.join(WATERMARK_DIR,x).replace(' ','_') for x in os.listdir(WATERMARK_DIR) if '.png' in x])\n", 38 | "# rename all the watermark from replace ' ' to '_'\n", 39 | "\n", 40 | "random.shuffle(images)\n", 41 | "random.shuffle(watermarks)\n", 42 | "\n", 43 | "train_images = images[:int(len(images)*0.7)]\n", 44 | "val_images = images[int(len(images)*0.7):int(len(images)*0.8)]\n", 45 | "test_images = images[int(len(images)*0.8):]\n", 46 | "\n", 47 | "train_wms = watermarks[:int(len(watermarks)*0.7)]\n", 48 | "val_wms = watermarks[int(len(watermarks)*0.7):int(len(watermarks)*0.8)]\n", 49 | "test_wms = watermarks[int(len(watermarks)*0.8):]\n", 50 | "\n", 51 | "# save all the settings to file\n", 52 | "names = ['train_images','val_images','test_images','train_wms','val_wms','test_wms']\n", 53 | "lists = [train_images,val_images,test_images,train_wms,val_wms,test_wms]\n", 54 | "dataset = dict(zip(names, lists))\n", 55 | "\n", 56 | "for name,content in dataset.items():\n", 57 | " with open('dataset/%s.txt'%name,'w') as f:\n", 58 | " f.write(\"\\n\".join(content))\n", 59 | "\n", 60 | "print('SAVE ALL THE SETTING')\n", 61 | "\n", 62 | "for name, images in dataset.items():\n", 63 | " if 'images' not in name:\n", 64 | " continue\n", 65 | " # for each setting, synthesis the watermark\n", 66 | " # for each image, add X(X=6) watermark in differnet position, alpha,\n", 67 | " # save the synthesized image, watermark mask, reshaped mask,\n", 68 | " save_path = 'dataset/%s/'%name\n", 69 | " os.makedirs('%s/image'%(save_path))\n", 70 | " os.makedirs('%s/mask'%(save_path))\n", 71 | " os.makedirs('%s/wm'%(save_path))\n", 72 | " \n", 73 | " for img in images:\n", 74 | " im = Image.open(img).convert('RGBA')\n", 75 | " imw,imh = im.size\n", 76 | " \n", 77 | " for wmg in random.choices(dataset[name.replace('images','wms')],k=6):\n", 78 | " wm = Image.open(wmg.replace('_',' ')).convert(\"RGBA\") # RGBA\n", 79 | " # get the mask of wm\n", 80 | " # data agumentation of wm\n", 81 | " wm = wm.rotate(angle=random.randint(0,360),expand=True) # rotate\n", 82 | " \n", 83 | " # make sure the \n", 84 | " imrw = random.randrange(int(0.4*imw),int(0.8*imw))\n", 85 | " imrh = random.randrange(int(0.4*imh),int(0.8*imh))\n", 86 | " wmsize = imrh if imrw > imrh else imrw\n", 87 | " wm = wm.resize((wmsize,wmsize),Image.BILINEAR)\n", 88 | " w,h = wm.size # new size \n", 89 | " \n", 90 | " box_left = random.randint(0,imw-w)\n", 91 | " box_upper = random.randint(0,imh-h)\n", 92 | " wmm = wm.copy()\n", 93 | " wm.putalpha(random.randint(int(255*0.4),int(255*0.8))) # alpha\n", 94 | " \n", 95 | " ims,wmc = trans_paste(im,wm,wmm,(box_left,box_upper))\n", 96 | " \n", 97 | " wmnp = np.array(wmc) # h,w,3\n", 98 | " mask = np.sum(wmnp,axis=2)>0\n", 99 | " mm = Image.fromarray(np.uint8(mask*255),mode='L')\n", 100 | " \n", 101 | " identifier = os.path.basename(img).split('.')[0] +'-'+os.path.basename(wmg).split('.')[0] + '.png'\n", 102 | " # save \n", 103 | " wmc.save('%s/wm/%s'%(save_path,identifier))\n", 104 | " ims.save('%s/image/%s'%(save_path,identifier))\n", 105 | " mm.save('%s/mask/%s'%(save_path,identifier))\n", 106 | " \n", 107 | " " 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | 
"execution_count": null, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [] 130 | } 131 | ], 132 | "metadata": { 133 | "kernelspec": { 134 | "display_name": "Python 3", 135 | "language": "python", 136 | "name": "python3" 137 | }, 138 | "language_info": { 139 | "codemirror_mode": { 140 | "name": "ipython", 141 | "version": 3 142 | }, 143 | "file_extension": ".py", 144 | "mimetype": "text/x-python", 145 | "name": "python", 146 | "nbconvert_exporter": "python", 147 | "pygments_lexer": "ipython3", 148 | "version": "3.7.4" 149 | } 150 | }, 151 | "nbformat": 4, 152 | "nbformat_minor": 2 153 | } 154 | -------------------------------------------------------------------------------- /scripts/utils/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import numpy as np 5 | import scipy.misc 6 | import matplotlib.pyplot as plt 7 | import torch 8 | import torchvision 9 | 10 | from .misc import * 11 | from .imutils import * 12 | 13 | 14 | def color_normalize(x, mean, std): 15 | if x.size(0) == 1: 16 | x = x.repeat(3, x.size(1), x.size(2)) 17 | 18 | for t, m, s in zip(x, mean, std): 19 | t.sub_(m) 20 | return x 21 | 22 | 23 | def flip_back(flip_output, dataset='mpii'): 24 | """ 25 | flip output map 26 | """ 27 | if dataset == 'mpii': 28 | matchedParts = ( 29 | [0,5], [1,4], [2,3], 30 | [10,15], [11,14], [12,13] 31 | ) 32 | else: 33 | print('Not supported dataset: ' + dataset) 34 | 35 | # flip output horizontally 36 | flip_output = fliplr(flip_output.numpy()) 37 | 38 | # Change left-right parts 39 | for pair in matchedParts: 40 | tmp = np.copy(flip_output[:, pair[0], :, :]) 41 | flip_output[:, pair[0], :, :] = flip_output[:, pair[1], :, :] 42 | flip_output[:, pair[1], :, :] = tmp 43 | 44 | return torch.from_numpy(flip_output).float() 45 | 46 | 47 | def shufflelr(x, width, dataset='mpii'): 48 | """ 49 | flip coords 50 | """ 51 | if dataset == 'mpii': 52 | matchedParts = ( 53 | [0,5], [1,4], [2,3], 54 | [10,15], [11,14], [12,13] 55 | ) 56 | else: 57 | print('Not supported dataset: ' + dataset) 58 | 59 | # Flip horizontal 60 | x[:, 0] = width - x[:, 0] 61 | 62 | # Change left-right parts 63 | for pair in matchedParts: 64 | tmp = x[pair[0], :].clone() 65 | x[pair[0], :] = x[pair[1], :] 66 | x[pair[1], :] = tmp 67 | 68 | return x 69 | 70 | 71 | def fliplr(x): 72 | if x.ndim == 3: 73 | x = np.transpose(np.fliplr(np.transpose(x, (0, 2, 1))), (0, 2, 1)) 74 | elif x.ndim == 4: 75 | for i in range(x.shape[0]): 76 | x[i] = np.transpose(np.fliplr(np.transpose(x[i], (0, 2, 1))), (0, 2, 1)) 77 | return x.astype(float) 78 | 79 | 80 | def get_transform(center, scale, res, rot=0): 81 | """ 82 | General image processing functions 83 | """ 84 | # Generate transformation matrix 85 | h = 200 * scale 86 | t = np.zeros((3, 3)) 87 | t[0, 0] = float(res[1]) / h 88 | t[1, 1] = float(res[0]) / h 89 | t[0, 2] = res[1] * (-float(center[0]) / h + .5) 90 | t[1, 2] = res[0] * (-float(center[1]) / h + .5) 91 | t[2, 2] = 1 92 | if not rot == 0: 93 | rot = -rot # To match direction of rotation from cropping 94 | rot_mat = np.zeros((3,3)) 95 | rot_rad = rot * np.pi / 180 96 | sn,cs = np.sin(rot_rad), np.cos(rot_rad) 97 | rot_mat[0,:2] = [cs, -sn] 98 | 
rot_mat[1,:2] = [sn, cs] 99 | rot_mat[2,2] = 1 100 | # Need to rotate around center 101 | t_mat = np.eye(3) 102 | t_mat[0,2] = -res[1]/2 103 | t_mat[1,2] = -res[0]/2 104 | t_inv = t_mat.copy() 105 | t_inv[:2,2] *= -1 106 | t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t))) 107 | return t 108 | 109 | 110 | def transform(pt, center, scale, res, invert=0, rot=0): 111 | # Transform pixel location to different reference 112 | t = get_transform(center, scale, res, rot=rot) 113 | if invert: 114 | t = np.linalg.inv(t) 115 | new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T 116 | new_pt = np.dot(t, new_pt) 117 | return new_pt[:2].astype(int) + 1 118 | 119 | 120 | def transform_preds(coords, center, scale, res): 121 | # size = coords.size() 122 | # coords = coords.view(-1, coords.size(-1)) 123 | # print(coords.size()) 124 | for p in range(coords.size(0)): 125 | coords[p, 0:2] = to_torch(transform(coords[p, 0:2], center, scale, res, 1, 0)) 126 | return coords 127 | 128 | 129 | def crop(img, center, scale, res, rot=0): 130 | img = im_to_numpy(img) 131 | 132 | # Upper left point 133 | ul = np.array(transform([0, 0], center, scale, res, invert=1)) 134 | # Bottom right point 135 | br = np.array(transform(res, center, scale, res, invert=1)) 136 | 137 | # Padding so that when rotated proper amount of context is included 138 | pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) 139 | if not rot == 0: 140 | ul -= pad 141 | br += pad 142 | 143 | new_shape = [br[1] - ul[1], br[0] - ul[0]] 144 | if len(img.shape) > 2: 145 | new_shape += [img.shape[2]] 146 | new_img = np.zeros(new_shape) 147 | 148 | # Range to fill new array 149 | new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0] 150 | new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1] 151 | # Range to sample from original image 152 | old_x = max(0, ul[0]), min(len(img[0]), br[0]) 153 | old_y = max(0, ul[1]), min(len(img), br[1]) 154 | new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]] 155 | 156 | if not rot == 0: 157 | # Remove padding 158 | new_img = scipy.misc.imrotate(new_img, rot) 159 | new_img = new_img[pad:-pad, pad:-pad] 160 | 161 | new_img = im_to_torch(scipy.misc.imresize(new_img, res)) 162 | return new_img 163 | 164 | 165 | def get_right(img,gray=False): 166 | img = im_to_numpy(img) #H*W*C 167 | 168 | new_img = img[:,0:256,:] 169 | 170 | 171 | new_img = im_to_torch(new_img) 172 | if gray == True: 173 | new_img = new_img[1,:,:]; 174 | 175 | return new_img 176 | 177 | class NormalizeInverse(torchvision.transforms.Normalize): 178 | """ 179 | Undoes the normalization and returns the reconstructed images in the input domain. 
180 | """ 181 | 182 | def __init__(self, mean, std): 183 | mean = torch.as_tensor(mean) 184 | std = torch.as_tensor(std) 185 | std_inv = 1 / (std + 1e-7) 186 | mean_inv = -mean * std_inv 187 | super().__init__(mean=mean_inv, std=std_inv) 188 | 189 | def __call__(self, tensor): 190 | return super().__call__(tensor.clone()) 191 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | 2 | import scripts.models as models 3 | 4 | model_names = sorted(name for name in models.__dict__ 5 | if name.islower() and not name.startswith("__") 6 | and callable(models.__dict__[name])) 7 | 8 | class Options(): 9 | """docstring for Options""" 10 | def __init__(self): 11 | pass 12 | 13 | def init(self, parser): 14 | # Model structure 15 | parser.add_argument('--arch', '-a', metavar='ARCH', default='dhn', 16 | choices=model_names, 17 | help='model architecture: ' + 18 | ' | '.join(model_names) + 19 | ' (default: resnet18)') 20 | parser.add_argument('--darch', metavar='ARCH', default='dhn', 21 | choices=model_names, 22 | help='model architecture: ' + 23 | ' | '.join(model_names) + 24 | ' (default: resnet18)') 25 | 26 | parser.add_argument('--machine', '-m', metavar='NACHINE', default='basic') 27 | # Training strategy 28 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', 29 | help='number of data loading workers (default: 4)') 30 | parser.add_argument('--epochs', default=30, type=int, metavar='N', 31 | help='number of total epochs to run') 32 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 33 | help='manual epoch number (useful on restarts)') 34 | parser.add_argument('--train-batch', default=64, type=int, metavar='N', 35 | help='train batchsize') 36 | parser.add_argument('--test-batch', default=6, type=int, metavar='N', 37 | help='test batchsize') 38 | parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,metavar='LR', help='initial learning rate') 39 | parser.add_argument('--dlr', '--dlearning-rate', default=1e-3, type=float, help='initial learning rate') 40 | parser.add_argument('--beta1', default=0.9, type=float, help='initial learning rate') 41 | parser.add_argument('--beta2', default=0.999, type=float, help='initial learning rate') 42 | parser.add_argument('--momentum', default=0, type=float, metavar='M', 43 | help='momentum') 44 | parser.add_argument('--weight-decay', '--wd', default=0, type=float, 45 | metavar='W', help='weight decay (default: 0)') 46 | parser.add_argument('--schedule', type=int, nargs='+', default=[5, 10], 47 | help='Decrease learning rate at these epochs.') 48 | parser.add_argument('--gamma', type=float, default=0.1, 49 | help='LR is multiplied by gamma on schedule.') 50 | # Data processing 51 | parser.add_argument('-f', '--flip', dest='flip', action='store_true', 52 | help='flip the input during validation') 53 | parser.add_argument('--lambdaL1', type=float, default=1, help='the weight of L1.') 54 | parser.add_argument('--alpha', type=float, default=0.5, 55 | help='Groundtruth Gaussian sigma.') 56 | parser.add_argument('--sigma-decay', type=float, default=0, 57 | help='Sigma decay rate for each epoch.') 58 | # Miscs 59 | parser.add_argument('--base-dir', default='/PATH_TO_DATA_FOLDER/', type=str, metavar='PATH') 60 | parser.add_argument('--data', default='', type=str, metavar='PATH', 61 | help='path to save checkpoint (default: checkpoint)') 62 | parser.add_argument('-c', '--checkpoint', 
default='checkpoint', type=str, metavar='PATH', 63 | help='path to save checkpoint (default: checkpoint)') 64 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 65 | help='path to latest checkpoint (default: none)') 66 | parser.add_argument('--finetune', default='', type=str, metavar='PATH', 67 | help='path to latest checkpoint (default: none)') 68 | 69 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 70 | help='evaluate model on validation set') 71 | parser.add_argument('--style-loss', default=0, type=float, 72 | help='preception loss') 73 | parser.add_argument('--ssim-loss', default=0, type=float,help='msssim loss') 74 | parser.add_argument('--att-loss', default=1, type=float,help='msssim loss') 75 | parser.add_argument('--default-loss',default=False,type=bool) 76 | parser.add_argument('--sltype', default='vggx', type=str) 77 | parser.add_argument('-da', '--data-augumentation', default=False, type=bool, 78 | help='preception loss') 79 | parser.add_argument('-d', '--debug', dest='debug', action='store_true', 80 | help='show intermediate results') 81 | parser.add_argument('--input-size', default=256, type=int, metavar='N', 82 | help='train batchsize') 83 | parser.add_argument('--freq', default=-1, type=int, metavar='N', 84 | help='evaluation frequence') 85 | parser.add_argument('--normalized-input', default=False, type=bool, 86 | help='train batchsize') 87 | parser.add_argument('--res', default=False, type=bool,help='residual learning for s2am') 88 | parser.add_argument('--requires-grad', default=False, type=bool, 89 | help='train batchsize') 90 | parser.add_argument('--limited-dataset', default=0, type=int, metavar='N') 91 | parser.add_argument('--gpu',default=True,type=bool) 92 | parser.add_argument('--masked',default=False,type=bool) 93 | parser.add_argument('--gan-norm', default=False,type=bool, help='train batchsize') 94 | parser.add_argument('--hl', default=False,type=bool, help='homogenious leanring') 95 | parser.add_argument('--loss-type', default='l2',type=str, help='train batchsize') 96 | return parser -------------------------------------------------------------------------------- /scripts/models/rasc.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torchvision 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import math 9 | 10 | from scripts.utils.model_init import * 11 | from scripts.models.vgg import Vgg16 12 | from scripts.models.blocks import * 13 | 14 | 15 | class CAWapper(nn.Module): 16 | """docstring for SENet""" 17 | 18 | def __init__(self, channel, type_of_connection=BasicLearningBlock): 19 | super(CAWapper, self).__init__() 20 | self.attention = ContextualAttention(ksize=3, stride=1, rate=2, fuse_k=3, softmax_scale=10, fuse=True, use_cuda=True) 21 | 22 | def forward(self, feature, mask): 23 | _, _, w, _ = feature.size() 24 | _, _, mw, _ = mask.size() 25 | # binaryfiy 26 | # selected the feature from the background as the additional feature to masked splicing feature. 
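        # (avg_pool2d with stride mw//w shrinks the full-resolution mask to the feature-map width, and round() re-binarizes it)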
27 | mask = torch.round(F.avg_pool2d(mask, 2, stride=mw//w)) 28 | 29 | result = self.attention(feature,mask) 30 | 31 | return result 32 | 33 | 34 | class NLWapper(nn.Module): 35 | """docstring for SENet""" 36 | 37 | def __init__(self, channel, type_of_connection=BasicLearningBlock): 38 | super(NLWapper, self).__init__() 39 | self.attention = NONLocalBlock2D(channel) 40 | 41 | def forward(self, feature, mask): 42 | _, _, w, _ = feature.size() 43 | _, _, mw, _ = mask.size() 44 | # binaryfiy 45 | # selected the feature from the background as the additional feature to masked splicing feature. 46 | # mask = torch.round(F.avg_pool2d(mask, 2, stride=mw//w)) 47 | 48 | result = self.attention(feature) 49 | 50 | return result 51 | 52 | class SENet(nn.Module): 53 | """docstring for SENet""" 54 | def __init__(self,channel,type_of_connection=BasicLearningBlock): 55 | super(SENet, self).__init__() 56 | self.attention = SEBlock(channel,16) 57 | 58 | def forward(self,feature,mask): 59 | _,_,w,_ = feature.size() 60 | _,_,mw,_ = mask.size() 61 | # binaryfiy 62 | # selected the feature from the background as the additional feature to masked splicing feature. 63 | mask = torch.round(F.avg_pool2d(mask,2,stride=mw//w)) 64 | 65 | result = self.attention(feature) 66 | 67 | return result 68 | 69 | class CBAMConnect(nn.Module): 70 | def __init__(self,channel): 71 | super(CBAMConnect, self).__init__() 72 | self.attention = CBAM(channel) 73 | 74 | def forward(self,feature,mask): 75 | results = self.attention(feature) 76 | return results 77 | 78 | 79 | 80 | class RASC(nn.Module): 81 | def __init__(self,channel,type_of_connection=BasicLearningBlock): 82 | super(RASC, self).__init__() 83 | self.connection = type_of_connection(channel) 84 | self.background_attention = GlobalAttentionModule(channel,16) 85 | self.mixed_attention = GlobalAttentionModule(channel,16) 86 | self.spliced_attention = GlobalAttentionModule(channel,16) 87 | self.gaussianMask = GaussianSmoothing(1,5,1) 88 | 89 | def forward(self,feature,mask): 90 | _,_,w,_ = feature.size() 91 | _,_,mw,_ = mask.size() 92 | # binaryfiy 93 | # selected the feature from the background as the additional feature to masked splicing feature. 94 | if w != mw: 95 | mask = torch.round(F.avg_pool2d(mask,2,stride=mw//w)) 96 | reverse_mask = -1*(mask-1) 97 | # here we add gaussin filter to mask and reverse_mask for better harimoization of edges. 
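        # (GaussianSmoothing(1, 5, 1) blurs the hard 0/1 mask with a 5x5, sigma=1 kernel; the reflect padding of 2 keeps the
        # spatial size, so the blend of background and spliced features below has soft edges)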
98 | 99 | mask = self.gaussianMask(F.pad(mask,(2,2,2,2),mode='reflect')) 100 | reverse_mask = self.gaussianMask(F.pad(reverse_mask,(2,2,2,2),mode='reflect')) 101 | 102 | 103 | background = self.background_attention(feature) * reverse_mask 104 | selected_feature = self.mixed_attention(feature) 105 | spliced_feature = self.spliced_attention(feature) 106 | spliced = ( self.connection(spliced_feature) + selected_feature ) * mask 107 | return background + spliced 108 | 109 | 110 | class UNO(nn.Module): 111 | def __init__(self,channel): 112 | super(UNO, self).__init__() 113 | 114 | def forward(self,feature,_m): 115 | return feature 116 | 117 | 118 | class URASC(nn.Module): 119 | def __init__(self,channel,type_of_connection=BasicLearningBlock): 120 | super(URASC, self).__init__() 121 | self.connection = type_of_connection(channel) 122 | self.background_attention = GlobalAttentionModule(channel,16) 123 | self.mixed_attention = GlobalAttentionModule(channel,16) 124 | self.spliced_attention = GlobalAttentionModule(channel,16) 125 | self.mask_attention = SpatialAttentionModule(channel,16) 126 | 127 | def forward(self,feature, m=None): 128 | _,_,w,_ = feature.size() 129 | 130 | mask, reverse_mask = self.mask_attention(feature) 131 | 132 | background = self.background_attention(feature) * reverse_mask 133 | selected_feature = self.mixed_attention(feature) 134 | spliced_feature = self.spliced_attention(feature) 135 | spliced = ( self.connection(spliced_feature) + selected_feature ) * mask 136 | return background + spliced 137 | 138 | 139 | class MaskedURASC(nn.Module): 140 | def __init__(self,channel,type_of_connection=BasicLearningBlock): 141 | super(MaskedURASC, self).__init__() 142 | self.connection = type_of_connection(channel) 143 | self.background_attention = GlobalAttentionModule(channel,16) 144 | self.mixed_attention = GlobalAttentionModule(channel,16) 145 | self.spliced_attention = GlobalAttentionModule(channel,16) 146 | self.mask_attention = SpatialAttentionModule(channel,16) 147 | 148 | def forward(self,feature): 149 | _,_,w,_ = feature.size() 150 | 151 | mask, reverse_mask = self.mask_attention(feature) 152 | 153 | background = self.background_attention(feature) * reverse_mask 154 | selected_feature = self.mixed_attention(feature) 155 | spliced_feature = self.spliced_attention(feature) 156 | spliced = ( self.connection(spliced_feature) + selected_feature ) * mask 157 | return background + spliced, mask 158 | 159 | -------------------------------------------------------------------------------- /scripts/models/discriminator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import functools 3 | import math 4 | import torch 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | from torch import nn 8 | from torch import Tensor 9 | from torch.nn import Parameter 10 | from scripts.utils.model_init import * 11 | from torch.optim.optimizer import Optimizer, required 12 | 13 | 14 | __all__ = ['patchgan','sngan','maskedsngan'] 15 | 16 | 17 | class SNCoXvWithActivation(torch.nn.Module): 18 | """ 19 | SN convolution for spetral normalization conv 20 | """ 21 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)): 22 | super(SNCoXvWithActivation, self).__init__() 23 | self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 24 | self.conv2d = 
torch.nn.utils.spectral_norm(self.conv2d) 25 | self.activation = activation 26 | for m in self.modules(): 27 | if isinstance(m, nn.Conv2d): 28 | nn.init.kaiming_normal_(m.weight) 29 | def forward(self, input): 30 | x = self.conv2d(input) 31 | if self.activation is not None: 32 | return self.activation(x) 33 | else: 34 | return x 35 | 36 | def l2normalize(v, eps=1e-12): 37 | return v / (v.norm() + eps) 38 | 39 | 40 | class SpectralNorm(nn.Module): 41 | def __init__(self, module, name='weight', power_iterations=1): 42 | super(SpectralNorm, self).__init__() 43 | self.module = module 44 | self.name = name 45 | self.power_iterations = power_iterations 46 | if not self._made_params(): 47 | self._make_params() 48 | 49 | def _update_u_v(self): 50 | u = getattr(self.module, self.name + "_u") 51 | v = getattr(self.module, self.name + "_v") 52 | w = getattr(self.module, self.name + "_bar") 53 | 54 | height = w.data.shape[0] 55 | for _ in range(self.power_iterations): 56 | v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data)) 57 | u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data)) 58 | 59 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data)) 60 | sigma = u.dot(w.view(height, -1).mv(v)) 61 | setattr(self.module, self.name, w / sigma.expand_as(w)) 62 | 63 | def _made_params(self): 64 | try: 65 | u = getattr(self.module, self.name + "_u") 66 | v = getattr(self.module, self.name + "_v") 67 | w = getattr(self.module, self.name + "_bar") 68 | return True 69 | except AttributeError: 70 | return False 71 | 72 | 73 | def _make_params(self): 74 | w = getattr(self.module, self.name) 75 | 76 | height = w.data.shape[0] 77 | width = w.view(height, -1).data.shape[1] 78 | 79 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 80 | v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) 81 | u.data = l2normalize(u.data) 82 | v.data = l2normalize(v.data) 83 | w_bar = Parameter(w.data) 84 | 85 | del self.module._parameters[self.name] 86 | 87 | self.module.register_parameter(self.name + "_u", u) 88 | self.module.register_parameter(self.name + "_v", v) 89 | self.module.register_parameter(self.name + "_bar", w_bar) 90 | 91 | 92 | def forward(self, *args): 93 | self._update_u_v() 94 | return self.module.forward(*args) 95 | 96 | 97 | def get_pad(in_, ksize, stride, atrous=1): 98 | out_ = np.ceil(float(in_)/stride) 99 | return int(((out_ - 1) * stride + atrous*(ksize-1) + 1 - in_)/2) 100 | 101 | class SNDiscriminator(nn.Module): 102 | def __init__(self,channel=6): 103 | super(SNDiscriminator, self).__init__() 104 | cnum = 32 105 | self.discriminator_net = nn.Sequential( 106 | SNCoXvWithActivation(channel, 2*cnum, 4, 2, padding=get_pad(256, 5, 2)), 107 | SNCoXvWithActivation(2*cnum, 4*cnum, 4, 2, padding=get_pad(128, 5, 2)), 108 | SNCoXvWithActivation(4*cnum, 8*cnum, 4, 2, padding=get_pad(64, 5, 2)), 109 | SNCoXvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(32, 5, 2)), 110 | SNCoXvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(16, 5, 2)), # 8*8*256 111 | # SNConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(8, 5, 2)), # 4*4*256 112 | # SNConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(4, 5, 2)), # 2*2*256 113 | ) 114 | # self.linear = nn.Linear(2*2*256,1) 115 | 116 | def forward(self, img_A, img_B): 117 | # Concatenate image and condition image by channels to produce input 118 | img_input = torch.cat((img_A, img_B), 1) 119 | x = self.discriminator_net(img_input) 120 | # x = x.view((x.size(0),-1)) 121 | # x = 
self.linear(x) 122 | return x 123 | 124 | class Discriminator(nn.Module): 125 | def __init__(self, in_channels=3): 126 | super(Discriminator, self).__init__() 127 | 128 | def discriminator_block(in_filters, out_filters, normalization=True): 129 | """Returns downsampling layers of each discriminator block""" 130 | layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] 131 | if normalization: 132 | layers.append(nn.InstanceNorm2d(out_filters)) 133 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 134 | return layers 135 | 136 | self.model = nn.Sequential( 137 | *discriminator_block(in_channels*2, 64, normalization=False), 138 | *discriminator_block(64, 128), 139 | *discriminator_block(128, 256), 140 | *discriminator_block(256, 512), 141 | nn.ZeroPad2d((1, 0, 1, 0)), 142 | nn.Conv2d(512, 1, 4, padding=1, bias=False) 143 | ) 144 | 145 | def forward(self, img_A, img_B): 146 | # Concatenate image and condition image by channels to produce input 147 | img_input = torch.cat((img_A, img_B), 1) 148 | return self.model(img_input) 149 | 150 | 151 | def patchgan(): 152 | model = Discriminator() 153 | model.apply(weights_init_kaiming) 154 | return model 155 | 156 | def sngan(): 157 | model = SNDiscriminator() 158 | model.apply(weights_init_kaiming) 159 | return model 160 | 161 | def maskedsngan(): 162 | model = SNDiscriminator(channel=7) 163 | model.apply(weights_init_kaiming) 164 | return model -------------------------------------------------------------------------------- /scripts/models/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from scripts.models.blocks import * 6 | from scripts.models.rasc import * 7 | 8 | 9 | class MinimalUnetV2(nn.Module): 10 | """docstring for MinimalUnet""" 11 | def __init__(self, down=None,up=None,submodule=None,attention=None,withoutskip=False,**kwags): 12 | super(MinimalUnetV2, self).__init__() 13 | 14 | self.down = nn.Sequential(*down) 15 | self.up = nn.Sequential(*up) 16 | self.sub = submodule 17 | self.attention = attention 18 | self.withoutskip = withoutskip 19 | self.is_attention = not self.attention == None 20 | self.is_sub = not submodule == None 21 | 22 | def forward(self,x,mask=None): 23 | if self.is_sub: 24 | x_up,_ = self.sub(self.down(x),mask) 25 | else: 26 | x_up = self.down(x) 27 | 28 | if self.withoutskip: #outer or inner. 29 | x_out = self.up(x_up) 30 | else: 31 | if self.is_attention: 32 | x_out = (self.attention(torch.cat([x,self.up(x_up)],1),mask),mask) 33 | else: 34 | x_out = (torch.cat([x,self.up(x_up)],1),mask) 35 | 36 | return x_out 37 | 38 | 39 | class MinimalUnet(nn.Module): 40 | """docstring for MinimalUnet""" 41 | def __init__(self, down=None,up=None,submodule=None,attention=None,withoutskip=False,**kwags): 42 | super(MinimalUnet, self).__init__() 43 | 44 | self.down = nn.Sequential(*down) 45 | self.up = nn.Sequential(*up) 46 | self.sub = submodule 47 | self.attention = attention 48 | self.withoutskip = withoutskip 49 | self.is_attention = not self.attention == None 50 | self.is_sub = not submodule == None 51 | 52 | def forward(self,x,mask=None): 53 | if self.is_sub: 54 | x_up,_ = self.sub(self.down(x),mask) 55 | else: 56 | x_up = self.down(x) 57 | 58 | if self.is_attention: 59 | x = self.attention(x,mask) 60 | 61 | if self.withoutskip: #outer or inner. 
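# The outermost block returns the decoded tensor directly; every other block returns
# (torch.cat([x, up(x_up)], 1), mask) so the parent level can keep threading the mask
# through its attention module.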
62 | x_out = self.up(x_up)
63 | else:
64 | x_out = (torch.cat([x,self.up(x_up)],1),mask)
65 | 
66 | return x_out
67 | 
68 | 
69 | # Defines the submodule with skip connection.
70 | # X -------------------identity---------------------- X
71 | # |-- downsampling -- |submodule| -- upsampling --|
72 | class UnetSkipConnectionBlock(nn.Module):
73 | def __init__(self, outer_nc, inner_nc, input_nc=None,
74 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False,is_attention_layer=False,
75 | attention_model=RASC,basicblock=MinimalUnet,outermostattention=False):
76 | super(UnetSkipConnectionBlock, self).__init__()
77 | self.outermost = outermost
78 | if type(norm_layer) == functools.partial:
79 | use_bias = norm_layer.func == nn.InstanceNorm2d
80 | else:
81 | use_bias = norm_layer == nn.InstanceNorm2d
82 | if input_nc is None:
83 | input_nc = outer_nc
84 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
85 | stride=2, padding=1, bias=use_bias)
86 | downrelu = nn.LeakyReLU(0.2, True)
87 | downnorm = norm_layer(inner_nc)
88 | uprelu = nn.ReLU(True)
89 | upnorm = norm_layer(outer_nc)
90 | 
91 | 
92 | if outermost:
93 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
94 | kernel_size=4, stride=2,
95 | padding=1)
96 | down = [downconv]
97 | up = [uprelu, upconv]
98 | model = basicblock(down,up,submodule,withoutskip=outermost)
99 | elif innermost:
100 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
101 | kernel_size=4, stride=2,
102 | padding=1, bias=use_bias)
103 | down = [downrelu, downconv]
104 | up = [uprelu, upconv, upnorm]
105 | model = basicblock(down,up)
106 | else:
107 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
108 | kernel_size=4, stride=2,
109 | padding=1, bias=use_bias)
110 | down = [downrelu, downconv, downnorm]
111 | up = [uprelu, upconv, upnorm]
112 | 
113 | if is_attention_layer:
114 | if MinimalUnetV2.__qualname__ in basicblock.__qualname__:
115 | attention_model = attention_model(input_nc*2)
116 | else:
117 | attention_model = attention_model(input_nc)
118 | else:
119 | attention_model = None
120 | 
121 | if use_dropout:
122 | model = basicblock(down,up + [nn.Dropout(0.5)],submodule,attention_model,outermostattention=outermostattention)
123 | else:
124 | model = basicblock(down,up,submodule,attention_model,outermostattention=outermostattention)
125 | 
126 | self.model = model
127 | 
128 | 
129 | def forward(self, x,mask=None):
130 | # build the mask for attention use
131 | return self.model(x,mask)
132 | 
133 | class UnetGenerator(nn.Module):
134 | def __init__(self, input_nc, output_nc, num_downs=8, ngf=64,norm_layer=nn.BatchNorm2d, use_dropout=False,
135 | is_attention_layer=False,attention_model=RASC,use_inner_attention=False,basicblock=MinimalUnet):
136 | super(UnetGenerator, self).__init__()
137 | 
138 | # 8 for 256x256
139 | # 9 for 512x512
140 | # construct unet structure
141 | self.need_mask = not input_nc == output_nc
142 | 
143 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True,basicblock=basicblock) # 1
144 | for i in range(num_downs - 5): #3 times
145 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout,is_attention_layer=use_inner_attention,attention_model=attention_model,basicblock=basicblock) # 8,4,2
146 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
norm_layer=norm_layer,is_attention_layer=is_attention_layer,attention_model=attention_model,basicblock=basicblock) #16 147 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer,is_attention_layer=is_attention_layer,attention_model=attention_model,basicblock=basicblock) #32 148 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer,is_attention_layer=is_attention_layer,attention_model=attention_model,basicblock=basicblock, outermostattention=True) #64 149 | unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, basicblock=basicblock, norm_layer=norm_layer) # 128 150 | 151 | self.model = unet_block 152 | 153 | def forward(self, input): 154 | if self.need_mask: 155 | return self.model(input,input[:,3:4,:,:]) 156 | else: 157 | return self.model(input[:,0:3,:,:],input[:,3:4,:,:]) 158 | 159 | 160 | 161 | -------------------------------------------------------------------------------- /scripts/utils/imutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import scipy.misc 7 | 8 | from .misc import * 9 | 10 | def im_to_numpy(img): 11 | img = to_numpy(img) 12 | img = np.transpose(img, (1, 2, 0)) # H*W*C 13 | return img 14 | 15 | def im_to_torch(img): 16 | img = np.transpose(img, (2, 0, 1)) # C*H*W 17 | img = to_torch(img).float() 18 | if img.max() > 1: 19 | img /= 255 20 | return img 21 | 22 | def load_image(img_path): 23 | # H x W x C => C x H x W 24 | return im_to_torch(scipy.misc.imread(img_path, mode='RGB')) 25 | 26 | def imread_all(img_path): 27 | return scipy.misc.imread(img_path, mode='RGB') 28 | 29 | def load_image_gray(img_path): 30 | # H x W x C => C x H x W 31 | x = scipy.misc.imread(img_path, mode='L') 32 | x = x[:,:,np.newaxis] 33 | return im_to_torch(x) 34 | 35 | def resize(img, owidth, oheight): 36 | img = im_to_numpy(img) 37 | 38 | if img.shape[2] == 1: 39 | img = scipy.misc.imresize(img.squeeze(),(oheight,owidth)) 40 | img = img[:,:,np.newaxis] 41 | else: 42 | img = scipy.misc.imresize( 43 | img, 44 | (oheight, owidth) 45 | ) 46 | img = im_to_torch(img) 47 | # print('%f %f' % (img.min(), img.max())) 48 | return img 49 | 50 | # ============================================================================= 51 | # Helpful functions generating groundtruth labelmap 52 | # ============================================================================= 53 | 54 | def gaussian(shape=(7,7),sigma=1): 55 | """ 56 | 2D gaussian mask - should give the same result as MATLAB's 57 | fspecial('gaussian',[shape],[sigma]) 58 | """ 59 | m,n = [(ss-1.)/2. 
for ss in shape] 60 | y,x = np.ogrid[-m:m+1,-n:n+1] 61 | h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) ) 62 | h[ h < np.finfo(h.dtype).eps*h.max() ] = 0 63 | return to_torch(h).float() 64 | 65 | def draw_labelmap(img, pt, sigma, type='Gaussian'): 66 | # Draw a 2D gaussian 67 | # Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py 68 | img = to_numpy(img) 69 | 70 | # Check that any part of the gaussian is in-bounds 71 | ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)] 72 | br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)] 73 | if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or 74 | br[0] < 0 or br[1] < 0): 75 | # If not, just return the image as is 76 | return to_torch(img) 77 | 78 | # Generate gaussian 79 | size = 6 * sigma + 1 80 | x = np.arange(0, size, 1, float) 81 | y = x[:, np.newaxis] 82 | x0 = y0 = size // 2 83 | # The gaussian is not normalized, we want the center value to equal 1 84 | if type == 'Gaussian': 85 | g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) 86 | elif type == 'Cauchy': 87 | g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5) 88 | 89 | 90 | # Usable gaussian range 91 | g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0] 92 | g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1] 93 | # Image range 94 | img_x = max(0, ul[0]), min(br[0], img.shape[1]) 95 | img_y = max(0, ul[1]), min(br[1], img.shape[0]) 96 | 97 | img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]] 98 | return to_torch(img) 99 | 100 | # ============================================================================= 101 | # Helpful display functions 102 | # ============================================================================= 103 | 104 | def gauss(x, a, b, c, d=0): 105 | return a * np.exp(-(x - b)**2 / (2 * c**2)) + d 106 | 107 | def color_heatmap(x): 108 | x = to_numpy(x) 109 | color = np.zeros((x.shape[0],x.shape[1],3)) 110 | color[:,:,0] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3) 111 | color[:,:,1] = gauss(x, 1, .5, .3) 112 | color[:,:,2] = gauss(x, 1, .2, .3) 113 | color[color > 1] = 1 114 | color = (color * 255).astype(np.uint8) 115 | return color 116 | 117 | def imshow(img): 118 | npimg = im_to_numpy(img*255).astype(np.uint8) 119 | plt.imshow(npimg) 120 | plt.axis('off') 121 | 122 | def show_joints(img, pts): 123 | imshow(img) 124 | 125 | for i in range(pts.size(0)): 126 | if pts[i, 2] > 0: 127 | plt.plot(pts[i, 0], pts[i, 1], 'yo') 128 | plt.axis('off') 129 | 130 | def show_sample(inputs, target): 131 | num_sample = inputs.size(0) 132 | num_joints = target.size(1) 133 | height = target.size(2) 134 | width = target.size(3) 135 | 136 | for n in range(num_sample): 137 | inp = resize(inputs[n], width, height) 138 | out = inp 139 | for p in range(num_joints): 140 | tgt = inp*0.5 + color_heatmap(target[n,p,:,:])*0.5 141 | out = torch.cat((out, tgt), 2) 142 | 143 | imshow(out) 144 | plt.show() 145 | 146 | def sample_with_heatmap(inp, out, num_rows=2, parts_to_show=None): 147 | inp = to_numpy(inp * 255) 148 | out = to_numpy(out) 149 | 150 | img = np.zeros((inp.shape[1], inp.shape[2], inp.shape[0])) 151 | for i in range(3): 152 | img[:, :, i] = inp[i, :, :] 153 | 154 | if parts_to_show is None: 155 | parts_to_show = np.arange(out.shape[0]) 156 | 157 | # Generate a single image to display input/output pair 158 | num_cols = int(np.ceil(float(len(parts_to_show)) / num_rows)) 159 | size = img.shape[0] // num_rows 160 | 161 | full_img = np.zeros((img.shape[0], size * (num_cols + num_rows), 3), 
np.uint8) 162 | full_img[:img.shape[0], :img.shape[1]] = img 163 | 164 | inp_small = scipy.misc.imresize(img, [size, size]) 165 | 166 | # Set up heatmap display for each part 167 | for i, part in enumerate(parts_to_show): 168 | part_idx = part 169 | out_resized = scipy.misc.imresize(out[part_idx], [size, size]) 170 | out_resized = out_resized.astype(float)/255 171 | out_img = inp_small.copy() * .3 172 | color_hm = color_heatmap(out_resized) 173 | out_img += color_hm * .7 174 | 175 | col_offset = (i % num_cols + num_rows) * size 176 | row_offset = (i // num_cols) * size 177 | full_img[row_offset:row_offset + size, col_offset:col_offset + size] = out_img 178 | 179 | return full_img 180 | 181 | def batch_with_heatmap(inputs, outputs, mean=torch.Tensor([0.5, 0.5, 0.5]), num_rows=2, parts_to_show=None): 182 | batch_img = [] 183 | for n in range(min(inputs.size(0), 4)): 184 | inp = inputs[n] + mean.view(3, 1, 1).expand_as(inputs[n]) 185 | batch_img.append( 186 | sample_with_heatmap(inp.clamp(0, 1), outputs[n], num_rows=num_rows, parts_to_show=parts_to_show) 187 | ) 188 | return np.concatenate(batch_img) 189 | 190 | 191 | def normalize_batch(batch): 192 | # normalize using imagenet mean and std 193 | mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1) 194 | std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1) 195 | batch = batch/255.0 196 | return (batch - mean) / std 197 | 198 | def show_image_tensor(tensor): 199 | re = [] 200 | for i in range(tensor.size(0)): 201 | inp = tensor[i].data.cpu() #w,h,c 202 | inp = inp.numpy().transpose((1, 2, 0)) 203 | mean = np.array([0.485, 0.456, 0.406]) 204 | std = np.array([0.229, 0.224, 0.225]) 205 | inp = std * inp + mean 206 | inp = np.clip(inp, 0, 1).transpose((2,0,1)) 207 | re.append(torch.from_numpy(inp).unsqueeze(0)) 208 | return torch.cat(re,0) 209 | 210 | 211 | def get_jet(): 212 | colormap_int = np.zeros((256, 3), np.uint8) 213 | 214 | for i in range(0, 256, 1): 215 | colormap_int[i, 0] = np.int_(np.round(cm.jet(i)[0] * 255.0)) 216 | colormap_int[i, 1] = np.int_(np.round(cm.jet(i)[1] * 255.0)) 217 | colormap_int[i, 2] = np.int_(np.round(cm.jet(i)[2] * 255.0)) 218 | 219 | return colormap_int 220 | 221 | def clamp(num, min_value, max_value): 222 | return max(min(num, max_value), min_value) 223 | 224 | def gray2color(gray_array, color_map): 225 | 226 | rows, cols = gray_array.shape 227 | color_array = np.zeros((rows, cols, 3), np.uint8) 228 | 229 | for i in range(0, rows): 230 | for j in range(0, cols): 231 | # log(256,2) = 8 , log(1,2) = 0 * 8 232 | color_array[i, j] = color_map[clamp(int(abs(gray_array[i, j])*10),0,255)] 233 | 234 | return color_array 235 | 236 | class objectview(object): 237 | def __init__(self, *args, **kwargs): 238 | d = dict(*args, **kwargs) 239 | self.__dict__ = d -------------------------------------------------------------------------------- /scripts/models/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import functools 7 | import math 8 | import numbers 9 | 10 | from scripts.utils.model_init import * 11 | from scripts.models.vgg import Vgg16 12 | from torch import nn, cuda 13 | from torch.autograd import Variable 14 | 15 | class BasicLearningBlock(nn.Module): 16 | """docstring for BasicLearningBlock""" 17 | def __init__(self,channel): 18 | super(BasicLearningBlock, self).__init__() 19 | self.rconv1 = 
nn.Conv2d(channel,channel*2,3,padding=1,bias=False) 20 | self.rbn1 = nn.BatchNorm2d(channel*2) 21 | self.rconv2 = nn.Conv2d(channel*2,channel,3,padding=1,bias=False) 22 | self.rbn2 = nn.BatchNorm2d(channel) 23 | 24 | def forward(self,feature): 25 | return F.elu(self.rbn2(self.rconv2(F.elu(self.rbn1(self.rconv1(feature)))))) 26 | 27 | 28 | 29 | # From https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/3 30 | class GaussianSmoothing(nn.Module): 31 | """ 32 | Apply gaussian smoothing on a 33 | 1d, 2d or 3d tensor. Filtering is performed seperately for each channel 34 | in the input using a depthwise convolution. 35 | Arguments: 36 | channels (int, sequence): Number of channels of the input tensors. Output will 37 | have this number of channels as well. 38 | kernel_size (int, sequence): Size of the gaussian kernel. 39 | sigma (float, sequence): Standard deviation of the gaussian kernel. 40 | dim (int, optional): The number of dimensions of the data. 41 | Default value is 2 (spatial). 42 | """ 43 | def __init__(self, channels, kernel_size, sigma, dim=2): 44 | super(GaussianSmoothing, self).__init__() 45 | if isinstance(kernel_size, numbers.Number): 46 | kernel_size = [kernel_size] * dim 47 | if isinstance(sigma, numbers.Number): 48 | sigma = [sigma] * dim 49 | 50 | # The gaussian kernel is the product of the 51 | # gaussian function of each dimension. 52 | kernel = 1 53 | meshgrids = torch.meshgrid( 54 | [ 55 | torch.arange(size, dtype=torch.float32) 56 | for size in kernel_size 57 | ] 58 | ) 59 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 60 | mean = (size - 1) / 2 61 | kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ 62 | torch.exp(-((mgrid - mean) / (2 * std)) ** 2) 63 | 64 | # Make sure sum of values in gaussian kernel equals 1. 65 | kernel = kernel / torch.sum(kernel) 66 | 67 | # Reshape to depthwise convolutional weight 68 | kernel = kernel.view(1, 1, *kernel.size()) 69 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 70 | 71 | self.register_buffer('weight', kernel) 72 | self.groups = channels 73 | 74 | if dim == 1: 75 | self.conv = F.conv1d 76 | elif dim == 2: 77 | self.conv = F.conv2d 78 | elif dim == 3: 79 | self.conv = F.conv3d 80 | else: 81 | raise RuntimeError( 82 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) 83 | ) 84 | 85 | def forward(self, input): 86 | """ 87 | Apply gaussian filter to input. 88 | Arguments: 89 | input (torch.Tensor): Input to apply gaussian filter on. 90 | Returns: 91 | filtered (torch.Tensor): Filtered output. 
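Note:
    No padding is applied inside the module, so the output shrinks by
    kernel_size - 1 pixels per spatial dimension; callers such as RASC
    reflect-pad the mask by 2 on each side before applying the 5x5 kernel.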
92 | """ 93 | return self.conv(input, weight=self.weight, groups=self.groups) 94 | 95 | class ChannelPool(nn.Module): 96 | def __init__(self,types): 97 | super(ChannelPool, self).__init__() 98 | if types == 'avg': 99 | self.poolingx = nn.AdaptiveAvgPool1d(1) 100 | elif types == 'max': 101 | self.poolingx = nn.AdaptiveMaxPool1d(1) 102 | else: 103 | raise 'inner error' 104 | 105 | def forward(self, input): 106 | n, c, w, h = input.size() 107 | input = input.view(n,c,w*h).permute(0,2,1) 108 | pooled = self.poolingx(input)# b,w*h,c -> b,w*h,1 109 | _, _, c = pooled.size() 110 | return pooled.view(n,c,w,h) 111 | 112 | 113 | 114 | class SEBlock(nn.Module): 115 | """docstring for SEBlock""" 116 | def __init__(self, channel,reducation=16): 117 | super(SEBlock, self).__init__() 118 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 119 | self.fc = nn.Sequential( 120 | nn.Linear(channel,channel//reducation), 121 | nn.ReLU(inplace=True), 122 | nn.Linear(channel//reducation,channel), 123 | nn.Sigmoid()) 124 | 125 | def forward(self,x): 126 | b,c,w,h = x.size() 127 | y1 = self.avg_pool(x).view(b,c) 128 | y = self.fc(y1).view(b,c,1,1) 129 | return x*y 130 | 131 | 132 | 133 | class GlobalAttentionModule(nn.Module): 134 | """docstring for GlobalAttentionModule""" 135 | def __init__(self, channel,reducation=16): 136 | super(GlobalAttentionModule, self).__init__() 137 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 138 | self.max_pool = nn.AdaptiveMaxPool2d(1) 139 | self.fc = nn.Sequential( 140 | nn.Linear(channel*2,channel//reducation), 141 | nn.ReLU(inplace=True), 142 | nn.Linear(channel//reducation,channel), 143 | nn.Sigmoid()) 144 | 145 | def forward(self,x): 146 | b,c,w,h = x.size() 147 | y1 = self.avg_pool(x).view(b,c) 148 | y2 = self.max_pool(x).view(b,c) 149 | y = self.fc(torch.cat([y1,y2],1)).view(b,c,1,1) 150 | return x*y 151 | 152 | class SpatialAttentionModule(nn.Module): 153 | """docstring for SpatialAttentionModule""" 154 | def __init__(self, channel,reducation=16): 155 | super(SpatialAttentionModule, self).__init__() 156 | self.avg_pool = ChannelPool('avg') 157 | self.max_pool = ChannelPool('max') 158 | self.fc = nn.Sequential( 159 | nn.Conv2d(2,reducation,7,stride=1,padding=3), 160 | nn.ReLU(inplace=True), 161 | nn.Conv2d(reducation,1,7,stride=1,padding=3), 162 | nn.Sigmoid()) 163 | 164 | def forward(self,x): 165 | b,c,w,h = x.size() 166 | y1 = self.avg_pool(x) 167 | y2 = self.max_pool(x) 168 | y = self.fc(torch.cat([y1,y2],1)) 169 | yr = 1-y 170 | return y,yr 171 | 172 | 173 | 174 | class GlobalAttentionModuleJustSigmoid(nn.Module): 175 | """docstring for GlobalAttentionModule""" 176 | def __init__(self, channel,reducation=16): 177 | super(GlobalAttentionModuleJustSigmoid, self).__init__() 178 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 179 | self.max_pool = nn.AdaptiveMaxPool2d(1) 180 | self.fc = nn.Sequential( 181 | nn.Linear(channel*2,channel//reducation), 182 | nn.ReLU(inplace=True), 183 | nn.Linear(channel//reducation,channel), 184 | nn.Sigmoid()) 185 | 186 | def forward(self,x): 187 | b,c,w,h = x.size() 188 | y1 = self.avg_pool(x).view(b,c) 189 | y2 = self.max_pool(x).view(b,c) 190 | y = self.fc(torch.cat([y1,y2],1)).view(b,c,1,1) 191 | return y 192 | 193 | 194 | 195 | class BasicBlock(nn.Module): 196 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False): 197 | super(BasicBlock, self).__init__() 198 | self.out_channels = out_planes 199 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 
padding=padding, dilation=dilation, groups=groups, bias=bias) 200 | self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None 201 | self.relu = nn.ReLU() if relu else None 202 | 203 | def forward(self, x): 204 | x = self.conv(x) 205 | if self.bn is not None: 206 | x = self.bn(x) 207 | if self.relu is not None: 208 | x = self.relu(x) 209 | return x 210 | 211 | class Flatten(nn.Module): 212 | def forward(self, x): 213 | return x.view(x.size(0), -1) 214 | 215 | class ChannelGate(nn.Module): 216 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): 217 | super(ChannelGate, self).__init__() 218 | self.gate_channels = gate_channels 219 | self.mlp = nn.Sequential( 220 | Flatten(), 221 | nn.Linear(gate_channels, gate_channels // reduction_ratio), 222 | nn.ReLU(), 223 | nn.Linear(gate_channels // reduction_ratio, gate_channels) 224 | ) 225 | self.pool_types = pool_types 226 | def forward(self, x): 227 | channel_att_sum = None 228 | for pool_type in self.pool_types: 229 | if pool_type=='avg': 230 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 231 | channel_att_raw = self.mlp( avg_pool ) 232 | elif pool_type=='max': 233 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 234 | channel_att_raw = self.mlp( max_pool ) 235 | elif pool_type=='lp': 236 | lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 237 | channel_att_raw = self.mlp( lp_pool ) 238 | elif pool_type=='lse': 239 | # LSE pool only 240 | lse_pool = logsumexp_2d(x) 241 | channel_att_raw = self.mlp( lse_pool ) 242 | 243 | if channel_att_sum is None: 244 | channel_att_sum = channel_att_raw 245 | else: 246 | channel_att_sum = channel_att_sum + channel_att_raw 247 | 248 | scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) 249 | return x * scale 250 | 251 | def logsumexp_2d(tensor): 252 | tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) 253 | s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) 254 | outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() 255 | return outputs 256 | 257 | class ChannelPoolX(nn.Module): 258 | def forward(self, x): 259 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) 260 | 261 | class SpatialGate(nn.Module): 262 | def __init__(self): 263 | super(SpatialGate, self).__init__() 264 | kernel_size = 7 265 | self.compress = ChannelPoolX() 266 | self.spatial = BasicBlock(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) 267 | def forward(self, x): 268 | x_compress = self.compress(x) 269 | x_out = self.spatial(x_compress) 270 | scale = F.sigmoid(x_out) # broadcasting 271 | return x * scale 272 | 273 | class CBAM(nn.Module): 274 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): 275 | super(CBAM, self).__init__() 276 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) 277 | self.no_spatial=no_spatial 278 | if not no_spatial: 279 | self.SpatialGate = SpatialGate() 280 | def forward(self, x): 281 | x_out = self.ChannelGate(x) 282 | if not self.no_spatial: 283 | x_out = self.SpatialGate(x_out) 284 | return x_out 285 | 286 | 287 | -------------------------------------------------------------------------------- /scripts/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import 
torch.nn.functional as F 4 | from scripts.models.vgg import Vgg19 5 | from torchvision import models 6 | from scripts.utils.misc import resize_to_match 7 | # from pytorch_msssim import SSIM, MS_SSIM 8 | import pytorch_ssim 9 | 10 | class WeightedBCE(nn.Module): 11 | def __init__(self): 12 | super(WeightedBCE, self).__init__() 13 | 14 | def forward(self, pred, gt): 15 | eposion = 1e-10 16 | sigmoid_pred = torch.sigmoid(pred) 17 | count_pos = torch.sum(gt)*1.0+eposion 18 | count_neg = torch.sum(1.-gt)*1.0 19 | beta = count_neg/count_pos 20 | beta_back = count_pos / (count_pos + count_neg) 21 | 22 | bce1 = nn.BCEWithLogitsLoss(pos_weight=beta) 23 | loss = beta_back*bce1(pred, gt) 24 | 25 | return loss 26 | 27 | 28 | def l1_relative(reconstructed, real, mask): 29 | batch = real.size(0) 30 | area = torch.sum(mask.view(batch,-1),dim=1) 31 | reconstructed = reconstructed * mask 32 | real = real * mask 33 | 34 | loss_l1 = torch.abs(reconstructed - real).view(batch, -1) 35 | loss_l1 = torch.sum(loss_l1, dim=1) / area 36 | loss_l1 = torch.sum(loss_l1) / batch 37 | return loss_l1 38 | 39 | 40 | def is_dic(x): 41 | return type(x) == type([]) 42 | 43 | class Losses(nn.Module): 44 | def __init__(self, argx, device): 45 | super(Losses, self).__init__() 46 | self.args = argx 47 | 48 | if self.args.loss_type == 'l1bl2': 49 | self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), nn.BCELoss(), nn.MSELoss() 50 | elif self.args.loss_type == 'l1wbl2': 51 | self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), WeightedBCE(), nn.MSELoss() 52 | elif self.args.loss_type == 'l2wbl2': 53 | self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), WeightedBCE(), nn.MSELoss() 54 | elif self.args.loss_type == 'l2xbl2': 55 | self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCEWithLogitsLoss(), nn.MSELoss() 56 | else: # l2bl2 57 | self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCELoss(), nn.MSELoss() 58 | 59 | if self.args.style_loss > 0: 60 | self.vggloss = VGGLoss(self.args.sltype).to(device) 61 | 62 | if self.args.ssim_loss > 0: 63 | self.ssimloss = pytorch_ssim.SSIM().to(device) 64 | 65 | self.outputLoss = self.outputLoss.to(device) 66 | self.attLoss = self.attLoss.to(device) 67 | self.wrloss = self.wrloss.to(device) 68 | 69 | 70 | def forward(self,imgx,target,attx,mask,wmx,wm): 71 | pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss = 0,0,0,0,0 72 | 73 | if is_dic(imgx): 74 | 75 | if self.args.masked: 76 | # calculate the overall loss and side output 77 | pixel_loss = self.outputLoss(imgx[0],target) + sum([self.outputLoss(im,resize_to_match(mask,im)*resize_to_match(target,im)) for im in imgx[1:]]) 78 | else: 79 | pixel_loss = sum([self.outputLoss(im,resize_to_match(target,im)) for im in imgx]) 80 | 81 | if self.args.style_loss > 0: 82 | vgg_loss = sum([self.vggloss(im,resize_to_match(target,im),resize_to_match(mask,im)) for im in imgx]) 83 | 84 | if self.args.ssim_loss > 0: 85 | ssim_loss = sum([ 1 - self.ssimloss(im,resize_to_match(target,im)) for im in imgx]) 86 | else: 87 | 88 | if self.args.masked: 89 | pixel_loss = self.outputLoss(imgx,mask*target) 90 | else: 91 | pixel_loss = self.outputLoss(imgx,target) 92 | 93 | if self.args.style_loss > 0: 94 | vgg_loss = self.vggloss(imgx,target,mask) 95 | 96 | if self.args.ssim_loss > 0: 97 | ssim_loss = 1 - self.ssimloss(imgx,target) 98 | 99 | if is_dic(attx): 100 | att_loss = sum([self.attLoss(at,resize_to_match(mask,at)) for at in attx]) 101 | else: 102 | att_loss = self.attLoss(attx, mask) 103 | 104 | if is_dic(wmx): 105 | wm_loss = 
sum([self.wrloss(w,resize_to_match(wm,w)) for w in wmx]) 106 | else: 107 | if self.args.masked: 108 | wm_loss = self.wrloss(wmx,mask*wm) 109 | else: 110 | wm_loss = self.wrloss(wmx, wm) 111 | 112 | return pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss 113 | 114 | 115 | 116 | def gram_matrix(feat): 117 | # https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/utils.py 118 | (b, ch, h, w) = feat.size() 119 | feat = feat.view(b, ch, h * w) 120 | feat_t = feat.transpose(1, 2) 121 | gram = torch.bmm(feat, feat_t) / (ch * h * w) 122 | return gram 123 | 124 | class MeanShift(nn.Conv2d): 125 | def __init__(self, data_mean, data_std, data_range=1, norm=True): 126 | """norm (bool): normalize/denormalize the stats""" 127 | c = len(data_mean) 128 | super(MeanShift, self).__init__(c, c, kernel_size=1) 129 | std = torch.Tensor(data_std) 130 | self.weight.data = torch.eye(c).view(c, c, 1, 1) 131 | if norm: 132 | self.weight.data.div_(std.view(c, 1, 1, 1)) 133 | self.bias.data = -1 * data_range * torch.Tensor(data_mean) 134 | self.bias.data.div_(std) 135 | else: 136 | self.weight.data.mul_(std.view(c, 1, 1, 1)) 137 | self.bias.data = data_range * torch.Tensor(data_mean) 138 | self.requires_grad = False 139 | 140 | 141 | 142 | def VGGLoss(losstype): 143 | if losstype == 'vgg': 144 | return VGGLossA() 145 | elif losstype == 'vggx': 146 | return VGGLossX(mask=False) 147 | elif losstype == 'mvggx': 148 | return VGGLossX(mask=True) 149 | elif losstype == 'rvggx': 150 | return VGGLossX(mask=True,relative=True) 151 | else: 152 | raise Exception("error in %s"%losstype) 153 | 154 | 155 | 156 | class VGGLossA(nn.Module): 157 | def __init__(self, vgg=None, weights=None, indices=None, normalize=True): 158 | super(VGGLossA, self).__init__() 159 | if vgg is None: 160 | self.vgg = Vgg19().cuda() 161 | else: 162 | self.vgg = vgg 163 | self.criterion = nn.L1Loss() 164 | self.weights = weights or [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5] 165 | self.indices = indices or [2, 7, 12, 21, 30] 166 | if normalize: 167 | self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda() 168 | else: 169 | self.normalize = None 170 | 171 | def forward(self, x, y): 172 | if self.normalize is not None: 173 | x = self.normalize(x) 174 | y = self.normalize(y) 175 | x_vgg, y_vgg = self.vgg(x, self.indices), self.vgg(y, self.indices) 176 | loss = 0 177 | for i in range(len(x_vgg)): 178 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) 179 | return loss 180 | 181 | 182 | class VGG16FeatureExtractor(nn.Module): 183 | def __init__(self): 184 | super().__init__() 185 | vgg16 = models.vgg16(pretrained=True) 186 | self.enc_1 = nn.Sequential(*vgg16.features[:5]) 187 | self.enc_2 = nn.Sequential(*vgg16.features[5:10]) 188 | self.enc_3 = nn.Sequential(*vgg16.features[10:17]) 189 | 190 | # fix the encoder 191 | for i in range(3): 192 | for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters(): 193 | param.requires_grad = False 194 | 195 | def forward(self, image): 196 | results = [image] 197 | for i in range(3): 198 | func = getattr(self, 'enc_{:d}'.format(i + 1)) 199 | results.append(func(results[-1])) 200 | return results[1:] 201 | 202 | class VGGLossX(nn.Module): 203 | def __init__(self, normalize=True, mask=False, relative=False): 204 | super(VGGLossX, self).__init__() 205 | 206 | self.vgg = VGG16FeatureExtractor().cuda() 207 | self.criterion = nn.L1Loss().cuda() if not relative else l1_relative 208 | self.use_mask= mask 209 | self.relative = relative 210 | 211 | 
if normalize: 212 | self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda() 213 | else: 214 | self.normalize = None 215 | 216 | def forward(self, x, y, Xmask=None): 217 | if not self.use_mask: 218 | mask = torch.ones_like(x)[:,0:1,:,:] 219 | else: 220 | mask = Xmask 221 | 222 | if self.normalize is not None: 223 | x = self.normalize(x) 224 | y = self.normalize(y) 225 | 226 | x_vgg = self.vgg(x) 227 | y_vgg = self.vgg(y) 228 | 229 | loss = 0 230 | for i in range(3): 231 | if self.relative: 232 | loss += self.criterion(x_vgg[i],y_vgg[i].detach(),resize_to_match(mask,x_vgg[i])) 233 | else: 234 | loss += self.criterion(resize_to_match(mask,x_vgg[i])*x_vgg[i],resize_to_match(mask,y_vgg[i])*y_vgg[i].detach()) 235 | 236 | return loss 237 | 238 | 239 | class GANLosses(object): 240 | """docstring for Loss""" 241 | def __init__(self, gantype): 242 | super(GANLosses, self).__init__() 243 | self.generator_loss = gen_gan(gantype) 244 | self.discriminator_loss = dis_gan(gantype) 245 | self.gantype = gantype 246 | 247 | def g_loss(self,dis_fake): 248 | if 'hinge' in self.gantype: 249 | return gen_hinge(dis_fake) 250 | else: 251 | return self.generator_loss(dis_fake) 252 | 253 | def d_loss(self,dis_fake,dis_real): 254 | if 'hinge' in self.gantype: 255 | return dis_hinge(dis_fake,dis_real) 256 | else: 257 | return self.discriminator_loss(dis_fake,dis_real) 258 | 259 | 260 | class gen_gan(nn.Module): 261 | def __init__(self,gantype): 262 | super(gen_gan,self).__init__() 263 | if gantype == 'lsgan': 264 | self.criterion = nn.MSELoss() 265 | elif gantype == 'naive': 266 | self.criterion = nn.BCEWithLogitsLoss() 267 | else: 268 | raise Exception("error gan type") 269 | 270 | def forward(self,dis_fake): 271 | return self.criterion(dis_fake, torch.ones_like(dis_fake)) 272 | 273 | class dis_gan(nn.Module): 274 | def __init__(self,gantype): 275 | super(dis_gan,self).__init__() 276 | if gantype == 'lsgan': 277 | self.criterion = nn.MSELoss() 278 | elif gantype == 'naive': 279 | self.criterion = nn.BCEWithLogitsLoss() 280 | else: 281 | raise Exception("error gan type") 282 | 283 | def forward(self,dis_fake,dis_real): 284 | loss_fake = self.criterion(dis_fake, torch.zeros_like(dis_fake)) 285 | loss_real = self.criterion(dis_real, torch.ones_like(dis_real)) 286 | return loss_fake, loss_real 287 | 288 | # def gen_gan(dis_fake): 289 | # # fake -> 1 290 | # return F.binary_cross_entropy_with_logits(dis_fake,torch.ones_like(dis_fake)) 291 | 292 | # def dis_gan(dis_fake,dis_real): 293 | # # fake -> 0 , real ->1 294 | # loss_fake = F.binary_cross_entropy_with_logits(dis_fake, torch.zeros_like(dis_real)) 295 | # loss_real = F.binary_cross_entropy_with_logits(dis_real, torch.ones_like(dis_fake)) 296 | # return loss_fake,loss_real 297 | 298 | # def gen_lsgan(dis_fake): 299 | # loss = F.mse_loss(dis_fake,torch.ones_like(dis_fake)) # 300 | # return loss 301 | 302 | # def dis_lsgan(dis_fake, dis_real): 303 | # loss_fake = F.mse_loss(dis_fake, torch.zeros_like(dis_real)) 304 | # loss_real = F.mse_loss(dis_real, torch.ones_like(dis_real)) 305 | # return loss_fake,loss_real 306 | 307 | def gen_hinge(dis_fake, dis_real=None): 308 | return -torch.mean(dis_fake) 309 | 310 | def dis_hinge(dis_fake, dis_real): 311 | loss_fake = torch.mean(torch.relu(1. + dis_fake)) 312 | loss_real = torch.mean(torch.relu(1. 
- dis_real)) 313 | return loss_fake,loss_real 314 | 315 | -------------------------------------------------------------------------------- /scripts/machines/S2AM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.backends.cudnn as cudnn 4 | from progress.bar import Bar 5 | import json 6 | import numpy as np 7 | from tensorboardX import SummaryWriter 8 | from scripts.utils.evaluation import accuracy, AverageMeter, final_preds 9 | from scripts.utils.osutils import mkdir_p, isfile, isdir, join 10 | from scripts.utils.parallel import DataParallelModel, DataParallelCriterion 11 | import pytorch_ssim as pytorch_ssim 12 | import torch.optim 13 | import sys,shutil,os 14 | import time 15 | import scripts.models as archs 16 | from math import log10 17 | from torch.autograd import Variable 18 | from scripts.utils.losses import VGGLoss 19 | from scripts.utils.imutils import im_to_numpy 20 | 21 | import skimage.io 22 | from skimage.measure import compare_psnr,compare_ssim 23 | 24 | 25 | class S2AM(object): 26 | def __init__(self, datasets =(None,None), models = None, args = None, **kwargs): 27 | super(S2AM, self).__init__() 28 | 29 | self.args = args 30 | 31 | # create model 32 | print("==> creating model ") 33 | self.model = archs.__dict__[self.args.arch]() 34 | print("==> creating model [Finish]") 35 | 36 | self.train_loader, self.val_loader = datasets 37 | self.loss = torch.nn.MSELoss() 38 | 39 | self.title = '_'+args.machine + '_' + args.data + '_' + args.arch 40 | self.args.checkpoint = args.checkpoint + self.title 41 | self.device = torch.device('cuda') 42 | # create checkpoint dir 43 | if not isdir(self.args.checkpoint): 44 | mkdir_p(self.args.checkpoint) 45 | 46 | self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), 47 | lr=args.lr, 48 | betas=(args.beta1,args.beta2), 49 | weight_decay=args.weight_decay) 50 | 51 | if not self.args.evaluate: 52 | self.writer = SummaryWriter(self.args.checkpoint+'/'+'ckpt') 53 | 54 | self.best_acc = 0 55 | self.is_best = False 56 | self.current_epoch = 0 57 | self.hl = 1 58 | self.metric = -100000 59 | self.count_gpu = len(range(torch.cuda.device_count())) 60 | 61 | if self.args.style_loss > 0: 62 | # init perception loss 63 | self.vggloss = VGGLoss(self.args.sltype).to(self.device) 64 | 65 | if self.count_gpu > 1 : # multiple 66 | # self.model = DataParallelModel(self.model, device_ids=range(torch.cuda.device_count())) 67 | # self.loss = DataParallelCriterion(self.loss, device_ids=range(torch.cuda.device_count())) 68 | self.model = nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count())) 69 | 70 | self.model.to(self.device) 71 | self.loss.to(self.device) 72 | 73 | print('==> Total params: %.2fM' % (sum(p.numel() for p in self.model.parameters())/1000000.0)) 74 | print('==> Total devices: %d' % (torch.cuda.device_count())) 75 | print('==> Current Checkpoint: %s' % (self.args.checkpoint)) 76 | 77 | 78 | if self.args.resume != '': 79 | self.resume(self.args.resume) 80 | 81 | 82 | def train(self,epoch): 83 | batch_time = AverageMeter() 84 | data_time = AverageMeter() 85 | losses = AverageMeter() 86 | lossvgg = AverageMeter() 87 | 88 | # switch to train mode 89 | self.model.train() 90 | end = time.time() 91 | 92 | bar = Bar('Processing', max=len(self.train_loader)*self.hl) 93 | for _ in range(self.hl): 94 | for i, batches in enumerate(self.train_loader): 95 | # measure data loading time 96 | inputs = 
batches['image'].to(self.device) 97 | target = batches['target'].to(self.device) 98 | mask =batches['mask'].to(self.device) 99 | current_index = len(self.train_loader) * epoch + i 100 | 101 | feeded = torch.cat([inputs,mask],dim=1) 102 | feeded = feeded.to(self.device) 103 | 104 | output = self.model(feeded) 105 | 106 | if self.args.res: 107 | output = output + inputs 108 | 109 | L2_loss = self.loss(output,target) 110 | 111 | if self.args.style_loss > 0: 112 | vgg_loss = self.vggloss(output,target,mask) 113 | else: 114 | vgg_loss = 0 115 | 116 | total_loss = L2_loss + self.args.style_loss * vgg_loss 117 | 118 | # compute gradient and do SGD step 119 | self.optimizer.zero_grad() 120 | total_loss.backward() 121 | self.optimizer.step() 122 | 123 | # measure accuracy and record loss 124 | losses.update(L2_loss.item(), inputs.size(0)) 125 | 126 | if self.args.style_loss > 0 : 127 | lossvgg.update(vgg_loss.item(), inputs.size(0)) 128 | 129 | # measure elapsed time 130 | batch_time.update(time.time() - end) 131 | end = time.time() 132 | 133 | # plot progress 134 | suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss VGG: {loss_vgg:.4f}'.format( 135 | batch=i + 1, 136 | size=len(self.train_loader), 137 | data=data_time.val, 138 | bt=batch_time.val, 139 | total=bar.elapsed_td, 140 | eta=bar.eta_td, 141 | loss_label=losses.avg, 142 | loss_vgg=lossvgg.avg 143 | ) 144 | 145 | if current_index % 1000 == 0: 146 | print(suffix) 147 | 148 | if self.args.freq > 0 and current_index % self.args.freq == 0: 149 | self.validate(current_index) 150 | self.flush() 151 | self.save_checkpoint() 152 | 153 | self.record('train/loss_L2', losses.avg, current_index) 154 | 155 | 156 | def test(self, ): 157 | 158 | # switch to evaluate mode 159 | self.model.eval() 160 | 161 | ssimes = AverageMeter() 162 | psnres = AverageMeter() 163 | 164 | with torch.no_grad(): 165 | for i, batches in enumerate(self.val_loader): 166 | 167 | inputs = batches['image'].to(self.device) 168 | target = batches['target'].to(self.device) 169 | mask =batches['mask'].to(self.device) 170 | 171 | feeded = torch.cat([inputs,mask],dim=1) 172 | feeded = feeded.to(self.device) 173 | 174 | output = self.model(feeded) 175 | 176 | if self.args.res: 177 | output = output + inputs 178 | 179 | # recover the image to 255 180 | output = im_to_numpy(torch.clamp(output[0]*255,min=0.0,max=255.0)).astype(np.uint8) 181 | target = im_to_numpy(torch.clamp(target[0]*255,min=0.0,max=255.0)).astype(np.uint8) 182 | 183 | skimage.io.imsave('%s/%s'%(self.args.checkpoint,batches['name'][0]), output) 184 | 185 | psnr = compare_psnr(target,output) 186 | ssim = compare_ssim(target,output,multichannel=True) 187 | 188 | psnres.update(psnr, inputs.size(0)) 189 | ssimes.update(ssim, inputs.size(0)) 190 | 191 | print("%s:PSNR:%s,SSIM:%s"%(self.args.checkpoint,psnres.avg,ssimes.avg)) 192 | print("DONE.\n") 193 | 194 | 195 | def validate(self, epoch): 196 | batch_time = AverageMeter() 197 | data_time = AverageMeter() 198 | losses = AverageMeter() 199 | ssimes = AverageMeter() 200 | psnres = AverageMeter() 201 | # switch to evaluate mode 202 | self.model.eval() 203 | 204 | end = time.time() 205 | with torch.no_grad(): 206 | for i, batches in enumerate(self.val_loader): 207 | 208 | inputs = batches['image'].to(self.device) 209 | target = batches['target'].to(self.device) 210 | mask =batches['mask'].to(self.device) 211 | 212 | feeded = torch.cat([inputs,mask],dim=1) 213 | feeded = feeded.to(self.device) 214 | 
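# Validation builds the network input exactly as in training: the binary mask is
# concatenated to the RGB image as a fourth channel.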
215 | output = self.model(feeded)
216 | 
217 | if self.args.res:
218 | output = output + inputs
219 | 
220 | L2_loss = self.loss(output, target)
221 | 
222 | psnr = 10 * log10(1 / L2_loss.item())
223 | ssim = pytorch_ssim.ssim(output, target)
224 | 
225 | losses.update(L2_loss.item(), inputs.size(0))
226 | psnres.update(psnr, inputs.size(0))
227 | ssimes.update(ssim.item(), inputs.size(0))
228 | 
229 | # measure elapsed time
230 | batch_time.update(time.time() - end)
231 | end = time.time()
232 | 
233 | print("Epochs:%s,Losses:%.3f,PSNR:%.3f,SSIM:%.3f"%(epoch+1, losses.avg,psnres.avg,ssimes.avg))
234 | self.record('val/loss_L2', losses.avg, epoch)
235 | self.record('val/PSNR', psnres.avg, epoch)
236 | self.record('val/SSIM', ssimes.avg, epoch)
237 | 
238 | self.metric = psnres.avg
239 | 
240 | def resume(self,resume_path):
241 | if isfile(resume_path):
242 | print("=> loading checkpoint '{}'".format(resume_path))
243 | current_checkpoint = torch.load(resume_path)
244 | if isinstance(current_checkpoint['state_dict'], torch.nn.DataParallel):
245 | current_checkpoint['state_dict'] = current_checkpoint['state_dict'].module
246 | 
247 | if isinstance(current_checkpoint['optimizer'], torch.nn.DataParallel):
248 | current_checkpoint['optimizer'] = current_checkpoint['optimizer'].module
249 | 
250 | self.args.start_epoch = current_checkpoint['epoch']
251 | self.metric = current_checkpoint['best_acc']
252 | self.model.load_state_dict(current_checkpoint['state_dict'])
253 | # self.optimizer.load_state_dict(current_checkpoint['optimizer'])
254 | print("=> loaded checkpoint '{}' (epoch {})"
255 | .format(resume_path, current_checkpoint['epoch']))
256 | else:
257 | raise Exception("=> no checkpoint found at '{}'".format(resume_path))
258 | 
259 | def save_checkpoint(self,filename='checkpoint.pth.tar', snapshot=None):
260 | is_best = self.best_acc < self.metric
261 | 
262 | if is_best:
263 | self.best_acc = self.metric
264 | 
265 | state = {
266 | 'epoch': self.current_epoch + 1,
267 | 'arch': self.args.arch,
268 | 'state_dict': self.model.state_dict(),
269 | 'best_acc': self.best_acc,
270 | 'optimizer' : self.optimizer.state_dict() if self.optimizer else None,
271 | }
272 | 
273 | filepath = os.path.join(self.args.checkpoint, filename)
274 | torch.save(state, filepath)
275 | 
276 | if snapshot and state['epoch'] % snapshot == 0:
277 | shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch'])))
278 | 
279 | if is_best:
280 | self.best_acc = self.metric
281 | print('Saving Best Metric with PSNR:%s'%self.best_acc)
282 | shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'model_best.pth.tar'))
283 | 
284 | def clean(self):
285 | self.writer.close()
286 | 
287 | def record(self,k,v,epoch):
288 | self.writer.add_scalar(k, v, epoch)
289 | 
290 | def flush(self):
291 | self.writer.flush()
292 | sys.stdout.flush()
293 | 
294 | def norm(self,x):
295 | if self.args.gan_norm:
296 | return x*2.0 - 1.0
297 | else:
298 | return x
299 | 
300 | def denorm(self,x):
301 | if self.args.gan_norm:
302 | return (x+1.0)/2.0
303 | else:
304 | return x
305 | 
306 | 
--------------------------------------------------------------------------------
/scripts/utils/parallel.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Hang Zhang, Rutgers University, Email: zhang.hang@rutgers.edu
3 | ## Modified by Thomas Wolf, HuggingFace Inc., Email:
thomas@huggingface.co 4 | ## Copyright (c) 2017-2018 5 | ## 6 | ## This source code is licensed under the MIT-style license found in the 7 | ## LICENSE file in the root directory of this source tree 8 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 9 | 10 | """Encoding Data Parallel""" 11 | import threading 12 | import functools 13 | import torch 14 | from torch.autograd import Variable, Function 15 | import torch.cuda.comm as comm 16 | from torch.nn.parallel import DistributedDataParallel 17 | from torch.nn.parallel.data_parallel import DataParallel 18 | from torch.nn.parallel.parallel_apply import get_a_var 19 | from torch.nn.parallel.scatter_gather import gather 20 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 21 | 22 | torch_ver = torch.__version__[:3] 23 | 24 | __all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion', 25 | 'patch_replication_callback'] 26 | 27 | def allreduce(*inputs): 28 | """Cross GPU all reduce autograd operation for calculate mean and 29 | variance in SyncBN. 30 | """ 31 | return AllReduce.apply(*inputs) 32 | 33 | class AllReduce(Function): 34 | @staticmethod 35 | def forward(ctx, num_inputs, *inputs): 36 | ctx.num_inputs = num_inputs 37 | ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)] 38 | inputs = [inputs[i:i + num_inputs] 39 | for i in range(0, len(inputs), num_inputs)] 40 | # sort before reduce sum 41 | inputs = sorted(inputs, key=lambda i: i[0].get_device()) 42 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) 43 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus) 44 | return tuple([t for tensors in outputs for t in tensors]) 45 | 46 | @staticmethod 47 | def backward(ctx, *inputs): 48 | inputs = [i.data for i in inputs] 49 | inputs = [inputs[i:i + ctx.num_inputs] 50 | for i in range(0, len(inputs), ctx.num_inputs)] 51 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) 52 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus) 53 | return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors]) 54 | 55 | 56 | class Reduce(Function): 57 | @staticmethod 58 | def forward(ctx, *inputs): 59 | ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))] 60 | inputs = sorted(inputs, key=lambda i: i.get_device()) 61 | return comm.reduce_add(inputs) 62 | 63 | @staticmethod 64 | def backward(ctx, gradOutput): 65 | return Broadcast.apply(ctx.target_gpus, gradOutput) 66 | 67 | class DistributedDataParallelModel(DistributedDataParallel): 68 | """Implements data parallelism at the module level for the DistributedDataParallel module. 69 | This container parallelizes the application of the given module by 70 | splitting the input across the specified devices by chunking in the 71 | batch dimension. 72 | In the forward pass, the module is replicated on each device, 73 | and each replica handles a portion of the input. During the backwards pass, 74 | gradients from each replica are summed into the original module. 75 | Note that the outputs are not gathered, please use compatible 76 | :class:`encoding.parallel.DataParallelCriterion`. 77 | The batch size should be larger than the number of GPUs used. It should 78 | also be an integer multiple of the number of GPUs so that each chunk is 79 | the same size (so that each GPU processes the same number of samples). 
80 | Args: 81 | module: module to be parallelized 82 | device_ids: CUDA devices (default: all devices) 83 | Reference: 84 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, 85 | Amit Agrawal. “Context Encoding for Semantic Segmentation. 86 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* 87 | Example:: 88 | >>> net = encoding.nn.DistributedDataParallelModel(model, device_ids=[0, 1, 2]) 89 | >>> y = net(x) 90 | """ 91 | def gather(self, outputs, output_device): 92 | return outputs 93 | 94 | class DataParallelModel(DataParallel): 95 | """Implements data parallelism at the module level. 96 | 97 | This container parallelizes the application of the given module by 98 | splitting the input across the specified devices by chunking in the 99 | batch dimension. 100 | In the forward pass, the module is replicated on each device, 101 | and each replica handles a portion of the input. During the backwards pass, 102 | gradients from each replica are summed into the original module. 103 | Note that the outputs are not gathered, please use compatible 104 | :class:`encoding.parallel.DataParallelCriterion`. 105 | 106 | The batch size should be larger than the number of GPUs used. It should 107 | also be an integer multiple of the number of GPUs so that each chunk is 108 | the same size (so that each GPU processes the same number of samples). 109 | 110 | Args: 111 | module: module to be parallelized 112 | device_ids: CUDA devices (default: all devices) 113 | 114 | Reference: 115 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, 116 | Amit Agrawal. “Context Encoding for Semantic Segmentation. 117 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* 118 | 119 | Example:: 120 | 121 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2]) 122 | >>> y = net(x) 123 | """ 124 | def gather(self, outputs, output_device): 125 | return outputs 126 | 127 | def replicate(self, module, device_ids): 128 | modules = super(DataParallelModel, self).replicate(module, device_ids) 129 | execute_replication_callbacks(modules) 130 | return modules 131 | 132 | 133 | class DataParallelCriterion(DataParallel): 134 | """ 135 | Calculate loss in multiple-GPUs, which balance the memory usage. 136 | The targets are splitted across the specified devices by chunking in 137 | the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`. 138 | 139 | Reference: 140 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, 141 | Amit Agrawal. “Context Encoding for Semantic Segmentation. 
142 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* 143 | 144 | Example:: 145 | 146 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2]) 147 | >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2]) 148 | >>> y = net(x) 149 | >>> loss = criterion(y, target) 150 | """ 151 | def forward(self, inputs, *targets, **kwargs): 152 | # input should be already scatterd 153 | # scattering the targets instead 154 | if not self.device_ids: 155 | return self.module(inputs, *targets, **kwargs) 156 | targets, kwargs = self.scatter(targets, kwargs, self.device_ids) 157 | if len(self.device_ids) == 1: 158 | return self.module(inputs, *targets[0], **kwargs[0]) 159 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 160 | outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs) 161 | #return Reduce.apply(*outputs) / len(outputs) 162 | #return self.gather(outputs, self.output_device).mean() 163 | return self.gather(outputs, self.output_device) 164 | 165 | 166 | def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None): 167 | assert len(modules) == len(inputs) 168 | assert len(targets) == len(inputs) 169 | if kwargs_tup: 170 | assert len(modules) == len(kwargs_tup) 171 | else: 172 | kwargs_tup = ({},) * len(modules) 173 | if devices is not None: 174 | assert len(modules) == len(devices) 175 | else: 176 | devices = [None] * len(modules) 177 | 178 | lock = threading.Lock() 179 | results = {} 180 | if torch_ver != "0.3": 181 | grad_enabled = torch.is_grad_enabled() 182 | 183 | def _worker(i, module, input, target, kwargs, device=None): 184 | if torch_ver != "0.3": 185 | torch.set_grad_enabled(grad_enabled) 186 | if device is None: 187 | device = get_a_var(input).get_device() 188 | try: 189 | with torch.cuda.device(device): 190 | # this also avoids accidental slicing of `input` if it is a Tensor 191 | if not isinstance(input, (list, tuple)): 192 | input = (input,) 193 | if not isinstance(target, (list, tuple)): 194 | target = (target,) 195 | output = module(*(input + target), **kwargs) 196 | with lock: 197 | results[i] = output 198 | except Exception as e: 199 | with lock: 200 | results[i] = e 201 | 202 | if len(modules) > 1: 203 | threads = [threading.Thread(target=_worker, 204 | args=(i, module, input, target, 205 | kwargs, device),) 206 | for i, (module, input, target, kwargs, device) in 207 | enumerate(zip(modules, inputs, targets, kwargs_tup, devices))] 208 | 209 | for thread in threads: 210 | thread.start() 211 | for thread in threads: 212 | thread.join() 213 | else: 214 | _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) 215 | 216 | outputs = [] 217 | for i in range(len(inputs)): 218 | output = results[i] 219 | if isinstance(output, Exception): 220 | raise output 221 | outputs.append(output) 222 | return outputs 223 | 224 | 225 | ########################################################################### 226 | # Adapted from Synchronized-BatchNorm-PyTorch. 227 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 228 | # 229 | class CallbackContext(object): 230 | pass 231 | 232 | 233 | def execute_replication_callbacks(modules): 234 | """ 235 | Execute an replication callback `__data_parallel_replicate__` on each module created 236 | by original replication. 
237 | 238 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 239 | 240 | Note that, as all modules are isomorphism, we assign each sub-module with a context 241 | (shared among multiple copies of this module on different devices). 242 | Through this context, different copies can share some information. 243 | 244 | We guarantee that the callback on the master copy (the first copy) will be called ahead 245 | of calling the callback of any slave copies. 246 | """ 247 | master_copy = modules[0] 248 | nr_modules = len(list(master_copy.modules())) 249 | ctxs = [CallbackContext() for _ in range(nr_modules)] 250 | 251 | for i, module in enumerate(modules): 252 | for j, m in enumerate(module.modules()): 253 | if hasattr(m, '__data_parallel_replicate__'): 254 | m.__data_parallel_replicate__(ctxs[j], i) 255 | 256 | 257 | def patch_replication_callback(data_parallel): 258 | """ 259 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 260 | Useful when you have customized `DataParallel` implementation. 261 | 262 | Examples: 263 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 264 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 265 | > patch_replication_callback(sync_bn) 266 | # this is equivalent to 267 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 268 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 269 | """ 270 | 271 | assert isinstance(data_parallel, DataParallel) 272 | 273 | old_replicate = data_parallel.replicate 274 | 275 | @functools.wraps(old_replicate) 276 | def new_replicate(module, device_ids): 277 | modules = old_replicate(module, device_ids) 278 | execute_replication_callbacks(modules) 279 | return modules 280 | 281 | data_parallel.replicate = new_replicate -------------------------------------------------------------------------------- /scripts/machines/BasicMachine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.backends.cudnn as cudnn 4 | from progress.bar import Bar 5 | import json 6 | import numpy as np 7 | from tensorboardX import SummaryWriter 8 | from scripts.utils.evaluation import accuracy, AverageMeter, final_preds 9 | from scripts.utils.osutils import mkdir_p, isfile, isdir, join 10 | from scripts.utils.parallel import DataParallelModel, DataParallelCriterion 11 | import pytorch_ssim as pytorch_ssim 12 | import torch.optim 13 | import sys,shutil,os 14 | import time 15 | import scripts.models as archs 16 | from math import log10 17 | from torch.autograd import Variable 18 | from scripts.utils.losses import VGGLoss 19 | from scripts.utils.imutils import im_to_numpy 20 | 21 | import skimage.io 22 | from skimage.measure import compare_psnr,compare_ssim 23 | 24 | 25 | class BasicMachine(object): 26 | def __init__(self, datasets =(None,None), models = None, args = None, **kwargs): 27 | super(BasicMachine, self).__init__() 28 | 29 | self.args = args 30 | 31 | # create model 32 | print("==> creating model ") 33 | self.model = archs.__dict__[self.args.arch]() 34 | print("==> creating model [Finish]") 35 | 36 | self.train_loader, self.val_loader = datasets 37 | self.loss = torch.nn.MSELoss() 38 | 39 | self.title = '_'+args.machine + '_' + args.data + '_' + args.arch 40 | self.args.checkpoint = args.checkpoint + self.title 41 | self.device = torch.device('cuda') 42 | # create checkpoint dir 43 | if not isdir(self.args.checkpoint): 44 | mkdir_p(self.args.checkpoint) 
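        # The run directory is args.checkpoint extended with "_<machine>_<data>_<arch>" (see
        # self.title above); e.g. hypothetical flags "-c eval/demo --machine vx --data _images
        # --arch vvv4n" would resolve to "eval/demo_vx__images_vvv4n". Everything the machine
        # produces lands there: TensorBoard event files (SummaryWriter) under <checkpoint>/ckpt
        # when training, the checkpoint.pth.tar / model_best.pth.tar files written by
        # save_checkpoint(), and the restored images saved by test().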
45 | 46 | self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), 47 | lr=args.lr, 48 | betas=(args.beta1,args.beta2), 49 | weight_decay=args.weight_decay) 50 | 51 | if not self.args.evaluate: 52 | self.writer = SummaryWriter(self.args.checkpoint+'/'+'ckpt') 53 | 54 | self.best_acc = 0 55 | self.is_best = False 56 | self.current_epoch = 0 57 | self.metric = -100000 58 | self.hl = 6 if self.args.hl else 1 59 | self.count_gpu = len(range(torch.cuda.device_count())) 60 | 61 | if self.args.style_loss > 0: 62 | # init perception loss 63 | self.vggloss = VGGLoss(self.args.sltype).to(self.device) 64 | 65 | if self.count_gpu > 1 : # multiple 66 | # self.model = DataParallelModel(self.model, device_ids=range(torch.cuda.device_count())) 67 | # self.loss = DataParallelCriterion(self.loss, device_ids=range(torch.cuda.device_count())) 68 | self.model = nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count())) 69 | 70 | self.model.to(self.device) 71 | self.loss.to(self.device) 72 | 73 | print('==> Total params: %.2fM' % (sum(p.numel() for p in self.model.parameters())/1000000.0)) 74 | print('==> Total devices: %d' % (torch.cuda.device_count())) 75 | print('==> Current Checkpoint: %s' % (self.args.checkpoint)) 76 | 77 | 78 | if self.args.resume != '': 79 | self.resume(self.args.resume) 80 | 81 | 82 | def train(self,epoch): 83 | batch_time = AverageMeter() 84 | data_time = AverageMeter() 85 | losses = AverageMeter() 86 | lossvgg = AverageMeter() 87 | 88 | # switch to train mode 89 | self.model.train() 90 | end = time.time() 91 | 92 | bar = Bar('Processing', max=len(self.train_loader)*self.hl) 93 | for _ in range(self.hl): 94 | for i, batches in enumerate(self.train_loader): 95 | # measure data loading time 96 | inputs = batches['image'] 97 | target = batches['target'].to(self.device) 98 | mask =batches['mask'].to(self.device) 99 | current_index = len(self.train_loader) * epoch + i 100 | 101 | if self.args.hl: 102 | feeded = torch.cat([inputs,mask],dim=1) 103 | else: 104 | feeded = inputs 105 | feeded = feeded.to(self.device) 106 | 107 | output = self.model(feeded) 108 | L2_loss = self.loss(output,target) 109 | 110 | if self.args.style_loss > 0: 111 | vgg_loss = self.vggloss(output,target,mask) 112 | else: 113 | vgg_loss = 0 114 | 115 | total_loss = L2_loss + self.args.style_loss * vgg_loss 116 | 117 | # compute gradient and do SGD step 118 | self.optimizer.zero_grad() 119 | total_loss.backward() 120 | self.optimizer.step() 121 | 122 | # measure accuracy and record loss 123 | losses.update(L2_loss.item(), inputs.size(0)) 124 | 125 | if self.args.style_loss > 0 : 126 | lossvgg.update(vgg_loss.item(), inputs.size(0)) 127 | 128 | # measure elapsed time 129 | batch_time.update(time.time() - end) 130 | end = time.time() 131 | 132 | # plot progress 133 | suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss VGG: {loss_vgg:.4f}'.format( 134 | batch=i + 1, 135 | size=len(self.train_loader), 136 | data=data_time.val, 137 | bt=batch_time.val, 138 | total=bar.elapsed_td, 139 | eta=bar.eta_td, 140 | loss_label=losses.avg, 141 | loss_vgg=lossvgg.avg 142 | ) 143 | 144 | if current_index % 1000 == 0: 145 | print(suffix) 146 | 147 | if self.args.freq > 0 and current_index % self.args.freq == 0: 148 | self.validate(current_index) 149 | self.flush() 150 | self.save_checkpoint() 151 | 152 | self.record('train/loss_L2', losses.avg, current_index) 153 | 154 | 155 | def test(self, ): 156 | 157 | # 
switch to evaluate mode 158 | self.model.eval() 159 | 160 | ssimes = AverageMeter() 161 | psnres = AverageMeter() 162 | 163 | with torch.no_grad(): 164 | for i, batches in enumerate(self.val_loader): 165 | 166 | inputs = batches['image'].to(self.device) 167 | target = batches['target'].to(self.device) 168 | mask =batches['mask'].to(self.device) 169 | 170 | outputs = self.model(inputs) 171 | 172 | # select the outputs by the giving arch 173 | if type(outputs) == type(inputs): 174 | output = outputs 175 | elif type(outputs[0]) == type([]): 176 | output = outputs[0][0] 177 | else: 178 | output = outputs[0] 179 | 180 | # recover the image to 255 181 | output = im_to_numpy(torch.clamp(output[0]*255,min=0.0,max=255.0)).astype(np.uint8) 182 | target = im_to_numpy(torch.clamp(target[0]*255,min=0.0,max=255.0)).astype(np.uint8) 183 | 184 | skimage.io.imsave('%s/%s'%(self.args.checkpoint,batches['name'][0]), output) 185 | 186 | psnr = compare_psnr(target,output) 187 | ssim = compare_ssim(target,output,multichannel=True) 188 | 189 | psnres.update(psnr, inputs.size(0)) 190 | ssimes.update(ssim, inputs.size(0)) 191 | 192 | print("%s:PSNR:%s,SSIM:%s"%(self.args.checkpoint,psnres.avg,ssimes.avg)) 193 | print("DONE.\n") 194 | 195 | 196 | def validate(self, epoch): 197 | batch_time = AverageMeter() 198 | data_time = AverageMeter() 199 | losses = AverageMeter() 200 | ssimes = AverageMeter() 201 | psnres = AverageMeter() 202 | # switch to evaluate mode 203 | self.model.eval() 204 | 205 | end = time.time() 206 | with torch.no_grad(): 207 | for i, batches in enumerate(self.val_loader): 208 | 209 | inputs = batches['image'].to(self.device) 210 | target = batches['target'].to(self.device) 211 | mask =batches['mask'].to(self.device) 212 | 213 | if self.args.hl: 214 | feeded = torch.cat([inputs,torch.zeros((1,4,self.args.input_size,self.args.input_size)).to(self.device)],dim=1) 215 | else: 216 | feeded = inputs 217 | 218 | output = self.model(feeded) 219 | 220 | L2_loss = self.loss(output, target) 221 | 222 | psnr = 10 * log10(1 / L2_loss.item()) 223 | ssim = pytorch_ssim.ssim(output, target) 224 | 225 | losses.update(L2_loss.item(), inputs.size(0)) 226 | psnres.update(psnr, inputs.size(0)) 227 | ssimes.update(ssim.item(), inputs.size(0)) 228 | 229 | # measure elapsed time 230 | batch_time.update(time.time() - end) 231 | end = time.time() 232 | 233 | print("Epoches:%s,Losses:%.3f,PSNR:%.3f,SSIM:%.3f"%(epoch+1, losses.avg,psnres.avg,ssimes.avg)) 234 | self.record('val/loss_L2', losses.avg, epoch) 235 | self.record('val/PSNR', psnres.avg, epoch) 236 | self.record('val/SSIM', ssimes.avg, epoch) 237 | 238 | self.metric = psnres.avg 239 | 240 | def resume(self,resume_path): 241 | if isfile(resume_path): 242 | print("=> loading checkpoint '{}'".format(resume_path)) 243 | current_checkpoint = torch.load(resume_path) 244 | if isinstance(current_checkpoint['state_dict'], torch.nn.DataParallel): 245 | current_checkpoint['state_dict'] = current_checkpoint['state_dict'].module 246 | 247 | if isinstance(current_checkpoint['optimizer'], torch.nn.DataParallel): 248 | current_checkpoint['optimizer'] = current_checkpoint['optimizer'].module 249 | 250 | self.args.start_epoch = current_checkpoint['epoch'] 251 | self.metric = current_checkpoint['best_acc'] 252 | self.model.load_state_dict(current_checkpoint['state_dict']) 253 | # self.optimizer.load_state_dict(current_checkpoint['optimizer']) 254 | print("=> loaded checkpoint '{}' (epoch {})" 255 | .format(resume_path, current_checkpoint['epoch'])) 256 | else: 257 | raise 
Exception("=> no checkpoint found at '{}'".format(resume_path)) 258 | 259 | def save_checkpoint(self,filename='checkpoint.pth.tar', snapshot=None): 260 | is_best = True if self.best_acc < self.metric else False 261 | 262 | if is_best: 263 | self.best_acc = self.metric 264 | 265 | state = { 266 | 'epoch': self.current_epoch + 1, 267 | 'arch': self.args.arch, 268 | 'state_dict': self.model.state_dict(), 269 | 'best_acc': self.best_acc, 270 | 'optimizer' : self.optimizer.state_dict() if self.optimizer else None, 271 | } 272 | 273 | filepath = os.path.join(self.args.checkpoint, filename) 274 | torch.save(state, filepath) 275 | 276 | if snapshot and state['epoch'] % snapshot == 0: 277 | shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch']))) 278 | 279 | if is_best: 280 | self.best_acc = self.metric 281 | print('Saving Best Metric with PSNR:%s'%self.best_acc) 282 | shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'model_best.pth.tar')) 283 | 284 | def clean(self): 285 | self.writer.close() 286 | 287 | def record(self,k,v,epoch): 288 | self.writer.add_scalar(k, v, epoch) 289 | 290 | def flush(self): 291 | self.writer.flush() 292 | sys.stdout.flush() 293 | 294 | def norm(self,x): 295 | if self.args.gan_norm: 296 | return x*2.0 - 1.0 297 | else: 298 | return x 299 | 300 | def denorm(self,x): 301 | if self.args.gan_norm: 302 | return (x+1.0)/2.0 303 | else: 304 | return x 305 | 306 | -------------------------------------------------------------------------------- /scripts/machines/VX.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from progress.bar import Bar 4 | from tqdm import tqdm 5 | import pytorch_ssim 6 | import json 7 | import sys,time,os 8 | import torchvision 9 | from math import log10 10 | import numpy as np 11 | from .BasicMachine import BasicMachine 12 | from scripts.utils.evaluation import accuracy, AverageMeter, final_preds 13 | from scripts.utils.misc import resize_to_match 14 | from torch.autograd import Variable 15 | import torch.nn.functional as F 16 | from scripts.utils.parallel import DataParallelModel, DataParallelCriterion 17 | from scripts.utils.losses import VGGLoss, l1_relative,is_dic 18 | from scripts.utils.imutils import im_to_numpy 19 | import skimage.io 20 | from skimage.measure import compare_psnr,compare_ssim 21 | 22 | 23 | class Losses(nn.Module): 24 | def __init__(self, argx, device, norm_func=None, denorm_func=None): 25 | super(Losses, self).__init__() 26 | self.args = argx 27 | 28 | if self.args.loss_type == 'l1bl2': 29 | self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), nn.BCELoss(), nn.MSELoss() 30 | elif self.args.loss_type == 'l2xbl2': 31 | self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCEWithLogitsLoss(), nn.MSELoss() 32 | elif self.args.loss_type == 'relative' or self.args.loss_type == 'hybrid': 33 | self.outputLoss, self.attLoss, self.wrloss = l1_relative, nn.BCELoss(), l1_relative 34 | else: # l2bl2 35 | self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCELoss(), nn.MSELoss() 36 | 37 | self.default = nn.L1Loss() 38 | 39 | if self.args.style_loss > 0: 40 | self.vggloss = VGGLoss(self.args.sltype).to(device) 41 | 42 | if self.args.ssim_loss > 0: 43 | self.ssimloss = pytorch_ssim.SSIM().to(device) 44 | 45 | self.norm = norm_func 46 | self.denorm = denorm_func 47 | 48 | 49 | def forward(self,pred_ims,target,pred_ms,mask,pred_wms,wm): 50 |
pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss = [0]*5 51 | pred_ims = pred_ims if is_dic(pred_ims) else [pred_ims] 52 | 53 | # try the loss in the masked region 54 | if self.args.masked and 'hybrid' in self.args.loss_type: # masked loss 55 | pixel_loss += sum([self.outputLoss(pred_im, target, mask) for pred_im in pred_ims]) 56 | pixel_loss += sum([self.default(pred_im*pred_ms,target*mask) for pred_im in pred_ims]) 57 | recov_imgs = [ self.denorm(pred_im*mask + (1-mask)*self.norm(target)) for pred_im in pred_ims ] 58 | wm_loss += self.wrloss(pred_wms, wm, mask) 59 | wm_loss += self.default(pred_wms*pred_ms, wm*mask) 60 | 61 | elif self.args.masked and 'relative' in self.args.loss_type: # masked loss 62 | pixel_loss += sum([self.outputLoss(pred_im, target, mask) for pred_im in pred_ims]) 63 | recov_imgs = [ self.denorm(pred_im*mask + (1-mask)*self.norm(target)) for pred_im in pred_ims ] 64 | wm_loss = self.wrloss(pred_wms, wm, mask) 65 | elif self.args.masked: 66 | pixel_loss += sum([self.outputLoss(pred_im*mask, target*mask) for pred_im in pred_ims]) 67 | recov_imgs = [ self.denorm(pred_im*pred_ms + (1-pred_ms)*self.norm(target)) for pred_im in pred_ims ] 68 | wm_loss = self.wrloss(pred_wms*mask, wm*mask) 69 | else: 70 | pixel_loss += sum([self.outputLoss(pred_im*pred_ms, target*mask) for pred_im in pred_ims]) 71 | recov_imgs = [ self.denorm(pred_im*pred_ms + (1-pred_ms)*self.norm(target)) for pred_im in pred_ims ] 72 | wm_loss = self.wrloss(pred_wms*pred_ms,wm*mask) 73 | 74 | pixel_loss += sum([self.default(im,target) for im in recov_imgs]) 75 | 76 | if self.args.style_loss > 0: 77 | vgg_loss = sum([self.vggloss(im,target,mask) for im in recov_imgs]) 78 | 79 | if self.args.ssim_loss > 0: 80 | ssim_loss = sum([ 1 - self.ssimloss(im,target) for im in recov_imgs]) 81 | 82 | att_loss = self.attLoss(pred_ms, mask) 83 | 84 | return pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss 85 | 86 | 87 | class VX(BasicMachine): 88 | def __init__(self,**kwargs): 89 | BasicMachine.__init__(self,**kwargs) 90 | self.loss = Losses(self.args, self.device, self.norm, self.denorm) 91 | self.model.set_optimizers() 92 | self.optimizer = None 93 | 94 | def train(self,epoch): 95 | 96 | self.current_epoch = epoch 97 | 98 | batch_time = AverageMeter() 99 | data_time = AverageMeter() 100 | losses = AverageMeter() 101 | lossMask = AverageMeter() 102 | lossWM = AverageMeter() 103 | lossMX = AverageMeter() 104 | lossvgg = AverageMeter() 105 | lossssim = AverageMeter() 106 | 107 | # switch to train mode 108 | self.model.train() 109 | 110 | end = time.time() 111 | bar = Bar('Processing {} '.format(self.args.arch), max=len(self.train_loader)) 112 | 113 | for i, batches in enumerate(self.train_loader): 114 | 115 | current_index = len(self.train_loader) * epoch + i 116 | 117 | inputs = batches['image'].to(self.device) 118 | target = batches['target'].to(self.device) 119 | mask = batches['mask'].to(self.device) 120 | wm = batches['wm'].to(self.device) 121 | 122 | outputs = self.model(self.norm(inputs)) 123 | 124 | self.model.zero_grad_all() 125 | 126 | l2_loss,att_loss,wm_loss,style_loss,ssim_loss = self.loss(outputs[0],self.norm(target),outputs[1],mask,outputs[2],self.norm(wm)) 127 | total_loss = 2*l2_loss + self.args.att_loss * att_loss + wm_loss + self.args.style_loss * style_loss + self.args.ssim_loss * ssim_loss 128 | 129 | # compute gradient and do SGD step 130 | total_loss.backward() 131 | self.model.step_all() 132 | 133 | # measure accuracy and record loss 134 | losses.update(l2_loss.item(), inputs.size(0)) 135 | 
lossMask.update(att_loss.item(), inputs.size(0)) 136 | lossWM.update(wm_loss.item(), inputs.size(0)) 137 | 138 | if self.args.style_loss > 0 : 139 | lossvgg.update(style_loss.item(), inputs.size(0)) 140 | 141 | if self.args.ssim_loss > 0 : 142 | lossssim.update(ssim_loss.item(), inputs.size(0)) 143 | 144 | 145 | # measure elapsed time 146 | batch_time.update(time.time() - end) 147 | end = time.time() 148 | 149 | # plot progress 150 | suffix = "({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss Mask: {loss_mask:.4f} | loss WM: {loss_wm:.4f} | loss VGG: {loss_vgg:.4f} | loss SSIM: {loss_ssim:.4f}| loss MX: {loss_mx:.4f}".format( 151 | batch=i + 1, 152 | size=len(self.train_loader), 153 | data=data_time.val, 154 | bt=batch_time.val, 155 | total=bar.elapsed_td, 156 | eta=bar.eta_td, 157 | loss_label=losses.avg, 158 | loss_mask=lossMask.avg, 159 | loss_wm=lossWM.avg, 160 | loss_vgg=lossvgg.avg, 161 | loss_ssim=lossssim.avg, 162 | loss_mx=lossMX.avg 163 | ) 164 | if current_index % 1000 == 0: 165 | print(suffix) 166 | 167 | if self.args.freq > 0 and current_index % self.args.freq == 0: 168 | self.validate(current_index) 169 | self.flush() 170 | self.save_checkpoint() 171 | 172 | self.record('train/loss_L2', losses.avg, epoch) 173 | self.record('train/loss_Mask', lossMask.avg, epoch) 174 | self.record('train/loss_WM', lossWM.avg, epoch) 175 | self.record('train/loss_VGG', lossvgg.avg, epoch) 176 | self.record('train/loss_SSIM', lossssim.avg, epoch) 177 | self.record('train/loss_MX', lossMX.avg, epoch) 178 | 179 | 180 | 181 | 182 | def validate(self, epoch): 183 | 184 | self.current_epoch = epoch 185 | 186 | batch_time = AverageMeter() 187 | data_time = AverageMeter() 188 | losses = AverageMeter() 189 | lossMask = AverageMeter() 190 | psnres = AverageMeter() 191 | ssimes = AverageMeter() 192 | 193 | # switch to evaluate mode 194 | self.model.eval() 195 | 196 | end = time.time() 197 | bar = Bar('Processing {} '.format(self.args.arch), max=len(self.val_loader)) 198 | with torch.no_grad(): 199 | for i, batches in enumerate(self.val_loader): 200 | 201 | current_index = len(self.val_loader) * epoch + i 202 | 203 | inputs = batches['image'].to(self.device) 204 | target = batches['target'].to(self.device) 205 | 206 | outputs = self.model(self.norm(inputs)) 207 | imoutput,immask,imwatermark = outputs 208 | imoutput = imoutput[0] if is_dic(imoutput) else imoutput 209 | 210 | imfinal = self.denorm(imoutput*immask + self.norm(inputs)*(1-immask)) 211 | 212 | if i % 300 == 0: 213 | # save the sample images 214 | ims = torch.cat([inputs,target,imfinal,immask.repeat(1,3,1,1)],dim=3) 215 | torchvision.utils.save_image(ims,os.path.join(self.args.checkpoint,'%s_%s.jpg'%(i,epoch))) 216 | 217 | # here two choice: mseLoss or NLLLoss 218 | psnr = 10 * log10(1 / F.mse_loss(imfinal,target).item()) 219 | 220 | ssim = pytorch_ssim.ssim(imfinal,target) 221 | 222 | psnres.update(psnr, inputs.size(0)) 223 | ssimes.update(ssim, inputs.size(0)) 224 | 225 | # measure elapsed time 226 | batch_time.update(time.time() - end) 227 | end = time.time() 228 | 229 | # plot progress 230 | bar.suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_L2: {loss_label:.4f} | Loss_Mask: {loss_mask:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}'.format( 231 | batch=i + 1, 232 | size=len(self.val_loader), 233 | data=data_time.val, 234 | bt=batch_time.val, 235 | total=bar.elapsed_td, 236 | eta=bar.eta_td, 237 | loss_label=losses.avg, 
238 | loss_mask=lossMask.avg, 239 | psnr=psnres.avg, 240 | ssim=ssimes.avg 241 | ) 242 | bar.next() 243 | bar.finish() 244 | 245 | print("Iter:%s,Losses:%s,PSNR:%.4f,SSIM:%.4f"%(epoch, losses.avg,psnres.avg,ssimes.avg)) 246 | self.record('val/loss_L2', losses.avg, epoch) 247 | self.record('val/lossMask', lossMask.avg, epoch) 248 | self.record('val/PSNR', psnres.avg, epoch) 249 | self.record('val/SSIM', ssimes.avg, epoch) 250 | self.metric = psnres.avg 251 | 252 | self.model.train() 253 | 254 | def test(self, ): 255 | 256 | # switch to evaluate mode 257 | self.model.eval() 258 | print("==> testing VM model ") 259 | ssimes = AverageMeter() 260 | psnres = AverageMeter() 261 | ssimesx = AverageMeter() 262 | psnresx = AverageMeter() 263 | 264 | with torch.no_grad(): 265 | for i, batches in enumerate(tqdm(self.val_loader)): 266 | 267 | inputs = batches['image'].to(self.device) 268 | target = batches['target'].to(self.device) 269 | mask =batches['mask'].to(self.device) 270 | 271 | # select the outputs by the giving arch 272 | outputs = self.model(self.norm(inputs)) 273 | imoutput,immask,imwatermark = outputs 274 | imoutput = imoutput[0] if is_dic(imoutput) else imoutput 275 | 276 | imfinal = self.denorm(imoutput*immask + self.norm(inputs)*(1-immask)) 277 | psnrx = 10 * log10(1 / F.mse_loss(imfinal,target).item()) 278 | ssimx = pytorch_ssim.ssim(imfinal,target) 279 | # recover the image to 255 280 | imfinal = im_to_numpy(torch.clamp(imfinal[0]*255,min=0.0,max=255.0)).astype(np.uint8) 281 | target = im_to_numpy(torch.clamp(target[0]*255,min=0.0,max=255.0)).astype(np.uint8) 282 | 283 | skimage.io.imsave('%s/%s'%(self.args.checkpoint,batches['name'][0]), imfinal) 284 | 285 | psnr = compare_psnr(target,imfinal) 286 | ssim = compare_ssim(target,imfinal,multichannel=True) 287 | 288 | psnres.update(psnr, inputs.size(0)) 289 | ssimes.update(ssim, inputs.size(0)) 290 | psnresx.update(psnrx, inputs.size(0)) 291 | ssimesx.update(ssimx, inputs.size(0)) 292 | 293 | print("%s:PSNR:%.5f(%.5f),SSIM:%.5f(%.5f)"%(self.args.checkpoint,psnres.avg,psnresx.avg,ssimes.avg,ssimesx.avg)) 294 | print("DONE.\n") -------------------------------------------------------------------------------- /scripts/models/vmu.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from scripts.models.blocks import SEBlock 6 | from scripts.models.rasc import * 7 | from scripts.models.unet import UnetGenerator,MinimalUnetV2 8 | 9 | def weight_init(m): 10 | if isinstance(m, nn.Conv2d): 11 | nn.init.xavier_normal_(m.weight) 12 | nn.init.constant_(m.bias, 0) 13 | 14 | def reset_params(model): 15 | for i, m in enumerate(model.modules()): 16 | weight_init(m) 17 | 18 | 19 | def conv3x3(in_channels, out_channels, stride=1, 20 | padding=1, bias=True, groups=1): 21 | return nn.Conv2d( 22 | in_channels, 23 | out_channels, 24 | kernel_size=3, 25 | stride=stride, 26 | padding=padding, 27 | bias=bias, 28 | groups=groups) 29 | 30 | 31 | def up_conv2x2(in_channels, out_channels, transpose=True): 32 | if transpose: 33 | return nn.ConvTranspose2d( 34 | in_channels, 35 | out_channels, 36 | kernel_size=2, 37 | stride=2) 38 | else: 39 | return nn.Sequential( 40 | nn.Upsample(mode='bilinear', scale_factor=2), 41 | conv1x1(in_channels, out_channels)) 42 | 43 | 44 | def conv1x1(in_channels, out_channels, groups=1): 45 | return nn.Conv2d( 46 | in_channels, 47 | out_channels, 48 | kernel_size=1, 49 | groups=groups, 50 | stride=1) 51 | 52 | 53 | 54 | 
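# The blocks below are shared residual U-Net components: DownCoXvD stacks 3x3 convs with
# optional BatchNorm and a residual sum, then halves the resolution with MaxPool2d, while
# UpCoXvD upsamples 2x (ConvTranspose2d by default), concatenates (or sums) the matching
# encoder skip, and optionally applies RASC attention (use_att) to the fused features before
# its residual convs. A rough shape trace for the default UnetVM(depth=5, start_filters=32)
# on an illustrative 256x256 input:
#   encoder:  3x256x256 -> 32x128x128 -> 64x64x64 -> 128x32x32 -> 256x16x16 -> 512x16x16
#   decoders: mirror the encoder back to full resolution, ending in a 1x1 conv that emits
#             3 channels for the image/watermark branches and 1 channel for the mask branch.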
55 | class UpCoXvD(nn.Module): 56 | 57 | def __init__(self, in_channels, out_channels, blocks, residual=True, batch_norm=True, transpose=True,concat=True,use_att=False): 58 | super(UpCoXvD, self).__init__() 59 | self.concat = concat 60 | self.residual = residual 61 | self.batch_norm = batch_norm 62 | self.bn = None 63 | self.conv2 = [] 64 | self.use_att = use_att 65 | self.up_conv = up_conv2x2(in_channels, out_channels, transpose=transpose) 66 | 67 | if self.use_att: 68 | self.s2am = RASC(2 * out_channels) 69 | else: 70 | self.s2am = None 71 | 72 | if self.concat: 73 | self.conv1 = conv3x3(2 * out_channels, out_channels) 74 | else: 75 | self.conv1 = conv3x3(out_channels, out_channels) 76 | for _ in range(blocks): 77 | self.conv2.append(conv3x3(out_channels, out_channels)) 78 | if self.batch_norm: 79 | self.bn = [] 80 | for _ in range(blocks): 81 | self.bn.append(nn.BatchNorm2d(out_channels)) 82 | self.bn = nn.ModuleList(self.bn) 83 | self.conv2 = nn.ModuleList(self.conv2) 84 | 85 | def forward(self, from_up, from_down, mask=None): 86 | from_up = self.up_conv(from_up) 87 | if self.concat: 88 | x1 = torch.cat((from_up, from_down), 1) 89 | else: 90 | if from_down is not None: 91 | x1 = from_up + from_down 92 | else: 93 | x1 = from_up 94 | 95 | if self.use_att: 96 | x1 = self.s2am(x1,mask) 97 | 98 | x1 = F.relu(self.conv1(x1)) 99 | x2 = None 100 | for idx, conv in enumerate(self.conv2): 101 | x2 = conv(x1) 102 | if self.batch_norm: 103 | x2 = self.bn[idx](x2) 104 | if self.residual: 105 | x2 = x2 + x1 106 | x2 = F.relu(x2) 107 | x1 = x2 108 | return x2 109 | 110 | 111 | class DownCoXvD(nn.Module): 112 | 113 | def __init__(self, in_channels, out_channels, blocks, pooling=True, residual=True, batch_norm=True): 114 | super(DownCoXvD, self).__init__() 115 | self.pooling = pooling 116 | self.residual = residual 117 | self.batch_norm = batch_norm 118 | self.bn = None 119 | self.pool = None 120 | self.conv1 = conv3x3(in_channels, out_channels) 121 | self.conv2 = [] 122 | for _ in range(blocks): 123 | self.conv2.append(conv3x3(out_channels, out_channels)) 124 | if self.batch_norm: 125 | self.bn = [] 126 | for _ in range(blocks): 127 | self.bn.append(nn.BatchNorm2d(out_channels)) 128 | self.bn = nn.ModuleList(self.bn) 129 | if self.pooling: 130 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 131 | self.conv2 = nn.ModuleList(self.conv2) 132 | 133 | def __call__(self, x): 134 | return self.forward(x) 135 | 136 | def forward(self, x): 137 | x1 = F.relu(self.conv1(x)) 138 | x2 = None 139 | for idx, conv in enumerate(self.conv2): 140 | x2 = conv(x1) 141 | if self.batch_norm: 142 | x2 = self.bn[idx](x2) 143 | if self.residual: 144 | x2 = x2 + x1 145 | x2 = F.relu(x2) 146 | x1 = x2 147 | before_pool = x2 148 | if self.pooling: 149 | x2 = self.pool(x2) 150 | return x2, before_pool 151 | 152 | class UnetDecoderD(nn.Module): 153 | def __init__(self, in_channels=512, out_channels=3, depth=5, blocks=1, residual=True, batch_norm=True, 154 | transpose=True, concat=True, is_final=True): 155 | super(UnetDecoderD, self).__init__() 156 | self.conv_final = None 157 | self.up_convs = [] 158 | outs = in_channels 159 | for i in range(depth-1): 160 | ins = outs 161 | outs = ins // 2 162 | # 512,256 163 | # 256,128 164 | # 128,64 165 | # 64,32 166 | up_conv = UpCoXvD(ins, outs, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose, 167 | concat=concat) 168 | self.up_convs.append(up_conv) 169 | if is_final: 170 | self.conv_final = conv1x1(outs, out_channels) 171 | else: 172 | up_conv = UpCoXvD(outs, 
out_channels, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose, 173 | concat=concat) 174 | self.up_convs.append(up_conv) 175 | self.up_convs = nn.ModuleList(self.up_convs) 176 | reset_params(self) 177 | 178 | def __call__(self, x, encoder_outs=None): 179 | return self.forward(x, encoder_outs) 180 | 181 | def forward(self, x, encoder_outs=None): 182 | for i, up_conv in enumerate(self.up_convs): 183 | before_pool = None 184 | if encoder_outs is not None: 185 | before_pool = encoder_outs[-(i+2)] 186 | x = up_conv(x, before_pool) 187 | if self.conv_final is not None: 188 | x = self.conv_final(x) 189 | return x 190 | 191 | 192 | class UnetEncoderD(nn.Module): 193 | 194 | def __init__(self, in_channels=3, depth=5, blocks=1, start_filters=32, residual=True, batch_norm=True): 195 | super(UnetEncoderD, self).__init__() 196 | self.down_convs = [] 197 | outs = None 198 | if type(blocks) is tuple: 199 | blocks = blocks[0] 200 | for i in range(depth): 201 | ins = in_channels if i == 0 else outs 202 | outs = start_filters*(2**i) 203 | pooling = True if i < depth-1 else False 204 | down_conv = DownCoXvD(ins, outs, blocks, pooling=pooling, residual=residual, batch_norm=batch_norm) 205 | self.down_convs.append(down_conv) 206 | self.down_convs = nn.ModuleList(self.down_convs) 207 | reset_params(self) 208 | 209 | def __call__(self, x): 210 | return self.forward(x) 211 | 212 | def forward(self, x): 213 | encoder_outs = [] 214 | for d_conv in self.down_convs: 215 | x, before_pool = d_conv(x) 216 | encoder_outs.append(before_pool) 217 | return x, encoder_outs 218 | 219 | 220 | 221 | class UnetVM(nn.Module): 222 | 223 | def __init__(self, in_channels=3, depth=5, shared_depth=0, use_vm_decoder=False, blocks=1, 224 | out_channels_image=3, out_channels_mask=1, start_filters=32, residual=True, batch_norm=True, 225 | transpose=True, concat=True, transfer_data=True, long_skip=False): 226 | super(UnetVM, self).__init__() 227 | self.transfer_data = transfer_data 228 | self.shared = shared_depth 229 | self.optimizer_encoder, self.optimizer_image, self.optimizer_vm = None, None, None 230 | self.optimizer_mask, self.optimizer_shared = None, None 231 | if type(blocks) is not tuple: 232 | blocks = (blocks, blocks, blocks, blocks, blocks) 233 | if not transfer_data: 234 | concat = False 235 | self.encoder = UnetEncoderD(in_channels=in_channels, depth=depth, blocks=blocks[0], 236 | start_filters=start_filters, residual=residual, batch_norm=batch_norm) 237 | self.image_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1), 238 | out_channels=out_channels_image, depth=depth - shared_depth, 239 | blocks=blocks[1], residual=residual, batch_norm=batch_norm, 240 | transpose=transpose, concat=concat) 241 | self.mask_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - 1), 242 | out_channels=out_channels_mask, depth=depth, 243 | blocks=blocks[2], residual=residual, batch_norm=batch_norm, 244 | transpose=transpose, concat=concat) 245 | self.vm_decoder = None 246 | if use_vm_decoder: 247 | self.vm_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1), 248 | out_channels=out_channels_image, depth=depth - shared_depth, 249 | blocks=blocks[3], residual=residual, batch_norm=batch_norm, 250 | transpose=transpose, concat=concat) 251 | self.shared_decoder = None 252 | self.long_skip = long_skip 253 | self._forward = self.unshared_forward 254 | if self.shared != 0: 255 | self._forward = self.shared_forward 256 | self.shared_decoder = 
UnetDecoderD(in_channels=start_filters * 2 ** (depth - 1), 257 | out_channels=start_filters * 2 ** (depth - shared_depth - 1), 258 | depth=shared_depth, blocks=blocks[4], residual=residual, 259 | batch_norm=batch_norm, transpose=transpose, concat=concat, 260 | is_final=False) 261 | 262 | def set_optimizers(self): 263 | self.optimizer_encoder = torch.optim.Adam(self.encoder.parameters(), lr=0.001) 264 | self.optimizer_image = torch.optim.Adam(self.image_decoder.parameters(), lr=0.001) 265 | self.optimizer_mask = torch.optim.Adam(self.mask_decoder.parameters(), lr=0.001) 266 | if self.vm_decoder is not None: 267 | self.optimizer_vm = torch.optim.Adam(self.vm_decoder.parameters(), lr=0.001) 268 | if self.shared != 0: 269 | self.optimizer_shared = torch.optim.Adam(self.shared_decoder.parameters(), lr=0.001) 270 | 271 | def zero_grad_all(self): 272 | self.optimizer_encoder.zero_grad() 273 | self.optimizer_image.zero_grad() 274 | self.optimizer_mask.zero_grad() 275 | if self.vm_decoder is not None: 276 | self.optimizer_vm.zero_grad() 277 | if self.shared != 0: 278 | self.optimizer_shared.zero_grad() 279 | 280 | def step_all(self): 281 | self.optimizer_encoder.step() 282 | self.optimizer_image.step() 283 | self.optimizer_mask.step() 284 | if self.vm_decoder is not None: 285 | self.optimizer_vm.step() 286 | if self.shared != 0: 287 | self.optimizer_shared.step() 288 | 289 | def step_optimizer_image(self): 290 | self.optimizer_image.step() 291 | 292 | def __call__(self, synthesized): 293 | return self._forward(synthesized) 294 | 295 | def forward(self, synthesized): 296 | return self._forward(synthesized) 297 | 298 | def unshared_forward(self, synthesized): 299 | image_code, before_pool = self.encoder(synthesized) 300 | if not self.transfer_data: 301 | before_pool = None 302 | reconstructed_image = torch.tanh(self.image_decoder(image_code, before_pool)) 303 | reconstructed_mask = torch.sigmoid(self.mask_decoder(image_code, before_pool)) 304 | if self.vm_decoder is not None: 305 | reconstructed_vm = torch.tanh(self.vm_decoder(image_code, before_pool)) 306 | return reconstructed_image, reconstructed_mask, reconstructed_vm 307 | return reconstructed_image, reconstructed_mask 308 | 309 | def shared_forward(self, synthesized): 310 | image_code, before_pool = self.encoder(synthesized) 311 | if self.transfer_data: 312 | shared_before_pool = before_pool[- self.shared - 1:] 313 | unshared_before_pool = before_pool[: - self.shared] 314 | else: 315 | before_pool = None 316 | shared_before_pool = None 317 | unshared_before_pool = None 318 | x = self.shared_decoder(image_code, shared_before_pool) 319 | reconstructed_image = torch.tanh(self.image_decoder(x, unshared_before_pool)) 320 | if self.long_skip: 321 | reconstructed_image = reconstructed_image + synthesized 322 | 323 | reconstructed_mask = torch.sigmoid(self.mask_decoder(image_code, before_pool)) 324 | if self.vm_decoder is not None: 325 | reconstructed_vm = torch.tanh(self.vm_decoder(x, unshared_before_pool)) 326 | if self.long_skip: 327 | reconstructed_vm = reconstructed_vm + synthesized 328 | return reconstructed_image, reconstructed_mask, reconstructed_vm 329 | return reconstructed_image, reconstructed_mask 330 | -------------------------------------------------------------------------------- /scripts/models/sa_resunet.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from scripts.models.blocks import SEBlock 6 | from 
scripts.models.rasc import * 7 | from scripts.models.unet import UnetGenerator,MinimalUnetV2 8 | 9 | def weight_init(m): 10 | if isinstance(m, nn.Conv2d): 11 | nn.init.xavier_normal_(m.weight) 12 | nn.init.constant_(m.bias, 0) 13 | 14 | def reset_params(model): 15 | for i, m in enumerate(model.modules()): 16 | weight_init(m) 17 | 18 | 19 | def conv3x3(in_channels, out_channels, stride=1, 20 | padding=1, bias=True, groups=1): 21 | return nn.Conv2d( 22 | in_channels, 23 | out_channels, 24 | kernel_size=3, 25 | stride=stride, 26 | padding=padding, 27 | bias=bias, 28 | groups=groups) 29 | 30 | 31 | def up_conv2x2(in_channels, out_channels, transpose=True): 32 | if transpose: 33 | return nn.ConvTranspose2d( 34 | in_channels, 35 | out_channels, 36 | kernel_size=2, 37 | stride=2) 38 | else: 39 | return nn.Sequential( 40 | nn.Upsample(mode='bilinear', scale_factor=2), 41 | conv1x1(in_channels, out_channels)) 42 | 43 | 44 | def conv1x1(in_channels, out_channels, groups=1): 45 | return nn.Conv2d( 46 | in_channels, 47 | out_channels, 48 | kernel_size=1, 49 | groups=groups, 50 | stride=1) 51 | 52 | 53 | class UpCoXvD(nn.Module): 54 | 55 | def __init__(self, in_channels, out_channels, blocks, residual=True,norm=nn.BatchNorm2d, act=F.relu,batch_norm=True, transpose=True,concat=True,use_att=False): 56 | super(UpCoXvD, self).__init__() 57 | self.concat = concat 58 | self.residual = residual 59 | self.batch_norm = batch_norm 60 | self.bn = None 61 | self.conv2 = [] 62 | self.use_att = use_att 63 | self.up_conv = up_conv2x2(in_channels, out_channels, transpose=transpose) 64 | self.norm0 = norm(out_channels) 65 | 66 | if self.use_att: 67 | self.s2am = RASC(2 * out_channels) 68 | else: 69 | self.s2am = None 70 | 71 | if self.concat: 72 | self.conv1 = conv3x3(2 * out_channels, out_channels) 73 | self.norm1 = norm(out_channels , out_channels) 74 | else: 75 | self.conv1 = conv3x3(out_channels, out_channels) 76 | self.norm1 = norm(out_channels , out_channels) 77 | 78 | for _ in range(blocks): 79 | self.conv2.append(conv3x3(out_channels, out_channels)) 80 | if self.batch_norm: 81 | self.bn = [] 82 | for _ in range(blocks): 83 | self.bn.append(norm(out_channels)) 84 | self.bn = nn.ModuleList(self.bn) 85 | self.conv2 = nn.ModuleList(self.conv2) 86 | self.act = act 87 | 88 | def forward(self, from_up, from_down, mask=None,se=None): 89 | from_up = self.act(self.norm0(self.up_conv(from_up))) 90 | if self.concat: 91 | x1 = torch.cat((from_up, from_down), 1) 92 | else: 93 | if from_down is not None: 94 | x1 = from_up + from_down 95 | else: 96 | x1 = from_up 97 | 98 | if self.use_att: 99 | x1 = self.s2am(x1,mask) 100 | 101 | x1 = self.act(self.norm1(self.conv1(x1))) 102 | x2 = None 103 | for idx, conv in enumerate(self.conv2): 104 | x2 = conv(x1) 105 | if self.batch_norm: 106 | x2 = self.bn[idx](x2) 107 | 108 | if (se is not None) and (idx == len(self.conv2) - 1): # last 109 | x2 = se(x2) 110 | 111 | if self.residual: 112 | x2 = x2 + x1 113 | x2 = self.act(x2) 114 | x1 = x2 115 | return x2 116 | 117 | 118 | class DownCoXvD(nn.Module): 119 | 120 | def __init__(self, in_channels, out_channels, blocks, pooling=True, norm=nn.BatchNorm2d,act=F.relu,residual=True, batch_norm=True): 121 | super(DownCoXvD, self).__init__() 122 | self.pooling = pooling 123 | self.residual = residual 124 | self.batch_norm = batch_norm 125 | self.bn = None 126 | self.pool = None 127 | self.conv1 = conv3x3(in_channels, out_channels) 128 | self.norm1 = norm(out_channels) 129 | 130 | self.conv2 = [] 131 | for _ in range(blocks): 132 | 
self.conv2.append(conv3x3(out_channels, out_channels)) 133 | if self.batch_norm: 134 | self.bn = [] 135 | for _ in range(blocks): 136 | self.bn.append(norm(out_channels)) 137 | self.bn = nn.ModuleList(self.bn) 138 | if self.pooling: 139 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 140 | self.conv2 = nn.ModuleList(self.conv2) 141 | self.act = act 142 | 143 | def __call__(self, x): 144 | return self.forward(x) 145 | 146 | def forward(self, x): 147 | x1 = self.act(self.norm1(self.conv1(x))) 148 | x2 = None 149 | for idx, conv in enumerate(self.conv2): 150 | x2 = conv(x1) 151 | if self.batch_norm: 152 | x2 = self.bn[idx](x2) 153 | if self.residual: 154 | x2 = x2 + x1 155 | x2 = self.act(x2) 156 | x1 = x2 157 | before_pool = x2 158 | if self.pooling: 159 | x2 = self.pool(x2) 160 | return x2, before_pool 161 | 162 | class UnetDecoderD(nn.Module): 163 | def __init__(self, in_channels=512, out_channels=3, norm=nn.BatchNorm2d,act=F.relu, depth=5, blocks=1, residual=True, batch_norm=True, 164 | transpose=True, concat=True, is_final=True, use_att=False): 165 | super(UnetDecoderD, self).__init__() 166 | self.conv_final = None 167 | self.up_convs = [] 168 | self.atts = [] 169 | self.use_att = use_att 170 | 171 | outs = in_channels 172 | for i in range(depth-1): # depth = 1 173 | ins = outs 174 | outs = ins // 2 175 | # 512,256 176 | # 256,128 177 | # 128,64 178 | # 64,32 179 | up_conv = UpCoXvD(ins, outs, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose, 180 | concat=concat, norm=norm, act=act) 181 | if self.use_att: 182 | self.atts.append(SEBlock(outs)) 183 | 184 | self.up_convs.append(up_conv) 185 | 186 | if is_final: 187 | self.conv_final = conv1x1(outs, out_channels) 188 | else: 189 | up_conv = UpCoXvD(outs, out_channels, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose, 190 | concat=concat,norm=norm, act=act) 191 | if self.use_att: 192 | self.atts.append(SEBlock(out_channels)) 193 | 194 | self.up_convs.append(up_conv) 195 | self.up_convs = nn.ModuleList(self.up_convs) 196 | self.atts = nn.ModuleList(self.atts) 197 | 198 | reset_params(self) 199 | 200 | def __call__(self, x, encoder_outs=None): 201 | return self.forward(x, encoder_outs) 202 | 203 | def forward(self, x, encoder_outs=None): 204 | for i, up_conv in enumerate(self.up_convs): 205 | before_pool = None 206 | if encoder_outs is not None: 207 | before_pool = encoder_outs[-(i+2)] 208 | x = up_conv(x, before_pool) 209 | if self.use_att: 210 | x = self.atts[i](x) 211 | 212 | if self.conv_final is not None: 213 | x = self.conv_final(x) 214 | return x 215 | 216 | 217 | class UnetDecoderDatt(nn.Module): 218 | def __init__(self, in_channels=512, out_channels=3, depth=5, blocks=1, residual=True, batch_norm=True, 219 | transpose=True, concat=True, is_final=True, norm=nn.BatchNorm2d,act=F.relu): 220 | super(UnetDecoderDatt, self).__init__() 221 | self.conv_final = None 222 | self.up_convs = [] 223 | self.im_atts = [] 224 | self.vm_atts = [] 225 | self.mask_atts = [] 226 | 227 | outs = in_channels 228 | for i in range(depth-1): # depth = 5 [0,1,2,3] 229 | ins = outs 230 | outs = ins // 2 231 | # 512,256 232 | # 256,128 233 | # 128,64 234 | # 64,32 235 | up_conv = UpCoXvD(ins, outs, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose, 236 | concat=concat, norm=nn.BatchNorm2d,act=F.relu) 237 | self.up_convs.append(up_conv) 238 | self.im_atts.append(SEBlock(outs)) 239 | self.vm_atts.append(SEBlock(outs)) 240 | self.mask_atts.append(SEBlock(outs)) 241 | if is_final: 242 | self.conv_final = 
conv1x1(outs, out_channels) 243 | else: 244 | up_conv = UpCoXvD(outs, out_channels, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose, 245 | concat=concat, norm=nn.BatchNorm2d,act=F.relu) 246 | self.up_convs.append(up_conv) 247 | self.im_atts.append(SEBlock(out_channels)) 248 | self.vm_atts.append(SEBlock(out_channels)) 249 | self.mask_atts.append(SEBlock(out_channels)) 250 | 251 | self.up_convs = nn.ModuleList(self.up_convs) 252 | self.im_atts = nn.ModuleList(self.im_atts) 253 | self.vm_atts = nn.ModuleList(self.vm_atts) 254 | self.mask_atts = nn.ModuleList(self.mask_atts) 255 | 256 | reset_params(self) 257 | 258 | def forward(self, input, encoder_outs=None): 259 | # im branch 260 | x = input 261 | for i, up_conv in enumerate(self.up_convs): 262 | before_pool = None 263 | if encoder_outs is not None: 264 | before_pool = encoder_outs[-(i+2)] 265 | x = up_conv(x, before_pool,se=self.im_atts[i]) 266 | x_im = x 267 | 268 | x = input 269 | for i, up_conv in enumerate(self.up_convs): 270 | before_pool = None 271 | if encoder_outs is not None: 272 | before_pool = encoder_outs[-(i+2)] 273 | x = up_conv(x, before_pool, se = self.mask_atts[i]) 274 | x_mask = x 275 | 276 | x = input 277 | for i, up_conv in enumerate(self.up_convs): 278 | before_pool = None 279 | if encoder_outs is not None: 280 | before_pool = encoder_outs[-(i+2)] 281 | x = up_conv(x, before_pool, se=self.vm_atts[i]) 282 | x_vm = x 283 | 284 | return x_im,x_mask,x_vm 285 | 286 | class UnetEncoderD(nn.Module): 287 | 288 | def __init__(self, in_channels=3, depth=5, blocks=1, start_filters=32, residual=True, batch_norm=True, norm=nn.BatchNorm2d, act=F.relu): 289 | super(UnetEncoderD, self).__init__() 290 | self.down_convs = [] 291 | outs = None 292 | if type(blocks) is tuple: 293 | blocks = blocks[0] 294 | for i in range(depth): 295 | ins = in_channels if i == 0 else outs 296 | outs = start_filters*(2**i) 297 | pooling = True if i < depth-1 else False 298 | down_conv = DownCoXvD(ins, outs, blocks, pooling=pooling, residual=residual, batch_norm=batch_norm, norm=nn.BatchNorm2d, act=F.relu) 299 | self.down_convs.append(down_conv) 300 | self.down_convs = nn.ModuleList(self.down_convs) 301 | reset_params(self) 302 | 303 | def __call__(self, x): 304 | return self.forward(x) 305 | 306 | def forward(self, x): 307 | encoder_outs = [] 308 | for d_conv in self.down_convs: 309 | x, before_pool = d_conv(x) 310 | encoder_outs.append(before_pool) 311 | return x, encoder_outs 312 | 313 | class ResDown(nn.Module): 314 | def __init__(self, in_size, out_size, pooling=True, use_att=False): 315 | super(ResDown, self).__init__() 316 | self.model = DownCoXvD(in_size, out_size, 3, pooling=pooling) 317 | 318 | def forward(self, x): 319 | return self.model(x) 320 | 321 | class ResUp(nn.Module): 322 | def __init__(self, in_size, out_size, use_att=False): 323 | super(ResUp, self).__init__() 324 | self.model = UpCoXvD(in_size, out_size, 3, use_att=use_att) 325 | 326 | def forward(self, x, skip_input, mask=None): 327 | return self.model(x,skip_input,mask) 328 | 329 | class ResDownNew(nn.Module): 330 | def __init__(self, in_size, out_size, pooling=True, use_att=False): 331 | super(ResDownNew, self).__init__() 332 | self.model = DownCoXvD(in_size, out_size, 3, pooling=pooling, norm=nn.InstanceNorm2d, act=F.leaky_relu) 333 | 334 | def forward(self, x): 335 | return self.model(x) 336 | 337 | class ResUpNew(nn.Module): 338 | def __init__(self, in_size, out_size, use_att=False): 339 | super(ResUpNew, self).__init__() 340 | self.model = 
UpCoXvD(in_size, out_size, 3, use_att=use_att, norm=nn.InstanceNorm2d) 341 | 342 | def forward(self, x, skip_input, mask=None): 343 | return self.model(x,skip_input,mask) 344 | 345 | 346 | 347 | class VMSingle(nn.Module): 348 | def __init__(self, in_channels=3, out_channels=3, down=ResDown, up=ResUp, ngf=32, res=True,use_att=False): 349 | super(VMSingle, self).__init__() 350 | 351 | self.down1 = down(in_channels, ngf) 352 | self.down2 = down(ngf, ngf*2) 353 | self.down3 = down(ngf*2, ngf*4) 354 | self.down4 = down(ngf*4, ngf*8) 355 | self.down5 = down(ngf*8, ngf*16, pooling=False) 356 | 357 | self.up1 = up(ngf*16, ngf*8) 358 | self.up2 = up(ngf*8, ngf*4, use_att=use_att) 359 | self.up3 = up(ngf*4, ngf*2, use_att=use_att) 360 | self.up4 = up(ngf*2, ngf*1, use_att=use_att) 361 | 362 | self.im = nn.Conv2d(ngf, 3, 1) 363 | self.res = res 364 | 365 | 366 | def forward(self, input): 367 | img, mask = input[:,0:3,:,:],input[:,3:4,:,:] 368 | # U-Net generator with skip connections from encoder to decoder 369 | x,d1 = self.down1(input) # 128,256 370 | x,d2 = self.down2(x) # 64,128 371 | x,d3 = self.down3(x) # 32,64 372 | x,d4 = self.down4(x) # 16,32 373 | x,_ = self.down5(x) # 8,16 374 | 375 | x = self.up1(x, d4) # 16 376 | x = self.up2(x, d3, mask) # 32 377 | x = self.up3(x, d2, mask) # 64 378 | x = self.up4(x, d1, mask) # 128 379 | im = self.im(x) 380 | 381 | return im 382 | 383 | 384 | 385 | class VMSingleS2AM(nn.Module): 386 | def __init__(self, in_channels=3, out_channels=3, down=ResDown, up=ResUp, ngf=32): 387 | super(VMSingleS2AM, self).__init__() 388 | 389 | self.down1 = down(in_channels, ngf) 390 | self.down2 = down(ngf, ngf*2) 391 | self.down3 = down(ngf*2, ngf*4) 392 | self.down4 = down(ngf*4, ngf*8) 393 | self.down5 = down(ngf*8, ngf*16, pooling=False) 394 | 395 | self.up1 = up(ngf*16, ngf*8) 396 | self.up2 = up(ngf*8, ngf*4) 397 | self.s2am2 = RASC(ngf*4) 398 | 399 | self.up3 = up(ngf*4, ngf*2) 400 | self.s2am3 = RASC(ngf*2) 401 | 402 | self.up4 = up(ngf*2, ngf*1) 403 | self.s2am4 = RASC(ngf) 404 | 405 | self.im = nn.Conv2d(ngf, 3, 1) 406 | 407 | 408 | def forward(self, input): 409 | img, mask = input[:,0:3,:,:],input[:,3:4,:,:] 410 | # U-Net generator with skip connections from encoder to decoder 411 | x,d1 = self.down1(input) # 128,256 412 | x,d2 = self.down2(x) # 64,128 413 | x,d3 = self.down3(x) # 32,64 414 | x,d4 = self.down4(x) # 16,32 415 | x,_ = self.down5(x) # 8,16 416 | 417 | x = self.up1(x, d4) # 16 418 | x = self.up2(x, d3) # 32 419 | x = self.s2am2(x, mask) 420 | 421 | x = self.up3(x, d2) # 64 422 | x = self.s2am3(x, mask) 423 | 424 | x = self.up4(x, d1) # 128 425 | x = self.s2am4(x, mask) 426 | im = self.im(x) 427 | return im 428 | 429 | 430 | class UnetVMS2AMv4(nn.Module): 431 | 432 | def __init__(self, in_channels=3, depth=5, shared_depth=0, use_vm_decoder=False, blocks=1, 433 | out_channels_image=3, out_channels_mask=1, start_filters=32, residual=True, batch_norm=True, 434 | transpose=True, concat=True, transfer_data=True, long_skip=False, s2am='unet', use_coarser=True,no_stage2=False): 435 | super(UnetVMS2AMv4, self).__init__() 436 | self.transfer_data = transfer_data 437 | self.shared = shared_depth 438 | self.optimizer_encoder, self.optimizer_image, self.optimizer_vm = None, None, None 439 | self.optimizer_mask, self.optimizer_shared = None, None 440 | if type(blocks) is not tuple: 441 | blocks = (blocks, blocks, blocks, blocks, blocks) 442 | if not transfer_data: 443 | concat = False 444 | self.encoder = UnetEncoderD(in_channels=in_channels, depth=depth, 
blocks=blocks[0], 445 | start_filters=start_filters, residual=residual, batch_norm=batch_norm,norm=nn.InstanceNorm2d,act=F.leaky_relu) 446 | self.image_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1), 447 | out_channels=out_channels_image, depth=depth - shared_depth, 448 | blocks=blocks[1], residual=residual, batch_norm=batch_norm, 449 | transpose=transpose, concat=concat,norm=nn.InstanceNorm2d) 450 | self.mask_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1), 451 | out_channels=out_channels_mask, depth=depth - shared_depth, 452 | blocks=blocks[2], residual=residual, batch_norm=batch_norm, 453 | transpose=transpose, concat=concat,norm=nn.InstanceNorm2d) 454 | self.vm_decoder = None 455 | if use_vm_decoder: 456 | self.vm_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1), 457 | out_channels=out_channels_image, depth=depth - shared_depth, 458 | blocks=blocks[3], residual=residual, batch_norm=batch_norm, 459 | transpose=transpose, concat=concat,norm=nn.InstanceNorm2d) 460 | self.shared_decoder = None 461 | self.use_coarser = use_coarser 462 | self.long_skip = long_skip 463 | self.no_stage2 = no_stage2 464 | self._forward = self.unshared_forward 465 | if self.shared != 0: 466 | self._forward = self.shared_forward 467 | self.shared_decoder = UnetDecoderDatt(in_channels=start_filters * 2 ** (depth - 1), 468 | out_channels=start_filters * 2 ** (depth - shared_depth - 1), 469 | depth=shared_depth, blocks=blocks[4], residual=residual, 470 | batch_norm=batch_norm, transpose=transpose, concat=concat, 471 | is_final=False,norm=nn.InstanceNorm2d) 472 | 473 | if s2am == 'unet': 474 | self.s2am = UnetGenerator(4,3,is_attention_layer=True,attention_model=RASC,basicblock=MinimalUnetV2) 475 | elif s2am == 'vm': 476 | self.s2am = VMSingle(4) 477 | elif s2am == 'vms2am': 478 | self.s2am = VMSingleS2AM(4,down=ResDownNew,up=ResUpNew) 479 | 480 | def set_optimizers(self): 481 | self.optimizer_encoder = torch.optim.Adam(self.encoder.parameters(), lr=0.001) 482 | self.optimizer_image = torch.optim.Adam(self.image_decoder.parameters(), lr=0.001) 483 | self.optimizer_mask = torch.optim.Adam(self.mask_decoder.parameters(), lr=0.001) 484 | self.optimizer_s2am = torch.optim.Adam(self.s2am.parameters(), lr=0.001) 485 | 486 | if self.vm_decoder is not None: 487 | self.optimizer_vm = torch.optim.Adam(self.vm_decoder.parameters(), lr=0.001) 488 | if self.shared != 0: 489 | self.optimizer_shared = torch.optim.Adam(self.shared_decoder.parameters(), lr=0.001) 490 | 491 | def zero_grad_all(self): 492 | self.optimizer_encoder.zero_grad() 493 | self.optimizer_image.zero_grad() 494 | self.optimizer_mask.zero_grad() 495 | self.optimizer_s2am.zero_grad() 496 | if self.vm_decoder is not None: 497 | self.optimizer_vm.zero_grad() 498 | if self.shared != 0: 499 | self.optimizer_shared.zero_grad() 500 | 501 | def step_all(self): 502 | self.optimizer_encoder.step() 503 | self.optimizer_image.step() 504 | self.optimizer_mask.step() 505 | self.optimizer_s2am.step() 506 | if self.vm_decoder is not None: 507 | self.optimizer_vm.step() 508 | if self.shared != 0: 509 | self.optimizer_shared.step() 510 | 511 | def step_optimizer_image(self): 512 | self.optimizer_image.step() 513 | 514 | def __call__(self, synthesized): 515 | return self._forward(synthesized) 516 | 517 | def forward(self, synthesized): 518 | return self._forward(synthesized) 519 | 520 | def unshared_forward(self, synthesized): 521 | image_code, before_pool = 
self.encoder(synthesized) 522 | if not self.transfer_data: 523 | before_pool = None 524 | reconstructed_image = torch.tanh(self.image_decoder(image_code, before_pool)) 525 | reconstructed_mask = torch.sigmoid(self.mask_decoder(image_code, before_pool)) 526 | if self.vm_decoder is not None: 527 | reconstructed_vm = torch.tanh(self.vm_decoder(image_code, before_pool)) 528 | return reconstructed_image, reconstructed_mask, reconstructed_vm 529 | return reconstructed_image, reconstructed_mask 530 | 531 | def shared_forward(self, synthesized): 532 | image_code, before_pool = self.encoder(synthesized) 533 | if self.transfer_data: 534 | shared_before_pool = before_pool[- self.shared - 1:] 535 | unshared_before_pool = before_pool[: - self.shared] 536 | else: 537 | before_pool = None 538 | shared_before_pool = None 539 | unshared_before_pool = None 540 | im,mask,vm = self.shared_decoder(image_code, shared_before_pool) 541 | reconstructed_image = torch.tanh(self.image_decoder(im, unshared_before_pool)) 542 | if self.long_skip: 543 | reconstructed_image = reconstructed_image + synthesized 544 | 545 | reconstructed_mask = torch.sigmoid(self.mask_decoder(mask, unshared_before_pool)) 546 | if self.vm_decoder is not None: 547 | reconstructed_vm = torch.tanh(self.vm_decoder(vm, unshared_before_pool)) 548 | if self.long_skip: 549 | reconstructed_vm = reconstructed_vm + synthesized 550 | 551 | coarser = reconstructed_image * reconstructed_mask + (1-reconstructed_mask)* synthesized 552 | 553 | if self.use_coarser: 554 | refine = torch.tanh(self.s2am(torch.cat([coarser,reconstructed_mask],dim=1))) + coarser 555 | elif self.no_stage2: 556 | refine = torch.tanh(self.s2am(torch.cat([coarser,reconstructed_mask],dim=1))) 557 | else: 558 | refine = torch.tanh(self.s2am(torch.cat([coarser,reconstructed_mask],dim=1))) + synthesized 559 | 560 | # final = refine * reconstructed_mask + (1-reconstructed_mask)* synthesized 561 | if self.vm_decoder is not None: 562 | return [refine, reconstructed_image], reconstructed_mask, reconstructed_vm 563 | else: 564 | return [refine, reconstructed_image], reconstructed_mask 565 | 566 | 567 | --------------------------------------------------------------------------------
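The listings above end the model definitions. As a minimal usage sketch (not part of the repository, and assuming the package is on PYTHONPATH and that the companion modules imported by sa_resunet.py, such as scripts.models.rasc and scripts.models.unet, provide RASC, SEBlock and UnetGenerator as used above), the stacked watermark-removal network can be driven the same way the VX machine drives it. The constructor arguments and the 256x256 input size below are illustrative assumptions, not the published training configuration:

import torch
import torch.nn.functional as F
from math import log10

from scripts.models.sa_resunet import UnetVMS2AMv4

# Illustrative hyper-parameters; any shared_depth > 0 routes forward() through shared_forward,
# and use_vm_decoder=True yields the three outputs that VX.train/validate unpack.
model = UnetVMS2AMv4(shared_depth=2, blocks=3, use_vm_decoder=True, long_skip=True)
model.set_optimizers()                        # one Adam per sub-network (encoder, decoders, s2am)

x = torch.rand(2, 3, 256, 256)                # watermarked input, values in [0, 1]
target = torch.rand(2, 3, 256, 256)           # dummy watermark-free ground truth

images, mask, wm = model(x)                   # image predictions, watermark mask, recovered watermark
refined, coarse = images                      # stage-2 refinement and stage-1 estimate
composite = refined * mask + x * (1 - mask)   # same composition as VX.validate/test (norm/denorm omitted)

model.zero_grad_all()
F.l1_loss(composite, target).backward()       # stand-in for the repository's hybrid loss
model.step_all()

# PSNR as computed in the machines, valid for images scaled to [0, 1]:
psnr = 10 * log10(1 / F.mse_loss(composite, target).item())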