├── scripts
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── COCO.py
│   │   └── BIH.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── backbone_unet.py
│   │   ├── vgg.py
│   │   ├── rasc.py
│   │   ├── discriminator.py
│   │   ├── unet.py
│   │   ├── blocks.py
│   │   ├── vmu.py
│   │   └── sa_resunet.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── osutils.py
│   │   ├── model_init.py
│   │   ├── misc.py
│   │   ├── evaluation.py
│   │   ├── logger.py
│   │   ├── transforms.py
│   │   ├── imutils.py
│   │   ├── losses.py
│   │   └── parallel.py
│   └── machines
│       ├── __init__.py
│       ├── S2AM.py
│       ├── BasicMachine.py
│       └── VX.py
├── examples
│   ├── test.sh
│   └── evaluate.sh
├── requirements.txt
├── test.py
├── main.py
├── options.py
├── watermark_synthesis.ipynb
└── README.md
/scripts/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .COCO import COCO
2 | from .BIH import BIH
3 |
4 | __all__ = ('COCO','BIH')
--------------------------------------------------------------------------------
/scripts/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .vgg import *
2 | from .backbone_unet import *
3 | from .discriminator import *
4 |
5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.19.1
2 | opencv-python==3.4.8.29
3 | Pillow
4 | scikit-image==0.14.5
5 | scikit-learn==0.23.1
6 | scipy==1.2.1
7 | sklearn==0.0
8 | tensorboardX
9 | torch>=1.0.0
10 | torchvision
11 | matplotlib
--------------------------------------------------------------------------------
/scripts/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .evaluation import *
4 | from .imutils import *
5 | from .logger import *
6 | from .misc import *
7 | from .osutils import *
8 | from .transforms import *
9 |
--------------------------------------------------------------------------------
/scripts/machines/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .BasicMachine import BasicMachine
3 | from .VX import VX
4 | from .S2AM import S2AM
5 |
6 | def basic(**kwargs):
7 | return BasicMachine(**kwargs)
8 |
9 | def s2am(**kwargs):
10 | return S2AM(**kwargs)
11 |
12 | def vx(**kwargs):
13 | return VX(**kwargs)
14 |
--------------------------------------------------------------------------------
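The lowercase factory functions above are what let the entry points select a machine by name. A minimal sketch of that dispatch, mirroring how main.py and test.py use `machines.__dict__[args.machine]` (the `args`/`data_loaders` objects are placeholders that come from options.py and the dataset setup):

```python
import scripts.machines as machines

# Each machine is exposed under a lowercase name, so a command-line string
# such as --machine vx maps directly to a constructor.
for name in ('basic', 's2am', 'vx'):
    assert callable(machines.__dict__[name])

# In main.py / test.py the lookup is used like this (placeholders shown):
# Machine = machines.__dict__[args.machine](datasets=data_loaders, args=args)
# Machine.train(epoch); Machine.validate(epoch); Machine.test()
```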
/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from . import datasets
4 | from . import models
5 | from . import utils
6 |
7 | # import os, sys
8 | # sys.path.append(os.path.join(os.path.dirname(__file__), "progress"))
9 | # from progress.bar import Bar as Bar
10 |
11 | # __version__ = '0.1.0'
--------------------------------------------------------------------------------
/examples/test.sh:
--------------------------------------------------------------------------------
1 |
2 | set -ex
3 |
4 | CUDA_VISIBLE_DEVICES=0 python /data/home/yb87432/s2am/test.py \
5 | -c test/10kgray_ssim\
6 | --resume /data/home/yb87432/s2am/eval/10kgray/1e3_bs6_256_hybrid_ssim_vgg_vx__images_vvv4n/model_best.pth.tar\
7 | --arch vvv4n\
8 | --machine vx\
9 | --input-size 256\
10 | --test-batch 1\
11 | --evaluate\
12 | --base-dir $HOME/watermark/10kgray/\
13 | --data _images
--------------------------------------------------------------------------------
/scripts/utils/osutils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import os
4 | import errno
5 |
6 | def mkdir_p(dir_path):
7 | try:
8 | os.makedirs(dir_path)
9 | except OSError as e:
10 | if e.errno != errno.EEXIST:
11 | raise
12 |
13 | def isfile(fname):
14 | return os.path.isfile(fname)
15 |
16 | def isdir(dirname):
17 | return os.path.isdir(dirname)
18 |
19 | def join(path, *paths):
20 | return os.path.join(path, *paths)
21 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 |
3 | import argparse
4 | import torch
5 |
6 | torch.backends.cudnn.benchmark = True
7 |
8 | from scripts.utils.misc import save_checkpoint, adjust_learning_rate
9 |
10 | import scripts.datasets as datasets
11 | import scripts.machines as machines
12 | from options import Options
13 |
14 | def main(args):
15 |
16 | val_loader = torch.utils.data.DataLoader(datasets.COCO('val',args),batch_size=args.test_batch, shuffle=False,
17 | num_workers=args.workers, pin_memory=True)
18 |
19 | data_loaders = (None,val_loader)
20 |
21 | Machine = machines.__dict__[args.machine](datasets=data_loaders, args=args)
22 |
23 | Machine.test()
24 |
25 | if __name__ == '__main__':
26 | parser=Options().init(argparse.ArgumentParser(description='WaterMark Removal'))
27 | main(parser.parse_args())
28 |
--------------------------------------------------------------------------------
/examples/evaluate.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 |
3 |
4 |
5 | # example training scripts for AAAI-21
6 | # Split then Refine: Stacked Attention-guided ResUNets for Blind Single Image Visible Watermark Removal
7 |
8 |
9 | CUDA_VISIBLE_DEVICES=0 python /data/home/yb87432/s2am/main.py --epochs 100\
10 | --schedule 100\
11 | --lr 1e-3\
12 | -c eval/10kgray/1e3_bs4_256_hybrid_ssim_vgg\
13 | --arch vvv4n\
14 | --sltype vggx\
15 | --style-loss 0.025\
16 | --ssim-loss 0.15\
17 | --masked True\
18 | --loss-type hybrid\
19 | --limited-dataset 1\
20 | --machine vx\
21 | --input-size 256\
22 | --train-batch 4\
23 | --test-batch 1\
24 | --base-dir $HOME/watermark/10kgray/\
25 | --data _images
26 |
27 |
28 |
29 |
30 |
31 | # example training scripts for TIP-20
32 | # Improving the Harmony of the Composite Image by Spatial-Separated Attention Module
33 | # * in the original version, the res = False
34 | # suitable for the iHarmony4 dataset.
35 |
36 | python /data/home/yb87432/mypaper/s2am/main.py --epochs 200\
37 | --schedule 150\
38 | --lr 1e-3\
39 | -c checkpoint/normal_rasc_HAdobe5k_res \
40 | --arch rascv2\
41 | --style-loss 0\
42 | --ssim-loss 0\
43 | --limited-dataset 0\
44 | --res True\
45 | --machine s2am\
46 | --input-size 256\
47 | --train-batch 16\
48 | --test-batch 1\
49 | --base-dir $HOME/Datasets/\
50 | --data HAdobe5k
--------------------------------------------------------------------------------
/scripts/models/backbone_unet.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import torch
4 | import torchvision
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import numpy as np
8 | import functools
9 | import math
10 |
11 | from scripts.utils.model_init import *
12 | from scripts.models.rasc import *
13 | from scripts.models.unet import UnetGenerator,MinimalUnetV2
14 | from scripts.models.vmu import UnetVM
15 | from scripts.models.sa_resunet import UnetVMS2AMv4
16 |
17 |
18 | # our method
19 | def vvv4n(**kwargs):
20 | return UnetVMS2AMv4(shared_depth=2, blocks=3, long_skip=True, use_vm_decoder=True,s2am='vms2am')
21 |
22 |
23 | # BVMR
24 | def vm3(**kwargs):
25 | return UnetVM(shared_depth=2, blocks=3, use_vm_decoder=True)
26 |
27 |
28 | # Blind version of S2AM
29 | def urasc(**kwargs):
30 | model = UnetGenerator(3,3,is_attention_layer=True,attention_model=URASC,basicblock=MinimalUnetV2)
31 | model.apply(weights_init_kaiming)
32 | return model
33 |
34 |
35 | # Improving the Harmony of the Composite Image by Spatial-Separated Attention Module
36 | # Xiaodong Cun and Chi-Man Pun
37 | # University of Macau
38 | # Trans. on Image Processing, vol. 29, pp. 4759-4771, 2020.
39 | def rascv2(**kwargs):
40 | model = UnetGenerator(4,3,is_attention_layer=True,attention_model=RASC,basicblock=MinimalUnetV2)
41 | model.apply(weights_init_kaiming)
42 | return model
43 |
44 | # just original unet
45 | def unet(**kwargs):
46 | model = UnetGenerator(3,3)
47 | model.apply(weights_init_kaiming)
48 | return model
49 |
50 |
51 |
--------------------------------------------------------------------------------
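The architectures defined above are resolved by their lowercase function names through `models.__dict__` (options.py builds the list of valid `--arch` choices the same way). A small sketch, assuming the repository and its dependencies are importable:

```python
import scripts.models as models

# 'vvv4n' is the SplitNet+RefineNet model used by examples/*.sh; looking it
# up by name mirrors how the machines resolve args.arch.
model = models.__dict__['vvv4n']()
print(type(model).__name__)                              # UnetVMS2AMv4
print(sum(p.numel() for p in model.parameters()), 'parameters')
```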
/scripts/utils/model_init.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from torch.nn import init
4 |
5 |
6 | def weights_init_normal(m):
7 | classname = m.__class__.__name__
8 | # print(classname)
9 | if classname.find('Conv') != -1:
10 | init.normal_(m.weight.data, 0.0, 0.02)
11 | elif classname.find('Linear') != -1:
12 | init.normal_(m.weight.data, 0.0, 0.02)
13 | elif classname.find('BatchNorm2d') != -1:
14 | init.normal_(m.weight.data, 1.0, 0.02)
15 | init.constant_(m.bias.data, 0.0)
16 |
17 |
18 | def weights_init_xavier(m):
19 | classname = m.__class__.__name__
20 | # print(classname)
21 | if classname.find('Conv') != -1:
22 | init.xavier_normal_(m.weight.data, gain=0.02)
23 | elif classname.find('Linear') != -1:
24 | init.xavier_normal_(m.weight.data, gain=0.02)
25 | # elif classname.find('BatchNorm2d') != -1:
26 | # init.normal(m.weight.data, 1.0, 0.02)
27 | # init.constant(m.bias.data, 0.0)
28 |
29 |
30 | def weights_init_kaiming(m):
31 | classname = m.__class__.__name__
32 | # print(classname)
33 | if classname.find('Conv') != -1 and m.weight.requires_grad == True:
34 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
35 | elif classname.find('Linear') != -1 and m.weight.requires_grad == True:
36 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
37 | elif classname.find('BatchNorm2d') != -1 and m.weight.requires_grad == True:
38 | init.normal_(m.weight.data, 1.0, 0.02)
39 | init.constant_(m.bias.data, 0.0)
40 |
41 |
42 | def weights_init_orthogonal(m):
43 | classname = m.__class__.__name__
44 | if classname.find('Conv') != -1:
45 | init.orthogonal_(m.weight.data, gain=1)
46 | elif classname.find('Linear') != -1:
47 | init.orthogonal_(m.weight.data, gain=1)
48 | # elif classname.find('BatchNorm2d') != -1:
49 | # init.normal(m.weight.data, 1.0, 0.02)
50 | # init.constant(m.bias.data, 0.0)
--------------------------------------------------------------------------------
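All of these initializers are intended to be passed to `nn.Module.apply`, which visits every sub-module so the class-name checks above run once per layer (backbone_unet.py does exactly this with `weights_init_kaiming`). A self-contained example:

```python
import torch.nn as nn
from scripts.utils.model_init import weights_init_kaiming

# apply() walks every sub-module: Conv2d/Linear layers get Kaiming init,
# BatchNorm2d gets N(1, 0.02) weights and zero bias, other layers are skipped.
net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(inplace=True),
    nn.Conv2d(16, 3, kernel_size=3, padding=1),
)
net.apply(weights_init_kaiming)
```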
/scripts/utils/misc.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import os
4 | import shutil
5 | import torch
6 | import math
7 | import numpy as np
8 | import scipy.io
9 | import matplotlib.pyplot as plt
10 | import torch.nn.functional as F
11 |
12 | def to_numpy(tensor):
13 | if torch.is_tensor(tensor):
14 | return tensor.cpu().numpy()
15 | elif type(tensor).__module__ != 'numpy':
16 | raise ValueError("Cannot convert {} to numpy array"
17 | .format(type(tensor)))
18 | return tensor
19 |
20 | def resize_to_match(fm,to):
21 | # just use interpolate
22 | # [1,3] = (h,w)
23 | return F.interpolate(fm,to.size()[-2:],mode='bilinear',align_corners=False)
24 |
25 | def to_torch(ndarray):
26 | if type(ndarray).__module__ == 'numpy':
27 | return torch.from_numpy(ndarray)
28 | elif not torch.is_tensor(ndarray):
29 | raise ValueError("Cannot convert {} to torch tensor"
30 | .format(type(ndarray)))
31 | return ndarray
32 |
33 |
34 | def save_checkpoint(machine,filename='checkpoint.pth.tar', snapshot=None):
35 | is_best = True if machine.best_acc < machine.metric else False
36 |
37 | if is_best:
38 | machine.best_acc = machine.metric
39 |
40 | state = {
41 | 'epoch': machine.current_epoch + 1,
42 | 'arch': machine.args.arch,
43 | 'state_dict': machine.model.state_dict(),
44 | 'best_acc': machine.best_acc,
45 | 'optimizer' : machine.optimizer.state_dict(),
46 | }
47 |
48 | filepath = os.path.join(machine.args.checkpoint, filename)
49 | torch.save(state, filepath)
50 |
51 | if snapshot and state['epoch'] % snapshot == 0:
52 | shutil.copyfile(filepath, os.path.join(machine.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch'])))
53 |
54 | if is_best:
55 | machine.best_acc = machine.metric
56 | print('Saving Best Metric with PSNR:%s'%machine.best_acc)
57 | shutil.copyfile(filepath, os.path.join(machine.args.checkpoint, 'model_best.pth.tar'))
58 |
59 |
60 |
61 | def save_pred(preds, checkpoint='checkpoint', filename='preds_valid.mat'):
62 | preds = to_numpy(preds)
63 | filepath = os.path.join(checkpoint, filename)
64 | scipy.io.savemat(filepath, mdict={'preds' : preds})
65 |
66 |
67 | def adjust_learning_rate(datasets,optimizer, epoch, lr,args):
68 | """Sets the learning rate to the initial LR decayed by schedule"""
69 | if epoch in args.schedule:
70 | lr *= args.gamma
71 | for param_group in optimizer.param_groups:
72 | param_group['lr'] = lr
73 |
74 | # decay sigma
75 | for dset in datasets:
76 | if args.sigma_decay > 0:
77 | dset.dataset.sigma *= args.sigma_decay
78 |
79 |
80 | return lr
81 |
82 |
83 |
84 |
85 |
--------------------------------------------------------------------------------
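A small, self-contained sketch of the `adjust_learning_rate` schedule above: the learning rate is multiplied by `args.gamma` whenever the current epoch appears in `args.schedule`. The option values are placeholders, and an empty loader list is passed so the sigma-decay branch is skipped:

```python
import argparse
import torch
from scripts.utils.misc import adjust_learning_rate

args = argparse.Namespace(schedule=[5, 10], gamma=0.1, sigma_decay=0)
optimizer = torch.optim.Adam(torch.nn.Linear(4, 4).parameters(), lr=1e-3)

lr = 1e-3
for epoch in range(12):
    lr = adjust_learning_rate([], optimizer, epoch, lr, args)
print(lr)  # ~1e-05: decayed by 0.1 at epoch 5 and again at epoch 10
```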
/README.md:
--------------------------------------------------------------------------------
1 | This repo contains the code and results of the AAAI 2021 paper:
2 |
3 | [Split then Refine: Stacked Attention-guided ResUNets for Blind Single Image Visible Watermark Removal](https://arxiv.org/abs/2012.07007)
4 | [Xiaodong Cun](http://vinthony.github.io), [Chi-Man Pun*](http://www.cis.umac.mo/~cmpun/)
5 | [University of Macau](http://um.edu.mo/)
6 |
7 | [Datasets](#Resources) | [Models](#Resources) | [Paper](https://arxiv.org/abs/2012.07007) | [🔥Online Demo!](https://colab.research.google.com/drive/1pYY7byBjM-7aFIWk8HcF9nK_s6pqGwww?usp=sharing)(Google CoLab)
8 |
9 |
10 |
11 |
12 |
13 | Overview of the proposed two-stage framework. First, we propose a multi-task network, SplitNet, for watermark detection, removal, and recovery. Then, we propose RefineNet, which smooths the learned region using the predicted mask and the recovered background from the previous stage. As a consequence, our network can be trained end-to-end without any manual intervention. Note that, for clarity, we do not show the skip-connections between the encoders and decoders.
14 |
15 |
16 | > The whole project will be released in January 2021 (almost).
17 |
18 |
19 | ### Datasets
20 |
21 | We synthesized four different datasets for training and testing; you can download them via [huggingface](https://huggingface.co/datasets/vinthony/watermark-removal-logo/tree/main).
22 |
23 | 
24 |
25 |
26 | ### Pre-trained Models
27 |
28 | * [27kpng_model_best.pth.tar (google drive)](https://drive.google.com/file/d/1KpSJ6385CHN6WlAINqB3CYrJdleQTJBc/view?usp=sharing)
29 |
30 | > The other pre-trained models are still being reorganized and uploaded; they will be released soon.
31 |
32 |
33 | ### Demos
34 |
35 | An easy-to-use online demo can be found on [Google Colab](https://colab.research.google.com/drive/1pYY7byBjM-7aFIWk8HcF9nK_s6pqGwww?usp=sharing).
36 |
37 | The local demo will be released soon.
38 |
39 | ### Pre-requirements
40 |
41 | ```
42 | pip install -r requirements.txt
43 | ```
44 |
45 | ### Train
46 |
47 | Besides training our method, we also give an example of how to train [s2am](https://github.com/vinthony/s2am) under our framework. More details can be found in the shell scripts.
48 |
49 |
50 | ```
51 | bash examples/evaluate.sh
52 | ```
53 |
54 | ### Test
55 |
56 | ```
57 | bash examples/test.sh
58 | ```
59 |
60 | ## **Acknowledgements**
61 | The author would like to thank Nan Chen for her helpful discussions.
62 |
63 | Part of the code is based upon our previous work on image harmonization, [s2am](https://github.com/vinthony/s2am).
64 |
65 | ## **Citation**
66 |
67 | If you find our work useful in your research, please consider citing:
68 |
69 | ```
70 | @misc{cun2020split,
71 | title={Split then Refine: Stacked Attention-guided ResUNets for Blind Single Image Visible Watermark Removal},
72 | author={Xiaodong Cun and Chi-Man Pun},
73 | year={2020},
74 | eprint={2012.07007},
75 | archivePrefix={arXiv},
76 | primaryClass={cs.CV}
77 | }
78 | ```
79 |
80 | ## **Contact**
81 | Please contact me if you have any questions (Xiaodong Cun, yb87432@um.edu.mo).
82 |
--------------------------------------------------------------------------------
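The overview paragraph in the README describes the split-then-refine data flow; the snippet below is a purely illustrative sketch of that flow with hypothetical stand-in modules, not the repository's actual implementation (the released model is built by `vvv4n()` in scripts/models/backbone_unet.py):

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: the real SplitNet/RefineNet live in
# scripts/models/sa_resunet.py and are wired together by backbone_unet.vvv4n().
def split_then_refine(split_net: nn.Module, refine_net: nn.Module,
                      watermarked: torch.Tensor):
    # Stage 1: multi-task prediction of the coarse background, the watermark
    # mask, and the watermark itself.
    coarse_bg, mask, watermark = split_net(watermarked)
    # Stage 2: refine the coarse background, guided by the predicted mask.
    refined = refine_net(torch.cat([watermarked, coarse_bg, mask], dim=1))
    return refined, mask, watermark
```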
/scripts/models/vgg.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | import torch
4 | from torchvision import models
5 |
6 |
7 | class Vgg16(torch.nn.Module):
8 | def __init__(self, requires_grad=False):
9 | super(Vgg16, self).__init__()
10 | vgg_pretrained_features = models.vgg16(pretrained=True).features
11 | self.slice1 = torch.nn.Sequential()
12 | self.slice2 = torch.nn.Sequential()
13 | self.slice3 = torch.nn.Sequential()
14 | self.slice4 = torch.nn.Sequential()
15 | self.slice5 = torch.nn.Sequential()
16 | for x in range(4):
17 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
18 | for x in range(4, 9):
19 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
20 | for x in range(9, 16):
21 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
22 | for x in range(16, 23):
23 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
24 | for x in range(23,30):
25 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
26 |
27 | if not requires_grad:
28 | for param in self.parameters():
29 | param.requires_grad = False
30 |
31 | def forward(self, X):
32 | h = self.slice1(X)
33 | h_relu1_2 = h
34 | h = self.slice2(h)
35 | h_relu2_2 = h
36 | h = self.slice3(h)
37 | h_relu3_3 = h
38 | h = self.slice4(h)
39 | h_relu4_3 = h
40 | h = self.slice5(h)
41 | h_relu5_3 = h
42 | # vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3','relu5_3'])
43 | # out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
44 | return (h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
45 |
46 |
47 | class Vgg19(torch.nn.Module):
48 | def __init__(self, requires_grad=False):
49 | super(Vgg19, self).__init__()
50 | # vgg_pretrained_features = models.vgg19(pretrained=True).features
51 | self.vgg_pretrained_features = models.vgg19(pretrained=True).features
52 | # self.slice1 = torch.nn.Sequential()
53 | # self.slice2 = torch.nn.Sequential()
54 | # self.slice3 = torch.nn.Sequential()
55 | # self.slice4 = torch.nn.Sequential()
56 | # self.slice5 = torch.nn.Sequential()
57 | # for x in range(2):
58 | # self.slice1.add_module(str(x), vgg_pretrained_features[x])
59 | # for x in range(2, 7):
60 | # self.slice2.add_module(str(x), vgg_pretrained_features[x])
61 | # for x in range(7, 12):
62 | # self.slice3.add_module(str(x), vgg_pretrained_features[x])
63 | # for x in range(12, 21):
64 | # self.slice4.add_module(str(x), vgg_pretrained_features[x])
65 | # for x in range(21, 30):
66 | # self.slice5.add_module(str(x), vgg_pretrained_features[x])
67 | if not requires_grad:
68 | for param in self.parameters():
69 | param.requires_grad = False
70 |
71 | def forward(self, X, indices=None):
72 | if indices is None:
73 | indices = [2, 7, 12, 21, 30]
74 | out = []
75 | #indices = sorted(indices)
76 | for i in range(indices[-1]):
77 | X = self.vgg_pretrained_features[i](X)
78 | if (i+1) in indices:
79 | out.append(X)
80 |
81 | return out
82 |
--------------------------------------------------------------------------------
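These VGG wrappers return intermediate ReLU activations, which is how a perceptual (feature-matching) loss is typically computed against them. A minimal sketch using `Vgg16`; the plain L1 distance and equal stage weighting here are illustrative, not necessarily what scripts/utils/losses.py uses:

```python
import torch
import torch.nn.functional as F
from scripts.models.vgg import Vgg16

vgg = Vgg16(requires_grad=False).eval()        # loads torchvision's pretrained VGG-16

pred = torch.rand(1, 3, 256, 256, requires_grad=True)   # e.g. network output
target = torch.rand(1, 3, 256, 256)                      # watermark-free target

with torch.no_grad():
    target_feats = vgg(target)
pred_feats = vgg(pred)

# feature-matching (perceptual) loss over the five ReLU stages
perceptual = sum(F.l1_loss(p, t) for p, t in zip(pred_feats, target_feats))
perceptual.backward()
```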
/main.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 |
3 | import argparse
4 | import torch,time,os
5 |
6 | torch.backends.cudnn.benchmark = True
7 |
8 | from scripts.utils.misc import save_checkpoint, adjust_learning_rate
9 |
10 | import scripts.datasets as datasets
11 | import scripts.machines as machines
12 | from options import Options
13 |
14 | def main(args):
15 |
16 | if any(name in args.base_dir + args.data for name in ('HFlickr', 'HCOCO', 'Hday2night', 'HAdobe5k')):
17 | dataset_func = datasets.BIH
18 | else:
19 | dataset_func = datasets.COCO
20 |
21 | train_loader = torch.utils.data.DataLoader(dataset_func('train',args),batch_size=args.train_batch, shuffle=True,
22 | num_workers=args.workers, pin_memory=True)
23 |
24 | val_loader = torch.utils.data.DataLoader(dataset_func('val',args),batch_size=args.test_batch, shuffle=False,
25 | num_workers=args.workers, pin_memory=True)
26 |
27 | lr = args.lr
28 | data_loaders = (train_loader,val_loader)
29 |
30 | Machine = machines.__dict__[args.machine](datasets=data_loaders, args=args)
31 | print('============================ Initialization Finished && Training Start =============================================')
32 |
33 | for epoch in range(Machine.args.start_epoch, Machine.args.epochs):
34 |
35 | print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))
36 | lr = adjust_learning_rate(data_loaders, Machine.optimizer, epoch, lr, args)
37 |
38 | Machine.record('lr',lr, epoch)
39 | Machine.train(epoch)
40 |
41 | if args.freq < 0:
42 | Machine.validate(epoch)
43 | Machine.flush()
44 | Machine.save_checkpoint()
45 |
46 | if __name__ == '__main__':
47 | parser=Options().init(argparse.ArgumentParser(description='WaterMark Removal'))
48 | args = parser.parse_args()
49 | print('==================================== WaterMark Removal =============================================')
50 | print('==> {:50}: {:<}'.format("Start Time",time.ctime(time.time())))
51 | print('==> {:50}: {:<}'.format("USE GPU",os.environ['CUDA_VISIBLE_DEVICES']))
52 | print('==================================== Stable Parameters =============================================')
53 | for arg in vars(args):
54 | if type(getattr(args, arg)) == type([]):
55 | if ','.join([ str(i) for i in getattr(args, arg)]) == ','.join([ str(i) for i in parser.get_default(arg)]):
56 | print('==> {:50}: {:<}({:<})'.format(arg,','.join([ str(i) for i in getattr(args, arg)]),','.join([ str(i) for i in parser.get_default(arg)])))
57 | else:
58 | if getattr(args, arg) == parser.get_default(arg):
59 | print('==> {:50}: {:<}({:<})'.format(arg,getattr(args, arg),parser.get_default(arg)))
60 | print('==================================== Changed Parameters =============================================')
61 | for arg in vars(args):
62 | if type(getattr(args, arg)) == type([]):
63 | if ','.join([ str(i) for i in getattr(args, arg)]) != ','.join([ str(i) for i in parser.get_default(arg)]):
64 | print('==> {:50}: {:<}({:<})'.format(arg,','.join([ str(i) for i in getattr(args, arg)]),','.join([ str(i) for i in parser.get_default(arg)])))
65 | else:
66 | if getattr(args, arg) != parser.get_default(arg):
67 | print('==> {:50}: {:<}({:<})'.format(arg,getattr(args, arg),parser.get_default(arg)))
68 | print('==================================== Start Init Model ===============================================')
69 | main(args)
70 | print('==================================== FINISH WITHOUT ERROR =============================================')
71 |
--------------------------------------------------------------------------------
/scripts/datasets/COCO.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 |
3 | import os
4 | import csv
5 | import numpy as np
6 | import json
7 | import random
8 | import math
9 | import matplotlib.pyplot as plt
10 | from collections import namedtuple
11 | from os import listdir
12 | from os.path import isfile, join
13 |
14 | import torch
15 | import torch.utils.data as data
16 |
17 | from scripts.utils.osutils import *
18 | from scripts.utils.imutils import *
19 | from scripts.utils.transforms import *
20 | import torchvision.transforms as transforms
21 | from PIL import Image
22 | from PIL import ImageEnhance
23 | from PIL import ImageFilter
24 | from PIL import ImageFile
25 | ImageFile.LOAD_TRUNCATED_IMAGES = True
26 |
27 | class COCO(data.Dataset):
28 | def __init__(self,train,config=None, sample=[],gan_norm=False):
29 |
30 | self.train = []
31 | self.anno = []
32 | self.mask = []
33 | self.wm = []
34 | self.input_size = config.input_size
35 | self.normalized_input = config.normalized_input
36 | self.base_folder = config.base_dir
37 | self.dataset = train+config.data
38 |
39 | if config == None:
40 | self.data_augumentation = False
41 | else:
42 | self.data_augumentation = config.data_augumentation
43 |
44 | self.istrain = False if self.dataset.find('train') == -1 else True
45 | self.sample = sample
46 | self.gan_norm = gan_norm
47 | mypath = join(self.base_folder,self.dataset)
48 | file_names = sorted([f for f in listdir(join(mypath,'image')) if isfile(join(mypath,'image', f)) ])
49 |
50 | if config.limited_dataset > 0:
51 | xtrain = sorted(list(set([ file_name.split('-')[0] for file_name in file_names ])))
52 | tmp = []
53 | for x in xtrain:
54 | # get the file_name by identifier
55 | tmp.append([y for y in file_names if x in y][0])
56 |
57 | file_names = tmp
58 | else:
59 | file_names = file_names
60 |
61 | for file_name in file_names:
62 | self.train.append(os.path.join(mypath,'image',file_name))
63 | self.mask.append(os.path.join(mypath,'mask',file_name))
64 | self.wm.append(os.path.join(mypath,'wm',file_name))
65 | self.anno.append(os.path.join(self.base_folder,'natural',file_name.split('-')[0]+'.jpg'))
66 |
67 | if len(self.sample) > 0 :
68 | self.train = [ self.train[i] for i in self.sample ]
69 | self.mask = [ self.mask[i] for i in self.sample ]
70 | self.anno = [ self.anno[i] for i in self.sample ]
71 |
72 | self.trans = transforms.Compose([
73 | transforms.Resize((self.input_size,self.input_size)),
74 | transforms.ToTensor()
75 | ])
76 |
77 | print('total Dataset of '+self.dataset+' is : ', len(self.train))
78 |
79 |
80 | def __getitem__(self, index):
81 | img = Image.open(self.train[index]).convert('RGB')
82 | mask = Image.open(self.mask[index]).convert('L')
83 | anno = Image.open(self.anno[index]).convert('RGB')
84 | wm = Image.open(self.wm[index]).convert('RGB')
85 |
86 | return {"image": self.trans(img),
87 | "target": self.trans(anno),
88 | "mask": self.trans(mask),
89 | "wm": self.trans(wm),
90 | "name": self.train[index].split('/')[-1],
91 | "imgurl":self.train[index],
92 | "maskurl":self.mask[index],
93 | "targeturl":self.anno[index],
94 | "wmurl":self.wm[index]
95 | }
96 |
97 | def __len__(self):
98 |
99 | return len(self.train)
100 |
--------------------------------------------------------------------------------
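The dataset reads everything it needs from the option object, so a loader can be built the same way main.py does. A hedged sketch with placeholder paths; the folder must contain the `image/`, `mask/` and `wm/` sub-directories produced by watermark_synthesis.ipynb, plus the shared `natural/` images:

```python
import argparse
import torch
from scripts.datasets.COCO import COCO

# Placeholder option values; options.py defines the full set of flags.
cfg = argparse.Namespace(input_size=256, normalized_input=False,
                         base_dir='/path/to/10kgray/', data='_images',
                         data_augumentation=False, limited_dataset=0)

train_set = COCO('train', cfg)   # looks under base_dir/train_images/{image,mask,wm}
train_loader = torch.utils.data.DataLoader(train_set, batch_size=4, shuffle=True,
                                           num_workers=2, pin_memory=True)
batch = next(iter(train_loader))
print(batch['image'].shape, batch['mask'].shape)  # [4, 3, 256, 256], [4, 1, 256, 256]
```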
/scripts/datasets/BIH.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 |
3 | import os
4 | import csv
5 | import numpy as np
6 | import json
7 | import random
8 | import math
9 | import matplotlib.pyplot as plt
10 | from collections import namedtuple
11 | from os import listdir
12 | from os.path import isfile, join
13 |
14 | import torch
15 | import torch.utils.data as data
16 |
17 | from scripts.utils.osutils import *
18 | from scripts.utils.imutils import *
19 | from scripts.utils.transforms import *
20 | import torchvision.transforms as transforms
21 | from PIL import Image
22 | from PIL import ImageEnhance
23 | from PIL import ImageFilter
24 | from PIL import ImageFile
25 | ImageFile.LOAD_TRUNCATED_IMAGES = True
26 |
27 | class BIH(data.Dataset):
28 | def __init__(self,train,config=None, sample=[],gan_norm=False):
29 |
30 | self.train = []
31 | self.anno = []
32 | self.mask = []
33 | self.wm = []
34 | self.input_size = config.input_size
35 | self.normalized_input = config.normalized_input
36 | self.base_folder = config.base_dir +'/' + config.data
37 | self.dataset = config.data
38 |
39 | if config == None:
40 | self.data_augumentation = False
41 | else:
42 | self.data_augumentation = config.data_augumentation
43 |
44 | self.istrain = False if train.find('train') == -1 else True
45 | self.sample = sample
46 | self.gan_norm = gan_norm
47 | mypath = join(self.base_folder,self.dataset+'_'+train+'.txt')
48 |
49 | with open(mypath) as f:
50 | # here we get the filenames
51 | file_names = [ im.strip() for im in f.readlines() ]
52 |
53 | if config.limited_dataset > 0:
54 | xtrain = sorted(list(set([ file_name.split('-')[0] for file_name in file_names ])))
55 | tmp = []
56 | for x in xtrain:
57 | tmp.append([y for y in file_names if x in y][0])
58 |
59 | file_names = tmp
60 | else:
61 | file_names = file_names
62 |
63 | for file_name in file_names:
64 | self.train.append(os.path.join(self.base_folder,'images',file_name))
65 | self.mask.append(os.path.join(self.base_folder,'masks','_'.join(file_name.split('_')[0:2])+'.png'))
66 | self.anno.append(os.path.join(self.base_folder,'reals',file_name.split('_')[0]+'.jpg'))
67 |
68 | if len(self.sample) > 0 :
69 | self.train = [ self.train[i] for i in self.sample ]
70 | self.mask = [ self.mask[i] for i in self.sample ]
71 | self.anno = [ self.anno[i] for i in self.sample ]
72 |
73 | self.trans = transforms.Compose([
74 | transforms.Resize((self.input_size,self.input_size)),
75 | transforms.ToTensor()
76 | ])
77 |
78 | print('total Dataset of '+self.dataset+' is : ', len(self.train))
79 |
80 |
81 | def __getitem__(self, index):
82 | img = Image.open(self.train[index]).convert('RGB')
83 | mask = Image.open(self.mask[index]).convert('L')
84 | anno = Image.open(self.anno[index]).convert('RGB')
85 |
86 | # for shadow removal and blind image harmonization, here is no ground truth wm
87 | # wm = Image.open(self.wm[index]).convert('RGB')
88 |
89 | return {"image": self.trans(img),
90 | "target": self.trans(anno),
91 | "mask": self.trans(mask),
92 | "name": self.train[index].split('/')[-1],
93 | "imgurl":self.train[index],
94 | "maskurl":self.mask[index],
95 | "targeturl":self.anno[index],
96 | }
97 |
98 | def __len__(self):
99 |
100 | return len(self.train)
101 |
--------------------------------------------------------------------------------
/scripts/utils/evaluation.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import math
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from random import randint
7 |
8 | from .misc import *
9 | from .transforms import transform, transform_preds
10 |
11 | __all__ = ['accuracy', 'AverageMeter']
12 |
13 | def get_preds(scores):
14 | ''' get predictions from score maps in torch Tensor
15 | return type: torch.LongTensor
16 | '''
17 | assert scores.dim() == 4, 'Score maps should be 4-dim'
18 | maxval, idx = torch.max(scores.view(scores.size(0), scores.size(1), -1), 2)
19 |
20 | maxval = maxval.view(scores.size(0), scores.size(1), 1)
21 | idx = idx.view(scores.size(0), scores.size(1), 1) + 1
22 |
23 | preds = idx.repeat(1, 1, 2).float()
24 |
25 | preds[:,:,0] = (preds[:,:,0] - 1) % scores.size(3) + 1
26 | preds[:,:,1] = torch.floor((preds[:,:,1] - 1) / scores.size(2)) + 1
27 |
28 | pred_mask = maxval.gt(0).repeat(1, 1, 2).float()
29 | preds *= pred_mask
30 | return preds
31 |
32 | def calc_dists(preds, target, normalize):
33 | preds = preds.float()
34 | target = target.float()
35 | dists = torch.zeros(preds.size(1), preds.size(0))
36 | for n in range(preds.size(0)):
37 | for c in range(preds.size(1)):
38 | if target[n,c,0] > 1 and target[n, c, 1] > 1:
39 | dists[c, n] = torch.dist(preds[n,c,:], target[n,c,:])/normalize[n]
40 | else:
41 | dists[c, n] = -1
42 | return dists
43 |
44 | def dist_acc(dists, thr=0.5):
45 | ''' Return percentage below threshold while ignoring values with a -1 '''
46 | if dists.ne(-1).sum() > 0:
47 | return dists.le(thr).eq(dists.ne(-1)).sum()*1.0 / dists.ne(-1).sum()
48 | else:
49 | return -1
50 |
51 |
52 |
53 | def accuracy(output, target, thr=0.5):
54 | ''' Calculate accuracy according to PCK, but uses ground truth heatmap rather than x,y locations
55 | First value to be returned is average accuracy across 'idxs', followed by individual accuracies
56 | '''
57 | # output_mask = torch.gt(output,thr);
58 | # target_mask = torch.gt(target,thr);
59 | # equal_mask = torch.eq(output_mask,target_mask);
60 | # fp_equal_mask = torch.lt(output_mask,target_mask);
61 | # fn_equal_mask = torch.gt(output_mask,target_mask);
62 |
63 |
64 | # tp = torch.sum(equal_mask);
65 | # fn = torch.sum(fn_equal_mask);
66 | # fp = torch.sum(fp_equal_mask);
67 |
68 | # return 2*tp / (2*tp+fn+fp)
69 |
70 |
71 | if output.dim() > 2:
72 | v,i = torch.max(output,1);
73 | else:
74 | v,i = torch.max(output,1);
75 | return torch.sum(target.long() == i).float()/target.numel()
76 |
77 | def final_preds(output, center, scale, res):
78 | coords = get_preds(output) # float type
79 |
80 | # pose-processing
81 | for n in range(coords.size(0)):
82 | for p in range(coords.size(1)):
83 | hm = output[n][p]
84 | px = int(math.floor(coords[n][p][0]))
85 | py = int(math.floor(coords[n][p][1]))
86 | if px > 1 and px < res[0] and py > 1 and py < res[1]:
87 | diff = torch.Tensor([hm[py - 1][px] - hm[py - 1][px - 2], hm[py][px - 1]-hm[py - 2][px - 1]])
88 | coords[n][p] += diff.sign() * .25
89 | coords += 0.5
90 | preds = coords.clone()
91 |
92 | # Transform back
93 | for i in range(coords.size(0)):
94 | preds[i] = transform_preds(coords[i], center[i], scale[i], res)
95 |
96 | if preds.dim() < 3:
97 | preds = preds.view(1, preds.size())
98 |
99 | return preds
100 |
101 |
102 | class AverageMeter(object):
103 | """Computes and stores the average and current value"""
104 | def __init__(self):
105 | self.reset()
106 |
107 | def reset(self):
108 | self.val = 0
109 | self.avg = 0
110 | self.sum = 0
111 | self.count = 0
112 |
113 | def update(self, val, n=1):
114 | self.val = val
115 | self.sum += val * n
116 | self.count += n
117 | self.avg = self.sum / self.count
118 |
--------------------------------------------------------------------------------
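`AverageMeter` keeps a running mean of per-batch values; a minimal usage example:

```python
from scripts.utils.evaluation import AverageMeter

losses = AverageMeter()
for step in range(1, 6):
    batch_loss = 1.0 / step          # placeholder per-batch loss value
    losses.update(batch_loss, n=4)   # n is typically the batch size
print('last: %.3f  running mean: %.3f' % (losses.val, losses.avg))
```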
/scripts/utils/logger.py:
--------------------------------------------------------------------------------
1 | # A simple torch style logger
2 | # (C) Wei YANG 2017
3 | from __future__ import absolute_import
4 |
5 | import os
6 | import sys
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 |
10 | __all__ = ['Logger', 'LoggerMonitor', 'savefig']
11 |
12 | def savefig(fname, dpi=None):
13 | dpi = 150 if dpi == None else dpi
14 | plt.savefig(fname, dpi=dpi)
15 |
16 | def plot_overlap(logger, names=None):
17 | names = logger.names if names == None else names
18 | numbers = logger.numbers
19 | for _, name in enumerate(names):
20 | x = np.arange(len(numbers[name]))
21 | plt.plot(x, np.asarray(numbers[name]))
22 | return [logger.title + '(' + name + ')' for name in names]
23 |
24 | class Logger(object):
25 | '''Save training process to log file with simple plot function.'''
26 | def __init__(self, fpath, title=None, resume=False):
27 | self.file = None
28 | self.resume = resume
29 | self.title = '' if title == None else title
30 | if fpath is not None:
31 | if resume:
32 | self.file = open(fpath, 'r')
33 | name = self.file.readline()
34 | self.names = name.rstrip().split('\t')
35 | self.numbers = {}
36 | for _, name in enumerate(self.names):
37 | self.numbers[name] = []
38 |
39 | for numbers in self.file:
40 | numbers = numbers.rstrip().split('\t')
41 | for i in range(0, len(numbers)):
42 | self.numbers[self.names[i]].append(numbers[i])
43 | self.file.close()
44 | self.file = open(fpath, 'a')
45 | else:
46 | self.file = open(fpath, 'w')
47 |
48 | def set_names(self, names):
49 | if self.resume:
50 | pass
51 | # initialize numbers as empty list
52 | self.numbers = {}
53 | self.names = names
54 | for _, name in enumerate(self.names):
55 | self.file.write(name)
56 | self.file.write('\t')
57 | self.numbers[name] = []
58 | self.file.write('\n')
59 | self.file.flush()
60 |
61 |
62 | def append(self, numbers):
63 | assert len(self.names) == len(numbers), 'Numbers do not match names'
64 | for index, num in enumerate(numbers):
65 | self.file.write("{0:.6f}".format(num))
66 | self.file.write('\t')
67 | self.numbers[self.names[index]].append(num)
68 | self.file.write('\n')
69 | self.file.flush()
70 |
71 | def plot(self, names=None):
72 | names = self.names if names == None else names
73 | numbers = self.numbers
74 | for _, name in enumerate(names):
75 | x = np.arange(len(numbers[name]))
76 | plt.plot(x, np.asarray(numbers[name]))
77 | plt.legend([self.title + '(' + name + ')' for name in names])
78 | plt.grid(True)
79 |
80 | def close(self):
81 | if self.file is not None:
82 | self.file.close()
83 |
84 | class LoggerMonitor(object):
85 | '''Load and visualize multiple logs.'''
86 | def __init__ (self, paths):
87 | '''paths is a dictionary with {name: filepath} pairs'''
88 | self.loggers = []
89 | for title, path in paths.items():
90 | logger = Logger(path, title=title, resume=True)
91 | self.loggers.append(logger)
92 |
93 | def plot(self, names=None):
94 | plt.figure()
95 | plt.subplot(121)
96 | legend_text = []
97 | for logger in self.loggers:
98 | legend_text += plot_overlap(logger, names)
99 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
100 | plt.grid(True)
101 |
102 | if __name__ == '__main__':
103 | # # Example
104 | # logger = Logger('test.txt')
105 | # logger.set_names(['Train loss', 'Valid loss','Test loss'])
106 |
107 | # length = 100
108 | # t = np.arange(length)
109 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
110 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
111 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
112 |
113 | # for i in range(0, length):
114 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]])
115 | # logger.plot()
116 |
117 | # Example: logger monitor
118 | paths = {
119 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
120 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
121 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
122 | }
123 |
124 | field = ['Valid Acc.']
125 |
126 | monitor = LoggerMonitor(paths)
127 | monitor.plot(names=field)
128 | savefig('test.eps')
--------------------------------------------------------------------------------
/watermark_synthesis.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "SAVE ALL THE SETTING\n"
13 | ]
14 | }
15 | ],
16 | "source": [
17 | "# watermark synthesis\n",
18 | "import os \n",
19 | "import random\n",
20 | "import shutil\n",
21 | "from PIL import Image\n",
22 | "import numpy as np\n",
23 | "\n",
24 | "def trans_paste(bg_img,fg_img,mask,box=(0,0)):\n",
25 | " fg_img_trans = Image.new(\"RGBA\",bg_img.size)\n",
26 | " fg_img_trans.paste(fg_img,box,mask=mask)\n",
27 | " new_img = Image.alpha_composite(bg_img,fg_img_trans)\n",
28 | " return new_img,fg_img_trans\n",
29 | "\n",
30 | "if os.path.isdir('dataset'):\n",
31 | " shutil.rmtree('dataset')\n",
32 | "\n",
33 | "os.mkdir('dataset')\n",
34 | "BASE_IMG_DIR = '/Users/oishii/Downloads/val2014/'\n",
35 | "WATERMARK_DIR = 'logos' #1080 \n",
36 | "images = sorted([os.path.join(BASE_IMG_DIR,x) for x in os.listdir(BASE_IMG_DIR) if '.jpg' in x])\n",
37 | "watermarks = sorted([os.path.join(WATERMARK_DIR,x).replace(' ','_') for x in os.listdir(WATERMARK_DIR) if '.png' in x])\n",
38 | "# rename all the watermark from replace ' ' to '_'\n",
39 | "\n",
40 | "random.shuffle(images)\n",
41 | "random.shuffle(watermarks)\n",
42 | "\n",
43 | "train_images = images[:int(len(images)*0.7)]\n",
44 | "val_images = images[int(len(images)*0.7):int(len(images)*0.8)]\n",
45 | "test_images = images[int(len(images)*0.8):]\n",
46 | "\n",
47 | "train_wms = watermarks[:int(len(watermarks)*0.7)]\n",
48 | "val_wms = watermarks[int(len(watermarks)*0.7):int(len(watermarks)*0.8)]\n",
49 | "test_wms = watermarks[int(len(watermarks)*0.8):]\n",
50 | "\n",
51 | "# save all the settings to file\n",
52 | "names = ['train_images','val_images','test_images','train_wms','val_wms','test_wms']\n",
53 | "lists = [train_images,val_images,test_images,train_wms,val_wms,test_wms]\n",
54 | "dataset = dict(zip(names, lists))\n",
55 | "\n",
56 | "for name,content in dataset.items():\n",
57 | " with open('dataset/%s.txt'%name,'w') as f:\n",
58 | " f.write(\"\\n\".join(content))\n",
59 | "\n",
60 | "print('SAVE ALL THE SETTING')\n",
61 | "\n",
62 | "for name, images in dataset.items():\n",
63 | " if 'images' not in name:\n",
64 | " continue\n",
65 | " # for each setting, synthesis the watermark\n",
66 | " # for each image, add X(X=6) watermark in differnet position, alpha,\n",
67 | " # save the synthesized image, watermark mask, reshaped mask,\n",
68 | " save_path = 'dataset/%s/'%name\n",
69 | " os.makedirs('%s/image'%(save_path))\n",
70 | " os.makedirs('%s/mask'%(save_path))\n",
71 | " os.makedirs('%s/wm'%(save_path))\n",
72 | " \n",
73 | " for img in images:\n",
74 | " im = Image.open(img).convert('RGBA')\n",
75 | " imw,imh = im.size\n",
76 | " \n",
77 | " for wmg in random.choices(dataset[name.replace('images','wms')],k=6):\n",
78 | " wm = Image.open(wmg.replace('_',' ')).convert(\"RGBA\") # RGBA\n",
79 | " # get the mask of wm\n",
80 | " # data agumentation of wm\n",
81 | " wm = wm.rotate(angle=random.randint(0,360),expand=True) # rotate\n",
82 | " \n",
83 | " # make sure the \n",
84 | " imrw = random.randrange(int(0.4*imw),int(0.8*imw))\n",
85 | " imrh = random.randrange(int(0.4*imh),int(0.8*imh))\n",
86 | " wmsize = imrh if imrw > imrh else imrw\n",
87 | " wm = wm.resize((wmsize,wmsize),Image.BILINEAR)\n",
88 | " w,h = wm.size # new size \n",
89 | " \n",
90 | " box_left = random.randint(0,imw-w)\n",
91 | " box_upper = random.randint(0,imh-h)\n",
92 | " wmm = wm.copy()\n",
93 | " wm.putalpha(random.randint(int(255*0.4),int(255*0.8))) # alpha\n",
94 | " \n",
95 | " ims,wmc = trans_paste(im,wm,wmm,(box_left,box_upper))\n",
96 | " \n",
97 | " wmnp = np.array(wmc) # h,w,3\n",
98 | " mask = np.sum(wmnp,axis=2)>0\n",
99 | " mm = Image.fromarray(np.uint8(mask*255),mode='L')\n",
100 | " \n",
101 | " identifier = os.path.basename(img).split('.')[0] +'-'+os.path.basename(wmg).split('.')[0] + '.png'\n",
102 | " # save \n",
103 | " wmc.save('%s/wm/%s'%(save_path,identifier))\n",
104 | " ims.save('%s/image/%s'%(save_path,identifier))\n",
105 | " mm.save('%s/mask/%s'%(save_path,identifier))\n",
106 | " \n",
107 | " "
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "metadata": {},
114 | "outputs": [],
115 | "source": []
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": null,
120 | "metadata": {},
121 | "outputs": [],
122 | "source": []
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": []
130 | }
131 | ],
132 | "metadata": {
133 | "kernelspec": {
134 | "display_name": "Python 3",
135 | "language": "python",
136 | "name": "python3"
137 | },
138 | "language_info": {
139 | "codemirror_mode": {
140 | "name": "ipython",
141 | "version": 3
142 | },
143 | "file_extension": ".py",
144 | "mimetype": "text/x-python",
145 | "name": "python",
146 | "nbconvert_exporter": "python",
147 | "pygments_lexer": "ipython3",
148 | "version": "3.7.4"
149 | }
150 | },
151 | "nbformat": 4,
152 | "nbformat_minor": 2
153 | }
154 |
--------------------------------------------------------------------------------
/scripts/utils/transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import os
4 | import numpy as np
5 | import scipy.misc
6 | import matplotlib.pyplot as plt
7 | import torch
8 | import torchvision
9 |
10 | from .misc import *
11 | from .imutils import *
12 |
13 |
14 | def color_normalize(x, mean, std):
15 | if x.size(0) == 1:
16 | x = x.repeat(3, 1, 1)
17 |
18 | for t, m, s in zip(x, mean, std):
19 | t.sub_(m)
20 | return x
21 |
22 |
23 | def flip_back(flip_output, dataset='mpii'):
24 | """
25 | flip output map
26 | """
27 | if dataset == 'mpii':
28 | matchedParts = (
29 | [0,5], [1,4], [2,3],
30 | [10,15], [11,14], [12,13]
31 | )
32 | else:
33 | print('Not supported dataset: ' + dataset)
34 |
35 | # flip output horizontally
36 | flip_output = fliplr(flip_output.numpy())
37 |
38 | # Change left-right parts
39 | for pair in matchedParts:
40 | tmp = np.copy(flip_output[:, pair[0], :, :])
41 | flip_output[:, pair[0], :, :] = flip_output[:, pair[1], :, :]
42 | flip_output[:, pair[1], :, :] = tmp
43 |
44 | return torch.from_numpy(flip_output).float()
45 |
46 |
47 | def shufflelr(x, width, dataset='mpii'):
48 | """
49 | flip coords
50 | """
51 | if dataset == 'mpii':
52 | matchedParts = (
53 | [0,5], [1,4], [2,3],
54 | [10,15], [11,14], [12,13]
55 | )
56 | else:
57 | print('Not supported dataset: ' + dataset)
58 |
59 | # Flip horizontal
60 | x[:, 0] = width - x[:, 0]
61 |
62 | # Change left-right parts
63 | for pair in matchedParts:
64 | tmp = x[pair[0], :].clone()
65 | x[pair[0], :] = x[pair[1], :]
66 | x[pair[1], :] = tmp
67 |
68 | return x
69 |
70 |
71 | def fliplr(x):
72 | if x.ndim == 3:
73 | x = np.transpose(np.fliplr(np.transpose(x, (0, 2, 1))), (0, 2, 1))
74 | elif x.ndim == 4:
75 | for i in range(x.shape[0]):
76 | x[i] = np.transpose(np.fliplr(np.transpose(x[i], (0, 2, 1))), (0, 2, 1))
77 | return x.astype(float)
78 |
79 |
80 | def get_transform(center, scale, res, rot=0):
81 | """
82 | General image processing functions
83 | """
84 | # Generate transformation matrix
85 | h = 200 * scale
86 | t = np.zeros((3, 3))
87 | t[0, 0] = float(res[1]) / h
88 | t[1, 1] = float(res[0]) / h
89 | t[0, 2] = res[1] * (-float(center[0]) / h + .5)
90 | t[1, 2] = res[0] * (-float(center[1]) / h + .5)
91 | t[2, 2] = 1
92 | if not rot == 0:
93 | rot = -rot # To match direction of rotation from cropping
94 | rot_mat = np.zeros((3,3))
95 | rot_rad = rot * np.pi / 180
96 | sn,cs = np.sin(rot_rad), np.cos(rot_rad)
97 | rot_mat[0,:2] = [cs, -sn]
98 | rot_mat[1,:2] = [sn, cs]
99 | rot_mat[2,2] = 1
100 | # Need to rotate around center
101 | t_mat = np.eye(3)
102 | t_mat[0,2] = -res[1]/2
103 | t_mat[1,2] = -res[0]/2
104 | t_inv = t_mat.copy()
105 | t_inv[:2,2] *= -1
106 | t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t)))
107 | return t
108 |
109 |
110 | def transform(pt, center, scale, res, invert=0, rot=0):
111 | # Transform pixel location to different reference
112 | t = get_transform(center, scale, res, rot=rot)
113 | if invert:
114 | t = np.linalg.inv(t)
115 | new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
116 | new_pt = np.dot(t, new_pt)
117 | return new_pt[:2].astype(int) + 1
118 |
119 |
120 | def transform_preds(coords, center, scale, res):
121 | # size = coords.size()
122 | # coords = coords.view(-1, coords.size(-1))
123 | # print(coords.size())
124 | for p in range(coords.size(0)):
125 | coords[p, 0:2] = to_torch(transform(coords[p, 0:2], center, scale, res, 1, 0))
126 | return coords
127 |
128 |
129 | def crop(img, center, scale, res, rot=0):
130 | img = im_to_numpy(img)
131 |
132 | # Upper left point
133 | ul = np.array(transform([0, 0], center, scale, res, invert=1))
134 | # Bottom right point
135 | br = np.array(transform(res, center, scale, res, invert=1))
136 |
137 | # Padding so that when rotated proper amount of context is included
138 | pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
139 | if not rot == 0:
140 | ul -= pad
141 | br += pad
142 |
143 | new_shape = [br[1] - ul[1], br[0] - ul[0]]
144 | if len(img.shape) > 2:
145 | new_shape += [img.shape[2]]
146 | new_img = np.zeros(new_shape)
147 |
148 | # Range to fill new array
149 | new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
150 | new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
151 | # Range to sample from original image
152 | old_x = max(0, ul[0]), min(len(img[0]), br[0])
153 | old_y = max(0, ul[1]), min(len(img), br[1])
154 | new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]]
155 |
156 | if not rot == 0:
157 | # Remove padding
158 | new_img = scipy.misc.imrotate(new_img, rot)
159 | new_img = new_img[pad:-pad, pad:-pad]
160 |
161 | new_img = im_to_torch(scipy.misc.imresize(new_img, res))
162 | return new_img
163 |
164 |
165 | def get_right(img,gray=False):
166 | img = im_to_numpy(img) #H*W*C
167 |
168 | new_img = img[:,0:256,:]
169 |
170 |
171 | new_img = im_to_torch(new_img)
172 | if gray == True:
173 | new_img = new_img[1,:,:];
174 |
175 | return new_img
176 |
177 | class NormalizeInverse(torchvision.transforms.Normalize):
178 | """
179 | Undoes the normalization and returns the reconstructed images in the input domain.
180 | """
181 |
182 | def __init__(self, mean, std):
183 | mean = torch.as_tensor(mean)
184 | std = torch.as_tensor(std)
185 | std_inv = 1 / (std + 1e-7)
186 | mean_inv = -mean * std_inv
187 | super().__init__(mean=mean_inv, std=std_inv)
188 |
189 | def __call__(self, tensor):
190 | return super().__call__(tensor.clone())
191 |
--------------------------------------------------------------------------------
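`NormalizeInverse` above re-uses `torchvision.transforms.Normalize` with inverted statistics, so normalizing and then de-normalizing is (nearly) the identity. A quick round-trip check, using the ImageNet statistics purely as an example:

```python
import torch
import torchvision.transforms as T
from scripts.utils.transforms import NormalizeInverse

mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
normalize = T.Normalize(mean, std)
denormalize = NormalizeInverse(mean, std)

x = torch.rand(3, 64, 64)
x_norm = normalize(x.clone())        # clone in case Normalize works in place
x_back = denormalize(x_norm)
print(torch.allclose(x, x_back, atol=1e-4))  # True, up to the 1e-7 epsilon
```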
/options.py:
--------------------------------------------------------------------------------
1 |
2 | import scripts.models as models
3 |
4 | model_names = sorted(name for name in models.__dict__
5 | if name.islower() and not name.startswith("__")
6 | and callable(models.__dict__[name]))
7 |
8 | class Options():
9 | """docstring for Options"""
10 | def __init__(self):
11 | pass
12 |
13 | def init(self, parser):
14 | # Model structure
15 | parser.add_argument('--arch', '-a', metavar='ARCH', default='dhn',
16 | choices=model_names,
17 | help='model architecture: ' +
18 | ' | '.join(model_names) +
19 | ' (default: dhn)')
20 | parser.add_argument('--darch', metavar='ARCH', default='dhn',
21 | choices=model_names,
22 | help='model architecture: ' +
23 | ' | '.join(model_names) +
24 | ' (default: dhn)')
25 |
26 | parser.add_argument('--machine', '-m', metavar='MACHINE', default='basic')
27 | # Training strategy
28 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
29 | help='number of data loading workers (default: 2)')
30 | parser.add_argument('--epochs', default=30, type=int, metavar='N',
31 | help='number of total epochs to run')
32 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
33 | help='manual epoch number (useful on restarts)')
34 | parser.add_argument('--train-batch', default=64, type=int, metavar='N',
35 | help='train batchsize')
36 | parser.add_argument('--test-batch', default=6, type=int, metavar='N',
37 | help='test batchsize')
38 | parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,metavar='LR', help='initial learning rate')
39 | parser.add_argument('--dlr', '--dlearning-rate', default=1e-3, type=float, help='initial learning rate of the discriminator')
40 | parser.add_argument('--beta1', default=0.9, type=float, help='beta1 of the optimizer')
41 | parser.add_argument('--beta2', default=0.999, type=float, help='beta2 of the optimizer')
42 | parser.add_argument('--momentum', default=0, type=float, metavar='M',
43 | help='momentum')
44 | parser.add_argument('--weight-decay', '--wd', default=0, type=float,
45 | metavar='W', help='weight decay (default: 0)')
46 | parser.add_argument('--schedule', type=int, nargs='+', default=[5, 10],
47 | help='Decrease learning rate at these epochs.')
48 | parser.add_argument('--gamma', type=float, default=0.1,
49 | help='LR is multiplied by gamma on schedule.')
50 | # Data processing
51 | parser.add_argument('-f', '--flip', dest='flip', action='store_true',
52 | help='flip the input during validation')
53 | parser.add_argument('--lambdaL1', type=float, default=1, help='the weight of L1.')
54 | parser.add_argument('--alpha', type=float, default=0.5,
55 | help='Groundtruth Gaussian sigma.')
56 | parser.add_argument('--sigma-decay', type=float, default=0,
57 | help='Sigma decay rate for each epoch.')
58 | # Miscs
59 | parser.add_argument('--base-dir', default='/PATH_TO_DATA_FOLDER/', type=str, metavar='PATH')
60 | parser.add_argument('--data', default='', type=str, metavar='PATH',
61 | help='dataset name/suffix appended to base-dir (e.g. _images, HAdobe5k)')
62 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
63 | help='path to save checkpoint (default: checkpoint)')
64 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
65 | help='path to latest checkpoint (default: none)')
66 | parser.add_argument('--finetune', default='', type=str, metavar='PATH',
67 | help='path to latest checkpoint (default: none)')
68 |
69 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
70 | help='evaluate model on validation set')
71 | parser.add_argument('--style-loss', default=0, type=float,
72 | help='weight of the perceptual (style) loss')
73 | parser.add_argument('--ssim-loss', default=0, type=float, help='weight of the SSIM loss')
74 | parser.add_argument('--att-loss', default=1, type=float, help='weight of the attention loss')
75 | parser.add_argument('--default-loss',default=False,type=bool)
76 | parser.add_argument('--sltype', default='vggx', type=str)
77 | parser.add_argument('-da', '--data-augumentation', default=False, type=bool,
78 | help='enable data augmentation')
79 | parser.add_argument('-d', '--debug', dest='debug', action='store_true',
80 | help='show intermediate results')
81 | parser.add_argument('--input-size', default=256, type=int, metavar='N',
82 | help='input image size')
83 | parser.add_argument('--freq', default=-1, type=int, metavar='N',
84 | help='validation frequency (negative: validate every epoch)')
85 | parser.add_argument('--normalized-input', default=False, type=bool,
86 | help='normalize the input images')
87 | parser.add_argument('--res', default=False, type=bool,help='residual learning for s2am')
88 | parser.add_argument('--requires-grad', default=False, type=bool)
89 |
90 | parser.add_argument('--limited-dataset', default=0, type=int, metavar='N')
91 | parser.add_argument('--gpu',default=True,type=bool)
92 | parser.add_argument('--masked',default=False,type=bool)
93 | parser.add_argument('--gan-norm', default=False, type=bool, help='use GAN-style normalization')
94 | parser.add_argument('--hl', default=False, type=bool, help='homogeneous learning')
95 | parser.add_argument('--loss-type', default='l2', type=str, help="reconstruction loss type (e.g. 'l2', 'hybrid')")
96 | return parser
--------------------------------------------------------------------------------
/scripts/models/rasc.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import torch
4 | import torchvision
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import numpy as np
8 | import math
9 |
10 | from scripts.utils.model_init import *
11 | from scripts.models.vgg import Vgg16
12 | from scripts.models.blocks import *
13 |
14 |
15 | class CAWapper(nn.Module):
16 | """docstring for SENet"""
17 |
18 | def __init__(self, channel, type_of_connection=BasicLearningBlock):
19 | super(CAWapper, self).__init__()
20 | self.attention = ContextualAttention(ksize=3, stride=1, rate=2, fuse_k=3, softmax_scale=10, fuse=True, use_cuda=True)
21 |
22 | def forward(self, feature, mask):
23 | _, _, w, _ = feature.size()
24 | _, _, mw, _ = mask.size()
25 | # binarize the mask
26 | # select features from the background region as additional features for the masked splicing region.
27 | mask = torch.round(F.avg_pool2d(mask, 2, stride=mw//w))
28 |
29 | result = self.attention(feature,mask)
30 |
31 | return result
32 |
33 |
34 | class NLWapper(nn.Module):
35 | """docstring for SENet"""
36 |
37 | def __init__(self, channel, type_of_connection=BasicLearningBlock):
38 | super(NLWapper, self).__init__()
39 | self.attention = NONLocalBlock2D(channel)
40 |
41 | def forward(self, feature, mask):
42 | _, _, w, _ = feature.size()
43 | _, _, mw, _ = mask.size()
44 | # binarize the mask
45 | # select features from the background region as additional features for the masked splicing region.
46 | # mask = torch.round(F.avg_pool2d(mask, 2, stride=mw//w))
47 |
48 | result = self.attention(feature)
49 |
50 | return result
51 |
52 | class SENet(nn.Module):
53 | """docstring for SENet"""
54 | def __init__(self,channel,type_of_connection=BasicLearningBlock):
55 | super(SENet, self).__init__()
56 | self.attention = SEBlock(channel,16)
57 |
58 | def forward(self,feature,mask):
59 | _,_,w,_ = feature.size()
60 | _,_,mw,_ = mask.size()
61 | # binarize the mask
62 | # select features from the background region as additional features for the masked splicing region.
63 | mask = torch.round(F.avg_pool2d(mask,2,stride=mw//w))
64 |
65 | result = self.attention(feature)
66 |
67 | return result
68 |
69 | class CBAMConnect(nn.Module):
70 | def __init__(self,channel):
71 | super(CBAMConnect, self).__init__()
72 | self.attention = CBAM(channel)
73 |
74 | def forward(self,feature,mask):
75 | results = self.attention(feature)
76 | return results
77 |
78 |
79 |
80 | class RASC(nn.Module):
81 | def __init__(self,channel,type_of_connection=BasicLearningBlock):
82 | super(RASC, self).__init__()
83 | self.connection = type_of_connection(channel)
84 | self.background_attention = GlobalAttentionModule(channel,16)
85 | self.mixed_attention = GlobalAttentionModule(channel,16)
86 | self.spliced_attention = GlobalAttentionModule(channel,16)
87 | self.gaussianMask = GaussianSmoothing(1,5,1)
88 |
89 | def forward(self,feature,mask):
90 | _,_,w,_ = feature.size()
91 | _,_,mw,_ = mask.size()
92 |         # binarize the downsampled mask
93 |         # and use the background features as additional context for the masked spliced region.
94 | if w != mw:
95 | mask = torch.round(F.avg_pool2d(mask,2,stride=mw//w))
96 | reverse_mask = -1*(mask-1)
97 |         # apply a Gaussian filter to mask and reverse_mask for better harmonization around the edges.
98 |
99 | mask = self.gaussianMask(F.pad(mask,(2,2,2,2),mode='reflect'))
100 | reverse_mask = self.gaussianMask(F.pad(reverse_mask,(2,2,2,2),mode='reflect'))
101 |
102 |
103 | background = self.background_attention(feature) * reverse_mask
104 | selected_feature = self.mixed_attention(feature)
105 | spliced_feature = self.spliced_attention(feature)
106 | spliced = ( self.connection(spliced_feature) + selected_feature ) * mask
107 | return background + spliced
108 |
109 |
110 | class UNO(nn.Module):
111 | def __init__(self,channel):
112 | super(UNO, self).__init__()
113 |
114 | def forward(self,feature,_m):
115 | return feature
116 |
117 |
118 | class URASC(nn.Module):
119 | def __init__(self,channel,type_of_connection=BasicLearningBlock):
120 | super(URASC, self).__init__()
121 | self.connection = type_of_connection(channel)
122 | self.background_attention = GlobalAttentionModule(channel,16)
123 | self.mixed_attention = GlobalAttentionModule(channel,16)
124 | self.spliced_attention = GlobalAttentionModule(channel,16)
125 | self.mask_attention = SpatialAttentionModule(channel,16)
126 |
127 | def forward(self,feature, m=None):
128 | _,_,w,_ = feature.size()
129 |
130 | mask, reverse_mask = self.mask_attention(feature)
131 |
132 | background = self.background_attention(feature) * reverse_mask
133 | selected_feature = self.mixed_attention(feature)
134 | spliced_feature = self.spliced_attention(feature)
135 | spliced = ( self.connection(spliced_feature) + selected_feature ) * mask
136 | return background + spliced
137 |
138 |
139 | class MaskedURASC(nn.Module):
140 | def __init__(self,channel,type_of_connection=BasicLearningBlock):
141 | super(MaskedURASC, self).__init__()
142 | self.connection = type_of_connection(channel)
143 | self.background_attention = GlobalAttentionModule(channel,16)
144 | self.mixed_attention = GlobalAttentionModule(channel,16)
145 | self.spliced_attention = GlobalAttentionModule(channel,16)
146 | self.mask_attention = SpatialAttentionModule(channel,16)
147 |
148 | def forward(self,feature):
149 | _,_,w,_ = feature.size()
150 |
151 | mask, reverse_mask = self.mask_attention(feature)
152 |
153 | background = self.background_attention(feature) * reverse_mask
154 | selected_feature = self.mixed_attention(feature)
155 | spliced_feature = self.spliced_attention(feature)
156 | spliced = ( self.connection(spliced_feature) + selected_feature ) * mask
157 | return background + spliced, mask
158 |
159 |
--------------------------------------------------------------------------------
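
Illustrative usage of RASC (a sketch with dummy tensors, assuming the repository root is on PYTHONPATH): the mask is passed at input resolution, average-pooled down to the feature resolution inside forward(), then Gaussian-smoothed before gating the background and spliced-region branches.

import torch
from scripts.models.rasc import RASC

rasc = RASC(channel=64)
feature = torch.randn(1, 64, 32, 32)   # encoder/decoder feature map
mask = torch.rand(1, 1, 256, 256)      # splicing mask at input resolution
out = rasc(feature, mask)              # same shape as the input feature
print(out.shape)                       # torch.Size([1, 64, 32, 32])
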
/scripts/models/discriminator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import functools
3 | import math
4 | import torch
5 | from torch.autograd import Variable
6 | import torch.nn.functional as F
7 | from torch import nn
8 | from torch import Tensor
9 | from torch.nn import Parameter
10 | from scripts.utils.model_init import *
11 | from torch.optim.optimizer import Optimizer, required
12 |
13 |
14 | __all__ = ['patchgan','sngan','maskedsngan']
15 |
16 |
17 | class SNCoXvWithActivation(torch.nn.Module):
18 | """
19 |     Convolution with spectral normalization applied to its weight, followed by an optional activation.
20 | """
21 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)):
22 | super(SNCoXvWithActivation, self).__init__()
23 | self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
24 | self.conv2d = torch.nn.utils.spectral_norm(self.conv2d)
25 | self.activation = activation
26 | for m in self.modules():
27 | if isinstance(m, nn.Conv2d):
28 | nn.init.kaiming_normal_(m.weight)
29 | def forward(self, input):
30 | x = self.conv2d(input)
31 | if self.activation is not None:
32 | return self.activation(x)
33 | else:
34 | return x
35 |
36 | def l2normalize(v, eps=1e-12):
37 | return v / (v.norm() + eps)
38 |
39 |
40 | class SpectralNorm(nn.Module):
41 | def __init__(self, module, name='weight', power_iterations=1):
42 | super(SpectralNorm, self).__init__()
43 | self.module = module
44 | self.name = name
45 | self.power_iterations = power_iterations
46 | if not self._made_params():
47 | self._make_params()
48 |
49 | def _update_u_v(self):
50 | u = getattr(self.module, self.name + "_u")
51 | v = getattr(self.module, self.name + "_v")
52 | w = getattr(self.module, self.name + "_bar")
53 |
54 | height = w.data.shape[0]
55 | for _ in range(self.power_iterations):
56 | v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
57 | u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
58 |
59 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
60 | sigma = u.dot(w.view(height, -1).mv(v))
61 | setattr(self.module, self.name, w / sigma.expand_as(w))
62 |
63 | def _made_params(self):
64 | try:
65 | u = getattr(self.module, self.name + "_u")
66 | v = getattr(self.module, self.name + "_v")
67 | w = getattr(self.module, self.name + "_bar")
68 | return True
69 | except AttributeError:
70 | return False
71 |
72 |
73 | def _make_params(self):
74 | w = getattr(self.module, self.name)
75 |
76 | height = w.data.shape[0]
77 | width = w.view(height, -1).data.shape[1]
78 |
79 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
80 | v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
81 | u.data = l2normalize(u.data)
82 | v.data = l2normalize(v.data)
83 | w_bar = Parameter(w.data)
84 |
85 | del self.module._parameters[self.name]
86 |
87 | self.module.register_parameter(self.name + "_u", u)
88 | self.module.register_parameter(self.name + "_v", v)
89 | self.module.register_parameter(self.name + "_bar", w_bar)
90 |
91 |
92 | def forward(self, *args):
93 | self._update_u_v()
94 | return self.module.forward(*args)
95 |
96 |
97 | def get_pad(in_, ksize, stride, atrous=1):
98 | out_ = np.ceil(float(in_)/stride)
99 | return int(((out_ - 1) * stride + atrous*(ksize-1) + 1 - in_)/2)
100 |
101 | class SNDiscriminator(nn.Module):
102 | def __init__(self,channel=6):
103 | super(SNDiscriminator, self).__init__()
104 | cnum = 32
105 | self.discriminator_net = nn.Sequential(
106 | SNCoXvWithActivation(channel, 2*cnum, 4, 2, padding=get_pad(256, 5, 2)),
107 | SNCoXvWithActivation(2*cnum, 4*cnum, 4, 2, padding=get_pad(128, 5, 2)),
108 | SNCoXvWithActivation(4*cnum, 8*cnum, 4, 2, padding=get_pad(64, 5, 2)),
109 | SNCoXvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(32, 5, 2)),
110 | SNCoXvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(16, 5, 2)), # 8*8*256
111 | # SNConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(8, 5, 2)), # 4*4*256
112 | # SNConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(4, 5, 2)), # 2*2*256
113 | )
114 | # self.linear = nn.Linear(2*2*256,1)
115 |
116 | def forward(self, img_A, img_B):
117 | # Concatenate image and condition image by channels to produce input
118 | img_input = torch.cat((img_A, img_B), 1)
119 | x = self.discriminator_net(img_input)
120 | # x = x.view((x.size(0),-1))
121 | # x = self.linear(x)
122 | return x
123 |
124 | class Discriminator(nn.Module):
125 | def __init__(self, in_channels=3):
126 | super(Discriminator, self).__init__()
127 |
128 | def discriminator_block(in_filters, out_filters, normalization=True):
129 | """Returns downsampling layers of each discriminator block"""
130 | layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
131 | if normalization:
132 | layers.append(nn.InstanceNorm2d(out_filters))
133 | layers.append(nn.LeakyReLU(0.2, inplace=True))
134 | return layers
135 |
136 | self.model = nn.Sequential(
137 | *discriminator_block(in_channels*2, 64, normalization=False),
138 | *discriminator_block(64, 128),
139 | *discriminator_block(128, 256),
140 | *discriminator_block(256, 512),
141 | nn.ZeroPad2d((1, 0, 1, 0)),
142 | nn.Conv2d(512, 1, 4, padding=1, bias=False)
143 | )
144 |
145 | def forward(self, img_A, img_B):
146 | # Concatenate image and condition image by channels to produce input
147 | img_input = torch.cat((img_A, img_B), 1)
148 | return self.model(img_input)
149 |
150 |
151 | def patchgan():
152 | model = Discriminator()
153 | model.apply(weights_init_kaiming)
154 | return model
155 |
156 | def sngan():
157 | model = SNDiscriminator()
158 | model.apply(weights_init_kaiming)
159 | return model
160 |
161 | def maskedsngan():
162 | model = SNDiscriminator(channel=7)
163 | model.apply(weights_init_kaiming)
164 | return model
--------------------------------------------------------------------------------
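
A hedged sketch of the factory functions with dummy tensors (assuming the repository root is on PYTHONPATH): both discriminators take the image pair as two 3-channel tensors, concatenate them along the channel axis, and return a spatial map of patch scores rather than a single scalar.

import torch
from scripts.models.discriminator import patchgan, sngan

fake = torch.rand(2, 3, 256, 256)
real = torch.rand(2, 3, 256, 256)

d_patch = patchgan()               # PatchGAN with instance norm
d_sn = sngan()                     # spectrally normalized discriminator
print(d_patch(fake, real).shape)   # e.g. torch.Size([2, 1, 16, 16])
print(d_sn(fake, real).shape)      # e.g. torch.Size([2, 256, 8, 8])
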
/scripts/models/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 | from scripts.models.blocks import *
6 | from scripts.models.rasc import *
7 |
8 |
9 | class MinimalUnetV2(nn.Module):
10 |     """Minimal U-Net block (V2): attention is applied after the skip concatenation."""
11 |     def __init__(self, down=None,up=None,submodule=None,attention=None,withoutskip=False,**kwargs):
12 | super(MinimalUnetV2, self).__init__()
13 |
14 | self.down = nn.Sequential(*down)
15 | self.up = nn.Sequential(*up)
16 | self.sub = submodule
17 | self.attention = attention
18 | self.withoutskip = withoutskip
19 |         self.is_attention = self.attention is not None
20 |         self.is_sub = submodule is not None
21 |
22 | def forward(self,x,mask=None):
23 | if self.is_sub:
24 | x_up,_ = self.sub(self.down(x),mask)
25 | else:
26 | x_up = self.down(x)
27 |
28 | if self.withoutskip: #outer or inner.
29 | x_out = self.up(x_up)
30 | else:
31 | if self.is_attention:
32 | x_out = (self.attention(torch.cat([x,self.up(x_up)],1),mask),mask)
33 | else:
34 | x_out = (torch.cat([x,self.up(x_up)],1),mask)
35 |
36 | return x_out
37 |
38 |
39 | class MinimalUnet(nn.Module):
40 |     """Minimal U-Net block: attention is applied to the encoder features before the skip concatenation."""
41 |     def __init__(self, down=None,up=None,submodule=None,attention=None,withoutskip=False,**kwargs):
42 | super(MinimalUnet, self).__init__()
43 |
44 | self.down = nn.Sequential(*down)
45 | self.up = nn.Sequential(*up)
46 | self.sub = submodule
47 | self.attention = attention
48 | self.withoutskip = withoutskip
49 |         self.is_attention = self.attention is not None
50 |         self.is_sub = submodule is not None
51 |
52 | def forward(self,x,mask=None):
53 | if self.is_sub:
54 | x_up,_ = self.sub(self.down(x),mask)
55 | else:
56 | x_up = self.down(x)
57 |
58 | if self.is_attention:
59 | x = self.attention(x,mask)
60 |
61 | if self.withoutskip: #outer or inner.
62 | x_out = self.up(x_up)
63 | else:
64 | x_out = (torch.cat([x,self.up(x_up)],1),mask)
65 |
66 | return x_out
67 |
68 |
69 | # Defines the submodule with skip connection.
70 | # X -------------------identity---------------------- X
71 | # |-- downsampling -- |submodule| -- upsampling --|
72 | class UnetSkipConnectionBlock(nn.Module):
73 | def __init__(self, outer_nc, inner_nc, input_nc=None,
74 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False,is_attention_layer=False,
75 | attention_model=RASC,basicblock=MinimalUnet,outermostattention=False):
76 | super(UnetSkipConnectionBlock, self).__init__()
77 | self.outermost = outermost
78 | if type(norm_layer) == functools.partial:
79 | use_bias = norm_layer.func == nn.InstanceNorm2d
80 | else:
81 | use_bias = norm_layer == nn.InstanceNorm2d
82 | if input_nc is None:
83 | input_nc = outer_nc
84 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
85 | stride=2, padding=1, bias=use_bias)
86 | downrelu = nn.LeakyReLU(0.2, True)
87 | downnorm = norm_layer(inner_nc)
88 | uprelu = nn.ReLU(True)
89 | upnorm = norm_layer(outer_nc)
90 |
91 |
92 | if outermost:
93 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
94 | kernel_size=4, stride=2,
95 | padding=1)
96 | down = [downconv]
97 | up = [uprelu, upconv]
98 | model = basicblock(down,up,submodule,withoutskip=outermost)
99 | elif innermost:
100 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
101 | kernel_size=4, stride=2,
102 | padding=1, bias=use_bias)
103 | down = [downrelu, downconv]
104 | up = [uprelu, upconv, upnorm]
105 | model = basicblock(down,up)
106 | else:
107 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
108 | kernel_size=4, stride=2,
109 | padding=1, bias=use_bias)
110 | down = [downrelu, downconv, downnorm]
111 | up = [uprelu, upconv, upnorm]
112 |
113 | if is_attention_layer:
114 | if MinimalUnetV2.__qualname__ in basicblock.__qualname__ :
115 | attention_model = attention_model(input_nc*2)
116 | else:
117 | attention_model = attention_model(input_nc)
118 | else:
119 | attention_model = None
120 |
121 | if use_dropout:
122 |             model = basicblock(down, up + [nn.Dropout(0.5)], submodule, attention_model, outermostattention=outermostattention)
123 | else:
124 | model = basicblock(down,up,submodule,attention_model,outermostattention=outermostattention)
125 |
126 | self.model = model
127 |
128 |
129 | def forward(self, x,mask=None):
130 | # build the mask for attention use
131 | return self.model(x,mask)
132 |
133 | class UnetGenerator(nn.Module):
134 | def __init__(self, input_nc, output_nc, num_downs=8, ngf=64,norm_layer=nn.BatchNorm2d, use_dropout=False,
135 | is_attention_layer=False,attention_model=RASC,use_inner_attention=False,basicblock=MinimalUnet):
136 | super(UnetGenerator, self).__init__()
137 |
138 | # 8 for 256x256
139 | # 9 for 512x512
140 | # construct unet structure
141 |         self.need_mask = input_nc != output_nc
142 |
143 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True,basicblock=basicblock) # 1
144 | for i in range(num_downs - 5): #3 times
145 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout,is_attention_layer=use_inner_attention,attention_model=attention_model,basicblock=basicblock) # 8,4,2
146 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer,is_attention_layer=is_attention_layer,attention_model=attention_model,basicblock=basicblock) #16
147 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer,is_attention_layer=is_attention_layer,attention_model=attention_model,basicblock=basicblock) #32
148 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer,is_attention_layer=is_attention_layer,attention_model=attention_model,basicblock=basicblock, outermostattention=True) #64
149 | unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, basicblock=basicblock, norm_layer=norm_layer) # 128
150 |
151 | self.model = unet_block
152 |
153 | def forward(self, input):
154 | if self.need_mask:
155 | return self.model(input,input[:,3:4,:,:])
156 | else:
157 | return self.model(input[:,0:3,:,:],input[:,3:4,:,:])
158 |
159 |
160 |
161 |
--------------------------------------------------------------------------------
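
Hedged sketch of building the attention-augmented U-Net generator (dummy input, assuming the repository root is on PYTHONPATH). With input_nc=4 and output_nc=3 the forward pass treats channel 3 as the mask and feeds it to the attention blocks at the skip connections.

import torch
from scripts.models.unet import UnetGenerator
from scripts.models.rasc import RASC

net = UnetGenerator(input_nc=4, output_nc=3,
                    is_attention_layer=True, attention_model=RASC)
x = torch.rand(1, 4, 256, 256)   # RGB composite + 1-channel mask
print(net(x).shape)              # torch.Size([1, 3, 256, 256])
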
/scripts/utils/imutils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | import scipy.misc
7 | import matplotlib.pyplot as plt  # needed by the display helpers below (imshow, show_joints, show_sample)
8 | from matplotlib import cm        # needed by get_jet
9 |
8 | from .misc import *
9 |
10 | def im_to_numpy(img):
11 | img = to_numpy(img)
12 | img = np.transpose(img, (1, 2, 0)) # H*W*C
13 | return img
14 |
15 | def im_to_torch(img):
16 | img = np.transpose(img, (2, 0, 1)) # C*H*W
17 | img = to_torch(img).float()
18 | if img.max() > 1:
19 | img /= 255
20 | return img
21 |
22 | def load_image(img_path):
23 | # H x W x C => C x H x W
24 | return im_to_torch(scipy.misc.imread(img_path, mode='RGB'))
25 |
26 | def imread_all(img_path):
27 | return scipy.misc.imread(img_path, mode='RGB')
28 |
29 | def load_image_gray(img_path):
30 | # H x W x C => C x H x W
31 | x = scipy.misc.imread(img_path, mode='L')
32 | x = x[:,:,np.newaxis]
33 | return im_to_torch(x)
34 |
35 | def resize(img, owidth, oheight):
36 | img = im_to_numpy(img)
37 |
38 | if img.shape[2] == 1:
39 | img = scipy.misc.imresize(img.squeeze(),(oheight,owidth))
40 | img = img[:,:,np.newaxis]
41 | else:
42 | img = scipy.misc.imresize(
43 | img,
44 | (oheight, owidth)
45 | )
46 | img = im_to_torch(img)
47 | # print('%f %f' % (img.min(), img.max()))
48 | return img
49 |
50 | # =============================================================================
51 | # Helpful functions generating groundtruth labelmap
52 | # =============================================================================
53 |
54 | def gaussian(shape=(7,7),sigma=1):
55 | """
56 | 2D gaussian mask - should give the same result as MATLAB's
57 | fspecial('gaussian',[shape],[sigma])
58 | """
59 | m,n = [(ss-1.)/2. for ss in shape]
60 | y,x = np.ogrid[-m:m+1,-n:n+1]
61 | h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )
62 | h[ h < np.finfo(h.dtype).eps*h.max() ] = 0
63 | return to_torch(h).float()
64 |
65 | def draw_labelmap(img, pt, sigma, type='Gaussian'):
66 | # Draw a 2D gaussian
67 | # Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py
68 | img = to_numpy(img)
69 |
70 | # Check that any part of the gaussian is in-bounds
71 | ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
72 | br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]
73 | if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or
74 | br[0] < 0 or br[1] < 0):
75 | # If not, just return the image as is
76 | return to_torch(img)
77 |
78 | # Generate gaussian
79 | size = 6 * sigma + 1
80 | x = np.arange(0, size, 1, float)
81 | y = x[:, np.newaxis]
82 | x0 = y0 = size // 2
83 | # The gaussian is not normalized, we want the center value to equal 1
84 | if type == 'Gaussian':
85 | g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
86 | elif type == 'Cauchy':
87 | g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5)
88 |
89 |
90 | # Usable gaussian range
91 | g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
92 | g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
93 | # Image range
94 | img_x = max(0, ul[0]), min(br[0], img.shape[1])
95 | img_y = max(0, ul[1]), min(br[1], img.shape[0])
96 |
97 | img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
98 | return to_torch(img)
99 |
100 | # =============================================================================
101 | # Helpful display functions
102 | # =============================================================================
103 |
104 | def gauss(x, a, b, c, d=0):
105 | return a * np.exp(-(x - b)**2 / (2 * c**2)) + d
106 |
107 | def color_heatmap(x):
108 | x = to_numpy(x)
109 | color = np.zeros((x.shape[0],x.shape[1],3))
110 | color[:,:,0] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3)
111 | color[:,:,1] = gauss(x, 1, .5, .3)
112 | color[:,:,2] = gauss(x, 1, .2, .3)
113 | color[color > 1] = 1
114 | color = (color * 255).astype(np.uint8)
115 | return color
116 |
117 | def imshow(img):
118 | npimg = im_to_numpy(img*255).astype(np.uint8)
119 | plt.imshow(npimg)
120 | plt.axis('off')
121 |
122 | def show_joints(img, pts):
123 | imshow(img)
124 |
125 | for i in range(pts.size(0)):
126 | if pts[i, 2] > 0:
127 | plt.plot(pts[i, 0], pts[i, 1], 'yo')
128 | plt.axis('off')
129 |
130 | def show_sample(inputs, target):
131 | num_sample = inputs.size(0)
132 | num_joints = target.size(1)
133 | height = target.size(2)
134 | width = target.size(3)
135 |
136 | for n in range(num_sample):
137 | inp = resize(inputs[n], width, height)
138 | out = inp
139 | for p in range(num_joints):
140 | tgt = inp*0.5 + color_heatmap(target[n,p,:,:])*0.5
141 | out = torch.cat((out, tgt), 2)
142 |
143 | imshow(out)
144 | plt.show()
145 |
146 | def sample_with_heatmap(inp, out, num_rows=2, parts_to_show=None):
147 | inp = to_numpy(inp * 255)
148 | out = to_numpy(out)
149 |
150 | img = np.zeros((inp.shape[1], inp.shape[2], inp.shape[0]))
151 | for i in range(3):
152 | img[:, :, i] = inp[i, :, :]
153 |
154 | if parts_to_show is None:
155 | parts_to_show = np.arange(out.shape[0])
156 |
157 | # Generate a single image to display input/output pair
158 | num_cols = int(np.ceil(float(len(parts_to_show)) / num_rows))
159 | size = img.shape[0] // num_rows
160 |
161 | full_img = np.zeros((img.shape[0], size * (num_cols + num_rows), 3), np.uint8)
162 | full_img[:img.shape[0], :img.shape[1]] = img
163 |
164 | inp_small = scipy.misc.imresize(img, [size, size])
165 |
166 | # Set up heatmap display for each part
167 | for i, part in enumerate(parts_to_show):
168 | part_idx = part
169 | out_resized = scipy.misc.imresize(out[part_idx], [size, size])
170 | out_resized = out_resized.astype(float)/255
171 | out_img = inp_small.copy() * .3
172 | color_hm = color_heatmap(out_resized)
173 | out_img += color_hm * .7
174 |
175 | col_offset = (i % num_cols + num_rows) * size
176 | row_offset = (i // num_cols) * size
177 | full_img[row_offset:row_offset + size, col_offset:col_offset + size] = out_img
178 |
179 | return full_img
180 |
181 | def batch_with_heatmap(inputs, outputs, mean=torch.Tensor([0.5, 0.5, 0.5]), num_rows=2, parts_to_show=None):
182 | batch_img = []
183 | for n in range(min(inputs.size(0), 4)):
184 | inp = inputs[n] + mean.view(3, 1, 1).expand_as(inputs[n])
185 | batch_img.append(
186 | sample_with_heatmap(inp.clamp(0, 1), outputs[n], num_rows=num_rows, parts_to_show=parts_to_show)
187 | )
188 | return np.concatenate(batch_img)
189 |
190 |
191 | def normalize_batch(batch):
192 | # normalize using imagenet mean and std
193 | mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
194 | std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
195 | batch = batch/255.0
196 | return (batch - mean) / std
197 |
198 | def show_image_tensor(tensor):
199 | re = []
200 | for i in range(tensor.size(0)):
201 | inp = tensor[i].data.cpu() #w,h,c
202 | inp = inp.numpy().transpose((1, 2, 0))
203 | mean = np.array([0.485, 0.456, 0.406])
204 | std = np.array([0.229, 0.224, 0.225])
205 | inp = std * inp + mean
206 | inp = np.clip(inp, 0, 1).transpose((2,0,1))
207 | re.append(torch.from_numpy(inp).unsqueeze(0))
208 | return torch.cat(re,0)
209 |
210 |
211 | def get_jet():
212 | colormap_int = np.zeros((256, 3), np.uint8)
213 |
214 | for i in range(0, 256, 1):
215 | colormap_int[i, 0] = np.int_(np.round(cm.jet(i)[0] * 255.0))
216 | colormap_int[i, 1] = np.int_(np.round(cm.jet(i)[1] * 255.0))
217 | colormap_int[i, 2] = np.int_(np.round(cm.jet(i)[2] * 255.0))
218 |
219 | return colormap_int
220 |
221 | def clamp(num, min_value, max_value):
222 | return max(min(num, max_value), min_value)
223 |
224 | def gray2color(gray_array, color_map):
225 |
226 | rows, cols = gray_array.shape
227 | color_array = np.zeros((rows, cols, 3), np.uint8)
228 |
229 | for i in range(0, rows):
230 | for j in range(0, cols):
231 | # log(256,2) = 8 , log(1,2) = 0 * 8
232 | color_array[i, j] = color_map[clamp(int(abs(gray_array[i, j])*10),0,255)]
233 |
234 | return color_array
235 |
236 | class objectview(object):
237 | def __init__(self, *args, **kwargs):
238 | d = dict(*args, **kwargs)
239 | self.__dict__ = d
--------------------------------------------------------------------------------
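
Sketch of the basic tensor/array helpers (dummy data, assuming the repository root is on PYTHONPATH): im_to_torch converts H x W x C arrays in [0, 255] to C x H x W tensors in [0, 1], im_to_numpy is the inverse layout change, and normalize_batch applies ImageNet mean/std normalization to inputs in [0, 255].

import numpy as np
import torch
from scripts.utils.imutils import im_to_torch, im_to_numpy, normalize_batch

img = (np.random.rand(128, 128, 3) * 255).astype(np.float32)  # H x W x C
t = im_to_torch(img)                  # C x H x W in [0, 1]
back = im_to_numpy(t)                 # H x W x C again
batch = torch.rand(4, 3, 128, 128) * 255
print(t.shape, back.shape, normalize_batch(batch).shape)
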
/scripts/models/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import functools
7 | import math
8 | import numbers
9 |
10 | from scripts.utils.model_init import *
11 | from scripts.models.vgg import Vgg16
12 | from torch import nn, cuda
13 | from torch.autograd import Variable
14 |
15 | class BasicLearningBlock(nn.Module):
16 |     """Two 3x3 conv-BN-ELU layers; used as the learnable connection inside RASC."""
17 | def __init__(self,channel):
18 | super(BasicLearningBlock, self).__init__()
19 | self.rconv1 = nn.Conv2d(channel,channel*2,3,padding=1,bias=False)
20 | self.rbn1 = nn.BatchNorm2d(channel*2)
21 | self.rconv2 = nn.Conv2d(channel*2,channel,3,padding=1,bias=False)
22 | self.rbn2 = nn.BatchNorm2d(channel)
23 |
24 | def forward(self,feature):
25 | return F.elu(self.rbn2(self.rconv2(F.elu(self.rbn1(self.rconv1(feature))))))
26 |
27 |
28 |
29 | # From https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/3
30 | class GaussianSmoothing(nn.Module):
31 | """
32 | Apply gaussian smoothing on a
33 |     1d, 2d or 3d tensor. Filtering is performed separately for each channel
34 | in the input using a depthwise convolution.
35 | Arguments:
36 | channels (int, sequence): Number of channels of the input tensors. Output will
37 | have this number of channels as well.
38 | kernel_size (int, sequence): Size of the gaussian kernel.
39 | sigma (float, sequence): Standard deviation of the gaussian kernel.
40 | dim (int, optional): The number of dimensions of the data.
41 | Default value is 2 (spatial).
42 | """
43 | def __init__(self, channels, kernel_size, sigma, dim=2):
44 | super(GaussianSmoothing, self).__init__()
45 | if isinstance(kernel_size, numbers.Number):
46 | kernel_size = [kernel_size] * dim
47 | if isinstance(sigma, numbers.Number):
48 | sigma = [sigma] * dim
49 |
50 | # The gaussian kernel is the product of the
51 | # gaussian function of each dimension.
52 | kernel = 1
53 | meshgrids = torch.meshgrid(
54 | [
55 | torch.arange(size, dtype=torch.float32)
56 | for size in kernel_size
57 | ]
58 | )
59 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
60 | mean = (size - 1) / 2
61 | kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
62 | torch.exp(-((mgrid - mean) / (2 * std)) ** 2)
63 |
64 | # Make sure sum of values in gaussian kernel equals 1.
65 | kernel = kernel / torch.sum(kernel)
66 |
67 | # Reshape to depthwise convolutional weight
68 | kernel = kernel.view(1, 1, *kernel.size())
69 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
70 |
71 | self.register_buffer('weight', kernel)
72 | self.groups = channels
73 |
74 | if dim == 1:
75 | self.conv = F.conv1d
76 | elif dim == 2:
77 | self.conv = F.conv2d
78 | elif dim == 3:
79 | self.conv = F.conv3d
80 | else:
81 | raise RuntimeError(
82 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
83 | )
84 |
85 | def forward(self, input):
86 | """
87 | Apply gaussian filter to input.
88 | Arguments:
89 | input (torch.Tensor): Input to apply gaussian filter on.
90 | Returns:
91 | filtered (torch.Tensor): Filtered output.
92 | """
93 | return self.conv(input, weight=self.weight, groups=self.groups)
94 |
95 | class ChannelPool(nn.Module):
96 | def __init__(self,types):
97 | super(ChannelPool, self).__init__()
98 | if types == 'avg':
99 | self.poolingx = nn.AdaptiveAvgPool1d(1)
100 | elif types == 'max':
101 | self.poolingx = nn.AdaptiveMaxPool1d(1)
102 | else:
103 |             raise ValueError("ChannelPool type must be 'avg' or 'max'")
104 |
105 | def forward(self, input):
106 | n, c, w, h = input.size()
107 | input = input.view(n,c,w*h).permute(0,2,1)
108 | pooled = self.poolingx(input)# b,w*h,c -> b,w*h,1
109 | _, _, c = pooled.size()
110 | return pooled.view(n,c,w,h)
111 |
112 |
113 |
114 | class SEBlock(nn.Module):
115 | """docstring for SEBlock"""
116 | def __init__(self, channel,reducation=16):
117 | super(SEBlock, self).__init__()
118 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
119 | self.fc = nn.Sequential(
120 | nn.Linear(channel,channel//reducation),
121 | nn.ReLU(inplace=True),
122 | nn.Linear(channel//reducation,channel),
123 | nn.Sigmoid())
124 |
125 | def forward(self,x):
126 | b,c,w,h = x.size()
127 | y1 = self.avg_pool(x).view(b,c)
128 | y = self.fc(y1).view(b,c,1,1)
129 | return x*y
130 |
131 |
132 |
133 | class GlobalAttentionModule(nn.Module):
134 |     """Channel attention computed from global average- and max-pooled statistics."""
135 | def __init__(self, channel,reducation=16):
136 | super(GlobalAttentionModule, self).__init__()
137 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
138 | self.max_pool = nn.AdaptiveMaxPool2d(1)
139 | self.fc = nn.Sequential(
140 | nn.Linear(channel*2,channel//reducation),
141 | nn.ReLU(inplace=True),
142 | nn.Linear(channel//reducation,channel),
143 | nn.Sigmoid())
144 |
145 | def forward(self,x):
146 | b,c,w,h = x.size()
147 | y1 = self.avg_pool(x).view(b,c)
148 | y2 = self.max_pool(x).view(b,c)
149 | y = self.fc(torch.cat([y1,y2],1)).view(b,c,1,1)
150 | return x*y
151 |
152 | class SpatialAttentionModule(nn.Module):
153 |     """Spatial attention that predicts a soft mask and its complement from channel-pooled features."""
154 | def __init__(self, channel,reducation=16):
155 | super(SpatialAttentionModule, self).__init__()
156 | self.avg_pool = ChannelPool('avg')
157 | self.max_pool = ChannelPool('max')
158 | self.fc = nn.Sequential(
159 | nn.Conv2d(2,reducation,7,stride=1,padding=3),
160 | nn.ReLU(inplace=True),
161 | nn.Conv2d(reducation,1,7,stride=1,padding=3),
162 | nn.Sigmoid())
163 |
164 | def forward(self,x):
165 | b,c,w,h = x.size()
166 | y1 = self.avg_pool(x)
167 | y2 = self.max_pool(x)
168 | y = self.fc(torch.cat([y1,y2],1))
169 | yr = 1-y
170 | return y,yr
171 |
172 |
173 |
174 | class GlobalAttentionModuleJustSigmoid(nn.Module):
175 |     """Like GlobalAttentionModule, but returns the attention map itself instead of the re-weighted features."""
176 | def __init__(self, channel,reducation=16):
177 | super(GlobalAttentionModuleJustSigmoid, self).__init__()
178 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
179 | self.max_pool = nn.AdaptiveMaxPool2d(1)
180 | self.fc = nn.Sequential(
181 | nn.Linear(channel*2,channel//reducation),
182 | nn.ReLU(inplace=True),
183 | nn.Linear(channel//reducation,channel),
184 | nn.Sigmoid())
185 |
186 | def forward(self,x):
187 | b,c,w,h = x.size()
188 | y1 = self.avg_pool(x).view(b,c)
189 | y2 = self.max_pool(x).view(b,c)
190 | y = self.fc(torch.cat([y1,y2],1)).view(b,c,1,1)
191 | return y
192 |
193 |
194 |
195 | class BasicBlock(nn.Module):
196 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
197 | super(BasicBlock, self).__init__()
198 | self.out_channels = out_planes
199 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
200 | self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
201 | self.relu = nn.ReLU() if relu else None
202 |
203 | def forward(self, x):
204 | x = self.conv(x)
205 | if self.bn is not None:
206 | x = self.bn(x)
207 | if self.relu is not None:
208 | x = self.relu(x)
209 | return x
210 |
211 | class Flatten(nn.Module):
212 | def forward(self, x):
213 | return x.view(x.size(0), -1)
214 |
215 | class ChannelGate(nn.Module):
216 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
217 | super(ChannelGate, self).__init__()
218 | self.gate_channels = gate_channels
219 | self.mlp = nn.Sequential(
220 | Flatten(),
221 | nn.Linear(gate_channels, gate_channels // reduction_ratio),
222 | nn.ReLU(),
223 | nn.Linear(gate_channels // reduction_ratio, gate_channels)
224 | )
225 | self.pool_types = pool_types
226 | def forward(self, x):
227 | channel_att_sum = None
228 | for pool_type in self.pool_types:
229 | if pool_type=='avg':
230 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
231 | channel_att_raw = self.mlp( avg_pool )
232 | elif pool_type=='max':
233 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
234 | channel_att_raw = self.mlp( max_pool )
235 | elif pool_type=='lp':
236 | lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
237 | channel_att_raw = self.mlp( lp_pool )
238 | elif pool_type=='lse':
239 | # LSE pool only
240 | lse_pool = logsumexp_2d(x)
241 | channel_att_raw = self.mlp( lse_pool )
242 |
243 | if channel_att_sum is None:
244 | channel_att_sum = channel_att_raw
245 | else:
246 | channel_att_sum = channel_att_sum + channel_att_raw
247 |
248 |         scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
249 | return x * scale
250 |
251 | def logsumexp_2d(tensor):
252 | tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
253 | s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
254 | outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
255 | return outputs
256 |
257 | class ChannelPoolX(nn.Module):
258 | def forward(self, x):
259 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
260 |
261 | class SpatialGate(nn.Module):
262 | def __init__(self):
263 | super(SpatialGate, self).__init__()
264 | kernel_size = 7
265 | self.compress = ChannelPoolX()
266 | self.spatial = BasicBlock(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
267 | def forward(self, x):
268 | x_compress = self.compress(x)
269 | x_out = self.spatial(x_compress)
270 |         scale = torch.sigmoid(x_out) # broadcasting
271 | return x * scale
272 |
273 | class CBAM(nn.Module):
274 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
275 | super(CBAM, self).__init__()
276 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
277 | self.no_spatial=no_spatial
278 | if not no_spatial:
279 | self.SpatialGate = SpatialGate()
280 | def forward(self, x):
281 | x_out = self.ChannelGate(x)
282 | if not self.no_spatial:
283 | x_out = self.SpatialGate(x_out)
284 | return x_out
285 |
286 |
287 |
--------------------------------------------------------------------------------
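
A small sketch of GaussianSmoothing as it is used by RASC (dummy mask tensor, assuming the repository root is on PYTHONPATH): the module is a fixed depthwise blur, so the caller reflect-pads by (kernel_size - 1) // 2 to keep the spatial size unchanged.

import torch
import torch.nn.functional as F
from scripts.models.blocks import GaussianSmoothing

blur = GaussianSmoothing(channels=1, kernel_size=5, sigma=1)
mask = torch.rand(1, 1, 64, 64)
smoothed = blur(F.pad(mask, (2, 2, 2, 2), mode='reflect'))
print(smoothed.shape)   # torch.Size([1, 1, 64, 64])
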
/scripts/utils/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from scripts.models.vgg import Vgg19
5 | from torchvision import models
6 | from scripts.utils.misc import resize_to_match
7 | # from pytorch_msssim import SSIM, MS_SSIM
8 | import pytorch_ssim
9 |
10 | class WeightedBCE(nn.Module):
11 | def __init__(self):
12 | super(WeightedBCE, self).__init__()
13 |
14 | def forward(self, pred, gt):
15 |         epsilon = 1e-10
16 |         sigmoid_pred = torch.sigmoid(pred)
17 |         count_pos = torch.sum(gt)*1.0 + epsilon
18 | count_neg = torch.sum(1.-gt)*1.0
19 | beta = count_neg/count_pos
20 | beta_back = count_pos / (count_pos + count_neg)
21 |
22 | bce1 = nn.BCEWithLogitsLoss(pos_weight=beta)
23 | loss = beta_back*bce1(pred, gt)
24 |
25 | return loss
26 |
27 |
28 | def l1_relative(reconstructed, real, mask):
29 | batch = real.size(0)
30 | area = torch.sum(mask.view(batch,-1),dim=1)
31 | reconstructed = reconstructed * mask
32 | real = real * mask
33 |
34 | loss_l1 = torch.abs(reconstructed - real).view(batch, -1)
35 | loss_l1 = torch.sum(loss_l1, dim=1) / area
36 | loss_l1 = torch.sum(loss_l1) / batch
37 | return loss_l1
38 |
39 |
40 | def is_dic(x):
41 |     return isinstance(x, list)  # despite the name, this checks for a list of outputs
42 |
43 | class Losses(nn.Module):
44 | def __init__(self, argx, device):
45 | super(Losses, self).__init__()
46 | self.args = argx
47 |
48 | if self.args.loss_type == 'l1bl2':
49 | self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), nn.BCELoss(), nn.MSELoss()
50 | elif self.args.loss_type == 'l1wbl2':
51 | self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), WeightedBCE(), nn.MSELoss()
52 | elif self.args.loss_type == 'l2wbl2':
53 | self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), WeightedBCE(), nn.MSELoss()
54 | elif self.args.loss_type == 'l2xbl2':
55 | self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCEWithLogitsLoss(), nn.MSELoss()
56 | else: # l2bl2
57 | self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCELoss(), nn.MSELoss()
58 |
59 | if self.args.style_loss > 0:
60 | self.vggloss = VGGLoss(self.args.sltype).to(device)
61 |
62 | if self.args.ssim_loss > 0:
63 | self.ssimloss = pytorch_ssim.SSIM().to(device)
64 |
65 | self.outputLoss = self.outputLoss.to(device)
66 | self.attLoss = self.attLoss.to(device)
67 | self.wrloss = self.wrloss.to(device)
68 |
69 |
70 | def forward(self,imgx,target,attx,mask,wmx,wm):
71 | pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss = 0,0,0,0,0
72 |
73 | if is_dic(imgx):
74 |
75 | if self.args.masked:
76 | # calculate the overall loss and side output
77 | pixel_loss = self.outputLoss(imgx[0],target) + sum([self.outputLoss(im,resize_to_match(mask,im)*resize_to_match(target,im)) for im in imgx[1:]])
78 | else:
79 | pixel_loss = sum([self.outputLoss(im,resize_to_match(target,im)) for im in imgx])
80 |
81 | if self.args.style_loss > 0:
82 | vgg_loss = sum([self.vggloss(im,resize_to_match(target,im),resize_to_match(mask,im)) for im in imgx])
83 |
84 | if self.args.ssim_loss > 0:
85 | ssim_loss = sum([ 1 - self.ssimloss(im,resize_to_match(target,im)) for im in imgx])
86 | else:
87 |
88 | if self.args.masked:
89 | pixel_loss = self.outputLoss(imgx,mask*target)
90 | else:
91 | pixel_loss = self.outputLoss(imgx,target)
92 |
93 | if self.args.style_loss > 0:
94 | vgg_loss = self.vggloss(imgx,target,mask)
95 |
96 | if self.args.ssim_loss > 0:
97 | ssim_loss = 1 - self.ssimloss(imgx,target)
98 |
99 | if is_dic(attx):
100 | att_loss = sum([self.attLoss(at,resize_to_match(mask,at)) for at in attx])
101 | else:
102 | att_loss = self.attLoss(attx, mask)
103 |
104 | if is_dic(wmx):
105 | wm_loss = sum([self.wrloss(w,resize_to_match(wm,w)) for w in wmx])
106 | else:
107 | if self.args.masked:
108 | wm_loss = self.wrloss(wmx,mask*wm)
109 | else:
110 | wm_loss = self.wrloss(wmx, wm)
111 |
112 | return pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss
113 |
114 |
115 |
116 | def gram_matrix(feat):
117 | # https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/utils.py
118 | (b, ch, h, w) = feat.size()
119 | feat = feat.view(b, ch, h * w)
120 | feat_t = feat.transpose(1, 2)
121 | gram = torch.bmm(feat, feat_t) / (ch * h * w)
122 | return gram
123 |
124 | class MeanShift(nn.Conv2d):
125 | def __init__(self, data_mean, data_std, data_range=1, norm=True):
126 | """norm (bool): normalize/denormalize the stats"""
127 | c = len(data_mean)
128 | super(MeanShift, self).__init__(c, c, kernel_size=1)
129 | std = torch.Tensor(data_std)
130 | self.weight.data = torch.eye(c).view(c, c, 1, 1)
131 | if norm:
132 | self.weight.data.div_(std.view(c, 1, 1, 1))
133 | self.bias.data = -1 * data_range * torch.Tensor(data_mean)
134 | self.bias.data.div_(std)
135 | else:
136 | self.weight.data.mul_(std.view(c, 1, 1, 1))
137 | self.bias.data = data_range * torch.Tensor(data_mean)
138 | self.requires_grad = False
139 |
140 |
141 |
142 | def VGGLoss(losstype):
143 | if losstype == 'vgg':
144 | return VGGLossA()
145 | elif losstype == 'vggx':
146 | return VGGLossX(mask=False)
147 | elif losstype == 'mvggx':
148 | return VGGLossX(mask=True)
149 | elif losstype == 'rvggx':
150 | return VGGLossX(mask=True,relative=True)
151 | else:
152 | raise Exception("error in %s"%losstype)
153 |
154 |
155 |
156 | class VGGLossA(nn.Module):
157 | def __init__(self, vgg=None, weights=None, indices=None, normalize=True):
158 | super(VGGLossA, self).__init__()
159 | if vgg is None:
160 | self.vgg = Vgg19().cuda()
161 | else:
162 | self.vgg = vgg
163 | self.criterion = nn.L1Loss()
164 | self.weights = weights or [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5]
165 | self.indices = indices or [2, 7, 12, 21, 30]
166 | if normalize:
167 | self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
168 | else:
169 | self.normalize = None
170 |
171 | def forward(self, x, y):
172 | if self.normalize is not None:
173 | x = self.normalize(x)
174 | y = self.normalize(y)
175 | x_vgg, y_vgg = self.vgg(x, self.indices), self.vgg(y, self.indices)
176 | loss = 0
177 | for i in range(len(x_vgg)):
178 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
179 | return loss
180 |
181 |
182 | class VGG16FeatureExtractor(nn.Module):
183 | def __init__(self):
184 | super().__init__()
185 | vgg16 = models.vgg16(pretrained=True)
186 | self.enc_1 = nn.Sequential(*vgg16.features[:5])
187 | self.enc_2 = nn.Sequential(*vgg16.features[5:10])
188 | self.enc_3 = nn.Sequential(*vgg16.features[10:17])
189 |
190 | # fix the encoder
191 | for i in range(3):
192 | for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters():
193 | param.requires_grad = False
194 |
195 | def forward(self, image):
196 | results = [image]
197 | for i in range(3):
198 | func = getattr(self, 'enc_{:d}'.format(i + 1))
199 | results.append(func(results[-1]))
200 | return results[1:]
201 |
202 | class VGGLossX(nn.Module):
203 | def __init__(self, normalize=True, mask=False, relative=False):
204 | super(VGGLossX, self).__init__()
205 |
206 | self.vgg = VGG16FeatureExtractor().cuda()
207 | self.criterion = nn.L1Loss().cuda() if not relative else l1_relative
208 | self.use_mask= mask
209 | self.relative = relative
210 |
211 | if normalize:
212 | self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
213 | else:
214 | self.normalize = None
215 |
216 | def forward(self, x, y, Xmask=None):
217 | if not self.use_mask:
218 | mask = torch.ones_like(x)[:,0:1,:,:]
219 | else:
220 | mask = Xmask
221 |
222 | if self.normalize is not None:
223 | x = self.normalize(x)
224 | y = self.normalize(y)
225 |
226 | x_vgg = self.vgg(x)
227 | y_vgg = self.vgg(y)
228 |
229 | loss = 0
230 | for i in range(3):
231 | if self.relative:
232 | loss += self.criterion(x_vgg[i],y_vgg[i].detach(),resize_to_match(mask,x_vgg[i]))
233 | else:
234 | loss += self.criterion(resize_to_match(mask,x_vgg[i])*x_vgg[i],resize_to_match(mask,y_vgg[i])*y_vgg[i].detach())
235 |
236 | return loss
237 |
238 |
239 | class GANLosses(object):
240 | """docstring for Loss"""
241 | def __init__(self, gantype):
242 | super(GANLosses, self).__init__()
243 | self.generator_loss = gen_gan(gantype)
244 | self.discriminator_loss = dis_gan(gantype)
245 | self.gantype = gantype
246 |
247 | def g_loss(self,dis_fake):
248 | if 'hinge' in self.gantype:
249 | return gen_hinge(dis_fake)
250 | else:
251 | return self.generator_loss(dis_fake)
252 |
253 | def d_loss(self,dis_fake,dis_real):
254 | if 'hinge' in self.gantype:
255 | return dis_hinge(dis_fake,dis_real)
256 | else:
257 | return self.discriminator_loss(dis_fake,dis_real)
258 |
259 |
260 | class gen_gan(nn.Module):
261 | def __init__(self,gantype):
262 | super(gen_gan,self).__init__()
263 | if gantype == 'lsgan':
264 | self.criterion = nn.MSELoss()
265 | elif gantype == 'naive':
266 | self.criterion = nn.BCEWithLogitsLoss()
267 | else:
268 | raise Exception("error gan type")
269 |
270 | def forward(self,dis_fake):
271 | return self.criterion(dis_fake, torch.ones_like(dis_fake))
272 |
273 | class dis_gan(nn.Module):
274 | def __init__(self,gantype):
275 | super(dis_gan,self).__init__()
276 | if gantype == 'lsgan':
277 | self.criterion = nn.MSELoss()
278 | elif gantype == 'naive':
279 | self.criterion = nn.BCEWithLogitsLoss()
280 | else:
281 | raise Exception("error gan type")
282 |
283 | def forward(self,dis_fake,dis_real):
284 | loss_fake = self.criterion(dis_fake, torch.zeros_like(dis_fake))
285 | loss_real = self.criterion(dis_real, torch.ones_like(dis_real))
286 | return loss_fake, loss_real
287 |
288 | # def gen_gan(dis_fake):
289 | # # fake -> 1
290 | # return F.binary_cross_entropy_with_logits(dis_fake,torch.ones_like(dis_fake))
291 |
292 | # def dis_gan(dis_fake,dis_real):
293 | # # fake -> 0 , real ->1
294 | # loss_fake = F.binary_cross_entropy_with_logits(dis_fake, torch.zeros_like(dis_real))
295 | # loss_real = F.binary_cross_entropy_with_logits(dis_real, torch.ones_like(dis_fake))
296 | # return loss_fake,loss_real
297 |
298 | # def gen_lsgan(dis_fake):
299 | # loss = F.mse_loss(dis_fake,torch.ones_like(dis_fake)) #
300 | # return loss
301 |
302 | # def dis_lsgan(dis_fake, dis_real):
303 | # loss_fake = F.mse_loss(dis_fake, torch.zeros_like(dis_real))
304 | # loss_real = F.mse_loss(dis_real, torch.ones_like(dis_real))
305 | # return loss_fake,loss_real
306 |
307 | def gen_hinge(dis_fake, dis_real=None):
308 | return -torch.mean(dis_fake)
309 |
310 | def dis_hinge(dis_fake, dis_real):
311 | loss_fake = torch.mean(torch.relu(1. + dis_fake))
312 | loss_real = torch.mean(torch.relu(1. - dis_real))
313 | return loss_fake,loss_real
314 |
315 |
--------------------------------------------------------------------------------
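
Sketch of the two standalone pieces of this module with dummy tensors (assuming the repository root is on PYTHONPATH and that the pytorch_ssim package imported at the top is installed): l1_relative averages the L1 error over the masked region only, and WeightedBCE balances positive/negative pixels for the attention mask.

import torch
from scripts.utils.losses import l1_relative, WeightedBCE

pred = torch.rand(2, 3, 64, 64)
target = torch.rand(2, 3, 64, 64)
mask = (torch.rand(2, 1, 64, 64) > 0.5).float()   # watermark / splicing region

print(l1_relative(pred, target, mask))            # L1 averaged over masked pixels
print(WeightedBCE()(torch.randn(2, 1, 64, 64), mask))
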
/scripts/machines/S2AM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.backends.cudnn as cudnn
4 | from progress.bar import Bar
5 | import json
6 | import numpy as np
7 | from tensorboardX import SummaryWriter
8 | from scripts.utils.evaluation import accuracy, AverageMeter, final_preds
9 | from scripts.utils.osutils import mkdir_p, isfile, isdir, join
10 | from scripts.utils.parallel import DataParallelModel, DataParallelCriterion
11 | import pytorch_ssim as pytorch_ssim
12 | import torch.optim
13 | import sys,shutil,os
14 | import time
15 | import scripts.models as archs
16 | from math import log10
17 | from torch.autograd import Variable
18 | from scripts.utils.losses import VGGLoss
19 | from scripts.utils.imutils import im_to_numpy
20 |
21 | import skimage.io
22 | from skimage.measure import compare_psnr,compare_ssim
23 |
24 |
25 | class S2AM(object):
26 | def __init__(self, datasets =(None,None), models = None, args = None, **kwargs):
27 | super(S2AM, self).__init__()
28 |
29 | self.args = args
30 |
31 | # create model
32 | print("==> creating model ")
33 | self.model = archs.__dict__[self.args.arch]()
34 | print("==> creating model [Finish]")
35 |
36 | self.train_loader, self.val_loader = datasets
37 | self.loss = torch.nn.MSELoss()
38 |
39 | self.title = '_'+args.machine + '_' + args.data + '_' + args.arch
40 | self.args.checkpoint = args.checkpoint + self.title
41 | self.device = torch.device('cuda')
42 | # create checkpoint dir
43 | if not isdir(self.args.checkpoint):
44 | mkdir_p(self.args.checkpoint)
45 |
46 | self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()),
47 | lr=args.lr,
48 | betas=(args.beta1,args.beta2),
49 | weight_decay=args.weight_decay)
50 |
51 | if not self.args.evaluate:
52 | self.writer = SummaryWriter(self.args.checkpoint+'/'+'ckpt')
53 |
54 | self.best_acc = 0
55 | self.is_best = False
56 | self.current_epoch = 0
57 | self.hl = 1
58 | self.metric = -100000
59 |         self.count_gpu = torch.cuda.device_count()
60 |
61 | if self.args.style_loss > 0:
62 | # init perception loss
63 | self.vggloss = VGGLoss(self.args.sltype).to(self.device)
64 |
65 | if self.count_gpu > 1 : # multiple
66 | # self.model = DataParallelModel(self.model, device_ids=range(torch.cuda.device_count()))
67 | # self.loss = DataParallelCriterion(self.loss, device_ids=range(torch.cuda.device_count()))
68 | self.model = nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))
69 |
70 | self.model.to(self.device)
71 | self.loss.to(self.device)
72 |
73 | print('==> Total params: %.2fM' % (sum(p.numel() for p in self.model.parameters())/1000000.0))
74 | print('==> Total devices: %d' % (torch.cuda.device_count()))
75 | print('==> Current Checkpoint: %s' % (self.args.checkpoint))
76 |
77 |
78 | if self.args.resume != '':
79 | self.resume(self.args.resume)
80 |
81 |
82 | def train(self,epoch):
83 | batch_time = AverageMeter()
84 | data_time = AverageMeter()
85 | losses = AverageMeter()
86 | lossvgg = AverageMeter()
87 |
88 | # switch to train mode
89 | self.model.train()
90 | end = time.time()
91 |
92 | bar = Bar('Processing', max=len(self.train_loader)*self.hl)
93 | for _ in range(self.hl):
94 | for i, batches in enumerate(self.train_loader):
95 | # measure data loading time
96 | inputs = batches['image'].to(self.device)
97 | target = batches['target'].to(self.device)
98 | mask =batches['mask'].to(self.device)
99 | current_index = len(self.train_loader) * epoch + i
100 |
101 | feeded = torch.cat([inputs,mask],dim=1)
102 | feeded = feeded.to(self.device)
103 |
104 | output = self.model(feeded)
105 |
106 | if self.args.res:
107 | output = output + inputs
108 |
109 | L2_loss = self.loss(output,target)
110 |
111 | if self.args.style_loss > 0:
112 | vgg_loss = self.vggloss(output,target,mask)
113 | else:
114 | vgg_loss = 0
115 |
116 | total_loss = L2_loss + self.args.style_loss * vgg_loss
117 |
118 | # compute gradient and do SGD step
119 | self.optimizer.zero_grad()
120 | total_loss.backward()
121 | self.optimizer.step()
122 |
123 | # measure accuracy and record loss
124 | losses.update(L2_loss.item(), inputs.size(0))
125 |
126 | if self.args.style_loss > 0 :
127 | lossvgg.update(vgg_loss.item(), inputs.size(0))
128 |
129 | # measure elapsed time
130 | batch_time.update(time.time() - end)
131 | end = time.time()
132 |
133 | # plot progress
134 | suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss VGG: {loss_vgg:.4f}'.format(
135 | batch=i + 1,
136 | size=len(self.train_loader),
137 | data=data_time.val,
138 | bt=batch_time.val,
139 | total=bar.elapsed_td,
140 | eta=bar.eta_td,
141 | loss_label=losses.avg,
142 | loss_vgg=lossvgg.avg
143 | )
144 |
145 | if current_index % 1000 == 0:
146 | print(suffix)
147 |
148 | if self.args.freq > 0 and current_index % self.args.freq == 0:
149 | self.validate(current_index)
150 | self.flush()
151 | self.save_checkpoint()
152 |
153 | self.record('train/loss_L2', losses.avg, current_index)
154 |
155 |
156 | def test(self, ):
157 |
158 | # switch to evaluate mode
159 | self.model.eval()
160 |
161 | ssimes = AverageMeter()
162 | psnres = AverageMeter()
163 |
164 | with torch.no_grad():
165 | for i, batches in enumerate(self.val_loader):
166 |
167 | inputs = batches['image'].to(self.device)
168 | target = batches['target'].to(self.device)
169 | mask =batches['mask'].to(self.device)
170 |
171 | feeded = torch.cat([inputs,mask],dim=1)
172 | feeded = feeded.to(self.device)
173 |
174 | output = self.model(feeded)
175 |
176 | if self.args.res:
177 | output = output + inputs
178 |
179 | # recover the image to 255
180 | output = im_to_numpy(torch.clamp(output[0]*255,min=0.0,max=255.0)).astype(np.uint8)
181 | target = im_to_numpy(torch.clamp(target[0]*255,min=0.0,max=255.0)).astype(np.uint8)
182 |
183 | skimage.io.imsave('%s/%s'%(self.args.checkpoint,batches['name'][0]), output)
184 |
185 | psnr = compare_psnr(target,output)
186 | ssim = compare_ssim(target,output,multichannel=True)
187 |
188 | psnres.update(psnr, inputs.size(0))
189 | ssimes.update(ssim, inputs.size(0))
190 |
191 | print("%s:PSNR:%s,SSIM:%s"%(self.args.checkpoint,psnres.avg,ssimes.avg))
192 | print("DONE.\n")
193 |
194 |
195 | def validate(self, epoch):
196 | batch_time = AverageMeter()
197 | data_time = AverageMeter()
198 | losses = AverageMeter()
199 | ssimes = AverageMeter()
200 | psnres = AverageMeter()
201 | # switch to evaluate mode
202 | self.model.eval()
203 |
204 | end = time.time()
205 | with torch.no_grad():
206 | for i, batches in enumerate(self.val_loader):
207 |
208 | inputs = batches['image'].to(self.device)
209 | target = batches['target'].to(self.device)
210 | mask =batches['mask'].to(self.device)
211 |
212 | feeded = torch.cat([inputs,mask],dim=1)
213 | feeded = feeded.to(self.device)
214 |
215 | output = self.model(feeded)
216 |
217 | if self.args.res:
218 | output = output + inputs
219 |
220 | L2_loss = self.loss(output, target)
221 |
222 | psnr = 10 * log10(1 / L2_loss.item())
223 | ssim = pytorch_ssim.ssim(output, target)
224 |
225 | losses.update(L2_loss.item(), inputs.size(0))
226 | psnres.update(psnr, inputs.size(0))
227 | ssimes.update(ssim.item(), inputs.size(0))
228 |
229 | # measure elapsed time
230 | batch_time.update(time.time() - end)
231 | end = time.time()
232 |
233 |             print("Epoch:%s, Loss:%.3f, PSNR:%.3f, SSIM:%.3f"%(epoch+1, losses.avg,psnres.avg,ssimes.avg))
234 | self.record('val/loss_L2', losses.avg, epoch)
235 | self.record('val/PSNR', psnres.avg, epoch)
236 | self.record('val/SSIM', ssimes.avg, epoch)
237 |
238 | self.metric = psnres.avg
239 |
240 | def resume(self,resume_path):
241 | if isfile(resume_path):
242 | print("=> loading checkpoint '{}'".format(resume_path))
243 | current_checkpoint = torch.load(resume_path)
244 | if isinstance(current_checkpoint['state_dict'], torch.nn.DataParallel):
245 | current_checkpoint['state_dict'] = current_checkpoint['state_dict'].module
246 |
247 | if isinstance(current_checkpoint['optimizer'], torch.nn.DataParallel):
248 | current_checkpoint['optimizer'] = current_checkpoint['optimizer'].module
249 |
250 | self.args.start_epoch = current_checkpoint['epoch']
251 | self.metric = current_checkpoint['best_acc']
252 | self.model.load_state_dict(current_checkpoint['state_dict'])
253 | # self.optimizer.load_state_dict(current_checkpoint['optimizer'])
254 | print("=> loaded checkpoint '{}' (epoch {})"
255 | .format(resume_path, current_checkpoint['epoch']))
256 | else:
257 | raise Exception("=> no checkpoint found at '{}'".format(resume_path))
258 |
259 | def save_checkpoint(self,filename='checkpoint.pth.tar', snapshot=None):
260 |         is_best = self.best_acc < self.metric
261 |
262 | if is_best:
263 | self.best_acc = self.metric
264 |
265 | state = {
266 | 'epoch': self.current_epoch + 1,
267 | 'arch': self.args.arch,
268 | 'state_dict': self.model.state_dict(),
269 | 'best_acc': self.best_acc,
270 | 'optimizer' : self.optimizer.state_dict() if self.optimizer else None,
271 | }
272 |
273 | filepath = os.path.join(self.args.checkpoint, filename)
274 | torch.save(state, filepath)
275 |
276 | if snapshot and state['epoch'] % snapshot == 0:
277 |             shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch'])))
278 |
279 | if is_best:
280 | self.best_acc = self.metric
281 | print('Saving Best Metric with PSNR:%s'%self.best_acc)
282 | shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'model_best.pth.tar'))
283 |
284 | def clean(self):
285 | self.writer.close()
286 |
287 | def record(self,k,v,epoch):
288 | self.writer.add_scalar(k, v, epoch)
289 |
290 | def flush(self):
291 | self.writer.flush()
292 | sys.stdout.flush()
293 |
294 | def norm(self,x):
295 | if self.args.gan_norm:
296 | return x*2.0 - 1.0
297 | else:
298 | return x
299 |
300 | def denorm(self,x):
301 | if self.args.gan_norm:
302 | return (x+1.0)/2.0
303 | else:
304 | return x
305 |
306 |
--------------------------------------------------------------------------------
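
For reference, the PSNR reported by validate() is derived directly from the MSE under the assumption that images lie in [0, 1]; a minimal sketch with dummy tensors (not tied to the repository):

import torch
import torch.nn.functional as F
from math import log10

output = torch.rand(1, 3, 256, 256)
target = torch.rand(1, 3, 256, 256)
mse = F.mse_loss(output, target).item()
print('PSNR: %.2f dB' % (10 * log10(1 / mse)))
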
/scripts/utils/parallel.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Hang Zhang, Rutgers University, Email: zhang.hang@rutgers.edu
3 | ## Modified by Thomas Wolf, HuggingFace Inc., Email: thomas@huggingface.co
4 | ## Copyright (c) 2017-2018
5 | ##
6 | ## This source code is licensed under the MIT-style license found in the
7 | ## LICENSE file in the root directory of this source tree
8 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
9 |
10 | """Encoding Data Parallel"""
11 | import threading
12 | import functools
13 | import torch
14 | from torch.autograd import Variable, Function
15 | import torch.cuda.comm as comm
16 | from torch.nn.parallel import DistributedDataParallel
17 | from torch.nn.parallel.data_parallel import DataParallel
18 | from torch.nn.parallel.parallel_apply import get_a_var
19 | from torch.nn.parallel.scatter_gather import gather
20 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
21 |
22 | torch_ver = torch.__version__[:3]
23 |
24 | __all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion',
25 | 'patch_replication_callback']
26 |
27 | def allreduce(*inputs):
28 |     """Cross-GPU all-reduce autograd operation for calculating the mean and
29 | variance in SyncBN.
30 | """
31 | return AllReduce.apply(*inputs)
32 |
33 | class AllReduce(Function):
34 | @staticmethod
35 | def forward(ctx, num_inputs, *inputs):
36 | ctx.num_inputs = num_inputs
37 | ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
38 | inputs = [inputs[i:i + num_inputs]
39 | for i in range(0, len(inputs), num_inputs)]
40 | # sort before reduce sum
41 | inputs = sorted(inputs, key=lambda i: i[0].get_device())
42 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
43 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
44 | return tuple([t for tensors in outputs for t in tensors])
45 |
46 | @staticmethod
47 | def backward(ctx, *inputs):
48 | inputs = [i.data for i in inputs]
49 | inputs = [inputs[i:i + ctx.num_inputs]
50 | for i in range(0, len(inputs), ctx.num_inputs)]
51 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
52 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
53 | return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])
54 |
55 |
56 | class Reduce(Function):
57 | @staticmethod
58 | def forward(ctx, *inputs):
59 | ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
60 | inputs = sorted(inputs, key=lambda i: i.get_device())
61 | return comm.reduce_add(inputs)
62 |
63 | @staticmethod
64 | def backward(ctx, gradOutput):
65 | return Broadcast.apply(ctx.target_gpus, gradOutput)
66 |
67 | class DistributedDataParallelModel(DistributedDataParallel):
68 | """Implements data parallelism at the module level for the DistributedDataParallel module.
69 | This container parallelizes the application of the given module by
70 | splitting the input across the specified devices by chunking in the
71 | batch dimension.
72 | In the forward pass, the module is replicated on each device,
73 | and each replica handles a portion of the input. During the backwards pass,
74 | gradients from each replica are summed into the original module.
75 |     Note that the outputs are not gathered; please use the compatible
76 | :class:`encoding.parallel.DataParallelCriterion`.
77 | The batch size should be larger than the number of GPUs used. It should
78 | also be an integer multiple of the number of GPUs so that each chunk is
79 | the same size (so that each GPU processes the same number of samples).
80 | Args:
81 | module: module to be parallelized
82 | device_ids: CUDA devices (default: all devices)
83 | Reference:
84 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
85 |         Amit Agrawal. “Context Encoding for Semantic Segmentation.”
86 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
87 | Example::
88 | >>> net = encoding.nn.DistributedDataParallelModel(model, device_ids=[0, 1, 2])
89 | >>> y = net(x)
90 | """
91 | def gather(self, outputs, output_device):
92 | return outputs
93 |
94 | class DataParallelModel(DataParallel):
95 | """Implements data parallelism at the module level.
96 |
97 | This container parallelizes the application of the given module by
98 | splitting the input across the specified devices by chunking in the
99 | batch dimension.
100 | In the forward pass, the module is replicated on each device,
101 | and each replica handles a portion of the input. During the backwards pass,
102 | gradients from each replica are summed into the original module.
103 |     Note that the outputs are not gathered; please use the compatible
104 | :class:`encoding.parallel.DataParallelCriterion`.
105 |
106 | The batch size should be larger than the number of GPUs used. It should
107 | also be an integer multiple of the number of GPUs so that each chunk is
108 | the same size (so that each GPU processes the same number of samples).
109 |
110 | Args:
111 | module: module to be parallelized
112 | device_ids: CUDA devices (default: all devices)
113 |
114 | Reference:
115 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
116 |         Amit Agrawal. “Context Encoding for Semantic Segmentation.”
117 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
118 |
119 | Example::
120 |
121 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
122 | >>> y = net(x)
123 | """
124 | def gather(self, outputs, output_device):
125 | return outputs
126 |
127 | def replicate(self, module, device_ids):
128 | modules = super(DataParallelModel, self).replicate(module, device_ids)
129 | execute_replication_callbacks(modules)
130 | return modules
131 |
132 |
133 | class DataParallelCriterion(DataParallel):
134 | """
135 |     Calculate the loss across multiple GPUs, which balances the memory usage.
136 |     The targets are split across the specified devices by chunking in
137 | the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`.
138 |
139 | Reference:
140 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
141 |         Amit Agrawal. “Context Encoding for Semantic Segmentation.”
142 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
143 |
144 | Example::
145 |
146 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
147 | >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
148 | >>> y = net(x)
149 | >>> loss = criterion(y, target)
150 | """
151 | def forward(self, inputs, *targets, **kwargs):
152 |         # inputs should already be scattered
153 | # scattering the targets instead
154 | if not self.device_ids:
155 | return self.module(inputs, *targets, **kwargs)
156 | targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
157 | if len(self.device_ids) == 1:
158 | return self.module(inputs, *targets[0], **kwargs[0])
159 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
160 | outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
161 | #return Reduce.apply(*outputs) / len(outputs)
162 | #return self.gather(outputs, self.output_device).mean()
163 | return self.gather(outputs, self.output_device)
164 |
165 |
166 | def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
167 | assert len(modules) == len(inputs)
168 | assert len(targets) == len(inputs)
169 | if kwargs_tup:
170 | assert len(modules) == len(kwargs_tup)
171 | else:
172 | kwargs_tup = ({},) * len(modules)
173 | if devices is not None:
174 | assert len(modules) == len(devices)
175 | else:
176 | devices = [None] * len(modules)
177 |
178 | lock = threading.Lock()
179 | results = {}
180 | if torch_ver != "0.3":
181 | grad_enabled = torch.is_grad_enabled()
182 |
183 | def _worker(i, module, input, target, kwargs, device=None):
184 | if torch_ver != "0.3":
185 | torch.set_grad_enabled(grad_enabled)
186 | if device is None:
187 | device = get_a_var(input).get_device()
188 | try:
189 | with torch.cuda.device(device):
190 | # this also avoids accidental slicing of `input` if it is a Tensor
191 | if not isinstance(input, (list, tuple)):
192 | input = (input,)
193 | if not isinstance(target, (list, tuple)):
194 | target = (target,)
195 | output = module(*(input + target), **kwargs)
196 | with lock:
197 | results[i] = output
198 | except Exception as e:
199 | with lock:
200 | results[i] = e
201 |
202 | if len(modules) > 1:
203 | threads = [threading.Thread(target=_worker,
204 | args=(i, module, input, target,
205 | kwargs, device),)
206 | for i, (module, input, target, kwargs, device) in
207 | enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]
208 |
209 | for thread in threads:
210 | thread.start()
211 | for thread in threads:
212 | thread.join()
213 | else:
214 |         _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])
215 |
216 | outputs = []
217 | for i in range(len(inputs)):
218 | output = results[i]
219 | if isinstance(output, Exception):
220 | raise output
221 | outputs.append(output)
222 | return outputs
223 |
224 |
225 | ###########################################################################
226 | # Adapted from Synchronized-BatchNorm-PyTorch.
227 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
228 | #
229 | class CallbackContext(object):
230 | pass
231 |
232 |
233 | def execute_replication_callbacks(modules):
234 | """
235 | Execute an replication callback `__data_parallel_replicate__` on each module created
236 | by original replication.
237 |
238 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
239 |
240 |     Note that, as all replicated modules are isomorphic, we assign each sub-module a context
241 | (shared among multiple copies of this module on different devices).
242 | Through this context, different copies can share some information.
243 |
244 |     We guarantee that the callback on the master copy (the first copy) will be called before
245 |     the callbacks on any slave copies.
246 | """
247 | master_copy = modules[0]
248 | nr_modules = len(list(master_copy.modules()))
249 | ctxs = [CallbackContext() for _ in range(nr_modules)]
250 |
251 | for i, module in enumerate(modules):
252 | for j, m in enumerate(module.modules()):
253 | if hasattr(m, '__data_parallel_replicate__'):
254 | m.__data_parallel_replicate__(ctxs[j], i)
255 |
256 |
257 | def patch_replication_callback(data_parallel):
258 | """
259 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
260 | Useful when you have customized `DataParallel` implementation.
261 |
262 | Examples:
263 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
264 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
265 | > patch_replication_callback(sync_bn)
266 | # this is equivalent to
267 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
268 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
269 | """
270 |
271 | assert isinstance(data_parallel, DataParallel)
272 |
273 | old_replicate = data_parallel.replicate
274 |
275 | @functools.wraps(old_replicate)
276 | def new_replicate(module, device_ids):
277 | modules = old_replicate(module, device_ids)
278 | execute_replication_callbacks(modules)
279 | return modules
280 |
281 | data_parallel.replicate = new_replicate
--------------------------------------------------------------------------------
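A minimal usage sketch for the two wrappers above, assuming a machine with at least two GPUs; the toy model, criterion and tensor shapes are placeholders rather than code from this repository. `DataParallelModel` leaves the per-device outputs ungathered, and `DataParallelCriterion` scatters the targets and evaluates the loss on each device:

import torch
import torch.nn as nn
from scripts.utils.parallel import DataParallelModel, DataParallelCriterion

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 3, 3, padding=1))
criterion = nn.MSELoss()

if torch.cuda.device_count() > 1:
    model = DataParallelModel(model.cuda())        # outputs stay on their own devices
    criterion = DataParallelCriterion(criterion)   # targets are scattered to match
    x = torch.randn(8, 3, 64, 64).cuda()
    y = torch.randn(8, 3, 64, 64).cuda()
    outputs = model(x)                 # list with one output tensor per GPU
    loss = criterion(outputs, y)       # per-device losses gathered on device 0
    if loss.dim() > 0:                 # reduce to a scalar before backward
        loss = loss.mean()
    loss.backward()

--------------------------------------------------------------------------------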
/scripts/machines/BasicMachine.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.backends.cudnn as cudnn
4 | from progress.bar import Bar
5 | import json
6 | import numpy as np
7 | from tensorboardX import SummaryWriter
8 | from scripts.utils.evaluation import accuracy, AverageMeter, final_preds
9 | from scripts.utils.osutils import mkdir_p, isfile, isdir, join
10 | from scripts.utils.parallel import DataParallelModel, DataParallelCriterion
11 | import pytorch_ssim as pytorch_ssim
12 | import torch.optim
13 | import sys,shutil,os
14 | import time
15 | import scripts.models as archs
16 | from math import log10
17 | from torch.autograd import Variable
18 | from scripts.utils.losses import VGGLoss
19 | from scripts.utils.imutils import im_to_numpy
20 |
21 | import skimage.io
22 | from skimage.measure import compare_psnr,compare_ssim
23 |
24 |
25 | class BasicMachine(object):
26 | def __init__(self, datasets =(None,None), models = None, args = None, **kwargs):
27 | super(BasicMachine, self).__init__()
28 |
29 | self.args = args
30 |
31 | # create model
32 | print("==> creating model ")
33 | self.model = archs.__dict__[self.args.arch]()
34 | print("==> creating model [Finish]")
35 |
36 | self.train_loader, self.val_loader = datasets
37 | self.loss = torch.nn.MSELoss()
38 |
39 | self.title = '_'+args.machine + '_' + args.data + '_' + args.arch
40 | self.args.checkpoint = args.checkpoint + self.title
41 | self.device = torch.device('cuda')
42 | # create checkpoint dir
43 | if not isdir(self.args.checkpoint):
44 | mkdir_p(self.args.checkpoint)
45 |
46 | self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()),
47 | lr=args.lr,
48 | betas=(args.beta1,args.beta2),
49 | weight_decay=args.weight_decay)
50 |
51 | if not self.args.evaluate:
52 | self.writer = SummaryWriter(self.args.checkpoint+'/'+'ckpt')
53 |
54 | self.best_acc = 0
55 | self.is_best = False
56 | self.current_epoch = 0
57 | self.metric = -100000
58 | self.hl = 6 if self.args.hl else 1
59 |         self.count_gpu = torch.cuda.device_count()
60 |
61 | if self.args.style_loss > 0:
62 | # init perception loss
63 | self.vggloss = VGGLoss(self.args.sltype).to(self.device)
64 |
65 | if self.count_gpu > 1 : # multiple
66 | # self.model = DataParallelModel(self.model, device_ids=range(torch.cuda.device_count()))
67 | # self.loss = DataParallelCriterion(self.loss, device_ids=range(torch.cuda.device_count()))
68 | self.model = nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))
69 |
70 | self.model.to(self.device)
71 | self.loss.to(self.device)
72 |
73 | print('==> Total params: %.2fM' % (sum(p.numel() for p in self.model.parameters())/1000000.0))
74 | print('==> Total devices: %d' % (torch.cuda.device_count()))
75 | print('==> Current Checkpoint: %s' % (self.args.checkpoint))
76 |
77 |
78 | if self.args.resume != '':
79 | self.resume(self.args.resume)
80 |
81 |
82 | def train(self,epoch):
83 | batch_time = AverageMeter()
84 | data_time = AverageMeter()
85 | losses = AverageMeter()
86 | lossvgg = AverageMeter()
87 |
88 | # switch to train mode
89 | self.model.train()
90 | end = time.time()
91 |
92 | bar = Bar('Processing', max=len(self.train_loader)*self.hl)
93 | for _ in range(self.hl):
94 | for i, batches in enumerate(self.train_loader):
95 | # measure data loading time
96 | inputs = batches['image']
97 | target = batches['target'].to(self.device)
98 |                 mask = batches['mask'].to(self.device)
99 | current_index = len(self.train_loader) * epoch + i
100 |
101 | if self.args.hl:
102 | feeded = torch.cat([inputs,mask],dim=1)
103 | else:
104 | feeded = inputs
105 | feeded = feeded.to(self.device)
106 |
107 | output = self.model(feeded)
108 | L2_loss = self.loss(output,target)
109 |
110 | if self.args.style_loss > 0:
111 | vgg_loss = self.vggloss(output,target,mask)
112 | else:
113 | vgg_loss = 0
114 |
115 | total_loss = L2_loss + self.args.style_loss * vgg_loss
116 |
117 | # compute gradient and do SGD step
118 | self.optimizer.zero_grad()
119 | total_loss.backward()
120 | self.optimizer.step()
121 |
122 | # measure accuracy and record loss
123 | losses.update(L2_loss.item(), inputs.size(0))
124 |
125 | if self.args.style_loss > 0 :
126 | lossvgg.update(vgg_loss.item(), inputs.size(0))
127 |
128 | # measure elapsed time
129 | batch_time.update(time.time() - end)
130 | end = time.time()
131 |
132 | # plot progress
133 | suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss VGG: {loss_vgg:.4f}'.format(
134 | batch=i + 1,
135 | size=len(self.train_loader),
136 | data=data_time.val,
137 | bt=batch_time.val,
138 | total=bar.elapsed_td,
139 | eta=bar.eta_td,
140 | loss_label=losses.avg,
141 | loss_vgg=lossvgg.avg
142 | )
143 |
144 | if current_index % 1000 == 0:
145 | print(suffix)
146 |
147 | if self.args.freq > 0 and current_index % self.args.freq == 0:
148 | self.validate(current_index)
149 | self.flush()
150 | self.save_checkpoint()
151 |
152 | self.record('train/loss_L2', losses.avg, current_index)
153 |
154 |
155 | def test(self, ):
156 |
157 | # switch to evaluate mode
158 | self.model.eval()
159 |
160 | ssimes = AverageMeter()
161 | psnres = AverageMeter()
162 |
163 | with torch.no_grad():
164 | for i, batches in enumerate(self.val_loader):
165 |
166 | inputs = batches['image'].to(self.device)
167 | target = batches['target'].to(self.device)
168 |                 mask = batches['mask'].to(self.device)
169 |
170 | outputs = self.model(inputs)
171 |
172 |                 # select the output according to the given arch
173 | if type(outputs) == type(inputs):
174 | output = outputs
175 | elif type(outputs[0]) == type([]):
176 | output = outputs[0][0]
177 | else:
178 | output = outputs[0]
179 |
180 | # recover the image to 255
181 | output = im_to_numpy(torch.clamp(output[0]*255,min=0.0,max=255.0)).astype(np.uint8)
182 | target = im_to_numpy(torch.clamp(target[0]*255,min=0.0,max=255.0)).astype(np.uint8)
183 |
184 | skimage.io.imsave('%s/%s'%(self.args.checkpoint,batches['name'][0]), output)
185 |
186 | psnr = compare_psnr(target,output)
187 | ssim = compare_ssim(target,output,multichannel=True)
188 |
189 | psnres.update(psnr, inputs.size(0))
190 | ssimes.update(ssim, inputs.size(0))
191 |
192 | print("%s:PSNR:%s,SSIM:%s"%(self.args.checkpoint,psnres.avg,ssimes.avg))
193 | print("DONE.\n")
194 |
195 |
196 | def validate(self, epoch):
197 | batch_time = AverageMeter()
198 | data_time = AverageMeter()
199 | losses = AverageMeter()
200 | ssimes = AverageMeter()
201 | psnres = AverageMeter()
202 | # switch to evaluate mode
203 | self.model.eval()
204 |
205 | end = time.time()
206 | with torch.no_grad():
207 | for i, batches in enumerate(self.val_loader):
208 |
209 | inputs = batches['image'].to(self.device)
210 | target = batches['target'].to(self.device)
211 |                 mask = batches['mask'].to(self.device)
212 |
213 | if self.args.hl:
214 | feeded = torch.cat([inputs,torch.zeros((1,4,self.args.input_size,self.args.input_size)).to(self.device)],dim=1)
215 | else:
216 | feeded = inputs
217 |
218 | output = self.model(feeded)
219 |
220 | L2_loss = self.loss(output, target)
221 |
222 | psnr = 10 * log10(1 / L2_loss.item())
223 | ssim = pytorch_ssim.ssim(output, target)
224 |
225 | losses.update(L2_loss.item(), inputs.size(0))
226 | psnres.update(psnr, inputs.size(0))
227 | ssimes.update(ssim.item(), inputs.size(0))
228 |
229 | # measure elapsed time
230 | batch_time.update(time.time() - end)
231 | end = time.time()
232 |
233 |         print("Epoch:%s,Losses:%.3f,PSNR:%.3f,SSIM:%.3f"%(epoch+1, losses.avg,psnres.avg,ssimes.avg))
234 | self.record('val/loss_L2', losses.avg, epoch)
235 | self.record('val/PSNR', psnres.avg, epoch)
236 | self.record('val/SSIM', ssimes.avg, epoch)
237 |
238 | self.metric = psnres.avg
239 |
240 | def resume(self,resume_path):
241 | if isfile(resume_path):
242 | print("=> loading checkpoint '{}'".format(resume_path))
243 | current_checkpoint = torch.load(resume_path)
244 | if isinstance(current_checkpoint['state_dict'], torch.nn.DataParallel):
245 | current_checkpoint['state_dict'] = current_checkpoint['state_dict'].module
246 |
247 | if isinstance(current_checkpoint['optimizer'], torch.nn.DataParallel):
248 | current_checkpoint['optimizer'] = current_checkpoint['optimizer'].module
249 |
250 | self.args.start_epoch = current_checkpoint['epoch']
251 | self.metric = current_checkpoint['best_acc']
252 | self.model.load_state_dict(current_checkpoint['state_dict'])
253 | # self.optimizer.load_state_dict(current_checkpoint['optimizer'])
254 | print("=> loaded checkpoint '{}' (epoch {})"
255 | .format(resume_path, current_checkpoint['epoch']))
256 | else:
257 | raise Exception("=> no checkpoint found at '{}'".format(resume_path))
258 |
259 | def save_checkpoint(self,filename='checkpoint.pth.tar', snapshot=None):
260 |         is_best = self.best_acc < self.metric
261 |
262 | if is_best:
263 | self.best_acc = self.metric
264 |
265 | state = {
266 | 'epoch': self.current_epoch + 1,
267 | 'arch': self.args.arch,
268 | 'state_dict': self.model.state_dict(),
269 | 'best_acc': self.best_acc,
270 | 'optimizer' : self.optimizer.state_dict() if self.optimizer else None,
271 | }
272 |
273 | filepath = os.path.join(self.args.checkpoint, filename)
274 | torch.save(state, filepath)
275 |
276 | if snapshot and state['epoch'] % snapshot == 0:
277 |             shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch'])))
278 |
279 | if is_best:
280 | self.best_acc = self.metric
281 | print('Saving Best Metric with PSNR:%s'%self.best_acc)
282 | shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'model_best.pth.tar'))
283 |
284 | def clean(self):
285 | self.writer.close()
286 |
287 | def record(self,k,v,epoch):
288 | self.writer.add_scalar(k, v, epoch)
289 |
290 | def flush(self):
291 | self.writer.flush()
292 | sys.stdout.flush()
293 |
294 | def norm(self,x):
295 | if self.args.gan_norm:
296 | return x*2.0 - 1.0
297 | else:
298 | return x
299 |
300 | def denorm(self,x):
301 | if self.args.gan_norm:
302 | return (x+1.0)/2.0
303 | else:
304 | return x
305 |
306 |
--------------------------------------------------------------------------------
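For orientation, a hedged sketch of how a Machine is typically driven; the real entry point is main.py (not reproduced here), and `args`, `train_loader` and `val_loader` are assumed to come from options.py and scripts.datasets respectively:

import scripts.machines as machines

machine = machines.basic(datasets=(train_loader, val_loader), models=None, args=args)
for epoch in range(args.start_epoch, args.epochs):   # args.epochs is assumed from options.py
    machine.train(epoch)
    machine.validate(epoch)
    machine.flush()
    machine.save_checkpoint()
machine.clean()

--------------------------------------------------------------------------------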
/scripts/machines/VX.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from progress.bar import Bar
4 | from tqdm import tqdm
5 | import pytorch_ssim
6 | import json
7 | import sys,time,os
8 | import torchvision
9 | from math import log10
10 | import numpy as np
11 | from .BasicMachine import BasicMachine
12 | from scripts.utils.evaluation import accuracy, AverageMeter, final_preds
13 | from scripts.utils.misc import resize_to_match
14 | from torch.autograd import Variable
15 | import torch.nn.functional as F
16 | from scripts.utils.parallel import DataParallelModel, DataParallelCriterion
17 | from scripts.utils.losses import VGGLoss, l1_relative,is_dic
18 | from scripts.utils.imutils import im_to_numpy
19 | import skimage.io
20 | from skimage.measure import compare_psnr,compare_ssim
21 |
22 |
23 | class Losses(nn.Module):
24 | def __init__(self, argx, device, norm_func=None, denorm_func=None):
25 | super(Losses, self).__init__()
26 | self.args = argx
27 |
28 | if self.args.loss_type == 'l1bl2':
29 | self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), nn.BCELoss(), nn.MSELoss()
30 | elif self.args.loss_type == 'l2xbl2':
31 | self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCEWithLogitsLoss(), nn.MSELoss()
32 | elif self.args.loss_type == 'relative' or self.args.loss_type == 'hybrid':
33 | self.outputLoss, self.attLoss, self.wrloss = l1_relative, nn.BCELoss(), l1_relative
34 | else: # l2bl2
35 | self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCELoss(), nn.MSELoss()
36 |
37 | self.default = nn.L1Loss()
38 |
39 | if self.args.style_loss > 0:
40 | self.vggloss = VGGLoss(self.args.sltype).to(device)
41 |
42 | if self.args.ssim_loss > 0:
43 | self.ssimloss = pytorch_ssim.SSIM().to(device)
44 |
45 | self.norm = norm_func
46 | self.denorm = denorm_func
47 |
48 |
49 | def forward(self,pred_ims,target,pred_ms,mask,pred_wms,wm):
50 | pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss = [0]*5
51 | pred_ims = pred_ims if is_dic(pred_ims) else [pred_ims]
52 |
53 |         # compute the loss in the masked region when args.masked is set
54 | if self.args.masked and 'hybrid' in self.args.loss_type: # masked loss
55 | pixel_loss += sum([self.outputLoss(pred_im, target, mask) for pred_im in pred_ims])
56 | pixel_loss += sum([self.default(pred_im*pred_ms,target*mask) for pred_im in pred_ims])
57 | recov_imgs = [ self.denorm(pred_im*mask + (1-mask)*self.norm(target)) for pred_im in pred_ims ]
58 | wm_loss += self.wrloss(pred_wms, wm, mask)
59 | wm_loss += self.default(pred_wms*pred_ms, wm*mask)
60 |
61 | elif self.args.masked and 'relative' in self.args.loss_type: # masked loss
62 | pixel_loss += sum([self.outputLoss(pred_im, target, mask) for pred_im in pred_ims])
63 | recov_imgs = [ self.denorm(pred_im*mask + (1-mask)*self.norm(target)) for pred_im in pred_ims ]
64 | wm_loss = self.wrloss(pred_wms, wm, mask)
65 | elif self.args.masked:
66 | pixel_loss += sum([self.outputLoss(pred_im*mask, target*mask) for pred_im in pred_ims])
67 | recov_imgs = [ self.denorm(pred_im*pred_ms + (1-pred_ms)*self.norm(target)) for pred_im in pred_ims ]
68 | wm_loss = self.wrloss(pred_wms*mask, wm*mask)
69 | else:
70 | pixel_loss += sum([self.outputLoss(pred_im*pred_ms, target*mask) for pred_im in pred_ims])
71 | recov_imgs = [ self.denorm(pred_im*pred_ms + (1-pred_ms)*self.norm(target)) for pred_im in pred_ims ]
72 | wm_loss = self.wrloss(pred_wms*pred_ms,wm*mask)
73 |
74 | pixel_loss += sum([self.default(im,target) for im in recov_imgs])
75 |
76 | if self.args.style_loss > 0:
77 | vgg_loss = sum([self.vggloss(im,target,mask) for im in recov_imgs])
78 |
79 | if self.args.ssim_loss > 0:
80 | ssim_loss = sum([ 1 - self.ssimloss(im,target) for im in recov_imgs])
81 |
82 | att_loss = self.attLoss(pred_ms, mask)
83 |
84 | return pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss
85 |
86 |
87 | class VX(BasicMachine):
88 | def __init__(self,**kwargs):
89 | BasicMachine.__init__(self,**kwargs)
90 | self.loss = Losses(self.args, self.device, self.norm, self.denorm)
91 | self.model.set_optimizers()
92 | self.optimizer = None
93 |
94 | def train(self,epoch):
95 |
96 | self.current_epoch = epoch
97 |
98 | batch_time = AverageMeter()
99 | data_time = AverageMeter()
100 | losses = AverageMeter()
101 | lossMask = AverageMeter()
102 | lossWM = AverageMeter()
103 | lossMX = AverageMeter()
104 | lossvgg = AverageMeter()
105 | lossssim = AverageMeter()
106 |
107 | # switch to train mode
108 | self.model.train()
109 |
110 | end = time.time()
111 | bar = Bar('Processing {} '.format(self.args.arch), max=len(self.train_loader))
112 |
113 | for i, batches in enumerate(self.train_loader):
114 |
115 | current_index = len(self.train_loader) * epoch + i
116 |
117 | inputs = batches['image'].to(self.device)
118 | target = batches['target'].to(self.device)
119 | mask = batches['mask'].to(self.device)
120 | wm = batches['wm'].to(self.device)
121 |
122 | outputs = self.model(self.norm(inputs))
123 |
124 | self.model.zero_grad_all()
125 |
126 | l2_loss,att_loss,wm_loss,style_loss,ssim_loss = self.loss(outputs[0],self.norm(target),outputs[1],mask,outputs[2],self.norm(wm))
127 | total_loss = 2*l2_loss + self.args.att_loss * att_loss + wm_loss + self.args.style_loss * style_loss + self.args.ssim_loss * ssim_loss
128 |
129 | # compute gradient and do SGD step
130 | total_loss.backward()
131 | self.model.step_all()
132 |
133 | # measure accuracy and record loss
134 | losses.update(l2_loss.item(), inputs.size(0))
135 | lossMask.update(att_loss.item(), inputs.size(0))
136 | lossWM.update(wm_loss.item(), inputs.size(0))
137 |
138 | if self.args.style_loss > 0 :
139 | lossvgg.update(style_loss.item(), inputs.size(0))
140 |
141 | if self.args.ssim_loss > 0 :
142 | lossssim.update(ssim_loss.item(), inputs.size(0))
143 |
144 |
145 | # measure elapsed time
146 | batch_time.update(time.time() - end)
147 | end = time.time()
148 |
149 | # plot progress
150 | suffix = "({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss Mask: {loss_mask:.4f} | loss WM: {loss_wm:.4f} | loss VGG: {loss_vgg:.4f} | loss SSIM: {loss_ssim:.4f}| loss MX: {loss_mx:.4f}".format(
151 | batch=i + 1,
152 | size=len(self.train_loader),
153 | data=data_time.val,
154 | bt=batch_time.val,
155 | total=bar.elapsed_td,
156 | eta=bar.eta_td,
157 | loss_label=losses.avg,
158 | loss_mask=lossMask.avg,
159 | loss_wm=lossWM.avg,
160 | loss_vgg=lossvgg.avg,
161 | loss_ssim=lossssim.avg,
162 | loss_mx=lossMX.avg
163 | )
164 | if current_index % 1000 == 0:
165 | print(suffix)
166 |
167 | if self.args.freq > 0 and current_index % self.args.freq == 0:
168 | self.validate(current_index)
169 | self.flush()
170 | self.save_checkpoint()
171 |
172 | self.record('train/loss_L2', losses.avg, epoch)
173 | self.record('train/loss_Mask', lossMask.avg, epoch)
174 | self.record('train/loss_WM', lossWM.avg, epoch)
175 | self.record('train/loss_VGG', lossvgg.avg, epoch)
176 | self.record('train/loss_SSIM', lossssim.avg, epoch)
177 | self.record('train/loss_MX', lossMX.avg, epoch)
178 |
179 |
180 |
181 |
182 | def validate(self, epoch):
183 |
184 | self.current_epoch = epoch
185 |
186 | batch_time = AverageMeter()
187 | data_time = AverageMeter()
188 | losses = AverageMeter()
189 | lossMask = AverageMeter()
190 | psnres = AverageMeter()
191 | ssimes = AverageMeter()
192 |
193 | # switch to evaluate mode
194 | self.model.eval()
195 |
196 | end = time.time()
197 | bar = Bar('Processing {} '.format(self.args.arch), max=len(self.val_loader))
198 | with torch.no_grad():
199 | for i, batches in enumerate(self.val_loader):
200 |
201 | current_index = len(self.val_loader) * epoch + i
202 |
203 | inputs = batches['image'].to(self.device)
204 | target = batches['target'].to(self.device)
205 |
206 | outputs = self.model(self.norm(inputs))
207 | imoutput,immask,imwatermark = outputs
208 | imoutput = imoutput[0] if is_dic(imoutput) else imoutput
209 |
210 | imfinal = self.denorm(imoutput*immask + self.norm(inputs)*(1-immask))
211 |
212 | if i % 300 == 0:
213 | # save the sample images
214 | ims = torch.cat([inputs,target,imfinal,immask.repeat(1,3,1,1)],dim=3)
215 | torchvision.utils.save_image(ims,os.path.join(self.args.checkpoint,'%s_%s.jpg'%(i,epoch)))
216 |
217 |                 # two choices here: MSE loss or NLL loss
218 | psnr = 10 * log10(1 / F.mse_loss(imfinal,target).item())
219 |
220 | ssim = pytorch_ssim.ssim(imfinal,target)
221 |
222 | psnres.update(psnr, inputs.size(0))
223 | ssimes.update(ssim, inputs.size(0))
224 |
225 | # measure elapsed time
226 | batch_time.update(time.time() - end)
227 | end = time.time()
228 |
229 | # plot progress
230 | bar.suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_L2: {loss_label:.4f} | Loss_Mask: {loss_mask:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}'.format(
231 | batch=i + 1,
232 | size=len(self.val_loader),
233 | data=data_time.val,
234 | bt=batch_time.val,
235 | total=bar.elapsed_td,
236 | eta=bar.eta_td,
237 | loss_label=losses.avg,
238 | loss_mask=lossMask.avg,
239 | psnr=psnres.avg,
240 | ssim=ssimes.avg
241 | )
242 | bar.next()
243 | bar.finish()
244 |
245 | print("Iter:%s,Losses:%s,PSNR:%.4f,SSIM:%.4f"%(epoch, losses.avg,psnres.avg,ssimes.avg))
246 | self.record('val/loss_L2', losses.avg, epoch)
247 | self.record('val/lossMask', lossMask.avg, epoch)
248 | self.record('val/PSNR', psnres.avg, epoch)
249 | self.record('val/SSIM', ssimes.avg, epoch)
250 | self.metric = psnres.avg
251 |
252 | self.model.train()
253 |
254 | def test(self, ):
255 |
256 | # switch to evaluate mode
257 | self.model.eval()
258 | print("==> testing VM model ")
259 | ssimes = AverageMeter()
260 | psnres = AverageMeter()
261 | ssimesx = AverageMeter()
262 | psnresx = AverageMeter()
263 |
264 | with torch.no_grad():
265 | for i, batches in enumerate(tqdm(self.val_loader)):
266 |
267 | inputs = batches['image'].to(self.device)
268 | target = batches['target'].to(self.device)
269 |                 mask = batches['mask'].to(self.device)
270 |
271 |                 # select the output according to the given arch
272 | outputs = self.model(self.norm(inputs))
273 | imoutput,immask,imwatermark = outputs
274 | imoutput = imoutput[0] if is_dic(imoutput) else imoutput
275 |
276 | imfinal = self.denorm(imoutput*immask + self.norm(inputs)*(1-immask))
277 | psnrx = 10 * log10(1 / F.mse_loss(imfinal,target).item())
278 | ssimx = pytorch_ssim.ssim(imfinal,target)
279 | # recover the image to 255
280 | imfinal = im_to_numpy(torch.clamp(imfinal[0]*255,min=0.0,max=255.0)).astype(np.uint8)
281 | target = im_to_numpy(torch.clamp(target[0]*255,min=0.0,max=255.0)).astype(np.uint8)
282 |
283 | skimage.io.imsave('%s/%s'%(self.args.checkpoint,batches['name'][0]), imfinal)
284 |
285 | psnr = compare_psnr(target,imfinal)
286 | ssim = compare_ssim(target,imfinal,multichannel=True)
287 |
288 | psnres.update(psnr, inputs.size(0))
289 | ssimes.update(ssim, inputs.size(0))
290 | psnresx.update(psnrx, inputs.size(0))
291 | ssimesx.update(ssimx, inputs.size(0))
292 |
293 | print("%s:PSNR:%.5f(%.5f),SSIM:%.5f(%.5f)"%(self.args.checkpoint,psnres.avg,psnresx.avg,ssimes.avg,ssimesx.avg))
294 | print("DONE.\n")
--------------------------------------------------------------------------------
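The compositing step used in VX.validate/test keeps the network prediction only inside the predicted mask and passes the re-normalised input through everywhere else. A small self-contained illustration with dummy tensors (shapes are arbitrary; the norm/denorm pair mirrors BasicMachine.norm/denorm when args.gan_norm is enabled):

import torch

inputs   = torch.rand(1, 3, 8, 8)    # watermarked image in [0, 1]
imoutput = torch.rand(1, 3, 8, 8)    # network prediction in normalised space
immask   = torch.zeros(1, 1, 8, 8)
immask[..., 2:6, 2:6] = 1.0          # predicted watermark region

norm   = lambda x: x * 2.0 - 1.0
denorm = lambda x: (x + 1.0) / 2.0

imfinal = denorm(imoutput * immask + norm(inputs) * (1 - immask))
# outside the mask the input survives the norm/denorm round trip unchanged
assert torch.allclose(imfinal[..., 0, 0], inputs[..., 0, 0])

--------------------------------------------------------------------------------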
/scripts/models/vmu.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from scripts.models.blocks import SEBlock
6 | from scripts.models.rasc import *
7 | from scripts.models.unet import UnetGenerator,MinimalUnetV2
8 |
9 | def weight_init(m):
10 | if isinstance(m, nn.Conv2d):
11 | nn.init.xavier_normal_(m.weight)
12 | nn.init.constant_(m.bias, 0)
13 |
14 | def reset_params(model):
15 | for i, m in enumerate(model.modules()):
16 | weight_init(m)
17 |
18 |
19 | def conv3x3(in_channels, out_channels, stride=1,
20 | padding=1, bias=True, groups=1):
21 | return nn.Conv2d(
22 | in_channels,
23 | out_channels,
24 | kernel_size=3,
25 | stride=stride,
26 | padding=padding,
27 | bias=bias,
28 | groups=groups)
29 |
30 |
31 | def up_conv2x2(in_channels, out_channels, transpose=True):
32 | if transpose:
33 | return nn.ConvTranspose2d(
34 | in_channels,
35 | out_channels,
36 | kernel_size=2,
37 | stride=2)
38 | else:
39 | return nn.Sequential(
40 | nn.Upsample(mode='bilinear', scale_factor=2),
41 | conv1x1(in_channels, out_channels))
42 |
43 |
44 | def conv1x1(in_channels, out_channels, groups=1):
45 | return nn.Conv2d(
46 | in_channels,
47 | out_channels,
48 | kernel_size=1,
49 | groups=groups,
50 | stride=1)
51 |
52 |
53 |
54 |
55 | class UpCoXvD(nn.Module):
56 |
57 | def __init__(self, in_channels, out_channels, blocks, residual=True, batch_norm=True, transpose=True,concat=True,use_att=False):
58 | super(UpCoXvD, self).__init__()
59 | self.concat = concat
60 | self.residual = residual
61 | self.batch_norm = batch_norm
62 | self.bn = None
63 | self.conv2 = []
64 | self.use_att = use_att
65 | self.up_conv = up_conv2x2(in_channels, out_channels, transpose=transpose)
66 |
67 | if self.use_att:
68 | self.s2am = RASC(2 * out_channels)
69 | else:
70 | self.s2am = None
71 |
72 | if self.concat:
73 | self.conv1 = conv3x3(2 * out_channels, out_channels)
74 | else:
75 | self.conv1 = conv3x3(out_channels, out_channels)
76 | for _ in range(blocks):
77 | self.conv2.append(conv3x3(out_channels, out_channels))
78 | if self.batch_norm:
79 | self.bn = []
80 | for _ in range(blocks):
81 | self.bn.append(nn.BatchNorm2d(out_channels))
82 | self.bn = nn.ModuleList(self.bn)
83 | self.conv2 = nn.ModuleList(self.conv2)
84 |
85 | def forward(self, from_up, from_down, mask=None):
86 | from_up = self.up_conv(from_up)
87 | if self.concat:
88 | x1 = torch.cat((from_up, from_down), 1)
89 | else:
90 | if from_down is not None:
91 | x1 = from_up + from_down
92 | else:
93 | x1 = from_up
94 |
95 | if self.use_att:
96 | x1 = self.s2am(x1,mask)
97 |
98 | x1 = F.relu(self.conv1(x1))
99 | x2 = None
100 | for idx, conv in enumerate(self.conv2):
101 | x2 = conv(x1)
102 | if self.batch_norm:
103 | x2 = self.bn[idx](x2)
104 | if self.residual:
105 | x2 = x2 + x1
106 | x2 = F.relu(x2)
107 | x1 = x2
108 | return x2
109 |
110 |
111 | class DownCoXvD(nn.Module):
112 |
113 | def __init__(self, in_channels, out_channels, blocks, pooling=True, residual=True, batch_norm=True):
114 | super(DownCoXvD, self).__init__()
115 | self.pooling = pooling
116 | self.residual = residual
117 | self.batch_norm = batch_norm
118 | self.bn = None
119 | self.pool = None
120 | self.conv1 = conv3x3(in_channels, out_channels)
121 | self.conv2 = []
122 | for _ in range(blocks):
123 | self.conv2.append(conv3x3(out_channels, out_channels))
124 | if self.batch_norm:
125 | self.bn = []
126 | for _ in range(blocks):
127 | self.bn.append(nn.BatchNorm2d(out_channels))
128 | self.bn = nn.ModuleList(self.bn)
129 | if self.pooling:
130 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
131 | self.conv2 = nn.ModuleList(self.conv2)
132 |
133 | def __call__(self, x):
134 | return self.forward(x)
135 |
136 | def forward(self, x):
137 | x1 = F.relu(self.conv1(x))
138 | x2 = None
139 | for idx, conv in enumerate(self.conv2):
140 | x2 = conv(x1)
141 | if self.batch_norm:
142 | x2 = self.bn[idx](x2)
143 | if self.residual:
144 | x2 = x2 + x1
145 | x2 = F.relu(x2)
146 | x1 = x2
147 | before_pool = x2
148 | if self.pooling:
149 | x2 = self.pool(x2)
150 | return x2, before_pool
151 |
152 | class UnetDecoderD(nn.Module):
153 | def __init__(self, in_channels=512, out_channels=3, depth=5, blocks=1, residual=True, batch_norm=True,
154 | transpose=True, concat=True, is_final=True):
155 | super(UnetDecoderD, self).__init__()
156 | self.conv_final = None
157 | self.up_convs = []
158 | outs = in_channels
159 | for i in range(depth-1):
160 | ins = outs
161 | outs = ins // 2
162 | # 512,256
163 | # 256,128
164 | # 128,64
165 | # 64,32
166 | up_conv = UpCoXvD(ins, outs, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
167 | concat=concat)
168 | self.up_convs.append(up_conv)
169 | if is_final:
170 | self.conv_final = conv1x1(outs, out_channels)
171 | else:
172 | up_conv = UpCoXvD(outs, out_channels, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
173 | concat=concat)
174 | self.up_convs.append(up_conv)
175 | self.up_convs = nn.ModuleList(self.up_convs)
176 | reset_params(self)
177 |
178 | def __call__(self, x, encoder_outs=None):
179 | return self.forward(x, encoder_outs)
180 |
181 | def forward(self, x, encoder_outs=None):
182 | for i, up_conv in enumerate(self.up_convs):
183 | before_pool = None
184 | if encoder_outs is not None:
185 | before_pool = encoder_outs[-(i+2)]
186 | x = up_conv(x, before_pool)
187 | if self.conv_final is not None:
188 | x = self.conv_final(x)
189 | return x
190 |
191 |
192 | class UnetEncoderD(nn.Module):
193 |
194 | def __init__(self, in_channels=3, depth=5, blocks=1, start_filters=32, residual=True, batch_norm=True):
195 | super(UnetEncoderD, self).__init__()
196 | self.down_convs = []
197 | outs = None
198 | if type(blocks) is tuple:
199 | blocks = blocks[0]
200 | for i in range(depth):
201 | ins = in_channels if i == 0 else outs
202 | outs = start_filters*(2**i)
203 | pooling = True if i < depth-1 else False
204 | down_conv = DownCoXvD(ins, outs, blocks, pooling=pooling, residual=residual, batch_norm=batch_norm)
205 | self.down_convs.append(down_conv)
206 | self.down_convs = nn.ModuleList(self.down_convs)
207 | reset_params(self)
208 |
209 | def __call__(self, x):
210 | return self.forward(x)
211 |
212 | def forward(self, x):
213 | encoder_outs = []
214 | for d_conv in self.down_convs:
215 | x, before_pool = d_conv(x)
216 | encoder_outs.append(before_pool)
217 | return x, encoder_outs
218 |
219 |
220 |
221 | class UnetVM(nn.Module):
222 |
223 | def __init__(self, in_channels=3, depth=5, shared_depth=0, use_vm_decoder=False, blocks=1,
224 | out_channels_image=3, out_channels_mask=1, start_filters=32, residual=True, batch_norm=True,
225 | transpose=True, concat=True, transfer_data=True, long_skip=False):
226 | super(UnetVM, self).__init__()
227 | self.transfer_data = transfer_data
228 | self.shared = shared_depth
229 | self.optimizer_encoder, self.optimizer_image, self.optimizer_vm = None, None, None
230 | self.optimizer_mask, self.optimizer_shared = None, None
231 | if type(blocks) is not tuple:
232 | blocks = (blocks, blocks, blocks, blocks, blocks)
233 | if not transfer_data:
234 | concat = False
235 | self.encoder = UnetEncoderD(in_channels=in_channels, depth=depth, blocks=blocks[0],
236 | start_filters=start_filters, residual=residual, batch_norm=batch_norm)
237 | self.image_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
238 | out_channels=out_channels_image, depth=depth - shared_depth,
239 | blocks=blocks[1], residual=residual, batch_norm=batch_norm,
240 | transpose=transpose, concat=concat)
241 | self.mask_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - 1),
242 | out_channels=out_channels_mask, depth=depth,
243 | blocks=blocks[2], residual=residual, batch_norm=batch_norm,
244 | transpose=transpose, concat=concat)
245 | self.vm_decoder = None
246 | if use_vm_decoder:
247 | self.vm_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
248 | out_channels=out_channels_image, depth=depth - shared_depth,
249 | blocks=blocks[3], residual=residual, batch_norm=batch_norm,
250 | transpose=transpose, concat=concat)
251 | self.shared_decoder = None
252 | self.long_skip = long_skip
253 | self._forward = self.unshared_forward
254 | if self.shared != 0:
255 | self._forward = self.shared_forward
256 | self.shared_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - 1),
257 | out_channels=start_filters * 2 ** (depth - shared_depth - 1),
258 | depth=shared_depth, blocks=blocks[4], residual=residual,
259 | batch_norm=batch_norm, transpose=transpose, concat=concat,
260 | is_final=False)
261 |
262 | def set_optimizers(self):
263 | self.optimizer_encoder = torch.optim.Adam(self.encoder.parameters(), lr=0.001)
264 | self.optimizer_image = torch.optim.Adam(self.image_decoder.parameters(), lr=0.001)
265 | self.optimizer_mask = torch.optim.Adam(self.mask_decoder.parameters(), lr=0.001)
266 | if self.vm_decoder is not None:
267 | self.optimizer_vm = torch.optim.Adam(self.vm_decoder.parameters(), lr=0.001)
268 | if self.shared != 0:
269 | self.optimizer_shared = torch.optim.Adam(self.shared_decoder.parameters(), lr=0.001)
270 |
271 | def zero_grad_all(self):
272 | self.optimizer_encoder.zero_grad()
273 | self.optimizer_image.zero_grad()
274 | self.optimizer_mask.zero_grad()
275 | if self.vm_decoder is not None:
276 | self.optimizer_vm.zero_grad()
277 | if self.shared != 0:
278 | self.optimizer_shared.zero_grad()
279 |
280 | def step_all(self):
281 | self.optimizer_encoder.step()
282 | self.optimizer_image.step()
283 | self.optimizer_mask.step()
284 | if self.vm_decoder is not None:
285 | self.optimizer_vm.step()
286 | if self.shared != 0:
287 | self.optimizer_shared.step()
288 |
289 | def step_optimizer_image(self):
290 | self.optimizer_image.step()
291 |
292 | def __call__(self, synthesized):
293 | return self._forward(synthesized)
294 |
295 | def forward(self, synthesized):
296 | return self._forward(synthesized)
297 |
298 | def unshared_forward(self, synthesized):
299 | image_code, before_pool = self.encoder(synthesized)
300 | if not self.transfer_data:
301 | before_pool = None
302 | reconstructed_image = torch.tanh(self.image_decoder(image_code, before_pool))
303 | reconstructed_mask = torch.sigmoid(self.mask_decoder(image_code, before_pool))
304 | if self.vm_decoder is not None:
305 | reconstructed_vm = torch.tanh(self.vm_decoder(image_code, before_pool))
306 | return reconstructed_image, reconstructed_mask, reconstructed_vm
307 | return reconstructed_image, reconstructed_mask
308 |
309 | def shared_forward(self, synthesized):
310 | image_code, before_pool = self.encoder(synthesized)
311 | if self.transfer_data:
312 | shared_before_pool = before_pool[- self.shared - 1:]
313 | unshared_before_pool = before_pool[: - self.shared]
314 | else:
315 | before_pool = None
316 | shared_before_pool = None
317 | unshared_before_pool = None
318 | x = self.shared_decoder(image_code, shared_before_pool)
319 | reconstructed_image = torch.tanh(self.image_decoder(x, unshared_before_pool))
320 | if self.long_skip:
321 | reconstructed_image = reconstructed_image + synthesized
322 |
323 | reconstructed_mask = torch.sigmoid(self.mask_decoder(image_code, before_pool))
324 | if self.vm_decoder is not None:
325 | reconstructed_vm = torch.tanh(self.vm_decoder(x, unshared_before_pool))
326 | if self.long_skip:
327 | reconstructed_vm = reconstructed_vm + synthesized
328 | return reconstructed_image, reconstructed_mask, reconstructed_vm
329 | return reconstructed_image, reconstructed_mask
330 |
--------------------------------------------------------------------------------
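A minimal forward-pass sketch for UnetVM above; the configuration is an assumption for illustration, not the one used by main.py. With shared_depth > 0 and use_vm_decoder=True, the shared decoder feeds both the image and watermark branches, and the network returns the reconstructed image, the predicted mask and the watermark estimate:

import torch
from scripts.models.vmu import UnetVM

net = UnetVM(in_channels=3, depth=5, shared_depth=2, use_vm_decoder=True, blocks=1)
net.set_optimizers()                     # one Adam optimizer per sub-network

x = torch.rand(2, 3, 256, 256)           # a dummy watermarked batch
im, mask, wm = net(x)
print(im.shape, mask.shape, wm.shape)    # (2, 3, 256, 256), (2, 1, 256, 256), (2, 3, 256, 256)

--------------------------------------------------------------------------------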
/scripts/models/sa_resunet.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from scripts.models.blocks import SEBlock
6 | from scripts.models.rasc import *
7 | from scripts.models.unet import UnetGenerator,MinimalUnetV2
8 |
9 | def weight_init(m):
10 | if isinstance(m, nn.Conv2d):
11 | nn.init.xavier_normal_(m.weight)
12 | nn.init.constant_(m.bias, 0)
13 |
14 | def reset_params(model):
15 | for i, m in enumerate(model.modules()):
16 | weight_init(m)
17 |
18 |
19 | def conv3x3(in_channels, out_channels, stride=1,
20 | padding=1, bias=True, groups=1):
21 | return nn.Conv2d(
22 | in_channels,
23 | out_channels,
24 | kernel_size=3,
25 | stride=stride,
26 | padding=padding,
27 | bias=bias,
28 | groups=groups)
29 |
30 |
31 | def up_conv2x2(in_channels, out_channels, transpose=True):
32 | if transpose:
33 | return nn.ConvTranspose2d(
34 | in_channels,
35 | out_channels,
36 | kernel_size=2,
37 | stride=2)
38 | else:
39 | return nn.Sequential(
40 | nn.Upsample(mode='bilinear', scale_factor=2),
41 | conv1x1(in_channels, out_channels))
42 |
43 |
44 | def conv1x1(in_channels, out_channels, groups=1):
45 | return nn.Conv2d(
46 | in_channels,
47 | out_channels,
48 | kernel_size=1,
49 | groups=groups,
50 | stride=1)
51 |
52 |
53 | class UpCoXvD(nn.Module):
54 |
55 | def __init__(self, in_channels, out_channels, blocks, residual=True,norm=nn.BatchNorm2d, act=F.relu,batch_norm=True, transpose=True,concat=True,use_att=False):
56 | super(UpCoXvD, self).__init__()
57 | self.concat = concat
58 | self.residual = residual
59 | self.batch_norm = batch_norm
60 | self.bn = None
61 | self.conv2 = []
62 | self.use_att = use_att
63 | self.up_conv = up_conv2x2(in_channels, out_channels, transpose=transpose)
64 | self.norm0 = norm(out_channels)
65 |
66 | if self.use_att:
67 | self.s2am = RASC(2 * out_channels)
68 | else:
69 | self.s2am = None
70 |
71 | if self.concat:
72 | self.conv1 = conv3x3(2 * out_channels, out_channels)
73 |             self.norm1 = norm(out_channels)
74 |         else:
75 |             self.conv1 = conv3x3(out_channels, out_channels)
76 |             self.norm1 = norm(out_channels)
77 |
78 | for _ in range(blocks):
79 | self.conv2.append(conv3x3(out_channels, out_channels))
80 | if self.batch_norm:
81 | self.bn = []
82 | for _ in range(blocks):
83 | self.bn.append(norm(out_channels))
84 | self.bn = nn.ModuleList(self.bn)
85 | self.conv2 = nn.ModuleList(self.conv2)
86 | self.act = act
87 |
88 | def forward(self, from_up, from_down, mask=None,se=None):
89 | from_up = self.act(self.norm0(self.up_conv(from_up)))
90 | if self.concat:
91 | x1 = torch.cat((from_up, from_down), 1)
92 | else:
93 | if from_down is not None:
94 | x1 = from_up + from_down
95 | else:
96 | x1 = from_up
97 |
98 | if self.use_att:
99 | x1 = self.s2am(x1,mask)
100 |
101 | x1 = self.act(self.norm1(self.conv1(x1)))
102 | x2 = None
103 | for idx, conv in enumerate(self.conv2):
104 | x2 = conv(x1)
105 | if self.batch_norm:
106 | x2 = self.bn[idx](x2)
107 |
108 | if (se is not None) and (idx == len(self.conv2) - 1): # last
109 | x2 = se(x2)
110 |
111 | if self.residual:
112 | x2 = x2 + x1
113 | x2 = self.act(x2)
114 | x1 = x2
115 | return x2
116 |
117 |
118 | class DownCoXvD(nn.Module):
119 |
120 | def __init__(self, in_channels, out_channels, blocks, pooling=True, norm=nn.BatchNorm2d,act=F.relu,residual=True, batch_norm=True):
121 | super(DownCoXvD, self).__init__()
122 | self.pooling = pooling
123 | self.residual = residual
124 | self.batch_norm = batch_norm
125 | self.bn = None
126 | self.pool = None
127 | self.conv1 = conv3x3(in_channels, out_channels)
128 | self.norm1 = norm(out_channels)
129 |
130 | self.conv2 = []
131 | for _ in range(blocks):
132 | self.conv2.append(conv3x3(out_channels, out_channels))
133 | if self.batch_norm:
134 | self.bn = []
135 | for _ in range(blocks):
136 | self.bn.append(norm(out_channels))
137 | self.bn = nn.ModuleList(self.bn)
138 | if self.pooling:
139 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
140 | self.conv2 = nn.ModuleList(self.conv2)
141 | self.act = act
142 |
143 | def __call__(self, x):
144 | return self.forward(x)
145 |
146 | def forward(self, x):
147 | x1 = self.act(self.norm1(self.conv1(x)))
148 | x2 = None
149 | for idx, conv in enumerate(self.conv2):
150 | x2 = conv(x1)
151 | if self.batch_norm:
152 | x2 = self.bn[idx](x2)
153 | if self.residual:
154 | x2 = x2 + x1
155 | x2 = self.act(x2)
156 | x1 = x2
157 | before_pool = x2
158 | if self.pooling:
159 | x2 = self.pool(x2)
160 | return x2, before_pool
161 |
162 | class UnetDecoderD(nn.Module):
163 | def __init__(self, in_channels=512, out_channels=3, norm=nn.BatchNorm2d,act=F.relu, depth=5, blocks=1, residual=True, batch_norm=True,
164 | transpose=True, concat=True, is_final=True, use_att=False):
165 | super(UnetDecoderD, self).__init__()
166 | self.conv_final = None
167 | self.up_convs = []
168 | self.atts = []
169 | self.use_att = use_att
170 |
171 | outs = in_channels
172 | for i in range(depth-1): # depth = 1
173 | ins = outs
174 | outs = ins // 2
175 | # 512,256
176 | # 256,128
177 | # 128,64
178 | # 64,32
179 | up_conv = UpCoXvD(ins, outs, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
180 | concat=concat, norm=norm, act=act)
181 | if self.use_att:
182 | self.atts.append(SEBlock(outs))
183 |
184 | self.up_convs.append(up_conv)
185 |
186 | if is_final:
187 | self.conv_final = conv1x1(outs, out_channels)
188 | else:
189 | up_conv = UpCoXvD(outs, out_channels, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
190 | concat=concat,norm=norm, act=act)
191 | if self.use_att:
192 | self.atts.append(SEBlock(out_channels))
193 |
194 | self.up_convs.append(up_conv)
195 | self.up_convs = nn.ModuleList(self.up_convs)
196 | self.atts = nn.ModuleList(self.atts)
197 |
198 | reset_params(self)
199 |
200 | def __call__(self, x, encoder_outs=None):
201 | return self.forward(x, encoder_outs)
202 |
203 | def forward(self, x, encoder_outs=None):
204 | for i, up_conv in enumerate(self.up_convs):
205 | before_pool = None
206 | if encoder_outs is not None:
207 | before_pool = encoder_outs[-(i+2)]
208 | x = up_conv(x, before_pool)
209 | if self.use_att:
210 | x = self.atts[i](x)
211 |
212 | if self.conv_final is not None:
213 | x = self.conv_final(x)
214 | return x
215 |
216 |
217 | class UnetDecoderDatt(nn.Module):
218 | def __init__(self, in_channels=512, out_channels=3, depth=5, blocks=1, residual=True, batch_norm=True,
219 | transpose=True, concat=True, is_final=True, norm=nn.BatchNorm2d,act=F.relu):
220 | super(UnetDecoderDatt, self).__init__()
221 | self.conv_final = None
222 | self.up_convs = []
223 | self.im_atts = []
224 | self.vm_atts = []
225 | self.mask_atts = []
226 |
227 | outs = in_channels
228 | for i in range(depth-1): # depth = 5 [0,1,2,3]
229 | ins = outs
230 | outs = ins // 2
231 | # 512,256
232 | # 256,128
233 | # 128,64
234 | # 64,32
235 | up_conv = UpCoXvD(ins, outs, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
236 |                               concat=concat, norm=norm, act=act)
237 | self.up_convs.append(up_conv)
238 | self.im_atts.append(SEBlock(outs))
239 | self.vm_atts.append(SEBlock(outs))
240 | self.mask_atts.append(SEBlock(outs))
241 | if is_final:
242 | self.conv_final = conv1x1(outs, out_channels)
243 | else:
244 | up_conv = UpCoXvD(outs, out_channels, blocks, residual=residual, batch_norm=batch_norm, transpose=transpose,
245 |                               concat=concat, norm=norm, act=act)
246 | self.up_convs.append(up_conv)
247 | self.im_atts.append(SEBlock(out_channels))
248 | self.vm_atts.append(SEBlock(out_channels))
249 | self.mask_atts.append(SEBlock(out_channels))
250 |
251 | self.up_convs = nn.ModuleList(self.up_convs)
252 | self.im_atts = nn.ModuleList(self.im_atts)
253 | self.vm_atts = nn.ModuleList(self.vm_atts)
254 | self.mask_atts = nn.ModuleList(self.mask_atts)
255 |
256 | reset_params(self)
257 |
258 | def forward(self, input, encoder_outs=None):
259 | # im branch
260 | x = input
261 | for i, up_conv in enumerate(self.up_convs):
262 | before_pool = None
263 | if encoder_outs is not None:
264 | before_pool = encoder_outs[-(i+2)]
265 | x = up_conv(x, before_pool,se=self.im_atts[i])
266 | x_im = x
267 |
268 | x = input
269 | for i, up_conv in enumerate(self.up_convs):
270 | before_pool = None
271 | if encoder_outs is not None:
272 | before_pool = encoder_outs[-(i+2)]
273 | x = up_conv(x, before_pool, se = self.mask_atts[i])
274 | x_mask = x
275 |
276 | x = input
277 | for i, up_conv in enumerate(self.up_convs):
278 | before_pool = None
279 | if encoder_outs is not None:
280 | before_pool = encoder_outs[-(i+2)]
281 | x = up_conv(x, before_pool, se=self.vm_atts[i])
282 | x_vm = x
283 |
284 | return x_im,x_mask,x_vm
285 |
286 | class UnetEncoderD(nn.Module):
287 |
288 | def __init__(self, in_channels=3, depth=5, blocks=1, start_filters=32, residual=True, batch_norm=True, norm=nn.BatchNorm2d, act=F.relu):
289 | super(UnetEncoderD, self).__init__()
290 | self.down_convs = []
291 | outs = None
292 | if type(blocks) is tuple:
293 | blocks = blocks[0]
294 | for i in range(depth):
295 | ins = in_channels if i == 0 else outs
296 | outs = start_filters*(2**i)
297 | pooling = True if i < depth-1 else False
298 |             down_conv = DownCoXvD(ins, outs, blocks, pooling=pooling, residual=residual, batch_norm=batch_norm, norm=norm, act=act)
299 | self.down_convs.append(down_conv)
300 | self.down_convs = nn.ModuleList(self.down_convs)
301 | reset_params(self)
302 |
303 | def __call__(self, x):
304 | return self.forward(x)
305 |
306 | def forward(self, x):
307 | encoder_outs = []
308 | for d_conv in self.down_convs:
309 | x, before_pool = d_conv(x)
310 | encoder_outs.append(before_pool)
311 | return x, encoder_outs
312 |
313 | class ResDown(nn.Module):
314 | def __init__(self, in_size, out_size, pooling=True, use_att=False):
315 | super(ResDown, self).__init__()
316 | self.model = DownCoXvD(in_size, out_size, 3, pooling=pooling)
317 |
318 | def forward(self, x):
319 | return self.model(x)
320 |
321 | class ResUp(nn.Module):
322 | def __init__(self, in_size, out_size, use_att=False):
323 | super(ResUp, self).__init__()
324 | self.model = UpCoXvD(in_size, out_size, 3, use_att=use_att)
325 |
326 | def forward(self, x, skip_input, mask=None):
327 | return self.model(x,skip_input,mask)
328 |
329 | class ResDownNew(nn.Module):
330 | def __init__(self, in_size, out_size, pooling=True, use_att=False):
331 | super(ResDownNew, self).__init__()
332 | self.model = DownCoXvD(in_size, out_size, 3, pooling=pooling, norm=nn.InstanceNorm2d, act=F.leaky_relu)
333 |
334 | def forward(self, x):
335 | return self.model(x)
336 |
337 | class ResUpNew(nn.Module):
338 | def __init__(self, in_size, out_size, use_att=False):
339 | super(ResUpNew, self).__init__()
340 | self.model = UpCoXvD(in_size, out_size, 3, use_att=use_att, norm=nn.InstanceNorm2d)
341 |
342 | def forward(self, x, skip_input, mask=None):
343 | return self.model(x,skip_input,mask)
344 |
345 |
346 |
347 | class VMSingle(nn.Module):
348 | def __init__(self, in_channels=3, out_channels=3, down=ResDown, up=ResUp, ngf=32, res=True,use_att=False):
349 | super(VMSingle, self).__init__()
350 |
351 | self.down1 = down(in_channels, ngf)
352 | self.down2 = down(ngf, ngf*2)
353 | self.down3 = down(ngf*2, ngf*4)
354 | self.down4 = down(ngf*4, ngf*8)
355 | self.down5 = down(ngf*8, ngf*16, pooling=False)
356 |
357 | self.up1 = up(ngf*16, ngf*8)
358 | self.up2 = up(ngf*8, ngf*4, use_att=use_att)
359 | self.up3 = up(ngf*4, ngf*2, use_att=use_att)
360 | self.up4 = up(ngf*2, ngf*1, use_att=use_att)
361 |
362 | self.im = nn.Conv2d(ngf, 3, 1)
363 | self.res = res
364 |
365 |
366 | def forward(self, input):
367 | img, mask = input[:,0:3,:,:],input[:,3:4,:,:]
368 | # U-Net generator with skip connections from encoder to decoder
369 | x,d1 = self.down1(input) # 128,256
370 | x,d2 = self.down2(x) # 64,128
371 | x,d3 = self.down3(x) # 32,64
372 | x,d4 = self.down4(x) # 16,32
373 | x,_ = self.down5(x) # 8,16
374 |
375 | x = self.up1(x, d4) # 16
376 | x = self.up2(x, d3, mask) # 32
377 | x = self.up3(x, d2, mask) # 64
378 | x = self.up4(x, d1, mask) # 128
379 | im = self.im(x)
380 |
381 | return im
382 |
383 |
384 |
385 | class VMSingleS2AM(nn.Module):
386 | def __init__(self, in_channels=3, out_channels=3, down=ResDown, up=ResUp, ngf=32):
387 | super(VMSingleS2AM, self).__init__()
388 |
389 | self.down1 = down(in_channels, ngf)
390 | self.down2 = down(ngf, ngf*2)
391 | self.down3 = down(ngf*2, ngf*4)
392 | self.down4 = down(ngf*4, ngf*8)
393 | self.down5 = down(ngf*8, ngf*16, pooling=False)
394 |
395 | self.up1 = up(ngf*16, ngf*8)
396 | self.up2 = up(ngf*8, ngf*4)
397 | self.s2am2 = RASC(ngf*4)
398 |
399 | self.up3 = up(ngf*4, ngf*2)
400 | self.s2am3 = RASC(ngf*2)
401 |
402 | self.up4 = up(ngf*2, ngf*1)
403 | self.s2am4 = RASC(ngf)
404 |
405 | self.im = nn.Conv2d(ngf, 3, 1)
406 |
407 |
408 | def forward(self, input):
409 | img, mask = input[:,0:3,:,:],input[:,3:4,:,:]
410 | # U-Net generator with skip connections from encoder to decoder
411 | x,d1 = self.down1(input) # 128,256
412 | x,d2 = self.down2(x) # 64,128
413 | x,d3 = self.down3(x) # 32,64
414 | x,d4 = self.down4(x) # 16,32
415 | x,_ = self.down5(x) # 8,16
416 |
417 | x = self.up1(x, d4) # 16
418 | x = self.up2(x, d3) # 32
419 | x = self.s2am2(x, mask)
420 |
421 | x = self.up3(x, d2) # 64
422 | x = self.s2am3(x, mask)
423 |
424 | x = self.up4(x, d1) # 128
425 | x = self.s2am4(x, mask)
426 | im = self.im(x)
427 | return im
428 |
429 |
430 | class UnetVMS2AMv4(nn.Module):
431 |
432 | def __init__(self, in_channels=3, depth=5, shared_depth=0, use_vm_decoder=False, blocks=1,
433 | out_channels_image=3, out_channels_mask=1, start_filters=32, residual=True, batch_norm=True,
434 | transpose=True, concat=True, transfer_data=True, long_skip=False, s2am='unet', use_coarser=True,no_stage2=False):
435 | super(UnetVMS2AMv4, self).__init__()
436 | self.transfer_data = transfer_data
437 | self.shared = shared_depth
438 | self.optimizer_encoder, self.optimizer_image, self.optimizer_vm = None, None, None
439 |         self.optimizer_mask, self.optimizer_shared, self.optimizer_s2am = None, None, None
440 |         if not isinstance(blocks, tuple):
441 | blocks = (blocks, blocks, blocks, blocks, blocks)
442 | if not transfer_data:
443 | concat = False
444 | self.encoder = UnetEncoderD(in_channels=in_channels, depth=depth, blocks=blocks[0],
445 | start_filters=start_filters, residual=residual, batch_norm=batch_norm,norm=nn.InstanceNorm2d,act=F.leaky_relu)
446 | self.image_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
447 | out_channels=out_channels_image, depth=depth - shared_depth,
448 | blocks=blocks[1], residual=residual, batch_norm=batch_norm,
449 | transpose=transpose, concat=concat,norm=nn.InstanceNorm2d)
450 | self.mask_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
451 | out_channels=out_channels_mask, depth=depth - shared_depth,
452 | blocks=blocks[2], residual=residual, batch_norm=batch_norm,
453 | transpose=transpose, concat=concat,norm=nn.InstanceNorm2d)
454 | self.vm_decoder = None
455 | if use_vm_decoder:
456 | self.vm_decoder = UnetDecoderD(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
457 | out_channels=out_channels_image, depth=depth - shared_depth,
458 | blocks=blocks[3], residual=residual, batch_norm=batch_norm,
459 | transpose=transpose, concat=concat,norm=nn.InstanceNorm2d)
460 | self.shared_decoder = None
461 | self.use_coarser = use_coarser
462 | self.long_skip = long_skip
463 | self.no_stage2 = no_stage2
464 | self._forward = self.unshared_forward
465 | if self.shared != 0:
466 | self._forward = self.shared_forward
467 | self.shared_decoder = UnetDecoderDatt(in_channels=start_filters * 2 ** (depth - 1),
468 | out_channels=start_filters * 2 ** (depth - shared_depth - 1),
469 | depth=shared_depth, blocks=blocks[4], residual=residual,
470 | batch_norm=batch_norm, transpose=transpose, concat=concat,
471 | is_final=False,norm=nn.InstanceNorm2d)
472 |
473 |         if s2am == 'unet':  # stage-2 refinement net: input = coarse image (3ch) + mask (1ch)
474 |             self.s2am = UnetGenerator(4, 3, is_attention_layer=True, attention_model=RASC, basicblock=MinimalUnetV2)
475 |         elif s2am == 'vm':
476 |             self.s2am = VMSingle(4)
477 |         elif s2am == 'vms2am':
478 |             self.s2am = VMSingleS2AM(4, down=ResDownNew, up=ResUpNew)
479 |
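    # Each sub-network gets its own Adam optimizer (lr=1e-3), so the encoder,
    # decoders and stage-2 network can be zeroed/stepped together
    # (zero_grad_all / step_all) or individually (e.g. step_optimizer_image).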
480 | def set_optimizers(self):
481 | self.optimizer_encoder = torch.optim.Adam(self.encoder.parameters(), lr=0.001)
482 | self.optimizer_image = torch.optim.Adam(self.image_decoder.parameters(), lr=0.001)
483 | self.optimizer_mask = torch.optim.Adam(self.mask_decoder.parameters(), lr=0.001)
484 | self.optimizer_s2am = torch.optim.Adam(self.s2am.parameters(), lr=0.001)
485 |
486 | if self.vm_decoder is not None:
487 | self.optimizer_vm = torch.optim.Adam(self.vm_decoder.parameters(), lr=0.001)
488 | if self.shared != 0:
489 | self.optimizer_shared = torch.optim.Adam(self.shared_decoder.parameters(), lr=0.001)
490 |
491 | def zero_grad_all(self):
492 | self.optimizer_encoder.zero_grad()
493 | self.optimizer_image.zero_grad()
494 | self.optimizer_mask.zero_grad()
495 | self.optimizer_s2am.zero_grad()
496 | if self.vm_decoder is not None:
497 | self.optimizer_vm.zero_grad()
498 | if self.shared != 0:
499 | self.optimizer_shared.zero_grad()
500 |
501 | def step_all(self):
502 | self.optimizer_encoder.step()
503 | self.optimizer_image.step()
504 | self.optimizer_mask.step()
505 | self.optimizer_s2am.step()
506 | if self.vm_decoder is not None:
507 | self.optimizer_vm.step()
508 | if self.shared != 0:
509 | self.optimizer_shared.step()
510 |
511 | def step_optimizer_image(self):
512 | self.optimizer_image.step()
513 |
514 |     # nn.Module.__call__ dispatches to forward(), which runs the _forward
515 |     # selected in __init__ (unshared_forward or shared_forward).
516 |
517 | def forward(self, synthesized):
518 | return self._forward(synthesized)
519 |
520 | def unshared_forward(self, synthesized):
521 | image_code, before_pool = self.encoder(synthesized)
522 | if not self.transfer_data:
523 | before_pool = None
524 | reconstructed_image = torch.tanh(self.image_decoder(image_code, before_pool))
525 | reconstructed_mask = torch.sigmoid(self.mask_decoder(image_code, before_pool))
526 | if self.vm_decoder is not None:
527 | reconstructed_vm = torch.tanh(self.vm_decoder(image_code, before_pool))
528 | return reconstructed_image, reconstructed_mask, reconstructed_vm
529 | return reconstructed_image, reconstructed_mask
530 |
531 | def shared_forward(self, synthesized):
532 | image_code, before_pool = self.encoder(synthesized)
533 | if self.transfer_data:
534 | shared_before_pool = before_pool[- self.shared - 1:]
535 | unshared_before_pool = before_pool[: - self.shared]
536 | else:
537 | before_pool = None
538 | shared_before_pool = None
539 | unshared_before_pool = None
540 |         im, mask, vm = self.shared_decoder(image_code, shared_before_pool)  # shared trunk splits into image / mask / watermark features
541 | reconstructed_image = torch.tanh(self.image_decoder(im, unshared_before_pool))
542 | if self.long_skip:
543 | reconstructed_image = reconstructed_image + synthesized
544 |
545 | reconstructed_mask = torch.sigmoid(self.mask_decoder(mask, unshared_before_pool))
546 | if self.vm_decoder is not None:
547 | reconstructed_vm = torch.tanh(self.vm_decoder(vm, unshared_before_pool))
548 | if self.long_skip:
549 | reconstructed_vm = reconstructed_vm + synthesized
550 |
551 |         coarser = reconstructed_image * reconstructed_mask + (1 - reconstructed_mask) * synthesized  # coarse composite: prediction inside the mask, input elsewhere
552 |
553 |         if self.use_coarser:  # refine as a residual over the coarse composite
554 |             refine = torch.tanh(self.s2am(torch.cat([coarser, reconstructed_mask], dim=1))) + coarser
555 |         elif self.no_stage2:  # predict the refined image directly
556 |             refine = torch.tanh(self.s2am(torch.cat([coarser, reconstructed_mask], dim=1)))
557 |         else:  # refine as a residual over the original input
558 |             refine = torch.tanh(self.s2am(torch.cat([coarser, reconstructed_mask], dim=1))) + synthesized
559 |
560 | # final = refine * reconstructed_mask + (1-reconstructed_mask)* synthesized
561 | if self.vm_decoder is not None:
562 | return [refine, reconstructed_image], reconstructed_mask, reconstructed_vm
563 | else:
564 | return [refine, reconstructed_image], reconstructed_mask
565 |
566 |
567 |
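if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original pipeline). It assumes
    # the building blocks defined earlier in this file (DownCoXvD / UpCoXvD,
    # RASC, UnetEncoderD, UnetDecoderD / UnetDecoderDatt, UnetGenerator,
    # MinimalUnetV2) behave as they are used above, with 256x256 inputs and a
    # binary mask as the 4th channel.
    with torch.no_grad():
        x = torch.randn(1, 4, 256, 256)                       # image (3ch) + mask (1ch)
        print(VMSingle(4)(x).shape)                           # expected: (1, 3, 256, 256)
        print(VMSingleS2AM(4, down=ResDownNew, up=ResUpNew)(x).shape)

        net = UnetVMS2AMv4(in_channels=3, shared_depth=2)
        outs, pred_mask = net(torch.randn(1, 3, 256, 256))    # [refined, coarse], mask
        print(outs[0].shape, outs[1].shape, pred_mask.shape)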
--------------------------------------------------------------------------------