├── README.md ├── architecture_ops.py ├── dataloading.py ├── main.py ├── config.py ├── architectures.py ├── utils.py ├── transformations.py ├── model.py └── ops.py /README.md: -------------------------------------------------------------------------------- 1 | git README.md 2 | 3 | Work in progress 4 | 5 | Instructions: 6 | Install tensorflow 1.14 7 | 8 | Download 9 | a) video dataset or, 10 | b) image dataset 11 | 12 | Change dataloader 13 | a) use the structure of load_train_human3m in dataloading.py 14 | b) use the structure of load_train_generic in dataloading.py 15 | 16 | Start training (current hyperparameters work for roughly cropped human3m without background) 17 | a) python main.py test_a --gpu 0 --dataset human3m --mode train --bn 16 18 | b) python main.py test_b --gpu 0 --dataset generic --mode train --bn 16 --static 19 | 20 | Eval (use same flags as in training except for mode) 21 | a) python main.py test_a --gpu 0 --dataset human3m --mode predict --bn 16 22 | b) python main.py test_b --gpu 0 --dataset generic --mode predict --bn 16 --static 23 | 24 | 25 | Todo: 26 | -code documentation 27 | -add preprocessing 28 | -add hyperparameters for datasets 29 | -add eval functions 30 | 31 | [Project page with videos](https://compvis.github.io/unsupervised-disentangling/) 32 | -------------------------------------------------------------------------------- /architecture_ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from utils import wrappy 4 | 5 | 6 | def _conv(inputs, filters, kernel_size=1, strides=1, pad='VALID', name='conv'): 7 | with tf.variable_scope(name): 8 | # Kernel for convolution, Xavier Initialisation 9 | kernel = tf.get_variable(shape=[kernel_size, kernel_size, inputs.get_shape().as_list()[3], filters], 10 | initializer=tf.contrib.layers.xavier_initializer(uniform=False), name='weights') 11 | conv = tf.nn.conv2d(inputs, kernel, [1, strides, strides, 1], 
padding=pad, data_format='NHWC') 12 | return conv 13 | 14 | 15 | def _conv_bn_relu(inputs, filters, train, kernel_size=1, strides=1, pad='VALID', name='conv_bn_relu'): 16 | with tf.variable_scope(name): 17 | kernel = tf.get_variable(shape=[kernel_size, kernel_size, inputs.get_shape().as_list()[3], filters], 18 | initializer=tf.contrib.layers.xavier_initializer(uniform=False), name='weights') 19 | conv = tf.nn.conv2d(inputs, kernel, [1, strides, strides, 1], padding=pad, data_format='NHWC') 20 | norm = tf.nn.relu(tf.layers.batch_normalization(conv, momentum=0.9, epsilon=1e-5, training=train, name='bn')) 21 | return norm 22 | 23 | 24 | def _conv_block(inputs, numOut, train, name='conv_block'): 25 | with tf.variable_scope(name): 26 | with tf.variable_scope('norm_1'): 27 | norm_1 = tf.nn.relu(tf.layers.batch_normalization(inputs, momentum=0.9, epsilon=1e-5, training=train, name='bn')) 28 | 29 | conv_1 = _conv(norm_1, int(numOut / 2), kernel_size=1, strides=1, pad='VALID', name='conv') 30 | with tf.variable_scope('norm_2'): 31 | norm_2 = tf.nn.relu(tf.layers.batch_normalization(conv_1, momentum=0.9, epsilon=1e-5, training=train, name='bn')) 32 | 33 | pad = tf.pad(norm_2, np.array([[0, 0], [1, 1], [1, 1], [0, 0]]), name='pad') 34 | conv_2 = _conv(pad, int(numOut / 2), kernel_size=3, strides=1, pad='VALID', name='conv') 35 | with tf.variable_scope('norm_3'): 36 | norm_3 = tf.nn.relu(tf.layers.batch_normalization(conv_2, momentum=0.9, epsilon=1e-5, training=train, name='bn')) 37 | 38 | conv_3 = _conv(norm_3, int(numOut), kernel_size=1, strides=1, pad='VALID', name='conv') 39 | return conv_3 40 | 41 | 42 | def _skip_layer(inputs, num_out, name='skip_layer'): 43 | with tf.variable_scope(name): 44 | if inputs.get_shape().as_list()[3] == num_out: 45 | return inputs 46 | else: 47 | conv = _conv(inputs, num_out, kernel_size=1, strides=1, name='conv') 48 | return conv 49 | 50 | 51 | def _residual(inputs, num_out, train, name='residual_block'): 52 | with tf.variable_scope(name): 
53 | convb = _conv_block(inputs, num_out, train=train) 54 | skipl = _skip_layer(inputs, num_out) 55 | return tf.add_n([convb, skipl], name='res_block') 56 | 57 | 58 | @wrappy 59 | def nccuc(input_A, input_B, n_filters, padding, training, name): 60 | with tf.variable_scope("layer{}".format(name)): 61 | for i, F in enumerate(n_filters): 62 | if i < 1: 63 | x0 = input_A 64 | x1 = tf.layers.conv2d(x0, F, (4, 4), strides=(1, 1), activation=None, padding=padding, 65 | kernel_regularizer=tf.contrib.layers.l2_regularizer(0.1), name="conv_{}".format(i + 1)) 66 | x1 = tf.layers.batch_normalization( 67 | x1, training=training, name="bn_{}".format(i + 1)) 68 | x1 = tf.nn.relu(x1, name="relu{}_{}".format(name, i + 1)) 69 | 70 | elif i == 1: 71 | up_conv = tf.layers.conv2d_transpose(x1, filters=F, kernel_size=4, strides=2, padding=padding, 72 | kernel_regularizer=tf.contrib.layers.l2_regularizer(0.1), name="upsample_{}".format(name)) 73 | 74 | up_conv = tf.nn.relu(up_conv, name="relu{}_{}".format(name, i + 1)) 75 | return tf.concat([up_conv, input_B], axis=-1, name="concat_{}".format(name)) 76 | 77 | else: 78 | return x1 -------------------------------------------------------------------------------- /dataloading.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import glob 3 | import matplotlib.pyplot as plt 4 | from PIL import Image 5 | import os 6 | import random 7 | import numpy as np 8 | 9 | 10 | 11 | 12 | def preprocess_image(image): 13 | image = tf.image.decode_jpeg(image, channels=3) 14 | image = tf.cast(image, dtype=tf.float32) 15 | image = image / 255. 
16 | return image 17 | 18 | 19 | def load_and_preprocess_image(path): 20 | image = tf.io.read_file(path) 21 | return preprocess_image(image) 22 | 23 | 24 | def chunks(l, n): 25 | """Yield successive n-sized chunks from l.""" 26 | for i in range(0, len(l), n): 27 | elem = l[i:i + n] 28 | random.shuffle(elem) 29 | yield elem 30 | 31 | 32 | def load_train_human3m(arg, path='../datasets/human3/train/'): 33 | """ 34 | creates tf.dataset from the human3m video dataset with the following folder structure on disk: 35 | human3 36 | train 37 | S1 38 | directions <- first video 39 | 0.jpg 40 | 1.jpg 41 | 2.jpg 42 | ... 43 | discussion <- second video 44 | 0.jpg 45 | 1.jpg 46 | 2.jpg 47 | ... 48 | ... 49 | S5 50 | directions 51 | 0.jpg 52 | 1.jpg 53 | 2.jpg 54 | ... 55 | discussion 56 | 0.jpg 57 | 1.jpg 58 | 2.jpg 59 | ... 60 | ... 61 | ... 62 | :param path: 63 | :return: 64 | """ 65 | vids = [f for f in glob.glob(path + "*/*", recursive=True)] 66 | frames = [] 67 | for vid in vids: 68 | for chunk in chunks(sorted(glob.glob(vid + "/*.jpg", recursive=True), 69 | key=lambda x: int(x.split('/')[-1].split('.jpg')[0])), arg.chunk_size): 70 | if len(chunk) == arg.chunk_size: 71 | random.shuffle(chunk) 72 | frames.append(chunk) 73 | random.shuffle(frames) 74 | frames = np.asarray(frames) 75 | raw_dataset = tf.data.Dataset.from_tensor_slices(frames)\ 76 | .interleave(lambda x: tf.data.Dataset.from_tensor_slices(x) 77 | .shuffle(arg.n_shuffle, reshuffle_each_iteration=True), cycle_length=arg.chunk_size, block_length=2) 78 | return raw_dataset 79 | 80 | 81 | def load_test_human3m(arg, path='../datasets/human3/test/'): 82 | chunk_size = 2 83 | vids = [f for f in glob.glob(path + "*/*", recursive=True)] 84 | frames = [] 85 | for vid in vids: 86 | for chunk in chunks( 87 | sorted(glob.glob(vid + "/*.jpg", recursive=True), key=lambda x: int(x.split('/')[-1].split('.jpg')[0])), 88 | chunk_size): 89 | if len(chunk) == 2: 90 | frames.append(chunk) 91 | frames = np.asarray(frames) 92 | 
raw_dataset = tf.data.Dataset.from_tensor_slices(frames).flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x)) 93 | return raw_dataset 94 | 95 | 96 | def load_train_generic(arg, path='../datasets/generic/train_images/'): 97 | frames = glob.glob(path + "*.jpg", recursive=True) 98 | frames = np.asarray(frames).reshape(-1, 1) 99 | raw_dataset = tf.data.Dataset.from_tensor_slices(frames)\ 100 | .flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x)).shuffle(arg.n_shuffle, reshuffle_each_iteration=True) 101 | return raw_dataset 102 | 103 | 104 | def load_test_generic(arg, path='../datasets/generic/test_images/'): 105 | frames = glob.glob(path + "*.jpg", recursive=True) 106 | frames = np.asarray(frames).reshape(-1, 1) 107 | raw_dataset = tf.data.Dataset.from_tensor_slices(frames)\ 108 | .flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x)) 109 | return raw_dataset 110 | 111 | 112 | dataset_map_train = {'generic': load_train_generic, 'human3m': load_train_human3m} 113 | dataset_map_test = {'generic': load_test_generic, 'human3m': load_test_human3m} 114 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataloading import load_and_preprocess_image, dataset_map_train, dataset_map_test 3 | from transformations import tps_parameters 4 | from dotmap import DotMap 5 | import numpy as np 6 | from config import parse_args, write_hyperparameters 7 | from model import Model 8 | from utils import save_python_files, transformation_parameters, find_ckpt, batch_colour_map, save, initialize_uninitialized 9 | import tensorflow as tf 10 | 11 | 12 | def main(arg): 13 | os.environ["CUDA_VISIBLE_DEVICES"] = str(arg.gpu) 14 | model_save_dir = "../experiments/" + arg.name + "/" 15 | 16 | with tf.variable_scope("Data_prep"): 17 | if arg.mode == 'train': 18 | raw_dataset = dataset_map_train[arg.dataset](arg) 19 | 20 | elif arg.mode == 'predict': 
21 | raw_dataset = dataset_map_test[arg.dataset](arg) 22 | 23 | dataset = raw_dataset.map(load_and_preprocess_image, num_parallel_calls=arg.data_parallel_calls) 24 | dataset = dataset.batch(arg['bn'], drop_remainder=True).repeat(arg.epochs) 25 | iterator = dataset.make_one_shot_iterator() 26 | next_element = iterator.get_next() 27 | b_images = next_element 28 | 29 | orig_images = tf.tile(b_images, [2, 1, 1, 1]) 30 | 31 | scal = tf.placeholder(dtype=tf.float32, shape=(), name='scal_placeholder') 32 | tps_scal = tf.placeholder(dtype=tf.float32, shape=(), name='tps_placeholder') 33 | rot_scal = tf.placeholder(dtype=tf.float32, shape=(), name='rot_scal_placeholder') 34 | off_scal = tf.placeholder(dtype=tf.float32, shape=(), name='off_scal_placeholder') 35 | scal_var = tf.placeholder(dtype=tf.float32, shape=(), name='scal_var_placeholder') 36 | augm_scal = tf.placeholder(dtype=tf.float32, shape=(), name='augm_scal_placeholder') 37 | 38 | tps_param_dic = tps_parameters(2 * arg.bn, scal, tps_scal, rot_scal, off_scal, scal_var) 39 | tps_param_dic.augm_scal = augm_scal 40 | 41 | ctr = 0 42 | config = tf.ConfigProto() 43 | config.gpu_options.allow_growth = True 44 | config.gpu_options.per_process_gpu_memory_fraction = 0.95 45 | with tf.Session(config=config) as sess: 46 | 47 | model = Model(orig_images, arg, tps_param_dic) 48 | tvar = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 49 | saver = tf.train.Saver(var_list=tvar) 50 | merged = tf.summary.merge_all() 51 | 52 | if arg.mode == 'train': 53 | if arg.load: 54 | ckpt, ctr = find_ckpt(model_save_dir + 'saved_model/') 55 | saver.restore(sess, ckpt) 56 | else: 57 | save_python_files(save_dir=model_save_dir + 'bin/') 58 | write_hyperparameters(arg.toDict(), model_save_dir) 59 | sess.run(tf.global_variables_initializer()) 60 | 61 | writer = tf.summary.FileWriter("../summaries/" + arg.name, graph=sess.graph) 62 | 63 | elif arg.mode == 'predict': 64 | ckpt, ctr = find_ckpt(model_save_dir + 'saved_model/') 65 | 
saver.restore(sess, ckpt) 66 | 67 | initialize_uninitialized(sess) 68 | while True: 69 | try: 70 | feed = transformation_parameters(arg, ctr, no_transform=(arg.mode == 'predict')) # no transform if arg.visualize 71 | trf = {scal: feed.scal, tps_scal: feed.tps_scal, 72 | scal_var: feed.scal_var, rot_scal: feed.rot_scal, off_scal: feed.off_scal, augm_scal: feed.augm_scal} 73 | ctr += 1 74 | if arg.mode == 'train': 75 | if np.mod(ctr, arg.summary_interval) == 0: 76 | merged_summary = sess.run(merged, feed_dict=trf) 77 | writer.add_summary(merged_summary, global_step=ctr) 78 | 79 | _, loss = sess.run([model.optimize, model.loss], feed_dict=trf) 80 | if np.mod(ctr, arg.save_interval) == 0: 81 | saver.save(sess, model_save_dir + '/saved_model/' + 'save_net.ckpt', global_step=ctr) 82 | 83 | elif arg.mode == 'predict': 84 | img, img_rec, mu, heat_raw = sess.run([model.image_in, model.reconstruct_same_id, model.mu, 85 | batch_colour_map(model.part_maps)], feed_dict=trf) 86 | 87 | save(img, mu, ctr) 88 | 89 | except tf.errors.OutOfRangeError: 90 | print("End of training.") 91 | break 92 | 93 | 94 | if __name__ == '__main__': 95 | arg = DotMap(vars(parse_args())) 96 | if arg.decoder == 'standard': 97 | if arg.reconstr_dim == 256: 98 | arg.rec_stages = [[256, 256], [128, 128], [64, 64], [32, 32], [16, 16], [8, 8], [4, 4]] 99 | arg.feat_slices = [[0, 0], [0, 0], [0, 0], [0, 0], [4, arg.n_parts], [2, 4], [0, 2]] 100 | arg.part_depths = [arg.n_parts, arg.n_parts, arg.n_parts, arg.n_parts, arg.n_parts, 4, 2] 101 | 102 | if arg.reconstr_dim == 128: 103 | arg.rec_stages = [[128, 128], [64, 64], [32, 32], [16, 16], [8, 8], [4, 4]] 104 | arg.feat_slices = [[0, 0], [0, 0], [0, 0], [4, arg.n_parts], [2, 4], [0, 2]] 105 | arg.part_depths = [arg.n_parts, arg.n_parts, arg.n_parts, arg.n_parts, 4, 2] 106 | main(arg) 107 | -------------------------------------------------------------------------------- /config.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from dataloading import dataset_map_train 4 | from architectures import decoder_map, encoder_map 5 | 6 | import tensorflow as tf 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('name', type=str, help="name of the experiment") 11 | 12 | # run setting 13 | parser.add_argument('--mode', default='train', choices=['train', 'predict'], required=True) 14 | parser.add_argument('--gpu', type=int, required=True) 15 | parser.add_argument('--load', action='store_true') 16 | 17 | # dataset folder 18 | parser.add_argument('--dataset', choices=dataset_map_train.keys(), required=True) 19 | 20 | # options 21 | parser.add_argument('--covariance', action='store_true') 22 | parser.add_argument('--feat_shape', default=True, type=bool) 23 | parser.add_argument('--L1', action='store_true') 24 | 25 | parser.add_argument('--heat_feat_normalize', default=True, type=bool) 26 | parser.add_argument('--epochs', default=100000, type=int, help="number of epochs") 27 | 28 | # architectures 29 | parser.add_argument('--decoder', default='standard', choices=decoder_map.keys()) 30 | parser.add_argument('--encoder', default='seperate', choices=encoder_map.keys()) 31 | parser.add_argument('--in_dim', default=256, type=int, choices=[128, 256], help="dim of input img 256 or 128") 32 | parser.add_argument('--reconstr_dim', default=256, type=int, choices=[128, 256], help="dim of reconstructed img 256 or 128") 33 | parser.add_argument('--heat_dim', default=64, type=int, choices=[64], help="dim of part_map (fixed)") 34 | 35 | # modes 36 | parser.add_argument('--l_2_scal', default=0.1, type=float, help="scale around part means that is considered for l2") 37 | parser.add_argument('--l_2_threshold', default=0.2, type=float, help="") 38 | parser.add_argument('--L_inv_scal', default=0.8, type=float, help="") 39 | 40 | parser.add_argument('--bn', 
default=32, type=int, help="batchsize if not slim and 2 * batchsize if slim") 41 | parser.add_argument('--n_parts', default=16, type=int, help="number of parts") 42 | parser.add_argument('--n_features', default=64, type=int, help="neurons of feature map layer") 43 | parser.add_argument('--n_c', default=3, type=int) 44 | parser.add_argument('--nFeat_1', default=256, type=int, help="neurons in residual module of part hourglass") 45 | parser.add_argument('--nFeat_2', default=256, type=int, help="neurons in residual module of feature hourglass") 46 | 47 | # loss multiplication constants 48 | parser.add_argument('--lr', default=0.001, type=float, help="learning rate of network") 49 | parser.add_argument('--lr_d', default=0.001, type=float, help="adversarial setting: learning rate discriminator network") 50 | 51 | parser.add_argument('--c_l2', default=1., type=float, help="") 52 | parser.add_argument('--c_trans', default=5., type=float, help="") 53 | parser.add_argument('--c_precision_trans', default=0.1, type=float, help="") 54 | parser.add_argument('--c_t', default=1., type=float, help="") 55 | 56 | # tps parameters 57 | parser.add_argument('--schedule_scale', default=100000, type=int, help="") 58 | parser.add_argument('--scal', default=[0.8], type=float, nargs='+', help="default 0.6 sensible shedule [0.6, 0.6]") 59 | parser.add_argument('--tps_scal', default=[0.05], type=float, nargs='+', help="sensible shedule [0.01, 0.08]") 60 | parser.add_argument('--rot_scal', default=[0.1], type=float, nargs='+', help="sensible shedule [0.05, 0.6]") 61 | parser.add_argument('--off_scal', default=[0.15], type=float, nargs='+', help="sensible shedule [0.05, 0.15]") 62 | parser.add_argument('--scal_var', default=[0.05], type=float, nargs='+', help="sensible shedule [0.05, 0.2]") 63 | parser.add_argument('--augm_scal', default=[1.], type=float, nargs='+', help="sensible shedule [0.0, 1.]") 64 | 65 | #appearance parameters 66 | parser.add_argument('--contrast_var', default=0.5, 
type=float, help="contrast variation") 67 | parser.add_argument('--brightness_var', default=0.3, type=float, help="brightness variation") 68 | parser.add_argument('--saturation_var', default=0.1, type=float, help="saturation variation") 69 | parser.add_argument('--hue_var', default=0.3, type=float, help="hue variation") 70 | parser.add_argument('--p_flip', default=0., type=float, help="contrast variation") 71 | 72 | # adverserial 73 | parser.add_argument('--adverserial', action='store_true') 74 | parser.add_argument('--c_g', default=0.0002, type=float, help="factor weighting adversarial loss generator") 75 | parser.add_argument('--patch_size', default=[49, 49], type=int, nargs=2, help="dim of patch_size") 76 | 77 | parser.add_argument('--print_vars', action='store_true') 78 | parser.add_argument('--save_interval', default=20000, type=int, help="saves model every n gradient steps") 79 | parser.add_argument('--summary_interval', default=500, type=int, help="writes summary every n gradient steps") 80 | 81 | parser.add_argument('--static', action='store_true') # for e.g.birds (inter-species reconstruction too difficult) 82 | parser.add_argument('--chunk_size', default=16, type=int, help="group of consecutive frames from video which are used for shape transformations") 83 | parser.add_argument('--n_shuffle', default=64, type=int, help="n shuffle data") 84 | parser.add_argument('--data_parallel_calls', default=4, type=int, help="number of parallel calls for tf map for preprocessing data") 85 | 86 | arg = parser.parse_args() 87 | return arg 88 | 89 | 90 | def write_hyperparameters(r, save_dir): 91 | filename = save_dir + "config.txt" 92 | with open(filename, "a") as input_file: 93 | for k, v in r.items(): 94 | line = '{}, {}'.format(k, v) 95 | print(line) 96 | print(line, file=input_file) 97 | 98 | 99 | -------------------------------------------------------------------------------- /architectures.py:
-------------------------------------------------------------------------------- 1 | from utils import wrappy 2 | import tensorflow as tf 3 | from architecture_ops import _residual, _conv_bn_relu, _conv, nccuc 4 | from ops import softmax, get_features 5 | 6 | 7 | 8 | @wrappy 9 | def discriminator_patch(image, train): 10 | padding = 'VALID' 11 | x0 = image 12 | x1 = tf.layers.conv2d(x0, 32, 4, strides=1, padding=padding, activation=tf.nn.leaky_relu, name="conv_0") # 46 13 | x1 = tf.layers.batch_normalization(x1, training=train, name='bn_0') 14 | x1 = tf.layers.conv2d(x1, 64, 4, strides=2, padding=padding, activation=tf.nn.leaky_relu, name="conv_1") # 44 15 | x1 = tf.layers.batch_normalization(x1, training=train, name='bn_1') 16 | x2 = tf.layers.conv2d(x1, 128, 4, strides=2, padding=padding, activation=tf.nn.leaky_relu, name="conv_2") # 10 17 | x2 = tf.layers.batch_normalization(x2, training=train, name='bn_2') 18 | x3 = tf.layers.conv2d(x2, 256, 4, strides=2, padding=padding, activation=tf.nn.leaky_relu, name="conv_3") # 4 19 | x3 = tf.layers.batch_normalization(x3, training=train, name='bn_3') 20 | x4 = tf.reshape(x3, shape=[-1, 4 * 4 * 256]) 21 | x4 = tf.layers.dense(x4, 1, name="last_fc") 22 | return tf.nn.sigmoid(x4), x4 23 | 24 | 25 | @wrappy 26 | def decoder(encoding_list, train, reconstr_dim, n_c): 27 | padding = 'SAME' 28 | 29 | input = encoding_list[-1] 30 | conv1 = nccuc(input, encoding_list[-2], [512, 512], padding, train, name=1) # 8 31 | conv2 = nccuc(conv1, encoding_list[-3], [512, 256], padding, train, name=2) # 16 32 | conv3 = nccuc(conv2, encoding_list[-4], [256, 256], padding, train, name=3) # 32 33 | 34 | if reconstr_dim == 128: 35 | conv4 = nccuc(conv3, encoding_list[-5], [256, 128], padding, train, name=4) # 64 36 | conv5 = nccuc(conv4, encoding_list[-6], [128, 64], padding, train, name=5) # 128 37 | conv6 = tf.layers.conv2d(conv5, n_c, 6, strides=1, padding='SAME', activation=tf.nn.sigmoid, name="conv_6") 38 | reconstruction = conv6 # 128 39 | 
40 | if reconstr_dim == 256: 41 | conv4 = nccuc(conv3, encoding_list[-5], [256, 128], padding, train, name=4) # 64 42 | conv5 = nccuc(conv4, encoding_list[-6], [128, 128], padding, train, name=5) # 128 43 | conv6 = nccuc(conv5, encoding_list[-7], [128, 64], padding, train, name=6) # 256 44 | conv7 = tf.layers.conv2d(conv6, n_c, 6, strides=1, padding='SAME', activation=tf.nn.sigmoid, name="conv_7") 45 | reconstruction = conv7 # 256 46 | return reconstruction 47 | 48 | 49 | def _hourglass(inputs, n, numOut, train, name='hourglass'): 50 | """ Hourglass Module 51 | Args: 52 | inputs : Input Tensor 53 | n : Number of downsampling step 54 | numOut : Number of Output Features (channels) 55 | name : Name of the block 56 | """ 57 | with tf.variable_scope(name): 58 | # Upper Branch 59 | up_1 = _residual(inputs, numOut, train=train, name='up_1') 60 | # Lower Branch 61 | low_ = tf.contrib.layers.max_pool2d(inputs, [2, 2], [2, 2], padding='VALID') 62 | low_1 = _residual(low_, numOut, train=train, name='low_1') 63 | 64 | if n > 0: 65 | low_2 = _hourglass(low_1, n - 1, numOut, train=train, name='low_2') 66 | else: 67 | low_2 = _residual(low_1, numOut, train=train, name='low_2') 68 | 69 | low_3 = _residual(low_2, numOut, train=train, name='low_3') 70 | up_2 = tf.image.resize_nearest_neighbor(low_3, tf.shape(low_3)[1:3] * 2, name='upsampling') 71 | return tf.add_n([up_2, up_1], name='out_hg') 72 | 73 | 74 | @wrappy 75 | def seperate_hourglass(inputs, train, n_landmark, n_features, nFeat_1, nFeat_2): 76 | _, h, w, c = inputs.get_shape().as_list() 77 | nLow = 4 # hourglass preprocessing reduces by factor two hourglass by factor 16 (2⁴) e.g. 
128 -> 4 78 | n_Low_feat = 1 79 | dropout_rate = 0.2 80 | 81 | # Storage Table 82 | hg = [None] * 2 83 | ll = [None] * 2 84 | ll_ = [None] * 2 85 | drop = [None] * 2 86 | out = [None] * 2 87 | out_ = [None] * 2 88 | sum_ = [None] * 2 89 | 90 | nFeat_1 = nFeat_1 91 | nFeat_2 = nFeat_2 92 | 93 | train = train 94 | 95 | with tf.variable_scope('model'): 96 | with tf.variable_scope('preprocessing'): 97 | if h == 256: 98 | pad1 = tf.pad(inputs, [[0, 0], [2, 2], [2, 2], [0, 0]], name='pad_1') 99 | conv1 = _conv_bn_relu(pad1, filters=64, train=train, kernel_size=6, strides=2, name='conv_256_to_128') 100 | r1 = _residual(conv1, num_out=128, train=train, name='r1') 101 | pool1 = tf.contrib.layers.max_pool2d(r1, [2, 2], [2, 2], padding='VALID') 102 | r2 = _residual(pool1, num_out=int(nFeat_1 / 2), train=train, name='r2') 103 | r3 = _residual(r2, num_out=nFeat_1, train=train, name='r3') 104 | 105 | elif h == 128: 106 | pad1 = tf.pad(inputs, [[0, 0], [2, 2], [2, 2], [0, 0]], name='pad_1') 107 | conv1 = _conv_bn_relu(pad1, filters=64, train=train, kernel_size=6, strides=2, name='conv_64_to_32') 108 | r3 = _residual(conv1, num_out=nFeat_1, train=train, name='r3') 109 | 110 | elif h == 64: 111 | pad1 = tf.pad(inputs, [[0, 0], [3, 2], [3, 2], [0, 0]], name='pad_1') 112 | conv1 = _conv_bn_relu(pad1, filters=64, train=train, kernel_size=6, strides=1, name='conv_64_to_32') 113 | r3 = _residual(conv1, num_out=nFeat_1, train=train, name='r3') 114 | 115 | else: 116 | raise ValueError 117 | 118 | with tf.variable_scope('stage_0'): 119 | hg[0] = _hourglass(r3, nLow, nFeat_1, train=train, name='hourglass') 120 | drop[0] = tf.layers.dropout(hg[0], rate=dropout_rate, training=train, name='dropout') 121 | ll[0] = _conv_bn_relu(drop[0], nFeat_1, train=train, kernel_size=1, strides=1, pad='VALID', name='conv') 122 | ll_[0] = _conv(ll[0], nFeat_1, 1, 1, 'VALID', 'll') 123 | out[0] = _conv(ll[0], n_landmark, 1, 1, 'VALID', 'out') 124 | out_[0] = _conv(softmax(out[0]), nFeat_1, 1, 1, 'VALID', 
'out_') 125 | sum_[0] = tf.add_n([out_[0], r3], name='merge') 126 | 127 | with tf.variable_scope('stage_1'): 128 | hg[1] = _hourglass(sum_[0], n_Low_feat, nFeat_2, train=train, name='hourglass') 129 | drop[1] = tf.layers.dropout(hg[1], rate=dropout_rate, 130 | training=train, name='dropout') 131 | ll[1] = _conv_bn_relu(drop[1], nFeat_2, train=train, kernel_size=1, strides=1, 132 | pad='VALID', name='conv') 133 | 134 | out[1] = _conv(ll[1], n_features, 1, 1, 'VALID', 'out') 135 | 136 | features = out[1] 137 | return softmax(out[0]), features 138 | 139 | 140 | decoder_map = {'standard': decoder} 141 | encoder_map = {'seperate': seperate_hourglass} 142 | 143 | 144 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import tensorflow as tf 3 | import numpy as np 4 | import os 5 | import matplotlib 6 | 7 | 8 | matplotlib.use('agg') 9 | import matplotlib.pyplot as plt 10 | from PIL import Image 11 | from matplotlib import cm 12 | import glob 13 | from shutil import copyfile 14 | from dotmap import DotMap 15 | 16 | 17 | 18 | 19 | def wrappy(func): 20 | def wrapped(*args, **kwargs): 21 | with tf.variable_scope(func.__name__): 22 | return func(*args, **kwargs) 23 | 24 | return wrapped 25 | 26 | 27 | def doublewrap(function): 28 | """ 29 | A decorator decorator, allowing to use the decorator to be used without 30 | parentheses if not arguments are provided. All arguments must be optional. 31 | """ 32 | 33 | @functools.wraps(function) 34 | def decorator(*args, **kwargs): 35 | if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): 36 | return function(args[0]) 37 | else: 38 | return lambda wrapee: function(wrapee, *args, **kwargs) 39 | 40 | return decorator 41 | 42 | 43 | @doublewrap 44 | def define_scope(function, *args, **kwargs): 45 | """ 46 | A decorator for functions that define TensorFlow operations. 
The wrapped 47 | function will only be executed once. Subsequent calls to it will directly 48 | return the result so that operations are added to the graph only once. 49 | 50 | The operations added by the function live within a tf.variable_scope(). If 51 | this decorator is used with arguments, they will be forwarded to the 52 | variable scope. The scope name defaults to the name of the wrapped 53 | function. 54 | """ 55 | attribute = '_cache_' + function.__name__ 56 | name = function.__name__ 57 | 58 | @property 59 | @functools.wraps(function) 60 | def decorator(self): 61 | if not hasattr(self, attribute): 62 | with tf.variable_scope(name): 63 | setattr(self, attribute, function(self, *args, **kwargs)) 64 | return getattr(self, attribute) 65 | 66 | return decorator 67 | 68 | 69 | def initialize_uninitialized(sess): 70 | global_vars = tf.global_variables() 71 | is_not_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars]) 72 | not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f] 73 | 74 | print([str(i.name) for i in not_initialized_vars]) # only for testing 75 | if len(not_initialized_vars): 76 | sess.run(tf.variables_initializer(not_initialized_vars)) 77 | 78 | 79 | def probabilistic_switch(a, b, counter, scale=10000): 80 | """ 81 | :param a: 82 | :param b: 83 | :param counter: 84 | :param scale: corresponds to decay rate 85 | :return: at counter 0 a is returned with p=1. 
the probability decays 86 | asymptotically to p 0.5 with increasing counter values 87 | """ 88 | p = counter / (2 * counter + scale) 89 | r = np.random.choice([a, b], p=[1 - p, p]) 90 | return r 91 | 92 | 93 | def evolve_a_to_b(min_max, time): 94 | if len(min_max) == 1: 95 | evolve = min_max[0] 96 | elif len(min_max) == 2: 97 | evolve = time * min_max[1] + (1 - time) * min_max[0] 98 | 99 | return evolve 100 | 101 | 102 | def transformation_parameters(arg=None, ctr=None, no_transform=False): 103 | """ 104 | if no transform: arg.scal is still used 105 | default for penn {'scal': 0.5, 'tps_scal': 0.05, 'rot_scal': 0.3, 'off_scal': 0.15, 'scal_var': 0.1} 106 | :param t: 107 | :param range: 108 | :return: 109 | """ 110 | trf_arg = {} 111 | 112 | if no_transform: 113 | trf_arg['scal'] = arg.scal[0] 114 | trf_arg['tps_scal'] = 0. 115 | trf_arg['rot_scal'] = 0. 116 | trf_arg['off_scal'] = 0. 117 | trf_arg['scal_var'] = 0. 118 | trf_arg['augm_scal'] = 0. 119 | 120 | else: 121 | time = min(ctr / arg.schedule_scale, 1.) 
122 | trf_arg['scal'] = evolve_a_to_b(arg.scal, time) 123 | trf_arg['tps_scal'] = evolve_a_to_b(arg.tps_scal, time) 124 | trf_arg['rot_scal'] = evolve_a_to_b(arg.rot_scal, time) 125 | trf_arg['off_scal'] = evolve_a_to_b(arg.off_scal, time) 126 | trf_arg['scal_var'] = evolve_a_to_b(arg.scal_var, time) 127 | trf_arg['augm_scal'] = evolve_a_to_b(arg.augm_scal, time) 128 | 129 | dotty = DotMap(trf_arg) 130 | return dotty 131 | 132 | 133 | def batch_colour_map(heat_map): 134 | c = heat_map.get_shape().as_list()[-1] 135 | colour = [] 136 | for i in range(c): 137 | colour.append(cm.hsv(float(i / c))[:3]) 138 | colour = tf.constant(colour) 139 | colour_map = tf.einsum('bijk,kl->bijl', heat_map, colour) 140 | return colour_map 141 | 142 | 143 | def np_batch_colour_map(heat_map): 144 | c = heat_map.shape[-1] 145 | colour = [] 146 | for i in range(c): 147 | colour.append(cm.hsv(float(i / c))[:3]) 148 | np_colour = np.array(colour) 149 | colour_map = np.einsum('bijk,kl->bijl', heat_map, np_colour) 150 | return colour_map 151 | 152 | 153 | def identify_parts(image, raw, n_parts, version): 154 | image_base = np.array(Image.fromarray(image[0]).resize((64, 64))) / 255. 155 | base = image_base[:, :, 0] + image_base[ :, :, 1] + image_base[:, :, 2] 156 | directory = os.path.join('../images/' + str(version) + "/identify/") 157 | if not os.path.exists(directory): 158 | os.makedirs(directory) 159 | for i in range(n_parts): 160 | print("hep") 161 | plt.imshow(raw[0, :, :, i] + 0.02 * base, cmap='gray') 162 | fname = directory + str(i) + '.png' 163 | plt.savefig(fname, bbox_inches='tight') 164 | 165 | 166 | def save(img, mu, counter): 167 | batch_size, out_shape = img.shape[0], img.shape[1:3] 168 | marker_list = ["o", "v", "s", "|", "_"] 169 | directory = os.path.join('../images/landmarks/') 170 | if not os.path.exists(directory): 171 | os.makedirs(directory) 172 | s = out_shape[0] // 8 173 | n_parts = mu.shape[-2] 174 | mu_img = (mu + 1.) / 2. 
* np.array(out_shape)[0] 175 | steps = batch_size 176 | step_size = 1 177 | 178 | for i in range(0, steps, step_size): 179 | plt.imshow(img[i]) 180 | for j in range(n_parts): 181 | plt.scatter(mu_img[i, j, 1], mu_img[i, j, 0], s=s, marker=marker_list[np.mod(j, len(marker_list))], color=cm.hsv(float(j / n_parts))) 182 | 183 | plt.axis('off') 184 | fname = directory + str(counter) + '_' + str(i) + '.png' 185 | plt.savefig(fname, bbox_inches='tight') 186 | plt.close() 187 | 188 | @wrappy 189 | def tf_summary_feat_and_parts(encoding_list, part_depths, visualize_features=False, square=True): 190 | for n, enc in enumerate(encoding_list): 191 | part_maps, feat_maps = enc[:, :, :, :part_depths[n]], enc[:, :, :, part_depths[n]:] 192 | if square: 193 | part_maps = part_maps ** 2 194 | color_part_map = batch_colour_map(part_maps) 195 | with tf.variable_scope("parts"): 196 | tf.summary.image(name="parts" + str(n), tensor=color_part_map, max_outputs=4) 197 | 198 | if visualize_features: 199 | if feat_maps.get_shape().as_list()[-1] > 0: 200 | with tf.variable_scope("feature_maps"): 201 | if square: 202 | feat_maps = feat_maps ** 2 203 | color_feat_map = batch_colour_map( 204 | feat_maps / tf.reduce_sum(feat_maps, axis=[1, 2], keepdims=True)) 205 | tf.summary.image(name="feat_maps" + str(n), tensor=color_feat_map ** 2, max_outputs=4) 206 | 207 | 208 | @wrappy 209 | def part_to_color_map(encoding_list, part_depths, size, square=True, ): 210 | part_maps = encoding_list[0][:, :, :, :part_depths[0]] 211 | if square: 212 | part_maps = part_maps ** 4 213 | color_part_map = batch_colour_map(part_maps) 214 | color_part_map = tf.image.resize_images(color_part_map, size=(size, size)) 215 | 216 | return color_part_map 217 | 218 | 219 | 220 | def save_python_files(save_dir): 221 | assert (not os.path.exists(save_dir)) 222 | os.makedirs(save_dir) 223 | for file in glob.glob("*.py"): 224 | copyfile(src=file, dst=save_dir + file) 225 | 226 | 227 | def find_ckpt(dir): 228 | filename = dir + 
def convert_image_np(inp):
    """
    Turn a channel-first tensor into a displayable numpy image.

    :param inp: object exposing ``.numpy()`` that yields a rank-3 array; axes are
        reordered from (0, 1, 2) to (1, 2, 0)
    :return: numpy array with the first axis moved to the end and values clipped to [0, 1]
    """
    channel_last = np.transpose(inp.numpy(), (1, 2, 0))
    return np.clip(channel_last, 0, 1)
def make_input_tps_param(tps_param, move_point=None, scal_point=None):
    """
    Build control points and their displacement vectors from sampled TPS parameters.

    The displaced control points are scaled about ``offset``, rotated about ``offset_2``
    and expressed relative to the undisplaced control points.

    :param tps_param: object with attributes ``coord``, ``vector``, ``offset``,
        ``offset_2``, ``rot_mat`` and ``t_scal``
    :param move_point: optional shift added to the control points before rescaling
    :param scal_point: optional per-sample scale, rank 2 (einsum index ``bk``)
    :return: tuple ``(coord, t_vector)``
    :raises AssertionError: if only one of ``move_point`` / ``scal_point`` is given
    """
    control_pts = tps_param.coord
    pivot_scale = tps_param.offset
    pivot_rot = tps_param.offset_2

    displaced = control_pts + tps_param.vector - pivot_scale
    scaled_coord = tf.einsum('bk,bck->bck', tps_param.t_scal, displaced) + pivot_scale
    rotated = tf.einsum('blk,bck->bcl', tps_param.rot_mat, scaled_coord - pivot_rot) + pivot_rot
    t_vector = rotated - control_pts

    crop_requested = move_point is not None and scal_point is not None
    if not crop_requested:
        # either both crop arguments are supplied or none of them
        assert (move_point is None and scal_point is None)
        return control_pts, t_vector

    cropped_pts = tf.einsum('bk,bck->bck', scal_point, control_pts + move_point)
    cropped_vec = tf.einsum('bk,bck->bck', scal_point, t_vector)
    return cropped_pts, cropped_vec
# code adapted from https://github.com/iwyoo/tf_ThinPlateSpline
def ThinPlateSpline(U, coord, vector, out_size, n_c, move=None, scal=None):
    """
    Warp a batch of images with a thin plate spline defined by control points.

    :param U: image batch; axes 0, 1, 2 are read as batch, height, width
    :param coord: control points, rank 3 (batch, point, 2); the last axis is reversed on
        entry, presumably converting (y, x) to (x, y) - TODO confirm against caller
    :param vector: displacement of every control point, same layout as ``coord``
    :param out_size: side length of the square output grid
    :param n_c: number of image channels
    :param move: optional offset added to the sampling positions, indexed
        ``move[:, :, 0]`` (y) and ``move[:, :, 1]`` (x)
    :param scal: optional scale of the sampling positions, indexed ``scal[:, 0]`` (y)
        and ``scal[:, 1]`` (x)
    :return: tuple ``(output, t_arr)``: the bilinearly resampled images of shape
        [bn, out_size, out_size, n_c] and the sampling positions (y, x) of shape
        [bn, out_size, out_size, 2]
    """

    coord = coord[:, :, ::-1]
    vector = vector[:, :, ::-1]

    num_batch = tf.shape(U)[0]
    height = tf.shape(U)[1]
    width = tf.shape(U)[2]
    channels = n_c
    out_height = out_size
    out_width = out_size
    height_f = tf.cast(height, 'float32')
    width_f = tf.cast(width, 'float32')
    num_point = tf.shape(coord)[1]

    def _repeat(x, n_repeats):
        # repeat every entry of x n_repeats times: outer product with a row of ones, then flatten
        rep = tf.transpose(tf.expand_dims(tf.ones(shape=tf.stack([n_repeats, ])), 1), [1, 0])
        rep = tf.cast(rep, 'int32')
        x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
        return tf.reshape(x, [-1])

    def _interpolate(im, y, x):
        # bilinear lookup of im at the (y, x) positions, returned flat as [n_positions, channels]
        # constants
        y = tf.cast(y, 'float32')
        x = tf.cast(x, 'float32')

        zero = tf.zeros([], dtype='int32')
        max_y = tf.cast(height - 1, 'int32')
        max_x = tf.cast(width - 1, 'int32')

        # scale indices from approx. [-1, 1] to [0, width/height]

        y = (y + 1) * height_f / 2.0
        x = (x + 1) * width_f / 2.0

        y = tf.reshape(y, [-1])
        x = tf.reshape(x, [-1])

        # do sampling: the four integer neighbours of every position
        y0 = tf.cast(tf.floor(y), 'int32')
        y1 = y0 + 1
        x0 = tf.cast(tf.floor(x), 'int32')
        x1 = x0 + 1

        # positions outside the image fall back to the border pixels
        y0 = tf.clip_by_value(y0, zero, max_y)
        y1 = tf.clip_by_value(y1, zero, max_y)
        x0 = tf.clip_by_value(x0, zero, max_x)
        x1 = tf.clip_by_value(x1, zero, max_x)

        # flat index of the first pixel of each batch element, repeated once per output position
        base = _repeat(tf.range(num_batch) * width * height, out_height * out_width)
        base_y0 = base + y0 * width
        base_y1 = base + y1 * width
        idx_a = base_y0 + x0
        idx_b = base_y1 + x0
        idx_c = base_y0 + x1
        idx_d = base_y1 + x1

        # use indices to lookup pixels in the flat image and restore
        # channels dim
        im_flat = tf.reshape(im, [-1, channels])
        im_flat = tf.cast(im_flat, 'float32')
        Ia = tf.gather(im_flat, idx_a)
        Ib = tf.gather(im_flat, idx_b)
        Ic = tf.gather(im_flat, idx_c)
        Id = tf.gather(im_flat, idx_d)

        # and finally calculate interpolated values
        # (weights are taken from the clipped neighbours)
        x0_f = tf.cast(x0, 'float32')
        x1_f = tf.cast(x1, 'float32')
        y0_f = tf.cast(y0, 'float32')
        y1_f = tf.cast(y1, 'float32')
        wa = tf.expand_dims(((x1_f - x) * (y1_f - y)), 1)
        wb = tf.expand_dims(((x1_f - x) * (y - y0_f)), 1)
        wc = tf.expand_dims(((x - x0_f) * (y1_f - y)), 1)
        wd = tf.expand_dims(((x - x0_f) * (y - y0_f)), 1)
        output = tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])
        return output

    def _meshgrid(height, width, coord):
        # regular grid on [-1, 1] x [-1, 1] stacked with the radial basis value towards every control point

        x_t = tf.tile(tf.reshape(tf.linspace(- 1., 1., width), [1, width]), [height, 1])
        y_t = tf.tile(tf.reshape(tf.linspace(- 1., 1., height), [height, 1]), [1, width])

        x_t_flat = tf.reshape(x_t, (1, 1, -1))
        y_t_flat = tf.reshape(y_t, (1, 1, -1))

        px = tf.expand_dims(coord[:, :, 0], 2)  # [bn, pn, 1]
        py = tf.expand_dims(coord[:, :, 1], 2)  # [bn, pn, 1]
        d2 = tf.square(x_t_flat - px) + tf.square(y_t_flat - py)
        # kernel d2 * log(d2); the small constant keeps the log finite at the control points
        r = d2 * tf.log(d2 + 1e-6)  # [bn, pn, h*w]
        x_t_flat_g = tf.tile(x_t_flat, [num_batch, 1, 1])  # [bn, 1, h*w]
        y_t_flat_g = tf.tile(y_t_flat, [num_batch, 1, 1])  # [bn, 1, h*w]
        ones = tf.ones_like(x_t_flat_g)  # [bn, 1, h*w]

        grid = tf.concat([ones, x_t_flat_g, y_t_flat_g, r], 1)  # [bn, 3+pn, h*w]
        return grid

    def _transform(T, coord, move, scal):
        # apply the solved spline coefficients T to the output grid, giving the sampling positions
        # grid of (1, x_t, y_t, r_1, ..., r_n)
        grid = _meshgrid(out_height, out_width, coord)  # [bn, 3+pn, h*w]

        # transform A x (1, x_t, y_t, r1, r2, ..., rn) -> (x_s, y_s)
        # [bn, 2, pn+3] x [bn, pn+3, h*w] -> [bn, 2, h*w]
        T_g = tf.matmul(T, grid)  # [bn, 2, h*w]
        x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
        y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])

        if move is not None and scal is not None:
            # optional affine correction of the sampling positions
            off_y = tf.expand_dims(move[:, :, 0], axis=-1)
            off_x = tf.expand_dims(move[:, :, 1], axis=-1)
            scal_y = tf.expand_dims(tf.expand_dims(scal[:, 0], axis=-1), axis=-1)
            scal_x = tf.expand_dims(tf.expand_dims(scal[:, 1], axis=-1), axis=-1)
            y = (y_s * scal_y + off_y)
            x = (x_s * scal_x + off_x)

        else:
            # move and scal have to be given together or not at all
            assert (move is None and scal is None)
            y = y_s
            x = x_s

        return y, x

    def _solve_system(coord, vector):
        # solve the linear TPS system so that every control point maps to coord + vector
        ones = tf.ones([num_batch, num_point, 1], dtype="float32")
        p = tf.concat([ones, coord], 2)  # [bn, pn, 3]

        p_1 = tf.reshape(p, [num_batch, -1, 1, 3])  # [bn, pn, 1, 3]
        p_2 = tf.reshape(p, [num_batch, 1, -1, 3])  # [bn, 1, pn, 3]
        d2 = tf.reduce_sum(tf.square(p_1 - p_2), 3)  # [bn, pn, pn]
        r = d2 * tf.log(d2 + 1e-6)  # Kernel [bn, pn, pn]

        zeros = tf.zeros([num_batch, 3, 3], dtype="float32")
        W_0 = tf.concat([p, r], 2)  # [bn, pn, 3+pn]
        W_1 = tf.concat([zeros, tf.transpose(p, [0, 2, 1])], 2)  # [bn, 3, pn+3]
        W = tf.concat([W_0, W_1], 1)  # [bn, pn+3, pn+3]
        W_inv = tf.matrix_inverse(W)

        # targets of the control points, zero-padded for the three extra constraint rows
        tp = tf.pad(coord + vector,
                    [[0, 0], [0, 3], [0, 0]], "CONSTANT")  # [bn, pn+3, 2]
        T = tf.matmul(W_inv, tp)  # [bn, pn+3, 2]
        T = tf.transpose(T, [0, 2, 1])  # [bn, 2, pn+3]

        return T

    T = _solve_system(coord, vector)
    y, x = _transform(T, coord, move, scal)
    input_transformed = _interpolate(U, y, x)
    output = tf.reshape(input_transformed, [num_batch, out_height, out_width, channels])
    y = tf.reshape(y, [num_batch, out_height, out_width, 1])
    x = tf.reshape(x, [num_batch, out_height, out_width, 1])
    # sampling positions are returned with y first
    t_arr = tf.concat([y, x], axis=-1)
    return output, t_arr
class Model:
    """
    Builds the complete graph: TPS warping, part encoder, image decoder, optional patch
    discriminator, losses, train ops and image summaries.

    Construction has side effects only: all ops and summaries are added to the default
    graph. Results are exposed as attributes (``mu``, ``part_maps``,
    ``reconstruct_same_id``, ...).
    """

    def __init__(self, orig_img, arg, tps_param_dic):
        """
        :param orig_img: batch of input images fed to the thin plate spline warp
        :param arg: hyper-parameter object, read via attribute access (``arg.mode``,
            ``arg.bn``, ``arg.n_parts``, ...)
        :param tps_param_dic: sampled TPS parameters, passed to ``make_input_tps_param``
        """
        self.arg = arg

        self.train = (self.arg.mode == 'train')
        self.tps_par = tps_param_dic
        self.image_orig = orig_img
        self.encoder = encoder_map[arg.encoder]
        self.img_decoder = decoder_map[arg.decoder]
        self.discriminator = discriminator_patch

        # tensors filled in by graph()
        self.image_in, self.image_rec, self.transform_mesh = None, None, None

        self.mu, self.mu_t, self.stddev_t = None, None, None

        self.volume_mesh, self.features, self.L_inv, self.part_maps = None, None, None, None
        self.integrant = None
        self.encoding_same_id, self.reconstruct_same_id = None, None

        # filled in by loss(); heat_mask_l2 stays None when arg.fold_with_shape is set
        self.heat_mask_l2, self.fold_img_squared = None, None

        # adverserial
        self.adverserial = self.arg.adverserial
        self.t_D, self.t_D_logits = None, None
        self.patches = None

        self.update_ops = None

        self.graph()
        # attribute access is enough here: members decorated with define_scope are read
        # without a call (see ``self.loss`` in optimize), the access itself builds the ops
        self.optimize
        self.visualize()

    def graph(self):
        """Assemble the forward pass and store all intermediate tensors on ``self``."""
        with tf.variable_scope("tps"):
            coord, vector = make_input_tps_param(self.tps_par)
            t_images, t_mesh = ThinPlateSpline(self.image_orig, coord, vector, self.arg.in_dim, self.arg.n_c)
            self.image_in, self.image_rec = prepare_pairs(t_images, self.arg.reconstr_dim, self.arg)
            self.transform_mesh = tf.image.resize_images(t_mesh, size=(self.arg.heat_dim, self.arg.heat_dim))
            self.volume_mesh = AbsDetJacobian(self.transform_mesh)

        with tf.variable_scope("encoding"):
            self.part_maps, raw_features = self.encoder(self.image_in, self.train,
                                                        self.arg.n_parts, self.arg.n_features,
                                                        self.arg.nFeat_1,
                                                        self.arg.nFeat_2)

            self.mu, self.L_inv = part_map_to_mu_L_inv(part_maps=self.part_maps, scal=self.arg.L_inv_scal)
            self.features = get_features(raw_features, self.part_maps, slim=True)

        with tf.variable_scope("transform"):
            # part maps weighted by the volume element of the warp; only the helper axis
            # added by expand_dims is squeezed, so a batch or part axis of size 1 survives
            integrant = tf.squeeze(
                tf.expand_dims(self.part_maps, axis=-1) * tf.expand_dims(self.volume_mesh, axis=-1), axis=-1)
            self.integrant = integrant / tf.reduce_sum(integrant, axis=[1, 2],
                                                       keepdims=True)

            # first and second moment of the warped coordinates under each part map
            self.mu_t = tf.einsum('aijk,aijl->akl', self.integrant, self.transform_mesh)
            transform_mesh_out_prod = tf.einsum('aijm,aijn->aijmn', self.transform_mesh, self.transform_mesh)
            mu_out_prod = tf.einsum('akm,akn->akmn', self.mu_t, self.mu_t)
            self.stddev_t = tf.einsum('aijk,aijmn->akmn', self.integrant, transform_mesh_out_prod) - mu_out_prod

        with tf.variable_scope("generation"):
            with tf.variable_scope("encoding"):
                self.encoding_same_id = feat_mu_to_enc(self.features, self.mu, self.L_inv,
                                                       self.arg.rec_stages, self.arg.part_depths,
                                                       self.arg.feat_slices, n_reverse=2,
                                                       covariance=self.arg.covariance,
                                                       feat_shape=self.arg.average_features_mode,
                                                       heat_feat_normalize=self.arg.heat_feat_normalize,
                                                       static=self.arg.static)

            self.reconstruct_same_id = self.img_decoder(self.encoding_same_id, self.train, self.arg.reconstr_dim,
                                                        self.arg.n_c)

        if self.adverserial:
            with tf.variable_scope("adverserial_on_patches"):
                flatten_dim = 2 * self.arg.bn * self.arg.n_parts
                part_map_last_layer = self.encoding_same_id[0][:, :, :, :self.arg.part_depths[0]]
                real_patches = get_img_slice_around_mu(tf.concat([self.image_rec, part_map_last_layer], axis=-1),
                                                       self.mu, self.arg.patch_size)
                real_patches = tf.reshape(real_patches,
                                          shape=[flatten_dim, self.arg.patch_size[0], self.arg.patch_size[1], -1])
                fake_patches_same_id = get_img_slice_around_mu(tf.concat(
                    [self.reconstruct_same_id, part_map_last_layer], axis=-1), self.mu, self.arg.patch_size)
                fake_patches_same_id = tf.reshape(fake_patches_same_id, shape=[flatten_dim, self.arg.patch_size[0],
                                                                               self.arg.patch_size[1], -1])
                # real patches first, fake patches second; loss() splits at flatten_dim
                self.patches = tf.concat([real_patches, fake_patches_same_id], axis=0)
                self.t_D, self.t_D_logits = self.discriminator(self.patches, train=self.train)

        self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

    @define_scope
    def loss(self):
        """
        :return: tuple ``(transform_loss, precision_loss, l2_loss, d_loss, g_loss)``;
            the last two are constant zero without the adversarial branch
        """
        # the two halves of the batch hold the two differently transformed views
        mu_t_1, mu_t_2 = self.mu_t[:self.arg.bn], self.mu_t[self.arg.bn:]
        stddev_t_1, stddev_t_2 = self.stddev_t[:self.arg.bn], self.stddev_t[self.arg.bn:]
        transform_loss = tf.reduce_mean((mu_t_1 - mu_t_2) ** 2)

        precision_sq = (stddev_t_1 - stddev_t_2) ** 2

        # eps keeps the gradient of sqrt finite at zero
        eps = 1e-6
        precision_loss = tf.reduce_mean(tf.sqrt(tf.reduce_sum(precision_sq, axis=[2, 3]) + eps))

        img_difference = self.reconstruct_same_id - self.image_rec

        if self.arg.L1:
            distance_metric = tf.abs(img_difference)

        else:
            distance_metric = img_difference ** 2

        if self.arg.fold_with_shape:
            fold_img_squared = fold_img_with_L_inv(
                distance_metric, tf.stop_gradient(self.mu), tf.stop_gradient(self.L_inv), self.arg.l_2_scal,
                visualize=False, threshold=self.arg.l_2_threshold, normalize=True)
        else:
            fold_img_squared, self.heat_mask_l2 = fold_img_with_mu(distance_metric, self.mu, self.arg.l_2_scal,
                                                                   visualize=False,
                                                                   threshold=self.arg.l_2_threshold,
                                                                   normalize=True)

        self.fold_img_squared = fold_img_squared
        tf.summary.image(name="l2_loss", tensor=fold_img_squared, max_outputs=4, family="reconstr")
        l2_loss = tf.reduce_mean(tf.reduce_sum(fold_img_squared, axis=[1, 2]))

        if self.adverserial:
            flatten_dim = 2 * self.arg.bn * self.arg.n_parts
            # first flatten_dim rows belong to real patches, the rest to reconstructed ones
            D, D_ = self.t_D[:flatten_dim], self.t_D[flatten_dim:]
            D_logits, D_logits_ = self.t_D_logits[:flatten_dim], self.t_D_logits[flatten_dim:]

            d_loss_real = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logits, labels=tf.ones_like(D)))
            d_loss_fake = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logits_, labels=tf.zeros_like(D_)))
            d_loss = d_loss_real + d_loss_fake
            g_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logits_, labels=tf.ones_like(D_)))

        else:
            d_loss, g_loss = tf.constant(0.), tf.constant(0.)

        return transform_loss, precision_loss, l2_loss, d_loss, g_loss

    @define_scope
    def optimize(self):
        """
        :return: the generator train op, or the tuple ``(generator_op, discriminator_op)``
            when the adversarial branch is active
        """
        transform_loss, precision_loss, l2_loss, d_loss, g_loss = self.loss

        total_loss = self.arg.c_l2 * l2_loss + self.arg.c_trans * transform_loss \
            + self.arg.c_precision_trans * precision_loss + self.arg.c_g * g_loss

        tf.summary.scalar(name="total_loss", tensor=total_loss)
        tf.summary.scalar(name="l2", tensor=l2_loss)
        tf.summary.scalar(name="transform_loss", tensor=transform_loss)
        tf.summary.scalar(name="precision_loss", tensor=precision_loss)
        if self.adverserial:
            tf.summary.scalar(name="g_loss", tensor=g_loss)
            tf.summary.scalar(name="d_loss", tensor=d_loss)

        # discriminator variables are recognised by name and trained by their own optimizer
        tvar = tf.trainable_variables()
        adverserial_vars = [var for var in tvar if 'discriminator' in var.name]
        rest_vars = [var for var in tvar if 'discriminator' not in var.name]

        if self.arg.print_vars:
            if self.adverserial:
                print("adverserial_vars")
                for var in adverserial_vars:
                    print(var)

            print("normal_vars")
            for var in rest_vars:
                print(var)

        # the minimize ops are created inside the context so that every train step also
        # runs the collected update ops
        with tf.control_dependencies(self.update_ops):
            optimizer = tf.train.AdamOptimizer(learning_rate=self.arg.lr)

            optimizer_d = tf.train.AdamOptimizer(learning_rate=self.arg.lr_d)

            if self.adverserial:
                return optimizer.minimize(total_loss, var_list=rest_vars), \
                    optimizer_d.minimize(d_loss, var_list=adverserial_vars)
            else:
                return optimizer.minimize(total_loss, var_list=rest_vars)

    def visualize(self):
        """Register the image summaries (inputs, part maps, patches, reconstructions)."""
        tf.summary.image(name="g_reconstr", tensor=self.image_rec, max_outputs=4, family="reconstr")

        normal = part_to_color_map(self.encoding_same_id, self.arg.part_depths, size=self.arg.in_dim)
        normal = normal / (1 + tf.reduce_sum(normal, axis=-1, keepdims=True))
        # show the colour coded parts where they are strong enough, the input image elsewhere
        vis_normal = tf.where(tf.tile(tf.reduce_sum(normal, axis=-1, keepdims=True), [1, 1, 1, 3]) > 0.3, normal,
                              self.image_in)
        if self.heat_mask_l2 is not None:
            # dim everything outside the region that contributes to the reconstruction loss;
            # the mask only exists when loss() used fold_img_with_mu
            heat_mask_l2 = tf.image.resize_images(tf.tile(self.heat_mask_l2, [1, 1, 1, 3]),
                                                  size=(self.arg.in_dim, self.arg.in_dim))
            vis_normal = tf.where(heat_mask_l2 > self.arg.l_2_threshold, vis_normal, 0.3 * vis_normal)
        tf.summary.image(name="gt_t_1", tensor=vis_normal[:self.arg.bn], max_outputs=4, family="t_1")
        tf.summary.image(name="gt_t_2", tensor=vis_normal[self.arg.bn:], max_outputs=4, family="t_2")
        tf.summary.image(name="part_maps", tensor=batch_colour_map(self.part_maps[:self.arg.bn]), max_outputs=4,
                         family="t_1")
        tf.summary.image(name="part_maps", tensor=batch_colour_map(self.part_maps[self.arg.bn:]), max_outputs=4,
                         family="t_2")
        # tf.summary.image(name="VolumeElement", tensor=self.volume_mesh, max_outputs=4, family="Volume")

        if self.adverserial:
            f_dim = 2 * self.arg.bn * self.arg.n_parts
            with tf.variable_scope("patch_real"):
                tf.summary.image(name="patch_real",
                                 tensor=self.patches[:f_dim, :, :, :self.arg.n_c], max_outputs=4)
            with tf.variable_scope("patch_fake"):
                tf.summary.image(name="fake_same",
                                 tensor=self.patches[f_dim: f_dim + f_dim // 2, :, :, :self.arg.n_c],
                                 max_outputs=4)

        with tf.variable_scope("reconstr_same_id"):
            tf.summary.image(name="same_id_reconstruction", tensor=self.reconstruct_same_id, max_outputs=4,
                             family="reconstr")

        with tf.variable_scope("normal"):
            tf_summary_feat_and_parts(self.encoding_same_id, self.arg.part_depths, visualize_features=False)
@wrappy
def Parity(t_images, t_mesh, on=False):
    """
    Randomly mirror samples along axis 2, applying the same decision to image and mesh.

    :param t_images: image batch with a static batch dimension
    :param t_mesh: mesh batch belonging to ``t_images``
    :param on: switch; when false both inputs are passed through untouched
    :return: tuple ``(images, mesh)``
    """
    if not on:
        return t_images, t_mesh

    batch = t_images.get_shape().as_list()[0]
    # one coin flip per sample, broadcast over the remaining axes
    centred = tf.random_uniform(shape=[batch, 1, 1, 1], dtype=tf.float32) - 0.5
    flip = tf.cast(centred > 0., dtype=tf.float32)

    def _blend(tensor):
        # flip == 1 selects the mirrored copy, flip == 0 the original
        return flip * tensor[:, :, ::-1] + (1 - flip) * tensor

    mirrored_images = _blend(t_images)
    mirrored_mesh = _blend(t_mesh)
    return mirrored_images, mirrored_mesh
@wrappy
def part_map_to_mu_L_inv(part_maps, scal):
    """
    Calculate, for each channel of part_maps, its mean position and the inverse Cholesky
    factor of its spatial covariance.

    :param part_maps: tensor of part map activations [bn, h, w, n_part]; the moments are
        plain weighted sums, so each map presumably sums to one over (h, w) - TODO confirm
    :param scal: scalar factor multiplied onto L_inv
    :return: tuple (mu, L_inv): mu of shape [bn, n_part, 2] in (y, x) order, calculated
        on a grid of scale [-1, 1], and L_inv of shape [bn, n_part, 2, 2]
    """
    bn, h, w, nk = part_maps.get_shape().as_list()
    # coordinate grid [h, w, 2] holding (y, x) in [-1, 1]
    y_t = tf.tile(tf.reshape(tf.linspace(-1., 1., h), [h, 1]), [1, w])
    x_t = tf.tile(tf.reshape(tf.linspace(-1., 1., w), [1, w]), [h, 1])
    y_t = tf.expand_dims(y_t, axis=-1)
    x_t = tf.expand_dims(x_t, axis=-1)
    meshgrid = tf.concat([y_t, x_t], axis=-1)

    # first moment
    mu = tf.einsum('ijl,aijk->akl', meshgrid, part_maps)
    mu_out_prod = tf.einsum('akm,akn->akmn', mu, mu)

    # second moment minus outer product of the mean: 2 x 2 covariance per part
    mesh_out_prod = tf.einsum('ijm,ijn->ijmn', meshgrid, meshgrid)
    stddev = tf.einsum('ijmn,aijk->akmn', mesh_out_prod, part_maps) - mu_out_prod

    a_sq = stddev[:, :, 0, 0]
    a_b = stddev[:, :, 0, 1]
    b_sq_add_c_sq = stddev[:, :, 1, 1]
    eps = 1e-12

    # Cholesky factorisation by hand: Sigma = L L^T with L = [[a, 0], [b, c]];
    # precision = Sigma^-1 = L^-T L^-1, so L^-1 is what is needed in the end
    a = tf.sqrt(a_sq + eps)
    b = a_b / (a + eps)
    c = tf.sqrt(b_sq_add_c_sq - b ** 2 + eps)
    z = tf.zeros_like(a)

    # monitor the factor of the first part of the first sample
    tf.summary.scalar(name="L_0_0", tensor=a[0, 0])
    tf.summary.scalar(name="L_1_0", tensor=b[0, 0])
    tf.summary.scalar(name="L_1_1", tensor=c[0, 0])

    det = tf.expand_dims(tf.expand_dims(a * c, axis=-1), axis=-1)
    row_1 = tf.expand_dims(tf.concat([tf.expand_dims(c, axis=-1), tf.expand_dims(z, axis=-1)], axis=-1), axis=-2)
    row_2 = tf.expand_dims(tf.concat([tf.expand_dims(-b, axis=-1), tf.expand_dims(a, axis=-1)], axis=-1), axis=-2)

    # L^-1 = 1 / (a c) * [[c, 0], [-b, a]], additionally scaled by scal
    L_inv = scal / (det + eps) * tf.concat([row_1, row_2], axis=-2)
    tf.summary.scalar(name="L_inv_0_0", tensor=L_inv[0, 0, 0, 0])
    tf.summary.scalar(name="L_inv_1_0", tensor=L_inv[0, 0, 1, 0])
    tf.summary.scalar(name="L_inv_1_1", tensor=L_inv[0, 0, 1, 1])

    return mu, L_inv
@wrappy
def feat_mu_to_enc(features, mu, L_inv, reconstruct_stages, part_depths, feat_map_depths, static, n_reverse, covariance=None, feat_shape=None, heat_feat_normalize=True, range=10, ):
    """
    Render part positions and part features into one spatial encoding per decoder stage.

    :param features: tensor shape bn, nk, nf
    :param mu: tensor shape [bn, nk, 2] in range[-1,1]
    :param L_inv: tensor shape [bn, nk, 2, 2]
    :param reconstruct_stages: per stage the resolution, read as ``dims[0], dims[1]``
    :param part_depths: per stage the number of part heat maps to keep
    :param feat_map_depths: per stage a pair (start, stop) selecting the parts whose
        features are projected
    :param static: if truthy the two batch halves exchange features, otherwise
        ``reverse_batch`` is applied
    :param n_reverse: group size handed to ``reverse_batch``
    :param covariance: if truthy the part heat maps use ``L_inv``, otherwise a circular
        precision
    :param feat_shape: if truthy the feature heat maps use ``L_inv``, otherwise a
        circular precision
    :param heat_feat_normalize: divide the feature heat maps by (their sum over parts + 1)
    :param range: diagonal value of the circular precision matrix
    :return: list with one encoding tensor per stage
    """
    bn, nk, nf = features.get_shape().as_list()

    if static:
        half = bn // 2
        partner_features = tf.concat([features[half:], features[:half]], axis=0)
    else:
        partner_features = reverse_batch(features, n_reverse)

    isotropic = tf.constant([[range, 0.], [0, range]], dtype=tf.float32)
    circular_precision = tf.tile(tf.reshape(isotropic, shape=[1, 1, 2, 2]), multiples=[bn, nk, 1, 1])

    encoding_list = []
    for dims, part_depth, feat_slice in zip(reconstruct_stages, part_depths, feat_map_depths):
        h, w = dims[0], dims[1]

        rows = tf.reshape(tf.linspace(-1., 1., h), [h, 1])
        cols = tf.reshape(tf.linspace(- 1., 1., w), [1, w])
        y_t = tf.expand_dims(tf.tile(rows, [1, w]), axis=-1)
        x_t = tf.expand_dims(tf.tile(cols, [h, 1]), axis=-1)

        # grid coordinates, (y, x) stacked on the second to last axis, pixels flattened
        mesh = tf.concat([tf.reshape(y_t, (1, 1, 1, -1)), tf.reshape(x_t, (1, 1, 1, -1))], axis=-2)
        dist = mesh - tf.expand_dims(mu, axis=-1)

        # only the heat map variants that are consumed below are built
        if not covariance or not feat_shape:
            heat_circ, part_heat_circ = precision_dist_op(circular_precision, dist, part_depth, nk, h, w)
        if covariance or feat_shape:
            heat_shape, part_heat_shape = precision_dist_op(L_inv, dist, part_depth, nk, h, w)

        first, last = feat_slice[0], feat_slice[1]
        if last - first == 0:
            # stage without feature channels: part heat maps only
            encoding_list.append(part_heat_shape if covariance else part_heat_circ)
            continue

        selected_features = partner_features[:, first: last]
        heat_source = heat_shape if feat_shape else heat_circ
        heat_scal = heat_source[:, first: last]

        if heat_feat_normalize:
            heat_scal = heat_scal / (tf.reduce_sum(heat_scal, axis=1, keepdims=True) + 1)

        heat_feat_map = tf.einsum('bkij,bkn->bijn', heat_scal, selected_features)

        part_heat = part_heat_shape if covariance else part_heat_circ
        encoding_list.append(tf.concat([part_heat, heat_feat_map], axis=-1))

    return encoding_list
(int(h / 2)) 333 | assert (int(w / 2)) 334 | assert (bn_2 == bn) 335 | 336 | scal = tf.constant([img_h, img_w], dtype=tf.float32) 337 | mu = tf.stop_gradient(mu) 338 | mu_no_grad = tf.einsum('bkj,j->bkj', (mu + 1) / 2., scal) 339 | mu_no_grad = tf.cast(mu_no_grad, dtype=tf.int32) 340 | 341 | mu_no_grad = tf.reshape(mu_no_grad, shape=[bn, nk, 1, 1, 2]) 342 | y = tf.tile(tf.reshape(tf.range(- h // 2, h // 2), [1, 1, h, 1, 1]), [bn, nk, 1, w, 1]) 343 | x = tf.tile(tf.reshape(tf.range(- w // 2, w // 2), [1, 1, 1, w, 1]), [bn, nk, h, 1, 1]) 344 | 345 | field = tf.concat([y, x], axis=-1) + mu_no_grad 346 | 347 | h1 = tf.tile(tf.reshape(tf.range(bn), [bn, 1, 1, 1, 1]), [1, nk, h, w, 1]) 348 | 349 | idx = tf.concat([h1, field], axis=-1) 350 | 351 | image_slices = tf.gather_nd(img, idx) 352 | return image_slices 353 | 354 | 355 | 356 | 357 | @wrappy 358 | def fold_img_with_mu(img, mu, scale, visualize, threshold, normalize=True): 359 | """ 360 | folds the pixel values of img with potentials centered around the part means (mu) 361 | :param img: batch of images 362 | :param mu: batch of part means in range [-1, 1] 363 | :param scale: scale that governs the range of the potential 364 | :param visualize: 365 | :param normalize: whether to normalize the potentials 366 | :return: folded image 367 | """ 368 | bn, h, w, nc = img.get_shape().as_list() 369 | bn, nk, _ = mu.get_shape().as_list() 370 | 371 | py = tf.expand_dims(mu[:, :, 0], 2) 372 | px = tf.expand_dims(mu[:, :, 1], 2) 373 | 374 | py = tf.stop_gradient(py) 375 | px = tf.stop_gradient(px) 376 | 377 | y_t = tf.tile(tf.reshape(tf.linspace(-1., 1., h), [h, 1]), [1, w]) 378 | x_t = tf.tile(tf.reshape(tf.linspace(- 1., 1., w), [1, w]), [h, 1]) 379 | x_t_flat = tf.reshape(x_t, (1, 1, -1)) 380 | y_t_flat = tf.reshape(y_t, (1, 1, -1)) 381 | 382 | y_dist = py - y_t_flat 383 | x_dist = px - x_t_flat 384 | 385 | heat_scal = heat_map_function(y_dist=y_dist, x_dist=x_dist, x_scale=scale, y_scale=scale) 386 | heat_scal = 
tf.reshape(heat_scal, shape=[bn, nk, h, w]) # bn width height number parts 387 | heat_scal = tf.einsum('bkij->bij', heat_scal) 388 | heat_scal = tf.clip_by_value(t=heat_scal, clip_value_min=0., clip_value_max=1.) 389 | heat_scal = tf.where(heat_scal > threshold, heat_scal, tf.zeros_like(heat_scal)) 390 | 391 | norm = tf.reduce_sum(heat_scal, axis=[1, 2], keepdims=True) 392 | tf.summary.scalar("norm", tensor=tf.reduce_mean(norm)) 393 | if normalize: 394 | heat_scal_norm = heat_scal / norm 395 | folded_img = tf.einsum('bijc,bij->bijc', img, heat_scal_norm) 396 | if not normalize: 397 | folded_img = tf.einsum('bijc,bij->bijc', img, heat_scal) 398 | if visualize: 399 | tf.summary.image(name="foldy_map", tensor=tf.expand_dims(heat_scal, axis=-1), max_outputs=4) 400 | 401 | return folded_img, tf.expand_dims(heat_scal, axis=-1) 402 | 403 | 404 | 405 | 406 | @wrappy 407 | def mu_img_gate(mu, resolution, scale): 408 | """ 409 | folds the pixel values of img with potentials centered around the part means (mu) 410 | :param img: batch of images 411 | :param mu: batch of part means in range [-1, 1] 412 | :param scale: scale that governs the range of the potential 413 | :param visualize: 414 | :param normalize: whether to normalize the potentials 415 | :return: folded image 416 | """ 417 | bn, nk, _ = mu.get_shape().as_list() 418 | 419 | py = tf.expand_dims(mu[:, :, 0], 2) 420 | px = tf.expand_dims(mu[:, :, 1], 2) 421 | 422 | py = tf.stop_gradient(py) 423 | px = tf.stop_gradient(px) 424 | h, w = resolution 425 | 426 | y_t = tf.tile(tf.reshape(tf.linspace(-1., 1., h), [h, 1]), [1, w]) 427 | x_t = tf.tile(tf.reshape(tf.linspace(- 1., 1., w), [1, w]), [h, 1]) 428 | x_t_flat = tf.reshape(x_t, (1, 1, -1)) 429 | y_t_flat = tf.reshape(y_t, (1, 1, -1)) 430 | 431 | y_dist = py - y_t_flat 432 | x_dist = px - x_t_flat 433 | 434 | heat_scal = heat_map_function(y_dist=y_dist, x_dist=x_dist, x_scale=scale, y_scale=scale) 435 | heat_scal = tf.reshape(heat_scal, shape=[bn, nk, h, w]) # bn width 
height number parts 436 | heat_scal = tf.einsum('bkij->bij', heat_scal) 437 | return heat_scal 438 | 439 | @wrappy 440 | def binary_activation(x): 441 | cond = tf.less(x, tf.zeros(tf.shape(x))) 442 | out = tf.where(cond, tf.zeros(tf.shape(x)), tf.ones(tf.shape(x))) 443 | 444 | return out 445 | 446 | 447 | @wrappy 448 | def fold_img_with_L_inv(img, mu, L_inv, scale, visualize, threshold, normalize=True): 449 | """ 450 | folds the pixel values of img with potentials centered around the part means (mu) 451 | :param img: batch of images 452 | :param mu: batch of part means in range [-1, 1] 453 | :param scale: scale that governs the range of the potential 454 | :param visualize: 455 | :param normalize: whether to normalize the potentials 456 | :return: folded image 457 | """ 458 | bn, h, w, nc = img.get_shape().as_list() 459 | bn, nk, _ = mu.get_shape().as_list() 460 | 461 | mu_stop = tf.stop_gradient(mu) 462 | 463 | y_t = tf.tile(tf.reshape(tf.linspace(-1., 1., h), [h, 1]), [1, w]) 464 | x_t = tf.tile(tf.reshape(tf.linspace(- 1., 1., w), [1, w]), [h, 1]) 465 | x_t_flat = tf.reshape(x_t, (1, 1, -1)) 466 | y_t_flat = tf.reshape(y_t, (1, 1, -1)) 467 | 468 | mesh = tf.concat([y_t_flat, x_t_flat], axis=-2) 469 | dist = mesh - tf.expand_dims(mu_stop, axis=-1) 470 | 471 | proj_precision = tf.einsum('bnik,bnkf->bnif', scale * L_inv, dist) ** 2 # tf.matmul(precision, dist)**2 472 | proj_precision = tf.reduce_sum(proj_precision, axis=-2) # sum x and y axis 473 | 474 | heat = 1 / (1 + proj_precision) 475 | 476 | heat = tf.reshape(heat, shape=[bn, nk, h, w]) # bn width height number parts 477 | heat = tf.einsum('bkij->bij', heat) 478 | heat_scal = tf.clip_by_value(t=heat, clip_value_min=0., clip_value_max=1.) 
479 | heat_scal = tf.where(heat_scal > threshold, heat_scal, tf.zeros_like(heat_scal)) 480 | norm = tf.reduce_sum(heat_scal, axis=[1, 2], keepdims=True) 481 | tf.summary.scalar("norm", tensor=tf.reduce_mean(norm)) 482 | if normalize: 483 | heat_scal = heat_scal / norm 484 | folded_img = tf.einsum('bijc,bij->bijc', img, heat_scal) 485 | if visualize: 486 | tf.summary.image(name="foldy_map", tensor=tf.expand_dims(heat_scal, axis=-1), max_outputs=4) 487 | 488 | return folded_img 489 | 490 | 491 | @wrappy 492 | def probabilistic_switch(handles, handle_probs, counter, scale=10000.): 493 | t = counter / scale 494 | scheduled_probs = [] 495 | 496 | for p_1, p_2 in zip(handle_probs[::2], handle_probs[1::2]): 497 | scheduled_prob = t * p_1 + (1 - t) * p_2 498 | scheduled_probs.append(scheduled_prob) 499 | 500 | handle = np.random.choice(handles, p=scheduled_probs) 501 | return handle 502 | 503 | 504 | def initialize_uninitialized(sess): 505 | global_vars = tf.global_variables() 506 | is_not_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars]) 507 | not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f] 508 | 509 | print([str(i.name) for i in not_initialized_vars]) # only for testing 510 | if len(not_initialized_vars): 511 | sess.run(tf.variables_initializer(not_initialized_vars)) 512 | 513 | 514 | 515 | --------------------------------------------------------------------------------