├── CX
│   ├── __init__.py
│   ├── enums.py
│   ├── CX_helper.py
│   ├── CSFlow.py
│   └── CX_distance.py
├── utils
│   ├── __init__.py
│   ├── FetchManager.py
│   └── helper.py
├── images
│   ├── teaser_im.png
│   └── trump_cartoon.jpg
├── model.py
├── config.py
├── vgg_model.py
├── README.md
└── single_image_animation.py
/CX/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/images/teaser_im.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/roimehrez/contextualLoss/HEAD/images/teaser_im.png
--------------------------------------------------------------------------------
/images/trump_cartoon.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/roimehrez/contextualLoss/HEAD/images/trump_cartoon.jpg
--------------------------------------------------------------------------------
/CX/enums.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
4 | class Distance(Enum):
5 | L2 = 0
6 | DotProduct = 1
7 |
8 |
9 | class TensorAxis:
10 | N = 0
11 | H = 1
12 | W = 2
13 | C = 3
--------------------------------------------------------------------------------
/utils/FetchManager.py:
--------------------------------------------------------------------------------
1 | class FetchManager:
2 | def __init__(self, sess, fetches):
3 | self.fetches = fetches
4 | self.sess = sess
5 |
6 |     def fetch(self, feed_dictionary, additional_fetches=None):
7 |         fetches = self.fetches + (additional_fetches or [])
8 |         evaluation = self.sess.run(fetches, feed_dictionary)
9 |         return {k: v for k, v in zip(fetches, evaluation)}
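10 |
11 |
12 | # a minimal usage sketch (hypothetical graph; any feed-able tensors work):
13 | if __name__ == '__main__':
14 |     import tensorflow as tf
15 |     x = tf.placeholder(tf.float32, [None])
16 |     doubled = x * 2.0
17 |     with tf.Session() as sess:
18 |         fetcher = FetchManager(sess, [doubled])
19 |         results = fetcher.fetch({x: [1.0, 2.0]}, additional_fetches=[x])
20 |         print(results[doubled])  # -> [2. 4.]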
--------------------------------------------------------------------------------
/utils/helper.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------
2 | # code credits: https://github.com/CQFIO/PhotographicImageSynthesis
3 | # ---------------------------------------------------
4 | import numpy as np
5 | import scipy
6 | from config import *
7 | import tensorflow as tf
8 |
9 |
10 | def read_image(file_name, resize=True, fliplr=False):
11 | image = scipy.misc.imread(file_name)
12 | if resize:
13 | image = scipy.misc.imresize(image, size=config.TRAIN.resize, interp='bilinear', mode=None)
14 | if fliplr:
15 | image = np.fliplr(image)
16 | image = np.float32(image)
17 | return np.expand_dims(image, axis=0)
18 |
19 |
20 | def save_image(output, file_name):
21 | output = np.minimum(np.maximum(output, 0.0), 255.0)
22 | scipy.misc.toimage(output.squeeze(axis=0), cmin=0, cmax=255).save(file_name)
23 | return
24 |
25 |
26 | def write_loss_in_txt(loss_list, epoch):
27 | target = open(config.TRAIN.out_dir + "/%04d/score.txt" % epoch, 'w')
28 | target.write("%f" % np.mean(loss_list[np.where(loss_list)]))
29 | target.close()
30 |
31 |
32 | def random_crop_together(im1, im2, size):
33 | images = tf.concat([im1, im2], axis=0)
34 | images_croped = tf.random_crop(images, size=size)
35 | im1, im2 = tf.split(images_croped, 2, axis=0)
36 | return im1, im2
37 |
38 |
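39 | # note: scipy.misc.imread/imresize/toimage above were removed in SciPy >= 1.2,
40 | # so an older SciPy is required
41 |
42 | # a minimal usage sketch for random_crop_together (hypothetical shapes); both
43 | # crops share the same window because the images are concatenated along the
44 | # batch axis before cropping:
45 | if __name__ == '__main__':
46 |     a = tf.placeholder(tf.float32, [1, 256, 256, 3])
47 |     b = tf.placeholder(tf.float32, [1, 256, 256, 3])
48 |     a_crop, b_crop = random_crop_together(a, b, size=[2, 128, 128, 3])
49 |     print(a_crop.shape)  # (1, 128, 128, 3)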
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------
2 | # code credits: https://github.com/CQFIO/PhotographicImageSynthesis
3 | # ---------------------------------------------------
4 |
5 | import tensorflow.contrib.slim as slim
6 | from vgg_model import *
7 | from config import *
8 |
9 |
10 | # this function has been modified so that the images are portrait rather than landscape
11 | def recursive_generator(input_image, width):
12 | ar = config.TRAIN.aspect_ratio
13 | if width >= 128:
14 | dim = 512 // config.TRAIN.reduce_dim
15 | else:
16 | dim = 1024 // config.TRAIN.reduce_dim
17 |
18 | if width == 4:
19 | input = input_image
20 | else:
21 | downsampled_width = width // 2
22 | downsampled_input = tf.image.resize_area(input_image, (downsampled_width, downsampled_width // ar), align_corners=False)
23 | recursive_call = recursive_generator(downsampled_input, downsampled_width)
24 | predicted_on_downsampled = tf.image.resize_bilinear(recursive_call, (width, width // ar), align_corners=True)
25 | input = tf.concat([predicted_on_downsampled, input_image], 3)
26 |
27 | net = slim.conv2d(input, dim, [3, 3], rate=1, normalizer_fn=slim.layer_norm, activation_fn=lrelu, scope='g_' + str(width) + '_conv1')
28 | net = slim.conv2d(net, dim, [3, 3], rate=1, normalizer_fn=slim.layer_norm, activation_fn=lrelu, scope='g_' + str(width) + '_conv2')
29 |
30 | if width == config.TRAIN.sp*config.TRAIN.aspect_ratio:
31 | net = slim.conv2d(net, 3, [1, 1], rate=1, activation_fn=None, scope='g_' + str(width) + '_conv100')
32 | net = (net + 1.0) / 2.0 * 255.0
33 | return net
34 |
35 |
36 |
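37 | # a minimal usage sketch (mirrors the graph construction in single_image_animation.py;
38 | # assumes config.TRAIN.aspect_ratio == 1, the default):
39 | if __name__ == '__main__':
40 |     input_A = tf.placeholder(tf.float32, [None, config.TRAIN.sp, config.TRAIN.sp, 3])
41 |     with tf.variable_scope("g"):
42 |         output = recursive_generator(input_A, config.TRAIN.sp)  # recurses sp -> sp/2 -> ... -> 4
43 |     print(output.shape)  # (?, sp, sp, 3), values mapped to [0, 255]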
--------------------------------------------------------------------------------
/CX/CX_helper.py:
--------------------------------------------------------------------------------
1 | from CX import CSFlow
2 | import tensorflow as tf
3 |
4 |
5 | def random_sampling(tensor_NHWC, n, indices=None):
6 | N, H, W, C = tf.convert_to_tensor(tensor_NHWC).shape.as_list()
7 | S = H * W
8 | tensor_NSC = tf.reshape(tensor_NHWC, [N, S, C])
9 | all_indices = list(range(S))
10 | shuffled_indices = tf.random_shuffle(all_indices)
11 |     indices = tf.gather(shuffled_indices, list(range(n)), axis=0) if indices is None else indices
12 |     # (sampling without replacement via the shuffle above; tf.random_uniform would sample with replacement)
13 | res = tf.gather(tensor_NSC, indices, axis=1)
14 | return res, indices
15 |
16 |
17 | def random_pooling(feats, output_1d_size=100):
18 | is_input_tensor = type(feats) is tf.Tensor
19 |
20 | if is_input_tensor:
21 | feats = [feats]
22 |
23 | # convert all inputs to tensors
24 | feats = [tf.convert_to_tensor(feats_i) for feats_i in feats]
25 |
26 | N, H, W, C = feats[0].shape.as_list()
27 | feats_sampled_0, indices = random_sampling(feats[0], output_1d_size ** 2)
28 | res = [feats_sampled_0]
29 | for i in range(1, len(feats)):
30 | feats_sampled_i, _ = random_sampling(feats[i], -1, indices)
31 | res.append(feats_sampled_i)
32 |
33 | res = [tf.reshape(feats_sampled_i, [N, output_1d_size, output_1d_size, C]) for feats_sampled_i in res]
34 | if is_input_tensor:
35 | return res[0]
36 | return res
37 |
38 |
39 | def crop_quarters(feature_tensor):
40 | N, fH, fW, fC = feature_tensor.shape.as_list()
41 | quarters_list = []
42 | quarter_size = [N, round(fH / 2), round(fW / 2), fC]
43 | quarters_list.append(tf.slice(feature_tensor, [0, 0, 0, 0], quarter_size))
44 | quarters_list.append(tf.slice(feature_tensor, [0, round(fH / 2), 0, 0], quarter_size))
45 | quarters_list.append(tf.slice(feature_tensor, [0, 0, round(fW / 2), 0], quarter_size))
46 | quarters_list.append(tf.slice(feature_tensor, [0, round(fH / 2), round(fW / 2), 0], quarter_size))
47 | feature_tensor = tf.concat(quarters_list, axis=0)
48 | return feature_tensor
49 |
50 |
51 | def CX_loss_helper(vgg_A, vgg_B, CX_config):
52 |     if CX_config.crop_quarters:
53 | vgg_A = crop_quarters(vgg_A)
54 | vgg_B = crop_quarters(vgg_B)
55 |
56 | N, fH, fW, fC = vgg_A.shape.as_list()
57 | if fH * fW <= CX_config.max_sampling_1d_size ** 2:
58 | print(' #### Skipping pooling for CX....')
59 | else:
60 | print(' #### pooling for CX %d**2 out of %dx%d' % (CX_config.max_sampling_1d_size, fH, fW))
61 | vgg_A, vgg_B = random_pooling([vgg_A, vgg_B], output_1d_size=CX_config.max_sampling_1d_size)
62 |
63 | CX_loss = CSFlow.CX_loss(vgg_A, vgg_B, distance=CX_config.Dist, nnsigma=CX_config.nn_stretch_sigma)
64 | return CX_loss
65 |
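66 | # a minimal usage sketch (mirrors single_image_animation.py; vgg_real and
67 | # vgg_fake are the NHWC feature dictionaries returned by vgg_model.build_vgg19):
68 | #   CX_loss_list = [w * CX_loss_helper(vgg_real[layer], vgg_fake[layer], config.CX)
69 | #                   for layer, w in config.CX.feat_layers.items()]
70 | #   CX_style_loss = tf.reduce_sum(CX_loss_list) * config.W.CX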
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | from easydict import EasyDict as edict
3 | import json
4 | from CX.enums import Distance
5 | import tensorflow as tf
6 | import numpy as np
7 | import re
8 |
9 | celebA = False
10 | single_image = False
11 | zero_tensor = tf.constant(0.0, dtype=tf.float32)
12 | config = edict()
13 |
14 | #---------------------------------------------
15 | # update the right paths
16 | config.base_dir = 'C:/DATA/person2person/single/'
17 | config.single_image_B_file_name = 'images/trump_cartoon.jpg'
18 | config.vgg_model_path = 'C:/DATA/VGG_Model/imagenet-vgg-verydeep-19.mat'
19 | #---------------------------------------------
20 |
21 |
22 |
23 | config.W = edict()
24 | # weights
25 | config.W.CX = 1.0
26 | config.W.CX_content = 1.0
27 |
28 | # train parameters
29 | config.TRAIN = edict()
30 | config.TRAIN.is_train = True  # set to True to train, False to test
31 | config.TRAIN.sp = 256
32 | config.TRAIN.aspect_ratio = 1 # 1
33 | config.TRAIN.resize = [config.TRAIN.sp * config.TRAIN.aspect_ratio, config.TRAIN.sp]
34 | config.TRAIN.crop_size = [config.TRAIN.sp * config.TRAIN.aspect_ratio, config.TRAIN.sp]
35 | config.TRAIN.A_data_dir = 'train'
36 | config.TRAIN.out_dir = "result/"
37 | config.TRAIN.num_epochs = 10
38 | config.TRAIN.reduce_dim = 2  # use a smaller CRN model
39 | config.TRAIN.every_nth_frame = 1  # train using all frames
40 |
41 | config.VAL = edict()
42 | config.VAL.A_data_dir = 'test'
43 | config.VAL.every_nth_frame = 1
44 |
45 | config.TEST = edict()
46 | config.TEST.is_test = not config.TRAIN.is_train
47 | config.TEST.A_data_dir = config.VAL.A_data_dir
48 | # config.TEST.every_nth_frame = 5
49 | config.TEST.out_dir_postfix = "/test"
50 | config.TEST.random_crop = False # if False, take the top left corner
51 |
52 | config.CX = edict()
53 | config.CX.crop_quarters = False
54 | config.CX.max_sampling_1d_size = 65
55 | # config.dis.feat_layers = {'conv1_1': 1.0,'conv2_1': 1.0, 'conv3_1': 1.0, 'conv4_1': 1.0,'conv5_1': 1.0}
56 | config.CX.feat_layers = {'conv3_2': 1.0, 'conv4_2': 1.0}
57 | config.CX.feat_content_layers = {'conv4_2': 1.0} # for single image
58 | config.CX.Dist = Distance.DotProduct
59 | config.CX.nn_stretch_sigma = 0.5  # 0.1
60 | config.CX.patch_size = 5
61 | config.CX.patch_stride = 2
62 |
63 |
64 | def last_two_nums(s):
65 |     if s.endswith('vgg_input_im') or s == 'RGB':
66 |         return 'rgb'
67 |     all_nums = re.findall(r'\d+', s)
68 |     return all_nums[-2] + all_nums[-1]
69 |
70 |
71 |
72 |
73 |
74 | config.experiment_postfix = 'single_im'
75 | if config.W.CX > 0:
76 |     config.experiment_postfix += "_CXt"  # CX_target
77 |     config.experiment_postfix += '_'.join([last_two_nums(layer) for layer in sorted(config.CX.feat_layers.keys())])
78 |     config.experiment_postfix += '_{}'.format(config.W.CX)
79 | if config.W.CX_content:
80 |     config.experiment_postfix += "_CXs"  # CX_source
81 |     config.experiment_postfix += '_'.join([last_two_nums(layer) for layer in sorted(config.CX.feat_content_layers.keys())])
82 |     config.experiment_postfix += '_{}'.format(config.W.CX_content)
83 |
84 |
85 | # uncomment and update for test
86 | # config.experiment_postfix = 'm2f_D32_42_1.0(s0.5)_DC42_1.0'
87 |
88 | config.TRAIN.out_dir += config.experiment_postfix
89 | config.TEST.out_dir = config.TRAIN.out_dir
90 | if not os.path.exists(config.TRAIN.out_dir):
91 | os.makedirs(config.TRAIN.out_dir)
92 |
93 |
94 |
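95 | # a minimal sketch of the edits needed before running (placeholder paths; see
96 | # the README "Quick Start"):
97 | #   config.base_dir = '/path/to/single/'                             # contains train/ and test/
98 | #   config.vgg_model_path = '/path/to/imagenet-vgg-verydeep-19.mat'
99 | #   config.TRAIN.is_train = True                                     # False to run the test pass
100 | # with the defaults above, config.experiment_postfix evaluates to
101 | # 'single_im_CXt32_42_1.0_CXs42_1.0'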
--------------------------------------------------------------------------------
/vgg_model.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------
2 | # code credits: https://github.com/CQFIO/PhotographicImageSynthesis
3 | # ---------------------------------------------------
4 |
5 | import tensorflow as tf
6 | import tensorflow.contrib.slim as slim
7 | import numpy as np
8 | import scipy.io
9 | from config import *
10 |
11 |
12 | def lrelu(x):
13 | return tf.maximum(0.2 * x, x)
14 |
15 |
16 | def build_net(ntype, nin, nwb=None, name=None):
17 | if ntype == 'conv':
18 | return tf.nn.relu(tf.nn.conv2d(nin, nwb[0], strides=[1, 1, 1, 1], padding='SAME', name=name) + nwb[1])
19 | elif ntype == 'pool':
20 | return tf.nn.avg_pool(nin, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
21 |
22 |
23 | def get_weight_bias(vgg_layers, i):
24 | weights = vgg_layers[i][0][0][2][0][0]
25 | weights = tf.constant(weights)
26 | bias = vgg_layers[i][0][0][2][0][1]
27 | bias = tf.constant(np.reshape(bias, (bias.size)))
28 | return weights, bias
29 |
30 |
31 | def build_vgg19(input, reuse=False):
32 | if reuse:
33 | tf.get_variable_scope().reuse_variables()
34 | net = {}
35 | vgg_rawnet = scipy.io.loadmat(config.vgg_model_path)
36 | vgg_layers = vgg_rawnet['layers'][0]
37 | net['input'] = input - np.array([123.6800, 116.7790, 103.9390]).reshape((1, 1, 1, 3))
38 | net['conv1_1'] = build_net('conv', net['input'], get_weight_bias(vgg_layers, 0), name='vgg_conv1_1')
39 | net['conv1_2'] = build_net('conv', net['conv1_1'], get_weight_bias(vgg_layers, 2), name='vgg_conv1_2')
40 | net['pool1'] = build_net('pool', net['conv1_2'])
41 | net['conv2_1'] = build_net('conv', net['pool1'], get_weight_bias(vgg_layers, 5), name='vgg_conv2_1')
42 | net['conv2_2'] = build_net('conv', net['conv2_1'], get_weight_bias(vgg_layers, 7), name='vgg_conv2_2')
43 | net['pool2'] = build_net('pool', net['conv2_2'])
44 | net['conv3_1'] = build_net('conv', net['pool2'], get_weight_bias(vgg_layers, 10), name='vgg_conv3_1')
45 | net['conv3_2'] = build_net('conv', net['conv3_1'], get_weight_bias(vgg_layers, 12), name='vgg_conv3_2')
46 | net['conv3_3'] = build_net('conv', net['conv3_2'], get_weight_bias(vgg_layers, 14), name='vgg_conv3_3')
47 | net['conv3_4'] = build_net('conv', net['conv3_3'], get_weight_bias(vgg_layers, 16), name='vgg_conv3_4')
48 | net['pool3'] = build_net('pool', net['conv3_4'])
49 | net['conv4_1'] = build_net('conv', net['pool3'], get_weight_bias(vgg_layers, 19), name='vgg_conv4_1')
50 | net['conv4_2'] = build_net('conv', net['conv4_1'], get_weight_bias(vgg_layers, 21), name='vgg_conv4_2')
51 | net['conv4_3'] = build_net('conv', net['conv4_2'], get_weight_bias(vgg_layers, 23), name='vgg_conv4_3')
52 | net['conv4_4'] = build_net('conv', net['conv4_3'], get_weight_bias(vgg_layers, 25), name='vgg_conv4_4')
53 | net['pool4'] = build_net('pool', net['conv4_4'])
54 | net['conv5_1'] = build_net('conv', net['pool4'], get_weight_bias(vgg_layers, 28), name='vgg_conv5_1')
55 | net['conv5_2'] = build_net('conv', net['conv5_1'], get_weight_bias(vgg_layers, 30), name='vgg_conv5_2')
56 | net['conv5_3'] = build_net('conv', net['conv5_2'], get_weight_bias(vgg_layers, 32), name='vgg_conv5_3')
57 | net['conv5_4'] = build_net('conv', net['conv5_3'], get_weight_bias(vgg_layers, 34), name='vgg_conv5_4')
58 | net['pool5'] = build_net('pool', net['conv5_4'])
59 | return net
60 |
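61 | # a minimal usage sketch (mirrors single_image_animation.py; expects RGB images
62 | # in [0, 255], NHWC; pass reuse=True on repeated calls so the graph is shared):
63 | #   vgg_real = build_vgg19(real_image_B)
64 | #   vgg_fake = build_vgg19(generator, reuse=True)
65 | #   feats = vgg_real['conv4_2']  # activations used by the contextual loss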
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # The Contextual Loss [[project page]](https://cgm.technion.ac.il/Computer-Graphics-Multimedia/Software/Contextual/)
2 |
3 | This is a Tensorflow implementation of the Contextual loss function as reported in the following papers:
4 | (a PyTorch implementation is also available - see below)
5 |
6 | ### The Contextual Loss for Image Transformation with Non-Aligned Data, [arXiv](https://arxiv.org/abs/1803.02077)
7 | ### Learning to Maintain Natural Image Statistics, [arXiv](https://arxiv.org/abs/1803.04626)
8 |
9 | [Roey Mechrez*](https://roimehrez.github.io/), Itamar Talmi*, Firas Shama, [Lihi Zelnik-Manor](http://lihi.eew.technion.ac.il/). [The Technion](http://cgm.technion.ac.il/)
10 |
11 | Copyright 2018 Itamar Talmi and Roey Mechrez. Licensed for noncommercial research use only.
12 |
13 |
14 | ![teaser](images/teaser_im.png)
15 |
16 |
17 | ## Setup
18 |
19 | ### Background
20 | This code is mainly the contextual loss function itself. The two papers present many applications; here we provide just one: animation from a single image.
21 |
22 | An example pre-trained model can be downloaded from this [link](https://www.dropbox.com/s/37nz4hy7ai4pqxc/single_im_D32_42_1.0_DC42_1.0.zip?dl=0)
23 |
24 | The data for this example can be downloaded from this [link](https://www.dropbox.com/s/ggb6v6rv1a0212y/single.zip?dl=0)
25 |
26 | ### Requirements
27 | Required Python libraries: TensorFlow (>=1.0, <1.9; tested on 1.4), SciPy, NumPy, and easydict.
28 |
29 | Tested on Windows with an Intel i7 CPU and an Nvidia Titan Xp (and a 1080 Ti), with CUDA (>=8.0) and cuDNN (>=5.0). CPU mode should also work with minor changes.
30 |
31 |
32 | ### Quick Start (Testing)
33 | 1. Clone this repository.
34 | 2. Download the VGG-19 weights (```imagenet-vgg-verydeep-19.mat```) from this [link](https://www.dropbox.com/s/q3wjtaxr76cdx3t/imagenet-vgg-verydeep-19.mat?dl=0)
35 | 3. Extract the pre-trained model zip (see Background above) under the ```result``` folder. The model should be in ```base_dir/result/single_im_D32_42_1.0_DC42_1.0/```
36 | 4. Update ```config.base_dir``` and ```config.vgg_model_path``` in ```config.py```, then run ```single_image_animation.py```
37 |
38 | ### Training
39 | 1. Change ```config.TRAIN.is_train``` in ```config.py``` to ```True```
40 | 2. Arrange the paths to the data; ```config.base_dir``` should contain ```train``` and ```test``` folders
41 | 3. Run ```single_image_animation.py```; it trains for 10 epochs (```config.TRAIN.num_epochs```)
42 |
43 | ### PyTorch implementation
44 | We have also released a PyTorch implementation of the loss function. See ```CX/CX_distance.py```. Note that we haven't tested this implementation for reproducing the results in the papers.
45 |
46 |
47 | ## License
48 |
49 | This software is provided under the provisions of the GNU Lesser General Public License (LGPL);
50 | see: http://www.gnu.org/copyleft/lesser.html.
51 |
52 | This software can be used for research purposes only; you should cite
53 | the aforementioned papers in any resulting publication.
54 |
55 | The Software is provided "as is", without warranty of any kind.
56 |
57 |
58 | ## Citation
59 | If you use our code for research, please cite our papers:
60 | ```
61 | @article{mechrez2018contextual,
62 | title={The Contextual Loss for Image Transformation with Non-Aligned Data},
63 | author={Mechrez, Roey and Talmi, Itamar and Zelnik-Manor, Lihi},
64 | journal={arXiv preprint arXiv:1803.02077},
65 | year={2018}
66 | }
67 | @article{mechrez2018Learning,
68 |   title={Learning to Maintain Natural Image Statistics},
69 | author={Mechrez, Roey and Talmi, Itamar and Shama, Firas and Zelnik-Manor, Lihi},
70 | journal={arXiv preprint arXiv:1803.04626},
71 | year={2018}
72 | }
73 | ```
74 |
75 |
76 | ## Code References
77 |
78 | [1] Template Matching with Deformable Diversity Similarity, https://github.com/roimehrez/DDIS
79 |
80 | [2] Photographic Image Synthesis with Cascaded Refinement Networks, https://cqf.io/ImageSynthesis/
81 |
82 |
--------------------------------------------------------------------------------
/single_image_animation.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------
2 | # code credits: https://github.com/CQFIO/PhotographicImageSynthesis
3 | # ---------------------------------------------------
4 | from __future__ import division
5 |
6 | import time
7 |
8 | import utils.helper as helper
9 | from CX.CX_helper import *
10 | from model import *
11 | from utils.FetchManager import *
12 |
13 | sess = tf.Session()
14 |
15 | # ---------------------------------------------------
16 | # graph
17 | # ---------------------------------------------------
18 | with tf.variable_scope(tf.get_variable_scope()):
19 | input_A = tf.placeholder(tf.float32, [None, None, None, 3])
20 | input_B = tf.placeholder(tf.float32, [None, None, None, 3])
21 | input_A_test = tf.placeholder(tf.float32, [None, None, None, 3])
22 | input_image_A, real_image_B = helper.random_crop_together(input_A, input_B, [2, config.TRAIN.resize[0], config.TRAIN.resize[1], 3])
23 | with tf.variable_scope("g") as scope:
24 | generator = recursive_generator(input_image_A, config.TRAIN.sp)
25 | scope.reuse_variables()
26 | generator_test = recursive_generator(input_A_test, config.TRAIN.sp)
27 | weight = tf.placeholder(tf.float32)
28 | vgg_real = build_vgg19(real_image_B)
29 | vgg_fake = build_vgg19(generator, reuse=True)
30 | vgg_input = build_vgg19(input_image_A, reuse=True)
31 |
32 |
33 | ## --- contextual style/target---
34 | if config.W.CX > 0:
35 | CX_loss_list = [w * CX_loss_helper(vgg_real[layer], vgg_fake[layer], config.CX)
36 | for layer, w in config.CX.feat_layers.items()]
37 | CX_style_loss = tf.reduce_sum(CX_loss_list)
38 | CX_style_loss *= config.W.CX
39 | else:
40 | CX_style_loss = zero_tensor
41 |
42 | ## --- contextual content/source---
43 | if config.W.CX_content > 0:
44 | CX_loss_content_list = [w * CX_loss_helper(vgg_input[layer], vgg_fake[layer], config.CX)
45 | for layer, w in config.CX.feat_content_layers.items()]
46 | CX_content_loss = tf.reduce_sum(CX_loss_content_list)
47 | CX_content_loss *= config.W.CX_content
48 | else:
49 | CX_content_loss = zero_tensor
50 |
51 | ## --- total loss ---
52 | G_loss = CX_style_loss + CX_content_loss
53 |
54 |
55 | # create the optimization
56 | lr = tf.placeholder(tf.float32)
57 | var_list = [var for var in tf.trainable_variables() if var.name.startswith('g/g_')]
58 | G_opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss, var_list=var_list)
59 | saver = tf.train.Saver(max_to_keep=1000)
60 | sess.run(tf.global_variables_initializer())
61 |
62 |
63 | # load from checkpoint if exist
64 | def load(checkpoint_dir):
65 |     ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
66 |     if ckpt:
67 |         print('loaded ' + ckpt.model_checkpoint_path)
68 |         saver.restore(sess, ckpt.model_checkpoint_path)
69 |     return ckpt
70 |
71 |
72 | ckpt = load(config.TRAIN.out_dir)
73 |
74 |
75 | # ---------------------------------------------------
76 | # train
77 | # ---------------------------------------------------
78 | if config.TRAIN.is_train:
79 | file_list = os.listdir(config.base_dir + config.TRAIN.A_data_dir)
80 | val_file_list = os.listdir(config.base_dir + config.VAL.A_data_dir)
81 | file_list = np.random.permutation(file_list)
82 | assert len(file_list) > 0
83 | train_file_list = file_list[0::config.TRAIN.every_nth_frame]
84 | val_file_list = val_file_list[0::config.VAL.every_nth_frame]
85 | g_loss = np.zeros(len(train_file_list), dtype=float)
86 | fetcher = FetchManager(sess, [G_opt, G_loss])
87 | B_file_name = config.single_image_B_file_name
88 | B_image = helper.read_image(B_file_name) # training image B
89 |
90 | ## ------------ epoch loop -------------------------
91 | for epoch in range(1, config.TRAIN.num_epochs + 1):
92 | epoch_dir = config.TRAIN.out_dir + "/%04d" % epoch
93 | if os.path.isdir(epoch_dir):
94 | continue
95 | cnt = 0
96 |
97 | ## ------------ batch loop -------------------------
98 |     for ind in np.random.permutation(len(train_file_list)):
99 | st = time.time()
100 | cnt += 1
101 |
102 | A_file_name = config.base_dir + config.TRAIN.A_data_dir + '/' + train_file_list[ind]
103 |         if not os.path.isfile(A_file_name):
104 | continue
105 | A_image = helper.read_image(A_file_name) # training image A
106 |
107 | feed_dict = {input_A: A_image, input_B: B_image, lr: 1e-4}
108 |
109 |         # session run
110 |         evaluated = fetcher.fetch(feed_dict, [CX_style_loss, CX_content_loss])
111 |
112 |         g_loss[ind] = evaluated[G_loss]
113 |         log = "epoch:%d | cnt:%d | time:%.2f | loss:%.2f || dis_style:%.2f | dis_content:%.2f " % \
114 |               (epoch, cnt, time.time() - st, np.mean(g_loss[np.where(g_loss)]), evaluated[CX_style_loss], evaluated[CX_content_loss])
115 |         print(log)
116 | ##------------ end batch loop -------------------
117 |
118 | # -------------- save the model ------------------
119 |     # retry the save a few times; saving to Dropbox sometimes raises an error
120 |     for i in range(5):
121 |         try:
122 |             if not os.path.exists(epoch_dir): os.makedirs(epoch_dir)
123 |             helper.write_loss_in_txt(g_loss, epoch)
124 |             saver.save(sess, config.TRAIN.out_dir + "/model.ckpt")
125 |             break  # saved successfully, stop retrying
126 |         except:
127 |             time.sleep(1)
128 |
129 | ## ------------ validation loop -------------------------
130 | for ind in range(len(val_file_list)):
131 | A_file_name_val = config.base_dir + config.VAL.A_data_dir + '/' + val_file_list[ind]
132 | if not os.path.isfile(A_file_name_val): # test label
133 | continue
134 |         A_image_val = helper.read_image(A_file_name_val)  # validation image A
135 |
136 | output = sess.run(generator_test, feed_dict={input_A_test: A_image_val})
137 | output = np.concatenate([A_image_val, output, B_image], axis=2)
138 | helper.save_image(output, config.TRAIN.out_dir + "/%04d/" % epoch + val_file_list[ind].replace('.jpg', '_out.jpg'))
139 |
140 |
141 |
142 | # ---------------------------------------------------
143 | # test
144 | # ---------------------------------------------------
145 | if config.TEST.is_test:
146 | test_file_list = os.listdir(config.base_dir + config.TEST.A_data_dir)
147 | if not os.path.isdir(config.TEST.out_dir + config.TEST.out_dir_postfix):
148 | os.makedirs(config.TEST.out_dir + config.TEST.out_dir_postfix)
149 | time_list = np.zeros(len(test_file_list), dtype=float)
150 | for ind in range(len(test_file_list)):
151 | A_file_name_val = config.base_dir + config.TEST.A_data_dir + '/' + test_file_list[ind]
152 | if not os.path.isfile(A_file_name_val):
153 | continue
154 |         A_image_val = helper.read_image(A_file_name_val, fliplr=False)  # test image A
155 | st = time.time()
156 | output = sess.run(generator_test, feed_dict={input_A_test: A_image_val})
157 | et = time.time()
158 |         output = np.concatenate([A_image_val, output], axis=2)  # input and output side by side
159 | helper.save_image(output, config.TEST.out_dir + config.TEST.out_dir_postfix + "/" + test_file_list[ind].replace('.jpg', '_out.jpg'))
160 | time_list[ind] = et - st
161 | print("test for image #: %d, time: %1.4f" % (ind, et - st))
162 | print('average time per image: %f' % time_list.mean())
163 |
--------------------------------------------------------------------------------
/CX/CSFlow.py:
--------------------------------------------------------------------------------
1 | import CX.enums as enums
2 | import tensorflow as tf
3 | from CX.enums import TensorAxis, Distance
4 |
5 |
6 |
7 | class CSFlow:
8 | def __init__(self, sigma = float(0.1), b = float(1.0)):
9 | self.b = b
10 | self.sigma = sigma
11 |
12 | def __calculate_CS(self, scaled_distances, axis_for_normalization = TensorAxis.C):
13 | self.scaled_distances = scaled_distances
14 | self.cs_weights_before_normalization = tf.exp((self.b - scaled_distances) / self.sigma, name='weights_before_normalization')
15 | self.cs_NHWC = CSFlow.sum_normalize(self.cs_weights_before_normalization, axis_for_normalization)
16 |
17 | def reversed_direction_CS(self):
18 | cs_flow_opposite = CSFlow(self.sigma, self.b)
19 | cs_flow_opposite.raw_distances = self.raw_distances
20 | work_axis = [TensorAxis.H, TensorAxis.W]
21 | relative_dist = cs_flow_opposite.calc_relative_distances(axis=work_axis)
22 | cs_flow_opposite.__calculate_CS(relative_dist, work_axis)
23 | return cs_flow_opposite
24 |
25 | # --
26 | @staticmethod
27 | def create_using_L2(I_features, T_features, sigma = float(0.1), b = float(1.0)):
28 | cs_flow = CSFlow(sigma, b)
29 | # for debug:
30 | # I_features = tf.concat([I_features, I_features], axis=1)
31 | with tf.name_scope('CS'):
32 | # assert I_features.shape[TensorAxis.C] == T_features.shape[TensorAxis.C]
33 | c = T_features.shape[TensorAxis.C].value
34 | sT = T_features.shape.as_list()
35 | sI = I_features.shape.as_list()
36 |
37 | Ivecs = tf.reshape(I_features, (sI[TensorAxis.N], -1, sI[TensorAxis.C]))
38 | Tvecs = tf.reshape(T_features, (sI[TensorAxis.N], -1, sT[TensorAxis.C]))
39 | r_Ts = tf.reduce_sum(Tvecs * Tvecs, 2)
40 | r_Is = tf.reduce_sum(Ivecs * Ivecs, 2)
41 | raw_distances_list = []
42 | for i in range(sT[TensorAxis.N]):
43 | Ivec, Tvec, r_T, r_I = Ivecs[i], Tvecs[i], r_Ts[i], r_Is[i]
44 | A = Tvec @ tf.transpose(Ivec)
45 | cs_flow.A = A
46 | # A = tf.matmul(Tvec, tf.transpose(Ivec))
47 | r_T = tf.reshape(r_T, [-1, 1]) # turn to column vector
48 | dist = r_T - 2 * A + r_I
49 | cs_shape = sI[:3] + [dist.shape[0].value]
50 | cs_shape[0] = 1
51 | dist = tf.reshape(tf.transpose(dist), cs_shape)
52 | # protecting against numerical problems, dist should be positive
53 | dist = tf.maximum(float(0.0), dist)
54 | # dist = tf.sqrt(dist)
55 | raw_distances_list += [dist]
56 |
57 | cs_flow.raw_distances = tf.convert_to_tensor([tf.squeeze(raw_dist, axis=0) for raw_dist in raw_distances_list])
58 |
59 | relative_dist = cs_flow.calc_relative_distances()
60 | cs_flow.__calculate_CS(relative_dist)
61 | return cs_flow
62 |
63 | #--
64 | @staticmethod
65 | def create_using_dotP(I_features, T_features, sigma = float(1.0), b = float(1.0)):
66 | cs_flow = CSFlow(sigma, b)
67 | with tf.name_scope('CS'):
68 | # prepare feature before calculating cosine distance
69 | T_features, I_features = cs_flow.center_by_T(T_features, I_features)
70 | with tf.name_scope('TFeatures'):
71 | T_features = CSFlow.l2_normalize_channelwise(T_features)
72 | with tf.name_scope('IFeatures'):
73 | I_features = CSFlow.l2_normalize_channelwise(I_features)
74 |
75 |             # work separately on each example in the batch
76 | cosine_dist_l = []
77 | N, _, __, ___ = T_features.shape.as_list()
78 | for i in range(N):
79 | T_features_i = tf.expand_dims(T_features[i, :, :, :], 0)
80 | I_features_i = tf.expand_dims(I_features[i, :, :, :], 0)
81 | patches_HWCN_i = cs_flow.patch_decomposition(T_features_i)
82 | cosine_dist_i = tf.nn.conv2d(I_features_i, patches_HWCN_i, strides=[1, 1, 1, 1],
83 | padding='VALID', use_cudnn_on_gpu=True, name='cosine_dist')
84 | cosine_dist_l.append(cosine_dist_i)
85 |
86 | cs_flow.cosine_dist = tf.concat(cosine_dist_l, axis = 0)
87 |
88 | cosine_dist_zero_to_one = -(cs_flow.cosine_dist - 1) / 2
89 | cs_flow.raw_distances = cosine_dist_zero_to_one
90 |
91 | relative_dist = cs_flow.calc_relative_distances()
92 | cs_flow.__calculate_CS(relative_dist)
93 | return cs_flow
94 |
95 | def calc_relative_distances(self, axis=TensorAxis.C):
96 | epsilon = 1e-5
97 | div = tf.reduce_min(self.raw_distances, axis=axis, keep_dims=True)
98 | # div = tf.reduce_mean(self.raw_distances, axis=axis, keep_dims=True)
99 | relative_dist = self.raw_distances / (div + epsilon)
100 | return relative_dist
101 |
102 | def weighted_average_dist(self, axis = TensorAxis.C):
103 |         if not hasattr(self, 'raw_distances'):
104 |             raise ValueError('raw_distances attribute does not exist; cannot calculate the weighted average distance')
105 |
106 | multiply = self.raw_distances * self.cs_NHWC
107 | return tf.reduce_sum(multiply, axis=axis, name='weightedDistPerPatch')
108 |
109 | # --
110 | @staticmethod
111 | def create(I_features, T_features, distance : enums.Distance, nnsigma=float(1.0), b=float(1.0)):
112 | if distance.value == enums.Distance.DotProduct.value:
113 | cs_flow = CSFlow.create_using_dotP(I_features, T_features, nnsigma, b)
114 | elif distance.value == enums.Distance.L2.value:
115 | cs_flow = CSFlow.create_using_L2(I_features, T_features, nnsigma, b)
116 |         else:
117 |             raise ValueError('unsupported distance: ' + str(distance))
118 | return cs_flow
119 |
120 | @staticmethod
121 | def sum_normalize(cs, axis=TensorAxis.C):
122 | reduce_sum = tf.reduce_sum(cs, axis, keep_dims=True, name='sum')
123 | return tf.divide(cs, reduce_sum, name='sumNormalized')
124 |
125 | def center_by_T(self, T_features, I_features):
126 |         # assuming both inputs are of the same size
127 |
128 |         # calculate statistics over [batch, height, width], expecting a 1x1xDepth tensor
129 | axes = [0, 1, 2]
130 | self.meanT, self.varT = tf.nn.moments(
131 | T_features, axes, name='TFeatures/moments')
132 |
133 | # we do not divide by std since its causing the histogram
134 | # for the final cs to be very thin, so the NN weights
135 | # are not distinctive, giving similar values for all patches.
136 | # stdT = tf.sqrt(varT, "stdT")
137 | # correct places with std zero
138 | # stdT[tf.less(stdT, tf.constant(0.001))] = tf.constant(1)
139 |
140 | # TODO check broadcasting here
141 | with tf.name_scope('TFeatures/centering'):
142 | self.T_features_centered = T_features - self.meanT
143 | with tf.name_scope('IFeatures/centering'):
144 | self.I_features_centered = I_features - self.meanT
145 |
146 |         return self.T_features_centered, self.I_features_centered
147 |
148 |     @staticmethod
149 |     def l2_normalize_channelwise(features):
150 | norms = tf.norm(features, ord='euclidean', axis=TensorAxis.C, name='norm')
151 | # expanding the norms tensor to support broadcast division
152 | norms_expanded = tf.expand_dims(norms, TensorAxis.C)
153 | features = tf.divide(features, norms_expanded, name='normalized')
154 | return features
155 |
156 | def patch_decomposition(self, T_features):
157 | # patch decomposition
158 | # see https://stackoverflow.com/questions/40731433/understanding-tf-extract-image-patches-for-extracting-patches-from-an-image
159 | patch_size = 1
160 | patches_as_depth_vectors = tf.extract_image_patches(
161 | images=T_features, ksizes=[1, patch_size, patch_size, 1],
162 | strides=[1, 1, 1, 1], rates=[1, 1, 1, 1], padding='VALID',
163 | name='patches_as_depth_vectors')
164 |
165 | self.patches_NHWC = tf.reshape(
166 | patches_as_depth_vectors,
167 | shape=[-1, patch_size, patch_size, patches_as_depth_vectors.shape[3].value],
168 | name='patches_PHWC')
169 |
170 | self.patches_HWCN = tf.transpose(
171 | self.patches_NHWC,
172 | perm=[1, 2, 3, 0],
173 | name='patches_HWCP') # tf.conv2 ready format
174 |
175 | return self.patches_HWCN
176 |
177 |
178 | #--------------------------------------------------
179 | # CX loss
180 | #--------------------------------------------------
181 |
182 |
183 | def CX_loss(T_features, I_features, distance=Distance.L2, nnsigma=float(1.0)):
184 | T_features = tf.convert_to_tensor(T_features, dtype=tf.float32)
185 | I_features = tf.convert_to_tensor(I_features, dtype=tf.float32)
186 |
187 | with tf.name_scope('CX'):
188 | cs_flow = CSFlow.create(I_features, T_features, distance, nnsigma)
189 |         # reduce the CS map over the spatial axes:
190 |         height_width_axis = [TensorAxis.H, TensorAxis.W]
191 |         # max over (H, W), then mean over the feature axis
192 |         cs = cs_flow.cs_NHWC
193 | k_max_NC = tf.reduce_max(cs, axis=height_width_axis)
194 | CS = tf.reduce_mean(k_max_NC, axis=[1])
195 |         # CS is in (0, 1]; higher similarity gives a lower loss
196 |         CX_loss = -tf.log(CS)
197 | CX_loss = tf.reduce_mean(CX_loss)
198 | return CX_loss
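199 |
200 |
201 | # a minimal usage sketch (mirrors CX_helper.CX_loss_helper; vgg_A and vgg_B are
202 | # same-sized NHWC feature maps, e.g. VGG activations):
203 | #   loss = CX_loss(vgg_A, vgg_B, distance=Distance.DotProduct, nnsigma=0.5)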
--------------------------------------------------------------------------------
/CX/CX_distance.py:
--------------------------------------------------------------------------------
1 | # import tensorflow as tf
2 | import torch
3 | import numpy as np
4 |
5 |
6 | class TensorAxis:
7 | N = 0
8 | H = 1
9 | W = 2
10 | C = 3
11 |
12 |
13 | class CSFlow:
14 | def __init__(self, sigma=float(0.1), b=float(1.0)):
15 | self.b = b
16 | self.sigma = sigma
17 |
18 | def __calculate_CS(self, scaled_distances, axis_for_normalization=TensorAxis.C):
19 | self.scaled_distances = scaled_distances
20 | self.cs_weights_before_normalization = torch.exp((self.b - scaled_distances) / self.sigma)
21 | # self.cs_weights_before_normalization = 1 / (1 + scaled_distances)
22 | # self.cs_NHWC = CSFlow.sum_normalize(self.cs_weights_before_normalization, axis_for_normalization)
23 | self.cs_NHWC = self.cs_weights_before_normalization
24 |
25 | # def reversed_direction_CS(self):
26 | # cs_flow_opposite = CSFlow(self.sigma, self.b)
27 | # cs_flow_opposite.raw_distances = self.raw_distances
28 | # work_axis = [TensorAxis.H, TensorAxis.W]
29 | # relative_dist = cs_flow_opposite.calc_relative_distances(axis=work_axis)
30 | # cs_flow_opposite.__calculate_CS(relative_dist, work_axis)
31 | # return cs_flow_opposite
32 |
33 | # --
34 | @staticmethod
35 | def create_using_L2(I_features, T_features, sigma=float(0.5), b=float(1.0)):
36 | cs_flow = CSFlow(sigma, b)
37 | sT = T_features.shape
38 | sI = I_features.shape
39 |
40 | Ivecs = torch.reshape(I_features, (sI[0], -1, sI[3]))
41 | Tvecs = torch.reshape(T_features, (sI[0], -1, sT[3]))
42 | r_Ts = torch.sum(Tvecs * Tvecs, 2)
43 | r_Is = torch.sum(Ivecs * Ivecs, 2)
44 | raw_distances_list = []
45 | for i in range(sT[0]):
46 | Ivec, Tvec, r_T, r_I = Ivecs[i], Tvecs[i], r_Ts[i], r_Is[i]
47 | A = Tvec @ torch.transpose(Ivec, 0, 1) # (matrix multiplication)
48 | cs_flow.A = A
49 | # A = tf.matmul(Tvec, tf.transpose(Ivec))
50 | r_T = torch.reshape(r_T, [-1, 1]) # turn to column vector
51 | dist = r_T - 2 * A + r_I
52 | dist = torch.reshape(torch.transpose(dist, 0, 1), shape=(1, sI[1], sI[2], dist.shape[0]))
53 | # protecting against numerical problems, dist should be positive
54 | dist = torch.clamp(dist, min=float(0.0))
55 | # dist = tf.sqrt(dist)
56 | raw_distances_list += [dist]
57 |
58 | cs_flow.raw_distances = torch.cat(raw_distances_list)
59 |
60 | relative_dist = cs_flow.calc_relative_distances()
61 | cs_flow.__calculate_CS(relative_dist)
62 | return cs_flow
63 |
64 | # --
65 | @staticmethod
66 | def create_using_L1(I_features, T_features, sigma=float(0.5), b=float(1.0)):
67 | cs_flow = CSFlow(sigma, b)
68 | sT = T_features.shape
69 | sI = I_features.shape
70 |
71 | Ivecs = torch.reshape(I_features, (sI[0], -1, sI[3]))
72 | Tvecs = torch.reshape(T_features, (sI[0], -1, sT[3]))
73 | raw_distances_list = []
74 | for i in range(sT[0]):
75 | Ivec, Tvec = Ivecs[i], Tvecs[i]
76 |             dist = torch.sum(torch.abs(Ivec.unsqueeze(1) - Tvec.unsqueeze(0)), dim=2)  # L1: sum of absolute differences
77 |             dist = torch.reshape(dist, shape=(1, sI[1], sI[2], dist.shape[1]))  # dist is already (S_I, S_T); no transpose needed
78 | # protecting against numerical problems, dist should be positive
79 | dist = torch.clamp(dist, min=float(0.0))
80 | # dist = tf.sqrt(dist)
81 | raw_distances_list += [dist]
82 |
83 | cs_flow.raw_distances = torch.cat(raw_distances_list)
84 |
85 | relative_dist = cs_flow.calc_relative_distances()
86 | cs_flow.__calculate_CS(relative_dist)
87 | return cs_flow
88 |
89 | # --
90 | @staticmethod
91 | def create_using_dotP(I_features, T_features, sigma=float(0.5), b=float(1.0)):
92 | cs_flow = CSFlow(sigma, b)
93 | # prepare feature before calculating cosine distance
94 | T_features, I_features = cs_flow.center_by_T(T_features, I_features)
95 | T_features = CSFlow.l2_normalize_channelwise(T_features)
96 | I_features = CSFlow.l2_normalize_channelwise(I_features)
97 |
98 |         # work separately on each example in the batch
99 | cosine_dist_l = []
100 | N = T_features.size()[0]
101 | for i in range(N):
102 |             T_features_i = T_features[i, :, :, :].unsqueeze_(0)  # HWC --> 1HWC
103 |             I_features_i = I_features[i, :, :, :].unsqueeze_(0).permute((0, 3, 1, 2))  # 1HWC --> 1CHW
104 |             patches_PC11_i = cs_flow.patch_decomposition(T_features_i)  # 1HWC --> PC11, with P=H*W
105 |             cosine_dist_i = torch.nn.functional.conv2d(I_features_i, patches_PC11_i)
106 |             cosine_dist_1HWC = cosine_dist_i.permute((0, 2, 3, 1))  # back to 1HWC
107 |             cosine_dist_l.append(cosine_dist_1HWC)
108 |
109 | cs_flow.cosine_dist = torch.cat(cosine_dist_l, dim=0)
110 |
111 |         cs_flow.raw_distances = - (cs_flow.cosine_dist - 1) / 2  # map cosine similarity in [-1, 1] to a distance in [0, 1]
112 |
113 | relative_dist = cs_flow.calc_relative_distances()
114 | cs_flow.__calculate_CS(relative_dist)
115 | return cs_flow
116 |
117 | def calc_relative_distances(self, axis=TensorAxis.C):
118 | epsilon = 1e-5
119 | div = torch.min(self.raw_distances, dim=axis, keepdim=True)[0]
120 | relative_dist = self.raw_distances / (div + epsilon)
121 | return relative_dist
122 |
123 | @staticmethod
124 | def sum_normalize(cs, axis=TensorAxis.C):
125 | reduce_sum = torch.sum(cs, dim=axis, keepdim=True)
126 | cs_normalize = torch.div(cs, reduce_sum)
127 | return cs_normalize
128 |
129 |     def center_by_T(self, T_features, I_features):
130 |         # assuming both inputs are of the same size;
131 |         # calculate statistics over [batch, height, width], expecting a 1x1xDepth tensor
132 |         # (the chained mean/var calls below reduce axes [0, 1, 2] in turn)
133 | self.meanT = T_features.mean(0, keepdim=True).mean(1, keepdim=True).mean(2, keepdim=True)
134 | self.varT = T_features.var(0, keepdim=True).var(1, keepdim=True).var(2, keepdim=True)
135 | self.T_features_centered = T_features - self.meanT
136 | self.I_features_centered = I_features - self.meanT
137 |
138 | return self.T_features_centered, self.I_features_centered
139 |
140 | @staticmethod
141 | def l2_normalize_channelwise(features):
142 | norms = features.norm(p=2, dim=TensorAxis.C, keepdim=True)
143 | features = features.div(norms)
144 | return features
145 |
146 | def patch_decomposition(self, T_features):
147 | # 1HWC --> 11PC --> PC11, with P=H*W
148 | (N, H, W, C) = T_features.shape
149 | P = H * W
150 | patches_PC11 = T_features.reshape(shape=(1, 1, P, C)).permute(dims=(2, 3, 0, 1))
151 | return patches_PC11
152 |
153 | @staticmethod
154 | def pdist2(x, keepdim=False):
155 | sx = x.shape
156 | x = x.reshape(shape=(sx[0], sx[1] * sx[2], sx[3]))
157 | differences = x.unsqueeze(2) - x.unsqueeze(1)
158 | distances = torch.sum(differences**2, -1)
159 | if keepdim:
160 | distances = distances.reshape(shape=(sx[0], sx[1], sx[2], sx[3]))
161 | return distances
162 |
163 | @staticmethod
164 | def calcR_static(sT, order='C', deformation_sigma=0.05):
165 |         # order can be 'C' or 'F' (MATLAB order)
166 | pixel_count = sT[0] * sT[1]
167 |
168 | rangeRows = range(0, sT[1])
169 | rangeCols = range(0, sT[0])
170 | Js, Is = np.meshgrid(rangeRows, rangeCols)
171 | row_diff_from_first_row = Is
172 | col_diff_from_first_col = Js
173 |
174 | row_diff_from_first_row_3d_repeat = np.repeat(row_diff_from_first_row[:, :, np.newaxis], pixel_count, axis=2)
175 | col_diff_from_first_col_3d_repeat = np.repeat(col_diff_from_first_col[:, :, np.newaxis], pixel_count, axis=2)
176 |
177 | rowDiffs = -row_diff_from_first_row_3d_repeat + row_diff_from_first_row.flatten(order).reshape(1, 1, -1)
178 | colDiffs = -col_diff_from_first_col_3d_repeat + col_diff_from_first_col.flatten(order).reshape(1, 1, -1)
179 | R = rowDiffs ** 2 + colDiffs ** 2
180 | R = R.astype(np.float32)
181 | R = np.exp(-(R) / (2 * deformation_sigma ** 2))
182 | return R
183 |
184 |
185 |
186 |
187 |
188 |
189 | # --------------------------------------------------
190 | # CX loss
191 | # --------------------------------------------------
192 |
193 |
194 |
195 | def CX_loss(T_features, I_features, deformation=False, dis=False):
196 | # T_features = tf.convert_to_tensor(T_features, dtype=tf.float32)
197 | # I_features = tf.convert_to_tensor(I_features, dtype=tf.float32)
198 |     # since this is a conversion of TensorFlow code to PyTorch, we permute the tensors to NHWC (see from_pt2tf below)
199 | # T_features = normalize_tensor(T_features)
200 | # I_features = normalize_tensor(I_features)
201 |
202 |     # since this is originally a TensorFlow implementation,
203 |     # we keep all tensors in the TF convention (NHWC) rather than the PyTorch convention (NCHW).
204 | def from_pt2tf(Tpt):
205 | Ttf = Tpt.permute(0, 2, 3, 1)
206 | return Ttf
207 | # N x C x H x W --> N x H x W x C
208 | T_features_tf = from_pt2tf(T_features)
209 | I_features_tf = from_pt2tf(I_features)
210 |
211 | # cs_flow = CSFlow.create_using_dotP(I_features_tf, T_features_tf, sigma=1.0)
212 | cs_flow = CSFlow.create_using_L2(I_features_tf, T_features_tf, sigma=1.0)
213 |     # note: sum normalization is skipped in this PyTorch port
214 |     # (see __calculate_CS above)
215 | cs = cs_flow.cs_NHWC
216 |
217 | if deformation:
218 | deforma_sigma = 0.001
219 |         sT = T_features_tf.shape[1:3]  # spatial dims (H, W)
220 | R = CSFlow.calcR_static(sT, deformation_sigma=deforma_sigma)
221 | cs *= torch.Tensor(R).unsqueeze(dim=0).cuda()
222 |
223 | if dis:
224 | CS = []
225 | k_max_NC = torch.max(torch.max(cs, dim=1)[1], dim=1)[1]
226 | indices = k_max_NC.cpu()
227 | N, C = indices.shape
228 | for i in range(N):
229 | CS.append((C - len(torch.unique(indices[i, :]))) / C)
230 | score = torch.FloatTensor(CS)
231 | else:
232 | # reduce_max X and Y dims
233 | # cs = CSFlow.pdist2(cs,keepdim=True)
234 | k_max_NC = torch.max(torch.max(cs, dim=1)[0], dim=1)[0]
235 | # reduce mean over C dim
236 | CS = torch.mean(k_max_NC, dim=1)
237 | # score = 1/CS
238 | # score = torch.exp(-CS*10)
239 | score = -torch.log(CS)
240 | # reduce mean over N dim
241 | # CX_loss = torch.mean(CX_loss)
242 | return score
243 |
244 |
245 | def symetric_CX_loss(T_features, I_features):
246 | score = (CX_loss(T_features, I_features) + CX_loss(I_features, T_features)) / 2
247 | return score
248 |
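249 |
250 | # a minimal usage sketch (hypothetical shapes; inputs are same-sized NCHW float
251 | # tensors, e.g. VGG feature maps; CUDA is only needed for deformation=True):
252 | if __name__ == '__main__':
253 |     T = torch.rand(1, 64, 32, 32)
254 |     I = torch.rand(1, 64, 32, 32)
255 |     print(CX_loss(T, I))           # per-example -log(CS) score
256 |     print(symetric_CX_loss(T, I))  # averaged over both directions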
--------------------------------------------------------------------------------