├── .idea ├── .name ├── misc.xml ├── vcs.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── .gitignore ├── modules.xml └── DAH_github.iml ├── PerceptualSimilarity ├── data │ ├── __init__.py │ ├── dataset │ │ ├── __init__.py │ │ ├── base_dataset.py │ │ ├── jnd_dataset.py │ │ └── twoafc_dataset.py │ ├── base_data_loader.py │ ├── data_loader.py │ ├── custom_dataset_data_loader.py │ └── image_folder.py ├── util │ ├── __init__.py │ ├── util.py │ ├── html.py │ └── visualizer.py ├── .gitignore ├── imgs │ ├── fig1.png │ ├── ex_p0.png │ ├── ex_p1.png │ ├── ex_ref.png │ ├── ex_dir0 │ │ ├── 0.png │ │ └── 1.png │ ├── ex_dir1 │ │ ├── 0.png │ │ └── 1.png │ └── ex_dir_pair │ │ ├── ex_p0.png │ │ ├── ex_p1.png │ │ └── ex_ref.png ├── scripts │ ├── eval_valsets.sh │ ├── train_test_metric.sh │ ├── train_test_metric_tune.sh │ ├── train_test_metric_scratch.sh │ ├── download_dataset_valonly.sh │ └── download_dataset.sh ├── models │ ├── weights │ │ ├── v0.0 │ │ │ ├── alex.pth │ │ │ ├── vgg.pth │ │ │ └── squeeze.pth │ │ └── v0.1 │ │ │ ├── alex.pth │ │ │ ├── vgg.pth │ │ │ └── squeeze.pth │ ├── base_model.py │ ├── __init__.py │ ├── pretrained_networks.py │ ├── networks_basic.py │ └── dist_model.py ├── requirements.txt ├── compute_dists.py ├── compute_dists_dirs.py ├── LICENSE ├── perceptual_loss.py ├── Dockerfile ├── test_network.py ├── compute_dists_pair.py ├── test_dataset_model.py ├── train.py └── README.md ├── fig ├── AFE.jpg ├── AFE.pdf ├── GDE.jpg └── GDE.pdf ├── models ├── __pycache__ │ ├── module.cpython-38.pyc │ ├── RevealNet.cpython-38.pyc │ ├── HidingUNet_C.cpython-38.pyc │ └── HidingUNet_S.cpython-38.pyc ├── RevealNet.py ├── module.py ├── HidingUNet_S.py └── HidingUNet_C.py ├── runs └── main_dah_beta0.75_num_secret1num_cover1_2023-08-01_H19-31-52 │ └── events.out.tfevents.1690889527.amax-SYS-7049GP-TRT ├── scripts └── train_dah.sh ├── training └── main_dah_beta0.75_num_secret1num_cover1_2023-08-01_H19-31-52 │ └── trainingLogs │ └── train_44_log.txt ├── README.md └── main_DAH.py /.idea/.name: -------------------------------------------------------------------------------- 1 | main_DAH.py -------------------------------------------------------------------------------- /PerceptualSimilarity/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /PerceptualSimilarity/util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /PerceptualSimilarity/data/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /PerceptualSimilarity/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | 3 | checkpoints/* 4 | -------------------------------------------------------------------------------- /fig/AFE.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/fig/AFE.jpg -------------------------------------------------------------------------------- /fig/AFE.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/fig/AFE.pdf 
-------------------------------------------------------------------------------- /fig/GDE.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/fig/GDE.jpg -------------------------------------------------------------------------------- /fig/GDE.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/fig/GDE.pdf -------------------------------------------------------------------------------- /PerceptualSimilarity/imgs/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/fig1.png -------------------------------------------------------------------------------- /PerceptualSimilarity/imgs/ex_p0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/ex_p0.png -------------------------------------------------------------------------------- /PerceptualSimilarity/imgs/ex_p1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/ex_p1.png -------------------------------------------------------------------------------- /PerceptualSimilarity/imgs/ex_ref.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/ex_ref.png -------------------------------------------------------------------------------- /PerceptualSimilarity/imgs/ex_dir0/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/ex_dir0/0.png -------------------------------------------------------------------------------- /PerceptualSimilarity/imgs/ex_dir0/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/ex_dir0/1.png -------------------------------------------------------------------------------- /PerceptualSimilarity/imgs/ex_dir1/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/ex_dir1/0.png -------------------------------------------------------------------------------- /PerceptualSimilarity/imgs/ex_dir1/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/ex_dir1/1.png -------------------------------------------------------------------------------- /models/__pycache__/module.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/models/__pycache__/module.cpython-38.pyc -------------------------------------------------------------------------------- /PerceptualSimilarity/scripts/eval_valsets.sh: 
-------------------------------------------------------------------------------- 1 | 2 | python ./test_dataset_model.py --dataset_mode 2afc --model net-lin --net alex --use_gpu --batch_size 50 3 | 4 | -------------------------------------------------------------------------------- /models/__pycache__/RevealNet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/models/__pycache__/RevealNet.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/HidingUNet_C.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/models/__pycache__/HidingUNet_C.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/HidingUNet_S.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/models/__pycache__/HidingUNet_S.cpython-38.pyc -------------------------------------------------------------------------------- /PerceptualSimilarity/imgs/ex_dir_pair/ex_p0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/ex_dir_pair/ex_p0.png -------------------------------------------------------------------------------- /PerceptualSimilarity/imgs/ex_dir_pair/ex_p1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/ex_dir_pair/ex_p1.png -------------------------------------------------------------------------------- /PerceptualSimilarity/imgs/ex_dir_pair/ex_ref.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/ex_dir_pair/ex_ref.png -------------------------------------------------------------------------------- /PerceptualSimilarity/models/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/models/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /PerceptualSimilarity/models/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/models/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /PerceptualSimilarity/models/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/models/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /PerceptualSimilarity/models/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/models/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /PerceptualSimilarity/models/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/models/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /PerceptualSimilarity/models/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/models/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /PerceptualSimilarity/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=0.4.0 2 | torchvision>=0.2.1 3 | numpy>=1.14.3 4 | scipy>=1.0.1 5 | scikit-image>=0.13.0 6 | opencv>=2.4.11 7 | matplotlib>=1.5.1 8 | tqdm>=4.28.1 9 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /PerceptualSimilarity/data/base_data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | class BaseDataLoader(): 3 | def __init__(self): 4 | pass 5 | 6 | def initialize(self): 7 | pass 8 | 9 | def load_data(): 10 | return None 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /../../../../../../:\Users\lgm\Desktop\DAH_github\.idea/dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /runs/main_dah_beta0.75_num_secret1num_cover1_2023-08-01_H19-31-52/events.out.tfevents.1690889527.amax-SYS-7049GP-TRT: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/runs/main_dah_beta0.75_num_secret1num_cover1_2023-08-01_H19-31-52/events.out.tfevents.1690889527.amax-SYS-7049GP-TRT -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /PerceptualSimilarity/scripts/train_test_metric.sh: -------------------------------------------------------------------------------- 1 | 2 | 
TRIAL=${1} 3 | NET=${2} 4 | mkdir checkpoints 5 | mkdir checkpoints/${NET}_${TRIAL} 6 | python ./train.py --use_gpu --net ${NET} --name ${NET}_${TRIAL} 7 | python ./test_dataset_model.py --use_gpu --net ${NET} --model_path ./checkpoints/${NET}_${TRIAL}/latest_net_.pth 8 | -------------------------------------------------------------------------------- /PerceptualSimilarity/data/dataset/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | class BaseDataset(data.Dataset): 4 | def __init__(self): 5 | super(BaseDataset, self).__init__() 6 | 7 | def name(self): 8 | return 'BaseDataset' 9 | 10 | def initialize(self): 11 | pass 12 | 13 | -------------------------------------------------------------------------------- /PerceptualSimilarity/scripts/train_test_metric_tune.sh: -------------------------------------------------------------------------------- 1 | 2 | TRIAL=${1} 3 | NET=${2} 4 | mkdir checkpoints 5 | mkdir checkpoints/${NET}_${TRIAL}_tune 6 | python ./train.py --train_trunk --use_gpu --net ${NET} --name ${NET}_${TRIAL}_tune 7 | python ./test_dataset_model.py --train_trunk --use_gpu --net ${NET} --model_path ./checkpoints/${NET}_${TRIAL}_tune/latest_net_.pth 8 | -------------------------------------------------------------------------------- /PerceptualSimilarity/scripts/train_test_metric_scratch.sh: -------------------------------------------------------------------------------- 1 | 2 | TRIAL=${1} 3 | NET=${2} 4 | mkdir checkpoints 5 | mkdir checkpoints/${NET}_${TRIAL}_scratch 6 | python ./train.py --from_scratch --train_trunk --use_gpu --net ${NET} --name ${NET}_${TRIAL}_scratch 7 | python ./test_dataset_model.py --from_scratch --train_trunk --use_gpu --net ${NET} --model_path ./checkpoints/${NET}_${TRIAL}_scratch/latest_net_.pth 8 | 9 | -------------------------------------------------------------------------------- /scripts/train_dah.sh: -------------------------------------------------------------------------------- 1 | #python main_DAH.py --imageSize 128 --bs_secret 44 --num_training 1 --num_secret 1 --num_cover 1 --channel_cover 3 --channel_secret 3 --norm 'batch' --epochs 120 --loss 'l2' --beta 0.75 --remark 'main_dah' 2 | python main_DAH.py --imageSize 128 --bs_secret 44 --num_training 1 --num_secret 1 --num_cover 1 --channel_cover 3 --channel_secret 3 --norm 'batch' --epochs 120 --loss 'l2' --beta 0.75 --remark 'main_dah' -------------------------------------------------------------------------------- /PerceptualSimilarity/data/data_loader.py: -------------------------------------------------------------------------------- 1 | def CreateDataLoader(datafolder,dataroot='./dataset',dataset_mode='2afc',load_size=64,batch_size=1,serial_batches=True,nThreads=4): 2 | from data.custom_dataset_data_loader import CustomDatasetDataLoader 3 | data_loader = CustomDatasetDataLoader() 4 | # print(data_loader.name()) 5 | data_loader.initialize(datafolder,dataroot=dataroot+'/'+dataset_mode,dataset_mode=dataset_mode,load_size=load_size,batch_size=batch_size,serial_batches=serial_batches, nThreads=nThreads) 6 | return data_loader 7 | -------------------------------------------------------------------------------- /.idea/DAH_github.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | 13 | 15 | -------------------------------------------------------------------------------- /PerceptualSimilarity/scripts/download_dataset_valonly.sh: 
-------------------------------------------------------------------------------- 1 | 2 | mkdir dataset 3 | 4 | # JND Dataset 5 | wget https://people.eecs.berkeley.edu/~rich.zhang/projects/2018_perceptual/dataset/jnd.tar.gz -O ./dataset/jnd.tar.gz 6 | 7 | mkdir dataset/jnd 8 | tar -xzf ./dataset/jnd.tar.gz -C ./dataset 9 | rm ./dataset/jnd.tar.gz 10 | 11 | # 2AFC Val set 12 | mkdir dataset/2afc/ 13 | wget https://people.eecs.berkeley.edu/~rich.zhang/projects/2018_perceptual/dataset/twoafc_val.tar.gz -O ./dataset/twoafc_val.tar.gz 14 | 15 | mkdir dataset/2afc/val 16 | tar -xzf ./dataset/twoafc_val.tar.gz -C ./dataset/2afc 17 | rm ./dataset/twoafc_val.tar.gz 18 | -------------------------------------------------------------------------------- /PerceptualSimilarity/compute_dists.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import models 3 | from util import util 4 | 5 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 6 | parser.add_argument('-p0','--path0', type=str, default='./imgs/ex_ref.png') 7 | parser.add_argument('-p1','--path1', type=str, default='./imgs/ex_p0.png') 8 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU') 9 | 10 | opt = parser.parse_args() 11 | 12 | ## Initializing the model 13 | model = models.PerceptualLoss(model='net-lin',net='alex',use_gpu=opt.use_gpu) 14 | 15 | # Load images 16 | img0 = util.im2tensor(util.load_image(opt.path0)) # RGB image from [-1,1] 17 | img1 = util.im2tensor(util.load_image(opt.path1)) 18 | 19 | if(opt.use_gpu): 20 | img0 = img0.cuda() 21 | img1 = img1.cuda() 22 | 23 | 24 | # Compute distance 25 | dist01 = model.forward(img0,img1) 26 | print('Distance: %.3f'%dist01) 27 | -------------------------------------------------------------------------------- /PerceptualSimilarity/scripts/download_dataset.sh: -------------------------------------------------------------------------------- 1 | 2 | mkdir dataset 3 | 4 | # JND Dataset 5 | wget https://people.eecs.berkeley.edu/~rich.zhang/projects/2018_perceptual/dataset/jnd.tar.gz -O ./dataset/jnd.tar.gz 6 | 7 | mkdir dataset/jnd 8 | tar -xzf ./dataset/jnd.tar.gz -C ./dataset 9 | rm ./dataset/jnd.tar.gz 10 | 11 | # 2AFC Val set 12 | mkdir dataset/2afc/ 13 | wget https://people.eecs.berkeley.edu/~rich.zhang/projects/2018_perceptual/dataset/twoafc_val.tar.gz -O ./dataset/twoafc_val.tar.gz 14 | 15 | mkdir dataset/2afc/val 16 | tar -xzf ./dataset/twoafc_val.tar.gz -C ./dataset/2afc 17 | rm ./dataset/twoafc_val.tar.gz 18 | 19 | # 2AFC Train set 20 | mkdir dataset/2afc/ 21 | wget https://people.eecs.berkeley.edu/~rich.zhang/projects/2018_perceptual/dataset/twoafc_train.tar.gz -O ./dataset/twoafc_train.tar.gz 22 | 23 | mkdir dataset/2afc/train 24 | tar -xzf ./dataset/twoafc_train.tar.gz -C ./dataset/2afc 25 | rm ./dataset/twoafc_train.tar.gz 26 | -------------------------------------------------------------------------------- /training/main_dah_beta0.75_num_secret1num_cover1_2023-08-01_H19-31-52/trainingLogs/train_44_log.txt: -------------------------------------------------------------------------------- 1 | Namespace(Hnet_C='', Hnet_S='', Rnet='', beta=0.75, beta1=0.5, bs_secret=44, channel_cover=3, channel_secret=3, checkpoint='', checkpoint_diff='', cover_dependent=False, cuda=True, dataset='train', debug=False, decay_round=10, epochs=120, hostname='amax-SYS-7049GP-TRT', imageSize=128, iters_per_epoch=2000, logFrequency=1000, loss='l2', lr=0.001, ngpu=2, no_cover=False, 
noise_cover=False, norm='batch', num_cover=1, num_secret=1, num_training=1, outckpts='./training/main_dah_beta0.75_num_secret1num_cover1_2023-08-01_H19-31-52/checkPoints', outcodes='./training/main_dah_beta0.75_num_secret1num_cover1_2023-08-01_H19-31-52/codes', outlogs='./training/main_dah_beta0.75_num_secret1num_cover1_2023-08-01_H19-31-52/trainingLogs', plain_cover=False, remark='main_dah', resultPicFrequency=100, test='', testPics='./training/', test_diff='', trainpics='./training/main_dah_beta0.75_num_secret1num_cover1_2023-08-01_H19-31-52/trainPics', validationpics='./training/main_dah_beta0.75_num_secret1num_cover1_2023-08-01_H19-31-52/validationPics', workers=8) 2 | training is beginning ....................................................... 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep-adaptive-hiding-network 2 | Official PyTorch implementation of "Deep adaptive hiding network for image hiding using attentive frequency extraction and gradual depth extraction" 3 | 4 | ![GDE.jpg](fig/GDE.jpg) 5 | 6 | ![AFE.jpg](fig/AFE.jpg) 7 | ## Requirements 8 | This code was developed and tested with Python 3.6, PyTorch 1.5, and CUDA 10.2 on Ubuntu 18.04.5. 9 | 10 | ## Train DAH-Net on ImageNet datasets 11 | You can run the provided demo code as follows. 12 | 13 | 1. Prepare the ImageNet datasets and the visualization dataset. 14 | 15 | 2. Change the data paths on lines 210-214 of main_DAH.py. 16 | 17 | (Images for training are read from traindir and valdir, and images for visualization are read from coverdir and secretdir.) 18 | 19 | 3. ```sh ./scripts/train_dah.sh ``` 20 | 21 | ## Citation 22 | If you find our research helpful or influential, please consider citing: 23 | 24 | 25 | ### BibTeX 26 | @article{zhang2023deep, 27 | title={Deep adaptive hiding network for image hiding using attentive frequency extraction and gradual depth extraction}, 28 | author={Zhang, Le and Lu, Yao and Li, Jinxing and Chen, Fanglin and Lu, Guangming and Zhang, David}, 29 | journal={Neural Computing and Applications}, 30 | pages={1--19}, 31 | year={2023}, 32 | publisher={Springer} 33 | } 34 | -------------------------------------------------------------------------------- /PerceptualSimilarity/compute_dists_dirs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import models 4 | from util import util 5 | 6 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 7 | parser.add_argument('-d0','--dir0', type=str, default='./imgs/ex_dir0') 8 | parser.add_argument('-d1','--dir1', type=str, default='./imgs/ex_dir1') 9 | parser.add_argument('-o','--out', type=str, default='./imgs/example_dists.txt') 10 | parser.add_argument('-v','--version', type=str, default='0.1') 11 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU') 12 | 13 | opt = parser.parse_args() 14 | 15 | ## Initializing the model 16 | model = models.PerceptualLoss(model='net-lin',net='alex',use_gpu=opt.use_gpu,version=opt.version) 17 | 18 | # crawl directories 19 | f = open(opt.out,'w') 20 | files = os.listdir(opt.dir0) 21 | 22 | for file in files: 23 | if(os.path.exists(os.path.join(opt.dir1,file))): 24 | # Load images 25 | img0 = util.im2tensor(util.load_image(os.path.join(opt.dir0,file))) # RGB image from [-1,1] 26 | img1 = util.im2tensor(util.load_image(os.path.join(opt.dir1,file))) 27 | 28 | 
if(opt.use_gpu): 29 | img0 = img0.cuda() 30 | img1 = img1.cuda() 31 | 32 | # Compute distance 33 | dist01 = model.forward(img0,img1) 34 | print('%s: %.3f'%(file,dist01)) 35 | f.writelines('%s: %.6f\n'%(file,dist01)) 36 | 37 | f.close() 38 | -------------------------------------------------------------------------------- /PerceptualSimilarity/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
24 | 25 | -------------------------------------------------------------------------------- /PerceptualSimilarity/util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | from PIL import Image 5 | import numpy as np 6 | import os 7 | import matplotlib.pyplot as plt 8 | import torch 9 | 10 | def load_image(path): 11 | if(path[-3:] == 'dng'): 12 | import rawpy 13 | with rawpy.imread(path) as raw: 14 | img = raw.postprocess() 15 | elif(path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png'): 16 | import cv2 17 | return cv2.imread(path)[:,:,::-1] 18 | else: 19 | img = (255*plt.imread(path)[:,:,:3]).astype('uint8') 20 | 21 | return img 22 | 23 | def save_image(image_numpy, image_path, ): 24 | image_pil = Image.fromarray(image_numpy) 25 | image_pil.save(image_path) 26 | 27 | def mkdirs(paths): 28 | if isinstance(paths, list) and not isinstance(paths, str): 29 | for path in paths: 30 | mkdir(path) 31 | else: 32 | mkdir(paths) 33 | 34 | def mkdir(path): 35 | if not os.path.exists(path): 36 | os.makedirs(path) 37 | 38 | 39 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 40 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 41 | image_numpy = image_tensor[0].cpu().float().numpy() 42 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 43 | return image_numpy.astype(imtype) 44 | 45 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 46 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 47 | return torch.Tensor((image / factor - cent) 48 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 49 | -------------------------------------------------------------------------------- /PerceptualSimilarity/data/custom_dataset_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from data.base_data_loader import BaseDataLoader 3 | import os 4 | 5 | def CreateDataset(dataroots,dataset_mode='2afc',load_size=64,): 6 | dataset = None 7 | if dataset_mode=='2afc': # human judgements 8 | from dataset.twoafc_dataset import TwoAFCDataset 9 | dataset = TwoAFCDataset() 10 | elif dataset_mode=='jnd': # human judgements 11 | from dataset.jnd_dataset import JNDDataset 12 | dataset = JNDDataset() 13 | else: 14 | raise ValueError("Dataset Mode [%s] not recognized."%self.dataset_mode) 15 | 16 | dataset.initialize(dataroots,load_size=load_size) 17 | return dataset 18 | 19 | class CustomDatasetDataLoader(BaseDataLoader): 20 | def name(self): 21 | return 'CustomDatasetDataLoader' 22 | 23 | def initialize(self, datafolders, dataroot='./dataset',dataset_mode='2afc',load_size=64,batch_size=1,serial_batches=True, nThreads=1): 24 | BaseDataLoader.initialize(self) 25 | if(not isinstance(datafolders,list)): 26 | datafolders = [datafolders,] 27 | data_root_folders = [os.path.join(dataroot,datafolder) for datafolder in datafolders] 28 | self.dataset = CreateDataset(data_root_folders,dataset_mode=dataset_mode,load_size=load_size) 29 | self.dataloader = torch.utils.data.DataLoader( 30 | self.dataset, 31 | batch_size=batch_size, 32 | shuffle=not serial_batches, 33 | num_workers=int(nThreads)) 34 | 35 | def load_data(self): 36 | return self.dataloader 37 | 38 | def __len__(self): 39 | return len(self.dataset) 40 | -------------------------------------------------------------------------------- /PerceptualSimilarity/perceptual_loss.py: 
-------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | import scipy 6 | import scipy.misc 7 | import numpy as np 8 | import torch 9 | from torch.autograd import Variable 10 | import models 11 | 12 | use_gpu = True 13 | 14 | ref_path = './imgs/ex_ref.png' 15 | pred_path = './imgs/ex_p1.png' 16 | 17 | ref_img = scipy.misc.imread(ref_path).transpose(2, 0, 1) / 255. 18 | pred_img = scipy.misc.imread(pred_path).transpose(2, 0, 1) / 255. 19 | 20 | # Torchify 21 | ref = Variable(torch.FloatTensor(ref_img)[None,:,:,:]) 22 | pred = Variable(torch.FloatTensor(pred_img)[None,:,:,:], requires_grad=True) 23 | 24 | loss_fn = models.PerceptualLoss(model='net-lin', net='vgg', use_gpu=use_gpu) 25 | optimizer = torch.optim.Adam([pred,], lr=1e-3, betas=(0.9, 0.999)) 26 | 27 | import matplotlib.pyplot as plt 28 | plt.ion() 29 | fig = plt.figure(1) 30 | ax = fig.add_subplot(131) 31 | ax.imshow(ref_img.transpose(1, 2, 0)) 32 | ax.set_title('target') 33 | ax = fig.add_subplot(133) 34 | ax.imshow(pred_img.transpose(1, 2, 0)) 35 | ax.set_title('initialization') 36 | 37 | for i in range(1000): 38 | dist = loss_fn.forward(pred, ref, normalize=True) 39 | optimizer.zero_grad() 40 | dist.backward() 41 | optimizer.step() 42 | pred.data = torch.clamp(pred.data, 0, 1) 43 | 44 | if i % 10 == 0: 45 | print('iter %d, dist %.3g' % (i, dist.view(-1).data.cpu().numpy()[0])) 46 | pred_img = pred[0].data.cpu().numpy().transpose(1, 2, 0) 47 | pred_img = np.clip(pred_img, 0, 1) 48 | ax = fig.add_subplot(132) 49 | ax.imshow(pred_img) 50 | ax.set_title('iter %d, dist %.3f' % (i, dist.view(-1).data.cpu().numpy()[0])) 51 | plt.pause(5e-2) 52 | # plt.imsave('imgs_saved/%04d.jpg'%i,pred_img) 53 | 54 | 55 | -------------------------------------------------------------------------------- /PerceptualSimilarity/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.autograd import Variable 4 | from pdb import set_trace as st 5 | from IPython import embed 6 | 7 | class BaseModel(): 8 | def __init__(self): 9 | pass; 10 | 11 | def name(self): 12 | return 'BaseModel' 13 | 14 | def initialize(self, use_gpu=True, gpu_ids=[0]): 15 | self.use_gpu = use_gpu 16 | self.gpu_ids = gpu_ids 17 | 18 | def forward(self): 19 | pass 20 | 21 | def get_image_paths(self): 22 | pass 23 | 24 | def optimize_parameters(self): 25 | pass 26 | 27 | def get_current_visuals(self): 28 | return self.input 29 | 30 | def get_current_errors(self): 31 | return {} 32 | 33 | def save(self, label): 34 | pass 35 | 36 | # helper saving function that can be used by subclasses 37 | def save_network(self, network, path, network_label, epoch_label): 38 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 39 | save_path = os.path.join(path, save_filename) 40 | torch.save(network.state_dict(), save_path) 41 | 42 | # helper loading function that can be used by subclasses 43 | def load_network(self, network, network_label, epoch_label): 44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 45 | save_path = os.path.join(self.save_dir, save_filename) 46 | # print('Loading network from %s'%save_path) 47 | network.load_state_dict(torch.load(save_path)) 48 | 49 | def update_learning_rate(): 50 | pass 51 | 52 | def get_image_paths(self): 53 | return self.image_paths 54 | 55 | def save_done(self, flag=False): 56 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 57 | 
np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 58 | 59 | -------------------------------------------------------------------------------- /models/RevealNet.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: yongzhi li 4 | @contact: yongzhili@vip.qq.com 5 | 6 | @version: 1.0 7 | @file: Reveal.py 8 | @time: 2018/3/20 9 | 10 | """ 11 | 12 | import torch.nn as nn 13 | 14 | 15 | class RevealNet(nn.Module): 16 | def __init__(self, input_nc, output_nc, nhf=64, norm_layer=None, output_function=nn.Sigmoid): 17 | super(RevealNet, self).__init__() 18 | # input is (3) x 256 x 256 19 | 20 | self.conv1 = nn.Conv2d(input_nc, nhf, 3, 1, 1) 21 | self.conv2 = nn.Conv2d(nhf, nhf * 2, 3, 1, 1) 22 | self.conv3 = nn.Conv2d(nhf * 2, nhf * 4, 3, 1, 1) 23 | self.conv4 = nn.Conv2d(nhf * 4, nhf * 2, 3, 1, 1) 24 | self.conv5 = nn.Conv2d(nhf * 2, nhf, 3, 1, 1) 25 | self.conv6 = nn.Conv2d(nhf, output_nc, 3, 1, 1) 26 | self.output=output_function() 27 | self.relu = nn.ReLU(True) 28 | 29 | self.norm_layer = norm_layer 30 | if norm_layer != None: 31 | self.norm1 = norm_layer(nhf) 32 | self.norm2 = norm_layer(nhf*2) 33 | self.norm3 = norm_layer(nhf*4) 34 | self.norm4 = norm_layer(nhf*2) 35 | self.norm5 = norm_layer(nhf) 36 | 37 | def forward(self, input): 38 | 39 | if self.norm_layer != None: 40 | x=self.relu(self.norm1(self.conv1(input))) 41 | x=self.relu(self.norm2(self.conv2(x))) 42 | x=self.relu(self.norm3(self.conv3(x))) 43 | x=self.relu(self.norm4(self.conv4(x))) 44 | x=self.relu(self.norm5(self.conv5(x))) 45 | x=self.output(self.conv6(x)) 46 | else: 47 | x=self.relu(self.conv1(input)) 48 | x=self.relu(self.conv2(x)) 49 | x=self.relu(self.conv3(x)) 50 | x=self.relu(self.conv4(x)) 51 | x=self.relu(self.conv5(x)) 52 | x=self.output(self.conv6(x)) 53 | 54 | return x 55 | -------------------------------------------------------------------------------- /PerceptualSimilarity/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:9.0-base-ubuntu16.04 2 | 3 | LABEL maintainer="Seyoung Park " 4 | 5 | # This Dockerfile is forked from Tensorflow Dockerfile 6 | 7 | # Pick up some PyTorch gpu dependencies 8 | RUN apt-get update && apt-get install -y --no-install-recommends \ 9 | build-essential \ 10 | cuda-command-line-tools-9-0 \ 11 | cuda-cublas-9-0 \ 12 | cuda-cufft-9-0 \ 13 | cuda-curand-9-0 \ 14 | cuda-cusolver-9-0 \ 15 | cuda-cusparse-9-0 \ 16 | curl \ 17 | libcudnn7=7.1.4.18-1+cuda9.0 \ 18 | libfreetype6-dev \ 19 | libhdf5-serial-dev \ 20 | libpng12-dev \ 21 | libzmq3-dev \ 22 | pkg-config \ 23 | python \ 24 | python-dev \ 25 | rsync \ 26 | software-properties-common \ 27 | unzip \ 28 | && \ 29 | apt-get clean && \ 30 | rm -rf /var/lib/apt/lists/* 31 | 32 | 33 | # Install miniconda 34 | RUN apt-get update && apt-get install -y --no-install-recommends \ 35 | wget && \ 36 | MINICONDA="Miniconda3-latest-Linux-x86_64.sh" && \ 37 | wget --quiet https://repo.continuum.io/miniconda/$MINICONDA && \ 38 | bash $MINICONDA -b -p /miniconda && \ 39 | rm -f $MINICONDA 40 | ENV PATH /miniconda/bin:$PATH 41 | 42 | # Install PyTorch 43 | RUN conda update -n base conda && \ 44 | conda install pytorch torchvision cuda90 -c pytorch 45 | 46 | # Install PerceptualSimilarity dependencies 47 | RUN conda install numpy scipy jupyter matplotlib && \ 48 | conda install -c conda-forge scikit-image && \ 49 | apt-get install -y python-qt4 && \ 50 | pip install opencv-python 51 | 52 | # For CUDA 
profiling, TensorFlow requires CUPTI. Maybe PyTorch needs this too. 53 | ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH 54 | 55 | # IPython 56 | EXPOSE 8888 57 | 58 | WORKDIR "/notebooks" 59 | 60 | -------------------------------------------------------------------------------- /PerceptualSimilarity/test_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from util import util 3 | import models 4 | from models import dist_model as dm 5 | from IPython import embed 6 | 7 | use_gpu = False # Whether to use GPU 8 | spatial = True # Return a spatial map of perceptual distance. 9 | 10 | # Linearly calibrated models (LPIPS) 11 | model = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=use_gpu, spatial=spatial) 12 | # Can also set net = 'squeeze' or 'vgg' 13 | 14 | # Off-the-shelf uncalibrated networks 15 | # model = models.PerceptualLoss(model='net', net='alex', use_gpu=use_gpu, spatial=spatial) 16 | # Can also set net = 'squeeze' or 'vgg' 17 | 18 | # Low-level metrics 19 | # model = models.PerceptualLoss(model='L2', colorspace='Lab', use_gpu=use_gpu) 20 | # model = models.PerceptualLoss(model='ssim', colorspace='RGB', use_gpu=use_gpu) 21 | 22 | ## Example usage with dummy tensors 23 | dummy_im0 = torch.zeros(1,3,64,64) # image should be RGB, normalized to [-1,1] 24 | dummy_im1 = torch.zeros(1,3,64,64) 25 | if(use_gpu): 26 | dummy_im0 = dummy_im0.cuda() 27 | dummy_im1 = dummy_im1.cuda() 28 | dist = model.forward(dummy_im0,dummy_im1) 29 | 30 | ## Example usage with images 31 | ex_ref = util.im2tensor(util.load_image('./imgs/ex_ref.png')) 32 | ex_p0 = util.im2tensor(util.load_image('./imgs/ex_p0.png')) 33 | ex_p1 = util.im2tensor(util.load_image('./imgs/ex_p1.png')) 34 | if(use_gpu): 35 | ex_ref = ex_ref.cuda() 36 | ex_p0 = ex_p0.cuda() 37 | ex_p1 = ex_p1.cuda() 38 | 39 | ex_d0 = model.forward(ex_ref,ex_p0) 40 | ex_d1 = model.forward(ex_ref,ex_p1) 41 | 42 | if not spatial: 43 | print('Distances: (%.3f, %.3f)'%(ex_d0, ex_d1)) 44 | else: 45 | print('Distances: (%.3f, %.3f)'%(ex_d0.mean(), ex_d1.mean())) # The mean distance is approximately the same as the non-spatial distance 46 | 47 | # Visualize a spatially-varying distance map between ex_p0 and ex_ref 48 | import pylab 49 | pylab.imshow(ex_d0[0,0,...].data.cpu().numpy()) 50 | pylab.show() 51 | -------------------------------------------------------------------------------- /PerceptualSimilarity/compute_dists_pair.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import models 4 | from util import util 5 | import numpy as np 6 | from IPython import embed 7 | 8 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 9 | parser.add_argument('-d','--dir', type=str, default='./imgs/ex_dir0') 10 | parser.add_argument('-o','--out', type=str, default='./imgs/example_dists.txt') 11 | parser.add_argument('-v','--version', type=str, default='0.1') 12 | parser.add_argument('--all-pairs', action='store_true', help='turn on to test all N(N-1)/2 pairs, leave off to just do consecutive pairs (N-1)') 13 | parser.add_argument('-N', type=int, default=None) 14 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU') 15 | 16 | opt = parser.parse_args() 17 | 18 | ## Initializing the model 19 | model = models.PerceptualLoss(model='net-lin',net='alex',use_gpu=opt.use_gpu,version=opt.version) 20 | 21 | # crawl directories 22 | f = 
open(opt.out,'w') 23 | files = os.listdir(opt.dir) 24 | if(opt.N is not None): 25 | files = files[:opt.N] 26 | F = len(files) 27 | 28 | dists = [] 29 | for (ff,file) in enumerate(files[:-1]): 30 | img0 = util.im2tensor(util.load_image(os.path.join(opt.dir,file))) # RGB image from [-1,1] 31 | if(opt.use_gpu): 32 | img0 = img0.cuda() 33 | 34 | if(opt.all_pairs): 35 | files1 = files[ff+1:] 36 | else: 37 | files1 = [files[ff+1],] 38 | 39 | for file1 in files1: 40 | img1 = util.im2tensor(util.load_image(os.path.join(opt.dir,file1))) 41 | 42 | if(opt.use_gpu): 43 | img1 = img1.cuda() 44 | 45 | # Compute distance 46 | dist01 = model.forward(img0,img1) 47 | print('(%s,%s): %.3f'%(file,file1,dist01)) 48 | f.writelines('(%s,%s): %.6f\n'%(file,file1,dist01)) 49 | 50 | dists.append(dist01.item()) 51 | 52 | avg_dist = np.mean(np.array(dists)) 53 | stderr_dist = np.std(np.array(dists))/np.sqrt(len(dists)) 54 | 55 | print('Avg: %.5f +/- %.5f'%(avg_dist,stderr_dist)) 56 | f.writelines('Avg: %.6f +/- %.6f'%(avg_dist,stderr_dist)) 57 | 58 | f.close() 59 | -------------------------------------------------------------------------------- /PerceptualSimilarity/data/dataset/jnd_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import torchvision.transforms as transforms 3 | from data.dataset.base_dataset import BaseDataset 4 | from data.image_folder import make_dataset 5 | from PIL import Image 6 | import numpy as np 7 | import torch 8 | from IPython import embed 9 | 10 | class JNDDataset(BaseDataset): 11 | def initialize(self, dataroot, load_size=64): 12 | self.root = dataroot 13 | self.load_size = load_size 14 | 15 | self.dir_p0 = os.path.join(self.root, 'p0') 16 | self.p0_paths = make_dataset(self.dir_p0) 17 | self.p0_paths = sorted(self.p0_paths) 18 | 19 | self.dir_p1 = os.path.join(self.root, 'p1') 20 | self.p1_paths = make_dataset(self.dir_p1) 21 | self.p1_paths = sorted(self.p1_paths) 22 | 23 | transform_list = [] 24 | transform_list.append(transforms.Scale(load_size)) 25 | transform_list += [transforms.ToTensor(), 26 | transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))] 27 | 28 | self.transform = transforms.Compose(transform_list) 29 | 30 | # judgement directory 31 | self.dir_S = os.path.join(self.root, 'same') 32 | self.same_paths = make_dataset(self.dir_S,mode='np') 33 | self.same_paths = sorted(self.same_paths) 34 | 35 | def __getitem__(self, index): 36 | p0_path = self.p0_paths[index] 37 | p0_img_ = Image.open(p0_path).convert('RGB') 38 | p0_img = self.transform(p0_img_) 39 | 40 | p1_path = self.p1_paths[index] 41 | p1_img_ = Image.open(p1_path).convert('RGB') 42 | p1_img = self.transform(p1_img_) 43 | 44 | same_path = self.same_paths[index] 45 | same_img = np.load(same_path).reshape((1,1,1,)) # [0,1] 46 | 47 | same_img = torch.FloatTensor(same_img) 48 | 49 | return {'p0': p0_img, 'p1': p1_img, 'same': same_img, 50 | 'p0_path': p0_path, 'p1_path': p1_path, 'same_path': same_path} 51 | 52 | def __len__(self): 53 | return len(self.p0_paths) 54 | -------------------------------------------------------------------------------- /PerceptualSimilarity/util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, image_subdir='', reflesh=0): 8 | self.title = title 9 | self.web_dir = web_dir 10 | # self.img_dir = os.path.join(self.web_dir, ) 11 | self.img_subdir = image_subdir 12 | 
self.img_dir = os.path.join(self.web_dir, image_subdir) 13 | if not os.path.exists(self.web_dir): 14 | os.makedirs(self.web_dir) 15 | if not os.path.exists(self.img_dir): 16 | os.makedirs(self.img_dir) 17 | # print(self.img_dir) 18 | 19 | self.doc = dominate.document(title=title) 20 | if reflesh > 0: 21 | with self.doc.head: 22 | meta(http_equiv="reflesh", content=str(reflesh)) 23 | 24 | def get_image_dir(self): 25 | return self.img_dir 26 | 27 | def add_header(self, str): 28 | with self.doc: 29 | h3(str) 30 | 31 | def add_table(self, border=1): 32 | self.t = table(border=border, style="table-layout: fixed;") 33 | self.doc.add(self.t) 34 | 35 | def add_images(self, ims, txts, links, width=400): 36 | self.add_table() 37 | with self.t: 38 | with tr(): 39 | for im, txt, link in zip(ims, txts, links): 40 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 41 | with p(): 42 | with a(href=os.path.join(link)): 43 | img(style="width:%dpx" % width, src=os.path.join(im)) 44 | br() 45 | p(txt) 46 | 47 | def save(self,file='index'): 48 | html_file = '%s/%s.html' % (self.web_dir,file) 49 | f = open(html_file, 'wt') 50 | f.write(self.doc.render()) 51 | f.close() 52 | 53 | 54 | if __name__ == '__main__': 55 | html = HTML('web/', 'test_html') 56 | html.add_header('hello world') 57 | 58 | ims = [] 59 | txts = [] 60 | links = [] 61 | for n in range(4): 62 | ims.append('image_%d.png' % n) 63 | txts.append('text_%d' % n) 64 | links.append('image_%d.png' % n) 65 | html.add_images(ims, txts, links) 66 | html.save() 67 | -------------------------------------------------------------------------------- /PerceptualSimilarity/data/image_folder.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | ################################################################################ 7 | 8 | import torch.utils.data as data 9 | 10 | from PIL import Image 11 | import os 12 | import os.path 13 | 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 17 | ] 18 | 19 | NP_EXTENSIONS = ['.npy',] 20 | 21 | def is_image_file(filename, mode='img'): 22 | if(mode=='img'): 23 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 24 | elif(mode=='np'): 25 | return any(filename.endswith(extension) for extension in NP_EXTENSIONS) 26 | 27 | def make_dataset(dirs, mode='img'): 28 | if(not isinstance(dirs,list)): 29 | dirs = [dirs,] 30 | 31 | images = [] 32 | for dir in dirs: 33 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 34 | for root, _, fnames in sorted(os.walk(dir)): 35 | for fname in fnames: 36 | if is_image_file(fname, mode=mode): 37 | path = os.path.join(root, fname) 38 | images.append(path) 39 | 40 | # print("Found %i images in %s"%(len(images),root)) 41 | return images 42 | 43 | def default_loader(path): 44 | return Image.open(path).convert('RGB') 45 | 46 | class ImageFolder(data.Dataset): 47 | def __init__(self, root, transform=None, return_paths=False, 48 | loader=default_loader): 49 | imgs = make_dataset(root) 50 | if len(imgs) == 0: 51 | raise(RuntimeError("Found 0 images in: " + root + "\n" 52 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 53 | 54 | 
self.root = root 55 | self.imgs = imgs 56 | self.transform = transform 57 | self.return_paths = return_paths 58 | self.loader = loader 59 | 60 | def __getitem__(self, index): 61 | path = self.imgs[index] 62 | img = self.loader(path) 63 | if self.transform is not None: 64 | img = self.transform(img) 65 | if self.return_paths: 66 | return img, path 67 | else: 68 | return img 69 | 70 | def __len__(self): 71 | return len(self.imgs) 72 | -------------------------------------------------------------------------------- /PerceptualSimilarity/data/dataset/twoafc_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import torchvision.transforms as transforms 3 | from data.dataset.base_dataset import BaseDataset 4 | from data.image_folder import make_dataset 5 | from PIL import Image 6 | import numpy as np 7 | import torch 8 | # from IPython import embed 9 | 10 | class TwoAFCDataset(BaseDataset): 11 | def initialize(self, dataroots, load_size=64): 12 | if(not isinstance(dataroots,list)): 13 | dataroots = [dataroots,] 14 | self.roots = dataroots 15 | self.load_size = load_size 16 | 17 | # image directory 18 | self.dir_ref = [os.path.join(root, 'ref') for root in self.roots] 19 | self.ref_paths = make_dataset(self.dir_ref) 20 | self.ref_paths = sorted(self.ref_paths) 21 | 22 | self.dir_p0 = [os.path.join(root, 'p0') for root in self.roots] 23 | self.p0_paths = make_dataset(self.dir_p0) 24 | self.p0_paths = sorted(self.p0_paths) 25 | 26 | self.dir_p1 = [os.path.join(root, 'p1') for root in self.roots] 27 | self.p1_paths = make_dataset(self.dir_p1) 28 | self.p1_paths = sorted(self.p1_paths) 29 | 30 | transform_list = [] 31 | transform_list.append(transforms.Scale(load_size)) 32 | transform_list += [transforms.ToTensor(), 33 | transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))] 34 | 35 | self.transform = transforms.Compose(transform_list) 36 | 37 | # judgement directory 38 | self.dir_J = [os.path.join(root, 'judge') for root in self.roots] 39 | self.judge_paths = make_dataset(self.dir_J,mode='np') 40 | self.judge_paths = sorted(self.judge_paths) 41 | 42 | def __getitem__(self, index): 43 | p0_path = self.p0_paths[index] 44 | p0_img_ = Image.open(p0_path).convert('RGB') 45 | p0_img = self.transform(p0_img_) 46 | 47 | p1_path = self.p1_paths[index] 48 | p1_img_ = Image.open(p1_path).convert('RGB') 49 | p1_img = self.transform(p1_img_) 50 | 51 | ref_path = self.ref_paths[index] 52 | ref_img_ = Image.open(ref_path).convert('RGB') 53 | ref_img = self.transform(ref_img_) 54 | 55 | judge_path = self.judge_paths[index] 56 | # judge_img = (np.load(judge_path)*2.-1.).reshape((1,1,1,)) # [-1,1] 57 | judge_img = np.load(judge_path).reshape((1,1,1,)) # [0,1] 58 | 59 | judge_img = torch.FloatTensor(judge_img) 60 | 61 | return {'p0': p0_img, 'p1': p1_img, 'ref': ref_img, 'judge': judge_img, 62 | 'p0_path': p0_path, 'p1_path': p1_path, 'ref_path': ref_path, 'judge_path': judge_path} 63 | 64 | def __len__(self): 65 | return len(self.p0_paths) 66 | -------------------------------------------------------------------------------- /models/module.py: -------------------------------------------------------------------------------- 1 | """ 2 | Harmonic block definition. 3 | 4 | Licensed under the BSD License [see LICENSE for details]. 
5 | 6 | Written by Matej Ulicny 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import math 13 | import numpy as np 14 | 15 | 16 | def dct_filters(k=3, groups=1, expand_dim=1, level=None, DC=True, l1_norm=True): 17 | if level is None: 18 | nf = k ** 2 - int(not DC) 19 | else: 20 | if level <= k: 21 | nf = level * (level + 1) // 2 - int(not DC) 22 | else: 23 | r = 2 * k - 1 - level 24 | nf = k ** 2 - r * (r + 1) // 2 - int(not DC) 25 | filter_bank = np.zeros((nf, k, k), dtype=np.float32) 26 | m = 0 27 | for i in range(k): 28 | for j in range(k): 29 | if (not DC and i == 0 and j == 0) or (not level is None and i + j >= level): 30 | continue 31 | for x in range(k): 32 | for y in range(k): 33 | filter_bank[m, x, y] = math.cos((math.pi * (x + .5) * i) / k) * math.cos( 34 | (math.pi * (y + .5) * j) / k) 35 | if l1_norm: 36 | filter_bank[m, :, :] /= np.sum(np.abs(filter_bank[m, :, :])) 37 | else: 38 | ai = 1.0 if i > 0 else 1.0 / math.sqrt(2.0) 39 | aj = 1.0 if j > 0 else 1.0 / math.sqrt(2.0) 40 | filter_bank[m, :, :] *= (2.0 / k) * ai * aj 41 | m += 1 42 | #print(filter_bank.shape) 43 | filter_bank = np.tile(np.expand_dims(filter_bank, axis=expand_dim), (groups, 1, 1, 1)) 44 | #print(filter_bank.shape) 45 | return torch.FloatTensor(filter_bank) 46 | 47 | 48 | class Harm2d(nn.Module): 49 | 50 | def __init__(self, ni, no, kernel_size, stride=1, padding=0, bias=True, dilation=1, use_bn=True, level=None, 51 | DC=True, groups=1): 52 | super(Harm2d, self).__init__() 53 | self.ni = ni 54 | self.kernel_size = kernel_size 55 | self.stride = stride 56 | self.padding = padding 57 | self.dilation = dilation 58 | self.groups = groups 59 | self.dct = nn.Parameter( 60 | dct_filters(k=kernel_size, groups=ni, expand_dim=1, level=level, DC=DC), requires_grad=False) 61 | 62 | nf = self.dct.shape[0] // ni #if use_bn else self.dct.shape[1] 63 | self.bn = nn.BatchNorm2d(ni * nf, affine=False) 64 | '''self.weight = nn.Parameter( 65 | nn.init.kaiming_normal_(torch.Tensor(no, ni // self.groups * nf, 1, 1), mode='fan_out', 66 | nonlinearity='relu'))''' 67 | 68 | self.bias = nn.Parameter(nn.init.zeros_(torch.Tensor(no))) if bias else None 69 | 70 | def forward(self, x): 71 | #print('self.dct', self.dct.shape) 72 | #print('x', x.shape) 73 | #print('self.ni', self.ni) 74 | #print(x.size(1)) 75 | x = F.conv2d(x, self.dct, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=x.size(1)) 76 | #print('888') 77 | 78 | x = self.bn(x) 79 | #x = F.conv2d(x, self.weight, bias=self.bias, padding=0, groups=self.groups) 80 | return x 81 | 82 | -------------------------------------------------------------------------------- /PerceptualSimilarity/test_dataset_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from models import dist_model as dm 3 | from data import data_loader as dl 4 | import argparse 5 | from IPython import embed 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--dataset_mode', type=str, default='2afc', help='[2afc,jnd]') 9 | parser.add_argument('--datasets', type=str, nargs='+', default=['val/traditional','val/cnn','val/superres','val/deblur','val/color','val/frameinterp'], help='datasets to test - for jnd mode: [val/traditional],[val/cnn]; for 2afc mode: [train/traditional],[train/cnn],[train/mix],[val/traditional],[val/cnn],[val/color],[val/deblur],[val/frameinterp],[val/superres]') 10 | parser.add_argument('--model', type=str, default='net-lin', help='distance 
model type [net-lin] for linearly calibrated net, [net] for off-the-shelf network, [l2] for euclidean distance, [ssim] for Structured Similarity Image Metric') 11 | parser.add_argument('--net', type=str, default='alex', help='[squeeze], [alex], or [vgg] for network architectures') 12 | parser.add_argument('--colorspace', type=str, default='Lab', help='[Lab] or [RGB] for colorspace to use for l2, ssim model types') 13 | parser.add_argument('--batch_size', type=int, default=50, help='batch size to test image patches in') 14 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU') 15 | parser.add_argument('--gpu_ids', type=int, nargs='+', default=[0], help='gpus to use') 16 | parser.add_argument('--nThreads', type=int, default=4, help='number of threads to use in data loader') 17 | 18 | parser.add_argument('--model_path', type=str, default=None, help='location of model, will default to ./weights/v[version]/[net_name].pth') 19 | 20 | parser.add_argument('--from_scratch', action='store_true', help='model was initialized from scratch') 21 | parser.add_argument('--train_trunk', action='store_true', help='model trunk was trained/tuned') 22 | parser.add_argument('--version', type=str, default='0.1', help='v0.1 is latest, v0.0 was original release') 23 | 24 | opt = parser.parse_args() 25 | if(opt.model in ['l2','ssim']): 26 | opt.batch_size = 1 27 | 28 | # initialize model 29 | model = dm.DistModel() 30 | # model.initialize(model=opt.model,net=opt.net,colorspace=opt.colorspace,model_path=opt.model_path,use_gpu=opt.use_gpu) 31 | model.initialize(model=opt.model, net=opt.net, colorspace=opt.colorspace, 32 | model_path=opt.model_path, use_gpu=opt.use_gpu, pnet_rand=opt.from_scratch, pnet_tune=opt.train_trunk, 33 | version=opt.version, gpu_ids=opt.gpu_ids) 34 | 35 | if(opt.model in ['net-lin','net']): 36 | print('Testing model [%s]-[%s]'%(opt.model,opt.net)) 37 | elif(opt.model in ['l2','ssim']): 38 | print('Testing model [%s]-[%s]'%(opt.model,opt.colorspace)) 39 | 40 | # initialize data loader 41 | for dataset in opt.datasets: 42 | data_loader = dl.CreateDataLoader(dataset,dataset_mode=opt.dataset_mode, batch_size=opt.batch_size, nThreads=opt.nThreads) 43 | 44 | # evaluate model on data 45 | if(opt.dataset_mode=='2afc'): 46 | (score, results_verbose) = dm.score_2afc_dataset(data_loader, model.forward, name=dataset) 47 | elif(opt.dataset_mode=='jnd'): 48 | (score, results_verbose) = dm.score_jnd_dataset(data_loader, model.forward, name=dataset) 49 | 50 | # print results 51 | print(' Dataset [%s]: %.2f'%(dataset,100.*score)) 52 | 53 | -------------------------------------------------------------------------------- /PerceptualSimilarity/train.py: -------------------------------------------------------------------------------- 1 | import torch.backends.cudnn as cudnn 2 | cudnn.benchmark=False 3 | 4 | import numpy as np 5 | import time 6 | import os 7 | from models import dist_model as dm 8 | from data import data_loader as dl 9 | import argparse 10 | from util.visualizer import Visualizer 11 | from IPython import embed 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--datasets', type=str, nargs='+', default=['train/traditional','train/cnn','train/mix'], help='datasets to train on: [train/traditional],[train/cnn],[train/mix],[val/traditional],[val/cnn],[val/color],[val/deblur],[val/frameinterp],[val/superres]') 15 | parser.add_argument('--model', type=str, default='net-lin', help='distance model type [net-lin] for linearly calibrated net, [net] for 
off-the-shelf network, [l2] for euclidean distance, [ssim] for Structured Similarity Image Metric') 16 | parser.add_argument('--net', type=str, default='alex', help='[squeeze], [alex], or [vgg] for network architectures') 17 | parser.add_argument('--batch_size', type=int, default=50, help='batch size to test image patches in') 18 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU') 19 | parser.add_argument('--gpu_ids', type=int, nargs='+', default=[0], help='gpus to use') 20 | 21 | parser.add_argument('--nThreads', type=int, default=4, help='number of threads to use in data loader') 22 | parser.add_argument('--nepoch', type=int, default=5, help='# epochs at base learning rate') 23 | parser.add_argument('--nepoch_decay', type=int, default=5, help='# additional epochs at linearly learning rate') 24 | parser.add_argument('--display_freq', type=int, default=5000, help='frequency (in instances) of showing training results on screen') 25 | parser.add_argument('--print_freq', type=int, default=5000, help='frequency (in instances) of showing training results on console') 26 | parser.add_argument('--save_latest_freq', type=int, default=20000, help='frequency (in instances) of saving the latest results') 27 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') 28 | parser.add_argument('--display_id', type=int, default=0, help='window id of the visdom display, [0] for no displaying') 29 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size') 30 | parser.add_argument('--display_port', type=int, default=8001, help='visdom display port') 31 | parser.add_argument('--use_html', action='store_true', help='save off html pages') 32 | parser.add_argument('--checkpoints_dir', type=str, default='checkpoints', help='checkpoints directory') 33 | parser.add_argument('--name', type=str, default='tmp', help='directory name for training') 34 | 35 | parser.add_argument('--from_scratch', action='store_true', help='model was initialized from scratch') 36 | parser.add_argument('--train_trunk', action='store_true', help='model trunk was trained/tuned') 37 | parser.add_argument('--train_plot', action='store_true', help='plot saving') 38 | 39 | opt = parser.parse_args() 40 | opt.save_dir = os.path.join(opt.checkpoints_dir,opt.name) 41 | if(not os.path.exists(opt.save_dir)): 42 | os.mkdir(opt.save_dir) 43 | 44 | # initialize model 45 | model = dm.DistModel() 46 | model.initialize(model=opt.model, net=opt.net, use_gpu=opt.use_gpu, is_train=True, 47 | pnet_rand=opt.from_scratch, pnet_tune=opt.train_trunk, gpu_ids=opt.gpu_ids) 48 | 49 | # load data from all training sets 50 | data_loader = dl.CreateDataLoader(opt.datasets,dataset_mode='2afc', batch_size=opt.batch_size, serial_batches=False, nThreads=opt.nThreads) 51 | dataset = data_loader.load_data() 52 | dataset_size = len(data_loader) 53 | D = len(dataset) 54 | print('Loading %i instances from'%dataset_size,opt.datasets) 55 | visualizer = Visualizer(opt) 56 | 57 | total_steps = 0 58 | fid = open(os.path.join(opt.checkpoints_dir,opt.name,'train_log.txt'),'w+') 59 | for epoch in range(1, opt.nepoch + opt.nepoch_decay + 1): 60 | epoch_start_time = time.time() 61 | for i, data in enumerate(dataset): 62 | iter_start_time = time.time() 63 | total_steps += opt.batch_size 64 | epoch_iter = total_steps - dataset_size * (epoch - 1) 65 | 66 | model.set_input(data) 67 | model.optimize_parameters() 68 | 69 | if total_steps % opt.display_freq == 
0: 70 | visualizer.display_current_results(model.get_current_visuals(), epoch) 71 | 72 | if total_steps % opt.print_freq == 0: 73 | errors = model.get_current_errors() 74 | t = (time.time()-iter_start_time)/opt.batch_size 75 | t2o = (time.time()-epoch_start_time)/3600. 76 | t2 = t2o*D/(i+.0001) 77 | visualizer.print_current_errors(epoch, epoch_iter, errors, t, t2=t2, t2o=t2o, fid=fid) 78 | 79 | for key in errors.keys(): 80 | visualizer.plot_current_errors_save(epoch, float(epoch_iter)/dataset_size, opt, errors, keys=[key,], name=key, to_plot=opt.train_plot) 81 | 82 | if opt.display_id > 0: 83 | visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors) 84 | 85 | if total_steps % opt.save_latest_freq == 0: 86 | print('saving the latest model (epoch %d, total_steps %d)' % 87 | (epoch, total_steps)) 88 | model.save(opt.save_dir, 'latest') 89 | 90 | if epoch % opt.save_epoch_freq == 0: 91 | print('saving the model at the end of epoch %d, iters %d' % 92 | (epoch, total_steps)) 93 | model.save(opt.save_dir, 'latest') 94 | model.save(opt.save_dir, epoch) 95 | 96 | print('End of epoch %d / %d \t Time Taken: %d sec' % 97 | (epoch, opt.nepoch + opt.nepoch_decay, time.time() - epoch_start_time)) 98 | 99 | if epoch > opt.nepoch: 100 | model.update_learning_rate(opt.nepoch_decay) 101 | 102 | # model.save_done(True) 103 | fid.close() 104 | -------------------------------------------------------------------------------- /PerceptualSimilarity/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | #from skimage.measure import compare_ssim 8 | from skimage.metrics import structural_similarity as SSIM 9 | import torch 10 | from torch.autograd import Variable 11 | 12 | from PerceptualSimilarity.models import dist_model 13 | 14 | class PerceptualLoss(torch.nn.Module): 15 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0], version='0.1'): # VGG using our perceptually-learned weights (LPIPS metric) 16 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 17 | super(PerceptualLoss, self).__init__() 18 | # print('Setting up Perceptual loss...') 19 | self.use_gpu = use_gpu 20 | self.spatial = spatial 21 | self.gpu_ids = gpu_ids 22 | self.model = dist_model.DistModel() 23 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids, version=version) 24 | # print('...[%s] initialized'%self.model.name()) 25 | # print('...Done') 26 | 27 | def forward(self, pred, target, normalize=False): 28 | """ 29 | Pred and target are Variables. 
30 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 31 | If normalize is False, assumes the images are already between [-1,+1] 32 | 33 | Inputs pred and target are Nx3xHxW 34 | Output pytorch Variable N long 35 | """ 36 | 37 | if normalize: 38 | target = 2 * target - 1 39 | pred = 2 * pred - 1 40 | 41 | return self.model.forward(target, pred) 42 | 43 | def normalize_tensor(in_feat,eps=1e-10): 44 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 45 | return in_feat/(norm_factor+eps) 46 | 47 | def l2(p0, p1, range=255.): 48 | return .5*np.mean((p0 / range - p1 / range)**2) 49 | 50 | def psnr(p0, p1, peak=255.): 51 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 52 | 53 | def dssim(p0, p1, range=255.): 54 | return (1 - SSIM(p0, p1, data_range=range, multichannel=True)) / 2. 55 | 56 | def rgb2lab(in_img,mean_cent=False): 57 | from skimage import color 58 | img_lab = color.rgb2lab(in_img) 59 | if(mean_cent): 60 | img_lab[:,:,0] = img_lab[:,:,0]-50 61 | return img_lab 62 | 63 | def tensor2np(tensor_obj): 64 | # change dimension of a tensor object into a numpy array 65 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 66 | 67 | def np2tensor(np_obj): 68 | # change dimenion of np array into tensor array 69 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 70 | 71 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 72 | # image tensor to lab tensor 73 | from skimage import color 74 | 75 | img = tensor2im(image_tensor) 76 | img_lab = color.rgb2lab(img) 77 | if(mc_only): 78 | img_lab[:,:,0] = img_lab[:,:,0]-50 79 | if(to_norm and not mc_only): 80 | img_lab[:,:,0] = img_lab[:,:,0]-50 81 | img_lab = img_lab/100. 82 | 83 | return np2tensor(img_lab) 84 | 85 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 86 | from skimage import color 87 | import warnings 88 | warnings.filterwarnings("ignore") 89 | 90 | lab = tensor2np(lab_tensor)*100. 91 | lab[:,:,0] = lab[:,:,0]+50 92 | 93 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 94 | if(return_inbnd): 95 | # convert back to lab, see if we match 96 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 97 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 98 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 99 | return (im2tensor(rgb_back),mask) 100 | else: 101 | return im2tensor(rgb_back) 102 | 103 | def rgb2lab(input): 104 | from skimage import color 105 | return color.rgb2lab(input / 255.) 106 | 107 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 108 | image_numpy = image_tensor[0].cpu().float().numpy() 109 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 110 | return image_numpy.astype(imtype) 111 | 112 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 113 | return torch.Tensor((image / factor - cent) 114 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 115 | 116 | def tensor2vec(vector_tensor): 117 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 118 | 119 | def voc_ap(rec, prec, use_07_metric=False): 120 | """ ap = voc_ap(rec, prec, [use_07_metric]) 121 | Compute VOC AP given precision and recall. 122 | If use_07_metric is true, uses the 123 | VOC 07 11 point method (default:False). 124 | """ 125 | if use_07_metric: 126 | # 11 point metric 127 | ap = 0. 128 | for t in np.arange(0., 1.1, 0.1): 129 | if np.sum(rec >= t) == 0: 130 | p = 0 131 | else: 132 | p = np.max(prec[rec >= t]) 133 | ap = ap + p / 11. 
134 | else: 135 | # correct AP calculation 136 | # first append sentinel values at the end 137 | mrec = np.concatenate(([0.], rec, [1.])) 138 | mpre = np.concatenate(([0.], prec, [0.])) 139 | 140 | # compute the precision envelope 141 | for i in range(mpre.size - 1, 0, -1): 142 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 143 | 144 | # to calculate area under PR curve, look for points 145 | # where X axis (recall) changes value 146 | i = np.where(mrec[1:] != mrec[:-1])[0] 147 | 148 | # and sum (\Delta recall) * prec 149 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 150 | return ap 151 | 152 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 153 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 154 | image_numpy = image_tensor[0].cpu().float().numpy() 155 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 156 | return image_numpy.astype(imtype) 157 | 158 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 159 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 160 | return torch.Tensor((image / factor - cent) 161 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 162 | -------------------------------------------------------------------------------- /PerceptualSimilarity/models/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models as tv 4 | from IPython import embed 5 | 6 | class squeezenet(torch.nn.Module): 7 | def __init__(self, requires_grad=False, pretrained=True): 8 | super(squeezenet, self).__init__() 9 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 10 | self.slice1 = torch.nn.Sequential() 11 | self.slice2 = torch.nn.Sequential() 12 | self.slice3 = torch.nn.Sequential() 13 | self.slice4 = torch.nn.Sequential() 14 | self.slice5 = torch.nn.Sequential() 15 | self.slice6 = torch.nn.Sequential() 16 | self.slice7 = torch.nn.Sequential() 17 | self.N_slices = 7 18 | for x in range(2): 19 | self.slice1.add_module(str(x), pretrained_features[x]) 20 | for x in range(2,5): 21 | self.slice2.add_module(str(x), pretrained_features[x]) 22 | for x in range(5, 8): 23 | self.slice3.add_module(str(x), pretrained_features[x]) 24 | for x in range(8, 10): 25 | self.slice4.add_module(str(x), pretrained_features[x]) 26 | for x in range(10, 11): 27 | self.slice5.add_module(str(x), pretrained_features[x]) 28 | for x in range(11, 12): 29 | self.slice6.add_module(str(x), pretrained_features[x]) 30 | for x in range(12, 13): 31 | self.slice7.add_module(str(x), pretrained_features[x]) 32 | if not requires_grad: 33 | for param in self.parameters(): 34 | param.requires_grad = False 35 | 36 | def forward(self, X): 37 | h = self.slice1(X) 38 | h_relu1 = h 39 | h = self.slice2(h) 40 | h_relu2 = h 41 | h = self.slice3(h) 42 | h_relu3 = h 43 | h = self.slice4(h) 44 | h_relu4 = h 45 | h = self.slice5(h) 46 | h_relu5 = h 47 | h = self.slice6(h) 48 | h_relu6 = h 49 | h = self.slice7(h) 50 | h_relu7 = h 51 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) 52 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) 53 | 54 | return out 55 | 56 | 57 | class alexnet(torch.nn.Module): 58 | def __init__(self, requires_grad=False, pretrained=True): 59 | super(alexnet, self).__init__() 60 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 61 | self.slice1 = torch.nn.Sequential() 
62 | self.slice2 = torch.nn.Sequential() 63 | self.slice3 = torch.nn.Sequential() 64 | self.slice4 = torch.nn.Sequential() 65 | self.slice5 = torch.nn.Sequential() 66 | self.N_slices = 5 67 | for x in range(2): 68 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 69 | for x in range(2, 5): 70 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 71 | for x in range(5, 8): 72 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 73 | for x in range(8, 10): 74 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 75 | for x in range(10, 12): 76 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 77 | if not requires_grad: 78 | for param in self.parameters(): 79 | param.requires_grad = False 80 | 81 | def forward(self, X): 82 | h = self.slice1(X) 83 | h_relu1 = h 84 | h = self.slice2(h) 85 | h_relu2 = h 86 | h = self.slice3(h) 87 | h_relu3 = h 88 | h = self.slice4(h) 89 | h_relu4 = h 90 | h = self.slice5(h) 91 | h_relu5 = h 92 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 93 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 94 | 95 | return out 96 | 97 | class vgg16(torch.nn.Module): 98 | def __init__(self, requires_grad=False, pretrained=True): 99 | super(vgg16, self).__init__() 100 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 101 | self.slice1 = torch.nn.Sequential() 102 | self.slice2 = torch.nn.Sequential() 103 | self.slice3 = torch.nn.Sequential() 104 | self.slice4 = torch.nn.Sequential() 105 | self.slice5 = torch.nn.Sequential() 106 | self.N_slices = 5 107 | for x in range(4): 108 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 109 | for x in range(4, 9): 110 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 111 | for x in range(9, 16): 112 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 113 | for x in range(16, 23): 114 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 115 | for x in range(23, 30): 116 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 117 | if not requires_grad: 118 | for param in self.parameters(): 119 | param.requires_grad = False 120 | 121 | def forward(self, X): 122 | h = self.slice1(X) 123 | h_relu1_2 = h 124 | h = self.slice2(h) 125 | h_relu2_2 = h 126 | h = self.slice3(h) 127 | h_relu3_3 = h 128 | h = self.slice4(h) 129 | h_relu4_3 = h 130 | h = self.slice5(h) 131 | h_relu5_3 = h 132 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 133 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 134 | 135 | return out 136 | 137 | 138 | 139 | class resnet(torch.nn.Module): 140 | def __init__(self, requires_grad=False, pretrained=True, num=18): 141 | super(resnet, self).__init__() 142 | if(num==18): 143 | self.net = tv.resnet18(pretrained=pretrained) 144 | elif(num==34): 145 | self.net = tv.resnet34(pretrained=pretrained) 146 | elif(num==50): 147 | self.net = tv.resnet50(pretrained=pretrained) 148 | elif(num==101): 149 | self.net = tv.resnet101(pretrained=pretrained) 150 | elif(num==152): 151 | self.net = tv.resnet152(pretrained=pretrained) 152 | self.N_slices = 5 153 | 154 | self.conv1 = self.net.conv1 155 | self.bn1 = self.net.bn1 156 | self.relu = self.net.relu 157 | self.maxpool = self.net.maxpool 158 | self.layer1 = self.net.layer1 159 | self.layer2 = self.net.layer2 160 | self.layer3 = self.net.layer3 161 | self.layer4 = self.net.layer4 162 | 163 | def forward(self, X): 164 | h = 
self.conv1(X) 165 | h = self.bn1(h) 166 | h = self.relu(h) 167 | h_relu1 = h 168 | h = self.maxpool(h) 169 | h = self.layer1(h) 170 | h_conv2 = h 171 | h = self.layer2(h) 172 | h_conv3 = h 173 | h = self.layer3(h) 174 | h_conv4 = h 175 | h = self.layer4(h) 176 | h_conv5 = h 177 | 178 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) 179 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 180 | 181 | return out 182 | -------------------------------------------------------------------------------- /PerceptualSimilarity/models/networks_basic.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | from torch.autograd import Variable 9 | import numpy as np 10 | from pdb import set_trace as st 11 | from skimage import color 12 | from IPython import embed 13 | from . import pretrained_networks as pn 14 | 15 | from PerceptualSimilarity import models as util 16 | 17 | def spatial_average(in_tens, keepdim=True): 18 | # import pdb; pdb.set_trace() 19 | # return in_tens.mean([2,3],keepdim=keepdim) 20 | return in_tens.mean(2,keepdim=keepdim).mean(3,keepdim=keepdim) 21 | 22 | 23 | def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W 24 | in_H = in_tens.shape[2] 25 | scale_factor = 1.*out_H/in_H 26 | 27 | return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens) 28 | 29 | # Learned perceptual metric 30 | class PNetLin(nn.Module): 31 | def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True): 32 | super(PNetLin, self).__init__() 33 | 34 | self.pnet_type = pnet_type 35 | self.pnet_tune = pnet_tune 36 | self.pnet_rand = pnet_rand 37 | self.spatial = spatial 38 | self.lpips = lpips 39 | self.version = version 40 | self.scaling_layer = ScalingLayer() 41 | 42 | if(self.pnet_type in ['vgg','vgg16']): 43 | net_type = pn.vgg16 44 | self.chns = [64,128,256,512,512] 45 | elif(self.pnet_type=='alex'): 46 | net_type = pn.alexnet 47 | self.chns = [64,192,384,256,256] 48 | elif(self.pnet_type=='squeeze'): 49 | net_type = pn.squeezenet 50 | self.chns = [64,128,256,384,384,512,512] 51 | self.L = len(self.chns) 52 | 53 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) 54 | 55 | if(lpips): 56 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 57 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 58 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 59 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 60 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 61 | self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] 62 | if(self.pnet_type=='squeeze'): # 7 layers for squeezenet 63 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) 64 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) 65 | self.lins+=[self.lin5,self.lin6] 66 | 67 | def forward(self, in0, in1, retPerLayer=False): 68 | # v0.0 - original release had a bug, where input was not scaled 69 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) 70 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 71 | feats0, feats1, diffs = {}, {}, {} 72 | 73 | for kk in range(self.L): 74 | # import pdb; pdb.set_trace() 75 | 
feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk]) 76 | diffs[kk] = (feats0[kk]-feats1[kk])**2 77 | 78 | if(self.lpips): 79 | if(self.spatial): 80 | res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)] 81 | else: 82 | res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)] 83 | else: 84 | if(self.spatial): 85 | res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)] 86 | else: 87 | res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)] 88 | 89 | val = res[0] 90 | for l in range(1,self.L): 91 | val += res[l] 92 | 93 | if(retPerLayer): 94 | return (val, res) 95 | else: 96 | return val 97 | 98 | class ScalingLayer(nn.Module): 99 | def __init__(self): 100 | super(ScalingLayer, self).__init__() 101 | self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None]) 102 | self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None]) 103 | 104 | def forward(self, inp): 105 | return (inp - self.shift) / self.scale 106 | 107 | 108 | class NetLinLayer(nn.Module): 109 | ''' A single linear layer which does a 1x1 conv ''' 110 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 111 | super(NetLinLayer, self).__init__() 112 | 113 | layers = [nn.Dropout(),] if(use_dropout) else [] 114 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] 115 | self.model = nn.Sequential(*layers) 116 | 117 | 118 | class Dist2LogitLayer(nn.Module): 119 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' 120 | def __init__(self, chn_mid=32, use_sigmoid=True): 121 | super(Dist2LogitLayer, self).__init__() 122 | 123 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),] 124 | layers += [nn.LeakyReLU(0.2,True),] 125 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),] 126 | layers += [nn.LeakyReLU(0.2,True),] 127 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),] 128 | if(use_sigmoid): 129 | layers += [nn.Sigmoid(),] 130 | self.model = nn.Sequential(*layers) 131 | 132 | def forward(self,d0,d1,eps=0.1): 133 | return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1)) 134 | 135 | class BCERankingLoss(nn.Module): 136 | def __init__(self, chn_mid=32): 137 | super(BCERankingLoss, self).__init__() 138 | self.net = Dist2LogitLayer(chn_mid=chn_mid) 139 | # self.parameters = list(self.net.parameters()) 140 | self.loss = torch.nn.BCELoss() 141 | 142 | def forward(self, d0, d1, judge): 143 | per = (judge+1.)/2. 
144 | self.logit = self.net.forward(d0,d1) 145 | return self.loss(self.logit, per) 146 | 147 | # L2, DSSIM metrics 148 | class FakeNet(nn.Module): 149 | def __init__(self, use_gpu=True, colorspace='Lab'): 150 | super(FakeNet, self).__init__() 151 | self.use_gpu = use_gpu 152 | self.colorspace=colorspace 153 | 154 | class L2(FakeNet): 155 | 156 | def forward(self, in0, in1, retPerLayer=None): 157 | assert(in0.size()[0]==1) # currently only supports batchSize 1 158 | 159 | if(self.colorspace=='RGB'): 160 | (N,C,X,Y) = in0.size() 161 | value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N) 162 | return value 163 | elif(self.colorspace=='Lab'): 164 | value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), 165 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 166 | ret_var = Variable( torch.Tensor((value,) ) ) 167 | if(self.use_gpu): 168 | ret_var = ret_var.cuda() 169 | return ret_var 170 | 171 | class DSSIM(FakeNet): 172 | 173 | def forward(self, in0, in1, retPerLayer=None): 174 | assert(in0.size()[0]==1) # currently only supports batchSize 1 175 | 176 | if(self.colorspace=='RGB'): 177 | value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float') 178 | elif(self.colorspace=='Lab'): 179 | value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), 180 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 181 | ret_var = Variable( torch.Tensor((value,) ) ) 182 | if(self.use_gpu): 183 | ret_var = ret_var.cuda() 184 | return ret_var 185 | 186 | def print_network(net): 187 | num_params = 0 188 | for param in net.parameters(): 189 | num_params += param.numel() 190 | # print('Network',net) 191 | # print('Total number of parameters: %d' % num_params) 192 | -------------------------------------------------------------------------------- /PerceptualSimilarity/util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import time 4 | from . import util 5 | from . import html 6 | # from pdb import set_trace as st 7 | import matplotlib.pyplot as plt 8 | import math 9 | # from IPython import embed 10 | 11 | def zoom_to_res(img,res=256,order=0,axis=0): 12 | # img 3xXxX 13 | from scipy.ndimage import zoom 14 | zoom_factor = res/img.shape[1] 15 | if(axis==0): 16 | return zoom(img,[1,zoom_factor,zoom_factor],order=order) 17 | elif(axis==2): 18 | return zoom(img,[zoom_factor,zoom_factor,1],order=order) 19 | 20 | class Visualizer(): 21 | def __init__(self, opt): 22 | # self.opt = opt 23 | self.display_id = opt.display_id 24 | # self.use_html = opt.is_train and not opt.no_html 25 | self.win_size = opt.display_winsize 26 | self.name = opt.name 27 | self.display_cnt = 0 # display_current_results counter 28 | self.display_cnt_high = 0 29 | self.use_html = opt.use_html 30 | 31 | if self.display_id > 0: 32 | import visdom 33 | self.vis = visdom.Visdom(port = opt.display_port) 34 | 35 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 36 | util.mkdirs([self.web_dir,]) 37 | if self.use_html: 38 | self.img_dir = os.path.join(self.web_dir, 'images') 39 | print('create web directory %s...' 
% self.web_dir) 40 | util.mkdirs([self.img_dir,]) 41 | 42 | # |visuals|: dictionary of images to display or save 43 | def display_current_results(self, visuals, epoch, nrows=None, res=256): 44 | if self.display_id > 0: # show images in the browser 45 | title = self.name 46 | if(nrows is None): 47 | nrows = int(math.ceil(len(visuals.items()) / 2.0)) 48 | images = [] 49 | idx = 0 50 | for label, image_numpy in visuals.items(): 51 | title += " | " if idx % nrows == 0 else ", " 52 | title += label 53 | img = image_numpy.transpose([2, 0, 1]) 54 | img = zoom_to_res(img,res=res,order=0) 55 | images.append(img) 56 | idx += 1 57 | if len(visuals.items()) % 2 != 0: 58 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255 59 | white_image = zoom_to_res(white_image,res=res,order=0) 60 | images.append(white_image) 61 | self.vis.images(images, nrow=nrows, win=self.display_id + 1, 62 | opts=dict(title=title)) 63 | 64 | if self.use_html: # save images to a html file 65 | for label, image_numpy in visuals.items(): 66 | img_path = os.path.join(self.img_dir, 'epoch%.3d_cnt%.6d_%s.png' % (epoch, self.display_cnt, label)) 67 | util.save_image(zoom_to_res(image_numpy, res=res, axis=2), img_path) 68 | 69 | self.display_cnt += 1 70 | self.display_cnt_high = np.maximum(self.display_cnt_high, self.display_cnt) 71 | 72 | # update website 73 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) 74 | for n in range(epoch, 0, -1): 75 | webpage.add_header('epoch [%d]' % n) 76 | if(n==epoch): 77 | high = self.display_cnt 78 | else: 79 | high = self.display_cnt_high 80 | for c in range(high-1,-1,-1): 81 | ims = [] 82 | txts = [] 83 | links = [] 84 | 85 | for label, image_numpy in visuals.items(): 86 | img_path = 'epoch%.3d_cnt%.6d_%s.png' % (n, c, label) 87 | ims.append(os.path.join('images',img_path)) 88 | txts.append(label) 89 | links.append(os.path.join('images',img_path)) 90 | webpage.add_images(ims, txts, links, width=self.win_size) 91 | webpage.save() 92 | 93 | # save errors into a directory 94 | def plot_current_errors_save(self, epoch, counter_ratio, opt, errors,keys='+ALL',name='loss', to_plot=False): 95 | if not hasattr(self, 'plot_data'): 96 | self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())} 97 | self.plot_data['X'].append(epoch + counter_ratio) 98 | self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) 99 | 100 | # embed() 101 | if(keys=='+ALL'): 102 | plot_keys = self.plot_data['legend'] 103 | else: 104 | plot_keys = keys 105 | 106 | if(to_plot): 107 | (f,ax) = plt.subplots(1,1) 108 | for (k,kname) in enumerate(plot_keys): 109 | kk = np.where(np.array(self.plot_data['legend'])==kname)[0][0] 110 | x = self.plot_data['X'] 111 | y = np.array(self.plot_data['Y'])[:,kk] 112 | if(to_plot): 113 | ax.plot(x, y, 'o-', label=kname) 114 | np.save(os.path.join(self.web_dir,'%s_x')%kname,x) 115 | np.save(os.path.join(self.web_dir,'%s_y')%kname,y) 116 | 117 | if(to_plot): 118 | plt.legend(loc=0,fontsize='small') 119 | plt.xlabel('epoch') 120 | plt.ylabel('Value') 121 | f.savefig(os.path.join(self.web_dir,'%s.png'%name)) 122 | f.clf() 123 | plt.close() 124 | 125 | # errors: dictionary of error labels and values 126 | def plot_current_errors(self, epoch, counter_ratio, opt, errors): 127 | if not hasattr(self, 'plot_data'): 128 | self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())} 129 | self.plot_data['X'].append(epoch + counter_ratio) 130 | self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) 131 | self.vis.line( 132 | 
X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1), 133 | Y=np.array(self.plot_data['Y']), 134 | opts={ 135 | 'title': self.name + ' loss over time', 136 | 'legend': self.plot_data['legend'], 137 | 'xlabel': 'epoch', 138 | 'ylabel': 'loss'}, 139 | win=self.display_id) 140 | 141 | # errors: same format as |errors| of plotCurrentErrors 142 | def print_current_errors(self, epoch, i, errors, t, t2=-1, t2o=-1, fid=None): 143 | message = '(ep: %d, it: %d, t: %.3f[s], ept: %.2f/%.2f[h]) ' % (epoch, i, t, t2o, t2) 144 | message += (', ').join(['%s: %.3f' % (k, v) for k, v in errors.items()]) 145 | 146 | print(message) 147 | if(fid is not None): 148 | fid.write('%s\n'%message) 149 | 150 | 151 | # save image to the disk 152 | def save_images_simple(self, webpage, images, names, in_txts, prefix='', res=256): 153 | image_dir = webpage.get_image_dir() 154 | ims = [] 155 | txts = [] 156 | links = [] 157 | 158 | for name, image_numpy, txt in zip(names, images, in_txts): 159 | image_name = '%s_%s.png' % (prefix, name) 160 | save_path = os.path.join(image_dir, image_name) 161 | if(res is not None): 162 | util.save_image(zoom_to_res(image_numpy,res=res,axis=2), save_path) 163 | else: 164 | util.save_image(image_numpy, save_path) 165 | 166 | ims.append(os.path.join(webpage.img_subdir,image_name)) 167 | # txts.append(name) 168 | txts.append(txt) 169 | links.append(os.path.join(webpage.img_subdir,image_name)) 170 | # embed() 171 | webpage.add_images(ims, txts, links, width=self.win_size) 172 | 173 | # save image to the disk 174 | def save_images(self, webpage, images, names, image_path, title=''): 175 | image_dir = webpage.get_image_dir() 176 | # short_path = ntpath.basename(image_path) 177 | # name = os.path.splitext(short_path)[0] 178 | # name = short_path 179 | # webpage.add_header('%s, %s' % (name, title)) 180 | ims = [] 181 | txts = [] 182 | links = [] 183 | 184 | for label, image_numpy in zip(names, images): 185 | image_name = '%s.jpg' % (label,) 186 | save_path = os.path.join(image_dir, image_name) 187 | util.save_image(image_numpy, save_path) 188 | 189 | ims.append(image_name) 190 | txts.append(label) 191 | links.append(image_name) 192 | webpage.add_images(ims, txts, links, width=self.win_size) 193 | 194 | # save image to the disk 195 | # def save_images(self, webpage, visuals, image_path, short=False): 196 | # image_dir = webpage.get_image_dir() 197 | # if short: 198 | # short_path = ntpath.basename(image_path) 199 | # name = os.path.splitext(short_path)[0] 200 | # else: 201 | # name = image_path 202 | 203 | # webpage.add_header(name) 204 | # ims = [] 205 | # txts = [] 206 | # links = [] 207 | 208 | # for label, image_numpy in visuals.items(): 209 | # image_name = '%s_%s.png' % (name, label) 210 | # save_path = os.path.join(image_dir, image_name) 211 | # util.save_image(image_numpy, save_path) 212 | 213 | # ims.append(image_name) 214 | # txts.append(label) 215 | # links.append(image_name) 216 | # webpage.add_images(ims, txts, links, width=self.win_size) 217 | -------------------------------------------------------------------------------- /PerceptualSimilarity/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Perceptual Similarity Metric and Dataset [[Project Page]](http://richzhang.github.io/PerceptualSimilarity/) 3 | 4 | **The Unreasonable Effectiveness of Deep Features as a Perceptual Metric** 5 | [Richard Zhang](https://richzhang.github.io/), [Phillip Isola](http://web.mit.edu/phillipi/), [Alexei A. 
Efros](http://www.eecs.berkeley.edu/~efros/), [Eli Shechtman](https://research.adobe.com/person/eli-shechtman/), [Oliver Wang](http://www.oliverwang.info/). 6 |
In [CVPR](https://arxiv.org/abs/1801.03924), 2018. 7 | 8 | 9 | 10 | This repository contains our **perceptual metric (LPIPS)** and **dataset (BAPPS)**. It can also be used as a "perceptual loss". This uses PyTorch; a Tensorflow alternative is [here](https://github.com/alexlee-gk/lpips-tensorflow). 11 | 12 | **Table of Contents**
13 | 1. [Learned Perceptual Image Patch Similarity (LPIPS) metric](#1-learned-perceptual-image-patch-similarity-lpips-metric)
14 | a. [Basic Usage](#a-basic-usage) If you just want to run the metric through the command line, this is all you need.<br>
15 | b. ["Perceptual Loss" usage](#b-backpropping-through-the-metric)
16 | c. [About the metric](#c-about-the-metric)
17 | 2. [Berkeley-Adobe Perceptual Patch Similarity (BAPPS) dataset](#2-berkeley-adobe-perceptual-patch-similarity-bapps-dataset)
18 | a. [Download](#a-downloading-the-dataset)
19 | b. [Evaluation](#b-evaluating-a-perceptual-similarity-metric-on-a-dataset)
20 | c. [About the dataset](#c-about-the-dataset)
21 | d. [Train the metric using the dataset](#d-using-the-dataset-to-train-the-metric)
22 | 23 | ## (0) Dependencies/Setup 24 | 25 | ### Installation 26 | - Install PyTorch 1.0+ and torchvision fom http://pytorch.org 27 | 28 | ```bash 29 | pip install -r requirements.txt 30 | ``` 31 | - Clone this repo: 32 | ```bash 33 | git clone https://github.com/richzhang/PerceptualSimilarity 34 | cd PerceptualSimilarity 35 | ``` 36 | 37 | ## (1) Learned Perceptual Image Patch Similarity (LPIPS) metric 38 | 39 | Evaluate the distance between image patches. **Higher means further/more different. Lower means more similar.** 40 | 41 | ### (A) Basic Usage 42 | 43 | #### (A.I) Line commands 44 | 45 | Example scripts to take the distance between 2 specific images, all corresponding pairs of images in 2 directories, or all pairs of images within a directory: 46 | 47 | ``` 48 | python compute_dists.py -p0 imgs/ex_ref.png -p1 imgs/ex_p0.png --use_gpu 49 | python compute_dists_dirs.py -d0 imgs/ex_dir0 -d1 imgs/ex_dir1 -o imgs/example_dists.txt --use_gpu 50 | python compute_dists_pair.py -d imgs/ex_dir_pair -o imgs/example_dists_pair.txt --use_gpu 51 | ``` 52 | 53 | #### (A.II) Python code 54 | 55 | File [test_network.py](test_network.py) shows example usage. This snippet is all you really need. 56 | 57 | ```python 58 | import models 59 | model = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=use_gpu, gpu_ids=[0]) 60 | d = model.forward(im0,im1) 61 | ``` 62 | 63 | Variables ```im0, im1``` is a PyTorch Tensor/Variable with shape ```Nx3xHxW``` (```N``` patches of size ```HxW```, RGB images scaled in `[-1,+1]`). This returns `d`, a length `N` Tensor/Variable. 64 | 65 | Run `python test_network.py` to take the distance between example reference image [`ex_ref.png`](imgs/ex_ref.png) to distorted images [`ex_p0.png`](./imgs/ex_p0.png) and [`ex_p1.png`](imgs/ex_p1.png). Before running it - which do you think *should* be closer? 66 | 67 | **Some Options** By default in `model.initialize`: 68 | - `net='alex'`: Network `alex` is fastest, performs the best, and is the default. You can instead use `squeeze` or `vgg`. 69 | - `model='net-lin'`: This adds a linear calibration on top of intermediate features in the net. Set this to `model=net` to equally weight all the features. 70 | 71 | ### (B) Backpropping through the metric 72 | 73 | File [`perceptual_loss.py`](perceptual_loss.py) shows how to iteratively optimize using the metric. Run `python perceptual_loss.py` for a demo. The code can also be used to implement vanilla VGG loss, without our learned weights. 74 | 75 | ### (C) About the metric 76 | 77 | **Higher means further/more different. Lower means more similar.** 78 | 79 | We found that deep network activations work surprisingly well as a perceptual similarity metric. This was true across network architectures (SqueezeNet [2.8 MB], AlexNet [9.1 MB], and VGG [58.9 MB] provided similar scores) and supervisory signals (unsupervised, self-supervised, and supervised all perform strongly). We slightly improved scores by linearly "calibrating" networks - adding a linear layer on top of off-the-shelf classification networks. We provide 3 variants, using linear layers on top of the SqueezeNet, AlexNet (default), and VGG networks. 80 | 81 | If you use LPIPS in your publication, please specify which version you are using. The current version is 0.1. You can set `version='0.0'` for the initial release. 
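As a concrete illustration of the "perceptual loss" usage in (B): the LPIPS distance can be treated like any other differentiable loss and an image optimized against it, which is what [`perceptual_loss.py`](perceptual_loss.py) demonstrates. The snippet below is only a rough sketch of that pattern, not the script itself; the tensor names, dummy sizes, learning rate, and step count are illustrative placeholders.

```python
import torch
import models  # the PerceptualLoss wrapper shown in (A.II)

# Illustrative sketch: names and hyperparameters below are placeholders,
# not taken from perceptual_loss.py.
use_gpu = torch.cuda.is_available()
loss_fn = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=use_gpu, gpu_ids=[0])

# ref is a fixed target patch, pred is the image being optimized (both Nx3xHxW in [-1,+1])
ref = torch.rand(1, 3, 64, 64) * 2 - 1
pred = torch.rand(1, 3, 64, 64) * 2 - 1
if use_gpu:
    ref, pred = ref.cuda(), pred.cuda()
pred.requires_grad_(True)

optimizer = torch.optim.Adam([pred], lr=1e-3, betas=(0.9, 0.999))
for step in range(1000):
    optimizer.zero_grad()
    d = loss_fn.forward(pred, ref)  # LPIPS distance, one value per patch
    d.sum().backward()              # gradients flow back into pred
    optimizer.step()                # only pred is updated; the metric's weights are not in the optimizer
    pred.data.clamp_(-1, 1)         # keep the optimized image in the valid input range
```

After enough steps `pred` drifts toward `ref` perceptually; the same loop is what you would embed in a larger objective when using LPIPS as a loss term.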
82 | 83 | ## (2) Berkeley Adobe Perceptual Patch Similarity (BAPPS) dataset 84 | 85 | ### (A) Downloading the dataset 86 | 87 | Run `bash ./scripts/download_dataset.sh` to download and unzip the dataset into directory `./dataset`. It takes [6.6 GB] total. Alternatively, run `bash ./scripts/get_dataset_valonly.sh` to only download the validation set [1.3 GB]. 88 | - 2AFC train [5.3 GB] 89 | - 2AFC val [1.1 GB] 90 | - JND val [0.2 GB] 91 | 92 | ### (B) Evaluating a perceptual similarity metric on a dataset 93 | 94 | Script `test_dataset_model.py` evaluates a perceptual model on a subset of the dataset. 95 | 96 | **Dataset flags** 97 | - `--dataset_mode`: `2afc` or `jnd`, which type of perceptual judgment to evaluate 98 | - `--datasets`: list the datasets to evaluate 99 | - if `--dataset_mode 2afc`: choices are [`train/traditional`, `train/cnn`, `val/traditional`, `val/cnn`, `val/superres`, `val/deblur`, `val/color`, `val/frameinterp`] 100 | - if `--dataset_mode jnd`: choices are [`val/traditional`, `val/cnn`] 101 | 102 | **Perceptual similarity model flags** 103 | - `--model`: perceptual similarity model to use 104 | - `net-lin` for our LPIPS learned similarity model (linear network on top of internal activations of pretrained network) 105 | - `net` for a classification network (uncalibrated with all layers averaged) 106 | - `l2` for Euclidean distance 107 | - `ssim` for Structured Similarity Image Metric 108 | - `--net`: [`squeeze`,`alex`,`vgg`] for the `net-lin` and `net` models; ignored for `l2` and `ssim` models 109 | - `--colorspace`: choices are [`Lab`,`RGB`], used for the `l2` and `ssim` models; ignored for `net-lin` and `net` models 110 | 111 | **Misc flags** 112 | - `--batch_size`: evaluation batch size (will default to 1) 113 | - `--use_gpu`: turn on this flag for GPU usage 114 | 115 | An example usage is as follows: `python ./test_dataset_model.py --dataset_mode 2afc --datasets val/traditional val/cnn --model net-lin --net alex --use_gpu --batch_size 50`. This would evaluate our model on the "traditional" and "cnn" validation datasets. 116 | 117 | ### (C) About the dataset 118 | 119 | The dataset contains two types of perceptual judgements: **Two Alternative Forced Choice (2AFC)** and **Just Noticeable Differences (JND)**. 120 | 121 | **(1) 2AFC** Evaluators were given a patch triplet (1 reference + 2 distorted). They were asked to select which of the distorted was "closer" to the reference. 122 | 123 | Training sets contain 2 judgments/triplet. 124 | - `train/traditional` [56.6k triplets] 125 | - `train/cnn` [38.1k triplets] 126 | - `train/mix` [56.6k triplets] 127 | 128 | Validation sets contain 5 judgments/triplet. 129 | - `val/traditional` [4.7k triplets] 130 | - `val/cnn` [4.7k triplets] 131 | - `val/superres` [10.9k triplets] 132 | - `val/deblur` [9.4k triplets] 133 | - `val/color` [4.7k triplets] 134 | - `val/frameinterp` [1.9k triplets] 135 | 136 | Each 2AFC subdirectory contains the following folders: 137 | - `ref`: original reference patches 138 | - `p0,p1`: two distorted patches 139 | - `judge`: human judgments - 0 if all preferred p0, 1 if all humans preferred p1 140 | 141 | **(2) JND** Evaluators were presented with two patches - a reference and a distorted - for a limited time. They were asked if the patches were the same (identically) or different. 142 | 143 | Each set contains 3 human evaluations/example. 
144 | - `val/traditional` [4.8k pairs] 145 | - `val/cnn` [4.8k pairs] 146 | 147 | Each JND subdirectory contains the following folders: 148 | - `p0,p1`: two patches 149 | - `same`: human judgments: 0 if all humans thought patches were different, 1 if all humans thought patches were same 150 | 151 | ### (D) Using the dataset to train the metric 152 | 153 | See script `train_test_metric.sh` for an example of training and testing the metric. The script will train a model on the full training set for 10 epochs, and then test the learned metric on all of the validation sets. The numbers should roughly match the **Alex - lin** row in Table 5 in the [paper](https://arxiv.org/abs/1801.03924). The code supports training a linear layer on top of an existing representation. Training will add a subdirectory in the `checkpoints` directory. 154 | 155 | You can also train "scratch" and "tune" versions by running `train_test_metric_scratch.sh` and `train_test_metric_tune.sh`, respectively. 156 | 157 | ### Docker Environment 158 | 159 | [Docker](https://hub.docker.com/r/shinyeyes/perceptualsimilarity/) set up by [SuperShinyEyes](https://github.com/SuperShinyEyes). 160 | 161 | ## Citation 162 | 163 | If you find this repository useful for your research, please use the following. 164 | 165 | ``` 166 | @inproceedings{zhang2018perceptual, 167 | title={The Unreasonable Effectiveness of Deep Features as a Perceptual Metric}, 168 | author={Zhang, Richard and Isola, Phillip and Efros, Alexei A and Shechtman, Eli and Wang, Oliver}, 169 | booktitle={CVPR}, 170 | year={2018} 171 | } 172 | ``` 173 | 174 | ## Acknowledgements 175 | 176 | This repository borrows partially from the [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) repository. The average precision (AP) code is borrowed from the [py-faster-rcnn](https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/voc_eval.py) repository. Backpropping through the metric was implemented by [Angjoo Kanazawa](https://github.com/akanazawa). 177 | -------------------------------------------------------------------------------- /models/HidingUNet_S.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import functools 4 | 5 | import torch 6 | import torch.nn as nn 7 | from models.module import Harm2d 8 | import torch.nn.functional as F 9 | import math 10 | import numpy as np 11 | 12 | 13 | 14 | # Defines the Unet generator. 15 | # |num_downs|: number of downsamplings in UNet. 
For example, 16 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1 17 | # at the bottleneck 18 | 19 | def harm3x3(in_planes, out_planes, stride=1, level=None): 20 | """3x3 harmonic convolution with padding""" 21 | return Harm2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, 22 | bias=False, use_bn=False, level=level) 23 | 24 | def get_feature_cood(frequency_num=9): 25 | array=2*((np.arange(frequency_num)*1.0)/(frequency_num-1)) - 1 26 | #array = np.random.random(frequency_num) 27 | return torch.FloatTensor(np.float32(array))#.view(1, -1, 1, 1) 28 | 29 | 30 | class SELayer(nn.Module): 31 | def __init__(self, channel, reduction=32): 32 | super(SELayer, self).__init__() 33 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 34 | mip = max(8, channel // reduction) 35 | self.fc = nn.Sequential( 36 | nn.Linear(channel, mip, bias=False), 37 | nn.ReLU(inplace=True), 38 | nn.Linear(mip, channel, bias=False), 39 | nn.Sigmoid() 40 | ) 41 | 42 | def forward(self, x): 43 | b, c, _, _ = x.size() 44 | y = self.avg_pool(x).view(b, c) 45 | y = self.fc(y).view(b, c, 1, 1) 46 | #return x * y.expand_as(x) 47 | return y 48 | 49 | class attention(nn.Module): 50 | def __init__(self, channel, reduction=32): 51 | super(attention, self).__init__() 52 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 53 | mip = max(8, channel // reduction) 54 | 55 | #nn.Linear(channel, mip, bias=False), 56 | #nn.ReLU(inplace=True), 57 | #nn.Linear(mip, channel, bias=False), 58 | #nn.Sigmoid() 59 | self.pool_h = nn.AdaptiveAvgPool2d((None, 1)) 60 | self.pool_w = nn.AdaptiveAvgPool2d((1, None)) 61 | inp = channel 62 | oup = channel 63 | 64 | #mip = max(8, inp // reduction) 65 | 66 | self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0) 67 | self.bn1 = nn.BatchNorm2d(mip) 68 | self.act = nn.ReLU(inplace=True) 69 | 70 | self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0) 71 | self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0) 72 | self.conv_c = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0) 73 | 74 | def forward(self, x): 75 | '''b, c, _, _ = x.size() 76 | y = self.avg_pool(x).view(b, c) 77 | y = self.fc(y).view(b, c, 1, 1) 78 | ''' 79 | n, c, h, w = x.size() 80 | x_h = self.pool_h(x) 81 | x_w = self.pool_w(x).permute(0, 1, 3, 2) 82 | x_c = self.avg_pool(x) 83 | 84 | y = torch.cat([x_h, x_w, x_c], dim=2) 85 | y = self.conv1(y) 86 | y = self.bn1(y) 87 | y = self.act(y) 88 | 89 | x_h, x_w, x_c = torch.split(y, [h, w, 1], dim=2) 90 | x_w = x_w.permute(0, 1, 3, 2) 91 | 92 | a_h = self.conv_h(x_h).sigmoid() 93 | a_w = self.conv_w(x_w).sigmoid() 94 | a_c = self.conv_c(x_c).sigmoid() 95 | #return x * y.expand_as(x) 96 | return a_h*a_w*a_c 97 | 98 | class ChannelShuffle(nn.Module): 99 | 100 | def __init__(self, groups): 101 | super().__init__() 102 | self.groups = groups 103 | 104 | def forward(self, x): 105 | batchsize, channels, height, width = x.data.size() 106 | channels_per_group = int(channels / self.groups) 107 | 108 | #"""suppose a convolutional layer with g groups whose output has 109 | #g x n channels; we first reshape the output channel dimension 110 | #into (g, n)""" 111 | x = x.view(batchsize, self.groups, channels_per_group, height, width) 112 | 113 | #"""transposing and then flattening it back as the input of next layer.""" 114 | x = x.transpose(1, 2).contiguous() 115 | x = x.view(batchsize, -1, height, width) 116 | 117 | return x 118 | '''Hnet = UnetGenerator_C(input_nc=opt.channel_secret * opt.num_secret, output_nc=opt.channel_cover * opt.num_cover, 
119 | num_downs=num_downs, norm_layer=norm_layer, output_function=nn.Tanh)''' 120 | 121 | 122 | class UnetGenerator_S(nn.Module): 123 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, 124 | norm_layer=None, use_dropout=False, output_function=nn.Sigmoid): 125 | super(UnetGenerator_S, self).__init__() 126 | self.output_function = nn.Tanh 127 | '''self.tanh = output_function==nn.Tanh 128 | if self.tanh: 129 | self.factor = 10/255 130 | else: 131 | self.factor = 1.0''' 132 | nf = 9 133 | self.factor = 10 / 255 134 | self.tanh = nn.Tanh 135 | self.conv1 = nn.Conv2d(input_nc, 64, kernel_size=4, stride=2, padding=1, bias=False) 136 | self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False) 137 | self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False) 138 | self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False) 139 | self.conv5 = nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False) 140 | self.bn2 = norm_layer(128) 141 | self.bn3 = norm_layer(256) 142 | self.bn4 = norm_layer(512) 143 | self.leakyrelu = nn.LeakyReLU(0.2, True) 144 | self.relu = nn.ReLU() 145 | self.convtran5 = nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False) 146 | self.bnt5 = norm_layer(512) 147 | self.convtran4 = nn.ConvTranspose2d(1024, 256, kernel_size=4, stride=2, padding=1, bias=False) 148 | self.bnt4 = norm_layer(256) 149 | self.convtran3 = nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1, bias=False) 150 | self.bnt3 = norm_layer(128) 151 | self.convtran2 = nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1, bias=False) 152 | self.bnt2 = norm_layer(64) 153 | self.convtran1 = nn.ConvTranspose2d(128, output_nc, kernel_size=4, stride=2, padding=1, bias=False) 154 | 155 | self.dctconv1 = harm3x3(512, 512) 156 | self.dctconv2 = harm3x3(256, 256) 157 | self.dctconv3 = harm3x3(128, 128) 158 | self.dctconv4 = harm3x3(64, 64) 159 | self.dctconv5 = harm3x3(output_nc, output_nc) 160 | 161 | self.atten1 = attention(512) 162 | self.atten2 = attention(256) 163 | self.atten3 = attention(128) 164 | self.atten4 = attention(64) 165 | self.atten5 = attention(output_nc) 166 | self.channel_shuffle1 = ChannelShuffle(512) 167 | self.channel_shuffle2 = ChannelShuffle(256) 168 | self.channel_shuffle3 = ChannelShuffle(128) 169 | self.channel_shuffle4 = ChannelShuffle(64) 170 | self.channel_shuffle5 = ChannelShuffle(output_nc) 171 | 172 | self.f_atten = get_feature_cood(nf) 173 | self.f_atten_1 = nn.Parameter( 174 | self.f_atten.repeat(512).view(1, -1, 1, 1), requires_grad=True) 175 | self.f_atten_2 = nn.Parameter( 176 | self.f_atten.repeat(256).view(1, -1, 1, 1), requires_grad=True) 177 | self.f_atten_3 = nn.Parameter( 178 | self.f_atten.repeat(128).view(1, -1, 1, 1), requires_grad=True) 179 | self.f_atten_4 = nn.Parameter( 180 | self.f_atten.repeat(64).view(1, -1, 1, 1), requires_grad=True) 181 | self.f_atten_5 = nn.Parameter( 182 | self.f_atten.repeat(output_nc).view(1, -1, 1, 1), requires_grad=True) 183 | self.weight5 = nn.Parameter( 184 | nn.init.kaiming_normal_(torch.Tensor(512, 512 * nf, 1, 1), mode='fan_out', 185 | nonlinearity='relu')) 186 | self.weight4 = nn.Parameter( 187 | nn.init.kaiming_normal_(torch.Tensor(256, 256 * nf, 1, 1), mode='fan_out', 188 | nonlinearity='relu')) 189 | self.weight3 = nn.Parameter( 190 | nn.init.kaiming_normal_(torch.Tensor(128, 128 * nf, 1, 1), mode='fan_out', 191 | nonlinearity='relu')) 192 | self.weight2 = nn.Parameter( 193 | 
nn.init.kaiming_normal_(torch.Tensor(64, 64 * nf, 1, 1), mode='fan_out', 194 | nonlinearity='relu')) 195 | self.weight1 = nn.Parameter( 196 | nn.init.kaiming_normal_(torch.Tensor(output_nc, output_nc * nf, 1, 1), mode='fan_out', 197 | nonlinearity='relu')) 198 | self.groups = 1 199 | 200 | self.drop = nn.Dropout(0.5) 201 | 202 | 203 | def forward(self, input): 204 | out1 = self.conv1(input) 205 | out2 = self.bn2(self.conv2(self.leakyrelu(out1))) 206 | out3 = self.bn3(self.conv3(self.leakyrelu(out2))) 207 | out4 = self.bn4(self.conv4(self.leakyrelu(out3))) 208 | out5 = self.conv5(self.leakyrelu(out4)) 209 | out_5 = self.bnt5(self.convtran5(self.relu(out5))) 210 | out_dct_1 = self.dctconv1(out_5) 211 | #out_dct_1_cs = self.atten1_CAM(out_dct_1)+self.atten1_PAM(out_dct_1) 212 | #out_dct_1_cs = self.atten1(out_dct_1) 213 | out_dct_1_cs = self.atten1(out_5) 214 | out_dct_1_cs = self.channel_shuffle1(out_dct_1_cs.repeat(1, 9, 1, 1).expand_as(out_dct_1))#*out_dct_1 215 | out_dct_1_f = self.f_atten_1.expand_as(out_dct_1).to(input.device) #* out_dct_1 216 | out_dct_1 = (out_dct_1_cs+out_dct_1_f) * out_dct_1 217 | #out_5 = F.conv2d(out_dct_1, self.weight5, padding=0, groups=self.groups) 218 | out_5 = torch.cat([out4, out_5], 1) 219 | 220 | out_4 = self.bnt4(self.convtran4(self.relu(out_5))) 221 | out_4 = self.drop(out_4) 222 | #out_dct_2 = self.atten2(self.dctconv2(out_4)) 223 | out_dct_2 = self.dctconv2(out_4) 224 | #out_dct_2_cs = self.atten2_CAM(out_dct_2)+self.atten2_PAM(out_dct_2) 225 | out_dct_2_cs = self.atten2(out_4) 226 | out_dct_2_cs = self.channel_shuffle2(out_dct_2_cs.repeat(1, 9, 1, 1).expand_as(out_dct_2))#*out_dct_2 227 | out_dct_2_f = self.f_atten_2.expand_as(out_dct_2).to(input.device) #* out_dct_2 228 | out_dct_2 = (out_dct_2_cs + out_dct_2_f)* out_dct_2 229 | #out_4 = F.conv2d(out_dct_2, self.weight4, padding=0, groups=self.groups) 230 | out_4 = torch.cat([out3, out_4], 1) 231 | 232 | out_3 = self.bnt3(self.convtran3(self.relu(out_4))) 233 | #out_dct_3 = self.atten3(self.dctconv3(out_3)) 234 | out_dct_3 = self.dctconv3(out_3) 235 | #out_dct_3_cs = self.atten3_CAM(out_dct_3)+self.atten3_PAM(out_dct_3) 236 | out_dct_3_cs = self.atten3(out_3) 237 | out_dct_3_cs = self.channel_shuffle3(out_dct_3_cs.repeat(1, 9, 1, 1).expand_as(out_dct_3))#*out_dct_3 238 | out_dct_3_f = self.f_atten_3.expand_as(out_dct_3).to(input.device) #* out_dct_3 239 | out_dct_3 = (out_dct_3_cs + out_dct_3_f)*out_dct_3 240 | #out_3 = F.conv2d(out_dct_3, self.weight3, padding=0, groups=self.groups) 241 | out_3 = torch.cat([out2, out_3], 1) 242 | 243 | out_2 = self.bnt2(self.convtran2(self.relu(out_3))) 244 | #out_dct_4 = self.atten4(self.dctconv4(out_2)) 245 | out_dct_4 = self.dctconv4(out_2) 246 | #out_dct_4_cs = self.atten4_CAM(out_dct_4)+self.atten4_PAM(out_dct_4) 247 | out_dct_4_cs = self.atten4(out_2) 248 | out_dct_4_cs = self.channel_shuffle4(out_dct_4_cs.repeat(1, 9, 1, 1).expand_as(out_dct_4))#*out_dct_4 249 | out_dct_4_f = self.f_atten_4.expand_as(out_dct_4).to(input.device) #* out_dct_4 250 | out_dct_4 = (out_dct_4_cs + out_dct_4_f)* out_dct_4 251 | #out_2 = F.conv2d(out_dct_4, self.weight2, padding=0, groups=self.groups) 252 | out_2 = torch.cat([out1, out_2], 1) 253 | out_1 = self.relu(out_2) 254 | out_1 = self.convtran1(out_1) 255 | #out_dct_5 = self.atten5(self.dctconv5(out_1)) 256 | out_dct_5 = self.dctconv5(out_1) 257 | #out_dct_5_cs = self.atten5_CAM(out_dct_5) + self.atten5_PAM(out_dct_5) 258 | out_dct_5_cs = self.atten5(out_1) 259 | out_dct_5_cs = self.channel_shuffle5(out_dct_5_cs.repeat(1, 9, 1, 
1).expand_as(out_dct_5))#*out_dct_5 260 | out_dct_5_f = self.f_atten_5.expand_as(out_dct_5).to(input.device) # 261 | out_dct_5 = (out_dct_5_cs + out_dct_5_f)* out_dct_5 262 | #out = F.conv2d(out_dct_5, self.weight1, padding=0, groups=self.groups) 263 | #out = torch.tanh(out_dct_5) 264 | #out = torch.tanh(out_dct_5) 265 | #out_dct_5 = self.factor * out 266 | # out = torch.cat([input, out], 1) 267 | 268 | return out_dct_1, out_dct_2, out_dct_3, out_dct_4, out_dct_5 269 | #return out_dct_1, out_dct_5 270 | 271 | 272 | -------------------------------------------------------------------------------- /PerceptualSimilarity/models/dist_model.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | import os 9 | from collections import OrderedDict 10 | from torch.autograd import Variable 11 | import itertools 12 | from .base_model import BaseModel 13 | from scipy.ndimage import zoom 14 | import fractions 15 | import functools 16 | import skimage.transform 17 | from tqdm import tqdm 18 | 19 | from IPython import embed 20 | 21 | from . import networks_basic as networks 22 | import models as util 23 | 24 | class DistModel(BaseModel): 25 | def name(self): 26 | return self.model_name 27 | 28 | def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, 29 | use_gpu=True, printNet=False, spatial=False, 30 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): 31 | ''' 32 | INPUTS 33 | model - ['net-lin'] for linearly calibrated network 34 | ['net'] for off-the-shelf network 35 | ['L2'] for L2 distance in Lab colorspace 36 | ['SSIM'] for ssim in RGB colorspace 37 | net - ['squeeze','alex','vgg'] 38 | model_path - if None, will look in weights/[NET_NAME].pth 39 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM 40 | use_gpu - bool - whether or not to use a GPU 41 | printNet - bool - whether or not to print network architecture out 42 | spatial - bool - whether to output an array containing varying distances across spatial dimensions 43 | spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below). 44 | spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images. 45 | spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear). 
46 | is_train - bool - [True] for training mode 47 | lr - float - initial learning rate 48 | beta1 - float - initial momentum term for adam 49 | version - 0.1 for latest, 0.0 was original (with a bug) 50 | gpu_ids - int array - [0] by default, gpus to use 51 | ''' 52 | BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) 53 | 54 | self.model = model 55 | self.net = net 56 | self.is_train = is_train 57 | self.spatial = spatial 58 | self.gpu_ids = gpu_ids 59 | self.model_name = '%s [%s]'%(model,net) 60 | 61 | if(self.model == 'net-lin'): # pretrained net + linear layer 62 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, 63 | use_dropout=True, spatial=spatial, version=version, lpips=True) 64 | kw = {} 65 | if not use_gpu: 66 | kw['map_location'] = 'cpu' 67 | if(model_path is None): 68 | import inspect 69 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net))) 70 | 71 | if(not is_train): 72 | # print('Loading model from: %s'%model_path) 73 | self.net.load_state_dict(torch.load(model_path, **kw), strict=False) 74 | 75 | elif(self.model=='net'): # pretrained network 76 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) 77 | elif(self.model in ['L2','l2']): 78 | self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing 79 | self.model_name = 'L2' 80 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']): 81 | self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace) 82 | self.model_name = 'SSIM' 83 | else: 84 | raise ValueError("Model [%s] not recognized." % self.model) 85 | 86 | self.parameters = list(self.net.parameters()) 87 | 88 | if self.is_train: # training mode 89 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) 90 | self.rankLoss = networks.BCERankingLoss() 91 | self.parameters += list(self.rankLoss.net.parameters()) 92 | self.lr = lr 93 | self.old_lr = lr 94 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) 95 | else: # test mode 96 | self.net.eval() 97 | 98 | if(use_gpu): 99 | self.net.to(gpu_ids[0]) 100 | self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) 101 | if(self.is_train): 102 | self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 103 | 104 | if(printNet): 105 | # print('---------- Networks initialized -------------') 106 | networks.print_network(self.net) 107 | # print('-----------------------------------------------') 108 | 109 | def forward(self, in0, in1, retPerLayer=False): 110 | ''' Function computes the distance between image patches in0 and in1 111 | INPUTS 112 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] 113 | OUTPUT 114 | computed distances between in0 and in1 115 | ''' 116 | 117 | return self.net.forward(in0, in1, retPerLayer=retPerLayer) 118 | 119 | # ***** TRAINING FUNCTIONS ***** 120 | def optimize_parameters(self): 121 | self.forward_train() 122 | self.optimizer_net.zero_grad() 123 | self.backward_train() 124 | self.optimizer_net.step() 125 | self.clamp_weights() 126 | 127 | def clamp_weights(self): 128 | for module in self.net.modules(): 129 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)): 130 | module.weight.data = torch.clamp(module.weight.data,min=0) 131 | 132 | def set_input(self, data): 133 | self.input_ref = data['ref'] 134 | self.input_p0 = data['p0'] 135 | self.input_p1 = data['p1'] 136 | self.input_judge = 
data['judge'] 137 | 138 | if(self.use_gpu): 139 | self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) 140 | self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) 141 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) 142 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) 143 | 144 | self.var_ref = Variable(self.input_ref,requires_grad=True) 145 | self.var_p0 = Variable(self.input_p0,requires_grad=True) 146 | self.var_p1 = Variable(self.input_p1,requires_grad=True) 147 | 148 | def forward_train(self): # run forward pass 149 | # print(self.net.module.scaling_layer.shift) 150 | # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) 151 | 152 | self.d0 = self.forward(self.var_ref, self.var_p0) 153 | self.d1 = self.forward(self.var_ref, self.var_p1) 154 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) 155 | 156 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) 157 | 158 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) 159 | 160 | return self.loss_total 161 | 162 | def backward_train(self): 163 | torch.mean(self.loss_total).backward() 164 | 165 | def compute_accuracy(self,d0,d1,judge): 166 | ''' d0, d1 are Variables, judge is a Tensor ''' 167 | d1_lt_d0 = (d1 %f' % (type,self.old_lr, lr)) 210 | self.old_lr = lr 211 | 212 | def score_2afc_dataset(data_loader, func, name=''): 213 | ''' Function computes Two Alternative Forced Choice (2AFC) score using 214 | distance function 'func' in dataset 'data_loader' 215 | INPUTS 216 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside 217 | func - callable distance function - calling d=func(in0,in1) should take 2 218 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N 219 | OUTPUTS 220 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators 221 | [1] - dictionary with following elements 222 | d0s,d1s - N arrays containing distances between reference patch to perturbed patches 223 | gts - N array in [0,1], preferred patch selected by human evaluators 224 | (closer to "0" for left patch p0, "1" for right patch p1, 225 | "0.6" means 60pct people preferred right patch, 40pct preferred left) 226 | scores - N array in [0,1], corresponding to what percentage function agreed with humans 227 | CONSTS 228 | N - number of test triplets in data_loader 229 | ''' 230 | 231 | d0s = [] 232 | d1s = [] 233 | gts = [] 234 | 235 | for data in tqdm(data_loader.load_data(), desc=name): 236 | d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist() 237 | d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist() 238 | gts+=data['judge'].cpu().numpy().flatten().tolist() 239 | 240 | d0s = np.array(d0s) 241 | d1s = np.array(d1s) 242 | gts = np.array(gts) 243 | scores = (d0s>> not using cover in training 563 | cover_img.fill_(0.0) 564 | print('no_cover') 565 | if (opt.plain_cover or opt.noise_cover) and (val_cover == 0): 566 | cover_img.fill_(0.0) 567 | print('plain_cover') 568 | b, c, w, h = cover_img.size() 569 | 570 | if opt.plain_cover and (val_cover == 0): 571 | img_w1 = torch.cat((torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(), 572 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(), 573 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(), 574 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda()), dim=2) 575 | img_w2 = torch.cat((torch.rand(b, c, 1, 1).repeat(1, 1, w // 
4, h // 4).cuda(), 576 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(), 577 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(), 578 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda()), dim=2) 579 | img_w3 = torch.cat((torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(), 580 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(), 581 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(), 582 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda()), dim=2) 583 | img_w4 = torch.cat((torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(), 584 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(), 585 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(), 586 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda()), dim=2) 587 | img_wh = torch.cat((img_w1, img_w2, img_w3, img_w4), dim=3) 588 | cover_img = cover_img + img_wh 589 | print('if opt.plain_cover and (val_cover == 0):') 590 | if opt.noise_cover and (val_cover == 0): 591 | cover_img = cover_img + ((torch.rand(b, c, w, h) - 0.5) * 2 * 0 / 255).cuda() 592 | print('if opt.noise_cover and (val_cover == 0):') 593 | #+++++++++++++++++++++++++++ 594 | cover_imgv = cover_img 595 | 596 | if opt.cover_dependent: 597 | H_input = torch.cat((cover_imgv, secret_imgv), dim=1) 598 | else: 599 | H_input = secret_imgv 600 | 601 | out_dct_1, out_dct_2, out_dct_3, out_dct_4, itm_secret_img = Hnet_S(H_input) 602 | #************** 603 | if i_c != None: 604 | print('if i_c != None') 605 | if type(i_c) == type(1.0): 606 | ####### To keep one channel ####### 607 | itm_secret_img_clone = itm_secret_img.clone() 608 | itm_secret_img.fill_(0) 609 | itm_secret_img[:, int(i_c):int(i_c) + 1, :, :] = itm_secret_img_clone[:, int(i_c):int(i_c) + 1, :, :] 610 | if type(i_c) == type(1): 611 | print('aaaaa', i_c) 612 | ####### To set one channel to zero ####### 613 | itm_secret_img[:, i_c:i_c + 1, :, :].fill_(0.0) 614 | 615 | if position != None: 616 | print('if position != None') 617 | itm_secret_img[:, :, position:position + 1, position:position + 1].fill_(0.0) 618 | if Se_two == 2: 619 | print('if Se_two == 2') 620 | itm_secret_img_half = itm_secret_img[0:batch_size_secret // 2, :, :, :] 621 | itm_secret_img = itm_secret_img + torch.cat((itm_secret_img_half.clone().fill_(0.0), itm_secret_img_half), 0) 622 | elif type(Se_two) == type(0.1): 623 | print('type(Se_two) == type(0.1)') 624 | itm_secret_img = itm_secret_img + Se_two * torch.rand(itm_secret_img.size()).cuda() 625 | if opt.cover_dependent: 626 | container_img = itm_secret_img 627 | else: 628 | itm_secret_img = itm_secret_img.repeat(opt.num_training, 1, 1, 1) 629 | container_img = Hnet_C(cover_img, out_dct_1, out_dct_2, out_dct_3, out_dct_4, itm_secret_img) 630 | #************** 631 | errH = criterion(container_img, cover_imgv) # Hiding net 632 | 633 | rev_secret_img = Rnet(container_img) 634 | errR = criterion(rev_secret_img, secret_imgv_nh) # Reveal net 635 | 636 | # L1 metric 637 | diffH = (container_img - cover_imgv).abs().mean() * 255 638 | diffR = (rev_secret_img - secret_imgv_nh).abs().mean() * 255 639 | return cover_imgv, container_img, secret_imgv_nh, rev_secret_img, errH, errR, diffH, diffR 640 | 641 | 642 | def train(train_loader, epoch, Hnet_C, Hnet_S, Rnet, criterion): 643 | batch_time = AverageMeter() 644 | data_time = AverageMeter() 645 | Hlosses = AverageMeter() 646 | Rlosses = AverageMeter() 647 | SumLosses = AverageMeter() 648 | Hdiff = AverageMeter() 649 | Rdiff = AverageMeter() 650 | 651 | # Switch to train mode 652 
| Hnet_C.train() 653 | Hnet_S.train() 654 | Rnet.train() 655 | 656 | start_time = time.time() 657 | 658 | for i, ((secret_img, secret_target), (cover_img, cover_target)) in enumerate(train_loader, 0): 659 | 660 | data_time.update(time.time() - start_time) 661 | 662 | cover_imgv, container_img, secret_imgv_nh, rev_secret_img, errH, errR, diffH, diffR \ 663 | = forward_pass(secret_img, secret_target, cover_img, cover_target, Hnet_C, Hnet_S, Rnet, criterion) 664 | 665 | Hlosses.update(errH.item(), opt.bs_secret * opt.num_cover * opt.num_training) # H loss 666 | Rlosses.update(errR.item(), opt.bs_secret * opt.num_secret * opt.num_training) # R loss 667 | Hdiff.update(diffH.item(), opt.bs_secret * opt.num_cover * opt.num_training) 668 | Rdiff.update(diffR.item(), opt.bs_secret * opt.num_secret * opt.num_training) 669 | '''Hlosses.update(errH.data[0], opt.bs_secret * opt.num_cover * opt.num_training) # H loss 670 | Rlosses.update(errR.data[0], opt.bs_secret * opt.num_secret * opt.num_training) # R loss 671 | Hdiff.update(diffH.data[0], opt.bs_secret * opt.num_cover * opt.num_training) 672 | Rdiff.update(diffR.data[0], opt.bs_secret * opt.num_secret * opt.num_training)''' 673 | 674 | # Loss, backprop, and optimization step 675 | betaerrR_secret = opt.beta * errR 676 | err_sum = errH + betaerrR_secret 677 | optimizer.zero_grad() 678 | err_sum.backward() 679 | optimizer.step() 680 | 681 | # Time spent on one batch 682 | batch_time.update(time.time() - start_time) 683 | start_time = time.time() 684 | 685 | log = '[%d/%d][%d/%d]\tLoss_H: %.6f Loss_R: %.6f L1_H: %.4f L1_R: %.4f \tdatatime: %.4f \tbatchtime: %.4f' % ( 686 | epoch, opt.epochs, i, opt.iters_per_epoch, 687 | Hlosses.val, Rlosses.val, Hdiff.val, Rdiff.val, data_time.val, batch_time.val) 688 | 689 | if i % opt.logFrequency == 0: 690 | print(log) 691 | 692 | if epoch == opt.epochs and i % opt.resultPicFrequency == 0: 693 | save_result_pic(opt.bs_secret * opt.num_training, cover_imgv, container_img.data, secret_imgv_nh, 694 | rev_secret_img.data, epoch, i, opt.trainpics) 695 | 696 | if i == opt.iters_per_epoch - 1: 697 | break 698 | 699 | # To save the last batch only 700 | save_result_pic(opt.bs_secret * opt.num_training, cover_imgv, container_img.data, secret_imgv_nh, 701 | rev_secret_img.data, epoch, i, opt.trainpics) 702 | 703 | epoch_log = "Training[%d] Hloss=%.6f\tRloss=%.6f\tHdiff=%.4f\tRdiff=%.4f\tlr= %.6f\t Epoch time= %.4f" % ( 704 | epoch, Hlosses.avg, Rlosses.avg, Hdiff.avg, Rdiff.avg, optimizer.param_groups[0]['lr'], batch_time.sum) 705 | print_log(epoch_log, logPath) 706 | 707 | if not opt.debug: 708 | writer.add_scalar("lr/lr", optimizer.param_groups[0]['lr'], epoch) 709 | writer.add_scalar("lr/beta", opt.beta, epoch) 710 | writer.add_scalar('train/H_loss', Hlosses.avg, epoch) 711 | writer.add_scalar('train/R_loss', Rlosses.avg, epoch) 712 | writer.add_scalar('train/sum_loss', SumLosses.avg, epoch) 713 | writer.add_scalar('train/H_diff', Hdiff.avg, epoch) 714 | writer.add_scalar('train/R_diff', Rdiff.avg, epoch) 715 | 716 | 717 | def validation(val_loader, epoch, Hnet_C, Hnet_S,Rnet, criterion): 718 | print( 719 | "#################################################### validation begin ########################################################") 720 | start_time = time.time() 721 | Hnet_C.eval() 722 | Hnet_S.eval() 723 | Rnet.eval() 724 | batch_time = AverageMeter() 725 | Hlosses = AverageMeter() 726 | Rlosses = AverageMeter() 727 | SumLosses = AverageMeter() 728 | Hdiff = AverageMeter() 729 | Rdiff = AverageMeter() 730 | 731 | for 
i, ((secret_img, secret_target), (cover_img, cover_target)) in enumerate(val_loader, 0): 732 | 733 | cover_imgv, container_img, secret_imgv_nh, rev_secret_img, errH, errR, diffH, diffR \ 734 | = forward_pass(secret_img, secret_target, cover_img, cover_target, Hnet_C, Hnet_S, Rnet, criterion, val_cover=1) 735 | 736 | Hlosses.update(errH.item(), opt.bs_secret * opt.num_cover * opt.num_training) # H loss 737 | Rlosses.update(errR.item(), opt.bs_secret * opt.num_secret * opt.num_training) # R loss 738 | Hdiff.update(diffH.item(), opt.bs_secret * opt.num_cover * opt.num_training) 739 | Rdiff.update(diffR.item(), opt.bs_secret * opt.num_secret * opt.num_training) 740 | '''Hlosses.update(errH.data[0], opt.bs_secret * opt.num_cover * opt.num_training) # H loss 741 | Rlosses.update(errR.data[0], opt.bs_secret * opt.num_secret * opt.num_training) # R loss 742 | Hdiff.update(diffH.data[0], opt.bs_secret * opt.num_cover * opt.num_training) 743 | Rdiff.update(diffR.data[0], opt.bs_secret * opt.num_secret * opt.num_training)''' 744 | 745 | if i == 0: 746 | save_result_pic(opt.bs_secret * opt.num_training, cover_imgv, container_img.data, secret_imgv_nh, 747 | rev_secret_img.data, epoch, i, opt.validationpics) 748 | if epoch == opt.epochs and i % opt.resultPicFrequency == 0: 749 | save_result_pic(opt.bs_secret * opt.num_training, cover_imgv, container_img.data, secret_imgv_nh, 750 | rev_secret_img.data, epoch, i, opt.trainpics) 751 | if opt.num_secret >= 6: 752 | i_total = 80 753 | else: 754 | i_total = 200 755 | if i == i_total - 1: 756 | break 757 | 758 | batch_time.update(time.time() - start_time) 759 | start_time = time.time() 760 | 761 | val_log = "validation[%d] val_Hloss = %.6f\t val_Rloss = %.6f\t val_Hdiff = %.6f\t val_Rdiff=%.2f\t batch time=%.2f" % ( 762 | epoch, Hlosses.val, Rlosses.val, Hdiff.val, Rdiff.val, batch_time.val) 763 | if i % opt.logFrequency == 0: 764 | print(val_log) 765 | 766 | val_log = "validation[%d] val_Hloss = %.6f\t val_Rloss = %.6f\t val_Hdiff = %.4f\t val_Rdiff=%.4f\t validation time=%.2f" % ( 767 | epoch, Hlosses.avg, Rlosses.avg, Hdiff.avg, Rdiff.avg, batch_time.sum) 768 | print_log(val_log, logPath) 769 | 770 | if not opt.debug: 771 | writer.add_scalar('validation/H_loss_avg', Hlosses.avg, epoch) 772 | writer.add_scalar('validation/R_loss_avg', Rlosses.avg, epoch) 773 | writer.add_scalar('validation/H_diff_avg', Hdiff.avg, epoch) 774 | writer.add_scalar('validation/R_diff_avg', Rdiff.avg, epoch) 775 | 776 | print( 777 | "#################################################### validation end ########################################################") 778 | return Hlosses.avg, Rlosses.avg, Hdiff.avg, Rdiff.avg 779 | 780 | 781 | #def analysis(val_loader, epoch, Hnet, Rnet, HnetD, RnetD, criterion): 782 | def analysis(val_loader, epoch, Hnet_C, Hnet_S, Rnet, criterion): 783 | print( 784 | "#################################################### analysis begin ########################################################") 785 | 786 | Hnet_C.eval() 787 | Hnet_S.eval() 788 | Rnet.eval() 789 | Hdiff = AverageMeter() 790 | Rdiff = AverageMeter() 791 | psnr_C = AverageMeter() 792 | psnr_S = AverageMeter() 793 | ssim_C = AverageMeter() 794 | ssim_S = AverageMeter() 795 | lpips_C = AverageMeter() 796 | lpips_S = AverageMeter() 797 | 798 | #HnetD.eval() 799 | #RnetD.eval() 800 | import warnings 801 | warnings.filterwarnings("ignore") 802 | 803 | for ii, ((secret_img, secret_target), (cover_img, cover_target)) in enumerate(val_loader, 0): 804 | 805 | ####################################### 
Cover Agnostic ####################################### 806 | cover_imgv, container_img, secret_imgv_nh, rev_secret_img, errH, errR, diffH, diffR \ 807 | = forward_pass(secret_img, secret_target, cover_img, cover_target, Hnet_C, Hnet_S, Rnet, criterion, val_cover=1) 808 | secret_encoded = container_img - cover_imgv 809 | 810 | '''save_result_pic_analysis(opt.bs_secret * opt.num_training, cover_imgv.clone(), container_img.clone(), 811 | secret_imgv_nh.clone(), rev_secret_img.clone(), epoch, i, opt.validationpics)''' 812 | 813 | N, _, _, _ = rev_secret_img.shape 814 | 815 | cover_img_numpy = cover_imgv.clone().cpu().detach().numpy() 816 | container_img_numpy = container_img.clone().cpu().detach().numpy() 817 | 818 | cover_img_numpy = cover_img_numpy.transpose(0, 2, 3, 1) 819 | container_img_numpy = container_img_numpy.transpose(0, 2, 3, 1) 820 | 821 | rev_secret_numpy = rev_secret_img.cpu().detach().numpy() 822 | secret_img_numpy = secret_imgv_nh.cpu().detach().numpy() 823 | 824 | rev_secret_numpy = rev_secret_numpy.transpose(0, 2, 3, 1) 825 | secret_img_numpy = secret_img_numpy.transpose(0, 2, 3, 1) 826 | 827 | # PSNR 828 | print("Cover Agnostic") 829 | 830 | print("Secret APD C:", diffH.item()) 831 | 832 | psnr_c = np.zeros((N, 3)) 833 | for i in range(N): 834 | psnr_c[i, 0] = PSNR(cover_img_numpy[i, :, :, 0], container_img_numpy[i, :, :, 0]) 835 | psnr_c[i, 1] = PSNR(cover_img_numpy[i, :, :, 1], container_img_numpy[i, :, :, 1]) 836 | psnr_c[i, 2] = PSNR(cover_img_numpy[i, :, :, 2], container_img_numpy[i, :, :, 2]) 837 | print("Avg. PSNR C:", psnr_c.mean().item()) 838 | 839 | # SSIM 840 | ssim_c = np.zeros(N) 841 | for i in range(N): 842 | ssim_c[i] = SSIM(cover_img_numpy[i], container_img_numpy[i], multichannel=True) 843 | print("Avg. SSIM C:", ssim_c.mean().item()) 844 | 845 | # LPIPS 846 | import PerceptualSimilarity.models 847 | model = PerceptualSimilarity.models.PerceptualLoss(model='net-lin', net='alex', use_gpu=True, gpu_ids=[0]) 848 | lpips_c = model.forward(cover_imgv, container_img) 849 | print("Avg. LPIPS C:", lpips_c.mean().item()) 850 | 851 | print("Secret APD S:", diffR.item()) 852 | 853 | psnr_s = np.zeros(N) 854 | for i in range(N): 855 | psnr_s[i] = PSNR(secret_img_numpy[i], rev_secret_numpy[i]) 856 | print("Avg. PSNR S:", psnr_s.mean().item()) 857 | 858 | # SSIM 859 | ssim_s = np.zeros(N) 860 | for i in range(N): 861 | ssim_s[i] = SSIM(secret_img_numpy[i], rev_secret_numpy[i], multichannel=True) 862 | print("Avg. SSIM S:", ssim_s.mean().item()) 863 | 864 | # LPIPS 865 | import PerceptualSimilarity.models 866 | model = PerceptualSimilarity.models.PerceptualLoss(model='net-lin', net='alex', use_gpu=True, gpu_ids=[0]) 867 | secret_imgv_nh_1 = secret_imgv_nh.view(-1, 3, 128, 128) 868 | rev_secret_img_1 = rev_secret_img.view(-1, 3, 128, 128) 869 | lpips_s = model.forward(secret_imgv_nh_1, rev_secret_img_1) 870 | print("Avg. 
LPIPS S:", lpips_s.mean().item()) 871 | 872 | #print("*******DONE!**********") 873 | 874 | #break 875 | lpips_S.update(lpips_s.mean().item(), opt.bs_secret * opt.num_cover * opt.num_training) # H loss 876 | psnr_S.update(psnr_s.mean().item(), opt.bs_secret * opt.num_secret * opt.num_training) # R loss 877 | ssim_S.update(ssim_s.mean().item(), opt.bs_secret * opt.num_cover * opt.num_training) 878 | Rdiff.update(diffR.item(), opt.bs_secret * opt.num_secret * opt.num_training) 879 | lpips_C.update(lpips_c.mean().item(), opt.bs_secret * opt.num_cover * opt.num_training) # H loss 880 | psnr_C.update(psnr_c.mean().item(), opt.bs_secret * opt.num_secret * opt.num_training) # R loss 881 | ssim_C.update(ssim_c.mean().item(), opt.bs_secret * opt.num_cover * opt.num_training) 882 | Hdiff.update(diffH.item(), opt.bs_secret * opt.num_secret * opt.num_training) 883 | if opt.num_secret >= 6: 884 | i_total = 80 885 | else: 886 | i_total = 200 887 | if ii == i_total - 1: 888 | break 889 | print('Hdiff.avg, Rdiff.avg', Hdiff.avg, Rdiff.avg) 890 | print('Hdiff.avg', Hdiff.avg, 'psnr_c.avg', psnr_C.avg, 'ssim_c.avg', ssim_C.avg, 'lpips_c.avg', lpips_C.avg) 891 | print('Rdiff.avg', Rdiff.avg, 'psnr_s.avg', psnr_S.avg, 'ssim_s.avg', ssim_S.avg, 'lpips_s.avg', lpips_S.avg) 892 | 893 | 894 | 895 | def print_log(log_info, log_path, console=True): 896 | # print the info into the console 897 | if console: 898 | print(log_info) 899 | # debug mode don't write the log into files 900 | if not opt.debug: 901 | # write the log into log file 902 | if not os.path.exists(log_path): 903 | fp = open(log_path, "w") 904 | fp.writelines(log_info + "\n") 905 | else: 906 | with open(log_path, 'a+') as f: 907 | f.writelines(log_info + '\n') 908 | 909 | 910 | def adjust_learning_rate(optimizer, epoch): 911 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 912 | lr = opt.lr * (0.1 ** (epoch // 30)) 913 | for param_group in optimizer.param_groups: 914 | param_group['lr'] = lr 915 | 916 | 917 | # save result pic and the coverImg filePath and the secretImg filePath 918 | def save_result_pic_analysis(bs_secret_times_num_training, cover, container, secret, rev_secret, epoch, i, 919 | save_path=None, postname=''): 920 | path = './qualitative_results/' 921 | if not os.path.exists(path): 922 | os.makedirs(path) 923 | resultImgName = path + 'universal_qualitative_results.png' 924 | 925 | cover = cover[:4] 926 | container = container[:4] 927 | secret = secret[:4] 928 | rev_secret = rev_secret[:4] 929 | 930 | cover_gap = container - cover 931 | secret_gap = rev_secret - secret 932 | cover_gap = (cover_gap * 10 + 0.5).clamp_(0.0, 1.0) 933 | secret_gap = (secret_gap * 10 + 0.5).clamp_(0.0, 1.0) 934 | 935 | for i_cover in range(4): 936 | cover_i = cover[:, i_cover * opt.channel_cover:(i_cover + 1) * opt.channel_cover, :, :] 937 | container_i = container[:, i_cover * opt.channel_cover:(i_cover + 1) * opt.channel_cover, :, :] 938 | cover_gap_i = cover_gap[:, i_cover * opt.channel_cover:(i_cover + 1) * opt.channel_cover, :, :] 939 | 940 | if i_cover == 0: 941 | showCover = torch.cat((cover_i, container_i, cover_gap_i), 0) 942 | else: 943 | showCover = torch.cat((showCover, cover_i, container_i, cover_gap_i), 0) 944 | 945 | for i_secret in range(4): 946 | secret_i = secret[:, i_secret * opt.channel_secret:(i_secret + 1) * opt.channel_secret, :, :] 947 | rev_secret_i = rev_secret[:, i_secret * opt.channel_secret:(i_secret + 1) * opt.channel_secret, :, :] 948 | secret_gap_i = secret_gap[:, i_secret * 
opt.channel_secret:(i_secret + 1) * opt.channel_secret, :, :] 949 | 950 | if i_secret == 0: 951 | showSecret = torch.cat((secret_i, rev_secret_i, secret_gap_i), 0) 952 | else: 953 | showSecret = torch.cat((showSecret, secret_i, rev_secret_i, secret_gap_i), 0) 954 | 955 | showAll = torch.cat((showCover, showSecret), 0) 956 | showAll = showAll.reshape(6, 4, 3, 128, 128) 957 | showAll = showAll.permute(1, 0, 2, 3, 4) 958 | showAll = showAll.reshape(4 * 6, 3, 128, 128) 959 | vutils.save_image(showAll, resultImgName, nrow=6, padding=1, normalize=False) 960 | 961 | 962 | # save result pic and the coverImg filePath and the secretImg filePath 963 | def save_result_pic(bs_secret_times_num_training, cover, container, secret, rev_secret, epoch, i, save_path=None, 964 | postname=''): 965 | # if not opt.debug: 966 | # cover=container: bs*nt/nc; secret=rev_secret: bs*nt/3*nh 967 | if opt.debug: 968 | save_path = './debug/debug_images' 969 | resultImgName = '%s/ResultPics_epoch%03d_batch%04d%s.png' % (save_path, epoch, i, postname) 970 | 971 | cover_gap = container - cover 972 | secret_gap = rev_secret - secret 973 | cover_gap = (cover_gap * 10 + 0.5).clamp_(0.0, 1.0) 974 | secret_gap = (secret_gap * 10 + 0.5).clamp_(0.0, 1.0) 975 | # print(cover_gap.abs().sum(dim=-1).sum(dim=-1).sum(dim=-1), secret_gap.abs().sum(dim=-1).sum(dim=-1).sum(dim=-1)) 976 | 977 | # showCover = torch.cat((cover, container, cover_gap),0) 978 | 979 | for i_cover in range(opt.num_cover): 980 | cover_i = cover[:, i_cover * opt.channel_cover:(i_cover + 1) * opt.channel_cover, :, :] 981 | container_i = container[:, i_cover * opt.channel_cover:(i_cover + 1) * opt.channel_cover, :, :] 982 | cover_gap_i = cover_gap[:, i_cover * opt.channel_cover:(i_cover + 1) * opt.channel_cover, :, :] 983 | 984 | if i_cover == 0: 985 | showCover = torch.cat((cover_i, container_i, cover_gap_i), 0) 986 | else: 987 | showCover = torch.cat((showCover, cover_i, container_i, cover_gap_i), 0) 988 | 989 | for i_secret in range(opt.num_secret): 990 | secret_i = secret[:, i_secret * opt.channel_secret:(i_secret + 1) * opt.channel_secret, :, :] 991 | rev_secret_i = rev_secret[:, i_secret * opt.channel_secret:(i_secret + 1) * opt.channel_secret, :, :] 992 | secret_gap_i = secret_gap[:, i_secret * opt.channel_secret:(i_secret + 1) * opt.channel_secret, :, :] 993 | 994 | if i_secret == 0: 995 | showSecret = torch.cat((secret_i, rev_secret_i, secret_gap_i), 0) 996 | else: 997 | showSecret = torch.cat((showSecret, secret_i, rev_secret_i, secret_gap_i), 0) 998 | 999 | if opt.channel_secret == opt.channel_cover: 1000 | showAll = torch.cat((showCover, showSecret), 0) 1001 | vutils.save_image(showAll, resultImgName, nrow=bs_secret_times_num_training, padding=1, normalize=True) 1002 | else: 1003 | ContainerImgName = '%s/ContainerPics_epoch%03d_batch%04d.png' % (save_path, epoch, i) 1004 | SecretImgName = '%s/SecretPics_epoch%03d_batch%04d.png' % (save_path, epoch, i) 1005 | vutils.save_image(showCover, ContainerImgName, nrow=bs_secret_times_num_training, padding=1, normalize=True) 1006 | vutils.save_image(showSecret, SecretImgName, nrow=bs_secret_times_num_training, padding=1, normalize=True) 1007 | 1008 | 1009 | 1010 | 1011 | class AverageMeter(object): 1012 | """ 1013 | Computes and stores the average and current value. 
1014 | """ 1015 | 1016 | def __init__(self): 1017 | self.reset() 1018 | 1019 | def reset(self): 1020 | self.val = 0 1021 | self.avg = 0 1022 | self.sum = 0 1023 | self.count = 0 1024 | 1025 | def update(self, val, n=1): 1026 | self.val = val 1027 | self.sum += val * n 1028 | self.count += n 1029 | self.avg = self.sum / self.count 1030 | 1031 | 1032 | if __name__ == '__main__': 1033 | main() --------------------------------------------------------------------------------
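The training loop in train() backpropagates the joint objective err_sum = errH + opt.beta * errR, where errH compares the container image with the cover and errR compares the revealed secret with the original secret, and it logs L1 gaps scaled to the 0-255 range (L1_H / L1_R). Below is a minimal, self-contained sketch of that objective with random tensors standing in for the network outputs; the MSE criterion and beta = 0.75 are assumptions (main_DAH.py receives criterion as an argument and reads beta from opt.beta).

import torch
import torch.nn as nn

criterion = nn.MSELoss()   # assumed; main_DAH.py passes its own criterion into train()
beta = 0.75                # assumed; mirrors the beta0.75 run name under ./runs

# stand-ins for the Hnet_C/Hnet_S container and the Rnet reconstruction
cover      = torch.rand(2, 3, 128, 128)
container  = cover + 0.01 * torch.randn_like(cover)
secret     = torch.rand(2, 3, 128, 128)
rev_secret = secret + 0.01 * torch.randn_like(secret)

errH = criterion(container, cover)        # hiding loss: container should look like the cover
errR = criterion(rev_secret, secret)      # reveal loss: recovered secret should match the secret
err_sum = errH + beta * errR              # objective optimized in train()

diffH = (container - cover).abs().mean() * 255       # L1_H ("Secret APD C") reported in the logs
diffR = (rev_secret - secret).abs().mean() * 255      # L1_R ("Secret APD S")
print(err_sum.item(), diffH.item(), diffR.item())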
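analysis() evaluates cover/container and secret/revealed-secret quality after transposing the tensors to NHWC numpy arrays, computing PSNR per colour channel for the cover pair. A small sketch of that per-channel PSNR with a hand-rolled helper follows (the repository's PSNR and SSIM functions are imported elsewhere in main_DAH.py and may have a different signature).

import numpy as np

def psnr(a, b, data_range=1.0):
    # peak signal-to-noise ratio for arrays with values in [0, data_range]
    mse = np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2)
    return 10.0 * np.log10(data_range ** 2 / mse)

# NHWC arrays, as produced by the .transpose(0, 2, 3, 1) calls in analysis()
cover_np = np.random.rand(4, 128, 128, 3)
container_np = np.clip(cover_np + 0.01 * np.random.randn(*cover_np.shape), 0.0, 1.0)

N = cover_np.shape[0]
psnr_c = np.zeros((N, 3))
for i in range(N):
    for ch in range(3):   # one PSNR value per colour channel, as in the psnr_c loop above
        psnr_c[i, ch] = psnr(cover_np[i, :, :, ch], container_np[i, :, :, ch])
print("Avg. PSNR C:", psnr_c.mean())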
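The LPIPS numbers (lpips_C, lpips_S) come from the bundled PerceptualSimilarity package with the calibrated 'net-lin' AlexNet weights under models/weights/v0.1. A minimal usage sketch of the same constructor call made in analysis(), assuming the repository root is on PYTHONPATH, a CUDA device is available, and the pretrained weights are in place (DistModel.forward documents its inputs as Nx3xXxY patches scaled to [-1, 1]):

import torch
import PerceptualSimilarity.models as ps_models

# same call as in analysis(); loads weights/v0.1/alex.pth through DistModel.initialize
lpips = ps_models.PerceptualLoss(model='net-lin', net='alex', use_gpu=True, gpu_ids=[0])

x = torch.rand(1, 3, 128, 128).cuda() * 2 - 1
y = torch.rand(1, 3, 128, 128).cuda() * 2 - 1
print("LPIPS:", lpips.forward(x, y).mean().item())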