├── src ├── networks │ ├── __init__.py │ ├── methods.py │ ├── discriminator.py │ ├── resunet.py │ └── blocks.py ├── utils │ ├── __init__.py │ ├── osutils.py │ ├── model_init.py │ ├── misc.py │ ├── transforms.py │ ├── imutils.py │ ├── parallel.py │ └── losses.py ├── models │ ├── __init__.py │ ├── BasicModel.py │ └── SLBR.py └── __init__.py ├── figs ├── blocks.jpg ├── framework.png ├── bg_comparison.png └── mask_comparison.png ├── pytorch_iou ├── __pycache__ │ └── __init__.cpython-37.pyc └── __init__.py ├── pytorch_ssim ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── __init__.cpython-37.pyc └── __init__.py ├── requirements.txt ├── scripts ├── test_custom.sh ├── test.sh └── train.sh ├── datasets ├── __init__.py ├── lvw_dataset.py ├── clwd_dataset.py └── base_dataset.py ├── train.py ├── test_custom.py ├── README.md ├── evaluation.py ├── test.py └── options.py /src/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .methods import * 2 | from .discriminator import * 3 | 4 | -------------------------------------------------------------------------------- /figs/blocks.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SLBR-Visible-Watermark-Removal/HEAD/figs/blocks.jpg -------------------------------------------------------------------------------- /figs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SLBR-Visible-Watermark-Removal/HEAD/figs/framework.png -------------------------------------------------------------------------------- /figs/bg_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SLBR-Visible-Watermark-Removal/HEAD/figs/bg_comparison.png -------------------------------------------------------------------------------- /figs/mask_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SLBR-Visible-Watermark-Removal/HEAD/figs/mask_comparison.png -------------------------------------------------------------------------------- /pytorch_iou/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SLBR-Visible-Watermark-Removal/HEAD/pytorch_iou/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /pytorch_ssim/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SLBR-Visible-Watermark-Removal/HEAD/pytorch_ssim/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pytorch_ssim/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/SLBR-Visible-Watermark-Removal/HEAD/pytorch_ssim/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .imutils import * 4 | from .misc import * 5 | from .osutils import * 6 | from .transforms import * 7 | 
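For orientation, the components in the tree above are wired together by `train.py` (shown in full later in this dump). Below is a minimal sketch of that flow, assuming `args` is populated by `options.Options` with the flags used in `scripts/train.sh`; nothing here goes beyond what `train.py` itself does.

```python
# Sketch of the training flow (mirrors train.py later in this dump).
# Assumption: args comes from options.Options with scripts/train.sh flags.
import argparse
import torch

import datasets
import src.models as models
from options import Options

args = Options().init(argparse.ArgumentParser()).parse_args()

train_loader = torch.utils.data.DataLoader(
    datasets.CLWDDataset('train', args), batch_size=args.train_batch, shuffle=True)
val_loader = torch.utils.data.DataLoader(
    datasets.CLWDDataset('val', args), batch_size=args.test_batch, shuffle=False)

# 'slbr' resolves to the wrapper in src/models/SLBR.py, which builds the
# generator from src/networks and drives training and validation.
machine = models.__dict__[args.models](datasets=(train_loader, val_loader), args=args)
for epoch in range(args.start_epoch, args.epochs):
    machine.train(epoch)
    machine.validate(epoch)
    machine.save_checkpoint()
```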
-------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .BasicModel import BasicModel 3 | from .SLBR import SLBR 4 | 5 | 6 | def basic(**kwargs): 7 | return BasicModel(**kwargs) 8 | 9 | def slbr(**kwargs): 10 | return SLBR(**kwargs) 11 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import networks 4 | from . import utils 5 | from . import models 6 | 7 | # import os, sys 8 | # sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 9 | # from progress.bar import Bar as Bar 10 | 11 | # __version__ = '0.1.0' -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==2.2.2 2 | albumentations==0.4.5 3 | scipy==1.4.1 4 | scikit_image==0.17.2 5 | torch==1.6.0 6 | tqdm==4.46.1 7 | progress==1.5 8 | numpy==1.18.1 9 | torchvision==0.6.0 10 | opencv_python_headless==4.2.0.34 11 | Pillow==8.3.2 12 | scikit_learn==1.0 13 | skimage==0.0 14 | tensorboardX==2.4 15 | -------------------------------------------------------------------------------- /src/networks/methods.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torchvision 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import functools 9 | import math 10 | 11 | from src.utils.model_init import * 12 | from src.networks.resunet import SLBR 13 | 14 | 15 | # our method 16 | def slbr(**kwargs): 17 | return SLBR(args=kwargs['args'], shared_depth=1, blocks=3, long_skip=True) 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /src/utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import errno 5 | 6 | def mkdir_p(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | 13 | def isfile(fname): 14 | return os.path.isfile(fname) 15 | 16 | def isdir(dirname): 17 | return os.path.isdir(dirname) 18 | 19 | def join(path, *paths): 20 | return os.path.join(path, *paths) 21 | -------------------------------------------------------------------------------- /scripts/test_custom.sh: -------------------------------------------------------------------------------- 1 | K_CENTER=2 2 | K_REFINE=3 3 | K_SKIP=3 4 | MASK_MODE=res 5 | 6 | 7 | INPUT_SIZE=256 8 | NAME=slbr_v1 9 | TEST_DIR=/media/sda/Watermark 10 | 11 | CUDA_VISIBLE_DEVICES=1 python3 test_custom.py \ 12 | --name ${NAME} \ 13 | --nets slbr \ 14 | --models slbr \ 15 | --input-size ${INPUT_SIZE} \ 16 | --crop_size ${INPUT_SIZE} \ 17 | --test-batch 1 \ 18 | --evaluate\ 19 | --preprocess resize \ 20 | --no_flip \ 21 | --mask_mode ${MASK_MODE} \ 22 | --k_center ${K_CENTER} \ 23 | --use_refine \ 24 | --k_refine ${K_REFINE} \ 25 | --k_skip_stage ${K_SKIP} \ 26 | --resume /media/sda/Watermark/${NAME}/model_best.pth.tar \ 27 | --test_dir ${TEST_DIR} 28 | 29 | -------------------------------------------------------------------------------- /scripts/test.sh: 
-------------------------------------------------------------------------------- 1 | K_CENTER=2 2 | K_REFINE=3 3 | K_SKIP=3 4 | MASK_MODE=res 5 | 6 | INPUT_SIZE=256 7 | DATASET=CLWD 8 | NAME=slbr_v1 9 | 10 | CUDA_VISIBLE_DEVICES=1 python3 test.py \ 11 | --nets slbr \ 12 | --models slbr \ 13 | --input-size ${INPUT_SIZE} \ 14 | --crop_size ${INPUT_SIZE} \ 15 | --test-batch 1 \ 16 | --evaluate\ 17 | --dataset_dir /media/sda/datasets/Watermark/${DATASET} \ 18 | --preprocess resize \ 19 | --no_flip \ 20 | --name ${NAME} \ 21 | --mask_mode ${MASK_MODE} \ 22 | --k_center ${K_CENTER} \ 23 | --dataset ${DATASET} \ 24 | --resume /media/sda/Watermark/${NAME}/model_best.pth.tar \ 25 | --use_refine \ 26 | --k_refine ${K_REFINE} \ 27 | --k_skip_stage ${K_SKIP} 28 | 29 | # --checkpoint /media/sda/Watermark \ 30 | 31 | -------------------------------------------------------------------------------- /pytorch_iou/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | def _iou(pred, target, size_average = True): 7 | 8 | b = pred.shape[0] 9 | IoU = 0.0 10 | for i in range(0,b): 11 | #compute the IoU of the foreground 12 | Iand1 = torch.sum(target[i,:,:,:]*pred[i,:,:,:]) 13 | Ior1 = torch.sum(target[i,:,:,:]) + torch.sum(pred[i,:,:,:])-Iand1 14 | IoU1 = Iand1/Ior1 15 | 16 | #IoU loss is (1-IoU1) 17 | IoU = IoU + (1-IoU1) 18 | 19 | return IoU/b 20 | 21 | class IOU(torch.nn.Module): 22 | def __init__(self, size_average = True): 23 | super(IOU, self).__init__() 24 | self.size_average = size_average 25 | 26 | def forward(self, pred, target): 27 | 28 | return _iou(pred, target, self.size_average) 29 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | K_CENTER=2 2 | K_REFINE=3 3 | K_SKIP=3 4 | MASK_MODE=res #'cat' 5 | 6 | L1_LOSS=2 7 | CONTENT_LOSS=2.5e-1 8 | STYLE_LOSS=2.5e-1 9 | PRIMARY_LOSS=0.01 10 | IOU_LOSS=0.25 11 | 12 | INPUT_SIZE=256 13 | DATASET=CLWD 14 | NAME=slbr_v1 15 | # nohup python -u main.py \ 16 | CUDA_VISIBLE_DEVICES=1 python -u train.py \ 17 | --epochs 100 \ 18 | --schedule 65 \ 19 | --lr 1e-3 \ 20 | --gpu_id 1 \ 21 | --checkpoint /media/sda/Watermark \ 22 | --dataset_dir /media/sda/datasets/Watermark/${DATASET} \ 23 | --nets slbr \ 24 | --sltype vggx \ 25 | --mask_mode ${MASK_MODE} \ 26 | --lambda_content ${CONTENT_LOSS} \ 27 | --lambda_style ${STYLE_LOSS} \ 28 | --lambda_iou ${IOU_LOSS} \ 29 | --lambda_l1 ${L1_LOSS} \ 30 | --lambda_primary ${PRIMARY_LOSS} \ 31 | --masked True \ 32 | --loss-type hybrid \ 33 | --models slbr \ 34 | --input-size ${INPUT_SIZE} \ 35 | --crop_size ${INPUT_SIZE} \ 36 | --train-batch 8 \ 37 | --test-batch 1 \ 38 | --preprocess resize \ 39 | --name ${NAME} \ 40 | --k_center ${K_CENTER} \ 41 | --dataset ${DATASET} \ 42 | --use_refine \ 43 | --k_refine ${K_REFINE} \ 44 | --k_skip_stage ${K_SKIP} \ 45 | # --start-epoch 70 \ 46 | # --resume /media/sda/Watermark/${NAME}/model_best.pth.tar 47 | -------------------------------------------------------------------------------- /src/utils/model_init.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from torch.nn import init 4 | 5 | 6 | def weights_init_normal(m): 7 | classname = m.__class__.__name__ 8 | # print(classname) 9 | if classname.find('Conv') != -1: 10 | init.normal_(m.weight.data, 
0.0, 0.02) 11 | elif classname.find('Linear') != -1: 12 | init.normal_(m.weight.data, 0.0, 0.02) 13 | elif classname.find('BatchNorm2d') != -1: 14 | init.normal_(m.weight.data, 1.0, 0.02) 15 | init.constant_(m.bias.data, 0.0) 16 | 17 | 18 | def weights_init_xavier(m): 19 | classname = m.__class__.__name__ 20 | # print(classname) 21 | if classname.find('Conv') != -1: 22 | init.xavier_normal(m.weight.data, gain=0.02) 23 | elif classname.find('Linear') != -1: 24 | init.xavier_normal(m.weight.data, gain=0.02) 25 | # elif classname.find('BatchNorm2d') != -1: 26 | # init.normal(m.weight.data, 1.0, 0.02) 27 | # init.constant(m.bias.data, 0.0) 28 | 29 | 30 | def weights_init_kaiming(m): 31 | classname = m.__class__.__name__ 32 | # print(classname) 33 | if classname.find('Conv') != -1 and m.weight.requires_grad == True: 34 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 35 | elif classname.find('Linear') != -1 and m.weight.requires_grad == True: 36 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 37 | elif classname.find('BatchNorm2d') != -1 and m.weight.requires_grad == True: 38 | init.normal_(m.weight.data, 1.0, 0.02) 39 | init.constant_(m.bias.data, 0.0) 40 | 41 | 42 | def weights_init_orthogonal(m): 43 | classname = m.__class__.__name__ 44 | if classname.find('Conv') != -1: 45 | init.orthogonal(m.weight.data, gain=1) 46 | elif classname.find('Linear') != -1: 47 | init.orthogonal(m.weight.data, gain=1) 48 | # elif classname.find('BatchNorm2d') != -1: 49 | # init.normal(m.weight.data, 1.0, 0.02) 50 | # init.constant(m.bias.data, 0.0) -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # from .COCO import COCO 2 | # from .BIH import BIH 3 | from .clwd_dataset import CLWDDataset 4 | from .lvw_dataset import LVWDataset 5 | import importlib 6 | import torch.utils.data 7 | from datasets.base_dataset import BaseDataset 8 | 9 | __all__ = ('CLWDDataset', 'LVWDataset') 10 | 11 | 12 | 13 | 14 | 15 | def find_dataset_using_name(dataset_name): 16 | """Import the module "data/[dataset_name]_dataset.py". 17 | 18 | In the file, the class called DatasetNameDataset() will 19 | be instantiated. It has to be a subclass of BaseDataset, 20 | and it is case-insensitive. 21 | """ 22 | dataset_filename = "data." + dataset_name + "_dataset" 23 | datasetlib = importlib.import_module(dataset_filename) 24 | 25 | dataset = None 26 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 27 | for name, cls in datasetlib.__dict__.items(): 28 | if name.lower() == target_dataset_name.lower() \ 29 | and issubclass(cls, BaseDataset): 30 | dataset = cls 31 | 32 | if dataset is None: 33 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 34 | 35 | return dataset 36 | 37 | 38 | def get_option_setter(dataset_name): 39 | """Return the static method of the dataset class.""" 40 | dataset_class = find_dataset_using_name(dataset_name) 41 | return dataset_class.modify_commandline_options 42 | 43 | 44 | def create_dataset(opt): 45 | """Create a dataset given the option. 46 | 47 | This function wraps the class CustomDatasetDataLoader. 
48 | This is the main interface between this package and 'train.py'/'test.py' 49 | 50 | Example: 51 | >>> from data import create_dataset 52 | >>> dataset = create_dataset(opt) 53 | """ 54 | data_loader = CustomDatasetDataLoader(opt) 55 | dataset = data_loader.load_data() 56 | return dataset 57 | 58 | 59 | class CustomDatasetDataLoader(): 60 | """Wrapper class of Dataset class that performs multi-threaded data loading""" 61 | 62 | def __init__(self, opt): 63 | """Initialize this class 64 | 65 | Step 1: create a dataset instance given the name [dataset_mode] 66 | Step 2: create a multi-threaded data loader. 67 | """ 68 | self.opt = opt 69 | dataset_class = find_dataset_using_name(opt.dataset_mode) 70 | self.dataset = dataset_class(opt) 71 | print("dataset [%s] was created" % type(self.dataset).__name__) 72 | self.dataloader = torch.utils.data.DataLoader( 73 | self.dataset, 74 | batch_size=opt.batch_size, 75 | shuffle=not opt.serial_batches, 76 | num_workers=int(opt.num_threads)) 77 | 78 | def load_data(self): 79 | return self 80 | 81 | def __len__(self): 82 | """Return the number of data in the dataset""" 83 | return min(len(self.dataset), self.opt.max_dataset_size) 84 | 85 | def __iter__(self): 86 | """Return a batch of data""" 87 | for i, data in enumerate(self.dataloader): 88 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 89 | break 90 | yield data 91 | -------------------------------------------------------------------------------- /src/utils/misc.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import shutil 5 | import torch 6 | import math 7 | import numpy as np 8 | import scipy.io 9 | import matplotlib.pyplot as plt 10 | import torch.nn.functional as F 11 | 12 | def to_numpy(tensor): 13 | if torch.is_tensor(tensor): 14 | return tensor.cpu().numpy() 15 | elif type(tensor).__module__ != 'numpy': 16 | raise ValueError("Cannot convert {} to numpy array" 17 | .format(type(tensor))) 18 | return tensor 19 | 20 | def resize_to_match(fm,to): 21 | # just use interpolate 22 | # [1,3] = (h,w) 23 | return F.interpolate(fm,to.size()[-2:],mode='bilinear',align_corners=False) 24 | 25 | def to_torch(ndarray): 26 | if type(ndarray).__module__ == 'numpy': 27 | return torch.from_numpy(ndarray) 28 | elif not torch.is_tensor(ndarray): 29 | raise ValueError("Cannot convert {} to torch tensor" 30 | .format(type(ndarray))) 31 | return ndarray 32 | 33 | 34 | def save_checkpoint(machine,filename='checkpoint.pth.tar', snapshot=None): 35 | is_best = True if machine.best_acc < machine.metric else False 36 | 37 | if is_best: 38 | machine.best_acc = machine.metric 39 | 40 | state = { 41 | 'epoch': machine.current_epoch + 1, 42 | 'arch': machine.args.arch, 43 | 'state_dict': machine.model.state_dict(), 44 | 'best_acc': machine.best_acc, 45 | 'optimizer' : machine.optimizer.state_dict(), 46 | } 47 | 48 | filepath = os.path.join(machine.args.checkpoint, filename) 49 | torch.save(state, filepath) 50 | 51 | if snapshot and state['epoch'] % snapshot == 0: 52 | shutil.copyfile(filepath, os.path.join(machine.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state.epoch))) 53 | 54 | if is_best: 55 | machine.best_acc = machine.metric 56 | print('Saving Best Metric with PSNR:%s'%machine.best_acc) 57 | shutil.copyfile(filepath, os.path.join(machine.args.checkpoint, 'model_best.pth.tar')) 58 | 59 | 60 | 61 | def save_pred(preds, checkpoint='checkpoint', filename='preds_valid.mat'): 62 | preds = to_numpy(preds) 63 | 
filepath = os.path.join(checkpoint, filename) 64 | scipy.io.savemat(filepath, mdict={'preds' : preds}) 65 | 66 | 67 | def adjust_learning_rate(datasets, model, epoch, lr,args): 68 | """Sets the learning rate to the initial LR decayed by schedule""" 69 | if epoch in args.schedule: 70 | lr *= args.gamma 71 | optimizers = [getattr(model.model, attr) for attr in dir(model.model) if attr.startswith("optimizer") and getattr(model.model, attr) is not None] 72 | for optimizer in optimizers: 73 | for param_group in optimizer.param_groups: 74 | param_group['lr'] = lr 75 | else: # access the current learning rate 76 | optimizers = [getattr(model.model, attr) for attr in dir(model.model) if attr.startswith("optimizer") and getattr(model.model, attr) is not None] 77 | for optimizer in optimizers: 78 | for param_group in optimizer.param_groups: 79 | lr = param_group['lr'] 80 | # decay sigma 81 | # for dset in datasets: 82 | # if args.sigma_decay > 0: 83 | # dset.dataset.sigma *= args.sigma_decay 84 | # dset.dataset.sigma *= args.sigma_decay 85 | 86 | return lr 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import argparse 4 | import torch,time,os 5 | 6 | torch.backends.cudnn.benchmark = True 7 | 8 | from src.utils.misc import save_checkpoint, adjust_learning_rate 9 | import src.models as models 10 | 11 | import datasets as datasets 12 | from options import Options 13 | import numpy as np 14 | 15 | def main(args): 16 | args.seed = 1 17 | np.random.seed(args.seed) 18 | torch.manual_seed(args.seed) 19 | 20 | args.dataset = args.dataset.lower() 21 | if args.dataset == 'clwd': 22 | dataset_func = datasets.CLWDDataset 23 | elif args.dataset == 'lvw': 24 | dataset_func = datasets.LVWDataset 25 | else: 26 | raise ValueError("Not known dataset:\t{}".format(args.dataset)) 27 | 28 | train_loader = torch.utils.data.DataLoader(dataset_func('train',args),batch_size=args.train_batch, shuffle=True, 29 | num_workers=args.workers, pin_memory=True) 30 | 31 | val_loader = torch.utils.data.DataLoader(dataset_func('val',args),batch_size=args.test_batch, shuffle=False, 32 | num_workers=args.workers, pin_memory=True) 33 | 34 | lr = args.lr 35 | data_loaders = (train_loader,val_loader) 36 | 37 | model = models.__dict__[args.models](datasets=data_loaders, args=args) 38 | print('============================ Initization Finish && Training Start =============================================') 39 | 40 | for epoch in range(model.args.start_epoch, model.args.epochs): 41 | lr = adjust_learning_rate(data_loaders, model, epoch, lr, args) 42 | print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr)) 43 | 44 | model.record('lr',lr, epoch) 45 | model.train(epoch) 46 | # model.validate(epoch) 47 | if args.freq < 0: 48 | model.validate(epoch) 49 | model.flush() 50 | model.save_checkpoint() 51 | 52 | if __name__ == '__main__': 53 | torch.backends.cudnn.benchmark = True 54 | parser=Options().init(argparse.ArgumentParser(description='WaterMark Removal')) 55 | args = parser.parse_args() 56 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 57 | print('==================================== WaterMark Removal =============================================') 58 | print('==> {:50}: {:<}'.format("Start Time",time.ctime(time.time()))) 59 | print('==> {:50}: {:<}'.format("USE GPU",os.environ['CUDA_VISIBLE_DEVICES'])) 60 | 
print('==================================== Stable Parameters =============================================') 61 | for arg in vars(args): 62 | if type(getattr(args, arg)) == type([]): 63 | if ','.join([ str(i) for i in getattr(args, arg)]) == ','.join([ str(i) for i in parser.get_default(arg)]): 64 | print('==> {:50}: {:<}({:<})'.format(arg,','.join([ str(i) for i in getattr(args, arg)]),','.join([ str(i) for i in parser.get_default(arg)]))) 65 | else: 66 | if getattr(args, arg) == parser.get_default(arg): 67 | print('==> {:50}: {:<}({:<})'.format(arg,getattr(args, arg),parser.get_default(arg))) 68 | print('==================================== Changed Parameters =============================================') 69 | for arg in vars(args): 70 | if type(getattr(args, arg)) == type([]): 71 | if ','.join([ str(i) for i in getattr(args, arg)]) != ','.join([ str(i) for i in parser.get_default(arg)]): 72 | print('==> {:50}: {:<}({:<})'.format(arg,','.join([ str(i) for i in getattr(args, arg)]),','.join([ str(i) for i in parser.get_default(arg)]))) 73 | else: 74 | if getattr(args, arg) != parser.get_default(arg): 75 | print('==> {:50}: {:<}({:<})'.format(arg,getattr(args, arg),parser.get_default(arg))) 76 | print('==================================== Start Init Model ===============================================') 77 | main(args) 78 | print('==================================== FINISH WITHOUT ERROR =============================================') 79 | -------------------------------------------------------------------------------- /test_custom.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import cv2 5 | import numpy as np 6 | 7 | torch.backends.cudnn.benchmark = True 8 | 9 | import datasets as datasets 10 | import src.models as models 11 | from options import Options 12 | import torch.nn.functional as F 13 | 14 | 15 | 16 | def tensor2np(x, isMask=False): 17 | if isMask: 18 | if x.shape[1] == 1: 19 | x = x.repeat(1,3,1,1) 20 | x = ((x.cpu().detach()))*255 21 | else: 22 | x = x.cpu().detach() 23 | mean = 0 24 | std = 1 25 | x = (x * std + mean)*255 26 | 27 | return x.numpy().transpose(0,2,3,1).astype(np.uint8) 28 | 29 | def save_output(inputs, preds, save_dir, img_fn, extra_infos=None, verbose=False, alpha=0.5): 30 | outs = [] 31 | image = inputs['I'] #, inputs['bg'], inputs['mask'] 32 | image = cv2.cvtColor(tensor2np(image)[0], cv2.COLOR_RGB2BGR) 33 | 34 | bg_pred,mask_preds = preds['bg'], preds['mask'] 35 | bg_pred = cv2.cvtColor(tensor2np(bg_pred)[0], cv2.COLOR_RGB2BGR) 36 | mask_pred = tensor2np(mask_preds, isMask=True)[0] 37 | outs = [image, bg_pred, mask_pred] 38 | outimg = np.concatenate(outs, axis=1) 39 | 40 | if verbose==True: 41 | # print("show") 42 | cv2.imshow("out",outimg) 43 | cv2.waitKey(0) 44 | else: 45 | img_fn = os.path.split(img_fn)[-1] 46 | out_fn = os.path.join(save_dir, "{}{}".format(os.path.splitext(img_fn)[0], os.path.splitext(img_fn)[1])) 47 | cv2.imwrite(out_fn, outimg) 48 | 49 | def preprocess(file_path, img_size=512): 50 | img_J = cv2.imread(file_path) 51 | assert img_J is not None, "NoneType" 52 | h,w,_ = img_J.shape 53 | img_J = cv2.cvtColor(img_J, cv2.COLOR_BGR2RGB).astype(np.float)/255. 
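# note: np.float works with the NumPy version pinned in requirements.txt
# (< 1.20); the alias is removed in newer NumPy, where plain `float` is the
# drop-in replacement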
54 | img_J = torch.from_numpy(img_J.transpose(2,0,1)[np.newaxis,...]) #[1,C,H,W] 55 | img_J = F.interpolate(img_J, size=(img_size, img_size), mode='bilinear') 56 | 57 | return img_J 58 | 59 | 60 | def test_dataloder(img_path, crop_size): 61 | loaders = [] 62 | save_fns = [] 63 | 64 | for root, dirs, fns in os.walk(img_path): 65 | for dir in dirs: 66 | path = os.path.join(root, dir) 67 | fn_list = os.listdir(path) 68 | for fn in fn_list: 69 | if fn.startswith('.'): continue 70 | if not (fn.endswith('.jpg') or fn.endswith('jpeg') or fn.endswith('png') ): continue 71 | fn = os.path.join(path, fn) 72 | J = preprocess(fn, img_size=crop_size) 73 | loaders.append(J) 74 | save_fns.append(fn) 75 | return loaders,save_fns 76 | 77 | 78 | 79 | def main(args): 80 | 81 | Machine = models.__dict__[args.models](datasets=(None, None), args=args) 82 | 83 | model = Machine 84 | model.model.eval() 85 | print("==> testing VM model ") 86 | 87 | prediction_dir = os.path.join(args.test_dir,'rst') 88 | if not os.path.exists(prediction_dir): os.makedirs(prediction_dir) 89 | 90 | doc_loader,fns = test_dataloder(args.test_dir, args.crop_size) 91 | with torch.no_grad(): 92 | for i, batches in enumerate(zip(doc_loader, fns)): 93 | inputs, fn = batches[0], batches[1] 94 | inputs = inputs.to(model.device).float() 95 | outputs = model.model(inputs) 96 | imoutput,immask_all,imwatermark = outputs 97 | 98 | imoutput = imoutput[0] 99 | immask = immask_all[0] 100 | 101 | imfinal =imoutput*immask + model.norm(inputs)*(1-immask) 102 | save_output( 103 | inputs = {'I':inputs}, 104 | preds = {'bg':imfinal, 'mask':immask}, 105 | save_dir= prediction_dir, 106 | img_fn = fn 107 | ) 108 | 109 | 110 | 111 | 112 | 113 | 114 | if __name__ == '__main__': 115 | parser=Options().init(argparse.ArgumentParser(description='WaterMark Removal')) 116 | main(parser.parse_args()) 117 | 118 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Visible Watermark Removal via Self-calibrated Localization and Background Refinement 2 | --- 3 | 4 | ## Introduction 5 | This is the official code of the following paper: 6 | 7 | > 8 | > **Visible Watermark Removal via Self-calibrated Localization and Background Refinement**[[1]](#reference) 9 | >
Jing Liang<sup>1</sup>, Li Niu<sup>1</sup>, Fengjun Guo<sup>2</sup>, Teng Long<sup>2</sup> and Liqing Zhang<sup>1</sup> 10 | >
<sup>1</sup>MoE Key Lab of Artificial Intelligence, Shanghai Jiao Tong University 11 | >
<sup>2</sup>INTSIG
12 | ([ACM MM 2021](https://arxiv.org/pdf/2108.03581.pdf) | [Bibtex](#citation)) 13 | 14 | 15 | ### SLBR Network 16 | Here is our proposed **SLBR** (**S**elf-calibrated **L**ocalization and **B**ackground **R**efinement) network. The top row depicts the overall framework of SLBR, and the bottom row elaborates the details of our three proposed modules. 17 |
18 | ![Overall framework of SLBR](figs/framework.png) 19 |
20 |
21 | ![Details of the proposed modules](figs/blocks.jpg) 22 |
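For orientation, the generator itself comes from a small factory in `src/networks/methods.py`; below is a minimal sketch of how it is invoked. The `args` namespace is an assumption here: it must be populated by `options.py`, as in the training and test scripts.

```python
# Sketch: the network factory used by the model wrappers.
# Assumption: `args` is an options.Options namespace, as in train.py.
from src.networks.methods import slbr

net = slbr(args=args)  # SLBR(args=args, shared_depth=1, blocks=3, long_skip=True)
```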
23 | 24 | 25 | ## Quick Start 26 | ### Install 27 | - Install PyTorch>=1.0 following the [official instructions](https://pytorch.org/) 28 | - git clone https://github.com/bcmi/SLBR-Visible-Watermark-Removal.git 29 | - Install dependencies: pip install -r requirements.txt 30 | 31 | ### Data Preparation 32 | In this paper, we conduct all of our experiments on the recently released dataset [CLWD](https://drive.google.com/file/d/17y1gkUhIV6rZJg1gMG-gzVMnH27fm4Ij/view?usp=sharing)[[2]](#reference) and on LVW[[3]](#reference). You can contact the authors of LVW to obtain that dataset. 33 | 34 | 35 | 36 | ### Train and Test 37 | - How to train and test my model? 38 | 39 | We provide example training and test scripts: ```scripts/train.sh``` and ```scripts/test.sh```. 40 | 41 | Please specify the checkpoint save path in ```--checkpoint``` and the dataset path in ```--dataset_dir```. 42 | 43 | - How to test on my own data? 44 | 45 | We also provide an example script for testing on custom data: 46 | ```scripts/test_custom.sh``` You can further tailor ```test_custom.py``` to meet your demands; a minimal single-image inference sketch is also given at the end of this README. For the best performance, it is better to finetune on your own dataset, since our training images are sized 256x256. 47 | 48 | ### Pretrained Model 49 | Here is the model trained on the CLWD dataset: 50 | - [Google Drive](https://drive.google.com/file/d/1uTCzubnWZtu3HIXaK8xsXX-7x302ss13/view?usp=sharing) 51 | 52 | - [OneDrive](https://1drv.ms/u/s!AvQt5C5JE-WqkRkz9KI9o3OTfpZf?e=TDp9LV) 53 | 54 | ## Visualization Results 55 | We also show some qualitative comparisons with state-of-the-art methods: 56 | 57 | 58 |
59 | ![Background comparison](figs/bg_comparison.png) ![Mask comparison](figs/mask_comparison.png) 60 |
61 | 62 | 63 | ## **Acknowledgements** 64 | Part of the code is based upon the previous work [SplitNet](https://github.com/vinthony/deep-blind-watermark-removal)[[4]](#reference). 65 | 66 | ## Citation 67 | If you find this work or code helpful in your research, please cite: 68 | ```` 69 | @inproceedings{liang2021visible, 70 | title={Visible Watermark Removal via Self-calibrated Localization and Background Refinement}, 71 | author={Liang, Jing and Niu, Li and Guo, Fengjun and Long, Teng and Zhang, Liqing}, 72 | booktitle={Proceedings of the 29th ACM International Conference on Multimedia}, 73 | pages={4426--4434}, 74 | year={2021} 75 | } 76 | ```` 77 | 78 | ## Resources 79 | 80 | We have summarized the existing papers, code, and datasets on visible watermark removal in the following repository: 81 | [https://github.com/bcmi/Awesome-Visible-Watermark-Removal](https://github.com/bcmi/Awesome-Visible-Watermark-Removal) 82 | 83 | 84 | ## Reference 85 | [1] Jing Liang, Li Niu, Fengjun Guo, Teng Long and Liqing Zhang. 2021. Visible Watermark Removal via Self-calibrated Localization and Background Refinement. In *Proceedings of the 29th ACM International Conference on Multimedia*. [download](https://arxiv.org/pdf/2108.03581.pdf) 86 | 87 | [2] Yang Liu, Zhen Zhu, and Xiang Bai. 2021. WDNet: Watermark-Decomposition Network for Visible Watermark Removal. In *Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision*. 88 | 89 | [3] Danni Cheng, Xiang Li, Wei-Hong Li, Chan Lu, Fake Li, Hua Zhao, and Wei-Shi Zheng. 2018. Large-scale visible watermark detection and removal with deep convolutional networks. In *Chinese Conference on Pattern Recognition and Computer Vision*. 27–40. 90 | 91 | [4] Xiaodong Cun and Chi-Man Pun. 2020. Split then Refine: Stacked Attention-guided ResUNets for Blind Single Image Visible Watermark Removal. arXiv preprint arXiv:2012.07007 (2020).
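For reference, here is the single-image inference sketch mentioned in the Train and Test section. It is distilled from ```test_custom.py``` in this repository; ```input.jpg``` is a placeholder path, and ```args``` is assumed to carry the same flags as ```scripts/test_custom.sh``` (model ```slbr```, ```--resume``` pointing at the downloaded checkpoint).

```python
# Single-image inference sketch distilled from test_custom.py.
# Assumptions: args comes from options.Options with scripts/test_custom.sh
# flags; 'input.jpg' is a placeholder input path.
import argparse
import cv2
import numpy as np
import torch
import torch.nn.functional as F

import src.models as models
from options import Options

args = Options().init(argparse.ArgumentParser()).parse_args()
machine = models.__dict__[args.models](datasets=(None, None), args=args)
machine.model.eval()

img = cv2.cvtColor(cv2.imread('input.jpg'), cv2.COLOR_BGR2RGB).astype(float) / 255.
x = torch.from_numpy(img.transpose(2, 0, 1)[np.newaxis, ...]).float()
x = F.interpolate(x, size=(args.crop_size, args.crop_size), mode='bilinear')

with torch.no_grad():
    x = x.to(machine.device)
    imoutput, immask_all, imwatermark = machine.model(x)
    imoutput, immask = imoutput[0], immask_all[0]
    # composite: predicted background inside the mask, original input elsewhere
    imfinal = imoutput * immask + machine.norm(x) * (1 - immask)
```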
92 | -------------------------------------------------------------------------------- /datasets/lvw_dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import cv2 4 | import os.path as osp 5 | import os 6 | import sys 7 | import torch 8 | from torchvision import datasets, transforms 9 | from .base_dataset import get_transform 10 | import random 11 | 12 | class LVWDataset(torch.utils.data.Dataset): 13 | def __init__(self, phase, args): 14 | if phase == 'train': 15 | self.keep_background_prob = 0.01 16 | else: 17 | phase = 'test' 18 | self.keep_background_prob = -1 19 | self.augment_transform = get_transform(args, 20 | additional_targets={'J':'image', 'I':'image', 'watermark':'image', 'mask':'mask', 'alpha':'mask' }) #, 21 | 22 | 23 | self.transform_norm=transforms.Compose([transforms.ToTensor()]) 24 | self.transform_tensor= transforms.ToTensor() 25 | root = args.dataset_dir + '/' + phase + '/' 26 | self.imageJ_path=osp.join(root,'image','%s.png') 27 | self.imageI_path=osp.join(root,'background','%s.png') 28 | self.mask_path=osp.join(root,'mask','%s.png') 29 | # self.balance_path=osp.join(root,'Loss_balance','%s.png') 30 | self.alpha_path=osp.join(root,'alpha','%s.png') 31 | self.W_path=osp.join(root,'mask','%s.png') 32 | self.root = root 33 | self.transform= transforms 34 | self.ids = list() 35 | for file in os.listdir(root+'image'): 36 | #if(file[:-4]=='.jpg'): 37 | if file.endswith('.jpg') or file.endswith('.png'): 38 | self.ids.append(file.strip('.png')) 39 | 40 | def __getitem__(self,index): 41 | sample = self.get_sample(index) 42 | self.check_sample_types(sample) 43 | sample = self.augment_sample(sample) 44 | 45 | J = self.transform_norm(sample['J']) 46 | I = self.transform_norm(sample['I']) 47 | w = self.transform_norm(sample['watermark']) 48 | 49 | mask = sample['mask'][np.newaxis, ...].astype(np.float32) 50 | mask = np.where(mask > 0.1, 1, 0).astype(np.float32) 51 | alpha = sample['alpha'][np.newaxis, ...].astype(np.float32) 52 | balance = torch.ones_like(w) 53 | # mask = self.transform_tensor(mask) 54 | data = { 55 | 'image': J, 56 | 'target': I, 57 | 'wm': w, 58 | 'mask': mask, 59 | 'alpha':alpha, 60 | 'img_path':sample['img_path'] 61 | } 62 | return data 63 | #return J,I,mask,w, sample['img_path'] 64 | 65 | def __len__(self): 66 | return len(self.ids) 67 | 68 | def get_sample(self, index): 69 | img_id = self.ids[index] 70 | # img_id = self.corrupt_list[index % len(self.corrupt_list)].split('.')[0] 71 | img_J = np.asarray(Image.open(self.imageJ_path%img_id))[...,:3] 72 | # print(self.imageJ_path%img_id, type(img_J)) 73 | # img_J = cv2.cvtColor(img_J, cv2.COLOR_BGR2RGB) 74 | 75 | img_I = np.asarray(Image.open(self.imageI_path%img_id))[...,:3] 76 | # img_I = cv2.cvtColor(img_I, cv2.COLOR_BGR2RGB) 77 | 78 | w = np.asarray(Image.open(self.W_path%img_id))[...,:3] 79 | if w is None: print(self.W_path%img_id) 80 | # w = cv2.cvtColor(w, cv2.COLOR_BGR2RGB) 81 | 82 | mask = np.asarray(Image.open(self.mask_path%img_id)) 83 | alpha = np.asarray(Image.open(self.alpha_path%img_id)) 84 | 85 | mask = mask[:, :, 0].astype(np.float32) / 255. 86 | alpha = alpha[:, :, 0].astype(np.float32) / 255. 
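# mask and alpha are reduced to one channel in [0, 1]; __getitem__ above
# re-binarizes the mask via np.where(mask > 0.1, 1, 0)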
87 | 88 | return {'J': img_J, 'I': img_I, 'watermark': w, 'mask':mask, 'alpha':alpha, 'img_path':self.imageJ_path%img_id} 89 | 90 | def check_sample_types(self, sample): 91 | assert sample['J'].dtype == 'uint8' 92 | assert sample['I'].dtype == 'uint8' 93 | assert sample['watermark'].dtype == 'uint8' 94 | 95 | def augment_sample(self, sample): 96 | if self.augment_transform is None: 97 | return sample 98 | #print(self.transform.additional_targets.keys()) 99 | additional_targets = {target_name: sample[target_name] 100 | for target_name in self.augment_transform.additional_targets.keys()} 101 | 102 | valid_augmentation = False 103 | while not valid_augmentation: 104 | aug_output = self.augment_transform(image=sample['I'], **additional_targets) 105 | valid_augmentation = self.check_augmented_sample(sample, aug_output) 106 | 107 | for target_name, transformed_target in aug_output.items(): 108 | #print(target_name,transformed_target.shape) 109 | sample[target_name] = transformed_target 110 | 111 | return sample 112 | 113 | def check_augmented_sample(self, sample, aug_output): 114 | if self.keep_background_prob < 0.0 or random.random() < self.keep_background_prob: 115 | return True 116 | return aug_output['mask'].sum() > 100 -------------------------------------------------------------------------------- /pytorch_ssim/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | from math import exp 7 | 8 | def gaussian(window_size, sigma): 9 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 10 | return gauss/gauss.sum() 11 | 12 | def create_window(window_size, channel): 13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 15 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 16 | return window 17 | 18 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 19 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 20 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 21 | 22 | mu1_sq = mu1.pow(2) 23 | mu2_sq = mu2.pow(2) 24 | mu1_mu2 = mu1*mu2 25 | 26 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 27 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 28 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 29 | 30 | C1 = 0.01**2 31 | C2 = 0.03**2 32 | 33 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 34 | 35 | if size_average: 36 | return ssim_map.mean() 37 | else: 38 | # return ssim_map.mean(1).mean(1).mean(1) 39 | return ssim_map 40 | 41 | class SSIM(torch.nn.Module): 42 | def __init__(self, window_size = 11, size_average = True): 43 | super(SSIM, self).__init__() 44 | self.window_size = window_size 45 | self.size_average = size_average 46 | self.channel = 1 47 | self.window = create_window(window_size, self.channel) 48 | 49 | def forward(self, img1, img2): 50 | (_, channel, _, _) = img1.size() 51 | 52 | if channel == self.channel and self.window.data.type() == img1.data.type(): 53 | window = self.window 54 | else: 55 | window = 
create_window(self.window_size, channel) 56 | 57 | if img1.is_cuda: 58 | window = window.cuda(img1.get_device()) 59 | window = window.type_as(img1) 60 | 61 | self.window = window 62 | self.channel = channel 63 | 64 | 65 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 66 | 67 | def _logssim(img1, img2, window, window_size, channel, size_average = True): 68 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 69 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 70 | 71 | mu1_sq = mu1.pow(2) 72 | mu2_sq = mu2.pow(2) 73 | mu1_mu2 = mu1*mu2 74 | 75 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 76 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 77 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 78 | 79 | C1 = 0.01**2 80 | C2 = 0.03**2 81 | 82 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 83 | ssim_map = (ssim_map - torch.min(ssim_map))/(torch.max(ssim_map)-torch.min(ssim_map)) 84 | ssim_map = -torch.log(ssim_map + 1e-8) 85 | 86 | if size_average: 87 | return ssim_map.mean() 88 | else: 89 | return ssim_map.mean(1).mean(1).mean(1) 90 | 91 | class LOGSSIM(torch.nn.Module): 92 | def __init__(self, window_size = 11, size_average = True): 93 | super(LOGSSIM, self).__init__() 94 | self.window_size = window_size 95 | self.size_average = size_average 96 | self.channel = 1 97 | self.window = create_window(window_size, self.channel) 98 | 99 | def forward(self, img1, img2): 100 | (_, channel, _, _) = img1.size() 101 | 102 | if channel == self.channel and self.window.data.type() == img1.data.type(): 103 | window = self.window 104 | else: 105 | window = create_window(self.window_size, channel) 106 | 107 | if img1.is_cuda: 108 | window = window.cuda(img1.get_device()) 109 | window = window.type_as(img1) 110 | 111 | self.window = window 112 | self.channel = channel 113 | 114 | 115 | return _logssim(img1, img2, window, self.window_size, channel, self.size_average) 116 | 117 | 118 | def ssim(img1, img2, window_size = 11, size_average = True): 119 | (_, channel, _, _) = img1.size() 120 | window = create_window(window_size, channel) 121 | 122 | if img1.is_cuda: 123 | window = window.cuda(img1.get_device()) 124 | window = window.type_as(img1) 125 | 126 | return _ssim(img1, img2, window, window_size, channel, size_average) 127 | -------------------------------------------------------------------------------- /datasets/clwd_dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import cv2 4 | import os.path as osp 5 | import os 6 | import sys 7 | import torch 8 | from torchvision import datasets, transforms 9 | from .base_dataset import get_transform 10 | import random 11 | 12 | class CLWDDataset(torch.utils.data.Dataset): 13 | def __init__(self, is_train, args): 14 | 15 | args.is_train = is_train == 'train' 16 | if args.is_train == True: 17 | self.root = args.dataset_dir + '/train/' 18 | # self.keep_background_prob = 0.01 19 | self.keep_background_prob = -1 20 | elif args.is_train == False: 21 | self.root = args.dataset_dir + '/test/' #'/test/' 22 | self.keep_background_prob = -1 23 | args.preprocess = 'resize' 24 | args.no_flip = True 25 | 26 | self.args = args 27 | # Augmentataion? 
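# transform_norm below only converts HWC uint8 arrays to CHW float tensors
# in [0, 1]; the mean/std normalization is left commented out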
28 | self.transform_norm=transforms.Compose([ 29 | transforms.ToTensor()]) 30 | # transforms.Normalize( 31 | # # (0.485, 0.456, 0.406), 32 | # # (0.229, 0.224, 0.225) 33 | # (0.5,0.5,0.5), 34 | # (0.5,0.5,0.5) 35 | # )]) 36 | self.augment_transform = get_transform(args, 37 | additional_targets={'J':'image', 'I':'image', 'watermark':'image', 'mask':'mask', 'alpha':'mask' }) #, 38 | self.transform_tensor = transforms.ToTensor() 39 | 40 | self.imageJ_path=osp.join(self.root,'Watermarked_image','%s.jpg') 41 | self.imageI_path=osp.join(self.root,'Watermark_free_image','%s.jpg') 42 | self.mask_path=osp.join(self.root,'Mask','%s.png') 43 | # self.balance_path=osp.join(self.root,'Loss_balance','%s.png') 44 | self.alpha_path=osp.join(self.root,'Alpha','%s.png') 45 | self.W_path=osp.join(self.root,'Watermark','%s.png') 46 | 47 | self.ids = list() 48 | for file in os.listdir(self.root+'/Watermarked_image'): 49 | self.ids.append(file.strip('.jpg')) 50 | cv2.setNumThreads(0) 51 | cv2.ocl.setUseOpenCL(False) 52 | 53 | 54 | 55 | def __len__(self): 56 | return len(self.ids) 57 | 58 | def get_sample(self, index): 59 | img_id = self.ids[index] 60 | # img_id = self.corrupt_list[index % len(self.corrupt_list)].split('.')[0] 61 | img_J = cv2.imread(self.imageJ_path%img_id) 62 | img_J = cv2.cvtColor(img_J, cv2.COLOR_BGR2RGB) 63 | 64 | img_I = cv2.imread(self.imageI_path%img_id) 65 | img_I = cv2.cvtColor(img_I, cv2.COLOR_BGR2RGB) 66 | 67 | w = cv2.imread(self.W_path%img_id) 68 | if w is None: print(self.W_path%img_id) 69 | w = cv2.cvtColor(w, cv2.COLOR_BGR2RGB) 70 | 71 | mask = cv2.imread(self.mask_path%img_id) 72 | alpha = cv2.imread(self.alpha_path%img_id) 73 | 74 | mask = mask[:, :, 0].astype(np.float32) / 255. 75 | alpha = alpha[:, :, 0].astype(np.float32) / 255. 76 | 77 | return {'J': img_J, 'I': img_I, 'watermark': w, 'mask':mask, 'alpha':alpha, 'img_path':self.imageJ_path%img_id} 78 | 79 | 80 | def __getitem__(self, index): 81 | sample = self.get_sample(index) 82 | self.check_sample_types(sample) 83 | sample = self.augment_sample(sample) 84 | 85 | J = self.transform_norm(sample['J']) 86 | I = self.transform_norm(sample['I']) 87 | w = self.transform_norm(sample['watermark']) 88 | 89 | mask = sample['mask'][np.newaxis, ...].astype(np.float32) 90 | mask = np.where(mask > 0.1, 1, 0).astype(np.uint8) 91 | alpha = sample['alpha'][np.newaxis, ...].astype(np.float32) 92 | 93 | data = { 94 | 'image': J, 95 | 'target': I, 96 | 'wm': w, 97 | 'mask': mask, 98 | 'alpha':alpha, 99 | 'img_path':sample['img_path'] 100 | } 101 | return data 102 | 103 | def check_sample_types(self, sample): 104 | assert sample['J'].dtype == 'uint8' 105 | assert sample['I'].dtype == 'uint8' 106 | assert sample['watermark'].dtype == 'uint8' 107 | 108 | def augment_sample(self, sample): 109 | if self.augment_transform is None: 110 | return sample 111 | #print(self.transform.additional_targets.keys()) 112 | additional_targets = {target_name: sample[target_name] 113 | for target_name in self.augment_transform.additional_targets.keys()} 114 | 115 | valid_augmentation = False 116 | while not valid_augmentation: 117 | aug_output = self.augment_transform(image=sample['I'], **additional_targets) 118 | valid_augmentation = self.check_augmented_sample(sample, aug_output) 119 | 120 | for target_name, transformed_target in aug_output.items(): 121 | #print(target_name,transformed_target.shape) 122 | sample[target_name] = transformed_target 123 | 124 | return sample 125 | 126 | def check_augmented_sample(self, sample, aug_output): 127 | if 
self.keep_background_prob < 0.0 or random.random() < self.keep_background_prob: 128 | return True 129 | return aug_output['mask'].sum() > 100 130 | 131 | 132 | 133 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from sklearn.metrics import average_precision_score 4 | 5 | 6 | class AverageMeter(object): 7 | """Computes and stores the average and current value""" 8 | def __init__(self): 9 | self.reset() 10 | 11 | def reset(self): 12 | self.val = 0 13 | self.avg = 0 14 | self.sum = 0 15 | self.count = 0 16 | 17 | def update(self, val, n=1): 18 | self.val = val 19 | self.sum += val * n 20 | self.count += n 21 | self.avg = self.sum / self.count 22 | 23 | 24 | def normPRED(d, eps=1e-2): 25 | ma = torch.max(d) 26 | mi = torch.min(d) 27 | 28 | if ma-mi threshold, torch.ones_like(pred), torch.zeros_like(pred)).to(pred.device) 69 | intersection = (pred * gt).sum(dim=[1,2,3]) 70 | union = pred.sum(dim=[1,2,3]) + gt.sum(dim=[1,2,3]) - intersection 71 | return (intersection / (union+eps)).mean().item() 72 | 73 | def MAE(pred, gt): 74 | if isinstance(pred, torch.Tensor): 75 | return torch.mean(torch.abs(pred - gt)) 76 | elif isinstance(pred, np.ndarray): 77 | return np.mean(np.abs(pred-gt)) 78 | 79 | def FScore(pred, gt, beta2=1.0, threshold=0.5, eps=1e-6, reduce_dims=[1,2,3]): 80 | if isinstance(pred, torch.Tensor): 81 | if threshold == -1: threshold = pred.mean().item() * 2 82 | ones = torch.ones_like(pred).to(pred.device) 83 | zeros = torch.zeros_like(pred).to(pred.device) 84 | pred_ = torch.where(pred > threshold, ones, zeros) 85 | gt = torch.where(gt>threshold, ones, zeros) 86 | total_num = pred.nelement() 87 | 88 | TP = (pred_ * gt).sum(dim=reduce_dims) 89 | NumPrecision = pred_.sum(dim=reduce_dims) 90 | NumRecall = gt.sum(dim=reduce_dims) 91 | 92 | precision = TP / (NumPrecision+eps) 93 | recall = TP / (NumRecall+eps) 94 | F_beta = (1+beta2)*(precision * recall) / (beta2*precision + recall + eps) 95 | F_beta = F_beta.mean() 96 | 97 | elif isinstance(pred, np.ndarray): 98 | if threshold == -1: threshold = pred.mean()* 2 99 | pred_ = np.where(pred > threshold, 1.0, 0.0) 100 | gt = np.where(gt > threshold, 1.0, 0.0) 101 | total_num = np.prod(pred_.shape) 102 | 103 | TP = (pred_ * gt).sum() 104 | NumPrecision = pred_.sum() 105 | NumRecall = gt.sum() 106 | 107 | precision = TP / (NumPrecision+eps) 108 | recall = TP / (NumRecall+eps) 109 | F_beta = (1+beta2)*(precision * recall) / (beta2*precision + recall + eps) 110 | 111 | return F_beta 112 | 113 | class Fmeasure: 114 | def __init__(self, n_imgs, beta2=0.3, thresholds=[t/255 for t in range(255,-1, -1)]): 115 | self.n_imgs = n_imgs 116 | self.idx = 0 117 | self.beta2 = beta2 118 | self.thresholds = thresholds 119 | self.reset() 120 | 121 | def reset(self): 122 | if isinstance(self.thresholds, int): 123 | self.thresholds_fm = np.zeros((self.n_imgs, 1), dtype=np.float) 124 | elif isinstance(self.thresholds, list): 125 | self.thresholds_fm = np.zeros((self.n_imgs, len(self.thresholds)), type=np.float) 126 | self.adp_fm = np.zeros((self.n_imgs,), dtype=np.float) # adaptive threshold 127 | self.fixed_fm = np.zeros((self.n_imgs,), dtype=np.float) # fixed threshold: 0.5, beta2 = 1.0 128 | 129 | def update(self, pred, gt): 130 | if isinstance(self.thresholds, int): 131 | self.thresholds_fm[self.idx] = FScore(pred, gt, beta2=self.beta2, threshold=self.threshold) 132 | elif 
isinstance(self.thresholds, list): 133 | for i, t in enumerate(self.thresholds): 134 | self.thresholds_fm[self.idx, i] = FScore(pred, gt, beta2=self.beta2, threshold=t) 135 | # adaptive thresold 136 | self.adp_fm[self.idx] = FScore(pred, gt, beta2=self.beta2, threshold=-1) 137 | self.fixed_fm[self.idx] = FScore(pred, gt, beta2=1.0, threshold=0.5) 138 | self.idx+=1 139 | 140 | def val(self, eps=1e-6): 141 | column_Fm = self.thresholds.sum(axis=0) / (self.idx+eps) 142 | mean_Fm = column_Fm.mean() 143 | max_Fm = column_Fm.max() 144 | adp_Fm = (self.adp_fm.sum(axis=0) / (self.idx+eps)) 145 | fixed_Fm = (self.fixed_fm.sum(axis=0) / (self.idx+eps)) 146 | return {'meanFm':mean_Fm, 'maxFm':max_Fm, 'adpFm':adp_Fm, 'fixedFm':fixed_Fm} 147 | -------------------------------------------------------------------------------- /src/utils/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import numpy as np 5 | import scipy.misc 6 | import matplotlib.pyplot as plt 7 | import torch 8 | import torchvision 9 | 10 | from .misc import * 11 | from .imutils import * 12 | 13 | 14 | def color_normalize(x, mean, std): 15 | if x.size(0) == 1: 16 | x = x.repeat(3, x.size(1), x.size(2)) 17 | 18 | for t, m, s in zip(x, mean, std): 19 | t.sub_(m) 20 | return x 21 | 22 | 23 | def flip_back(flip_output, dataset='mpii'): 24 | """ 25 | flip output map 26 | """ 27 | if dataset == 'mpii': 28 | matchedParts = ( 29 | [0,5], [1,4], [2,3], 30 | [10,15], [11,14], [12,13] 31 | ) 32 | else: 33 | print('Not supported dataset: ' + dataset) 34 | 35 | # flip output horizontally 36 | flip_output = fliplr(flip_output.numpy()) 37 | 38 | # Change left-right parts 39 | for pair in matchedParts: 40 | tmp = np.copy(flip_output[:, pair[0], :, :]) 41 | flip_output[:, pair[0], :, :] = flip_output[:, pair[1], :, :] 42 | flip_output[:, pair[1], :, :] = tmp 43 | 44 | return torch.from_numpy(flip_output).float() 45 | 46 | 47 | def shufflelr(x, width, dataset='mpii'): 48 | """ 49 | flip coords 50 | """ 51 | if dataset == 'mpii': 52 | matchedParts = ( 53 | [0,5], [1,4], [2,3], 54 | [10,15], [11,14], [12,13] 55 | ) 56 | else: 57 | print('Not supported dataset: ' + dataset) 58 | 59 | # Flip horizontal 60 | x[:, 0] = width - x[:, 0] 61 | 62 | # Change left-right parts 63 | for pair in matchedParts: 64 | tmp = x[pair[0], :].clone() 65 | x[pair[0], :] = x[pair[1], :] 66 | x[pair[1], :] = tmp 67 | 68 | return x 69 | 70 | 71 | def fliplr(x): 72 | if x.ndim == 3: 73 | x = np.transpose(np.fliplr(np.transpose(x, (0, 2, 1))), (0, 2, 1)) 74 | elif x.ndim == 4: 75 | for i in range(x.shape[0]): 76 | x[i] = np.transpose(np.fliplr(np.transpose(x[i], (0, 2, 1))), (0, 2, 1)) 77 | return x.astype(float) 78 | 79 | 80 | def get_transform(center, scale, res, rot=0): 81 | """ 82 | General image processing functions 83 | """ 84 | # Generate transformation matrix 85 | h = 200 * scale 86 | t = np.zeros((3, 3)) 87 | t[0, 0] = float(res[1]) / h 88 | t[1, 1] = float(res[0]) / h 89 | t[0, 2] = res[1] * (-float(center[0]) / h + .5) 90 | t[1, 2] = res[0] * (-float(center[1]) / h + .5) 91 | t[2, 2] = 1 92 | if not rot == 0: 93 | rot = -rot # To match direction of rotation from cropping 94 | rot_mat = np.zeros((3,3)) 95 | rot_rad = rot * np.pi / 180 96 | sn,cs = np.sin(rot_rad), np.cos(rot_rad) 97 | rot_mat[0,:2] = [cs, -sn] 98 | rot_mat[1,:2] = [sn, cs] 99 | rot_mat[2,2] = 1 100 | # Need to rotate around center 101 | t_mat = np.eye(3) 102 | t_mat[0,2] = -res[1]/2 103 | 
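# t_mat translates the output center to the origin; the full transform below
# is t = t_inv @ rot_mat @ t_mat @ t (translate, rotate about the center,
# translate back)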
t_mat[1,2] = -res[0]/2 104 | t_inv = t_mat.copy() 105 | t_inv[:2,2] *= -1 106 | t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t))) 107 | return t 108 | 109 | 110 | def transform(pt, center, scale, res, invert=0, rot=0): 111 | # Transform pixel location to different reference 112 | t = get_transform(center, scale, res, rot=rot) 113 | if invert: 114 | t = np.linalg.inv(t) 115 | new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T 116 | new_pt = np.dot(t, new_pt) 117 | return new_pt[:2].astype(int) + 1 118 | 119 | 120 | def transform_preds(coords, center, scale, res): 121 | # size = coords.size() 122 | # coords = coords.view(-1, coords.size(-1)) 123 | # print(coords.size()) 124 | for p in range(coords.size(0)): 125 | coords[p, 0:2] = to_torch(transform(coords[p, 0:2], center, scale, res, 1, 0)) 126 | return coords 127 | 128 | 129 | def crop(img, center, scale, res, rot=0): 130 | img = im_to_numpy(img) 131 | 132 | # Upper left point 133 | ul = np.array(transform([0, 0], center, scale, res, invert=1)) 134 | # Bottom right point 135 | br = np.array(transform(res, center, scale, res, invert=1)) 136 | 137 | # Padding so that when rotated proper amount of context is included 138 | pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) 139 | if not rot == 0: 140 | ul -= pad 141 | br += pad 142 | 143 | new_shape = [br[1] - ul[1], br[0] - ul[0]] 144 | if len(img.shape) > 2: 145 | new_shape += [img.shape[2]] 146 | new_img = np.zeros(new_shape) 147 | 148 | # Range to fill new array 149 | new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0] 150 | new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1] 151 | # Range to sample from original image 152 | old_x = max(0, ul[0]), min(len(img[0]), br[0]) 153 | old_y = max(0, ul[1]), min(len(img), br[1]) 154 | new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], old_x[0]:old_x[1]] 155 | 156 | if not rot == 0: 157 | # Remove padding 158 | new_img = scipy.misc.imrotate(new_img, rot) 159 | new_img = new_img[pad:-pad, pad:-pad] 160 | 161 | new_img = im_to_torch(scipy.misc.imresize(new_img, res)) 162 | return new_img 163 | 164 | 165 | def get_right(img,gray=False): 166 | img = im_to_numpy(img) #H*W*C 167 | 168 | new_img = img[:,0:256,:] 169 | 170 | 171 | new_img = im_to_torch(new_img) 172 | if gray == True: 173 | new_img = new_img[1,:,:]; 174 | 175 | return new_img 176 | 177 | class NormalizeInverse(torchvision.transforms.Normalize): 178 | """ 179 | Undoes the normalization and returns the reconstructed images in the input domain. 180 | """ 181 | 182 | def __init__(self, mean, std): 183 | mean = torch.as_tensor(mean) 184 | std = torch.as_tensor(std) 185 | std_inv = 1 / (std + 1e-7) 186 | mean_inv = -mean * std_inv 187 | super().__init__(mean=mean_inv, std=std_inv) 188 | 189 | def __call__(self, tensor): 190 | return super().__call__(tensor.clone()) 191 | -------------------------------------------------------------------------------- /datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 2 | 3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 
4 | """ 5 | import random 6 | import numpy as np 7 | import torch.utils.data as data 8 | #from PIL import Image 9 | import cv2 10 | #import torchvision.transforms as transforms 11 | from abc import ABC, abstractmethod 12 | from albumentations import HorizontalFlip, RandomResizedCrop, Compose, DualTransform 13 | import albumentations.augmentations.transforms as transforms 14 | 15 | class BaseDataset(data.Dataset, ABC): 16 | """This class is an abstract base class (ABC) for datasets. 17 | 18 | To create a subclass, you need to implement the following four functions: 19 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 20 | -- <__len__>: return the size of dataset. 21 | -- <__getitem__>: get a data point. 22 | -- : (optionally) add dataset-specific options and set default options. 23 | """ 24 | 25 | def __init__(self, opt): 26 | """Initialize the class; save the options in the class 27 | 28 | Parameters: 29 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 30 | """ 31 | self.opt = opt 32 | self.root = opt.dataset_root #mia 33 | 34 | @staticmethod 35 | def modify_commandline_options(parser, is_train): 36 | """Add new dataset-specific options, and rewrite default values for existing options. 37 | 38 | Parameters: 39 | parser -- original option parser 40 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 41 | 42 | Returns: 43 | the modified parser. 44 | """ 45 | return parser 46 | 47 | @abstractmethod 48 | def __len__(self): 49 | """Return the total number of images in the dataset.""" 50 | return 0 51 | 52 | @abstractmethod 53 | def __getitem__(self, index): 54 | """Return a data point and its metadata information. 55 | 56 | Parameters: 57 | index - - a random integer for data indexing 58 | 59 | Returns: 60 | a dictionary of data with their names. It ususally contains the data itself and its metadata information. 
61 | """ 62 | pass 63 | 64 | class HCompose(Compose): 65 | def __init__(self, transforms, *args, additional_targets=None, no_nearest_for_masks=True, **kwargs): 66 | if additional_targets is None: 67 | additional_targets = { 68 | 'real': 'image', 69 | # 'mask': 'mask' 70 | } 71 | self.additional_targets = additional_targets 72 | super().__init__(transforms, *args, additional_targets=additional_targets, **kwargs) 73 | # if no_nearest_for_masks: 74 | # for t in transforms: 75 | # if isinstance(t, DualTransform): 76 | # t._additional_targets['mask'] = 'image' 77 | # t._additional_targets['edge'] = 'image' 78 | 79 | 80 | def get_params(opt, size): 81 | w, h = size 82 | new_h = h 83 | new_w = w 84 | if opt.preprocess == 'resize_and_crop': 85 | new_h = new_w = opt.load_size 86 | elif opt.preprocess == 'scale_width_and_crop': 87 | new_w = opt.load_size 88 | new_h = opt.load_size * h // w 89 | 90 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 91 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 92 | 93 | flip = random.random() > 0.5 94 | 95 | return {'crop_pos': (x, y), 'flip': flip} 96 | 97 | 98 | def get_transform(opt, params=None, grayscale=False, convert=True, additional_targets=None): 99 | transform_list = [] 100 | if grayscale: 101 | transform_list.append(transforms.ToGray()) 102 | if opt.preprocess == 'resize_and_crop': 103 | if params is None: 104 | transform_list.append(RandomResizedCrop(opt.crop_size, opt.crop_size, scale=(0.9, 1.0))) # 0.5,1.0 105 | elif opt.preprocess == 'resize': 106 | transform_list.append(transforms.Resize(opt.input_size, opt.input_size)) 107 | elif opt.preprocess == 'none': 108 | return HCompose(transform_list) 109 | 110 | if not opt.no_flip: 111 | if params is None: 112 | # print("flip") 113 | transform_list.append(HorizontalFlip()) 114 | 115 | return HCompose(transform_list, additional_targets=additional_targets) 116 | 117 | def __make_power_2(img, base): 118 | ow, oh = img.size 119 | h = int(round(oh / base) * base) 120 | w = int(round(ow / base) * base) 121 | if (h == oh) and (w == ow): 122 | return img 123 | 124 | __print_size_warning(ow, oh, w, h) 125 | return cv2.resize(img, (w, h), interpolation = cv2.INTER_LINEAR) 126 | 127 | 128 | ''' 129 | def __make_power_2(img, base, method=Image.BICUBIC): 130 | ow, oh = img.size 131 | h = int(round(oh / base) * base) 132 | w = int(round(ow / base) * base) 133 | if (h == oh) and (w == ow): 134 | return img 135 | 136 | __print_size_warning(ow, oh, w, h) 137 | return img.resize((w, h), method) 138 | 139 | 140 | def __scale_width(img, target_width, method=Image.BICUBIC): 141 | ow, oh = img.size 142 | if (ow == target_width): 143 | return img 144 | w = target_width 145 | h = int(target_width * oh / ow) 146 | return img.resize((w, h), method) 147 | 148 | 149 | def __crop(img, pos, size): 150 | ow, oh = img.size 151 | x1, y1 = pos 152 | tw = th = size 153 | if (ow > tw or oh > th): 154 | return img.crop((x1, y1, x1 + tw, y1 + th)) 155 | return img 156 | 157 | 158 | def __flip(img, flip): 159 | if flip: 160 | return img.transpose(Image.FLIP_LEFT_RIGHT) 161 | return img 162 | ''' 163 | 164 | def __print_size_warning(ow, oh, w, h): 165 | """Print warning information about image size(only print once)""" 166 | if not hasattr(__print_size_warning, 'has_printed'): 167 | print("The image size needs to be a multiple of 4. " 168 | "The loaded image size was (%d, %d), so it was adjusted to " 169 | "(%d, %d). 
This adjustment will be done to all images " 170 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 171 | __print_size_warning.has_printed = True 172 | -------------------------------------------------------------------------------- /src/networks/discriminator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import functools 3 | import math 4 | import torch 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | from torch import nn 8 | from torch import Tensor 9 | from torch.nn import Parameter 10 | from src.utils.model_init import * 11 | from torch.optim.optimizer import Optimizer, required 12 | 13 | 14 | __all__ = ['patchgan','sngan','maskedsngan'] 15 | 16 | 17 | class SNCoXvWithActivation(torch.nn.Module): 18 | """ 19 | SN convolution for spetral normalization conv 20 | """ 21 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, activation=torch.nn.LeakyReLU(0.2, inplace=True)): 22 | super(SNCoXvWithActivation, self).__init__() 23 | self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 24 | self.conv2d = torch.nn.utils.spectral_norm(self.conv2d) 25 | self.activation = activation 26 | for m in self.modules(): 27 | if isinstance(m, nn.Conv2d): 28 | nn.init.kaiming_normal_(m.weight) 29 | def forward(self, input): 30 | x = self.conv2d(input) 31 | if self.activation is not None: 32 | return self.activation(x) 33 | else: 34 | return x 35 | 36 | def l2normalize(v, eps=1e-12): 37 | return v / (v.norm() + eps) 38 | 39 | 40 | class SpectralNorm(nn.Module): 41 | def __init__(self, module, name='weight', power_iterations=1): 42 | super(SpectralNorm, self).__init__() 43 | self.module = module 44 | self.name = name 45 | self.power_iterations = power_iterations 46 | if not self._made_params(): 47 | self._make_params() 48 | 49 | def _update_u_v(self): 50 | u = getattr(self.module, self.name + "_u") 51 | v = getattr(self.module, self.name + "_v") 52 | w = getattr(self.module, self.name + "_bar") 53 | 54 | height = w.data.shape[0] 55 | for _ in range(self.power_iterations): 56 | v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data)) 57 | u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data)) 58 | 59 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data)) 60 | sigma = u.dot(w.view(height, -1).mv(v)) 61 | setattr(self.module, self.name, w / sigma.expand_as(w)) 62 | 63 | def _made_params(self): 64 | try: 65 | u = getattr(self.module, self.name + "_u") 66 | v = getattr(self.module, self.name + "_v") 67 | w = getattr(self.module, self.name + "_bar") 68 | return True 69 | except AttributeError: 70 | return False 71 | 72 | 73 | def _make_params(self): 74 | w = getattr(self.module, self.name) 75 | 76 | height = w.data.shape[0] 77 | width = w.view(height, -1).data.shape[1] 78 | 79 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 80 | v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) 81 | u.data = l2normalize(u.data) 82 | v.data = l2normalize(v.data) 83 | w_bar = Parameter(w.data) 84 | 85 | del self.module._parameters[self.name] 86 | 87 | self.module.register_parameter(self.name + "_u", u) 88 | self.module.register_parameter(self.name + "_v", v) 89 | self.module.register_parameter(self.name + "_bar", w_bar) 90 | 91 | 92 | def forward(self, *args): 93 | self._update_u_v() 94 | return self.module.forward(*args) 95 | 96 | 97 | 
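# --- Editorial sketch, not part of the original file: a standalone numerical check
# of the power iteration used by SpectralNorm._update_u_v above. The estimated
# sigma is the largest singular value of the flattened weight, so dividing the
# weight by it makes the layer approximately 1-Lipschitz. Uses only torch and the
# l2normalize helper defined above.

def _power_iteration_sigma(w, n_iters=50):
    # flatten (out_ch, in_ch, kh, kw) -> (out_ch, rest), as _update_u_v does
    w2d = w.view(w.shape[0], -1)
    u = l2normalize(torch.randn(w2d.shape[0]))
    v = l2normalize(torch.randn(w2d.shape[1]))
    for _ in range(n_iters):
        v = l2normalize(torch.mv(w2d.t(), u))
        u = l2normalize(torch.mv(w2d, v))
    return u.dot(torch.mv(w2d, v))   # sigma estimate: u^T W v

# Usage (should agree with the top singular value from torch.svd):
#   w = torch.randn(16, 3, 4, 4)
#   sigma = _power_iteration_sigma(w)
#   _, s, _ = torch.svd(w.view(16, -1))
#   print(float(sigma), float(s[0]))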
def get_pad(in_, ksize, stride, atrous=1): 98 | out_ = np.ceil(float(in_)/stride) 99 | return int(((out_ - 1) * stride + atrous*(ksize-1) + 1 - in_)/2) 100 | 101 | class SNDiscriminator(nn.Module): 102 | def __init__(self,channel=6): 103 | super(SNDiscriminator, self).__init__() 104 | cnum = 32 105 | self.discriminator_net = nn.Sequential( 106 | SNCoXvWithActivation(channel, 2*cnum, 4, 2, padding=get_pad(256, 5, 2)), 107 | SNCoXvWithActivation(2*cnum, 4*cnum, 4, 2, padding=get_pad(128, 5, 2)), 108 | SNCoXvWithActivation(4*cnum, 8*cnum, 4, 2, padding=get_pad(64, 5, 2)), 109 | SNCoXvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(32, 5, 2)), 110 | SNCoXvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(16, 5, 2)), # 8*8*256 111 | # SNConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(8, 5, 2)), # 4*4*256 112 | # SNConvWithActivation(8*cnum, 8*cnum, 4, 2, padding=get_pad(4, 5, 2)), # 2*2*256 113 | ) 114 | # self.linear = nn.Linear(2*2*256,1) 115 | 116 | def forward(self, img_A, img_B): 117 | # Concatenate image and condition image by channels to produce input 118 | img_input = torch.cat((img_A, img_B), 1) 119 | x = self.discriminator_net(img_input) 120 | # x = x.view((x.size(0),-1)) 121 | # x = self.linear(x) 122 | return x 123 | 124 | class Discriminator(nn.Module): 125 | def __init__(self, in_channels=3): 126 | super(Discriminator, self).__init__() 127 | 128 | def discriminator_block(in_filters, out_filters, normalization=True): 129 | """Returns downsampling layers of each discriminator block""" 130 | layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] 131 | if normalization: 132 | layers.append(nn.InstanceNorm2d(out_filters)) 133 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 134 | return layers 135 | 136 | self.model = nn.Sequential( 137 | *discriminator_block(in_channels*2, 64, normalization=False), 138 | *discriminator_block(64, 128), 139 | *discriminator_block(128, 256), 140 | *discriminator_block(256, 512), 141 | nn.ZeroPad2d((1, 0, 1, 0)), 142 | nn.Conv2d(512, 1, 4, padding=1, bias=False) 143 | ) 144 | 145 | def forward(self, img_A, img_B): 146 | # Concatenate image and condition image by channels to produce input 147 | img_input = torch.cat((img_A, img_B), 1) 148 | return self.model(img_input) 149 | 150 | 151 | def patchgan(): 152 | model = Discriminator() 153 | model.apply(weights_init_kaiming) 154 | return model 155 | 156 | def sngan(): 157 | model = SNDiscriminator() 158 | model.apply(weights_init_kaiming) 159 | return model 160 | 161 | def maskedsngan(): 162 | model = SNDiscriminator(channel=7) 163 | model.apply(weights_init_kaiming) 164 | return model -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import argparse 4 | import torch 5 | import os 6 | from math import log10 7 | import cv2 8 | import numpy as np 9 | 10 | torch.backends.cudnn.benchmark = True 11 | 12 | import datasets as datasets 13 | import src.models as models 14 | from options import Options 15 | import torch.nn.functional as F 16 | import pytorch_ssim 17 | from evaluation import compute_IoU, FScore, AverageMeter, compute_RMSE, normPRED 18 | from skimage.measure import compare_ssim as ssim 19 | import time 20 | 21 | 22 | def is_dic(x): 23 | return type(x) == type([]) 24 | 25 | 26 | 27 | def tensor2np(x, isMask=False): 28 | if isMask: 29 | if x.shape[1] == 1: 30 | x = x.repeat(1,3,1,1) 31 | x 
= ((x.cpu().detach()))*255 32 | else: 33 | x = x.cpu().detach() 34 | mean = 0 35 | std = 1 36 | x = (x * std + mean)*255 37 | 38 | return x.numpy().transpose(0,2,3,1).astype(np.uint8) 39 | 40 | def save_output(inputs, preds, save_dir, img_fn, extra_infos=None, verbose=False, alpha=0.5): 41 | outs = [] 42 | image, bg_gt,mask_gt = inputs['I'], inputs['bg'], inputs['mask'] 43 | image = cv2.cvtColor(tensor2np(image)[0], cv2.COLOR_RGB2BGR) 44 | # fg_gt = cv2.cvtColor(tensor2np(fg_gt)[0], cv2.COLOR_RGB2BGR) 45 | bg_gt = cv2.cvtColor(tensor2np(bg_gt)[0], cv2.COLOR_RGB2BGR) 46 | mask_gt = tensor2np(mask_gt, isMask=True)[0] 47 | 48 | bg_pred,mask_preds = preds['bg'], preds['mask'] 49 | # fg_pred = cv2.cvtColor(tensor2np(fg_pred)[0], cv2.COLOR_RGB2BGR) 50 | bg_pred = cv2.cvtColor(tensor2np(bg_pred)[0], cv2.COLOR_RGB2BGR) 51 | mask_preds = [tensor2np(m, isMask=True)[0] for m in mask_preds] 52 | main_mask = mask_preds[-2] 53 | mask_pred = mask_preds[0] 54 | outs = [image, bg_gt, bg_pred, mask_gt, mask_pred] #, main_mask] 55 | outimg = np.concatenate(outs, axis=1) 56 | 57 | if verbose==True: 58 | # print("show") 59 | cv2.imshow("out",outimg) 60 | cv2.waitKey(0) 61 | else: 62 | psnr = extra_infos['psnr'] 63 | rmsew = extra_infos['rmsew'] 64 | f1 = extra_infos['f1'] 65 | 66 | img_fn = os.path.split(img_fn)[-1] 67 | out_fn = os.path.join(save_dir, "{}_psnr_{:.2f}_rmsew_{:.2f}_f1_{:.4f}{}".format(os.path.splitext(img_fn)[0],psnr,rmsew, f1, os.path.splitext(img_fn)[1])) 68 | cv2.imwrite(out_fn, outimg) 69 | 70 | 71 | 72 | 73 | 74 | def main(args): 75 | args.dataset = args.dataset.lower() 76 | if args.dataset == 'clwd': 77 | dataset_func = datasets.CLWDDataset 78 | elif args.dataset == 'lvw': 79 | dataset_func = datasets.LVWDataset 80 | 81 | val_loader = torch.utils.data.DataLoader(dataset_func('test',args),batch_size=args.test_batch, shuffle=False, 82 | num_workers=args.workers, pin_memory=True) 83 | data_loaders = (None,val_loader) 84 | 85 | Machine = models.__dict__[args.models](datasets=data_loaders, args=args) 86 | 87 | 88 | model = Machine 89 | model.model.eval() 90 | print("==> testing VM model ") 91 | rmses = AverageMeter() 92 | rmsews = AverageMeter() 93 | ssimesx = AverageMeter() 94 | psnresx = AverageMeter() 95 | maskIoU = AverageMeter() 96 | maskF1 = AverageMeter() 97 | prime_maskIoU = AverageMeter() 98 | prime_maskF1 = AverageMeter() 99 | processTime = AverageMeter() 100 | 101 | prediction_dir = os.path.join(args.checkpoint,'rst') 102 | if not os.path.exists(prediction_dir): os.makedirs(prediction_dir) 103 | 104 | save_flag = False 105 | with torch.no_grad(): 106 | for i, batches in enumerate(model.val_loader): 107 | 108 | inputs = batches['image'].to(model.device) 109 | target = batches['target'].to(model.device) 110 | mask =batches['mask'].to(model.device) 111 | wm = batches['wm'].float().to(model.device) 112 | img_path = batches['img_path'] 113 | 114 | # select the outputs by the giving arch 115 | start_time = time.time() 116 | outputs = model.model(model.norm(inputs)) 117 | process_time = time.time() - start_time 118 | processTime.update((process_time*1000), inputs.size(0)) 119 | 120 | imoutput,immask_all,imwatermark = outputs 121 | imoutput = imoutput[0] if is_dic(imoutput) else imoutput 122 | 123 | immask = immask_all[0] 124 | 125 | imfinal =imoutput*immask + model.norm(inputs)*(1-immask) 126 | psnrx = 10 * log10(1 / F.mse_loss(imfinal,target).item()) 127 | final_np = (imfinal.detach().cpu().numpy()[0].transpose(1,2,0)*255).astype(np.uint8) 128 | target_np = 
(target.detach().cpu().numpy()[0].transpose(1,2,0)*255).astype(np.uint8) 129 | # ssimx = ssim(final_np, target_np, multichannel=True) 130 | ssimx = pytorch_ssim.ssim(imfinal, target) 131 | 132 | 133 | 134 | rmsex = compute_RMSE(imfinal, target, mask, is_w=False) 135 | rmsewx = compute_RMSE(imfinal, target, mask, is_w=True) 136 | rmses.update(rmsex, inputs.size(0)) 137 | rmsews.update(rmsewx, inputs.size(0)) 138 | psnresx.update(psnrx, inputs.size(0)) 139 | ssimesx.update(ssimx, inputs.size(0)) 140 | 141 | 142 | main_mask = immask_all[1::2] 143 | comp_mask = immask_all[2::2] 144 | out_mask = main_mask[-1] 145 | comp_mask = comp_mask[-1] 146 | 147 | 148 | prime_mask_pred = torch.where(out_mask > 0.5, torch.ones_like(out_mask), torch.zeros_like(out_mask)).to(out_mask.device) 149 | mask_pred = torch.where(comp_mask > 0.5, torch.ones_like(out_mask), torch.zeros_like(out_mask)).to(out_mask.device) 150 | 151 | iou = compute_IoU(prime_mask_pred, mask) 152 | prime_maskIoU.update(iou) 153 | f1 = FScore(prime_mask_pred, mask).item() 154 | prime_maskF1.update(f1, inputs.size(0)) 155 | 156 | iou = compute_IoU(mask_pred, mask) 157 | maskIoU.update(iou) 158 | f1 = FScore(mask_pred, mask).item() 159 | maskF1.update(f1, inputs.size(0)) 160 | 161 | if save_flag: 162 | save_output( 163 | inputs={'I':inputs, 'bg':target, 'mask':mask}, 164 | preds={'bg':imfinal, 'mask':immask_all}, 165 | save_dir=prediction_dir, 166 | img_fn=img_path[0], 167 | extra_infos={"psnr":psnrx, "rmsew":rmsewx, "f1":f1}, 168 | verbose=False 169 | ) 170 | if i % 100 == 0: 171 | print("Batch[%d/%d]| PSNR:%.4f | SSIM:%.4f | RMSE:%.4f | RMSEw:%.4f | primeIoU:%.4f, primeF1:%.4f | maskIoU:%.4f | maskF1:%.4f | time:%.2f" 172 | %(i,len(model.val_loader),psnresx.avg,ssimesx.avg, rmses.avg, rmsews.avg, prime_maskIoU.avg, prime_maskF1.avg, maskIoU.avg, maskF1.avg, processTime.avg)) 173 | print("Total:\nPSNR:%.4f | SSIM:%.4f | RMSE:%.4f | RMSEw:%.4f | primeIoU:%.4f, primeF1:%.4f | maskIoU:%.4f | maskF1:%.4f | time:%.2f" 174 | %(psnresx.avg,ssimesx.avg, rmses.avg, rmsews.avg, prime_maskIoU.avg, prime_maskF1.avg, maskIoU.avg, maskF1.avg, processTime.avg)) 175 | print("DONE.\n") 176 | 177 | 178 | if __name__ == '__main__': 179 | parser = Options().init(argparse.ArgumentParser(description='Watermark Removal')) 180 | main(parser.parse_args()) 181 | 182 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | 2 | import src.networks as networks 3 | 4 | model_names = sorted(name for name in networks.__dict__ 5 | if name.islower() and not name.startswith("__") 6 | and callable(networks.__dict__[name])) 7 | 8 | class Options(): 9 | """docstring for Options""" 10 | def __init__(self): 11 | pass 12 | 13 | def init(self, parser): 14 | # Model structure 15 | parser.add_argument('--nets', '-n', metavar='NET', default='dhn', 16 | choices=model_names, 17 | help='model architecture: ' + 18 | ' | '.join(model_names) + 19 | ' (default: dhn)') 20 | 21 | parser.add_argument('--models', '-m', metavar='MACHINE', default='basic') 22 | # Training strategy 23 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', 24 | help='number of data loading workers (default: 2)') 25 | parser.add_argument('--epochs', default=30, type=int, metavar='N', 26 | help='number of total epochs to run') 27 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 28 | help='manual epoch number (useful on restarts)') 29 |
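# --- Editorial worked example, not part of the original file, for the step
# schedule declared just below: with --lr 1e-3, --schedule 5 10 and --gamma 0.1,
# the learning rate is 1e-3 for epochs [0, 5), 1e-4 for [5, 10), and 1e-5 from
# epoch 10 on. BasicModel.resume() replays exactly this product when restarting:
#
#   lr = 1e-3
#   for milestone in [5, 10]:
#       if milestone <= current_epoch:   # current_epoch is hypothetical here
#           lr *= 0.1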
parser.add_argument('--train-batch', default=64, type=int, metavar='N', 30 | help='train batchsize') 31 | parser.add_argument('--test-batch', default=6, type=int, metavar='N', 32 | help='test batchsize') 33 | parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, metavar='LR', help='initial learning rate') 34 | parser.add_argument('--dlr', '--dlearning-rate', default=1e-3, type=float, help='initial discriminator learning rate') 35 | parser.add_argument('--beta1', default=0.9, type=float, help='Adam beta1') 36 | parser.add_argument('--beta2', default=0.999, type=float, help='Adam beta2') 37 | parser.add_argument('--momentum', default=0, type=float, metavar='M', 38 | help='momentum') 39 | parser.add_argument('--weight-decay', '--wd', default=0, type=float, 40 | metavar='W', help='weight decay (default: 0)') 41 | parser.add_argument('--schedule', type=int, nargs='+', default=[5, 10], 42 | help='Decrease learning rate at these epochs.') 43 | parser.add_argument('--gamma', type=float, default=0.1, 44 | help='LR is multiplied by gamma on schedule.') 45 | # Data processing 46 | parser.add_argument('-f', '--flip', dest='flip', action='store_true', 47 | help='flip the input during validation') 48 | 49 | parser.add_argument('--lambda_l1', type=float, default=4, help='the weight of the L1 loss') 50 | parser.add_argument('--lambda_primary', type=float, default=0.01, help='the weight of the primary mask prediction loss') 51 | parser.add_argument('--lambda_style', default=0, type=float, 52 | help='the weight of the style (perception) loss') 53 | parser.add_argument('--lambda_content', default=0, type=float, 54 | help='the weight of the content (perception) loss') 55 | 56 | parser.add_argument('--lambda_iou', default=0, type=float, help='the weight of the mask IoU loss') 57 | parser.add_argument('--lambda_mask', default=1, type=float, help='the weight of the mask loss') 58 | 59 | parser.add_argument('--sltype', default='vggx', type=str) 60 | 61 | parser.add_argument('--alpha', type=float, default=0.5, 62 | help='Groundtruth Gaussian sigma.') 63 | parser.add_argument('--sigma-decay', type=float, default=0, 64 | help='Sigma decay rate for each epoch.') 65 | # Miscs 66 | parser.add_argument('--dataset_dir', default='/PATH_TO_DATA_FOLDER/', type=str, metavar='PATH') 67 | parser.add_argument('--test_dir', default='/PATH_TO_DATA_FOLDER/', type=str, metavar='PATH') 68 | 69 | parser.add_argument('--data', default='', type=str, metavar='PATH', 70 | help='path to the dataset (default: none)') 71 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', 72 | help='path to save checkpoint (default: checkpoint)') 73 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 74 | help='path to latest checkpoint (default: none)') 75 | parser.add_argument('--finetune', default='', type=str, metavar='PATH', 76 | help='path to checkpoint to finetune from (default: none)') 77 | 78 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 79 | help='evaluate model on validation set') 80 | 81 | parser.add_argument('-da', '--data-augumentation', default=False, type=bool, 82 | help='use data augmentation') 83 | parser.add_argument('-d', '--debug', dest='debug', action='store_true', 84 | help='show intermediate results') 85 | parser.add_argument('--input-size', default=256, type=int, metavar='N', 86 | help='input image size') 87 | parser.add_argument('--freq', default=-1, type=int, metavar='N', 88 | help='evaluation frequency in iterations (-1 disables)') 89 | parser.add_argument('--normalized-input', default=False, type=bool, 90 | help='whether the input is already normalized') 91 |
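# --- Editorial caveat, not part of the original file, on the `type=bool` flags
# above and below (e.g. --normalized-input, --res, --gan-norm): argparse calls
# bool() on the raw string, and bool() of any non-empty string is True, so
# `--normalized-input False` still yields True. A small demonstration plus the
# usual fix:
#
#   import argparse
#   p = argparse.ArgumentParser()
#   p.add_argument('--flag', default=False, type=bool)
#   print(p.parse_args(['--flag', 'False']).flag)   # True (!)
#
#   p2 = argparse.ArgumentParser()
#   p2.add_argument('--flag', action='store_true')  # robust boolean switch
#   print(p2.parse_args([]).flag)                   # False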
parser.add_argument('--res', default=False, type=bool, help='residual learning for s2am') 92 | parser.add_argument('--requires-grad', default=False, type=bool, 93 | help='requires_grad flag') 94 | 95 | parser.add_argument('--gpu', default=True, type=bool) 96 | parser.add_argument('--gpu_id', default='0', type=str) 97 | parser.add_argument('--preprocess', default='resize_and_crop', type=str) # resize_and_crop | resize | none 98 | parser.add_argument('--crop_size', default=256, type=int) 99 | parser.add_argument('--no_flip', action='store_true') 100 | parser.add_argument('--masked', default=False, type=bool) 101 | parser.add_argument('--gan-norm', default=False, type=bool, help='normalize images to [-1, 1]') 102 | parser.add_argument('--hl', default=False, type=bool, help='homogeneous learning') 103 | parser.add_argument('--loss-type', default='l2', type=str, help='reconstruction loss type') 104 | 105 | parser.add_argument('--dataset', default='clwd', type=str, help='dataset name (clwd | lvw)') 106 | parser.add_argument('--name', default='v2', type=str, help='experiment name') 107 | 108 | parser.add_argument('--sim_metric', default='cos', type=str, help='similarity metric') 109 | parser.add_argument('--k_center', default=1, type=int, help='number of centers (k)') 110 | parser.add_argument('--project_mode', default='simple', type=str, help='projection mode') 111 | parser.add_argument('--mask_mode', default='cat', type=str, help='mask decoder mode') # vanilla, cat, ca, psp 112 | parser.add_argument('--bg_mode', default='res_mask', type=str, help='background decoder mode') # vanilla, res_mask, res_feat, proposed 113 | parser.add_argument('--use_refine', action='store_true', help='use the refinement stage') 114 | parser.add_argument('--k_refine', default=3, type=int, help='k for the refinement stage') 115 | parser.add_argument('--k_skip_stage', default=3, type=int, help='k for the skip stages') 116 | 117 | return parser -------------------------------------------------------------------------------- /src/utils/imutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import scipy.misc 7 | from matplotlib import cm, pyplot as plt 8 | from .misc import * 9 | 10 | def im_to_numpy(img): 11 | img = to_numpy(img) 12 | img = np.transpose(img, (1, 2, 0)) # H*W*C 13 | return img 14 | 15 | def im_to_torch(img): 16 | img = np.transpose(img, (2, 0, 1)) # C*H*W 17 | img = to_torch(img).float() 18 | if img.max() > 1: 19 | img /= 255 20 | return img 21 | 22 | def load_image(img_path): 23 | # H x W x C => C x H x W 24 | return im_to_torch(scipy.misc.imread(img_path, mode='RGB')) 25 | 26 | def imread_all(img_path): 27 | return scipy.misc.imread(img_path, mode='RGB') 28 | 29 | def load_image_gray(img_path): 30 | # H x W x C => C x H x W 31 | x = scipy.misc.imread(img_path, mode='L') 32 | x = x[:,:,np.newaxis] 33 | return im_to_torch(x) 34 | 35 | def resize(img, owidth, oheight): 36 | img = im_to_numpy(img) 37 | 38 | if img.shape[2] == 1: 39 | img = scipy.misc.imresize(img.squeeze(),(oheight,owidth)) 40 | img = img[:,:,np.newaxis] 41 | else: 42 | img = scipy.misc.imresize( 43 | img, 44 | (oheight, owidth) 45 | ) 46 | img = im_to_torch(img) 47 | # print('%f %f' % (img.min(), img.max())) 48 | return img 49 | 50 | # ============================================================================= 51 | # Helpful functions generating groundtruth labelmap 52 | # ============================================================================= 53 | 54 | def gaussian(shape=(7,7),sigma=1): 55 | """ 56 | 2D gaussian mask - should give the same result as
MATLAB's 57 | fspecial('gaussian',[shape],[sigma]) 58 | """ 59 | m,n = [(ss-1.)/2. for ss in shape] 60 | y,x = np.ogrid[-m:m+1,-n:n+1] 61 | h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) ) 62 | h[ h < np.finfo(h.dtype).eps*h.max() ] = 0 63 | return to_torch(h).float() 64 | 65 | def draw_labelmap(img, pt, sigma, type='Gaussian'): 66 | # Draw a 2D gaussian 67 | # Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py 68 | img = to_numpy(img) 69 | 70 | # Check that any part of the gaussian is in-bounds 71 | ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)] 72 | br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)] 73 | if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or 74 | br[0] < 0 or br[1] < 0): 75 | # If not, just return the image as is 76 | return to_torch(img) 77 | 78 | # Generate gaussian 79 | size = 6 * sigma + 1 80 | x = np.arange(0, size, 1, float) 81 | y = x[:, np.newaxis] 82 | x0 = y0 = size // 2 83 | # The gaussian is not normalized, we want the center value to equal 1 84 | if type == 'Gaussian': 85 | g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) 86 | elif type == 'Cauchy': 87 | g = sigma / (((x - x0) ** 2 + (y - y0) ** 2 + sigma ** 2) ** 1.5) 88 | 89 | 90 | # Usable gaussian range 91 | g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0] 92 | g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1] 93 | # Image range 94 | img_x = max(0, ul[0]), min(br[0], img.shape[1]) 95 | img_y = max(0, ul[1]), min(br[1], img.shape[0]) 96 | 97 | img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]] 98 | return to_torch(img) 99 | 100 | # ============================================================================= 101 | # Helpful display functions 102 | # ============================================================================= 103 | 104 | def gauss(x, a, b, c, d=0): 105 | return a * np.exp(-(x - b)**2 / (2 * c**2)) + d 106 | 107 | def color_heatmap(x): 108 | x = to_numpy(x) 109 | color = np.zeros((x.shape[0],x.shape[1],3)) 110 | color[:,:,0] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3) 111 | color[:,:,1] = gauss(x, 1, .5, .3) 112 | color[:,:,2] = gauss(x, 1, .2, .3) 113 | color[color > 1] = 1 114 | color = (color * 255).astype(np.uint8) 115 | return color 116 | 117 | def imshow(img): 118 | npimg = im_to_numpy(img*255).astype(np.uint8) 119 | plt.imshow(npimg) 120 | plt.axis('off') 121 | 122 | def show_joints(img, pts): 123 | imshow(img) 124 | 125 | for i in range(pts.size(0)): 126 | if pts[i, 2] > 0: 127 | plt.plot(pts[i, 0], pts[i, 1], 'yo') 128 | plt.axis('off') 129 | 130 | def show_sample(inputs, target): 131 | num_sample = inputs.size(0) 132 | num_joints = target.size(1) 133 | height = target.size(2) 134 | width = target.size(3) 135 | 136 | for n in range(num_sample): 137 | inp = resize(inputs[n], width, height) 138 | out = inp 139 | for p in range(num_joints): 140 | tgt = inp*0.5 + color_heatmap(target[n,p,:,:])*0.5 141 | out = torch.cat((out, tgt), 2) 142 | 143 | imshow(out) 144 | plt.show() 145 | 146 | def sample_with_heatmap(inp, out, num_rows=2, parts_to_show=None): 147 | inp = to_numpy(inp * 255) 148 | out = to_numpy(out) 149 | 150 | img = np.zeros((inp.shape[1], inp.shape[2], inp.shape[0])) 151 | for i in range(3): 152 | img[:, :, i] = inp[i, :, :] 153 | 154 | if parts_to_show is None: 155 | parts_to_show = np.arange(out.shape[0]) 156 | 157 | # Generate a single image to display input/output pair 158 | num_cols = int(np.ceil(float(len(parts_to_show)) / num_rows)) 159 | size = img.shape[0] // 
num_rows 160 | 161 | full_img = np.zeros((img.shape[0], size * (num_cols + num_rows), 3), np.uint8) 162 | full_img[:img.shape[0], :img.shape[1]] = img 163 | 164 | inp_small = scipy.misc.imresize(img, [size, size]) 165 | 166 | # Set up heatmap display for each part 167 | for i, part in enumerate(parts_to_show): 168 | part_idx = part 169 | out_resized = scipy.misc.imresize(out[part_idx], [size, size]) 170 | out_resized = out_resized.astype(float)/255 171 | out_img = inp_small.copy() * .3 172 | color_hm = color_heatmap(out_resized) 173 | out_img += color_hm * .7 174 | 175 | col_offset = (i % num_cols + num_rows) * size 176 | row_offset = (i // num_cols) * size 177 | full_img[row_offset:row_offset + size, col_offset:col_offset + size] = out_img 178 | 179 | return full_img 180 | 181 | def batch_with_heatmap(inputs, outputs, mean=torch.Tensor([0.5, 0.5, 0.5]), num_rows=2, parts_to_show=None): 182 | batch_img = [] 183 | for n in range(min(inputs.size(0), 4)): 184 | inp = inputs[n] + mean.view(3, 1, 1).expand_as(inputs[n]) 185 | batch_img.append( 186 | sample_with_heatmap(inp.clamp(0, 1), outputs[n], num_rows=num_rows, parts_to_show=parts_to_show) 187 | ) 188 | return np.concatenate(batch_img) 189 | 190 | 191 | def normalize_batch(batch): 192 | # normalize using imagenet mean and std 193 | mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1) 194 | std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1) 195 | batch = batch/255.0 196 | return (batch - mean) / std 197 | 198 | def show_image_tensor(tensor): 199 | re = [] 200 | for i in range(tensor.size(0)): 201 | inp = tensor[i].data.cpu() #w,h,c 202 | inp = inp.numpy().transpose((1, 2, 0)) 203 | mean = np.array([0.485, 0.456, 0.406]) 204 | std = np.array([0.229, 0.224, 0.225]) 205 | inp = std * inp + mean 206 | inp = np.clip(inp, 0, 1).transpose((2,0,1)) 207 | re.append(torch.from_numpy(inp).unsqueeze(0)) 208 | return torch.cat(re,0) 209 | 210 | 211 | def get_jet(): 212 | colormap_int = np.zeros((256, 3), np.uint8) 213 | 214 | for i in range(0, 256, 1): 215 | colormap_int[i, 0] = np.int_(np.round(cm.jet(i)[0] * 255.0)) 216 | colormap_int[i, 1] = np.int_(np.round(cm.jet(i)[1] * 255.0)) 217 | colormap_int[i, 2] = np.int_(np.round(cm.jet(i)[2] * 255.0)) 218 | 219 | return colormap_int 220 | 221 | def clamp(num, min_value, max_value): 222 | return max(min(num, max_value), min_value) 223 | 224 | def gray2color(gray_array, color_map): 225 | 226 | rows, cols = gray_array.shape 227 | color_array = np.zeros((rows, cols, 3), np.uint8) 228 | 229 | for i in range(0, rows): 230 | for j in range(0, cols): 231 | # log(256,2) = 8 , log(1,2) = 0 * 8 232 | color_array[i, j] = color_map[clamp(int(abs(gray_array[i, j])*10),0,255)] 233 | 234 | return color_array 235 | 236 | class objectview(object): 237 | def __init__(self, *args, **kwargs): 238 | d = dict(*args, **kwargs) 239 | self.__dict__ = d -------------------------------------------------------------------------------- /src/models/BasicModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.backends.cudnn as cudnn 4 | from progress.bar import Bar 5 | import json 6 | import numpy as np 7 | from tensorboardX import SummaryWriter 8 | 9 | import torch.optim 10 | import sys,shutil,os 11 | import time 12 | import src.networks as nets 13 | from math import log10 14 | import skimage.io 15 | from skimage.measure import compare_psnr,compare_ssim 16 | 17 | from evaluation import AverageMeter 18 | import pytorch_ssim 
as pytorch_ssim 19 | from src.utils.osutils import mkdir_p, isfile, isdir, join 20 | from src.utils.parallel import DataParallelModel, DataParallelCriterion 21 | from src.utils.losses import VGGLoss 22 | 23 | 24 | 25 | 26 | class BasicModel(object): 27 | def __init__(self, datasets =(None,None), models = None, args = None, **kwargs): 28 | super(BasicModel, self).__init__() 29 | 30 | self.args = args 31 | 32 | # create model 33 | print("==> creating model ") 34 | self.model = nets.__dict__[self.args.nets](args=args) 35 | print("==> creating model [Finish]") 36 | 37 | self.train_loader, self.val_loader = datasets 38 | self.loss = torch.nn.MSELoss() 39 | 40 | self.title = args.name 41 | self.args.checkpoint = os.path.join(args.checkpoint, self.title) 42 | self.device = torch.device('cuda') 43 | # create checkpoint dir 44 | if not isdir(self.args.checkpoint): 45 | mkdir_p(self.args.checkpoint) 46 | 47 | self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), 48 | lr=args.lr, 49 | betas=(args.beta1,args.beta2), 50 | weight_decay=args.weight_decay) 51 | 52 | if not self.args.evaluate: 53 | self.writer = SummaryWriter(self.args.checkpoint+'/'+'ckpt') 54 | 55 | self.best_acc = 0 56 | self.is_best = False 57 | self.current_epoch = 0 58 | self.metric = -100000 59 | self.hl = 6 if self.args.hl else 1 60 | self.count_gpu = len(range(torch.cuda.device_count())) 61 | 62 | if self.args.lambda_style > 0: 63 | # init perception loss 64 | self.vggloss = VGGLoss(self.args.sltype).to(self.device) 65 | 66 | if self.count_gpu > 1 : # multiple 67 | # self.model = DataParallelModel(self.model, device_ids=range(torch.cuda.device_count())) 68 | # self.loss = DataParallelCriterion(self.loss, device_ids=range(torch.cuda.device_count())) 69 | self.model.multi_gpu() 70 | 71 | self.model.to(self.device) 72 | self.loss.to(self.device) 73 | 74 | print('==> Total params: %.2fM' % (sum(p.numel() for p in self.model.parameters())/1000000.0)) 75 | print('==> Total devices: %d' % (torch.cuda.device_count())) 76 | print('==> Current Checkpoint: %s' % (self.args.checkpoint)) 77 | 78 | 79 | def train(self,epoch): 80 | batch_time = AverageMeter() 81 | data_time = AverageMeter() 82 | losses = AverageMeter() 83 | lossvgg = AverageMeter() 84 | 85 | # switch to train mode 86 | self.model.train() 87 | end = time.time() 88 | 89 | bar = Bar('Processing', max=len(self.train_loader)*self.hl) 90 | for _ in range(self.hl): 91 | for i, batches in enumerate(self.train_loader): 92 | # measure data loading time 93 | inputs = batches['image'] 94 | target = batches['target'].to(self.device) 95 | mask =batches['mask'].to(self.device) 96 | current_index = len(self.train_loader) * epoch + i 97 | 98 | if self.args.hl: 99 | feeded = torch.cat([inputs,mask],dim=1) 100 | else: 101 | feeded = inputs 102 | feeded = feeded.to(self.device) 103 | 104 | output = self.model(feeded) 105 | L2_loss = self.loss(output,target) 106 | 107 | if self.args.lambda_style > 0: 108 | vgg_loss = self.vggloss(output,target,mask) 109 | else: 110 | vgg_loss = 0 111 | 112 | total_loss = L2_loss + self.args.lambda_style * vgg_loss 113 | 114 | # compute gradient and do SGD step 115 | self.optimizer.zero_grad() 116 | total_loss.backward() 117 | self.optimizer.step() 118 | 119 | # measure accuracy and record loss 120 | losses.update(L2_loss.item(), inputs.size(0)) 121 | 122 | if self.args.lambda_style > 0 : 123 | lossvgg.update(vgg_loss.item(), inputs.size(0)) 124 | 125 | # measure elapsed time 126 | batch_time.update(time.time() - end) 127 | end 
= time.time() 128 | 129 | # plot progress 130 | suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss L2: {loss_label:.4f} | Loss VGG: {loss_vgg:.4f}'.format( 131 | batch=i + 1, 132 | size=len(self.train_loader), 133 | data=data_time.val, 134 | bt=batch_time.val, 135 | total=bar.elapsed_td, 136 | eta=bar.eta_td, 137 | loss_label=losses.avg, 138 | loss_vgg=lossvgg.avg 139 | ) 140 | 141 | if current_index % 1000 == 0: 142 | print(suffix) 143 | 144 | if self.args.freq > 0 and current_index % self.args.freq == 0: 145 | self.validate(current_index) 146 | self.flush() 147 | self.save_checkpoint() 148 | 149 | self.record('train/loss_L2', losses.avg, current_index) 150 | 151 | def validate(self, epoch): 152 | batch_time = AverageMeter() 153 | data_time = AverageMeter() 154 | losses = AverageMeter() 155 | ssimes = AverageMeter() 156 | psnres = AverageMeter() 157 | # switch to evaluate mode 158 | self.model.eval() 159 | 160 | end = time.time() 161 | with torch.no_grad(): 162 | for i, batches in enumerate(self.val_loader): 163 | 164 | inputs = batches['image'].to(self.device) 165 | target = batches['target'].to(self.device) 166 | mask =batches['mask'].to(self.device) 167 | 168 | if self.args.hl: 169 | feeded = torch.cat([inputs,torch.zeros((1,4,self.args.input_size,self.args.input_size)).to(self.device)],dim=1) 170 | else: 171 | feeded = inputs 172 | 173 | output = self.model(feeded) 174 | 175 | L2_loss = self.loss(output, target) 176 | 177 | psnr = 10 * log10(1 / L2_loss.item()) 178 | ssim = pytorch_ssim.ssim(output, target) 179 | 180 | losses.update(L2_loss.item(), inputs.size(0)) 181 | psnres.update(psnr, inputs.size(0)) 182 | ssimes.update(ssim.item(), inputs.size(0)) 183 | 184 | # measure elapsed time 185 | batch_time.update(time.time() - end) 186 | end = time.time() 187 | 188 | print("Epoches:%s,Losses:%.3f,PSNR:%.3f,SSIM:%.3f"%(epoch+1, losses.avg,psnres.avg,ssimes.avg)) 189 | self.record('val/loss_L2', losses.avg, epoch) 190 | self.record('val/PSNR', psnres.avg, epoch) 191 | self.record('val/SSIM', ssimes.avg, epoch) 192 | 193 | self.metric = psnres.avg 194 | 195 | def resume(self,resume_path): 196 | # if isfile(resume_path): 197 | if not os.path.exists(resume_path): 198 | resume_path = os.path.join(self.args.checkpoint, 'checkpoint.pth.tar') 199 | if not os.path.exists(resume_path): 200 | raise Exception("=> no checkpoint found at '{}'".format(resume_path)) 201 | 202 | print("=> loading checkpoint '{}'".format(resume_path)) 203 | current_checkpoint = torch.load(resume_path) 204 | if isinstance(current_checkpoint['state_dict'], torch.nn.DataParallel): 205 | current_checkpoint['state_dict'] = current_checkpoint['state_dict'].module 206 | 207 | if isinstance(current_checkpoint['optimizer'], torch.nn.DataParallel): 208 | current_checkpoint['optimizer'] = current_checkpoint['optimizer'].module 209 | 210 | if self.args.start_epoch == 0: 211 | self.args.start_epoch = current_checkpoint['epoch'] 212 | self.metric = current_checkpoint['best_acc'] 213 | items = list(current_checkpoint['state_dict'].keys()) 214 | 215 | ## restore the learning rate 216 | lr = self.args.lr 217 | for epoch in self.args.schedule: 218 | if epoch <= self.args.start_epoch: 219 | lr *= self.args.gamma 220 | optimizers = [getattr(self.model, attr) for attr in dir(self.model) if attr.startswith("optimizer") and getattr(self.model, attr) is not None] 221 | for optimizer in optimizers: 222 | for param_group in optimizer.param_groups: 223 | param_group['lr'] = lr 224 | 225 | # 
---------------- Load Model Weights -------------------------------------- 226 | self.model.load_state_dict(current_checkpoint['state_dict'], strict=True) 227 | print("=> loaded checkpoint '{}' (epoch {})" 228 | .format(resume_path, current_checkpoint['epoch'])) 229 | 230 | def save_checkpoint(self,filename='checkpoint.pth.tar', snapshot=None): 231 | is_best = self.best_acc < self.metric 232 | 233 | if is_best: 234 | self.best_acc = self.metric 235 | 236 | state = { 237 | 'epoch': self.current_epoch + 1, 238 | 'nets': self.args.nets, 239 | 'state_dict': self.model.state_dict(), 240 | 'best_acc': self.best_acc, 241 | 'optimizer' : self.optimizer.state_dict() if self.optimizer else None, 242 | } 243 | 244 | filepath = os.path.join(self.args.checkpoint, filename) 245 | torch.save(state, filepath) 246 | 247 | if snapshot and state['epoch'] % snapshot == 0: 248 | shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch']))) 249 | 250 | if is_best: 251 | self.best_acc = self.metric 252 | print('Saving Best Metric with PSNR:%s'%self.best_acc) 253 | if not os.path.exists(self.args.checkpoint): os.makedirs(self.args.checkpoint) 254 | shutil.copyfile(filepath, os.path.join(self.args.checkpoint, 'model_best.pth.tar')) 255 | 256 | def clean(self): 257 | self.writer.close() 258 | 259 | def record(self,k,v,epoch): 260 | self.writer.add_scalar(k, v, epoch) 261 | 262 | def flush(self): 263 | self.writer.flush() 264 | sys.stdout.flush() 265 | 266 | def norm(self,x): 267 | if self.args.gan_norm: 268 | return x*2.0 - 1.0 269 | else: 270 | return x 271 | 272 | def denorm(self,x): 273 | if self.args.gan_norm: 274 | return (x+1.0)/2.0 275 | else: 276 | return x 277 | 278 | -------------------------------------------------------------------------------- /src/utils/parallel.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang, Rutgers University, Email: zhang.hang@rutgers.edu 3 | ## Modified by Thomas Wolf, HuggingFace Inc., Email: thomas@huggingface.co 4 | ## Copyright (c) 2017-2018 5 | ## 6 | ## This source code is licensed under the MIT-style license found in the 7 | ## LICENSE file in the root directory of this source tree 8 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 9 | 10 | """Encoding Data Parallel""" 11 | import threading 12 | import functools 13 | import torch 14 | from torch.autograd import Variable, Function 15 | import torch.cuda.comm as comm 16 | from torch.nn.parallel import DistributedDataParallel 17 | from torch.nn.parallel.data_parallel import DataParallel 18 | from torch.nn.parallel.parallel_apply import get_a_var 19 | from torch.nn.parallel.scatter_gather import gather 20 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 21 | 22 | torch_ver = torch.__version__[:3] 23 | 24 | __all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion', 25 | 'patch_replication_callback'] 26 | 27 | def allreduce(*inputs): 28 | """Cross GPU all reduce autograd operation for calculating mean and 29 | variance in SyncBN.
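(Editorial worked example, not part of the original docstring: if GPU k holds n_k samples with local sums s_k = sum(x) and q_k = sum(x**2), all-reducing the triples (s_k, q_k, n_k) gives every replica S, Q and N, from which each computes the identical global mean = S/N and variance = Q/N - (S/N)**2, exactly the statistics SyncBN needs.)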
30 | """ 31 | return AllReduce.apply(*inputs) 32 | 33 | class AllReduce(Function): 34 | @staticmethod 35 | def forward(ctx, num_inputs, *inputs): 36 | ctx.num_inputs = num_inputs 37 | ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)] 38 | inputs = [inputs[i:i + num_inputs] 39 | for i in range(0, len(inputs), num_inputs)] 40 | # sort before reduce sum 41 | inputs = sorted(inputs, key=lambda i: i[0].get_device()) 42 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) 43 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus) 44 | return tuple([t for tensors in outputs for t in tensors]) 45 | 46 | @staticmethod 47 | def backward(ctx, *inputs): 48 | inputs = [i.data for i in inputs] 49 | inputs = [inputs[i:i + ctx.num_inputs] 50 | for i in range(0, len(inputs), ctx.num_inputs)] 51 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) 52 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus) 53 | return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors]) 54 | 55 | 56 | class Reduce(Function): 57 | @staticmethod 58 | def forward(ctx, *inputs): 59 | ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))] 60 | inputs = sorted(inputs, key=lambda i: i.get_device()) 61 | return comm.reduce_add(inputs) 62 | 63 | @staticmethod 64 | def backward(ctx, gradOutput): 65 | return Broadcast.apply(ctx.target_gpus, gradOutput) 66 | 67 | class DistributedDataParallelModel(DistributedDataParallel): 68 | """Implements data parallelism at the module level for the DistributedDataParallel module. 69 | This container parallelizes the application of the given module by 70 | splitting the input across the specified devices by chunking in the 71 | batch dimension. 72 | In the forward pass, the module is replicated on each device, 73 | and each replica handles a portion of the input. During the backwards pass, 74 | gradients from each replica are summed into the original module. 75 | Note that the outputs are not gathered, please use compatible 76 | :class:`encoding.parallel.DataParallelCriterion`. 77 | The batch size should be larger than the number of GPUs used. It should 78 | also be an integer multiple of the number of GPUs so that each chunk is 79 | the same size (so that each GPU processes the same number of samples). 80 | Args: 81 | module: module to be parallelized 82 | device_ids: CUDA devices (default: all devices) 83 | Reference: 84 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, 85 | Amit Agrawal. “Context Encoding for Semantic Segmentation. 86 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* 87 | Example:: 88 | >>> net = encoding.nn.DistributedDataParallelModel(model, device_ids=[0, 1, 2]) 89 | >>> y = net(x) 90 | """ 91 | def gather(self, outputs, output_device): 92 | return outputs 93 | 94 | class DataParallelModel(DataParallel): 95 | """Implements data parallelism at the module level. 96 | 97 | This container parallelizes the application of the given module by 98 | splitting the input across the specified devices by chunking in the 99 | batch dimension. 100 | In the forward pass, the module is replicated on each device, 101 | and each replica handles a portion of the input. During the backwards pass, 102 | gradients from each replica are summed into the original module. 103 | Note that the outputs are not gathered, please use compatible 104 | :class:`encoding.parallel.DataParallelCriterion`. 
105 | 106 | The batch size should be larger than the number of GPUs used. It should 107 | also be an integer multiple of the number of GPUs so that each chunk is 108 | the same size (so that each GPU processes the same number of samples). 109 | 110 | Args: 111 | module: module to be parallelized 112 | device_ids: CUDA devices (default: all devices) 113 | 114 | Reference: 115 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, 116 | Amit Agrawal. “Context Encoding for Semantic Segmentation. 117 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* 118 | 119 | Example:: 120 | 121 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2]) 122 | >>> y = net(x) 123 | """ 124 | def gather(self, outputs, output_device): 125 | return outputs 126 | 127 | def replicate(self, module, device_ids): 128 | modules = super(DataParallelModel, self).replicate(module, device_ids) 129 | execute_replication_callbacks(modules) 130 | return modules 131 | 132 | 133 | class DataParallelCriterion(DataParallel): 134 | """ 135 | Calculate loss in multiple-GPUs, which balance the memory usage. 136 | The targets are splitted across the specified devices by chunking in 137 | the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`. 138 | 139 | Reference: 140 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, 141 | Amit Agrawal. “Context Encoding for Semantic Segmentation. 142 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* 143 | 144 | Example:: 145 | 146 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2]) 147 | >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2]) 148 | >>> y = net(x) 149 | >>> loss = criterion(y, target) 150 | """ 151 | def forward(self, inputs, *targets, **kwargs): 152 | # input should be already scatterd 153 | # scattering the targets instead 154 | if not self.device_ids: 155 | return self.module(inputs, *targets, **kwargs) 156 | targets, kwargs = self.scatter(targets, kwargs, self.device_ids) 157 | if len(self.device_ids) == 1: 158 | return self.module(inputs, *targets[0], **kwargs[0]) 159 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 160 | outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs) 161 | #return Reduce.apply(*outputs) / len(outputs) 162 | #return self.gather(outputs, self.output_device).mean() 163 | return self.gather(outputs, self.output_device) 164 | 165 | 166 | def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None): 167 | assert len(modules) == len(inputs) 168 | assert len(targets) == len(inputs) 169 | if kwargs_tup: 170 | assert len(modules) == len(kwargs_tup) 171 | else: 172 | kwargs_tup = ({},) * len(modules) 173 | if devices is not None: 174 | assert len(modules) == len(devices) 175 | else: 176 | devices = [None] * len(modules) 177 | 178 | lock = threading.Lock() 179 | results = {} 180 | if torch_ver != "0.3": 181 | grad_enabled = torch.is_grad_enabled() 182 | 183 | def _worker(i, module, input, target, kwargs, device=None): 184 | if torch_ver != "0.3": 185 | torch.set_grad_enabled(grad_enabled) 186 | if device is None: 187 | device = get_a_var(input).get_device() 188 | try: 189 | with torch.cuda.device(device): 190 | # this also avoids accidental slicing of `input` if it is a Tensor 191 | if not isinstance(input, (list, tuple)): 192 | input = (input,) 193 | if not isinstance(target, 
(list, tuple)): 194 | target = (target,) 195 | output = module(*(input + target), **kwargs) 196 | with lock: 197 | results[i] = output 198 | except Exception as e: 199 | with lock: 200 | results[i] = e 201 | 202 | if len(modules) > 1: 203 | threads = [threading.Thread(target=_worker, 204 | args=(i, module, input, target, 205 | kwargs, device),) 206 | for i, (module, input, target, kwargs, device) in 207 | enumerate(zip(modules, inputs, targets, kwargs_tup, devices))] 208 | 209 | for thread in threads: 210 | thread.start() 211 | for thread in threads: 212 | thread.join() 213 | else: 214 | _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0]) 215 | 216 | outputs = [] 217 | for i in range(len(inputs)): 218 | output = results[i] 219 | if isinstance(output, Exception): 220 | raise output 221 | outputs.append(output) 222 | return outputs 223 | 224 | 225 | ########################################################################### 226 | # Adapted from Synchronized-BatchNorm-PyTorch. 227 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 228 | # 229 | class CallbackContext(object): 230 | pass 231 | 232 | 233 | def execute_replication_callbacks(modules): 234 | """ 235 | Execute a replication callback `__data_parallel_replicate__` on each module created 236 | by original replication. 237 | 238 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 239 | 240 | Note that, as all modules are isomorphic, we assign each sub-module with a context 241 | (shared among multiple copies of this module on different devices). 242 | Through this context, different copies can share some information. 243 | 244 | We guarantee that the callback on the master copy (the first copy) will be called ahead 245 | of calling the callback of any slave copies. 246 | """ 247 | master_copy = modules[0] 248 | nr_modules = len(list(master_copy.modules())) 249 | ctxs = [CallbackContext() for _ in range(nr_modules)] 250 | 251 | for i, module in enumerate(modules): 252 | for j, m in enumerate(module.modules()): 253 | if hasattr(m, '__data_parallel_replicate__'): 254 | m.__data_parallel_replicate__(ctxs[j], i) 255 | 256 | 257 | def patch_replication_callback(data_parallel): 258 | """ 259 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 260 | Useful when you have customized `DataParallel` implementation.
261 | 262 | Examples: 263 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 264 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 265 | > patch_replication_callback(sync_bn) 266 | # this is equivalent to 267 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 268 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 269 | """ 270 | 271 | assert isinstance(data_parallel, DataParallel) 272 | 273 | old_replicate = data_parallel.replicate 274 | 275 | @functools.wraps(old_replicate) 276 | def new_replicate(module, device_ids): 277 | modules = old_replicate(module, device_ids) 278 | execute_replication_callbacks(modules) 279 | return modules 280 | 281 | data_parallel.replicate = new_replicate -------------------------------------------------------------------------------- /src/models/SLBR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from progress.bar import Bar 4 | from tqdm import tqdm 5 | import pytorch_ssim 6 | import json 7 | import sys,time,os 8 | import torchvision 9 | from math import log10 10 | import numpy as np 11 | from .BasicModel import BasicModel 12 | from evaluation import AverageMeter, compute_IoU, FScore, compute_RMSE 13 | import torch.nn.functional as F 14 | from src.utils.parallel import DataParallelModel, DataParallelCriterion 15 | from src.utils.losses import VGGLoss, l1_relative,is_dic 16 | from src.utils.imutils import im_to_numpy 17 | import skimage.io 18 | from skimage.measure import compare_psnr,compare_ssim 19 | import torchvision 20 | import pytorch_iou 21 | import pytorch_ssim 22 | 23 | class Losses(nn.Module): 24 | def __init__(self, argx, device, norm_func, denorm_func): 25 | super(Losses, self).__init__() 26 | self.args = argx 27 | self.masked_l1_loss, self.mask_loss = l1_relative, nn.BCELoss() 28 | self.l1_loss = nn.L1Loss() 29 | 30 | if self.args.lambda_content > 0: 31 | self.vgg_loss = VGGLoss(self.args.sltype, style=self.args.lambda_style>0).to(device) 32 | 33 | if self.args.lambda_iou > 0: 34 | self.iou_loss = pytorch_iou.IOU(size_average=True) 35 | 36 | self.lambda_primary = self.args.lambda_primary 37 | self.gamma = 0.5 38 | self.norm = norm_func 39 | self.denorm = denorm_func 40 | 41 | def forward(self, synthesis, pred_ims, target, pred_ms, mask, threshold=0.5): 42 | pixel_loss, refine_loss, vgg_loss, mask_loss = [0]*4 43 | pred_ims = pred_ims if is_dic(pred_ims) else [pred_ims] 44 | 45 | # reconstruction loss 46 | pixel_loss += self.masked_l1_loss(pred_ims[-1], target, mask) # coarse stage 47 | if len(pred_ims) > 1: 48 | refine_loss = self.masked_l1_loss(pred_ims[0], target, mask) # refinement stage 49 | 50 | recov_imgs = [ self.denorm(pred_im*mask + (1-mask)*self.norm(target)) for pred_im in pred_ims ] 51 | pixel_loss += sum([self.l1_loss(im,target) for im in recov_imgs]) * 1.5 52 | 53 | 54 | if self.args.lambda_content > 0: 55 | vgg_loss = [self.vgg_loss(im,target,mask) for im in recov_imgs] 56 | vgg_loss = sum([vgg['content'] for vgg in vgg_loss]) * self.args.lambda_content + \ 57 | sum([vgg['style'] for vgg in vgg_loss]) * self.args.lambda_style 58 | 59 | # mask loss 60 | pred_ms = [F.interpolate(ms, size=mask.shape[2:], mode='bilinear') for ms in pred_ms] 61 | pred_ms = [pred_m.clamp(0,1) for pred_m in pred_ms] 62 | mask = mask.clamp(0,1) 63 | 64 | final_mask_loss = 0 65 | final_mask_loss += self.mask_loss(pred_ms[0], mask) 66 | 67 | primary_mask = pred_ms[1::2][::-1] 68 | self_calibrated_mask = pred_ms[2::2][::-1] 69 
| # primary prediction 70 | primary_loss = sum([self.mask_loss(pred_m, mask) * (self.gamma**i) for i,pred_m in enumerate(primary_mask)]) 71 | # self calibrated Branch 72 | self_calibrated_loss = sum([self.mask_loss(pred_m, mask) * (self.gamma**i) for i,pred_m in enumerate(self_calibrated_mask)]) 73 | if self.args.lambda_iou > 0: 74 | self_calibrated_loss += sum([self.iou_loss(pred_m, mask) * (self.gamma**i) for i,pred_m in enumerate(self_calibrated_mask)]) * self.args.lambda_iou 75 | 76 | mask_loss = final_mask_loss + self_calibrated_loss + self.lambda_primary * primary_loss 77 | return pixel_loss, refine_loss, vgg_loss, mask_loss 78 | 79 | 80 | 81 | 82 | class SLBR(BasicModel): 83 | def __init__(self,**kwargs): 84 | BasicModel.__init__(self,**kwargs) 85 | self.loss = Losses(self.args, self.device, self.norm, self.denorm) 86 | if isinstance(self.model, nn.DataParallel): 87 | self.model = self.model.module 88 | self.model.set_optimizers() 89 | if self.args.resume != '': 90 | self.resume(self.args.resume) 91 | 92 | def train(self,epoch): 93 | 94 | self.current_epoch = epoch 95 | 96 | batch_time = AverageMeter() 97 | data_time = AverageMeter() 98 | losses_meter = AverageMeter() 99 | loss_mask_meter = AverageMeter() 100 | loss_vgg_meter = AverageMeter() 101 | loss_refine_meter = AverageMeter() 102 | f1_meter = AverageMeter() 103 | # switch to train mode 104 | self.model.train() 105 | 106 | end = time.time() 107 | bar = Bar('Processing {} '.format(self.args.nets), max=len(self.train_loader)) 108 | for i, batches in enumerate(self.train_loader): 109 | current_index = len(self.train_loader) * epoch + i 110 | 111 | inputs = batches['image'].float().to(self.device) 112 | target = batches['target'].float().to(self.device) 113 | mask = batches['mask'].float().to(self.device) 114 | # wm = batches['wm'].float().to(self.device) 115 | # alpha_gt = batches['alpha'].float().to(self.device) 116 | img_path = batches['img_path'] 117 | 118 | outputs = self.model(self.norm(inputs)) 119 | self.model.zero_grad_all() 120 | coarse_loss, refine_loss, style_loss, mask_loss = self.loss( 121 | inputs,outputs[0],self.norm(target),outputs[1],mask) 122 | 123 | total_loss = self.args.lambda_l1*(coarse_loss+refine_loss) + self.args.lambda_mask * (mask_loss) + style_loss 124 | 125 | # compute gradient and do SGD step 126 | total_loss.backward() 127 | self.model.step_all() 128 | 129 | # measure accuracy and record loss 130 | losses_meter.update(coarse_loss.item(), inputs.size(0)) 131 | loss_mask_meter.update(mask_loss.item(), inputs.size(0)) 132 | if isinstance(refine_loss,int): 133 | loss_refine_meter.update(refine_loss, inputs.size(0)) 134 | else: 135 | loss_refine_meter.update(refine_loss.item(), inputs.size(0)) 136 | 137 | f1 = FScore(outputs[1][0], mask).item() 138 | f1_meter.update(f1, inputs.size(0)) 139 | if self.args.lambda_content > 0 and not isinstance(style_loss,int): 140 | loss_vgg_meter.update(style_loss.item(), inputs.size(0)) 141 | 142 | # measure elapsed timec 143 | batch_time.update(time.time() - end) 144 | end = time.time() 145 | 146 | # plot progress 147 | suffix = "({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | loss L1: {loss_label:.4f} | loss Refine: {loss_refine:.4f} | loss VGG: {loss_vgg:.4f} | loss Mask: {loss_mask:.4f} | mask F1: {mask_f1:.4f}".format( 148 | batch=i + 1, 149 | size=len(self.train_loader), 150 | data=data_time.val, 151 | bt=batch_time.val, 152 | total=bar.elapsed_td, 153 | eta=bar.eta_td, 154 | loss_label=losses_meter.avg, 155 | 
155 |                 loss_refine=loss_refine_meter.avg,
156 |                 loss_vgg=loss_vgg_meter.avg,
157 |                 loss_mask=loss_mask_meter.avg,
158 |                 mask_f1=f1_meter.avg,
159 |             )
160 |             if current_index % 100 == 0:
161 |                 print(suffix)
162 | 
163 |             if self.args.freq > 0 and current_index % self.args.freq == 0:
164 |                 self.validate(current_index)
165 |                 self.flush()
166 |                 self.save_checkpoint()
167 |             if i % 100 == 0:
168 |                 self.record('train/loss_L2', losses_meter.avg, current_index)
169 |                 self.record('train/loss_Refine', loss_refine_meter.avg, current_index)
170 |                 self.record('train/loss_VGG', loss_vgg_meter.avg, current_index)
171 |                 self.record('train/loss_Mask', loss_mask_meter.avg, current_index)
172 |                 self.record('train/mask_F1', f1_meter.avg, current_index)
173 | 
174 |                 mask_pred = outputs[1][0]
175 |                 bg_pred = self.denorm(outputs[0][0]*mask_pred + (1-mask_pred)*self.norm(inputs))
176 |                 show_size = 5 if inputs.shape[0] > 5 else inputs.shape[0]
177 |                 self.image_display = torch.cat([
178 |                     inputs[0:show_size].detach().cpu(),             # input image
179 |                     target[0:show_size].detach().cpu(),             # ground truth
180 |                     bg_pred[0:show_size].detach().cpu(),            # refine out
181 |                     mask[0:show_size].detach().cpu().repeat(1,3,1,1),
182 |                     outputs[1][0][0:show_size].detach().cpu().repeat(1,3,1,1),
183 |                     outputs[1][-2][0:show_size].detach().cpu().repeat(1,3,1,1)
184 |                 ],dim=0)
185 |                 image_dis = torchvision.utils.make_grid(self.image_display, nrow=show_size)
186 |                 self.writer.add_image('Image', image_dis, current_index)
187 |             del outputs
188 | 
189 | 
190 |     def validate(self, epoch):
191 | 
192 |         self.current_epoch = epoch
193 | 
194 |         batch_time = AverageMeter()
195 |         data_time = AverageMeter()
196 |         losses_meter = AverageMeter()
197 |         loss_mask_meter = AverageMeter()
198 |         psnr_meter = AverageMeter()
199 |         fpsnr_meter = AverageMeter()
200 |         ssim_meter = AverageMeter()
201 |         rmse_meter = AverageMeter()
202 |         rmsew_meter = AverageMeter()
203 | 
204 | 
205 |         coarse_psnr_meter = AverageMeter()
206 |         coarse_rmsew_meter = AverageMeter()
207 | 
208 |         iou_meter = AverageMeter()
209 |         f1_meter = AverageMeter()
210 |         # switch to evaluate mode
211 |         self.model.eval()
212 | 
213 |         end = time.time()
214 |         bar = Bar('Processing {} '.format(self.args.nets), max=len(self.val_loader))
215 |         with torch.no_grad():
216 |             for i, batches in enumerate(self.val_loader):
217 | 
218 |                 current_index = len(self.val_loader) * epoch + i
219 | 
220 |                 inputs = batches['image'].to(self.device)
221 |                 target = batches['target'].to(self.device)
222 |                 mask = batches['mask'].to(self.device)
223 |                 # alpha_gt = batches['alpha'].float().to(self.device)
224 | 
225 |                 outputs = self.model(self.norm(inputs))
226 |                 imoutput,immask,imwatermark = outputs
227 | 
228 |                 immask = immask[0]
229 |                 if len(imoutput) > 1:
230 |                     imcoarse = imoutput[1]
231 |                     imcoarse = imcoarse*immask + inputs*(1-immask)
232 |                 else: imcoarse = None
233 |                 imoutput = imoutput[0] if is_dic(imoutput) else imoutput
234 | 
235 |                 imfinal = self.denorm(imoutput*immask + self.norm(inputs)*(1-immask))
236 | 
237 |                 eps = 1e-6
238 |                 psnr = 10 * log10(1 / F.mse_loss(imfinal,target).item())
239 |                 fmse = F.mse_loss(imfinal*mask, target*mask, reduction='none').sum(dim=[1,2,3]) / (mask.sum(dim=[1,2,3])*3+eps)
240 |                 fpsnr = 10 * torch.log10(1 / fmse).mean().item()
241 |                 ssim = pytorch_ssim.ssim(imfinal,target)
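                # fPSNR ("focused" PSNR) above restricts the MSE to watermarked
                # pixels: the per-image squared error is summed inside the mask
                # and divided by the mask area times 3 color channels, so it
                # measures restoration quality inside the watermark region only.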
242 |                 if imcoarse is not None:
243 |                     psnr_coarse = 10 * log10(1 / F.mse_loss(imcoarse,target).item())
244 |                     rmsew_coarse = compute_RMSE(imcoarse, target, mask, is_w=True)
245 |                     coarse_psnr_meter.update(psnr_coarse, inputs.size(0))
246 |                     coarse_rmsew_meter.update(rmsew_coarse, inputs.size(0))
247 | 
248 |                 psnr_meter.update(psnr, inputs.size(0))
249 |                 fpsnr_meter.update(fpsnr, inputs.size(0))
250 |                 ssim_meter.update(ssim, inputs.size(0))
251 |                 rmse_meter.update(compute_RMSE(imfinal,target,mask),inputs.size(0))
252 |                 rmsew_meter.update(compute_RMSE(imfinal,target,mask,is_w=True), inputs.size(0))
253 | 
254 |                 iou = compute_IoU(immask, mask)
255 |                 iou_meter.update(iou, inputs.size(0))
256 |                 f1 = FScore(immask, mask).item()
257 |                 f1_meter.update(f1, inputs.size(0))
258 |                 # measure elapsed time
259 |                 batch_time.update(time.time() - end)
260 |                 end = time.time()
261 | 
262 |                 # plot progress
263 |                 if imcoarse is None:
264 |                     suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | PSNR: {psnr:.4f} | fPSNR: {fpsnr:.4f} | SSIM: {ssim:.4f} | RMSE: {rmse:.4f} | RMSEw: {rmsew:.4f} | IoU: {iou:.4f} | F1: {f1:.4f}'.format(
265 |                         batch=i + 1,
266 |                         size=len(self.val_loader),
267 |                         data=data_time.val,
268 |                         bt=batch_time.val,
269 |                         total=bar.elapsed_td,
270 |                         eta=bar.eta_td,
271 |                         psnr=psnr_meter.avg,
272 |                         fpsnr=fpsnr_meter.avg,
273 |                         ssim=ssim_meter.avg,
274 |                         rmse=rmse_meter.avg,
275 |                         rmsew=rmsew_meter.avg,
276 |                         iou=iou_meter.avg,
277 |                         f1=f1_meter.avg
278 |                     )
279 |                 else:
280 |                     suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | CPSNR: {cpsnr:.4f} | CRMSEw: {crmsew:.4f} | PSNR: {psnr:.4f} | fPSNR: {fpsnr:.4f} | RMSE: {rmse:.4f} | RMSEw: {rmsew:.4f} | SSIM: {ssim:.4f} | IoU: {iou:.4f} | F1: {f1:.4f}'.format(
281 |                         batch=i + 1,
282 |                         size=len(self.val_loader),
283 |                         data=data_time.val,
284 |                         bt=batch_time.val,
285 |                         total=bar.elapsed_td,
286 |                         eta=bar.eta_td,
287 |                         cpsnr=coarse_psnr_meter.avg,
288 |                         crmsew=coarse_rmsew_meter.avg,
289 |                         psnr=psnr_meter.avg,
290 |                         fpsnr=fpsnr_meter.avg,
291 |                         ssim=ssim_meter.avg,
292 |                         rmse=rmse_meter.avg,
293 |                         rmsew=rmsew_meter.avg,
294 |                         iou=iou_meter.avg,
295 |                         f1=f1_meter.avg
296 |                     )
297 |                 if i%100 == 0:
298 |                     print(suffix)
299 |                 # bar.next()
300 |         print("Total:")
301 |         print(suffix)
302 |         bar.finish()
303 | 
304 |         print("Iter:%s,losses:%s,PSNR:%.4f,SSIM:%.4f"%(epoch, losses_meter.avg,psnr_meter.avg,ssim_meter.avg))
305 |         self.record('val/loss_L2', losses_meter.avg, epoch)
306 |         self.record('val/loss_mask', loss_mask_meter.avg, epoch)
307 |         self.record('val/PSNR', psnr_meter.avg, epoch)
308 |         self.record('val/SSIM', ssim_meter.avg, epoch)
309 |         self.record('val/RMSEw', rmsew_meter.avg, epoch)
310 |         self.metric = psnr_meter.avg
311 | 
312 |         self.model.train()
313 | 
--------------------------------------------------------------------------------
/src/networks/resunet.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from src.networks.blocks import UpConv, DownConv, MBEBlock, SMRBlock, CFFBlock, ResDownNew, ResUpNew, ECABlock
6 | import scipy.stats as st
7 | import itertools
8 | import cv2
9 | 
10 | def weight_init(m):
11 |     if isinstance(m, nn.Conv2d):
12 |         nn.init.xavier_normal_(m.weight)
13 |         if m.bias is not None:
14 |             nn.init.constant_(m.bias, 0)
15 | 
16 | def reset_params(model):
17 |     for i, m in enumerate(model.modules()):
18 |         weight_init(m)
19 | 
20 | class CoarseEncoder(nn.Module):
21 |     def __init__(self, in_channels=3, depth=3, blocks=1, start_filters=32, residual=True, norm=nn.BatchNorm2d, act=F.relu):
22 |         super(CoarseEncoder, self).__init__()
23 |         self.down_convs = []
24 |         outs = None
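        # Channel plan for the loop below: level i maps ins -> start_filters * 2**i
        # channels and always pools, so the encoder doubles the channel count while
        # shrinking the spatial resolution at every level.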
25 |         if type(blocks) is tuple:
26 |             blocks = blocks[0]
27 |         for i in range(depth):
28 |             ins = in_channels if i == 0 else outs
29 |             outs = start_filters*(2**i)
30 |             # pooling = True if i < depth-1 else False
31 |             pooling = True
32 |             down_conv = DownConv(ins, outs, blocks, pooling=pooling, residual=residual, norm=norm, act=act)
33 |             self.down_convs.append(down_conv)
34 |         self.down_convs = nn.ModuleList(self.down_convs)
35 |         reset_params(self)
36 | 
37 |     def forward(self, x):
38 |         encoder_outs = []
39 |         for d_conv in self.down_convs:
40 |             x, before_pool = d_conv(x)
41 |             encoder_outs.append(before_pool)
42 |         return x, encoder_outs
43 | 
44 | class SharedBottleNeck(nn.Module):
45 |     def __init__(self, in_channels=512, depth=5, shared_depth=2, start_filters=32, blocks=1, residual=True,
46 |                  concat=True, norm=nn.BatchNorm2d, act=F.relu, dilations=[1,2,5]):
47 |         super(SharedBottleNeck, self).__init__()
48 |         self.down_convs = []
49 |         self.up_convs = []
50 |         self.down_im_atts = []
51 |         self.down_mask_atts = []
52 |         self.up_im_atts = []
53 |         self.up_mask_atts = []
54 | 
55 |         dilations = [1,2,5] # hard-coded here, overriding the constructor argument
56 |         start_depth = depth - shared_depth
57 |         max_filters = 512
58 |         for i in range(start_depth, depth): # e.g. depth=5, shared_depth=2 -> i in [3, 4]
59 |             ins = in_channels if i == start_depth else outs
60 |             outs = min(ins * 2, max_filters)
61 |             # Encoder convs
62 |             pooling = True if i < depth-1 else False
63 |             down_conv = DownConv(ins, outs, blocks, pooling=pooling, residual=residual, norm=norm, act=act, dilations=dilations)
64 |             self.down_convs.append(down_conv)
65 | 
66 |             # Decoder convs
67 |             if i < depth - 1:
68 |                 up_conv = UpConv(min(outs*2, max_filters), outs, blocks, residual=residual, concat=concat, norm=norm,act=F.relu, dilations=dilations)
69 |                 self.up_convs.append(up_conv)
70 |                 self.up_im_atts.append(ECABlock(outs))
71 |                 self.up_mask_atts.append(ECABlock(outs))
72 | 
73 |         self.down_convs = nn.ModuleList(self.down_convs)
74 |         self.up_convs = nn.ModuleList(self.up_convs)
75 | 
76 |         # task-specific channel attention blocks
77 |         self.up_im_atts = nn.ModuleList(self.up_im_atts)
78 |         self.up_mask_atts = nn.ModuleList(self.up_mask_atts)
79 | 
80 |         reset_params(self)
81 | 
82 |     def forward(self, input):
83 |         # Encoder convs
84 |         im_encoder_outs = []
85 |         mask_encoder_outs = []
86 |         x = input
87 |         for i, d_conv in enumerate(self.down_convs):
88 |             # d_conv, attn = nets
89 |             x, before_pool = d_conv(x)
90 |             im_encoder_outs.append(before_pool)
91 |             mask_encoder_outs.append(before_pool)
92 |         x_im = x
93 |         x_mask = x
94 | 
95 |         # Decoder convs
96 |         x = x_im
97 |         for i, nets in enumerate(zip(self.up_convs, self.up_im_atts)):
98 |             up_conv, attn = nets
99 |             before_pool = None
100 |             if im_encoder_outs is not None:
101 |                 before_pool = im_encoder_outs[-(i+2)]
102 |             x = up_conv(x, before_pool,se=attn)
103 |         x_im = x
104 | 
105 |         x = x_mask
106 |         for i, nets in enumerate(zip(self.up_convs, self.up_mask_atts)):
107 |             up_conv, attn = nets
108 |             before_pool = None
109 |             if mask_encoder_outs is not None:
110 |                 before_pool = mask_encoder_outs[-(i+2)]
111 |             x = up_conv(x, before_pool, se = attn)
112 |         x_mask = x
113 | 
114 |         return x_im, x_mask
115 | 
116 | class CoarseDecoder(nn.Module):
117 |     def __init__(self, args, in_channels=512, out_channels=3, norm='bn',act=F.relu, depth=5, blocks=1, residual=True,
118 |                  concat=True, use_att=False):
119 |         super(CoarseDecoder, self).__init__()
120 |         self.up_convs_bg = []
121 |         self.up_convs_mask = []
122 | 
123 |         # apply channel attention to the skip connections of the two decoders
124 |         self.atts_bg = []
125 |         self.atts_mask = []
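        # Two decoder branches are built per level in the loop below: MBEBlock
        # reconstructs the background and SMRBlock predicts the watermark mask;
        # when use_att is set, each branch re-weights the shared encoder skip
        # feature with its own ECABlock channel attention before consuming it.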
126 |         self.use_att = use_att
127 |         outs = in_channels
128 |         for i in range(depth):
129 |             ins = outs
130 |             outs = ins // 2
131 |             # background reconstruction branch
132 |             up_conv = MBEBlock(args.bg_mode, ins, outs, blocks=blocks, residual=residual, concat=concat, norm='in', act=act)
133 |             self.up_convs_bg.append(up_conv)
134 |             if self.use_att:
135 |                 self.atts_bg.append(ECABlock(outs))
136 | 
137 |             # mask prediction branch
138 |             up_conv = SMRBlock(args, ins, outs, blocks=blocks, residual=residual, concat=concat, norm=norm, act=act)
139 |             self.up_convs_mask.append(up_conv)
140 |             if self.use_att:
141 |                 self.atts_mask.append(ECABlock(outs))
142 |         # final conv
143 |         self.conv_final_bg = nn.Conv2d(outs, out_channels, 1,1,0)
144 | 
145 |         self.up_convs_bg = nn.ModuleList(self.up_convs_bg)
146 |         self.atts_bg = nn.ModuleList(self.atts_bg)
147 |         self.up_convs_mask = nn.ModuleList(self.up_convs_mask)
148 |         self.atts_mask = nn.ModuleList(self.atts_mask)
149 | 
150 |         reset_params(self)
151 | 
152 |     def forward(self, bg, fg, mask, encoder_outs=None):
153 |         bg_x = bg
154 |         fg_x = fg
155 |         mask_x = mask
156 |         mask_outs = []
157 |         bg_outs = []
158 |         for i, up_convs in enumerate(zip(self.up_convs_bg, self.up_convs_mask)):
159 |             up_bg, up_mask = up_convs
160 |             before_pool = None
161 |             if encoder_outs is not None:
162 |                 before_pool = encoder_outs[-(i+1)]
163 |             mask_before_pool = bg_before_pool = before_pool # default, so both names are bound when use_att is False
164 |             if self.use_att:
165 |                 mask_before_pool = self.atts_mask[i](before_pool)
166 |                 bg_before_pool = self.atts_bg[i](before_pool)
167 |             smr_outs = up_mask(mask_x, mask_before_pool)
168 |             mask_x = smr_outs['feats'][0]
169 |             primary_map, self_calibrated_map = smr_outs['attn_maps']
170 |             mask_outs.append(primary_map)
171 |             mask_outs.append(self_calibrated_map)
172 | 
173 | 
174 |             bg_x = up_bg(bg_x, bg_before_pool, self_calibrated_map.detach())
175 |             bg_outs.append(bg_x)
176 | 
177 |         if self.conv_final_bg is not None:
178 |             bg_x = self.conv_final_bg(bg_x)
179 |             mask_x = mask_outs[-1]
180 |             bg_outs = [bg_x] + bg_outs
181 |         return bg_outs, [mask_x] + mask_outs, None
182 | 
183 | 
184 | #################################################################
185 | # Refinement Stage
186 | #################################################################
187 | 
188 | 
189 | 
190 | class Refinement(nn.Module):
191 |     def __init__(self, in_channels=3, out_channels=3, shared_depth=2, down=ResDownNew, up=ResUpNew, ngf=32, n_cff=3, n_skips=3):
192 |         super(Refinement, self).__init__()
193 | 
194 |         self.conv_in = nn.Sequential(nn.Conv2d(in_channels, ngf, 3,1,1), nn.InstanceNorm2d(ngf), nn.LeakyReLU(0.2))
195 |         self.down1 = down(ngf, ngf)
196 |         self.down2 = down(ngf, ngf*2)
197 |         self.down3 = down(ngf*2, ngf*4, pooling=False, dilation=True)
198 | 
199 |         self.dec_conv2 = nn.Sequential(nn.Conv2d(ngf*1,ngf*1,1,1,0))
200 |         self.dec_conv3 = nn.Sequential(nn.Conv2d(ngf*2,ngf*1,1,1,0), nn.LeakyReLU(0.2), nn.Conv2d(ngf, ngf, 3,1,1), nn.LeakyReLU(0.2))
201 |         self.dec_conv4 = nn.Sequential(nn.Conv2d(ngf*4,ngf*2,1,1,0), nn.LeakyReLU(0.2), nn.Conv2d(ngf*2, ngf*2, 3,1,1), nn.LeakyReLU(0.2))
202 |         self.n_skips = n_skips
203 | 
204 | 
205 |         # CFF Blocks
206 |         self.cff_blocks = []
207 |         for i in range(n_cff):
208 |             self.cff_blocks.append(CFFBlock(ngf=ngf))
209 |         self.cff_blocks = nn.ModuleList(self.cff_blocks)
210 | 
211 |         self.out_conv = nn.Sequential(*[
212 |             nn.Conv2d(ngf + ngf*2 + ngf*4, ngf, 3,1,1),
213 |             nn.InstanceNorm2d(ngf),
214 |             nn.LeakyReLU(0.2),
215 |             nn.Conv2d(ngf, out_channels, 1,1,0)
216 |         ])
217 | 
218 |     def forward(self, input, coarse_bg, mask, encoder_outs, decoder_outs):
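        # n_skips gates how many coarse-decoder feature maps are injected into
        # the refinement encoder below: 3 uses all of dec_feat2/3/4, 0 disables
        # the skips entirely (the corresponding additions fall back to 0).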
219 |         if self.n_skips < 1:
220 |             dec_feat2 = 0
221 |         else:
222 |             dec_feat2 = self.dec_conv2(decoder_outs[0])
223 |         if self.n_skips < 2:
224 |             dec_feat3 = 0
225 |         else:
226 |             dec_feat3 = self.dec_conv3(decoder_outs[1]) # 64
227 |         if self.n_skips < 3:
228 |             dec_feat4 = 0
229 |         else:
230 |             dec_feat4 = self.dec_conv4(decoder_outs[2]) # 64
231 | 
232 |         xin = torch.cat([coarse_bg, mask], dim=1)
233 |         x = self.conv_in(xin)
234 | 
235 |         x,d1 = self.down1(x + dec_feat2) # 128,256
236 |         x,d2 = self.down2(x + dec_feat3) # 64,128
237 |         x,d3 = self.down3(x + dec_feat4) # 32,64
238 | 
239 |         xs = [d1,d2,d3]
240 |         for block in self.cff_blocks:
241 |             xs = block(xs)
242 | 
243 |         xs = [F.interpolate(x_hr, size=coarse_bg.shape[2:][::-1], mode='bilinear') for x_hr in xs] # note: [::-1] yields (W, H); only safe for square inputs
244 |         im = self.out_conv(torch.cat(xs,dim=1))
245 |         return im
246 | 
247 | 
248 | 
249 | 
250 | 
251 | 
252 | class SLBR(nn.Module):
253 | 
254 |     def __init__(self, args, in_channels=3, depth=5, shared_depth=2, blocks=1,
255 |                  out_channels_image=3, out_channels_mask=1, start_filters=32, residual=True,
256 |                  concat=True, long_skip=False):
257 |         super(SLBR, self).__init__()
258 |         self.shared = shared_depth = 2 # shared depth is fixed at 2 here, overriding the constructor argument
259 |         self.optimizer_encoder, self.optimizer_image, self.optimizer_wm = None, None, None
260 |         self.optimizer_mask, self.optimizer_shared = None, None
261 |         self.args = args
262 |         if type(blocks) is not tuple:
263 |             blocks = (blocks, blocks, blocks, blocks, blocks)
264 | 
265 |         # coarse stage
266 |         self.encoder = CoarseEncoder(in_channels=in_channels, depth= depth - shared_depth, blocks=blocks[0],
267 |                                      start_filters=start_filters, residual=residual, norm='bn',act=F.relu)
268 |         self.shared_decoder = SharedBottleNeck(in_channels=start_filters * 2 ** (depth - shared_depth - 1),
269 |                                                depth=depth, shared_depth=shared_depth, blocks=blocks[4], residual=residual,
270 |                                                concat=concat, norm='in')
271 | 
272 |         self.coarse_decoder = CoarseDecoder(args, in_channels=start_filters * 2 ** (depth - shared_depth),
273 |                                             out_channels=out_channels_image, depth=depth - shared_depth,
274 |                                             blocks=blocks[1], residual=residual,
275 |                                             concat=concat, norm='bn', use_att=True,
276 |                                             )
277 | 
278 |         self.long_skip = long_skip
279 | 
280 |         # refinement stage
281 |         if args.use_refine:
282 |             self.refinement = Refinement(in_channels=4, out_channels=3, shared_depth=1, n_cff=args.k_refine, n_skips=args.k_skip_stage)
283 |         else:
284 |             self.refinement = None
285 | 
286 |     def set_optimizers(self):
287 |         self.optimizer_encoder = torch.optim.Adam(self.encoder.parameters(), lr=self.args.lr)
288 |         self.optimizer_image = torch.optim.Adam(self.coarse_decoder.parameters(), lr=self.args.lr)
289 | 
290 |         if self.refinement is not None:
291 |             self.optimizer_refine = torch.optim.Adam(self.refinement.parameters(), lr=self.args.lr)
292 | 
293 |         if self.shared != 0:
294 |             self.optimizer_shared = torch.optim.Adam(self.shared_decoder.parameters(), lr=self.args.lr)
295 | 
296 |     def zero_grad_all(self):
297 |         self.optimizer_encoder.zero_grad()
298 |         self.optimizer_image.zero_grad()
299 | 
300 |         if self.shared != 0:
301 |             self.optimizer_shared.zero_grad()
302 |         if self.refinement is not None:
303 |             self.optimizer_refine.zero_grad()
304 | 
305 |     def step_all(self):
306 |         self.optimizer_encoder.step()
307 |         if self.shared != 0:
308 |             self.optimizer_shared.step()
309 |         self.optimizer_image.step()
310 |         if self.refinement is not None:
311 |             self.optimizer_refine.step()
312 | 
313 |     def multi_gpu(self):
314 |         self.encoder = nn.DataParallel(self.encoder, device_ids=range(torch.cuda.device_count()))
315 |         self.shared_decoder = nn.DataParallel(self.shared_decoder, device_ids=range(torch.cuda.device_count()))
316 |         self.coarse_decoder = nn.DataParallel(self.coarse_decoder, device_ids=range(torch.cuda.device_count()))
317 |         if self.refinement is not None:
318 |             self.refinement = nn.DataParallel(self.refinement, device_ids=range(torch.cuda.device_count()))
319 |         return
320 | 
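    # Forward pipeline: coarse encoder -> shared bottleneck -> two-branch coarse
    # decoder (background + mask pyramid) -> composite the prediction into the
    # masked region only -> optional CFF refinement. The coarse image is residual
    # (tanh output added to the input) when long_skip is set; the refined one
    # always is.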
321 |     def forward(self, synthesized):
322 |         image_code, before_pool = self.encoder(synthesized)
323 |         unshared_before_pool = before_pool #[: - self.shared]
324 | 
325 |         im, mask = self.shared_decoder(image_code)
326 |         ims, mask, wm = self.coarse_decoder(im, None, mask, unshared_before_pool)
327 |         im = ims[0]
328 |         reconstructed_image = torch.tanh(im)
329 |         if self.long_skip:
330 |             reconstructed_image = (reconstructed_image + synthesized).clamp(0,1)
331 | 
332 |         reconstructed_mask = mask[0]
333 |         reconstructed_wm = wm
334 | 
335 |         if self.refinement is not None:
336 |             dec_feats = (ims)[1:][::-1]
337 |             coarser = reconstructed_image * reconstructed_mask + (1-reconstructed_mask)* synthesized
338 |             refine_bg = self.refinement(synthesized, coarser, reconstructed_mask, None, dec_feats)
339 |             refine_bg = (torch.tanh(refine_bg) + synthesized).clamp(0,1) # coarser
340 |             return [refine_bg, reconstructed_image], mask, [reconstructed_wm]
341 | 
342 |         else:
343 |             return [reconstructed_image], mask, [reconstructed_wm]
344 | 
345 | 
346 | 
--------------------------------------------------------------------------------
/src/utils/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision import models
5 | from src.utils.misc import resize_to_match
6 | import pytorch_ssim
7 | import numpy as np
8 | import cv2
9 | 
10 | 
11 | class FocalLoss(nn.Module):
12 |     def __init__(self, alpha=0, gamma=2, logits=False, reduce=True):
13 |         super(FocalLoss, self).__init__()
14 |         self.alpha = alpha
15 |         self.gamma = gamma
16 |         self.logits = logits
17 |         self.reduce = reduce
18 | 
19 |     def forward(self, inputs, targets):
20 |         if self.logits:
21 |             BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False)
22 |         else:
23 |             BCE_loss = F.binary_cross_entropy(inputs, targets, reduce=False)
24 |         pt = torch.exp(-BCE_loss)
25 |         F_loss = (1-pt)**self.gamma * BCE_loss
26 |         if self.alpha > 0:
27 |             F_loss = self.alpha * targets * F_loss + (1 - self.alpha) * (1-targets)* F_loss
28 |         if self.reduce:
29 |             return torch.mean(F_loss)
30 |         else:
31 |             return F_loss
32 | 
33 | class WeightedBCE(nn.Module):
34 |     def __init__(self):
35 |         super(WeightedBCE, self).__init__()
36 | 
37 |     def forward(self, pred, gt):
38 |         epsilon = 1e-10
39 |         sigmoid_pred = torch.sigmoid(pred) # unused: BCEWithLogitsLoss below applies the sigmoid itself
40 |         count_pos = torch.sum(gt)*1.0+epsilon
41 |         count_neg = torch.sum(1.-gt)*1.0
42 |         beta = count_neg/count_pos
43 |         beta_back = count_pos / (count_pos + count_neg)
44 | 
45 |         bce1 = nn.BCEWithLogitsLoss(pos_weight=beta)
46 |         loss = beta_back*bce1(pred, gt)
47 | 
48 |         return loss
49 | 
50 | 
51 | def l1_relative(reconstructed, real, mask):
52 |     batch = real.size(0)
53 |     area = torch.sum(mask.view(batch,-1),dim=1)
54 |     reconstructed = reconstructed * mask
55 |     real = real * mask
56 | 
57 |     loss_l1 = torch.abs(reconstructed - real).view(batch, -1)
58 |     loss_l1 = torch.sum(loss_l1, dim=1) / (area+1e-6)
59 |     loss_l1 = torch.sum(loss_l1) / batch
60 |     return loss_l1
61 | 
62 | 
63 | def is_dic(x):
64 |     return type(x) == type([])
65 | 
66 | class Losses(nn.Module):
67 |     def __init__(self, argx, device):
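        """Composite training loss selected via args.loss_type.

        Naming scheme, read off the branches below: the leading 'l1'/'l2' picks
        the image reconstruction loss, the middle 'b'/'wb'/'xb' picks plain BCE,
        weighted BCE, or BCE-with-logits for the attention/mask branch, and the
        trailing 'l2' is the MSE watermark loss. VGG and SSIM terms are added
        when lambda_style / ssim_loss are positive.
        """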
68 |         super(Losses, self).__init__()
69 |         self.args = argx
70 | 
71 |         if self.args.loss_type == 'l1bl2':
72 |             self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), nn.BCELoss(), nn.MSELoss()
73 |         elif self.args.loss_type == 'l1wbl2':
74 |             self.outputLoss, self.attLoss, self.wrloss = nn.L1Loss(), WeightedBCE(), nn.MSELoss()
75 |         elif self.args.loss_type == 'l2wbl2':
76 |             self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), WeightedBCE(), nn.MSELoss()
77 |         elif self.args.loss_type == 'l2xbl2':
78 |             self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCEWithLogitsLoss(), nn.MSELoss()
79 |         else: # l2bl2
80 |             self.outputLoss, self.attLoss, self.wrloss = nn.MSELoss(), nn.BCELoss(), nn.MSELoss()
81 | 
82 |         if self.args.lambda_style > 0:
83 |             self.vggloss = VGGLoss(self.args.sltype).to(device)
84 | 
85 |         if self.args.ssim_loss > 0:
86 |             self.ssimloss = pytorch_ssim.SSIM().to(device)
87 | 
88 |         self.outputLoss = self.outputLoss.to(device)
89 |         self.attLoss = self.attLoss.to(device)
90 |         self.wrloss = self.wrloss.to(device)
91 | 
92 | 
93 |     def forward(self,imgx,target,attx,mask,wmx,wm):
94 |         pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss = 0,0,0,0,0
95 | 
96 |         if is_dic(imgx):
97 | 
98 |             if self.args.masked:
99 |                 # calculate the overall loss and side output
100 |                 pixel_loss = self.outputLoss(imgx[0],target) + sum([self.outputLoss(im,resize_to_match(mask,im)*resize_to_match(target,im)) for im in imgx[1:]])
101 |             else:
102 |                 pixel_loss = sum([self.outputLoss(im,resize_to_match(target,im)) for im in imgx])
103 | 
104 |             if self.args.lambda_style > 0:
105 |                 vgg_loss = sum([self.vggloss(im,resize_to_match(target,im),resize_to_match(mask,im)) for im in imgx])
106 | 
107 |             if self.args.ssim_loss > 0:
108 |                 ssim_loss = sum([ 1 - self.ssimloss(im,resize_to_match(target,im)) for im in imgx])
109 |         else:
110 | 
111 |             if self.args.masked:
112 |                 pixel_loss = self.outputLoss(imgx,mask*target)
113 |             else:
114 |                 pixel_loss = self.outputLoss(imgx,target)
115 | 
116 |             if self.args.lambda_style > 0:
117 |                 vgg_loss = self.vggloss(imgx,target,mask)
118 | 
119 |             if self.args.ssim_loss > 0:
120 |                 ssim_loss = 1 - self.ssimloss(imgx,target)
121 | 
122 |         if is_dic(attx):
123 |             att_loss = sum([self.attLoss(at,resize_to_match(mask,at)) for at in attx])
124 |         else:
125 |             att_loss = self.attLoss(attx, mask)
126 | 
127 |         if is_dic(wmx):
128 |             wm_loss = sum([self.wrloss(w,resize_to_match(wm,w)) for w in wmx])
129 |         else:
130 |             if self.args.masked:
131 |                 wm_loss = self.wrloss(wmx,mask*wm)
132 |             else:
133 |                 wm_loss = self.wrloss(wmx, wm)
134 | 
135 |         return pixel_loss,att_loss,wm_loss,vgg_loss,ssim_loss
136 | 
137 | 
138 | 
139 | 
140 | 
141 | class MeanShift(nn.Conv2d):
142 |     def __init__(self, data_mean, data_std, data_range=1, norm=True):
143 |         """norm (bool): normalize/denormalize the stats"""
144 |         c = len(data_mean)
145 |         super(MeanShift, self).__init__(c, c, kernel_size=1)
146 |         std = torch.Tensor(data_std)
147 |         self.weight.data = torch.eye(c).view(c, c, 1, 1)
148 |         if norm:
149 |             self.weight.data.div_(std.view(c, 1, 1, 1))
150 |             self.bias.data = -1 * data_range * torch.Tensor(data_mean)
151 |             self.bias.data.div_(std)
152 |         else:
153 |             self.weight.data.mul_(std.view(c, 1, 1, 1))
154 |             self.bias.data = data_range * torch.Tensor(data_mean)
155 |         for p in self.parameters(): p.requires_grad = False # freeze; assigning self.requires_grad = False would not stop gradients
156 | 
157 | 
158 | 
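# sltype values accepted by the factory below: 'vggx' is the plain (unmasked)
# VGG16 content loss, 'mvggx' restricts it to the watermark mask, and 'rvggx'
# additionally scores the masked features with the area-normalized l1_relative.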
159 | def VGGLoss(losstype, style=False):
160 |     if losstype == 'vggx':
161 |         return VGGLossX(mask=False, style=style)
162 |     elif losstype == 'mvggx':
163 |         return VGGLossX(mask=True)
164 |     elif losstype == 'rvggx':
165 |         return VGGLossX(mask=True,relative=True)
166 |     else:
167 |         raise Exception("error in %s"%losstype)
168 | 
169 | 
170 | 
171 | 
172 | 
173 | class VGG16FeatureExtractor(nn.Module):
174 |     def __init__(self):
175 |         super().__init__()
176 |         vgg16 = models.vgg16(pretrained=True)
177 |         self.enc_1 = nn.Sequential(*vgg16.features[:5])
178 |         self.enc_2 = nn.Sequential(*vgg16.features[5:10])
179 |         self.enc_3 = nn.Sequential(*vgg16.features[10:17])
180 | 
181 |         # fix the encoder
182 |         for i in range(3):
183 |             for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters():
184 |                 param.requires_grad = False
185 | 
186 |     def forward(self, image):
187 |         results = [image]
188 |         for i in range(3):
189 |             func = getattr(self, 'enc_{:d}'.format(i + 1))
190 |             results.append(func(results[-1]))
191 |         return results[1:]
192 | 
193 | class VGGLossX(nn.Module):
194 |     def __init__(self, normalize=True, mask=False, relative=False, style=False):
195 |         super(VGGLossX, self).__init__()
196 | 
197 |         self.vgg = VGG16FeatureExtractor().cuda()
198 |         self.criterion = nn.L1Loss().cuda() if not relative else l1_relative
199 |         self.use_style = style
200 |         self.use_mask = mask
201 |         self.relative = relative
202 | 
203 |         if normalize:
204 |             self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
205 |         else:
206 |             self.normalize = None
207 | 
208 |     def forward(self, x, y, Xmask=None):
209 |         if not self.use_mask:
210 |             mask = torch.ones_like(x)[:,0:1,:,:]
211 |         else:
212 |             mask = Xmask
213 | 
214 |         x0 = x
215 |         y0 = y
216 |         if self.normalize is not None:
217 |             x = self.normalize(x)
218 |             y = self.normalize(y)
219 | 
220 |         x_vgg = self.vgg(x)
221 |         y_vgg = self.vgg(y)
222 |         # self.visualize([x0]+x_vgg, [y0]+y_vgg)
223 |         loss = 0
224 |         style_loss = 0
225 |         for i in range(3):
226 |             # VGG Content Loss
227 |             if self.relative:
228 |                 loss += self.criterion(x_vgg[i],y_vgg[i].detach(),resize_to_match(mask,x_vgg[i]))
229 |             else:
230 |                 loss += self.criterion(resize_to_match(mask,x_vgg[i])*x_vgg[i],resize_to_match(mask,y_vgg[i])*y_vgg[i].detach()) #
231 |                 # loss += self.criterion(x_vgg[i], y_vgg[i].detach())
232 |             # VGG Style Loss
233 |             if self.use_style:
234 |                 x_gram = self.gram_matrix(x_vgg[i])
235 |                 y_gram = self.gram_matrix(y_vgg[i].detach())
236 |                 style_loss += F.l1_loss(x_gram, y_gram)
237 | 
238 |         return {"content":loss, "style":style_loss}
239 | 
240 | 
241 |     def gram_matrix(self, feat):
242 |         # https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/utils.py
243 |         (b, ch, h, w) = feat.size()
244 |         feat = feat.view(b, ch, h * w)
245 |         feat_t = feat.transpose(1, 2)
246 |         gram = torch.bmm(feat, feat_t) / (ch * h * w)
247 |         return gram
248 | 
249 |     def normPRED(self, d, eps=1e-2):
250 |         ma = np.max(d)
251 |         mi = np.min(d)
252 | 
253 |         if ma-mi < eps:
254 |             dn = d - mi
255 |         else:
256 |             dn = (d - mi) / (ma - mi)
257 |         return dn
--------------------------------------------------------------------------------
/src/networks/blocks.py:
--------------------------------------------------------------------------------
332 |         importance_map = torch.where(mask >= self.threshold, torch.ones_like(mask), torch.zeros_like(mask)).to(mask.device)
333 |         s_area = torch.clamp_min(torch.sum(importance_map, dim=[2,3]), self.min_area)[:,0:1]
334 |         if self.k_center != 2:
335 |             keys = [torch.sum(k*importance_map, dim=[2,3]) / s_area for k in keys] # b,c * k
336 |         else:
337 |             keys = [
338 |                 torch.sum(keys[0]*importance_map, dim=[2,3]) / s_area,
339 |                 torch.sum(keys[1]*(1-importance_map), dim=[2,3]) / (keys[1].shape[2]*keys[1].shape[3] - s_area + eps)
340 |             ]
341 | 
342 |         f_query = query # b, c, h, w
343 |         f_key = [k.reshape(b,c,1,1).repeat(1, 1, f_query.size(2),f_query.size(3)) for k in keys]
344 |         attention_scores = []
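        # Each pooled key (a foreground/background prototype vector) is broadcast
        # back to every spatial position, concatenated with the query map, squashed
        # by tanh, and scored by self.sim_func, giving one similarity map per key.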
345 |         for k in f_key:
346 |             combine_qk = torch.cat([f_query, k],dim=1).tanh() # tanh
347 |             sk = self.sim_func(combine_qk)
348 |             attention_scores.append(sk)
349 |         s = ascore = torch.cat(attention_scores, dim=1) # b,k,h,w
350 | 
351 |         s = s.permute(0,2,3,1) # b,h,w,k
352 |         v = self.v_conv(key_in)
353 |         if self.k_center == 2:
354 |             v_fg = torch.sum(v[:,:c]*importance_map, dim=[2,3]) / s_area
355 |             v_bg = torch.sum(v[:,c:]*(1-importance_map), dim=[2,3]) / (v.shape[2]*v.shape[3] - s_area + eps)
356 |             v = torch.cat([v_fg, v_bg],dim=1)
357 |         else:
358 |             v = torch.sum(v*importance_map, dim=[2,3]) / s_area # b, c*k
359 |         v = v.reshape(b, self.k_center, c) # b, k, c
360 |         attn = torch.bmm(s.reshape(b,h*w,self.k_center), v).reshape(b,h,w,c).permute(0,3,1,2)
361 |         s = self.out_conv(attn + query)
362 |         return s
363 | 
364 | 
365 |     def forward(self, xin, xout, xmask):
366 |         b_num,c,h,w = xin.shape
367 |         attention_score = self.compute_attention(xin, xout, xmask) # b,1,h,w after out_conv
368 |         attention_score = attention_score.reshape(b_num,1,h,w)
369 |         return xout, attention_score.sigmoid()
370 | 
371 | 
372 | 
373 | 
374 | 
375 | 
376 | 
377 | ## Refinement Stage
378 | class ResDownNew(nn.Module):
379 |     def __init__(self, in_size, out_size, pooling=True, use_att=False, dilation=False):
380 |         super(ResDownNew, self).__init__()
381 |         self.model = DownConv(in_size, out_size, 3, pooling=pooling, norm=nn.InstanceNorm2d, act=F.leaky_relu, dilations=[1,2,5] if dilation else [])
382 | 
383 |     def forward(self, x):
384 |         return self.model(x)
385 | 
386 | class ResUpNew(nn.Module):
387 |     def __init__(self, in_size, out_size, use_att=False):
388 |         super(ResUpNew, self).__init__()
389 |         self.model = UpConv(in_size, out_size, 3, use_att=use_att, norm=nn.InstanceNorm2d)
390 | 
391 |     def forward(self, x, skip_input, mask=None):
392 |         return self.model(x,skip_input,mask)
393 | 
394 | 
395 | 
396 | 
397 | class CFFBlock(nn.Module):
398 |     def __init__(self, down=ResDownNew, up=ResUpNew, ngf = 32):
399 |         super(CFFBlock, self).__init__()
400 |         self.down1 = down(ngf, ngf)
401 |         self.down2 = down(ngf, ngf*2)
402 |         self.down3 = down(ngf*2, ngf*4, pooling=False, dilation=True)
403 | 
404 |         self.conv22 = nn.Sequential(*[
405 |             nn.Conv2d(ngf*2, ngf, 3,1,1),
406 |             nn.LeakyReLU(0.2),
407 |             nn.Conv2d(ngf, ngf, 3,1,1),
408 |             nn.LeakyReLU(0.2)
409 |         ])
410 | 
411 |         self.conv33 = nn.Sequential(*[
412 |             nn.Conv2d(ngf*4,ngf*2, 3, 1, 1),
413 |             nn.LeakyReLU(0.2),
414 |             nn.Conv2d(ngf*2, ngf*2, 3,1,1),
415 |             nn.LeakyReLU(0.2)
416 |         ])
417 | 
418 |         self.up32 = nn.Sequential(*[
419 |             nn.Conv2d(ngf*4, ngf*1, 3,1,1),
420 |             nn.LeakyReLU(0.2),
421 |         ])
422 | 
423 |         self.up31 = nn.Sequential(*[
424 |             nn.Conv2d(ngf*4, ngf*1, 3,1,1),
425 |             nn.LeakyReLU(0.2),
426 |         ])
427 | 
428 |     def forward(self, inputs):
429 |         x1,x2,x3 = inputs
430 |         x32 = F.interpolate(x3, size=x2.shape[2:][::-1], mode='bilinear') # [::-1] assumes square feature maps
431 |         x32 = self.up32(x32)
432 |         x31 = F.interpolate(x3, size=x1.shape[2:][::-1], mode='bilinear')
433 |         x31 = self.up31(x31)
434 | 
435 |         # cross-connection
436 |         x,d1 = self.down1(x1 + x31)
437 |         x,d2 = self.down2(x + self.conv22(x2) + x32)
438 |         d3,_ = self.down3(x + self.conv33(x3))
439 |         return [d1,d2,d3]
440 | 
441 | 
442 | 
443 | 
444 | 
445 | 
446 | 
--------------------------------------------------------------------------------
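# Usage sketch (illustrative; not part of the repository). It wires the SLBR
# network above together roughly the way train.py and scripts/test_custom.sh do.
# `args` is a hypothetical stand-in exposing only the options the network reads
# (bg_mode, use_refine, k_refine, k_skip_stage, mask_mode, k_center, lr); the
# real option set lives in options.py, and the bg_mode/mask_mode values here
# are assumptions rather than documented defaults.
import argparse
import torch
from src.networks.resunet import SLBR

args = argparse.Namespace(bg_mode='res_mask', use_refine=True, k_refine=3,
                          k_skip_stage=3, mask_mode='res', k_center=2, lr=1e-3)
net = SLBR(args=args, shared_depth=1, blocks=3, long_skip=True)  # matches methods.slbr()
net.set_optimizers()
net.eval()

x = torch.rand(1, 3, 256, 256)   # a synthesized watermarked image in [0, 1]
with torch.no_grad():
    ims, masks, wm = net(x)      # [refined, coarse] images, mask pyramid, watermark
print(ims[0].shape)              # expected: torch.Size([1, 3, 256, 256])
print(masks[0].shape)            # expected: torch.Size([1, 1, 256, 256])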