├── models ├── __init__.py ├── components.py ├── networks.py └── vgg19.py ├── utils ├── __init__.py ├── logger.py ├── utils.py ├── metric.py ├── img_process.py ├── face_sketch_data.py ├── search_dataset.py ├── loss.py └── FeatureSIM.m ├── example_img.png ├── download_feature.sh ├── .gitignore ├── download_data_models.sh ├── LICENSE ├── test.py ├── train.py ├── README.md ├── data_process └── face_rectify.py └── face2sketch_wild.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /example_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaofengc/Face-Sketch-Wild/HEAD/example_img.png -------------------------------------------------------------------------------- /download_feature.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | PROJECT_DIR=$PWD 4 | echo 'Downloading Precalculated Features' 5 | wget http://www.visionlab.cs.hku.hk/data/Face-Sketch-Wild/features.tgz -P $PROJECT_DIR/data 6 | 7 | 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | others/ 2 | data/ 3 | data_process/shape_predictor_68_face_landmarks.dat 4 | result/ 5 | result_all/ 6 | result_ours/ 7 | test/ 8 | weight/ 9 | eval_code/ 10 | pretrain_model/ 11 | clean_git_cache.py 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | -------------------------------------------------------------------------------- /download_data_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | PROJECT_DIR=$PWD 4 | mkdir $PROJECT_DIR/data 5 | mkdir $PROJECT_DIR/pretrain_model 6 | echo 'Downloading Datasets' 7 | wget http://www.visionlab.cs.hku.hk/data/Face-Sketch-Wild/datasets.tgz -P $PROJECT_DIR/data 8 | echo 'Downloading Pretrain Models' 9 | wget http://www.visionlab.cs.hku.hk/data/Face-Sketch-Wild/models.tgz -P $PROJECT_DIR/pretrain_model 10 | 11 | 12 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | from collections import OrderedDict 4 | import numpy as np 5 | 6 | import matplotlib 7 | matplotlib.use('Agg') 8 | import matplotlib.pyplot as plt 9 | 10 | class Logger(): 11 | def __init__(self, save_weight_path): 12 | self.log_dir = save_weight_path 13 | self.iter_log = [] 14 | 15 | def iterLogUpdate(self, loss): 16 | """ 17 | iteration log: [iter][loss] 18 | """ 19 | self.iter_log.append(loss) 20 | 21 | def draw_loss_curve(self): 22 | fig = plt.figure() 23 | ax = fig.add_subplot(111) 24 | ax.plot(range(len(self.iter_log)), self.iter_log) 25 | ax.set_title('Loss Curve') 26 | plt.tight_layout() 27 | fig.savefig(os.path.join(self.log_dir, 'epoch_summary.pdf')) 28 | plt.close(fig) 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Chen Chaofeng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | gpus = '2' 5 | if sys.argv[1] == '1': 6 | test_dir = './data/CUFS/test_photos' 7 | test_gt_dir = './data/CUFS/test_sketches' 8 | result_dir = './result/CUFS' 9 | test_weight_path = './pretrain_model/cufs-epochs-026-meanshift30-G.pth' 10 | elif sys.argv[1] == '2': 11 | test_dir = './data/CUFSF/test_photos' 12 | test_gt_dir = './data/CUFSF/test_sketches' 13 | result_dir = './result/CUFSF' 14 | test_weight_path = './pretrain_model/cufsf-epochs-019-meanshift30-G.pth' 15 | elif sys.argv[1] == '3': 16 | test_dir = './data/CUHK_student/test_photos' 17 | test_gt_dir = './data/CUHK_student/test_sketches' 18 | result_dir = './result/CUHK_student' 19 | test_weight_path = './pretrain_model/cufs-epochs-026-meanshift30-G.pth' 20 | elif sys.argv[1] == '4': 21 | test_dir = './data/vgg_test/' 22 | test_gt_dir = 'none' 23 | result_dir = './result/VGG' 24 | test_weight_path = './pretrain_model/vgg-epochs-003-G.pth' 25 | 26 | param = [ 27 | '--gpus {}'.format(gpus), 28 | '--test-dir {}'.format(test_dir), 29 | '--test-gt-dir {}'.format(test_gt_dir), 30 | '--result-dir {}'.format(result_dir), 31 | '--test-weight-path {}'.format(test_weight_path), 32 | ] 33 | 34 | os.system('python face2sketch_wild.py eval {}'.format(" ".join(param))) 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import os 4 | 5 | def mkdirs(dirs): 6 | if isinstance(dirs, list): 7 | for i in dirs: 8 | if not os.path.exists(i): 9 | os.makedirs(i) 10 | elif isinstance(dirs, str): 11 | if not os.path.exists(dirs): 12 | os.makedirs(dirs) 13 | else: 14 | raise Exception('dirs should be list or string.') 15 | 16 | 17 | def to_device(tensor): 18 | """ 19 | Move tensor to device. If GPU is is_available, move to GPU. 
20 | """ 21 | if torch.cuda.is_available(): 22 | return tensor.cuda() 23 | else: 24 | return tensor 25 | 26 | 27 | def tensorToVar(tensor): 28 | """ 29 | Convert a tensor to Variable 30 | If cuda is avaible, move to GPU 31 | """ 32 | if torch.cuda.is_available(): 33 | return Variable(tensor.cuda()) 34 | else: 35 | return Variable(tensor) 36 | 37 | 38 | def extract_patches(img, patch_size=(3, 3), stride=(1, 1)): 39 | """ 40 | Divide img into overlapping patches with stride = 1 41 | img: (b, c, h, w) 42 | output patches: (b, nH*nW, c, patch_size) 43 | """ 44 | assert type(patch_size) in [int, tuple], 'patch size should be int or tuple int' 45 | assert type(stride) in [int, tuple], 'stride size should be int or tuple int' 46 | if type(stride) is int: 47 | stride = (stride, stride) 48 | if type(patch_size) is int: 49 | patch_size = (patch_size, patch_size) 50 | patches = img.unfold(2, patch_size[0], stride[0]).unfold(3, patch_size[1], stride[1]) 51 | patches = patches.contiguous().view(img.shape[0], img.shape[1], -1, patch_size[0], patch_size[1]) 52 | patches = patches.transpose(1, 2) 53 | return patches 54 | 55 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | # gpus = '1,2' 5 | gpus = '1' 6 | seed = 12345 7 | batch_size = 6 8 | learning_rate = 1e-3 9 | epochs = 40 10 | vgg_weight = './pretrain_model/vgg_conv.pth' 11 | weight_root = './weight' 12 | Gnorm = 'IN' 13 | Dnorm = 'None' 14 | feature_layers = [0, 0, 1, 1, 1] 15 | resume = 0 16 | topk = 5 17 | vgg_select_num = 0 18 | meanshift = 30 19 | weight = [1e0, 1e3, 1e-5] 20 | train_style = 'cufs' # style loss, adv loss, tv loss 21 | other = 'vgg{:02d}-meanshift{}-{}'.format(vgg_select_num, meanshift, seed) 22 | train_data = [ 23 | './data/AR/train_photos', 24 | './data/CUHK_student/train_photos', 25 | './data/XM2VTS/train_photos', 26 | './data/CUFSF/train_photos', 27 | ] 28 | if vgg_select_num: 29 | train_data.append('./data/vggface_{:02d}/'.format(vgg_select_num)) 30 | param = [ 31 | '--gpus {}'.format(gpus), 32 | '--train-data {}'.format(" ".join(train_data)), 33 | '--train-style {}'.format(train_style), 34 | '--batch-size {}'.format(batch_size), 35 | '--epochs {}'.format(epochs), 36 | '--vgg19-weight {}'.format(vgg_weight), 37 | '--weight-root {}'.format(weight_root), 38 | '--Gnorm {}'.format(Gnorm), 39 | '--Dnorm {}'.format(Dnorm), 40 | '--weight {} {} {}'.format(*weight), 41 | '--flayers {} {} {} {} {}'.format(*feature_layers), 42 | '--topk {}'.format(topk), 43 | '--other {}'.format(other), 44 | '--resume {}'.format(resume), 45 | '--seed {}'.format(seed), 46 | ] 47 | 48 | os.system('python face2sketch_wild.py train {}'.format(" ".join(param))) 49 | print(train_data, '\tdone, ') 50 | 51 | -------------------------------------------------------------------------------- /utils/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.measure import compare_ssim 3 | import matlab_wrapper 4 | import os 5 | import cv2 as cv 6 | from PIL import Image 7 | 8 | def FSIM(matlab, gt_img, test_img): 9 | """Calculate FSIM score. 10 | ------------------------- 11 | Use matlab wrapper to calculate fsim score. 
12 | Codes come from: https://github.com/gregfreeman/image_quality_toolbox 13 | """ 14 | test_img = np.array(test_img) 15 | gt_img = np.array(gt_img) 16 | matlab.eval("addpath('./utils')") 17 | matlab.put('imageRef', gt_img) 18 | matlab.put('imageDis', test_img) 19 | matlab.eval('[score, fsimc] = FeatureSIM(imageRef, imageDis)') 20 | tmp_score = matlab.get('score') 21 | return tmp_score 22 | 23 | 24 | def SSIM(gt_img, test_img): 25 | """Calculate ssim score using skimage toolkit. 26 | """ 27 | test_img = np.array(test_img).astype(np.uint8) 28 | gt_img = np.array(gt_img).astype(np.uint8) 29 | tmp_score = compare_ssim(gt_img, test_img, gaussian_weights=True, sigma=1.5, use_sample_covariance=False) 30 | return tmp_score 31 | 32 | 33 | def avg_score(test_dir, gt_dir, metric_name='ssim', smooth=False, sigma=75, verbose=False): 34 | """ 35 | Read images from two folders and calculate the average score. 36 | """ 37 | metric_name = metric_name.lower() 38 | all_score = [] 39 | if metric_name == 'fsim': 40 | matlab = matlab_wrapper.MatlabSession() 41 | for name in sorted(sorted(os.listdir(gt_dir))): 42 | test_img = Image.open(os.path.join(test_dir, name)).convert('L') 43 | gt_img = Image.open(os.path.join(gt_dir, name)).convert('L') 44 | if smooth: 45 | test_img = cv.bilateralFilter(np.array(test_img),7,sigma,sigma) 46 | 47 | if metric_name == 'ssim': 48 | tmp_score = SSIM(gt_img, test_img) 49 | elif metric_name == 'fsim': 50 | tmp_score = FSIM(matlab, gt_img, test_img) 51 | if verbose: 52 | print('Image: {}, Metric: {}, Smooth: {}, Score: {}'.format(name, metric_name, smooth, tmp_score)) 53 | all_score.append(tmp_score) 54 | return np.mean(np.array(all_score)) 55 | 56 | -------------------------------------------------------------------------------- /utils/img_process.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | import numpy as np 4 | from PIL import Image, ImageDraw, ImageFilter 5 | import dlib 6 | 7 | from .utils import tensorToVar 8 | 9 | def read_img_var(img_path, color=1, size=None): 10 | """ 11 | Read image and convert it to Variable in 0~255. 12 | Args: 13 | img_path: str, test image path 14 | size: tuple, output size (1, C, W, H) 15 | """ 16 | if color: 17 | img = Image.open(img_path).convert('RGB') 18 | else: 19 | img = Image.open(img_path).convert('L') 20 | if size is not None: 21 | img = transforms.functional.resize(img, size) 22 | return tensorToVar(transforms.functional.to_tensor(img)).unsqueeze(0) * 255 23 | 24 | def read_sketch_var(img_path, color=1, size=None, addxy=1, DoG=1): 25 | """ 26 | Read image and convert it to Variable. 27 | Args: 28 | img_path: str, test image path 29 | size: tuple, output size (W, H) 30 | """ 31 | img = Image.open(img_path).convert('L') 32 | face_img = transforms.functional.resize(img, size) 33 | return tensorToVar(transforms.functional.to_tensor(face_img)) * 255 34 | 35 | 36 | def save_var_img(var, save_path=None, size=None): 37 | """ 38 | Post processing output Variable. 
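    Clamps values to [0, 255], converts to an RGB PIL image, optionally
    resizes it, and saves it to save_path when given.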
39 | Args: 40 | var: Variable, (1, C, H, W) 41 | """ 42 | out = var.squeeze().data.cpu().numpy() 43 | out[out>255] = 255 44 | out[out<0] = 0 45 | if len(out.shape) > 2: 46 | out = out.transpose(1, 2, 0) 47 | out = Image.fromarray(out.astype(np.uint8)).convert('RGB') 48 | if size: 49 | out = transforms.functional.resize(out, size) 50 | if save_path: 51 | out.save(save_path) 52 | return out 53 | 54 | def subtract_mean_batch(batch, img_type='face', sketch_mean_shift=0): 55 | """ 56 | Convert image batch to BGR and subtract imagenet mean 57 | Batch Size: (B, C, H, W), RGB 58 | Convert BGR to gray by: [0.114, 0.587, 0.299] 59 | """ 60 | vgg_mean_bgr = np.array([103.939, 116.779, 123.680]) 61 | sketch_mean = np.array([np.dot(vgg_mean_bgr, np.array([0.114, 0.587, 0.299]))]*3) 62 | if img_type == 'face': 63 | mean_bgr = vgg_mean_bgr 64 | elif img_type == 'sketch': 65 | mean_bgr = sketch_mean + sketch_mean_shift 66 | 67 | batch = batch[:, [2, 1, 0], :, :] 68 | batch = batch - tensorToVar(torch.Tensor(mean_bgr)).view(1, 3, 1, 1) 69 | return batch 70 | 71 | -------------------------------------------------------------------------------- /utils/face_sketch_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | from torchvision import transforms 4 | 5 | from PIL import Image, ImageEnhance 6 | import numpy as np 7 | import random 8 | import os 9 | 10 | 11 | class FaceDataset(Dataset): 12 | """ 13 | Face dataset. 14 | Args: 15 | img_dirs: dir list to read photo from. 16 | """ 17 | def __init__(self, img_dirs, shuffle=False, transform=None): 18 | self.shuffle = shuffle 19 | self.img_dirs = img_dirs 20 | self.img_names = self.__get_imgnames__() 21 | self.transform = transform 22 | 23 | def __get_imgnames__(self): 24 | tmp = [] 25 | for i in self.img_dirs: 26 | for name in os.listdir(i): 27 | tmp.append(os.path.join(i, name)) 28 | if self.shuffle: 29 | random.shuffle(tmp) 30 | return tmp 31 | 32 | def __len__(self): 33 | return len(self.img_names) 34 | 35 | def __getitem__(self, idx): 36 | face_path = self.img_names[idx] 37 | face = Image.open(face_path).convert('RGB') 38 | face_origin = Image.open(face_path).convert('RGB') 39 | sample = [face, face_origin] 40 | 41 | if self.transform: 42 | sample = self.transform(sample) 43 | return sample 44 | 45 | 46 | class Rescale(object): 47 | """ 48 | Rescale the image in a sample to a given size. 49 | 50 | Args: 51 | output_size: tuple, output image size (H, W) 52 | """ 53 | def __init__(self, output_size): 54 | assert isinstance(output_size, tuple) 55 | self.output_size = output_size 56 | 57 | def __call__(self, sample): 58 | for idx, i in enumerate(sample): 59 | sample[idx] = transforms.functional.resize(i, self.output_size) 60 | return sample 61 | 62 | 63 | class ToTensor(object): 64 | """Convert image to tensor, and normalize the value to [0, 255] 65 | """ 66 | def __call__(self, sample): 67 | for idx, i in enumerate(sample): 68 | sample[idx] = transforms.functional.to_tensor(i) * 255. 
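            # to_tensor() maps pixels to [0, 1]; scaling by 255 keeps the
            # 0-255 range expected by subtract_mean_batch() in img_process.py.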
69 | return sample 70 | 71 | 72 | class ColorJitter(transforms.ColorJitter): 73 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, sharp=0.0): 74 | super(ColorJitter, self).__init__(brightness, contrast, saturation, hue) 75 | self.sharp = sharp 76 | 77 | def __call__(self, sample): 78 | img = sample[0] 79 | sharp_factor = np.random.uniform(max(0, 1 - self.sharp), 1 + self.sharp) 80 | enhancer = ImageEnhance.Sharpness(img) 81 | img = enhancer.enhance(sharp_factor) 82 | 83 | transform = self.get_params(self.brightness, self.contrast, 84 | self.saturation, self.hue) 85 | img = transform(img) 86 | sample[0] = img 87 | 88 | return sample 89 | 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Face Sketch Synthesis in the Wild 2 | 3 | PyTorch implementation for face sketch synthesis in the wild through semi-supervised learning. Here is an example: 4 | 5 | ![](example_img.png) 6 | 7 | [**Semi-Supervised Learning for Face Sketch Synthesis in the Wild.**](https://arxiv.org/abs/1812.04929) 8 | [Chaofeng Chen](https://chaofengc.github.io), [Wei Liu](http://www.visionlab.cs.hku.hk/people.html), [Xiao Tan](http://www.xtan.org/), [Kwan-Yee K. Wong](http://i.cs.hku.hk/~kykwong/). 9 | 10 | # Getting Started 11 | 12 | ## Prerequisite 13 | - Pytorch 0.3 14 | - torchvision 0.2 15 | - opencv-python 16 | - matlab_wrapper 17 | - Matlab (For FSIM evaluation) 18 | 19 | ## Datasets 20 | - We use [CUFS](http://mmlab.ie.cuhk.edu.hk/archive/facesketch.html) and [CUFSF](http://mmlab.ie.cuhk.edu.hk/archive/cufsf/) dataset provided by Chinese University of Hong Kong (CUHK) to train our networks. You can also download them from [HERE](http://www.ihitworld.com/RSLCR.html). 21 | - For the training of in the wild images, we use a subset of VGG-Face. 22 | 23 | ## Usage 24 | 25 | ### Download. 26 | Download the datasets and pretrained models using the following scripts. 27 | ``` 28 | bash download_data_models.sh 29 | ``` 30 | Download the precalculated features for fast patch matching. 31 | ``` 32 | bash download_feature.sh 33 | ``` 34 | If the server is not available, you can also download the resources from [BaiduYun](https://pan.baidu.com/s/1pKpSVj7trJhxXVp7MoECaA) or [GoogleDrive](https://drive.google.com/drive/folders/1CxURCNxV1MbfYRNLq3PFcQkBouP0MSUX?usp=sharing), and then extract the files to the corresponding directory according to the download scripts above. 35 | 36 | ### Quick Test 37 | After download the datasets and pretrain models, use the provided script to test the model 38 | ``` 39 | python test.py 1 # Test on CUFS test set 40 | python test.py 2 # Test on CUFSF test set 41 | python test.py 3 # Test on CUHK_Student test set 42 | python test.py 4 # Test on VGG test set 43 | ``` 44 | You can also test on your own test dataset. Simply change the `--test_dir` and `--test_weight_path`. If you have ground truth images, you can also specify `--test_gt_dir`. 45 | 46 | ### Train 47 | 1. Configure training process. 48 | - `vgg_select_num [0 or 10]`. `0`: no extra images in training. `10`: extra VGG-Face in training. **Only the largest vgg10 is provided here**. 49 | - `train_style [cufs, cufsf]`. use `cufs` or `cufsf` as the reference style. 50 | 51 | **The models in the paper are trained under 3 configurations**: 52 | - `--vgg_select_num 0 --train_style cufs`. Model evaluated on CUFS. 53 | - `--vgg_select_num 0 --train_style cufsf`. Model evaluated on CUFSF. 
54 | - `--vgg_select_num 10 --train_style cufs`. Model evaluated on VGG-Face. 55 | 56 | 2. Train the model. 57 | ``` 58 | python train.py 59 | ``` 60 | 61 | # Citation 62 | 63 | If you find this code or the provided data useful in your research, please consider cite: 64 | ``` 65 | @inproceedings{chen2018face-sketch-wild, 66 | title={Semi-Supervised Learning for Face Sketch Synthesis in the Wild}, 67 | author={Chen, Chaofeng and Liu, Wei and Tan, Xiao and Wong, Kwan-Yee~K.}, 68 | booktitle={Asian Conference on Computer Vision (ACCV)}, 69 | year={2018}, 70 | } 71 | ``` 72 | 73 | # Resources 74 | 75 | [1] [Random Sampling and Locality Constraint for Face Sketch Synthesis](http://www.ihitworld.com/RSLCR.html) 76 | [2] [Real-Time Exemplar-Based Face Sketch Synthesis](https://ybsong00.github.io/eccv14/index.html) 77 | 78 | -------------------------------------------------------------------------------- /models/components.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class ConvLayer(nn.Module): 5 | """Convolution layer with reflection padding. 6 | """ 7 | def __init__(self, in_channels, out_channels, kernel_size, stride, bias): 8 | super(ConvLayer, self).__init__() 9 | reflection_padding = kernel_size // 2 10 | self.reflection_pad = nn.ReflectionPad2d(reflection_padding) 11 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias) 12 | 13 | def forward(self, x): 14 | out = self.reflection_pad(x) 15 | out = self.conv2d(out) 16 | return out 17 | 18 | 19 | class NormLayer(nn.Module): 20 | """Normalization layers 21 | ------------------- 22 | # Args 23 | - channels: input channels 24 | - norm_type: normalization types. in: instance normalization; bn: batch normalization. 25 | """ 26 | def __init__(self, channels, norm_type): 27 | super(NormLayer, self).__init__() 28 | norm_type = norm_type.lower() 29 | if norm_type == 'in': 30 | self.norm_func = nn.InstanceNorm2d(channels, affine=True) 31 | elif norm_type == 'bn': 32 | self.norm_func == nn.BatchNorm2d(channels, affine=True) 33 | elif norm_type == 'none': 34 | self.norm_func = lambda x: x 35 | else: 36 | assert 1==0, 'Norm type {} not supported yet'.format(norm_type) 37 | 38 | def forward(self, x): 39 | return self.norm_func(x) 40 | 41 | 42 | class ResidualBlock(torch.nn.Module): 43 | """ResidualBlock 44 | --------------------- 45 | introduced in: https://arxiv.org/abs/1512.03385 46 | recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html 47 | """ 48 | def __init__(self, channels, norm_type='IN'): 49 | super(ResidualBlock, self).__init__() 50 | self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1, bias=False) 51 | self.norm1 = NormLayer(channels, norm_type) 52 | self.norm2 = NormLayer(channels, norm_type) 53 | self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1, bias=False) 54 | self.relu = nn.ReLU() 55 | 56 | def forward(self, x): 57 | residual = x 58 | out = self.relu(self.norm1(self.conv1(x))) 59 | out = self.norm2(self.conv2(out)) 60 | out = out + residual 61 | return out 62 | 63 | 64 | class UpsampleConvLayer(torch.nn.Module): 65 | """UpsampleConvLayer 66 | -------------------- 67 | Upsamples the input and then does a convolution. 
68 | This method produces less checkerboard effect compared to ConvTranspose2d, according to 69 | http://distill.pub/2016/deconv-checkerboard/ 70 | """ 71 | 72 | def __init__(self, in_channels, out_channels, kernel_size, stride, bias, upsample=None): 73 | super(UpsampleConvLayer, self).__init__() 74 | self.upsample = upsample 75 | if upsample: 76 | self.upsample_layer = torch.nn.Upsample(scale_factor=upsample) 77 | reflection_padding = kernel_size // 2 78 | self.reflection_pad = nn.ReflectionPad2d(reflection_padding) 79 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias) 80 | 81 | def forward(self, x): 82 | x_in = x 83 | if self.upsample: 84 | x_in = self.upsample_layer(x_in) 85 | out = self.reflection_pad(x_in) 86 | out = self.conv2d(out) 87 | return out 88 | 89 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | from components import * 2 | 3 | class SketchNet(nn.Module): 4 | """SketchNet, transform input RGB photo to gray sketch. 5 | --------------------- 6 | A U-Net architecture similar to: https://arxiv.org/pdf/1603.08155.pdf 7 | Codes borrowed from: https://github.com/pytorch/examples/tree/master/fast_neural_style 8 | """ 9 | def __init__(self, in_channels=3, out_channels=1, norm_type='IN'): 10 | super(SketchNet, self).__init__() 11 | # Downsample convolution layers 12 | self.conv1 = ConvLayer(in_channels, 32, kernel_size=3, stride=1, bias=False) 13 | self.norm1 = NormLayer(32, norm_type) 14 | self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2, bias=False) 15 | self.norm2 = NormLayer(64, norm_type) 16 | self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2, bias=False) 17 | self.norm3 = NormLayer(128, norm_type) 18 | 19 | # Residual layers 20 | self.res1 = ResidualBlock(128, norm_type) 21 | self.res2 = ResidualBlock(128, norm_type) 22 | self.res3 = ResidualBlock(128, norm_type) 23 | self.res4 = ResidualBlock(128, norm_type) 24 | self.res5 = ResidualBlock(128, norm_type) 25 | 26 | # Upsampling Layers 27 | self.deconv1 = UpsampleConvLayer(256, 64, kernel_size=3, stride=1, bias=False, upsample=2) 28 | self.norm4 = NormLayer(64, norm_type) 29 | self.deconv2 = UpsampleConvLayer(128, 32, kernel_size=3, stride=1, bias=False, upsample=2) 30 | self.norm5 = NormLayer(32, norm_type) 31 | self.deconv3 = ConvLayer(64, out_channels, kernel_size=3, stride=1, bias=True) 32 | 33 | # Non-linear layer 34 | self.relu = nn.ReLU(True) 35 | 36 | def forward(self, X): 37 | y_conv1 = self.relu(self.norm1(self.conv1(X))) 38 | y_conv2 = self.relu(self.norm2(self.conv2(y_conv1))) 39 | y_conv3 = self.relu(self.norm3(self.conv3(y_conv2))) 40 | y = self.res1(y_conv3) 41 | y = self.res2(y) 42 | y = self.res3(y) 43 | y = self.res4(y) 44 | y_deconv0 = self.res5(y) 45 | y_deconv0 = torch.cat((y_deconv0, y_conv3), 1) 46 | y_deconv1 = self.relu(self.norm4(self.deconv1(y_deconv0))) 47 | y_deconv1 = torch.cat((y_deconv1, y_conv2), 1) 48 | y_deconv2 = self.relu(self.norm5(self.deconv2(y_deconv1))) 49 | y_deconv2 = torch.cat((y_deconv2, y_conv1), 1) 50 | y = self.deconv3(y_deconv2) 51 | return y 52 | 53 | 54 | class DNet(nn.Module): 55 | """Discriminator network. 
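    A fully convolutional patch discriminator: four stride-2 convolutions
    followed by a 1-channel convolution, so forward() returns a patch-level
    score map rather than a single scalar. It is trained with the
    least-squares (MSE) GAN objective in face2sketch_wild.py.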
56 | """ 57 | def __init__(self, in_channels=1, norm_type='IN'): 58 | super(DNet, self).__init__() 59 | b = True if norm_type == 'none' else False 60 | self.net = nn.Sequential( 61 | ConvLayer(in_channels, 32, kernel_size=3, stride=2, bias=True), 62 | nn.ReLU(inplace=True), 63 | ConvLayer(32, 64, kernel_size=3, stride=2, bias=b), 64 | NormLayer(64, norm_type), 65 | nn.ReLU(inplace=True), 66 | ConvLayer(64, 128, kernel_size=3, stride=2, bias=b), 67 | NormLayer(128, norm_type), 68 | nn.ReLU(inplace=True), 69 | ConvLayer(128, 256, kernel_size=3, stride=2, bias=b), 70 | NormLayer(256, norm_type), 71 | nn.ReLU(inplace=True), 72 | ConvLayer(256, 1, kernel_size=3, stride=1, bias=True), 73 | ) 74 | 75 | def forward(self, x): 76 | return self.net(x) 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /utils/search_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from . import img_process 5 | 6 | def get_real_sketch_batch(batch_size, img_name_list, dataset_filter): 7 | img_name_list_all = np.array([x.strip() for x in open(img_name_list).readlines()]) 8 | img_name_list = [] 9 | for idx, i in enumerate(img_name_list_all): 10 | for j in dataset_filter: 11 | if j in i: 12 | img_name_list.append(i) 13 | break 14 | sketch_name_list = [x.replace('train_photos', 'train_sketches') for x in img_name_list] 15 | sketch_name_list = np.array(sketch_name_list) 16 | img_batch = np.random.choice(sketch_name_list, batch_size, replace=False) 17 | img_batch = [img_process.read_img_var(x, 0, size=(224, 224)) for x in img_batch] 18 | return torch.stack(img_batch).squeeze(1) 19 | 20 | 21 | def find_photo_sketch_batch(photo_batch, dataset_path, img_name_list, vgg_model, 22 | topk=1, dataset_filter=['CUHK_student', 'AR'], compare_layer=['r51']): 23 | """ 24 | Search the dataset to find the topk matching image. 
25 | """ 26 | dataset_all = torch.load(dataset_path) 27 | dataset_all = torch.autograd.Variable(dataset_all.type_as(photo_batch.data)) 28 | img_name_list_all = np.array([x.strip() for x in open(img_name_list).readlines()]) 29 | img_name_list = [] 30 | dataset_idx = [] 31 | for idx, i in enumerate(img_name_list_all): 32 | for j in dataset_filter: 33 | if j in i: 34 | img_name_list.append(i) 35 | dataset_idx.append(idx) 36 | break 37 | dataset = dataset_all[dataset_idx] 38 | img_name_list = np.array(img_name_list) 39 | 40 | photo_feat = vgg_model(img_process.subtract_mean_batch(photo_batch), compare_layer)[0] 41 | photo_feat = torch.nn.functional.normalize(photo_feat, p=2, dim=1).view(photo_feat.size(0), photo_feat.size(1), -1) 42 | dataset = torch.nn.functional.normalize(dataset, p=2, dim=1).view(dataset.size(0), dataset.size(1), -1) 43 | img_idx = [] 44 | for i in range(photo_feat.size(0)): 45 | dist = photo_feat[i].unsqueeze(0) * dataset 46 | dist = torch.sum(dist, -1) 47 | dist = torch.sum(dist, -1) 48 | _, best_idx = torch.topk(dist, topk, 0) 49 | img_idx += best_idx.data.cpu().tolist() 50 | 51 | match_img_list = img_name_list[img_idx] 52 | match_sketch_list = [x.replace('train_photos', 'train_sketches') for x in match_img_list] 53 | 54 | match_img_batch = [img_process.read_img_var(x, size=(224, 224)) for x in match_img_list] 55 | match_sketch_batch = [img_process.read_img_var(x, size=(224, 224)) for x in match_sketch_list] 56 | match_sketch_batch, match_img_batch = torch.stack(match_sketch_batch).squeeze(), torch.stack(match_img_batch).squeeze() 57 | 58 | return match_sketch_batch, match_img_batch 59 | 60 | def select_random_batch(ref_img_list, batch_size, dataset_filter=['CUHK_student', 'AR']): 61 | ref_img_list_all = np.array([x.strip() for x in open(ref_img_list).readlines()]) 62 | ref_img_list = [] 63 | for idx, i in enumerate(ref_img_list_all): 64 | for j in dataset_filter: 65 | if j in i: 66 | ref_img_list.append(i) 67 | break 68 | ref_img_list = np.array(ref_img_list) 69 | 70 | selected_ref_img = np.random.choice(ref_img_list, batch_size, replace=False) 71 | selected_ref_sketch = [x.replace('train_photos', 'train_sketches') for x in selected_ref_img] 72 | 73 | selected_ref_batch = [img_process.read_img_var(x, size=(224, 224)) for x in selected_ref_img] 74 | selected_sketch_batch = [img_process.read_img_var(x, size=(224, 224)) for x in selected_ref_sketch] 75 | selected_sketch_batch, selected_ref_batch = torch.stack(selected_sketch_batch).squeeze(1), torch.stack(selected_ref_batch).squeeze(1) 76 | return selected_ref_batch, selected_sketch_batch 77 | 78 | -------------------------------------------------------------------------------- /models/vgg19.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class VGG(nn.Module): 8 | """VGG19 model. 
9 | --------------------- 10 | Codes borrowed from: https://github.com/leongatys/PytorchNeuralStyleTransfer 11 | """ 12 | def __init__(self, pool='max', pool_ks=2, pool_st=2): 13 | super(VGG, self).__init__() 14 | # vgg modules 15 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) 16 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 17 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 18 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 19 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1) 20 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 21 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 22 | self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 23 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1) 24 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 25 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 26 | self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 27 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 28 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 29 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 30 | self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 31 | if pool == 'max': 32 | self.pool1 = nn.MaxPool2d(kernel_size=pool_ks, stride=pool_st) 33 | self.pool2 = nn.MaxPool2d(kernel_size=pool_ks, stride=pool_st) 34 | self.pool3 = nn.MaxPool2d(kernel_size=pool_ks, stride=pool_st) 35 | self.pool4 = nn.MaxPool2d(kernel_size=pool_ks, stride=pool_st) 36 | self.pool5 = nn.MaxPool2d(kernel_size=pool_ks, stride=pool_st) 37 | elif pool == 'avg': 38 | self.pool1 = nn.AvgPool2d(kernel_size=pool_ks, stride=pool_st) 39 | self.pool2 = nn.AvgPool2d(kernel_size=pool_ks, stride=pool_st) 40 | self.pool3 = nn.AvgPool2d(kernel_size=pool_ks, stride=pool_st) 41 | self.pool4 = nn.AvgPool2d(kernel_size=pool_ks, stride=pool_st) 42 | self.pool5 = nn.AvgPool2d(kernel_size=pool_ks, stride=pool_st) 43 | 44 | def forward(self, x, out_keys): 45 | if len(x.size()) == 3: 46 | x = x.unsqueeze(1).repeat(1, 3, 1, 1) 47 | elif x.size(1) == 1: 48 | x = x.repeat(1, 3, 1, 1) 49 | out = {} 50 | out['r11'] = F.relu(self.conv1_1(x)) 51 | out['r12'] = F.relu(self.conv1_2(out['r11'])) 52 | out['p1'] = self.pool1(out['r12']) 53 | out['r21'] = F.relu(self.conv2_1(out['p1'])) 54 | out['r22'] = F.relu(self.conv2_2(out['r21'])) 55 | out['p2'] = self.pool2(out['r22']) 56 | out['r31'] = F.relu(self.conv3_1(out['p2'])) 57 | out['r32'] = F.relu(self.conv3_2(out['r31'])) 58 | out['r33'] = F.relu(self.conv3_3(out['r32'])) 59 | out['r34'] = F.relu(self.conv3_4(out['r33'])) 60 | out['p3'] = self.pool3(out['r34']) 61 | out['r41'] = F.relu(self.conv4_1(out['p3'])) 62 | out['r42'] = F.relu(self.conv4_2(out['r41'])) 63 | out['r43'] = F.relu(self.conv4_3(out['r42'])) 64 | out['r44'] = F.relu(self.conv4_4(out['r43'])) 65 | out['p4'] = self.pool4(out['r44']) 66 | out['r51'] = F.relu(self.conv5_1(out['p4'])) 67 | out['r52'] = F.relu(self.conv5_2(out['r51'])) 68 | out['r53'] = F.relu(self.conv5_3(out['r52'])) 69 | out['r54'] = F.relu(self.conv5_4(out['r53'])) 70 | out['p5'] = self.pool5(out['r54']) 71 | return [out[key] for key in out_keys] 72 | 73 | 74 | def vgg19(model_path): 75 | vgg = VGG(pool_ks=2, pool_st=2) 76 | vgg.load_state_dict(torch.load(model_path)) 77 | for param in vgg.parameters(): 78 | param.requires_grad = False 79 | if torch.cuda.is_available(): 80 | vgg.cuda() 81 | return vgg 82 | 
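# --------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file): how the frozen
# extractor above is driven by loss.py and face2sketch_wild.py. The weight
# path follows train.py; the random input batch is a stand-in for real photos.
if __name__ == '__main__':
    from utils.utils import tensorToVar
    from utils.img_process import subtract_mean_batch

    vgg = vgg19('./pretrain_model/vgg_conv.pth')             # frozen VGG19, on GPU if available
    photos = tensorToVar(torch.rand(2, 3, 224, 224) * 255)   # dummy RGB batch in [0, 255]
    photos = subtract_mean_batch(photos, 'face')             # RGB->BGR + ImageNet mean subtraction
    r31, r51 = vgg(photos, ['r31', 'r51'])                   # request any of r11..r54 / p1..p5
    print(r31.size(), r51.size())                            # (2, 256, 56, 56) and (2, 512, 14, 14)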
-------------------------------------------------------------------------------- /data_process/face_rectify.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rectify the face photo according to the paper: Real-Time Exemplar-Based Face Sketch Synthesis. 3 | shape: h=250, w=200 4 | position: left eye (x=75,y=125), right eye (x=125, y=125) 5 | 6 | This module use similarity transformation to roughly align the two eyes. 7 | Specifically, the transformation matrix can be written as: 8 | S = |s_x cos(\theta), sin(\theta) , t_x | 9 | |-sin(\theta) , s_y cos(\theta), t_y | 10 | There are 5 degrees in the above function, needs at least 3 points(x, y) to solve it. 11 | we can simply hallucinate a third point such that it forms an equilateral triangle with the two known points. 12 | 13 | Reference: 14 | http://www.learnopencv.com/average-face-opencv-c-python-tutorial/ 15 | http://blog.csdn.net/GraceDD/article/details/51382952 16 | """ 17 | import math 18 | import numpy as np 19 | import os 20 | 21 | import dlib 22 | import cv2 as cv 23 | from PIL import Image 24 | import matplotlib.pyplot as plt 25 | from natsort import natsorted 26 | 27 | def detect_fiducial_points(img, predictor_path): 28 | """ 29 | Detect face landmarks and return the mean points of left and right eyes. 30 | If there are multiple faces in one image, only select the first one. 31 | """ 32 | detector = dlib.get_frontal_face_detector() 33 | predictor = dlib.shape_predictor(predictor_path) 34 | dets = detector(img, 1) 35 | if len(dets) < 1: 36 | return [] 37 | for k, d in enumerate(dets): 38 | shape = predictor(img, d) 39 | break 40 | landmarks = [] 41 | for i in range(68): 42 | landmarks.append([shape.part(i).x, shape.part(i).y]) 43 | landmarks = np.array(landmarks) 44 | left_eye = landmarks[36:42] 45 | right_eye = landmarks[42:48] 46 | mouth = landmarks[48:68] 47 | return np.array([np.mean(left_eye, 0), np.mean(right_eye, 0)]).astype('int') 48 | 49 | 50 | def similarityTransform(inPoints, outPoints) : 51 | """ 52 | Calculate similarity transform: 53 | Input: 54 | (left eye, right eye) in (x, y) 55 | inPoints: (2, 2), numpy array. 56 | outPoints: (2, 2), numpy array 57 | Return: 58 | A partial affine transform. 
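        A 2x3 matrix from cv.estimateRigidTransform, suitable for cv.warpAffine.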
59 | """ 60 | s60 = math.sin(60*math.pi/180) 61 | c60 = math.cos(60*math.pi/180) 62 | 63 | inPts = np.copy(inPoints).tolist() 64 | outPts = np.copy(outPoints).tolist() 65 | xin = c60*(inPts[0][0] - inPts[1][0]) - s60*(inPts[0][1] - inPts[1][1]) + inPts[1][0] 66 | yin = s60*(inPts[0][0] - inPts[1][0]) + c60*(inPts[0][1] - inPts[1][1]) + inPts[1][1] 67 | inPts.append([np.int(xin), np.int(yin)]) 68 | 69 | xout = c60*(outPts[0][0] - outPts[1][0]) - s60*(outPts[0][1] - outPts[1][1]) + outPts[1][0] 70 | yout = s60*(outPts[0][0] - outPts[1][0]) + c60*(outPts[0][1] - outPts[1][1]) + outPts[1][1] 71 | outPts.append([np.int(xout), np.int(yout)]) 72 | tform = cv.estimateRigidTransform(np.array([inPts]), np.array([outPts]), False) 73 | 74 | return tform 75 | 76 | def rectify_img(img_path, predictor_path): 77 | template_eye_pos = np.array([[75, 125], [125, 125]]) 78 | template_size = (200, 250) 79 | img = cv.imread(img_path) 80 | detected_eyes = detect_fiducial_points(np.array(img), predictor_path) 81 | if not len(detected_eyes): 82 | return None 83 | trans = similarityTransform(detected_eyes, template_eye_pos) 84 | rect_img = cv.warpAffine(img, trans, template_size) 85 | return rect_img 86 | 87 | def align_img(ref_path, src_path, predictor_path): 88 | ref_img = cv.imread(ref_path) 89 | src_img = cv.imread(src_path) 90 | 91 | ref_eyes = detect_fiducial_points(np.array(ref_img), predictor_path) 92 | src_eyes = detect_fiducial_points(np.array(src_img), predictor_path) 93 | trans = similarityTransform(src_eyes, ref_eyes) 94 | rect_img = cv.warpAffine(src_img, trans, (200, 250)) 95 | return rect_img 96 | 97 | 98 | if __name__ == '__main__': 99 | src_dir = '../result_ours/CUFSF_intersect/ours_result' 100 | ref_dir = '../result_ours/CUFSF_intersect/gt_sketch' 101 | 102 | save_dir = '../result_ours/CUFSF_intersect/ours_warp' 103 | if not os.path.exists(save_dir): os.mkdir(save_dir) 104 | ref_img_list = natsorted(os.listdir(ref_dir)) 105 | src_img_list = natsorted(os.listdir(src_dir)) 106 | 107 | for i in range(len(ref_img_list)): 108 | ref_path = os.path.join(ref_dir, ref_img_list[i]) 109 | src_path = os.path.join(src_dir, src_img_list[i]) 110 | save_path = os.path.join(save_dir, ref_img_list[i]) 111 | warp_src = align_img(ref_path, src_path, './shape_predictor_68_face_landmarks.dat') 112 | cv.imwrite(save_path, warp_src) 113 | # template_eye_pos = np.array([[75, 125], [125, 125]]) 114 | # template_size = (200, 250) 115 | # img_path = '/disk1/cfchen/data/FERET/original_photo/00001.jpg' 116 | # img = cv.imread(img_path) 117 | # detected_eyes = detect_fiducial_points(np.array(img), './shape_predictor_68_face_landmarks.dat') 118 | # trans = similarityTransform(detected_eyes, template_eye_pos) 119 | # rect_img = cv.warpAffine(img, trans, template_size) 120 | # cv.imshow('test', rect_img) 121 | # cv.waitKey() 122 | 123 | 124 | 125 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.parameter as Param 5 | import numpy as np 6 | import torch.nn.functional as F 7 | from torchvision.transforms import functional as tf 8 | 9 | from .utils import tensorToVar, extract_patches 10 | from time import sleep 11 | 12 | def total_variation(x): 13 | """ 14 | Total Variation Loss. 
15 | """ 16 | return torch.sum(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]) 17 | ) + torch.sum(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])) 18 | 19 | 20 | def feature_mse_loss_func(x, y, vgg_model, mask=None, layer=['r51']): 21 | """ 22 | Feature loss define by vgg layer. 23 | """ 24 | if mask is not None: 25 | x = MaskGrad(mask)(x) 26 | x_feat = vgg_model(x, layer) 27 | y_feat = vgg_model(y, layer) 28 | loss = sum([nn.MSELoss()(a, b) for a, b in zip(x_feat, y_feat)]) 29 | return loss 30 | 31 | 32 | def feature_mrf_loss_func(x, y, vgg_model=None, layer=[], match_img_vgg=[], topk=1): 33 | assert isinstance(match_img_vgg, list), 'Parameter match_img_vgg should be a list' 34 | mrf_crit = MRFLoss(topk=topk) 35 | loss = 0. 36 | if len(layer) == 0 or layer[0] == 'r11' or layer[0] == 'r12': 37 | mrf_crit.patch_size = (5, 5) 38 | mrf_crit.filter_patch_stride = 4 39 | if len(layer) == 0: 40 | return mrf_crit(x, y) 41 | x_feat = vgg_model(x, layer) 42 | y_feat = vgg_model(y, layer) 43 | match_img_feat = [vgg_model(m, layer) for m in match_img_vgg] 44 | if len(match_img_vgg) == 0: 45 | for pred, gt in zip(x_feat, y_feat): 46 | loss += mrf_crit(pred, gt) 47 | elif len(match_img_vgg) == 1: 48 | for pred, gt, match0 in zip(x_feat, y_feat, match_img_feat[0]): 49 | loss += mrf_crit(pred, gt, [match0]) 50 | elif len(match_img_vgg) == 2: 51 | for pred, gt, match0, match1 in zip(x_feat, y_feat, match_img_feat[0], match_img_feat[1]): 52 | loss += mrf_crit(pred, gt, [match0, match1]) 53 | return loss 54 | 55 | 56 | class MRFLoss(nn.Module): 57 | """ 58 | Feature level patch matching loss. 59 | """ 60 | def __init__(self, patch_size=(3, 3), filter_patch_stride=1, compare_stride=1, topk=1): 61 | super(MRFLoss, self).__init__() 62 | self.crit = nn.MSELoss() 63 | self.patch_size = patch_size 64 | self.compare_stride = compare_stride 65 | self.filter_patch_stride = filter_patch_stride 66 | self.topk = topk 67 | 68 | def best_topk_match(self, x1, x2): 69 | """ 70 | Best topk match. 
71 | x1: reference feature, (B, C, H, W) 72 | x2: topk candidate feature patches, (B*topk, nH*nW, c, patch_size, patch_size) 73 | """ 74 | x1 = F.normalize(x1, p=2, dim=1) 75 | x2 = F.normalize(x2, p=2, dim=2) 76 | k_match, spatial_match = [], [] 77 | dist_func = nn.Conv2d(x1.size(1), x2.size(1), (x2.size(3), x2.size(4)), stride=self.compare_stride, bias=False) 78 | if torch.cuda.is_available(): 79 | dist_func.cuda() 80 | dist_func.eval() 81 | for i in range(x1.size(0)): 82 | tmp_value, tmp_idx = [], [] 83 | for j in range(self.topk): 84 | dist_func.weight.data = x2[i*self.topk + j].squeeze().data 85 | cosine_dist = dist_func(x1[i].unsqueeze(0)) 86 | max_value, max_idx = torch.max(cosine_dist, dim=1, keepdim=False) 87 | tmp_value.append(max_value) 88 | tmp_idx.append(max_idx) 89 | topk_value = torch.stack(tmp_value) 90 | _, k_idx = torch.max(topk_value, dim=0, keepdim=False) 91 | spatial_idx = torch.stack(tmp_idx) 92 | k_match.append(k_idx.squeeze().view(-1).data) 93 | spatial_match.append(spatial_idx.squeeze(1).view(spatial_idx.shape[0], -1).data) 94 | return torch.stack(k_match), torch.stack(spatial_match) 95 | 96 | def get_new_style_map(self): 97 | # Visulize new_target_style_patches 98 | B, nHnW, c, _, _ = self.new_style_feature.size() 99 | feature_map = torch.mean(self.new_style_feature.view(B, nHnW, c, -1), -1) 100 | feature_map = feature_map.view(B, np.sqrt(nHnW).astype(int), np.sqrt(nHnW).astype(int), c) 101 | feature_map = feature_map.permute(0, 3, 1, 2) 102 | return feature_map 103 | 104 | def get_pixel_match(self, topk_ref_style): 105 | topk_style_patches = extract_patches(topk_ref_style, (12, 12), 4) 106 | pred_shape = list(topk_style_patches.size()) 107 | pred_shape[0] = 1 108 | new_topk_target_style_patches = tensorToVar(torch.zeros(pred_shape[0]*self.topk, 109 | pred_shape[1], pred_shape[2], pred_shape[3], pred_shape[4])) 110 | self.spatial_best_match = self.spatial_best_match.view(pred_shape[0]*self.topk, -1) 111 | for i in range(pred_shape[0]*self.topk): 112 | new_topk_target_style_patches[i] = topk_style_patches[[i], self.spatial_best_match[i]] 113 | new_topk_target_style_patches = new_topk_target_style_patches.view(pred_shape[0], self.topk, 114 | pred_shape[1], pred_shape[2], pred_shape[3], pred_shape[4]) 115 | new_target_style_patches = tensorToVar(torch.zeros(pred_shape)) 116 | for i in range(self.k_best_match.shape[0]): 117 | for j in range(self.k_best_match.shape[1]): 118 | new_target_style_patches[i, j] = new_topk_target_style_patches[i, self.k_best_match[i, j], j] 119 | B, nHnW, c, _, _ = new_target_style_patches.shape 120 | nH = int(np.sqrt(nHnW)) 121 | pix_vis = new_target_style_patches[:, :, :, 4:8, 4:8].squeeze() 122 | 123 | pix_vis = pix_vis.permute(1, 0, 2, 3).contiguous() 124 | pix_vis = pix_vis.view(3, nH, nH, 4, 4) 125 | pix_vis = pix_vis.permute(0, 1, 3, 2, 4).contiguous() 126 | pix_vis = pix_vis.view(3, nH*4, nH*4) 127 | return pix_vis.unsqueeze(0) 128 | 129 | def forward(self, pred_style, target_style, match=[]): 130 | """ 131 | pred_style: feature of predicted image 132 | target_style: target style feature 133 | match: images used to match pred_style with target style 134 | 135 | switch(len(match)): 136 | case 0: matching is done between pred_style and target_style 137 | case 1: matching is done between match[0] and target style 138 | case 2: matching is done between match[0] and match[1] 139 | """ 140 | assert isinstance(match, list), 'Parameter match should be a list' 141 | target_style_patches = extract_patches(target_style, self.patch_size, 
self.filter_patch_stride) 142 | pred_style_patches = extract_patches(pred_style, self.patch_size, self.compare_stride) 143 | 144 | bk, nhnw, c, psz, psz = target_style_patches.shape 145 | 146 | if len(match) == 0: 147 | k_best_match, spatial_best_match = self.best_topk_match(pred_style, target_style_patches) 148 | elif len(match) == 1: 149 | k_best_match, spatial_best_match = self.best_topk_match(match[0], target_style_patches) 150 | elif len(match) == 2: 151 | match_patches = extract_patches(match[1], self.patch_size, self.filter_patch_stride) 152 | k_best_match, spatial_best_match = self.best_topk_match(match[0], match_patches) 153 | 154 | self.k_best_match = k_best_match 155 | self.spatial_best_match = spatial_best_match 156 | 157 | pred_shape = pred_style_patches.size() 158 | new_topk_target_style_patches = tensorToVar(torch.zeros(pred_shape[0]*self.topk, 159 | pred_shape[1], pred_shape[2], pred_shape[3], pred_shape[4])) 160 | spatial_best_match = spatial_best_match.view(pred_shape[0]*self.topk, -1) 161 | for i in range(pred_shape[0]*self.topk): 162 | new_topk_target_style_patches[i] = target_style_patches[[i], spatial_best_match[i]] 163 | new_topk_target_style_patches = new_topk_target_style_patches.view(pred_shape[0], self.topk, 164 | pred_shape[1], pred_shape[2], pred_shape[3], pred_shape[4]) 165 | new_target_style_patches = tensorToVar(torch.zeros(pred_shape)) 166 | for i in range(k_best_match.shape[0]): 167 | for j in range(k_best_match.shape[1]): 168 | new_target_style_patches[i, j] = new_topk_target_style_patches[i, k_best_match[i, j], j] 169 | self.new_style_feature = new_target_style_patches 170 | loss = self.crit(pred_style_patches, new_target_style_patches) 171 | return loss 172 | 173 | 174 | 175 | 176 | 177 | -------------------------------------------------------------------------------- /face2sketch_wild.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | from torch.optim import Adam 5 | from torch.optim.lr_scheduler import MultiStepLR 6 | from torchvision import transforms 7 | 8 | import argparse 9 | import os 10 | import numpy as np 11 | from time import time 12 | from datetime import datetime 13 | import itertools 14 | import copy 15 | from glob import glob 16 | 17 | from utils.face_sketch_data import * 18 | from models.networks import SketchNet, DNet 19 | from models.vgg19 import vgg19 20 | from utils import loss 21 | from utils import img_process 22 | from utils import search_dataset 23 | from utils import logger 24 | from utils import utils 25 | from utils.metric import avg_score 26 | 27 | 28 | def cmd_option(): 29 | arg_parser = argparse.ArgumentParser(description='CMD arguments for the face sketch network') 30 | arg_parser.add_argument('train_eval', type=str, default='train', help='Train or eval') 31 | arg_parser.add_argument('--gpus', type=str, default='0', help='Which gpus to train the model') 32 | arg_parser.add_argument('--train-data', type=str, nargs='*', 33 | default=["./data/AR/train_photos", "./data/CUHK_student/train_photos", "./data/XM2VTS/train_photos", "./data/CUFSF/train_photos"], help="Train data dir root") 34 | arg_parser.add_argument('--resume', type=int, default=0, help='Resume training or not') 35 | arg_parser.add_argument('--train-style', type=str, default='cufs', help='Styles used to train') 36 | arg_parser.add_argument('--seed', type=int, default=1234, help='Random seed for training') 37 | 
arg_parser.add_argument('--batch-size', type=int, default=6, help='Train batch size') 38 | arg_parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate for training') 39 | arg_parser.add_argument('--epochs', type=int, default=40, help='Training epochs to generate') 40 | arg_parser.add_argument('--weight-root', type=str, default='./weight', help='Weight saving path') 41 | arg_parser.add_argument('--vgg19-weight', type=str, default='/home/cfchen/pytorch_models/vgg_conv.pth', 42 | help='Pretrained vgg19 weight path') 43 | arg_parser.add_argument('--Gnorm', type=str, default='IN', help="Instance(IN) normalization or batch(BN) normalization") 44 | arg_parser.add_argument('--Dnorm', type=str, default='None', help="Instance(IN) normalization or batch(BN) normalization") 45 | arg_parser.add_argument('--flayers', type=int, nargs=5, default=[0, 0, 1, 1, 1], help="Layers used to calculate feature loss") 46 | arg_parser.add_argument('--weight', type=float, nargs=3, default=[1e0, 1e3, 1e-5], help="MSE loss weight, Feature loss weight, and total variation weight") 47 | arg_parser.add_argument('--topk', type=int, default=1, help="Topk image choose to match input photo") 48 | arg_parser.add_argument('--meanshift', type=int, default=20, help="mean shift of the predicted sketch.") 49 | arg_parser.add_argument('--other', type=str, default='', help="Other information") 50 | 51 | arg_parser.add_argument('--test-dir', type=str, default='', help='Test image directory') 52 | arg_parser.add_argument('--test-gt-dir', type=str, default='', help='Test ground truth image directory') 53 | arg_parser.add_argument('--result-dir', type=str, default='./result', help='Result saving directory') 54 | arg_parser.add_argument('--test-weight-path', type=str, default='', help='Test model path') 55 | return arg_parser.parse_args() 56 | 57 | def train(args): 58 | torch.backends.cudnn.benchmark=True 59 | torch.backends.cudnn.deterministic = True 60 | np.random.seed(args.seed) 61 | random.seed(args.seed) 62 | torch.manual_seed(args.seed) 63 | torch.cuda.manual_seed_all(args.seed) 64 | 65 | # -------------------- Load data ---------------------------------- 66 | transform = transforms.Compose([ 67 | Rescale((224, 224)), 68 | ColorJitter(0.5, 0.5, 0.5, 0.3, 0.5), 69 | ToTensor(), 70 | ]) 71 | dataset = FaceDataset(args.train_data, True, transform=transform) 72 | data_loader = DataLoader(dataset, shuffle=True, batch_size=args.batch_size, drop_last=True, num_workers=4) 73 | 74 | # ----------------- Define networks --------------------------------- 75 | Gnet= SketchNet(in_channels=3, out_channels=1, norm_type=args.Gnorm) 76 | Dnet = DNet(norm_type=args.Dnorm) 77 | vgg19_model = vgg19(args.vgg19_weight) 78 | 79 | gpu_ids = [int(x) for x in args.gpus.split(',')] 80 | if len(gpu_ids) > 0: 81 | Gnet.cuda() 82 | Dnet.cuda() 83 | Gnet = nn.DataParallel(Gnet, device_ids=gpu_ids) 84 | Dnet = nn.DataParallel(Dnet, device_ids=gpu_ids) 85 | vgg19_model = nn.DataParallel(vgg19_model, device_ids=gpu_ids) 86 | 87 | Gnet.train() 88 | Dnet.train() 89 | 90 | if args.resume: 91 | weights = glob(os.path.join(args.save_weight_path, '*-*.pth')) 92 | weight_path = sorted(weights)[-1][:-5] 93 | Gnet.load_state_dict(torch.load(weight_path + 'G.pth')) 94 | Dnet.load_state_dict(torch.load(weight_path + 'D.pth')) 95 | 96 | # ---------------- set optimizer and learning rate --------------------- 97 | args.epochs = np.ceil(args.epochs * 1000 / len(dataset)) 98 | args.epochs = max(int(args.epochs), 4) 99 | ms = [int(1./4 * args.epochs), int(2./4 * 
args.epochs)] 100 | 101 | optim_G = Adam(Gnet.parameters(), args.lr) 102 | optim_D = Adam(Dnet.parameters(), args.lr) 103 | scheduler_G = MultiStepLR(optim_G, milestones=ms, gamma=0.1) 104 | scheduler_D = MultiStepLR(optim_D, milestones=ms, gamma=0.1) 105 | mse_crit = nn.MSELoss() 106 | 107 | # ---------------------- Define reference styles and feature loss layers ---------- 108 | if args.train_style == 'cufs': 109 | ref_style_dataset = ['CUHK_student', 'AR', 'XM2VTS'] 110 | ref_feature = './data/cufs_feature_dataset.pth' 111 | ref_img_list = './data/cufs_reference_img_list.txt' 112 | elif args.train_style == 'cufsf': 113 | ref_style_dataset = ['CUFSF'] 114 | ref_feature = './data/cufsf_feature_dataset.pth' 115 | ref_img_list = './data/cufsf_reference_img_list.txt' 116 | else: 117 | assert 1==0, 'Train style {} not supported.'.format(args.train_style) 118 | 119 | vgg_feature_layers = ['r11', 'r21', 'r31', 'r41', 'r51'] 120 | feature_loss_layers = list(itertools.compress(vgg_feature_layers, args.flayers)) 121 | 122 | log = logger.Logger(args.save_weight_path) 123 | 124 | for e in range(args.epochs): 125 | scheduler_G.step() 126 | scheduler_D.step() 127 | sample_count = 0 128 | for batch_idx, batch_data in enumerate(data_loader): 129 | # ---------------- Load data ------------------- 130 | start = time() 131 | train_img, train_img_org = [utils.tensorToVar(x) for x in batch_data] 132 | topk_sketch_img, topk_photo_img = search_dataset.find_photo_sketch_batch( 133 | train_img_org, ref_feature, ref_img_list, 134 | vgg19_model, dataset_filter=ref_style_dataset, topk=args.topk) 135 | random_real_sketch = search_dataset.get_real_sketch_batch(train_img.size(0), ref_img_list, dataset_filter=ref_style_dataset) 136 | end = time() 137 | data_time = end - start 138 | sample_count += train_img.size(0) 139 | 140 | # ---------------- Model forward ------------------- 141 | start = time() 142 | fake_sketch = Gnet(train_img) 143 | fake_score = Dnet(fake_sketch) 144 | real_score = Dnet(random_real_sketch) 145 | 146 | real_label = torch.ones_like(fake_score) 147 | fake_label = torch.zeros_like(fake_score) 148 | 149 | # ----------------- Calculate loss and backward ------------------- 150 | train_img_org_vgg = img_process.subtract_mean_batch(train_img_org, 'face') 151 | topk_sketch_img_vgg = img_process.subtract_mean_batch(topk_sketch_img, 'sketch') 152 | topk_photo_img_vgg = img_process.subtract_mean_batch(topk_photo_img, 'face') 153 | fake_sketch_vgg = img_process.subtract_mean_batch(fake_sketch.expand_as(train_img_org), 'sketch', args.meanshift) 154 | 155 | style_loss = loss.feature_mrf_loss_func( 156 | fake_sketch_vgg, topk_sketch_img_vgg, vgg19_model, 157 | feature_loss_layers, [train_img_org_vgg, topk_photo_img_vgg], topk=args.topk) 158 | 159 | tv_loss = loss.total_variation(fake_sketch) 160 | 161 | # GAN Loss 162 | adv_loss = mse_crit(fake_score, real_label) * args.weight[1] 163 | tv_loss = tv_loss * args.weight[2] 164 | loss_G = style_loss * args.weight[0] + adv_loss + tv_loss 165 | loss_D = 0.5 * mse_crit(fake_score, fake_label) + 0.5 * mse_crit(real_score, real_label) 166 | 167 | # Update parameters 168 | optim_D.zero_grad() 169 | loss_D.backward(retain_graph=True) 170 | optim_D.step() 171 | 172 | optim_G.zero_grad() 173 | loss_G.backward() 174 | optim_G.step() 175 | 176 | end = time() 177 | train_time = end - start 178 | 179 | # ----------------- Print result and log the output ------------------- 180 | log.iterLogUpdate(loss_G.data[0]) 181 | if batch_idx % 100 == 0: 182 | log.draw_loss_curve() 183 | 
184 | msg = "{:%Y-%m-%d %H:%M:%S}\tEpoch [{:03d}/{:03d}]\tBatch [{:03d}/{:03d}]\tData: {:.2f} Train: {:.2f}\tLoss: G-{:.4f}, Adv-{:.4f}, tv-{:.4f}, D-{:.4f}".format( 185 | datetime.now(), 186 | e, args.epochs, sample_count, len(dataset), 187 | data_time, train_time, *[x.data[0] for x in [loss_G, adv_loss, tv_loss, loss_D]]) 188 | print(msg) 189 | log_file = open(os.path.join(args.save_weight_path, 'log.txt'), 'a+') 190 | log_file.write(msg + '\n') 191 | log_file.close() 192 | 193 | save_weight_name = "epochs-{:03d}-".format(e) 194 | G_cpu_model = copy.deepcopy(Gnet).cpu() 195 | D_cpu_model = copy.deepcopy(Dnet).cpu() 196 | torch.save(G_cpu_model.state_dict(), os.path.join(args.save_weight_path, save_weight_name+'G.pth')) 197 | torch.save(D_cpu_model.state_dict(), os.path.join(args.save_weight_path, save_weight_name+'D.pth')) 198 | 199 | 200 | def test(args): 201 | """ 202 | Test image of a given directory. Calculate the quantitative result if ground truth dir is provided. 203 | """ 204 | Gnet= SketchNet(in_channels=3, out_channels=1, norm_type=args.Gnorm) 205 | gpu_ids = [int(x) for x in args.gpus.split(',')] 206 | if len(gpu_ids) > 0: 207 | Gnet.cuda() 208 | Gnet = nn.DataParallel(Gnet, device_ids=gpu_ids) 209 | Gnet.eval() 210 | Gnet.load_state_dict(torch.load(args.test_weight_path)) 211 | 212 | utils.mkdirs(args.result_dir) 213 | for img_name in os.listdir(args.test_dir): 214 | test_img_path = os.path.join(args.test_dir, img_name) 215 | test_img = img_process.read_img_var(test_img_path, size=(256, 256)) 216 | face_pred = Gnet(test_img) 217 | 218 | sketch_save_path = os.path.join(args.result_dir, img_name) 219 | img_process.save_var_img(face_pred, sketch_save_path, (250, 200)) 220 | print('Save sketch in', sketch_save_path) 221 | 222 | if args.test_gt_dir != 'none': 223 | print('------------ Calculating average SSIM (This may take for a while)-----------') 224 | avg_ssim = avg_score(args.result_dir, args.test_gt_dir, metric_name='ssim', smooth=False, verbose=True) 225 | print('------------ Calculating smoothed average SSIM (This may take for a while)-----------') 226 | avg_ssim_smoothed = avg_score(args.result_dir, args.test_gt_dir, metric_name='ssim', smooth=True, verbose=True) 227 | print('------------ Calculating average FSIM (This may take for a while)-----------') 228 | avg_fsim = avg_score(args.result_dir, args.test_gt_dir, metric_name='fsim', smooth=False, verbose=True) 229 | print('------------ Calculating smoothed average FSIM (This may take for a while)-----------') 230 | avg_fsim_smoothed = avg_score(args.result_dir, args.test_gt_dir, metric_name='fsim', smooth=True, verbose=True) 231 | print('Average SSIM: {}'.format(avg_ssim)) 232 | print('Average SSIM (Smoothed): {}'.format(avg_ssim_smoothed)) 233 | print('Average FSIM: {}'.format(avg_fsim)) 234 | print('Average FSIM (Smoothed): {}'.format(avg_fsim_smoothed)) 235 | 236 | if __name__ == '__main__': 237 | args = cmd_option() 238 | gpu_ids = [int(x) for x in args.gpus.split(',')] 239 | torch.cuda.set_device(gpu_ids[0]) 240 | 241 | args.save_weight_dir = 'face2sketch-norm_G{}_D{}-top{}-style_{}-flayers{}-weight-{:.1e}-{:.1e}-{:.1e}-epoch{:02d}-{}'.format( 242 | args.Gnorm, args.Dnorm, args.topk, args.train_style, "".join(map(str, args.flayers)), 243 | args.weight[0], args.weight[1], args.weight[2], 244 | args.epochs, args.other) 245 | args.save_weight_path = os.path.join(args.weight_root, args.save_weight_dir) 246 | 247 | if args.train_eval == 'train': 248 | print('Saving weight path', args.save_weight_path) 249 | 
utils.mkdirs(args.save_weight_path) 250 | train(args) 251 | elif args.train_eval == 'eval': 252 | test(args) 253 | 254 | 255 | -------------------------------------------------------------------------------- /utils/FeatureSIM.m: -------------------------------------------------------------------------------- 1 | function [FSIM, FSIMc] = FeatureSIM(imageRef, imageDis) 2 | % ======================================================================== 3 | % FSIM Index with automatic downsampling, Version 1.0 4 | % Copyright(c) 2010 Lin ZHANG, Lei Zhang, Xuanqin Mou and David Zhang 5 | % All Rights Reserved. 6 | % 7 | % ---------------------------------------------------------------------- 8 | % Permission to use, copy, or modify this software and its documentation 9 | % for educational and research purposes only and without fee is here 10 | % granted, provided that this copyright notice and the original authors' 11 | % names appear on all copies and supporting documentation. This program 12 | % shall not be used, rewritten, or adapted as the basis of a commercial 13 | % software or hardware product without first obtaining permission of the 14 | % authors. The authors make no representations about the suitability of 15 | % this software for any purpose. It is provided "as is" without express 16 | % or implied warranty. 17 | %---------------------------------------------------------------------- 18 | % 19 | % This is an implementation of the algorithm for calculating the 20 | % Feature SIMilarity (FSIM) index between two images. 21 | % 22 | % Please refer to the following paper 23 | % 24 | % Lin Zhang, Lei Zhang, Xuanqin Mou, and David Zhang, "FSIM: a feature similarity 25 | % index for image quality assessment", IEEE Transactions on Image Processing, vol. 20, no. 8, pp. 2378-2386, 2011. 26 | % 27 | %---------------------------------------------------------------------- 28 | % 29 | %Input : (1) imageRef: the first image being compared 30 | % (2) imageDis: the second image being compared 31 | % 32 | %Output: (1) FSIM: is the similarity score calculated using the FSIM algorithm. FSIM 33 | % only considers the luminance component of images. Color images 34 | % will be converted to grayscale first. 35 | % (2) FSIMc: is the similarity score calculated using the FSIMc algorithm. FSIMc 36 | % considers both the grayscale and the color information. 37 | %Note: For grayscale images, the returned FSIM and FSIMc are the same. 38 | % 39 | %----------------------------------------------------------------------- 40 | % 41 | %Usage: 42 | %Given 2 test images img1 and img2. For gray-scale images, their dynamic range should be 0-255. 43 | %For color images, the dynamic range of each color channel should be 0-255.
44 | % 45 | %[FSIM, FSIMc] = FeatureSIM(img1, img2); 46 | %----------------------------------------------------------------------- 47 | 48 | [rows, cols] = size(imageRef(:,:,1)); 49 | I1 = ones(rows, cols); 50 | I2 = ones(rows, cols); 51 | Q1 = ones(rows, cols); 52 | Q2 = ones(rows, cols); 53 | 54 | if ndims(imageRef) == 3 %images are colorful 55 | Y1 = 0.299 * double(imageRef(:,:,1)) + 0.587 * double(imageRef(:,:,2)) + 0.114 * double(imageRef(:,:,3)); 56 | Y2 = 0.299 * double(imageDis(:,:,1)) + 0.587 * double(imageDis(:,:,2)) + 0.114 * double(imageDis(:,:,3)); 57 | I1 = 0.596 * double(imageRef(:,:,1)) - 0.274 * double(imageRef(:,:,2)) - 0.322 * double(imageRef(:,:,3)); 58 | I2 = 0.596 * double(imageDis(:,:,1)) - 0.274 * double(imageDis(:,:,2)) - 0.322 * double(imageDis(:,:,3)); 59 | Q1 = 0.211 * double(imageRef(:,:,1)) - 0.523 * double(imageRef(:,:,2)) + 0.312 * double(imageRef(:,:,3)); 60 | Q2 = 0.211 * double(imageDis(:,:,1)) - 0.523 * double(imageDis(:,:,2)) + 0.312 * double(imageDis(:,:,3)); 61 | else %images are grayscale 62 | Y1 = imageRef; 63 | Y2 = imageDis; 64 | end 65 | 66 | Y1 = double(Y1); 67 | Y2 = double(Y2); 68 | %%%%%%%%%%%%%%%%%%%%%%%%% 69 | % Downsample the image 70 | %%%%%%%%%%%%%%%%%%%%%%%%% 71 | minDimension = min(rows,cols); 72 | F = max(1,round(minDimension / 256)); 73 | aveKernel = fspecial('average',F); 74 | 75 | aveI1 = conv2(I1, aveKernel,'same'); 76 | aveI2 = conv2(I2, aveKernel,'same'); 77 | I1 = aveI1(1:F:rows,1:F:cols); 78 | I2 = aveI2(1:F:rows,1:F:cols); 79 | 80 | aveQ1 = conv2(Q1, aveKernel,'same'); 81 | aveQ2 = conv2(Q2, aveKernel,'same'); 82 | Q1 = aveQ1(1:F:rows,1:F:cols); 83 | Q2 = aveQ2(1:F:rows,1:F:cols); 84 | 85 | aveY1 = conv2(Y1, aveKernel,'same'); 86 | aveY2 = conv2(Y2, aveKernel,'same'); 87 | Y1 = aveY1(1:F:rows,1:F:cols); 88 | Y2 = aveY2(1:F:rows,1:F:cols); 89 | 90 | %%%%%%%%%%%%%%%%%%%%%%%%% 91 | % Calculate the phase congruency maps 92 | %%%%%%%%%%%%%%%%%%%%%%%%% 93 | PC1 = phasecong2(Y1); 94 | PC2 = phasecong2(Y2); 95 | 96 | %%%%%%%%%%%%%%%%%%%%%%%%% 97 | % Calculate the gradient map 98 | %%%%%%%%%%%%%%%%%%%%%%%%% 99 | dx = [3 0 -3; 10 0 -10; 3 0 -3]/16; 100 | dy = [3 10 3; 0 0 0; -3 -10 -3]/16; 101 | IxY1 = conv2(Y1, dx, 'same'); 102 | IyY1 = conv2(Y1, dy, 'same'); 103 | gradientMap1 = sqrt(IxY1.^2 + IyY1.^2); 104 | 105 | IxY2 = conv2(Y2, dx, 'same'); 106 | IyY2 = conv2(Y2, dy, 'same'); 107 | gradientMap2 = sqrt(IxY2.^2 + IyY2.^2); 108 | 109 | %%%%%%%%%%%%%%%%%%%%%%%%% 110 | % Calculate the FSIM 111 | %%%%%%%%%%%%%%%%%%%%%%%%% 112 | T1 = 0.85; %fixed 113 | T2 = 160; %fixed 114 | PCSimMatrix = (2 * PC1 .* PC2 + T1) ./ (PC1.^2 + PC2.^2 + T1); 115 | gradientSimMatrix = (2*gradientMap1.*gradientMap2 + T2) ./(gradientMap1.^2 + gradientMap2.^2 + T2); 116 | PCm = max(PC1, PC2); 117 | SimMatrix = gradientSimMatrix .* PCSimMatrix .* PCm; 118 | FSIM = sum(sum(SimMatrix)) / sum(sum(PCm)); 119 | 120 | %%%%%%%%%%%%%%%%%%%%%%%%% 121 | % Calculate the FSIMc 122 | %%%%%%%%%%%%%%%%%%%%%%%%% 123 | T3 = 200; 124 | T4 = 200; 125 | ISimMatrix = (2 * I1 .* I2 + T3) ./ (I1.^2 + I2.^2 + T3); 126 | QSimMatrix = (2 * Q1 .* Q2 + T4) ./ (Q1.^2 + Q2.^2 + T4); 127 | 128 | lambda = 0.03; 129 | 130 | SimMatrixC = gradientSimMatrix .* PCSimMatrix .* real((ISimMatrix .* QSimMatrix) .^ lambda) .* PCm; 131 | FSIMc = sum(sum(SimMatrixC)) / sum(sum(PCm)); 132 | 133 | return; 134 | 135 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 136 | 137 | function [ResultPC]=phasecong2(im) 138 | % 
======================================================================== 139 | % Copyright (c) 1996-2009 Peter Kovesi 140 | % School of Computer Science & Software Engineering 141 | % The University of Western Australia 142 | % http://www.csse.uwa.edu.au/ 143 | % 144 | % Permission is hereby granted, free of charge, to any person obtaining a copy 145 | % of this software and associated documentation files (the "Software"), to deal 146 | % in the Software without restriction, subject to the following conditions: 147 | % 148 | % The above copyright notice and this permission notice shall be included in all 149 | % copies or substantial portions of the Software. 150 | % 151 | % The software is provided "as is", without warranty of any kind. 152 | % References: 153 | % 154 | % Peter Kovesi, "Image Features From Phase Congruency". Videre: A 155 | % Journal of Computer Vision Research. MIT Press. Volume 1, Number 3, 156 | % Summer 1999 http://mitpress.mit.edu/e-journals/Videre/001/v13.html 157 | 158 | nscale = 4; % Number of wavelet scales. 159 | norient = 4; % Number of filter orientations. 160 | minWaveLength = 6; % Wavelength of smallest scale filter. 161 | mult = 2; % Scaling factor between successive filters. 162 | sigmaOnf = 0.55; % Ratio of the standard deviation of the 163 | % Gaussian describing the log Gabor filter's 164 | % transfer function in the frequency domain 165 | % to the filter center frequency. 166 | dThetaOnSigma = 1.2; % Ratio of angular interval between filter orientations 167 | % and the standard deviation of the angular Gaussian 168 | % function used to construct filters in the 169 | % freq. plane. 170 | k = 2.0; % No of standard deviations of the noise 171 | % energy beyond the mean at which we set the 172 | % noise threshold point. 173 | % below which phase congruency values get 174 | % penalized. 175 | epsilon = .0001; % Used to prevent division by zero. 176 | 177 | thetaSigma = pi/norient/dThetaOnSigma; % Calculate the standard deviation of the 178 | % angular Gaussian function used to 179 | % construct filters in the freq. plane. 180 | 181 | [rows,cols] = size(im); 182 | imagefft = fft2(im); % Fourier transform of image 183 | 184 | zero = zeros(rows,cols); 185 | EO = cell(nscale, norient); % Array of convolution results. 186 | 187 | estMeanE2n = []; 188 | ifftFilterArray = cell(1,nscale); % Array of inverse FFTs of filters 189 | 190 | % Pre-compute some stuff to speed up filter construction 191 | 192 | % Set up X and Y matrices with ranges normalised to +/- 0.5 193 | % The following code adjusts things appropriately for odd and even values 194 | % of rows and columns. 195 | if mod(cols,2) 196 | xrange = [-(cols-1)/2:(cols-1)/2]/(cols-1); 197 | else 198 | xrange = [-cols/2:(cols/2-1)]/cols; 199 | end 200 | 201 | if mod(rows,2) 202 | yrange = [-(rows-1)/2:(rows-1)/2]/(rows-1); 203 | else 204 | yrange = [-rows/2:(rows/2-1)]/rows; 205 | end 206 | 207 | [x,y] = meshgrid(xrange, yrange); 208 | 209 | radius = sqrt(x.^2 + y.^2); % Matrix values contain *normalised* radius from centre. 210 | theta = atan2(-y,x); % Matrix values contain polar angle. 211 | % (note -ve y is used to give +ve 212 | % anti-clockwise angles) 213 | 214 | radius = ifftshift(radius); % Quadrant shift radius and theta so that filters 215 | theta = ifftshift(theta); % are constructed with 0 frequency at the corners. 216 | radius(1,1) = 1; % Get rid of the 0 radius value at the 0 217 | % frequency point (now at top-left corner) 218 | % so that taking the log of the radius will 219 | % not cause trouble. 
220 | 221 | sintheta = sin(theta); 222 | costheta = cos(theta); 223 | clear x; clear y; clear theta; % save a little memory 224 | 225 | % Filters are constructed in terms of two components. 226 | % 1) The radial component, which controls the frequency band that the filter 227 | % responds to 228 | % 2) The angular component, which controls the orientation that the filter 229 | % responds to. 230 | % The two components are multiplied together to construct the overall filter. 231 | 232 | % Construct the radial filter components... 233 | 234 | % First construct a low-pass filter that is as large as possible, yet falls 235 | % away to zero at the boundaries. All log Gabor filters are multiplied by 236 | % this to ensure no extra frequencies at the 'corners' of the FFT are 237 | % incorporated as this seems to upset the normalisation process when 238 | % calculating phase congrunecy. 239 | lp = lowpassfilter([rows,cols],.45,15); % Radius .45, 'sharpness' 15 240 | 241 | logGabor = cell(1,nscale); 242 | 243 | for s = 1:nscale 244 | wavelength = minWaveLength*mult^(s-1); 245 | fo = 1.0/wavelength; % Centre frequency of filter. 246 | logGabor{s} = exp((-(log(radius/fo)).^2) / (2 * log(sigmaOnf)^2)); 247 | logGabor{s} = logGabor{s}.*lp; % Apply low-pass filter 248 | logGabor{s}(1,1) = 0; % Set the value at the 0 frequency point of the filter 249 | % back to zero (undo the radius fudge). 250 | end 251 | 252 | % Then construct the angular filter components... 253 | 254 | spread = cell(1,norient); 255 | 256 | for o = 1:norient 257 | angl = (o-1)*pi/norient; % Filter angle. 258 | 259 | % For each point in the filter matrix calculate the angular distance from 260 | % the specified filter orientation. To overcome the angular wrap-around 261 | % problem sine difference and cosine difference values are first computed 262 | % and then the atan2 function is used to determine angular distance. 263 | 264 | ds = sintheta * cos(angl) - costheta * sin(angl); % Difference in sine. 265 | dc = costheta * cos(angl) + sintheta * sin(angl); % Difference in cosine. 266 | dtheta = abs(atan2(ds,dc)); % Absolute angular distance. 267 | spread{o} = exp((-dtheta.^2) / (2 * thetaSigma^2)); % Calculate the 268 | % angular filter component. 269 | end 270 | 271 | % The main loop... 272 | EnergyAll(rows,cols) = 0; 273 | AnAll(rows,cols) = 0; 274 | 275 | for o = 1:norient % For each orientation. 276 | sumE_ThisOrient = zero; % Initialize accumulator matrices. 277 | sumO_ThisOrient = zero; 278 | sumAn_ThisOrient = zero; 279 | Energy = zero; 280 | for s = 1:nscale, % For each scale. 281 | filter = logGabor{s} .* spread{o}; % Multiply radial and angular 282 | % components to get the filter. 283 | ifftFilt = real(ifft2(filter))*sqrt(rows*cols); % Note rescaling to match power 284 | ifftFilterArray{s} = ifftFilt; % record ifft2 of filter 285 | % Convolve image with even and odd filters returning the result in EO 286 | EO{s,o} = ifft2(imagefft .* filter); 287 | 288 | An = abs(EO{s,o}); % Amplitude of even & odd filter response. 289 | sumAn_ThisOrient = sumAn_ThisOrient + An; % Sum of amplitude responses. 290 | sumE_ThisOrient = sumE_ThisOrient + real(EO{s,o}); % Sum of even filter convolution results. 291 | sumO_ThisOrient = sumO_ThisOrient + imag(EO{s,o}); % Sum of odd filter convolution results. 292 | if s==1 % Record mean squared filter value at smallest 293 | EM_n = sum(sum(filter.^2)); % scale. This is used for noise estimation. 294 | maxAn = An; % Record the maximum An over all scales. 
295 | else 296 | maxAn = max(maxAn, An); 297 | end 298 | end % ... and process the next scale 299 | 300 | % Get weighted mean filter response vector, this gives the weighted mean 301 | % phase angle. 302 | 303 | XEnergy = sqrt(sumE_ThisOrient.^2 + sumO_ThisOrient.^2) + epsilon; 304 | MeanE = sumE_ThisOrient ./ XEnergy; 305 | MeanO = sumO_ThisOrient ./ XEnergy; 306 | 307 | % Now calculate An(cos(phase_deviation) - | sin(phase_deviation)) | by 308 | % using dot and cross products between the weighted mean filter response 309 | % vector and the individual filter response vectors at each scale. This 310 | % quantity is phase congruency multiplied by An, which we call energy. 311 | 312 | for s = 1:nscale, 313 | E = real(EO{s,o}); O = imag(EO{s,o}); % Extract even and odd 314 | % convolution results. 315 | Energy = Energy + E.*MeanE + O.*MeanO - abs(E.*MeanO - O.*MeanE); 316 | end 317 | 318 | % Compensate for noise 319 | % We estimate the noise power from the energy squared response at the 320 | % smallest scale. If the noise is Gaussian the energy squared will have a 321 | % Chi-squared 2DOF pdf. We calculate the median energy squared response 322 | % as this is a robust statistic. From this we estimate the mean. 323 | % The estimate of noise power is obtained by dividing the mean squared 324 | % energy value by the mean squared filter value 325 | 326 | medianE2n = median(reshape(abs(EO{1,o}).^2,1,rows*cols)); 327 | meanE2n = -medianE2n/log(0.5); 328 | estMeanE2n(o) = meanE2n; 329 | 330 | noisePower = meanE2n/EM_n; % Estimate of noise power. 331 | 332 | % Now estimate the total energy^2 due to noise 333 | % Estimate for sum(An^2) + sum(Ai.*Aj.*(cphi.*cphj + sphi.*sphj)) 334 | 335 | EstSumAn2 = zero; 336 | for s = 1:nscale 337 | EstSumAn2 = EstSumAn2 + ifftFilterArray{s}.^2; 338 | end 339 | 340 | EstSumAiAj = zero; 341 | for si = 1:(nscale-1) 342 | for sj = (si+1):nscale 343 | EstSumAiAj = EstSumAiAj + ifftFilterArray{si}.*ifftFilterArray{sj}; 344 | end 345 | end 346 | sumEstSumAn2 = sum(sum(EstSumAn2)); 347 | sumEstSumAiAj = sum(sum(EstSumAiAj)); 348 | 349 | EstNoiseEnergy2 = 2*noisePower*sumEstSumAn2 + 4*noisePower*sumEstSumAiAj; 350 | 351 | tau = sqrt(EstNoiseEnergy2/2); % Rayleigh parameter 352 | EstNoiseEnergy = tau*sqrt(pi/2); % Expected value of noise energy 353 | EstNoiseEnergySigma = sqrt( (2-pi/2)*tau^2 ); 354 | 355 | T = EstNoiseEnergy + k*EstNoiseEnergySigma; % Noise threshold 356 | 357 | % The estimated noise effect calculated above is only valid for the PC_1 measure. 358 | % The PC_2 measure does not lend itself readily to the same analysis. However 359 | % empirically it seems that the noise effect is overestimated roughly by a factor 360 | % of 1.7 for the filter parameters used here. 361 | 362 | T = T/1.7; % Empirical rescaling of the estimated noise effect to 363 | % suit the PC_2 phase congruency measure 364 | Energy = max(Energy - T, zero); % Apply noise threshold 365 | 366 | EnergyAll = EnergyAll + Energy; 367 | AnAll = AnAll + sumAn_ThisOrient; 368 | end % For each orientation 369 | ResultPC = EnergyAll ./ AnAll; 370 | return; 371 | 372 | 373 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 374 | % LOWPASSFILTER - Constructs a low-pass butterworth filter. 375 | % 376 | % usage: f = lowpassfilter(sze, cutoff, n) 377 | % 378 | % where: sze is a two element vector specifying the size of filter 379 | % to construct [rows cols]. 
380 | % cutoff is the cutoff frequency of the filter 0 - 0.5 381 | % n is the order of the filter, the higher n is the sharper 382 | % the transition is. (n must be an integer >= 1). 383 | % Note that n is doubled so that it is always an even integer. 384 | % 385 | % 1 386 | % f = -------------------- 387 | % 2n 388 | % 1.0 + (w/cutoff) 389 | % 390 | % The frequency origin of the returned filter is at the corners. 391 | % 392 | % See also: HIGHPASSFILTER, HIGHBOOSTFILTER, BANDPASSFILTER 393 | % 394 | 395 | % Copyright (c) 1999 Peter Kovesi 396 | % School of Computer Science & Software Engineering 397 | % The University of Western Australia 398 | % http://www.csse.uwa.edu.au/ 399 | % 400 | % Permission is hereby granted, free of charge, to any person obtaining a copy 401 | % of this software and associated documentation files (the "Software"), to deal 402 | % in the Software without restriction, subject to the following conditions: 403 | % 404 | % The above copyright notice and this permission notice shall be included in 405 | % all copies or substantial portions of the Software. 406 | % 407 | % The Software is provided "as is", without warranty of any kind. 408 | 409 | % October 1999 410 | % August 2005 - Fixed up frequency ranges for odd and even sized filters 411 | % (previous code was a bit approximate) 412 | 413 | function f = lowpassfilter(sze, cutoff, n) 414 | 415 | if cutoff < 0 || cutoff > 0.5 416 | error('cutoff frequency must be between 0 and 0.5'); 417 | end 418 | 419 | if rem(n,1) ~= 0 || n < 1 420 | error('n must be an integer >= 1'); 421 | end 422 | 423 | if length(sze) == 1 424 | rows = sze; cols = sze; 425 | else 426 | rows = sze(1); cols = sze(2); 427 | end 428 | 429 | % Set up X and Y matrices with ranges normalised to +/- 0.5 430 | % The following code adjusts things appropriately for odd and even values 431 | % of rows and columns. 432 | if mod(cols,2) 433 | xrange = [-(cols-1)/2:(cols-1)/2]/(cols-1); 434 | else 435 | xrange = [-cols/2:(cols/2-1)]/cols; 436 | end 437 | 438 | if mod(rows,2) 439 | yrange = [-(rows-1)/2:(rows-1)/2]/(rows-1); 440 | else 441 | yrange = [-rows/2:(rows/2-1)]/rows; 442 | end 443 | 444 | [x,y] = meshgrid(xrange, yrange); 445 | radius = sqrt(x.^2 + y.^2); % A matrix with every pixel = radius relative to centre. 446 | f = ifftshift( 1 ./ (1.0 + (radius ./ cutoff).^(2*n)) ); % The filter 447 | return; 448 | --------------------------------------------------------------------------------
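The final pooling step in utils/FeatureSIM.m combines the two phase congruency maps and the two gradient magnitude maps into a single score. The NumPy sketch below illustrates only that combination: it assumes the phase congruency maps pc1 and pc2 have already been computed (by phasecong2 above, or any Python port of it), skips the automatic downsampling, and the function name fsim_from_pc is used here purely for illustration.

import numpy as np
from scipy.signal import convolve2d

def fsim_from_pc(y1, y2, pc1, pc2, t1=0.85, t2=160):
    """y1, y2: luminance images in [0, 255]; pc1, pc2: precomputed phase congruency maps."""
    # Scharr-like gradient operators, as in FeatureSIM.m
    dx = np.array([[3, 0, -3], [10, 0, -10], [3, 0, -3]]) / 16.0
    dy = dx.T
    g1 = np.sqrt(convolve2d(y1, dx, mode='same') ** 2 + convolve2d(y1, dy, mode='same') ** 2)
    g2 = np.sqrt(convolve2d(y2, dx, mode='same') ** 2 + convolve2d(y2, dy, mode='same') ** 2)

    # Per-pixel similarity of phase congruency and of gradient magnitude
    pc_sim = (2 * pc1 * pc2 + t1) / (pc1 ** 2 + pc2 ** 2 + t1)
    g_sim = (2 * g1 * g2 + t2) / (g1 ** 2 + g2 ** 2 + t2)

    # Pool the joint similarity, weighted by the stronger phase congruency response
    pcm = np.maximum(pc1, pc2)
    return np.sum(pc_sim * g_sim * pcm) / np.sum(pcm)

The weighting by max(PC1, PC2) means pixels with strong phase congruency in either image dominate the score, which is why FSIM tracks structural differences such as missing strokes in a synthesized sketch.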
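The lowpassfilter helper builds a frequency-domain Butterworth window (radius 0.45, order 15 in phasecong2) so the log-Gabor filters pick up no energy at the corners of the FFT. A rough NumPy equivalent is sketched below, following the same odd/even grid handling; the name lowpass_filter is illustrative, not part of the repository.

import numpy as np

def lowpass_filter(rows, cols, cutoff=0.45, n=15):
    # Frequency coordinates normalised to +/- 0.5, handling odd and even sizes
    xr = (np.arange(cols) - (cols - 1) / 2) / (cols - 1) if cols % 2 else (np.arange(cols) - cols / 2) / cols
    yr = (np.arange(rows) - (rows - 1) / 2) / (rows - 1) if rows % 2 else (np.arange(rows) - rows / 2) / rows
    x, y = np.meshgrid(xr, yr)
    radius = np.sqrt(x ** 2 + y ** 2)           # normalised distance from the centre frequency
    f = 1.0 / (1.0 + (radius / cutoff) ** (2 * n))
    return np.fft.ifftshift(f)                  # move the origin to the corners, as in the MATLAB code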
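For reference, the adversarial and total-variation terms used in the training loop of face2sketch_wild.py earlier in this listing can be written compactly as below. This is a sketch only: the least-squares GAN losses mirror the mse_crit calls in the loop, but the total_variation definition here (sum of absolute neighbour differences) is an assumption and may differ from utils/loss.total_variation, and the MRF style loss is not reproduced.

import torch
import torch.nn as nn

mse = nn.MSELoss()

def total_variation(img):
    # Sum of absolute differences between neighbouring pixels along H and W
    tv_h = (img[:, :, 1:, :] - img[:, :, :-1, :]).abs().sum()
    tv_w = (img[:, :, :, 1:] - img[:, :, :, :-1]).abs().sum()
    return tv_h + tv_w

def generator_loss(fake_score, style_loss, fake_sketch, weight):
    # weight = [style, adversarial, tv], matching the args.weight layout used above
    adv = mse(fake_score, torch.ones_like(fake_score)) * weight[1]
    tv = total_variation(fake_sketch) * weight[2]
    return style_loss * weight[0] + adv + tv

def discriminator_loss(fake_score, real_score):
    # Least-squares GAN: push fake scores towards 0 and real scores towards 1
    return 0.5 * mse(fake_score, torch.zeros_like(fake_score)) + 0.5 * mse(real_score, torch.ones_like(real_score))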