├── CX ├── __init__.py ├── enums.py ├── CX_helper.py ├── CSFlow.py └── CX_distance.py ├── utils ├── __init__.py ├── FetchManager.py └── helper.py ├── images ├── teaser_im.png └── trump_cartoon.jpg ├── model.py ├── config.py ├── vgg_model.py ├── README.md └── single_image_animation.py /CX/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /images/teaser_im.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/roimehrez/contextualLoss/HEAD/images/teaser_im.png -------------------------------------------------------------------------------- /images/trump_cartoon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/roimehrez/contextualLoss/HEAD/images/trump_cartoon.jpg -------------------------------------------------------------------------------- /CX/enums.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class Distance(Enum): 5 | L2 = 0 6 | DotProduct = 1 7 | 8 | 9 | class TensorAxis: 10 | N = 0 11 | H = 1 12 | W = 2 13 | C = 3 -------------------------------------------------------------------------------- /utils/FetchManager.py: -------------------------------------------------------------------------------- 1 | class FetchManager: 2 | def __init__(self, sess, fetches): 3 | self.fetches = fetches 4 | self.sess = sess 5 | 6 | def fetch(self, feed_dictionary, additional_fetches=[]): 7 | fetches = self.fetches + additional_fetches 8 | evaluation = self.sess.run(fetches, feed_dictionary) 9 | return {k:v for k,v in zip(fetches, evaluation)} -------------------------------------------------------------------------------- /utils/helper.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------- 2 | # code credits: https://github.com/CQFIO/PhotographicImageSynthesis 3 | # --------------------------------------------------- 4 | import numpy as np 5 | import scipy 6 | from config import * 7 | import tensorflow as tf 8 | 9 | 10 | def read_image(file_name, resize=True, fliplr=False): 11 | image = scipy.misc.imread(file_name) 12 | if resize: 13 | image = scipy.misc.imresize(image, size=config.TRAIN.resize, interp='bilinear', mode=None) 14 | if fliplr: 15 | image = np.fliplr(image) 16 | image = np.float32(image) 17 | return np.expand_dims(image, axis=0) 18 | 19 | 20 | def save_image(output, file_name): 21 | output = np.minimum(np.maximum(output, 0.0), 255.0) 22 | scipy.misc.toimage(output.squeeze(axis=0), cmin=0, cmax=255).save(file_name) 23 | return 24 | 25 | 26 | def write_loss_in_txt(loss_list, epoch): 27 | target = open(config.TRAIN.out_dir + "/%04d/score.txt" % epoch, 'w') 28 | target.write("%f" % np.mean(loss_list[np.where(loss_list)])) 29 | target.close() 30 | 31 | 32 | def random_crop_together(im1, im2, size): 33 | images = tf.concat([im1, im2], axis=0) 34 | images_croped = tf.random_crop(images, size=size) 35 | im1, im2 = tf.split(images_croped, 2, axis=0) 36 | return im1, im2 37 | 38 | -------------------------------------------------------------------------------- /model.py: 
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------
2 | # code credits: https://github.com/CQFIO/PhotographicImageSynthesis
3 | # ---------------------------------------------------
4 | 
5 | import tensorflow.contrib.slim as slim
6 | from vgg_model import *
7 | from config import *
8 | 
9 | 
10 | # this function has been modified so that the images are portrait rather than landscape
11 | def recursive_generator(input_image, width):
12 |     ar = config.TRAIN.aspect_ratio
13 |     if width >= 128:
14 |         dim = 512 // config.TRAIN.reduce_dim
15 |     else:
16 |         dim = 1024 // config.TRAIN.reduce_dim
17 | 
18 |     if width == 4:
19 |         input = input_image
20 |     else:
21 |         downsampled_width = width // 2
22 |         downsampled_input = tf.image.resize_area(input_image, (downsampled_width, downsampled_width // ar), align_corners=False)
23 |         recursive_call = recursive_generator(downsampled_input, downsampled_width)
24 |         predicted_on_downsampled = tf.image.resize_bilinear(recursive_call, (width, width // ar), align_corners=True)
25 |         input = tf.concat([predicted_on_downsampled, input_image], 3)
26 | 
27 |     net = slim.conv2d(input, dim, [3, 3], rate=1, normalizer_fn=slim.layer_norm, activation_fn=lrelu, scope='g_' + str(width) + '_conv1')
28 |     net = slim.conv2d(net, dim, [3, 3], rate=1, normalizer_fn=slim.layer_norm, activation_fn=lrelu, scope='g_' + str(width) + '_conv2')
29 | 
30 |     if width == config.TRAIN.sp * config.TRAIN.aspect_ratio:
31 |         net = slim.conv2d(net, 3, [1, 1], rate=1, activation_fn=None, scope='g_' + str(width) + '_conv100')
32 |         net = (net + 1.0) / 2.0 * 255.0
33 |     return net
34 | 
35 | 
--------------------------------------------------------------------------------
/CX/CX_helper.py:
--------------------------------------------------------------------------------
1 | from CX import CSFlow
2 | import tensorflow as tf
3 | 
4 | 
5 | def random_sampling(tensor_NHWC, n, indices=None):
6 |     N, H, W, C = tf.convert_to_tensor(tensor_NHWC).shape.as_list()
7 |     S = H * W
8 |     tensor_NSC = tf.reshape(tensor_NHWC, [N, S, C])
9 |     all_indices = list(range(S))
10 |     shuffled_indices = tf.random_shuffle(all_indices)
11 |     indices = tf.gather(shuffled_indices, list(range(n)), axis=0) if indices is None else indices
12 |     # indices_old = tf.random_uniform([n], 0, S, tf.int32)  # unused leftover of an older sampling scheme
13 |     res = tf.gather(tensor_NSC, indices, axis=1)
14 |     return res, indices
15 | 
16 | 
17 | def random_pooling(feats, output_1d_size=100):
18 |     is_input_tensor = type(feats) is tf.Tensor
19 | 
20 |     if is_input_tensor:
21 |         feats = [feats]
22 | 
23 |     # convert all inputs to tensors
24 |     feats = [tf.convert_to_tensor(feats_i) for feats_i in feats]
25 | 
26 |     N, H, W, C = feats[0].shape.as_list()
27 |     feats_sampled_0, indices = random_sampling(feats[0], output_1d_size ** 2)
28 |     res = [feats_sampled_0]
29 |     for i in range(1, len(feats)):
30 |         feats_sampled_i, _ = random_sampling(feats[i], -1, indices)
31 |         res.append(feats_sampled_i)
32 | 
33 |     res = [tf.reshape(feats_sampled_i, [N, output_1d_size, output_1d_size, C]) for feats_sampled_i in res]
34 |     if is_input_tensor:
35 |         return res[0]
36 |     return res
37 | 
38 | 
39 | def crop_quarters(feature_tensor):
40 |     N, fH, fW, fC = feature_tensor.shape.as_list()
41 |     quarters_list = []
42 |     quarter_size = [N, round(fH / 2), round(fW / 2), fC]
43 |     quarters_list.append(tf.slice(feature_tensor, [0, 0, 0, 0], quarter_size))  # top-left
44 |     quarters_list.append(tf.slice(feature_tensor, [0, round(fH / 2), 0, 0], quarter_size))  # bottom-left
45 |     quarters_list.append(tf.slice(feature_tensor, [0, 0, round(fW / 2), 0], quarter_size))  # top-right
46 |     quarters_list.append(tf.slice(feature_tensor, [0, round(fH / 2), round(fW / 2), 0], quarter_size))  # bottom-right
47 |     feature_tensor = tf.concat(quarters_list, axis=0)
48 |     return feature_tensor
49 | 
50 | 
51 | def CX_loss_helper(vgg_A, vgg_B, CX_config):
52 |     if CX_config.crop_quarters is True:
53 |         vgg_A = crop_quarters(vgg_A)
54 |         vgg_B = crop_quarters(vgg_B)
55 | 
56 |     N, fH, fW, fC = vgg_A.shape.as_list()
57 |     if fH * fW <= CX_config.max_sampling_1d_size ** 2:
58 |         print(' #### Skipping pooling for CX....')
59 |     else:
60 |         print(' #### pooling for CX %d**2 out of %dx%d' % (CX_config.max_sampling_1d_size, fH, fW))
61 |         vgg_A, vgg_B = random_pooling([vgg_A, vgg_B], output_1d_size=CX_config.max_sampling_1d_size)
62 | 
63 |     CX_loss = CSFlow.CX_loss(vgg_A, vgg_B, distance=CX_config.Dist, nnsigma=CX_config.nn_stretch_sigma)
64 |     return CX_loss
65 | 
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | from easydict import EasyDict as edict
3 | import json
4 | from CX.enums import Distance
5 | import tensorflow as tf
6 | import numpy as np
7 | import re
8 | 
9 | celebA = False
10 | single_image = False
11 | zero_tensor = tf.constant(0.0, dtype=tf.float32)
12 | config = edict()
13 | 
14 | #---------------------------------------------
15 | # update the right paths
16 | config.base_dir = 'C:/DATA/person2person/single/'
17 | config.single_image_B_file_name = 'images/trump_cartoon.jpg'
18 | config.vgg_model_path = 'C:/DATA/VGG_Model/imagenet-vgg-verydeep-19.mat'
19 | #---------------------------------------------
20 | 
21 | 
22 | 
23 | config.W = edict()
24 | # weights
25 | config.W.CX = 1.0
26 | config.W.CX_content = 1.0
27 | 
28 | # train parameters
29 | config.TRAIN = edict()
30 | config.TRAIN.is_train = True  # change to True if you want to train
31 | config.TRAIN.sp = 256
32 | config.TRAIN.aspect_ratio = 1  # 1
33 | config.TRAIN.resize = [config.TRAIN.sp * config.TRAIN.aspect_ratio, config.TRAIN.sp]
34 | config.TRAIN.crop_size = [config.TRAIN.sp * config.TRAIN.aspect_ratio, config.TRAIN.sp]
35 | config.TRAIN.A_data_dir = 'train'
36 | config.TRAIN.out_dir = "result/"
37 | config.TRAIN.num_epochs = 10
38 | config.TRAIN.reduce_dim = 2  # use a smaller CRN model
39 | config.TRAIN.every_nth_frame = 1  # train using all frames
40 | 
41 | config.VAL = edict()
42 | config.VAL.A_data_dir = 'test'
43 | config.VAL.every_nth_frame = 1
44 | 
45 | config.TEST = edict()
46 | config.TEST.is_test = not config.TRAIN.is_train
47 | config.TEST.A_data_dir = config.VAL.A_data_dir
48 | # config.TEST.every_nth_frame = 5
49 | config.TEST.out_dir_postfix = "/test"
50 | config.TEST.random_crop = False  # if False, take the top left corner
51 | 
52 | config.CX = edict()
53 | config.CX.crop_quarters = False
54 | config.CX.max_sampling_1d_size = 65
55 | # config.dis.feat_layers = {'conv1_1': 1.0,'conv2_1': 1.0, 'conv3_1': 1.0, 'conv4_1': 1.0,'conv5_1': 1.0}
56 | config.CX.feat_layers = {'conv3_2': 1.0, 'conv4_2': 1.0}
57 | config.CX.feat_content_layers = {'conv4_2': 1.0}  # for single image
58 | config.CX.Dist = Distance.DotProduct
59 | config.CX.nn_stretch_sigma = 0.5  # 0.1
60 | config.CX.patch_size = 5
61 | config.CX.patch_stride = 2
62 | 
63 | 
64 | def last_two_nums(name):
65 |     if name.endswith('vgg_input_im') or name == 'RGB':
66 |         return 'rgb'
67 |     all_nums = re.findall(r'\d+', name)
68 |     return all_nums[-2] + all_nums[-1]
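69 | # e.g. last_two_nums('conv3_2') -> '32', last_two_nums('vgg_input_im') -> 'rgb'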
70 | 71 | 72 | 73 | 74 | config.expirament_postfix = 'single_im' 75 | if config.W.CX > 0: 76 | config.expirament_postfix += "_CXt" #CX_target 77 | config.expirament_postfix += '_'.join([last_two_nums(layer) for layer in sorted(config.CX.feat_layers.keys())]) 78 | config.expirament_postfix += '_{}'.format(config.W.CX) 79 | if config.W.CX_content: 80 | config.expirament_postfix += "_CXs" #CX_source 81 | config.expirament_postfix += '_'.join([last_two_nums(layer) for layer in sorted(config.CX.feat_content_layers.keys())]) 82 | config.expirament_postfix += '_{}'.format(config.W.CX_content) 83 | 84 | 85 | # uncomment and update for test 86 | # config.expirament_postfix = 'm2f_D32_42_1.0(s0.5)_DC42_1.0' 87 | 88 | config.TRAIN.out_dir += config.expirament_postfix 89 | config.TEST.out_dir = config.TRAIN.out_dir 90 | if not os.path.exists(config.TRAIN.out_dir): 91 | os.makedirs(config.TRAIN.out_dir) 92 | 93 | 94 | -------------------------------------------------------------------------------- /vgg_model.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------- 2 | # code credits: https://github.com/CQFIO/PhotographicImageSynthesis 3 | # --------------------------------------------------- 4 | 5 | import tensorflow as tf 6 | import tensorflow.contrib.slim as slim 7 | import numpy as np 8 | import scipy.io 9 | from config import * 10 | 11 | 12 | def lrelu(x): 13 | return tf.maximum(0.2 * x, x) 14 | 15 | 16 | def build_net(ntype, nin, nwb=None, name=None): 17 | if ntype == 'conv': 18 | return tf.nn.relu(tf.nn.conv2d(nin, nwb[0], strides=[1, 1, 1, 1], padding='SAME', name=name) + nwb[1]) 19 | elif ntype == 'pool': 20 | return tf.nn.avg_pool(nin, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 21 | 22 | 23 | def get_weight_bias(vgg_layers, i): 24 | weights = vgg_layers[i][0][0][2][0][0] 25 | weights = tf.constant(weights) 26 | bias = vgg_layers[i][0][0][2][0][1] 27 | bias = tf.constant(np.reshape(bias, (bias.size))) 28 | return weights, bias 29 | 30 | 31 | def build_vgg19(input, reuse=False): 32 | if reuse: 33 | tf.get_variable_scope().reuse_variables() 34 | net = {} 35 | vgg_rawnet = scipy.io.loadmat(config.vgg_model_path) 36 | vgg_layers = vgg_rawnet['layers'][0] 37 | net['input'] = input - np.array([123.6800, 116.7790, 103.9390]).reshape((1, 1, 1, 3)) 38 | net['conv1_1'] = build_net('conv', net['input'], get_weight_bias(vgg_layers, 0), name='vgg_conv1_1') 39 | net['conv1_2'] = build_net('conv', net['conv1_1'], get_weight_bias(vgg_layers, 2), name='vgg_conv1_2') 40 | net['pool1'] = build_net('pool', net['conv1_2']) 41 | net['conv2_1'] = build_net('conv', net['pool1'], get_weight_bias(vgg_layers, 5), name='vgg_conv2_1') 42 | net['conv2_2'] = build_net('conv', net['conv2_1'], get_weight_bias(vgg_layers, 7), name='vgg_conv2_2') 43 | net['pool2'] = build_net('pool', net['conv2_2']) 44 | net['conv3_1'] = build_net('conv', net['pool2'], get_weight_bias(vgg_layers, 10), name='vgg_conv3_1') 45 | net['conv3_2'] = build_net('conv', net['conv3_1'], get_weight_bias(vgg_layers, 12), name='vgg_conv3_2') 46 | net['conv3_3'] = build_net('conv', net['conv3_2'], get_weight_bias(vgg_layers, 14), name='vgg_conv3_3') 47 | net['conv3_4'] = build_net('conv', net['conv3_3'], get_weight_bias(vgg_layers, 16), name='vgg_conv3_4') 48 | net['pool3'] = build_net('pool', net['conv3_4']) 49 | net['conv4_1'] = build_net('conv', net['pool3'], get_weight_bias(vgg_layers, 19), name='vgg_conv4_1') 50 | net['conv4_2'] = build_net('conv', 
net['conv4_1'], get_weight_bias(vgg_layers, 21), name='vgg_conv4_2')
51 |     net['conv4_3'] = build_net('conv', net['conv4_2'], get_weight_bias(vgg_layers, 23), name='vgg_conv4_3')
52 |     net['conv4_4'] = build_net('conv', net['conv4_3'], get_weight_bias(vgg_layers, 25), name='vgg_conv4_4')
53 |     net['pool4'] = build_net('pool', net['conv4_4'])
54 |     net['conv5_1'] = build_net('conv', net['pool4'], get_weight_bias(vgg_layers, 28), name='vgg_conv5_1')
55 |     net['conv5_2'] = build_net('conv', net['conv5_1'], get_weight_bias(vgg_layers, 30), name='vgg_conv5_2')
56 |     net['conv5_3'] = build_net('conv', net['conv5_2'], get_weight_bias(vgg_layers, 32), name='vgg_conv5_3')
57 |     net['conv5_4'] = build_net('conv', net['conv5_3'], get_weight_bias(vgg_layers, 34), name='vgg_conv5_4')
58 |     net['pool5'] = build_net('pool', net['conv5_4'])
59 |     return net
60 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # The Contextual Loss [[project page]](https://cgm.technion.ac.il/Computer-Graphics-Multimedia/Software/Contextual/)
2 | 
3 | This is a TensorFlow implementation of the Contextual loss function as reported in the following papers:
4 | (a PyTorch implementation is also available - see below)
5 | 
6 | ### The Contextual Loss for Image Transformation with Non-Aligned Data, [arXiv](https://arxiv.org/abs/1803.02077)
7 | ### Learning to Maintain Natural Image Statistics, [arXiv](https://arxiv.org/abs/1803.04626)
8 | 
9 | [Roey Mechrez*](https://roimehrez.github.io/), Itamar Talmi*, Firas Shama, [Lihi Zelnik-Manor](http://lihi.eew.technion.ac.il/). [The Technion](http://cgm.technion.ac.il/)
10 | 
11 | Copyright 2018 Itamar Talmi and Roey Mechrez. Licensed for noncommercial research use only.
12 | 
13 | 
14 | 15 |
16 | 
17 | ## Setup
18 | 
19 | ### Background
20 | This code is mainly the contextual loss function itself. The two papers cover many applications; here we provide just one of them: animation from a single image.
21 | 
22 | An example pre-trained model can be downloaded from this [link](https://www.dropbox.com/s/37nz4hy7ai4pqxc/single_im_D32_42_1.0_DC42_1.0.zip?dl=0)
23 | 
24 | The data for this example can be downloaded from this [link](https://www.dropbox.com/s/ggb6v6rv1a0212y/single.zip?dl=0)
25 | 
26 | ### Requirement
27 | Required Python libraries: TensorFlow (>=1.0, <1.9, tested on 1.4) + SciPy + NumPy + easydict
28 | 
29 | Tested on Windows + Intel i7 CPU + Nvidia Titan Xp (and 1080 Ti) with CUDA (>=8.0) and cuDNN (>=5.0). CPU mode should also work with minor changes.
30 | 
31 | 
32 | ### Quick Start (Testing)
33 | 1. Clone this repository.
34 | 2. Download the pretrained VGG-19 network from this [link](https://www.dropbox.com/s/q3wjtaxr76cdx3t/imagenet-vgg-verydeep-19.mat?dl=0)
35 | 3. Extract the pre-trained model zip file (see Background above) under the ```result``` folder. The models should be in ```base_dir/result/single_im_D32_42_1.0_DC42_1.0/```
36 | 4. Update ```config.base_dir``` and ```config.vgg_model_path``` in ```config.py``` and run ```single_image_animation.py```
37 | 
38 | ### Training
39 | 1. Change ```config.TRAIN.is_train``` to ```True```
40 | 2. Arrange the paths to the data; ```config.base_dir``` should contain ```train``` and ```test``` folders
41 | 3. Run ```single_image_animation.py```; training runs for 10 epochs.
42 | 
43 | ### PyTorch implementation
44 | We have also released a PyTorch implementation of the loss function. See ```CX/CX_distance.py```. Note that we haven't tested this implementation to reproduce the results in the papers.
45 | 
46 | 
47 | ## License
48 | 
49 | This software is provided under the provisions of the Lesser GNU Public License (LGPL).
50 | See: http://www.gnu.org/copyleft/lesser.html.
51 | 
52 | This software can be used only for research purposes; you should cite
53 | the aforementioned papers in any resulting publication.
54 | 
55 | The Software is provided "as is", without warranty of any kind.
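
## Usage example
The following is a minimal sketch of calling the TensorFlow loss directly on two feature maps, in case you want to reuse it outside the animation application. The NHWC placeholder shapes are illustrative assumptions (the loss requires fully defined static shapes), and note that importing ```config.py``` creates the output directory as a side effect:

```python
import tensorflow as tf
from CX.CX_helper import CX_loss_helper
from config import config

# NHWC feature maps of the target and the generated image,
# e.g. activations of the same VGG-19 layer for both images
feat_real = tf.placeholder(tf.float32, [1, 64, 64, 256])
feat_fake = tf.placeholder(tf.float32, [1, 64, 64, 256])

# scalar contextual loss, using the CX settings (distance metric,
# bandwidth sigma, random pooling size) defined in config.py
cx = CX_loss_helper(feat_real, feat_fake, config.CX)
```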
56 | 
57 | 
58 | ## Citation
59 | If you use our code for research, please cite our papers:
60 | ```
61 | @article{mechrez2018contextual,
62 |   title={The Contextual Loss for Image Transformation with Non-Aligned Data},
63 |   author={Mechrez, Roey and Talmi, Itamar and Zelnik-Manor, Lihi},
64 |   journal={arXiv preprint arXiv:1803.02077},
65 |   year={2018}
66 | }
67 | @article{mechrez2018Learning,
68 |   title={Learning to Maintain Natural Image Statistics},
69 |   author={Mechrez, Roey and Talmi, Itamar and Shama, Firas and Zelnik-Manor, Lihi},
70 |   journal={arXiv preprint arXiv:1803.04626},
71 |   year={2018}
72 | }
73 | ```
74 | 
75 | 
76 | ## Code References
77 | 
78 | [1] Template Matching with Deformable Diversity Similarity, https://github.com/roimehrez/DDIS
79 | 
80 | [2] Photographic Image Synthesis with Cascaded Refinement Networks, https://cqf.io/ImageSynthesis/
81 | 
82 | 
--------------------------------------------------------------------------------
/single_image_animation.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------
2 | # code credits: https://github.com/CQFIO/PhotographicImageSynthesis
3 | # ---------------------------------------------------
4 | from __future__ import division
5 | 
6 | import time
7 | 
8 | import utils.helper as helper
9 | from CX.CX_helper import *
10 | from model import *
11 | from utils.FetchManager import *
12 | 
13 | sess = tf.Session()
14 | 
15 | # ---------------------------------------------------
16 | # graph
17 | # ---------------------------------------------------
18 | with tf.variable_scope(tf.get_variable_scope()):
19 |     input_A = tf.placeholder(tf.float32, [None, None, None, 3])
20 |     input_B = tf.placeholder(tf.float32, [None, None, None, 3])
21 |     input_A_test = tf.placeholder(tf.float32, [None, None, None, 3])
22 |     input_image_A, real_image_B = helper.random_crop_together(input_A, input_B, [2, config.TRAIN.resize[0], config.TRAIN.resize[1], 3])
23 |     with tf.variable_scope("g") as scope:
24 |         generator = recursive_generator(input_image_A, config.TRAIN.sp)
25 |         scope.reuse_variables()
26 |         generator_test = recursive_generator(input_A_test, config.TRAIN.sp)
27 |     weight = tf.placeholder(tf.float32)  # unused
28 |     vgg_real = build_vgg19(real_image_B)
29 |     vgg_fake = build_vgg19(generator, reuse=True)
30 |     vgg_input = build_vgg19(input_image_A, reuse=True)
31 | 
32 | 
33 | ## --- contextual style/target ---
34 | if config.W.CX > 0:
35 |     CX_loss_list = [w * CX_loss_helper(vgg_real[layer], vgg_fake[layer], config.CX)
36 |                     for layer, w in config.CX.feat_layers.items()]
37 |     CX_style_loss = tf.reduce_sum(CX_loss_list)
38 |     CX_style_loss *= config.W.CX
39 | else:
40 |     CX_style_loss = zero_tensor
41 | 
42 | ## --- contextual content/source ---
43 | if config.W.CX_content > 0:
44 |     CX_loss_content_list = [w * CX_loss_helper(vgg_input[layer], vgg_fake[layer], config.CX)
45 |                             for layer, w in config.CX.feat_content_layers.items()]
46 |     CX_content_loss = tf.reduce_sum(CX_loss_content_list)
47 |     CX_content_loss *= config.W.CX_content
48 | else:
49 |     CX_content_loss = zero_tensor
50 | 
51 | ## --- total loss ---
52 | G_loss = CX_style_loss + CX_content_loss
53 | 
54 | 
55 | # create the optimization
56 | lr = tf.placeholder(tf.float32)
57 | var_list = [var for var in tf.trainable_variables() if var.name.startswith('g/g_')]
58 | G_opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss, var_list=var_list)
59 | saver = tf.train.Saver(max_to_keep=1000)
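# initialize all variables; if a checkpoint is found below, saver.restore() will overwrite them with the saved values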
60 | sess.run(tf.global_variables_initializer())
61 | 
62 | 
63 | # load from a checkpoint if one exists
64 | def load(ckpt_dir):
65 |     ckpt = tf.train.get_checkpoint_state(ckpt_dir)
66 |     if ckpt:
67 |         print('loaded ' + ckpt.model_checkpoint_path)
68 |         saver.restore(sess, ckpt.model_checkpoint_path)
69 |     return ckpt
70 | 
71 | 
72 | ckpt = load(config.TRAIN.out_dir)
73 | 
74 | 
75 | # ---------------------------------------------------
76 | # train
77 | # ---------------------------------------------------
78 | if config.TRAIN.is_train:
79 |     file_list = os.listdir(config.base_dir + config.TRAIN.A_data_dir)
80 |     val_file_list = os.listdir(config.base_dir + config.VAL.A_data_dir)
81 |     file_list = np.random.permutation(file_list)
82 |     assert len(file_list) > 0
83 |     train_file_list = file_list[0::config.TRAIN.every_nth_frame]
84 |     val_file_list = val_file_list[0::config.VAL.every_nth_frame]
85 |     g_loss = np.zeros(len(train_file_list), dtype=float)
86 |     fetcher = FetchManager(sess, [G_opt, G_loss])
87 |     B_file_name = config.single_image_B_file_name
88 |     B_image = helper.read_image(B_file_name)  # training image B
89 | 
90 |     ## ------------ epoch loop -------------------------
91 |     for epoch in range(1, config.TRAIN.num_epochs + 1):
92 |         epoch_dir = config.TRAIN.out_dir + "/%04d" % epoch
93 |         if os.path.isdir(epoch_dir):
94 |             continue
95 |         cnt = 0
96 | 
97 |         ## ------------ batch loop -------------------------
98 |         for ind in np.random.permutation(len(train_file_list)):
99 |             st = time.time()
100 |             cnt += 1
101 | 
102 |             A_file_name = config.base_dir + config.TRAIN.A_data_dir + '/' + train_file_list[ind]
103 |             if not os.path.isfile(A_file_name):
104 |                 continue
105 |             A_image = helper.read_image(A_file_name)  # training image A
106 | 
107 |             feed_dict = {input_A: A_image, input_B: B_image, lr: 1e-4}
108 | 
109 |             # session run
110 |             eval = fetcher.fetch(feed_dict, [CX_style_loss, CX_content_loss])
111 | 
112 |             g_loss[ind] = eval[G_loss]
113 |             log = "epoch:%d | cnt:%d | time:%.2f | loss:%.2f || dis_style:%.2f | dis_content:%.2f " % \
114 |                   (epoch, cnt, time.time() - st, np.mean(g_loss[np.where(g_loss)]), eval[CX_style_loss], eval[CX_content_loss])
115 |             print(log)
116 |         ## ------------ end batch loop -------------------
117 | 
118 |         # -------------- save the model ------------------
119 |         # we use a loop with try/except to verify that the save succeeded; saving to Dropbox sometimes causes an error.
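        # (up to 5 attempts; a failed attempt sleeps for 1 s before retrying)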
120 |         for i in range(5):
121 |             try:
122 |                 if not os.path.exists(epoch_dir):
123 |                     os.makedirs(epoch_dir)
124 |                 helper.write_loss_in_txt(g_loss, epoch)
125 |                 saver.save(sess, config.TRAIN.out_dir + "/model.ckpt")
126 |             except Exception:
127 |                 time.sleep(1)
128 | 
129 |         ## ------------ validation loop -------------------------
130 |         for ind in range(len(val_file_list)):
131 |             A_file_name_val = config.base_dir + config.VAL.A_data_dir + '/' + val_file_list[ind]
132 |             if not os.path.isfile(A_file_name_val):  # test label
133 |                 continue
134 |             A_image_val = helper.read_image(A_file_name_val)  # validation image A
135 |             # B_image_val = helper.read_image(B_file_name_val)
136 |             output = sess.run(generator_test, feed_dict={input_A_test: A_image_val})
137 |             output = np.concatenate([A_image_val, output, B_image], axis=2)
138 |             helper.save_image(output, config.TRAIN.out_dir + "/%04d/" % epoch + val_file_list[ind].replace('.jpg', '_out.jpg'))
139 | 
140 | 
141 | 
142 | # ---------------------------------------------------
143 | # test
144 | # ---------------------------------------------------
145 | if config.TEST.is_test:
146 |     test_file_list = os.listdir(config.base_dir + config.TEST.A_data_dir)
147 |     if not os.path.isdir(config.TEST.out_dir + config.TEST.out_dir_postfix):
148 |         os.makedirs(config.TEST.out_dir + config.TEST.out_dir_postfix)
149 |     time_list = np.zeros(len(test_file_list), dtype=float)
150 |     for ind in range(len(test_file_list)):
151 |         A_file_name_val = config.base_dir + config.TEST.A_data_dir + '/' + test_file_list[ind]
152 |         if not os.path.isfile(A_file_name_val):
153 |             continue
154 |         A_image_val = helper.read_image(A_file_name_val, fliplr=False)  # test image A
155 |         st = time.time()
156 |         output = sess.run(generator_test, feed_dict={input_A_test: A_image_val})
157 |         et = time.time()
158 |         output = np.concatenate([A_image_val, output], axis=2)  # B_image_val
159 |         helper.save_image(output, config.TEST.out_dir + config.TEST.out_dir_postfix + "/" + test_file_list[ind].replace('.jpg', '_out.jpg'))
160 |         time_list[ind] = et - st
161 |         print("test for image #: %d, time: %1.4f" % (ind, et - st))
162 |     print('average time per image: %f' % time_list.mean())
163 | 
--------------------------------------------------------------------------------
/CX/CSFlow.py:
--------------------------------------------------------------------------------
1 | 
2 | import CX.enums as enums
3 | import tensorflow as tf
4 | from CX.enums import TensorAxis, Distance
5 | 
6 | 
7 | class CSFlow:
8 |     def __init__(self, sigma=float(0.1), b=float(1.0)):
9 |         self.b = b
10 |         self.sigma = sigma
11 | 
12 |     def __calculate_CS(self, scaled_distances, axis_for_normalization=TensorAxis.C):
13 |         self.scaled_distances = scaled_distances
14 |         self.cs_weights_before_normalization = tf.exp((self.b - scaled_distances) / self.sigma, name='weights_before_normalization')
15 |         self.cs_NHWC = CSFlow.sum_normalize(self.cs_weights_before_normalization, axis_for_normalization)
16 | 
17 |     def reversed_direction_CS(self):
18 |         cs_flow_opposite = CSFlow(self.sigma, self.b)
19 |         cs_flow_opposite.raw_distances = self.raw_distances
20 |         work_axis = [TensorAxis.H, TensorAxis.W]
21 |         relative_dist = cs_flow_opposite.calc_relative_distances(axis=work_axis)
22 |         cs_flow_opposite.__calculate_CS(relative_dist, work_axis)
23 |         return cs_flow_opposite
24 | 
25 |     # --
26 |     @staticmethod
27 |     def create_using_L2(I_features, T_features, sigma=float(0.1), b=float(1.0)):
28 |         cs_flow = CSFlow(sigma, b)
29 |         # for debug:
30 |         # I_features = tf.concat([I_features, I_features], axis=1)
31 |         with tf.name_scope('CS'):
32 |             # assert I_features.shape[TensorAxis.C] == T_features.shape[TensorAxis.C]
33 |             c = T_features.shape[TensorAxis.C].value
34 |             sT = T_features.shape.as_list()
35 |             sI = I_features.shape.as_list()
36 | 
37 |             Ivecs = tf.reshape(I_features, (sI[TensorAxis.N], -1, sI[TensorAxis.C]))
38 |             Tvecs = tf.reshape(T_features, (sI[TensorAxis.N], -1, sT[TensorAxis.C]))
39 |             r_Ts = tf.reduce_sum(Tvecs * Tvecs, 2)
40 |             r_Is = tf.reduce_sum(Ivecs * Ivecs, 2)
41 |             raw_distances_list = []
42 |             for i in range(sT[TensorAxis.N]):
43 |                 Ivec, Tvec, r_T, r_I = Ivecs[i], Tvecs[i], r_Ts[i], r_Is[i]
44 |                 A = Tvec @ tf.transpose(Ivec)
45 |                 cs_flow.A = A
46 |                 # A = tf.matmul(Tvec, tf.transpose(Ivec))
47 |                 r_T = tf.reshape(r_T, [-1, 1])  # turn to column vector
48 |                 dist = r_T - 2 * A + r_I  # ||t||^2 - 2*t.i + ||i||^2 = ||t - i||^2
49 |                 cs_shape = sI[:3] + [dist.shape[0].value]
50 |                 cs_shape[0] = 1
51 |                 dist = tf.reshape(tf.transpose(dist), cs_shape)
52 |                 # protecting against numerical problems, dist should be positive
53 |                 dist = tf.maximum(float(0.0), dist)
54 |                 # dist = tf.sqrt(dist)
55 |                 raw_distances_list += [dist]
56 | 
57 |             cs_flow.raw_distances = tf.convert_to_tensor([tf.squeeze(raw_dist, axis=0) for raw_dist in raw_distances_list])
58 | 
59 |             relative_dist = cs_flow.calc_relative_distances()
60 |             cs_flow.__calculate_CS(relative_dist)
61 |             return cs_flow
62 | 
63 |     # --
64 |     @staticmethod
65 |     def create_using_dotP(I_features, T_features, sigma=float(1.0), b=float(1.0)):
66 |         cs_flow = CSFlow(sigma, b)
67 |         with tf.name_scope('CS'):
68 |             # prepare features before calculating cosine distance
69 |             T_features, I_features = cs_flow.center_by_T(T_features, I_features)
70 |             with tf.name_scope('TFeatures'):
71 |                 T_features = CSFlow.l2_normalize_channelwise(T_features)
72 |             with tf.name_scope('IFeatures'):
73 |                 I_features = CSFlow.l2_normalize_channelwise(I_features)
74 | 
75 |                 # work separately on each example along the batch dim
76 |                 cosine_dist_l = []
77 |                 N, _, __, ___ = T_features.shape.as_list()
78 |                 for i in range(N):
79 |                     T_features_i = tf.expand_dims(T_features[i, :, :, :], 0)
80 |                     I_features_i = tf.expand_dims(I_features[i, :, :, :], 0)
81 |                     patches_HWCN_i = cs_flow.patch_decomposition(T_features_i)
82 |                     cosine_dist_i = tf.nn.conv2d(I_features_i, patches_HWCN_i, strides=[1, 1, 1, 1],
83 |                                                  padding='VALID', use_cudnn_on_gpu=True, name='cosine_dist')
84 |                     cosine_dist_l.append(cosine_dist_i)
85 | 
86 |                 cs_flow.cosine_dist = tf.concat(cosine_dist_l, axis=0)
87 | 
88 |                 cosine_dist_zero_to_one = -(cs_flow.cosine_dist - 1) / 2
89 |                 cs_flow.raw_distances = cosine_dist_zero_to_one
90 | 
91 |                 relative_dist = cs_flow.calc_relative_distances()
92 |                 cs_flow.__calculate_CS(relative_dist)
93 |                 return cs_flow
94 | 
95 |     def calc_relative_distances(self, axis=TensorAxis.C):
96 |         epsilon = 1e-5
97 |         div = tf.reduce_min(self.raw_distances, axis=axis, keep_dims=True)
98 |         # div = tf.reduce_mean(self.raw_distances, axis=axis, keep_dims=True)
99 |         relative_dist = self.raw_distances / (div + epsilon)
100 |         return relative_dist
101 | 
102 |     def weighted_average_dist(self, axis=TensorAxis.C):
103 |         if not hasattr(self, 'raw_distances'):
104 |             raise Exception("raw_distances property does not exist; can't calculate the weighted average distance")
105 | 
106 |         multiply = self.raw_distances * self.cs_NHWC
107 |         return tf.reduce_sum(multiply, axis=axis, name='weightedDistPerPatch')
108 | 
109 |     # --
110 |     @staticmethod
111 |     def create(I_features, T_features, distance: enums.Distance, nnsigma=float(1.0), b=float(1.0)):
112 |         if distance.value == enums.Distance.DotProduct.value:
113 |             cs_flow = CSFlow.create_using_dotP(I_features, T_features, nnsigma, b)
114 |         elif distance.value == enums.Distance.L2.value:
115 |             cs_flow = CSFlow.create_using_L2(I_features, T_features, nnsigma, b)
116 |         else:
117 |             raise ValueError('unsupported distance ' + str(distance))
118 |         return cs_flow
119 | 
120 |     @staticmethod
121 |     def sum_normalize(cs, axis=TensorAxis.C):
122 |         reduce_sum = tf.reduce_sum(cs, axis, keep_dims=True, name='sum')
123 |         return tf.divide(cs, reduce_sum, name='sumNormalized')
124 | 
125 |     def center_by_T(self, T_features, I_features):
126 |         # assuming both inputs are of the same size
127 | 
128 |         # calculate stats over [batch, height, width], expecting a 1x1xDepth tensor
129 |         axes = [0, 1, 2]
130 |         self.meanT, self.varT = tf.nn.moments(
131 |             T_features, axes, name='TFeatures/moments')
132 | 
133 |         # we do not divide by std since it causes the histogram
134 |         # of the final cs to be very thin, so the NN weights
135 |         # are not distinctive, giving similar values for all patches.
136 |         # stdT = tf.sqrt(varT, "stdT")
137 |         # correct places with std zero
138 |         # stdT[tf.less(stdT, tf.constant(0.001))] = tf.constant(1)
139 | 
140 |         # TODO check broadcasting here
141 |         with tf.name_scope('TFeatures/centering'):
142 |             self.T_features_centered = T_features - self.meanT
143 |         with tf.name_scope('IFeatures/centering'):
144 |             self.I_features_centered = I_features - self.meanT
145 | 
146 |         return self.T_features_centered, self.I_features_centered
147 | 
148 |     @staticmethod
149 |     def l2_normalize_channelwise(features):
150 |         norms = tf.norm(features, ord='euclidean', axis=TensorAxis.C, name='norm')
151 |         # expanding the norms tensor to support broadcast division
152 |         norms_expanded = tf.expand_dims(norms, TensorAxis.C)
153 |         features = tf.divide(features, norms_expanded, name='normalized')
154 |         return features
155 | 
156 |     def patch_decomposition(self, T_features):
157 |         # patch decomposition
158 |         # see https://stackoverflow.com/questions/40731433/understanding-tf-extract-image-patches-for-extracting-patches-from-an-image
159 |         patch_size = 1
160 |         patches_as_depth_vectors = tf.extract_image_patches(
161 |             images=T_features, ksizes=[1, patch_size, patch_size, 1],
162 |             strides=[1, 1, 1, 1], rates=[1, 1, 1, 1], padding='VALID',
163 |             name='patches_as_depth_vectors')
164 | 
165 |         self.patches_NHWC = tf.reshape(
166 |             patches_as_depth_vectors,
167 |             shape=[-1, patch_size, patch_size, patches_as_depth_vectors.shape[3].value],
168 |             name='patches_PHWC')
169 | 
170 |         self.patches_HWCN = tf.transpose(
171 |             self.patches_NHWC,
172 |             perm=[1, 2, 3, 0],
173 |             name='patches_HWCP')  # tf.nn.conv2d-ready filter format
174 | 
175 |         return self.patches_HWCN
176 | 
177 | 
178 | #--------------------------------------------------
179 | # CX loss
180 | #--------------------------------------------------
181 | 
182 | 
183 | def CX_loss(T_features, I_features, distance=Distance.L2, nnsigma=float(1.0)):
184 |     T_features = tf.convert_to_tensor(T_features, dtype=tf.float32)
185 |     I_features = tf.convert_to_tensor(I_features, dtype=tf.float32)
186 | 
187 |     with tf.name_scope('CX'):
188 |         cs_flow = CSFlow.create(I_features, T_features, distance, nnsigma)
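        # the contextual similarity is CX(x, y) = mean_j max_i CS_ij, and the loss below is -log(CX)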
189 |         # reduce the normalized CS map to a per-image score:
190 |         height_width_axis = [TensorAxis.H, TensorAxis.W]
191 |         # max over spatial positions, then mean over the channel (target-patch) axis
192 |         cs = cs_flow.cs_NHWC
193 |         k_max_NC = tf.reduce_max(cs, axis=height_width_axis)
194 |         CS = tf.reduce_mean(k_max_NC, axis=[1])
195 |         CX_as_loss = 1 - CS
196 |         CX_loss = -tf.log(1 - CX_as_loss)
197 |         CX_loss = tf.reduce_mean(CX_loss)
198 |         return CX_loss
--------------------------------------------------------------------------------
/CX/CX_distance.py:
--------------------------------------------------------------------------------
1 | # import tensorflow as tf
2 | import torch
3 | import numpy as np
4 | 
5 | 
6 | class TensorAxis:
7 |     N = 0
8 |     H = 1
9 |     W = 2
10 |     C = 3
11 | 
12 | 
13 | class CSFlow:
14 |     def __init__(self, sigma=float(0.1), b=float(1.0)):
15 |         self.b = b
16 |         self.sigma = sigma
17 | 
18 |     def __calculate_CS(self, scaled_distances, axis_for_normalization=TensorAxis.C):
19 |         self.scaled_distances = scaled_distances
20 |         self.cs_weights_before_normalization = torch.exp((self.b - scaled_distances) / self.sigma)
21 |         # self.cs_weights_before_normalization = 1 / (1 + scaled_distances)
22 |         # self.cs_NHWC = CSFlow.sum_normalize(self.cs_weights_before_normalization, axis_for_normalization)
23 |         self.cs_NHWC = self.cs_weights_before_normalization
24 | 
25 |     # def reversed_direction_CS(self):
26 |     #     cs_flow_opposite = CSFlow(self.sigma, self.b)
27 |     #     cs_flow_opposite.raw_distances = self.raw_distances
28 |     #     work_axis = [TensorAxis.H, TensorAxis.W]
29 |     #     relative_dist = cs_flow_opposite.calc_relative_distances(axis=work_axis)
30 |     #     cs_flow_opposite.__calculate_CS(relative_dist, work_axis)
31 |     #     return cs_flow_opposite
32 | 
33 |     # --
34 |     @staticmethod
35 |     def create_using_L2(I_features, T_features, sigma=float(0.5), b=float(1.0)):
36 |         cs_flow = CSFlow(sigma, b)
37 |         sT = T_features.shape
38 |         sI = I_features.shape
39 | 
40 |         Ivecs = torch.reshape(I_features, (sI[0], -1, sI[3]))
41 |         Tvecs = torch.reshape(T_features, (sI[0], -1, sT[3]))
42 |         r_Ts = torch.sum(Tvecs * Tvecs, 2)
43 |         r_Is = torch.sum(Ivecs * Ivecs, 2)
44 |         raw_distances_list = []
45 |         for i in range(sT[0]):
46 |             Ivec, Tvec, r_T, r_I = Ivecs[i], Tvecs[i], r_Ts[i], r_Is[i]
47 |             A = Tvec @ torch.transpose(Ivec, 0, 1)  # (matrix multiplication)
48 |             cs_flow.A = A
49 |             # A = tf.matmul(Tvec, tf.transpose(Ivec))
50 |             r_T = torch.reshape(r_T, [-1, 1])  # turn to column vector
51 |             dist = r_T - 2 * A + r_I
52 |             dist = torch.reshape(torch.transpose(dist, 0, 1), shape=(1, sI[1], sI[2], dist.shape[0]))
53 |             # protecting against numerical problems, dist should be positive
54 |             dist = torch.clamp(dist, min=float(0.0))
55 |             # dist = tf.sqrt(dist)
56 |             raw_distances_list += [dist]
57 | 
58 |         cs_flow.raw_distances = torch.cat(raw_distances_list)
59 | 
60 |         relative_dist = cs_flow.calc_relative_distances()
61 |         cs_flow.__calculate_CS(relative_dist)
62 |         return cs_flow
63 | 
64 |     # --
65 |     @staticmethod
66 |     def create_using_L1(I_features, T_features, sigma=float(0.5), b=float(1.0)):
67 |         cs_flow = CSFlow(sigma, b)
68 |         sT = T_features.shape
69 |         sI = I_features.shape
70 | 
71 |         Ivecs = torch.reshape(I_features, (sI[0], -1, sI[3]))
72 |         Tvecs = torch.reshape(T_features, (sI[0], -1, sT[3]))
73 |         raw_distances_list = []
74 |         for i in range(sT[0]):
75 |             Ivec, Tvec = Ivecs[i], Tvecs[i]
76 |             dist = torch.sum(torch.abs(Ivec.unsqueeze(1) - Tvec.unsqueeze(0)), dim=2)
77 |             dist = torch.reshape(torch.transpose(dist, 0, 1), shape=(1, sI[1], sI[2], dist.shape[0]))
78 |             # protecting against numerical problems, dist should be positive
79 |             dist = torch.clamp(dist, min=float(0.0))
80 |             # dist = tf.sqrt(dist)
81 |             raw_distances_list += [dist]
82 | 
83 |         cs_flow.raw_distances = torch.cat(raw_distances_list)
84 | 
85 |         relative_dist = cs_flow.calc_relative_distances()
86 |         cs_flow.__calculate_CS(relative_dist)
87 |         return cs_flow
88 | 
89 |     # --
90 |     @staticmethod
91 |     def create_using_dotP(I_features, T_features, sigma=float(0.5), b=float(1.0)):
92 |         cs_flow = CSFlow(sigma, b)
93 |         # prepare features before calculating cosine distance
94 |         T_features, I_features = cs_flow.center_by_T(T_features, I_features)
95 |         T_features = CSFlow.l2_normalize_channelwise(T_features)
96 |         I_features = CSFlow.l2_normalize_channelwise(I_features)
97 | 
98 |         # work separately on each example along the batch dim
99 |         cosine_dist_l = []
100 |         N = T_features.size()[0]
101 |         for i in range(N):
102 |             T_features_i = T_features[i, :, :, :].unsqueeze_(0)  # stays 1HWC, as patch_decomposition expects
103 |             I_features_i = I_features[i, :, :, :].unsqueeze_(0).permute((0, 3, 1, 2))  # 1HWC --> 1CHW for conv2d
104 |             patches_PC11_i = cs_flow.patch_decomposition(T_features_i)  # 1HWC --> PC11, with P=H*W
105 |             cosine_dist_i = torch.nn.functional.conv2d(I_features_i, patches_PC11_i)
106 |             # cosine_dist_1HWC = cosine_dist_i.permute((0, 2, 3, 1))  # unused
107 |             cosine_dist_l.append(cosine_dist_i.permute((0, 2, 3, 1)))  # back to 1HWC
108 | 
109 |         cs_flow.cosine_dist = torch.cat(cosine_dist_l, dim=0)
110 | 
111 |         cs_flow.raw_distances = - (cs_flow.cosine_dist - 1) / 2  # map cosine similarity in [-1, 1] to a distance in [0, 1]
112 | 
113 |         relative_dist = cs_flow.calc_relative_distances()
114 |         cs_flow.__calculate_CS(relative_dist)
115 |         return cs_flow
116 | 
117 |     def calc_relative_distances(self, axis=TensorAxis.C):
118 |         epsilon = 1e-5
119 |         div = torch.min(self.raw_distances, dim=axis, keepdim=True)[0]
120 |         relative_dist = self.raw_distances / (div + epsilon)
121 |         return relative_dist
122 | 
123 |     @staticmethod
124 |     def sum_normalize(cs, axis=TensorAxis.C):
125 |         reduce_sum = torch.sum(cs, dim=axis, keepdim=True)
126 |         cs_normalize = torch.div(cs, reduce_sum)
127 |         return cs_normalize
128 | 
129 |     def center_by_T(self, T_features, I_features):
130 |         # assuming both inputs are of the same size
131 |         # calculate stats over [batch, height, width], expecting a 1x1xDepth tensor
132 |         axes = [0, 1, 2]
133 |         self.meanT = T_features.mean(0, keepdim=True).mean(1, keepdim=True).mean(2, keepdim=True)
134 |         self.varT = T_features.var(0, keepdim=True).var(1, keepdim=True).var(2, keepdim=True)  # unused downstream
135 |         self.T_features_centered = T_features - self.meanT
136 |         self.I_features_centered = I_features - self.meanT
137 | 
138 |         return self.T_features_centered, self.I_features_centered
139 | 
140 |     @staticmethod
141 |     def l2_normalize_channelwise(features):
142 |         norms = features.norm(p=2, dim=TensorAxis.C, keepdim=True)
143 |         features = features.div(norms)
144 |         return features
145 | 
146 |     def patch_decomposition(self, T_features):
147 |         # 1HWC --> 11PC --> PC11, with P=H*W
148 |         (N, H, W, C) = T_features.shape
149 |         P = H * W
150 |         patches_PC11 = T_features.reshape(shape=(1, 1, P, C)).permute(dims=(2, 3, 0, 1))
151 |         return patches_PC11
152 | 
153 |     @staticmethod
154 |     def pdist2(x, keepdim=False):
155 |         sx = x.shape
156 |         x = x.reshape(shape=(sx[0], sx[1] * sx[2], sx[3]))
157 |         differences = x.unsqueeze(2) - x.unsqueeze(1)
158 |         distances = torch.sum(differences**2, -1)
159 |         if keepdim:
160 |             distances = distances.reshape(shape=(sx[0], sx[1], sx[2], sx[3]))
161 |         return distances
162 | 
163 |     @staticmethod
164 |     def calcR_static(sT, order='C', deformation_sigma=0.05):
165 |         # order can be 'C' or 'F' (MATLAB order)
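        # builds a Gaussian spatial prior R[p, q] = exp(-d(p, q)^2 / (2 * deformation_sigma^2)) over all pixel pairs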
166 |         pixel_count = sT[0] * sT[1]
167 | 
168 |         rangeRows = range(0, sT[1])
169 |         rangeCols = range(0, sT[0])
170 |         Js, Is = np.meshgrid(rangeRows, rangeCols)
171 |         row_diff_from_first_row = Is
172 |         col_diff_from_first_col = Js
173 | 
174 |         row_diff_from_first_row_3d_repeat = np.repeat(row_diff_from_first_row[:, :, np.newaxis], pixel_count, axis=2)
175 |         col_diff_from_first_col_3d_repeat = np.repeat(col_diff_from_first_col[:, :, np.newaxis], pixel_count, axis=2)
176 | 
177 |         rowDiffs = -row_diff_from_first_row_3d_repeat + row_diff_from_first_row.flatten(order).reshape(1, 1, -1)
178 |         colDiffs = -col_diff_from_first_col_3d_repeat + col_diff_from_first_col.flatten(order).reshape(1, 1, -1)
179 |         R = rowDiffs ** 2 + colDiffs ** 2
180 |         R = R.astype(np.float32)
181 |         R = np.exp(-(R) / (2 * deformation_sigma ** 2))
182 |         return R
183 | 
184 | 
185 | 
186 | 
187 | 
188 | 
189 | # --------------------------------------------------
190 | # CX loss
191 | # --------------------------------------------------
192 | 
193 | 
194 | 
195 | def CX_loss(T_features, I_features, deformation=False, dis=False):
196 |     # T_features = tf.convert_to_tensor(T_features, dtype=tf.float32)
197 |     # I_features = tf.convert_to_tensor(I_features, dtype=tf.float32)
198 |     # note: this is a conversion of TensorFlow code to PyTorch
199 |     # T_features = normalize_tensor(T_features)
200 |     # I_features = normalize_tensor(I_features)
201 | 
202 |     # since this was originally a TensorFlow implementation,
203 |     # we keep all tensors in the TF (NHWC) convention rather than the PyTorch (NCHW) one.
204 |     def from_pt2tf(Tpt):
205 |         Ttf = Tpt.permute(0, 2, 3, 1)
206 |         return Ttf
207 |     # N x C x H x W --> N x H x W x C
208 |     T_features_tf = from_pt2tf(T_features)
209 |     I_features_tf = from_pt2tf(I_features)
210 | 
211 |     # cs_flow = CSFlow.create_using_dotP(I_features_tf, T_features_tf, sigma=1.0)
212 |     cs_flow = CSFlow.create_using_L2(I_features_tf, T_features_tf, sigma=1.0)
213 |     # reduce the CS map to a per-image score:
214 |     # (note: sum_normalize is disabled in __calculate_CS, so cs holds unnormalized weights)
215 |     cs = cs_flow.cs_NHWC
216 | 
217 |     if deformation:
218 |         deforma_sigma = 0.001
219 |         sT = T_features_tf.shape[1:2 + 1]  # (H, W)
220 |         R = CSFlow.calcR_static(sT, deformation_sigma=deforma_sigma)
221 |         cs *= torch.Tensor(R).unsqueeze(dim=0).cuda()
222 | 
223 |     if dis:
224 |         CS = []
225 |         k_max_NC = torch.max(torch.max(cs, dim=1)[1], dim=1)[1]
226 |         indices = k_max_NC.cpu()
227 |         N, C = indices.shape
228 |         for i in range(N):
229 |             CS.append((C - len(torch.unique(indices[i, :]))) / C)
230 |         score = torch.FloatTensor(CS)
231 |     else:
232 |         # reduce_max over the X and Y dims
233 |         # cs = CSFlow.pdist2(cs, keepdim=True)
234 |         k_max_NC = torch.max(torch.max(cs, dim=1)[0], dim=1)[0]
235 |         # reduce mean over the C dim
236 |         CS = torch.mean(k_max_NC, dim=1)
237 |         # score = 1/CS
238 |         # score = torch.exp(-CS*10)
239 |         score = -torch.log(CS)
240 |     # reduce mean over N dim
241 |     # CX_loss = torch.mean(CX_loss)
242 |     return score
243 | 
244 | 
245 | def symetric_CX_loss(T_features, I_features):
246 |     score = (CX_loss(T_features, I_features) + CX_loss(I_features, T_features)) / 2
247 |     return score
248 | 
--------------------------------------------------------------------------------