├── pytorch_ssim
│   ├── LICENSE.txt
│   ├── setup.cfg
│   ├── einstein.png
│   ├── max_ssim.gif
│   ├── .gitignore
│   ├── setup.py
│   ├── max_ssim.py
│   ├── README.md
│   └── pytorch_ssim
│       └── __init__.py
├── PerceptualSimilarity
│   ├── data
│   │   ├── __init__.py
│   │   ├── dataset
│   │   │   ├── __init__.py
│   │   │   ├── base_dataset.py
│   │   │   ├── jnd_dataset.py
│   │   │   └── twoafc_dataset.py
│   │   ├── base_data_loader.py
│   │   ├── data_loader.py
│   │   ├── custom_dataset_data_loader.py
│   │   └── image_folder.py
│   ├── util
│   │   ├── __init__.py
│   │   ├── util.py
│   │   ├── html.py
│   │   └── visualizer.py
│   ├── .gitignore
│   ├── imgs
│   │   ├── ex_p0.png
│   │   ├── ex_p1.png
│   │   ├── ex_ref.png
│   │   ├── fig1.png
│   │   ├── ex_dir0
│   │   │   ├── 0.png
│   │   │   └── 1.png
│   │   ├── ex_dir1
│   │   │   ├── 0.png
│   │   │   └── 1.png
│   │   └── ex_dir_pair
│   │       ├── ex_p0.png
│   │       ├── ex_p1.png
│   │       └── ex_ref.png
│   ├── scripts
│   │   ├── eval_valsets.sh
│   │   ├── train_test_metric.sh
│   │   ├── train_test_metric_tune.sh
│   │   ├── train_test_metric_scratch.sh
│   │   ├── download_dataset_valonly.sh
│   │   └── download_dataset.sh
│   ├── lpips
│   │   ├── weights
│   │   │   ├── v0.0
│   │   │   │   ├── alex.pth
│   │   │   │   ├── vgg.pth
│   │   │   │   └── squeeze.pth
│   │   │   └── v0.1
│   │   │       ├── alex.pth
│   │   │       ├── vgg.pth
│   │   │       └── squeeze.pth
│   │   ├── __init__.py
│   │   ├── pretrained_networks.py
│   │   ├── lpips.py
│   │   └── trainer.py
│   ├── requirements.txt
│   ├── setup.py
│   ├── lpips_2imgs.py
│   ├── lpips_2dirs.py
│   ├── LICENSE
│   ├── test_network.py
│   ├── lpips_loss.py
│   ├── Dockerfile
│   ├── lpips_1dir_allpairs.py
│   ├── test_dataset_model.py
│   ├── train.py
│   └── README.md
├── pytorch_msssim
│   ├── LICENSE.txt
│   ├── try1.py
│   ├── setup.cfg
│   ├── einstein.png
│   ├── .gitignore
│   ├── setup.py
│   ├── max_ssim.py
│   ├── README.md
│   └── pytorch_msssim
│       └── __init__.py
├── Examples
│   └── BRViT_sample1.jfif
├── LICENSE
├── README.md
├── evaluate.py
├── store_results.py
├── train.py
├── Bokeh_Data
│   └── test.csv
└── model.py
/pytorch_ssim/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT
2 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/pytorch_msssim/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT
2 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/data/dataset/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/pytorch_msssim/try1.py:
--------------------------------------------------------------------------------
1 | import pytorch_msssim
2 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 |
3 | checkpoints/*
4 |
--------------------------------------------------------------------------------
/pytorch_msssim/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description-file = README.md
3 |
--------------------------------------------------------------------------------
/pytorch_ssim/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description-file = README.md
3 |
--------------------------------------------------------------------------------
/pytorch_ssim/einstein.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/HEAD/pytorch_ssim/einstein.png
--------------------------------------------------------------------------------
/pytorch_ssim/max_ssim.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/HEAD/pytorch_ssim/max_ssim.gif
--------------------------------------------------------------------------------
/Examples/BRViT_sample1.jfif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/HEAD/Examples/BRViT_sample1.jfif
--------------------------------------------------------------------------------
/pytorch_msssim/einstein.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/HEAD/pytorch_msssim/einstein.png
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/ex_p0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/HEAD/PerceptualSimilarity/imgs/ex_p0.png
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/ex_p1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/HEAD/PerceptualSimilarity/imgs/ex_p1.png
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/ex_ref.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/HEAD/PerceptualSimilarity/imgs/ex_ref.png
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/HEAD/PerceptualSimilarity/imgs/fig1.png
--------------------------------------------------------------------------------
/PerceptualSimilarity/scripts/eval_valsets.sh:
--------------------------------------------------------------------------------
1 |
2 | python ./test_dataset_model.py --dataset_mode 2afc --model lpips --net alex --use_gpu --batch_size 50
3 |
4 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/ex_dir0/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/HEAD/PerceptualSimilarity/imgs/ex_dir0/0.png
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/ex_dir0/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/HEAD/PerceptualSimilarity/imgs/ex_dir0/1.png
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/ex_dir1/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/HEAD/PerceptualSimilarity/imgs/ex_dir1/0.png
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/ex_dir1/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/HEAD/PerceptualSimilarity/imgs/ex_dir1/1.png
--------------------------------------------------------------------------------
/pytorch_msssim/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.gif
3 | *.png
4 | *.jpg
5 | test*
6 | !einstein.png
7 | !max_ssim.gif
8 | MANIFEST
9 | dist/*
10 | .sync-config.cson
11 |
--------------------------------------------------------------------------------
/pytorch_ssim/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.gif
3 | *.png
4 | *.jpg
5 | test*
6 | !einstein.png
7 | !max_ssim.gif
8 | MANIFEST
9 | dist/*
10 | .sync-config.cson
11 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/ex_dir_pair/ex_p0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/HEAD/PerceptualSimilarity/imgs/ex_dir_pair/ex_p0.png
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/ex_dir_pair/ex_p1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/HEAD/PerceptualSimilarity/imgs/ex_dir_pair/ex_p1.png
--------------------------------------------------------------------------------
/PerceptualSimilarity/imgs/ex_dir_pair/ex_ref.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/HEAD/PerceptualSimilarity/imgs/ex_dir_pair/ex_ref.png
--------------------------------------------------------------------------------
/PerceptualSimilarity/lpips/weights/v0.0/alex.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/HEAD/PerceptualSimilarity/lpips/weights/v0.0/alex.pth
--------------------------------------------------------------------------------
/PerceptualSimilarity/lpips/weights/v0.0/vgg.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/HEAD/PerceptualSimilarity/lpips/weights/v0.0/vgg.pth
--------------------------------------------------------------------------------
/PerceptualSimilarity/lpips/weights/v0.1/alex.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/HEAD/PerceptualSimilarity/lpips/weights/v0.1/alex.pth
--------------------------------------------------------------------------------
/PerceptualSimilarity/lpips/weights/v0.1/vgg.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/HEAD/PerceptualSimilarity/lpips/weights/v0.1/vgg.pth
--------------------------------------------------------------------------------
/PerceptualSimilarity/lpips/weights/v0.0/squeeze.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/HEAD/PerceptualSimilarity/lpips/weights/v0.0/squeeze.pth
--------------------------------------------------------------------------------
/PerceptualSimilarity/lpips/weights/v0.1/squeeze.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Soester10/Bokeh-Rendering-with-Vision-Transformers/HEAD/PerceptualSimilarity/lpips/weights/v0.1/squeeze.pth
--------------------------------------------------------------------------------
/PerceptualSimilarity/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=0.4.0
2 | torchvision>=0.2.1
3 | numpy>=1.14.3
4 | scipy>=1.0.1
5 | scikit-image>=0.13.0
6 | opencv>=2.4.11
7 | matplotlib>=1.5.1
8 | tqdm>=4.28.1
9 | jupyter
10 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/data/base_data_loader.py:
--------------------------------------------------------------------------------
1 |
2 | class BaseDataLoader():
3 | def __init__(self):
4 | pass
5 |
6 | def initialize(self):
7 | pass
8 |
9 |     def load_data(self):
10 |         return None
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/scripts/train_test_metric.sh:
--------------------------------------------------------------------------------
1 |
2 | TRIAL=${1}
3 | NET=${2}
4 | mkdir checkpoints
5 | mkdir checkpoints/${NET}_${TRIAL}
6 | python ./train.py --use_gpu --net ${NET} --name ${NET}_${TRIAL}
7 | python ./test_dataset_model.py --use_gpu --net ${NET} --model_path ./checkpoints/${NET}_${TRIAL}/latest_net_.pth
8 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/data/dataset/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | class BaseDataset(data.Dataset):
4 | def __init__(self):
5 | super(BaseDataset, self).__init__()
6 |
7 | def name(self):
8 | return 'BaseDataset'
9 |
10 | def initialize(self):
11 | pass
12 |
13 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/scripts/train_test_metric_tune.sh:
--------------------------------------------------------------------------------
1 |
2 | TRIAL=${1}
3 | NET=${2}
4 | mkdir checkpoints
5 | mkdir checkpoints/${NET}_${TRIAL}_tune
6 | python ./train.py --train_trunk --use_gpu --net ${NET} --name ${NET}_${TRIAL}_tune
7 | python ./test_dataset_model.py --train_trunk --use_gpu --net ${NET} --model_path ./checkpoints/${NET}_${TRIAL}_tune/latest_net_.pth
8 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/scripts/train_test_metric_scratch.sh:
--------------------------------------------------------------------------------
1 |
2 | TRIAL=${1}
3 | NET=${2}
4 | mkdir checkpoints
5 | mkdir checkpoints/${NET}_${TRIAL}_scratch
6 | python ./train.py --from_scratch --train_trunk --use_gpu --net ${NET} --name ${NET}_${TRIAL}_scratch
7 | python ./test_dataset_model.py --from_scratch --train_trunk --use_gpu --net ${NET} --model_path ./checkpoints/${NET}_${TRIAL}_scratch/latest_net_.pth
8 |
9 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/data/data_loader.py:
--------------------------------------------------------------------------------
1 | def CreateDataLoader(datafolder,dataroot='./dataset',dataset_mode='2afc',load_size=64,batch_size=1,serial_batches=True,nThreads=4):
2 | from data.custom_dataset_data_loader import CustomDatasetDataLoader
3 | data_loader = CustomDatasetDataLoader()
4 | # print(data_loader.name())
5 | data_loader.initialize(datafolder,dataroot=dataroot+'/'+dataset_mode,dataset_mode=dataset_mode,load_size=load_size,batch_size=batch_size,serial_batches=serial_batches, nThreads=nThreads)
6 | return data_loader
7 |
--------------------------------------------------------------------------------
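A minimal sketch of how `CreateDataLoader` above is typically used (the folder name `val/traditional` and the batch size are illustrative assumptions, not taken from this file):

```python
# Hypothetical usage sketch for CreateDataLoader (run from the PerceptualSimilarity root).
from data.data_loader import CreateDataLoader

data_loader = CreateDataLoader('val/traditional', dataroot='./dataset',
                               dataset_mode='2afc', load_size=64, batch_size=50)
loader = data_loader.load_data()
for batch in loader:
    # each batch is a dict with 'p0', 'p1', 'ref', 'judge' and the corresponding paths
    print(batch['p0'].shape, batch['judge'].shape)
    break
```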
/pytorch_msssim/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 |
3 | setup(
4 | name = 'pytorch_msssim',
5 | packages = ['pytorch_msssim'], # this must be the same as the name above
6 | version = '0.1',
7 | description = 'Differentiable multi-scale structural similarity (MS-SSIM) index',
8 | author = 'Jorge Pessoa',
9 | author_email = 'jpessoa.on@gmail.com',
10 | url = 'https://github.com/jorge-pessoa/pytorch-msssim', # use the URL to the github repo
11 | keywords = ['pytorch', 'image-processing', 'deep-learning', 'ms-ssim'], # arbitrary keywords
12 | classifiers = [],
13 | )
14 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/scripts/download_dataset_valonly.sh:
--------------------------------------------------------------------------------
1 |
2 | mkdir dataset
3 |
4 | # JND Dataset
5 | wget https://people.eecs.berkeley.edu/~rich.zhang/projects/2018_perceptual/dataset/jnd.tar.gz -O ./dataset/jnd.tar.gz
6 |
7 | mkdir dataset/jnd
8 | tar -xzf ./dataset/jnd.tar.gz -C ./dataset
9 | rm ./dataset/jnd.tar.gz
10 |
11 | # 2AFC Val set
12 | mkdir dataset/2afc/
13 | wget https://people.eecs.berkeley.edu/~rich.zhang/projects/2018_perceptual/dataset/twoafc_val.tar.gz -O ./dataset/twoafc_val.tar.gz
14 |
15 | mkdir dataset/2afc/val
16 | tar -xzf ./dataset/twoafc_val.tar.gz -C ./dataset/2afc
17 | rm ./dataset/twoafc_val.tar.gz
18 |
--------------------------------------------------------------------------------
/pytorch_ssim/setup.py:
--------------------------------------------------------------------------------
1 | from distutils.core import setup
2 | setup(
3 | name = 'pytorch_ssim',
4 | packages = ['pytorch_ssim'], # this must be the same as the name above
5 | version = '0.1',
6 | description = 'Differentiable structural similarity (SSIM) index',
7 | author = 'Po-Hsun (Evan) Su',
8 | author_email = 'evan.pohsun.su@gmail.com',
9 | url = 'https://github.com/Po-Hsun-Su/pytorch-ssim', # use the URL to the github repo
10 | download_url = 'https://github.com/Po-Hsun-Su/pytorch-ssim/archive/0.1.tar.gz', # I'll explain this in a second
11 | keywords = ['pytorch', 'image-processing', 'deep-learning'], # arbitrary keywords
12 | classifiers = [],
13 | )
14 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/setup.py:
--------------------------------------------------------------------------------
1 |
2 | import setuptools
3 | with open("README.md", "r") as fh:
4 | long_description = fh.read()
5 | setuptools.setup(
6 | name='lpips',
7 | version='0.1.2',
8 | author="Richard Zhang",
9 | author_email="rizhang@adobe.com",
10 | description="LPIPS Similarity metric",
11 | long_description=long_description,
12 | long_description_content_type="text/markdown",
13 | url="https://github.com/richzhang/PerceptualSimilarity",
14 | packages=['lpips'],
15 | package_data={'lpips': ['weights/v0.0/*.pth','weights/v0.1/*.pth']},
16 | include_package_data=True,
17 | classifiers=[
18 | "Programming Language :: Python :: 3",
19 | "License :: OSI Approved :: BSD License",
20 | "Operating System :: OS Independent",
21 | ],
22 | )
23 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/scripts/download_dataset.sh:
--------------------------------------------------------------------------------
1 |
2 | mkdir dataset
3 |
4 | # JND Dataset
5 | wget https://people.eecs.berkeley.edu/~rich.zhang/projects/2018_perceptual/dataset/jnd.tar.gz -O ./dataset/jnd.tar.gz
6 |
7 | mkdir dataset/jnd
8 | tar -xzf ./dataset/jnd.tar.gz -C ./dataset
9 | rm ./dataset/jnd.tar.gz
10 |
11 | # 2AFC Val set
12 | mkdir dataset/2afc/
13 | wget https://people.eecs.berkeley.edu/~rich.zhang/projects/2018_perceptual/dataset/twoafc_val.tar.gz -O ./dataset/twoafc_val.tar.gz
14 |
15 | mkdir dataset/2afc/val
16 | tar -xzf ./dataset/twoafc_val.tar.gz -C ./dataset/2afc
17 | rm ./dataset/twoafc_val.tar.gz
18 |
19 | # 2AFC Train set
20 | mkdir dataset/2afc/
21 | wget https://people.eecs.berkeley.edu/~rich.zhang/projects/2018_perceptual/dataset/twoafc_train.tar.gz -O ./dataset/twoafc_train.tar.gz
22 |
23 | mkdir dataset/2afc/train
24 | tar -xzf ./dataset/twoafc_train.tar.gz -C ./dataset/2afc
25 | rm ./dataset/twoafc_train.tar.gz
26 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/lpips_2imgs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import lpips
3 |
4 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
5 | parser.add_argument('-p0','--path0', type=str, default='./imgs/ex_ref.png')
6 | parser.add_argument('-p1','--path1', type=str, default='./imgs/ex_p0.png')
7 | parser.add_argument('-v','--version', type=str, default='0.1')
8 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU')
9 |
10 | opt = parser.parse_args()
11 |
12 | ## Initializing the model
13 | loss_fn = lpips.LPIPS(net='alex',version=opt.version)
14 |
15 | if(opt.use_gpu):
16 | loss_fn.cuda()
17 |
18 | # Load images
19 | img0 = lpips.im2tensor(lpips.load_image(opt.path0)) # RGB image from [-1,1]
20 | img1 = lpips.im2tensor(lpips.load_image(opt.path1))
21 |
22 | if(opt.use_gpu):
23 | img0 = img0.cuda()
24 | img1 = img1.cuda()
25 |
26 | # Compute distance
27 | dist01 = loss_fn.forward(img0,img1)
28 | print('Distance: %.3f'%dist01)
29 |
--------------------------------------------------------------------------------
/pytorch_ssim/max_ssim.py:
--------------------------------------------------------------------------------
1 | import pytorch_ssim
2 | import torch
3 | from torch.autograd import Variable
4 | from torch import optim
5 | import cv2
6 | import numpy as np
7 |
8 | npImg1 = cv2.imread("einstein.png")
9 |
10 | img1 = torch.from_numpy(np.rollaxis(npImg1, 2)).float().unsqueeze(0)/255.0
11 | img2 = torch.rand(img1.size())
12 |
13 | if torch.cuda.is_available():
14 | img1 = img1.cuda()
15 | img2 = img2.cuda()
16 |
17 |
18 | img1 = Variable( img1, requires_grad=False)
19 | img2 = Variable( img2, requires_grad = True)
20 |
21 |
22 | # Functional: pytorch_ssim.ssim(img1, img2, window_size = 11, size_average = True)
23 | ssim_value = pytorch_ssim.ssim(img1, img2).data[0]
24 | print("Initial ssim:", ssim_value)
25 |
26 | # Module: pytorch_ssim.SSIM(window_size = 11, size_average = True)
27 | ssim_loss = pytorch_ssim.SSIM()
28 |
29 | optimizer = optim.Adam([img2], lr=0.01)
30 |
31 | while ssim_value < 0.95:
32 | optimizer.zero_grad()
33 | ssim_out = -ssim_loss(img1, img2)
34 | ssim_value = - ssim_out.data[0]
35 | print(ssim_value)
36 | ssim_out.backward()
37 | optimizer.step()
38 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Hariharan N
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/lpips_2dirs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import lpips
4 |
5 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
6 | parser.add_argument('-d0','--dir0', type=str, default='./imgs/ex_dir0')
7 | parser.add_argument('-d1','--dir1', type=str, default='./imgs/ex_dir1')
8 | parser.add_argument('-o','--out', type=str, default='./imgs/example_dists.txt')
9 | parser.add_argument('-v','--version', type=str, default='0.1')
10 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU')
11 |
12 | opt = parser.parse_args()
13 |
14 | ## Initializing the model
15 | loss_fn = lpips.LPIPS(net='alex',version=opt.version)
16 | if(opt.use_gpu):
17 | loss_fn.cuda()
18 |
19 | # crawl directories
20 | f = open(opt.out,'w')
21 | files = os.listdir(opt.dir0)
22 |
23 | for file in files:
24 | if(os.path.exists(os.path.join(opt.dir1,file))):
25 | # Load images
26 | img0 = lpips.im2tensor(lpips.load_image(os.path.join(opt.dir0,file))) # RGB image from [-1,1]
27 | img1 = lpips.im2tensor(lpips.load_image(os.path.join(opt.dir1,file)))
28 |
29 | if(opt.use_gpu):
30 | img0 = img0.cuda()
31 | img1 = img1.cuda()
32 |
33 | # Compute distance
34 | dist01 = loss_fn.forward(img0,img1)
35 | print('%s: %.3f'%(file,dist01))
36 | f.writelines('%s: %.6f\n'%(file,dist01))
37 |
38 | f.close()
39 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 |
25 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Bokeh-Rendering-with-Vision-Transformers
2 | Establishing new state-of-the-art results for Bokeh Rendering on the EBB! Dataset. The preprint of our work can be found [here](https://www.techrxiv.org/articles/preprint/Bokeh_Effect_Rendering_with_Vision_Transformers/17714849).
3 |
4 | ### Sample:
5 |
6 |
7 |
8 |
9 | ### References
10 |
11 | Model adapted from https://github.com/isl-org/DPT
12 |
13 | SSIM loss can be found at https://github.com/Po-Hsun-Su/pytorch-ssim
14 |
15 | MSSSIM loss can be found at https://github.com/jorge-pessoa/pytorch-msssim
16 |
17 | LPIPS can be found at https://github.com/richzhang/PerceptualSimilarity
18 |
19 |
20 | ### BRViT Weights
21 |
22 | Our latest model weights can be downloaded from [here](https://drive.google.com/file/d/1V4oX1fARjaIujXQ7Vf4UDwxJhm9ubVG-/view?usp=sharing).
23 |
24 |
25 | ### BRViT Metrics
26 |
27 | Common metrics computed with the latest weights, for model comparison:
28 |
29 | 1. PSNR: 24.76
30 | 2. SSIM: 0.8904
31 | 3. LPIPS: 0.1924
32 |
33 |
34 | ### Dataset
35 |
36 | Training: [https://data.vision.ee.ethz.ch/timofter/AIM19Bokeh/Training.zip](https://data.vision.ee.ethz.ch/timofter/AIM19Bokeh/Training.zip)
37 |
38 | Validation: [https://data.vision.ee.ethz.ch/timofter/AIM19Bokeh/ValidationBokehFree.zip](https://data.vision.ee.ethz.ch/timofter/AIM19Bokeh/ValidationBokehFree.zip)
39 |
40 | Testing: [https://data.vision.ee.ethz.ch/timofter/AIM19Bokeh/TestBokehFree.zip](https://data.vision.ee.ethz.ch/timofter/AIM19Bokeh/TestBokehFree.zip)
41 |
--------------------------------------------------------------------------------
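For reference, a minimal sketch of how the PSNR/SSIM/LPIPS numbers reported in the README above can be computed for a single prediction/ground-truth pair (the file paths are placeholders; evaluate.py runs the full loop over the test split):

```python
# Hypothetical single-pair metric sketch; paths are placeholders.
import PerceptualSimilarity.lpips.lpips as lpips
from PerceptualSimilarity.util import util
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim

gt = util.load_image('path/to/ground_truth.png')    # uint8 HxWx3 RGB
pred = util.load_image('path/to/prediction.png')

psnr = compare_psnr(gt, pred)
ssim = compare_ssim(gt, pred, multichannel=True)

loss_fn = lpips.LPIPS(net='alex')                       # same setting as evaluate.py
d = loss_fn.forward(util.im2tensor(gt), util.im2tensor(pred))  # inputs scaled to [-1, 1]
print('PSNR %.2f  SSIM %.4f  LPIPS %.4f' % (psnr, ssim, d.item()))
```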
/PerceptualSimilarity/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import numpy as np
4 | from PIL import Image
5 | import os
6 | import matplotlib.pyplot as plt
7 | import torch
8 |
9 |
10 | def load_image(path):
11 | if(path[-3:] == 'dng'):
12 | import rawpy
13 | with rawpy.imread(path) as raw:
14 | img = raw.postprocess()
15 | elif(path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png'):
16 | import cv2
17 | return cv2.imread(path)[:,:,::-1]
18 | else:
19 | img = (255*plt.imread(path)[:,:,:3]).astype('uint8')
20 |
21 | return img
22 |
23 | def save_image(image_numpy, image_path, ):
24 | image_pil = Image.fromarray(image_numpy)
25 | image_pil.save(image_path)
26 |
27 | def mkdirs(paths):
28 | if isinstance(paths, list) and not isinstance(paths, str):
29 | for path in paths:
30 | mkdir(path)
31 | else:
32 | mkdir(paths)
33 |
34 | def mkdir(path):
35 | if not os.path.exists(path):
36 | os.makedirs(path)
37 |
38 |
39 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
40 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
41 | image_numpy = image_tensor[0].cpu().float().numpy()
42 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
43 | return image_numpy.astype(imtype)
44 |
45 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
46 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
47 | return torch.Tensor((image / factor - cent)
48 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
49 |
--------------------------------------------------------------------------------
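A small round-trip sketch of the conversion convention defined above: `im2tensor` maps a uint8 HxWx3 image to a 1x3xHxW float tensor in roughly [-1, 1], and `tensor2im` inverts it (the random image below is only for illustration):

```python
# Illustrative round trip through im2tensor / tensor2im.
import numpy as np
from PerceptualSimilarity.util import util

img = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
t = util.im2tensor(img)        # shape (1, 3, 64, 64), values in about [-1, 1]
back = util.tensor2im(t)       # uint8 again, equal to img up to quantization
print(t.shape, int(np.abs(back.astype(int) - img).max()))
```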
/PerceptualSimilarity/test_network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import lpips
3 | from IPython import embed
4 |
5 | use_gpu = False # Whether to use GPU
6 | spatial = True # Return a spatial map of perceptual distance.
7 |
8 | # Linearly calibrated models (LPIPS)
9 | loss_fn = lpips.LPIPS(net='alex', spatial=spatial) # Can also set net = 'squeeze' or 'vgg'
10 | # loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg'
11 |
12 | if(use_gpu):
13 | loss_fn.cuda()
14 |
15 | ## Example usage with dummy tensors
16 | dummy_im0 = torch.zeros(1,3,64,64) # image should be RGB, normalized to [-1,1]
17 | dummy_im1 = torch.zeros(1,3,64,64)
18 | if(use_gpu):
19 | dummy_im0 = dummy_im0.cuda()
20 | dummy_im1 = dummy_im1.cuda()
21 | dist = loss_fn.forward(dummy_im0,dummy_im1)
22 |
23 | ## Example usage with images
24 | ex_ref = lpips.im2tensor(lpips.load_image('./imgs/ex_ref.png'))
25 | ex_p0 = lpips.im2tensor(lpips.load_image('./imgs/ex_p0.png'))
26 | ex_p1 = lpips.im2tensor(lpips.load_image('./imgs/ex_p1.png'))
27 | if(use_gpu):
28 | ex_ref = ex_ref.cuda()
29 | ex_p0 = ex_p0.cuda()
30 | ex_p1 = ex_p1.cuda()
31 |
32 | ex_d0 = loss_fn.forward(ex_ref,ex_p0)
33 | ex_d1 = loss_fn.forward(ex_ref,ex_p1)
34 |
35 | if not spatial:
36 | print('Distances: (%.3f, %.3f)'%(ex_d0, ex_d1))
37 | else:
38 | print('Distances: (%.3f, %.3f)'%(ex_d0.mean(), ex_d1.mean())) # The mean distance is approximately the same as the non-spatial distance
39 |
40 | # Visualize a spatially-varying distance map between ex_p0 and ex_ref
41 | import pylab
42 | pylab.imshow(ex_d0[0,0,...].data.cpu().numpy())
43 | pylab.show()
44 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/data/custom_dataset_data_loader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from data.base_data_loader import BaseDataLoader
3 | import os
4 |
5 | def CreateDataset(dataroots,dataset_mode='2afc',load_size=64,):
6 | dataset = None
7 | if dataset_mode=='2afc': # human judgements
8 | from data.dataset.twoafc_dataset import TwoAFCDataset
9 | dataset = TwoAFCDataset()
10 | elif dataset_mode=='jnd': # human judgements
11 | from data.dataset.jnd_dataset import JNDDataset
12 | dataset = JNDDataset()
13 | else:
14 | raise ValueError("Dataset Mode [%s] not recognized."%self.dataset_mode)
15 |
16 | dataset.initialize(dataroots,load_size=load_size)
17 | return dataset
18 |
19 | class CustomDatasetDataLoader(BaseDataLoader):
20 | def name(self):
21 | return 'CustomDatasetDataLoader'
22 |
23 | def initialize(self, datafolders, dataroot='./dataset',dataset_mode='2afc',load_size=64,batch_size=1,serial_batches=True, nThreads=1):
24 | BaseDataLoader.initialize(self)
25 | if(not isinstance(datafolders,list)):
26 | datafolders = [datafolders,]
27 | data_root_folders = [os.path.join(dataroot,datafolder) for datafolder in datafolders]
28 | self.dataset = CreateDataset(data_root_folders,dataset_mode=dataset_mode,load_size=load_size)
29 | self.dataloader = torch.utils.data.DataLoader(
30 | self.dataset,
31 | batch_size=batch_size,
32 | shuffle=not serial_batches,
33 | num_workers=int(nThreads))
34 |
35 | def load_data(self):
36 | return self.dataloader
37 |
38 | def __len__(self):
39 | return len(self.dataset)
40 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/lpips_loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.autograd import Variable
4 | import matplotlib.pyplot as plt
5 | import argparse
6 | import lpips
7 |
8 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
9 | parser.add_argument('--ref_path', type=str, default='./imgs/ex_ref.png')
10 | parser.add_argument('--pred_path', type=str, default='./imgs/ex_p1.png')
11 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU')
12 |
13 | opt = parser.parse_args()
14 |
15 | loss_fn = lpips.LPIPS(net='vgg')
16 | if(opt.use_gpu):
17 | loss_fn.cuda()
18 |
19 | ref = lpips.im2tensor(lpips.load_image(opt.ref_path))
20 | pred = Variable(lpips.im2tensor(lpips.load_image(opt.pred_path)), requires_grad=True)
21 | if(opt.use_gpu):
22 | with torch.no_grad():
23 | ref = ref.cuda()
24 | pred = pred.cuda()
25 |
26 | optimizer = torch.optim.Adam([pred,], lr=1e-3, betas=(0.9, 0.999))
27 |
28 | plt.ion()
29 | fig = plt.figure(1)
30 | ax = fig.add_subplot(131)
31 | ax.imshow(lpips.tensor2im(ref))
32 | ax.set_title('target')
33 | ax = fig.add_subplot(133)
34 | ax.imshow(lpips.tensor2im(pred.data))
35 | ax.set_title('initialization')
36 |
37 | for i in range(1000):
38 | dist = loss_fn.forward(pred, ref)
39 | optimizer.zero_grad()
40 | dist.backward()
41 | optimizer.step()
42 | pred.data = torch.clamp(pred.data, -1, 1)
43 |
44 | if i % 10 == 0:
45 | print('iter %d, dist %.3g' % (i, dist.view(-1).data.cpu().numpy()[0]))
46 | pred.data = torch.clamp(pred.data, -1, 1)
47 | pred_img = lpips.tensor2im(pred.data)
48 |
49 | ax = fig.add_subplot(132)
50 | ax.imshow(pred_img)
51 | ax.set_title('iter %d, dist %.3f' % (i, dist.view(-1).data.cpu().numpy()[0]))
52 | plt.pause(5e-2)
53 | # plt.imsave('imgs_saved/%04d.jpg'%i,pred_img)
54 |
55 |
56 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:9.0-base-ubuntu16.04
2 |
3 | LABEL maintainer="Seyoung Park "
4 |
5 | # This Dockerfile is forked from Tensorflow Dockerfile
6 |
7 | # Pick up some PyTorch gpu dependencies
8 | RUN apt-get update && apt-get install -y --no-install-recommends \
9 | build-essential \
10 | cuda-command-line-tools-9-0 \
11 | cuda-cublas-9-0 \
12 | cuda-cufft-9-0 \
13 | cuda-curand-9-0 \
14 | cuda-cusolver-9-0 \
15 | cuda-cusparse-9-0 \
16 | curl \
17 | libcudnn7=7.1.4.18-1+cuda9.0 \
18 | libfreetype6-dev \
19 | libhdf5-serial-dev \
20 | libpng12-dev \
21 | libzmq3-dev \
22 | pkg-config \
23 | python \
24 | python-dev \
25 | rsync \
26 | software-properties-common \
27 | unzip \
28 | && \
29 | apt-get clean && \
30 | rm -rf /var/lib/apt/lists/*
31 |
32 |
33 | # Install miniconda
34 | RUN apt-get update && apt-get install -y --no-install-recommends \
35 | wget && \
36 | MINICONDA="Miniconda3-latest-Linux-x86_64.sh" && \
37 | wget --quiet https://repo.continuum.io/miniconda/$MINICONDA && \
38 | bash $MINICONDA -b -p /miniconda && \
39 | rm -f $MINICONDA
40 | ENV PATH /miniconda/bin:$PATH
41 |
42 | # Install PyTorch
43 | RUN conda update -n base conda && \
44 | conda install pytorch torchvision cuda90 -c pytorch
45 |
46 | # Install PerceptualSimilarity dependencies
47 | RUN conda install numpy scipy jupyter matplotlib && \
48 | conda install -c conda-forge scikit-image && \
49 | apt-get install -y python-qt4 && \
50 | pip install opencv-python
51 |
52 | # For CUDA profiling, TensorFlow requires CUPTI. Maybe PyTorch needs this too.
53 | ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
54 |
55 | # IPython
56 | EXPOSE 8888
57 |
58 | WORKDIR "/notebooks"
59 |
60 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/lpips_1dir_allpairs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import lpips
4 | import numpy as np
5 |
6 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
7 | parser.add_argument('-d','--dir', type=str, default='./imgs/ex_dir_pair')
8 | parser.add_argument('-o','--out', type=str, default='./imgs/example_dists.txt')
9 | parser.add_argument('-v','--version', type=str, default='0.1')
10 | parser.add_argument('--all-pairs', action='store_true', help='turn on to test all N(N-1)/2 pairs, leave off to just do consecutive pairs (N-1)')
11 | parser.add_argument('-N', type=int, default=None)
12 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU')
13 |
14 | opt = parser.parse_args()
15 |
16 | ## Initializing the model
17 | loss_fn = lpips.LPIPS(net='alex',version=opt.version)
18 | if(opt.use_gpu):
19 | loss_fn.cuda()
20 |
21 | # crawl directories
22 | f = open(opt.out,'w')
23 | files = os.listdir(opt.dir)
24 | if(opt.N is not None):
25 | files = files[:opt.N]
26 | F = len(files)
27 |
28 | dists = []
29 | for (ff,file) in enumerate(files[:-1]):
30 | img0 = lpips.im2tensor(lpips.load_image(os.path.join(opt.dir,file))) # RGB image from [-1,1]
31 | if(opt.use_gpu):
32 | img0 = img0.cuda()
33 |
34 | if(opt.all_pairs):
35 | files1 = files[ff+1:]
36 | else:
37 | files1 = [files[ff+1],]
38 |
39 | for file1 in files1:
40 | img1 = lpips.im2tensor(lpips.load_image(os.path.join(opt.dir,file1)))
41 |
42 | if(opt.use_gpu):
43 | img1 = img1.cuda()
44 |
45 | # Compute distance
46 | dist01 = loss_fn.forward(img0,img1)
47 | print('(%s,%s): %.3f'%(file,file1,dist01))
48 | f.writelines('(%s,%s): %.6f\n'%(file,file1,dist01))
49 |
50 | dists.append(dist01.item())
51 |
52 | avg_dist = np.mean(np.array(dists))
53 | stderr_dist = np.std(np.array(dists))/np.sqrt(len(dists))
54 |
55 | print('Avg: %.5f +/- %.5f'%(avg_dist,stderr_dist))
56 | f.writelines('Avg: %.6f +/- %.6f'%(avg_dist,stderr_dist))
57 |
58 | f.close()
59 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/data/dataset/jnd_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import torchvision.transforms as transforms
3 | from data.dataset.base_dataset import BaseDataset
4 | from data.image_folder import make_dataset
5 | from PIL import Image
6 | import numpy as np
7 | import torch
8 | from IPython import embed
9 |
10 | class JNDDataset(BaseDataset):
11 | def initialize(self, dataroot, load_size=64):
12 | self.root = dataroot
13 | self.load_size = load_size
14 |
15 | self.dir_p0 = os.path.join(self.root, 'p0')
16 | self.p0_paths = make_dataset(self.dir_p0)
17 | self.p0_paths = sorted(self.p0_paths)
18 |
19 | self.dir_p1 = os.path.join(self.root, 'p1')
20 | self.p1_paths = make_dataset(self.dir_p1)
21 | self.p1_paths = sorted(self.p1_paths)
22 |
23 | transform_list = []
24 | transform_list.append(transforms.Scale(load_size))
25 | transform_list += [transforms.ToTensor(),
26 | transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]
27 |
28 | self.transform = transforms.Compose(transform_list)
29 |
30 | # judgement directory
31 | self.dir_S = os.path.join(self.root, 'same')
32 | self.same_paths = make_dataset(self.dir_S,mode='np')
33 | self.same_paths = sorted(self.same_paths)
34 |
35 | def __getitem__(self, index):
36 | p0_path = self.p0_paths[index]
37 | p0_img_ = Image.open(p0_path).convert('RGB')
38 | p0_img = self.transform(p0_img_)
39 |
40 | p1_path = self.p1_paths[index]
41 | p1_img_ = Image.open(p1_path).convert('RGB')
42 | p1_img = self.transform(p1_img_)
43 |
44 | same_path = self.same_paths[index]
45 | same_img = np.load(same_path).reshape((1,1,1,)) # [0,1]
46 |
47 | same_img = torch.FloatTensor(same_img)
48 |
49 | return {'p0': p0_img, 'p1': p1_img, 'same': same_img,
50 | 'p0_path': p0_path, 'p1_path': p1_path, 'same_path': same_path}
51 |
52 | def __len__(self):
53 | return len(self.p0_paths)
54 |
--------------------------------------------------------------------------------
/pytorch_ssim/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-ssim
2 |
3 | ### Differentiable structural similarity (SSIM) index.
4 |  
5 |
6 | ## Installation
7 | 1. Clone this repo.
8 | 2. Copy the "pytorch_ssim" folder into your project.
9 |
10 | ## Example
11 | ### basic usage
12 | ```python
13 | import pytorch_ssim
14 | import torch
15 | from torch.autograd import Variable
16 |
17 | img1 = Variable(torch.rand(1, 1, 256, 256))
18 | img2 = Variable(torch.rand(1, 1, 256, 256))
19 |
20 | if torch.cuda.is_available():
21 | img1 = img1.cuda()
22 | img2 = img2.cuda()
23 |
24 | print(pytorch_ssim.ssim(img1, img2))
25 |
26 | ssim_loss = pytorch_ssim.SSIM(window_size = 11)
27 |
28 | print(ssim_loss(img1, img2))
29 |
30 | ```
31 | ### maximize ssim
32 | ```python
33 | import pytorch_ssim
34 | import torch
35 | from torch.autograd import Variable
36 | from torch import optim
37 | import cv2
38 | import numpy as np
39 |
40 | npImg1 = cv2.imread("einstein.png")
41 |
42 | img1 = torch.from_numpy(np.rollaxis(npImg1, 2)).float().unsqueeze(0)/255.0
43 | img2 = torch.rand(img1.size())
44 |
45 | if torch.cuda.is_available():
46 | img1 = img1.cuda()
47 | img2 = img2.cuda()
48 |
49 |
50 | img1 = Variable( img1, requires_grad=False)
51 | img2 = Variable( img2, requires_grad = True)
52 |
53 |
54 | # Functional: pytorch_ssim.ssim(img1, img2, window_size = 11, size_average = True)
55 | ssim_value = pytorch_ssim.ssim(img1, img2).data[0]
56 | print("Initial ssim:", ssim_value)
57 |
58 | # Module: pytorch_ssim.SSIM(window_size = 11, size_average = True)
59 | ssim_loss = pytorch_ssim.SSIM()
60 |
61 | optimizer = optim.Adam([img2], lr=0.01)
62 |
63 | while ssim_value < 0.95:
64 | optimizer.zero_grad()
65 | ssim_out = -ssim_loss(img1, img2)
66 | ssim_value = - ssim_out.data[0]
67 | print(ssim_value)
68 | ssim_out.backward()
69 | optimizer.step()
70 |
71 | ```
72 |
73 | ## Reference
74 | https://ece.uwaterloo.ca/~z70wang/research/ssim/
75 |
--------------------------------------------------------------------------------
/pytorch_msssim/max_ssim.py:
--------------------------------------------------------------------------------
1 | from pytorch_msssim import msssim, ssim
2 | import torch
3 | from torch import optim
4 |
5 | from PIL import Image
6 | from torchvision.transforms.functional import to_tensor
7 | import numpy as np
8 |
9 | # display = True requires matplotlib
10 | display = True
11 | metric = 'MSSSIM' # MSSSIM or SSIM
12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13 |
14 | def post_process(img):
15 | img = img.detach().cpu().numpy()
16 | img = np.transpose(np.squeeze(img, axis=0), (1, 2, 0))
17 | img = np.squeeze(img) # works if grayscale
18 | return img
19 |
20 | # Preprocessing
21 | img1 = to_tensor(Image.open('einstein.png')).unsqueeze(0).type(torch.FloatTensor)
22 |
23 | img2 = torch.rand(img1.size())
24 | img2 = torch.nn.functional.sigmoid(img2) # use sigmoid to clamp between [0, 1]
25 |
26 | img1 = img1.to(device)
27 | img2 = img2.to(device)
28 |
29 | img1.requires_grad = False
30 | img2.requires_grad = True
31 |
32 | loss_func = msssim if metric == 'MSSSIM' else ssim
33 |
34 | value = loss_func(img1, img2)
35 | print("Initial %s: %.5f" % (metric, value.item()))
36 |
37 | optimizer = optim.Adam([img2], lr=0.01)
38 |
39 | # MSSSIM yields higher values for worse results, because noise is removed in scales with lower resolutions
40 | threshold = 0.999 if metric == 'MSSSIM' else 0.9
41 |
42 | while value < threshold:
43 | optimizer.zero_grad()
44 | msssim_out = -loss_func(img1, img2)
45 | value = -msssim_out.item()
46 | print('Current MS-SSIM = %.5f' % value)
47 | msssim_out.backward()
48 | optimizer.step()
49 |
50 | if display:
51 | # Post processing
52 | img1np = post_process(img1)
53 | img2 = torch.nn.functional.sigmoid(img2)
54 | img2np = post_process(img2)
55 | import matplotlib.pyplot as plt
56 | cmap = 'gray' if len(img1np.shape) == 2 else None
57 | plt.subplot(1, 2, 1)
58 | plt.imshow(img1np, cmap=cmap)
59 | plt.title('Original')
60 | plt.subplot(1, 2, 2)
61 | plt.imshow(img2np, cmap=cmap)
62 | plt.title('Generated, {:s}: {:.3f}'.format(metric, value))
63 | plt.show()
64 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import timm
6 | import types
7 | import math
8 | import numpy as np
9 | from PIL import Image
10 | import PIL.Image as pil
11 | from torchvision import transforms, datasets
12 | from torch.utils.data import Dataset, DataLoader
13 | import torch.utils.data as data
14 | import PerceptualSimilarity.lpips.lpips as lpips
15 | import glob
16 | import gc
17 | import pytorch_msssim.pytorch_msssim as pytorch_msssim
18 | from torch.cuda.amp import autocast
19 | from model import BRViT
20 | import matplotlib.pyplot as plt
21 | from pytorch_ssim.pytorch_ssim import ssim
22 | from torchvision.utils import save_image
23 | from PerceptualSimilarity.util import util
24 |
25 |
26 | import pandas as pd
27 | from tqdm import tqdm
28 | import sys
29 | import cv2
30 |
31 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr
32 | from skimage.metrics import structural_similarity as compare_ssim
33 |
34 | batch_size = 1
35 | feed_width = 1536
36 | feed_height = 1024
37 |
38 |
39 |
40 | def evaluate():
41 | tot_lpips_loss = 0.0
42 | total_psnr = 0.0
43 | total_ssim = 0.0
44 |
45 | device = torch.device("cuda")
46 | loss_fn = lpips.LPIPS(net='alex').to(device)
47 |
48 | for i in tqdm(range(294)):
49 | csv_file = "Bokeh_Data/test.csv"
50 | root_dir = "."
51 | dataa = pd.read_csv(csv_file)
52 | idx = i
53 |
54 | img0 = util.im2tensor(util.load_image(root_dir + dataa.iloc[idx, 0][1:])) # RGB image from [-1,1]
55 | img1 = util.im2tensor(util.load_image(f"Results/{4400+i}.png"))
56 | img0 = img0.to(device)
57 | img1 = img1.to(device)
58 |
59 | lpips_loss = loss_fn.forward(img0, img1)
60 | tot_lpips_loss += lpips_loss.item()
61 |         # uint8 numpy images for PSNR/SSIM (conversion assumed; I0/I1 were left undefined in the original)
62 |         I0, I1 = util.tensor2im(img0), util.tensor2im(img1)
63 |         total_psnr += compare_psnr(I0, I1)
64 |         total_ssim += compare_ssim(I0, I1, multichannel=True)
65 |
66 |
67 | print("TOTAL LPIPS:",":", tot_lpips_loss / 294)
68 | print("TOTAL PSNR",":", total_psnr / 294)
69 | print("TOTAL SSIM",":", total_ssim / 294)
70 |
71 |
72 | if __name__ == "__main__":
73 | evaluate()
74 |
75 |
76 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import *
3 | import os
4 |
5 |
6 | class HTML:
7 | def __init__(self, web_dir, title, image_subdir='', reflesh=0):
8 | self.title = title
9 | self.web_dir = web_dir
10 | # self.img_dir = os.path.join(self.web_dir, )
11 | self.img_subdir = image_subdir
12 | self.img_dir = os.path.join(self.web_dir, image_subdir)
13 | if not os.path.exists(self.web_dir):
14 | os.makedirs(self.web_dir)
15 | if not os.path.exists(self.img_dir):
16 | os.makedirs(self.img_dir)
17 | # print(self.img_dir)
18 |
19 | self.doc = dominate.document(title=title)
20 | if reflesh > 0:
21 | with self.doc.head:
22 | meta(http_equiv="reflesh", content=str(reflesh))
23 |
24 | def get_image_dir(self):
25 | return self.img_dir
26 |
27 | def add_header(self, str):
28 | with self.doc:
29 | h3(str)
30 |
31 | def add_table(self, border=1):
32 | self.t = table(border=border, style="table-layout: fixed;")
33 | self.doc.add(self.t)
34 |
35 | def add_images(self, ims, txts, links, width=400):
36 | self.add_table()
37 | with self.t:
38 | with tr():
39 | for im, txt, link in zip(ims, txts, links):
40 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
41 | with p():
42 | with a(href=os.path.join(link)):
43 | img(style="width:%dpx" % width, src=os.path.join(im))
44 | br()
45 | p(txt)
46 |
47 | def save(self,file='index'):
48 | html_file = '%s/%s.html' % (self.web_dir,file)
49 | f = open(html_file, 'wt')
50 | f.write(self.doc.render())
51 | f.close()
52 |
53 |
54 | if __name__ == '__main__':
55 | html = HTML('web/', 'test_html')
56 | html.add_header('hello world')
57 |
58 | ims = []
59 | txts = []
60 | links = []
61 | for n in range(4):
62 | ims.append('image_%d.png' % n)
63 | txts.append('text_%d' % n)
64 | links.append('image_%d.png' % n)
65 | html.add_images(ims, txts, links)
66 | html.save()
67 |
--------------------------------------------------------------------------------
/store_results.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import timm
6 | import types
7 | import math
8 | import numpy as np
9 | from PIL import Image
10 | import PIL.Image as pil
11 | from torchvision import transforms, datasets
12 | from torch.utils.data import Dataset, DataLoader
13 | import torch.utils.data as data
14 | import PerceptualSimilarity.lpips.lpips as lpips
15 | import glob
16 | import gc
17 | import pytorch_msssim.pytorch_msssim as pytorch_msssim
18 | from torch.cuda.amp import autocast
19 | from model import BRViT
20 | import matplotlib.pyplot as plt
21 | from pytorch_ssim.pytorch_ssim import ssim
22 | from torchvision.utils import save_image
23 |
24 | import pandas as pd
25 | from tqdm import tqdm
26 | import sys
27 | import cv2
28 | import time
29 | import scipy.misc
30 |
31 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr
32 | from skimage.metrics import structural_similarity as compare_ssim
33 |
34 |
35 | feed_width = 1536
36 | feed_height = 1024
37 |
38 |
39 | #to store results
40 | def store_results():
41 | device = torch.device("cuda")
42 |
43 | model = BRViT().to(device)
44 |
45 | PATH = "weights/BRViT_53_0.pt"
46 | model.load_state_dict(torch.load(PATH), strict=True)
47 |
48 | with torch.no_grad():
49 | for i in tqdm(range(4400,4694)):
50 | csv_file = "Bokeh_Data/test.csv"
51 | data = pd.read_csv(csv_file)
52 | root_dir = "."
53 | idx = i - 4400
54 | bok = pil.open(root_dir + data.iloc[idx, 0][1:]).convert('RGB')
55 | input_image = pil.open(root_dir + data.iloc[idx, 1][1:]).convert('RGB')
56 | original_width, original_height = input_image.size
57 |
58 | org_image = input_image
59 | org_image = transforms.ToTensor()(org_image).unsqueeze(0)
60 |
61 | input_image = input_image.resize((feed_width, feed_height), pil.LANCZOS)
62 | input_image = transforms.ToTensor()(input_image).unsqueeze(0)
63 |
64 | # prediction
65 | org_image = org_image.to(device)
66 | input_image = input_image.to(device)
67 |
68 | bok_pred = model(input_image)
69 |
70 | bok_pred = F.interpolate(bok_pred,(original_height,original_width),mode = 'bilinear')
71 |
72 | save_image(bok_pred,'Results/'+ str(i) +'.png')
73 |
74 |
75 |
76 | if __name__ == "__main__":
77 | store_results()
78 |
79 |
80 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/data/image_folder.py:
--------------------------------------------------------------------------------
1 | ################################################################################
2 | # Code from
3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
4 | # Modified the original code so that it also loads images from the current
5 | # directory as well as the subdirectories
6 | ################################################################################
7 |
8 | import torch.utils.data as data
9 |
10 | from PIL import Image
11 | import os
12 | import os.path
13 |
14 | IMG_EXTENSIONS = [
15 | '.jpg', '.JPG', '.jpeg', '.JPEG',
16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
17 | ]
18 |
19 | NP_EXTENSIONS = ['.npy',]
20 |
21 | def is_image_file(filename, mode='img'):
22 | if(mode=='img'):
23 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
24 | elif(mode=='np'):
25 | return any(filename.endswith(extension) for extension in NP_EXTENSIONS)
26 |
27 | def make_dataset(dirs, mode='img'):
28 | if(not isinstance(dirs,list)):
29 | dirs = [dirs,]
30 |
31 | images = []
32 | for dir in dirs:
33 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
34 | for root, _, fnames in sorted(os.walk(dir)):
35 | for fname in fnames:
36 | if is_image_file(fname, mode=mode):
37 | path = os.path.join(root, fname)
38 | images.append(path)
39 |
40 | # print("Found %i images in %s"%(len(images),root))
41 | return images
42 |
43 | def default_loader(path):
44 | return Image.open(path).convert('RGB')
45 |
46 | class ImageFolder(data.Dataset):
47 | def __init__(self, root, transform=None, return_paths=False,
48 | loader=default_loader):
49 | imgs = make_dataset(root)
50 | if len(imgs) == 0:
51 | raise(RuntimeError("Found 0 images in: " + root + "\n"
52 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
53 |
54 | self.root = root
55 | self.imgs = imgs
56 | self.transform = transform
57 | self.return_paths = return_paths
58 | self.loader = loader
59 |
60 | def __getitem__(self, index):
61 | path = self.imgs[index]
62 | img = self.loader(path)
63 | if self.transform is not None:
64 | img = self.transform(img)
65 | if self.return_paths:
66 | return img, path
67 | else:
68 | return img
69 |
70 | def __len__(self):
71 | return len(self.imgs)
72 |
--------------------------------------------------------------------------------
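A short usage sketch for the helpers above (`./imgs` is simply the example directory shipped with this repo; any image folder works):

```python
# Hypothetical usage of make_dataset / ImageFolder (run from the PerceptualSimilarity root).
from torchvision import transforms
from data.image_folder import make_dataset, ImageFolder

paths = make_dataset('./imgs')   # every image found under ./imgs, including subdirectories
dataset = ImageFolder('./imgs', transform=transforms.ToTensor(), return_paths=True)
img, path = dataset[0]
print(len(paths), img.shape, path)
```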
/PerceptualSimilarity/data/dataset/twoafc_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import torchvision.transforms as transforms
3 | from data.dataset.base_dataset import BaseDataset
4 | from data.image_folder import make_dataset
5 | from PIL import Image
6 | import numpy as np
7 | import torch
8 | # from IPython import embed
9 |
10 | class TwoAFCDataset(BaseDataset):
11 | def initialize(self, dataroots, load_size=64):
12 | if(not isinstance(dataroots,list)):
13 | dataroots = [dataroots,]
14 | self.roots = dataroots
15 | self.load_size = load_size
16 |
17 | # image directory
18 | self.dir_ref = [os.path.join(root, 'ref') for root in self.roots]
19 | self.ref_paths = make_dataset(self.dir_ref)
20 | self.ref_paths = sorted(self.ref_paths)
21 |
22 | self.dir_p0 = [os.path.join(root, 'p0') for root in self.roots]
23 | self.p0_paths = make_dataset(self.dir_p0)
24 | self.p0_paths = sorted(self.p0_paths)
25 |
26 | self.dir_p1 = [os.path.join(root, 'p1') for root in self.roots]
27 | self.p1_paths = make_dataset(self.dir_p1)
28 | self.p1_paths = sorted(self.p1_paths)
29 |
30 | transform_list = []
31 | transform_list.append(transforms.Scale(load_size))
32 | transform_list += [transforms.ToTensor(),
33 | transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]
34 |
35 | self.transform = transforms.Compose(transform_list)
36 |
37 | # judgement directory
38 | self.dir_J = [os.path.join(root, 'judge') for root in self.roots]
39 | self.judge_paths = make_dataset(self.dir_J,mode='np')
40 | self.judge_paths = sorted(self.judge_paths)
41 |
42 | def __getitem__(self, index):
43 | p0_path = self.p0_paths[index]
44 | p0_img_ = Image.open(p0_path).convert('RGB')
45 | p0_img = self.transform(p0_img_)
46 |
47 | p1_path = self.p1_paths[index]
48 | p1_img_ = Image.open(p1_path).convert('RGB')
49 | p1_img = self.transform(p1_img_)
50 |
51 | ref_path = self.ref_paths[index]
52 | ref_img_ = Image.open(ref_path).convert('RGB')
53 | ref_img = self.transform(ref_img_)
54 |
55 | judge_path = self.judge_paths[index]
56 | # judge_img = (np.load(judge_path)*2.-1.).reshape((1,1,1,)) # [-1,1]
57 | judge_img = np.load(judge_path).reshape((1,1,1,)) # [0,1]
58 |
59 | judge_img = torch.FloatTensor(judge_img)
60 |
61 | return {'p0': p0_img, 'p1': p1_img, 'ref': ref_img, 'judge': judge_img,
62 | 'p0_path': p0_path, 'p1_path': p1_path, 'ref_path': ref_path, 'judge_path': judge_path}
63 |
64 | def __len__(self):
65 | return len(self.p0_paths)
66 |
--------------------------------------------------------------------------------
/pytorch_msssim/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-msssim
2 |
3 | ### Differentiable Multi-Scale Structural Similarity (MS-SSIM) index
4 |
5 | This small utility provides a differentiable MS-SSIM implementation for PyTorch, based on Po-Hsun Su's implementation of SSIM at https://github.com/Po-Hsun-Su/pytorch-ssim.
6 | At the moment, only the product method for MS-SSIM is supported.
7 |
8 | ## Installation
9 |
10 | Master branch now only supports PyTorch 0.4 or higher. All development occurs in the dev branch (`git checkout dev` after cloning the repository to get the latest development version).
11 |
12 | To install the current version of pytorch_msssim:
13 |
14 | 1. Clone this repo.
15 | 2. Go to the repo directory.
16 | 3. Run `python setup.py install`
17 |
18 | or
19 |
20 | 1. Clone this repo.
21 | 2. Copy the `pytorch_msssim` folder into your project.
22 |
23 | To install a version of pytorch_msssim that runs on PyTorch 0.3.1 or lower, use the tag checkpoint-0.3. To do so, run the following commands after cloning the repository:
24 |
25 | ```
26 | git fetch --all --tags
27 | git checkout tags/checkpoint-0.3
28 | ```
29 |
30 | ## Example
31 |
32 | ### Basic usage
33 | ```python
34 | import pytorch_msssim
35 | import torch
36 |
37 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
38 | m = pytorch_msssim.MSSSIM().to(device)
39 |
40 | img1 = torch.rand(1, 1, 256, 256, device=device)
41 | img2 = torch.rand(1, 1, 256, 256, device=device)
42 |
43 | print(pytorch_msssim.msssim(img1, img2))
44 | print(m(img1, img2))
45 |
46 |
47 | ```
48 |
49 | ### Training
50 |
51 | For a detailed example of how to use MS-SSIM for optimization, take a look at the file max_ssim.py; a minimal sketch is also shown below.
52 |
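A minimal optimization sketch (assuming only the functional `msssim` shown in Basic usage; max_ssim.py in this repo contains the full, annotated version). A randomly initialized image is pushed towards a target by minimizing 1 - MS-SSIM:

```python
import torch
import pytorch_msssim

target = torch.rand(1, 1, 128, 128)                    # image to match
pred = torch.rand(1, 1, 128, 128, requires_grad=True)  # leaf tensor optimized directly
optimizer = torch.optim.Adam([pred], lr=0.01)

for step in range(200):
    optimizer.zero_grad()
    # "relu" normalization keeps early, possibly-negative levels from producing NaNs
    loss = 1 - pytorch_msssim.msssim(pred, target, normalize="relu")
    loss.backward()
    optimizer.step()
```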
53 |
54 | ### Stability and normalization
55 |
56 | MS-SSIM is a particularly unstable metric when used with some architectures and may produce NaN values early in training. The msssim function provides a normalize argument to help in these cases. There are three possible values; we recommend normalize="relu" when training.
57 |
58 | - None : no normalization is applied; use this for evaluation
59 | - "relu" : the `ssim` and `cs` values of each level are rectified with a ReLU during the calculation, ensuring that negative values are zeroed
60 | - "simple" : the `ssim` result of each level is averaged with 1, giving an expected lower bound of 0.5 - use this ONLY for the initial iterations of training, or while the normalized score averages below 0.6
61 |
62 | Currently, for backward compatibility, a value of True is equivalent to "simple" normalization.
63 |
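For illustration, the three settings map onto the following calls on the functional API (a sketch; the tensor contents are arbitrary):

```python
import torch
import pytorch_msssim

img1 = torch.rand(1, 3, 256, 256)
img2 = torch.rand(1, 3, 256, 256)

score_eval   = pytorch_msssim.msssim(img1, img2)                      # None: use for evaluation
score_train  = pytorch_msssim.msssim(img1, img2, normalize="relu")    # recommended during training
score_warmup = pytorch_msssim.msssim(img1, img2, normalize="simple")  # first iterations only
```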
64 | ## Reference
65 | https://ece.uwaterloo.ca/~z70wang/research/ssim/
66 |
67 | https://github.com/Po-Hsun-Su/pytorch-ssim
68 |
69 | Thanks to z70wang for proposing MS-SSIM and providing the initial implementation, and to Po-Hsun-Su for the initial differentiable SSIM implementation for PyTorch.
70 |
--------------------------------------------------------------------------------
/pytorch_ssim/pytorch_ssim/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 |
7 | def gaussian(window_size, sigma):
8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 | return gauss/gauss.sum()
10 |
11 | def create_window(window_size, channel):
12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
15 | return window
16 |
17 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
20 |
21 | mu1_sq = mu1.pow(2)
22 | mu2_sq = mu2.pow(2)
23 | mu1_mu2 = mu1*mu2
24 |
25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
28 |
29 | C1 = 0.01**2
30 | C2 = 0.03**2
31 |
32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
33 |
34 | if size_average:
35 | return ssim_map.mean()
36 | else:
37 | return ssim_map.mean(1).mean(1).mean(1)
38 |
39 | class SSIM(torch.nn.Module):
40 | def __init__(self, window_size = 11, size_average = True):
41 | super(SSIM, self).__init__()
42 | self.window_size = window_size
43 | self.size_average = size_average
44 | self.channel = 1
45 | self.window = create_window(window_size, self.channel)
46 |
47 | def forward(self, img1, img2):
48 | (_, channel, _, _) = img1.size()
49 |
50 | if channel == self.channel and self.window.data.type() == img1.data.type():
51 | window = self.window
52 | else:
53 | window = create_window(self.window_size, channel)
54 |
55 | if img1.is_cuda:
56 | window = window.cuda(img1.get_device())
57 | window = window.type_as(img1)
58 |
59 | self.window = window
60 | self.channel = channel
61 |
62 |
63 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
64 |
65 | def ssim(img1, img2, window_size = 11, size_average = True):
66 | (_, channel, _, _) = img1.size()
67 | window = create_window(window_size, channel)
68 |
69 | if img1.is_cuda:
70 | window = window.cuda(img1.get_device())
71 | window = window.type_as(img1)
72 |
73 | return _ssim(img1, img2, window, window_size, channel, size_average)
74 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/test_dataset_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import lpips
3 | from data import data_loader as dl
4 | import argparse
5 | from IPython import embed
6 |
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--dataset_mode', type=str, default='2afc', help='[2afc,jnd]')
9 | parser.add_argument('--datasets', type=str, nargs='+', default=['val/traditional','val/cnn','val/superres','val/deblur','val/color','val/frameinterp'], help='datasets to test - for jnd mode: [val/traditional],[val/cnn]; for 2afc mode: [train/traditional],[train/cnn],[train/mix],[val/traditional],[val/cnn],[val/color],[val/deblur],[val/frameinterp],[val/superres]')
10 | parser.add_argument('--model', type=str, default='lpips', help='distance model type [lpips] for linearly calibrated net, [baseline] for off-the-shelf network, [l2] for euclidean distance, [ssim] for Structured Similarity Image Metric')
11 | parser.add_argument('--net', type=str, default='alex', help='[squeeze], [alex], or [vgg] for network architectures')
12 | parser.add_argument('--colorspace', type=str, default='Lab', help='[Lab] or [RGB] for colorspace to use for l2, ssim model types')
13 | parser.add_argument('--batch_size', type=int, default=50, help='batch size to test image patches in')
14 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU')
15 | parser.add_argument('--gpu_ids', type=int, nargs='+', default=[0], help='gpus to use')
16 | parser.add_argument('--nThreads', type=int, default=4, help='number of threads to use in data loader')
17 |
18 | parser.add_argument('--model_path', type=str, default=None, help='location of model, will default to ./weights/v[version]/[net_name].pth')
19 |
20 | parser.add_argument('--from_scratch', action='store_true', help='model was initialized from scratch')
21 | parser.add_argument('--train_trunk', action='store_true', help='model trunk was trained/tuned')
22 | parser.add_argument('--version', type=str, default='0.1', help='v0.1 is latest, v0.0 was original release')
23 |
24 | opt = parser.parse_args()
25 | if(opt.model in ['l2','ssim']):
26 | opt.batch_size = 1
27 |
28 | # initialize model
29 | trainer = lpips.Trainer()
30 | # trainer.initialize(model=opt.model,net=opt.net,colorspace=opt.colorspace,model_path=opt.model_path,use_gpu=opt.use_gpu)
31 | trainer.initialize(model=opt.model, net=opt.net, colorspace=opt.colorspace,
32 | model_path=opt.model_path, use_gpu=opt.use_gpu, pnet_rand=opt.from_scratch, pnet_tune=opt.train_trunk,
33 | version=opt.version, gpu_ids=opt.gpu_ids)
34 |
35 | if(opt.model in ['lpips','baseline']):
36 | print('Testing model [%s]-[%s]'%(opt.model,opt.net))
37 | elif(opt.model in ['l2','ssim']):
38 | print('Testing model [%s]-[%s]'%(opt.model,opt.colorspace))
39 |
40 | # initialize data loader
41 | for dataset in opt.datasets:
42 | data_loader = dl.CreateDataLoader(dataset,dataset_mode=opt.dataset_mode, batch_size=opt.batch_size, nThreads=opt.nThreads)
43 |
44 | # evaluate model on data
45 | if(opt.dataset_mode=='2afc'):
46 | (score, results_verbose) = lpips.score_2afc_dataset(data_loader, trainer.forward, name=dataset)
47 | elif(opt.dataset_mode=='jnd'):
48 | (score, results_verbose) = lpips.score_jnd_dataset(data_loader, trainer.forward, name=dataset)
49 |
50 | # print results
51 | print(' Dataset [%s]: %.2f'%(dataset,100.*score))
52 |
53 |
--------------------------------------------------------------------------------
/pytorch_msssim/pytorch_msssim/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from math import exp
4 | import numpy as np
5 |
6 |
7 | def gaussian(window_size, sigma):
8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 | return gauss/gauss.sum()
10 |
11 |
12 | def create_window(window_size, channel=1):
13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
15 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
16 | return window
17 |
18 |
19 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
20 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
21 | if val_range is None:
22 | if torch.max(img1) > 128:
23 | max_val = 255
24 | else:
25 | max_val = 1
26 |
27 | if torch.min(img1) < -0.5:
28 | min_val = -1
29 | else:
30 | min_val = 0
31 | L = max_val - min_val
32 | else:
33 | L = val_range
34 |
35 | padd = 0
36 | (_, channel, height, width) = img1.size()
37 | if window is None:
38 | real_size = min(window_size, height, width)
39 | window = create_window(real_size, channel=channel).to(img1.device)
40 |
41 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
42 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
43 |
44 | mu1_sq = mu1.pow(2)
45 | mu2_sq = mu2.pow(2)
46 | mu1_mu2 = mu1 * mu2
47 |
48 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
49 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
50 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
51 |
52 | C1 = (0.01 * L) ** 2
53 | C2 = (0.03 * L) ** 2
54 |
55 | v1 = 2.0 * sigma12 + C2
56 | v2 = sigma1_sq + sigma2_sq + C2
57 | cs = v1 / v2 # contrast sensitivity
58 |
59 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
60 |
61 | if size_average:
62 | cs = cs.mean()
63 | ret = ssim_map.mean()
64 | else:
65 | cs = cs.mean(1).mean(1).mean(1)
66 | ret = ssim_map.mean(1).mean(1).mean(1)
67 |
68 | if full:
69 | return ret, cs
70 | return ret
71 |
72 |
73 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=None):
74 | device = img1.device
75 | weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
76 | levels = weights.size()[0]
77 | ssims = []
78 | mcs = []
79 | for _ in range(levels):
80 | sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
81 |
82 | # Relu normalize (not compliant with original definition)
83 | if normalize == "relu":
84 | ssims.append(torch.relu(sim))
85 | mcs.append(torch.relu(cs))
86 | else:
87 | ssims.append(sim)
88 | mcs.append(cs)
89 |
90 | img1 = F.avg_pool2d(img1, (2, 2))
91 | img2 = F.avg_pool2d(img2, (2, 2))
92 |
93 | ssims = torch.stack(ssims)
94 | mcs = torch.stack(mcs)
95 |
96 | # Simple normalize (not compliant with original definition)
97 | # TODO: remove support for normalize == True (kept for backward support)
98 | if normalize == "simple" or normalize == True:
99 | ssims = (ssims + 1) / 2
100 | mcs = (mcs + 1) / 2
101 |
102 | pow1 = mcs ** weights
103 | pow2 = ssims ** weights
104 |
105 | # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
106 | output = torch.prod(pow1[:-1]) * pow2[-1]
107 | ## output = torch.prod(pow1[:-1] * pow2[-1])
108 | return output
109 |
110 |
111 | # Classes to re-use window
112 | class SSIM(torch.nn.Module):
113 | def __init__(self, window_size=11, size_average=True, val_range=None):
114 | super(SSIM, self).__init__()
115 | self.window_size = window_size
116 | self.size_average = size_average
117 | self.val_range = val_range
118 |
119 | # Assume 1 channel for SSIM
120 | self.channel = 1
121 | self.window = create_window(window_size)
122 |
123 | def forward(self, img1, img2):
124 | (_, channel, _, _) = img1.size()
125 |
126 | if channel == self.channel and self.window.dtype == img1.dtype:
127 | window = self.window
128 | else:
129 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
130 | self.window = window
131 | self.channel = channel
132 |
133 | return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
134 |
135 | class MSSSIM(torch.nn.Module):
136 | def __init__(self, window_size=11, size_average=True, channel=3):
137 | super(MSSSIM, self).__init__()
138 | self.window_size = window_size
139 | self.size_average = size_average
140 | self.channel = channel
141 |
142 | def forward(self, img1, img2):
143 | # TODO: store window between calls if possible
144 | return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
145 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | import cv2
3 | import numbers
4 | import math
5 | import os
6 | import sys
7 | import glob
8 | import argparse
9 | import numpy as np
10 | import PIL.Image as pil
11 | import matplotlib as mpl
12 | import matplotlib.cm as cm
13 | import torch
14 | from torchvision import transforms, datasets
15 | from torch.utils.data import Dataset, DataLoader
16 | import pandas as pd
17 | import torch.nn as nn
18 | import torch.optim as optim
19 | import torch.nn.functional as F
20 | import math
21 | from pytorch_msssim.pytorch_msssim import msssim
22 | import torchvision
23 | from torch.autograd import Variable
24 | from pytorch_ssim.pytorch_ssim import ssim
25 | import PerceptualSimilarity.lpips.lpips as lpips
26 |
27 | from tqdm import tqdm
28 |
29 | device = torch.device("cuda:0")
30 |
31 | from model import BRViT
32 |
33 |
34 | feed_width = 768
35 | feed_height = 512
36 |
37 |
38 | bokehnet = BRViT().to(device)
39 | batch_size = 1
40 |
41 |
42 | PATH = "BRViT_53_0.pt"
43 | bokehnet.load_state_dict(torch.load(PATH), strict=False)
44 |
45 | print("weights loaded!!")
46 |
47 |
48 | class bokehDataset(Dataset):
49 |
50 | def __init__(self, csv_file,root_dir, transform=None):
51 |
52 | self.data = pd.read_csv(csv_file)
53 | self.transform = transform
54 | self.root_dir = root_dir
55 |
56 | def __len__(self):
57 | return len(self.data)
58 |
59 | def __getitem__(self, idx):
60 |
61 | bok = pil.open(self.root_dir + self.data.iloc[idx, 0][1:]).convert('RGB')
62 | org = pil.open(self.root_dir + self.data.iloc[idx, 1][1:]).convert('RGB')
63 |
64 | bok = bok.resize((feed_width, feed_height), pil.LANCZOS)
65 | org = org.resize((feed_width, feed_height), pil.LANCZOS)
66 | if self.transform :
67 | bok_dep = self.transform(bok)
68 | org_dep = self.transform(org)
69 | return (bok_dep, org_dep)
70 |
71 | transform1 = transforms.Compose(
72 | [
73 | transforms.ToTensor(),
74 | ])
75 |
76 |
77 | transform2 = transforms.Compose(
78 | [
79 | transforms.RandomHorizontalFlip(p=1),
80 | transforms.ToTensor(),
81 | ])
82 |
83 |
84 | transform3 = transforms.Compose(
85 | [
86 | transforms.RandomVerticalFlip(p=1),
87 | transforms.ToTensor(),
88 | ])
89 |
90 |
91 |
92 | trainset1 = bokehDataset(csv_file = './Bokeh_Data/train.csv', root_dir = '.',transform = transform1)
93 | trainset2 = bokehDataset(csv_file = './Bokeh_Data/train.csv', root_dir = '.',transform = transform2)
94 | trainset3 = bokehDataset(csv_file = './Bokeh_Data/train.csv', root_dir = '.',transform = transform3)
95 |
96 |
97 | trainloader = torch.utils.data.DataLoader(torch.utils.data.ConcatDataset([trainset1,trainset2,trainset3]), batch_size=batch_size,
98 | shuffle=True, num_workers=0)
99 |
100 | testset = bokehDataset(csv_file = './Bokeh_Data/test.csv', root_dir = '.', transform = transform1)
101 | testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
102 | shuffle=False, num_workers=0)
103 |
104 |
105 | learning_rate = 0.00001
106 |
107 | optimizer = optim.Adam(bokehnet.parameters(), lr=learning_rate, betas=(0.9, 0.999))
108 |
109 |
110 | sm = nn.Softmax(dim=1)
111 |
112 | MSE_LossFn = nn.MSELoss()
113 | L1_LossFn = nn.L1Loss()
114 |
115 |
116 | def train(dataloader):
117 | running_l1_loss = 0
118 | running_ms_loss = 0
119 | running_sal_loss = 0
120 | running_loss = 0
121 |
122 | for i,data in enumerate(dataloader,0) :
123 | bok , org = data
124 | bok , org = bok.to(device) , org.to(device)
125 |
126 | optimizer.zero_grad()
127 |
128 | bok_pred = bokehnet(org)
129 |
130 | loss = (1-msssim(bok_pred, bok))
131 |
132 | ## loss = L1_LossFn(bok_pred, bok)
133 |
134 | running_loss += loss.item()
135 |
136 | loss.backward()
137 | optimizer.step()
138 | if (i % 10 == 0):
139 | print ('Batch: ',i,'/',len(dataloader),' Loss:', loss.item())
140 |
141 |
142 | if ((i+1)%8000==0):
143 | torch.save(bokehnet.state_dict(), './weights/BRViT_' + str(epoch) + '_' + str(i) + '.pt')
144 | print(loss.item())
145 |
146 | print (running_loss/len(dataloader))
147 |
148 |
149 |
150 |
151 | def val(dataloader):
152 | running_l1_loss = 0
153 | running_ms_loss = 0
154 | running_lips_loss = 0
155 |
156 | with torch.no_grad():
157 | for i,data in enumerate(tqdm(dataloader),0) :
158 | bok , org = data
159 | bok , org = bok.to(device) , org.to(device)
160 |
161 | bok_pred = bokehnet(org)
162 |
163 | try:
164 | l1_loss = L1_LossFn(bok_pred, bok)
165 | except:
166 | l1_loss = L1_LossFn(bok_pred[0], bok)
167 |
168 | ms_loss = 1-ssim(bok_pred, bok)
169 |
170 |
171 | running_l1_loss += l1_loss.item()
172 | running_ms_loss += ms_loss.item()
173 |
174 |
175 | print ('Validation l1 Loss: ',running_l1_loss/len(dataloader))
176 | print ('Validation ms Loss: ',running_ms_loss/len(dataloader))
177 |
178 | ## torch.save(bokehnet.state_dict(), './weights/BRViT_'+str(epoch)+'.pt')
179 |
180 | try:
181 | with open("log.txt", 'a') as f:
182 | f.write(f"{running_ms_loss/len(dataloader)}\n")
183 | except:
184 | pass
185 |
186 |
187 |
188 | start_ep = 0
189 | for epoch in range(start_ep, 40):
190 | print (epoch)
191 |
192 | train(trainloader)
193 |
194 | with torch.no_grad():
195 | val(testloader)
196 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/train.py:
--------------------------------------------------------------------------------
1 | import torch.backends.cudnn as cudnn
2 | cudnn.benchmark=False
3 |
4 | import numpy as np
5 | import time
6 | import os
7 | import lpips
8 | from data import data_loader as dl
9 | import argparse
10 | from util.visualizer import Visualizer
11 | from IPython import embed
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--datasets', type=str, nargs='+', default=['train/traditional','train/cnn','train/mix'], help='datasets to train on: [train/traditional],[train/cnn],[train/mix],[val/traditional],[val/cnn],[val/color],[val/deblur],[val/frameinterp],[val/superres]')
15 | parser.add_argument('--model', type=str, default='lpips', help='distance model type [lpips] for linearly calibrated net, [baseline] for off-the-shelf network, [l2] for euclidean distance, [ssim] for Structured Similarity Image Metric')
16 | parser.add_argument('--net', type=str, default='alex', help='[squeeze], [alex], or [vgg] for network architectures')
17 | parser.add_argument('--batch_size', type=int, default=50, help='batch size to test image patches in')
18 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU')
19 | parser.add_argument('--gpu_ids', type=int, nargs='+', default=[0], help='gpus to use')
20 |
21 | parser.add_argument('--nThreads', type=int, default=4, help='number of threads to use in data loader')
22 | parser.add_argument('--nepoch', type=int, default=5, help='# epochs at base learning rate')
23 | parser.add_argument('--nepoch_decay', type=int, default=5, help='# additional epochs with linearly decaying learning rate')
24 | parser.add_argument('--display_freq', type=int, default=5000, help='frequency (in instances) of showing training results on screen')
25 | parser.add_argument('--print_freq', type=int, default=5000, help='frequency (in instances) of showing training results on console')
26 | parser.add_argument('--save_latest_freq', type=int, default=20000, help='frequency (in instances) of saving the latest results')
27 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
28 | parser.add_argument('--display_id', type=int, default=0, help='window id of the visdom display, [0] for no displaying')
29 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
30 | parser.add_argument('--display_port', type=int, default=8001, help='visdom display port')
31 | parser.add_argument('--use_html', action='store_true', help='save off html pages')
32 | parser.add_argument('--checkpoints_dir', type=str, default='checkpoints', help='checkpoints directory')
33 | parser.add_argument('--name', type=str, default='tmp', help='directory name for training')
34 |
35 | parser.add_argument('--from_scratch', action='store_true', help='model was initialized from scratch')
36 | parser.add_argument('--train_trunk', action='store_true', help='model trunk was trained/tuned')
37 | parser.add_argument('--train_plot', action='store_true', help='plot saving')
38 |
39 | opt = parser.parse_args()
40 | opt.save_dir = os.path.join(opt.checkpoints_dir,opt.name)
41 | if(not os.path.exists(opt.save_dir)):
42 | os.mkdir(opt.save_dir)
43 |
44 | # initialize model
45 | trainer = lpips.Trainer()
46 | trainer.initialize(model=opt.model, net=opt.net, use_gpu=opt.use_gpu, is_train=True,
47 | pnet_rand=opt.from_scratch, pnet_tune=opt.train_trunk, gpu_ids=opt.gpu_ids)
48 |
49 | # load data from all training sets
50 | data_loader = dl.CreateDataLoader(opt.datasets,dataset_mode='2afc', batch_size=opt.batch_size, serial_batches=False, nThreads=opt.nThreads)
51 | dataset = data_loader.load_data()
52 | dataset_size = len(data_loader)
53 | D = len(dataset)
54 | print('Loading %i instances from'%dataset_size,opt.datasets)
55 | visualizer = Visualizer(opt)
56 |
57 | total_steps = 0
58 | fid = open(os.path.join(opt.checkpoints_dir,opt.name,'train_log.txt'),'w+')
59 | for epoch in range(1, opt.nepoch + opt.nepoch_decay + 1):
60 | epoch_start_time = time.time()
61 | for i, data in enumerate(dataset):
62 | iter_start_time = time.time()
63 | total_steps += opt.batch_size
64 | epoch_iter = total_steps - dataset_size * (epoch - 1)
65 |
66 | trainer.set_input(data)
67 | trainer.optimize_parameters()
68 |
69 | if total_steps % opt.display_freq == 0:
70 | visualizer.display_current_results(trainer.get_current_visuals(), epoch)
71 |
72 | if total_steps % opt.print_freq == 0:
73 | errors = trainer.get_current_errors()
74 | t = (time.time()-iter_start_time)/opt.batch_size
75 | t2o = (time.time()-epoch_start_time)/3600.
76 | t2 = t2o*D/(i+.0001)
77 | visualizer.print_current_errors(epoch, epoch_iter, errors, t, t2=t2, t2o=t2o, fid=fid)
78 |
79 | for key in errors.keys():
80 | visualizer.plot_current_errors_save(epoch, float(epoch_iter)/dataset_size, opt, errors, keys=[key,], name=key, to_plot=opt.train_plot)
81 |
82 | if opt.display_id > 0:
83 | visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)
84 |
85 | if total_steps % opt.save_latest_freq == 0:
86 | print('saving the latest model (epoch %d, total_steps %d)' %
87 | (epoch, total_steps))
88 | trainer.save(opt.save_dir, 'latest')
89 |
90 | if epoch % opt.save_epoch_freq == 0:
91 | print('saving the model at the end of epoch %d, iters %d' %
92 | (epoch, total_steps))
93 | trainer.save(opt.save_dir, 'latest')
94 | trainer.save(opt.save_dir, epoch)
95 |
96 | print('End of epoch %d / %d \t Time Taken: %d sec' %
97 | (epoch, opt.nepoch + opt.nepoch_decay, time.time() - epoch_start_time))
98 |
99 | if epoch > opt.nepoch:
100 | trainer.update_learning_rate(opt.nepoch_decay)
101 |
102 | # trainer.save_done(True)
103 | fid.close()
104 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/lpips/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import numpy as np
7 | import torch
8 | # from torch.autograd import Variable
9 |
10 | from PerceptualSimilarity.lpips.trainer import *
11 | from PerceptualSimilarity.lpips.lpips import *
12 |
13 | # class PerceptualLoss(torch.nn.Module):
14 | # def __init__(self, model='lpips', net='alex', spatial=False, use_gpu=False, gpu_ids=[0], version='0.1'): # VGG using our perceptually-learned weights (LPIPS metric)
15 | # # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
16 | # super(PerceptualLoss, self).__init__()
17 | # print('Setting up Perceptual loss...')
18 | # self.use_gpu = use_gpu
19 | # self.spatial = spatial
20 | # self.gpu_ids = gpu_ids
21 | # self.model = dist_model.DistModel()
22 | # self.model.initialize(model=model, net=net, use_gpu=use_gpu, spatial=self.spatial, gpu_ids=gpu_ids, version=version)
23 | # print('...[%s] initialized'%self.model.name())
24 | # print('...Done')
25 |
26 | # def forward(self, pred, target, normalize=False):
27 | # """
28 | # Pred and target are Variables.
29 | # If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
30 | # If normalize is False, assumes the images are already between [-1,+1]
31 |
32 | # Inputs pred and target are Nx3xHxW
33 | # Output pytorch Variable N long
34 | # """
35 |
36 | # if normalize:
37 | # target = 2 * target - 1
38 | # pred = 2 * pred - 1
39 |
40 | # return self.model.forward(target, pred)
41 |
42 | def normalize_tensor(in_feat,eps=1e-10):
43 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
44 | return in_feat/(norm_factor+eps)
45 |
46 | def l2(p0, p1, range=255.):
47 | return .5*np.mean((p0 / range - p1 / range)**2)
48 |
49 | def psnr(p0, p1, peak=255.):
50 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))
51 |
52 | def dssim(p0, p1, range=255.):
53 | from skimage.measure import compare_ssim
54 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
55 |
56 | def rgb2lab(in_img,mean_cent=False):
57 | from skimage import color
58 | img_lab = color.rgb2lab(in_img)
59 | if(mean_cent):
60 | img_lab[:,:,0] = img_lab[:,:,0]-50
61 | return img_lab
62 |
63 | def tensor2np(tensor_obj):
64 | # change dimension of a tensor object into a numpy array
65 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
66 |
67 | def np2tensor(np_obj):
68 |     # change dimension of a numpy array into a tensor array
69 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
70 |
71 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
72 | # image tensor to lab tensor
73 | from skimage import color
74 |
75 | img = tensor2im(image_tensor)
76 | img_lab = color.rgb2lab(img)
77 | if(mc_only):
78 | img_lab[:,:,0] = img_lab[:,:,0]-50
79 | if(to_norm and not mc_only):
80 | img_lab[:,:,0] = img_lab[:,:,0]-50
81 | img_lab = img_lab/100.
82 |
83 | return np2tensor(img_lab)
84 |
85 | def tensorlab2tensor(lab_tensor,return_inbnd=False):
86 | from skimage import color
87 | import warnings
88 | warnings.filterwarnings("ignore")
89 |
90 | lab = tensor2np(lab_tensor)*100.
91 | lab[:,:,0] = lab[:,:,0]+50
92 |
93 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
94 | if(return_inbnd):
95 | # convert back to lab, see if we match
96 | lab_back = color.rgb2lab(rgb_back.astype('uint8'))
97 | mask = 1.*np.isclose(lab_back,lab,atol=2.)
98 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
99 | return (im2tensor(rgb_back),mask)
100 | else:
101 | return im2tensor(rgb_back)
102 |
103 | def load_image(path):
104 |     if(path[-3:] == 'dng'):
105 |         import rawpy
106 |         with rawpy.imread(path) as raw:
107 |             img = raw.postprocess()
108 |     elif(path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png' or path[-4:]=='jpeg'):
109 |         import cv2
110 |         return cv2.imread(path)[:,:,::-1]
111 |     else:
112 |         import matplotlib.pyplot as plt  # needed only for this generic fallback
113 |         img = (255*plt.imread(path)[:,:,:3]).astype('uint8')
114 |     return img
115 |
116 | def rgb2lab(input):
117 | from skimage import color
118 | return color.rgb2lab(input / 255.)
119 |
120 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
121 | image_numpy = image_tensor[0].cpu().float().numpy()
122 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
123 | return image_numpy.astype(imtype)
124 |
125 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
126 | return torch.Tensor((image / factor - cent)
127 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
128 |
129 | def tensor2vec(vector_tensor):
130 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
131 |
132 |
133 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
134 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
135 | image_numpy = image_tensor[0].cpu().float().numpy()
136 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
137 | return image_numpy.astype(imtype)
138 |
139 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
140 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
141 | return torch.Tensor((image / factor - cent)
142 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
143 |
144 |
145 |
146 | def voc_ap(rec, prec, use_07_metric=False):
147 | """ ap = voc_ap(rec, prec, [use_07_metric])
148 | Compute VOC AP given precision and recall.
149 | If use_07_metric is true, uses the
150 | VOC 07 11 point method (default:False).
151 | """
152 | if use_07_metric:
153 | # 11 point metric
154 | ap = 0.
155 | for t in np.arange(0., 1.1, 0.1):
156 | if np.sum(rec >= t) == 0:
157 | p = 0
158 | else:
159 | p = np.max(prec[rec >= t])
160 | ap = ap + p / 11.
161 | else:
162 | # correct AP calculation
163 | # first append sentinel values at the end
164 | mrec = np.concatenate(([0.], rec, [1.]))
165 | mpre = np.concatenate(([0.], prec, [0.]))
166 |
167 | # compute the precision envelope
168 | for i in range(mpre.size - 1, 0, -1):
169 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
170 |
171 | # to calculate area under PR curve, look for points
172 | # where X axis (recall) changes value
173 | i = np.where(mrec[1:] != mrec[:-1])[0]
174 |
175 | # and sum (\Delta recall) * prec
176 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
177 | return ap
178 |
179 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/lpips/pretrained_networks.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import torch
3 | from torchvision import models as tv
4 |
5 | class squeezenet(torch.nn.Module):
6 | def __init__(self, requires_grad=False, pretrained=True):
7 | super(squeezenet, self).__init__()
8 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
9 | self.slice1 = torch.nn.Sequential()
10 | self.slice2 = torch.nn.Sequential()
11 | self.slice3 = torch.nn.Sequential()
12 | self.slice4 = torch.nn.Sequential()
13 | self.slice5 = torch.nn.Sequential()
14 | self.slice6 = torch.nn.Sequential()
15 | self.slice7 = torch.nn.Sequential()
16 | self.N_slices = 7
17 | for x in range(2):
18 | self.slice1.add_module(str(x), pretrained_features[x])
19 | for x in range(2,5):
20 | self.slice2.add_module(str(x), pretrained_features[x])
21 | for x in range(5, 8):
22 | self.slice3.add_module(str(x), pretrained_features[x])
23 | for x in range(8, 10):
24 | self.slice4.add_module(str(x), pretrained_features[x])
25 | for x in range(10, 11):
26 | self.slice5.add_module(str(x), pretrained_features[x])
27 | for x in range(11, 12):
28 | self.slice6.add_module(str(x), pretrained_features[x])
29 | for x in range(12, 13):
30 | self.slice7.add_module(str(x), pretrained_features[x])
31 | if not requires_grad:
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 | def forward(self, X):
36 | h = self.slice1(X)
37 | h_relu1 = h
38 | h = self.slice2(h)
39 | h_relu2 = h
40 | h = self.slice3(h)
41 | h_relu3 = h
42 | h = self.slice4(h)
43 | h_relu4 = h
44 | h = self.slice5(h)
45 | h_relu5 = h
46 | h = self.slice6(h)
47 | h_relu6 = h
48 | h = self.slice7(h)
49 | h_relu7 = h
50 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
51 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
52 |
53 | return out
54 |
55 |
56 | class alexnet(torch.nn.Module):
57 | def __init__(self, requires_grad=False, pretrained=True):
58 | super(alexnet, self).__init__()
59 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
60 | self.slice1 = torch.nn.Sequential()
61 | self.slice2 = torch.nn.Sequential()
62 | self.slice3 = torch.nn.Sequential()
63 | self.slice4 = torch.nn.Sequential()
64 | self.slice5 = torch.nn.Sequential()
65 | self.N_slices = 5
66 | for x in range(2):
67 | self.slice1.add_module(str(x), alexnet_pretrained_features[x])
68 | for x in range(2, 5):
69 | self.slice2.add_module(str(x), alexnet_pretrained_features[x])
70 | for x in range(5, 8):
71 | self.slice3.add_module(str(x), alexnet_pretrained_features[x])
72 | for x in range(8, 10):
73 | self.slice4.add_module(str(x), alexnet_pretrained_features[x])
74 | for x in range(10, 12):
75 | self.slice5.add_module(str(x), alexnet_pretrained_features[x])
76 | if not requires_grad:
77 | for param in self.parameters():
78 | param.requires_grad = False
79 |
80 | def forward(self, X):
81 | h = self.slice1(X)
82 | h_relu1 = h
83 | h = self.slice2(h)
84 | h_relu2 = h
85 | h = self.slice3(h)
86 | h_relu3 = h
87 | h = self.slice4(h)
88 | h_relu4 = h
89 | h = self.slice5(h)
90 | h_relu5 = h
91 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
92 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
93 |
94 | return out
95 |
96 | class vgg16(torch.nn.Module):
97 | def __init__(self, requires_grad=False, pretrained=True):
98 | super(vgg16, self).__init__()
99 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
100 | self.slice1 = torch.nn.Sequential()
101 | self.slice2 = torch.nn.Sequential()
102 | self.slice3 = torch.nn.Sequential()
103 | self.slice4 = torch.nn.Sequential()
104 | self.slice5 = torch.nn.Sequential()
105 | self.N_slices = 5
106 | for x in range(4):
107 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
108 | for x in range(4, 9):
109 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
110 | for x in range(9, 16):
111 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
112 | for x in range(16, 23):
113 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
114 | for x in range(23, 30):
115 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
116 | if not requires_grad:
117 | for param in self.parameters():
118 | param.requires_grad = False
119 |
120 | def forward(self, X):
121 | h = self.slice1(X)
122 | h_relu1_2 = h
123 | h = self.slice2(h)
124 | h_relu2_2 = h
125 | h = self.slice3(h)
126 | h_relu3_3 = h
127 | h = self.slice4(h)
128 | h_relu4_3 = h
129 | h = self.slice5(h)
130 | h_relu5_3 = h
131 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
132 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
133 |
134 | return out
135 |
136 |
137 |
138 | class resnet(torch.nn.Module):
139 | def __init__(self, requires_grad=False, pretrained=True, num=18):
140 | super(resnet, self).__init__()
141 | if(num==18):
142 | self.net = tv.resnet18(pretrained=pretrained)
143 | elif(num==34):
144 | self.net = tv.resnet34(pretrained=pretrained)
145 | elif(num==50):
146 | self.net = tv.resnet50(pretrained=pretrained)
147 | elif(num==101):
148 | self.net = tv.resnet101(pretrained=pretrained)
149 | elif(num==152):
150 | self.net = tv.resnet152(pretrained=pretrained)
151 | self.N_slices = 5
152 |
153 | self.conv1 = self.net.conv1
154 | self.bn1 = self.net.bn1
155 | self.relu = self.net.relu
156 | self.maxpool = self.net.maxpool
157 | self.layer1 = self.net.layer1
158 | self.layer2 = self.net.layer2
159 | self.layer3 = self.net.layer3
160 | self.layer4 = self.net.layer4
161 |
162 | def forward(self, X):
163 | h = self.conv1(X)
164 | h = self.bn1(h)
165 | h = self.relu(h)
166 | h_relu1 = h
167 | h = self.maxpool(h)
168 | h = self.layer1(h)
169 | h_conv2 = h
170 | h = self.layer2(h)
171 | h_conv3 = h
172 | h = self.layer3(h)
173 | h_conv4 = h
174 | h = self.layer4(h)
175 | h_conv5 = h
176 |
177 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
178 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
179 |
180 | return out
181 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import time
4 | from . import util
5 | from . import html
6 | import matplotlib.pyplot as plt
7 | import math
8 | # from IPython import embed
9 |
10 | def zoom_to_res(img,res=256,order=0,axis=0):
11 | # img 3xXxX
12 | from scipy.ndimage import zoom
13 | zoom_factor = res/img.shape[1]
14 | if(axis==0):
15 | return zoom(img,[1,zoom_factor,zoom_factor],order=order)
16 | elif(axis==2):
17 | return zoom(img,[zoom_factor,zoom_factor,1],order=order)
18 |
19 | class Visualizer():
20 | def __init__(self, opt):
21 | # self.opt = opt
22 | self.display_id = opt.display_id
23 | # self.use_html = opt.is_train and not opt.no_html
24 | self.win_size = opt.display_winsize
25 | self.name = opt.name
26 | self.display_cnt = 0 # display_current_results counter
27 | self.display_cnt_high = 0
28 | self.use_html = opt.use_html
29 |
30 | if self.display_id > 0:
31 | import visdom
32 | self.vis = visdom.Visdom(port = opt.display_port)
33 |
34 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
35 | util.mkdirs([self.web_dir,])
36 | if self.use_html:
37 | self.img_dir = os.path.join(self.web_dir, 'images')
38 | print('create web directory %s...' % self.web_dir)
39 | util.mkdirs([self.img_dir,])
40 |
41 | # |visuals|: dictionary of images to display or save
42 | def display_current_results(self, visuals, epoch, nrows=None, res=256):
43 | if self.display_id > 0: # show images in the browser
44 | title = self.name
45 | if(nrows is None):
46 | nrows = int(math.ceil(len(visuals.items()) / 2.0))
47 | images = []
48 | idx = 0
49 | for label, image_numpy in visuals.items():
50 | title += " | " if idx % nrows == 0 else ", "
51 | title += label
52 | img = image_numpy.transpose([2, 0, 1])
53 | img = zoom_to_res(img,res=res,order=0)
54 | images.append(img)
55 | idx += 1
56 | if len(visuals.items()) % 2 != 0:
57 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
58 | white_image = zoom_to_res(white_image,res=res,order=0)
59 | images.append(white_image)
60 | self.vis.images(images, nrow=nrows, win=self.display_id + 1,
61 | opts=dict(title=title))
62 |
63 | if self.use_html: # save images to a html file
64 | for label, image_numpy in visuals.items():
65 | img_path = os.path.join(self.img_dir, 'epoch%.3d_cnt%.6d_%s.png' % (epoch, self.display_cnt, label))
66 | util.save_image(zoom_to_res(image_numpy, res=res, axis=2), img_path)
67 |
68 | self.display_cnt += 1
69 | self.display_cnt_high = np.maximum(self.display_cnt_high, self.display_cnt)
70 |
71 | # update website
72 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
73 | for n in range(epoch, 0, -1):
74 | webpage.add_header('epoch [%d]' % n)
75 | if(n==epoch):
76 | high = self.display_cnt
77 | else:
78 | high = self.display_cnt_high
79 | for c in range(high-1,-1,-1):
80 | ims = []
81 | txts = []
82 | links = []
83 |
84 | for label, image_numpy in visuals.items():
85 | img_path = 'epoch%.3d_cnt%.6d_%s.png' % (n, c, label)
86 | ims.append(os.path.join('images',img_path))
87 | txts.append(label)
88 | links.append(os.path.join('images',img_path))
89 | webpage.add_images(ims, txts, links, width=self.win_size)
90 | webpage.save()
91 |
92 | # save errors into a directory
93 | def plot_current_errors_save(self, epoch, counter_ratio, opt, errors,keys='+ALL',name='loss', to_plot=False):
94 | if not hasattr(self, 'plot_data'):
95 | self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}
96 | self.plot_data['X'].append(epoch + counter_ratio)
97 | self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
98 |
99 | # embed()
100 | if(keys=='+ALL'):
101 | plot_keys = self.plot_data['legend']
102 | else:
103 | plot_keys = keys
104 |
105 | if(to_plot):
106 | (f,ax) = plt.subplots(1,1)
107 | for (k,kname) in enumerate(plot_keys):
108 | kk = np.where(np.array(self.plot_data['legend'])==kname)[0][0]
109 | x = self.plot_data['X']
110 | y = np.array(self.plot_data['Y'])[:,kk]
111 | if(to_plot):
112 | ax.plot(x, y, 'o-', label=kname)
113 | np.save(os.path.join(self.web_dir,'%s_x')%kname,x)
114 | np.save(os.path.join(self.web_dir,'%s_y')%kname,y)
115 |
116 | if(to_plot):
117 | plt.legend(loc=0,fontsize='small')
118 | plt.xlabel('epoch')
119 | plt.ylabel('Value')
120 | f.savefig(os.path.join(self.web_dir,'%s.png'%name))
121 | f.clf()
122 | plt.close()
123 |
124 | # errors: dictionary of error labels and values
125 | def plot_current_errors(self, epoch, counter_ratio, opt, errors):
126 | if not hasattr(self, 'plot_data'):
127 | self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}
128 | self.plot_data['X'].append(epoch + counter_ratio)
129 | self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
130 | self.vis.line(
131 | X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1),
132 | Y=np.array(self.plot_data['Y']),
133 | opts={
134 | 'title': self.name + ' loss over time',
135 | 'legend': self.plot_data['legend'],
136 | 'xlabel': 'epoch',
137 | 'ylabel': 'loss'},
138 | win=self.display_id)
139 |
140 | # errors: same format as |errors| of plotCurrentErrors
141 | def print_current_errors(self, epoch, i, errors, t, t2=-1, t2o=-1, fid=None):
142 | message = '(ep: %d, it: %d, t: %.3f[s], ept: %.2f/%.2f[h]) ' % (epoch, i, t, t2o, t2)
143 | message += (', ').join(['%s: %.3f' % (k, v) for k, v in errors.items()])
144 |
145 | print(message)
146 | if(fid is not None):
147 | fid.write('%s\n'%message)
148 |
149 |
150 | # save image to the disk
151 | def save_images_simple(self, webpage, images, names, in_txts, prefix='', res=256):
152 | image_dir = webpage.get_image_dir()
153 | ims = []
154 | txts = []
155 | links = []
156 |
157 | for name, image_numpy, txt in zip(names, images, in_txts):
158 | image_name = '%s_%s.png' % (prefix, name)
159 | save_path = os.path.join(image_dir, image_name)
160 | if(res is not None):
161 | util.save_image(zoom_to_res(image_numpy,res=res,axis=2), save_path)
162 | else:
163 | util.save_image(image_numpy, save_path)
164 |
165 | ims.append(os.path.join(webpage.img_subdir,image_name))
166 | # txts.append(name)
167 | txts.append(txt)
168 | links.append(os.path.join(webpage.img_subdir,image_name))
169 | # embed()
170 | webpage.add_images(ims, txts, links, width=self.win_size)
171 |
172 | # save image to the disk
173 | def save_images(self, webpage, images, names, image_path, title=''):
174 | image_dir = webpage.get_image_dir()
175 | # short_path = ntpath.basename(image_path)
176 | # name = os.path.splitext(short_path)[0]
177 | # name = short_path
178 | # webpage.add_header('%s, %s' % (name, title))
179 | ims = []
180 | txts = []
181 | links = []
182 |
183 | for label, image_numpy in zip(names, images):
184 | image_name = '%s.jpg' % (label,)
185 | save_path = os.path.join(image_dir, image_name)
186 | util.save_image(image_numpy, save_path)
187 |
188 | ims.append(image_name)
189 | txts.append(label)
190 | links.append(image_name)
191 | webpage.add_images(ims, txts, links, width=self.win_size)
192 |
193 | # save image to the disk
194 | # def save_images(self, webpage, visuals, image_path, short=False):
195 | # image_dir = webpage.get_image_dir()
196 | # if short:
197 | # short_path = ntpath.basename(image_path)
198 | # name = os.path.splitext(short_path)[0]
199 | # else:
200 | # name = image_path
201 |
202 | # webpage.add_header(name)
203 | # ims = []
204 | # txts = []
205 | # links = []
206 |
207 | # for label, image_numpy in visuals.items():
208 | # image_name = '%s_%s.png' % (name, label)
209 | # save_path = os.path.join(image_dir, image_name)
210 | # util.save_image(image_numpy, save_path)
211 |
212 | # ims.append(image_name)
213 | # txts.append(label)
214 | # links.append(image_name)
215 | # webpage.add_images(ims, txts, links, width=self.win_size)
216 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/lpips/lpips.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.init as init
7 | from torch.autograd import Variable
8 | import numpy as np
9 | from . import pretrained_networks as pn
10 | import torch.nn
11 |
12 | import PerceptualSimilarity.lpips as lpips
13 |
14 | def spatial_average(in_tens, keepdim=True):
15 | return in_tens.mean([2,3],keepdim=keepdim)
16 |
17 | def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W
18 | in_H, in_W = in_tens.shape[2], in_tens.shape[3]
19 | return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens)
20 |
21 | # Learned perceptual metric
22 | class LPIPS(nn.Module):
23 | def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False,
24 | pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=True):
25 | # lpips - [True] means with linear calibration on top of base network
26 | # pretrained - [True] means load linear weights
27 |
28 | super(LPIPS, self).__init__()
29 | if(verbose):
30 | print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]'%
31 | ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))
32 |
33 | self.pnet_type = net
34 | self.pnet_tune = pnet_tune
35 | self.pnet_rand = pnet_rand
36 | self.spatial = spatial
37 | self.lpips = lpips # false means baseline of just averaging all layers
38 | self.version = version
39 | self.scaling_layer = ScalingLayer()
40 |
41 | if(self.pnet_type in ['vgg','vgg16']):
42 | net_type = pn.vgg16
43 | self.chns = [64,128,256,512,512]
44 | elif(self.pnet_type=='alex'):
45 | net_type = pn.alexnet
46 | self.chns = [64,192,384,256,256]
47 | elif(self.pnet_type=='squeeze'):
48 | net_type = pn.squeezenet
49 | self.chns = [64,128,256,384,384,512,512]
50 | self.L = len(self.chns)
51 |
52 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
53 |
54 | if(lpips):
55 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
56 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
57 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
58 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
59 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
60 | self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
61 | if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
62 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
63 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
64 | self.lins+=[self.lin5,self.lin6]
65 | self.lins = nn.ModuleList(self.lins)
66 |
67 | if(pretrained):
68 | if(model_path is None):
69 | import inspect
70 | import os
71 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth'%(version,net)))
72 |
73 | if(verbose):
74 | print('Loading model from: %s'%model_path)
75 | self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)
76 |
77 | if(eval_mode):
78 | self.eval()
79 |
80 | def forward(self, in0, in1, retPerLayer=False, normalize=False):
81 | if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
82 | in0 = 2 * in0 - 1
83 | in1 = 2 * in1 - 1
84 |
85 | # v0.0 - original release had a bug, where input was not scaled
86 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
87 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
88 | feats0, feats1, diffs = {}, {}, {}
89 |
90 | for kk in range(self.L):
91 | feats0[kk], feats1[kk] = lpips.normalize_tensor(outs0[kk]), lpips.normalize_tensor(outs1[kk])
92 | diffs[kk] = (feats0[kk]-feats1[kk])**2
93 |
94 | if(self.lpips):
95 | if(self.spatial):
96 | res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]
97 | else:
98 | res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
99 | else:
100 | if(self.spatial):
101 | res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)]
102 | else:
103 | res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
104 |
105 | val = res[0]
106 | for l in range(1,self.L):
107 | val += res[l]
108 |
109 | # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
110 | # b = torch.max(self.lins[kk](feats0[kk]**2))
111 | # for kk in range(self.L):
112 | # a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
113 | # b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2)))
114 | # a = a/self.L
115 | # from IPython import embed
116 | # embed()
117 | # return 10*torch.log10(b/a)
118 |
119 | if(retPerLayer):
120 | return (val, res)
121 | else:
122 | return val
123 |
124 |
125 | class ScalingLayer(nn.Module):
126 | def __init__(self):
127 | super(ScalingLayer, self).__init__()
128 | self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
129 | self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
130 |
131 | def forward(self, inp):
132 | return (inp - self.shift) / self.scale
133 |
134 |
135 | class NetLinLayer(nn.Module):
136 | ''' A single linear layer which does a 1x1 conv '''
137 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
138 | super(NetLinLayer, self).__init__()
139 |
140 | layers = [nn.Dropout(),] if(use_dropout) else []
141 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
142 | self.model = nn.Sequential(*layers)
143 |
144 | def forward(self, x):
145 | return self.model(x)
146 |
147 | class Dist2LogitLayer(nn.Module):
148 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
149 | def __init__(self, chn_mid=32, use_sigmoid=True):
150 | super(Dist2LogitLayer, self).__init__()
151 |
152 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
153 | layers += [nn.LeakyReLU(0.2,True),]
154 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
155 | layers += [nn.LeakyReLU(0.2,True),]
156 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
157 | if(use_sigmoid):
158 | layers += [nn.Sigmoid(),]
159 | self.model = nn.Sequential(*layers)
160 |
161 | def forward(self,d0,d1,eps=0.1):
162 | return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
163 |
164 | class BCERankingLoss(nn.Module):
165 | def __init__(self, chn_mid=32):
166 | super(BCERankingLoss, self).__init__()
167 | self.net = Dist2LogitLayer(chn_mid=chn_mid)
168 | # self.parameters = list(self.net.parameters())
169 | self.loss = torch.nn.BCELoss()
170 |
171 | def forward(self, d0, d1, judge):
172 | per = (judge+1.)/2.
173 | self.logit = self.net.forward(d0,d1)
174 | return self.loss(self.logit, per)
175 |
176 | # L2, DSSIM metrics
177 | class FakeNet(nn.Module):
178 | def __init__(self, use_gpu=True, colorspace='Lab'):
179 | super(FakeNet, self).__init__()
180 | self.use_gpu = use_gpu
181 | self.colorspace = colorspace
182 |
183 | class L2(FakeNet):
184 | def forward(self, in0, in1, retPerLayer=None):
185 | assert(in0.size()[0]==1) # currently only supports batchSize 1
186 |
187 | if(self.colorspace=='RGB'):
188 | (N,C,X,Y) = in0.size()
189 | value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
190 | return value
191 | elif(self.colorspace=='Lab'):
192 | value = lpips.l2(lpips.tensor2np(lpips.tensor2tensorlab(in0.data,to_norm=False)),
193 | lpips.tensor2np(lpips.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
194 | ret_var = Variable( torch.Tensor((value,) ) )
195 | if(self.use_gpu):
196 | ret_var = ret_var.cuda()
197 | return ret_var
198 |
199 | class DSSIM(FakeNet):
200 |
201 | def forward(self, in0, in1, retPerLayer=None):
202 | assert(in0.size()[0]==1) # currently only supports batchSize 1
203 |
204 | if(self.colorspace=='RGB'):
205 | value = lpips.dssim(1.*lpips.tensor2im(in0.data), 1.*lpips.tensor2im(in1.data), range=255.).astype('float')
206 | elif(self.colorspace=='Lab'):
207 | value = lpips.dssim(lpips.tensor2np(lpips.tensor2tensorlab(in0.data,to_norm=False)),
208 | lpips.tensor2np(lpips.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
209 | ret_var = Variable( torch.Tensor((value,) ) )
210 | if(self.use_gpu):
211 | ret_var = ret_var.cuda()
212 | return ret_var
213 |
214 | def print_network(net):
215 | num_params = 0
216 | for param in net.parameters():
217 | num_params += param.numel()
218 | print('Network',net)
219 | print('Total number of parameters: %d' % num_params)
220 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## Perceptual Similarity Metric and Dataset [[Project Page]](http://richzhang.github.io/PerceptualSimilarity/)
3 |
4 | **The Unreasonable Effectiveness of Deep Features as a Perceptual Metric**
5 | [Richard Zhang](https://richzhang.github.io/), [Phillip Isola](http://web.mit.edu/phillipi/), [Alexei A. Efros](http://www.eecs.berkeley.edu/~efros/), [Eli Shechtman](https://research.adobe.com/person/eli-shechtman/), [Oliver Wang](http://www.oliverwang.info/). In [CVPR](https://arxiv.org/abs/1801.03924), 2018.
6 |
7 |
8 |
9 | ### Quick start
10 |
11 | Run `pip install lpips`. The following Python code is all you need.
12 |
13 | ```python
14 | import lpips
15 | loss_fn_alex = lpips.LPIPS(net='alex') # best forward scores
16 | loss_fn_vgg = lpips.LPIPS(net='vgg') # closer to "traditional" perceptual loss, when used for optimization
17 |
18 | import torch
19 | img0 = torch.zeros(1,3,64,64) # image should be RGB, IMPORTANT: normalized to [-1,1]
20 | img1 = torch.zeros(1,3,64,64)
21 | d = loss_fn_alex(img0, img1)
22 | ```
23 |
24 | More thorough information about variants is below. This repository contains our **perceptual metric (LPIPS)** and **dataset (BAPPS)**. It can also be used as a "perceptual loss". This uses PyTorch; a TensorFlow alternative is [here](https://github.com/alexlee-gk/lpips-tensorflow).
25 |
26 |
27 | **Table of Contents**
28 | 1. [Learned Perceptual Image Patch Similarity (LPIPS) metric](#1-learned-perceptual-image-patch-similarity-lpips-metric)
29 |    a. [Basic Usage](#a-basic-usage) If you just want to run the metric from the command line, this is all you need.
30 | b. ["Perceptual Loss" usage](#b-backpropping-through-the-metric)
31 | c. [About the metric](#c-about-the-metric)
32 | 2. [Berkeley-Adobe Perceptual Patch Similarity (BAPPS) dataset](#2-berkeley-adobe-perceptual-patch-similarity-bapps-dataset)
33 | a. [Download](#a-downloading-the-dataset)
34 | b. [Evaluation](#b-evaluating-a-perceptual-similarity-metric-on-a-dataset)
35 | c. [About the dataset](#c-about-the-dataset)
36 | d. [Train the metric using the dataset](#d-using-the-dataset-to-train-the-metric)
37 |
38 | ## (0) Dependencies/Setup
39 |
40 | ### Installation
41 | - Install PyTorch 1.0+ and torchvision from http://pytorch.org
42 |
43 | ```bash
44 | pip install -r requirements.txt
45 | ```
46 | - Clone this repo:
47 | ```bash
48 | git clone https://github.com/richzhang/PerceptualSimilarity
49 | cd PerceptualSimilarity
50 | ```
51 |
52 | ## (1) Learned Perceptual Image Patch Similarity (LPIPS) metric
53 |
54 | Evaluate the distance between image patches. **Higher means further/more different. Lower means more similar.**
55 |
56 | ### (A) Basic Usage
57 |
58 | #### (A.I) Line commands
59 |
60 | Example scripts to take the distance between 2 specific images, all corresponding pairs of images in 2 directories, or all pairs of images within a directory:
61 |
62 | ```
63 | python lpips_2imgs.py -p0 imgs/ex_ref.png -p1 imgs/ex_p0.png --use_gpu
64 | python lpips_2dirs.py -d0 imgs/ex_dir0 -d1 imgs/ex_dir1 -o imgs/example_dists.txt --use_gpu
65 | python lpips_1dir_allpairs.py -d imgs/ex_dir_pair -o imgs/example_dists_pair.txt --use_gpu
66 | ```
67 |
68 | #### (A.II) Python code
69 |
70 | File [test_network.py](test_network.py) shows example usage. This snippet is all you really need.
71 |
72 | ```python
73 | import lpips
74 | loss_fn = lpips.LPIPS(net='alex')
75 | d = loss_fn.forward(im0,im1)
76 | ```
77 |
78 | Variables ```im0, im1``` are PyTorch Tensors/Variables with shape ```Nx3xHxW``` (```N``` patches of size ```HxW```, RGB images scaled to `[-1,+1]`). This returns `d`, a length-`N` Tensor/Variable.
79 |
80 | Run `python test_network.py` to compute the distance from the example reference image [`ex_ref.png`](imgs/ex_ref.png) to the distorted images [`ex_p0.png`](./imgs/ex_p0.png) and [`ex_p1.png`](imgs/ex_p1.png). Before running it - which do you think *should* be closer?
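If you are starting from image files on disk, the helper functions used by the example scripts load an image and rescale it to `[-1,+1]`; a minimal sketch along the lines of `lpips_2imgs.py`:

```python
import lpips

loss_fn = lpips.LPIPS(net='alex')

# load_image reads an RGB image as a numpy array; im2tensor converts it to a
# 1x3xHxW tensor scaled to [-1,+1], the range expected by the metric.
img0 = lpips.im2tensor(lpips.load_image('imgs/ex_ref.png'))
img1 = lpips.im2tensor(lpips.load_image('imgs/ex_p0.png'))

d = loss_fn(img0, img1)
print(d.item())
```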
81 |
82 | **Some Options** By default in `model.initialize`:
83 | - `net='alex'`. Network `alex` is fastest and performs best as a forward metric. For backpropping, `net='vgg'` is closer to the traditional "perceptual loss".
84 | - `lpips=True`. This adds a linear calibration on top of intermediate features in the net. Set `lpips=False` to weight all features equally.
85 |
86 | ### (B) Backpropping through the metric
87 |
88 | File [`lpips_loss.py`](lpips_loss.py) shows how to iteratively optimize using the metric. Run `python lpips_loss.py` for a demo. The code can also be used to implement vanilla VGG loss, without our learned weights.
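A minimal sketch of that optimization loop, in the spirit of `lpips_loss.py` (the target and initialization below are placeholders):

```python
import torch
import lpips

loss_fn = lpips.LPIPS(net='vgg')  # 'vgg' behaves most like the traditional perceptual loss

target = torch.rand(1, 3, 64, 64) * 2. - 1.            # placeholder target in [-1,1]
pred = torch.randn(1, 3, 64, 64, requires_grad=True)   # image being optimized

optimizer = torch.optim.Adam([pred], lr=1e-3)
for _ in range(100):
    optimizer.zero_grad()
    loss = loss_fn(pred, target).mean()
    loss.backward()
    optimizer.step()
    pred.data.clamp_(-1, 1)  # keep the optimized image in the valid range
```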
89 |
90 | ### (C) About the metric
91 |
92 | **Higher means further/more different. Lower means more similar.**
93 |
94 | We found that deep network activations work surprisingly well as a perceptual similarity metric. This was true across network architectures (SqueezeNet [2.8 MB], AlexNet [9.1 MB], and VGG [58.9 MB] provided similar scores) and supervisory signals (unsupervised, self-supervised, and supervised all perform strongly). We slightly improved scores by linearly "calibrating" networks - adding a linear layer on top of off-the-shelf classification networks. We provide 3 variants, using linear layers on top of the SqueezeNet, AlexNet (default), and VGG networks.
95 |
96 | If you use LPIPS in your publication, please specify which version you are using. The current version is 0.1. You can set `version='0.0'` for the initial release.
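For example, pinning the version when constructing the metric:

```python
import lpips
loss_fn = lpips.LPIPS(net='alex', version='0.1')     # current version (the default)
loss_fn_v0 = lpips.LPIPS(net='alex', version='0.0')  # initial release, for older numbers
```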
97 |
98 | ## (2) Berkeley-Adobe Perceptual Patch Similarity (BAPPS) dataset
99 |
100 | ### (A) Downloading the dataset
101 |
102 | Run `bash ./scripts/download_dataset.sh` to download and unzip the dataset into directory `./dataset`. It takes [6.6 GB] total. Alternatively, run `bash ./scripts/download_dataset_valonly.sh` to only download the validation set [1.3 GB].
103 | - 2AFC train [5.3 GB]
104 | - 2AFC val [1.1 GB]
105 | - JND val [0.2 GB]
106 |
107 | ### (B) Evaluating a perceptual similarity metric on a dataset
108 |
109 | Script `test_dataset_model.py` evaluates a perceptual model on a subset of the dataset.
110 |
111 | **Dataset flags**
112 | - `--dataset_mode`: `2afc` or `jnd`, which type of perceptual judgment to evaluate
113 | - `--datasets`: list the datasets to evaluate
114 | - if `--dataset_mode 2afc`: choices are [`train/traditional`, `train/cnn`, `val/traditional`, `val/cnn`, `val/superres`, `val/deblur`, `val/color`, `val/frameinterp`]
115 | - if `--dataset_mode jnd`: choices are [`val/traditional`, `val/cnn`]
116 |
117 | **Perceptual similarity model flags**
118 | - `--model`: perceptual similarity model to use
119 | - `lpips` for our LPIPS learned similarity model (linear network on top of internal activations of pretrained network)
120 | - `baseline` for a classification network (uncalibrated with all layers averaged)
121 | - `l2` for Euclidean distance
122 |     - `ssim` for the Structural Similarity Index Measure (SSIM)
123 | - `--net`: [`squeeze`,`alex`,`vgg`] for the `lpips` and `baseline` models; ignored for the `l2` and `ssim` models
124 | - `--colorspace`: choices are [`Lab`,`RGB`], used for the `l2` and `ssim` models; ignored for the `lpips` and `baseline` models
125 |
126 | **Misc flags**
127 | - `--batch_size`: evaluation batch size (will default to 1)
128 | - `--use_gpu`: turn on this flag for GPU usage
129 |
130 | An example usage is as follows: `python ./test_dataset_model.py --dataset_mode 2afc --datasets val/traditional val/cnn --model lpips --net alex --use_gpu --batch_size 50`. This would evaluate our model on the "traditional" and "cnn" validation datasets.
131 |
132 | ### (C) About the dataset
133 |
134 | The dataset contains two types of perceptual judgements: **Two Alternative Forced Choice (2AFC)** and **Just Noticeable Differences (JND)**.
135 |
136 | **(1) 2AFC** Evaluators were given a patch triplet (1 reference + 2 distorted). They were asked to select which of the distorted patches was "closer" to the reference.
137 |
138 | Training sets contain 2 judgments/triplet.
139 | - `train/traditional` [56.6k triplets]
140 | - `train/cnn` [38.1k triplets]
141 | - `train/mix` [56.6k triplets]
142 |
143 | Validation sets contain 5 judgments/triplet.
144 | - `val/traditional` [4.7k triplets]
145 | - `val/cnn` [4.7k triplets]
146 | - `val/superres` [10.9k triplets]
147 | - `val/deblur` [9.4k triplets]
148 | - `val/color` [4.7k triplets]
149 | - `val/frameinterp` [1.9k triplets]
150 |
151 | Each 2AFC subdirectory contains the following folders:
152 | - `ref`: original reference patches
153 | - `p0,p1`: two distorted patches
154 | - `judge`: human judgments - 0 if all humans preferred p0, 1 if all humans preferred p1 (fractional values indicate a split among evaluators)
155 |
156 | **(2) JND** Evaluators were presented with two patches - a reference and a distorted one - for a limited time. They were asked whether the patches were the same (identical) or different.
157 |
158 | Each set contains 3 human evaluations/example.
159 | - `val/traditional` [4.8k pairs]
160 | - `val/cnn` [4.8k pairs]
161 |
162 | Each JND subdirectory contains the following folders:
163 | - `p0,p1`: two patches
164 | - `same`: human judgments: 0 if all humans thought the patches were different, 1 if all humans thought the patches were the same
165 |
166 | ### (D) Using the dataset to train the metric
167 |
168 | See script `train_test_metric.sh` for an example of training and testing the metric. The script will train a model on the full training set for 10 epochs, and then test the learned metric on all of the validation sets. The numbers should roughly match the **Alex - lin** row in Table 5 in the [paper](https://arxiv.org/abs/1801.03924). The code supports training a linear layer on top of an existing representation. Training will add a subdirectory in the `checkpoints` directory.
169 |
170 | You can also train "scratch" and "tune" versions by running `train_test_metric_scratch.sh` and `train_test_metric_tune.sh`, respectively.
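From the repository root, the three variants are launched as follows (trained models are written to subdirectories of `checkpoints`):

```bash
bash ./scripts/train_test_metric.sh          # linear layer on top of frozen features ("lin")
bash ./scripts/train_test_metric_scratch.sh  # "scratch" variant
bash ./scripts/train_test_metric_tune.sh     # "tune" variant
```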
171 |
172 | ## Citation
173 |
174 | If you find this repository useful for your research, please use the following.
175 |
176 | ```
177 | @inproceedings{zhang2018perceptual,
178 | title={The Unreasonable Effectiveness of Deep Features as a Perceptual Metric},
179 | author={Zhang, Richard and Isola, Phillip and Efros, Alexei A and Shechtman, Eli and Wang, Oliver},
180 | booktitle={CVPR},
181 | year={2018}
182 | }
183 | ```
184 |
185 | ## Acknowledgements
186 |
187 | This repository borrows partially from the [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) repository. The average precision (AP) code is borrowed from the [py-faster-rcnn](https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/voc_eval.py) repository. [Angjoo Kanazawa](https://github.com/akanazawa), [Connelly Barnes](http://www.connellybarnes.com/work/), [Gaurav Mittal](https://github.com/g1910), [wilhelmhb](https://github.com/wilhelmhb), [Filippo Mameli](https://github.com/mameli), [SuperShinyEyes](https://github.com/SuperShinyEyes), [Minyoung Huh](http://people.csail.mit.edu/minhuh/) helped to improve the codebase.
188 |
--------------------------------------------------------------------------------
/PerceptualSimilarity/lpips/trainer.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | import numpy as np
5 | import torch
6 | from torch import nn
7 | from collections import OrderedDict
8 | from torch.autograd import Variable
9 | from scipy.ndimage import zoom
10 | from tqdm import tqdm
11 | import PerceptualSimilarity.lpips as lpips
12 | import os
13 |
14 |
15 | class Trainer():
16 | def name(self):
17 | return self.model_name
18 |
19 | def initialize(self, model='lpips', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
20 | use_gpu=True, printNet=False, spatial=False,
21 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
22 | '''
23 | INPUTS
24 | model - ['lpips'] for linearly calibrated network
25 | ['baseline'] for off-the-shelf network
26 | ['L2'] for L2 distance in Lab colorspace
27 | ['SSIM'] for ssim in RGB colorspace
28 | net - ['squeeze','alex','vgg']
29 | model_path - if None, will look in weights/[NET_NAME].pth
30 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
31 | use_gpu - bool - whether or not to use a GPU
32 | printNet - bool - whether or not to print network architecture out
33 | spatial - bool - whether to output an array containing varying distances across spatial dimensions
34 | is_train - bool - [True] for training mode
35 | lr - float - initial learning rate
36 | beta1 - float - initial momentum term for adam
37 | version - 0.1 for latest, 0.0 was original (with a bug)
38 | gpu_ids - int array - [0] by default, gpus to use
39 | '''
40 | self.use_gpu = use_gpu
41 | self.gpu_ids = gpu_ids
42 | self.model = model
43 | self.net = net
44 | self.is_train = is_train
45 | self.spatial = spatial
46 | self.model_name = '%s [%s]'%(model,net)
47 |
48 | if(self.model == 'lpips'): # pretrained net + linear layer
49 | self.net = lpips.LPIPS(pretrained=not is_train, net=net, version=version, lpips=True, spatial=spatial,
50 | pnet_rand=pnet_rand, pnet_tune=pnet_tune,
51 | use_dropout=True, model_path=model_path, eval_mode=False)
52 | elif(self.model=='baseline'): # pretrained network
53 | self.net = lpips.LPIPS(pnet_rand=pnet_rand, net=net, lpips=False)
54 | elif(self.model in ['L2','l2']):
55 | self.net = lpips.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing
56 | self.model_name = 'L2'
57 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
58 | self.net = lpips.DSSIM(use_gpu=use_gpu,colorspace=colorspace)
59 | self.model_name = 'SSIM'
60 | else:
61 | raise ValueError("Model [%s] not recognized." % self.model)
62 |
63 | self.parameters = list(self.net.parameters())
64 |
65 | if self.is_train: # training mode
66 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
67 | self.rankLoss = lpips.BCERankingLoss()
68 | self.parameters += list(self.rankLoss.net.parameters())
69 | self.lr = lr
70 | self.old_lr = lr
71 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
72 | else: # test mode
73 | self.net.eval()
74 |
75 | if(use_gpu):
76 | self.net.to(gpu_ids[0])
77 | self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
78 | if(self.is_train):
79 | self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
80 |
81 | if(printNet):
82 | print('---------- Networks initialized -------------')
83 | networks.print_network(self.net)
84 | print('-----------------------------------------------')
85 |
86 | def forward(self, in0, in1, retPerLayer=False):
87 | ''' Function computes the distance between image patches in0 and in1
88 | INPUTS
89 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
90 | OUTPUT
91 | computed distances between in0 and in1
92 | '''
93 |
94 | return self.net.forward(in0, in1, retPerLayer=retPerLayer)
95 |
96 | # ***** TRAINING FUNCTIONS *****
97 | def optimize_parameters(self):
98 | self.forward_train()
99 | self.optimizer_net.zero_grad()
100 | self.backward_train()
101 | self.optimizer_net.step()
102 | self.clamp_weights()
103 |
104 | def clamp_weights(self):
105 | for module in self.net.modules():
106 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)):
107 | module.weight.data = torch.clamp(module.weight.data,min=0)
108 |
109 | def set_input(self, data):
110 | self.input_ref = data['ref']
111 | self.input_p0 = data['p0']
112 | self.input_p1 = data['p1']
113 | self.input_judge = data['judge']
114 |
115 | if(self.use_gpu):
116 | self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
117 | self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
118 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
119 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
120 |
121 | self.var_ref = Variable(self.input_ref,requires_grad=True)
122 | self.var_p0 = Variable(self.input_p0,requires_grad=True)
123 | self.var_p1 = Variable(self.input_p1,requires_grad=True)
124 |
125 | def forward_train(self): # run forward pass
126 | self.d0 = self.forward(self.var_ref, self.var_p0)
127 | self.d1 = self.forward(self.var_ref, self.var_p1)
128 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)
129 |
130 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())
131 |
132 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)
133 |
134 | return self.loss_total
135 |
136 | def backward_train(self):
137 | torch.mean(self.loss_total).backward()
138 |
139 | def compute_accuracy(self,d0,d1,judge):
140 | ''' d0, d1 are Variables, judge is a Tensor '''
141 |         d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten()
196 |         print('update lr [%s] decay: %f -> %f' % (type,self.old_lr, lr))
197 | self.old_lr = lr
198 |
199 |
200 | def get_image_paths(self):
201 | return self.image_paths
202 |
203 | def save_done(self, flag=False):
204 | np.save(os.path.join(self.save_dir, 'done_flag'),flag)
205 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
206 |
207 |
208 | def score_2afc_dataset(data_loader, func, name=''):
209 | ''' Function computes Two Alternative Forced Choice (2AFC) score using
210 | distance function 'func' in dataset 'data_loader'
211 | INPUTS
212 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
213 | func - callable distance function - calling d=func(in0,in1) should take 2
214 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N
215 | OUTPUTS
216 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
217 | [1] - dictionary with following elements
218 | d0s,d1s - N arrays containing distances between reference patch to perturbed patches
219 | gts - N array in [0,1], preferred patch selected by human evaluators
220 | (closer to "0" for left patch p0, "1" for right patch p1,
221 | "0.6" means 60pct people preferred right patch, 40pct preferred left)
222 | scores - N array in [0,1], corresponding to what percentage function agreed with humans
223 | CONSTS
224 | N - number of test triplets in data_loader
225 | '''
226 |
227 | d0s = []
228 | d1s = []
229 | gts = []
230 |
231 | for data in tqdm(data_loader.load_data(), desc=name):
232 | d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist()
233 | d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist()
234 | gts+=data['judge'].cpu().numpy().flatten().tolist()
235 |
236 | d0s = np.array(d0s)
237 | d1s = np.array(d1s)
238 | gts = np.array(gts)
239 |     scores = (d0s<d1s)*(1.-gts) + (d1s<d0s)*gts + (d0s==d1s)*.5
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
483 |         if self.groups > 1:
484 | out = self.conv_merge(out)
485 |
486 | return self.skip_add.add(out, x)
487 |
488 |
489 |
490 | class FeatureFusionBlock_custom(nn.Module):
491 | """Feature fusion block."""
492 |
493 | def __init__(
494 | self,
495 | features,
496 | activation,
497 | deconv=False,
498 | bn=False,
499 | expand=False,
500 | align_corners=True,
501 | ):
502 | """Init.
503 |
504 | Args:
505 | features (int): number of features
506 | """
507 | super(FeatureFusionBlock_custom, self).__init__()
508 |
509 | self.deconv = deconv
510 | self.align_corners = align_corners
511 |
512 | self.groups = 1
513 |
514 | self.expand = expand
515 | out_features = features
516 | if self.expand == True:
517 | out_features = features // 2
518 |
519 | self.out_conv = nn.Conv2d(
520 | features,
521 | out_features,
522 | kernel_size=1,
523 | stride=1,
524 | padding=0,
525 | bias=True,
526 | groups=1,
527 | )
528 |
529 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
530 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
531 |
532 | self.skip_add = nn.quantized.FloatFunctional()
533 |
534 | def forward(self, *xs):
535 | """Forward pass.
536 |
537 | Returns:
538 | tensor: output
539 | """
540 | output = xs[0]
541 |
542 | if len(xs) == 2:
543 | res = self.resConfUnit1(xs[1])
544 | output = self.skip_add.add(output, res)
545 |
546 | output = self.resConfUnit2(output)
547 |
548 | output = nn.functional.interpolate(
549 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
550 | )
551 |
552 | output = self.out_conv(output)
553 |
554 | return output
555 |
556 |
557 |
558 | def _make_fusion_block(features, use_bn):
559 | return FeatureFusionBlock_custom(
560 | features,
561 | nn.ReLU(False),
562 | deconv=False,
563 | bn=use_bn,
564 | expand=False,
565 | align_corners=True,
566 | )
567 |
568 |
569 | class BaseModel(torch.nn.Module):
570 | def load(self, path):
571 | """Load model from file.
572 |
573 | Args:
574 | path (str): file path
575 | """
576 | parameters = torch.load(path, map_location=torch.device("cpu"))
577 |
578 | if "optimizer" in parameters:
579 | parameters = parameters["model"]
580 |
581 | self.load_state_dict(parameters)
582 |
583 |
584 | class Interpolate(nn.Module):
585 | """Interpolation module."""
586 |
587 | def __init__(self, scale_factor, mode, align_corners=False):
588 | """Init.
589 |
590 | Args:
591 | scale_factor (float): scaling
592 | mode (str): interpolation mode
593 | """
594 | super(Interpolate, self).__init__()
595 |
596 | self.interp = nn.functional.interpolate
597 | self.scale_factor = scale_factor
598 | self.mode = mode
599 | self.align_corners = align_corners
600 |
601 | def forward(self, x):
602 | """Forward pass.
603 |
604 | Args:
605 | x (tensor): input
606 |
607 | Returns:
608 | tensor: interpolated data
609 | """
610 |
611 | x = self.interp(
612 | x,
613 | scale_factor=self.scale_factor,
614 | mode=self.mode,
615 | align_corners=self.align_corners,
616 | )
617 |
618 | return x
619 |
620 |
621 |
622 | class BRViT(BaseModel):
623 | def __init__(
624 | self,
625 | features=256,
626 | backbone="vitb_rn50_384",
627 | non_negative=False,
628 | readout="project",
629 | channels_last=False,
630 | use_bn=False,
631 | enable_attention_hooks=False,
632 | ):
633 |
634 | super(BRViT, self).__init__()
635 |
636 | self.channels_last = channels_last
637 |
638 | hooks = {"vitb_rn50_384": [0, 1, 8, 11]}
639 |
640 | # Instantiate backbone and reassemble blocks
641 | self.pretrained, self.scratch = _make_encoder(
642 | backbone,
643 | features,
644 | True,
645 | groups=1,
646 | expand=False,
647 | exportable=False,
648 | hooks=hooks[backbone],
649 | use_readout=readout,
650 | enable_attention_hooks=enable_attention_hooks,
651 | )
652 |
653 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
654 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
655 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
656 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
657 |
658 | head = nn.Sequential(
659 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
660 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
661 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
662 | nn.ReLU(True),
663 | nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
664 | nn.ReLU(True) if non_negative else nn.Identity(),
665 | nn.Identity(),
666 | )
667 |
668 | self.scratch.output_conv = head
669 |
670 | def forward(self, x):
671 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
672 |
673 | layer_1_rn = self.scratch.layer1_rn(layer_1)
674 | layer_2_rn = self.scratch.layer2_rn(layer_2)
675 | layer_3_rn = self.scratch.layer3_rn(layer_3)
676 | layer_4_rn = self.scratch.layer4_rn(layer_4)
677 |
678 | path_4 = self.scratch.refinenet4(layer_4_rn)
679 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
680 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
681 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
682 |
683 | out = self.scratch.output_conv(path_1)
684 |
685 | return out
686 |
--------------------------------------------------------------------------------