├── .idea
│ ├── .name
│ ├── misc.xml
│ ├── vcs.xml
│ ├── inspectionProfiles
│ │ └── profiles_settings.xml
│ ├── .gitignore
│ ├── modules.xml
│ └── DAH_github.iml
├── PerceptualSimilarity
│ ├── data
│ │ ├── __init__.py
│ │ ├── dataset
│ │ │ ├── __init__.py
│ │ │ ├── base_dataset.py
│ │ │ ├── jnd_dataset.py
│ │ │ └── twoafc_dataset.py
│ │ ├── base_data_loader.py
│ │ ├── data_loader.py
│ │ ├── custom_dataset_data_loader.py
│ │ └── image_folder.py
│ ├── util
│ │ ├── __init__.py
│ │ ├── util.py
│ │ ├── html.py
│ │ └── visualizer.py
│ ├── .gitignore
│ ├── imgs
│ │ ├── fig1.png
│ │ ├── ex_p0.png
│ │ ├── ex_p1.png
│ │ ├── ex_ref.png
│ │ ├── ex_dir0
│ │ │ ├── 0.png
│ │ │ └── 1.png
│ │ ├── ex_dir1
│ │ │ ├── 0.png
│ │ │ └── 1.png
│ │ └── ex_dir_pair
│ │   ├── ex_p0.png
│ │   ├── ex_p1.png
│ │   └── ex_ref.png
│ ├── scripts
│ │ ├── eval_valsets.sh
│ │ ├── train_test_metric.sh
│ │ ├── train_test_metric_tune.sh
│ │ ├── train_test_metric_scratch.sh
│ │ ├── download_dataset_valonly.sh
│ │ └── download_dataset.sh
│ ├── models
│ │ ├── weights
│ │ │ ├── v0.0
│ │ │ │ ├── alex.pth
│ │ │ │ ├── vgg.pth
│ │ │ │ └── squeeze.pth
│ │ │ └── v0.1
│ │ │   ├── alex.pth
│ │ │   ├── vgg.pth
│ │ │   └── squeeze.pth
│ │ ├── base_model.py
│ │ ├── __init__.py
│ │ ├── pretrained_networks.py
│ │ ├── networks_basic.py
│ │ └── dist_model.py
│ ├── requirements.txt
│ ├── compute_dists.py
│ ├── compute_dists_dirs.py
│ ├── LICENSE
│ ├── perceptual_loss.py
│ ├── Dockerfile
│ ├── test_network.py
│ ├── compute_dists_pair.py
│ ├── test_dataset_model.py
│ ├── train.py
│ └── README.md
├── fig
│ ├── AFE.jpg
│ ├── AFE.pdf
│ ├── GDE.jpg
│ └── GDE.pdf
├── models
│ ├── __pycache__
│ │ ├── module.cpython-38.pyc
│ │ ├── RevealNet.cpython-38.pyc
│ │ ├── HidingUNet_C.cpython-38.pyc
│ │ └── HidingUNet_S.cpython-38.pyc
│ ├── RevealNet.py
│ ├── module.py
│ ├── HidingUNet_S.py
│ └── HidingUNet_C.py
├── runs
│ └── main_dah_beta0.75_num_secret1num_cover1_2023-08-01_H19-31-52
│   └── events.out.tfevents.1690889527.amax-SYS-7049GP-TRT
├── scripts
│ └── train_dah.sh
├── training
│ └── main_dah_beta0.75_num_secret1num_cover1_2023-08-01_H19-31-52
│   └── trainingLogs
│     └── train_44_log.txt
├── README.md
└── main_DAH.py
/.idea/.name:
--------------------------------------------------------------------------------
1 | main_DAH.py
--------------------------------------------------------------------------------
/PerceptualSimilarity/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/data/dataset/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 |
3 | checkpoints/*
4 |
--------------------------------------------------------------------------------
/fig/AFE.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/fig/AFE.jpg
--------------------------------------------------------------------------------
/fig/AFE.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/fig/AFE.pdf
--------------------------------------------------------------------------------
/fig/GDE.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/fig/GDE.jpg
--------------------------------------------------------------------------------
/fig/GDE.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/fig/GDE.pdf
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/fig1.png
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/ex_p0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/ex_p0.png
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/ex_p1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/ex_p1.png
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/ex_ref.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/ex_ref.png
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/ex_dir0/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/ex_dir0/0.png
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/ex_dir0/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/ex_dir0/1.png
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/ex_dir1/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/ex_dir1/0.png
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/ex_dir1/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/ex_dir1/1.png
--------------------------------------------------------------------------------
/models/__pycache__/module.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/models/__pycache__/module.cpython-38.pyc
--------------------------------------------------------------------------------
/PerceptualSimilarity/scripts/eval_valsets.sh:
--------------------------------------------------------------------------------
1 |
2 | python ./test_dataset_model.py --dataset_mode 2afc --model net-lin --net alex --use_gpu --batch_size 50
3 |
4 |
--------------------------------------------------------------------------------
/models/__pycache__/RevealNet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/models/__pycache__/RevealNet.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/HidingUNet_C.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/models/__pycache__/HidingUNet_C.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/HidingUNet_S.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/models/__pycache__/HidingUNet_S.cpython-38.pyc
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/ex_dir_pair/ex_p0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/ex_dir_pair/ex_p0.png
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/ex_dir_pair/ex_p1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/ex_dir_pair/ex_p1.png
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/ex_dir_pair/ex_ref.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/imgs/ex_dir_pair/ex_ref.png
--------------------------------------------------------------------------------
/PerceptualSimilarity/models/weights/v0.0/alex.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/models/weights/v0.0/alex.pth
--------------------------------------------------------------------------------
/PerceptualSimilarity/models/weights/v0.0/vgg.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/models/weights/v0.0/vgg.pth
--------------------------------------------------------------------------------
/PerceptualSimilarity/models/weights/v0.1/alex.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/models/weights/v0.1/alex.pth
--------------------------------------------------------------------------------
/PerceptualSimilarity/models/weights/v0.1/vgg.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/models/weights/v0.1/vgg.pth
--------------------------------------------------------------------------------
/PerceptualSimilarity/models/weights/v0.0/squeeze.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/models/weights/v0.0/squeeze.pth
--------------------------------------------------------------------------------
/PerceptualSimilarity/models/weights/v0.1/squeeze.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/PerceptualSimilarity/models/weights/v0.1/squeeze.pth
--------------------------------------------------------------------------------
/PerceptualSimilarity/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=0.4.0
2 | torchvision>=0.2.1
3 | numpy>=1.14.3
4 | scipy>=1.0.1
5 | scikit-image>=0.13.0
6 | opencv>=2.4.11
7 | matplotlib>=1.5.1
8 | tqdm>=4.28.1
9 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/data/base_data_loader.py:
--------------------------------------------------------------------------------
1 |
2 | class BaseDataLoader():
3 | def __init__(self):
4 | pass
5 |
6 | def initialize(self):
7 | pass
8 |
9 | def load_data(self):
10 | return None
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /../../../../../../:\Users\lgm\Desktop\DAH_github\.idea/dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/runs/main_dah_beta0.75_num_secret1num_cover1_2023-08-01_H19-31-52/events.out.tfevents.1690889527.amax-SYS-7049GP-TRT:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangle408/Deep-adaptive-hiding-network/HEAD/runs/main_dah_beta0.75_num_secret1num_cover1_2023-08-01_H19-31-52/events.out.tfevents.1690889527.amax-SYS-7049GP-TRT
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/scripts/train_test_metric.sh:
--------------------------------------------------------------------------------
1 |
2 | TRIAL=${1}
3 | NET=${2}
4 | mkdir checkpoints
5 | mkdir checkpoints/${NET}_${TRIAL}
6 | python ./train.py --use_gpu --net ${NET} --name ${NET}_${TRIAL}
7 | python ./test_dataset_model.py --use_gpu --net ${NET} --model_path ./checkpoints/${NET}_${TRIAL}/latest_net_.pth
8 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/data/dataset/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | class BaseDataset(data.Dataset):
4 | def __init__(self):
5 | super(BaseDataset, self).__init__()
6 |
7 | def name(self):
8 | return 'BaseDataset'
9 |
10 | def initialize(self):
11 | pass
12 |
13 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/scripts/train_test_metric_tune.sh:
--------------------------------------------------------------------------------
1 |
2 | TRIAL=${1}
3 | NET=${2}
4 | mkdir checkpoints
5 | mkdir checkpoints/${NET}_${TRIAL}_tune
6 | python ./train.py --train_trunk --use_gpu --net ${NET} --name ${NET}_${TRIAL}_tune
7 | python ./test_dataset_model.py --train_trunk --use_gpu --net ${NET} --model_path ./checkpoints/${NET}_${TRIAL}_tune/latest_net_.pth
8 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/scripts/train_test_metric_scratch.sh:
--------------------------------------------------------------------------------
1 |
2 | TRIAL=${1}
3 | NET=${2}
4 | mkdir checkpoints
5 | mkdir checkpoints/${NET}_${TRIAL}_scratch
6 | python ./train.py --from_scratch --train_trunk --use_gpu --net ${NET} --name ${NET}_${TRIAL}_scratch
7 | python ./test_dataset_model.py --from_scratch --train_trunk --use_gpu --net ${NET} --model_path ./checkpoints/${NET}_${TRIAL}_scratch/latest_net_.pth
8 |
9 |
--------------------------------------------------------------------------------
/scripts/train_dah.sh:
--------------------------------------------------------------------------------
1 | #python main_DAH.py --imageSize 128 --bs_secret 44 --num_training 1 --num_secret 1 --num_cover 1 --channel_cover 3 --channel_secret 3 --norm 'batch' --epochs 120 --loss 'l2' --beta 0.75 --remark 'main_dah'
2 | python main_DAH.py --imageSize 128 --bs_secret 44 --num_training 1 --num_secret 1 --num_cover 1 --channel_cover 3 --channel_secret 3 --norm 'batch' --epochs 120 --loss 'l2' --beta 0.75 --remark 'main_dah'
--------------------------------------------------------------------------------
/PerceptualSimilarity/data/data_loader.py:
--------------------------------------------------------------------------------
1 | def CreateDataLoader(datafolder,dataroot='./dataset',dataset_mode='2afc',load_size=64,batch_size=1,serial_batches=True,nThreads=4):
2 | from data.custom_dataset_data_loader import CustomDatasetDataLoader
3 | data_loader = CustomDatasetDataLoader()
4 | # print(data_loader.name())
5 | data_loader.initialize(datafolder,dataroot=dataroot+'/'+dataset_mode,dataset_mode=dataset_mode,load_size=load_size,batch_size=batch_size,serial_batches=serial_batches, nThreads=nThreads)
6 | return data_loader
7 |
--------------------------------------------------------------------------------
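
Taken together with `custom_dataset_data_loader.py` and the dataset classes below, this is the whole evaluation data pipeline. A minimal usage sketch, assuming it is run from the `PerceptualSimilarity/` root and that the 2AFC validation data has already been fetched with `scripts/download_dataset_valonly.sh`:

```python
# Hedged sketch: build a loader for one 2AFC validation split and peek at a batch.
from data import data_loader as dl

loader = dl.CreateDataLoader('val/traditional', dataroot='./dataset',
                             dataset_mode='2afc', load_size=64, batch_size=50)
for batch in loader.load_data():
    # each batch is the dict returned by TwoAFCDataset.__getitem__
    print(batch['ref'].shape, batch['p0'].shape, batch['p1'].shape, batch['judge'].shape)
    break
```
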
/.idea/DAH_github.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/scripts/download_dataset_valonly.sh:
--------------------------------------------------------------------------------
1 |
2 | mkdir dataset
3 |
4 | # JND Dataset
5 | wget https://people.eecs.berkeley.edu/~rich.zhang/projects/2018_perceptual/dataset/jnd.tar.gz -O ./dataset/jnd.tar.gz
6 |
7 | mkdir dataset/jnd
8 | tar -xzf ./dataset/jnd.tar.gz -C ./dataset
9 | rm ./dataset/jnd.tar.gz
10 |
11 | # 2AFC Val set
12 | mkdir dataset/2afc/
13 | wget https://people.eecs.berkeley.edu/~rich.zhang/projects/2018_perceptual/dataset/twoafc_val.tar.gz -O ./dataset/twoafc_val.tar.gz
14 |
15 | mkdir dataset/2afc/val
16 | tar -xzf ./dataset/twoafc_val.tar.gz -C ./dataset/2afc
17 | rm ./dataset/twoafc_val.tar.gz
18 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/compute_dists.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import models
3 | from util import util
4 |
5 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
6 | parser.add_argument('-p0','--path0', type=str, default='./imgs/ex_ref.png')
7 | parser.add_argument('-p1','--path1', type=str, default='./imgs/ex_p0.png')
8 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU')
9 |
10 | opt = parser.parse_args()
11 |
12 | ## Initializing the model
13 | model = models.PerceptualLoss(model='net-lin',net='alex',use_gpu=opt.use_gpu)
14 |
15 | # Load images
16 | img0 = util.im2tensor(util.load_image(opt.path0)) # RGB image from [-1,1]
17 | img1 = util.im2tensor(util.load_image(opt.path1))
18 |
19 | if(opt.use_gpu):
20 | img0 = img0.cuda()
21 | img1 = img1.cuda()
22 |
23 |
24 | # Compute distance
25 | dist01 = model.forward(img0,img1)
26 | print('Distance: %.3f'%dist01)
27 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/scripts/download_dataset.sh:
--------------------------------------------------------------------------------
1 |
2 | mkdir dataset
3 |
4 | # JND Dataset
5 | wget https://people.eecs.berkeley.edu/~rich.zhang/projects/2018_perceptual/dataset/jnd.tar.gz -O ./dataset/jnd.tar.gz
6 |
7 | mkdir dataset/jnd
8 | tar -xzf ./dataset/jnd.tar.gz -C ./dataset
9 | rm ./dataset/jnd.tar.gz
10 |
11 | # 2AFC Val set
12 | mkdir dataset/2afc/
13 | wget https://people.eecs.berkeley.edu/~rich.zhang/projects/2018_perceptual/dataset/twoafc_val.tar.gz -O ./dataset/twoafc_val.tar.gz
14 |
15 | mkdir dataset/2afc/val
16 | tar -xzf ./dataset/twoafc_val.tar.gz -C ./dataset/2afc
17 | rm ./dataset/twoafc_val.tar.gz
18 |
19 | # 2AFC Train set
20 | mkdir dataset/2afc/
21 | wget https://people.eecs.berkeley.edu/~rich.zhang/projects/2018_perceptual/dataset/twoafc_train.tar.gz -O ./dataset/twoafc_train.tar.gz
22 |
23 | mkdir dataset/2afc/train
24 | tar -xzf ./dataset/twoafc_train.tar.gz -C ./dataset/2afc
25 | rm ./dataset/twoafc_train.tar.gz
26 |
--------------------------------------------------------------------------------
/training/main_dah_beta0.75_num_secret1num_cover1_2023-08-01_H19-31-52/trainingLogs/train_44_log.txt:
--------------------------------------------------------------------------------
1 | Namespace(Hnet_C='', Hnet_S='', Rnet='', beta=0.75, beta1=0.5, bs_secret=44, channel_cover=3, channel_secret=3, checkpoint='', checkpoint_diff='', cover_dependent=False, cuda=True, dataset='train', debug=False, decay_round=10, epochs=120, hostname='amax-SYS-7049GP-TRT', imageSize=128, iters_per_epoch=2000, logFrequency=1000, loss='l2', lr=0.001, ngpu=2, no_cover=False, noise_cover=False, norm='batch', num_cover=1, num_secret=1, num_training=1, outckpts='./training/main_dah_beta0.75_num_secret1num_cover1_2023-08-01_H19-31-52/checkPoints', outcodes='./training/main_dah_beta0.75_num_secret1num_cover1_2023-08-01_H19-31-52/codes', outlogs='./training/main_dah_beta0.75_num_secret1num_cover1_2023-08-01_H19-31-52/trainingLogs', plain_cover=False, remark='main_dah', resultPicFrequency=100, test='', testPics='./training/', test_diff='', trainpics='./training/main_dah_beta0.75_num_secret1num_cover1_2023-08-01_H19-31-52/trainPics', validationpics='./training/main_dah_beta0.75_num_secret1num_cover1_2023-08-01_H19-31-52/validationPics', workers=8)
2 | training is beginning .......................................................
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep-adaptive-hiding-network
2 | Official PyTorch implementation of "Deep adaptive hiding network for image hiding using attentive frequency extraction and gradual depth extraction"
3 |
4 | ![AFE](./fig/AFE.jpg)
5 |
6 | ![GDE](./fig/GDE.jpg)
7 | ## Requirements
8 | This code was developed and tested with Python 3.6, PyTorch 1.5, and CUDA 10.2 on Ubuntu 18.04.5.
9 |
10 | ## Train DAH-Net on ImageNet datasets
11 | You can run the provided demo code as follows.
12 |
13 | 1. Prepare the ImageNet datasets and visualization dataset.
14 |
15 | 2. Change the data paths on lines 210-214 of main_DAH.py (see the sketch below).
16 |
17 | (Training images live in `traindir` and `valdir`; the images used for visualization live in `coverdir` and `secretdir`.)
18 |
19 | 3. ```sh ./scripts/train_dah.sh ```
20 |
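For reference, a hypothetical sketch of the data-path block referenced in step 2 (the variable names follow the parenthetical above; the actual paths, and lines 210-214 of main_DAH.py, will differ on your machine):

```python
# Hypothetical example only -- point these at your own data.
traindir  = '/path/to/ImageNet/train'         # training images
valdir    = '/path/to/ImageNet/val'           # validation images
coverdir  = '/path/to/visualization/cover'    # cover images for visualization
secretdir = '/path/to/visualization/secret'   # secret images for visualization
```
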
21 | ## Citation
22 | If you find our research helpful or influential, please consider citing it:
23 |
24 |
25 | ### BibTeX
26 | @article{zhang2023deep,
27 | title={Deep adaptive hiding network for image hiding using attentive frequency extraction and gradual depth extraction},
28 | author={Zhang, Le and Lu, Yao and Li, Jinxing and Chen, Fanglin and Lu, Guangming and Zhang, David},
29 | journal={Neural Computing and Applications},
30 | pages={1--19},
31 | year={2023},
32 | publisher={Springer}
33 | }
34 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/compute_dists_dirs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import models
4 | from util import util
5 |
6 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
7 | parser.add_argument('-d0','--dir0', type=str, default='./imgs/ex_dir0')
8 | parser.add_argument('-d1','--dir1', type=str, default='./imgs/ex_dir1')
9 | parser.add_argument('-o','--out', type=str, default='./imgs/example_dists.txt')
10 | parser.add_argument('-v','--version', type=str, default='0.1')
11 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU')
12 |
13 | opt = parser.parse_args()
14 |
15 | ## Initializing the model
16 | model = models.PerceptualLoss(model='net-lin',net='alex',use_gpu=opt.use_gpu,version=opt.version)
17 |
18 | # crawl directories
19 | f = open(opt.out,'w')
20 | files = os.listdir(opt.dir0)
21 |
22 | for file in files:
23 | if(os.path.exists(os.path.join(opt.dir1,file))):
24 | # Load images
25 | img0 = util.im2tensor(util.load_image(os.path.join(opt.dir0,file))) # RGB image from [-1,1]
26 | img1 = util.im2tensor(util.load_image(os.path.join(opt.dir1,file)))
27 |
28 | if(opt.use_gpu):
29 | img0 = img0.cuda()
30 | img1 = img1.cuda()
31 |
32 | # Compute distance
33 | dist01 = model.forward(img0,img1)
34 | print('%s: %.3f'%(file,dist01))
35 | f.writelines('%s: %.6f\n'%(file,dist01))
36 |
37 | f.close()
38 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 |
25 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import numpy as np
4 | from PIL import Image
5 | import numpy as np
6 | import os
7 | import matplotlib.pyplot as plt
8 | import torch
9 |
10 | def load_image(path):
11 | if(path[-3:] == 'dng'):
12 | import rawpy
13 | with rawpy.imread(path) as raw:
14 | img = raw.postprocess()
15 | elif(path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png'):
16 | import cv2
17 | return cv2.imread(path)[:,:,::-1]
18 | else:
19 | img = (255*plt.imread(path)[:,:,:3]).astype('uint8')
20 |
21 | return img
22 |
23 | def save_image(image_numpy, image_path):
24 | image_pil = Image.fromarray(image_numpy)
25 | image_pil.save(image_path)
26 |
27 | def mkdirs(paths):
28 | if isinstance(paths, list) and not isinstance(paths, str):
29 | for path in paths:
30 | mkdir(path)
31 | else:
32 | mkdir(paths)
33 |
34 | def mkdir(path):
35 | if not os.path.exists(path):
36 | os.makedirs(path)
37 |
38 |
39 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
40 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
41 | image_numpy = image_tensor[0].cpu().float().numpy()
42 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
43 | return image_numpy.astype(imtype)
44 |
45 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
46 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
47 | return torch.Tensor((image / factor - cent)
48 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
49 |
--------------------------------------------------------------------------------
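
The two conversion helpers above define the normalization convention used throughout the LPIPS code: images are RGB tensors scaled to [-1, 1]. A minimal round-trip sketch, assuming it is run from the `PerceptualSimilarity/` directory:

```python
# Hedged sketch of the im2tensor / tensor2im round trip.
from util import util

img = util.load_image('./imgs/ex_ref.png')   # HxWx3 RGB uint8 in [0, 255]
t = util.im2tensor(img)                      # 1x3xHxW float tensor in [-1, 1]
back = util.tensor2im(t)                     # HxWx3 uint8 again
assert back.shape == img.shape
```
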
/PerceptualSimilarity/data/custom_dataset_data_loader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from data.base_data_loader import BaseDataLoader
3 | import os
4 |
5 | def CreateDataset(dataroots,dataset_mode='2afc',load_size=64,):
6 | dataset = None
7 | if dataset_mode=='2afc': # human judgements
8 | from data.dataset.twoafc_dataset import TwoAFCDataset
9 | dataset = TwoAFCDataset()
10 | elif dataset_mode=='jnd': # human judgements
11 | from data.dataset.jnd_dataset import JNDDataset
12 | dataset = JNDDataset()
13 | else:
14 | raise ValueError("Dataset Mode [%s] not recognized." % dataset_mode)
15 |
16 | dataset.initialize(dataroots,load_size=load_size)
17 | return dataset
18 |
19 | class CustomDatasetDataLoader(BaseDataLoader):
20 | def name(self):
21 | return 'CustomDatasetDataLoader'
22 |
23 | def initialize(self, datafolders, dataroot='./dataset',dataset_mode='2afc',load_size=64,batch_size=1,serial_batches=True, nThreads=1):
24 | BaseDataLoader.initialize(self)
25 | if(not isinstance(datafolders,list)):
26 | datafolders = [datafolders,]
27 | data_root_folders = [os.path.join(dataroot,datafolder) for datafolder in datafolders]
28 | self.dataset = CreateDataset(data_root_folders,dataset_mode=dataset_mode,load_size=load_size)
29 | self.dataloader = torch.utils.data.DataLoader(
30 | self.dataset,
31 | batch_size=batch_size,
32 | shuffle=not serial_batches,
33 | num_workers=int(nThreads))
34 |
35 | def load_data(self):
36 | return self.dataloader
37 |
38 | def __len__(self):
39 | return len(self.dataset)
40 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/perceptual_loss.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | import sys
5 | import scipy
6 | import scipy.misc
7 | import numpy as np
8 | import torch
9 | from torch.autograd import Variable
10 | import models
11 |
12 | use_gpu = True
13 |
14 | ref_path = './imgs/ex_ref.png'
15 | pred_path = './imgs/ex_p1.png'
16 |
17 | ref_img = scipy.misc.imread(ref_path).transpose(2, 0, 1) / 255.
18 | pred_img = scipy.misc.imread(pred_path).transpose(2, 0, 1) / 255.
19 |
20 | # Torchify
21 | ref = Variable(torch.FloatTensor(ref_img)[None,:,:,:])
22 | pred = Variable(torch.FloatTensor(pred_img)[None,:,:,:], requires_grad=True)
23 |
24 | loss_fn = models.PerceptualLoss(model='net-lin', net='vgg', use_gpu=use_gpu)
25 | optimizer = torch.optim.Adam([pred,], lr=1e-3, betas=(0.9, 0.999))
26 |
27 | import matplotlib.pyplot as plt
28 | plt.ion()
29 | fig = plt.figure(1)
30 | ax = fig.add_subplot(131)
31 | ax.imshow(ref_img.transpose(1, 2, 0))
32 | ax.set_title('target')
33 | ax = fig.add_subplot(133)
34 | ax.imshow(pred_img.transpose(1, 2, 0))
35 | ax.set_title('initialization')
36 |
37 | for i in range(1000):
38 | dist = loss_fn.forward(pred, ref, normalize=True)
39 | optimizer.zero_grad()
40 | dist.backward()
41 | optimizer.step()
42 | pred.data = torch.clamp(pred.data, 0, 1)
43 |
44 | if i % 10 == 0:
45 | print('iter %d, dist %.3g' % (i, dist.view(-1).data.cpu().numpy()[0]))
46 | pred_img = pred[0].data.cpu().numpy().transpose(1, 2, 0)
47 | pred_img = np.clip(pred_img, 0, 1)
48 | ax = fig.add_subplot(132)
49 | ax.imshow(pred_img)
50 | ax.set_title('iter %d, dist %.3f' % (i, dist.view(-1).data.cpu().numpy()[0]))
51 | plt.pause(5e-2)
52 | # plt.imsave('imgs_saved/%04d.jpg'%i,pred_img)
53 |
54 |
55 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from pdb import set_trace as st
5 | from IPython import embed
6 |
7 | class BaseModel():
8 | def __init__(self):
9 | pass
10 |
11 | def name(self):
12 | return 'BaseModel'
13 |
14 | def initialize(self, use_gpu=True, gpu_ids=[0]):
15 | self.use_gpu = use_gpu
16 | self.gpu_ids = gpu_ids
17 |
18 | def forward(self):
19 | pass
20 |
21 | def get_image_paths(self):
22 | pass
23 |
24 | def optimize_parameters(self):
25 | pass
26 |
27 | def get_current_visuals(self):
28 | return self.input
29 |
30 | def get_current_errors(self):
31 | return {}
32 |
33 | def save(self, label):
34 | pass
35 |
36 | # helper saving function that can be used by subclasses
37 | def save_network(self, network, path, network_label, epoch_label):
38 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
39 | save_path = os.path.join(path, save_filename)
40 | torch.save(network.state_dict(), save_path)
41 |
42 | # helper loading function that can be used by subclasses
43 | def load_network(self, network, network_label, epoch_label):
44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
45 | save_path = os.path.join(self.save_dir, save_filename)
46 | # print('Loading network from %s'%save_path)
47 | network.load_state_dict(torch.load(save_path))
48 |
49 | def update_learning_rate(self):
50 | pass
51 |
52 | def get_image_paths(self):
53 | return self.image_paths
54 |
55 | def save_done(self, flag=False):
56 | np.save(os.path.join(self.save_dir, 'done_flag'),flag)
57 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
58 |
59 |
--------------------------------------------------------------------------------
/models/RevealNet.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: yongzhi li
4 | @contact: yongzhili@vip.qq.com
5 |
6 | @version: 1.0
7 | @file: Reveal.py
8 | @time: 2018/3/20
9 |
10 | """
11 |
12 | import torch.nn as nn
13 |
14 |
15 | class RevealNet(nn.Module):
16 | def __init__(self, input_nc, output_nc, nhf=64, norm_layer=None, output_function=nn.Sigmoid):
17 | super(RevealNet, self).__init__()
18 | # input is (3) x 256 x 256
19 |
20 | self.conv1 = nn.Conv2d(input_nc, nhf, 3, 1, 1)
21 | self.conv2 = nn.Conv2d(nhf, nhf * 2, 3, 1, 1)
22 | self.conv3 = nn.Conv2d(nhf * 2, nhf * 4, 3, 1, 1)
23 | self.conv4 = nn.Conv2d(nhf * 4, nhf * 2, 3, 1, 1)
24 | self.conv5 = nn.Conv2d(nhf * 2, nhf, 3, 1, 1)
25 | self.conv6 = nn.Conv2d(nhf, output_nc, 3, 1, 1)
26 | self.output=output_function()
27 | self.relu = nn.ReLU(True)
28 |
29 | self.norm_layer = norm_layer
30 | if norm_layer != None:
31 | self.norm1 = norm_layer(nhf)
32 | self.norm2 = norm_layer(nhf*2)
33 | self.norm3 = norm_layer(nhf*4)
34 | self.norm4 = norm_layer(nhf*2)
35 | self.norm5 = norm_layer(nhf)
36 |
37 | def forward(self, input):
38 |
39 | if self.norm_layer != None:
40 | x=self.relu(self.norm1(self.conv1(input)))
41 | x=self.relu(self.norm2(self.conv2(x)))
42 | x=self.relu(self.norm3(self.conv3(x)))
43 | x=self.relu(self.norm4(self.conv4(x)))
44 | x=self.relu(self.norm5(self.conv5(x)))
45 | x=self.output(self.conv6(x))
46 | else:
47 | x=self.relu(self.conv1(input))
48 | x=self.relu(self.conv2(x))
49 | x=self.relu(self.conv3(x))
50 | x=self.relu(self.conv4(x))
51 | x=self.relu(self.conv5(x))
52 | x=self.output(self.conv6(x))
53 |
54 | return x
55 |
--------------------------------------------------------------------------------
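
A minimal shape-check sketch for the reveal network; the 3-channel, 128x128 setting mirrors the flags in `scripts/train_dah.sh`, and the import path assumes the repository root is on `PYTHONPATH`:

```python
# Hedged sketch: instantiate RevealNet and confirm it preserves spatial size.
import torch
import torch.nn as nn
from models.RevealNet import RevealNet

rnet = RevealNet(input_nc=3, output_nc=3, nhf=64,
                 norm_layer=nn.BatchNorm2d, output_function=nn.Sigmoid)
stego = torch.randn(2, 3, 128, 128)   # dummy container batch
secret_rec = rnet(stego)              # recovered secret estimate in [0, 1]
print(secret_rec.shape)               # torch.Size([2, 3, 128, 128])
```
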
/PerceptualSimilarity/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:9.0-base-ubuntu16.04
2 |
3 | LABEL maintainer="Seyoung Park "
4 |
5 | # This Dockerfile is forked from Tensorflow Dockerfile
6 |
7 | # Pick up some PyTorch gpu dependencies
8 | RUN apt-get update && apt-get install -y --no-install-recommends \
9 | build-essential \
10 | cuda-command-line-tools-9-0 \
11 | cuda-cublas-9-0 \
12 | cuda-cufft-9-0 \
13 | cuda-curand-9-0 \
14 | cuda-cusolver-9-0 \
15 | cuda-cusparse-9-0 \
16 | curl \
17 | libcudnn7=7.1.4.18-1+cuda9.0 \
18 | libfreetype6-dev \
19 | libhdf5-serial-dev \
20 | libpng12-dev \
21 | libzmq3-dev \
22 | pkg-config \
23 | python \
24 | python-dev \
25 | rsync \
26 | software-properties-common \
27 | unzip \
28 | && \
29 | apt-get clean && \
30 | rm -rf /var/lib/apt/lists/*
31 |
32 |
33 | # Install miniconda
34 | RUN apt-get update && apt-get install -y --no-install-recommends \
35 | wget && \
36 | MINICONDA="Miniconda3-latest-Linux-x86_64.sh" && \
37 | wget --quiet https://repo.continuum.io/miniconda/$MINICONDA && \
38 | bash $MINICONDA -b -p /miniconda && \
39 | rm -f $MINICONDA
40 | ENV PATH /miniconda/bin:$PATH
41 |
42 | # Install PyTorch
43 | RUN conda update -n base conda && \
44 | conda install pytorch torchvision cuda90 -c pytorch
45 |
46 | # Install PerceptualSimilarity dependencies
47 | RUN conda install numpy scipy jupyter matplotlib && \
48 | conda install -c conda-forge scikit-image && \
49 | apt-get install -y python-qt4 && \
50 | pip install opencv-python
51 |
52 | # For CUDA profiling, TensorFlow requires CUPTI. Maybe PyTorch needs this too.
53 | ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
54 |
55 | # IPython
56 | EXPOSE 8888
57 |
58 | WORKDIR "/notebooks"
59 |
60 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/test_network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from util import util
3 | import models
4 | from models import dist_model as dm
5 | from IPython import embed
6 |
7 | use_gpu = False # Whether to use GPU
8 | spatial = True # Return a spatial map of perceptual distance.
9 |
10 | # Linearly calibrated models (LPIPS)
11 | model = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=use_gpu, spatial=spatial)
12 | # Can also set net = 'squeeze' or 'vgg'
13 |
14 | # Off-the-shelf uncalibrated networks
15 | # model = models.PerceptualLoss(model='net', net='alex', use_gpu=use_gpu, spatial=spatial)
16 | # Can also set net = 'squeeze' or 'vgg'
17 |
18 | # Low-level metrics
19 | # model = models.PerceptualLoss(model='L2', colorspace='Lab', use_gpu=use_gpu)
20 | # model = models.PerceptualLoss(model='ssim', colorspace='RGB', use_gpu=use_gpu)
21 |
22 | ## Example usage with dummy tensors
23 | dummy_im0 = torch.zeros(1,3,64,64) # image should be RGB, normalized to [-1,1]
24 | dummy_im1 = torch.zeros(1,3,64,64)
25 | if(use_gpu):
26 | dummy_im0 = dummy_im0.cuda()
27 | dummy_im1 = dummy_im1.cuda()
28 | dist = model.forward(dummy_im0,dummy_im1)
29 |
30 | ## Example usage with images
31 | ex_ref = util.im2tensor(util.load_image('./imgs/ex_ref.png'))
32 | ex_p0 = util.im2tensor(util.load_image('./imgs/ex_p0.png'))
33 | ex_p1 = util.im2tensor(util.load_image('./imgs/ex_p1.png'))
34 | if(use_gpu):
35 | ex_ref = ex_ref.cuda()
36 | ex_p0 = ex_p0.cuda()
37 | ex_p1 = ex_p1.cuda()
38 |
39 | ex_d0 = model.forward(ex_ref,ex_p0)
40 | ex_d1 = model.forward(ex_ref,ex_p1)
41 |
42 | if not spatial:
43 | print('Distances: (%.3f, %.3f)'%(ex_d0, ex_d1))
44 | else:
45 | print('Distances: (%.3f, %.3f)'%(ex_d0.mean(), ex_d1.mean())) # The mean distance is approximately the same as the non-spatial distance
46 |
47 | # Visualize a spatially-varying distance map between ex_p0 and ex_ref
48 | import pylab
49 | pylab.imshow(ex_d0[0,0,...].data.cpu().numpy())
50 | pylab.show()
51 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/compute_dists_pair.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import models
4 | from util import util
5 | import numpy as np
6 | from IPython import embed
7 |
8 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
9 | parser.add_argument('-d','--dir', type=str, default='./imgs/ex_dir0')
10 | parser.add_argument('-o','--out', type=str, default='./imgs/example_dists.txt')
11 | parser.add_argument('-v','--version', type=str, default='0.1')
12 | parser.add_argument('--all-pairs', action='store_true', help='turn on to test all N(N-1)/2 pairs, leave off to just do consecutive pairs (N-1)')
13 | parser.add_argument('-N', type=int, default=None)
14 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU')
15 |
16 | opt = parser.parse_args()
17 |
18 | ## Initializing the model
19 | model = models.PerceptualLoss(model='net-lin',net='alex',use_gpu=opt.use_gpu,version=opt.version)
20 |
21 | # crawl directories
22 | f = open(opt.out,'w')
23 | files = os.listdir(opt.dir)
24 | if(opt.N is not None):
25 | files = files[:opt.N]
26 | F = len(files)
27 |
28 | dists = []
29 | for (ff,file) in enumerate(files[:-1]):
30 | img0 = util.im2tensor(util.load_image(os.path.join(opt.dir,file))) # RGB image from [-1,1]
31 | if(opt.use_gpu):
32 | img0 = img0.cuda()
33 |
34 | if(opt.all_pairs):
35 | files1 = files[ff+1:]
36 | else:
37 | files1 = [files[ff+1],]
38 |
39 | for file1 in files1:
40 | img1 = util.im2tensor(util.load_image(os.path.join(opt.dir,file1)))
41 |
42 | if(opt.use_gpu):
43 | img1 = img1.cuda()
44 |
45 | # Compute distance
46 | dist01 = model.forward(img0,img1)
47 | print('(%s,%s): %.3f'%(file,file1,dist01))
48 | f.writelines('(%s,%s): %.6f\n'%(file,file1,dist01))
49 |
50 | dists.append(dist01.item())
51 |
52 | avg_dist = np.mean(np.array(dists))
53 | stderr_dist = np.std(np.array(dists))/np.sqrt(len(dists))
54 |
55 | print('Avg: %.5f +/- %.5f'%(avg_dist,stderr_dist))
56 | f.writelines('Avg: %.6f +/- %.6f'%(avg_dist,stderr_dist))
57 |
58 | f.close()
59 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/data/dataset/jnd_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import torchvision.transforms as transforms
3 | from data.dataset.base_dataset import BaseDataset
4 | from data.image_folder import make_dataset
5 | from PIL import Image
6 | import numpy as np
7 | import torch
8 | from IPython import embed
9 |
10 | class JNDDataset(BaseDataset):
11 | def initialize(self, dataroot, load_size=64):
12 | self.root = dataroot
13 | self.load_size = load_size
14 |
15 | self.dir_p0 = os.path.join(self.root, 'p0')
16 | self.p0_paths = make_dataset(self.dir_p0)
17 | self.p0_paths = sorted(self.p0_paths)
18 |
19 | self.dir_p1 = os.path.join(self.root, 'p1')
20 | self.p1_paths = make_dataset(self.dir_p1)
21 | self.p1_paths = sorted(self.p1_paths)
22 |
23 | transform_list = []
24 | transform_list.append(transforms.Scale(load_size))
25 | transform_list += [transforms.ToTensor(),
26 | transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]
27 |
28 | self.transform = transforms.Compose(transform_list)
29 |
30 | # judgement directory
31 | self.dir_S = os.path.join(self.root, 'same')
32 | self.same_paths = make_dataset(self.dir_S,mode='np')
33 | self.same_paths = sorted(self.same_paths)
34 |
35 | def __getitem__(self, index):
36 | p0_path = self.p0_paths[index]
37 | p0_img_ = Image.open(p0_path).convert('RGB')
38 | p0_img = self.transform(p0_img_)
39 |
40 | p1_path = self.p1_paths[index]
41 | p1_img_ = Image.open(p1_path).convert('RGB')
42 | p1_img = self.transform(p1_img_)
43 |
44 | same_path = self.same_paths[index]
45 | same_img = np.load(same_path).reshape((1,1,1,)) # [0,1]
46 |
47 | same_img = torch.FloatTensor(same_img)
48 |
49 | return {'p0': p0_img, 'p1': p1_img, 'same': same_img,
50 | 'p0_path': p0_path, 'p1_path': p1_path, 'same_path': same_path}
51 |
52 | def __len__(self):
53 | return len(self.p0_paths)
54 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import *
3 | import os
4 |
5 |
6 | class HTML:
7 | def __init__(self, web_dir, title, image_subdir='', reflesh=0):
8 | self.title = title
9 | self.web_dir = web_dir
10 | # self.img_dir = os.path.join(self.web_dir, )
11 | self.img_subdir = image_subdir
12 | self.img_dir = os.path.join(self.web_dir, image_subdir)
13 | if not os.path.exists(self.web_dir):
14 | os.makedirs(self.web_dir)
15 | if not os.path.exists(self.img_dir):
16 | os.makedirs(self.img_dir)
17 | # print(self.img_dir)
18 |
19 | self.doc = dominate.document(title=title)
20 | if reflesh > 0:
21 | with self.doc.head:
22 | meta(http_equiv="reflesh", content=str(reflesh))
23 |
24 | def get_image_dir(self):
25 | return self.img_dir
26 |
27 | def add_header(self, str):
28 | with self.doc:
29 | h3(str)
30 |
31 | def add_table(self, border=1):
32 | self.t = table(border=border, style="table-layout: fixed;")
33 | self.doc.add(self.t)
34 |
35 | def add_images(self, ims, txts, links, width=400):
36 | self.add_table()
37 | with self.t:
38 | with tr():
39 | for im, txt, link in zip(ims, txts, links):
40 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
41 | with p():
42 | with a(href=os.path.join(link)):
43 | img(style="width:%dpx" % width, src=os.path.join(im))
44 | br()
45 | p(txt)
46 |
47 | def save(self,file='index'):
48 | html_file = '%s/%s.html' % (self.web_dir,file)
49 | f = open(html_file, 'wt')
50 | f.write(self.doc.render())
51 | f.close()
52 |
53 |
54 | if __name__ == '__main__':
55 | html = HTML('web/', 'test_html')
56 | html.add_header('hello world')
57 |
58 | ims = []
59 | txts = []
60 | links = []
61 | for n in range(4):
62 | ims.append('image_%d.png' % n)
63 | txts.append('text_%d' % n)
64 | links.append('image_%d.png' % n)
65 | html.add_images(ims, txts, links)
66 | html.save()
67 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/data/image_folder.py:
--------------------------------------------------------------------------------
1 | ################################################################################
2 | # Code from
3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
4 | # Modified the original code so that it also loads images from the current
5 | # directory as well as the subdirectories
6 | ################################################################################
7 |
8 | import torch.utils.data as data
9 |
10 | from PIL import Image
11 | import os
12 | import os.path
13 |
14 | IMG_EXTENSIONS = [
15 | '.jpg', '.JPG', '.jpeg', '.JPEG',
16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
17 | ]
18 |
19 | NP_EXTENSIONS = ['.npy',]
20 |
21 | def is_image_file(filename, mode='img'):
22 | if(mode=='img'):
23 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
24 | elif(mode=='np'):
25 | return any(filename.endswith(extension) for extension in NP_EXTENSIONS)
26 |
27 | def make_dataset(dirs, mode='img'):
28 | if(not isinstance(dirs,list)):
29 | dirs = [dirs,]
30 |
31 | images = []
32 | for dir in dirs:
33 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
34 | for root, _, fnames in sorted(os.walk(dir)):
35 | for fname in fnames:
36 | if is_image_file(fname, mode=mode):
37 | path = os.path.join(root, fname)
38 | images.append(path)
39 |
40 | # print("Found %i images in %s"%(len(images),root))
41 | return images
42 |
43 | def default_loader(path):
44 | return Image.open(path).convert('RGB')
45 |
46 | class ImageFolder(data.Dataset):
47 | def __init__(self, root, transform=None, return_paths=False,
48 | loader=default_loader):
49 | imgs = make_dataset(root)
50 | if len(imgs) == 0:
51 | raise(RuntimeError("Found 0 images in: " + root + "\n"
52 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
53 |
54 | self.root = root
55 | self.imgs = imgs
56 | self.transform = transform
57 | self.return_paths = return_paths
58 | self.loader = loader
59 |
60 | def __getitem__(self, index):
61 | path = self.imgs[index]
62 | img = self.loader(path)
63 | if self.transform is not None:
64 | img = self.transform(img)
65 | if self.return_paths:
66 | return img, path
67 | else:
68 | return img
69 |
70 | def __len__(self):
71 | return len(self.imgs)
72 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/data/dataset/twoafc_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import torchvision.transforms as transforms
3 | from data.dataset.base_dataset import BaseDataset
4 | from data.image_folder import make_dataset
5 | from PIL import Image
6 | import numpy as np
7 | import torch
8 | # from IPython import embed
9 |
10 | class TwoAFCDataset(BaseDataset):
11 | def initialize(self, dataroots, load_size=64):
12 | if(not isinstance(dataroots,list)):
13 | dataroots = [dataroots,]
14 | self.roots = dataroots
15 | self.load_size = load_size
16 |
17 | # image directory
18 | self.dir_ref = [os.path.join(root, 'ref') for root in self.roots]
19 | self.ref_paths = make_dataset(self.dir_ref)
20 | self.ref_paths = sorted(self.ref_paths)
21 |
22 | self.dir_p0 = [os.path.join(root, 'p0') for root in self.roots]
23 | self.p0_paths = make_dataset(self.dir_p0)
24 | self.p0_paths = sorted(self.p0_paths)
25 |
26 | self.dir_p1 = [os.path.join(root, 'p1') for root in self.roots]
27 | self.p1_paths = make_dataset(self.dir_p1)
28 | self.p1_paths = sorted(self.p1_paths)
29 |
30 | transform_list = []
31 | transform_list.append(transforms.Scale(load_size))
32 | transform_list += [transforms.ToTensor(),
33 | transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]
34 |
35 | self.transform = transforms.Compose(transform_list)
36 |
37 | # judgement directory
38 | self.dir_J = [os.path.join(root, 'judge') for root in self.roots]
39 | self.judge_paths = make_dataset(self.dir_J,mode='np')
40 | self.judge_paths = sorted(self.judge_paths)
41 |
42 | def __getitem__(self, index):
43 | p0_path = self.p0_paths[index]
44 | p0_img_ = Image.open(p0_path).convert('RGB')
45 | p0_img = self.transform(p0_img_)
46 |
47 | p1_path = self.p1_paths[index]
48 | p1_img_ = Image.open(p1_path).convert('RGB')
49 | p1_img = self.transform(p1_img_)
50 |
51 | ref_path = self.ref_paths[index]
52 | ref_img_ = Image.open(ref_path).convert('RGB')
53 | ref_img = self.transform(ref_img_)
54 |
55 | judge_path = self.judge_paths[index]
56 | # judge_img = (np.load(judge_path)*2.-1.).reshape((1,1,1,)) # [-1,1]
57 | judge_img = np.load(judge_path).reshape((1,1,1,)) # [0,1]
58 |
59 | judge_img = torch.FloatTensor(judge_img)
60 |
61 | return {'p0': p0_img, 'p1': p1_img, 'ref': ref_img, 'judge': judge_img,
62 | 'p0_path': p0_path, 'p1_path': p1_path, 'ref_path': ref_path, 'judge_path': judge_path}
63 |
64 | def __len__(self):
65 | return len(self.p0_paths)
66 |
--------------------------------------------------------------------------------
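
Each 2AFC sample pairs a reference patch with two distortions and a human judgement in [0, 1] (the fraction of raters who preferred `p1`). The `dm.score_2afc_dataset` call in `test_dataset_model.py` turns per-sample distances into an agreement score; a hedged sketch of the conventional formula (not the repository's exact code), given distances `d0 = d(ref, p0)` and `d1 = d(ref, p1)`:

```python
import numpy as np

def twoafc_score(d0, d1, judge):
    # Credit the metric when it ranks the pair the same way the human majority did;
    # ties count as half credit. judge is the fraction of humans preferring p1.
    d0, d1, judge = map(np.asarray, (d0, d1, judge))
    per_sample = (d0 < d1) * (1.0 - judge) + (d1 < d0) * judge + (d0 == d1) * 0.5
    return per_sample.mean()

# Toy example: the metric prefers p0, and 80% of raters also preferred p0.
print(twoafc_score([0.2], [0.5], [0.2]))   # -> 0.8
```
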
/models/module.py:
--------------------------------------------------------------------------------
1 | """
2 | Harmonic block definition.
3 |
4 | Licensed under the BSD License [see LICENSE for details].
5 |
6 | Written by Matej Ulicny
7 | """
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import math
13 | import numpy as np
14 |
15 |
16 | def dct_filters(k=3, groups=1, expand_dim=1, level=None, DC=True, l1_norm=True):
17 | if level is None:
18 | nf = k ** 2 - int(not DC)
19 | else:
20 | if level <= k:
21 | nf = level * (level + 1) // 2 - int(not DC)
22 | else:
23 | r = 2 * k - 1 - level
24 | nf = k ** 2 - r * (r + 1) // 2 - int(not DC)
25 | filter_bank = np.zeros((nf, k, k), dtype=np.float32)
26 | m = 0
27 | for i in range(k):
28 | for j in range(k):
29 | if (not DC and i == 0 and j == 0) or (level is not None and i + j >= level):
30 | continue
31 | for x in range(k):
32 | for y in range(k):
33 | filter_bank[m, x, y] = math.cos((math.pi * (x + .5) * i) / k) * math.cos(
34 | (math.pi * (y + .5) * j) / k)
35 | if l1_norm:
36 | filter_bank[m, :, :] /= np.sum(np.abs(filter_bank[m, :, :]))
37 | else:
38 | ai = 1.0 if i > 0 else 1.0 / math.sqrt(2.0)
39 | aj = 1.0 if j > 0 else 1.0 / math.sqrt(2.0)
40 | filter_bank[m, :, :] *= (2.0 / k) * ai * aj
41 | m += 1
42 | #print(filter_bank.shape)
43 | filter_bank = np.tile(np.expand_dims(filter_bank, axis=expand_dim), (groups, 1, 1, 1))
44 | #print(filter_bank.shape)
45 | return torch.FloatTensor(filter_bank)
46 |
47 |
48 | class Harm2d(nn.Module):
49 |
50 | def __init__(self, ni, no, kernel_size, stride=1, padding=0, bias=True, dilation=1, use_bn=True, level=None,
51 | DC=True, groups=1):
52 | super(Harm2d, self).__init__()
53 | self.ni = ni
54 | self.kernel_size = kernel_size
55 | self.stride = stride
56 | self.padding = padding
57 | self.dilation = dilation
58 | self.groups = groups
59 | self.dct = nn.Parameter(
60 | dct_filters(k=kernel_size, groups=ni, expand_dim=1, level=level, DC=DC), requires_grad=False)
61 |
62 | nf = self.dct.shape[0] // ni #if use_bn else self.dct.shape[1]
63 | self.bn = nn.BatchNorm2d(ni * nf, affine=False)
64 | '''self.weight = nn.Parameter(
65 | nn.init.kaiming_normal_(torch.Tensor(no, ni // self.groups * nf, 1, 1), mode='fan_out',
66 | nonlinearity='relu'))'''
67 |
68 | self.bias = nn.Parameter(nn.init.zeros_(torch.Tensor(no))) if bias else None
69 |
70 | def forward(self, x):
71 | #print('self.dct', self.dct.shape)
72 | #print('x', x.shape)
73 | #print('self.ni', self.ni)
74 | #print(x.size(1))
75 | x = F.conv2d(x, self.dct, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=x.size(1))
76 | #print('888')
77 |
78 | x = self.bn(x)
79 | #x = F.conv2d(x, self.weight, bias=self.bias, padding=0, groups=self.groups)
80 | return x
81 |
82 |
--------------------------------------------------------------------------------
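
A small smoke-test sketch for the harmonic block; the `models.module` import path assumes the repository root is on `PYTHONPATH`. Note that with the 1x1 mixing weight commented out above, `Harm2d` outputs `ni * nf` DCT response channels (here 3 x 9 = 27) rather than `no`:

```python
# Hedged sketch: run a dummy batch through Harm2d and inspect the output shape.
import torch
from models.module import Harm2d

harm = Harm2d(ni=3, no=64, kernel_size=3, padding=1)
x = torch.randn(2, 3, 128, 128)   # dummy image batch
y = harm(x)                       # depthwise 3x3 DCT filtering + BatchNorm
print(y.shape)                    # torch.Size([2, 27, 128, 128]) -> 3 channels x 9 DCT bases
```
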
/PerceptualSimilarity/test_dataset_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from models import dist_model as dm
3 | from data import data_loader as dl
4 | import argparse
5 | from IPython import embed
6 |
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--dataset_mode', type=str, default='2afc', help='[2afc,jnd]')
9 | parser.add_argument('--datasets', type=str, nargs='+', default=['val/traditional','val/cnn','val/superres','val/deblur','val/color','val/frameinterp'], help='datasets to test - for jnd mode: [val/traditional],[val/cnn]; for 2afc mode: [train/traditional],[train/cnn],[train/mix],[val/traditional],[val/cnn],[val/color],[val/deblur],[val/frameinterp],[val/superres]')
10 | parser.add_argument('--model', type=str, default='net-lin', help='distance model type [net-lin] for linearly calibrated net, [net] for off-the-shelf network, [l2] for euclidean distance, [ssim] for Structured Similarity Image Metric')
11 | parser.add_argument('--net', type=str, default='alex', help='[squeeze], [alex], or [vgg] for network architectures')
12 | parser.add_argument('--colorspace', type=str, default='Lab', help='[Lab] or [RGB] for colorspace to use for l2, ssim model types')
13 | parser.add_argument('--batch_size', type=int, default=50, help='batch size to test image patches in')
14 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU')
15 | parser.add_argument('--gpu_ids', type=int, nargs='+', default=[0], help='gpus to use')
16 | parser.add_argument('--nThreads', type=int, default=4, help='number of threads to use in data loader')
17 |
18 | parser.add_argument('--model_path', type=str, default=None, help='location of model, will default to ./weights/v[version]/[net_name].pth')
19 |
20 | parser.add_argument('--from_scratch', action='store_true', help='model was initialized from scratch')
21 | parser.add_argument('--train_trunk', action='store_true', help='model trunk was trained/tuned')
22 | parser.add_argument('--version', type=str, default='0.1', help='v0.1 is latest, v0.0 was original release')
23 |
24 | opt = parser.parse_args()
25 | if(opt.model in ['l2','ssim']):
26 | opt.batch_size = 1
27 |
28 | # initialize model
29 | model = dm.DistModel()
30 | # model.initialize(model=opt.model,net=opt.net,colorspace=opt.colorspace,model_path=opt.model_path,use_gpu=opt.use_gpu)
31 | model.initialize(model=opt.model, net=opt.net, colorspace=opt.colorspace,
32 | model_path=opt.model_path, use_gpu=opt.use_gpu, pnet_rand=opt.from_scratch, pnet_tune=opt.train_trunk,
33 | version=opt.version, gpu_ids=opt.gpu_ids)
34 |
35 | if(opt.model in ['net-lin','net']):
36 | print('Testing model [%s]-[%s]'%(opt.model,opt.net))
37 | elif(opt.model in ['l2','ssim']):
38 | print('Testing model [%s]-[%s]'%(opt.model,opt.colorspace))
39 |
40 | # initialize data loader
41 | for dataset in opt.datasets:
42 | data_loader = dl.CreateDataLoader(dataset,dataset_mode=opt.dataset_mode, batch_size=opt.batch_size, nThreads=opt.nThreads)
43 |
44 | # evaluate model on data
45 | if(opt.dataset_mode=='2afc'):
46 | (score, results_verbose) = dm.score_2afc_dataset(data_loader, model.forward, name=dataset)
47 | elif(opt.dataset_mode=='jnd'):
48 | (score, results_verbose) = dm.score_jnd_dataset(data_loader, model.forward, name=dataset)
49 |
50 | # print results
51 | print(' Dataset [%s]: %.2f'%(dataset,100.*score))
52 |
53 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/train.py:
--------------------------------------------------------------------------------
1 | import torch.backends.cudnn as cudnn
2 | cudnn.benchmark=False
3 |
4 | import numpy as np
5 | import time
6 | import os
7 | from models import dist_model as dm
8 | from data import data_loader as dl
9 | import argparse
10 | from util.visualizer import Visualizer
11 | from IPython import embed
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--datasets', type=str, nargs='+', default=['train/traditional','train/cnn','train/mix'], help='datasets to train on: [train/traditional],[train/cnn],[train/mix],[val/traditional],[val/cnn],[val/color],[val/deblur],[val/frameinterp],[val/superres]')
15 | parser.add_argument('--model', type=str, default='net-lin', help='distance model type [net-lin] for linearly calibrated net, [net] for off-the-shelf network, [l2] for euclidean distance, [ssim] for Structured Similarity Image Metric')
16 | parser.add_argument('--net', type=str, default='alex', help='[squeeze], [alex], or [vgg] for network architectures')
17 | parser.add_argument('--batch_size', type=int, default=50, help='batch size to test image patches in')
18 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU')
19 | parser.add_argument('--gpu_ids', type=int, nargs='+', default=[0], help='gpus to use')
20 |
21 | parser.add_argument('--nThreads', type=int, default=4, help='number of threads to use in data loader')
22 | parser.add_argument('--nepoch', type=int, default=5, help='# epochs at base learning rate')
23 | parser.add_argument('--nepoch_decay', type=int, default=5, help='# additional epochs at linearly decaying learning rate')
24 | parser.add_argument('--display_freq', type=int, default=5000, help='frequency (in instances) of showing training results on screen')
25 | parser.add_argument('--print_freq', type=int, default=5000, help='frequency (in instances) of showing training results on console')
26 | parser.add_argument('--save_latest_freq', type=int, default=20000, help='frequency (in instances) of saving the latest results')
27 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
28 | parser.add_argument('--display_id', type=int, default=0, help='window id of the visdom display, [0] for no displaying')
29 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
30 | parser.add_argument('--display_port', type=int, default=8001, help='visdom display port')
31 | parser.add_argument('--use_html', action='store_true', help='save off html pages')
32 | parser.add_argument('--checkpoints_dir', type=str, default='checkpoints', help='checkpoints directory')
33 | parser.add_argument('--name', type=str, default='tmp', help='directory name for training')
34 |
35 | parser.add_argument('--from_scratch', action='store_true', help='model was initialized from scratch')
36 | parser.add_argument('--train_trunk', action='store_true', help='model trunk was trained/tuned')
37 | parser.add_argument('--train_plot', action='store_true', help='plot saving')
38 |
39 | opt = parser.parse_args()
40 | opt.save_dir = os.path.join(opt.checkpoints_dir,opt.name)
41 | if(not os.path.exists(opt.save_dir)):
42 | os.mkdir(opt.save_dir)
43 |
44 | # initialize model
45 | model = dm.DistModel()
46 | model.initialize(model=opt.model, net=opt.net, use_gpu=opt.use_gpu, is_train=True,
47 | pnet_rand=opt.from_scratch, pnet_tune=opt.train_trunk, gpu_ids=opt.gpu_ids)
48 |
49 | # load data from all training sets
50 | data_loader = dl.CreateDataLoader(opt.datasets,dataset_mode='2afc', batch_size=opt.batch_size, serial_batches=False, nThreads=opt.nThreads)
51 | dataset = data_loader.load_data()
52 | dataset_size = len(data_loader)
53 | D = len(dataset)
54 | print('Loading %i instances from'%dataset_size,opt.datasets)
55 | visualizer = Visualizer(opt)
56 |
57 | total_steps = 0
58 | fid = open(os.path.join(opt.checkpoints_dir,opt.name,'train_log.txt'),'w+')
59 | for epoch in range(1, opt.nepoch + opt.nepoch_decay + 1):
60 | epoch_start_time = time.time()
61 | for i, data in enumerate(dataset):
62 | iter_start_time = time.time()
63 | total_steps += opt.batch_size
64 | epoch_iter = total_steps - dataset_size * (epoch - 1)
65 |
66 | model.set_input(data)
67 | model.optimize_parameters()
68 |
69 | if total_steps % opt.display_freq == 0:
70 | visualizer.display_current_results(model.get_current_visuals(), epoch)
71 |
72 | if total_steps % opt.print_freq == 0:
73 | errors = model.get_current_errors()
74 | t = (time.time()-iter_start_time)/opt.batch_size
75 | t2o = (time.time()-epoch_start_time)/3600.
76 | t2 = t2o*D/(i+.0001)
77 | visualizer.print_current_errors(epoch, epoch_iter, errors, t, t2=t2, t2o=t2o, fid=fid)
78 |
79 | for key in errors.keys():
80 | visualizer.plot_current_errors_save(epoch, float(epoch_iter)/dataset_size, opt, errors, keys=[key,], name=key, to_plot=opt.train_plot)
81 |
82 | if opt.display_id > 0:
83 | visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)
84 |
85 | if total_steps % opt.save_latest_freq == 0:
86 | print('saving the latest model (epoch %d, total_steps %d)' %
87 | (epoch, total_steps))
88 | model.save(opt.save_dir, 'latest')
89 |
90 | if epoch % opt.save_epoch_freq == 0:
91 | print('saving the model at the end of epoch %d, iters %d' %
92 | (epoch, total_steps))
93 | model.save(opt.save_dir, 'latest')
94 | model.save(opt.save_dir, epoch)
95 |
96 | print('End of epoch %d / %d \t Time Taken: %d sec' %
97 | (epoch, opt.nepoch + opt.nepoch_decay, time.time() - epoch_start_time))
98 |
99 | if epoch > opt.nepoch:
100 | model.update_learning_rate(opt.nepoch_decay)
101 |
102 | # model.save_done(True)
103 | fid.close()
104 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import numpy as np
7 | #from skimage.measure import compare_ssim
8 | from skimage.metrics import structural_similarity as SSIM
9 | import torch
10 | from torch.autograd import Variable
11 |
12 | from PerceptualSimilarity.models import dist_model
13 |
14 | class PerceptualLoss(torch.nn.Module):
15 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0], version='0.1'): # VGG using our perceptually-learned weights (LPIPS metric)
16 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
17 | super(PerceptualLoss, self).__init__()
18 | # print('Setting up Perceptual loss...')
19 | self.use_gpu = use_gpu
20 | self.spatial = spatial
21 | self.gpu_ids = gpu_ids
22 | self.model = dist_model.DistModel()
23 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids, version=version)
24 | # print('...[%s] initialized'%self.model.name())
25 | # print('...Done')
26 |
27 | def forward(self, pred, target, normalize=False):
28 | """
29 | Pred and target are Variables.
30 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
31 | If normalize is False, assumes the images are already between [-1,+1]
32 |
33 | Inputs pred and target are Nx3xHxW
34 | Output pytorch Variable N long
35 | """
36 |
37 | if normalize:
38 | target = 2 * target - 1
39 | pred = 2 * pred - 1
40 |
41 | return self.model.forward(target, pred)
42 |
43 | def normalize_tensor(in_feat,eps=1e-10):
44 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
45 | return in_feat/(norm_factor+eps)
46 |
47 | def l2(p0, p1, range=255.):
48 | return .5*np.mean((p0 / range - p1 / range)**2)
49 |
50 | def psnr(p0, p1, peak=255.):
51 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))
52 |
53 | def dssim(p0, p1, range=255.):
54 | return (1 - SSIM(p0, p1, data_range=range, multichannel=True)) / 2.
55 |
56 | def rgb2lab(in_img,mean_cent=False):
57 | from skimage import color
58 | img_lab = color.rgb2lab(in_img)
59 | if(mean_cent):
60 | img_lab[:,:,0] = img_lab[:,:,0]-50
61 | return img_lab
62 |
63 | def tensor2np(tensor_obj):
64 | # change dimension of a tensor object into a numpy array
65 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
66 |
67 | def np2tensor(np_obj):
68 |     # change dimension of np array into tensor array
69 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
70 |
71 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
72 | # image tensor to lab tensor
73 | from skimage import color
74 |
75 | img = tensor2im(image_tensor)
76 | img_lab = color.rgb2lab(img)
77 | if(mc_only):
78 | img_lab[:,:,0] = img_lab[:,:,0]-50
79 | if(to_norm and not mc_only):
80 | img_lab[:,:,0] = img_lab[:,:,0]-50
81 | img_lab = img_lab/100.
82 |
83 | return np2tensor(img_lab)
84 |
85 | def tensorlab2tensor(lab_tensor,return_inbnd=False):
86 | from skimage import color
87 | import warnings
88 | warnings.filterwarnings("ignore")
89 |
90 | lab = tensor2np(lab_tensor)*100.
91 | lab[:,:,0] = lab[:,:,0]+50
92 |
93 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
94 | if(return_inbnd):
95 | # convert back to lab, see if we match
96 | lab_back = color.rgb2lab(rgb_back.astype('uint8'))
97 | mask = 1.*np.isclose(lab_back,lab,atol=2.)
98 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
99 | return (im2tensor(rgb_back),mask)
100 | else:
101 | return im2tensor(rgb_back)
102 |
103 | def rgb2lab(input):
104 | from skimage import color
105 | return color.rgb2lab(input / 255.)
106 |
107 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
108 | image_numpy = image_tensor[0].cpu().float().numpy()
109 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
110 | return image_numpy.astype(imtype)
111 |
112 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
113 | return torch.Tensor((image / factor - cent)
114 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
115 |
116 | def tensor2vec(vector_tensor):
117 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
118 |
119 | def voc_ap(rec, prec, use_07_metric=False):
120 | """ ap = voc_ap(rec, prec, [use_07_metric])
121 | Compute VOC AP given precision and recall.
122 | If use_07_metric is true, uses the
123 | VOC 07 11 point method (default:False).
124 | """
125 | if use_07_metric:
126 | # 11 point metric
127 | ap = 0.
128 | for t in np.arange(0., 1.1, 0.1):
129 | if np.sum(rec >= t) == 0:
130 | p = 0
131 | else:
132 | p = np.max(prec[rec >= t])
133 | ap = ap + p / 11.
134 | else:
135 | # correct AP calculation
136 | # first append sentinel values at the end
137 | mrec = np.concatenate(([0.], rec, [1.]))
138 | mpre = np.concatenate(([0.], prec, [0.]))
139 |
140 | # compute the precision envelope
141 | for i in range(mpre.size - 1, 0, -1):
142 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
143 |
144 | # to calculate area under PR curve, look for points
145 | # where X axis (recall) changes value
146 | i = np.where(mrec[1:] != mrec[:-1])[0]
147 |
148 | # and sum (\Delta recall) * prec
149 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
150 | return ap
151 |
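# NOTE (editor): the tensor2im and im2tensor definitions below repeat the identical helpers
# defined earlier in this file; being declared later, they are the ones in effect at import time.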
152 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
153 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
154 | image_numpy = image_tensor[0].cpu().float().numpy()
155 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
156 | return image_numpy.astype(imtype)
157 |
158 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
159 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
160 | return torch.Tensor((image / factor - cent)
161 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
162 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/models/pretrained_networks.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import torch
3 | from torchvision import models as tv
4 | from IPython import embed
5 |
6 | class squeezenet(torch.nn.Module):
7 | def __init__(self, requires_grad=False, pretrained=True):
8 | super(squeezenet, self).__init__()
9 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
10 | self.slice1 = torch.nn.Sequential()
11 | self.slice2 = torch.nn.Sequential()
12 | self.slice3 = torch.nn.Sequential()
13 | self.slice4 = torch.nn.Sequential()
14 | self.slice5 = torch.nn.Sequential()
15 | self.slice6 = torch.nn.Sequential()
16 | self.slice7 = torch.nn.Sequential()
17 | self.N_slices = 7
18 | for x in range(2):
19 | self.slice1.add_module(str(x), pretrained_features[x])
20 | for x in range(2,5):
21 | self.slice2.add_module(str(x), pretrained_features[x])
22 | for x in range(5, 8):
23 | self.slice3.add_module(str(x), pretrained_features[x])
24 | for x in range(8, 10):
25 | self.slice4.add_module(str(x), pretrained_features[x])
26 | for x in range(10, 11):
27 | self.slice5.add_module(str(x), pretrained_features[x])
28 | for x in range(11, 12):
29 | self.slice6.add_module(str(x), pretrained_features[x])
30 | for x in range(12, 13):
31 | self.slice7.add_module(str(x), pretrained_features[x])
32 | if not requires_grad:
33 | for param in self.parameters():
34 | param.requires_grad = False
35 |
36 | def forward(self, X):
37 | h = self.slice1(X)
38 | h_relu1 = h
39 | h = self.slice2(h)
40 | h_relu2 = h
41 | h = self.slice3(h)
42 | h_relu3 = h
43 | h = self.slice4(h)
44 | h_relu4 = h
45 | h = self.slice5(h)
46 | h_relu5 = h
47 | h = self.slice6(h)
48 | h_relu6 = h
49 | h = self.slice7(h)
50 | h_relu7 = h
51 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
52 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
53 |
54 | return out
55 |
56 |
57 | class alexnet(torch.nn.Module):
58 | def __init__(self, requires_grad=False, pretrained=True):
59 | super(alexnet, self).__init__()
60 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
61 | self.slice1 = torch.nn.Sequential()
62 | self.slice2 = torch.nn.Sequential()
63 | self.slice3 = torch.nn.Sequential()
64 | self.slice4 = torch.nn.Sequential()
65 | self.slice5 = torch.nn.Sequential()
66 | self.N_slices = 5
67 | for x in range(2):
68 | self.slice1.add_module(str(x), alexnet_pretrained_features[x])
69 | for x in range(2, 5):
70 | self.slice2.add_module(str(x), alexnet_pretrained_features[x])
71 | for x in range(5, 8):
72 | self.slice3.add_module(str(x), alexnet_pretrained_features[x])
73 | for x in range(8, 10):
74 | self.slice4.add_module(str(x), alexnet_pretrained_features[x])
75 | for x in range(10, 12):
76 | self.slice5.add_module(str(x), alexnet_pretrained_features[x])
77 | if not requires_grad:
78 | for param in self.parameters():
79 | param.requires_grad = False
80 |
81 | def forward(self, X):
82 | h = self.slice1(X)
83 | h_relu1 = h
84 | h = self.slice2(h)
85 | h_relu2 = h
86 | h = self.slice3(h)
87 | h_relu3 = h
88 | h = self.slice4(h)
89 | h_relu4 = h
90 | h = self.slice5(h)
91 | h_relu5 = h
92 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
93 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
94 |
95 | return out
96 |
97 | class vgg16(torch.nn.Module):
98 | def __init__(self, requires_grad=False, pretrained=True):
99 | super(vgg16, self).__init__()
100 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
101 | self.slice1 = torch.nn.Sequential()
102 | self.slice2 = torch.nn.Sequential()
103 | self.slice3 = torch.nn.Sequential()
104 | self.slice4 = torch.nn.Sequential()
105 | self.slice5 = torch.nn.Sequential()
106 | self.N_slices = 5
107 | for x in range(4):
108 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
109 | for x in range(4, 9):
110 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
111 | for x in range(9, 16):
112 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
113 | for x in range(16, 23):
114 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
115 | for x in range(23, 30):
116 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
117 | if not requires_grad:
118 | for param in self.parameters():
119 | param.requires_grad = False
120 |
121 | def forward(self, X):
122 | h = self.slice1(X)
123 | h_relu1_2 = h
124 | h = self.slice2(h)
125 | h_relu2_2 = h
126 | h = self.slice3(h)
127 | h_relu3_3 = h
128 | h = self.slice4(h)
129 | h_relu4_3 = h
130 | h = self.slice5(h)
131 | h_relu5_3 = h
132 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
133 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
134 |
135 | return out
136 |
137 |
138 |
139 | class resnet(torch.nn.Module):
140 | def __init__(self, requires_grad=False, pretrained=True, num=18):
141 | super(resnet, self).__init__()
142 | if(num==18):
143 | self.net = tv.resnet18(pretrained=pretrained)
144 | elif(num==34):
145 | self.net = tv.resnet34(pretrained=pretrained)
146 | elif(num==50):
147 | self.net = tv.resnet50(pretrained=pretrained)
148 | elif(num==101):
149 | self.net = tv.resnet101(pretrained=pretrained)
150 | elif(num==152):
151 | self.net = tv.resnet152(pretrained=pretrained)
152 | self.N_slices = 5
153 |
154 | self.conv1 = self.net.conv1
155 | self.bn1 = self.net.bn1
156 | self.relu = self.net.relu
157 | self.maxpool = self.net.maxpool
158 | self.layer1 = self.net.layer1
159 | self.layer2 = self.net.layer2
160 | self.layer3 = self.net.layer3
161 | self.layer4 = self.net.layer4
162 |
163 | def forward(self, X):
164 | h = self.conv1(X)
165 | h = self.bn1(h)
166 | h = self.relu(h)
167 | h_relu1 = h
168 | h = self.maxpool(h)
169 | h = self.layer1(h)
170 | h_conv2 = h
171 | h = self.layer2(h)
172 | h_conv3 = h
173 | h = self.layer3(h)
174 | h_conv4 = h
175 | h = self.layer4(h)
176 | h_conv5 = h
177 |
178 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
179 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
180 |
181 | return out
182 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/models/networks_basic.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | import sys
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.init as init
8 | from torch.autograd import Variable
9 | import numpy as np
10 | from pdb import set_trace as st
11 | from skimage import color
12 | from IPython import embed
13 | from . import pretrained_networks as pn
14 |
15 | from PerceptualSimilarity import models as util
16 |
17 | def spatial_average(in_tens, keepdim=True):
18 | # import pdb; pdb.set_trace()
19 | # return in_tens.mean([2,3],keepdim=keepdim)
20 | return in_tens.mean(2,keepdim=keepdim).mean(3,keepdim=keepdim)
21 |
22 |
23 | def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
24 | in_H = in_tens.shape[2]
25 | scale_factor = 1.*out_H/in_H
26 |
27 | return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
28 |
29 | # Learned perceptual metric
30 | class PNetLin(nn.Module):
31 | def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
32 | super(PNetLin, self).__init__()
33 |
34 | self.pnet_type = pnet_type
35 | self.pnet_tune = pnet_tune
36 | self.pnet_rand = pnet_rand
37 | self.spatial = spatial
38 | self.lpips = lpips
39 | self.version = version
40 | self.scaling_layer = ScalingLayer()
41 |
42 | if(self.pnet_type in ['vgg','vgg16']):
43 | net_type = pn.vgg16
44 | self.chns = [64,128,256,512,512]
45 | elif(self.pnet_type=='alex'):
46 | net_type = pn.alexnet
47 | self.chns = [64,192,384,256,256]
48 | elif(self.pnet_type=='squeeze'):
49 | net_type = pn.squeezenet
50 | self.chns = [64,128,256,384,384,512,512]
51 | self.L = len(self.chns)
52 |
53 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
54 |
55 | if(lpips):
56 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
57 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
58 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
59 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
60 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
61 | self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
62 | if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
63 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
64 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
65 | self.lins+=[self.lin5,self.lin6]
66 |
67 | def forward(self, in0, in1, retPerLayer=False):
68 | # v0.0 - original release had a bug, where input was not scaled
69 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
70 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
71 | feats0, feats1, diffs = {}, {}, {}
72 |
73 | for kk in range(self.L):
74 | # import pdb; pdb.set_trace()
75 | feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
76 | diffs[kk] = (feats0[kk]-feats1[kk])**2
77 |
78 | if(self.lpips):
79 | if(self.spatial):
80 | res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
81 | else:
82 | res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
83 | else:
84 | if(self.spatial):
85 | res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
86 | else:
87 | res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
88 |
89 | val = res[0]
90 | for l in range(1,self.L):
91 | val += res[l]
92 |
93 | if(retPerLayer):
94 | return (val, res)
95 | else:
96 | return val
97 |
98 | class ScalingLayer(nn.Module):
99 | def __init__(self):
100 | super(ScalingLayer, self).__init__()
101 | self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
102 | self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
103 |
104 | def forward(self, inp):
105 | return (inp - self.shift) / self.scale
106 |
107 |
108 | class NetLinLayer(nn.Module):
109 | ''' A single linear layer which does a 1x1 conv '''
110 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
111 | super(NetLinLayer, self).__init__()
112 |
113 | layers = [nn.Dropout(),] if(use_dropout) else []
114 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
115 | self.model = nn.Sequential(*layers)
116 |
117 |
118 | class Dist2LogitLayer(nn.Module):
119 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
120 | def __init__(self, chn_mid=32, use_sigmoid=True):
121 | super(Dist2LogitLayer, self).__init__()
122 |
123 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
124 | layers += [nn.LeakyReLU(0.2,True),]
125 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
126 | layers += [nn.LeakyReLU(0.2,True),]
127 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
128 | if(use_sigmoid):
129 | layers += [nn.Sigmoid(),]
130 | self.model = nn.Sequential(*layers)
131 |
132 | def forward(self,d0,d1,eps=0.1):
133 | return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
134 |
135 | class BCERankingLoss(nn.Module):
136 | def __init__(self, chn_mid=32):
137 | super(BCERankingLoss, self).__init__()
138 | self.net = Dist2LogitLayer(chn_mid=chn_mid)
139 | # self.parameters = list(self.net.parameters())
140 | self.loss = torch.nn.BCELoss()
141 |
142 | def forward(self, d0, d1, judge):
143 | per = (judge+1.)/2.
144 | self.logit = self.net.forward(d0,d1)
145 | return self.loss(self.logit, per)
146 |
147 | # L2, DSSIM metrics
148 | class FakeNet(nn.Module):
149 | def __init__(self, use_gpu=True, colorspace='Lab'):
150 | super(FakeNet, self).__init__()
151 | self.use_gpu = use_gpu
152 | self.colorspace=colorspace
153 |
154 | class L2(FakeNet):
155 |
156 | def forward(self, in0, in1, retPerLayer=None):
157 | assert(in0.size()[0]==1) # currently only supports batchSize 1
158 |
159 | if(self.colorspace=='RGB'):
160 | (N,C,X,Y) = in0.size()
161 | value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
162 | return value
163 | elif(self.colorspace=='Lab'):
164 | value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
165 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
166 | ret_var = Variable( torch.Tensor((value,) ) )
167 | if(self.use_gpu):
168 | ret_var = ret_var.cuda()
169 | return ret_var
170 |
171 | class DSSIM(FakeNet):
172 |
173 | def forward(self, in0, in1, retPerLayer=None):
174 | assert(in0.size()[0]==1) # currently only supports batchSize 1
175 |
176 | if(self.colorspace=='RGB'):
177 | value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
178 | elif(self.colorspace=='Lab'):
179 | value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
180 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
181 | ret_var = Variable( torch.Tensor((value,) ) )
182 | if(self.use_gpu):
183 | ret_var = ret_var.cuda()
184 | return ret_var
185 |
186 | def print_network(net):
187 | num_params = 0
188 | for param in net.parameters():
189 | num_params += param.numel()
190 | # print('Network',net)
191 | # print('Total number of parameters: %d' % num_params)
192 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import time
4 | from . import util
5 | from . import html
6 | # from pdb import set_trace as st
7 | import matplotlib.pyplot as plt
8 | import math
9 | # from IPython import embed
10 |
11 | def zoom_to_res(img,res=256,order=0,axis=0):
12 | # img 3xXxX
13 | from scipy.ndimage import zoom
14 | zoom_factor = res/img.shape[1]
15 | if(axis==0):
16 | return zoom(img,[1,zoom_factor,zoom_factor],order=order)
17 | elif(axis==2):
18 | return zoom(img,[zoom_factor,zoom_factor,1],order=order)
19 |
20 | class Visualizer():
21 | def __init__(self, opt):
22 | # self.opt = opt
23 | self.display_id = opt.display_id
24 | # self.use_html = opt.is_train and not opt.no_html
25 | self.win_size = opt.display_winsize
26 | self.name = opt.name
27 | self.display_cnt = 0 # display_current_results counter
28 | self.display_cnt_high = 0
29 | self.use_html = opt.use_html
30 |
31 | if self.display_id > 0:
32 | import visdom
33 | self.vis = visdom.Visdom(port = opt.display_port)
34 |
35 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
36 | util.mkdirs([self.web_dir,])
37 | if self.use_html:
38 | self.img_dir = os.path.join(self.web_dir, 'images')
39 | print('create web directory %s...' % self.web_dir)
40 | util.mkdirs([self.img_dir,])
41 |
42 | # |visuals|: dictionary of images to display or save
43 | def display_current_results(self, visuals, epoch, nrows=None, res=256):
44 | if self.display_id > 0: # show images in the browser
45 | title = self.name
46 | if(nrows is None):
47 | nrows = int(math.ceil(len(visuals.items()) / 2.0))
48 | images = []
49 | idx = 0
50 | for label, image_numpy in visuals.items():
51 | title += " | " if idx % nrows == 0 else ", "
52 | title += label
53 | img = image_numpy.transpose([2, 0, 1])
54 | img = zoom_to_res(img,res=res,order=0)
55 | images.append(img)
56 | idx += 1
57 | if len(visuals.items()) % 2 != 0:
58 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
59 | white_image = zoom_to_res(white_image,res=res,order=0)
60 | images.append(white_image)
61 | self.vis.images(images, nrow=nrows, win=self.display_id + 1,
62 | opts=dict(title=title))
63 |
64 | if self.use_html: # save images to a html file
65 | for label, image_numpy in visuals.items():
66 | img_path = os.path.join(self.img_dir, 'epoch%.3d_cnt%.6d_%s.png' % (epoch, self.display_cnt, label))
67 | util.save_image(zoom_to_res(image_numpy, res=res, axis=2), img_path)
68 |
69 | self.display_cnt += 1
70 | self.display_cnt_high = np.maximum(self.display_cnt_high, self.display_cnt)
71 |
72 | # update website
73 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
74 | for n in range(epoch, 0, -1):
75 | webpage.add_header('epoch [%d]' % n)
76 | if(n==epoch):
77 | high = self.display_cnt
78 | else:
79 | high = self.display_cnt_high
80 | for c in range(high-1,-1,-1):
81 | ims = []
82 | txts = []
83 | links = []
84 |
85 | for label, image_numpy in visuals.items():
86 | img_path = 'epoch%.3d_cnt%.6d_%s.png' % (n, c, label)
87 | ims.append(os.path.join('images',img_path))
88 | txts.append(label)
89 | links.append(os.path.join('images',img_path))
90 | webpage.add_images(ims, txts, links, width=self.win_size)
91 | webpage.save()
92 |
93 | # save errors into a directory
94 | def plot_current_errors_save(self, epoch, counter_ratio, opt, errors,keys='+ALL',name='loss', to_plot=False):
95 | if not hasattr(self, 'plot_data'):
96 | self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}
97 | self.plot_data['X'].append(epoch + counter_ratio)
98 | self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
99 |
100 | # embed()
101 | if(keys=='+ALL'):
102 | plot_keys = self.plot_data['legend']
103 | else:
104 | plot_keys = keys
105 |
106 | if(to_plot):
107 | (f,ax) = plt.subplots(1,1)
108 | for (k,kname) in enumerate(plot_keys):
109 | kk = np.where(np.array(self.plot_data['legend'])==kname)[0][0]
110 | x = self.plot_data['X']
111 | y = np.array(self.plot_data['Y'])[:,kk]
112 | if(to_plot):
113 | ax.plot(x, y, 'o-', label=kname)
114 | np.save(os.path.join(self.web_dir,'%s_x')%kname,x)
115 | np.save(os.path.join(self.web_dir,'%s_y')%kname,y)
116 |
117 | if(to_plot):
118 | plt.legend(loc=0,fontsize='small')
119 | plt.xlabel('epoch')
120 | plt.ylabel('Value')
121 | f.savefig(os.path.join(self.web_dir,'%s.png'%name))
122 | f.clf()
123 | plt.close()
124 |
125 | # errors: dictionary of error labels and values
126 | def plot_current_errors(self, epoch, counter_ratio, opt, errors):
127 | if not hasattr(self, 'plot_data'):
128 | self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}
129 | self.plot_data['X'].append(epoch + counter_ratio)
130 | self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
131 | self.vis.line(
132 | X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1),
133 | Y=np.array(self.plot_data['Y']),
134 | opts={
135 | 'title': self.name + ' loss over time',
136 | 'legend': self.plot_data['legend'],
137 | 'xlabel': 'epoch',
138 | 'ylabel': 'loss'},
139 | win=self.display_id)
140 |
141 | # errors: same format as |errors| of plotCurrentErrors
142 | def print_current_errors(self, epoch, i, errors, t, t2=-1, t2o=-1, fid=None):
143 | message = '(ep: %d, it: %d, t: %.3f[s], ept: %.2f/%.2f[h]) ' % (epoch, i, t, t2o, t2)
144 | message += (', ').join(['%s: %.3f' % (k, v) for k, v in errors.items()])
145 |
146 | print(message)
147 | if(fid is not None):
148 | fid.write('%s\n'%message)
149 |
150 |
151 | # save image to the disk
152 | def save_images_simple(self, webpage, images, names, in_txts, prefix='', res=256):
153 | image_dir = webpage.get_image_dir()
154 | ims = []
155 | txts = []
156 | links = []
157 |
158 | for name, image_numpy, txt in zip(names, images, in_txts):
159 | image_name = '%s_%s.png' % (prefix, name)
160 | save_path = os.path.join(image_dir, image_name)
161 | if(res is not None):
162 | util.save_image(zoom_to_res(image_numpy,res=res,axis=2), save_path)
163 | else:
164 | util.save_image(image_numpy, save_path)
165 |
166 | ims.append(os.path.join(webpage.img_subdir,image_name))
167 | # txts.append(name)
168 | txts.append(txt)
169 | links.append(os.path.join(webpage.img_subdir,image_name))
170 | # embed()
171 | webpage.add_images(ims, txts, links, width=self.win_size)
172 |
173 | # save image to the disk
174 | def save_images(self, webpage, images, names, image_path, title=''):
175 | image_dir = webpage.get_image_dir()
176 | # short_path = ntpath.basename(image_path)
177 | # name = os.path.splitext(short_path)[0]
178 | # name = short_path
179 | # webpage.add_header('%s, %s' % (name, title))
180 | ims = []
181 | txts = []
182 | links = []
183 |
184 | for label, image_numpy in zip(names, images):
185 | image_name = '%s.jpg' % (label,)
186 | save_path = os.path.join(image_dir, image_name)
187 | util.save_image(image_numpy, save_path)
188 |
189 | ims.append(image_name)
190 | txts.append(label)
191 | links.append(image_name)
192 | webpage.add_images(ims, txts, links, width=self.win_size)
193 |
194 | # save image to the disk
195 | # def save_images(self, webpage, visuals, image_path, short=False):
196 | # image_dir = webpage.get_image_dir()
197 | # if short:
198 | # short_path = ntpath.basename(image_path)
199 | # name = os.path.splitext(short_path)[0]
200 | # else:
201 | # name = image_path
202 |
203 | # webpage.add_header(name)
204 | # ims = []
205 | # txts = []
206 | # links = []
207 |
208 | # for label, image_numpy in visuals.items():
209 | # image_name = '%s_%s.png' % (name, label)
210 | # save_path = os.path.join(image_dir, image_name)
211 | # util.save_image(image_numpy, save_path)
212 |
213 | # ims.append(image_name)
214 | # txts.append(label)
215 | # links.append(image_name)
216 | # webpage.add_images(ims, txts, links, width=self.win_size)
217 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## Perceptual Similarity Metric and Dataset [[Project Page]](http://richzhang.github.io/PerceptualSimilarity/)
3 |
4 | **The Unreasonable Effectiveness of Deep Features as a Perceptual Metric**
5 | [Richard Zhang](https://richzhang.github.io/), [Phillip Isola](http://web.mit.edu/phillipi/), [Alexei A. Efros](http://www.eecs.berkeley.edu/~efros/), [Eli Shechtman](https://research.adobe.com/person/eli-shechtman/), [Oliver Wang](http://www.oliverwang.info/).
6 |
6 | In [CVPR](https://arxiv.org/abs/1801.03924), 2018.
7 |
8 |
9 |
10 | This repository contains our **perceptual metric (LPIPS)** and **dataset (BAPPS)**. It can also be used as a "perceptual loss". This uses PyTorch; a Tensorflow alternative is [here](https://github.com/alexlee-gk/lpips-tensorflow).
11 |
12 | **Table of Contents**
13 | 1. [Learned Perceptual Image Patch Similarity (LPIPS) metric](#1-learned-perceptual-image-patch-similarity-lpips-metric)
14 | a. [Basic Usage](#a-basic-usage) If you just want to run the metric through the command line, this is all you need.
15 | b. ["Perceptual Loss" usage](#b-backpropping-through-the-metric)
16 | c. [About the metric](#c-about-the-metric)
17 | 2. [Berkeley-Adobe Perceptual Patch Similarity (BAPPS) dataset](#2-berkeley-adobe-perceptual-patch-similarity-bapps-dataset)
18 | a. [Download](#a-downloading-the-dataset)
19 | b. [Evaluation](#b-evaluating-a-perceptual-similarity-metric-on-a-dataset)
20 | c. [About the dataset](#c-about-the-dataset)
21 | d. [Train the metric using the dataset](#d-using-the-dataset-to-train-the-metric)
22 |
23 | ## (0) Dependencies/Setup
24 |
25 | ### Installation
26 | - Install PyTorch 1.0+ and torchvision from http://pytorch.org
27 |
28 | ```bash
29 | pip install -r requirements.txt
30 | ```
31 | - Clone this repo:
32 | ```bash
33 | git clone https://github.com/richzhang/PerceptualSimilarity
34 | cd PerceptualSimilarity
35 | ```
36 |
37 | ## (1) Learned Perceptual Image Patch Similarity (LPIPS) metric
38 |
39 | Evaluate the distance between image patches. **Higher means further/more different. Lower means more similar.**
40 |
41 | ### (A) Basic Usage
42 |
43 | #### (A.I) Line commands
44 |
45 | Example scripts to take the distance between 2 specific images, all corresponding pairs of images in 2 directories, or all pairs of images within a directory:
46 |
47 | ```
48 | python compute_dists.py -p0 imgs/ex_ref.png -p1 imgs/ex_p0.png --use_gpu
49 | python compute_dists_dirs.py -d0 imgs/ex_dir0 -d1 imgs/ex_dir1 -o imgs/example_dists.txt --use_gpu
50 | python compute_dists_pair.py -d imgs/ex_dir_pair -o imgs/example_dists_pair.txt --use_gpu
51 | ```
52 |
53 | #### (A.II) Python code
54 |
55 | File [test_network.py](test_network.py) shows example usage. This snippet is all you really need.
56 |
57 | ```python
58 | import models
59 | model = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=use_gpu, gpu_ids=[0])
60 | d = model.forward(im0,im1)
61 | ```
62 |
63 | Variables ```im0, im1``` are PyTorch Tensors/Variables with shape ```Nx3xHxW``` (```N``` patches of size ```HxW```, RGB images scaled to `[-1,+1]`). The call returns `d`, a length-`N` Tensor/Variable of distances.
64 |
65 | Run `python test_network.py` to take the distance from the example reference image [`ex_ref.png`](imgs/ex_ref.png) to the distorted images [`ex_p0.png`](./imgs/ex_p0.png) and [`ex_p1.png`](imgs/ex_p1.png). Before running it, which do you think *should* be closer?
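
Below is a small editor-added sketch (not part of the original README) of preparing real images for `model.forward`. It reads two of the bundled example images with PIL (assumed installed) and uses the `im2tensor` helper defined in `models/__init__.py` to map uint8 RGB arrays into the expected `Nx3xHxW` tensors in `[-1,+1]`; it assumes the v0.1 `alex` weights in `models/weights` and the torchvision AlexNet weights are available.

```python
import numpy as np
from PIL import Image
import models  # PerceptualSimilarity/models

model = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=False)

# im2tensor maps an HxWx3 uint8 array in [0,255] to a 1x3xHxW tensor in [-1,+1]
im0 = models.im2tensor(np.array(Image.open('imgs/ex_ref.png').convert('RGB')))
im1 = models.im2tensor(np.array(Image.open('imgs/ex_p0.png').convert('RGB')))

d = model.forward(im0, im1)  # single-element tensor: LPIPS distance
print('distance: %.4f' % d.item())
```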
66 |
67 | **Some Options** By default in `model.initialize`:
68 | - `net='alex'`: Network `alex` is fastest, performs the best, and is the default. You can instead use `squeeze` or `vgg`.
69 | - `model='net-lin'`: This adds a linear calibration on top of intermediate features in the net. Set this to `model='net'` to weight all the features equally.
70 |
71 | ### (B) Backpropping through the metric
72 |
73 | File [`perceptual_loss.py`](perceptual_loss.py) shows how to iteratively optimize using the metric. Run `python perceptual_loss.py` for a demo. The code can also be used to implement vanilla VGG loss, without our learned weights.
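
As a rough orientation, here is a minimal editor-added sketch of the same idea (it is not the repository's `perceptual_loss.py`): the metric is used as a differentiable loss and a randomly initialized image is optimized toward a stand-in reference. Sizes, step count, and learning rate are placeholders.

```python
import torch
import models  # PerceptualSimilarity/models

use_gpu = torch.cuda.is_available()
device = 'cuda' if use_gpu else 'cpu'
loss_fn = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=use_gpu)

ref = torch.rand(1, 3, 64, 64, device=device) * 2 - 1                # stand-in reference in [-1,+1]
pred = torch.zeros(1, 3, 64, 64, device=device, requires_grad=True)  # image being optimized
opt = torch.optim.Adam([pred], lr=1e-2)

for step in range(200):
    opt.zero_grad()
    dist = loss_fn.forward(pred, ref)  # LPIPS distance, differentiable w.r.t. pred
    dist.sum().backward()
    opt.step()
    if step % 50 == 0:
        print('step %d, dist %.4f' % (step, dist.sum().item()))
```

Only `pred` is handed to the optimizer, so the metric's weights are not updated; gradients simply flow back into the input image, which is the same mechanism a "perceptual loss" relies on when training another generator network.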
74 |
75 | ### (C) About the metric
76 |
77 | **Higher means further/more different. Lower means more similar.**
78 |
79 | We found that deep network activations work surprisingly well as a perceptual similarity metric. This was true across network architectures (SqueezeNet [2.8 MB], AlexNet [9.1 MB], and VGG [58.9 MB] provided similar scores) and supervisory signals (unsupervised, self-supervised, and supervised all perform strongly). We slightly improved scores by linearly "calibrating" networks - adding a linear layer on top of off-the-shelf classification networks. We provide 3 variants, using linear layers on top of the SqueezeNet, AlexNet (default), and VGG networks.
80 |
81 | If you use LPIPS in your publication, please specify which version you are using. The current version is 0.1. You can set `version='0.0'` for the initial release.
82 |
83 | ## (2) Berkeley-Adobe Perceptual Patch Similarity (BAPPS) dataset
84 |
85 | ### (A) Downloading the dataset
86 |
87 | Run `bash ./scripts/download_dataset.sh` to download and unzip the dataset into directory `./dataset`. It takes [6.6 GB] total. Alternatively, run `bash ./scripts/download_dataset_valonly.sh` to download only the validation set [1.3 GB].
88 | - 2AFC train [5.3 GB]
89 | - 2AFC val [1.1 GB]
90 | - JND val [0.2 GB]
91 |
92 | ### (B) Evaluating a perceptual similarity metric on a dataset
93 |
94 | Script `test_dataset_model.py` evaluates a perceptual model on a subset of the dataset.
95 |
96 | **Dataset flags**
97 | - `--dataset_mode`: `2afc` or `jnd`, which type of perceptual judgment to evaluate
98 | - `--datasets`: list the datasets to evaluate
99 | - if `--dataset_mode 2afc`: choices are [`train/traditional`, `train/cnn`, `val/traditional`, `val/cnn`, `val/superres`, `val/deblur`, `val/color`, `val/frameinterp`]
100 | - if `--dataset_mode jnd`: choices are [`val/traditional`, `val/cnn`]
101 |
102 | **Perceptual similarity model flags**
103 | - `--model`: perceptual similarity model to use
104 | - `net-lin` for our LPIPS learned similarity model (linear network on top of internal activations of pretrained network)
105 | - `net` for a classification network (uncalibrated with all layers averaged)
106 | - `l2` for Euclidean distance
107 | - `ssim` for Structured Similarity Image Metric
108 | - `--net`: [`squeeze`,`alex`,`vgg`] for the `net-lin` and `net` models; ignored for `l2` and `ssim` models
109 | - `--colorspace`: choices are [`Lab`,`RGB`], used for the `l2` and `ssim` models; ignored for `net-lin` and `net` models
110 |
111 | **Misc flags**
112 | - `--batch_size`: evaluation batch size (will default to 1)
113 | - `--use_gpu`: turn on this flag for GPU usage
114 |
115 | An example usage is as follows: `python ./test_dataset_model.py --dataset_mode 2afc --datasets val/traditional val/cnn --model net-lin --net alex --use_gpu --batch_size 50`. This would evaluate our model on the "traditional" and "cnn" validation datasets.
116 |
117 | ### (C) About the dataset
118 |
119 | The dataset contains two types of perceptual judgements: **Two Alternative Forced Choice (2AFC)** and **Just Noticeable Differences (JND)**.
120 |
121 | **(1) 2AFC** Evaluators were given a patch triplet (1 reference + 2 distorted). They were asked to select which of the two distorted patches was "closer" to the reference.
122 |
123 | Training sets contain 2 judgments/triplet.
124 | - `train/traditional` [56.6k triplets]
125 | - `train/cnn` [38.1k triplets]
126 | - `train/mix` [56.6k triplets]
127 |
128 | Validation sets contain 5 judgments/triplet.
129 | - `val/traditional` [4.7k triplets]
130 | - `val/cnn` [4.7k triplets]
131 | - `val/superres` [10.9k triplets]
132 | - `val/deblur` [9.4k triplets]
133 | - `val/color` [4.7k triplets]
134 | - `val/frameinterp` [1.9k triplets]
135 |
136 | Each 2AFC subdirectory contains the following folders:
137 | - `ref`: original reference patches
138 | - `p0,p1`: two distorted patches
139 | - `judge`: human judgments - 0 if all humans preferred p0, 1 if all humans preferred p1
140 |
141 | **(2) JND** Evaluators were presented with two patches - a reference and a distorted - for a limited time. They were asked if the patches were the same (identical) or different.
142 |
143 | Each set contains 3 human evaluations/example.
144 | - `val/traditional` [4.8k pairs]
145 | - `val/cnn` [4.8k pairs]
146 |
147 | Each JND subdirectory contains the following folders:
148 | - `p0,p1`: two patches
149 | - `same`: human judgments: 0 if all humans thought the patches were different, 1 if all humans thought the patches were the same
150 |
151 | ### (D) Using the dataset to train the metric
152 |
153 | See script `train_test_metric.sh` for an example of training and testing the metric. The script will train a model on the full training set for 10 epochs, and then test the learned metric on all of the validation sets. The numbers should roughly match the **Alex - lin** row in Table 5 in the [paper](https://arxiv.org/abs/1801.03924). The code supports training a linear layer on top of an existing representation. Training will add a subdirectory in the `checkpoints` directory.
154 |
155 | You can also train "scratch" and "tune" versions by running `train_test_metric_scratch.sh` and `train_test_metric_tune.sh`, respectively.
156 |
157 | ### Docker Environment
158 |
159 | [Docker](https://hub.docker.com/r/shinyeyes/perceptualsimilarity/) set up by [SuperShinyEyes](https://github.com/SuperShinyEyes).
160 |
161 | ## Citation
162 |
163 | If you find this repository useful for your research, please use the following.
164 |
165 | ```
166 | @inproceedings{zhang2018perceptual,
167 | title={The Unreasonable Effectiveness of Deep Features as a Perceptual Metric},
168 | author={Zhang, Richard and Isola, Phillip and Efros, Alexei A and Shechtman, Eli and Wang, Oliver},
169 | booktitle={CVPR},
170 | year={2018}
171 | }
172 | ```
173 |
174 | ## Acknowledgements
175 |
176 | This repository borrows partially from the [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) repository. The average precision (AP) code is borrowed from the [py-faster-rcnn](https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/voc_eval.py) repository. Backpropping through the metric was implemented by [Angjoo Kanazawa](https://github.com/akanazawa).
177 |
--------------------------------------------------------------------------------
/models/HidingUNet_S.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 |
3 | import functools
4 |
5 | import torch
6 | import torch.nn as nn
7 | from models.module import Harm2d
8 | import torch.nn.functional as F
9 | import math
10 | import numpy as np
11 |
12 |
13 |
14 | # Defines the Unet generator.
15 | # |num_downs|: number of downsamplings in UNet. For example,
16 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1
17 | # at the bottleneck
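# (editor note) each downsampling halves the spatial resolution, so with num_downs == 7:
# 128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 2 -> 1 at the bottleneck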
18 |
19 | def harm3x3(in_planes, out_planes, stride=1, level=None):
20 | """3x3 harmonic convolution with padding"""
21 | return Harm2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1,
22 | bias=False, use_bn=False, level=level)
23 |
24 | def get_feature_cood(frequency_num=9):
25 | array=2*((np.arange(frequency_num)*1.0)/(frequency_num-1)) - 1
26 | #array = np.random.random(frequency_num)
27 | return torch.FloatTensor(np.float32(array))#.view(1, -1, 1, 1)
28 |
29 |
30 | class SELayer(nn.Module):
31 | def __init__(self, channel, reduction=32):
32 | super(SELayer, self).__init__()
33 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
34 | mip = max(8, channel // reduction)
35 | self.fc = nn.Sequential(
36 | nn.Linear(channel, mip, bias=False),
37 | nn.ReLU(inplace=True),
38 | nn.Linear(mip, channel, bias=False),
39 | nn.Sigmoid()
40 | )
41 |
42 | def forward(self, x):
43 | b, c, _, _ = x.size()
44 | y = self.avg_pool(x).view(b, c)
45 | y = self.fc(y).view(b, c, 1, 1)
46 | #return x * y.expand_as(x)
47 | return y
48 |
49 | class attention(nn.Module):
50 | def __init__(self, channel, reduction=32):
51 | super(attention, self).__init__()
52 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
53 | mip = max(8, channel // reduction)
54 |
55 | #nn.Linear(channel, mip, bias=False),
56 | #nn.ReLU(inplace=True),
57 | #nn.Linear(mip, channel, bias=False),
58 | #nn.Sigmoid()
59 | self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
60 | self.pool_w = nn.AdaptiveAvgPool2d((1, None))
61 | inp = channel
62 | oup = channel
63 |
64 | #mip = max(8, inp // reduction)
65 |
66 | self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
67 | self.bn1 = nn.BatchNorm2d(mip)
68 | self.act = nn.ReLU(inplace=True)
69 |
70 | self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
71 | self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
72 | self.conv_c = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
73 |
74 | def forward(self, x):
75 | '''b, c, _, _ = x.size()
76 | y = self.avg_pool(x).view(b, c)
77 | y = self.fc(y).view(b, c, 1, 1)
78 | '''
79 | n, c, h, w = x.size()
80 | x_h = self.pool_h(x)
81 | x_w = self.pool_w(x).permute(0, 1, 3, 2)
82 | x_c = self.avg_pool(x)
83 |
84 | y = torch.cat([x_h, x_w, x_c], dim=2)
85 | y = self.conv1(y)
86 | y = self.bn1(y)
87 | y = self.act(y)
88 |
89 | x_h, x_w, x_c = torch.split(y, [h, w, 1], dim=2)
90 | x_w = x_w.permute(0, 1, 3, 2)
91 |
92 | a_h = self.conv_h(x_h).sigmoid()
93 | a_w = self.conv_w(x_w).sigmoid()
94 | a_c = self.conv_c(x_c).sigmoid()
95 | #return x * y.expand_as(x)
96 | return a_h*a_w*a_c
97 |
98 | class ChannelShuffle(nn.Module):
99 |
100 | def __init__(self, groups):
101 | super().__init__()
102 | self.groups = groups
103 |
104 | def forward(self, x):
105 | batchsize, channels, height, width = x.data.size()
106 | channels_per_group = int(channels / self.groups)
107 |
108 | #"""suppose a convolutional layer with g groups whose output has
109 | #g x n channels; we first reshape the output channel dimension
110 | #into (g, n)"""
111 | x = x.view(batchsize, self.groups, channels_per_group, height, width)
112 |
113 | #"""transposing and then flattening it back as the input of next layer."""
114 | x = x.transpose(1, 2).contiguous()
115 | x = x.view(batchsize, -1, height, width)
116 |
117 | return x
118 | '''Hnet = UnetGenerator_C(input_nc=opt.channel_secret * opt.num_secret, output_nc=opt.channel_cover * opt.num_cover,
119 | num_downs=num_downs, norm_layer=norm_layer, output_function=nn.Tanh)'''
120 |
121 |
122 | class UnetGenerator_S(nn.Module):
123 | def __init__(self, input_nc, output_nc, num_downs, ngf=64,
124 | norm_layer=None, use_dropout=False, output_function=nn.Sigmoid):
125 | super(UnetGenerator_S, self).__init__()
126 | self.output_function = nn.Tanh
127 | '''self.tanh = output_function==nn.Tanh
128 | if self.tanh:
129 | self.factor = 10/255
130 | else:
131 | self.factor = 1.0'''
132 | nf = 9
133 | self.factor = 10 / 255
134 | self.tanh = nn.Tanh
135 | self.conv1 = nn.Conv2d(input_nc, 64, kernel_size=4, stride=2, padding=1, bias=False)
136 | self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False)
137 | self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False)
138 | self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False)
139 | self.conv5 = nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False)
140 | self.bn2 = norm_layer(128)
141 | self.bn3 = norm_layer(256)
142 | self.bn4 = norm_layer(512)
143 | self.leakyrelu = nn.LeakyReLU(0.2, True)
144 | self.relu = nn.ReLU()
145 | self.convtran5 = nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False)
146 | self.bnt5 = norm_layer(512)
147 | self.convtran4 = nn.ConvTranspose2d(1024, 256, kernel_size=4, stride=2, padding=1, bias=False)
148 | self.bnt4 = norm_layer(256)
149 | self.convtran3 = nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1, bias=False)
150 | self.bnt3 = norm_layer(128)
151 | self.convtran2 = nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1, bias=False)
152 | self.bnt2 = norm_layer(64)
153 | self.convtran1 = nn.ConvTranspose2d(128, output_nc, kernel_size=4, stride=2, padding=1, bias=False)
154 |
155 | self.dctconv1 = harm3x3(512, 512)
156 | self.dctconv2 = harm3x3(256, 256)
157 | self.dctconv3 = harm3x3(128, 128)
158 | self.dctconv4 = harm3x3(64, 64)
159 | self.dctconv5 = harm3x3(output_nc, output_nc)
160 |
161 | self.atten1 = attention(512)
162 | self.atten2 = attention(256)
163 | self.atten3 = attention(128)
164 | self.atten4 = attention(64)
165 | self.atten5 = attention(output_nc)
166 | self.channel_shuffle1 = ChannelShuffle(512)
167 | self.channel_shuffle2 = ChannelShuffle(256)
168 | self.channel_shuffle3 = ChannelShuffle(128)
169 | self.channel_shuffle4 = ChannelShuffle(64)
170 | self.channel_shuffle5 = ChannelShuffle(output_nc)
171 |
172 | self.f_atten = get_feature_cood(nf)
173 | self.f_atten_1 = nn.Parameter(
174 | self.f_atten.repeat(512).view(1, -1, 1, 1), requires_grad=True)
175 | self.f_atten_2 = nn.Parameter(
176 | self.f_atten.repeat(256).view(1, -1, 1, 1), requires_grad=True)
177 | self.f_atten_3 = nn.Parameter(
178 | self.f_atten.repeat(128).view(1, -1, 1, 1), requires_grad=True)
179 | self.f_atten_4 = nn.Parameter(
180 | self.f_atten.repeat(64).view(1, -1, 1, 1), requires_grad=True)
181 | self.f_atten_5 = nn.Parameter(
182 | self.f_atten.repeat(output_nc).view(1, -1, 1, 1), requires_grad=True)
183 | self.weight5 = nn.Parameter(
184 | nn.init.kaiming_normal_(torch.Tensor(512, 512 * nf, 1, 1), mode='fan_out',
185 | nonlinearity='relu'))
186 | self.weight4 = nn.Parameter(
187 | nn.init.kaiming_normal_(torch.Tensor(256, 256 * nf, 1, 1), mode='fan_out',
188 | nonlinearity='relu'))
189 | self.weight3 = nn.Parameter(
190 | nn.init.kaiming_normal_(torch.Tensor(128, 128 * nf, 1, 1), mode='fan_out',
191 | nonlinearity='relu'))
192 | self.weight2 = nn.Parameter(
193 | nn.init.kaiming_normal_(torch.Tensor(64, 64 * nf, 1, 1), mode='fan_out',
194 | nonlinearity='relu'))
195 | self.weight1 = nn.Parameter(
196 | nn.init.kaiming_normal_(torch.Tensor(output_nc, output_nc * nf, 1, 1), mode='fan_out',
197 | nonlinearity='relu'))
198 | self.groups = 1
199 |
200 | self.drop = nn.Dropout(0.5)
201 |
202 |
203 | def forward(self, input):
204 | out1 = self.conv1(input)
205 | out2 = self.bn2(self.conv2(self.leakyrelu(out1)))
206 | out3 = self.bn3(self.conv3(self.leakyrelu(out2)))
207 | out4 = self.bn4(self.conv4(self.leakyrelu(out3)))
208 | out5 = self.conv5(self.leakyrelu(out4))
209 | out_5 = self.bnt5(self.convtran5(self.relu(out5)))
210 | out_dct_1 = self.dctconv1(out_5)
211 | #out_dct_1_cs = self.atten1_CAM(out_dct_1)+self.atten1_PAM(out_dct_1)
212 | #out_dct_1_cs = self.atten1(out_dct_1)
213 | out_dct_1_cs = self.atten1(out_5)
214 | out_dct_1_cs = self.channel_shuffle1(out_dct_1_cs.repeat(1, 9, 1, 1).expand_as(out_dct_1))#*out_dct_1
215 | out_dct_1_f = self.f_atten_1.expand_as(out_dct_1).to(input.device) #* out_dct_1
216 | out_dct_1 = (out_dct_1_cs+out_dct_1_f) * out_dct_1
217 | #out_5 = F.conv2d(out_dct_1, self.weight5, padding=0, groups=self.groups)
218 | out_5 = torch.cat([out4, out_5], 1)
219 |
220 | out_4 = self.bnt4(self.convtran4(self.relu(out_5)))
221 | out_4 = self.drop(out_4)
222 | #out_dct_2 = self.atten2(self.dctconv2(out_4))
223 | out_dct_2 = self.dctconv2(out_4)
224 | #out_dct_2_cs = self.atten2_CAM(out_dct_2)+self.atten2_PAM(out_dct_2)
225 | out_dct_2_cs = self.atten2(out_4)
226 | out_dct_2_cs = self.channel_shuffle2(out_dct_2_cs.repeat(1, 9, 1, 1).expand_as(out_dct_2))#*out_dct_2
227 | out_dct_2_f = self.f_atten_2.expand_as(out_dct_2).to(input.device) #* out_dct_2
228 | out_dct_2 = (out_dct_2_cs + out_dct_2_f)* out_dct_2
229 | #out_4 = F.conv2d(out_dct_2, self.weight4, padding=0, groups=self.groups)
230 | out_4 = torch.cat([out3, out_4], 1)
231 |
232 | out_3 = self.bnt3(self.convtran3(self.relu(out_4)))
233 | #out_dct_3 = self.atten3(self.dctconv3(out_3))
234 | out_dct_3 = self.dctconv3(out_3)
235 | #out_dct_3_cs = self.atten3_CAM(out_dct_3)+self.atten3_PAM(out_dct_3)
236 | out_dct_3_cs = self.atten3(out_3)
237 | out_dct_3_cs = self.channel_shuffle3(out_dct_3_cs.repeat(1, 9, 1, 1).expand_as(out_dct_3))#*out_dct_3
238 | out_dct_3_f = self.f_atten_3.expand_as(out_dct_3).to(input.device) #* out_dct_3
239 | out_dct_3 = (out_dct_3_cs + out_dct_3_f)*out_dct_3
240 | #out_3 = F.conv2d(out_dct_3, self.weight3, padding=0, groups=self.groups)
241 | out_3 = torch.cat([out2, out_3], 1)
242 |
243 | out_2 = self.bnt2(self.convtran2(self.relu(out_3)))
244 | #out_dct_4 = self.atten4(self.dctconv4(out_2))
245 | out_dct_4 = self.dctconv4(out_2)
246 | #out_dct_4_cs = self.atten4_CAM(out_dct_4)+self.atten4_PAM(out_dct_4)
247 | out_dct_4_cs = self.atten4(out_2)
248 | out_dct_4_cs = self.channel_shuffle4(out_dct_4_cs.repeat(1, 9, 1, 1).expand_as(out_dct_4))#*out_dct_4
249 | out_dct_4_f = self.f_atten_4.expand_as(out_dct_4).to(input.device) #* out_dct_4
250 | out_dct_4 = (out_dct_4_cs + out_dct_4_f)* out_dct_4
251 | #out_2 = F.conv2d(out_dct_4, self.weight2, padding=0, groups=self.groups)
252 | out_2 = torch.cat([out1, out_2], 1)
253 | out_1 = self.relu(out_2)
254 | out_1 = self.convtran1(out_1)
255 | #out_dct_5 = self.atten5(self.dctconv5(out_1))
256 | out_dct_5 = self.dctconv5(out_1)
257 | #out_dct_5_cs = self.atten5_CAM(out_dct_5) + self.atten5_PAM(out_dct_5)
258 | out_dct_5_cs = self.atten5(out_1)
259 | out_dct_5_cs = self.channel_shuffle5(out_dct_5_cs.repeat(1, 9, 1, 1).expand_as(out_dct_5))#*out_dct_5
260 | out_dct_5_f = self.f_atten_5.expand_as(out_dct_5).to(input.device) #
261 | out_dct_5 = (out_dct_5_cs + out_dct_5_f)* out_dct_5
262 | #out = F.conv2d(out_dct_5, self.weight1, padding=0, groups=self.groups)
263 | #out = torch.tanh(out_dct_5)
264 | #out = torch.tanh(out_dct_5)
265 | #out_dct_5 = self.factor * out
266 | # out = torch.cat([input, out], 1)
267 |
268 | return out_dct_1, out_dct_2, out_dct_3, out_dct_4, out_dct_5
269 | #return out_dct_1, out_dct_5
270 |
271 |
272 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/models/dist_model.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | import sys
5 | import numpy as np
6 | import torch
7 | from torch import nn
8 | import os
9 | from collections import OrderedDict
10 | from torch.autograd import Variable
11 | import itertools
12 | from .base_model import BaseModel
13 | from scipy.ndimage import zoom
14 | import fractions
15 | import functools
16 | import skimage.transform
17 | from tqdm import tqdm
18 |
19 | from IPython import embed
20 |
21 | from . import networks_basic as networks
22 | import models as util
23 |
24 | class DistModel(BaseModel):
25 | def name(self):
26 | return self.model_name
27 |
28 | def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
29 | use_gpu=True, printNet=False, spatial=False,
30 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
31 | '''
32 | INPUTS
33 | model - ['net-lin'] for linearly calibrated network
34 | ['net'] for off-the-shelf network
35 | ['L2'] for L2 distance in Lab colorspace
36 | ['SSIM'] for ssim in RGB colorspace
37 | net - ['squeeze','alex','vgg']
38 | model_path - if None, will look in weights/[NET_NAME].pth
39 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
40 | use_gpu - bool - whether or not to use a GPU
41 | printNet - bool - whether or not to print network architecture out
42 | spatial - bool - whether to output an array containing varying distances across spatial dimensions
43 | spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
44 | spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
45 | spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
46 | is_train - bool - [True] for training mode
47 | lr - float - initial learning rate
48 | beta1 - float - initial momentum term for adam
49 | version - 0.1 for latest, 0.0 was original (with a bug)
50 | gpu_ids - int array - [0] by default, gpus to use
51 | '''
52 | BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)
53 |
54 | self.model = model
55 | self.net = net
56 | self.is_train = is_train
57 | self.spatial = spatial
58 | self.gpu_ids = gpu_ids
59 | self.model_name = '%s [%s]'%(model,net)
60 |
61 | if(self.model == 'net-lin'): # pretrained net + linear layer
62 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
63 | use_dropout=True, spatial=spatial, version=version, lpips=True)
64 | kw = {}
65 | if not use_gpu:
66 | kw['map_location'] = 'cpu'
67 | if(model_path is None):
68 | import inspect
69 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net)))
70 |
71 | if(not is_train):
72 | # print('Loading model from: %s'%model_path)
73 | self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
74 |
75 | elif(self.model=='net'): # pretrained network
76 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
77 | elif(self.model in ['L2','l2']):
78 | self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing
79 | self.model_name = 'L2'
80 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
81 | self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace)
82 | self.model_name = 'SSIM'
83 | else:
84 | raise ValueError("Model [%s] not recognized." % self.model)
85 |
86 | self.parameters = list(self.net.parameters())
87 |
88 | if self.is_train: # training mode
89 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
90 | self.rankLoss = networks.BCERankingLoss()
91 | self.parameters += list(self.rankLoss.net.parameters())
92 | self.lr = lr
93 | self.old_lr = lr
94 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
95 | else: # test mode
96 | self.net.eval()
97 |
98 | if(use_gpu):
99 | self.net.to(gpu_ids[0])
100 | self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
101 | if(self.is_train):
102 | self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
103 |
104 | if(printNet):
105 | # print('---------- Networks initialized -------------')
106 | networks.print_network(self.net)
107 | # print('-----------------------------------------------')
108 |
109 | def forward(self, in0, in1, retPerLayer=False):
110 | ''' Function computes the distance between image patches in0 and in1
111 | INPUTS
112 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
113 | OUTPUT
114 | computed distances between in0 and in1
115 | '''
116 |
117 | return self.net.forward(in0, in1, retPerLayer=retPerLayer)
118 |
119 | # ***** TRAINING FUNCTIONS *****
120 | def optimize_parameters(self):
121 | self.forward_train()
122 | self.optimizer_net.zero_grad()
123 | self.backward_train()
124 | self.optimizer_net.step()
125 | self.clamp_weights()
126 |
127 | def clamp_weights(self):
128 | for module in self.net.modules():
129 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)):
130 | module.weight.data = torch.clamp(module.weight.data,min=0)
131 |
132 | def set_input(self, data):
133 | self.input_ref = data['ref']
134 | self.input_p0 = data['p0']
135 | self.input_p1 = data['p1']
136 | self.input_judge = data['judge']
137 |
138 | if(self.use_gpu):
139 | self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
140 | self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
141 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
142 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
143 |
144 | self.var_ref = Variable(self.input_ref,requires_grad=True)
145 | self.var_p0 = Variable(self.input_p0,requires_grad=True)
146 | self.var_p1 = Variable(self.input_p1,requires_grad=True)
147 |
148 | def forward_train(self): # run forward pass
149 | # print(self.net.module.scaling_layer.shift)
150 | # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
151 |
152 | self.d0 = self.forward(self.var_ref, self.var_p0)
153 | self.d1 = self.forward(self.var_ref, self.var_p1)
154 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)
155 |
156 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())
157 |
158 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)
159 |
160 | return self.loss_total
161 |
162 | def backward_train(self):
163 | torch.mean(self.loss_total).backward()
164 |
165 | def compute_accuracy(self,d0,d1,judge):
166 | ''' d0, d1 are Variables, judge is a Tensor '''
167 | d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten()
168 | judge_per = judge.cpu().numpy().flatten()
169 | return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per)
209 | print('update lr [%s] decay: %f -> %f' % (type,self.old_lr, lr))
210 | self.old_lr = lr
211 |
212 | def score_2afc_dataset(data_loader, func, name=''):
213 | ''' Function computes Two Alternative Forced Choice (2AFC) score using
214 | distance function 'func' in dataset 'data_loader'
215 | INPUTS
216 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
217 | func - callable distance function - calling d=func(in0,in1) should take 2
218 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N
219 | OUTPUTS
220 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
221 | [1] - dictionary with following elements
222 | d0s,d1s - N arrays containing distances between reference patch to perturbed patches
223 | gts - N array in [0,1], preferred patch selected by human evaluators
224 | (closer to "0" for left patch p0, "1" for right patch p1,
225 | "0.6" means 60pct people preferred right patch, 40pct preferred left)
226 | scores - N array in [0,1], corresponding to what percentage function agreed with humans
227 | CONSTS
228 | N - number of test triplets in data_loader
229 | '''
230 |
231 | d0s = []
232 | d1s = []
233 | gts = []
234 |
235 | for data in tqdm(data_loader.load_data(), desc=name):
236 | d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist()
237 | d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist()
238 | gts+=data['judge'].cpu().numpy().flatten().tolist()
239 |
240 | d0s = np.array(d0s)
241 | d1s = np.array(d1s)
242 | gts = np.array(gts)
242 | scores = (d0s<d1s)*(1.-gts)+(d1s<d0s)*gts+(d1s==d0s)*.5
243 |
244 | return(np.mean(scores), dict(d0s=d0s,d1s=d1s,gts=gts,scores=scores))
--------------------------------------------------------------------------------
/main_DAH.py:
--------------------------------------------------------------------------------
562 | if opt.no_cover and (val_cover == 0):  # no_cover >>> not using cover in training
563 | cover_img.fill_(0.0)
564 | print('no_cover')
565 | if (opt.plain_cover or opt.noise_cover) and (val_cover == 0):
566 | cover_img.fill_(0.0)
567 | print('plain_cover')
568 | b, c, w, h = cover_img.size()
569 |
570 | if opt.plain_cover and (val_cover == 0):
571 | img_w1 = torch.cat((torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
572 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
573 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
574 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda()), dim=2)
575 | img_w2 = torch.cat((torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
576 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
577 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
578 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda()), dim=2)
579 | img_w3 = torch.cat((torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
580 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
581 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
582 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda()), dim=2)
583 | img_w4 = torch.cat((torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
584 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
585 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda(),
586 | torch.rand(b, c, 1, 1).repeat(1, 1, w // 4, h // 4).cuda()), dim=2)
587 | img_wh = torch.cat((img_w1, img_w2, img_w3, img_w4), dim=3)  # assemble a 4x4 grid of constant-color blocks to serve as a random "plain" cover
588 | cover_img = cover_img + img_wh
589 | print('if opt.plain_cover and (val_cover == 0):')
590 | if opt.noise_cover and (val_cover == 0):
591 | cover_img = cover_img + ((torch.rand(b, c, w, h) - 0.5) * 2 * 0 / 255).cuda()
592 | print('if opt.noise_cover and (val_cover == 0):')
593 | #+++++++++++++++++++++++++++
594 | cover_imgv = cover_img
595 |
596 | if opt.cover_dependent:
597 | H_input = torch.cat((cover_imgv, secret_imgv), dim=1)
598 | else:
599 | H_input = secret_imgv
600 |
601 | out_dct_1, out_dct_2, out_dct_3, out_dct_4, itm_secret_img = Hnet_S(H_input)
602 | #**************
603 | if i_c != None:
604 | print('if i_c != None')
605 | if type(i_c) == type(1.0):
606 | ####### To keep one channel #######
607 | itm_secret_img_clone = itm_secret_img.clone()
608 | itm_secret_img.fill_(0)
609 | itm_secret_img[:, int(i_c):int(i_c) + 1, :, :] = itm_secret_img_clone[:, int(i_c):int(i_c) + 1, :, :]
610 | if type(i_c) == type(1):
611 | print('aaaaa', i_c)
612 | ####### To set one channel to zero #######
613 | itm_secret_img[:, i_c:i_c + 1, :, :].fill_(0.0)
614 |
615 | if position != None:
616 | print('if position != None')
617 | itm_secret_img[:, :, position:position + 1, position:position + 1].fill_(0.0)
618 | if Se_two == 2:
619 | print('if Se_two == 2')
620 | itm_secret_img_half = itm_secret_img[0:batch_size_secret // 2, :, :, :]
621 | itm_secret_img = itm_secret_img + torch.cat((itm_secret_img_half.clone().fill_(0.0), itm_secret_img_half), 0)
622 | elif type(Se_two) == type(0.1):
623 | print('type(Se_two) == type(0.1)')
624 | itm_secret_img = itm_secret_img + Se_two * torch.rand(itm_secret_img.size()).cuda()
625 | if opt.cover_dependent:
626 | container_img = itm_secret_img
627 | else:
628 | itm_secret_img = itm_secret_img.repeat(opt.num_training, 1, 1, 1)
629 | container_img = Hnet_C(cover_img, out_dct_1, out_dct_2, out_dct_3, out_dct_4, itm_secret_img)
630 | #**************
631 | errH = criterion(container_img, cover_imgv) # Hiding net
632 |
633 | rev_secret_img = Rnet(container_img)
634 | errR = criterion(rev_secret_img, secret_imgv_nh) # Reveal net
635 |
636 | # L1 metric
637 | diffH = (container_img - cover_imgv).abs().mean() * 255
638 | diffR = (rev_secret_img - secret_imgv_nh).abs().mean() * 255
639 | return cover_imgv, container_img, secret_imgv_nh, rev_secret_img, errH, errR, diffH, diffR
640 |
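# diffH/diffR above are average pixel differences (APD): assuming image tensors in
# [0, 1], multiplying the mean absolute error by 255 expresses it on the 8-bit scale.
# Minimal sketch (illustrative, not from the file):
#
#   apd_cover = (container_img - cover_imgv).abs().mean() * 255   # e.g. MAE 0.01 -> APD 2.55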
641 |
642 | def train(train_loader, epoch, Hnet_C, Hnet_S, Rnet, criterion):
643 | batch_time = AverageMeter()
644 | data_time = AverageMeter()
645 | Hlosses = AverageMeter()
646 | Rlosses = AverageMeter()
647 | SumLosses = AverageMeter()
648 | Hdiff = AverageMeter()
649 | Rdiff = AverageMeter()
650 |
651 | # Switch to train mode
652 | Hnet_C.train()
653 | Hnet_S.train()
654 | Rnet.train()
655 |
656 | start_time = time.time()
657 |
658 | for i, ((secret_img, secret_target), (cover_img, cover_target)) in enumerate(train_loader, 0):
659 |
660 | data_time.update(time.time() - start_time)
661 |
662 | cover_imgv, container_img, secret_imgv_nh, rev_secret_img, errH, errR, diffH, diffR \
663 | = forward_pass(secret_img, secret_target, cover_img, cover_target, Hnet_C, Hnet_S, Rnet, criterion)
664 |
665 | Hlosses.update(errH.item(), opt.bs_secret * opt.num_cover * opt.num_training) # H loss
666 | Rlosses.update(errR.item(), opt.bs_secret * opt.num_secret * opt.num_training) # R loss
667 | Hdiff.update(diffH.item(), opt.bs_secret * opt.num_cover * opt.num_training)
668 | Rdiff.update(diffR.item(), opt.bs_secret * opt.num_secret * opt.num_training)
669 | '''Hlosses.update(errH.data[0], opt.bs_secret * opt.num_cover * opt.num_training) # H loss
670 | Rlosses.update(errR.data[0], opt.bs_secret * opt.num_secret * opt.num_training) # R loss
671 | Hdiff.update(diffH.data[0], opt.bs_secret * opt.num_cover * opt.num_training)
672 | Rdiff.update(diffR.data[0], opt.bs_secret * opt.num_secret * opt.num_training)'''
673 |
674 | # Loss, backprop, and optimization step
675 | betaerrR_secret = opt.beta * errR
676 | err_sum = errH + betaerrR_secret
677 | optimizer.zero_grad()
678 | err_sum.backward()
679 | optimizer.step()
680 |
681 | # Time spent on one batch
682 | batch_time.update(time.time() - start_time)
683 | start_time = time.time()
684 |
685 | log = '[%d/%d][%d/%d]\tLoss_H: %.6f Loss_R: %.6f L1_H: %.4f L1_R: %.4f \tdatatime: %.4f \tbatchtime: %.4f' % (
686 | epoch, opt.epochs, i, opt.iters_per_epoch,
687 | Hlosses.val, Rlosses.val, Hdiff.val, Rdiff.val, data_time.val, batch_time.val)
688 |
689 | if i % opt.logFrequency == 0:
690 | print(log)
691 |
692 | if epoch == opt.epochs and i % opt.resultPicFrequency == 0:
693 | save_result_pic(opt.bs_secret * opt.num_training, cover_imgv, container_img.data, secret_imgv_nh,
694 | rev_secret_img.data, epoch, i, opt.trainpics)
695 |
696 | if i == opt.iters_per_epoch - 1:
697 | break
698 |
699 | # To save the last batch only
700 | save_result_pic(opt.bs_secret * opt.num_training, cover_imgv, container_img.data, secret_imgv_nh,
701 | rev_secret_img.data, epoch, i, opt.trainpics)
702 |
703 | epoch_log = "Training[%d] Hloss=%.6f\tRloss=%.6f\tHdiff=%.4f\tRdiff=%.4f\tlr= %.6f\t Epoch time= %.4f" % (
704 | epoch, Hlosses.avg, Rlosses.avg, Hdiff.avg, Rdiff.avg, optimizer.param_groups[0]['lr'], batch_time.sum)
705 | print_log(epoch_log, logPath)
706 |
707 | if not opt.debug:
708 | writer.add_scalar("lr/lr", optimizer.param_groups[0]['lr'], epoch)
709 | writer.add_scalar("lr/beta", opt.beta, epoch)
710 | writer.add_scalar('train/H_loss', Hlosses.avg, epoch)
711 | writer.add_scalar('train/R_loss', Rlosses.avg, epoch)
712 | writer.add_scalar('train/sum_loss', Hlosses.avg + opt.beta * Rlosses.avg, epoch)  # SumLosses is never updated above, so log the beta-weighted sum directly
713 | writer.add_scalar('train/H_diff', Hdiff.avg, epoch)
714 | writer.add_scalar('train/R_diff', Rdiff.avg, epoch)
715 |
716 |
717 | def validation(val_loader, epoch, Hnet_C, Hnet_S,Rnet, criterion):
718 | print(
719 | "#################################################### validation begin ########################################################")
720 | start_time = time.time()
721 | Hnet_C.eval()
722 | Hnet_S.eval()
723 | Rnet.eval()
724 | batch_time = AverageMeter()
725 | Hlosses = AverageMeter()
726 | Rlosses = AverageMeter()
727 | SumLosses = AverageMeter()
728 | Hdiff = AverageMeter()
729 | Rdiff = AverageMeter()
730 |
731 | for i, ((secret_img, secret_target), (cover_img, cover_target)) in enumerate(val_loader, 0):
732 |
733 | cover_imgv, container_img, secret_imgv_nh, rev_secret_img, errH, errR, diffH, diffR \
734 | = forward_pass(secret_img, secret_target, cover_img, cover_target, Hnet_C, Hnet_S, Rnet, criterion, val_cover=1)
735 |
736 | Hlosses.update(errH.item(), opt.bs_secret * opt.num_cover * opt.num_training) # H loss
737 | Rlosses.update(errR.item(), opt.bs_secret * opt.num_secret * opt.num_training) # R loss
738 | Hdiff.update(diffH.item(), opt.bs_secret * opt.num_cover * opt.num_training)
739 | Rdiff.update(diffR.item(), opt.bs_secret * opt.num_secret * opt.num_training)
740 | '''Hlosses.update(errH.data[0], opt.bs_secret * opt.num_cover * opt.num_training) # H loss
741 | Rlosses.update(errR.data[0], opt.bs_secret * opt.num_secret * opt.num_training) # R loss
742 | Hdiff.update(diffH.data[0], opt.bs_secret * opt.num_cover * opt.num_training)
743 | Rdiff.update(diffR.data[0], opt.bs_secret * opt.num_secret * opt.num_training)'''
744 |
745 | if i == 0:
746 | save_result_pic(opt.bs_secret * opt.num_training, cover_imgv, container_img.data, secret_imgv_nh,
747 | rev_secret_img.data, epoch, i, opt.validationpics)
748 | if epoch == opt.epochs and i % opt.resultPicFrequency == 0:
749 | save_result_pic(opt.bs_secret * opt.num_training, cover_imgv, container_img.data, secret_imgv_nh,
750 | rev_secret_img.data, epoch, i, opt.trainpics)
751 | if opt.num_secret >= 6:
752 | i_total = 80
753 | else:
754 | i_total = 200
755 | if i == i_total - 1:
756 | break
757 |
758 | batch_time.update(time.time() - start_time)
759 | start_time = time.time()
760 |
761 | val_log = "validation[%d] val_Hloss = %.6f\t val_Rloss = %.6f\t val_Hdiff = %.6f\t val_Rdiff=%.2f\t batch time=%.2f" % (
762 | epoch, Hlosses.val, Rlosses.val, Hdiff.val, Rdiff.val, batch_time.val)
763 | if i % opt.logFrequency == 0:
764 | print(val_log)
765 |
766 | val_log = "validation[%d] val_Hloss = %.6f\t val_Rloss = %.6f\t val_Hdiff = %.4f\t val_Rdiff=%.4f\t validation time=%.2f" % (
767 | epoch, Hlosses.avg, Rlosses.avg, Hdiff.avg, Rdiff.avg, batch_time.sum)
768 | print_log(val_log, logPath)
769 |
770 | if not opt.debug:
771 | writer.add_scalar('validation/H_loss_avg', Hlosses.avg, epoch)
772 | writer.add_scalar('validation/R_loss_avg', Rlosses.avg, epoch)
773 | writer.add_scalar('validation/H_diff_avg', Hdiff.avg, epoch)
774 | writer.add_scalar('validation/R_diff_avg', Rdiff.avg, epoch)
775 |
776 | print(
777 | "#################################################### validation end ########################################################")
778 | return Hlosses.avg, Rlosses.avg, Hdiff.avg, Rdiff.avg
779 |
780 |
781 | #def analysis(val_loader, epoch, Hnet, Rnet, HnetD, RnetD, criterion):
782 | def analysis(val_loader, epoch, Hnet_C, Hnet_S, Rnet, criterion):
783 | print(
784 | "#################################################### analysis begin ########################################################")
785 |
786 | Hnet_C.eval()
787 | Hnet_S.eval()
788 | Rnet.eval()
789 | Hdiff = AverageMeter()
790 | Rdiff = AverageMeter()
791 | psnr_C = AverageMeter()
792 | psnr_S = AverageMeter()
793 | ssim_C = AverageMeter()
794 | ssim_S = AverageMeter()
795 | lpips_C = AverageMeter()
796 | lpips_S = AverageMeter()
797 |
798 | #HnetD.eval()
799 | #RnetD.eval()
800 | import warnings
801 | warnings.filterwarnings("ignore")
802 |
803 | for ii, ((secret_img, secret_target), (cover_img, cover_target)) in enumerate(val_loader, 0):
804 |
805 | ####################################### Cover Agnostic #######################################
806 | cover_imgv, container_img, secret_imgv_nh, rev_secret_img, errH, errR, diffH, diffR \
807 | = forward_pass(secret_img, secret_target, cover_img, cover_target, Hnet_C, Hnet_S, Rnet, criterion, val_cover=1)
808 | secret_encoded = container_img - cover_imgv
809 |
810 | '''save_result_pic_analysis(opt.bs_secret * opt.num_training, cover_imgv.clone(), container_img.clone(),
811 | secret_imgv_nh.clone(), rev_secret_img.clone(), epoch, i, opt.validationpics)'''
812 |
813 | N, _, _, _ = rev_secret_img.shape
814 |
815 | cover_img_numpy = cover_imgv.clone().cpu().detach().numpy()
816 | container_img_numpy = container_img.clone().cpu().detach().numpy()
817 |
818 | cover_img_numpy = cover_img_numpy.transpose(0, 2, 3, 1)
819 | container_img_numpy = container_img_numpy.transpose(0, 2, 3, 1)
820 |
821 | rev_secret_numpy = rev_secret_img.cpu().detach().numpy()
822 | secret_img_numpy = secret_imgv_nh.cpu().detach().numpy()
823 |
824 | rev_secret_numpy = rev_secret_numpy.transpose(0, 2, 3, 1)
825 | secret_img_numpy = secret_img_numpy.transpose(0, 2, 3, 1)
826 |
827 | # PSNR
828 | print("Cover Agnostic")
829 |
830 | print("Secret APD C:", diffH.item())
831 |
832 | psnr_c = np.zeros((N, 3))
833 | for i in range(N):
834 | psnr_c[i, 0] = PSNR(cover_img_numpy[i, :, :, 0], container_img_numpy[i, :, :, 0])
835 | psnr_c[i, 1] = PSNR(cover_img_numpy[i, :, :, 1], container_img_numpy[i, :, :, 1])
836 | psnr_c[i, 2] = PSNR(cover_img_numpy[i, :, :, 2], container_img_numpy[i, :, :, 2])
837 | print("Avg. PSNR C:", psnr_c.mean().item())
838 |
839 | # SSIM
840 | ssim_c = np.zeros(N)
841 | for i in range(N):
842 | ssim_c[i] = SSIM(cover_img_numpy[i], container_img_numpy[i], multichannel=True)
843 | print("Avg. SSIM C:", ssim_c.mean().item())
844 |
845 | # LPIPS
846 | import PerceptualSimilarity.models
847 | model = PerceptualSimilarity.models.PerceptualLoss(model='net-lin', net='alex', use_gpu=True, gpu_ids=[0])
848 | lpips_c = model.forward(cover_imgv, container_img)
849 | print("Avg. LPIPS C:", lpips_c.mean().item())
850 |
851 | print("Secret APD S:", diffR.item())
852 |
853 | psnr_s = np.zeros(N)
854 | for i in range(N):
855 | psnr_s[i] = PSNR(secret_img_numpy[i], rev_secret_numpy[i])
856 | print("Avg. PSNR S:", psnr_s.mean().item())
857 |
858 | # SSIM
859 | ssim_s = np.zeros(N)
860 | for i in range(N):
861 | ssim_s[i] = SSIM(secret_img_numpy[i], rev_secret_numpy[i], multichannel=True)
862 | print("Avg. SSIM S:", ssim_s.mean().item())
863 |
864 | # LPIPS
865 | import PerceptualSimilarity.models
866 | model = PerceptualSimilarity.models.PerceptualLoss(model='net-lin', net='alex', use_gpu=True, gpu_ids=[0])
867 | secret_imgv_nh_1 = secret_imgv_nh.view(-1, 3, 128, 128)
868 | rev_secret_img_1 = rev_secret_img.view(-1, 3, 128, 128)
869 | lpips_s = model.forward(secret_imgv_nh_1, rev_secret_img_1)
870 | print("Avg. LPIPS S:", lpips_s.mean().item())
871 |
872 | #print("*******DONE!**********")
873 |
874 | #break
875 | lpips_S.update(lpips_s.mean().item(), opt.bs_secret * opt.num_cover * opt.num_training)  # secret-branch LPIPS
876 | psnr_S.update(psnr_s.mean().item(), opt.bs_secret * opt.num_secret * opt.num_training)  # secret-branch PSNR
877 | ssim_S.update(ssim_s.mean().item(), opt.bs_secret * opt.num_cover * opt.num_training)
878 | Rdiff.update(diffR.item(), opt.bs_secret * opt.num_secret * opt.num_training)
879 | lpips_C.update(lpips_c.mean().item(), opt.bs_secret * opt.num_cover * opt.num_training)  # cover-branch LPIPS
880 | psnr_C.update(psnr_c.mean().item(), opt.bs_secret * opt.num_secret * opt.num_training)  # cover-branch PSNR
881 | ssim_C.update(ssim_c.mean().item(), opt.bs_secret * opt.num_cover * opt.num_training)
882 | Hdiff.update(diffH.item(), opt.bs_secret * opt.num_secret * opt.num_training)
883 | if opt.num_secret >= 6:
884 | i_total = 80
885 | else:
886 | i_total = 200
887 | if ii == i_total - 1:
888 | break
889 | print('Hdiff.avg, Rdiff.avg', Hdiff.avg, Rdiff.avg)
890 | print('Hdiff.avg', Hdiff.avg, 'psnr_c.avg', psnr_C.avg, 'ssim_c.avg', ssim_C.avg, 'lpips_c.avg', lpips_C.avg)
891 | print('Rdiff.avg', Rdiff.avg, 'psnr_s.avg', psnr_S.avg, 'ssim_s.avg', ssim_S.avg, 'lpips_s.avg', lpips_S.avg)
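# Minimal sketch of per-channel PSNR as used above (illustrative; assumes the standard
# definition with data_range 1.0 for tensors in [0, 1] -- the imported PSNR helper may
# choose its data range differently):
#
#   import numpy as np
#   def psnr(x, y, data_range=1.0):
#       mse = np.mean((x - y) ** 2)
#       return 10.0 * np.log10((data_range ** 2) / mse)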
892 |
893 |
894 |
895 | def print_log(log_info, log_path, console=True):
896 | # print the info into the console
897 | if console:
898 | print(log_info)
899 | # debug mode don't write the log into files
900 | if not opt.debug:
901 | # write the log into log file
902 | if not os.path.exists(log_path):
903 | fp = open(log_path, "w")
904 | fp.writelines(log_info + "\n")
905 | else:
906 | with open(log_path, 'a+') as f:
907 | f.writelines(log_info + '\n')
908 |
909 |
910 | def adjust_learning_rate(optimizer, epoch):
911 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
912 | lr = opt.lr * (0.1 ** (epoch // 30))
913 | for param_group in optimizer.param_groups:
914 | param_group['lr'] = lr
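# Worked example of the decay above (assuming opt.lr = 0.001):
#   epochs  0-29 -> lr = 0.001
#   epochs 30-59 -> lr = 0.0001
#   epochs 60-89 -> lr = 0.00001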
915 |
916 |
917 | # save result pic and the coverImg filePath and the secretImg filePath
918 | def save_result_pic_analysis(bs_secret_times_num_training, cover, container, secret, rev_secret, epoch, i,
919 | save_path=None, postname=''):
920 | path = './qualitative_results/'
921 | if not os.path.exists(path):
922 | os.makedirs(path)
923 | resultImgName = path + 'universal_qualitative_results.png'
924 |
925 | cover = cover[:4]
926 | container = container[:4]
927 | secret = secret[:4]
928 | rev_secret = rev_secret[:4]
929 |
930 | cover_gap = container - cover
931 | secret_gap = rev_secret - secret
932 | cover_gap = (cover_gap * 10 + 0.5).clamp_(0.0, 1.0)
933 | secret_gap = (secret_gap * 10 + 0.5).clamp_(0.0, 1.0)
934 |
935 | for i_cover in range(4):
936 | cover_i = cover[:, i_cover * opt.channel_cover:(i_cover + 1) * opt.channel_cover, :, :]
937 | container_i = container[:, i_cover * opt.channel_cover:(i_cover + 1) * opt.channel_cover, :, :]
938 | cover_gap_i = cover_gap[:, i_cover * opt.channel_cover:(i_cover + 1) * opt.channel_cover, :, :]
939 |
940 | if i_cover == 0:
941 | showCover = torch.cat((cover_i, container_i, cover_gap_i), 0)
942 | else:
943 | showCover = torch.cat((showCover, cover_i, container_i, cover_gap_i), 0)
944 |
945 | for i_secret in range(4):
946 | secret_i = secret[:, i_secret * opt.channel_secret:(i_secret + 1) * opt.channel_secret, :, :]
947 | rev_secret_i = rev_secret[:, i_secret * opt.channel_secret:(i_secret + 1) * opt.channel_secret, :, :]
948 | secret_gap_i = secret_gap[:, i_secret * opt.channel_secret:(i_secret + 1) * opt.channel_secret, :, :]
949 |
950 | if i_secret == 0:
951 | showSecret = torch.cat((secret_i, rev_secret_i, secret_gap_i), 0)
952 | else:
953 | showSecret = torch.cat((showSecret, secret_i, rev_secret_i, secret_gap_i), 0)
954 |
955 | showAll = torch.cat((showCover, showSecret), 0)
956 | showAll = showAll.reshape(6, 4, 3, 128, 128)
957 | showAll = showAll.permute(1, 0, 2, 3, 4)
958 | showAll = showAll.reshape(4 * 6, 3, 128, 128)
959 | vutils.save_image(showAll, resultImgName, nrow=6, padding=1, normalize=False)
960 |
961 |
962 | # save result pic and the coverImg filePath and the secretImg filePath
963 | def save_result_pic(bs_secret_times_num_training, cover, container, secret, rev_secret, epoch, i, save_path=None,
964 | postname=''):
965 | # if not opt.debug:
966 | # cover=container: bs*nt/nc; secret=rev_secret: bs*nt/3*nh
967 | if opt.debug:
968 | save_path = './debug/debug_images'
969 | resultImgName = '%s/ResultPics_epoch%03d_batch%04d%s.png' % (save_path, epoch, i, postname)
970 |
971 | cover_gap = container - cover
972 | secret_gap = rev_secret - secret
973 | cover_gap = (cover_gap * 10 + 0.5).clamp_(0.0, 1.0)
974 | secret_gap = (secret_gap * 10 + 0.5).clamp_(0.0, 1.0)
975 | # print(cover_gap.abs().sum(dim=-1).sum(dim=-1).sum(dim=-1), secret_gap.abs().sum(dim=-1).sum(dim=-1).sum(dim=-1))
976 |
977 | # showCover = torch.cat((cover, container, cover_gap),0)
978 |
979 | for i_cover in range(opt.num_cover):
980 | cover_i = cover[:, i_cover * opt.channel_cover:(i_cover + 1) * opt.channel_cover, :, :]
981 | container_i = container[:, i_cover * opt.channel_cover:(i_cover + 1) * opt.channel_cover, :, :]
982 | cover_gap_i = cover_gap[:, i_cover * opt.channel_cover:(i_cover + 1) * opt.channel_cover, :, :]
983 |
984 | if i_cover == 0:
985 | showCover = torch.cat((cover_i, container_i, cover_gap_i), 0)
986 | else:
987 | showCover = torch.cat((showCover, cover_i, container_i, cover_gap_i), 0)
988 |
989 | for i_secret in range(opt.num_secret):
990 | secret_i = secret[:, i_secret * opt.channel_secret:(i_secret + 1) * opt.channel_secret, :, :]
991 | rev_secret_i = rev_secret[:, i_secret * opt.channel_secret:(i_secret + 1) * opt.channel_secret, :, :]
992 | secret_gap_i = secret_gap[:, i_secret * opt.channel_secret:(i_secret + 1) * opt.channel_secret, :, :]
993 |
994 | if i_secret == 0:
995 | showSecret = torch.cat((secret_i, rev_secret_i, secret_gap_i), 0)
996 | else:
997 | showSecret = torch.cat((showSecret, secret_i, rev_secret_i, secret_gap_i), 0)
998 |
999 | if opt.channel_secret == opt.channel_cover:
1000 | showAll = torch.cat((showCover, showSecret), 0)
1001 | vutils.save_image(showAll, resultImgName, nrow=bs_secret_times_num_training, padding=1, normalize=True)
1002 | else:
1003 | ContainerImgName = '%s/ContainerPics_epoch%03d_batch%04d.png' % (save_path, epoch, i)
1004 | SecretImgName = '%s/SecretPics_epoch%03d_batch%04d.png' % (save_path, epoch, i)
1005 | vutils.save_image(showCover, ContainerImgName, nrow=bs_secret_times_num_training, padding=1, normalize=True)
1006 | vutils.save_image(showSecret, SecretImgName, nrow=bs_secret_times_num_training, padding=1, normalize=True)
1007 |
1008 |
1009 |
1010 |
1011 | class AverageMeter(object):
1012 | """
1013 | Computes and stores the average and current value.
1014 | """
1015 |
1016 | def __init__(self):
1017 | self.reset()
1018 |
1019 | def reset(self):
1020 | self.val = 0
1021 | self.avg = 0
1022 | self.sum = 0
1023 | self.count = 0
1024 |
1025 | def update(self, val, n=1):
1026 | self.val = val
1027 | self.sum += val * n
1028 | self.count += n
1029 | self.avg = self.sum / self.count
1030 |
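# Illustrative usage (not from the original file):
#   meter = AverageMeter()
#   meter.update(2.0, n=4)   # val=2.0, sum=8.0,  count=4, avg=2.0
#   meter.update(4.0, n=4)   # val=4.0, sum=24.0, count=8, avg=3.0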
1031 |
1032 | if __name__ == '__main__':
1033 | main()
--------------------------------------------------------------------------------