├── classifier └── __init__.py ├── metrics ├── LPIPS │ ├── data │ │ ├── __init__.py │ │ ├── dataset │ │ │ ├── __init__.py │ │ │ ├── base_dataset.py │ │ │ ├── jnd_dataset.py │ │ │ └── twoafc_dataset.py │ │ ├── base_data_loader.py │ │ ├── data_loader.py │ │ ├── custom_dataset_data_loader.py │ │ └── image_folder.py │ ├── util │ │ ├── __init__.py │ │ ├── util.py │ │ ├── html.py │ │ └── visualizer.py │ ├── .gitignore │ ├── imgs │ │ ├── ex_p0.png │ │ ├── ex_p1.png │ │ ├── fig1.png │ │ ├── ex_ref.png │ │ ├── ex_dir0 │ │ │ ├── 0.png │ │ │ └── 1.png │ │ ├── ex_dir1 │ │ │ ├── 0.png │ │ │ └── 1.png │ │ └── ex_dir_pair │ │ │ ├── ex_p0.png │ │ │ ├── ex_p1.png │ │ │ └── ex_ref.png │ ├── models │ │ ├── weights │ │ │ ├── v0.0 │ │ │ │ ├── vgg.pth │ │ │ │ ├── alex.pth │ │ │ │ └── squeeze.pth │ │ │ └── v0.1 │ │ │ │ ├── vgg.pth │ │ │ │ ├── alex.pth │ │ │ │ └── squeeze.pth │ │ ├── base_model.py │ │ ├── __init__.py │ │ ├── pretrained_networks.py │ │ ├── networks_basic.py │ │ └── dist_model.py │ ├── scripts │ │ ├── eval_valsets.sh │ │ ├── train_test_metric.sh │ │ ├── train_test_metric_tune.sh │ │ ├── train_test_metric_scratch.sh │ │ ├── download_dataset_valonly.sh │ │ └── download_dataset.sh │ ├── requirements.txt │ ├── compute_dists.py │ ├── compute_dists_dirs.py │ ├── LICENSE │ ├── perceptual_loss.py │ ├── Dockerfile │ ├── test_network.py │ ├── compute_dists_pair.py │ ├── test_dataset_model.py │ ├── train.py │ └── README.md ├── mIoU │ ├── __pycache__ │ │ ├── loss.cpython-310.pyc │ │ ├── main.cpython-310.pyc │ │ ├── unet.cpython-310.pyc │ │ └── dataset.cpython-310.pyc │ ├── loss.py │ ├── License.txt │ ├── unet.py │ ├── dataset.py │ └── main.py ├── FID │ ├── __pycache__ │ │ ├── fid_score.cpython-310.pyc │ │ ├── inception.cpython-310.pyc │ │ └── tests_with_FID.cpython-310.pyc │ ├── README.md │ ├── tests_with_FID.py │ ├── fid_score.py │ └── LICENSE └── __init__.py ├── .gitignore ├── requirements.txt ├── src ├── model_selection.py ├── test_options.py ├── base_model.py ├── utils_model.py ├── train_options.py ├── utils_datasets.py ├── utils.py └── base_options.py ├── ideal_fid_score.py ├── README.md ├── evaluate.py ├── train.py ├── test.py ├── data └── create_mvtec_dataset.py └── LICENSE /classifier/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metrics/LPIPS/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metrics/LPIPS/util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metrics/LPIPS/data/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metrics/LPIPS/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | 3 | checkpoints/* 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | results/ 2 | outputs/ 3 | data/*/ 4 | weights/*/ 5 | classifier/data/*/ 6 | classifier/output/*/ 7 | -------------------------------------------------------------------------------- /metrics/LPIPS/imgs/ex_p0.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/LPIPS/imgs/ex_p0.png -------------------------------------------------------------------------------- /metrics/LPIPS/imgs/ex_p1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/LPIPS/imgs/ex_p1.png -------------------------------------------------------------------------------- /metrics/LPIPS/imgs/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/LPIPS/imgs/fig1.png -------------------------------------------------------------------------------- /metrics/LPIPS/imgs/ex_ref.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/LPIPS/imgs/ex_ref.png -------------------------------------------------------------------------------- /metrics/LPIPS/imgs/ex_dir0/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/LPIPS/imgs/ex_dir0/0.png -------------------------------------------------------------------------------- /metrics/LPIPS/imgs/ex_dir0/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/LPIPS/imgs/ex_dir0/1.png -------------------------------------------------------------------------------- /metrics/LPIPS/imgs/ex_dir1/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/LPIPS/imgs/ex_dir1/0.png -------------------------------------------------------------------------------- /metrics/LPIPS/imgs/ex_dir1/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/LPIPS/imgs/ex_dir1/1.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pillow 2 | tqdm 3 | matplotlib 4 | opencv-python 5 | numpy 6 | torchmetrics 7 | torch-fidelity 8 | scikit-image 9 | bing-image-downloader -------------------------------------------------------------------------------- /metrics/LPIPS/imgs/ex_dir_pair/ex_p0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/LPIPS/imgs/ex_dir_pair/ex_p0.png -------------------------------------------------------------------------------- /metrics/LPIPS/imgs/ex_dir_pair/ex_p1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/LPIPS/imgs/ex_dir_pair/ex_p1.png -------------------------------------------------------------------------------- /metrics/LPIPS/imgs/ex_dir_pair/ex_ref.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/LPIPS/imgs/ex_dir_pair/ex_ref.png -------------------------------------------------------------------------------- /metrics/LPIPS/models/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/LPIPS/models/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /metrics/LPIPS/models/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/LPIPS/models/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /metrics/LPIPS/scripts/eval_valsets.sh: -------------------------------------------------------------------------------- 1 | 2 | python ./test_dataset_model.py --dataset_mode 2afc --model net-lin --net alex --use_gpu --batch_size 50 3 | 4 | -------------------------------------------------------------------------------- /metrics/LPIPS/models/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/LPIPS/models/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /metrics/LPIPS/models/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/LPIPS/models/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /metrics/LPIPS/models/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/LPIPS/models/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /metrics/LPIPS/models/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/LPIPS/models/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /metrics/mIoU/__pycache__/loss.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/mIoU/__pycache__/loss.cpython-310.pyc -------------------------------------------------------------------------------- /metrics/mIoU/__pycache__/main.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/mIoU/__pycache__/main.cpython-310.pyc -------------------------------------------------------------------------------- /metrics/mIoU/__pycache__/unet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/mIoU/__pycache__/unet.cpython-310.pyc -------------------------------------------------------------------------------- 
/metrics/mIoU/__pycache__/dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/mIoU/__pycache__/dataset.cpython-310.pyc -------------------------------------------------------------------------------- /metrics/FID/__pycache__/fid_score.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/FID/__pycache__/fid_score.cpython-310.pyc -------------------------------------------------------------------------------- /metrics/FID/__pycache__/inception.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/FID/__pycache__/inception.cpython-310.pyc -------------------------------------------------------------------------------- /metrics/FID/__pycache__/tests_with_FID.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pasqualecoscia/SyntheticDefectGeneration/HEAD/metrics/FID/__pycache__/tests_with_FID.cpython-310.pyc -------------------------------------------------------------------------------- /metrics/LPIPS/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=0.4.0 2 | torchvision>=0.2.1 3 | numpy>=1.14.3 4 | scipy>=1.0.1 5 | scikit-image>=0.13.0 6 | opencv>=2.4.11 7 | matplotlib>=1.5.1 8 | tqdm>=4.28.1 9 | -------------------------------------------------------------------------------- /metrics/LPIPS/data/base_data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | class BaseDataLoader(): 3 | def __init__(self): 4 | pass 5 | 6 | def initialize(self): 7 | pass 8 | 9 | def load_data(self): 10 | return None 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /metrics/LPIPS/scripts/train_test_metric.sh: -------------------------------------------------------------------------------- 1 | 2 | TRIAL=${1} 3 | NET=${2} 4 | mkdir checkpoints 5 | mkdir checkpoints/${NET}_${TRIAL} 6 | python ./train.py --use_gpu --net ${NET} --name ${NET}_${TRIAL} 7 | python ./test_dataset_model.py --use_gpu --net ${NET} --model_path ./checkpoints/${NET}_${TRIAL}/latest_net_.pth 8 | -------------------------------------------------------------------------------- /metrics/LPIPS/data/dataset/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | class BaseDataset(data.Dataset): 4 | def __init__(self): 5 | super(BaseDataset, self).__init__() 6 | 7 | def name(self): 8 | return 'BaseDataset' 9 | 10 | def initialize(self): 11 | pass 12 | 13 | -------------------------------------------------------------------------------- /metrics/LPIPS/scripts/train_test_metric_tune.sh: -------------------------------------------------------------------------------- 1 | 2 | TRIAL=${1} 3 | NET=${2} 4 | mkdir checkpoints 5 | mkdir checkpoints/${NET}_${TRIAL}_tune 6 | python ./train.py --train_trunk --use_gpu --net ${NET} --name ${NET}_${TRIAL}_tune 7 | python ./test_dataset_model.py --train_trunk --use_gpu --net ${NET} --model_path ./checkpoints/${NET}_${TRIAL}_tune/latest_net_.pth 8 |
-------------------------------------------------------------------------------- /metrics/LPIPS/scripts/train_test_metric_scratch.sh: -------------------------------------------------------------------------------- 1 | 2 | TRIAL=${1} 3 | NET=${2} 4 | mkdir checkpoints 5 | mkdir checkpoints/${NET}_${TRIAL}_scratch 6 | python ./train.py --from_scratch --train_trunk --use_gpu --net ${NET} --name ${NET}_${TRIAL}_scratch 7 | python ./test_dataset_model.py --from_scratch --train_trunk --use_gpu --net ${NET} --model_path ./checkpoints/${NET}_${TRIAL}_scratch/latest_net_.pth 8 | 9 | -------------------------------------------------------------------------------- /src/model_selection.py: -------------------------------------------------------------------------------- 1 | from src.cycle_gan import CycleGAN 2 | from src.cycle_gan_mask import CycleGAN_Mask 3 | from src.gan_mask import GAN_Mask 4 | 5 | def select_model(args): 6 | """ Select model """ 7 | if args.model == 'cycle_gan': 8 | model = CycleGAN(args) 9 | elif args.model == 'cycle_gan_mask': 10 | model = CycleGAN_Mask(args) 11 | #TODO:else 12 | return model -------------------------------------------------------------------------------- /metrics/mIoU/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def dice_loss(pred, target, smooth = 1.): 5 | pred = pred.contiguous() 6 | target = target.contiguous() 7 | 8 | intersection = (pred * target).sum(dim=2).sum(dim=2) 9 | 10 | loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth))) 11 | 12 | return loss.mean() 13 | -------------------------------------------------------------------------------- /src/test_options.py: -------------------------------------------------------------------------------- 1 | from src.base_options import BaseOptions 2 | 3 | class TestOptions(BaseOptions): 4 | """This class includes test options. It also includes shared options defined in BaseOptions.""" 5 | def initialize(self, parser): 6 | # Initialize base options 7 | parser = BaseOptions.initialize(self, parser) 8 | parser.add_argument("--input_type", type=str, default="standard") 9 | return parser 10 | 11 | -------------------------------------------------------------------------------- /metrics/LPIPS/data/data_loader.py: -------------------------------------------------------------------------------- 1 | def CreateDataLoader(datafolder,dataroot='./dataset',dataset_mode='2afc',load_size=64,batch_size=1,serial_batches=True,nThreads=4): 2 | from data.custom_dataset_data_loader import CustomDatasetDataLoader 3 | data_loader = CustomDatasetDataLoader() 4 | # print(data_loader.name()) 5 | data_loader.initialize(datafolder,dataroot=dataroot+'/'+dataset_mode,dataset_mode=dataset_mode,load_size=load_size,batch_size=batch_size,serial_batches=serial_batches, nThreads=nThreads) 6 | return data_loader 7 | -------------------------------------------------------------------------------- /metrics/LPIPS/scripts/download_dataset_valonly.sh: -------------------------------------------------------------------------------- 1 | 2 | mkdir dataset 3 | 4 | # JND Dataset 5 | wget https://people.eecs.berkeley.edu/~rich.zhang/projects/2018_perceptual/dataset/jnd.tar.gz -O ./dataset/jnd.tar.gz 6 | 7 | mkdir dataset/jnd 8 | tar -xzf ./dataset/jnd.tar.gz -C ./dataset 9 | rm ./dataset/jnd.tar.gz 10 | 11 | # 2AFC Val set 12 | mkdir dataset/2afc/ 13 | wget
https://people.eecs.berkeley.edu/~rich.zhang/projects/2018_perceptual/dataset/twoafc_val.tar.gz -O ./dataset/twoafc_val.tar.gz 14 | 15 | mkdir dataset/2afc/val 16 | tar -xzf ./dataset/twoafc_val.tar.gz -C ./dataset/2afc 17 | rm ./dataset/twoafc_val.tar.gz 18 | -------------------------------------------------------------------------------- /metrics/LPIPS/scripts/download_dataset.sh: -------------------------------------------------------------------------------- 1 | 2 | mkdir dataset 3 | 4 | # JND Dataset 5 | wget https://people.eecs.berkeley.edu/~rich.zhang/projects/2018_perceptual/dataset/jnd.tar.gz -O ./dataset/jnd.tar.gz 6 | 7 | mkdir dataset/jnd 8 | tar -xzf ./dataset/jnd.tar.gz -C ./dataset 9 | rm ./dataset/jnd.tar.gz 10 | 11 | # 2AFC Val set 12 | mkdir dataset/2afc/ 13 | wget https://people.eecs.berkeley.edu/~rich.zhang/projects/2018_perceptual/dataset/twoafc_val.tar.gz -O ./dataset/twoafc_val.tar.gz 14 | 15 | mkdir dataset/2afc/val 16 | tar -xzf ./dataset/twoafc_val.tar.gz -C ./dataset/2afc 17 | rm ./dataset/twoafc_val.tar.gz 18 | 19 | # 2AFC Train set 20 | mkdir dataset/2afc/ 21 | wget https://people.eecs.berkeley.edu/~rich.zhang/projects/2018_perceptual/dataset/twoafc_train.tar.gz -O ./dataset/twoafc_train.tar.gz 22 | 23 | mkdir dataset/2afc/train 24 | tar -xzf ./dataset/twoafc_train.tar.gz -C ./dataset/2afc 25 | rm ./dataset/twoafc_train.tar.gz 26 | -------------------------------------------------------------------------------- /metrics/LPIPS/compute_dists.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import models 3 | from util import util 4 | 5 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 6 | parser.add_argument('-p0','--path0', type=str, default='./imgs/ex_ref.png') 7 | parser.add_argument('-p1','--path1', type=str, default='./imgs/ex_p0.png') 8 | parser.add_argument('-v','--version', type=str, default='0.1') 9 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU') 10 | 11 | opt = parser.parse_args() 12 | 13 | ## Initializing the model 14 | model = models.PerceptualLoss(model='net-lin',net='alex',use_gpu=opt.use_gpu,version=opt.version) 15 | 16 | # Load images 17 | img0 = util.im2tensor(util.load_image(opt.path0)) # RGB image from [-1,1] 18 | img1 = util.im2tensor(util.load_image(opt.path1)) 19 | 20 | if(opt.use_gpu): 21 | img0 = img0.cuda() 22 | img1 = img1.cuda() 23 | 24 | 25 | # Compute distance 26 | dist01 = model.forward(img0,img1) 27 | print('Distance: %.3f'%dist01) 28 | -------------------------------------------------------------------------------- /metrics/mIoU/License.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Naoto Usuyama 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /metrics/LPIPS/compute_dists_dirs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import models 4 | from util import util 5 | 6 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 7 | parser.add_argument('-d0','--dir0', type=str, default='./imgs/ex_dir0') 8 | parser.add_argument('-d1','--dir1', type=str, default='./imgs/ex_dir1') 9 | parser.add_argument('-o','--out', type=str, default='./imgs/example_dists.txt') 10 | parser.add_argument('-v','--version', type=str, default='0.1') 11 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU') 12 | 13 | opt = parser.parse_args() 14 | 15 | ## Initializing the model 16 | model = models.PerceptualLoss(model='net-lin',net='alex',use_gpu=opt.use_gpu,version=opt.version) 17 | 18 | # crawl directories 19 | f = open(opt.out,'w') 20 | files = os.listdir(opt.dir0) 21 | 22 | for file in files: 23 | if(os.path.exists(os.path.join(opt.dir1,file))): 24 | # Load images 25 | img0 = util.im2tensor(util.load_image(os.path.join(opt.dir0,file))) # RGB image from [-1,1] 26 | img1 = util.im2tensor(util.load_image(os.path.join(opt.dir1,file))) 27 | 28 | if(opt.use_gpu): 29 | img0 = img0.cuda() 30 | img1 = img1.cuda() 31 | 32 | # Compute distance 33 | dist01 = model.forward(img0,img1) 34 | print('%s: %.3f'%(file,dist01)) 35 | f.writelines('%s: %.6f\n'%(file,dist01)) 36 | 37 | f.close() 38 | -------------------------------------------------------------------------------- /metrics/LPIPS/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | -------------------------------------------------------------------------------- /metrics/LPIPS/util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | from PIL import Image 5 | import numpy as np 6 | import os 7 | import matplotlib.pyplot as plt 8 | import torch 9 | 10 | def load_image(path): 11 | if(path[-3:] == 'dng'): 12 | import rawpy 13 | with rawpy.imread(path) as raw: 14 | img = raw.postprocess() 15 | elif(path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png'): 16 | import cv2 17 | return cv2.imread(path)[:,:,::-1] 18 | else: 19 | img = (255*plt.imread(path)[:,:,:3]).astype('uint8') 20 | 21 | return img 22 | 23 | def save_image(image_numpy, image_path, ): 24 | image_pil = Image.fromarray(image_numpy) 25 | image_pil.save(image_path) 26 | 27 | def mkdirs(paths): 28 | if isinstance(paths, list) and not isinstance(paths, str): 29 | for path in paths: 30 | mkdir(path) 31 | else: 32 | mkdir(paths) 33 | 34 | def mkdir(path): 35 | if not os.path.exists(path): 36 | os.makedirs(path) 37 | 38 | 39 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 40 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 41 | image_numpy = image_tensor[0].cpu().float().numpy() 42 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 43 | return image_numpy.astype(imtype) 44 | 45 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 46 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 47 | return torch.Tensor((image / factor - cent) 48 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 49 | -------------------------------------------------------------------------------- /metrics/LPIPS/data/custom_dataset_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from data.base_data_loader import BaseDataLoader 3 | import os 4 | 5 | def CreateDataset(dataroots,dataset_mode='2afc',load_size=64,): 6 | dataset = None 7 | if dataset_mode=='2afc': # human judgements 8 | from dataset.twoafc_dataset import TwoAFCDataset 9 | dataset = TwoAFCDataset() 10 | elif dataset_mode=='jnd': # human judgements 11 | from dataset.jnd_dataset import JNDDataset 12 | dataset = JNDDataset() 13 | else: 14 | raise ValueError("Dataset Mode [%s] not recognized."%self.dataset_mode) 15 | 16 | dataset.initialize(dataroots,load_size=load_size) 17 | return dataset 18 | 19 | class CustomDatasetDataLoader(BaseDataLoader): 20 | def name(self): 21 | return 'CustomDatasetDataLoader' 22 | 23 | def initialize(self, datafolders, dataroot='./dataset',dataset_mode='2afc',load_size=64,batch_size=1,serial_batches=True, nThreads=1): 24 | BaseDataLoader.initialize(self) 25 | if(not isinstance(datafolders,list)): 26 | datafolders = [datafolders,] 27 | data_root_folders = [os.path.join(dataroot,datafolder) for datafolder in datafolders] 28 | self.dataset = 
CreateDataset(data_root_folders,dataset_mode=dataset_mode,load_size=load_size) 29 | self.dataloader = torch.utils.data.DataLoader( 30 | self.dataset, 31 | batch_size=batch_size, 32 | shuffle=not serial_batches, 33 | num_workers=int(nThreads)) 34 | 35 | def load_data(self): 36 | return self.dataloader 37 | 38 | def __len__(self): 39 | return len(self.dataset) 40 | -------------------------------------------------------------------------------- /metrics/LPIPS/perceptual_loss.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | import scipy 6 | import scipy.misc 7 | import numpy as np 8 | import torch 9 | from torch.autograd import Variable 10 | import models 11 | 12 | use_gpu = True 13 | 14 | ref_path = './imgs/ex_ref.png' 15 | pred_path = './imgs/ex_p1.png' 16 | 17 | ref_img = scipy.misc.imread(ref_path).transpose(2, 0, 1) / 255. 18 | pred_img = scipy.misc.imread(pred_path).transpose(2, 0, 1) / 255. 19 | 20 | # Torchify 21 | ref = Variable(torch.FloatTensor(ref_img)[None,:,:,:]) 22 | pred = Variable(torch.FloatTensor(pred_img)[None,:,:,:], requires_grad=True) 23 | 24 | loss_fn = models.PerceptualLoss(model='net-lin', net='vgg', use_gpu=use_gpu) 25 | optimizer = torch.optim.Adam([pred,], lr=1e-3, betas=(0.9, 0.999)) 26 | 27 | import matplotlib.pyplot as plt 28 | plt.ion() 29 | fig = plt.figure(1) 30 | ax = fig.add_subplot(131) 31 | ax.imshow(ref_img.transpose(1, 2, 0)) 32 | ax.set_title('target') 33 | ax = fig.add_subplot(133) 34 | ax.imshow(pred_img.transpose(1, 2, 0)) 35 | ax.set_title('initialization') 36 | 37 | for i in range(1000): 38 | dist = loss_fn.forward(pred, ref, normalize=True) 39 | optimizer.zero_grad() 40 | dist.backward() 41 | optimizer.step() 42 | pred.data = torch.clamp(pred.data, 0, 1) 43 | 44 | if i % 10 == 0: 45 | print('iter %d, dist %.3g' % (i, dist.view(-1).data.cpu().numpy()[0])) 46 | pred_img = pred[0].data.cpu().numpy().transpose(1, 2, 0) 47 | pred_img = np.clip(pred_img, 0, 1) 48 | ax = fig.add_subplot(132) 49 | ax.imshow(pred_img) 50 | ax.set_title('iter %d, dist %.3f' % (i, dist.view(-1).data.cpu().numpy()[0])) 51 | plt.pause(5e-2) 52 | # plt.imsave('imgs_saved/%04d.jpg'%i,pred_img) 53 | 54 | 55 | -------------------------------------------------------------------------------- /metrics/LPIPS/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.autograd import Variable 4 | from pdb import set_trace as st 5 | from IPython import embed 6 | 7 | class BaseModel(): 8 | def __init__(self): 9 | pass; 10 | 11 | def name(self): 12 | return 'BaseModel' 13 | 14 | def initialize(self, use_gpu=True, gpu_ids=[0]): 15 | self.use_gpu = use_gpu 16 | self.gpu_ids = gpu_ids 17 | 18 | def forward(self): 19 | pass 20 | 21 | def get_image_paths(self): 22 | pass 23 | 24 | def optimize_parameters(self): 25 | pass 26 | 27 | def get_current_visuals(self): 28 | return self.input 29 | 30 | def get_current_errors(self): 31 | return {} 32 | 33 | def save(self, label): 34 | pass 35 | 36 | # helper saving function that can be used by subclasses 37 | def save_network(self, network, path, network_label, epoch_label): 38 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 39 | save_path = os.path.join(path, save_filename) 40 | torch.save(network.state_dict(), save_path) 41 | 42 | # helper loading function that can be used by subclasses 43 | def load_network(self, network, network_label, 
epoch_label): 44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 45 | save_path = os.path.join(self.save_dir, save_filename) 46 | print('Loading network from %s'%save_path) 47 | network.load_state_dict(torch.load(save_path)) 48 | 49 | def update_learning_rate(): 50 | pass 51 | 52 | def get_image_paths(self): 53 | return self.image_paths 54 | 55 | def save_done(self, flag=False): 56 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 57 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 58 | 59 | -------------------------------------------------------------------------------- /ideal_fid_score.py: -------------------------------------------------------------------------------- 1 | from torchmetrics.image.fid import FrechetInceptionDistance 2 | import torchvision.transforms as transforms 3 | import os 4 | import torch 5 | import torchvision.io as io 6 | import argparse 7 | 8 | def read_images(path, transform): 9 | ''' 10 | Read images in path and applies transformation 'transform'. 11 | 12 | Input 13 | ------ 14 | path 15 | transform: torchvision.transforms.Resize 16 | 17 | Output 18 | ------ 19 | batch: torch.Tensor (N x 3 x H x W) 20 | ''' 21 | 22 | batch_size = len(os.listdir(path)) 23 | batch = torch.zeros(batch_size, 3, transform.size, transform.size, dtype=torch.uint8) 24 | for i, filename in enumerate(os.listdir(path)): 25 | batch[i] = transform(io.read_image(os.path.join(path, filename))) 26 | 27 | return batch 28 | 29 | parser = argparse.ArgumentParser( 30 | description="Adversarial Defect Synthesis - PyTorch") 31 | 32 | parser.add_argument("--dataroot", type=str, default="./data", help="path to datasets. (default:./data)") 33 | parser.add_argument("--dataset", type=str, default="mvtec_dataset", help="dataset name. 
(default:`horse2zebra`) Option: [apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, facades, selfie2anime, iphone2dslr_flower, ae_photos, ]") 34 | 35 | args = parser.parse_args() 36 | 37 | # Compute FID score 38 | fid = FrechetInceptionDistance(feature=2048) 39 | 40 | PATH_1= os.path.join(args.dataroot, args.dataset, "train/B") 41 | PATH_2 = os.path.join(args.dataroot, args.dataset, "test/B") 42 | 43 | # Resize the image with given size 44 | transform = transforms.Resize(256) 45 | 46 | images_1 = read_images(PATH_1, transform) 47 | images_2 = read_images(PATH_2, transform) 48 | 49 | fid.update(images_1, real=True) 50 | fid.update(images_2, real=False) 51 | print(f"FID score: {fid.compute():.2f}.") -------------------------------------------------------------------------------- /metrics/LPIPS/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:9.0-base-ubuntu16.04 2 | 3 | LABEL maintainer="Seyoung Park " 4 | 5 | # This Dockerfile is forked from Tensorflow Dockerfile 6 | 7 | # Pick up some PyTorch gpu dependencies 8 | RUN apt-get update && apt-get install -y --no-install-recommends \ 9 | build-essential \ 10 | cuda-command-line-tools-9-0 \ 11 | cuda-cublas-9-0 \ 12 | cuda-cufft-9-0 \ 13 | cuda-curand-9-0 \ 14 | cuda-cusolver-9-0 \ 15 | cuda-cusparse-9-0 \ 16 | curl \ 17 | libcudnn7=7.1.4.18-1+cuda9.0 \ 18 | libfreetype6-dev \ 19 | libhdf5-serial-dev \ 20 | libpng12-dev \ 21 | libzmq3-dev \ 22 | pkg-config \ 23 | python \ 24 | python-dev \ 25 | rsync \ 26 | software-properties-common \ 27 | unzip \ 28 | && \ 29 | apt-get clean && \ 30 | rm -rf /var/lib/apt/lists/* 31 | 32 | 33 | # Install miniconda 34 | RUN apt-get update && apt-get install -y --no-install-recommends \ 35 | wget && \ 36 | MINICONDA="Miniconda3-latest-Linux-x86_64.sh" && \ 37 | wget --quiet https://repo.continuum.io/miniconda/$MINICONDA && \ 38 | bash $MINICONDA -b -p /miniconda && \ 39 | rm -f $MINICONDA 40 | ENV PATH /miniconda/bin:$PATH 41 | 42 | # Install PyTorch 43 | RUN conda update -n base conda && \ 44 | conda install pytorch torchvision cuda90 -c pytorch 45 | 46 | # Install PerceptualSimilarity dependencies 47 | RUN conda install numpy scipy jupyter matplotlib && \ 48 | conda install -c conda-forge scikit-image && \ 49 | apt-get install -y python-qt4 && \ 50 | pip install opencv-python 51 | 52 | # For CUDA profiling, TensorFlow requires CUPTI. Maybe PyTorch needs this too. 53 | ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH 54 | 55 | # IPython 56 | EXPOSE 8888 57 | 58 | WORKDIR "/notebooks" 59 | 60 | -------------------------------------------------------------------------------- /metrics/LPIPS/test_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from util import util 3 | import models 4 | from models import dist_model as dm 5 | from IPython import embed 6 | 7 | use_gpu = False # Whether to use GPU 8 | spatial = True # Return a spatial map of perceptual distance. 
9 | 10 | # Linearly calibrated models (LPIPS) 11 | model = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=use_gpu, spatial=spatial) 12 | # Can also set net = 'squeeze' or 'vgg' 13 | 14 | # Off-the-shelf uncalibrated networks 15 | # model = models.PerceptualLoss(model='net', net='alex', use_gpu=use_gpu, spatial=spatial) 16 | # Can also set net = 'squeeze' or 'vgg' 17 | 18 | # Low-level metrics 19 | # model = models.PerceptualLoss(model='L2', colorspace='Lab', use_gpu=use_gpu) 20 | # model = models.PerceptualLoss(model='ssim', colorspace='RGB', use_gpu=use_gpu) 21 | 22 | ## Example usage with dummy tensors 23 | dummy_im0 = torch.zeros(1,3,64,64) # image should be RGB, normalized to [-1,1] 24 | dummy_im1 = torch.zeros(1,3,64,64) 25 | if(use_gpu): 26 | dummy_im0 = dummy_im0.cuda() 27 | dummy_im1 = dummy_im1.cuda() 28 | dist = model.forward(dummy_im0,dummy_im1) 29 | 30 | ## Example usage with images 31 | ex_ref = util.im2tensor(util.load_image('./imgs/ex_ref.png')) 32 | ex_p0 = util.im2tensor(util.load_image('./imgs/ex_p0.png')) 33 | ex_p1 = util.im2tensor(util.load_image('./imgs/ex_p1.png')) 34 | if(use_gpu): 35 | ex_ref = ex_ref.cuda() 36 | ex_p0 = ex_p0.cuda() 37 | ex_p1 = ex_p1.cuda() 38 | 39 | ex_d0 = model.forward(ex_ref,ex_p0) 40 | ex_d1 = model.forward(ex_ref,ex_p1) 41 | 42 | if not spatial: 43 | print('Distances: (%.3f, %.3f)'%(ex_d0, ex_d1)) 44 | else: 45 | print('Distances: (%.3f, %.3f)'%(ex_d0.mean(), ex_d1.mean())) # The mean distance is approximately the same as the non-spatial distance 46 | 47 | # Visualize a spatially-varying distance map between ex_p0 and ex_ref 48 | import pylab 49 | pylab.imshow(ex_d0[0,0,...].data.cpu().numpy()) 50 | pylab.show() 51 | -------------------------------------------------------------------------------- /metrics/LPIPS/compute_dists_pair.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import models 4 | from util import util 5 | import numpy as np 6 | from IPython import embed 7 | 8 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 9 | parser.add_argument('-d','--dir', type=str, default='./imgs/ex_dir0') 10 | parser.add_argument('-o','--out', type=str, default='./imgs/example_dists.txt') 11 | parser.add_argument('-v','--version', type=str, default='0.1') 12 | parser.add_argument('--all-pairs', action='store_true', help='turn on to test all N(N-1)/2 pairs, leave off to just do consecutive pairs (N-1)') 13 | parser.add_argument('-N', type=int, default=None) 14 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU') 15 | 16 | opt = parser.parse_args() 17 | 18 | ## Initializing the model 19 | model = models.PerceptualLoss(model='net-lin',net='alex',use_gpu=opt.use_gpu,version=opt.version) 20 | 21 | # crawl directories 22 | f = open(opt.out,'w') 23 | files = os.listdir(opt.dir) 24 | if(opt.N is not None): 25 | files = files[:opt.N] 26 | F = len(files) 27 | 28 | dists = [] 29 | for (ff,file) in enumerate(files[:-1]): 30 | img0 = util.im2tensor(util.load_image(os.path.join(opt.dir,file))) # RGB image from [-1,1] 31 | if(opt.use_gpu): 32 | img0 = img0.cuda() 33 | 34 | if(opt.all_pairs): 35 | files1 = files[ff+1:] 36 | else: 37 | files1 = [files[ff+1],] 38 | 39 | for file1 in files1: 40 | img1 = util.im2tensor(util.load_image(os.path.join(opt.dir,file1))) 41 | 42 | if(opt.use_gpu): 43 | img1 = img1.cuda() 44 | 45 | # Compute distance 46 | dist01 = model.forward(img0,img1) 47 | print('(%s,%s): 
%.3f'%(file,file1,dist01)) 48 | f.writelines('(%s,%s): %.6f\n'%(file,file1,dist01)) 49 | 50 | dists.append(dist01.item()) 51 | 52 | avg_dist = np.mean(np.array(dists)) 53 | stderr_dist = np.std(np.array(dists))/np.sqrt(len(dists)) 54 | 55 | print('Avg: %.5f +/- %.5f'%(avg_dist,stderr_dist)) 56 | f.writelines('Avg: %.6f +/- %.6f'%(avg_dist,stderr_dist)) 57 | 58 | f.close() 59 | -------------------------------------------------------------------------------- /metrics/LPIPS/data/dataset/jnd_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import torchvision.transforms as transforms 3 | from data.dataset.base_dataset import BaseDataset 4 | from data.image_folder import make_dataset 5 | from PIL import Image 6 | import numpy as np 7 | import torch 8 | from IPython import embed 9 | 10 | class JNDDataset(BaseDataset): 11 | def initialize(self, dataroot, load_size=64): 12 | self.root = dataroot 13 | self.load_size = load_size 14 | 15 | self.dir_p0 = os.path.join(self.root, 'p0') 16 | self.p0_paths = make_dataset(self.dir_p0) 17 | self.p0_paths = sorted(self.p0_paths) 18 | 19 | self.dir_p1 = os.path.join(self.root, 'p1') 20 | self.p1_paths = make_dataset(self.dir_p1) 21 | self.p1_paths = sorted(self.p1_paths) 22 | 23 | transform_list = [] 24 | transform_list.append(transforms.Scale(load_size)) 25 | transform_list += [transforms.ToTensor(), 26 | transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))] 27 | 28 | self.transform = transforms.Compose(transform_list) 29 | 30 | # judgement directory 31 | self.dir_S = os.path.join(self.root, 'same') 32 | self.same_paths = make_dataset(self.dir_S,mode='np') 33 | self.same_paths = sorted(self.same_paths) 34 | 35 | def __getitem__(self, index): 36 | p0_path = self.p0_paths[index] 37 | p0_img_ = Image.open(p0_path).convert('RGB') 38 | p0_img = self.transform(p0_img_) 39 | 40 | p1_path = self.p1_paths[index] 41 | p1_img_ = Image.open(p1_path).convert('RGB') 42 | p1_img = self.transform(p1_img_) 43 | 44 | same_path = self.same_paths[index] 45 | same_img = np.load(same_path).reshape((1,1,1,)) # [0,1] 46 | 47 | same_img = torch.FloatTensor(same_img) 48 | 49 | return {'p0': p0_img, 'p1': p1_img, 'same': same_img, 50 | 'p0_path': p0_path, 'p1_path': p1_path, 'same_path': same_path} 51 | 52 | def __len__(self): 53 | return len(self.p0_paths) 54 | -------------------------------------------------------------------------------- /metrics/LPIPS/util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, image_subdir='', reflesh=0): 8 | self.title = title 9 | self.web_dir = web_dir 10 | # self.img_dir = os.path.join(self.web_dir, ) 11 | self.img_subdir = image_subdir 12 | self.img_dir = os.path.join(self.web_dir, image_subdir) 13 | if not os.path.exists(self.web_dir): 14 | os.makedirs(self.web_dir) 15 | if not os.path.exists(self.img_dir): 16 | os.makedirs(self.img_dir) 17 | # print(self.img_dir) 18 | 19 | self.doc = dominate.document(title=title) 20 | if reflesh > 0: 21 | with self.doc.head: 22 | meta(http_equiv="reflesh", content=str(reflesh)) 23 | 24 | def get_image_dir(self): 25 | return self.img_dir 26 | 27 | def add_header(self, str): 28 | with self.doc: 29 | h3(str) 30 | 31 | def add_table(self, border=1): 32 | self.t = table(border=border, style="table-layout: fixed;") 33 | self.doc.add(self.t) 34 | 35 | def add_images(self, ims, txts, 
links, width=400): 36 | self.add_table() 37 | with self.t: 38 | with tr(): 39 | for im, txt, link in zip(ims, txts, links): 40 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 41 | with p(): 42 | with a(href=os.path.join(link)): 43 | img(style="width:%dpx" % width, src=os.path.join(im)) 44 | br() 45 | p(txt) 46 | 47 | def save(self,file='index'): 48 | html_file = '%s/%s.html' % (self.web_dir,file) 49 | f = open(html_file, 'wt') 50 | f.write(self.doc.render()) 51 | f.close() 52 | 53 | 54 | if __name__ == '__main__': 55 | html = HTML('web/', 'test_html') 56 | html.add_header('hello world') 57 | 58 | ims = [] 59 | txts = [] 60 | links = [] 61 | for n in range(4): 62 | ims.append('image_%d.png' % n) 63 | txts.append('text_%d' % n) 64 | links.append('image_%d.png' % n) 65 | html.add_images(ims, txts, links) 66 | html.save() 67 | -------------------------------------------------------------------------------- /src/base_model.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import torch 3 | 4 | class BaseModel(ABC): 5 | """Create a model given the options""" 6 | 7 | def __init__(self, args): 8 | self.args = args 9 | self.device = torch.device("cuda:0" if args.cuda else "cpu") 10 | 11 | @abstractmethod 12 | def create_network(): 13 | """ Create network """ 14 | pass 15 | 16 | @abstractmethod 17 | def resume_training(): 18 | """ Resume training for each network component if path is provided """ 19 | pass 20 | 21 | @abstractmethod 22 | def define_losses(): 23 | """ Define loss functions """ 24 | pass 25 | 26 | @abstractmethod 27 | def set_optimizer(): 28 | """ Set optimizer """ 29 | pass 30 | 31 | @abstractmethod 32 | def set_decayLR(): 33 | """ Set decay LR""" 34 | pass 35 | 36 | @abstractmethod 37 | def prepare_inputs(): 38 | """ Prepare inputs for the model""" 39 | pass 40 | 41 | @abstractmethod 42 | def compute_loss_and_update(): 43 | """ Compute model losses and update it""" 44 | pass 45 | 46 | @abstractmethod 47 | def set_description(): 48 | """ Define progress bar string output """ 49 | pass 50 | 51 | @abstractmethod 52 | def update_learning_rates(): 53 | """ Update learning rates """ 54 | pass 55 | 56 | @abstractmethod 57 | def save_parameters(): 58 | """ Save network parameters """ 59 | pass 60 | 61 | @abstractmethod 62 | def load_weights(): 63 | """ Load weights for testing """ 64 | pass 65 | 66 | @abstractmethod 67 | def set_eval_mode(): 68 | """ Set network in evaluation mode for testing """ 69 | pass 70 | 71 | @abstractmethod 72 | def test(): 73 | """ Test model """ 74 | pass 75 | 76 | @abstractmethod 77 | def metrics_evaluation(): 78 | """ Evaluate metrics for the model """ 79 | pass 80 | 81 | @abstractmethod 82 | def metrics_initialization(): 83 | """ Metrics Initialization """ 84 | pass 85 | 86 | @abstractmethod 87 | def print_evaluation_metrics(self): 88 | """ Print evaluation metrics """ 89 | pass 90 | 91 | @abstractmethod 92 | def save_training_progress(self): 93 | """ Save training progress for first batch """ 94 | pass -------------------------------------------------------------------------------- /metrics/LPIPS/data/image_folder.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | 
# directory as well as the subdirectories 6 | ################################################################################ 7 | 8 | import torch.utils.data as data 9 | 10 | from PIL import Image 11 | import os 12 | import os.path 13 | 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 17 | ] 18 | 19 | NP_EXTENSIONS = ['.npy',] 20 | 21 | def is_image_file(filename, mode='img'): 22 | if(mode=='img'): 23 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 24 | elif(mode=='np'): 25 | return any(filename.endswith(extension) for extension in NP_EXTENSIONS) 26 | 27 | def make_dataset(dirs, mode='img'): 28 | if(not isinstance(dirs,list)): 29 | dirs = [dirs,] 30 | 31 | images = [] 32 | for dir in dirs: 33 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 34 | for root, _, fnames in sorted(os.walk(dir)): 35 | for fname in fnames: 36 | if is_image_file(fname, mode=mode): 37 | path = os.path.join(root, fname) 38 | images.append(path) 39 | 40 | # print("Found %i images in %s"%(len(images),root)) 41 | return images 42 | 43 | def default_loader(path): 44 | return Image.open(path).convert('RGB') 45 | 46 | class ImageFolder(data.Dataset): 47 | def __init__(self, root, transform=None, return_paths=False, 48 | loader=default_loader): 49 | imgs = make_dataset(root) 50 | if len(imgs) == 0: 51 | raise(RuntimeError("Found 0 images in: " + root + "\n" 52 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 53 | 54 | self.root = root 55 | self.imgs = imgs 56 | self.transform = transform 57 | self.return_paths = return_paths 58 | self.loader = loader 59 | 60 | def __getitem__(self, index): 61 | path = self.imgs[index] 62 | img = self.loader(path) 63 | if self.transform is not None: 64 | img = self.transform(img) 65 | if self.return_paths: 66 | return img, path 67 | else: 68 | return img 69 | 70 | def __len__(self): 71 | return len(self.imgs) 72 | -------------------------------------------------------------------------------- /metrics/LPIPS/data/dataset/twoafc_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import torchvision.transforms as transforms 3 | from data.dataset.base_dataset import BaseDataset 4 | from data.image_folder import make_dataset 5 | from PIL import Image 6 | import numpy as np 7 | import torch 8 | # from IPython import embed 9 | 10 | class TwoAFCDataset(BaseDataset): 11 | def initialize(self, dataroots, load_size=64): 12 | if(not isinstance(dataroots,list)): 13 | dataroots = [dataroots,] 14 | self.roots = dataroots 15 | self.load_size = load_size 16 | 17 | # image directory 18 | self.dir_ref = [os.path.join(root, 'ref') for root in self.roots] 19 | self.ref_paths = make_dataset(self.dir_ref) 20 | self.ref_paths = sorted(self.ref_paths) 21 | 22 | self.dir_p0 = [os.path.join(root, 'p0') for root in self.roots] 23 | self.p0_paths = make_dataset(self.dir_p0) 24 | self.p0_paths = sorted(self.p0_paths) 25 | 26 | self.dir_p1 = [os.path.join(root, 'p1') for root in self.roots] 27 | self.p1_paths = make_dataset(self.dir_p1) 28 | self.p1_paths = sorted(self.p1_paths) 29 | 30 | transform_list = [] 31 | transform_list.append(transforms.Scale(load_size)) 32 | transform_list += [transforms.ToTensor(), 33 | transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))] 34 | 35 | self.transform = transforms.Compose(transform_list) 36 | 37 | # judgement directory 38 | self.dir_J = [os.path.join(root, 'judge') for root in self.roots] 39 
| self.judge_paths = make_dataset(self.dir_J,mode='np') 40 | self.judge_paths = sorted(self.judge_paths) 41 | 42 | def __getitem__(self, index): 43 | p0_path = self.p0_paths[index] 44 | p0_img_ = Image.open(p0_path).convert('RGB') 45 | p0_img = self.transform(p0_img_) 46 | 47 | p1_path = self.p1_paths[index] 48 | p1_img_ = Image.open(p1_path).convert('RGB') 49 | p1_img = self.transform(p1_img_) 50 | 51 | ref_path = self.ref_paths[index] 52 | ref_img_ = Image.open(ref_path).convert('RGB') 53 | ref_img = self.transform(ref_img_) 54 | 55 | judge_path = self.judge_paths[index] 56 | # judge_img = (np.load(judge_path)*2.-1.).reshape((1,1,1,)) # [-1,1] 57 | judge_img = np.load(judge_path).reshape((1,1,1,)) # [0,1] 58 | 59 | judge_img = torch.FloatTensor(judge_img) 60 | 61 | return {'p0': p0_img, 'p1': p1_img, 'ref': ref_img, 'judge': judge_img, 62 | 'p0_path': p0_path, 'p1_path': p1_path, 'ref_path': ref_path, 'judge_path': judge_path} 63 | 64 | def __len__(self): 65 | return len(self.p0_paths) 66 | -------------------------------------------------------------------------------- /src/utils_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import random 3 | import torch 4 | import os 5 | 6 | class DecayLR: 7 | def __init__(self, epochs, offset, decay_epochs): 8 | epoch_flag = epochs - decay_epochs 9 | assert (epoch_flag > 0), "Decay must start before the training session ends!" 10 | self.epochs = epochs 11 | self.offset = offset 12 | self.decay_epochs = decay_epochs 13 | 14 | def step(self, epoch): 15 | return 1.0 - max(0, epoch + self.offset - self.decay_epochs) / ( 16 | self.epochs - self.decay_epochs) 17 | 18 | 19 | class AdversarialLoss(nn.Module): 20 | """Define MSE GAN loss. 21 | The GANLoss class abstracts away the need to create the target label tensor 22 | that has the same size as the input. 23 | """ 24 | 25 | def __init__(self, device, target_real_label=1.0, target_fake_label=0.0): 26 | """ Initialize the GANLoss class. 27 | Parameters: 28 | target_real_label (bool) - - label for a real image 29 | target_fake_label (bool) - - label of a fake image 30 | Note: Do not use sigmoid as the last layer of Discriminator. 31 | """ 32 | super(AdversarialLoss, self).__init__() 33 | self.register_buffer('real_label', torch.tensor(target_real_label).to(device)) 34 | self.register_buffer('fake_label', torch.tensor(target_fake_label).to(device)) 35 | 36 | self.loss = nn.MSELoss().to(device) 37 | 38 | 39 | def get_target_tensor(self, prediction, target_is_real): 40 | """Create label tensors with the same size as the input. 41 | Parameters: 42 | prediction (tensor) - - typically the prediction from a discriminator 43 | target_is_real (bool) - - if the ground truth label is for real images or fake images 44 | Returns: 45 | A label tensor filled with ground truth label, and with the size of the input 46 | """ 47 | 48 | if target_is_real: 49 | target_tensor = self.real_label 50 | else: 51 | target_tensor = self.fake_label 52 | return target_tensor.expand_as(prediction) 53 | 54 | def __call__(self, prediction, target_is_real): 55 | """Calculate loss given Discriminator's output and ground truth labels. 56 | Parameters: 57 | prediction (tensor) - - typically the prediction output from a discriminator 58 | target_is_real (bool) - - if the ground truth label is for real images or fake images 59 | Returns: 60 | the calculated loss.
61 | """ 62 | target_tensor = self.get_target_tensor(prediction, target_is_real) 63 | loss = self.loss(prediction, target_tensor) 64 | return loss -------------------------------------------------------------------------------- /src/train_options.py: -------------------------------------------------------------------------------- 1 | from src.base_options import BaseOptions 2 | 3 | class TrainOptions(BaseOptions): 4 | """This class includes training options. It also includes shared options defined in BaseOptions.""" 5 | def initialize(self, parser): 6 | # Initialize base options 7 | parser = BaseOptions.initialize(self, parser) 8 | # network saving and loading parameters 9 | parser.add_argument("--weightsf", default="./weights", help="folder saving weights. (default:'./weights').") 10 | parser.add_argument("--netG_A2B", default="", help="path to netG_A2B (to continue training)") 11 | parser.add_argument("--netG_B2A", default="", help="path to netG_B2A (to continue training)") 12 | parser.add_argument("--netD_A", default="", help="path to netD_A (to continue training)") 13 | parser.add_argument("--netD_B", default="", help="path to netD_B (to continue training)") 14 | parser.add_argument("--netD_fit", default="", help="path to netD_fit (to continue training)") 15 | parser.add_argument("--netD_mask", default="", help="path to netD_mask (to continue training)") 16 | # training parameters 17 | parser.add_argument("--decay_epochs", type=int, default=100, help="epoch to start linearly decaying the learning rate to 0. (default:100)") 18 | parser.add_argument("--batch_size", default=1, type=int, metavar="N", help="mini-batch size (default: 1), this is the total batch size of all GPUs on the current node when using Data Parallel or Distributed Data Parallel") 19 | parser.add_argument("--lr", type=float, default=0.0002, help="learning rate. (default:0.0002)") 20 | parser.add_argument("--beta_1", type=float, default=0.5, help="Beta 1 optimizer. (default:0.5)") 21 | parser.add_argument("--beta_2", type=float, default=0.999, help="Beta 2 optimizer.
(default:0.999)") 22 | parser.add_argument("--lambda_identity_A", type=float, default=5.0, help="Weight identity A loss") 23 | parser.add_argument("--lambda_identity_B", type=float, default=5.0, help="Weight identity B loss") 24 | parser.add_argument("--lambda_GAN_A2B", type=float, default=1.0, help="Weight adversarial A2B loss") 25 | parser.add_argument("--lambda_GAN_B2A", type=float, default=1.0, help="Weight adversarial B2A loss") 26 | parser.add_argument("--lambda_cycle_ABA", type=float, default=10.0, help="Weight cycle B2A loss") 27 | parser.add_argument("--lambda_cycle_BAB", type=float, default=10.0, help="Weight cycle B2A loss") 28 | parser.add_argument("--lambda_GAN_fit", type=float, default=150.0, help="Weight adversarial fit loss") 29 | parser.add_argument("--lambda_background", type=float, default=0.1, help="Weight background loss") 30 | 31 | return parser 32 | -------------------------------------------------------------------------------- /metrics/mIoU/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | 5 | def convrelu(in_channels, out_channels, kernel, padding): 6 | return nn.Sequential( 7 | nn.Conv2d(in_channels, out_channels, kernel, padding=padding), 8 | nn.ReLU(inplace=True), 9 | ) 10 | 11 | class ResNetUNet(nn.Module): 12 | def __init__(self, n_class): 13 | super().__init__() 14 | 15 | self.base_model = models.resnet18(pretrained=True) 16 | #self.base_model.load_state_dict(torch.load("resnet18-f37072fd.pth")) 17 | 18 | self.base_layers = list(self.base_model.children()) 19 | 20 | self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2) 21 | self.layer0_1x1 = convrelu(64, 64, 1, 0) 22 | self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4) 23 | self.layer1_1x1 = convrelu(64, 64, 1, 0) 24 | self.layer2 = self.base_layers[5] # size=(N, 128, x.H/8, x.W/8) 25 | self.layer2_1x1 = convrelu(128, 128, 1, 0) 26 | self.layer3 = self.base_layers[6] # size=(N, 256, x.H/16, x.W/16) 27 | self.layer3_1x1 = convrelu(256, 256, 1, 0) 28 | self.layer4 = self.base_layers[7] # size=(N, 512, x.H/32, x.W/32) 29 | self.layer4_1x1 = convrelu(512, 512, 1, 0) 30 | 31 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 32 | 33 | self.conv_up3 = convrelu(256 + 512, 512, 3, 1) 34 | self.conv_up2 = convrelu(128 + 512, 256, 3, 1) 35 | self.conv_up1 = convrelu(64 + 256, 256, 3, 1) 36 | self.conv_up0 = convrelu(64 + 256, 128, 3, 1) 37 | 38 | self.conv_original_size0 = convrelu(3, 64, 3, 1) 39 | self.conv_original_size1 = convrelu(64, 64, 3, 1) 40 | self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1) 41 | 42 | self.conv_last = nn.Conv2d(64, n_class, 1) 43 | 44 | def forward(self, input): 45 | x_original = self.conv_original_size0(input) 46 | x_original = self.conv_original_size1(x_original) 47 | 48 | layer0 = self.layer0(input) 49 | layer1 = self.layer1(layer0) 50 | layer2 = self.layer2(layer1) 51 | layer3 = self.layer3(layer2) 52 | layer4 = self.layer4(layer3) 53 | 54 | layer4 = self.layer4_1x1(layer4) 55 | x = self.upsample(layer4) 56 | layer3 = self.layer3_1x1(layer3) 57 | x = torch.cat([x, layer3], dim=1) 58 | x = self.conv_up3(x) 59 | 60 | x = self.upsample(x) 61 | layer2 = self.layer2_1x1(layer2) 62 | x = torch.cat([x, layer2], dim=1) 63 | x = self.conv_up2(x) 64 | 65 | x = self.upsample(x) 66 | layer1 = self.layer1_1x1(layer1) 67 | x = torch.cat([x, layer1], dim=1) 68 | x = self.conv_up1(x) 69 | 
70 | x = self.upsample(x) 71 | layer0 = self.layer0_1x1(layer0) 72 | x = torch.cat([x, layer0], dim=1) 73 | x = self.conv_up0(x) 74 | 75 | x = self.upsample(x) 76 | x = torch.cat([x, x_original], dim=1) 77 | x = self.conv_original_size2(x) 78 | 79 | out = self.conv_last(x) 80 | 81 | return out -------------------------------------------------------------------------------- /src/utils_datasets.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import random 4 | 5 | import cv2 6 | import numpy as np 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | import torch 10 | 11 | class ImageDataset(Dataset): 12 | def __init__(self, root, transform=None, unaligned=False, mode="train", mask=False): 13 | 14 | self.transform = transform 15 | self.unaligned = unaligned 16 | self.mask = mask 17 | 18 | self.files_A = sorted(glob.glob(os.path.join(root, f"{mode}/A") + "/*.*")) 19 | self.files_B = sorted(glob.glob(os.path.join(root, f"{mode}/B") + "/*.*")) 20 | if mask: 21 | # Load mask folder 22 | self.files_B_mask = sorted(glob.glob(os.path.join(root, f"{mode}/mask") + "/*.*")) 23 | 24 | # Count number of files for domain A and B 25 | self.n_files_A = len(self.files_A) 26 | self.n_files_B = len(self.files_B) 27 | 28 | def __getitem__(self, index): 29 | if self.unaligned: 30 | b_index = random.randint(0, len(self.files_B) - 1) 31 | else: 32 | b_index = index % len(self.files_B) 33 | 34 | if self.mask: 35 | # Image and mask 36 | # Apply transformation 37 | items = self.transform( 38 | image=np.asarray(Image.open(self.files_A[index % len(self.files_A)])), 39 | imageB=np.asarray(Image.open(self.files_B[b_index])), 40 | maskB=np.asarray(Image.open(self.files_B_mask[b_index])) 41 | ) 42 | out = {"A": items["image"], "B": items["imageB"], "B_mask": items["maskB"]} 43 | else: 44 | # Image only 45 | # Apply transformation 46 | items = self.transform( 47 | image=np.asarray(Image.open(self.files_A[index % len(self.files_A)])), 48 | imageB=np.asarray(Image.open(self.files_B[b_index])) 49 | ) 50 | out = {"A": items["image"], "B": items["imageB"]} 51 | 52 | return out 53 | 54 | 55 | def __len__(self): 56 | return max(len(self.files_A), len(self.files_B)) 57 | 58 | def num_files_AB(self): 59 | # Return number of files per domain 60 | return (self.n_files_A, self.n_files_B) 61 | 62 | class ReplayBuffer: 63 | def __init__(self, max_size=50): 64 | assert (max_size > 0), "Empty buffer or trying to create a black hole. Be careful." 
65 | self.max_size = max_size 66 | self.data = [] 67 | 68 | def push_and_pop(self, data): 69 | to_return = [] 70 | for element in data.data: 71 | element = torch.unsqueeze(element, 0) 72 | if len(self.data) < self.max_size: 73 | self.data.append(element) 74 | to_return.append(element) 75 | else: 76 | if random.uniform(0, 1) > 0.5: 77 | i = random.randint(0, self.max_size - 1) 78 | to_return.append(self.data[i].clone()) 79 | self.data[i] = element 80 | else: 81 | to_return.append(element) 82 | return torch.cat(to_return) 83 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.transforms import ToTensor 3 | from .FID.fid_score import calculate_fid_given_paths 4 | from .mIoU.main import compute_miou 5 | from .LPIPS.models import PerceptualLoss 6 | 7 | p_model = PerceptualLoss(model='net-lin', net='alex', use_gpu=True, gpu_ids=[0]) 8 | 9 | def FID(list_real_image, list_fake_image, dims=2048, device='cpu'): 10 | """ Compute FID score between two sets of images""" 11 | 12 | device = torch.device(device) 13 | 14 | fid_value = calculate_fid_given_paths(paths=[list_real_image, list_fake_image], 15 | batch_size=1, 16 | device=device, 17 | dims=dims, 18 | num_workers=1) 19 | 20 | return fid_value 21 | 22 | 23 | def LPIPS(list_fake_image): 24 | """ 25 | Compute average LPIPS between pairs of fake images 26 | """ 27 | dist_diversity = 0 28 | count = 0 29 | lst_im = list() 30 | # --- unpack images --- # 31 | for i in range(len(list_fake_image)): 32 | lst_im.append(ToTensor()(list_fake_image[i]).unsqueeze(0)) 33 | # --- compute LPIPS between pairs of images --- # 34 | for i in range(len(lst_im))[:100]: 35 | for j in range(i + 1, len(lst_im))[:100]: 36 | dist_diversity += p_model.forward(lst_im[i], lst_im[j]) 37 | count += 1 38 | return dist_diversity/count 39 | 40 | 41 | def LPIPS_to_train(list_real_image, list_fake_image, names_fake_image): 42 | """ 43 | For each fake image find the LPIPS to the closest training image 44 | """ 45 | dist_to_real_dict = dict() 46 | ans1 = 0 47 | count = 0 48 | lst_real, list_fake = list(), list() 49 | # --- unpack images --- # 50 | for i in range(len(list_fake_image)): 51 | list_fake.append(ToTensor()(list_fake_image[i]).unsqueeze(0)) 52 | for i in range(len(list_real_image)): 53 | lst_real.append(ToTensor()(list_real_image[i]).unsqueeze(0)) 54 | # --- compute average minimum LPIPS from a fake image to real images --- # 55 | for i in range(len(list_fake)): 56 | tens_im1 = list_fake[i] 57 | cur_ans = list() 58 | for j in range(len(lst_real)): 59 | tens_im2 = lst_real[j] 60 | dist_to_real = p_model.forward(tens_im1, tens_im2) 61 | cur_ans.append(dist_to_real) 62 | cur_min = torch.min(torch.Tensor(cur_ans)) 63 | dist_to_real_dict[names_fake_image[i]] = float(cur_min.detach().cpu().item()) 64 | ans1 += cur_min 65 | count += 1 66 | ans = ans1 / count 67 | return ans, dist_to_real_dict 68 | 69 | def mIoU(path_real_images, names_real_image, path_real_masks, names_real_masks, 70 | exp_folder, names_fake_image, names_fake_masks, im_res): 71 | """ 72 | Train a simple UNet on fake (real) images&masks, test on real (fake) images&masks. 
73 | Report mIoU and segmentation accuracy for the whole sets (fake->real and real->fake) as well as 74 | individual scores for each fake image 75 | """ 76 | metrics_tensor, results, results_acc = compute_miou(path_real_images, names_real_image, path_real_masks, names_real_masks, 77 | exp_folder, names_fake_image, names_fake_masks, im_res) 78 | return metrics_tensor, results, results_acc 79 | -------------------------------------------------------------------------------- /metrics/LPIPS/test_dataset_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from models import dist_model as dm 3 | from data import data_loader as dl 4 | import argparse 5 | from IPython import embed 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--dataset_mode', type=str, default='2afc', help='[2afc,jnd]') 9 | parser.add_argument('--datasets', type=str, nargs='+', default=['val/traditional','val/cnn','val/superres','val/deblur','val/color','val/frameinterp'], help='datasets to test - for jnd mode: [val/traditional],[val/cnn]; for 2afc mode: [train/traditional],[train/cnn],[train/mix],[val/traditional],[val/cnn],[val/color],[val/deblur],[val/frameinterp],[val/superres]') 10 | parser.add_argument('--model', type=str, default='net-lin', help='distance model type [net-lin] for linearly calibrated net, [net] for off-the-shelf network, [l2] for euclidean distance, [ssim] for Structured Similarity Image Metric') 11 | parser.add_argument('--net', type=str, default='alex', help='[squeeze], [alex], or [vgg] for network architectures') 12 | parser.add_argument('--colorspace', type=str, default='Lab', help='[Lab] or [RGB] for colorspace to use for l2, ssim model types') 13 | parser.add_argument('--batch_size', type=int, default=50, help='batch size to test image patches in') 14 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU') 15 | parser.add_argument('--gpu_ids', type=int, nargs='+', default=[0], help='gpus to use') 16 | parser.add_argument('--nThreads', type=int, default=4, help='number of threads to use in data loader') 17 | 18 | parser.add_argument('--model_path', type=str, default=None, help='location of model, will default to ./weights/v[version]/[net_name].pth') 19 | 20 | parser.add_argument('--from_scratch', action='store_true', help='model was initialized from scratch') 21 | parser.add_argument('--train_trunk', action='store_true', help='model trunk was trained/tuned') 22 | parser.add_argument('--version', type=str, default='0.1', help='v0.1 is latest, v0.0 was original release') 23 | 24 | opt = parser.parse_args() 25 | if(opt.model in ['l2','ssim']): 26 | opt.batch_size = 1 27 | 28 | # initialize model 29 | model = dm.DistModel() 30 | # model.initialize(model=opt.model,net=opt.net,colorspace=opt.colorspace,model_path=opt.model_path,use_gpu=opt.use_gpu) 31 | model.initialize(model=opt.model, net=opt.net, colorspace=opt.colorspace, 32 | model_path=opt.model_path, use_gpu=opt.use_gpu, pnet_rand=opt.from_scratch, pnet_tune=opt.train_trunk, 33 | version=opt.version, gpu_ids=opt.gpu_ids) 34 | 35 | if(opt.model in ['net-lin','net']): 36 | print('Testing model [%s]-[%s]'%(opt.model,opt.net)) 37 | elif(opt.model in ['l2','ssim']): 38 | print('Testing model [%s]-[%s]'%(opt.model,opt.colorspace)) 39 | 40 | # initialize data loader 41 | for dataset in opt.datasets: 42 | data_loader = dl.CreateDataLoader(dataset,dataset_mode=opt.dataset_mode, batch_size=opt.batch_size, nThreads=opt.nThreads) 43 | 44 | # 
evaluate model on data 45 | if(opt.dataset_mode=='2afc'): 46 | (score, results_verbose) = dm.score_2afc_dataset(data_loader, model.forward, name=dataset) 47 | elif(opt.dataset_mode=='jnd'): 48 | (score, results_verbose) = dm.score_jnd_dataset(data_loader, model.forward, name=dataset) 49 | 50 | # print results 51 | print(' Dataset [%s]: %.2f'%(dataset,100.*score)) 52 | 53 | -------------------------------------------------------------------------------- /metrics/FID/README.md: -------------------------------------------------------------------------------- 1 | # Fréchet Inception Distance (FID score) in PyTorch 2 | 3 | This is a port of the official implementation of [Fréchet Inception Distance](https://arxiv.org/abs/1706.08500) to PyTorch. 4 | See [https://github.com/bioinf-jku/TTUR](https://github.com/bioinf-jku/TTUR) for the original implementation using Tensorflow. 5 | 6 | FID is a measure of similarity between two datasets of images. 7 | It was shown to correlate well with human judgement of visual quality and is most often used to evaluate the quality of samples of Generative Adversarial Networks. 8 | FID is calculated by computing the [Fréchet distance](https://en.wikipedia.org/wiki/Fr%C3%A9chet_distance) between two Gaussians fitted to feature representations of the Inception network. 9 | 10 | Further insights and an independent evaluation of the FID score can be found in [Are GANs Created Equal? A Large-Scale Study](https://arxiv.org/abs/1711.10337). 11 | 12 | **Note that the official implementation gives slightly different scores.** If you report FID scores in your paper, and you want them to be exactly comparable to FID scores reported in other papers, you should use [the official Tensorflow implementation](https://github.com/bioinf-jku/TTUR). 13 | You can still use this version if you want a quick FID estimate without installing Tensorflow. 14 | 15 | **Update:** The weights and the model are now exactly the same as in the official Tensorflow implementation, and I verified them to give the same results (around `1e-8` mean absolute error) on single inputs on my platform. However, due to differences in the image interpolation implementation and library backends, FID results might still differ slightly from the original implementation. A test I ran (details are to come) resulted in `.08` absolute error and `0.0009` relative error. 16 | 17 | ## Usage 18 | 19 | Requirements: 20 | - python3 21 | - pytorch 22 | - torchvision 23 | - numpy 24 | - scipy 25 | 26 | To compute the FID score between two datasets, where images of each dataset are contained in an individual folder: 27 | ``` 28 | ./fid_score.py path/to/dataset1 path/to/dataset2 29 | ``` 30 | 31 | To run the evaluation on GPU, use the flag `--gpu N`, where `N` is the index of the GPU to use. 32 | 33 | ### Using different layers for feature maps 34 | 35 | In difference to the official implementation, you can choose to use a different feature layer of the Inception network instead of the default `pool3` layer. 36 | As the lower layer features still have spatial extent, the features are first global average pooled to a vector before estimating mean and covariance. 37 | 38 | This might be useful if the datasets you want to compare have less than the otherwise required 2048 images. 39 | Note that this changes the magnitude of the FID score and you can not compare them against scores calculated on another dimensionality. 40 | The resulting scores might also no longer correlate with visual quality. 
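As an illustration, here is a minimal Python sketch of the same computation with a lower-dimensional feature layer (assuming the `fid_score` module is importable, as done in `metrics/__init__.py` of this repository; the folder paths and the choice of 768-dimensional features are only examples — the command-line equivalent is the `--dims` flag described next):

```
import torch
from fid_score import calculate_fid_given_paths

# Example folders (hypothetical paths); each must contain the images of one dataset.
paths = ['path/to/dataset1', 'path/to/dataset2']

# dims=768 selects the pre-aux-classifier features instead of the default 2048-d pool3 features.
fid_value = calculate_fid_given_paths(paths,
                                      batch_size=50,
                                      device=torch.device('cpu'),
                                      dims=768,
                                      num_workers=1)
print(f'FID (768-d features): {fid_value:.2f}')
```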
41 | 42 | You can select the dimensionality of features to use with the flag `--dims N`, where N is the dimensionality of features. 43 | The choices are: 44 | - 64: first max pooling features 45 | - 192: second max pooling featurs 46 | - 768: pre-aux classifier features 47 | - 2048: final average pooling features (this is the default) 48 | 49 | ## License 50 | 51 | This implementation is licensed under the Apache License 2.0. 52 | 53 | FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", see [https://arxiv.org/abs/1706.08500](https://arxiv.org/abs/1706.08500) 54 | 55 | The original implementation is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0. 56 | See [https://github.com/bioinf-jku/TTUR](https://github.com/bioinf-jku/TTUR). 57 | -------------------------------------------------------------------------------- /metrics/FID/tests_with_FID.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from PIL import Image 5 | 6 | from pytorch_fid import fid_score, inception 7 | 8 | 9 | @pytest.fixture 10 | def device(): 11 | return torch.device('cpu') 12 | 13 | 14 | def test_calculate_fid_given_statistics(mocker, tmp_path, device): 15 | dim = 2048 16 | m1, m2 = np.zeros((dim,)), np.ones((dim,)) 17 | sigma = np.eye(dim) 18 | 19 | def dummy_statistics(path, model, batch_size, dims, device, num_workers): 20 | if path.endswith('1'): 21 | return m1, sigma 22 | elif path.endswith('2'): 23 | return m2, sigma 24 | else: 25 | raise ValueError 26 | 27 | mocker.patch('pytorch_fid.fid_score.compute_statistics_of_path', 28 | side_effect=dummy_statistics) 29 | 30 | dir_names = ['1', '2'] 31 | paths = [] 32 | for name in dir_names: 33 | path = tmp_path / name 34 | path.mkdir() 35 | paths.append(str(path)) 36 | 37 | fid_value = fid_score.calculate_fid_given_paths(paths, 38 | batch_size=dim, 39 | device=device, 40 | dims=dim, 41 | num_workers=0) 42 | 43 | # Given equal covariance, FID is just the squared norm of difference 44 | assert fid_value == np.sum((m1 - m2)**2) 45 | 46 | 47 | def test_compute_statistics_of_path(mocker, tmp_path, device): 48 | model = mocker.MagicMock(inception.InceptionV3)() 49 | model.side_effect = lambda inp: [inp.mean(dim=(2, 3), keepdim=True)] 50 | 51 | size = (4, 4, 3) 52 | arrays = [np.zeros(size), np.ones(size) * 0.5, np.ones(size)] 53 | images = [(arr * 255).astype(np.uint8) for arr in arrays] 54 | 55 | paths = [] 56 | for idx, image in enumerate(images): 57 | paths.append(str(tmp_path / '{}.png'.format(idx))) 58 | Image.fromarray(image, mode='RGB').save(paths[-1]) 59 | 60 | stats = fid_score.compute_statistics_of_path(str(tmp_path), model, 61 | batch_size=len(images), 62 | dims=3, 63 | device=device, 64 | num_workers=0) 65 | 66 | assert np.allclose(stats[0], np.ones((3,)) * 0.5, atol=1e-3) 67 | assert np.allclose(stats[1], np.ones((3, 3)) * 0.25) 68 | 69 | 70 | def test_compute_statistics_of_path_from_file(mocker, tmp_path, device): 71 | model = mocker.MagicMock(inception.InceptionV3)() 72 | 73 | mu = np.random.randn(5) 74 | sigma = np.random.randn(5, 5) 75 | 76 | path = tmp_path / 'stats.npz' 77 | with path.open('wb') as f: 78 | np.savez(f, mu=mu, sigma=sigma) 79 | 80 | stats = fid_score.compute_statistics_of_path(str(path), model, 81 | batch_size=1, 82 | dims=5, 83 | device=device, 84 | num_workers=0) 85 | 86 
| assert np.allclose(stats[0], mu) 87 | assert np.allclose(stats[1], sigma) 88 | 89 | 90 | def test_image_types(tmp_path): 91 | in_arr = np.ones((24, 24, 3), dtype=np.uint8) * 255 92 | in_image = Image.fromarray(in_arr, mode='RGB') 93 | 94 | paths = [] 95 | for ext in fid_score.IMAGE_EXTENSIONS: 96 | paths.append(str(tmp_path / 'img.{}'.format(ext))) 97 | in_image.save(paths[-1]) 98 | 99 | dataset = fid_score.ImagePathDataset(paths) 100 | 101 | for img in dataset: 102 | assert np.allclose(np.array(img), in_arr) 103 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adversarial Defect Synthesis - PyTorch 2 | 3 | ### Overview 4 | This repository contains the PyTorch implementation of [Adversarial Defect Synthesis for Industrial Products in Low Data Regime](https://ieeexplore.ieee.org/document/10222874). 5 | 6 | ### Installation 7 | 8 | #### Clone and install requirements 9 | 10 | ```bash 11 | $ git clone https://github.com/pasqualecoscia/SyntheticDefectGeneration 12 | $ cd SyntheticDefectGeneration/ 13 | $ pip3 install -r requirements.txt 14 | ``` 15 | 16 | #### Download dataset 17 | 18 | Download the [MVTec AD dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad) and extract the data into the data/mvtec folder. Then, select one product and defect and run: 19 | 20 | ```bash 21 | $ python3 data/create_mvtec_dataset.py --product product_name --defect defect_name 22 | ``` 23 | 24 | ### Train 25 | 26 | The following command can be used to train the model. 27 | 28 | ```bash 29 | $ python3 train.py --cuda 30 | ``` 31 | See src/train_options.py and src/base_options.py for more details. 32 | 33 | ### Test 34 | 35 | The following command can be used to test the model. 36 | 37 | ```bash 38 | $ python3 test.py --cuda 39 | ``` 40 | 41 | #### Resume training 42 | 43 | If you want to load pre-trained weights, run the following command. 44 | 45 | ```bash 46 | # Select the epoch to load 47 | $ python3 train.py --cuda \ 48 | --netG_A2B weights/mvtec_dataset/netG_A2B_epoch_100.pth \ 49 | --netG_B2A weights/mvtec_dataset/netG_B2A_epoch_100.pth \ 50 | --netD_A weights/mvtec_dataset/netD_A_epoch_100.pth \ 51 | --netD_B weights/mvtec_dataset/netD_B_epoch_100.pth \ 52 | --netD_fit weights/mvtec_dataset/netD_fit_epoch_100.pth \ 53 | --netD_mask weights/mvtec_dataset/netD_mask_epoch_100.pth 54 | ``` 55 | ### Metrics Evaluation 56 | 57 | The following command can be used to evaluate the quality of the generated images. 58 | 59 | ```bash 60 | # Select the epoch to evaluate 61 | $ python3 evaluate.py --cuda --epoch 150 62 | ``` 63 | 64 | #### Classifier 65 | 66 | To run the classification experiment, use the following command (different models are supported). 67 | 68 | ```bash 69 | # Example: resnet18 model for 150 epochs 70 | $ python3 classifier.py --cuda --model resnet18 --batch_size 50 --epochs 150 71 | ``` 72 | 73 | #### Adversarial Defect Synthesis for Industrial Products in Low Data Regime 74 | _Pasquale Coscia, Angelo Genovese, Fabio Scotti, Vincenzo Piuri_
75 | 76 | **Abstract**
77 | Synthetic defect generation is an important aid for advanced manufacturing and production processes. Industrial scenarios rely on automated image-based quality control methods to avoid time-consuming manual inspections and promptly identify products not complying with specific quality standards. However, these methods show poor performance in the case of ill-posed low-data training regimes, and the lack of defective samples, due to operational costs or privacy policies, strongly limits their large-scale applicability.To overcome these limitations, we propose an innovative architecture based on an unpaired image-to-image (I2I) translation model to guide a transformation from a defect-free to a defective domain for common industrial products and propose simultaneously localizing their synthesized defects through a segmentation mask. As a performance evaluation, we measure image similarity and variability using standard metrics employed for generative models. Finally, we demonstrate that inspection networks, trained on synthesized samples, improve their accuracy in spotting real defective products. 78 | 79 | ``` 80 | @INPROCEEDINGS{defsynthesis, 81 | author={Coscia, Pasquale and Genovese, Angelo and Scotti, Fabio and Piuri, Vincenzo}, 82 | booktitle={2023 IEEE International Conference on Image Processing (ICIP)}, 83 | title={Adversarial Defect Synthesis for Industrial Products in Low Data Regime}, 84 | year={2023}, 85 | pages={1360-1364}, 86 | doi={10.1109/ICIP49359.2023.10222874}} 87 | ``` 88 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | import torch.backends.cudnn as cudnn 5 | import torch.utils.data 6 | import torch.utils.data.distributed 7 | import torchvision.io as io 8 | from tqdm import tqdm 9 | import albumentations as A 10 | from albumentations.pytorch import ToTensorV2 11 | from src.utils_datasets import ImageDataset 12 | from PIL.Image import Resampling 13 | # from src.models import define_generator 14 | from src.test_options import TestOptions 15 | from src.utils import create_test_folders, select_input_root, normalize_images_diff 16 | from src.model_selection import select_model 17 | 18 | import os 19 | import argparse 20 | import numpy as np 21 | import pickle 22 | import torch 23 | from PIL import Image 24 | from torchvision.transforms import ToTensor 25 | from metrics import FID, LPIPS, LPIPS_to_train 26 | 27 | if __name__ == '__main__': 28 | 29 | # Load test options 30 | args = TestOptions().parse() 31 | # Input parser for epoch selection 32 | parser = argparse.ArgumentParser(description='Evaluation metrics parser') 33 | parser.add_argument("--cuda", action="store_true", help="Enables cuda") 34 | parser.add_argument('--epoch', type=int, required=True, help='epoch to evaluate') 35 | parser.add_argument('--dims', type=int, default=64, help='FID dims features [64, 192, 768, 2048]') 36 | args_input = parser.parse_args() 37 | 38 | # Set cudnn 39 | cudnn.benchmark = True 40 | 41 | # GPU check 42 | if torch.cuda.is_available() and args_input.cuda == True: 43 | device = 'cuda' 44 | else: 45 | device = 'cpu' 46 | print("WARNING: you should probably run with --cuda if you have a GPU support.") 47 | 48 | print(device) 49 | DATA_PATH = os.path.join(args.outf, args.dataset, "test", args.input_type, "A2B/", "epoch_" + str(args_input.epoch)) 50 | DATA_PATH_REAL_B = os.path.join(args.dataroot, args.dataset, "train", 
"B/") 51 | # Read fake/real images 52 | list_fake_image_B, list_real_image_A, list_train_image_B = list(), list(), list() 53 | paths_fake_image_B, paths_real_image_A, paths_train_image_B = list(), list(), list() 54 | names_fake_image_B = sorted([f for f in os.listdir(DATA_PATH) if ("fake" in f.split('.')[0] and "mask" not in f.split('.')[0])]) 55 | names_real_image_A = sorted([f for f in os.listdir(DATA_PATH) if "real" in f.split('.')[0]]) 56 | names_train_image_B = sorted([f for f in os.listdir(DATA_PATH_REAL_B)]) 57 | 58 | for i in range(len(names_fake_image_B)): 59 | ############################## CHECK RESIZE 60 | p = os.path.join(DATA_PATH, names_fake_image_B[i]) 61 | im = (Image.open(p).convert("RGB")) 62 | list_fake_image_B += [im.resize((args.image_size, args.image_size), Resampling.BILINEAR)] 63 | paths_fake_image_B.append(p) 64 | #im_res = (ToTensor()(list_fake_image[0]).shape[2], ToTensor()(list_fake_image[0]).shape[1]) 65 | for i in range(len(names_real_image_A)): 66 | p = os.path.join(DATA_PATH, names_real_image_A[i]) 67 | im = (Image.open(p).convert("RGB")) 68 | list_real_image_A += [im.resize((args.image_size, args.image_size), Resampling.BILINEAR)] 69 | paths_real_image_A.append(p) 70 | for i in range(len(names_train_image_B)): 71 | p = os.path.join(DATA_PATH_REAL_B, names_train_image_B[i]) 72 | im = (Image.open(p).convert("RGB")) 73 | list_train_image_B += [im.resize((args.image_size, args.image_size), Resampling.BILINEAR)] 74 | paths_train_image_B.append(p) 75 | 76 | # --- Compute the metrics --- # 77 | with torch.no_grad(): 78 | fid_score = FID(paths_train_image_B, paths_fake_image_B, dims=args_input.dims, device=device) 79 | lpips = LPIPS(list_fake_image_B) 80 | dist_to_tr, dist_to_tr_byimage = LPIPS_to_train(list_train_image_B, list_fake_image_B, names_fake_image_B) 81 | 82 | print(f"FID: {fid_score:.2f}") 83 | print(f"LPIPS: {lpips.item():.2f}") 84 | print(f"LPIPS_to_train: {dist_to_tr:.2f}") 85 | 86 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import torch.backends.cudnn as cudnn 5 | import torch.utils.data 6 | import torchvision.transforms as transforms 7 | from PIL import Image 8 | from tqdm import tqdm 9 | import albumentations as A 10 | from albumentations.pytorch import ToTensorV2 11 | 12 | import cv2 13 | from src.utils_datasets import ImageDataset 14 | from src.utils import create_train_folders 15 | from src.model_selection import select_model 16 | from src.train_options import TrainOptions 17 | 18 | if __name__ == '__main__': 19 | 20 | # Load train options 21 | args = TrainOptions().parse() # get training options 22 | 23 | # Create training folders 24 | create_train_folders(args) 25 | 26 | # Set cudnn 27 | cudnn.benchmark = True 28 | 29 | # GPU check 30 | if torch.cuda.is_available() and not args.cuda: 31 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 32 | 33 | # Dataset 34 | # Normalize between [-1, 1] 35 | mean = (.5, .5, .5) 36 | std = (.5, .5, .5) 37 | 38 | DATASET_ROOT = os.path.join(args.dataroot, args.dataset) 39 | 40 | # Apply same transformation to image and mask (if provided) 41 | if args.mask: 42 | additional_targets = { 43 | 'imageB': 'image', 44 | 'maskB': 'mask', 45 | } 46 | else: 47 | additional_targets = { 48 | 'imageB': 'image', 49 | } 50 | 51 | transform = A.Compose([ 52 | A.Resize(int(args.image_size * 1.12), int(args.image_size * 1.12), 
interpolation=cv2.INTER_CUBIC), 53 | A.RandomCrop(args.image_size, args.image_size, p=1), 54 | A.HorizontalFlip(p=0.5), 55 | A.VerticalFlip(p=0.5), 56 | # A.ShiftScaleRotate (shift_limit=0.05, scale_limit=0.2, rotate_limit=10, interpolation=1, \ 57 | # border_mode=1, value=None, mask_value=None, shift_limit_x=None, shift_limit_y=None, rotate_method='largest_box', always_apply=False, p=0.5), 58 | A.ElasticTransform (alpha=1, sigma=2, alpha_affine=0.5, interpolation=cv2.INTER_CUBIC, \ 59 | border_mode=cv2.BORDER_REFLECT, value=None, mask_value=0, always_apply=False, approximate=False, \ 60 | same_dxdy=False, p=0.5), 61 | #A.RandomBrightnessContrast(p=0.2), 62 | #A.RandomContrast(p=0.2), 63 | A.Normalize(mean=mean, std=std), 64 | ToTensorV2() 65 | ], 66 | additional_targets=additional_targets 67 | ) 68 | 69 | 70 | dataset = ImageDataset(root=DATASET_ROOT, 71 | # transform_mask=transforms.Compose([ 72 | # transforms.Resize(int(args.image_size * 1.12), Image.BICUBIC), 73 | # transforms.RandomCrop(args.image_size), 74 | # transforms.RandomHorizontalFlip(), 75 | # transforms.ToTensor(), 76 | # ]), 77 | # transform_image = transforms.Normalize(mean=mean, std=std), 78 | transform = transform, 79 | unaligned=True, 80 | mask=args.mask) 81 | 82 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True) 83 | 84 | # Create model 85 | model = select_model(args) 86 | 87 | model.create_network() 88 | 89 | # Resume training if paths are provided 90 | model.resume_training() 91 | 92 | # Define losses 93 | model.define_losses() 94 | 95 | # Set optimizer 96 | model.set_optimizer() 97 | 98 | # Set decay LR 99 | model.set_decayLR() 100 | 101 | labels = {'real_label': 1, 'fake_label': 0} 102 | 103 | for epoch in range(1, args.epochs+1): 104 | progress_bar = tqdm(enumerate(dataloader), total=len(dataloader)) 105 | for i, data in progress_bar: 106 | 107 | # Prepare input data 108 | inputs = (data, labels) 109 | inputs = model.prepare_inputs(inputs) 110 | 111 | (output, losses) = model.compute_loss_and_update(inputs) 112 | 113 | if i == 0: 114 | model.save_training_progress(inputs, mean, std, epoch) 115 | 116 | # Print progress bar info 117 | progress_bar.set_description(model.set_description(losses, epoch, args.epochs, len(dataloader), i)) 118 | 119 | # Save images at specific frequencies 120 | if i % args.save_freq_images == 0: 121 | model.save_training_images(inputs, output, epoch, i) 122 | 123 | if epoch % (args.save_freq) == 0: 124 | # Save checkpoints 125 | model.save_parameters(epoch) 126 | 127 | # Update learning rates 128 | model.update_learning_rates() 129 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | 5 | def create_train_folders(args): 6 | """ Create train folders """ 7 | 8 | # Create folder for output images 9 | try: 10 | os.makedirs(args.outf) 11 | print('Output folder created!') 12 | except OSError: 13 | print('WARNING: Output folder already created or problem encountered!') 14 | 15 | # Create folder for saving model's weights 16 | try: 17 | os.makedirs(args.weightsf) 18 | print('Weights folder created!') 19 | except OSError: 20 | print('WARNING: Weights folder already created or problem encountered!') 21 | 22 | try: 23 | os.makedirs(os.path.join(args.weightsf, args.dataset)) 24 | print(f'Weights folder for dataset {args.dataset} created!') 25 | except OSError: 26 | 
print(f'WARNING: Weights folder for dataset {args.dataset} already created or problem encountered!') 27 | 28 | # Images folders 29 | try: 30 | os.makedirs(os.path.join(args.outf, args.dataset, "train", "A2B")) 31 | print('A2B folder created!') 32 | os.makedirs(os.path.join(args.outf, args.dataset, "train", "B2A")) 33 | print('B2A folder created!') 34 | except OSError: 35 | print('WARNING: Train A2B or Train B2A already created or problem encountered!') 36 | 37 | # Training progress folders 38 | try: 39 | os.makedirs(os.path.join(args.outf, args.dataset, "train_progress/A2B")) 40 | print('Training progress A2B folder created!') 41 | os.makedirs(os.path.join(args.outf, args.dataset, "train_progress/B2A")) 42 | print('Training progress B2A folder created!') 43 | except OSError: 44 | print('WARNING: Training progress folders already created or problem encountered!') 45 | 46 | def create_test_folders(args, epoch): 47 | """ Create test folders """ 48 | # Images folders 49 | try: 50 | os.makedirs(os.path.join(args.outf, args.dataset, "test", args.input_type, "A2B", "epoch_" + str(epoch))) 51 | print('A2B folder created!') 52 | os.makedirs(os.path.join(args.outf, args.dataset, "test", args.input_type, "B2A", "epoch_" + str(epoch))) 53 | print('B2A folder created!') 54 | except OSError: 55 | print('WARNING: Test A2B or Test B2A already created or problem encountered!') 56 | 57 | def select_input_root(args): 58 | """ Select input root """ 59 | if args.input_type == 'standard': # no-modified test set 60 | ROOT_PATH = os.path.join(args.dataroot, args.dataset) 61 | elif args.input_type == 'random': # random noise input 62 | ROOT_PATH = os.path.join(args.dataroot, 'random') 63 | elif args.input_type == 'other_products': # images of other products 64 | ROOT_PATH = os.path.join(args.dataroot, 'other_products') 65 | elif args.input_type == 'checkboard': # images of checkboard patterns 66 | ROOT_PATH = os.path.join(args.dataroot, 'checkboard') 67 | elif args.input_type == 'gradient': # images of checkboard patterns 68 | ROOT_PATH = os.path.join(args.dataroot, 'gradient') 69 | return ROOT_PATH 70 | 71 | def normalize_images_diff(output, mean, std): 72 | """ Normalize input images and computes also differences between real and fake images""" 73 | 74 | real_image_A = output['real_image_A'] 75 | real_image_B = output['real_image_B'] 76 | fake_image_A = output['fake_image_A'] 77 | fake_image_B = output['fake_image_B'] 78 | 79 | # Transform [-1, 1] -> [0, 1] 80 | real_image_A = real_image_A * std + mean 81 | real_image_B = real_image_B * std + mean 82 | fake_image_A = fake_image_A * std + mean 83 | fake_image_B = fake_image_B * std + mean 84 | 85 | # Transform [0, 1] -> [0, 255] 86 | real_image_A = np.moveaxis(real_image_A.squeeze(0).cpu().numpy()*255, 0, -1).astype(np.uint8) 87 | real_image_B = np.moveaxis(real_image_B.squeeze(0).cpu().numpy()*255, 0, -1).astype(np.uint8) 88 | fake_image_A = np.moveaxis(fake_image_A.squeeze(0).cpu().numpy()*255, 0, -1).astype(np.uint8) 89 | fake_image_B = np.moveaxis(fake_image_B.squeeze(0).cpu().numpy()*255, 0, -1).astype(np.uint8) 90 | 91 | # Absolute differences 92 | diff_A2B = cv2.absdiff(real_image_A, fake_image_B) # -> np.abs(img1 - img2), same as PIL.ImageChops.difference(im1, im2) 93 | diff_B2A = cv2.absdiff(real_image_B, fake_image_A) # -> np.abs(img1 - img2) 94 | 95 | # Convert to grayscale image 96 | diff_A2B = cv2.cvtColor(diff_A2B, cv2.COLOR_RGB2GRAY) 97 | diff_B2A = cv2.cvtColor(diff_B2A, cv2.COLOR_RGB2GRAY) 98 | 99 | out = { 100 | 'real_image_A': real_image_A, 101 
| 'real_image_B': real_image_B, 102 | 'fake_image_A': fake_image_A, 103 | 'fake_image_B': fake_image_B, 104 | 'diff_A2B': diff_A2B, 105 | 'diff_B2A': diff_B2A 106 | } 107 | return out -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | import torch.backends.cudnn as cudnn 5 | import torch.utils.data 6 | import torch.utils.data.distributed 7 | import torchvision.transforms as transforms 8 | import torchvision.io as io 9 | from tqdm import tqdm 10 | import albumentations as A 11 | from albumentations.pytorch import ToTensorV2 12 | from src.utils_datasets import ImageDataset 13 | 14 | from torchmetrics.image.fid import FrechetInceptionDistance 15 | from torchmetrics.image.kid import KernelInceptionDistance 16 | # from src.models import define_generator 17 | from src.test_options import TestOptions 18 | from src.utils import create_test_folders, select_input_root, normalize_images_diff 19 | from src.model_selection import select_model 20 | 21 | def remove_duplicates(path, num_files): 22 | ''' 23 | Remove duplicates if images are not aligned and each domain has a different number of files 24 | ''' 25 | files_list = sorted(glob.glob(path + "*.*")) 26 | files_list = files_list[num_files:] 27 | for p in files_list: 28 | os.remove(p) 29 | 30 | def read_images(path, transform): 31 | ''' 32 | Read images in path and applies transformation 'transform'. 33 | 34 | Input 35 | ------ 36 | path 37 | transform: torchvision.transforms.Resize 38 | 39 | Output 40 | ------ 41 | batch: torch.Tensor (N x 3 x H x W) 42 | ''' 43 | 44 | batch_size = len(os.listdir(path)) 45 | batch = torch.zeros(batch_size, 3, transform.size, transform.size, dtype=torch.uint8) 46 | for i, filename in enumerate(os.listdir(path)): 47 | batch[i] = transform(io.read_image(os.path.join(path, filename))) 48 | 49 | return batch 50 | 51 | if __name__ == '__main__': 52 | 53 | # Load test options 54 | args = TestOptions().parse() 55 | 56 | # Set cudnn 57 | cudnn.benchmark = True 58 | 59 | # GPU check 60 | if torch.cuda.is_available() and not args.cuda: 61 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 62 | 63 | # Dataset 64 | mean_ = (.5, .5, .5) 65 | std_ = (.5, .5, .5) 66 | # Select input type 67 | ROOT_PATH = select_input_root(args) 68 | 69 | # Apply same transformation to image and mask (if provided) 70 | if args.mask: 71 | additional_targets = { 72 | 'imageB': 'image', 73 | 'maskB': 'mask', 74 | } 75 | else: 76 | additional_targets = { 77 | 'imageB': 'image', 78 | } 79 | 80 | transform = A.Compose([ 81 | A.Resize(args.image_size, args.image_size), 82 | A.Normalize(mean=mean_, std=std_), 83 | ToTensorV2() 84 | ], 85 | additional_targets=additional_targets 86 | ) 87 | 88 | dataset = ImageDataset(root=ROOT_PATH, 89 | # transform=transforms.Compose([ 90 | # transforms.Resize((args.image_size, args.image_size)), 91 | # transforms.ToTensor(), 92 | # transforms.Normalize(mean=mean, std=std) 93 | # ]), 94 | transform=transform, 95 | mode="test", 96 | mask=args.mask) 97 | 98 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, pin_memory=True) 99 | 100 | # Test every save_freq iteration 101 | epochs = list(range(args.save_freq, args.epochs + 1, args.save_freq)) 102 | for epoch in epochs: 103 | # Create test folders 104 | create_test_folders(args, epoch) 105 | 106 | progress_bar = tqdm(enumerate(dataloader), 
total=len(dataloader)) 107 | 108 | # Create models 109 | model = select_model(args) 110 | model.create_network(mode='test') 111 | 112 | # Load weights 113 | model.load_weights(epoch) 114 | 115 | # Set evaluation mode 116 | model.set_eval_mode() 117 | 118 | # Initialize metrics 119 | model.metrics_initialization() 120 | 121 | # Normalization values 122 | std = torch.tensor(std_, device=model.device).view(3, 1, 1) 123 | mean = torch.tensor(mean_, device=model.device).view(3, 1, 1) 124 | 125 | for i, data in progress_bar: 126 | labels = {'real_label': 1, 'fake_label': 0} 127 | # Prepare input data 128 | inputs = (data, labels) 129 | inputs = model.prepare_inputs(inputs) 130 | 131 | # Model testing 132 | output = model.test(inputs) 133 | 134 | # Normalize images and compute differences for saving images 135 | norm_output = normalize_images_diff(output, mean, std) 136 | model.save_test_images(output, norm_output, i, epoch) 137 | 138 | # Evaluate metrics 139 | model.metrics_evaluation(norm_output, output["fake_mask_B"]) 140 | 141 | progress_bar.set_description(f"Processing images {i + 1} of {len(dataloader)}") 142 | 143 | 144 | # Print metrics 145 | model.print_evaluation_metrics() 146 | -------------------------------------------------------------------------------- /src/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | class BaseOptions(): 4 | 5 | def __init__(self): 6 | """Reset the class; indicates the class hasn't been initialized""" 7 | self.initialized = False 8 | 9 | def initialize(self, parser): 10 | """Define the common options that are used for both training and testing.""" 11 | # saving frequencies 12 | parser.add_argument("--save_freq_images", default=100, type=int, help="image saving frequency (in iterations). (default:100)") 13 | parser.add_argument("--save_freq", default=25, type=int, help="saving frequency. (default:25). The net will be tested on each set of saved parameters.") 14 | # Number of training epochs (used also for loading parameters during testing) 15 | parser.add_argument("--epochs", default=300, type=int, metavar="N", help="number of total epochs to run") 16 | # basic parameters 17 | parser.add_argument("--dataroot", type=str, default="./data", help="path to datasets. (default:./data)") 18 | parser.add_argument("--dataset", type=str, default="mvtec_dataset", help="dataset name. (default: mvtec_dataset) Options: [apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, facades, selfie2anime, iphone2dslr_flower, ae_photos]") 19 | parser.add_argument("--mask", action="store_false", help="Load binary masks for defective images.") 20 | parser.add_argument("--outf", default="./results", help="folder for output images. 
(default:'./results').") 21 | # model parameters 22 | parser.add_argument('--model', type=str, default='cycle_gan_mask', help='specify model type [cycle_gan|cycle_gan_mask]') 23 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale') 24 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale') 25 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 26 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') 27 | # Network layers 28 | parser.add_argument('--netG', type=str, default='ResNet9', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]') 29 | parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator') 30 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') 31 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') 32 | parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers') 33 | # dataset parameters 34 | parser.add_argument("--image_size", type=int, default=256, help="size of the data crop (squared assumed). (default:256)") 35 | # additional parameters 36 | parser.add_argument("--cuda", action="store_true", help="Enables cuda") 37 | # Mask threshold 38 | parser.add_argument("--thrs_mask", type=float, default=0.2, help="Threshold for obtaining defect segmentation mask from raw logits") 39 | self.initialized = True 40 | 41 | return parser 42 | 43 | def gather_options(self): 44 | """Initialize parser with basic options (only once).""" 45 | if not self.initialized: # check if it has been initialized 46 | parser = argparse.ArgumentParser(description="INFECT: Defects Blending for Industrial Products") 47 | parser = self.initialize(parser) 48 | 49 | # save and return the parser 50 | self.parser = parser 51 | return parser.parse_args() 52 | 53 | 54 | def print_options(self, opt): 55 | """Print and save options 56 | It will print both current options and default values(if different). 
57 | It will save options into a text file / [checkpoints_dir] / opt.txt 58 | """ 59 | message = '' 60 | message += '----------------- Options ---------------\n' 61 | for k, v in sorted(vars(opt).items()): 62 | comment = '' 63 | default = self.parser.get_default(k) 64 | if v != default: 65 | comment = '\t[default: %s]' % str(default) 66 | message += '{:>20}: {:<30}{}\n'.format(str(k), str(v), comment) 67 | message += '----------------- End -------------------' 68 | print(message) 69 | 70 | # # save to the disk 71 | # expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 72 | # utils.mkdirs(expr_dir) 73 | # file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) 74 | # with open(file_name, 'wt') as opt_file: 75 | # opt_file.write(message) 76 | # opt_file.write('\n') 77 | 78 | def parse(self): 79 | """ Parse and print options """ 80 | opt = self.gather_options() 81 | 82 | self.print_options(opt) 83 | 84 | self.opt = opt 85 | return self.opt 86 | -------------------------------------------------------------------------------- /metrics/LPIPS/train.py: -------------------------------------------------------------------------------- 1 | import torch.backends.cudnn as cudnn 2 | cudnn.benchmark=False 3 | 4 | import numpy as np 5 | import time 6 | import os 7 | from models import dist_model as dm 8 | from data import data_loader as dl 9 | import argparse 10 | from util.visualizer import Visualizer 11 | from IPython import embed 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--datasets', type=str, nargs='+', default=['train/traditional','train/cnn','train/mix'], help='datasets to train on: [train/traditional],[train/cnn],[train/mix],[val/traditional],[val/cnn],[val/color],[val/deblur],[val/frameinterp],[val/superres]') 15 | parser.add_argument('--model', type=str, default='net-lin', help='distance model type [net-lin] for linearly calibrated net, [net] for off-the-shelf network, [l2] for euclidean distance, [ssim] for Structured Similarity Image Metric') 16 | parser.add_argument('--net', type=str, default='alex', help='[squeeze], [alex], or [vgg] for network architectures') 17 | parser.add_argument('--batch_size', type=int, default=50, help='batch size to test image patches in') 18 | parser.add_argument('--use_gpu', action='store_true', help='turn on flag to use GPU') 19 | parser.add_argument('--gpu_ids', type=int, nargs='+', default=[0], help='gpus to use') 20 | 21 | parser.add_argument('--nThreads', type=int, default=4, help='number of threads to use in data loader') 22 | parser.add_argument('--nepoch', type=int, default=5, help='# epochs at base learning rate') 23 | parser.add_argument('--nepoch_decay', type=int, default=5, help='# additional epochs at linearly learning rate') 24 | parser.add_argument('--display_freq', type=int, default=5000, help='frequency (in instances) of showing training results on screen') 25 | parser.add_argument('--print_freq', type=int, default=5000, help='frequency (in instances) of showing training results on console') 26 | parser.add_argument('--save_latest_freq', type=int, default=20000, help='frequency (in instances) of saving the latest results') 27 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') 28 | parser.add_argument('--display_id', type=int, default=0, help='window id of the visdom display, [0] for no displaying') 29 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size') 30 | 
parser.add_argument('--display_port', type=int, default=8001, help='visdom display port') 31 | parser.add_argument('--use_html', action='store_true', help='save off html pages') 32 | parser.add_argument('--checkpoints_dir', type=str, default='checkpoints', help='checkpoints directory') 33 | parser.add_argument('--name', type=str, default='tmp', help='directory name for training') 34 | 35 | parser.add_argument('--from_scratch', action='store_true', help='model was initialized from scratch') 36 | parser.add_argument('--train_trunk', action='store_true', help='model trunk was trained/tuned') 37 | parser.add_argument('--train_plot', action='store_true', help='plot saving') 38 | 39 | opt = parser.parse_args() 40 | opt.save_dir = os.path.join(opt.checkpoints_dir,opt.name) 41 | if(not os.path.exists(opt.save_dir)): 42 | os.mkdir(opt.save_dir) 43 | 44 | # initialize model 45 | model = dm.DistModel() 46 | model.initialize(model=opt.model, net=opt.net, use_gpu=opt.use_gpu, is_train=True, 47 | pnet_rand=opt.from_scratch, pnet_tune=opt.train_trunk, gpu_ids=opt.gpu_ids) 48 | 49 | # load data from all training sets 50 | data_loader = dl.CreateDataLoader(opt.datasets,dataset_mode='2afc', batch_size=opt.batch_size, serial_batches=False, nThreads=opt.nThreads) 51 | dataset = data_loader.load_data() 52 | dataset_size = len(data_loader) 53 | D = len(dataset) 54 | print('Loading %i instances from'%dataset_size,opt.datasets) 55 | visualizer = Visualizer(opt) 56 | 57 | total_steps = 0 58 | fid = open(os.path.join(opt.checkpoints_dir,opt.name,'train_log.txt'),'w+') 59 | for epoch in range(1, opt.nepoch + opt.nepoch_decay + 1): 60 | epoch_start_time = time.time() 61 | for i, data in enumerate(dataset): 62 | iter_start_time = time.time() 63 | total_steps += opt.batch_size 64 | epoch_iter = total_steps - dataset_size * (epoch - 1) 65 | 66 | model.set_input(data) 67 | model.optimize_parameters() 68 | 69 | if total_steps % opt.display_freq == 0: 70 | visualizer.display_current_results(model.get_current_visuals(), epoch) 71 | 72 | if total_steps % opt.print_freq == 0: 73 | errors = model.get_current_errors() 74 | t = (time.time()-iter_start_time)/opt.batch_size 75 | t2o = (time.time()-epoch_start_time)/3600. 
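# Descriptive note (added): t2o is the number of hours elapsed so far in the current epoch; t2 (computed on the next line) extrapolates it over all D batches to estimate the duration of the full epoch in hours.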
76 | t2 = t2o*D/(i+.0001) 77 | visualizer.print_current_errors(epoch, epoch_iter, errors, t, t2=t2, t2o=t2o, fid=fid) 78 | 79 | for key in errors.keys(): 80 | visualizer.plot_current_errors_save(epoch, float(epoch_iter)/dataset_size, opt, errors, keys=[key,], name=key, to_plot=opt.train_plot) 81 | 82 | if opt.display_id > 0: 83 | visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors) 84 | 85 | if total_steps % opt.save_latest_freq == 0: 86 | print('saving the latest model (epoch %d, total_steps %d)' % 87 | (epoch, total_steps)) 88 | model.save(opt.save_dir, 'latest') 89 | 90 | if epoch % opt.save_epoch_freq == 0: 91 | print('saving the model at the end of epoch %d, iters %d' % 92 | (epoch, total_steps)) 93 | model.save(opt.save_dir, 'latest') 94 | model.save(opt.save_dir, epoch) 95 | 96 | print('End of epoch %d / %d \t Time Taken: %d sec' % 97 | (epoch, opt.nepoch + opt.nepoch_decay, time.time() - epoch_start_time)) 98 | 99 | if epoch > opt.nepoch: 100 | model.update_learning_rate(opt.nepoch_decay) 101 | 102 | # model.save_done(True) 103 | fid.close() 104 | -------------------------------------------------------------------------------- /metrics/mIoU/dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import numpy as np 4 | from torchvision import transforms as TR 5 | from PIL import Image, ImageOps 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader 8 | 9 | 10 | class SimDataset(Dataset): 11 | def __init__(self, fld_im, names_real_image, fld_mask, names_real_masks, im_res, real=True, num_ch=None, no_transform=False): 12 | self.real = real 13 | 14 | self.frame_path = fld_im 15 | self.frames = names_real_image 16 | self.mask_path = fld_mask 17 | self.masks = names_real_masks 18 | self.im_res = (im_res[1], im_res[0]) 19 | self.no_transform = no_transform 20 | if real: 21 | self.num_mask_channels = self.get_num_mask_channels() 22 | else: 23 | self.num_mask_channels = num_ch 24 | 25 | self.transforms = get_transforms(im_res, no_transform) 26 | 27 | def __len__(self): 28 | return 10000000 29 | 30 | def __getitem__(self, indx): 31 | idx = indx % len(self.frames) 32 | 33 | seed = np.random.randint(2147483647) 34 | random.seed(seed) 35 | torch.manual_seed(seed) 36 | 37 | img_pil = Image.open("%s/%s" % (self.frame_path, self.frames[idx])).convert("RGB") 38 | target_size = self.im_res 39 | 40 | res = self.transforms(TR.functional.resize(img_pil, size=target_size)).to("cuda") 41 | ans = (res - 0.5) * 2 42 | 43 | random.seed(seed) # apply this seed to target tranfsorms 44 | torch.manual_seed(seed) # needed for torchvision 0.7 45 | 46 | mask_pil = Image.open("%s/%s" % (self.mask_path, self.masks[idx][:-4] + ".png")) 47 | mask = self.transforms(TR.functional.resize(mask_pil, size=target_size, interpolation=Image.NEAREST)).to("cuda") 48 | mask = self.create_mask_channels(mask) # mask should be N+1 channels 49 | return [ans, mask] 50 | 51 | def create_mask_channels(self, mask): 52 | if (mask.unique() * 256).max() > 20: # only object and background 53 | mask = (torch.sum(mask, dim=(0,), keepdim=True) > 0)*1.0 54 | mask = torch.cat((1 - mask, mask), dim=0) 55 | return mask 56 | else: # background and many objects 57 | integers = torch.round(mask * 256) 58 | mask = torch.nn.functional.one_hot(integers.long(), num_classes=self.num_mask_channels).float()[ 59 | 0].permute(2, 0, 1) 60 | return mask 61 | 62 | def get_num_mask_channels(self): 63 | masks = self.masks 64 | c = 0 65 | for item in 
range(len(masks)): 66 | im = TR.ToTensor()(Image.open(os.path.join(self.mask_path, masks[item]))) 67 | if (im.unique() * 256).max() > 20: 68 | c = 2 if 2 > c else c 69 | else: 70 | cur = torch.max(torch.round(im * 256)) 71 | c = cur + 1 if cur + 1 > c else c 72 | return int(c) 73 | 74 | 75 | def get_transforms(im_res, no_transform): 76 | prob_augm = 0.3 77 | tr_list = list() 78 | 79 | if not no_transform: 80 | TR.RandomApply( 81 | [TR.RandomResizedCrop(size=im_res, scale=(0.75, 1.0), ratio=(1, 1))], 82 | p=prob_augm), 83 | 84 | tr_list.append(TR.RandomApply([TR.RandomHorizontalFlip(p=1)], p=prob_augm / 2)), 85 | tr_list.append(TR.RandomApply([myVerticalTranslation(fraction=(0.05, 0.3))], p=prob_augm)), 86 | tr_list.append(TR.RandomApply([myHorizontalTranslation(fraction=(0.05, 0.3))], p=prob_augm)), 87 | tr_list.append(TR.ToTensor()) 88 | return TR.Compose(tr_list) 89 | 90 | 91 | class myRandomResizedCrop(TR.RandomResizedCrop): 92 | def __init__(self, size=256, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), ): 93 | super(myRandomResizedCrop, self).__init__(size, scale, ratio) 94 | 95 | def __call__(self, img): 96 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 97 | return TR.functional.resized_crop(img, i, j, h, w, (img.size[1], img.size[0]), self.interpolation) 98 | 99 | 100 | class myVerticalTranslation(TR.RandomResizedCrop): 101 | def __init__(self, fraction=(0.05, 0.3)): 102 | self.fraction = fraction 103 | super(myVerticalTranslation, self).__init__(size=256) 104 | 105 | def __call__(self, img): 106 | margin = torch.rand(1) * (self.fraction[1] - self.fraction[0]) + self.fraction[0] 107 | direct_up = (torch.rand(1) < 0.5) # up or down 108 | width, height = img.size 109 | left, right = 0, width 110 | shift = -int(height * margin) if direct_up else int(height * margin) 111 | if direct_up: 112 | top, bottom = 0, int(height * margin), 113 | else: 114 | top, bottom = height - int(height * margin), height 115 | im_to_paste = ImageOps.flip(img.crop((left, top, right, bottom))) 116 | img = img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, shift)) 117 | if direct_up: 118 | img.paste(im_to_paste, (0, 0)) 119 | else: 120 | img.paste(im_to_paste, (0, height - shift)) 121 | return img 122 | 123 | 124 | class myHorizontalTranslation(TR.RandomResizedCrop): 125 | def __init__(self, fraction=(0.05, 0.3)): 126 | self.fraction = fraction 127 | super(myHorizontalTranslation, self).__init__(size=256) 128 | 129 | def __call__(self, img): 130 | margin = torch.rand(1) * (self.fraction[1] - self.fraction[0]) + self.fraction[0] 131 | direct_left = (torch.rand(1) < 0.5) # up or down 132 | width, height = img.size 133 | top, bottom = 0, height 134 | shift = -int(width * margin) if direct_left else int(width * margin) 135 | if direct_left: 136 | left, right = 0, int(width * margin) 137 | else: 138 | left, right = width - int(width * margin), width 139 | im_to_paste = ImageOps.mirror(img.crop((left, top, right, bottom))) 140 | img = img.transform(img.size, Image.AFFINE, (1, 0, shift, 0, 1, 0)) 141 | if direct_left: 142 | img.paste(im_to_paste, (0, 0)) 143 | else: 144 | img.paste(im_to_paste, (width - shift, 0)) 145 | return img -------------------------------------------------------------------------------- /metrics/LPIPS/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | from skimage.metrics import 
structural_similarity 8 | import torch 9 | from torch.autograd import Variable 10 | 11 | from . import dist_model 12 | 13 | class PerceptualLoss(torch.nn.Module): 14 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0], version='0.1'): # VGG using our perceptually-learned weights (LPIPS metric) 15 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 16 | super(PerceptualLoss, self).__init__() 17 | print('Setting up Perceptual loss...') 18 | self.use_gpu = use_gpu 19 | self.spatial = spatial 20 | self.gpu_ids = gpu_ids 21 | self.model = dist_model.DistModel() 22 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids, version=version) 23 | print('...[%s] initialized'%self.model.name()) 24 | print('...Done') 25 | 26 | def forward(self, pred, target, normalize=False): 27 | """ 28 | Pred and target are Variables. 29 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 30 | If normalize is False, assumes the images are already between [-1,+1] 31 | 32 | Inputs pred and target are Nx3xHxW 33 | Output pytorch Variable N long 34 | """ 35 | 36 | if normalize: 37 | target = 2 * target - 1 38 | pred = 2 * pred - 1 39 | 40 | return self.model.forward(target, pred) 41 | 42 | def normalize_tensor(in_feat,eps=1e-10): 43 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 44 | return in_feat/(norm_factor+eps) 45 | 46 | def l2(p0, p1, range=255.): 47 | return .5*np.mean((p0 / range - p1 / range)**2) 48 | 49 | def psnr(p0, p1, peak=255.): 50 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 51 | 52 | def dssim(p0, p1, range=255.): 53 | return (1 - structural_similarity(p0, p1, data_range=range, multichannel=True)) / 2. 54 | 55 | def rgb2lab(in_img,mean_cent=False): 56 | from skimage import color 57 | img_lab = color.rgb2lab(in_img) 58 | if(mean_cent): 59 | img_lab[:,:,0] = img_lab[:,:,0]-50 60 | return img_lab 61 | 62 | def tensor2np(tensor_obj): 63 | # change dimension of a tensor object into a numpy array 64 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 65 | 66 | def np2tensor(np_obj): 67 | # change dimenion of np array into tensor array 68 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 69 | 70 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 71 | # image tensor to lab tensor 72 | from skimage import color 73 | 74 | img = tensor2im(image_tensor) 75 | img_lab = color.rgb2lab(img) 76 | if(mc_only): 77 | img_lab[:,:,0] = img_lab[:,:,0]-50 78 | if(to_norm and not mc_only): 79 | img_lab[:,:,0] = img_lab[:,:,0]-50 80 | img_lab = img_lab/100. 81 | 82 | return np2tensor(img_lab) 83 | 84 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 85 | from skimage import color 86 | import warnings 87 | warnings.filterwarnings("ignore") 88 | 89 | lab = tensor2np(lab_tensor)*100. 90 | lab[:,:,0] = lab[:,:,0]+50 91 | 92 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 93 | if(return_inbnd): 94 | # convert back to lab, see if we match 95 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 96 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 97 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 98 | return (im2tensor(rgb_back),mask) 99 | else: 100 | return im2tensor(rgb_back) 101 | 102 | def rgb2lab(input): 103 | from skimage import color 104 | return color.rgb2lab(input / 255.) 
105 | 106 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 107 | image_numpy = image_tensor[0].cpu().float().numpy() 108 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 109 | return image_numpy.astype(imtype) 110 | 111 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 112 | return torch.Tensor((image / factor - cent) 113 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 114 | 115 | def tensor2vec(vector_tensor): 116 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 117 | 118 | def voc_ap(rec, prec, use_07_metric=False): 119 | """ ap = voc_ap(rec, prec, [use_07_metric]) 120 | Compute VOC AP given precision and recall. 121 | If use_07_metric is true, uses the 122 | VOC 07 11 point method (default:False). 123 | """ 124 | if use_07_metric: 125 | # 11 point metric 126 | ap = 0. 127 | for t in np.arange(0., 1.1, 0.1): 128 | if np.sum(rec >= t) == 0: 129 | p = 0 130 | else: 131 | p = np.max(prec[rec >= t]) 132 | ap = ap + p / 11. 133 | else: 134 | # correct AP calculation 135 | # first append sentinel values at the end 136 | mrec = np.concatenate(([0.], rec, [1.])) 137 | mpre = np.concatenate(([0.], prec, [0.])) 138 | 139 | # compute the precision envelope 140 | for i in range(mpre.size - 1, 0, -1): 141 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 142 | 143 | # to calculate area under PR curve, look for points 144 | # where X axis (recall) changes value 145 | i = np.where(mrec[1:] != mrec[:-1])[0] 146 | 147 | # and sum (\Delta recall) * prec 148 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 149 | return ap 150 | 151 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 152 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 153 | image_numpy = image_tensor[0].cpu().float().numpy() 154 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 155 | return image_numpy.astype(imtype) 156 | 157 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 158 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 159 | return torch.Tensor((image / factor - cent) 160 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 161 | -------------------------------------------------------------------------------- /metrics/LPIPS/models/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models as tv 4 | from IPython import embed 5 | 6 | class squeezenet(torch.nn.Module): 7 | def __init__(self, requires_grad=False, pretrained=True): 8 | super(squeezenet, self).__init__() 9 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 10 | self.slice1 = torch.nn.Sequential() 11 | self.slice2 = torch.nn.Sequential() 12 | self.slice3 = torch.nn.Sequential() 13 | self.slice4 = torch.nn.Sequential() 14 | self.slice5 = torch.nn.Sequential() 15 | self.slice6 = torch.nn.Sequential() 16 | self.slice7 = torch.nn.Sequential() 17 | self.N_slices = 7 18 | for x in range(2): 19 | self.slice1.add_module(str(x), pretrained_features[x]) 20 | for x in range(2,5): 21 | self.slice2.add_module(str(x), pretrained_features[x]) 22 | for x in range(5, 8): 23 | self.slice3.add_module(str(x), pretrained_features[x]) 24 | for x in range(8, 10): 25 | self.slice4.add_module(str(x), pretrained_features[x]) 26 | for x in range(10, 11): 27 | self.slice5.add_module(str(x), pretrained_features[x]) 28 | for x in range(11, 12): 29 | self.slice6.add_module(str(x), 
pretrained_features[x]) 30 | for x in range(12, 13): 31 | self.slice7.add_module(str(x), pretrained_features[x]) 32 | if not requires_grad: 33 | for param in self.parameters(): 34 | param.requires_grad = False 35 | 36 | def forward(self, X): 37 | h = self.slice1(X) 38 | h_relu1 = h 39 | h = self.slice2(h) 40 | h_relu2 = h 41 | h = self.slice3(h) 42 | h_relu3 = h 43 | h = self.slice4(h) 44 | h_relu4 = h 45 | h = self.slice5(h) 46 | h_relu5 = h 47 | h = self.slice6(h) 48 | h_relu6 = h 49 | h = self.slice7(h) 50 | h_relu7 = h 51 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) 52 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) 53 | 54 | return out 55 | 56 | 57 | class alexnet(torch.nn.Module): 58 | def __init__(self, requires_grad=False, pretrained=True): 59 | super(alexnet, self).__init__() 60 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 61 | #torch.save(alexnet_pretrained_features, "alex.pth") 62 | #alexnet_pretrained_features = torch.load("SIFID/PerceptualSimilarity/models/alex.pth") 63 | 64 | self.slice1 = torch.nn.Sequential() 65 | self.slice2 = torch.nn.Sequential() 66 | self.slice3 = torch.nn.Sequential() 67 | self.slice4 = torch.nn.Sequential() 68 | self.slice5 = torch.nn.Sequential() 69 | self.N_slices = 5 70 | for x in range(2): 71 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 72 | for x in range(2, 5): 73 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 74 | for x in range(5, 8): 75 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 76 | for x in range(8, 10): 77 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 78 | for x in range(10, 12): 79 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 80 | if not requires_grad: 81 | for param in self.parameters(): 82 | param.requires_grad = False 83 | 84 | def forward(self, X): 85 | h = self.slice1(X) 86 | h_relu1 = h 87 | h = self.slice2(h) 88 | h_relu2 = h 89 | h = self.slice3(h) 90 | h_relu3 = h 91 | h = self.slice4(h) 92 | h_relu4 = h 93 | h = self.slice5(h) 94 | h_relu5 = h 95 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 96 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 97 | 98 | return out 99 | 100 | class vgg16(torch.nn.Module): 101 | def __init__(self, requires_grad=False, pretrained=True): 102 | super(vgg16, self).__init__() 103 | 104 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 105 | #vgg_pretrained_features = torch.load("SIFID/PerceptualSimilarity/models/vgg16.pth") 106 | 107 | self.slice1 = torch.nn.Sequential() 108 | self.slice2 = torch.nn.Sequential() 109 | self.slice3 = torch.nn.Sequential() 110 | self.slice4 = torch.nn.Sequential() 111 | self.slice5 = torch.nn.Sequential() 112 | self.N_slices = 5 113 | for x in range(4): 114 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 115 | for x in range(4, 9): 116 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 117 | for x in range(9, 16): 118 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 119 | for x in range(16, 23): 120 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 121 | for x in range(23, 30): 122 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 123 | if not requires_grad: 124 | for param in self.parameters(): 125 | param.requires_grad = False 126 | 127 | def forward(self, X): 128 | h = self.slice1(X) 129 | h_relu1_2 = h 130 | h = 
self.slice2(h) 131 | h_relu2_2 = h 132 | h = self.slice3(h) 133 | h_relu3_3 = h 134 | h = self.slice4(h) 135 | h_relu4_3 = h 136 | h = self.slice5(h) 137 | h_relu5_3 = h 138 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 139 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 140 | 141 | return out 142 | 143 | 144 | 145 | class resnet(torch.nn.Module): 146 | def __init__(self, requires_grad=False, pretrained=True, num=18): 147 | super(resnet, self).__init__() 148 | if(num==18): 149 | self.net = tv.resnet18(pretrained=pretrained) 150 | elif(num==34): 151 | self.net = tv.resnet34(pretrained=pretrained) 152 | elif(num==50): 153 | self.net = tv.resnet50(pretrained=pretrained) 154 | elif(num==101): 155 | self.net = tv.resnet101(pretrained=pretrained) 156 | elif(num==152): 157 | self.net = tv.resnet152(pretrained=pretrained) 158 | self.N_slices = 5 159 | 160 | self.conv1 = self.net.conv1 161 | self.bn1 = self.net.bn1 162 | self.relu = self.net.relu 163 | self.maxpool = self.net.maxpool 164 | self.layer1 = self.net.layer1 165 | self.layer2 = self.net.layer2 166 | self.layer3 = self.net.layer3 167 | self.layer4 = self.net.layer4 168 | 169 | def forward(self, X): 170 | h = self.conv1(X) 171 | h = self.bn1(h) 172 | h = self.relu(h) 173 | h_relu1 = h 174 | h = self.maxpool(h) 175 | h = self.layer1(h) 176 | h_conv2 = h 177 | h = self.layer2(h) 178 | h_conv3 = h 179 | h = self.layer3(h) 180 | h_conv4 = h 181 | h = self.layer4(h) 182 | h_conv5 = h 183 | 184 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) 185 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 186 | 187 | return out 188 | -------------------------------------------------------------------------------- /metrics/mIoU/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code implements a simple UNet to compute the mIoU metric evaluating the quality of segmentation masks. 3 | This code is adopted from usuyama/pytorch-unet. 
4 | """ 5 | 6 | 7 | import torch 8 | import numpy as np 9 | from torch.utils.data import Dataset, DataLoader 10 | from collections import defaultdict 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | from torch.optim import lr_scheduler 14 | from .loss import dice_loss 15 | from .dataset import SimDataset 16 | from .unet import ResNetUNet 17 | 18 | 19 | def calc_loss(pred, target, metrics, bce_weight=0.5): 20 | bce = F.binary_cross_entropy_with_logits(pred, target) 21 | 22 | pred = F.sigmoid(pred) 23 | dice = dice_loss(pred, target) 24 | 25 | loss = bce * bce_weight + dice * (1 - bce_weight) 26 | 27 | metrics['bce'] += bce.data.cpu().numpy() * target.size(0) 28 | metrics['dice'] += dice.data.cpu().numpy() * target.size(0) 29 | metrics['loss'] += loss.data.cpu().numpy() * target.size(0) 30 | return loss 31 | 32 | 33 | def print_metrics(metrics, epoch_samples, phase): 34 | outputs = [] 35 | for k in metrics.keys(): 36 | outputs.append("{}: {:4f}".format(k, metrics[k] / epoch_samples)) 37 | 38 | print("{}: {}".format(phase, ", ".join(outputs))) 39 | 40 | 41 | def iou_pytorch(outputs: torch.Tensor, labels: torch.Tensor): 42 | # You can comment out this line if you are passing tensors of equal shape 43 | # But if you are passing output from UNet or something it will most probably 44 | # be with the BATCH x 1 x H x W shape 45 | ###outputs = outputs.squeeze(1) # BATCH x 1 x H x W => BATCH x H x W 46 | SMOOTH = 1e-6 47 | outputs = torch.nn.functional.one_hot(outputs) 48 | labels = torch.nn.functional.one_hot(labels) 49 | intersection = (outputs & labels).float().sum((1, 2)) # Will be zero if Truth=0 or Prediction=0 50 | union = (outputs | labels).float().sum((1, 2)) # Will be zzero if both are 0 51 | iou = (intersection + SMOOTH) / (union + SMOOTH) # We smooth our devision to avoid 0/0 52 | thresholded = torch.clamp(20 * (iou - 0.5), 0, 10).ceil() / 10 # This is equal to comparing with thresolds 53 | #print(iou.shape, thresholded.shape) 54 | return thresholded 55 | 56 | 57 | def train_model(model, optimizer, scheduler, dataloader_train, num_epochs=500): 58 | model.train() 59 | 60 | for epoch, batch in enumerate(dataloader_train): 61 | if epoch >= num_epochs: 62 | break 63 | inputs, labels = batch 64 | scheduler.step() 65 | inputs = inputs.to("cuda") 66 | labels = labels.to("cuda") 67 | # zero the parameter gradients 68 | optimizer.zero_grad() 69 | metrics = defaultdict(float) 70 | # forward 71 | # track history if only in train 72 | outputs = model(inputs) 73 | loss = calc_loss(outputs, labels, metrics) 74 | 75 | # backward + optimize only if in training phase 76 | loss.backward() 77 | optimizer.step() 78 | return model 79 | 80 | # ------------------------------------------------------------------- 81 | 82 | 83 | def compute_miou(path_real_images, names_real_image, path_real_masks, names_real_masks, 84 | exp_folder, names_fake_image, names_fake_masks, im_res): 85 | train_set = SimDataset(path_real_images, names_real_image, path_real_masks, names_real_masks, im_res, real=True) 86 | num_ch = train_set.num_mask_channels 87 | val_set = SimDataset(exp_folder, names_fake_image, exp_folder, names_fake_masks, im_res, real=False, num_ch=num_ch) 88 | image_datasets = {'train': train_set, 'val': val_set} 89 | batch_size = 5 90 | dataloaders = { 91 | 'train': DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0), 92 | 'val': DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0) 93 | } 94 | 95 | device = torch.device('cuda' if 
torch.cuda.is_available() else 'cpu') 96 | model = ResNetUNet(n_class=num_ch) 97 | model = model.to(device) 98 | optimizer_ft = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4) 99 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=1000, gamma=1.0) 100 | 101 | model = train_model(model, optimizer_ft, exp_lr_scheduler, dataloaders["val"], 500) 102 | 103 | model.eval() # Set model to the evaluation mode 104 | all_corr, sum_corr, cur_iou, countt = 0, 0, 0, 0 105 | for i, batch in enumerate(dataloaders["train"]): 106 | if i > 100: 107 | break 108 | inputs, labels = batch 109 | inputs = inputs.to(device) 110 | labels = labels.to(device) 111 | # Predict 112 | pred = model(inputs) 113 | # The loss functions include the sigmoid function. 114 | pred = F.sigmoid(pred) 115 | pred = pred.data 116 | pred1 = torch.argmax(pred, dim=1) 117 | pred2 = torch.argmax(labels, dim=1) 118 | correct = ((pred1 == pred2)*1).sum() 119 | sum_corr += correct 120 | all_corr += torch.numel(pred1) 121 | cur_iou += iou_pytorch(pred1, pred2).mean() 122 | countt += 1 123 | metrics_tensor = np.array([-1.0, -1.0, -1.0, -1.0]) 124 | metrics_tensor[0] = sum_corr / all_corr 125 | metrics_tensor[1] = cur_iou / countt 126 | 127 | # HERE TRAINING on real 128 | 129 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 130 | model1 = ResNetUNet(n_class=num_ch) 131 | model1 = model1.to(device) 132 | model1.train() 133 | 134 | optimizer_ft = optim.Adam(filter(lambda p: p.requires_grad, model1.parameters()), lr=1e-4) 135 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=1000, gamma=1.0) 136 | 137 | model1 = train_model(model1, optimizer_ft, exp_lr_scheduler, dataloaders["train"], 500) 138 | model1.eval() # Set model to the evaluation mode 139 | all_corr, sum_corr, cur_iou, countt = 0, 0, 0, 0 140 | for i, batch in enumerate(dataloaders["val"]): 141 | if i > 100: 142 | break 143 | inputs, labels = batch 144 | inputs = inputs.to(device) 145 | labels = labels.to(device) 146 | 147 | # Predict 148 | pred = model1(inputs) 149 | # The loss functions include the sigmoid function. 150 | pred = F.sigmoid(pred) 151 | pred = pred.data 152 | 153 | pred1 = torch.argmax(pred, dim=1) 154 | pred2 = torch.argmax(labels, dim=1) 155 | correct = ((pred1 == pred2)*1).sum() 156 | sum_corr += correct 157 | all_corr += torch.numel(pred1) 158 | 159 | cur_iou += iou_pytorch(pred1, pred2).mean() 160 | countt += 1 161 | 162 | metrics_tensor[2] = sum_corr / all_corr 163 | metrics_tensor[3] = cur_iou / countt 164 | 165 | #### metric per image below: 166 | 167 | val_per_fr_set = SimDataset(exp_folder, names_fake_image, exp_folder, names_fake_masks, im_res, real=False, num_ch=num_ch, no_transform=True) 168 | dataloader_per_frame = DataLoader(val_per_fr_set, batch_size=1, shuffle=False, num_workers=0) 169 | 170 | results = dict() 171 | results_acc = dict() 172 | 173 | for i, batch in enumerate(dataloader_per_frame): 174 | if i >= len(val_per_fr_set.masks): 175 | break 176 | inputs, labels = batch 177 | inputs = inputs.to(device) 178 | labels = labels.to(device) 179 | 180 | # Predict 181 | pred = model1(inputs) 182 | # The loss functions include the sigmoid function. 
183 | pred = F.sigmoid(pred) 184 | pred = pred.data 185 | 186 | pred1 = torch.argmax(pred, dim=1) 187 | pred2 = torch.argmax(labels, dim=1) 188 | correct = ((pred1 == pred2)*1).sum() 189 | acc = correct / torch.numel(pred1) 190 | results_acc[val_per_fr_set.frames[i]] = acc.detach().cpu().numpy() 191 | cur_iou = iou_pytorch(pred1, pred2).mean() 192 | results[val_per_fr_set.frames[i]] = cur_iou.detach().cpu().numpy() 193 | 194 | return metrics_tensor, results, results_acc 195 | -------------------------------------------------------------------------------- /metrics/LPIPS/models/networks_basic.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | from torch.autograd import Variable 9 | import numpy as np 10 | from pdb import set_trace as st 11 | from skimage import color 12 | from IPython import embed 13 | from . import pretrained_networks as pn 14 | 15 | import sys 16 | sys.path.append("metrics/LPIPS/") 17 | import models as util 18 | 19 | def spatial_average(in_tens, keepdim=True): 20 | return in_tens.mean([2,3],keepdim=keepdim) 21 | 22 | def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W 23 | in_H, in_W = in_tens.shape[2], in_tens.shape[3] 24 | scale_factor_H, scale_factor_W = 1.*out_HW[0]/in_H, 1.*out_HW[1]/in_W 25 | 26 | return nn.Upsample(scale_factor=(scale_factor_H, scale_factor_W), mode='bilinear', align_corners=False)(in_tens) 27 | 28 | # Learned perceptual metric 29 | class PNetLin(nn.Module): 30 | def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True): 31 | super(PNetLin, self).__init__() 32 | 33 | self.pnet_type = pnet_type 34 | self.pnet_tune = pnet_tune 35 | self.pnet_rand = pnet_rand 36 | self.spatial = spatial 37 | self.lpips = lpips 38 | self.version = version 39 | self.scaling_layer = ScalingLayer() 40 | 41 | if(self.pnet_type in ['vgg','vgg16']): 42 | net_type = pn.vgg16 43 | self.chns = [64,128,256,512,512] 44 | elif(self.pnet_type=='alex'): 45 | net_type = pn.alexnet 46 | self.chns = [64,192,384,256,256] 47 | elif(self.pnet_type=='squeeze'): 48 | net_type = pn.squeezenet 49 | self.chns = [64,128,256,384,384,512,512] 50 | self.L = len(self.chns) 51 | 52 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) 53 | 54 | if(lpips): 55 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 56 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 57 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 58 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 59 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 60 | self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] 61 | if(self.pnet_type=='squeeze'): # 7 layers for squeezenet 62 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) 63 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) 64 | self.lins+=[self.lin5,self.lin6] 65 | 66 | def forward(self, in0, in1, retPerLayer=False): 67 | # v0.0 - original release had a bug, where input was not scaled 68 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) 69 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 70 | feats0, feats1, diffs = {}, {}, {} 71 | 72 | for kk in range(self.L): 73 | feats0[kk], 
feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk]) 74 | diffs[kk] = (feats0[kk]-feats1[kk])**2 75 | 76 | if(self.lpips): 77 | if(self.spatial): 78 | res = [upsample(self.lins[kk].model(diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)] 79 | else: 80 | res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)] 81 | else: 82 | if(self.spatial): 83 | res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)] 84 | else: 85 | res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)] 86 | 87 | val = res[0] 88 | for l in range(1,self.L): 89 | val += res[l] 90 | 91 | if(retPerLayer): 92 | return (val, res) 93 | else: 94 | return val 95 | 96 | class ScalingLayer(nn.Module): 97 | def __init__(self): 98 | super(ScalingLayer, self).__init__() 99 | self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None]) 100 | self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None]) 101 | 102 | def forward(self, inp): 103 | return (inp - self.shift) / self.scale 104 | 105 | 106 | class NetLinLayer(nn.Module): 107 | ''' A single linear layer which does a 1x1 conv ''' 108 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 109 | super(NetLinLayer, self).__init__() 110 | 111 | layers = [nn.Dropout(),] if(use_dropout) else [] 112 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] 113 | self.model = nn.Sequential(*layers) 114 | 115 | 116 | class Dist2LogitLayer(nn.Module): 117 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' 118 | def __init__(self, chn_mid=32, use_sigmoid=True): 119 | super(Dist2LogitLayer, self).__init__() 120 | 121 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),] 122 | layers += [nn.LeakyReLU(0.2,True),] 123 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),] 124 | layers += [nn.LeakyReLU(0.2,True),] 125 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),] 126 | if(use_sigmoid): 127 | layers += [nn.Sigmoid(),] 128 | self.model = nn.Sequential(*layers) 129 | 130 | def forward(self,d0,d1,eps=0.1): 131 | return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1)) 132 | 133 | class BCERankingLoss(nn.Module): 134 | def __init__(self, chn_mid=32): 135 | super(BCERankingLoss, self).__init__() 136 | self.net = Dist2LogitLayer(chn_mid=chn_mid) 137 | # self.parameters = list(self.net.parameters()) 138 | self.loss = torch.nn.BCELoss() 139 | 140 | def forward(self, d0, d1, judge): 141 | per = (judge+1.)/2. 
142 | self.logit = self.net.forward(d0,d1) 143 | return self.loss(self.logit, per) 144 | 145 | # L2, DSSIM metrics 146 | class FakeNet(nn.Module): 147 | def __init__(self, use_gpu=True, colorspace='Lab'): 148 | super(FakeNet, self).__init__() 149 | self.use_gpu = use_gpu 150 | self.colorspace=colorspace 151 | 152 | class L2(FakeNet): 153 | 154 | def forward(self, in0, in1, retPerLayer=None): 155 | assert(in0.size()[0]==1) # currently only supports batchSize 1 156 | 157 | if(self.colorspace=='RGB'): 158 | (N,C,X,Y) = in0.size() 159 | value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N) 160 | return value 161 | elif(self.colorspace=='Lab'): 162 | value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), 163 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 164 | ret_var = Variable( torch.Tensor((value,) ) ) 165 | if(self.use_gpu): 166 | ret_var = ret_var.cuda() 167 | return ret_var 168 | 169 | class DSSIM(FakeNet): 170 | 171 | def forward(self, in0, in1, retPerLayer=None): 172 | assert(in0.size()[0]==1) # currently only supports batchSize 1 173 | 174 | if(self.colorspace=='RGB'): 175 | value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float') 176 | elif(self.colorspace=='Lab'): 177 | value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), 178 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 179 | ret_var = Variable( torch.Tensor((value,) ) ) 180 | if(self.use_gpu): 181 | ret_var = ret_var.cuda() 182 | return ret_var 183 | 184 | def print_network(net): 185 | num_params = 0 186 | for param in net.parameters(): 187 | num_params += param.numel() 188 | print('Network',net) 189 | print('Total number of parameters: %d' % num_params) 190 | -------------------------------------------------------------------------------- /data/create_mvtec_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import random 5 | 6 | def copy_files(src_folder, dest_folder, mask_folder=None, mask_dest_folder=None, file_list=None): 7 | ''' 8 | Copy all files from src_folder to dest_folder. 
9 | If add_folder is provided, its content is copied into dest_folder 10 | ''' 11 | src_files = os.listdir(src_folder) 12 | for file_name in src_files: 13 | full_file_name = os.path.join(src_folder, file_name) 14 | if os.path.isfile(full_file_name) and (file_list is None or file_name in file_list): 15 | # Rename if file exists 16 | num_file = int(os.path.splitext(os.path.basename(full_file_name))[0]) 17 | ext_file = os.path.splitext(os.path.basename(full_file_name))[-1] 18 | dest_folder_file_renamed = os.path.join(dest_folder, f"{num_file:03}" + ext_file) 19 | while os.path.exists(dest_folder_file_renamed): 20 | num_file += 1 21 | dest_folder_file_renamed = os.path.join(dest_folder, f"{num_file:03}" + ext_file) 22 | shutil.copy(full_file_name, dest_folder_file_renamed) 23 | if mask_folder is not None: 24 | end_name = '_mask' + ext_file 25 | MASK_FILE_PATH = os.path.join(mask_folder, full_file_name.split('/')[-1].split('.')[0] + end_name) 26 | MASK_DEST_FOLDER = os.path.join(mask_dest_folder, MASK_FILE_PATH.split('/')[-1]) 27 | shutil.copy(MASK_FILE_PATH, MASK_DEST_FOLDER) 28 | 29 | class MVTECDataset: 30 | # Class to create the MVTEC dataset using the same format as cycle-gan dataset 31 | 32 | def __init__(self, PATH): 33 | # Extract products and defects 34 | 35 | # Dataset Path 36 | self.path = PATH 37 | 38 | # Dictionary containing for each product its corresponding defects 39 | self.products_defects_dict = self.extract_products_defects(self.path) 40 | 41 | self.selected_products = [] 42 | self.selected_defects = [] 43 | 44 | def extract_products_defects(self, mvtec_path): 45 | ''' 46 | Returns a dictionary where keys are the products found in mvtec_path and 47 | values are the corresponding defects 48 | ''' 49 | prods_defs_dict = dict() 50 | 51 | for product in os.listdir(mvtec_path): 52 | PRODUCT_PATH = os.path.join(mvtec_path, product) 53 | if os.path.isdir(PRODUCT_PATH): 54 | # Defects of product "product" 55 | defects_list = [defect for defect in os.listdir(os.path.join(PRODUCT_PATH, "test")) if defect != "good"] 56 | # Add product and defect to dict 57 | prods_defs_dict[product] = defects_list 58 | 59 | return prods_defs_dict 60 | 61 | def check_inputs(self, product, defect): 62 | ''' 63 | Check if input product name or defect name is correct 64 | Output: 65 | True if correct 66 | False if either product or defect name is not correct 67 | 68 | This function also creates selected_products and selected_defects lists. 
69 | ''' 70 | # Boolean variables that check if product and defect names are correct 71 | _product = False 72 | _defect = False 73 | 74 | prod_def_dict = self.products_defects_dict 75 | 76 | #TODO: implement specific defect for each product 77 | if product in prod_def_dict.keys(): 78 | _product = True 79 | self.selected_products.append(product) 80 | if (defect in prod_def_dict[product]): 81 | _defect = True 82 | self.selected_defects.append(defect) 83 | elif defect == 'one': 84 | _defect = True 85 | self.selected_defects.append(random.choice(prod_def_dict[product])) 86 | elif defect == 'all': 87 | _defect = True 88 | self.selected_defects = prod_def_dict[product] 89 | # elif product == 'all': 90 | # _product = True 91 | # self.selected_products = prod_def_dict.keys() 92 | # if defect == 'one': 93 | # _defect = True 94 | # for p in self.selected_products: 95 | # self.selected_defects.append(random.choice(prod_def_dict[p])) 96 | # elif defect == 'all': 97 | # _defect = True 98 | 99 | return (_product and _defect) 100 | 101 | 102 | def save_dataset(self, product, defect): 103 | ''' 104 | Save dataset in the following format: 105 | 106 | DATA 107 | |------PRODUCT_NAME 108 | |-------------TRAIN 109 | | |-----A 110 | | |-----B 111 | | |-----mask 112 | 113 | |-------------TEST 114 | | |-----A 115 | | |-----B 116 | | |-----mask 117 | 118 | ''' 119 | 120 | # Check if product name and defect name are correct 121 | # Note: 'all' means all products (or all defects) 122 | if not self.check_inputs(product, defect): 123 | raise ValueError('Product or defect name not recognized. Please, check input values.') 124 | 125 | print(f"Selected product: {self.selected_products}") 126 | print(f"Selected defect: {self.selected_defects}") 127 | 128 | # Define dataset path './data/mvtec_dataset' 129 | DATASET_PATH = os.path.join(os.path.dirname(self.path), 'mvtec_dataset') 130 | 131 | # Create train and test folders 132 | TRAIN_A_PATH = os.path.join(DATASET_PATH, "train/A") 133 | TRAIN_B_PATH = os.path.join(DATASET_PATH, "train/B") 134 | TRAIN_MASK_PATH = os.path.join(DATASET_PATH, "train/mask") 135 | TEST_A_PATH = os.path.join(DATASET_PATH, "test/A") 136 | TEST_B_PATH = os.path.join(DATASET_PATH, "test/B") 137 | TEST_MASK_PATH = os.path.join(DATASET_PATH, "test/mask") 138 | 139 | try: 140 | os.makedirs(DATASET_PATH) 141 | os.makedirs(TRAIN_A_PATH) 142 | os.makedirs(TRAIN_B_PATH) 143 | os.makedirs(TRAIN_MASK_PATH) 144 | os.makedirs(TEST_A_PATH) 145 | os.makedirs(TEST_B_PATH) 146 | os.makedirs(TEST_MASK_PATH) 147 | except OSError: 148 | pass 149 | 150 | for product in self.selected_products: 151 | # DOMAIN A IMAGES 152 | TRAIN_A_IMAGES_PATH = os.path.join(self.path, product, "train/good") 153 | TEST_A_IMAGES_PATH = os.path.join(self.path, product, "test/good") 154 | copy_files(src_folder=TRAIN_A_IMAGES_PATH, dest_folder=TRAIN_A_PATH) 155 | copy_files(src_folder=TEST_A_IMAGES_PATH, dest_folder=TEST_A_PATH) 156 | 157 | for defect in self.selected_defects: 158 | 159 | # DOMAIN B IMAGES 160 | DOMAIN_B_PATH = os.path.join(self.path, product, "test", defect) 161 | DOMAIN_B_PATH_MASK = os.path.join(self.path, product, "ground_truth", defect) 162 | images_b_list = os.listdir(DOMAIN_B_PATH) 163 | images_b_list.sort() 164 | # Select 80/20 for train/test 165 | 166 | train_b_list = images_b_list[:round(len(images_b_list) * 0.8)] 167 | test_b_list = images_b_list[round(len(images_b_list) * 0.8):] 168 | 169 | # Copy files 170 | copy_files(src_folder=DOMAIN_B_PATH, dest_folder=TRAIN_B_PATH, mask_folder=DOMAIN_B_PATH_MASK, 
mask_dest_folder=TRAIN_MASK_PATH, file_list=train_b_list) 171 | copy_files(src_folder=DOMAIN_B_PATH, dest_folder=TEST_B_PATH, mask_folder=DOMAIN_B_PATH_MASK, mask_dest_folder=TEST_MASK_PATH, file_list=test_b_list) 172 | 173 | if __name__ == "__main__": 174 | ''' 175 | Create MVTEC dataset using the same format as cycle-gan code. 176 | Four different cases are considered (P stands for product while D for defect): 177 | 178 | 1. 1 P - 1 D -> one product and one defect 179 | 1. 1 P - N D -> one product and N defects (N -> all) 180 | --------------------------------------------- IMPLEMENTED THIS FOR NOW 181 | 182 | 1. N P - 1 D -> N products and one defect per product (the defect is randomly chosen) TODO: select specific defect 183 | 1. N P - N D -> N products and N defects 184 | 185 | ''' 186 | parser = argparse.ArgumentParser( 187 | description="MVTEC dataset creator. You can optionally specify product and defect.") 188 | parser.add_argument("--product", type=str, default="transistor", help="Product. (default:`transistor`). If 'all', all products are selected.") 189 | 190 | parser.add_argument("--defect", type=str, default="all", help="Defect. (default:`all`). Values [one|all|defect_name]. If 'all', all defects are used and combined. If 'one' a defect is randomly selected. If 'defect_name', the selected defect is used.") 191 | args = parser.parse_args() 192 | 193 | CURRENT_PATH = os.path.dirname(os.path.realpath(__file__)) 194 | MVTEC_PATH = os.path.join(CURRENT_PATH, 'mvtec') 195 | 196 | mvtect_dataset = MVTECDataset(MVTEC_PATH) 197 | mvtect_dataset.save_dataset(args.product, args.defect) 198 | -------------------------------------------------------------------------------- /metrics/LPIPS/util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import time 4 | from . import util 5 | from . import html 6 | # from pdb import set_trace as st 7 | import matplotlib.pyplot as plt 8 | import math 9 | # from IPython import embed 10 | 11 | def zoom_to_res(img,res=256,order=0,axis=0): 12 | # img 3xXxX 13 | from scipy.ndimage import zoom 14 | zoom_factor = res/img.shape[1] 15 | if(axis==0): 16 | return zoom(img,[1,zoom_factor,zoom_factor],order=order) 17 | elif(axis==2): 18 | return zoom(img,[zoom_factor,zoom_factor,1],order=order) 19 | 20 | class Visualizer(): 21 | def __init__(self, opt): 22 | # self.opt = opt 23 | self.display_id = opt.display_id 24 | # self.use_html = opt.is_train and not opt.no_html 25 | self.win_size = opt.display_winsize 26 | self.name = opt.name 27 | self.display_cnt = 0 # display_current_results counter 28 | self.display_cnt_high = 0 29 | self.use_html = opt.use_html 30 | 31 | if self.display_id > 0: 32 | import visdom 33 | self.vis = visdom.Visdom(port = opt.display_port) 34 | 35 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 36 | util.mkdirs([self.web_dir,]) 37 | if self.use_html: 38 | self.img_dir = os.path.join(self.web_dir, 'images') 39 | print('create web directory %s...' 
% self.web_dir) 40 | util.mkdirs([self.img_dir,]) 41 | 42 | # |visuals|: dictionary of images to display or save 43 | def display_current_results(self, visuals, epoch, nrows=None, res=256): 44 | if self.display_id > 0: # show images in the browser 45 | title = self.name 46 | if(nrows is None): 47 | nrows = int(math.ceil(len(visuals.items()) / 2.0)) 48 | images = [] 49 | idx = 0 50 | for label, image_numpy in visuals.items(): 51 | title += " | " if idx % nrows == 0 else ", " 52 | title += label 53 | img = image_numpy.transpose([2, 0, 1]) 54 | img = zoom_to_res(img,res=res,order=0) 55 | images.append(img) 56 | idx += 1 57 | if len(visuals.items()) % 2 != 0: 58 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255 59 | white_image = zoom_to_res(white_image,res=res,order=0) 60 | images.append(white_image) 61 | self.vis.images(images, nrow=nrows, win=self.display_id + 1, 62 | opts=dict(title=title)) 63 | 64 | if self.use_html: # save images to a html file 65 | for label, image_numpy in visuals.items(): 66 | img_path = os.path.join(self.img_dir, 'epoch%.3d_cnt%.6d_%s.png' % (epoch, self.display_cnt, label)) 67 | util.save_image(zoom_to_res(image_numpy, res=res, axis=2), img_path) 68 | 69 | self.display_cnt += 1 70 | self.display_cnt_high = np.maximum(self.display_cnt_high, self.display_cnt) 71 | 72 | # update website 73 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) 74 | for n in range(epoch, 0, -1): 75 | webpage.add_header('epoch [%d]' % n) 76 | if(n==epoch): 77 | high = self.display_cnt 78 | else: 79 | high = self.display_cnt_high 80 | for c in range(high-1,-1,-1): 81 | ims = [] 82 | txts = [] 83 | links = [] 84 | 85 | for label, image_numpy in visuals.items(): 86 | img_path = 'epoch%.3d_cnt%.6d_%s.png' % (n, c, label) 87 | ims.append(os.path.join('images',img_path)) 88 | txts.append(label) 89 | links.append(os.path.join('images',img_path)) 90 | webpage.add_images(ims, txts, links, width=self.win_size) 91 | webpage.save() 92 | 93 | # save errors into a directory 94 | def plot_current_errors_save(self, epoch, counter_ratio, opt, errors,keys='+ALL',name='loss', to_plot=False): 95 | if not hasattr(self, 'plot_data'): 96 | self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())} 97 | self.plot_data['X'].append(epoch + counter_ratio) 98 | self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) 99 | 100 | # embed() 101 | if(keys=='+ALL'): 102 | plot_keys = self.plot_data['legend'] 103 | else: 104 | plot_keys = keys 105 | 106 | if(to_plot): 107 | (f,ax) = plt.subplots(1,1) 108 | for (k,kname) in enumerate(plot_keys): 109 | kk = np.where(np.array(self.plot_data['legend'])==kname)[0][0] 110 | x = self.plot_data['X'] 111 | y = np.array(self.plot_data['Y'])[:,kk] 112 | if(to_plot): 113 | ax.plot(x, y, 'o-', label=kname) 114 | np.save(os.path.join(self.web_dir,'%s_x')%kname,x) 115 | np.save(os.path.join(self.web_dir,'%s_y')%kname,y) 116 | 117 | if(to_plot): 118 | plt.legend(loc=0,fontsize='small') 119 | plt.xlabel('epoch') 120 | plt.ylabel('Value') 121 | f.savefig(os.path.join(self.web_dir,'%s.png'%name)) 122 | f.clf() 123 | plt.close() 124 | 125 | # errors: dictionary of error labels and values 126 | def plot_current_errors(self, epoch, counter_ratio, opt, errors): 127 | if not hasattr(self, 'plot_data'): 128 | self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())} 129 | self.plot_data['X'].append(epoch + counter_ratio) 130 | self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) 131 | self.vis.line( 132 | 
X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1), 133 | Y=np.array(self.plot_data['Y']), 134 | opts={ 135 | 'title': self.name + ' loss over time', 136 | 'legend': self.plot_data['legend'], 137 | 'xlabel': 'epoch', 138 | 'ylabel': 'loss'}, 139 | win=self.display_id) 140 | 141 | # errors: same format as |errors| of plotCurrentErrors 142 | def print_current_errors(self, epoch, i, errors, t, t2=-1, t2o=-1, fid=None): 143 | message = '(ep: %d, it: %d, t: %.3f[s], ept: %.2f/%.2f[h]) ' % (epoch, i, t, t2o, t2) 144 | message += (', ').join(['%s: %.3f' % (k, v) for k, v in errors.items()]) 145 | 146 | print(message) 147 | if(fid is not None): 148 | fid.write('%s\n'%message) 149 | 150 | 151 | # save image to the disk 152 | def save_images_simple(self, webpage, images, names, in_txts, prefix='', res=256): 153 | image_dir = webpage.get_image_dir() 154 | ims = [] 155 | txts = [] 156 | links = [] 157 | 158 | for name, image_numpy, txt in zip(names, images, in_txts): 159 | image_name = '%s_%s.png' % (prefix, name) 160 | save_path = os.path.join(image_dir, image_name) 161 | if(res is not None): 162 | util.save_image(zoom_to_res(image_numpy,res=res,axis=2), save_path) 163 | else: 164 | util.save_image(image_numpy, save_path) 165 | 166 | ims.append(os.path.join(webpage.img_subdir,image_name)) 167 | # txts.append(name) 168 | txts.append(txt) 169 | links.append(os.path.join(webpage.img_subdir,image_name)) 170 | # embed() 171 | webpage.add_images(ims, txts, links, width=self.win_size) 172 | 173 | # save image to the disk 174 | def save_images(self, webpage, images, names, image_path, title=''): 175 | image_dir = webpage.get_image_dir() 176 | # short_path = ntpath.basename(image_path) 177 | # name = os.path.splitext(short_path)[0] 178 | # name = short_path 179 | # webpage.add_header('%s, %s' % (name, title)) 180 | ims = [] 181 | txts = [] 182 | links = [] 183 | 184 | for label, image_numpy in zip(names, images): 185 | image_name = '%s.jpg' % (label,) 186 | save_path = os.path.join(image_dir, image_name) 187 | util.save_image(image_numpy, save_path) 188 | 189 | ims.append(image_name) 190 | txts.append(label) 191 | links.append(image_name) 192 | webpage.add_images(ims, txts, links, width=self.win_size) 193 | 194 | # save image to the disk 195 | # def save_images(self, webpage, visuals, image_path, short=False): 196 | # image_dir = webpage.get_image_dir() 197 | # if short: 198 | # short_path = ntpath.basename(image_path) 199 | # name = os.path.splitext(short_path)[0] 200 | # else: 201 | # name = image_path 202 | 203 | # webpage.add_header(name) 204 | # ims = [] 205 | # txts = [] 206 | # links = [] 207 | 208 | # for label, image_numpy in visuals.items(): 209 | # image_name = '%s_%s.png' % (name, label) 210 | # save_path = os.path.join(image_dir, image_name) 211 | # util.save_image(image_numpy, save_path) 212 | 213 | # ims.append(image_name) 214 | # txts.append(label) 215 | # links.append(image_name) 216 | # webpage.add_images(ims, txts, links, width=self.win_size) 217 | -------------------------------------------------------------------------------- /metrics/FID/fid_score.py: -------------------------------------------------------------------------------- 1 | """Calculates the Frechet Inception Distance (FID) to evalulate GANs 2 | 3 | The FID metric calculates the distance between two distributions of images. 4 | Typically, we have summary statistics (mean & covariance matrix) of one 5 | of these distributions, while the 2nd distribution is given by a GAN. 
6 | 7 | When run as a stand-alone program, it compares the distribution of 8 | images that are stored as PNG/JPEG at a specified location with a 9 | distribution given by summary statistics (in pickle format). 10 | 11 | The FID is calculated by assuming that X_1 and X_2 are the activations of 12 | the pool_3 layer of the inception net for generated samples and real world 13 | samples respectively. 14 | 15 | See --help to see further details. 16 | 17 | Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 18 | of Tensorflow 19 | 20 | Copyright 2018 Institute of Bioinformatics, JKU Linz 21 | 22 | Licensed under the Apache License, Version 2.0 (the "License"); 23 | you may not use this file except in compliance with the License. 24 | You may obtain a copy of the License at 25 | 26 | http://www.apache.org/licenses/LICENSE-2.0 27 | 28 | Unless required by applicable law or agreed to in writing, software 29 | distributed under the License is distributed on an "AS IS" BASIS, 30 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 31 | See the License for the specific language governing permissions and 32 | limitations under the License. 33 | """ 34 | import os 35 | import pathlib 36 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 37 | 38 | import numpy as np 39 | import torch 40 | import torchvision.transforms as TF 41 | from PIL import Image 42 | from scipy import linalg 43 | from torch.nn.functional import adaptive_avg_pool2d 44 | 45 | try: 46 | from tqdm import tqdm 47 | except ImportError: 48 | # If tqdm is not available, provide a mock version of it 49 | def tqdm(x): 50 | return x 51 | 52 | from .inception import InceptionV3 53 | 54 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 55 | 'tif', 'tiff', 'webp'} 56 | 57 | class ImagePathDataset(torch.utils.data.Dataset): 58 | def __init__(self, files, transforms=None): 59 | self.files = files 60 | self.transforms = transforms 61 | 62 | def __len__(self): 63 | return len(self.files) 64 | 65 | def __getitem__(self, i): 66 | path = self.files[i] 67 | img = Image.open(path).convert('RGB') 68 | if self.transforms is not None: 69 | img = self.transforms(img) 70 | return img 71 | 72 | 73 | def get_activations(files, model, batch_size=50, dims=2048, device='cpu', 74 | num_workers=1): 75 | """Calculates the activations of the pool_3 layer for all images. 76 | 77 | Params: 78 | -- files : List of image files paths 79 | -- model : Instance of inception model 80 | -- batch_size : Batch size of images for the model to process at once. 81 | Make sure that the number of samples is a multiple of 82 | the batch size, otherwise some samples are ignored. This 83 | behavior is retained to match the original FID score 84 | implementation. 85 | -- dims : Dimensionality of features returned by Inception 86 | -- device : Device to run calculations 87 | -- num_workers : Number of parallel dataloader workers 88 | 89 | Returns: 90 | -- A numpy array of dimension (num images, dims) that contains the 91 | activations of the given tensor when feeding inception with the 92 | query tensor. 93 | """ 94 | model.eval() 95 | 96 | if batch_size > len(files): 97 | print(('Warning: batch size is bigger than the data size. 
' 98 | 'Setting batch size to data size')) 99 | batch_size = len(files) 100 | 101 | dataset = ImagePathDataset(files, transforms=TF.ToTensor()) 102 | dataloader = torch.utils.data.DataLoader(dataset, 103 | batch_size=batch_size, 104 | shuffle=False, 105 | drop_last=False, 106 | num_workers=num_workers) 107 | 108 | pred_arr = np.empty((len(files), dims)) 109 | 110 | start_idx = 0 111 | 112 | for batch in tqdm(dataloader): 113 | batch = batch.to(device) 114 | 115 | with torch.no_grad(): 116 | pred = model(batch)[0] 117 | 118 | # If model output is not scalar, apply global spatial average pooling. 119 | # This happens if you choose a dimensionality not equal 2048. 120 | if pred.size(2) != 1 or pred.size(3) != 1: 121 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 122 | 123 | pred = pred.squeeze(3).squeeze(2).cpu().numpy() 124 | 125 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred 126 | 127 | start_idx = start_idx + pred.shape[0] 128 | 129 | return pred_arr 130 | 131 | 132 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 133 | """Numpy implementation of the Frechet Distance. 134 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 135 | and X_2 ~ N(mu_2, C_2) is 136 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 137 | 138 | Stable version by Dougal J. Sutherland. 139 | 140 | Params: 141 | -- mu1 : Numpy array containing the activations of a layer of the 142 | inception net (like returned by the function 'get_predictions') 143 | for generated samples. 144 | -- mu2 : The sample mean over activations, precalculated on an 145 | representative data set. 146 | -- sigma1: The covariance matrix over activations for generated samples. 147 | -- sigma2: The covariance matrix over activations, precalculated on an 148 | representative data set. 149 | 150 | Returns: 151 | -- : The Frechet Distance. 152 | """ 153 | 154 | mu1 = np.atleast_1d(mu1) 155 | mu2 = np.atleast_1d(mu2) 156 | 157 | sigma1 = np.atleast_2d(sigma1) 158 | sigma2 = np.atleast_2d(sigma2) 159 | 160 | assert mu1.shape == mu2.shape, \ 161 | 'Training and test mean vectors have different lengths' 162 | assert sigma1.shape == sigma2.shape, \ 163 | 'Training and test covariances have different dimensions' 164 | 165 | diff = mu1 - mu2 166 | 167 | # Product might be almost singular 168 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 169 | if not np.isfinite(covmean).all(): 170 | msg = ('fid calculation produces singular product; ' 171 | 'adding %s to diagonal of cov estimates') % eps 172 | print(msg) 173 | offset = np.eye(sigma1.shape[0]) * eps 174 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 175 | 176 | # Numerical error might give slight imaginary component 177 | if np.iscomplexobj(covmean): 178 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 179 | m = np.max(np.abs(covmean.imag)) 180 | raise ValueError('Imaginary component {}'.format(m)) 181 | covmean = covmean.real 182 | 183 | tr_covmean = np.trace(covmean) 184 | 185 | return (diff.dot(diff) + np.trace(sigma1) 186 | + np.trace(sigma2) - 2 * tr_covmean) 187 | 188 | 189 | def calculate_activation_statistics(files, model, batch_size=50, dims=2048, 190 | device='cpu', num_workers=1): 191 | """Calculation of the statistics used by the FID. 192 | Params: 193 | -- files : List of image files paths 194 | -- model : Instance of inception model 195 | -- batch_size : The images numpy array is split into batches with 196 | batch size batch_size. 
A reasonable batch size 197 | depends on the hardware. 198 | -- dims : Dimensionality of features returned by Inception 199 | -- device : Device to run calculations 200 | -- num_workers : Number of parallel dataloader workers 201 | 202 | Returns: 203 | -- mu : The mean over samples of the activations of the pool_3 layer of 204 | the inception model. 205 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 206 | the inception model. 207 | """ 208 | act = get_activations(files, model, batch_size, dims, device, num_workers) 209 | mu = np.mean(act, axis=0) 210 | sigma = np.cov(act, rowvar=False) 211 | return mu, sigma 212 | 213 | 214 | def compute_statistics_of_path(path, model, batch_size, dims, device, 215 | num_workers=1): 216 | # if path.endswith('.npz'): 217 | # with np.load(path) as f: 218 | # m, s = f['mu'][:], f['sigma'][:] 219 | # else: 220 | # path = pathlib.Path(path) 221 | # files = sorted([file for ext in IMAGE_EXTENSIONS 222 | # for file in path.glob('*.{}'.format(ext))]) 223 | m, s = calculate_activation_statistics(path, model, batch_size, 224 | dims, device, num_workers) 225 | 226 | return m, s 227 | 228 | 229 | def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1): 230 | """Calculates the FID of two paths""" 231 | for p1 in paths: 232 | for p in p1: 233 | if not os.path.exists(p): 234 | raise RuntimeError('Invalid path: %s' % p) 235 | 236 | if dims not in list(InceptionV3.BLOCK_INDEX_BY_DIM): 237 | raise Exception("Dims not accepted. Please digit another value.") 238 | 239 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 240 | 241 | model = InceptionV3([block_idx]).to(device) 242 | 243 | m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, 244 | dims, device, num_workers) 245 | m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, 246 | dims, device, num_workers) 247 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 248 | 249 | return fid_value 250 | 251 | 252 | -------------------------------------------------------------------------------- /metrics/LPIPS/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Perceptual Similarity Metric and Dataset [[Project Page]](http://richzhang.github.io/PerceptualSimilarity/) 3 | 4 | **The Unreasonable Effectiveness of Deep Features as a Perceptual Metric** 5 | [Richard Zhang](https://richzhang.github.io/), [Phillip Isola](http://web.mit.edu/phillipi/), [Alexei A. Efros](http://www.eecs.berkeley.edu/~efros/), [Eli Shechtman](https://research.adobe.com/person/eli-shechtman/), [Oliver Wang](http://www.oliverwang.info/). 6 |
In [CVPR](https://arxiv.org/abs/1801.03924), 2018. 7 | 8 | 9 | 10 | This repository contains our **perceptual metric (LPIPS)** and **dataset (BAPPS)**. It can also be used as a "perceptual loss". This uses PyTorch; a Tensorflow alternative is [here](https://github.com/alexlee-gk/lpips-tensorflow). 11 | 12 | **Table of Contents**
13 | 1. [Learned Perceptual Image Patch Similarity (LPIPS) metric](#1-learned-perceptual-image-patch-similarity-lpips-metric)
14 | a. [Basic Usage](#a-basic-usage) If you just want to run the metric through command line, this is all you need.
15 | b. ["Perceptual Loss" usage](#b-backpropping-through-the-metric)
16 | c. [About the metric](#c-about-the-metric)
17 | 2. [Berkeley-Adobe Perceptual Patch Similarity (BAPPS) dataset](#2-berkeley-adobe-perceptual-patch-similarity-bapps-dataset)
18 | a. [Download](#a-downloading-the-dataset)
19 | b. [Evaluation](#b-evaluating-a-perceptual-similarity-metric-on-a-dataset)
20 | c. [About the dataset](#c-about-the-dataset)
21 | d. [Train the metric using the dataset](#d-using-the-dataset-to-train-the-metric)
22 | 23 | ## (0) Dependencies/Setup 24 | 25 | ### Installation 26 | - Install PyTorch 1.0+ and torchvision from http://pytorch.org 27 | 28 | ```bash 29 | pip install -r requirements.txt 30 | ``` 31 | - Clone this repo: 32 | ```bash 33 | git clone https://github.com/richzhang/PerceptualSimilarity 34 | cd PerceptualSimilarity 35 | ``` 36 | 37 | ## (1) Learned Perceptual Image Patch Similarity (LPIPS) metric 38 | 39 | Evaluate the distance between image patches. **Higher means further/more different. Lower means more similar.** 40 | 41 | ### (A) Basic Usage 42 | 43 | #### (A.I) Line commands 44 | 45 | Example scripts to take the distance between 2 specific images, all corresponding pairs of images in 2 directories, or all pairs of images within a directory: 46 | 47 | ``` 48 | python compute_dists.py -p0 imgs/ex_ref.png -p1 imgs/ex_p0.png --use_gpu 49 | python compute_dists_dirs.py -d0 imgs/ex_dir0 -d1 imgs/ex_dir1 -o imgs/example_dists.txt --use_gpu 50 | python compute_dists_pair.py -d imgs/ex_dir_pair -o imgs/example_dists_pair.txt --use_gpu 51 | ``` 52 | 53 | #### (A.II) Python code 54 | 55 | File [test_network.py](test_network.py) shows example usage. This snippet is all you really need. 56 | 57 | ```python 58 | import models 59 | model = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=use_gpu, gpu_ids=[0]) 60 | d = model.forward(im0,im1) 61 | ``` 62 | 63 | Variables ```im0, im1``` are PyTorch Tensors/Variables with shape ```Nx3xHxW``` (```N``` patches of size ```HxW```, RGB images scaled in `[-1,+1]`). This returns `d`, a length `N` Tensor/Variable. 64 | 65 | Run `python test_network.py` to take the distance between the example reference image [`ex_ref.png`](imgs/ex_ref.png) and the distorted images [`ex_p0.png`](./imgs/ex_p0.png) and [`ex_p1.png`](imgs/ex_p1.png). Before running it - which do you think *should* be closer? 66 | 67 | **Some Options** By default in `model.initialize`: 68 | - `net='alex'`: Network `alex` is fastest, performs the best, and is the default. You can instead use `squeeze` or `vgg`. 69 | - `model='net-lin'`: This adds a linear calibration on top of intermediate features in the net. Set this to `model='net'` to equally weight all the features. 70 | 71 | ### (B) Backpropping through the metric 72 | 73 | File [`perceptual_loss.py`](perceptual_loss.py) shows how to iteratively optimize using the metric. Run `python perceptual_loss.py` for a demo. The code can also be used to implement vanilla VGG loss, without our learned weights. 74 | 75 | ### (C) About the metric 76 | 77 | **Higher means further/more different. Lower means more similar.** 78 | 79 | We found that deep network activations work surprisingly well as a perceptual similarity metric. This was true across network architectures (SqueezeNet [2.8 MB], AlexNet [9.1 MB], and VGG [58.9 MB] provided similar scores) and supervisory signals (unsupervised, self-supervised, and supervised all perform strongly). We slightly improved scores by linearly "calibrating" networks - adding a linear layer on top of off-the-shelf classification networks. We provide 3 variants, using linear layers on top of the SqueezeNet, AlexNet (default), and VGG networks. 80 | 81 | If you use LPIPS in your publication, please specify which version you are using. The current version is 0.1. You can set `version='0.0'` for the initial release. 
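As a complement to (B) above, here is a minimal sketch of backpropping through the metric to pull one image toward a reference. It is not the reference implementation ([`perceptual_loss.py`](perceptual_loss.py) remains the actual demo); the tensor sizes, learning rate, and iteration count are illustrative placeholders.

```python
import torch
import models  # run from the repo root so the models package is importable

use_gpu = torch.cuda.is_available()
loss_fn = models.PerceptualLoss(model='net-lin', net='alex', use_gpu=use_gpu, gpu_ids=[0])

# ref is a fixed target patch, pred the image being optimized; both Nx3xHxW in [-1,+1]
ref = torch.rand(1, 3, 64, 64) * 2 - 1
pred = torch.rand(1, 3, 64, 64) * 2 - 1
if use_gpu:
    ref, pred = ref.cuda(), pred.cuda()
pred.requires_grad_(True)

optimizer = torch.optim.Adam([pred], lr=1e-3)
for step in range(100):          # illustrative iteration count
    optimizer.zero_grad()
    d = loss_fn.forward(pred, ref)   # perceptual distance, one value per image in the batch
    d.sum().backward()               # gradients flow back into pred
    optimizer.step()
    pred.data.clamp_(-1, 1)          # keep the image in the range the metric expects
```

Because the metric is differentiable end to end, the same pattern can be used as an auxiliary loss term inside a larger training loop.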
82 | 83 | ## (2) Berkeley Adobe Perceptual Patch Similarity (BAPPS) dataset 84 | 85 | ### (A) Downloading the dataset 86 | 87 | Run `bash ./scripts/download_dataset.sh` to download and unzip the dataset into directory `./dataset`. It takes [6.6 GB] total. Alternatively, run `bash ./scripts/get_dataset_valonly.sh` to only download the validation set [1.3 GB]. 88 | - 2AFC train [5.3 GB] 89 | - 2AFC val [1.1 GB] 90 | - JND val [0.2 GB] 91 | 92 | ### (B) Evaluating a perceptual similarity metric on a dataset 93 | 94 | Script `test_dataset_model.py` evaluates a perceptual model on a subset of the dataset. 95 | 96 | **Dataset flags** 97 | - `--dataset_mode`: `2afc` or `jnd`, which type of perceptual judgment to evaluate 98 | - `--datasets`: list the datasets to evaluate 99 | - if `--dataset_mode 2afc`: choices are [`train/traditional`, `train/cnn`, `val/traditional`, `val/cnn`, `val/superres`, `val/deblur`, `val/color`, `val/frameinterp`] 100 | - if `--dataset_mode jnd`: choices are [`val/traditional`, `val/cnn`] 101 | 102 | **Perceptual similarity model flags** 103 | - `--model`: perceptual similarity model to use 104 | - `net-lin` for our LPIPS learned similarity model (linear network on top of internal activations of pretrained network) 105 | - `net` for a classification network (uncalibrated with all layers averaged) 106 | - `l2` for Euclidean distance 107 | - `ssim` for Structured Similarity Image Metric 108 | - `--net`: [`squeeze`,`alex`,`vgg`] for the `net-lin` and `net` models; ignored for `l2` and `ssim` models 109 | - `--colorspace`: choices are [`Lab`,`RGB`], used for the `l2` and `ssim` models; ignored for `net-lin` and `net` models 110 | 111 | **Misc flags** 112 | - `--batch_size`: evaluation batch size (will default to 1) 113 | - `--use_gpu`: turn on this flag for GPU usage 114 | 115 | An example usage is as follows: `python ./test_dataset_model.py --dataset_mode 2afc --datasets val/traditional val/cnn --model net-lin --net alex --use_gpu --batch_size 50`. This would evaluate our model on the "traditional" and "cnn" validation datasets. 116 | 117 | ### (C) About the dataset 118 | 119 | The dataset contains two types of perceptual judgements: **Two Alternative Forced Choice (2AFC)** and **Just Noticeable Differences (JND)**. 120 | 121 | **(1) 2AFC** Evaluators were given a patch triplet (1 reference + 2 distorted). They were asked to select which of the distorted was "closer" to the reference. 122 | 123 | Training sets contain 2 judgments/triplet. 124 | - `train/traditional` [56.6k triplets] 125 | - `train/cnn` [38.1k triplets] 126 | - `train/mix` [56.6k triplets] 127 | 128 | Validation sets contain 5 judgments/triplet. 129 | - `val/traditional` [4.7k triplets] 130 | - `val/cnn` [4.7k triplets] 131 | - `val/superres` [10.9k triplets] 132 | - `val/deblur` [9.4k triplets] 133 | - `val/color` [4.7k triplets] 134 | - `val/frameinterp` [1.9k triplets] 135 | 136 | Each 2AFC subdirectory contains the following folders: 137 | - `ref`: original reference patches 138 | - `p0,p1`: two distorted patches 139 | - `judge`: human judgments - 0 if all preferred p0, 1 if all humans preferred p1 140 | 141 | **(2) JND** Evaluators were presented with two patches - a reference and a distorted - for a limited time. They were asked if the patches were the same (identically) or different. 142 | 143 | Each set contains 3 human evaluations/example. 
144 | - `val/traditional` [4.8k pairs] 145 | - `val/cnn` [4.8k pairs] 146 | 147 | Each JND subdirectory contains the following folders: 148 | - `p0,p1`: two patches 149 | - `same`: human judgments: 0 if all humans thought the patches were different, 1 if all humans thought the patches were the same 150 | 151 | ### (D) Using the dataset to train the metric 152 | 153 | See script `train_test_metric.sh` for an example of training and testing the metric. The script will train a model on the full training set for 10 epochs, and then test the learned metric on all of the validation sets. The numbers should roughly match the **Alex - lin** row in Table 5 in the [paper](https://arxiv.org/abs/1801.03924). The code supports training a linear layer on top of an existing representation. Training will add a subdirectory to the `checkpoints` directory. 154 | 155 | You can also train "scratch" and "tune" versions by running `train_test_metric_scratch.sh` and `train_test_metric_tune.sh`, respectively. 156 | 157 | ### Docker Environment 158 | 159 | A [Docker](https://hub.docker.com/r/shinyeyes/perceptualsimilarity/) environment was set up by [SuperShinyEyes](https://github.com/SuperShinyEyes). 160 | 161 | ## Citation 162 | 163 | If you find this repository useful for your research, please use the following citation. 164 | 165 | ``` 166 | @inproceedings{zhang2018perceptual, 167 | title={The Unreasonable Effectiveness of Deep Features as a Perceptual Metric}, 168 | author={Zhang, Richard and Isola, Phillip and Efros, Alexei A and Shechtman, Eli and Wang, Oliver}, 169 | booktitle={CVPR}, 170 | year={2018} 171 | } 172 | ``` 173 | 174 | ## Acknowledgements 175 | 176 | This repository borrows partially from the [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) repository. The average precision (AP) code is borrowed from the [py-faster-rcnn](https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/voc_eval.py) repository. Backpropping through the metric was implemented by [Angjoo Kanazawa](https://github.com/akanazawa). 177 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files.
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "{}" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright {yyyy} {name of copyright owner} 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /metrics/FID/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /metrics/LPIPS/models/dist_model.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | import os 9 | from collections import OrderedDict 10 | from torch.autograd import Variable 11 | import itertools 12 | from .base_model import BaseModel 13 | from scipy.ndimage import zoom 14 | import fractions 15 | import functools 16 | import skimage.transform 17 | from tqdm import tqdm 18 | 19 | from IPython import embed 20 | 21 | from . 
import networks_basic as networks 22 | import models as util 23 | 24 | class DistModel(BaseModel): 25 | def name(self): 26 | return self.model_name 27 | 28 | def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, 29 | use_gpu=True, printNet=False, spatial=False, 30 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): 31 | ''' 32 | INPUTS 33 | model - ['net-lin'] for linearly calibrated network 34 | ['net'] for off-the-shelf network 35 | ['L2'] for L2 distance in Lab colorspace 36 | ['SSIM'] for ssim in RGB colorspace 37 | net - ['squeeze','alex','vgg'] 38 | model_path - if None, will look in weights/[NET_NAME].pth 39 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM 40 | use_gpu - bool - whether or not to use a GPU 41 | printNet - bool - whether or not to print network architecture out 42 | spatial - bool - whether to output an array containing varying distances across spatial dimensions 43 | is_train - bool - [True] for training mode 44 | lr - float - initial learning rate 45 | beta1 - float - initial momentum term for adam 46 | version - 0.1 for latest, 0.0 was original (with a bug) 47 | gpu_ids - int array - [0] by default, gpus to use 48 | ''' 49 | BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) 50 | 51 | self.model = model 52 | self.net = net 53 | self.is_train = is_train 54 | self.spatial = spatial 55 | self.gpu_ids = gpu_ids 56 | self.model_name = '%s [%s]'%(model,net) 57 | 58 | if(self.model == 'net-lin'): # pretrained net + linear layer 59 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, 60 | use_dropout=True, spatial=spatial, version=version, lpips=True) 61 | kw = {} 62 | if not use_gpu: 63 | kw['map_location'] = 'cpu' 64 | if(model_path is None): 65 | import inspect 66 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net))) 67 | 68 | if(not is_train): 69 | print('Loading model from: %s'%model_path) 70 | self.net.load_state_dict(torch.load(model_path, **kw), strict=False) 71 | 72 | elif(self.model=='net'): # pretrained network 73 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) 74 | elif(self.model in ['L2','l2']): 75 | self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing 76 | self.model_name = 'L2' 77 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']): 78 | self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace) 79 | self.model_name = 'SSIM' 80 | else: 81 | raise ValueError("Model [%s] not recognized." 
% self.model) 82 | 83 | self.parameters = list(self.net.parameters()) 84 | 85 | if self.is_train: # training mode 86 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) 87 | self.rankLoss = networks.BCERankingLoss() 88 | self.parameters += list(self.rankLoss.net.parameters()) 89 | self.lr = lr 90 | self.old_lr = lr 91 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) 92 | else: # test mode 93 | self.net.eval() 94 | 95 | if(use_gpu): 96 | self.net.to(gpu_ids[0]) 97 | self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) 98 | if(self.is_train): 99 | self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 100 | 101 | if(printNet): 102 | print('---------- Networks initialized -------------') 103 | networks.print_network(self.net) 104 | print('-----------------------------------------------') 105 | 106 | def forward(self, in0, in1, retPerLayer=False): 107 | ''' Function computes the distance between image patches in0 and in1 108 | INPUTS 109 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] 110 | OUTPUT 111 | computed distances between in0 and in1 112 | ''' 113 | 114 | return self.net.forward(in0, in1, retPerLayer=retPerLayer) 115 | 116 | # ***** TRAINING FUNCTIONS ***** 117 | def optimize_parameters(self): 118 | self.forward_train() 119 | self.optimizer_net.zero_grad() 120 | self.backward_train() 121 | self.optimizer_net.step() 122 | self.clamp_weights() 123 | 124 | def clamp_weights(self): 125 | for module in self.net.modules(): 126 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)): 127 | module.weight.data = torch.clamp(module.weight.data,min=0) 128 | 129 | def set_input(self, data): 130 | self.input_ref = data['ref'] 131 | self.input_p0 = data['p0'] 132 | self.input_p1 = data['p1'] 133 | self.input_judge = data['judge'] 134 | 135 | if(self.use_gpu): 136 | self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) 137 | self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) 138 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) 139 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) 140 | 141 | self.var_ref = Variable(self.input_ref,requires_grad=True) 142 | self.var_p0 = Variable(self.input_p0,requires_grad=True) 143 | self.var_p1 = Variable(self.input_p1,requires_grad=True) 144 | 145 | def forward_train(self): # run forward pass 146 | # print(self.net.module.scaling_layer.shift) 147 | # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) 148 | 149 | self.d0 = self.forward(self.var_ref, self.var_p0) 150 | self.d1 = self.forward(self.var_ref, self.var_p1) 151 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) 152 | 153 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) 154 | 155 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) 
156 | 157 | return self.loss_total 158 | 159 | def backward_train(self): 160 | torch.mean(self.loss_total).backward() 161 | 162 | def compute_accuracy(self,d0,d1,judge): 163 | ''' d0, d1 are Variables, judge is a Tensor ''' 164 | d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten() 165 | judge_per = judge.cpu().numpy().flatten() 166 | return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per) 167 | 168 | def get_current_errors(self): 169 | retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()), 170 | ('acc_r', self.acc_r)]) 171 | 172 | for key in retDict.keys(): 173 | retDict[key] = np.mean(retDict[key]) 174 | 175 | return retDict 176 | 177 | def get_current_visuals(self): 178 | zoom_factor = 256/self.var_ref.data.size()[2] 179 | 180 | ref_img = util.tensor2im(self.var_ref.data) 181 | p0_img = util.tensor2im(self.var_p0.data) 182 | p1_img = util.tensor2im(self.var_p1.data) 183 | 184 | ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0) 185 | p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0) 186 | p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0) 187 | 188 | return OrderedDict([('ref', ref_img_vis), 189 | ('p0', p0_img_vis), 190 | ('p1', p1_img_vis)]) 191 | 192 | def save(self, path, label): 193 | if(self.use_gpu): 194 | self.save_network(self.net.module, path, '', label) 195 | else: 196 | self.save_network(self.net, path, '', label) 197 | self.save_network(self.rankLoss.net, path, 'rank', label) 198 | 199 | def update_learning_rate(self,nepoch_decay): 200 | lrd = self.lr / nepoch_decay 201 | lr = self.old_lr - lrd 202 | 203 | for param_group in self.optimizer_net.param_groups: 204 | param_group['lr'] = lr 205 | 206 | print('update lr [%s] decay: %f -> %f' % (type,self.old_lr, lr)) 207 | self.old_lr = lr 208 | 209 | def score_2afc_dataset(data_loader, func, name=''): 210 | ''' Function computes Two Alternative Forced Choice (2AFC) score using 211 | distance function 'func' in dataset 'data_loader' 212 | INPUTS 213 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside 214 | func - callable distance function - calling d=func(in0,in1) should take 2 215 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N 216 | OUTPUTS 217 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators 218 | [1] - dictionary with following elements 219 | d0s,d1s - N arrays containing distances from reference patch to perturbed patches 220 | gts - N array in [0,1], preferred patch selected by human evaluators 221 | (closer to "0" for left patch p0, "1" for right patch p1, 222 | "0.6" means 60pct people preferred right patch, 40pct preferred left) 223 | scores - N array in [0,1], corresponding to what percentage function agreed with humans 224 | CONSTS 225 | N - number of test triplets in data_loader 226 | ''' 227 | 228 | d0s = [] 229 | d1s = [] 230 | gts = [] 231 | 232 | for data in tqdm(data_loader.load_data(), desc=name): 233 | d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist() 234 | d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist() 235 | gts+=data['judge'].cpu().numpy().flatten().tolist() 236 | 237 | d0s = np.array(d0s) 238 | d1s = np.array(d1s) 239 | gts = np.array(gts) 240 | scores = (d0s