├── vis
│   ├── linspecer.m
│   ├── vppAutoKeypointShowSingle.m
│   ├── figure_show_callback.m
│   ├── vppAutoKeypointImageRecon.m
│   └── vppDetailsResultsTemplate.m
├── demo
│   ├── input
│   │   ├── 047269.jpg
│   │   ├── 075525.jpg
│   │   ├── 114868.jpg
│   │   └── 156207.jpg
│   ├── vppAutoKeypointShow.m
│   └── vppAutoKeypointShowSingle.m
├── requirements.txt
├── zutils
│   ├── pt_utils.py
│   ├── np_utils.py
│   └── option_struct.py
├── bin
│   ├── stitch_samples.sh
│   ├── batch_test.sh
│   ├── batch_test_large_testset.sh
│   ├── batch_test_large_trainset.sh
│   ├── run_tensorboard.sh
│   ├── link_to_latest_samples.sh
│   └── wipe_test_cache.sh
├── one_step_test_celeba_demo.sh
├── net_modules
│   ├── distribution_utils.py
│   ├── common.py
│   ├── auto_struct
│   │   ├── utils.py
│   │   ├── keypoint_decoder.py
│   │   └── generic_encoder.py
│   ├── ndeconv.py
│   ├── nearest_upsampling.py
│   ├── gen.py
│   ├── nconv.py
│   ├── pixel_bias.py
│   ├── spatial_transformer_pt.py
│   ├── pt_group_connected.py
│   ├── pt_patch_batch_normalization.py
│   └── spatial_transformer.py
├── nets
│   ├── distribution
│   │   ├── bernoullix.py
│   │   ├── spike_in_01.py
│   │   ├── generic.py
│   │   ├── gaussian_in_01.py
│   │   ├── gaussian_fixedvar_in_01.py
│   │   ├── spike_at_zero.py
│   │   ├── spike.py
│   │   ├── bernoulli.py
│   │   ├── gaussian_fixedvar.py
│   │   └── gaussian.py
│   ├── recon
│   │   └── generic_single_dist.py
│   ├── decoder
│   │   ├── car_64x64.py
│   │   ├── general_80x80.py
│   │   ├── general_64x64.py
│   │   └── general_128x128_landmark.py
│   ├── encoder
│   │   ├── general_64x64.py
│   │   ├── car_64x64.py
│   │   ├── general_128x128_landmark.py
│   │   └── general_80x80.py
│   └── data
│       ├── cat_80x80.py
│       └── aflw_80x80.py
├── evaluation
│   ├── merge_large_dataset.m
│   ├── linear_regressor.m
│   ├── mean_error_IOD.m
│   ├── face_evaluation.m
│   └── cat_evaluation.m
├── one_step_test_cat.sh
├── one_step_test_celeba.sh
├── one_step_test_aflw.sh
├── .gitignore
├── runner
│   ├── resumable_data_module_wrapper.py
│   ├── data_op.py
│   ├── train_pipeline.py
│   ├── one_epoch_runner.py
│   └── preprocessing_data_module_wrapper.py
├── exp-ae-aflw-10.py
├── exp-ae-aflw-30.py
├── model
│   ├── pipeline.py
│   └── options.py
├── tools
│   └── run_test_in_folder.py
├── exp-ae-cat-10.py
├── exp-ae-cat-20.py
├── exp-ae-celeba-mafl-30.py
├── download_cat.sh
├── exp-ae-celeba-mafl-10.py
├── download_aflw.sh
├── download_celeba.sh
└── README.md

/vis/linspecer.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YutingZhang/lmdis-rep/HEAD/vis/linspecer.m
--------------------------------------------------------------------------------
/demo/input/047269.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YutingZhang/lmdis-rep/HEAD/demo/input/047269.jpg
--------------------------------------------------------------------------------
/demo/input/075525.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YutingZhang/lmdis-rep/HEAD/demo/input/075525.jpg
--------------------------------------------------------------------------------
/demo/input/114868.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YutingZhang/lmdis-rep/HEAD/demo/input/114868.jpg
--------------------------------------------------------------------------------
/demo/input/156207.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YutingZhang/lmdis-rep/HEAD/demo/input/156207.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy >= 1.12.0
2 | scipy >= 0.18.1
3 | tensorflow[, -gpu] > 1.0.0, <=1.7.0
4 | prettytensor
5 | opencv-python
6 | 
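Note: line 3 above is shorthand, not literal pip syntax -- install exactly one of tensorflow or tensorflow-gpu within the stated version range. A possible install command (CPU build; one valid version choice within the range):

    pip3 install "numpy>=1.12.0" "scipy>=0.18.1" "tensorflow>1.0.0,<=1.7.0" prettytensor opencv-python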
--------------------------------------------------------------------------------
/zutils/pt_utils.py:
--------------------------------------------------------------------------------
1 | import prettytensor as pt
2 | 
3 | 
4 | def default_phase():
5 |     defaults = pt.pretty_tensor_class._defaults
6 |     if 'phase' in defaults:
7 |         return defaults['phase']
8 |     else:
9 |         return pt.Phase.test
10 | 
--------------------------------------------------------------------------------
/bin/stitch_samples.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | ls -d var/results/*/test.final/prior_samples var/results/*/test.snapshot/*/prior_samples | \
4 | while read line; do
5 |     if [ -e $line".png" ]; then continue; fi
6 |     echo "$line"
7 |     montage -mode concatenate `ls $line/*.png` $line".png"
8 | done
9 | 
10 | 
--------------------------------------------------------------------------------
/one_step_test_celeba_demo.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | SNAPSHOT_ITER=""
4 | SPECIFIC_MODEL_DIR=$(cd `dirname $0`; pwd)'/pretrained_results/celeba_10'
5 | 
6 | python3 "./tools/run_test_in_folder.py" "$SPECIFIC_MODEL_DIR" "'test_subset':'demo', 'test_limit':None" "test.demo" "$SNAPSHOT_ITER" "False" "True"
7 | 
8 | TEST_PRED_FILE=$SPECIFIC_MODEL_DIR'/test.demo/posterior_param.mat'
9 | matlab -nosplash -r "tmp=load('$TEST_PRED_FILE'); cd demo; vppAutoKeypointShow(tmp.data, tmp.encoded.structure_param,'output');exit()"
--------------------------------------------------------------------------------
/zutils/np_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | def onthot_to_int(onehot_array, axis=1, dtype=np.int64, keepdims=False):
5 | 
6 |     s = list(onehot_array.shape)  # use a list: the shape is edited below, and a tuple would raise TypeError
7 |     num = s[axis]
8 |     nonzero_indexes = np.nonzero(onehot_array)
9 |     index_arr = np.arange(num, dtype=dtype)  # honor the dtype argument
10 |     all_indexes = index_arr[nonzero_indexes[axis]]
11 |     if keepdims:
12 |         s[axis] = 1
13 |     else:
14 |         s = s[0:axis] + s[axis+1:]
15 |     all_indexes = np.reshape(all_indexes, s)
16 | 
17 |     return all_indexes
18 | 
--------------------------------------------------------------------------------
/net_modules/distribution_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import zutils.tf_math_funcs as tmf
3 | 
4 | 
5 | # handle extended structure_param, patch_feature_x, overall_features
6 | def reshape_extended_features(a, param_factor):
7 |     if a is None:
8 |         return None, None
9 |     kept_shape = tmf.get_shape(a)[:-1] + [tmf.get_shape(a)[-1] // param_factor]
10 |     a = tf.reshape(a, kept_shape + [param_factor])
11 |     main_chooser = (slice(None, None),) * len(kept_shape) + (0,)
12 |     b = a[main_chooser]
13 |     return a, b
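reshape_extended_features splits the last axis of an "extended" feature tensor into (features, param_factor) and also returns the main slice at parameter index 0. A minimal shape sketch (hypothetical sizes; assumes TensorFlow 1.x as pinned in requirements.txt):

    import tensorflow as tf
    from net_modules.distribution_utils import reshape_extended_features

    x = tf.zeros([16, 10, 24])             # batch x keypoints x (8 features * param_factor 3)
    a, b = reshape_extended_features(x, 3)
    # a has shape [16, 10, 8, 3]  -- all parameter copies
    # b has shape [16, 10, 8]     -- the main copy at parameter index 0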
"sponge" > /dev/null; then 16 | EDITOR=sponge 17 | else 18 | EDITOR=cat 19 | fi 20 | 21 | $EDITOR | python3 "$TOOLBOX_DIR/tools/run_test_in_folder.py" - "$OVERRIDE_PARAM" "$TEST_PATH" "$SNAPSHOT_ITER" 22 | 23 | -------------------------------------------------------------------------------- /nets/distribution/bernoullix.py: -------------------------------------------------------------------------------- 1 | import nets.distribution.bernoulli 2 | 3 | BaseFactory = nets.distribution.bernoulli.Factory 4 | 5 | 6 | class Factory(BaseFactory): 7 | 8 | def __init__(self, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | 11 | # use cross-entropy for nll 12 | def nll(self, dist_param, samples): 13 | gt_dist_param = self.parametrize(samples, None) # no need for latent dim for bernoulli 14 | xnll = self.cross_entropy(gt_dist_param, dist_param) 15 | return xnll, True 16 | 17 | -------------------------------------------------------------------------------- /bin/batch_test_large_testset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ "$1" == "" ]; then 4 | echo 'Usage: run_test_in_folder.py EXP_PATH [OVERRIDE_PARAM [TEST_PATH [SNAPSHOT_ITER]]]' 5 | exit -1 6 | fi 7 | 8 | SCRIPT_DIR=`dirname "$0"` 9 | TOOLBOX_DIR=`readlink -f "$SCRIPT_DIR"/..` 10 | 11 | OVERRIDE_PARAM="$1" 12 | TEST_PATH="$2" 13 | SNAPSHOT_ITER="$3" 14 | 15 | if type "sponge" > /dev/null; then 16 | EDITOR=sponge 17 | else 18 | EDITOR=cat 19 | fi 20 | 21 | $EDITOR | python3 "$TOOLBOX_DIR/tools/run_test_large_testset_in_folder.py" - "$OVERRIDE_PARAM" "$TEST_PATH" "$SNAPSHOT_ITER" 22 | 23 | -------------------------------------------------------------------------------- /bin/batch_test_large_trainset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ "$1" == "" ]; then 4 | echo 'Usage: run_test_in_folder.py EXP_PATH [OVERRIDE_PARAM [TEST_PATH [SNAPSHOT_ITER]]]' 5 | exit -1 6 | fi 7 | 8 | SCRIPT_DIR=`dirname "$0"` 9 | TOOLBOX_DIR=`readlink -f "$SCRIPT_DIR"/..` 10 | 11 | OVERRIDE_PARAM="$1" 12 | TEST_PATH="$2" 13 | SNAPSHOT_ITER="$3" 14 | 15 | if type "sponge" > /dev/null; then 16 | EDITOR=sponge 17 | else 18 | EDITOR=cat 19 | fi 20 | 21 | $EDITOR | python3 "$TOOLBOX_DIR/tools/run_test_large_trainset_in_folder.py" - "$OVERRIDE_PARAM" "$TEST_PATH" "$SNAPSHOT_ITER" 22 | 23 | -------------------------------------------------------------------------------- /evaluation/merge_large_dataset.m: -------------------------------------------------------------------------------- 1 | function merge_large_trainset(data_dir) 2 | files = dir(sprintf('%s/posterior_param_*.mat',data_dir)); 3 | file_names = sort({files.name}); 4 | encoded = struct; 5 | encoded.structure_param = []; 6 | for i=1:numel(file_names) 7 | file = file_names{i}; 8 | fprintf(file); 9 | tmp = load(sprintf('%s/%s',data_dir,file)); 10 | encoded.structure_param = [encoded.structure_param; tmp.encoded.structure_param]; 11 | end 12 | save(sprintf('%s/posterior_param.mat',data_dir), 'encoded'); 13 | end 14 | -------------------------------------------------------------------------------- /bin/run_tensorboard.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | PORT=$1 4 | if [ -z "$PORT" ]; then 5 | PORT=6006 6 | fi 7 | 8 | SCRIPT_DIR=`dirname "$0"` 9 | TOOLBOX_DIR=`readlink -f "$SCRIPT_DIR"/..` 10 | RESULT_DIR="$TOOLBOX_DIR"/var/results 11 | 12 | if type "sponge" > /dev/null; then 13 | EDITOR=sponge 
--------------------------------------------------------------------------------
/bin/run_tensorboard.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | PORT=$1
4 | if [ -z "$PORT" ]; then
5 |     PORT=6006
6 | fi
7 | 
8 | SCRIPT_DIR=`dirname "$0"`
9 | TOOLBOX_DIR=`readlink -f "$SCRIPT_DIR"/..`
10 | RESULT_DIR="$TOOLBOX_DIR"/var/results
11 | 
12 | if type "sponge" > /dev/null 2>&1; then
13 |     EDITOR=sponge
14 | else
15 |     EDITOR=cat
16 | fi
17 | 
18 | echo "Please Input the List of Experiments:"
19 | 
20 | $EDITOR | while read line; do
21 |     echo $PORT ":" $line
22 |     if [[ "$line" != /* ]]; then
23 |         line=$RESULT_DIR/$line
24 |     fi
25 |     ( cd "$line"/logs; tensorboard --logdir=. --port=$PORT > /dev/null 2>&1 ) &
26 |     PORT=$((PORT+1))
27 | done
28 | 
29 | while true; do sleep 1; done
30 | 
--------------------------------------------------------------------------------
/demo/vppAutoKeypointShow.m:
--------------------------------------------------------------------------------
1 | function vppAutoKeypointShow(im, keypoints, output_dir)
2 | 
3 | output_to_dir = exist('output_dir', 'var') && ~isempty(output_dir);
4 | if output_to_dir
5 |     if ~exist(output_dir, 'dir')
6 |         mkdir(output_dir)
7 |     end
8 | end
9 | 
10 | batch_mode_started = false;
11 | for k=1:size(im, 1)
12 |     I = squeeze(im(k, :, :, :));
13 |     clf
14 |     the_figure = gcf;
15 |     P=double(squeeze(keypoints(k,:,:)));
16 |     vppAutoKeypointShowSingle(I, P)
17 |     set(gca,'visible','off');
18 |     set(gcf,'color','white');
19 |     if output_to_dir
20 |         fprintf('%d / %d\n', k, size(im, 1))
21 |         saveas(the_figure, fullfile(output_dir, sprintf('%d.eps', k)), 'epsc')
22 |     else
23 |         while waitforbuttonpress; end
24 |     end
25 | end
26 | 
--------------------------------------------------------------------------------
/nets/distribution/spike_in_01.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import zutils.tf_math_funcs as tmf
3 | import nets.distribution.spike
4 | 
5 | BaseFactory = nets.distribution.spike.Factory
6 | 
7 | epsilon = tmf.epsilon
8 | 
9 | 
10 | class Factory(BaseFactory):
11 | 
12 |     def __init__(self, *args, **kwargs):
13 |         super().__init__(*args, **kwargs)
14 | 
15 |     @classmethod
16 |     def transform2param(cls, input_tensor, latent_dim):
17 |         """ Create network for converting input_tensor to distribution parameters
18 | 
19 |         :param input_tensor: (posterior phase) input tensor for the posterior
20 |         :param latent_dim: dimension of the latent_variables
21 |         :return: param_tensor - distribution parameters
22 |         """
23 |         param_tensor = tf.sigmoid(input_tensor)
24 |         return param_tensor
25 | 
--------------------------------------------------------------------------------
/evaluation/linear_regressor.m:
--------------------------------------------------------------------------------
1 | function [W] = linear_regressor(pred_kp, gt_kp)
2 |     fprintf('clean the gt keypoints mat and prediction keypoints mat \n');
3 |     nan_rows = [];
4 |     for k=1:size(gt_kp, 1)
5 |         if (max(gt_kp(k,:))==0)
6 |             nan_rows = [nan_rows, k];
7 |         end
8 |     end
9 |     if length(nan_rows) > 0
10 |         fprintf('nan_rows\n');
11 |         nan_rows
12 |     end
13 |     for i = size(nan_rows, 2):-1:1
14 |         row_idx = nan_rows(1,i);
15 |         gt_kp = gt_kp([1:(row_idx-1),(row_idx+1):size(gt_kp,1)], :, :);
16 |         pred_kp = pred_kp([1:(row_idx-1),(row_idx+1):size(pred_kp,1)], :, :);
17 |     end
18 |     clean_gt_kp = gt_kp;
19 |     X = reshape(pred_kp,[size(pred_kp,1), size(pred_kp,2)*size(pred_kp,3)] );
20 |     Y = reshape(clean_gt_kp,[size(clean_gt_kp,1), size(clean_gt_kp,2)*size(clean_gt_kp,3)] );
21 |     %The simplest linear regression: solve X*W = Y in the least-squares sense, so W = X\Y
22 |     W = X\Y;
23 | end
24 | 
--------------------------------------------------------------------------------
/net_modules/common.py:
--------------------------------------------------------------------------------
1 | from zutils.py_utils import value_class_for_with
2 | import zutils.tf_math_funcs as tmf
3 | 
4 | 
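# `default_activation` below holds the default activation (leaky ReLU) as a
# with-able value object (via value_class_for_with), so callers can override
# it temporarily. The imports that follow are for their side effects only:
# each module registers a custom layer method with PrettyTensor through
# @prettytensor.Register, and the dummy `_prettytensorN` assignments merely
# keep the imports from looking unused.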
default_activation = value_class_for_with(tmf.leaky_relu) 6 | 7 | import net_modules.nearest_upsampling 8 | _prettytensor2 = net_modules.nearest_upsampling.prettytensor 9 | 10 | import net_modules.deconv 11 | _prettytensor3 = net_modules.deconv.prettytensor 12 | 13 | import net_modules.ndeconv 14 | _prettytensor4 = net_modules.ndeconv.prettytensor 15 | 16 | import net_modules.pixel_bias 17 | _prettytensor5 = net_modules.pixel_bias.prettytensor 18 | 19 | import net_modules.spatial_transformer_pt 20 | _prettytensor6 = net_modules.spatial_transformer_pt.prettytensor 21 | 22 | import net_modules.pt_patch_batch_normalization 23 | _prettytensor7 = net_modules.pt_patch_batch_normalization.prettytensor 24 | 25 | import net_modules.pt_group_connected 26 | _prettytensor9 = net_modules.pt_group_connected.prettytensor 27 | -------------------------------------------------------------------------------- /one_step_test_cat.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SNAPSHOT_ITER="" 4 | SPECIFIC_MODEL_DIR=$(cd `dirname $0`; pwd)'/pretrained_results/cat_10' 5 | 6 | TRAIN_GT_FILE=$(cd `dirname $0`; pwd)'/data/cat_data/cat_train_keypoints.mat' 7 | TEST_GT_FILE=$(cd `dirname $0`; pwd)'/data/cat_data/cat_test_keypoints.mat' 8 | 9 | python3 "./tools/run_test_in_folder.py" "$SPECIFIC_MODEL_DIR" "'test_subset':'test', 'test_limit':None" "test.test" "$SNAPSHOT_ITER" "False" "True" 10 | 11 | python3 "./tools/run_test_in_folder.py" "$SPECIFIC_MODEL_DIR" "'test_subset':'train', 'test_limit':None" "test.train" "$SNAPSHOT_ITER" "True" "False" 12 | 13 | DATA_DIR=$SPECIFIC_MODEL_DIR'/test.train' 14 | matlab -nodesktop -nosplash -r "cd('./evaluation');merge_large_dataset('$DATA_DIR');exit;" 15 | 16 | TRAIN_PRED_FILE=$SPECIFIC_MODEL_DIR'/test.train/posterior_param.mat' 17 | TEST_PRED_FILE=$SPECIFIC_MODEL_DIR'/test.test/posterior_param.mat' 18 | matlab -nosplash -r "cd('./evaluation');cat_evaluation('$TRAIN_PRED_FILE','$TRAIN_GT_FILE','$TEST_PRED_FILE','$TEST_GT_FILE');exit()" 19 | -------------------------------------------------------------------------------- /one_step_test_celeba.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SNAPSHOT_ITER="" 4 | SPECIFIC_MODEL_DIR=$(cd `dirname $0`; pwd)'/pretrained_results/celeba_10' 5 | 6 | TRAIN_GT_FILE=$(cd `dirname $0`; pwd)'/data/celeba_data/celeba_mafl_train.mat' 7 | TEST_GT_FILE=$(cd `dirname $0`; pwd)'/data/celeba_data/celeba_mafl_test.mat' 8 | 9 | python3 "./tools/run_test_in_folder.py" "$SPECIFIC_MODEL_DIR" "'test_subset':'test', 'test_limit':None" "test.test" "$SNAPSHOT_ITER" "False" "True" 10 | 11 | python3 "./tools/run_test_in_folder.py" "$SPECIFIC_MODEL_DIR" "'test_subset':'train', 'test_limit':None" "test.train" "$SNAPSHOT_ITER" "True" "False" 12 | 13 | DATA_DIR=$SPECIFIC_MODEL_DIR'/test.train' 14 | matlab -nodesktop -nosplash -r "cd('./evaluation');merge_large_dataset('$DATA_DIR');exit;" 15 | 16 | TRAIN_PRED_FILE=$SPECIFIC_MODEL_DIR'/test.train/posterior_param.mat' 17 | TEST_PRED_FILE=$SPECIFIC_MODEL_DIR'/test.test/posterior_param.mat' 18 | matlab -nosplash -r "cd('./evaluation');face_evaluation('$TRAIN_PRED_FILE','$TRAIN_GT_FILE','$TEST_PRED_FILE','$TEST_GT_FILE');exit()" 19 | -------------------------------------------------------------------------------- /net_modules/auto_struct/utils.py: -------------------------------------------------------------------------------- 1 | from zutils.py_utils import recursive_merge_dicts 2 | 3 | 4 | class 
ModuleOutputStrip: 5 | 6 | def __init__(self): 7 | self.extra_outputs = dict() 8 | self.extra_outputs["save"] = dict() 9 | self.extra_outputs["extra_recon"] = dict() 10 | self.extra_outputs["for_decoder"] = dict() 11 | self.extra_outputs["for_discriminator"] = dict() 12 | 13 | def __call__(self, module_output): 14 | if isinstance(module_output, tuple): 15 | if not module_output: 16 | return None 17 | if len(module_output) == 1: 18 | return module_output[0] 19 | 20 | if module_output[-1] is not None: 21 | self.extra_outputs = recursive_merge_dicts( 22 | self.extra_outputs, module_output[-1] 23 | ) 24 | if len(module_output) == 2: 25 | return module_output[0] 26 | else: 27 | return module_output[:-1] 28 | else: 29 | return module_output 30 | 31 | -------------------------------------------------------------------------------- /demo/vppAutoKeypointShowSingle.m: -------------------------------------------------------------------------------- 1 | function vppAutoKeypointShowSingle(I, P) 2 | 3 | colormap(gca, 'gray') 4 | imagesc(I) 5 | pbaspect(gca, [size(I,2), size(I,1), 1]) 6 | hold on 7 | if isempty(P) 8 | hold off; 9 | return 10 | end 11 | H=size(I,1); 12 | W=size(I,2); 13 | ASPECT_RATIO=W/H; 14 | asrq=sqrt(ASPECT_RATIO); 15 | 16 | P(:,[1,2]) = bsxfun(@times, P(:,[1,2]), [size(I,1)*asrq, size(I,2)/asrq])+1; 17 | P = double(P); 18 | if 0 19 | for j = 1:size(P,1) 20 | the_color = [1 0 0]; 21 | plot(P(j,2),P(j,1), ... 22 | 'o', 'Color', the_color, 'MarkerFaceColor', the_color); 23 | text(P(j,2),P(j,1), int2str(j), ... 24 | 'HorizontalAlignment', 'left', 'VerticalAlignment', 'bottom', ... 25 | 'FontSize', 12, ... 26 | 'Color', the_color ... 27 | ); 28 | end 29 | else 30 | %the_color_list = linspecer(size(P,1)); 31 | the_color_list = jet(size(P,1)); 32 | for j = 1:size(P,1) 33 | the_color = the_color_list(j,:); 34 | plot(P(j,2),P(j,1), ... 
35 | '+', 'Color', the_color, 'LineWidth', 5, 'MarkerSize', 8); 36 | end 37 | end 38 | 39 | hold off 40 | -------------------------------------------------------------------------------- /one_step_test_aflw.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SNAPSHOT_ITER="" 4 | SPECIFIC_MODEL_DIR=$(cd `dirname $0`; pwd)'/pretrained_results/aflw_10' 5 | 6 | TRAIN_GT_FILE=$(cd `dirname $0`; pwd)'/data/aflw_data/aflw_train_keypoints.mat' 7 | TEST_GT_FILE=$(cd `dirname $0`; pwd)'/data/aflw_data/aflw_test_keypoints.mat' 8 | 9 | python3 "./tools/run_test_in_folder.py" "$SPECIFIC_MODEL_DIR" "'test_subset':'test', 'test_limit':None" "test.test" "$SNAPSHOT_ITER" "True" "False" 10 | 11 | python3 "./tools/run_test_in_folder.py" "$SPECIFIC_MODEL_DIR" "'test_subset':'train', 'test_limit':None" "test.train" "$SNAPSHOT_ITER" "True" "False" 12 | 13 | DATA_DIR=$SPECIFIC_MODEL_DIR'/test.train' 14 | matlab -nodesktop -nosplash -r "cd('./evaluation');merge_large_dataset('$DATA_DIR');exit;" 15 | DATA_DIR=$SPECIFIC_MODEL_DIR'/test.test' 16 | matlab -nodesktop -nosplash -r "cd('./evaluation');merge_large_dataset('$DATA_DIR');exit;" 17 | 18 | TRAIN_PRED_FILE=$SPECIFIC_MODEL_DIR'/test.train/posterior_param.mat' 19 | TEST_PRED_FILE=$SPECIFIC_MODEL_DIR'/test.test/posterior_param.mat' 20 | matlab -nosplash -r "cd('./evaluation');face_evaluation('$TRAIN_PRED_FILE','$TRAIN_GT_FILE','$TEST_PRED_FILE','$TEST_GT_FILE');exit()" 21 | -------------------------------------------------------------------------------- /nets/distribution/generic.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | import tensorflow as tf 3 | from collections import OrderedDict 4 | import math 5 | import zutils.tf_math_funcs as tmf 6 | 7 | epsilon = tmf.epsilon 8 | 9 | 10 | class Factory: 11 | 12 | def sampling(self, dist_param, batch_size, latent_dim): 13 | """ Create network for VAE latent variables (sampling only) 14 | 15 | :param dist_param: input for the posterior 16 | :param batch_size: batch size 17 | :param latent_dim: dimension of the latent_variables 18 | :return: samples - random samples from either posterior or prior distribution 19 | """ 20 | 21 | # generate random samples 22 | rho = tf.random_uniform([batch_size, latent_dim]) 23 | return self.inv_cdf(dist_param, rho) 24 | 25 | @abstractmethod 26 | def nll(self, dist_param, samples): 27 | return None, None 28 | 29 | def nll_formatted(self, dist_param, samples): 30 | x, a = self.nll(dist_param, samples) 31 | if isinstance(a, bool): 32 | s = x.get_shape() 33 | a = tf.tile(tf.reshape(tf.constant(a), [1]*len(s)), s) 34 | return x, a 35 | -------------------------------------------------------------------------------- /bin/link_to_latest_samples.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ -z "$1" ]; then 4 | BASE_FOLDER="var/results" 5 | else 6 | BASE_FOLDER=$1 7 | fi 8 | 9 | cd $BASE_FOLDER 10 | 11 | rm -f _figure_path.txt 12 | mkfifo _figure_path.txt 13 | rm -f _available_folders.txt 14 | mkfifo _available_folders.txt 15 | 16 | find . 
-maxdepth 1 -type d > _available_folders.txt & 17 | 18 | cat _available_folders.txt | while read line; do 19 | LATEST_FIGURE= 20 | if [ -e $line/test.final/prior_samples.png ]; then 21 | LATEST_FIGURE=$line/test.final/prior_samples.png 22 | elif [ -d $line/test.snapshot ]; then 23 | (cd $line/test.snapshot; ls -d step_*) | sed -e "s/^step_//" | sort -r -n | while read step_idx; do 24 | THE_FN=$line/test.snapshot/step_"$step_idx"/prior_samples.png 25 | if [ -e $THE_FN ]; then 26 | echo $THE_FN 27 | break 28 | fi 29 | done > _figure_path.txt & 30 | LATEST_FIGURE=`cat _figure_path.txt` 31 | fi 32 | if [ -z "$LATEST_FIGURE" ]; then 33 | continue 34 | fi 35 | echo "$LATEST_FIGURE" "->" "$line.png" 36 | rm -f $line.png 37 | ln -s $LATEST_FIGURE $line.png 38 | done 39 | 40 | rm -f _figure_path.txt _available_folders.txt 41 | 42 | 43 | -------------------------------------------------------------------------------- /nets/distribution/gaussian_in_01.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import zutils.tf_math_funcs as tmf 3 | import nets.distribution.gaussian 4 | 5 | BaseFactory = nets.distribution.gaussian.Factory 6 | 7 | epsilon = tmf.epsilon 8 | 9 | 10 | class Factory(BaseFactory): 11 | 12 | default_stddev_lower_bound = 0.05 13 | 14 | def __init__(self, *args, **kwargs): 15 | super().__init__(*args, **kwargs) 16 | 17 | @classmethod 18 | def transform2param(cls, input_tensor, latent_dim): 19 | """ Create network for converting input_tensor to distribution parameters 20 | 21 | :param input_tensor: (posterior phase) input tensor for the posterior 22 | :param latent_dim: dimension of the latent_variables 23 | :return: param_tensor - distribution parameters 24 | """ 25 | param_tensor = tf.concat( 26 | [tf.sigmoid(input_tensor[:, :latent_dim]), 27 | tmf.atanh_sigmoid(input_tensor[:, latent_dim:])+epsilon], 28 | axis=1 29 | ) 30 | return param_tensor 31 | 32 | @staticmethod 33 | def sample_to_real(samples): 34 | return tmf.inv_sigmoid(samples) 35 | 36 | @staticmethod 37 | def real_to_samples(samples_in_real): 38 | return tf.sigmoid(samples_in_real) 39 | -------------------------------------------------------------------------------- /vis/vppAutoKeypointShowSingle.m: -------------------------------------------------------------------------------- 1 | function vppAutoKeypointShowSingle(I, P) 2 | 3 | colormap(gca, 'gray') 4 | imagesc(I) 5 | pbaspect(gca, [size(I,2), size(I,1), 1]) 6 | hold on 7 | if isempty(P) 8 | hold off; 9 | return 10 | end 11 | H=size(I,1); 12 | W=size(I,2); 13 | ASPECT_RATIO=W/H; 14 | asrq=sqrt(ASPECT_RATIO); 15 | 16 | % P(:,[1,2]) = bsxfun(@times, P(:,[1,2]), [size(I,1), size(I,2)])+1; 17 | P(:,[1,2]) = bsxfun(@times, P(:,[1,2]), [size(I,1)*asrq, size(I,2)/asrq])+1; 18 | P = double(P); 19 | if 0 20 | for j = 1:size(P,1) 21 | the_color = [1 0 0]; 22 | plot(P(j,2),P(j,1), ... 23 | 'o', 'Color', the_color, 'MarkerFaceColor', the_color); 24 | text(P(j,2),P(j,1), int2str(j), ... 25 | 'HorizontalAlignment', 'left', 'VerticalAlignment', 'bottom', ... 26 | 'FontSize', 12, ... 27 | 'Color', the_color ... 28 | ); 29 | end 30 | else 31 | the_color_list = linspecer(size(P,1)); 32 | for j = 1:size(P,1) 33 | the_color = the_color_list(j,:); 34 | plot(P(j,2),P(j,1), ... 35 | '+', 'Color', the_color, 'LineWidth', 2.5, 'MarkerSize', 4); 36 | %plot(P(j,2),P(j,1), ... 
37 | % '+', 'Color', the_color, 'LineWidth', 5, 'MarkerSize', 8); 38 | end 39 | end 40 | 41 | hold off 42 | -------------------------------------------------------------------------------- /bin/wipe_test_cache.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | SCRIPT_PATH=`readlink -f "$0"` 4 | SCRIPT_FOLDER=`dirname "$SCRIPT_PATH"` 5 | TOOLBOX_FOLDER=`dirname "$SCRIPT_FOLDER"` 6 | CURRENT_FOLDER=`readlink -f .` 7 | 8 | RESULT_FOLDER=$1 9 | if [ -z "$RESULT_FOLDER" ]; then 10 | if [[ "${CURRENT_FOLDER}/" == "$TOOLBOX_FOLDER"/* ]]; then 11 | RESULT_FOLDER="$TOOLBOX_FOLDER"/var/results 12 | else 13 | RESULT_FOLDER="${CURRENT_FOLDER}" 14 | fi 15 | fi 16 | 17 | TMP_FOLDER=/tmp/clean_zyt_tf_cache.$$ 18 | mkdir -p "$TMP_FOLDER" 19 | mkfifo "$TMP_FOLDER"/model_iters 20 | mkfifo "$TMP_FOLDER"/test_iters 21 | mkfifo "$TMP_FOLDER"/rm_iters 22 | 23 | echo cd "$RESULT_FOLDER" 24 | cd "$RESULT_FOLDER" 25 | 26 | find . -type d -name "test.snapshot" | while read SNAPSHOT_ROOT; do 27 | echo "$SNAPSHOT_ROOT" 28 | EXP_ROOT=`dirname "$SNAPSHOT_ROOT"` 29 | MODEL_ROOT="$EXP_ROOT"/model 30 | ( cd "$MODEL_ROOT" && ls -d snapshot_step_*.index ) 2>/dev/null | sed -e 's/^snapshot_step_\(.*\)\.index$/\1/' | sort -u > "$TMP_FOLDER"/model_iters & 31 | ( cd "$SNAPSHOT_ROOT" && ls -d step_* ) 2>/dev/null | sed -e 's/^step_\(.*\)$/\1/' | sort -nu | head -n -1 | sort > "$TMP_FOLDER"/test_iters & 32 | comm -13 "$TMP_FOLDER"/model_iters "$TMP_FOLDER"/test_iters > "$TMP_FOLDER"/rm_iters & 33 | cat "$TMP_FOLDER"/rm_iters | while read RM_ITER; do 34 | echo "$SNAPSHOT_ROOT"/step_"$RM_ITER" 35 | done | xargs rm -rf 36 | done 37 | 38 | rm -r "$TMP_FOLDER" 39 | -------------------------------------------------------------------------------- /nets/distribution/gaussian_fixedvar_in_01.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import zutils.tf_math_funcs as tmf 3 | import nets.distribution.gaussian_fixedvar 4 | 5 | BaseFactory = nets.distribution.gaussian_fixedvar.Factory 6 | 7 | epsilon = tmf.epsilon 8 | 9 | 10 | class Factory(BaseFactory): 11 | 12 | def __init__(self, *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | if "options" in kwargs and "pre_sigmoid" in kwargs["options"]: 15 | self.pre_sigmoid = kwargs["options"]["pre_sigmoid"] 16 | else: 17 | self.pre_sigmoid = False 18 | 19 | def transform2param(self, input_tensor, latent_dim): 20 | """ Create network for converting input_tensor to distribution parameters 21 | 22 | :param input_tensor: (posterior phase) input tensor for the posterior 23 | :param latent_dim: dimension of the latent_variables 24 | :return: param_tensor - distribution parameters 25 | """ 26 | param_tensor = tf.concat( 27 | [input_tensor[:, :latent_dim] if self.pre_sigmoid else tf.sigmoid(input_tensor[:, :latent_dim]), 28 | tmf.atanh_sigmoid(input_tensor[:, latent_dim:])+epsilon], 29 | axis=1 30 | ) 31 | return param_tensor 32 | 33 | @staticmethod 34 | def sample_to_real(samples): 35 | return tmf.inv_sigmoid(samples) 36 | 37 | @staticmethod 38 | def real_to_samples(samples_in_real): 39 | return tf.sigmoid(samples_in_real) 40 | -------------------------------------------------------------------------------- /evaluation/mean_error_IOD.m: -------------------------------------------------------------------------------- 1 | function mean_error = mean_error_IOD(fit_kp, gt_kp) 2 | %assume the order of keypoints are right eye, left_eye, nose, mouth_corner_r 
mouth_corner_l
3 | fprintf('Calculate the error for each face \n');
4 | size(gt_kp)
5 | error_list = [];
6 | for i=1:size(gt_kp, 1)
7 |     fit_keypoints = squeeze(fit_kp(i, :, :));
8 |     gt_keypoints = squeeze(gt_kp(i, :, :));
9 |     face_error = 0;
10 |     for k = 1:size(gt_kp, 2)
11 |         face_error = face_error + norm(fit_keypoints(k,:)-gt_keypoints(k,:));
12 |     end
13 |     face_error = face_error/size(gt_kp, 2);
14 |     if size(gt_keypoints,1)==5
15 |         right_pupil = gt_keypoints(1, :);
16 |         left_pupil = gt_keypoints(2, :);
17 |     elseif size(gt_keypoints,1)==68
18 |         left_pupil = gt_keypoints(8, :);
19 |         right_pupil = gt_keypoints(11, :);
20 |     elseif (size(gt_keypoints,1)==9 || size(gt_keypoints,1)==7)
21 |         left_pupil = gt_keypoints(1, :);
22 |         right_pupil = gt_keypoints(2, :);
23 |     elseif (size(gt_keypoints,1)==6)
24 |         left_pupil = gt_keypoints(1, :);
25 |         right_pupil = gt_keypoints(2, :);
26 |     elseif (size(gt_keypoints, 1)==32)
27 |         left_pupil = [0,0];
28 |         right_pupil = [0,1];
29 |     end
30 |     IOD = norm(right_pupil-left_pupil);
31 |     if(IOD~=0)
32 |         face_error_normalized = face_error/IOD;
33 |         error_list = [error_list, face_error_normalized];
34 |     end
35 | end
36 | mean_error = mean(error_list);
37 | end
38 | 
--------------------------------------------------------------------------------
/evaluation/face_evaluation.m:
--------------------------------------------------------------------------------
1 | function face_evaluation(train_pred_file, train_gt_file, test_pred_file, test_gt_file)
2 |     train_pred = load(train_pred_file);
3 |     train_pred_kp = train_pred.encoded.structure_param; %scale (0,1)
4 |     train_gt = load(train_gt_file);
5 |     train_gt_hw = train_gt.hw;
6 |     train_gt_kp = train_gt.gt;%scale (0,h-1), scale (0, w-1)
7 | 
8 |     train_gt_kp(:,:,1) = train_gt_kp(:,:,1)./double(repmat(squeeze(train_gt_hw(:,1)),[1, size(train_gt_kp,2)]));
9 |     train_gt_kp(:,:,2) = train_gt_kp(:,:,2)./double(repmat(squeeze(train_gt_hw(:,2)),[1, size(train_gt_kp,2)]));
10 | 
11 |     test_pred = load(test_pred_file);
12 |     test_pred_kp = test_pred.encoded.structure_param; %scale (0,1) scale (0,1)
13 |     test_gt = load(test_gt_file);
14 |     test_gt_kp = test_gt.gt;
15 |     test_gt_hw = test_gt.hw;%scale (0, h-1), scale (0, w-1)
16 | 
17 |     test_gt_kp(:,:,1) = test_gt_kp(:,:,1)./double(repmat(squeeze(test_gt_hw(:,1)),[1, size(test_gt_kp,2)]));
18 |     test_gt_kp(:,:,2) = test_gt_kp(:,:,2)./double(repmat(squeeze(test_gt_hw(:,2)),[1, size(test_gt_kp,2)]));
19 | 
20 |     train_gt_kp = train_gt_kp - 0.5;
21 |     test_gt_kp = test_gt_kp - 0.5;
22 |     train_pred_kp = train_pred_kp - 0.5;
23 |     test_pred_kp = test_pred_kp - 0.5;
24 | 
25 |     W = linear_regressor(train_pred_kp, train_gt_kp);
26 |     test_pred_kp = reshape(test_pred_kp,[size(test_pred_kp,1), size(test_pred_kp,2)*size(test_pred_kp,3)] );
27 |     test_fit_kp = test_pred_kp*W;
28 |     test_fit_kp = reshape(test_fit_kp, [size(test_pred_kp,1), size(test_fit_kp,2)/2, 2]);
29 |     mean_error = mean_error_IOD(test_fit_kp, test_gt_kp);
30 |     mean_error
31 | end
32 | 
--------------------------------------------------------------------------------
/evaluation/cat_evaluation.m:
--------------------------------------------------------------------------------
1 | function cat_evaluation(train_pred_file, train_gt_file, test_pred_file, test_gt_file)
2 |     fprintf('Load the predicted keypoints and ground truth keypoints for the training dataset\n');
3 |     train_pred = load(train_pred_file);
4 |     train_pred_kp = train_pred.encoded.structure_param;
5 |     train_gt = load(train_gt_file);
6 |     train_gt_hw = train_gt.hw*80/100;
7 |     train_gt_kp = 
train_gt.gt(:,[1:4,6:7,9],:)-10; 8 | train_gt_kp(:,:,1) = train_gt_kp(:,:,1)./double(repmat(squeeze(train_gt_hw(:,1)),[1, size(train_gt_kp,2)])); 9 | train_gt_kp(:,:,2) = train_gt_kp(:,:,2)./double(repmat(squeeze(train_gt_hw(:,2)),[1, size(train_gt_kp,2)])); 10 | test_pred = load(test_pred_file); 11 | test_pred_kp = test_pred.encoded.structure_param; 12 | test_gt = load(test_gt_file); 13 | test_gt_kp = test_gt.gt(:,[1:4,6:7,9],:)-10; 14 | test_gt_hw = test_gt.hw*80/100; 15 | test_gt_kp(:,:,1) = test_gt_kp(:,:,1)./double(repmat(squeeze(test_gt_hw(:,1)),[1, size(test_gt_kp,2)])); 16 | test_gt_kp(:,:,2) = test_gt_kp(:,:,2)./double(repmat(squeeze(test_gt_hw(:,2)),[1, size(test_gt_kp,2)])); 17 | 18 | train_gt_kp = train_gt_kp - 0.5; 19 | test_gt_kp = test_gt_kp - 0.5; 20 | train_pred_kp = train_pred_kp - 0.5; 21 | test_pred_kp = test_pred_kp - 0.5; 22 | 23 | W = linear_regressor(train_pred_kp, train_gt_kp); 24 | test_pred_kp = reshape(test_pred_kp,[size(test_pred_kp,1), size(test_pred_kp,2)*size(test_pred_kp,3)] ); 25 | test_fit_kp = test_pred_kp*W; 26 | test_fit_kp = reshape(test_fit_kp, [size(test_pred_kp,1), size(test_fit_kp,2)/2, 2]); 27 | mean_error = mean_error_IOD(test_fit_kp, test_gt_kp); 28 | mean_error 29 | end 30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /var 2 | /exp- 3 | 4 | *.bak 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | 111 | -------------------------------------------------------------------------------- /runner/resumable_data_module_wrapper.py: -------------------------------------------------------------------------------- 1 | from zutils.py_utils import link_with_instance 2 | 3 | 4 | class Net: 5 | 6 | def __init__(self, data_module, _num_sample_factors=None): 7 | 8 | if _num_sample_factors is None: 9 | self._num_sample_factors = 1 10 | else: 11 | self._num_sample_factors = _num_sample_factors 12 | self._num_samples = data_module.num_samples() 13 | self.data_module = data_module 14 | self._total_pos = \ 15 | data_module.num_samples_finished() + \ 16 | self._num_samples * data_module.epoch() 17 | self._total_iter = data_module.iter() 18 | 19 | link_with_instance(self, self.data_module) 20 | 21 | def __call__(self, *args, **kwargs): 22 | return self.next_batch(*args, **kwargs) 23 | 24 | def next_batch(self, batch_size): 25 | self._total_iter += 1 26 | self._total_pos += batch_size 27 | data = self.data_module(batch_size) 28 | return data 29 | 30 | def num_samples_finished(self): 31 | return (self._total_pos % self._num_samples) * self._num_sample_factors 32 | 33 | def epoch(self): 34 | return (self._total_pos // self._num_samples) * self._num_sample_factors 35 | 36 | def iter(self): 37 | return self._total_iter * self._num_sample_factors 38 | 39 | def pos(self): 40 | return self._total_pos * self._num_sample_factors 41 | 42 | def set_iter(self, new_iter): 43 | self._total_iter = new_iter * self._num_sample_factors 44 | 45 | def set_pos(self, new_pos): 46 | self._total_pos = new_pos * self._num_sample_factors 47 | 48 | def reset(self): 49 | self._total_pos = 0 50 | self._total_iter = 0 51 | 52 | def num_samples(self): 53 | return self._num_samples * self._num_sample_factors 54 | 55 | -------------------------------------------------------------------------------- /net_modules/ndeconv.py: -------------------------------------------------------------------------------- 1 | from net_modules.deconv import * 2 | from net_modules.deconv import _deconv2d 3 | from net_modules.deconv import _kernel 4 | from net_modules.deconv import _stride 5 | 6 | import net_modules.ndeconv 7 | 8 | 9 | @prettytensor.Register( 10 | assign_defaults=('activation_fn', 'l2loss', 'stddev', 'batch_normalize')) 11 | class ndeconv2d: 12 | def __init__(self): 13 | self._deconv_internal = None 14 | 15 | def __call__( 
16 | self, 17 | input_layer, 18 | kernel, 19 | depth, 20 | name=PROVIDED, 21 | stride=None, 22 | activation_fn=None, 23 | l2loss=None, 24 | init=None, 25 | stddev=None, 26 | bias=True, 27 | edges=PAD_SAME, 28 | batch_normalize=False 29 | ): 30 | 31 | input_shape = input_layer.shape 32 | self._deconv_internal = _deconv2d() 33 | output_layer = self._deconv_internal( 34 | input_layer, 35 | kernel, depth, name, stride, activation_fn, l2loss, 36 | init, stddev, bias, edges, batch_normalize) 37 | output_shape = output_layer.shape 38 | 39 | output_mask = _deconv_mask(input_shape, output_shape, kernel, stride, edges, output_layer.dtype) 40 | output_layer.with_tensor(tf.multiply(output_layer.tensor, output_mask)) 41 | 42 | return output_layer 43 | 44 | 45 | def _deconv_mask(input_shape, output_shape, kernel, stride, padding, dtype): 46 | with tf.variable_scope("fida_factor"): 47 | with tf.device("/cpu:0"): 48 | filter_mask = tf.ones(shape=_kernel(kernel) + [1, 1], dtype=dtype) 49 | input_mask = tf.ones(shape=[1] + input_shape[1:3] + [1], dtype=dtype) 50 | output_mask = tf.nn.conv2d_transpose( 51 | value=input_mask, filter=filter_mask, output_shape=[1] + output_shape[1:3] + [1], 52 | strides=_stride(stride), padding=padding 53 | ) 54 | output_mask = tf.reduce_mean(output_mask) / output_mask 55 | return output_mask 56 | -------------------------------------------------------------------------------- /net_modules/nearest_upsampling.py: -------------------------------------------------------------------------------- 1 | from net_modules.deconv import * 2 | from net_modules.deconv import _deconv2d 3 | from net_modules.deconv import _kernel 4 | from net_modules.deconv import _stride 5 | from net_modules.deconv import get2d_deconv_output_size 6 | from net_modules.ndeconv import _deconv_mask 7 | from prettytensor import pretty_tensor_class as prettytensor 8 | 9 | import zutils.tf_math_funcs as tmf 10 | 11 | 12 | @prettytensor.Register() 13 | def nearest_upsampling( 14 | input_layer, kernel, stride, edges=PAD_SAME, name=PROVIDED 15 | ): 16 | 17 | assert len(input_layer.shape) == 4, "input rank must be 4" 18 | 19 | kernel = _kernel(kernel) 20 | stride = _stride(stride) 21 | 22 | input_height = input_layer.shape[1] 23 | input_width = input_layer.shape[2] 24 | depth = input_layer.shape[3] 25 | 26 | filter_height = kernel[0] 27 | filter_width = kernel[1] 28 | 29 | row_stride = stride[1] 30 | col_stride = stride[2] 31 | 32 | out_rows, out_cols = get2d_deconv_output_size( 33 | input_height, input_width, filter_height, filter_width, row_stride, col_stride, edges) 34 | 35 | output_shape_3d = [input_layer.shape[0], out_rows, out_cols, depth, 1] 36 | 37 | kernel_3d = kernel + [1] 38 | stride_3d = stride + [1] 39 | 40 | filter_mask = tf.ones(shape=kernel_3d + [1, 1], dtype=input_layer.dtype) 41 | output_tensor = tf.nn.conv3d_transpose( 42 | value=tf.expand_dims(input_layer, axis=-1), 43 | filter=filter_mask, output_shape=output_shape_3d, 44 | strides=stride_3d, padding=edges, name=name 45 | ) 46 | output_tensor = tf.squeeze(output_tensor, axis=4) 47 | if filter_height != row_stride or filter_width != col_stride: 48 | output_mask = _deconv_mask( 49 | tmf.get_shape(input_layer), output_shape_3d[:-1], 50 | kernel, stride, edges, input_layer.dtype) 51 | output_tensor = tf.multiply(output_tensor, output_mask) 52 | 53 | return input_layer.with_tensor(output_tensor) 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /vis/figure_show_callback.m: 
-------------------------------------------------------------------------------- 1 | classdef figure_show_callback < handle 2 | 3 | properties (GetAccess=public, SetAccess=protected) 4 | output_path_pattern = [] 5 | auto_continue = false 6 | end 7 | 8 | methods (Access=public) 9 | 10 | function obj = figure_show_callback(output_path_pattern, auto_continue) 11 | if nargin < 1 12 | output_path_pattern = []; 13 | end 14 | if nargin < 2 15 | auto_continue = false; 16 | end 17 | obj.output_path_pattern = output_path_pattern; 18 | if isempty(output_path_pattern) 19 | assert(~auto_continue, 'Cannot use auto continue for display mode'); 20 | end 21 | obj.auto_continue = auto_continue; 22 | end 23 | 24 | function callback(obj, varargin) 25 | obj.callback_dummy() 26 | obj.callback_no_user_input(varargin{:}) 27 | end 28 | 29 | function callback_dummy(obj) 30 | if ~obj.auto_continue 31 | c = waitforbuttonpress; 32 | if ~isempty(obj.output_path_pattern) && c 33 | obj.auto_continue = true; 34 | fprintf('Batch mode started\n') 35 | end 36 | end 37 | end 38 | 39 | function callback_no_user_input(obj, varargin) 40 | if ~isempty(obj.output_path_pattern) 41 | the_figure = gcf; 42 | the_axis = gca; 43 | file_path = sprintf(obj.output_path_pattern, varargin{:}); 44 | file_path = [file_path '.eps']; 45 | [parent_folder, ~, ~] = fileparts(file_path); 46 | if ~exist(parent_folder, 'file') 47 | mkdir(parent_folder) 48 | end 49 | fprintf('%s : ', file_path); tic 50 | saveas(the_figure, file_path, 'epsc'); 51 | toc 52 | figure(the_figure) 53 | axes(the_axis) 54 | pause(0.00001) 55 | end 56 | end 57 | 58 | end 59 | 60 | end 61 | -------------------------------------------------------------------------------- /runner/data_op.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from operator import itemgetter 3 | 4 | 5 | class RecordedDataFetcher: 6 | def __init__(self, data_mod, batch_size, debug_text=None): 7 | self.iter = 0 8 | self.data_mod = data_mod 9 | self.batch_size = batch_size 10 | self.debug_text = debug_text 11 | 12 | def next_batch(self): 13 | self.iter += 1 14 | if self.debug_text is not None: 15 | print("%s [iter %d]" % (str(self.debug_text), self.iter)) 16 | return self.data_mod(self.batch_size) 17 | 18 | 19 | def tf_variable_from_data_module(*args, **kwargs): 20 | with tf.device("/cpu:0"): 21 | return _tf_variable_from_data_module(*args, **kwargs) 22 | 23 | 24 | def _tf_variable_from_data_module(data_mod, batch_size, output_index=0, debug_text=None): 25 | out_type = data_mod.output_types() 26 | out_shape = data_mod.output_shapes() 27 | for i in range(len(out_type)): 28 | if isinstance(out_type[i], str): 29 | out_type[i] = getattr(tf, out_type[i]) 30 | 31 | rdf = RecordedDataFetcher(data_mod, batch_size, debug_text) 32 | 33 | data = tf.py_func(rdf.next_batch, [], out_type) 34 | if not (isinstance(data, list) or isinstance(data, tuple)): 35 | out_shape = list(out_shape) 36 | out_shape[0] = batch_size 37 | data = tf.reshape(data, out_shape) 38 | else: 39 | data = list(data) 40 | out_shape = list(out_shape) 41 | for i in range(len(out_shape)): 42 | out_shape[i] = list(out_shape[i]) 43 | out_shape[i][0] = batch_size 44 | data[i] = tf.reshape(data[i], out_shape[i]) 45 | 46 | assert output_index is not None, "output_index should be a scalar or a list" 47 | if not (isinstance(data, list) or isinstance(data, tuple)): 48 | data = [data] 49 | 50 | for i in range(len(data)): 51 | data[i] = tf.stop_gradient(data[i]) 52 | 53 | out_keys = 
data_mod.output_keys() 54 | data_dict = dict() 55 | for i in range(len(data)): 56 | data_dict[i] = data[i] 57 | data_dict[out_keys[i]] = data[i] 58 | 59 | if isinstance(output_index, (list, tuple)): 60 | return itemgetter(*output_index)(data_dict) 61 | else: 62 | return itemgetter(output_index)(data_dict) 63 | -------------------------------------------------------------------------------- /net_modules/gen.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from zutils.py_utils import call_func_with_ignored_args 3 | 4 | 5 | def get_class_instance(net_type, net_name, cls_name, *args, **kwargs): 6 | """ Get an instance of the network class 7 | 8 | :param net_type: 'encoder', 'latent', 'decoder', 'recon' 9 | :param net_name: a string indicating which net func to use 10 | :param cls_name: a string indicating which class to use 11 | :return: 12 | """ 13 | mod_name = "nets."+net_type+"."+net_name 14 | mod_spec = importlib.util.find_spec(mod_name) 15 | if mod_spec is None: 16 | return None 17 | cls_mod = importlib.import_module(mod_name) 18 | if hasattr(cls_mod, cls_name): 19 | cls_type = getattr(cls_mod, cls_name) 20 | cls_instance = call_func_with_ignored_args(cls_type, *args, **kwargs) 21 | return cls_instance 22 | else: 23 | return None 24 | 25 | 26 | def try_get_net_factory(net_type, net_name, *args, **kwargs): 27 | """ Get an instance of the network factory class 28 | 29 | :param net_type: 'encoder', 'latent', 'decoder', 'recon' 30 | :param net_name: a string indicating which net func to use 31 | :return: 32 | """ 33 | return get_class_instance(net_type, net_name, "Factory", *args, **kwargs) 34 | 35 | 36 | def get_net_factory(net_type, net_name, *args, **kwargs): 37 | """ Get an instance of the network factory class 38 | 39 | :param net_type: 'encoder', 'latent', 'decoder', 'recon' 40 | :param net_name: a string indicating which net func to use 41 | :return: 42 | """ 43 | a = try_get_net_factory(net_type, net_name, *args, **kwargs) 44 | assert a is not None, "Cannot find such a factory" 45 | return a 46 | 47 | 48 | def try_get_net_instance(net_type, net_name, *args, **kwargs): 49 | """ Get an instance of the network class 50 | 51 | :param net_type: 'data' 52 | :param net_name: a string indicating which net func to use 53 | :return: 54 | """ 55 | return get_class_instance(net_type, net_name, "Net", *args, **kwargs) 56 | 57 | 58 | def get_net_instance(net_type, net_name, *args, **kwargs): 59 | """ Get an instance of the network class 60 | 61 | :param net_type: 'data' 62 | :param net_name: a string indicating which net func to use 63 | :return: 64 | """ 65 | a = try_get_net_instance(net_type, net_name, *args, **kwargs) 66 | assert a is not None, "Cannot find such a net" 67 | return a 68 | 69 | -------------------------------------------------------------------------------- /exp-ae-aflw-10.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import sys 4 | from copy import copy 5 | from model.pipeline import Pipeline 6 | 7 | from tensorflow.python import debug as tf_debug 8 | 9 | 10 | if __name__ == "__main__": 11 | 12 | num_keypoints = 10 13 | patch_feature_dim = 8 14 | decoding_levels = 5 15 | kp_transform_loss = 1e4 16 | 17 | recon_weight = 0.1 18 | 19 | learning_rate=0.001 20 | 21 | keypoint_separation_bandwidth=0.06 22 | keypoint_separation_loss_weight = 16.0 23 | 24 | opt = { 25 | "optimizer": "Adam", 26 | "data_name": "aflw_80x80", 27 | "recon_name": 
"gaussian_fixedvar_in_01", 28 | "encoder_name": "general_80x80", 29 | "decoder_name": "general_80x80", 30 | "latent_dim": num_keypoints*2+(num_keypoints+1)*patch_feature_dim, 31 | "train_color_jittering": True, 32 | "train_random_mirroring": False, 33 | "train_batch_size": 16, 34 | "train_shuffle_capacity": 1000, 35 | "learning_rate": learning_rate, 36 | "max_epochs": 2000, 37 | "weight_decay": 1e-6, 38 | "test_steps": 5000, 39 | "test_limit": 200, 40 | "recon_weight": recon_weight, 41 | } 42 | 43 | opt["encoder_options"] = { 44 | "keypoint_num": num_keypoints, 45 | "patch_feature_dim": patch_feature_dim, 46 | "ae_recon_type": opt["recon_name"], 47 | "keypoint_concentration_loss_weight": 100., 48 | "keypoint_axis_balancing_loss_weight": 200., 49 | "keypoint_separation_loss_weight": keypoint_separation_loss_weight, 50 | "keypoint_separation_bandwidth": keypoint_separation_bandwidth, 51 | "keypoint_transform_loss_weight": kp_transform_loss, 52 | "keypoint_decoding_heatmap_levels": decoding_levels, 53 | "keypoint_decoding_heatmap_level_base": 0.5**(1/2), 54 | "image_channels": 3, 55 | } 56 | opt["decoder_options"] = copy(opt["encoder_options"]) 57 | 58 | # ------------------------------------- 59 | model_dir = os.path.join("results/aflw_10") 60 | checkpoint_dir = 'pretrained_results' 61 | checkpoint_filename = 'celeba_10/model/snapshot_step_224549' 62 | vp = Pipeline(None, opt, model_dir=model_dir) 63 | print(vp.opt) 64 | with vp.graph.as_default(): 65 | sess = vp.create_session() 66 | vp.run_full_train_from_checkpoint(sess, checkpoint_dir = checkpoint_dir, checkpoint_filename=checkpoint_filename) 67 | vp.run_full_test(sess) 68 | -------------------------------------------------------------------------------- /exp-ae-aflw-30.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import sys 4 | from copy import copy 5 | from model.pipeline import Pipeline 6 | 7 | from tensorflow.python import debug as tf_debug 8 | 9 | 10 | if __name__ == "__main__": 11 | 12 | num_keypoints = 30 13 | patch_feature_dim = 8 14 | decoding_levels = 5 15 | kp_transform_loss = 1e4 16 | 17 | recon_weight = 0.001 18 | 19 | learning_rate=0.01 20 | 21 | keypoint_separation_bandwidth=0.04 22 | keypoint_separation_loss_weight = 10.0 23 | 24 | opt = { 25 | "optimizer": "Adam", 26 | "data_name": "aflw_80x80", 27 | "recon_name": "gaussian_fixedvar_in_01", 28 | "encoder_name": "general_80x80", 29 | "decoder_name": "general_80x80", 30 | "latent_dim": num_keypoints*2+(num_keypoints+1)*patch_feature_dim, 31 | "train_color_jittering": True, 32 | "train_random_mirroring": False, 33 | "train_batch_size": 8, 34 | "train_shuffle_capacity": 1000, 35 | "learning_rate": learning_rate, 36 | "max_epochs": 2000, 37 | "weight_decay": 1e-6, 38 | "test_steps": 5000, 39 | "test_limit": 200, 40 | "recon_weight": recon_weight, 41 | } 42 | 43 | opt["encoder_options"] = { 44 | "keypoint_num": num_keypoints, 45 | "patch_feature_dim": patch_feature_dim, 46 | "ae_recon_type": opt["recon_name"], 47 | "keypoint_concentration_loss_weight": 100., 48 | "keypoint_axis_balancing_loss_weight": 200., 49 | "keypoint_separation_loss_weight": keypoint_separation_loss_weight, 50 | "keypoint_separation_bandwidth": keypoint_separation_bandwidth, 51 | "keypoint_transform_loss_weight": kp_transform_loss, 52 | "keypoint_decoding_heatmap_levels": decoding_levels, 53 | "keypoint_decoding_heatmap_level_base": 0.5**(1/2), 54 | "image_channels": 3, 55 | } 56 | opt["decoder_options"] = 
copy(opt["encoder_options"]) 57 | 58 | # ------------------------------------- 59 | model_dir = os.path.join("results/aflw_30") 60 | checkpoint_dir = 'pretrained_results' 61 | checkpoint_filename = 'celeba_30/model/snapshot_step_205317' 62 | vp = Pipeline(None, opt, model_dir=model_dir) 63 | print(vp.opt) 64 | with vp.graph.as_default(): 65 | sess = vp.create_session() 66 | vp.run_full_train_from_checkpoint(sess, checkpoint_dir = checkpoint_dir, checkpoint_filename=checkpoint_filename) 67 | vp.run_full_test(sess) 68 | -------------------------------------------------------------------------------- /model/pipeline.py: -------------------------------------------------------------------------------- 1 | from scipy.io import savemat 2 | 3 | import net_modules.gen 4 | from model.pipeline_netdef import PipelineNetDef 5 | from runner.one_epoch_runner import OneEpochRunner 6 | from zutils.py_utils import * 7 | 8 | net_factory = net_modules.gen.get_net_factory 9 | net_instance = net_modules.gen.get_net_instance 10 | 11 | 12 | class Pipeline(PipelineNetDef): 13 | 14 | def __init__(self, *args, **kwargs): 15 | super().__init__(*args, **kwargs) 16 | 17 | # training ------------------------------------------ 18 | 19 | def resume(self, sess): 20 | self.train.groups["ae"].trainer.run(sess) 21 | 22 | # posterior parameters ----------------------------------------------------------------------- 23 | 24 | def posterior_param(self, sess, output_fn, is_large=False, save_img_data=True): 25 | if save_img_data: 26 | output_list = self.posterior.outputs 27 | else: 28 | output_list = self.posterior.outputs 29 | output_list.pop('data') 30 | print(output_list) 31 | r = OneEpochRunner( 32 | self.posterior.data_module, 33 | output_fn = output_fn, 34 | num_samples=self.opt.test_limit, 35 | output_list=output_list, 36 | disp_time_interval=self.opt.disp_time_interval, 37 | is_large=is_large) 38 | return r.run(sess) 39 | 40 | """ 41 | def dump_posterior_param(self, pp, output_fn): 42 | dir_path = os.path.dirname(output_fn) 43 | if not os.path.exists(dir_path): 44 | os.makedirs(dir_path) 45 | # pickle.dump(pp, open(output_fn + ".p", "wb")) 46 | if "vis" in pp["decoded"]: 47 | pp["decoded"]["vis"] = self.output_scaled_color_image(pp["decoded"]["vis"]) 48 | if "data" in pp: 49 | pp["data"] = self.output_scaled_color_image(pp["data"]) 50 | savemat(output_fn + ".mat", pp) 51 | """ 52 | 53 | # all test ----------------------------------------------------------------------------- 54 | def test(self, sess, output_dir, is_snapshot=False, is_large=False, save_img_data=True): 55 | 56 | def nprint(*args, **kwargs): 57 | if not is_snapshot: 58 | print(*args, **kwargs) 59 | 60 | # not necessary for snapshot 61 | nprint('========== Posterior parameters') 62 | self.posterior_param(sess, os.path.join(output_dir, "posterior_param"), is_large, save_img_data) 63 | nprint('-- Done') 64 | 65 | 66 | -------------------------------------------------------------------------------- /tools/run_test_in_folder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import pickle 4 | from zutils.py_utils import * 5 | from model.pipeline import Pipeline 6 | 7 | 8 | def run_test(folder_path, override_dict, test_path, snapshot_iter, is_large, save_img_data): 9 | 10 | print("Folder path: %s" % folder_path) 11 | 12 | with open(os.path.join(folder_path, "PARAM.p"), 'rb') as f: 13 | opt0 = pickle.load(f) 14 | 15 | # opt = {**opt0, **override_dict} 16 | opt = recursive_merge_dicts(opt0, 
override_dict)
17 | 
18 |     vp = Pipeline(
19 |         None, opt, model_dir=folder_path,
20 |         auto_save_hyperparameters=False, use_logging=False
21 |     )
22 | 
23 |     print(vp.opt)
24 |     with vp.graph.as_default():
25 |         sess = vp.create_session()
26 |         vp.run_full_test_from_checkpoint(sess, test_path=test_path, snapshot_iter=snapshot_iter, is_large=is_large, save_img_data=save_img_data)
27 | 
28 | 
29 | def main():
30 |     if len(sys.argv) < 2:  # sys.argv always contains the script name, so check for the required EXP_PATH argument
31 |         print("Usage: run_test_in_folder.py EXP_PATH [OVERRIDE_PARAM [TEST_PATH [SNAPSHOT_ITER]]]")
32 |         exit(-1)
33 |     folder_path = sys.argv[1]
34 | 
35 |     opt_command = sys.argv[2] if len(sys.argv) > 2 else ""
36 |     override_dict = eval("{%s}" % opt_command)
37 | 
38 |     test_path = sys.argv[3] if len(sys.argv) > 3 else ""
39 |     if not test_path:
40 |         test_path = None
41 | 
42 |     snapshot_iter = sys.argv[4] if len(sys.argv) > 4 else ""
43 |     if not snapshot_iter:
44 |         snapshot_iter = None
45 |     else:
46 |         snapshot_iter = int(snapshot_iter)
47 | 
48 |     is_large = sys.argv[5] if len(sys.argv)>5 else ""
49 |     if is_large == 'False':
50 |         is_large = False
51 |     elif is_large == 'True':
52 |         is_large = True
53 |     else:
54 |         is_large = False
55 | 
56 |     save_img_data = sys.argv[6] if len(sys.argv)>6 else ""
57 |     if save_img_data == 'False':
58 |         save_img_data = False
59 |     elif save_img_data == 'True':
60 |         save_img_data = True
61 |     else:
62 |         save_img_data = True
63 | 
64 |     if folder_path != "-":
65 |         run_test(folder_path, override_dict, test_path, snapshot_iter, is_large, save_img_data)
66 |     else:
67 |         print("Please Input a List of Experiment Folders:")
68 |         for line in sys.stdin:
69 |             line = line[:-1]
70 |             if not line:
71 |                 continue
72 |             run_test(line, override_dict, test_path, snapshot_iter, is_large, save_img_data)
73 | 
74 | 
75 | if __name__ == "__main__":
76 |     main()
77 | 
--------------------------------------------------------------------------------
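run_test_in_folder.py is the entry point the one_step_test_*.sh scripts call; the positional arguments are EXP_PATH, OVERRIDE_PARAM (the body of a Python dict literal), TEST_PATH, SNAPSHOT_ITER, IS_LARGE, and SAVE_IMG_DATA. A typical single-folder invocation (the experiment path mirrors the pretrained models shipped with the repo):

    python3 tools/run_test_in_folder.py pretrained_results/celeba_10 \
        "'test_subset':'test', 'test_limit':None" test.test "" False True

Passing "-" as EXP_PATH makes the script read a list of experiment folders from stdin, which is what bin/batch_test.sh relies on.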
tf.reshape(se, input_shape) 46 | return se 47 | 48 | def mean(self, input_tensor): 49 | 50 | input_tensor, input_shape = self.flatten_dist_tensor(input_tensor) 51 | 52 | dist_param, param_tensor = self.visible_dist(input_tensor) 53 | vis = self.output_dist.mean(dist_param) 54 | 55 | param_tensor = tf.reshape( 56 | param_tensor, [input_shape[0], self.param_num()]+input_shape[1:]) 57 | vis = tf.reshape(vis, input_shape) 58 | return vis, param_tensor 59 | 60 | visualize = mean 61 | 62 | def visible_dist(self, input_tensor): 63 | s = tmf.get_shape(input_tensor) 64 | latent_dim = np.prod(s[1:])//self.param_num() 65 | param_tensor = self.output_dist.transform2param(input_tensor, latent_dim) 66 | dist_param = self.output_dist.parametrize(param_tensor, latent_dim) 67 | return dist_param, param_tensor 68 | 69 | def param_num(self): 70 | return self.output_dist.param_num() 71 | -------------------------------------------------------------------------------- /net_modules/nconv.py: -------------------------------------------------------------------------------- 1 | from net_modules.ndeconv import * 2 | from net_modules.ndeconv import _deconv_mask 3 | from prettytensor import parameters 4 | from prettytensor.pretty_tensor_image_methods import conv2d as pt_conv2d 5 | 6 | import zutils.tf_math_funcs as tmf 7 | from net_modules.deconv import _kernel 8 | from net_modules.deconv import _stride 9 | 10 | 11 | class _conv2d(prettytensor.VarStoreMethod): 12 | __call__ = pt_conv2d.__call__ 13 | 14 | 15 | @prettytensor.Register( 16 | assign_defaults=('activation_fn', 'l2loss', 'batch_normalize', 17 | 'parameter_modifier', 'phase')) 18 | class nconv2d: 19 | 20 | tmp_graph = tf.Graph() 21 | 22 | def __init__(self): 23 | self._internal_conv2d = _conv2d() 24 | 25 | def __call__( 26 | self, 27 | input_layer, 28 | kernel, 29 | depth, 30 | activation_fn=None, 31 | stride=(1, 1), 32 | l2loss=None, 33 | weights=None, 34 | bias=tf.zeros_initializer(), 35 | edges=PAD_SAME, 36 | batch_normalize=False, 37 | phase=prettytensor.Phase.train, 38 | parameter_modifier=parameters.identity, 39 | name=PROVIDED 40 | ): 41 | 42 | # compute output size 43 | input_shape = input_layer.shape 44 | input_mask_shape = [1]+input_shape[1:3]+[1] 45 | with self.tmp_graph.as_default(), tf.device("/cpu:0"): 46 | fake_output_mask = tf.nn.conv2d( 47 | input=tf.zeros(shape=input_mask_shape, dtype=tf.float32), 48 | filter=tf.zeros(shape=_kernel(kernel)+[1, 1], dtype=tf.float32), 49 | strides=_stride(stride), 50 | padding=edges 51 | ) 52 | output_mask_shape = tmf.get_shape(fake_output_mask) 53 | 54 | # generate input mask 55 | input_mask = _deconv_mask( 56 | output_mask_shape, input_mask_shape, kernel, stride, edges, input_layer.dtype 57 | ) 58 | 59 | input_layer.with_tensor( 60 | _gradient_scale(input_layer.tensor, input_mask) 61 | ) 62 | 63 | output_layer = self._internal_conv2d( 64 | input_layer, 65 | kernel, 66 | depth, 67 | activation_fn, 68 | stride, 69 | l2loss, 70 | weights, 71 | bias, 72 | edges, 73 | batch_normalize, 74 | phase, 75 | parameter_modifier, 76 | name 77 | ) 78 | 79 | return output_layer 80 | 81 | 82 | def _gradient_scale(data, mask): 83 | 84 | data_grad_mask = (data - tf.stop_gradient(data) + 1.) 
* mask 85 | unscaled_data = tf.stop_gradient(data / mask) * data_grad_mask 86 | return unscaled_data 87 | -------------------------------------------------------------------------------- /exp-ae-cat-10.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import sys 4 | from copy import copy 5 | from model.pipeline import Pipeline 6 | 7 | from tensorflow.python import debug as tf_debug 8 | 9 | 10 | if __name__ == "__main__": 11 | 12 | num_keypoints = 10 13 | patch_feature_dim = 8 14 | decoding_levels = 5 15 | kp_transform_loss = 5000 16 | 17 | base_recon_weight = 0.0001 18 | recon_weight = Pipeline.ValueScheduler( 19 | "piecewise_constant", 20 | [100000, 200000], 21 | [base_recon_weight, base_recon_weight*10, base_recon_weight*100] 22 | ) 23 | 24 | 25 | base_learning_rate=0.001 26 | learning_rate = Pipeline.ValueScheduler( 27 | "piecewise_constant", 28 | [100000, 200000], 29 | [base_learning_rate, base_learning_rate*0.1, base_learning_rate*0.01] 30 | ) 31 | 32 | keypoint_separation_bandwidth=0.08 33 | keypoint_separation_loss_weight = 20.0 34 | 35 | opt = { 36 | "optimizer": "Adam", 37 | "data_name": "cat_80x80", 38 | "recon_name": "gaussian_fixedvar_in_01", 39 | "encoder_name": "general_80x80", 40 | "decoder_name": "general_80x80", 41 | "latent_dim": num_keypoints*2+(num_keypoints+1)*patch_feature_dim, 42 | "train_color_jittering": True, 43 | "train_random_mirroring": False, 44 | "train_batch_size": 16, 45 | "train_shuffle_capacity": 1000, 46 | "learning_rate": learning_rate, 47 | "max_epochs": 2000, 48 | "weight_decay": 1e-6, 49 | "test_steps": 5000, 50 | "test_limit": 200, 51 | "recon_weight": recon_weight, 52 | } 53 | opt["encoder_options"] = { 54 | "keypoint_num": num_keypoints, 55 | "patch_feature_dim": patch_feature_dim, 56 | "ae_recon_type": opt["recon_name"], 57 | "keypoint_concentration_loss_weight": 100., 58 | "keypoint_axis_balancing_loss_weight": 200., 59 | "keypoint_separation_loss_weight": keypoint_separation_loss_weight, 60 | "keypoint_separation_bandwidth": keypoint_separation_bandwidth, 61 | "keypoint_transform_loss_weight": kp_transform_loss, 62 | "keypoint_decoding_heatmap_levels": decoding_levels, 63 | "keypoint_decoding_heatmap_level_base": 0.5**(1/2), 64 | "image_channels": 3, 65 | } 66 | opt["decoder_options"] = copy(opt["encoder_options"]) 67 | 68 | 69 | # ------------------------------------- 70 | model_dir = os.path.join("results/cat_10") 71 | vp = Pipeline(None, opt, model_dir=model_dir) 72 | print(vp.opt) 73 | with vp.graph.as_default(): 74 | sess = vp.create_session() 75 | vp.run_full_train(sess, restore=True) 76 | vp.run_full_test(sess) 77 | 78 | -------------------------------------------------------------------------------- /exp-ae-cat-20.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import sys 4 | from copy import copy 5 | from model.pipeline import Pipeline 6 | 7 | from tensorflow.python import debug as tf_debug 8 | 9 | 10 | if __name__ == "__main__": 11 | 12 | num_keypoints = 20 13 | patch_feature_dim = 8 14 | decoding_levels = 5 15 | kp_transform_loss = 5000 16 | 17 | base_recon_weight = 0.0001 18 | recon_weight = Pipeline.ValueScheduler( 19 | "piecewise_constant", 20 | [100000, 200000], 21 | [base_recon_weight, base_recon_weight*10, base_recon_weight*100] 22 | ) 23 | 24 | 25 | base_learning_rate=0.001 26 | learning_rate = Pipeline.ValueScheduler( 27 | "piecewise_constant", 28 | [100000, 
200000], 29 | [base_learning_rate, base_learning_rate*0.1, base_learning_rate*0.01] 30 | ) 31 | 32 | keypoint_separation_bandwidth=0.05 33 | keypoint_separation_loss_weight = 10.0 34 | 35 | opt = { 36 | "optimizer": "Adam", 37 | "data_name": "cat_80x80", 38 | "recon_name": "gaussian_fixedvar_in_01", 39 | "encoder_name": "general_80x80", 40 | "decoder_name": "general_80x80", 41 | "latent_dim": num_keypoints*2+(num_keypoints+1)*patch_feature_dim, 42 | "train_color_jittering": True, 43 | "train_random_mirroring": False, 44 | "train_batch_size": 16, 45 | "train_shuffle_capacity": 1000, 46 | "learning_rate": learning_rate, 47 | "max_epochs": 2000, 48 | "weight_decay": 1e-6, 49 | "test_steps": 5000, 50 | "test_limit": 200, 51 | "recon_weight": recon_weight, 52 | } 53 | opt["encoder_options"] = { 54 | "keypoint_num": num_keypoints, 55 | "patch_feature_dim": patch_feature_dim, 56 | "ae_recon_type": opt["recon_name"], 57 | "keypoint_concentration_loss_weight": 100., 58 | "keypoint_axis_balancing_loss_weight": 200., 59 | "keypoint_separation_loss_weight": keypoint_separation_loss_weight, 60 | "keypoint_separation_bandwidth": keypoint_separation_bandwidth, 61 | "keypoint_transform_loss_weight": kp_transform_loss, 62 | "keypoint_decoding_heatmap_levels": decoding_levels, 63 | "keypoint_decoding_heatmap_level_base": 0.5**(1/2), 64 | "image_channels": 3, 65 | } 66 | opt["decoder_options"] = copy(opt["encoder_options"]) 67 | 68 | 69 | # ------------------------------------- 70 | model_dir = os.path.join("results/cat_20") 71 | vp = Pipeline(None, opt, model_dir=model_dir) 72 | print(vp.opt) 73 | with vp.graph.as_default(): 74 | sess = vp.create_session() 75 | vp.run_full_train(sess, restore=True) 76 | vp.run_full_test(sess) 77 | 78 | -------------------------------------------------------------------------------- /exp-ae-celeba-mafl-30.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import sys 4 | from copy import copy 5 | from model.pipeline import Pipeline 6 | 7 | from tensorflow.python import debug as tf_debug 8 | 9 | 10 | if __name__ == "__main__": 11 | 12 | num_keypoints = 30 13 | patch_feature_dim = 8 14 | decoding_levels = 5 15 | kp_transform_loss = 1e4 16 | 17 | base_recon_weight = 0.1 18 | recon_weight = Pipeline.ValueScheduler( 19 | "piecewise_constant", 20 | [100000, 200000], 21 | [base_recon_weight, base_recon_weight*100, base_recon_weight*1000] 22 | ) 23 | 24 | base_learning_rate=0.01 25 | learning_rate = Pipeline.ValueScheduler( 26 | "piecewise_constant", 27 | [100000, 200000], 28 | [base_learning_rate, base_learning_rate*0.1, base_learning_rate*0.01] 29 | ) 30 | 31 | keypoint_separation_bandwidth=0.04 32 | keypoint_separation_loss_weight = 10 33 | 34 | opt = { 35 | "optimizer": "Adam", 36 | "data_name": "celeba_mafl_100x100_80x80", 37 | "recon_name": "gaussian_fixedvar_in_01", 38 | "encoder_name": "general_80x80", 39 | "decoder_name": "general_80x80", 40 | "latent_dim": num_keypoints*2+(num_keypoints+1)*patch_feature_dim, 41 | "train_color_jittering": True, 42 | "train_random_mirroring": False, 43 | "train_batch_size": 8, 44 | "train_shuffle_capacity": 1000, 45 | "learning_rate": learning_rate, 46 | "max_epochs": 2000, 47 | "weight_decay": 1e-6, 48 | "test_steps": 5000, 49 | "test_limit": 200, 50 | "recon_weight": recon_weight, 51 | } 52 | opt["encoder_options"] = { 53 | "keypoint_num": num_keypoints, 54 | "patch_feature_dim": patch_feature_dim, 55 | "ae_recon_type": opt["recon_name"], 56 | 
"keypoint_concentration_loss_weight": 100., 57 | "keypoint_axis_balancing_loss_weight": 200., 58 | "keypoint_separation_loss_weight": keypoint_separation_loss_weight, 59 | "keypoint_separation_bandwidth": keypoint_separation_bandwidth, 60 | "keypoint_transform_loss_weight": kp_transform_loss, 61 | "keypoint_decoding_heatmap_levels": decoding_levels, 62 | "keypoint_decoding_heatmap_level_base": 0.5**(1/2), 63 | "image_channels": 3, 64 | 65 | } 66 | opt["decoder_options"] = copy(opt["encoder_options"]) 67 | 68 | 69 | # ------------------------------------- 70 | model_dir = os.path.join("results/celeba_30") 71 | vp = Pipeline(None, opt, model_dir=model_dir) 72 | print(vp.opt) 73 | with vp.graph.as_default(): 74 | sess = vp.create_session() 75 | vp.run_full_train(sess, restore=True) 76 | vp.run_full_test(sess) 77 | 78 | -------------------------------------------------------------------------------- /download_cat.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # =========================== 3 | # Usage: ./setup.sh (model|data)? 4 | 5 | if wget --help | grep -q 'show-progress'; then 6 | WGET_FLAG="-q --show-progress" 7 | else 8 | WGET_FLAG="" 9 | fi 10 | 11 | # create a tmp directory for the downloading data 12 | TMP_DIR="./tmp_download" 13 | mkdir -p "${TMP_DIR}" 14 | 15 | #create the directory for the pre-trained model 16 | MODEL_DIR="./pretrained_results" 17 | mkdir -p "${MODEL_DIR}" 18 | 19 | #create the directory for the dataset 20 | DATA_DIR="./data" 21 | mkdir -p "${DATA_DIR}" 22 | 23 | # downloading model 24 | download_model() 25 | { 26 | # directory for aflw model 27 | TMP_MODEL_TAR_BALL="${TMP_DIR}/cat_pretrained_results.tar.gz" 28 | 29 | MODEL_URL="http://files.ytzhang.net/lmdis-rep/release-v1/cat/cat_pretrained_results.tar.gz" 30 | echo "Downloading pre-trained models ..." 31 | wget ${WGET_FLAG} "${MODEL_URL}" -O "${TMP_MODEL_TAR_BALL}" 32 | echo "Uncompressing pre-trained models ..." 33 | tar -xzf "${TMP_MODEL_TAR_BALL}" -C "${TMP_DIR}" 34 | 35 | # move model to default directories 36 | echo "Move pre-trained image network model to ${MODEL_DIR} ..." 37 | mv "${TMP_DIR}/cat_pretrained_results/cat_10" "${MODEL_DIR}/cat_10" 38 | mv "${TMP_DIR}/cat_pretrained_results/cat_20" "${MODEL_DIR}/cat_20" 39 | 40 | } 41 | 42 | # downloading data 43 | download_data() 44 | { 45 | # directory for cat data 46 | TMP_DATA_TAR_BALL="${DATA_DIR}/cat_data.tar.gz" 47 | DATA_URL="http://files.ytzhang.net/lmdis-rep/release-v1/cat/cat_data.tar.gz" 48 | echo "Downloading data ..." 49 | wget ${WGET_FLAG} "${DATA_URL}" -O "${TMP_DATA_TAR_BALL}" 50 | echo "Uncompressing data ..." 51 | tar -xzf "${TMP_DATA_TAR_BALL}" -C "${DATA_DIR}" 52 | rm -rf "${TMP_DATA_TAR_BALL}" 53 | 54 | TMP_IMAGE_TAR_BALL="${DATA_DIR}/cat_images.tar.gz" 55 | IMAGE_URL="http://files.ytzhang.net/lmdis-rep/release-v1/cat/cat_images.tar.gz" 56 | echo "Downloading images ..." 57 | wget ${WGET_FLAG} "${IMAGE_URL}" -O "${TMP_IMAGE_TAR_BALL}" 58 | echo "Uncompressing images ..." 59 | tar -xzf "${TMP_IMAGE_TAR_BALL}" -C "${DATA_DIR}" 60 | rm -rf "${TMP_IMAGE_TAR_BALL}" 61 | } 62 | 63 | # default to download all 64 | if [ $# -eq 0 ]; then 65 | download_model 66 | download_data 67 | else 68 | case $1 in 69 | "model") download_model 70 | ;; 71 | "data") download_data 72 | ;; 73 | *) echo "Usage: ./setup.sh [OPTION]" 74 | echo "" 75 | echo "No option will download both model and data." 
76 | echo "" 77 | echo "OPTION:\n\tmodel: only download the pre-trained models (.npy)" 78 | echo "\tdata: only download the data(.json)" 79 | ;; 80 | esac 81 | fi 82 | 83 | # clear the tmp files 84 | rm -rf "${TMP_DIR}" 85 | -------------------------------------------------------------------------------- /exp-ae-celeba-mafl-10.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import sys 4 | from copy import copy 5 | from model.pipeline import Pipeline 6 | 7 | from tensorflow.python import debug as tf_debug 8 | 9 | 10 | if __name__ == "__main__": 11 | 12 | num_keypoints = 10 13 | patch_feature_dim = 8 14 | decoding_levels = 5 15 | kp_transform_loss = 1e4 16 | 17 | base_recon_weight = 0.1 18 | recon_weight = Pipeline.ValueScheduler( 19 | "piecewise_constant", 20 | [100000, 200000], 21 | [base_recon_weight, base_recon_weight*10, base_recon_weight*100] 22 | ) 23 | 24 | base_learning_rate=0.001 25 | learning_rate = Pipeline.ValueScheduler( 26 | "piecewise_constant", 27 | [100000, 200000], 28 | [base_learning_rate, base_learning_rate*0.1, base_learning_rate*0.01] 29 | ) 30 | 31 | keypoint_separation_bandwidth=0.08 32 | keypoint_separation_loss_weight = 20 33 | 34 | opt = { 35 | "optimizer": "Adam", 36 | "data_name": "celeba_mafl_100x100_80x80", 37 | "recon_name": "gaussian_fixedvar_in_01", 38 | "encoder_name": "general_80x80", 39 | "decoder_name": "general_80x80", 40 | "latent_dim": num_keypoints*2+(num_keypoints+1)*patch_feature_dim, 41 | "train_color_jittering": True, 42 | "train_random_mirroring": False, 43 | "train_batch_size": 16, 44 | "train_shuffle_capacity": 1000, 45 | "learning_rate": learning_rate, 46 | "max_epochs": 2000, 47 | "weight_decay": 1e-6, 48 | "test_steps": 5000, 49 | "test_limit": 200, 50 | "recon_weight": recon_weight, 51 | #"keep_checkpoint_every_n_hours": 0.1 52 | } 53 | opt["encoder_options"] = { 54 | "keypoint_num": num_keypoints, 55 | "patch_feature_dim": patch_feature_dim, 56 | "ae_recon_type": opt["recon_name"], 57 | "keypoint_concentration_loss_weight": 100., 58 | "keypoint_axis_balancing_loss_weight": 200. 
, 59 | "keypoint_separation_loss_weight": keypoint_separation_loss_weight, 60 | "keypoint_separation_bandwidth": keypoint_separation_bandwidth, 61 | "keypoint_transform_loss_weight": kp_transform_loss, 62 | "keypoint_decoding_heatmap_levels": decoding_levels, 63 | "keypoint_decoding_heatmap_level_base": 0.5**(1/2), 64 | "image_channels": 3, 65 | 66 | } 67 | opt["decoder_options"] = copy(opt["encoder_options"]) 68 | 69 | 70 | # ------------------------------------- 71 | model_dir = os.path.join("results/celeba_10") 72 | vp = Pipeline(None, opt, model_dir=model_dir) 73 | print(vp.opt) 74 | with vp.graph.as_default(): 75 | sess = vp.create_session() 76 | vp.run_full_train(sess, restore=True) 77 | vp.run_full_test(sess) 78 | 79 | -------------------------------------------------------------------------------- /model/options.py: -------------------------------------------------------------------------------- 1 | from runner.train_pipeline_traindef import TrainDefOptions 2 | 3 | 4 | # default options of the model 5 | class ModelOptionDefinition: 6 | 7 | @staticmethod 8 | def main_model(p): 9 | 10 | # non-linearity 11 | p["non_linearity"] = "leaky_relu" 12 | 13 | # dataset name 14 | p["data_name"] = "mnist_binary_rbm" 15 | p["data_options"] = dict() 16 | 17 | p["test_data_name"] = p["data_name"] 18 | p["test_data_options"] = p["data_options"] 19 | 20 | # network names 21 | if p["data_name"] in ("mnist", ): 22 | p["latent_dim"] = 10 23 | p["encoder_name"] = "PLACE_HODLER" 24 | p["recon_name"] = "bernoullix" 25 | else: 26 | p["latent_dim"] = 128 27 | p["encoder_name"] = "PLACE_HODLER" 28 | p["recon_name"] = "guassian_in_01" 29 | 30 | p["decoder_name"] = p["encoder_name"] # by default, use the paired encoder and decoder 31 | 32 | # latent name 33 | p["decoder_options"] = dict() 34 | p["recon_options"] = dict() 35 | 36 | if p["recon_name"] in {"gaussian_in_01", "gaussian_fixedvar_in_01"}: 37 | p.set( 38 | "recon_options", 39 | { 40 | **{"stddev_lower_bound": 0.05}, # 0.005 41 | **p["recon_options"] 42 | } 43 | ) 44 | 45 | p["recon_weight"] = 1.0 46 | p["encoder_options"] = dict() 47 | 48 | p["condition_list"] = [] 49 | 50 | @staticmethod 51 | def model(p): 52 | p.require("main_model") 53 | 54 | 55 | # default options for the pipeline 56 | class PipelineOptionDefinition(ModelOptionDefinition, TrainDefOptions): 57 | 58 | @staticmethod 59 | def train(p): 60 | 61 | p.require("model") 62 | p.include("trainer") 63 | 64 | # big datasets 65 | if p["data_name"] in ( 66 | "CelebA", "celeba_80x80_landmark", 67 | "human_128x128", "human_128x128_landmark" 68 | ): 69 | assert p["train_shuffle_capacity"] != "full", "should not use full train_shuffle_capacity for large dataset" 70 | 71 | 72 | p["train_subset"] = "train" 73 | 74 | p["data_class_list"] = None 75 | 76 | # preprocessing 77 | if p["data_name"] in ("mnist",): 78 | p["train_color_jittering"] = False 79 | 80 | @staticmethod 81 | def test(p): 82 | 83 | p["rotate_batch_samples"] = False 84 | 85 | p["test_subset"] = "test" 86 | p["test_limit"] = None 87 | p.require("train") 88 | p["test_batch_size"] = p["train_batch_size"] 89 | 90 | @staticmethod 91 | def all(p): 92 | 93 | p.require("model") 94 | p.require("train") 95 | p.require("test") 96 | 97 | p.finalize() 98 | 99 | -------------------------------------------------------------------------------- /net_modules/pixel_bias.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | import collections 4 | 5 | from prettytensor import layers 6 | from 
prettytensor import parameters 7 | from prettytensor import pretty_tensor_class as prettytensor 8 | from prettytensor.pretty_tensor_class import PROVIDED 9 | 10 | 11 | @prettytensor.Register(assign_defaults=('activation_fn', 'parameter_modifier', 'phase')) 12 | class pixel_bias(prettytensor.VarStoreMethod): 13 | 14 | def __call__( 15 | self, input_layer, activation_fn=None, bias=tf.zeros_initializer(), phase=prettytensor.Phase.train, 16 | parameter_modifier=parameters.identity, name=PROVIDED 17 | ): 18 | """ 19 | Adds a learned per-pixel bias to the input and returns a tensor. 20 | The current PrettyTensor must have rank 4 with known spatial and channel dimensions. 21 | Args: 22 | input_layer: The Pretty Tensor object, supplied. 23 | activation_fn: An optional activation function (or a sequence of the function and its arguments) applied after the bias. 24 | bias: An initializer for the bias or a Tensor. No bias if set to None. 25 | phase: The phase of graph construction. See `pt.Phase`. 26 | parameter_modifier: A function to modify parameters that is applied after 27 | creation and before use. 28 | name: The name for this operation is also used to create/find the 29 | parameter variables. 30 | Returns: 31 | A Pretty Tensor handle to the layer. 32 | Raises: 33 | ValueError: if the Pretty Tensor is not rank 4 or its height, width, 34 | or channel dimension is not known. 35 | """ 36 | 37 | if input_layer.get_shape().ndims != 4: 38 | raise ValueError( 39 | 'pixel_bias requires a rank-4 Tensor with a known shape: %s' 40 | % input_layer.get_shape()) 41 | if input_layer.shape[1] is None or input_layer.shape[2] is None or input_layer.shape[3] is None: 42 | raise ValueError('input size must be known.') 43 | 44 | x = input_layer.tensor 45 | dtype = input_layer.dtype 46 | books = input_layer.bookkeeper 47 | b = parameter_modifier( 48 | 'bias', 49 | self.variable('bias', input_layer.shape[2:], bias, dt=dtype), 50 | phase) 51 | y = x + tf.expand_dims(b, axis=0) 52 | 53 | if activation_fn is not None: 54 | if not isinstance(activation_fn, collections.Sequence): 55 | activation_fn = (activation_fn,) 56 | y = layers.apply_activation(books, 57 | y, 58 | activation_fn[0], 59 | activation_args=activation_fn[1:]) 60 | books.add_histogram_summary(y, '%s/activations' % y.op.name) 61 | 62 | return input_layer.with_tensor(y, parameters=self.vars) 63 | 64 | # pylint: enable=invalid-name 65 | -------------------------------------------------------------------------------- /download_aflw.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # =========================== 3 | # Usage: ./download_aflw.sh (model|data)? 4 | 5 | if wget --help | grep -q 'show-progress'; then 6 | WGET_FLAG="-q --show-progress" 7 | else 8 | WGET_FLAG="" 9 | fi 10 | 11 | # create a tmp directory for the downloaded data 12 | TMP_DIR="./tmp_download" 13 | mkdir -p "${TMP_DIR}" 14 | 15 | # create the directory for the pre-trained models 16 | MODEL_DIR="./pretrained_results" 17 | mkdir -p "${MODEL_DIR}" 18 | 19 | # create the directory for the dataset 20 | DATA_DIR="./data" 21 | mkdir -p "${DATA_DIR}" 22 | 23 | # downloading model 24 | download_model() 25 | { 26 | # directory for the aflw models 27 | TMP_MODEL_TAR_BALL="${TMP_DIR}/aflw_pretrained_results.tar.gz" 28 | 29 | MODEL_URL="http://files.ytzhang.net/lmdis-rep/release-v1/aflw/aflw_pretrained_results.tar.gz" 30 | echo "Downloading pre-trained models ..." 31 | wget ${WGET_FLAG} "${MODEL_URL}" -O "${TMP_MODEL_TAR_BALL}" 32 | echo "Uncompressing pre-trained models ..." 
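# extract into the temporary directory first; the model folders are moved into ${MODEL_DIR} below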
33 | tar -xzf "${TMP_MODEL_TAR_BALL}" -C "${TMP_DIR}" 34 | 35 | # move the models to the default directories 36 | echo "Moving pre-trained models to ${MODEL_DIR} ..." 37 | mv "${TMP_DIR}/aflw_pretrained_results/aflw_10" "${MODEL_DIR}/aflw_10" 38 | mv "${TMP_DIR}/aflw_pretrained_results/aflw_30" "${MODEL_DIR}/aflw_30" 39 | 40 | } 41 | 42 | # downloading data 43 | download_data() 44 | { 45 | # directory for the aflw data 46 | TMP_DATA_TAR_BALL="${DATA_DIR}/aflw_data.tar.gz" 47 | DATA_URL="http://files.ytzhang.net/lmdis-rep/release-v1/aflw/aflw_data.tar.gz" 48 | echo "Downloading data ..." 49 | wget ${WGET_FLAG} "${DATA_URL}" -O "${TMP_DATA_TAR_BALL}" 50 | echo "Uncompressing data ..." 51 | tar -xzf "${TMP_DATA_TAR_BALL}" -C "${DATA_DIR}" 52 | rm -rf "${TMP_DATA_TAR_BALL}" 53 | 54 | TMP_IMAGE_TAR_BALL="${DATA_DIR}/aflw_images.tar.gz" 55 | IMAGE_URL="http://files.ytzhang.net/lmdis-rep/release-v1/aflw/aflw_images.tar.gz" 56 | echo "Downloading images ..." 57 | wget ${WGET_FLAG} "${IMAGE_URL}" -O "${TMP_IMAGE_TAR_BALL}" 58 | echo "Uncompressing images ..." 59 | tar -xzf "${TMP_IMAGE_TAR_BALL}" -C "${DATA_DIR}" 60 | mv "${DATA_DIR}/output" "${DATA_DIR}/aflw_images" 61 | rm -rf "${TMP_IMAGE_TAR_BALL}" 62 | } 63 | 64 | # default to downloading everything 65 | if [ $# -eq 0 ]; then 66 | download_model 67 | download_data 68 | else 69 | case $1 in 70 | "model") download_model 71 | ;; 72 | "data") download_data 73 | ;; 74 | *) echo "Usage: ./download_aflw.sh [OPTION]" 75 | echo "" 76 | echo "With no option, both the pre-trained models and the data are downloaded." 77 | echo "" 78 | echo "OPTION:\n\tmodel: only download the pre-trained models" 79 | echo "\tdata: only download the data" 80 | ;; 81 | esac 82 | fi 83 | 84 | # clear the tmp files 85 | rm -rf "${TMP_DIR}" 86 | -------------------------------------------------------------------------------- /runner/train_pipeline.py: -------------------------------------------------------------------------------- 1 | import zutils.tf_graph_utils as tgu 2 | from zutils.py_utils import * 3 | 4 | from runner.base_pipeline import Pipeline as BasePipeline 5 | 6 | from runner.train_pipeline_traindef import TrainDef as PipelineTrainDef 7 | from runner.train_pipeline_traindef import TrainDefOptions as PipelineTrainDefOptions 8 | 9 | from runner.data_op import tf_variable_from_data_module 10 | 11 | import runner.preprocessing_data_module_wrapper 12 | 13 | TrainDefOptions = PipelineTrainDefOptions 14 | 15 | 16 | class Pipeline(BasePipeline): 17 | 18 | TrainDef = PipelineTrainDef 19 | TrainDefOptions = PipelineTrainDefOptions 20 | 21 | def __init__(self, model_dir=None, **kwarg): 22 | 23 | super().__init__(model_dir, **kwarg) 24 | 25 | # gpu names 26 | self.ps_device_name, self.gpu_names = self._init_gpu_names() 27 | 28 | @staticmethod 29 | def _init_gpu_names(num_gpus=None, ps_cpu=True): 30 | # check devices 31 | ps_device = None 32 | gpu_names = None 33 | all_gpu_names = tgu.get_available_gpus() 34 | assert len(all_gpu_names) <= 1, \ 35 | "ERROR: this code does not support multiple devices. Use CUDA_VISIBLE_DEVICES=... to specify the GPU." 
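# both values remain None: with at most one visible GPU, TensorFlow's default device placement is used, so no parameter-server/GPU split is configured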
36 | 37 | return ps_device, gpu_names 38 | 39 | def init_logger_saver(self): 40 | 41 | self.init_logger() 42 | 43 | with self.graph.as_default(): 44 | 45 | print("* Link logger with trainer") 46 | for train_obj in self.train.groups.values(): 47 | train_obj.trainer.set_logger(self.logger) 48 | 49 | # saver 50 | print("* Define net+trainer saver") 51 | saver_kwargs = dict() 52 | if self.opt.keep_checkpoint_every_n_hours is not None: 53 | saver_kwargs["keep_checkpoint_every_n_hours"] = self.opt.keep_checkpoint_every_n_hours 54 | if self.opt.max_checkpoint_to_keep is not None: 55 | saver_kwargs["max_to_keep"] = self.opt.max_checkpoint_to_keep 56 | self.saver = tgu.MultiDeviceSaver(var_list=self.train.saveable_variables, **saver_kwargs) 57 | print("* Define net saver") 58 | self.net_saver = tgu.MultiDeviceSaver(var_list=self.train.net_variables, **saver_kwargs) 59 | 60 | def data_module_preprocess(self, raw_data_module, mode=None): 61 | return runner.preprocessing_data_module_wrapper.Net(raw_data_module, self.opt_dict, mode) 62 | 63 | def data_module_to_tensor(self, raw_data_module, data_fields, batch_size): 64 | preprocessed_data_module = self.data_module_preprocess(raw_data_module, mode="deterministic") 65 | return tf_variable_from_data_module(preprocessed_data_module, batch_size, data_fields) 66 | 67 | def output_scaled_color_image(self, im): 68 | if self.opt.image_color_scaling: 69 | im = (im - (1-self.opt.image_color_scaling)*0.5) / self.opt.image_color_scaling 70 | return im 71 | -------------------------------------------------------------------------------- /nets/decoder/car_64x64.py: -------------------------------------------------------------------------------- 1 | import prettytensor as pt 2 | import tensorflow as tf 3 | from net_modules.common import * 4 | 5 | from net_modules.hourglass import hourglass 6 | 7 | import zutils.tf_math_funcs as tmf 8 | 9 | import zutils.pt_utils as ptu 10 | 11 | from net_modules.auto_struct.generic_decoder import Factory as GenericDecoderFactory 12 | from net_modules.auto_struct.keypoint_decoder import Factory as BaseFactory 13 | 14 | class Factory(BaseFactory): 15 | 16 | @staticmethod 17 | def pt_defaults_scope_value(): 18 | return { 19 | # non-linearity 20 | 'activation_fn': default_activation.current_value, 21 | # for batch normalization 22 | 'batch_normalize': True, 23 | 'learned_moments_update_rate': 0.0003, 24 | # 'learned_moments_update_rate': 1., 25 | 'variance_epsilon': 0.001, 26 | 'scale_after_normalization': True 27 | } 28 | 29 | default_patch_feature_dim = 8 30 | 31 | def __init__(self, recon_dist_param_num=1, options=None): 32 | super().__init__(recon_dist_param_num, options) 33 | if "image_channels" in options: 34 | self.image_channels = options["image_channels"] 35 | else: 36 | self.image_channels = 3 37 | 38 | def image_size(self): 39 | return 64, 64 40 | 41 | def input_feature_dim(self): 42 | return 64 43 | 44 | def feature2image(self, feature_tensor): 45 | output_channels = 3*self.recon_dist_param_num 46 | 47 | hgd = [ 48 | {"type": "conv2d", "depth": 64, "decoder_depth": output_channels, "decoder_activation_fn": None}, 49 | {"type": "conv2d", "depth": 64, "decoder_depth": 32}, 50 | {"type": "skip", "layer_num": 2}, 51 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 40 x 40 52 | {"type": "conv2d", "depth": 128, "decoder_depth": 64}, 53 | {"type": "skip", "layer_num": 2}, 54 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 20x20 55 | {"type": "conv2d", "depth": 256}, 56 | {"type": "skip", "layer_num": 2}, 
57 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 10x10 58 | {"type": "conv2d", "depth": 512}, 59 | {"type": "skip", "layer_num": 2}, 60 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 5x5 61 | {"type": "conv2d", "depth": 512}, 62 | ] 63 | 64 | 65 | with pt.defaults_scope(**self.pt_defaults_scope_value()): 66 | output_tensor = hourglass( 67 | feature_tensor, hgd, 68 | net_type=self.options["hourglass_type"] if "hourglass_type" in self.options else None, 69 | extra_highlevel_feature=None 70 | ) 71 | return output_tensor 72 | 73 | rotate_dominating_features_if_necessary = GenericDecoderFactory.rotate_dominating_features_if_necessary 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /nets/decoder/general_80x80.py: -------------------------------------------------------------------------------- 1 | import prettytensor as pt 2 | import tensorflow as tf 3 | from net_modules.common import * 4 | 5 | from net_modules.hourglass import hourglass 6 | 7 | 8 | import zutils.tf_math_funcs as tmf 9 | 10 | import zutils.pt_utils as ptu 11 | 12 | from net_modules.auto_struct.generic_decoder import Factory as GenericDecoderFactory 13 | from net_modules.auto_struct.keypoint_decoder import Factory as BaseFactory 14 | 15 | class Factory(BaseFactory): 16 | 17 | @staticmethod 18 | def pt_defaults_scope_value(): 19 | return { 20 | # non-linearity 21 | 'activation_fn': default_activation.current_value, 22 | # for batch normalization 23 | 'batch_normalize': True, 24 | 'learned_moments_update_rate': 0.0003, 25 | # 'learned_moments_update_rate': 1., 26 | 'variance_epsilon': 0.001, 27 | 'scale_after_normalization': True 28 | } 29 | 30 | default_patch_feature_dim = 8 31 | 32 | def __init__(self, recon_dist_param_num=1, options=None): 33 | super().__init__(recon_dist_param_num, options) 34 | if "image_channels" in options: 35 | self.image_channels = options["image_channels"] 36 | else: 37 | self.image_channels = 3 38 | 39 | def image_size(self): 40 | return 80, 80 41 | 42 | 43 | def input_feature_dim(self): 44 | return 64 45 | 46 | def feature2image(self, feature_tensor): 47 | 48 | output_channels = 3*self.recon_dist_param_num 49 | hgd = [ 50 | {"type": "conv2d", "depth": 64, "decoder_depth": output_channels, "decoder_activation_fn": None}, 51 | {"type": "conv2d", "depth": 64, "decoder_depth": 32}, 52 | {"type": "skip", "layer_num": 2}, 53 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 40 x 40 54 | {"type": "conv2d", "depth": 128, "decoder_depth": 64}, 55 | {"type": "skip", "layer_num": 2}, 56 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 20x20 57 | {"type": "conv2d", "depth": 256}, 58 | {"type": "skip", "layer_num": 2}, 59 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 10x10 60 | {"type": "conv2d", "depth": 512}, 61 | {"type": "skip", "layer_num": 2}, 62 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 5x5 63 | {"type": "conv2d", "depth": 512}, 64 | ] 65 | 66 | with pt.defaults_scope(**self.pt_defaults_scope_value()): 67 | output_tensor = hourglass( 68 | feature_tensor, hgd, 69 | net_type=self.options["hourglass_type"] if "hourglass_type" in self.options else None, 70 | extra_highlevel_feature=None 71 | ) 72 | 73 | return output_tensor 74 | 75 | rotate_dominating_features_if_necessary = GenericDecoderFactory.rotate_dominating_features_if_necessary 76 | 77 | -------------------------------------------------------------------------------- /nets/decoder/general_64x64.py: 
-------------------------------------------------------------------------------- 1 | import prettytensor as pt 2 | import tensorflow as tf 3 | from net_modules.common import * 4 | 5 | from net_modules.hourglass import hourglass 6 | 7 | import zutils.tf_math_funcs as tmf 8 | 9 | import zutils.pt_utils as ptu 10 | 11 | from net_modules.auto_struct.generic_decoder import Factory as GenericDecoderFactory 12 | from net_modules.auto_struct.keypoint_decoder import Factory as BaseFactory 13 | 14 | 15 | class Factory(BaseFactory): 16 | 17 | @staticmethod 18 | def pt_defaults_scope_value(): 19 | return { 20 | # non-linearity 21 | 'activation_fn': default_activation.current_value, 22 | # for batch normalization 23 | 'batch_normalize': True, 24 | 'learned_moments_update_rate': 0.0003, 25 | # 'learned_moments_update_rate': 1., 26 | 'variance_epsilon': 0.001, 27 | 'scale_after_normalization': True 28 | } 29 | 30 | default_patch_feature_dim = 8 31 | 32 | def __init__(self, recon_dist_param_num=1, options=None): 33 | super().__init__(recon_dist_param_num, options) 34 | if "image_channels" in options: 35 | self.image_channels = options["image_channels"] 36 | else: 37 | self.image_channels = 3 38 | 39 | def image_size(self): 40 | return 64, 64 41 | 42 | def input_feature_dim(self): 43 | return 64 44 | 45 | def feature2image(self, feature_tensor): 46 | output_channels = 3*self.recon_dist_param_num 47 | 48 | hgd = [ 49 | {"type": "conv2d", "depth": 64, "decoder_depth": output_channels, "decoder_activation_fn": None}, 50 | {"type": "conv2d", "depth": 64, "decoder_depth": 32}, 51 | {"type": "skip", "layer_num": 2}, 52 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2},  # 32x32 53 | {"type": "conv2d", "depth": 128, "decoder_depth": 64}, 54 | {"type": "skip", "layer_num": 2}, 55 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2},  # 16x16 56 | {"type": "conv2d", "depth": 256}, 57 | {"type": "skip", "layer_num": 2}, 58 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2},  # 8x8 59 | {"type": "conv2d", "depth": 512}, 60 | {"type": "skip", "layer_num": 2}, 61 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2},  # 4x4 62 | {"type": "conv2d", "depth": 512}, 63 | ] 64 | 65 | with pt.defaults_scope(**self.pt_defaults_scope_value()): 66 | output_tensor = hourglass( 67 | feature_tensor, hgd, 68 | net_type=self.options["hourglass_type"] if "hourglass_type" in self.options else None, 69 | extra_highlevel_feature=None 70 | ) 71 | 72 | return output_tensor 73 | 74 | rotate_dominating_features_if_necessary = GenericDecoderFactory.rotate_dominating_features_if_necessary 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /net_modules/spatial_transformer_pt.py: -------------------------------------------------------------------------------- 1 | from net_modules.deconv import * 2 | from net_modules.spatial_transformer import transformer 3 | from net_modules.tps_stn import TPS_STN, TPS_TRANSFORM 4 | import numpy as np 5 | 6 | import zutils.tf_math_funcs as tmf 7 | 8 | 9 | @prettytensor.Register() 10 | def spatial_transformer( 11 | input_layer, theta, out_size, name=PROVIDED 12 | ): 13 | 14 | # init 15 | input_shape = tmf.get_shape(input_layer.tensor) 16 | assert len(input_shape) == 4, "input tensor must be rank 4" 17 | if isinstance(theta, np.ndarray): 18 | theta = tf.constant(theta) 19 | elif not tmf.is_tf_data(theta): 20 | theta = theta.tensor 21 | 22 | # apply transformer 23 | output = transformer(input_layer.tensor, theta, out_size=out_size, name=name) 
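# theta is expected to hold one 2x3 affine matrix (6 values) per batch element, as consumed by net_modules.spatial_transformer.transformer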
24 | 25 | # make output shape explicit 26 | output = tf.reshape(output, [input_shape[0]]+out_size+[input_shape[3]]) 27 | return output 28 | 29 | 30 | @prettytensor.Register() 31 | def coordinate_inv_transformer( 32 | input_layer, theta, name=PROVIDED 33 | ): 34 | 35 | # init 36 | input_tensor = input_layer.tensor 37 | input_shape = tmf.get_shape(input_tensor) 38 | assert len(input_shape) == 3, "input tensor must be rank 3" 39 | if isinstance(theta, np.ndarray): 40 | theta = tf.constant(theta) 41 | elif not tmf.is_tf_data(theta): 42 | theta = theta.tensor 43 | 44 | keypoint_num = tmf.get_shape(input_tensor)[1] 45 | 46 | with tf.variable_scope(name): 47 | kp2_e = tf.concat([input_tensor, tf.ones_like(input_tensor[:, :, :1])], axis=2) 48 | kp2_e = tf.expand_dims(kp2_e, axis=-1) 49 | transform_e = tf.tile(tf.expand_dims(theta, axis=1), [1, keypoint_num, 1, 1]) 50 | kp1from2_e = tf.matmul(transform_e, kp2_e) 51 | kp1from2 = tf.squeeze(kp1from2_e, axis=-1) 52 | 53 | return kp1from2 54 | 55 | 56 | @prettytensor.Register() 57 | def spatial_transformer_tps( 58 | input_layer, nx, ny, cp, out_size, fp_more=None, name=PROVIDED 59 | ): 60 | 61 | # init 62 | input_shape = tmf.get_shape(input_layer.tensor) 63 | assert len(input_shape) == 4, "input tensor must be rank 4" 64 | 65 | def convert_to_tensor(a): 66 | if isinstance(a, np.ndarray): 67 | a = tf.constant(a) 68 | elif not tmf.is_tf_data(a): 69 | a = a.tensor 70 | return a 71 | 72 | cp = convert_to_tensor(cp) 73 | 74 | batch_size = tmf.get_shape(input_layer)[0] 75 | 76 | # apply transformer 77 | with tf.variable_scope(name): 78 | cp = tf.reshape(cp, [batch_size, -1, 2]) 79 | output = TPS_STN(input_layer.tensor, nx, ny, cp, out_size=out_size, fp_more=fp_more) 80 | 81 | # make output shape explicit 82 | output = tf.reshape(output, [input_shape[0]]+out_size+[input_shape[3]]) 83 | return output 84 | 85 | 86 | @prettytensor.Register() 87 | def coordinate_inv_transformer_tps( 88 | input_layer, nx, ny, cp, fp_more=None, name=PROVIDED 89 | ): 90 | input_shape = tmf.get_shape(input_layer.tensor) 91 | assert len(input_shape) == 3, "input tensor must be rank 3" 92 | p = input_layer.tensor 93 | with tf.variable_scope(name): 94 | output = TPS_TRANSFORM(nx, ny, cp, p, fp_more=fp_more) 95 | return output 96 | -------------------------------------------------------------------------------- /download_celeba.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # =========================== 3 | # Usage: ./download_celeba.sh (model|data)? 4 | 5 | if wget --help | grep -q 'show-progress'; then 6 | WGET_FLAG="-q --show-progress" 7 | else 8 | WGET_FLAG="" 9 | fi 10 | 11 | # create a tmp directory for the downloaded data 12 | TMP_DIR="./tmp_download" 13 | mkdir -p "${TMP_DIR}" 14 | 15 | # create the directory for the pre-trained models 16 | MODEL_DIR="./pretrained_results" 17 | mkdir -p "${MODEL_DIR}" 18 | 19 | # create the directory for the dataset 20 | DATA_DIR="./data" 21 | mkdir -p "${DATA_DIR}" 22 | 23 | # downloading model 24 | download_model() 25 | { 26 | # directory for the celeba models 27 | TMP_MODEL_TAR_BALL="${TMP_DIR}/celeba_pretrained_results.tar.gz" 28 | 29 | MODEL_URL="http://files.ytzhang.net/lmdis-rep/release-v1/celeba/celeba_pretrained_results.tar.gz" 30 | echo "Downloading pre-trained models ..." 31 | wget ${WGET_FLAG} "${MODEL_URL}" -O "${TMP_MODEL_TAR_BALL}" 32 | echo "Uncompressing pre-trained models ..." 
33 | tar -xzf "${TMP_MODEL_TAR_BALL}" -C "${TMP_DIR}" 34 | 35 | # move the models to the default directories 36 | echo "Moving pre-trained models to ${MODEL_DIR} ..." 37 | mv "${TMP_DIR}/celeba_pretrained_results/celeba_10" "${MODEL_DIR}/celeba_10" 38 | mv "${TMP_DIR}/celeba_pretrained_results/celeba_30" "${MODEL_DIR}/celeba_30" 39 | 40 | } 41 | 42 | # downloading data 43 | download_data() 44 | { 45 | # directory for the celeba data 46 | TMP_DATA_TAR_BALL="${DATA_DIR}/celeba_data.tar.gz" 47 | DATA_URL="http://files.ytzhang.net/lmdis-rep/release-v1/celeba/celeba_data.tar.gz" 48 | echo "Downloading data ..." 49 | wget ${WGET_FLAG} "${DATA_URL}" -O "${TMP_DATA_TAR_BALL}" 50 | echo "Uncompressing data ..." 51 | tar -xzf "${TMP_DATA_TAR_BALL}" -C "${DATA_DIR}" 52 | rm -rf "${TMP_DATA_TAR_BALL}" 53 | 54 | if [ ! -d "${DATA_DIR}/celeba_images" ]; then 55 | echo "Warning! Please download the CelebA data from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and place the folder at data/celeba_images ..." 56 | elif [ ! -d "${DATA_DIR}/celeba_images/Eval" ]; then 57 | echo "Warning! Please download the CelebA data from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and place the Eval folder at data/celeba_images/Eval ..." 58 | elif [ ! -d "${DATA_DIR}/celeba_images/Img" ]; then 59 | echo "Warning! Please download the CelebA data from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and place the Img folder at data/celeba_images/Img ..." 60 | fi 61 | 62 | if [ ! -d "${DATA_DIR}/celeba_images/Img/img_align_celeba_png" ]; then 63 | echo "The official CelebA website provides JPEG images in Img/img_align_celeba. If you want to experiment with the CelebA PNG images, you can download them from http://files.ytzhang.net/lmdis-rep/release-v1/celeba/img_align_celeba_png.tar.gz and save the images in data/celeba_images/Img/img_align_celeba_png" 64 | fi 65 | } 66 | 67 | # default to downloading everything 68 | if [ $# -eq 0 ]; then 69 | download_model 70 | download_data 71 | else 72 | case $1 in 73 | "model") download_model 74 | ;; 75 | "data") download_data 76 | ;; 77 | *) echo "Usage: ./download_celeba.sh [OPTION]" 78 | echo "" 79 | echo "With no option, both the pre-trained models and the data are downloaded." 80 | echo "" 81 | echo "OPTION:\n\tmodel: only download the pre-trained models" 82 | echo "\tdata: only download the data" 83 | ;; 84 | esac 85 | fi 86 | 87 | # clear the tmp files 88 | rm -rf "${TMP_DIR}" 89 | -------------------------------------------------------------------------------- /vis/vppAutoKeypointImageRecon.m: -------------------------------------------------------------------------------- 1 | function vppAutoKeypointImageRecon(result_path, step, samples_ids, save_to_file, type_ids) 2 | 3 | if ~exist('type_ids', 'var') 4 | type_ids = {}; 5 | end 6 | 7 | if ischar(type_ids) 8 | type_ids = {type_ids}; 9 | end 10 | 11 | vppDetailsResultsTemplate( ... 12 | @(varargin) vppAutoKeypointReconPrior_Internal(type_ids, varargin{:}), ... 13 | result_path, step, samples_ids, save_to_file, '%s/%s_%s'); 14 | 15 | function vppAutoKeypointReconPrior_Internal( ... 
16 | type_ids, result_path, the_title, samples_ids, callback) 17 | 18 | A = load(fullfile(result_path, 'posterior_param.mat')); 19 | 20 | if ~isfield(A.encoded, 'structure_param') 21 | A.encoded.structure_param = zeros(numel(samples_ids),0); 22 | end 23 | 24 | if ~isfield(A.decoded, 'structure_param') 25 | A.decoded.structure_param = zeros(numel(samples_ids),0); 26 | end 27 | 28 | 29 | h = size(A.data, 2); 30 | w = size(A.data, 3); 31 | sample_num = size(A.data, 1); 32 | 33 | if isempty(samples_ids) 34 | samples_ids = 1:sample_num; 35 | end 36 | samples_ids = reshape(samples_ids, 1, numel(samples_ids)); 37 | 38 | N = numel(samples_ids); 39 | fw = ceil(sqrt(N*(4/3)*(h/w))); 40 | if (N/fw)==ceil(N/fw) 41 | ; 42 | elseif (N/(fw-1))==ceil((N-1)/(fw-1)) 43 | fw = fw-1; 44 | elseif (N/(fw+1))==ceil((N-1)/(fw+1)) 45 | fw = fw+1; 46 | end 47 | fh = ceil(N/fw); 48 | 49 | id_str = sprintf('%d-%d', min(samples_ids), max(samples_ids)); 50 | 51 | type_id_list = {'data-encoded', 'data-decoded', 'recon-encoded', 'recon-decoded'}; 52 | 53 | if ~isempty(type_ids) 54 | assert(all(ismember(type_ids, type_id_list)), 'unrecognized ids'); 55 | type_id_list = type_ids; 56 | end 57 | 58 | for type_id = 1:length(type_id_list) 59 | 60 | type_str = type_id_list{type_id}; 61 | 62 | D = struct(); 63 | switch type_str 64 | case 'data-encoded' 65 | D.vis = A.data; 66 | D.structure_param = A.encoded.structure_param; 67 | case 'data-decoded' 68 | D.vis = A.data; 69 | D.structure_param = A.decoded.structure_param; 70 | case 'recon-encoded' 71 | D.vis = A.decoded.vis; 72 | D.structure_param = A.encoded.structure_param; 73 | case 'recon-decoded' 74 | D.vis = A.decoded.vis; 75 | D.structure_param = A.decoded.structure_param; 76 | otherwise 77 | error('Internal error: Unrecognized type') 78 | end 79 | 80 | figure(1) 81 | set(gcf, 'color', 'white'); 82 | clf 83 | for j = 1:N 84 | sidx = samples_ids(j); 85 | subplot(fh, fw, j) 86 | vppAutoKeypointShowSingle(squeeze(D.vis(sidx, :, :, :)), squeeze(D.structure_param(sidx, :, :))) 87 | title(int2str(sidx)) 88 | 89 | ha = axes('Position',[0 0 1 1],'Xlim',[0 1],'Ylim',[0 1],'Box','off','Visible','off','Units','normalized', 'clipping' , 'off'); 90 | text(0.5, 1,sprintf('%s %s (samples %s) %s', '\bf', the_title, id_str, type_str), ... 91 | 'HorizontalAlignment' ,'center','VerticalAlignment', 'top'); 92 | 93 | end 94 | 95 | if type_id ~= length(type_id_list) 96 | callback.callback( ... 97 | [result_path '_recon_keypoints_batch'], ... 98 | id_str, type_str ... 99 | ); 100 | else 101 | callback.callback_no_user_input( ... 102 | [result_path '_recon_keypoints_batch'], ... 103 | id_str, type_str ... 
104 | ); 105 | end 106 | end 107 | 108 | -------------------------------------------------------------------------------- /nets/decoder/general_128x128_landmark.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import zutils.tf_math_funcs as tmf 3 | import numpy as np 4 | import zutils.pt_utils as ptu 5 | import prettytensor as pt 6 | from net_modules.common import * 7 | 8 | from net_modules.hourglass import hourglass 9 | 10 | from net_modules.auto_struct.keypoint_decoder import Factory as BaseFactory 11 | 12 | 13 | class Factory(BaseFactory): 14 | 15 | @staticmethod 16 | def pt_defaults_scope_value(): 17 | return { 18 | # non-linearity 19 | 'activation_fn': default_activation.current_value, 20 | # for batch normalization 21 | 'batch_normalize': True, 22 | 'learned_moments_update_rate': 0.0003, 23 | # 'learned_moments_update_rate': 1., 24 | 'variance_epsilon': 0.001, 25 | 'scale_after_normalization': True 26 | } 27 | 28 | default_patch_feature_dim = 8 29 | 30 | def __init__(self, recon_dist_param_num=1, options=None): 31 | super().__init__(recon_dist_param_num, options) 32 | if "image_channels" in options: 33 | self.image_channels = options["image_channels"] 34 | else: 35 | self.image_channels = 3 36 | 37 | def image_size(self): 38 | return 64, 64 39 | 40 | def feature2image(self, feature_map): 41 | output_channels = 3*self.recon_dist_param_num 42 | hgd = [ 43 | {"type": "conv2d", "depth": 64}, 44 | {"type": "skip", "layer_num": 2}, 45 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 32x32 46 | {"type": "conv2d", "depth": 128}, 47 | {"type": "skip", "layer_num": 2}, 48 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 16x16 49 | {"type": "conv2d", "depth": 256}, 50 | {"type": "skip", "layer_num": 2}, 51 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 8x8 52 | {"type": "conv2d", "depth": 512}, 53 | {"type": "skip", "layer_num": 2}, 54 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 4x4 55 | {"type": "conv2d", "depth": 512}, 56 | ] 57 | with pt.defaults_scope(**self.pt_defaults_scope_value()): 58 | output_tensor = hourglass( 59 | feature_map, hgd, 60 | net_type=self.options["hourglass_type"] if "hourglass_type" in self.options else None, 61 | extra_highlevel_feature=None 62 | ) 63 | output_tensor = ( 64 | pt.wrap(output_tensor). 65 | deconv2d(3, 32, stride=2). 66 | conv2d(3, output_channels, activation_fn=None) 67 | ).tensor 68 | return output_tensor 69 | 70 | def input_feature_dim(self): 71 | return 64 72 | 73 | def bg_feature2image(self, bg_feature): 74 | batch_size = tmf.get_shape(bg_feature)[0] 75 | with pt.defaults_scope(**self.pt_defaults_scope_value()): 76 | return ( 77 | pt.wrap(bg_feature). 78 | conv2d(3, 512). # 2 79 | deconv2d(3, 512, stride=2). # 4 80 | deconv2d(3, 256, stride=2). # 8 81 | deconv2d(3, 256, stride=2). # 16 82 | deconv2d(3, 128, stride=2). # 32 83 | deconv2d(3, 64, stride=2). # 64 84 | deconv2d(3, 32, stride=2). 
# 128 85 | conv2d(3, 3*self.recon_dist_param_num, activation_fn=None).tensor 86 | ) 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /nets/distribution/spike_at_zero.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from collections import OrderedDict 3 | import math 4 | import zutils.tf_math_funcs as tmf 5 | import nets.distribution.generic 6 | 7 | GenericFactory = nets.distribution.generic.Factory 8 | 9 | epsilon = tmf.epsilon 10 | 11 | 12 | class Factory(GenericFactory): 13 | 14 | def __init__(self): 15 | pass 16 | 17 | @staticmethod 18 | def is_atomic():  # True for discrete distribution 19 | return True 20 | 21 | @staticmethod 22 | def param_num(): 23 | return 0 24 | 25 | @staticmethod 26 | def param_dict(): 27 | return OrderedDict() 28 | 29 | @staticmethod 30 | def transform2param(input_tensor, latent_dim): 31 | """ Create network for converting input_tensor to distribution parameters 32 | 33 | :param input_tensor: (posterior phase) input tensor for the posterior 34 | :param latent_dim: dimension of the latent_variables 35 | :return: param_tensor - distribution parameters 36 | """ 37 | return input_tensor 38 | 39 | @classmethod 40 | def parametrize(cls, param_tensor, latent_dim): 41 | """ Create network for converting parameter_tensor to parameter dictionary 42 | 43 | :param param_tensor: (posterior phase) input tensor for the posterior 44 | :param latent_dim: dimension of the latent_variables 45 | :return: dist_param - distribution parameters 46 | """ 47 | dist_param = cls.param_dict() 48 | return dist_param 49 | 50 | @classmethod 51 | def deparametrize(cls, dist_param): 52 | param_tensor = None 53 | return param_tensor 54 | 55 | @staticmethod 56 | def nll(dist_param, samples): 57 | """ Compute negative log likelihood on given sample and distribution parameter 58 | 59 | :param samples: samples for evaluating PDF 60 | :param dist_param: input for the posterior 61 | :return: nll - negative log likelihood of the given samples under the distribution 62 | :return: is_atomic - is atomic, scalar or the same size as nll 63 | """ 64 | spike_nll = tf.where( 65 | tf.equal(samples, tf.zeros_like(samples)),  # elementwise comparison 66 | tf.ones_like(samples), -tf.ones_like(samples) * math.inf) 67 | return spike_nll, True 68 | 69 | @staticmethod 70 | def sampling(dist_param, batch_size, latent_dim): 71 | """ Create network for VAE latent variables (sampling only) 72 | 73 | :param dist_param: input for the posterior 74 | :param batch_size: batch size 75 | :param latent_dim: dimension of the latent_variables 76 | :return: samples - random samples from either posterior or prior distribution 77 | """ 78 | 79 | # generate random samples 80 | return tf.zeros([batch_size, latent_dim]) 81 | 82 | @staticmethod 83 | def self_entropy(dist_param): 84 | return 0.0 85 | 86 | @classmethod 87 | def kl_divergence(cls, dist_param, ref_dist_param, ref_dist_type=None): 88 | if not isinstance(ref_dist_type, cls) and ref_dist_type is not None:  # handle hybrid distribution 89 | return None 90 | return 0.0 91 | 92 | @classmethod 93 | def cross_entropy(cls, dist_param, ref_dist_param, ref_dist_type=None): 94 | if not isinstance(ref_dist_type, cls) and ref_dist_type is not None:  # handle hybrid distribution 95 | return None 96 | return 0.0 97 | 98 | @staticmethod 99 | def mean(dist_param): 100 | return 0.0 101 | -------------------------------------------------------------------------------- /nets/encoder/general_64x64.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import prettytensor as pt 4 | from net_modules.common import * 5 | from net_modules import keypoints_2d 6 | import math 7 | 8 | from net_modules.hourglass import hourglass 9 | 10 | import zutils.tf_math_funcs as tmf 11 | from net_modules.auto_struct.keypoint_encoder import Factory as BaseFactory 12 | 13 | class Factory(BaseFactory): 14 | @staticmethod 15 | def pt_defaults_scope_value(): 16 | return { 17 | # non-linearity 18 | 'activation_fn': default_activation.current_value, 19 | # for batch normalization 20 | 'batch_normalize': True, 21 | 'learned_moments_update_rate': 0.0003, 22 | 'variance_epsilon': 0.001, 23 | 'scale_after_normalization': True 24 | } 25 | 26 | default_patch_feature_dim = 8 27 | 28 | def __init__(self, output_channels, options): 29 | """ 30 | :param output_channels: output_channels for the encoding net 31 | """ 32 | super().__init__(output_channels, options) 33 | 34 | self.target_input_size = [80, 80] 35 | self.input_size = [64, 64] 36 | 37 | def image2heatmap(self, image_tensor): 38 | hgd = [ 39 | {"type": "conv2d", "depth": 32, "decoder_depth": self.options["keypoint_num"] + 1, 40 | "decoder_activation_fn": None}, 41 | # plus one for bg 42 | {"type": "conv2d", "depth": 32}, 43 | {"type": "skip", "layer_num": 3, }, 44 | {"type": "pool", "pool": "max"}, 45 | {"type": "conv2d", "depth": 64}, 46 | {"type": "conv2d", "depth": 64}, 47 | {"type": "skip", "layer_num": 3, }, 48 | {"type": "pool", "pool": "max"}, 49 | {"type": "conv2d", "depth": 64}, 50 | {"type": "conv2d", "depth": 64}, 51 | {"type": "skip", "layer_num": 3, }, 52 | {"type": "pool", "pool": "max"}, 53 | {"type": "conv2d", "depth": 64}, 54 | {"type": "conv2d", "depth": 64}, 55 | ] 56 | 57 | with pt.defaults_scope(**self.pt_defaults_scope_value()): 58 | raw_heatmap = hourglass( 59 | image_tensor, hgd, 60 | net_type=self.options["hourglass_type"] if "hourglass_type" in self.options else None 61 | ) 62 | # raw_heatmap = pt.wrap(raw_heatmap).pixel_bias(activation_fn=None).tensor 63 | 64 | return raw_heatmap 65 | 66 | 67 | 68 | def image2feature(self, image_tensor): 69 | 70 | if self.patch_feature_dim == 0: 71 | return None 72 | 73 | hgd = [ 74 | {"type": "conv2d", "depth": 32, "decoder_depth": 64}, 75 | {"type": "conv2d", "depth": 64, "decoder_depth": 64}, 76 | {"type": "skip", "layer_num": 2}, 77 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 40 x 40 78 | {"type": "conv2d", "depth": 128}, 79 | {"type": "skip", "layer_num": 2}, 80 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 20x20 81 | {"type": "conv2d", "depth": 256}, 82 | {"type": "skip", "layer_num": 2}, 83 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 10x10 84 | {"type": "conv2d", "depth": 512}, 85 | ] 86 | 87 | with pt.defaults_scope(**self.pt_defaults_scope_value()): 88 | feature_map = hourglass( 89 | image_tensor, hgd, 90 | net_type=self.options["hourglass_type"] if "hourglass_type" in self.options else None 91 | ) 92 | return feature_map 93 | 94 | -------------------------------------------------------------------------------- /nets/encoder/car_64x64.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import prettytensor as pt 4 | from net_modules.common import * 5 | from net_modules import keypoints_2d 6 | import math 7 | 8 | from net_modules.hourglass import hourglass 9 | 10 | import 
zutils.tf_math_funcs as tmf 11 | from net_modules.auto_struct.keypoint_encoder import Factory as BaseFactory 12 | 13 | class Factory(BaseFactory): 14 | @staticmethod 15 | def pt_defaults_scope_value(): 16 | return { 17 | # non-linearity 18 | 'activation_fn': default_activation.current_value, 19 | # for batch normalization 20 | 'batch_normalize': True, 21 | 'learned_moments_update_rate': 0.0003, 22 | 'variance_epsilon': 0.001, 23 | 'scale_after_normalization': True 24 | } 25 | 26 | default_patch_feature_dim = 8 27 | 28 | def __init__(self, output_channels, options): 29 | """ 30 | :param output_channels: output_channels for the encoding net 31 | """ 32 | super().__init__(output_channels, options) 33 | 34 | self.target_input_size = [96, 96] 35 | self.input_size = [64, 64] 36 | 37 | 38 | def image2heatmap(self, image_tensor): 39 | hgd = [ 40 | {"type": "conv2d", "depth": 32, "decoder_depth": self.options["keypoint_num"] + 1, 41 | "decoder_activation_fn": None}, 42 | # plus one for bg 43 | {"type": "conv2d", "depth": 32}, 44 | {"type": "skip", "layer_num": 3, }, 45 | {"type": "pool", "pool": "max"}, 46 | {"type": "conv2d", "depth": 64}, 47 | {"type": "conv2d", "depth": 64}, 48 | {"type": "skip", "layer_num": 3, }, 49 | {"type": "pool", "pool": "max"}, 50 | {"type": "conv2d", "depth": 64}, 51 | {"type": "conv2d", "depth": 64}, 52 | {"type": "skip", "layer_num": 3, }, 53 | {"type": "pool", "pool": "max"}, 54 | {"type": "conv2d", "depth": 64}, 55 | {"type": "conv2d", "depth": 64}, 56 | ] 57 | 58 | with pt.defaults_scope(**self.pt_defaults_scope_value()): 59 | raw_heatmap = hourglass( 60 | image_tensor, hgd, 61 | net_type=self.options["hourglass_type"] if "hourglass_type" in self.options else None 62 | ) 63 | # raw_heatmap = pt.wrap(raw_heatmap).pixel_bias(activation_fn=None).tensor 64 | 65 | return raw_heatmap 66 | 67 | def image2feature(self, image_tensor): 68 | 69 | if self.patch_feature_dim == 0: 70 | return None 71 | 72 | hgd = [ 73 | {"type": "conv2d", "depth": 32, "decoder_depth": 64}, 74 | {"type": "conv2d", "depth": 64, "decoder_depth": 64}, 75 | {"type": "skip", "layer_num": 2}, 76 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 40 x 40 77 | {"type": "conv2d", "depth": 128}, 78 | {"type": "skip", "layer_num": 2}, 79 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 20x20 80 | {"type": "conv2d", "depth": 256}, 81 | {"type": "skip", "layer_num": 2}, 82 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 10x10 83 | {"type": "conv2d", "depth": 512}, 84 | ] 85 | 86 | 87 | with pt.defaults_scope(**self.pt_defaults_scope_value()): 88 | feature_map = hourglass( 89 | image_tensor, hgd, 90 | net_type=self.options["hourglass_type"] if "hourglass_type" in self.options else None 91 | ) 92 | return feature_map 93 | 94 | 95 | -------------------------------------------------------------------------------- /vis/vppDetailsResultsTemplate.m: -------------------------------------------------------------------------------- 1 | function vppDetailsResultsTemplate( ... 
2 | show_func, result_path, step, samples_ids, save_to_file, output_pattern) 3 | 4 | if ~exist('save_to_file', 'var') || isempty(save_to_file) 5 | save_to_file = false; 6 | end 7 | 8 | if ~iscell(result_path) 9 | result_path = {result_path}; 10 | end 11 | 12 | if ~exist('output_pattern', 'var') || isempty(output_pattern) 13 | output_pattern = '%s/%d'; 14 | end 15 | 16 | auto_continue = false; 17 | for k = 1:numel(result_path) 18 | is_last_single = (k==numel(result_path)); 19 | auto_continue = vppDetailsResultsTemplate_Single( ... 20 | show_func, result_path{k}, step, samples_ids, save_to_file, ... 21 | auto_continue, is_last_single, output_pattern); 22 | end 23 | 24 | 25 | function batch_mode = vppDetailsResultsTemplate_Single( ... 26 | show_func, result_path, step, samples_ids, save_to_file, ... 27 | auto_continue, is_last_single, output_pattern) 28 | 29 | batch_mode = auto_continue; 30 | result_path0 = result_path; 31 | 32 | if isempty(step) 33 | step = 'latest'; 34 | end 35 | step0 = step; 36 | if ischar(step0) && (strcmp(step0, 'all') || strcmp(step0, 'latest')) 37 | all_available_steps = dir(fullfile(result_path, 'test.snapshot/step_*')); 38 | all_available_steps = {all_available_steps.name}; 39 | all_available_steps = ... 40 | cellfun(@(a) str2double(a(6:end)), all_available_steps); 41 | all_available_steps(isnan(all_available_steps)) = []; 42 | all_available_steps = sort(all_available_steps); 43 | step = all_available_steps; 44 | if exist(fullfile(result_path, 'test.final'), 'file') 45 | step(end+1) = -1; 46 | end 47 | if isempty(step) 48 | fprintf(2, 'No test snapshot found: %s\n', result_path) 49 | return 50 | end 51 | if strcmp(step0, 'latest') 52 | step = step(end); 53 | end 54 | end 55 | 56 | if save_to_file 57 | callback = figure_show_callback(output_pattern, auto_continue); 58 | else 59 | callback = figure_show_callback(); 60 | end 61 | 62 | if ischar(step) 63 | step = {step}; 64 | end 65 | 66 | for k = 1:numel(step) 67 | if iscell(step(k)) 68 | step_k = step{k}; 69 | else 70 | step_k = step(k); 71 | end 72 | if step_k<0 73 | step_k = 'final'; 74 | end 75 | if ischar(step_k) 76 | step_result_path = fullfile(result_path, sprintf('test.%s', step_k)); 77 | step_str = ['step ' step_k]; 78 | else 79 | step_result_path = fullfile(result_path, sprintf('test.snapshot/step_%d', step_k)); 80 | step_str = sprintf('step %d', step_k); 81 | end 82 | if ~exist(step_result_path, 'dir') 83 | fprintf(2, 'Test snapshot folder not available: %s\n', step_result_path) 84 | continue; 85 | end 86 | fprintf('Display: %s\n', step_result_path) 87 | % step_title = sprintf('step %d', step(k)); 88 | step_title = sprintf('%s\n%s ', ... 89 | strrep(wrap_str_at_comma(result_path0, 100), '_', '\_'), step_str); 90 | show_func( ... 
91 | step_result_path, step_title, samples_ids, callback); 92 | if kline_length 104 | bpb = s==','; 105 | bploc = find(bpb(1:min(end,line_length)),1,'last'); 106 | if isempty(bploc) 107 | bploc = find(bpb,1); 108 | end 109 | if ~isempty(bploc) 110 | r = [r, newline, s(1:bploc)]; 111 | s = s(bploc+1:end); 112 | end 113 | end 114 | r = [r, newline, s]; 115 | r = strtrim(r); 116 | 117 | 118 | -------------------------------------------------------------------------------- /nets/distribution/spike.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from collections import OrderedDict 3 | import math 4 | import zutils.tf_math_funcs as tmf 5 | import nets.distribution.generic 6 | 7 | GenericFactory = nets.distribution.generic.Factory 8 | 9 | epsilon = tmf.epsilon 10 | 11 | 12 | class Factory(GenericFactory): 13 | 14 | def __init__(self, **kwargs): 15 | pass 16 | 17 | @staticmethod 18 | def is_atomic(): # True for discrete distribution 19 | return True 20 | 21 | @staticmethod 22 | def param_num(): 23 | return 1 24 | 25 | @staticmethod 26 | def param_dict(a=0.0): 27 | return OrderedDict(a=a) 28 | 29 | def transform2param(self, input_tensor, latent_dim): 30 | """ Create network for converting input_tensor to distribution parameters 31 | 32 | :param input_tensor: (posterior phase) input tensor for the posterior 33 | :param latent_dim: dimension of the latent_variables 34 | :return: param_tensor - distribution parameters 35 | """ 36 | param_tensor = input_tensor 37 | return param_tensor 38 | 39 | @classmethod 40 | def parametrize(cls, param_tensor, latent_dim): 41 | """ Create network for converting parameter_tensor to parameter dictionary 42 | 43 | :param param_tensor: (posterior phase) input tensor for the posterior 44 | :param latent_dim: dimension of the latent_variables 45 | :return: dist_param - distribution parameters 46 | """ 47 | dist_param = cls.param_dict( 48 | a=param_tensor 49 | ) 50 | return dist_param 51 | 52 | @classmethod 53 | def deparametrize(cls, dist_param): 54 | param_tensor = dist_param["a"] 55 | return param_tensor 56 | 57 | @staticmethod 58 | def nll(dist_param, samples): 59 | """ Compute negative log likelihood on given sample and distribution parameter 60 | 61 | :param samples: samples for evaluating PDF 62 | :param dist_param: input for the posterior 63 | :return: likelihood - likelihood to draw such samples from given distribution 64 | :return: is_atomic - is atomic, scalar or the same size as likelihood 65 | """ 66 | spike_nll = tf.where( 67 | samples == dist_param["a"], 68 | tf.ones_like(samples), -tf.ones_like(samples)*math.inf) 69 | return spike_nll, True 70 | 71 | @staticmethod 72 | def sampling(dist_param, batch_size, latent_dim): 73 | """ Create network for VAE latent variables (sampling only) 74 | 75 | :param dist_param: input for the posterior 76 | :param batch_size: batch size 77 | :param latent_dim: dimension of the latent_variables 78 | :return: samples - random samples from either posterior or prior distribution 79 | """ 80 | 81 | # generate random samples 82 | return tf.ones([batch_size, latent_dim]) * dist_param["a"] 83 | 84 | @staticmethod 85 | def self_entropy(dist_param): 86 | return tf.zeros_like(dist_param["a"]) 87 | 88 | @classmethod 89 | def kl_divergence(cls, dist_param, ref_dist_param, ref_dist_type=None): 90 | if not isinstance(ref_dist_type, cls) and ref_dist_type is not None: # handle hybrid distribution 91 | return None 92 | return tf.zeros_like(dist_param["a"]) 93 | # return 
tf.where( 94 | # dist_param["a"] == ref_dist_param["a"], 95 | # tf.zeros_like(dist_param["a"]), tf.ones_like(dist_param["a"])*math.inf) 96 | 97 | @classmethod 98 | def cross_entropy(cls, dist_param, ref_dist_param, ref_dist_type=None): 99 | if not isinstance(ref_dist_type, cls) and ref_dist_type is not None: # handle hybrid distribution 100 | return None 101 | return tf.zeros_like(dist_param["a"]) 102 | # return tf.where( 103 | # dist_param["a"] == ref_dist_param["a"], 104 | # tf.zeros_like(dist_param["a"]), tf.ones_like(dist_param["a"])*math.inf) 105 | 106 | @staticmethod 107 | def mean(dist_param): 108 | return dist_param["a"] 109 | 110 | @staticmethod 111 | def sample_to_real(samples): 112 | return samples 113 | 114 | @staticmethod 115 | def real_to_samples(samples_in_real): 116 | return samples_in_real 117 | -------------------------------------------------------------------------------- /zutils/option_struct.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from copy import copy 3 | from easydict import EasyDict as edict 4 | 5 | 6 | class OptionStruct_UnsetCacheNone: 7 | pass 8 | 9 | 10 | class OptionStruct: 11 | 12 | def __init__(self, option_dict_or_struct): 13 | self.user_dict = {} 14 | self.enabled_dict = {} 15 | self.unset_set = set() 16 | self.option_def = None 17 | self.option_name = None 18 | if option_dict_or_struct is not None: 19 | self.add_user_dict(option_dict_or_struct) 20 | 21 | def add_user_dict(self, option_dict_or_struct): 22 | if isinstance(option_dict_or_struct, dict): 23 | app_user_dict = option_dict_or_struct 24 | elif isinstance(option_dict_or_struct, OptionStruct): 25 | app_user_dict = option_dict_or_struct.enabled_dict 26 | else: 27 | raise ValueError("Invalid option dict") 28 | self.user_dict = {**self.user_dict, **app_user_dict} 29 | 30 | def set(self, key, value): 31 | self.enabled_dict[key] = value 32 | self.user_dict[key] = value 33 | 34 | def unset(self, key): 35 | self.user_dict.pop(key, None) 36 | self.enabled_dict.pop(key, OptionStruct_UnsetCacheNone()) 37 | self.unset_set.add(key) 38 | 39 | def set_default(self, key, default_value): 40 | if key in self.user_dict: 41 | self.enabled_dict[key] = self.user_dict[key] 42 | else: 43 | self.enabled_dict[key] = default_value 44 | 45 | def get_enabled(self, key): 46 | return self.enabled_dict[key] 47 | 48 | def __getitem__(self, item): 49 | return self.get_enabled(item) 50 | 51 | def __setitem__(self, key, value): 52 | self.set_default(key, value) 53 | 54 | def get_namedtuple(self, tuple_type_name=None): 55 | if tuple_type_name is None: 56 | assert self.option_name is not None, "tuple_type_name must be specified" 57 | tuple_type_name = self.option_name 58 | return namedtuple(tuple_type_name, self.enabled_dict.keys())(**self.enabled_dict) 59 | 60 | def get_dict(self): 61 | return self.enabled_dict 62 | 63 | def get_edict(self): 64 | return edict(self.enabled_dict) 65 | 66 | def _require(self, option_name, is_include): 67 | assert isinstance(self.option_def, OptionDef), "invalid option_def" 68 | p = self.option_def[option_name] 69 | self.enabled_dict = {**self.enabled_dict, **p.enabled_dict} 70 | if is_include: 71 | self.user_dict = {**self.user_dict, **p.user_dict} 72 | else: 73 | self.user_dict = {**self.user_dict, **p.enabled_dict} 74 | self.unset_set = self.unset_set.union(p.unset_set) 75 | 76 | def include(self, option_name): 77 | self._require(option_name, is_include=True) 78 | 79 | def require(self, option_name): 80 | 
self._require(option_name, is_include=False) 81 | 82 | def finalize(self, error_uneaten=True): 83 | uneaten_keys = set(self.user_dict.keys()) - set(self.enabled_dict.keys()) - self.unset_set 84 | if len(uneaten_keys) > 0: 85 | print("WARNING: uneaten options") 86 | for k in uneaten_keys: 87 | print(" %s: " % k, end="") 88 | print(self.user_dict[k]) 89 | if error_uneaten: 90 | raise ValueError("uneaten options: " + k) 91 | 92 | 93 | class OptionDef: 94 | 95 | def __init__(self, user_dict={}, def_cls_or_obj=None): 96 | self._user_dict = user_dict 97 | self._opts = {} 98 | if def_cls_or_obj is None: 99 | self._def_obj = self 100 | elif def_cls_or_obj is type: 101 | self._def_obj = def_cls_or_obj() 102 | else: 103 | self._def_obj = def_cls_or_obj 104 | 105 | def __getitem__(self, item): 106 | if item in self._opts: 107 | return self._opts[item] 108 | else: 109 | assert hasattr(self._def_obj, item), "no such method for option definition" 110 | p = OptionStruct(self._user_dict) 111 | p.option_def = self 112 | p.option_name = item + "_options" 113 | # opt_def_func = getattr(self._def_obj, item) 114 | # opt_def_func(p) 115 | eval("self._def_obj.%s(p)" % item) 116 | self._opts[item] = p 117 | return p 118 | -------------------------------------------------------------------------------- /nets/encoder/general_128x128_landmark.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import zutils.tf_math_funcs as tmf 3 | import numpy as np 4 | import zutils.pt_utils as ptu 5 | import prettytensor as pt 6 | from net_modules.common import * 7 | 8 | from net_modules.hourglass import hourglass 9 | 10 | from net_modules.auto_struct.keypoint_encoder import Factory as BaseFactory 11 | 12 | 13 | class Factory(BaseFactory): 14 | 15 | @staticmethod 16 | def pt_defaults_scope_value(): 17 | return { 18 | # non-linearity 19 | 'activation_fn': default_activation.current_value, 20 | # for batch normalization 21 | 'batch_normalize': True, 22 | 'learned_moments_update_rate': 0.0003, 23 | 'variance_epsilon': 0.001, 24 | 'scale_after_normalization': True 25 | } 26 | 27 | default_patch_feature_dim = 8 28 | 29 | def __init__(self, output_channels, options): 30 | """ 31 | :param output_channels: output_channels for the encoding net 32 | """ 33 | super().__init__(output_channels, options) 34 | 35 | self.target_input_size = [192, 192] 36 | self.input_size = [128, 128] 37 | 38 | def image2heatmap(self, image_tensor): 39 | mid_tensor = ( 40 | pt.wrap(image_tensor). 41 | conv2d(3, 32). 
42 | max_pool(2, 2) 43 | ).tensor 44 | 45 | hgd = [ 46 | {"type": "conv2d", "depth": 64}, 47 | {"type": "skip", "layer_num": 2}, 48 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 32 x 32 49 | {"type": "conv2d", "depth": 128}, 50 | {"type": "skip", "layer_num": 2}, 51 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 16 x 16 52 | {"type": "conv2d", "depth": 256}, 53 | {"type": "skip", "layer_num": 2}, 54 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 8 x 8 55 | {"type": "conv2d", "depth": 512}, 56 | {"type": "skip", "layer_num": 2}, 57 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 4 x 4 58 | {"type": "conv2d", "depth": 512}, 59 | ] 60 | 61 | with pt.defaults_scope(**self.pt_defaults_scope_value()): 62 | raw_heatmap_feat = hourglass( 63 | mid_tensor, hgd, 64 | net_type = self.options["hourglass_type"] if "hourglass_type" in self.options else None 65 | ) 66 | 67 | return raw_heatmap_feat 68 | 69 | def image2feature(self, image_tensor): 70 | 71 | if self.patch_feature_dim == 0: 72 | return None 73 | 74 | mid_tensor = ( 75 | pt.wrap(image_tensor). 76 | conv2d(3, 32). 77 | max_pool(2, 2) 78 | ).tensor # 64x64 79 | 80 | hgd = [ 81 | {"type": "conv2d", "depth": 64}, 82 | {"type": "skip", "layer_num": 2}, 83 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 32 x 32 84 | {"type": "conv2d", "depth": 128}, 85 | {"type": "skip", "layer_num": 2}, 86 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 16 x 16 87 | {"type": "conv2d", "depth": 256}, 88 | {"type": "skip", "layer_num": 2}, 89 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 8 x 8 90 | {"type": "conv2d", "depth": 512}, 91 | {"type": "skip", "layer_num": 2}, 92 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 4 x 4 93 | {"type": "conv2d", "depth": 512}, 94 | ] 95 | 96 | with pt.defaults_scope(**self.pt_defaults_scope_value()): 97 | feature_map = hourglass( 98 | mid_tensor, hgd, 99 | net_type=self.options["hourglass_type"] if "hourglass_type" in self.options else None 100 | ) 101 | 102 | return feature_map 103 | 104 | def bg_feature(self, image_tensor): 105 | with pt.defaults_scope(**self.pt_defaults_scope_value()): 106 | return ( 107 | pt.wrap(image_tensor). 108 | conv2d(3, 32).max_pool(2, 2). # 64 109 | conv2d(3, 64).max_pool(2, 2). # 32 110 | conv2d(3, 128).max_pool(2, 2). # 16 111 | conv2d(3, 256).max_pool(2, 2). # 8 112 | conv2d(3, 256).max_pool(2, 2). # 4 113 | conv2d(3, 512).max_pool(2, 2). 
# 2 114 | conv2d(3, 512) 115 | ) 116 | 117 | 118 | -------------------------------------------------------------------------------- /nets/encoder/general_80x80.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import prettytensor as pt 4 | from net_modules.common import * 5 | from net_modules import keypoints_2d 6 | import math 7 | 8 | from net_modules.hourglass import hourglass 9 | 10 | import zutils.tf_math_funcs as tmf 11 | from net_modules.auto_struct.keypoint_encoder import Factory as BaseFactory 12 | 13 | def encoder_map(input_tensor, hourglass_type=None): 14 | 15 | hgd = [ 16 | {"type": "conv2d", "depth": 32, "decoder_depth": 64}, 17 | {"type": "conv2d", "depth": 64, "decoder_depth": 64}, 18 | {"type": "skip", "layer_num": 2}, 19 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 40 x 40 20 | {"type": "conv2d", "depth": 128}, 21 | {"type": "skip", "layer_num": 2}, 22 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 20x20 23 | {"type": "conv2d", "depth": 256}, 24 | {"type": "skip", "layer_num": 2}, 25 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 10x10 26 | {"type": "conv2d", "depth": 512}, 27 | ] 28 | 29 | output_tensor = hourglass( 30 | input_tensor, hgd, 31 | net_type=hourglass_type 32 | ) 33 | return output_tensor 34 | 35 | 36 | 37 | 38 | class Factory(BaseFactory): 39 | @staticmethod 40 | def pt_defaults_scope_value(): 41 | return { 42 | # non-linearity 43 | 'activation_fn': default_activation.current_value, 44 | # for batch normalization 45 | 'batch_normalize': True, 46 | 'learned_moments_update_rate': 0.0003, 47 | 'variance_epsilon': 0.001, 48 | 'scale_after_normalization': True 49 | } 50 | 51 | default_patch_feature_dim = 8 52 | 53 | def __init__(self, output_channels, options): 54 | """ 55 | :param output_channels: output_channels for the encoding net 56 | """ 57 | super().__init__(output_channels, options) 58 | 59 | self.target_input_size = [96, 96] 60 | self.input_size = [80, 80] 61 | 62 | def image2heatmap(self, image_tensor): 63 | hgd = [ 64 | {"type": "conv2d", "depth": 32, "decoder_depth": self.options["keypoint_num"] + 1, 65 | "decoder_activation_fn": None}, 66 | # plus one for bg 67 | {"type": "conv2d", "depth": 32}, 68 | {"type": "skip", "layer_num": 3, }, 69 | {"type": "pool", "pool": "max"}, 70 | {"type": "conv2d", "depth": 64}, 71 | {"type": "conv2d", "depth": 64}, 72 | {"type": "skip", "layer_num": 3, }, 73 | {"type": "pool", "pool": "max"}, 74 | {"type": "conv2d", "depth": 64}, 75 | {"type": "conv2d", "depth": 64}, 76 | {"type": "skip", "layer_num": 3, }, 77 | {"type": "pool", "pool": "max"}, 78 | {"type": "conv2d", "depth": 64}, 79 | {"type": "conv2d", "depth": 64}, 80 | ] 81 | 82 | with pt.defaults_scope(**self.pt_defaults_scope_value()): 83 | raw_heatmap = hourglass( 84 | image_tensor, hgd, 85 | net_type=self.options["hourglass_type"] if "hourglass_type" in self.options else None 86 | ) 87 | # raw_heatmap = pt.wrap(raw_heatmap).pixel_bias(activation_fn=None).tensor 88 | 89 | return raw_heatmap 90 | 91 | 92 | 93 | def image2feature(self, image_tensor): 94 | 95 | if self.patch_feature_dim == 0: 96 | return None 97 | 98 | hgd = [ 99 | {"type": "conv2d", "depth": 32, "decoder_depth": 64}, 100 | {"type": "conv2d", "depth": 64, "decoder_depth": 64}, 101 | {"type": "skip", "layer_num": 2}, 102 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 40 x 40 103 | {"type": "conv2d", "depth": 128}, 104 | {"type": "skip", 
"layer_num": 2}, 105 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 20x20 106 | {"type": "conv2d", "depth": 256}, 107 | {"type": "skip", "layer_num": 2}, 108 | {"type": "pool", "pool": "max", "kernel": 2, "stride": 2}, # 10x10 109 | {"type": "conv2d", "depth": 512}, 110 | ] 111 | 112 | with pt.defaults_scope(**self.pt_defaults_scope_value()): 113 | feature_map = hourglass( 114 | image_tensor, hgd, 115 | net_type=self.options["hourglass_type"] if "hourglass_type" in self.options else None 116 | ) 117 | return feature_map 118 | 119 | -------------------------------------------------------------------------------- /net_modules/pt_group_connected.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import tensorflow as tf 4 | 5 | from prettytensor import layers 6 | from prettytensor import parameters 7 | from prettytensor import pretty_tensor_class as prettytensor 8 | from prettytensor.pretty_tensor_class import PROVIDED 9 | 10 | # pylint: disable=invalid-name 11 | @prettytensor.Register(assign_defaults=('activation_fn', 'l2loss', 12 | 'parameter_modifier', 'phase')) 13 | class group_connected(prettytensor.VarStoreMethod): 14 | 15 | def __call__( 16 | self, 17 | input_layer, 18 | size, 19 | activation_fn=None, 20 | l2loss=None, 21 | weights=None, 22 | bias=tf.zeros_initializer(), 23 | transpose_weights=False, 24 | phase=prettytensor.Phase.train, 25 | parameter_modifier=parameters.identity, 26 | tie_groups=False, 27 | name=PROVIDED 28 | ): 29 | """Adds the parameters for a fully connected layer and returns a tensor. 30 | The current PrettyTensor must have rank 2. 31 | Args: 32 | input_layer: The Pretty Tensor object, supplied. 33 | size: The number of neurons 34 | activation_fn: A tuple of (activation_function, extra_parameters). Any 35 | function that takes a tensor as its first argument can be used. More 36 | common functions will have summaries added (e.g. relu). 37 | l2loss: Set to a value greater than 0 to use L2 regularization to decay 38 | the weights. 39 | weights: An initializer for weights or a Tensor. If not specified, 40 | uses He's initialization. 41 | bias: An initializer for the bias or a Tensor. No bias if set to None. 42 | transpose_weights: Flag indicating if weights should be transposed; 43 | this is useful for loading models with a different shape. 44 | phase: The phase of graph construction. See `pt.Phase`. 45 | parameter_modifier: A function to modify parameters that is applied after 46 | creation and before use. 47 | name: The name for this operation is also used to create/find the 48 | parameter variables. 49 | Returns: 50 | A Pretty Tensor handle to the layer. 51 | Raises: 52 | ValueError: if the Pretty Tensor is not rank 2 or the number of input 53 | nodes (second dim) is not known. 
54 | """ 55 | if input_layer.get_shape().ndims != 3: 56 | raise ValueError( 57 | 'group_connected requires a rank 3 Tensor with known 2nd and 3rd ' 58 | 'dimension: %s' % input_layer.get_shape()) 59 | group_num = input_layer.shape[1] 60 | in_size = input_layer.shape[2] 61 | if group_num is None: 62 | raise ValueError('Number of groups must be known.') 63 | if in_size is None: 64 | raise ValueError('Number of input nodes must be known.') 65 | books = input_layer.bookkeeper 66 | if weights is None: 67 | weights = layers.he_init(in_size, size, activation_fn) 68 | 69 | dtype = input_layer.tensor.dtype 70 | weight_shape = [group_num, size, in_size] if transpose_weights else [group_num, in_size, size] 71 | 72 | params_var = parameter_modifier( 73 | 'weights', 74 | self.variable('weights', weight_shape, 75 | weights, dt=dtype), 76 | phase) 77 | 78 | if tie_groups and phase == prettytensor.Phase.train: 79 | with tf.variable_scope("weight_tying"): 80 | params = tf.tile(tf.reduce_mean(params_var, axis=0, keep_dims=True), [group_num, 1, 1]) 81 | with tf.control_dependencies([tf.assign(params_var, params)]): 82 | params = tf.identity(params) 83 | else: 84 | params = params_var 85 | 86 | input_tensor = tf.expand_dims(input_layer, axis=-2) 87 | params_tensor = tf.tile(tf.expand_dims(params, axis=0), [tf.shape(input_tensor)[0], 1, 1, 1]) 88 | y = tf.matmul(input_tensor, params_tensor, transpose_b=transpose_weights, name=name) 89 | y = tf.squeeze(y, axis=2) 90 | layers.add_l2loss(books, params, l2loss) 91 | if bias is not None: 92 | y += parameter_modifier( 93 | 'bias', 94 | self.variable('bias', [size], bias, dt=dtype), 95 | phase) 96 | 97 | if activation_fn is not None: 98 | if not isinstance(activation_fn, collections.Sequence): 99 | activation_fn = (activation_fn,) 100 | y = layers.apply_activation(books, 101 | y, 102 | activation_fn[0], 103 | activation_args=activation_fn[1:]) 104 | books.add_histogram_summary(y, '%s/activations' % y.op.name) 105 | return input_layer.with_tensor(y, parameters=self.vars) 106 | # pylint: enable=invalid-name 107 | 108 | 109 | def main(): 110 | import prettytensor as pt 111 | input_tensor = tf.zeros([64, 5, 7]) 112 | output_tensor = pt.wrap(input_tensor).group_connected(15).tensor 113 | print(output_tensor) 114 | 115 | 116 | if __name__ == "main": 117 | main() 118 | -------------------------------------------------------------------------------- /nets/distribution/bernoulli.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from collections import OrderedDict 4 | import math 5 | import zutils.tf_math_funcs as tmf 6 | import nets.distribution.generic 7 | import nets.distribution.category 8 | 9 | CategoryFactory = nets.distribution.category.Factory 10 | 11 | GenericFactory = nets.distribution.generic.Factory 12 | 13 | epsilon = tmf.epsilon 14 | 15 | 16 | class Factory(GenericFactory): 17 | 18 | def __init__(self, tau=0., **kwargs): 19 | self.categ_dist = CategoryFactory(pi=[0.5, 0.5], tau=tau) 20 | 21 | @staticmethod 22 | def param_num(): 23 | return 1 24 | 25 | @staticmethod 26 | def param_dict(p=0.5): 27 | return OrderedDict(p=p) 28 | 29 | @classmethod 30 | def transform2param(cls, input_tensor, latent_dim): 31 | """ Create network for converting input_tensor to distribution parameters 32 | 33 | :param input_tensor: (posterior phase) input tensor for the posterior 34 | :param latent_dim: dimension of the latent_variables 35 | :return: param_tensor - distribution parameters 36 | """ 
37 | param_tensor = tf.sigmoid(input_tensor) 38 | return param_tensor 39 | 40 | @classmethod 41 | def parametrize(cls, param_tensor, latent_dim): 42 | """ Create network for converting parameter_tensor to parameter dictionary 43 | 44 | :param param_tensor: (posterior phase) input tensor for the posterior 45 | :param latent_dim: dimension of the latent_variables 46 | :return: dist_param - distribution parameters 47 | """ 48 | dist_param = cls.param_dict( 49 | p=param_tensor, 50 | ) 51 | return dist_param 52 | 53 | @classmethod 54 | def deparametrize(cls, dist_param): 55 | param_tensor = tmf.expand_minimum_ndim(dist_param["p"], axis=0) 56 | return param_tensor 57 | 58 | def nll(self, dist_param, samples): 59 | """ Compute negative log likelihood on given sample and distribution parameter 60 | 61 | :param samples: samples for evaluating PDF 62 | :param dist_param: input for the posterior 63 | :return: likelihood - likelihood to draw such samples from given distribution 64 | :return: is_atomic - is atomic, scalar or the same size as likelihood 65 | """ 66 | likelihood = tf.where(samples > 0.5, dist_param["p"], 1.0-dist_param["p"]) 67 | bernoulli_nll = -tf.log(likelihood+epsilon) 68 | return bernoulli_nll, True 69 | 70 | def sampling(self, dist_param, batch_size, latent_dim): 71 | """ Create network for VAE latent variables (sampling only) 72 | 73 | :param dist_param: input for the posterior 74 | :param batch_size: batch size 75 | :param latent_dim: dimension of the latent_variables 76 | :return: samples - random samples from either posterior or prior distribution 77 | """ 78 | 79 | # generate random samples 80 | if self.categ_dist.tau > 0.0: 81 | # soft sampling via the two-class categorical distribution 82 | p = self.deparametrize(dist_param) 83 | categ_dist_param = OrderedDict() 84 | categ_dist_param["K"] = 2 85 | t_p = tf.reshape(p, [batch_size, 1, latent_dim]) 86 | categ_dist_param["pi"] = tf.concat([1.0-t_p, t_p], axis=1) 87 | categ_samples = self.categ_dist.sampling(categ_dist_param, batch_size, latent_dim) # use the parameters built above (was mistakenly dist_param) 88 | return categ_samples[:, 1] 89 | else: 90 | # hard sampling 91 | rho = tf.random_uniform([batch_size, latent_dim]) 92 | return self.inv_cdf(dist_param, rho) 93 | 94 | def inv_cdf(self, dist_param, rho): 95 | p = self.deparametrize(dist_param) 96 | return tf.to_float(rho > 1.0-p) 97 | 98 | @staticmethod 99 | def self_entropy(dist_param): 100 | p = dist_param["p"] 101 | se = -p*tf.log(p+epsilon) - (1.0-p)*tf.log(1.0-p+epsilon) 102 | return se 103 | 104 | @classmethod 105 | def kl_divergence(cls, dist_param, ref_dist_param, ref_dist_type=None): 106 | if not isinstance(ref_dist_type, cls) and ref_dist_type is not None: # handle hybrid distribution 107 | return None 108 | 109 | p = dist_param["p"] 110 | p0 = ref_dist_param["p"] 111 | homo_kl = p*tf.log(p/(p0+epsilon)+epsilon) + (1.0-p)*tf.log((1.0-p)/(1.0-p0+epsilon)+epsilon) 112 | return homo_kl 113 | 114 | @classmethod 115 | def cross_entropy(cls, dist_param, ref_dist_param, ref_dist_type=None): 116 | if not isinstance(ref_dist_type, cls) and ref_dist_type is not None: # handle hybrid distribution 117 | return None 118 | 119 | p = dist_param["p"] 120 | p0 = ref_dist_param["p"] 121 | homo_ce = -p*tf.log(tf.clip_by_value(p0+epsilon, clip_value_min=epsilon, clip_value_max=1.)) - \ 122 | (1.0-p)*tf.log(tf.clip_by_value(1.0-p0+epsilon, clip_value_min=epsilon, clip_value_max=1.)) 123 | return homo_ce 124 | 125 | @staticmethod 126 | def mean(dist_param): 127 | p = dist_param["p"] 128 | return p 129 | --------------------------------------------------------------------------------
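All the factories in `nets/distribution` implement the same interface seen above (`param_num`, `param_dict`, `transform2param`, `nll`, `sampling`, `kl_divergence`, `mean`, ...), which is what lets the pipeline swap latent and reconstruction distributions freely. As a minimal NumPy sketch (illustrative only, not repo code) of the Bernoulli formulas implemented above:

```python
import numpy as np

eps = 1e-7  # plays the role of tmf.epsilon: guards against log(0)

def bernoulli_nll(p, x):
    # -log P(x) with P(1) = p and P(0) = 1 - p (mirrors Factory.nll)
    likelihood = np.where(x > 0.5, p, 1.0 - p)
    return -np.log(likelihood + eps)

def bernoulli_kl(p, p0):
    # KL(Bern(p) || Bern(p0)), elementwise (mirrors Factory.kl_divergence)
    return (p * np.log(p / (p0 + eps) + eps)
            + (1.0 - p) * np.log((1.0 - p) / (1.0 - p0 + eps) + eps))

p = np.array([0.9, 0.5])
print(bernoulli_nll(p, np.array([1.0, 1.0])))  # low NLL where p is high
print(bernoulli_kl(p, np.array([0.5, 0.5])))   # ~0 where p == p0
```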
/nets/distribution/gaussian_fixedvar.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from collections import OrderedDict 3 | import math 4 | import zutils.tf_math_funcs as tmf 5 | import nets.distribution.generic 6 | 7 | GenericFactory = nets.distribution.generic.Factory 8 | 9 | epsilon = tmf.epsilon 10 | 11 | 12 | class Factory(GenericFactory): 13 | 14 | def __init__(self, **kwargs): 15 | if "options" in kwargs and "stddev" in kwargs["options"]: 16 | self.stddev = kwargs["options"]["stddev"] 17 | else: 18 | self.stddev = 1. 19 | 20 | @staticmethod 21 | def param_num(): 22 | return 1 23 | 24 | @staticmethod 25 | def param_dict(mean=0.0): 26 | return OrderedDict(mean=mean) 27 | 28 | @classmethod 29 | def transform2param(cls, input_tensor, latent_dim): 30 | """ Create network for converting input_tensor to distribution parameters 31 | 32 | :param input_tensor: (posterior phase) input tensor for the posterior 33 | :param latent_dim: dimension of the latent_variables 34 | :return: param_tensor - distribution parameters 35 | """ 36 | assert tmf.get_shape(input_tensor)[1] == latent_dim, "wrong dim" 37 | param_tensor = input_tensor 38 | return param_tensor 39 | 40 | @classmethod 41 | def parametrize(cls, param_tensor, latent_dim): 42 | """ Create network for converting parameter_tensor to parameter dictionary 43 | 44 | :param param_tensor: (posterior phase) input tensor for the posterior 45 | :param latent_dim: dimension of the latent_variables 46 | :return: dist_param - distribution parameters 47 | """ 48 | assert tmf.get_shape(param_tensor)[1] == latent_dim, "wrong dim" 49 | dist_param = cls.param_dict( 50 | mean=param_tensor, 51 | ) 52 | return dist_param 53 | 54 | @classmethod 55 | def deparametrize(cls, dist_param): 56 | param_tensor = tmf.expand_minimum_ndim(dist_param["mean"], 2) 57 | return param_tensor 58 | 59 | def nll(self, dist_param, samples): 60 | """ Compute negative log likelihood on given sample and distribution parameter 61 | 62 | :param samples: samples for evaluating PDF 63 | :param dist_param: input for the posterior 64 | :return: likelihood - likelihood to draw such samples from given distribution 65 | :return: is_atomic - is atomic, scalar or the same size as likelihood 66 | """ 67 | u = dist_param["mean"] 68 | s = self.stddev 69 | x = samples 70 | 71 | gaussian_nll = 0.5*(tf.square((x-u)/s) + 72 | 2.0*tf.log(s) + math.log(2.0*math.pi)) 73 | 74 | x = tf.check_numerics(x, "gaussian nll inf or nan", "gaussian_nll_check") 75 | 76 | return gaussian_nll, False 77 | 78 | def sampling(self, dist_param, batch_size, latent_dim): 79 | """ Create network for VAE latent variables (sampling only) 80 | 81 | :param dist_param: input for the posterior 82 | :param batch_size: batch size 83 | :param latent_dim: dimension of the latent_variables 84 | :return: samples - random samples from either posterior or prior distribution 85 | """ 86 | 87 | # generate random samples 88 | varepsilon = tf.random_normal([batch_size, latent_dim]) 89 | samples = dist_param["mean"] + varepsilon * self.stddev 90 | return samples 91 | 92 | """ 93 | @staticmethod 94 | def inv_cdf(dist_param, rho): 95 | pass 96 | """ 97 | 98 | def self_entropy(self, dist_param): 99 | s = self.stddev 100 | gaussian_entropy = 0.5*(math.log(2*math.pi)+1.0) + tf.log(s + epsilon) 101 | 102 | gaussian_entropy = tf.ones_like(dist_param["mean"]) * gaussian_entropy 103 | 104 | return gaussian_entropy 105 | 106 | def kl_divergence(self, dist_param, ref_dist_param, 
ref_dist_type=None): 107 | if not isinstance(ref_dist_type, type(self)) and ref_dist_type is not None: # handle hybrid distribution 108 | return None 109 | 110 | u0 = ref_dist_param["mean"] 111 | s0 = self.stddev 112 | u = dist_param["mean"] 113 | s = self.stddev 114 | 115 | homo_kl = (tf.square(s) + tf.square(u-u0) - 1.0) / (2*s0) 116 | 117 | return homo_kl 118 | 119 | def cross_entropy(self, dist_param, ref_dist_param, ref_dist_type=None): 120 | if not isinstance(ref_dist_type, type(self)) and ref_dist_type is not None: # handle hybrid distribution 121 | return None 122 | 123 | u0 = ref_dist_param["mean"] 124 | s0 = self.stddev 125 | u = dist_param["mean"] 126 | s = self.stddev 127 | 128 | homo_ce = 0.5*(math.log(2*math.pi)+1.0) + tf.log(s0 + epsilon) + (tf.square(s) + tf.square(u-u0) - 1.0) / (2*s0) 129 | 130 | return homo_ce 131 | 132 | @staticmethod 133 | def mean(dist_param): 134 | mean = dist_param["mean"] 135 | return mean 136 | 137 | @staticmethod 138 | def sample_to_real(samples): 139 | return samples 140 | 141 | @staticmethod 142 | def real_to_samples(samples_in_real): 143 | return samples_in_real 144 | -------------------------------------------------------------------------------- /runner/one_epoch_runner.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import datetime 3 | import numpy as np 4 | import zutils.tf_math_funcs as tmf 5 | 6 | from zutils.py_utils import * 7 | from scipy.io import savemat 8 | 9 | class OneEpochRunner: 10 | 11 | def __init__( 12 | self, data_module, output_list=None, 13 | net_func=None, batch_axis=0, num_samples=None, disp_time_interval=2, 14 | output_fn=None, is_large=False): 15 | 16 | self.data_module = data_module 17 | self.num_samples = self.data_module.num_samples() 18 | self.batch_axis = batch_axis 19 | self.disp_time_interval = disp_time_interval 20 | self.output_fn = output_fn 21 | self.is_large = is_large 22 | 23 | if num_samples is not None: 24 | if self.num_samples < num_samples: 25 | print("specified number_samples is larger than one epoch") 26 | else: 27 | self.num_samples = num_samples 28 | 29 | self.use_net_func = output_list is None # otherwise use net_func 30 | if self.use_net_func: 31 | assert net_func is not None, \ 32 | "output_list and net_func should not be both specified" 33 | self.net_func = net_func 34 | # remark: net_func(sess) 35 | else: 36 | assert net_func is None, \ 37 | "one of output_list and net_func must be specified" 38 | self.output_list = output_list 39 | [self.flatten_output_list, self.output_wrap_func] = \ 40 | recursive_flatten_with_wrap_func( 41 | lambda x: tmf.is_tf_data(x), self.output_list) 42 | 43 | self.data_module.reset() 44 | self.cur_sample_end = 0 45 | 46 | def run_single_batch(self, sess): 47 | 48 | if self.cur_sample_end >= self.num_samples: 49 | return None 50 | 51 | if self.use_net_func: 52 | output_val = self.net_func(sess) 53 | else: 54 | output_val = sess.run(self.flatten_output_list, {}) 55 | output_val = self.output_wrap_func(output_val) 56 | 57 | batch_size = first_element_apply( 58 | lambda x: isinstance(x, np.ndarray), 59 | lambda x: x.shape[self.batch_axis], output_val) 60 | self.batch_size = batch_size 61 | 62 | new_end = self.cur_sample_end + batch_size 63 | if new_end > self.num_samples: 64 | effective_batch_size = \ 65 | batch_size - (new_end-self.num_samples) 66 | slice_indexes = (slice(None),)*self.batch_axis + (slice(effective_batch_size),) 67 | output_val = recursive_apply( 68 | lambda x: isinstance(x, 
np.ndarray), 69 | lambda x: x[slice_indexes], output_val) 70 | self.cur_sample_end = new_end 71 | return output_val 72 | 73 | def run(self, sess): 74 | disp_countdown = IfTimeout(self.disp_time_interval) 75 | num_samples_total = self.num_samples 76 | 77 | output_val_single = self.run_single_batch(sess) 78 | output_val = [] 79 | 80 | while output_val_single is not None: 81 | output_val += [output_val_single] 82 | 83 | iter = self.data_module.iter() 84 | if self.data_module.epoch() == 0: 85 | num_samples_finished = self.data_module.num_samples_finished() 86 | else: 87 | num_samples_finished = self.num_samples 88 | 89 | if disp_countdown.is_timeout(): 90 | epoch_percentage = num_samples_finished / num_samples_total * 100 91 | print("%s] Iter %d (%4.1f%% = %d / %d)" % 92 | (datetime.datetime.now().strftime('%Y-%m/%d-%H:%M:%S.%f'), 93 | iter, epoch_percentage, num_samples_finished, num_samples_total)) 94 | disp_countdown = IfTimeout(self.disp_time_interval) 95 | 96 | 97 | if self.is_large and (num_samples_finished % (100*self.batch_size) == 0 or num_samples_finished == self.num_samples): 98 | output_val = recursive_apply( 99 | lambda *args: isinstance(args[0], np.ndarray), 100 | lambda *args: np.concatenate(args, axis=self.batch_axis), 101 | *output_val) 102 | self.dir_path = os.path.dirname(self.output_fn+'_'+'%06d'%num_samples_finished) 103 | if not os.path.exists(self.dir_path): 104 | os.makedirs(self.dir_path) 105 | savemat(self.output_fn+'_'+'%06d'%num_samples_finished+'.mat',output_val) 106 | print('Saving part of output to '+ self.output_fn+'_'+'%06d'%num_samples_finished+'.mat') 107 | output_val = [] 108 | output_val_single = self.run_single_batch(sess) 109 | 110 | if not self.is_large: 111 | output_val = recursive_apply( 112 | lambda *args: isinstance(args[0], np.ndarray), 113 | lambda *args: np.concatenate(args, axis=self.batch_axis), 114 | *output_val) 115 | savemat(self.output_fn + ".mat", output_val) 116 | print('Saving output to ' + self.output_fn + ".mat") 117 | 118 | 119 | -------------------------------------------------------------------------------- /net_modules/pt_patch_batch_normalization.py: -------------------------------------------------------------------------------- 1 | from prettytensor.bookkeeper import * 2 | from prettytensor.bookkeeper import _bare_var_name 3 | from zutils.pt_utils import default_phase, pt 4 | from prettytensor import pretty_tensor_class as prettytensor 5 | 6 | 7 | # this is a monkey patch for exponential_moving_average in order to improve batch normalization 8 | 9 | 10 | class PatchedBookkeeper(Bookkeeper): 11 | 12 | old_exponential_moving_average = Bookkeeper.exponential_moving_average 13 | 14 | def exponential_moving_average( 15 | self, 16 | var, 17 | avg_var=None, 18 | decay=0.999, 19 | ignore_nan=False 20 | ): 21 | """Calculates the exponential moving average. 22 | TODO(): check if this implementation of moving average can now 23 | be replaced by tensorflows implementation. 24 | Adds a variable to keep track of the exponential moving average and adds an 25 | update operation to the bookkeeper. The name of the variable is 26 | '%s_average' % name prefixed with the current variable scope. 27 | Args: 28 | var: The variable for which a moving average should be computed. 29 | avg_var: The variable to set the average into, if None create a zero 30 | initialized one. 31 | decay: How much history to use in the moving average. 32 | Higher, means more history values [0, 1) accepted. 33 | ignore_nan: If the value is NaN or Inf, skip it. 
34 | Returns: 35 | The averaged variable. 36 | Raises: 37 | ValueError: if decay is not in [0, 1). 38 | """ 39 | 40 | with self._g.as_default(): 41 | if decay < 0 or decay >= 1.0: 42 | raise ValueError('Decay is %5.2f, but has to be in [0, 1).' % decay) 43 | if avg_var is None: 44 | avg_name = '%s_average' % _bare_var_name(var) 45 | with tf.control_dependencies(None): 46 | with tf.name_scope(avg_name + '/Initializer/'): 47 | if isinstance(var, tf.Variable): 48 | init_val = var.initialized_value() 49 | elif var.get_shape().is_fully_defined(): 50 | init_val = tf.constant(0, 51 | shape=var.get_shape(), 52 | dtype=var.dtype.base_dtype) 53 | else: 54 | init_val = tf.constant(0, dtype=var.dtype.base_dtype) 55 | avg_var = tf.Variable(init_val, name=avg_name, trainable=False) 56 | 57 | avg_name = _bare_var_name(avg_var) 58 | num_updates = tf.get_variable( 59 | name='%s_numupdates' % avg_name, shape=(), 60 | dtype=tf.int64, initializer=tf.constant_initializer(0, dtype=tf.int64), 61 | trainable=False 62 | ) 63 | is_running = tf.get_variable( 64 | name='%s_isrunning' % avg_name, shape=(), 65 | dtype=tf.bool, initializer=tf.constant_initializer(True, dtype=tf.bool), 66 | trainable=False 67 | ) 68 | 69 | exact_avg_mode = tf.group( 70 | tf.assign(is_running, False), tf.assign(num_updates, 0), tf.variables_initializer([avg_var]) 71 | ) 72 | running_avg_mode = tf.assign(is_running, True) 73 | 74 | # op to switch between running and non-running mode 75 | avg_var.is_switchable_avg = True 76 | avg_var.exact_avg_mode = exact_avg_mode 77 | avg_var.running_avg_mode = running_avg_mode 78 | avg_var.num_updates = num_updates 79 | 80 | # compute decay 81 | 82 | num_updates, is_running, _ = tf.tuple([num_updates, is_running, tf.assign_add(num_updates, 1)]) 83 | 84 | # -------------- the following control does not work------ Looks like a tensorflow bug when using Variables 85 | # with tf.control_dependencies([tf.assign_add(num_updates, 1)]): 86 | # num_updates = tf.identity(num_updates) 87 | # -------------------------------------------------------- 88 | 89 | num_updates_f = tf.cast(num_updates, tf.float32) 90 | running_decay = tf.minimum( 91 | decay, 92 | tf.maximum(0.9, (1.0 + num_updates_f) / (10.0 + num_updates_f)) 93 | ) 94 | exact_decay = (num_updates_f - 1.)/num_updates_f 95 | decay = tf.where(is_running, running_decay, exact_decay) 96 | # decay = tf.check_numerics(decay, "avg_decay_failed", "%s_avg_decay" % avg_name) 97 | 98 | # apply average 99 | with tf.device(avg_var.device): 100 | if ignore_nan: 101 | var = tf.where(tf.is_finite(var), var, avg_var) 102 | if var.get_shape().is_fully_defined(): 103 | avg_update = tf.assign_sub(avg_var, (1 - decay) * (avg_var - var)) 104 | else: 105 | avg_update = tf.assign(avg_var, 106 | avg_var - (1 - decay) * (avg_var - var), 107 | validate_shape=False) 108 | self._g.add_to_collection(GraphKeys.UPDATE_OPS, avg_update) 109 | 110 | return avg_update 111 | 112 | Bookkeeper.exponential_moving_average = PatchedBookkeeper.exponential_moving_average 113 | -------------------------------------------------------------------------------- /nets/distribution/gaussian.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from collections import OrderedDict 3 | import math 4 | import zutils.tf_math_funcs as tmf 5 | import nets.distribution.generic 6 | 7 | GenericFactory = nets.distribution.generic.Factory 8 | 9 | epsilon = tmf.epsilon 10 | 11 | 12 | class Factory(GenericFactory): 13 | 14 | default_stddev_lower_bound = 
epsilon 15 | 16 | def __init__(self, **kwargs): 17 | if "options" in kwargs: 18 | self.options = kwargs["options"] 19 | else: 20 | self.options = dict() 21 | if "stddev_lower_bound" not in self.options: 22 | self.options["stddev_lower_bound"] = epsilon 23 | 24 | @staticmethod 25 | def param_num(): 26 | return 2 27 | 28 | @staticmethod 29 | def param_dict(mean=0.0, stddev=1.0): 30 | return OrderedDict(mean=mean, stddev=stddev) 31 | 32 | def transform2param(self, input_tensor, latent_dim): 33 | """ Create network for converting input_tensor to distribution parameters 34 | 35 | :param input_tensor: (posterior phase) input tensor for the posterior 36 | :param latent_dim: dimension of the latent_variables 37 | :return: param_tensor - distribution parameters 38 | """ 39 | param_tensor = tf.concat( 40 | [input_tensor[:, :latent_dim], 41 | tf.maximum(tmf.atanh_sigmoid(input_tensor[:, latent_dim:]), self.options["stddev_lower_bound"])], 42 | axis=1 43 | ) 44 | return param_tensor 45 | 46 | @classmethod 47 | def parametrize(cls, param_tensor, latent_dim): 48 | """ Create network for converting parameter_tensor to parameter dictionary 49 | 50 | :param param_tensor: (posterior phase) input tensor for the posterior 51 | :param latent_dim: dimension of the latent_variables 52 | :return: dist_param - distribution parameters 53 | """ 54 | dist_param = cls.param_dict( 55 | mean=param_tensor[:, :latent_dim], 56 | stddev=param_tensor[:, latent_dim:] 57 | ) 58 | return dist_param 59 | 60 | @classmethod 61 | def deparametrize(cls, dist_param): 62 | param_tensor = tf.concat( 63 | [tmf.expand_minimum_ndim(dist_param["mean"], 2), 64 | tmf.expand_minimum_ndim(dist_param["stddev"], 2)], axis=1) 65 | return param_tensor 66 | 67 | @staticmethod 68 | def nll(dist_param, samples): 69 | """ Compute negative log likelihood on given sample and distribution parameter 70 | 71 | :param samples: samples for evaluating PDF 72 | :param dist_param: input for the posterior 73 | :return: likelihood - likelihood to draw such samples from given distribution 74 | :return: is_atomic - is atomic, scalar or the same size as likelihood 75 | """ 76 | u = dist_param["mean"] 77 | s = dist_param["stddev"] 78 | x = samples 79 | 80 | gaussian_nll = 0.5*(tf.square((x-u)/s) + 81 | 2.0*tf.log(s) + math.log(2.0*math.pi)) 82 | 83 | return gaussian_nll, False 84 | 85 | @staticmethod 86 | def sampling(dist_param, batch_size, latent_dim): 87 | """ Create network for VAE latent variables (sampling only) 88 | 89 | :param dist_param: input for the posterior 90 | :param batch_size: batch size 91 | :param latent_dim: dimension of the latent_variables 92 | :return: samples - random samples from either posterior or prior distribution 93 | """ 94 | 95 | # generate random samples 96 | varepsilon = tf.random_normal([batch_size, latent_dim]) 97 | samples = dist_param["mean"] + varepsilon * dist_param["stddev"] 98 | return samples 99 | 100 | """ 101 | @staticmethod 102 | def inv_cdf(dist_param, rho): 103 | pass 104 | """ 105 | 106 | @staticmethod 107 | def self_entropy(dist_param): 108 | s = dist_param["stddev"] 109 | gaussian_entropy = 0.5*(math.log(2*math.pi)+1.0) + tf.log(s + epsilon) 110 | return gaussian_entropy 111 | 112 | @classmethod 113 | def kl_divergence(cls, dist_param, ref_dist_param, ref_dist_type=None): 114 | if not isinstance(ref_dist_type, cls) and ref_dist_type is not None: # handle hybrid distribution 115 | return None 116 | 117 | u0 = ref_dist_param["mean"] 118 | s0 = ref_dist_param["stddev"] 119 | u = dist_param["mean"] 120 | s = 
dist_param["stddev"] 121 | 122 | homo_kl = tf.log(s0 + epsilon) - tf.log(s + epsilon) + \ 123 | 0.5*(tf.square(s) + tf.square(u-u0)) / tf.square(s0) - 0.5 124 | 125 | return homo_kl 126 | 127 | @classmethod 128 | def cross_entropy(cls, dist_param, ref_dist_param, ref_dist_type=None): 129 | if not isinstance(ref_dist_type, cls) and ref_dist_type is not None: # handle hybrid distribution 130 | return None 131 | 132 | u0 = ref_dist_param["mean"] 133 | s0 = ref_dist_param["stddev"] 134 | u = dist_param["mean"] 135 | s = dist_param["stddev"] 136 | 137 | homo_ce = 0.5*(math.log(2*math.pi)+1.0) + tf.log(s0 + epsilon) + \ 138 | 0.5 * (tf.square(s) + tf.square(u - u0)) / tf.square(s0) - 0.5 139 | 140 | return homo_ce 141 | 142 | @staticmethod 143 | def mean(dist_param): 144 | mean = dist_param["mean"] 145 | return mean 146 | 147 | @staticmethod 148 | def sample_to_real(samples): 149 | return samples 150 | 151 | @staticmethod 152 | def real_to_samples(samples_in_real): 153 | return samples_in_real 154 | -------------------------------------------------------------------------------- /net_modules/auto_struct/keypoint_decoder.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from net_modules import keypoints_2d 3 | import tensorflow as tf 4 | import zutils.tf_math_funcs as tmf 5 | import numpy as np 6 | import zutils.pt_utils as ptu 7 | import prettytensor as pt 8 | import math 9 | 10 | import warnings 11 | 12 | from net_modules.auto_struct.generic_decoder import Factory as BaseFactory 13 | 14 | from net_modules.auto_struct.keypoint_encoder import Factory as EncoderFactory 15 | 16 | 17 | class BasicFactory(BaseFactory): 18 | 19 | structure_param_num = EncoderFactory.structure_param_num 20 | 21 | def __init__(self, recon_dist_param_num=1, options=None): 22 | super().__init__(recon_dist_param_num, options) 23 | self.keypoint_init() 24 | 25 | def keypoint_init(self): 26 | 27 | self.base_gaussian_stddev = keypoints_2d.gaussian_2d_base_stddev 28 | if "base_gaussian_stddev" in self.options and self.options["base_gaussian_stddev"] is not None: 29 | self.base_gaussian_stddev = self.options["base_gaussian_stddev"] 30 | 31 | self.use_background_feature = "background_as_patch" not in self.options or self.options["background_as_patch"] 32 | 33 | @abstractmethod 34 | def image_size(self): 35 | raise ValueError("Must specify the image size.") 36 | return None, None 37 | 38 | def structure2heatmap(self, structure_param, extra_inputs=None): 39 | 40 | if extra_inputs is None: 41 | extra_inputs = dict() 42 | 43 | h, w = self.image_size() 44 | 45 | keypoint_param_raw = structure_param 46 | 47 | if "heatmap_stddev_for_patch_features" in extra_inputs and \ 48 | extra_inputs["heatmap_stddev_for_patch_features"] is not None: 49 | the_base_gaussian_stddev = extra_inputs["heatmap_stddev_for_patch_features"] 50 | else: 51 | the_base_gaussian_stddev = tf.ones_like(keypoint_param_raw)*self.base_gaussian_stddev 52 | 53 | def param2heatmap(std_scale=1): 54 | if std_scale == 1: 55 | keypoint_param = keypoint_param_raw 56 | else: 57 | param_dim = tmf.get_shape(keypoint_param_raw)[2] 58 | if param_dim == 2: 59 | keypoint_param = tf.concat([ 60 | keypoint_param_raw, the_base_gaussian_stddev*std_scale 61 | ], axis=2) 62 | elif param_dim == 3: 63 | keypoint_param = tf.concat([ 64 | keypoint_param_raw[:, :, :2], 65 | keypoint_param_raw[:, :, 2:3]*std_scale 66 | ], axis=2) 67 | else: 68 | keypoint_param = tf.concat([ 69 | keypoint_param_raw[:, :, :2], 70 | 
keypoint_param_raw[:, :, 2:4]*std_scale 71 | ], axis=2) 72 | if param_dim==5: 73 | keypoint_param = tf.concat([ 74 | keypoint_param, 75 | keypoint_param_raw[:, :, 4:5] 76 | ], axis=2) 77 | keypoint_map = keypoints_2d.gaussian_coordinate_to_keypoint_map( 78 | keypoint_param, h, w 79 | ) 80 | keypoint_map_with_bg = tf.concat( 81 | [keypoint_map, tf.ones_like(keypoint_map[:, :, :, 0:1]) * (1. / (h * w))], axis=3 82 | ) 83 | keypoint_map_with_bg /= tf.reduce_sum(keypoint_map_with_bg, axis=3, keep_dims=True) 84 | return keypoint_map_with_bg 85 | 86 | if "keypoint_decoding_heatmap_levels" not in self.options or \ 87 | self.options["keypoint_decoding_heatmap_levels"] == 1: 88 | return param2heatmap() 89 | else: 90 | assert "keypoint_decoding_heatmap_level_base" in self.options, \ 91 | "keypoint_decoding_heatmap_level_base must be specified" 92 | b = self.options["keypoint_decoding_heatmap_level_base"] 93 | s = 1. 94 | keypoint_map_list = list() 95 | for i in range(self.options["keypoint_decoding_heatmap_levels"]): 96 | keypoint_map_list.append(param2heatmap(std_scale=s)) 97 | s /= b 98 | return keypoint_map_list 99 | 100 | def structure_param2euclidean(self, structure_param): 101 | return keypoints_2d.gaussian2dparam_to_recon_code(structure_param) 102 | 103 | def post_image_reconstruction(self, im, extra_inputs=None): 104 | extra_outputs = dict() 105 | extra_outputs["save"] = dict() 106 | return im 107 | 108 | 109 | class Factory(BasicFactory): 110 | 111 | def __init__(self, *args, **kwargs): 112 | super().__init__(*args, **kwargs) 113 | 114 | def latent2structure_patch_overall_generic(self, latent_tensor): 115 | 116 | keypoint_num = self.options["keypoint_num"] 117 | 118 | batch_size = tmf.get_shape(latent_tensor)[0] 119 | total_dim = tmf.get_shape(latent_tensor)[1] 120 | 121 | cur_dim = 0 122 | keypoint_param_dim = self.structure_param_num * keypoint_num 123 | keypoint_tensor = latent_tensor[:, :keypoint_param_dim] 124 | keypoint_param = tf.reshape(keypoint_tensor, [batch_size, keypoint_num, -1]) 125 | cur_dim += keypoint_param_dim 126 | 127 | if self.patch_feature_dim is not None and self.patch_feature_dim > 0: 128 | all_patch_feat_dims = (keypoint_num+1)*self.patch_feature_dim 129 | patch_tensor = latent_tensor[:, keypoint_param_dim:keypoint_param_dim+all_patch_feat_dims] 130 | patch_features = tf.reshape(patch_tensor, [ 131 | batch_size, keypoint_num + (1 if self.use_background_feature else 0), 132 | self.patch_feature_dim 133 | ]) 134 | cur_dim += all_patch_feat_dims 135 | else: 136 | patch_features = None 137 | 138 | if total_dim > cur_dim: 139 | overall_features = latent_tensor[:, cur_dim:] 140 | if self.overall_feature_dim is None or self.overall_feature_dim+cur_dimplease copy the link and paste it in a new window to download. (Direct clicking the link may not work due to GitHub's way to handle links with redirection)*** 27 | 28 | - The [CelebA](https://drive.google.com/drive/folders/0B7EVK8r0v71pWEZsZE9oNnFzTm8) dataset, saved in the `data/celeba_images` folder. 29 | - [ground truth landmark annotation](http://files.ytzhang.net/lmdis-rep/release-v1/celeba/celeba_data.tar.gz) for pre-processed CelebA images, saved in the `data/celeba_data`. 30 | - The [AFLW](http://files.ytzhang.net/lmdis-rep/release-v1/aflw/aflw_images.tar.gz) dataset of preprocessed images, so that the face is on the similar position as face in CelebA images, saved in `data/aflw_images` folder. 
31 | - [ground truth landmark annotation](http://files.ytzhang.net/lmdis-rep/release-v1/aflw/aflw_data.tar.gz) for the pre-processed AFLW images, saved in the `data/aflw_data` folder. 32 | - The [CAT](http://files.ytzhang.net/lmdis-rep/release-v1/cat/cat_images.tar.gz) dataset of pre-processed images, cropped so that the face appears in a similar position in each image, saved in the `data/cat_images` folder. 33 | - [ground truth landmark annotation](http://files.ytzhang.net/lmdis-rep/release-v1/cat/cat_data.tar.gz) for the pre-processed CAT images, saved in the `data/cat_data` folder. 34 | 35 | ## Pretrained Models to download 36 | 37 | ***For all the links, please copy the link and paste it in a new window to download. (Direct clicking the link may not work due to GitHub's way of handling links with redirection.)*** 38 | 39 | The pretrained models for the CelebA dataset can be obtained via [this link](http://files.ytzhang.net/lmdis-rep/release-v1/celeba/celeba_pretrained_results.tar.gz); they detect 10 or 30 landmarks per image. 40 | 41 | The pretrained models for the AFLW dataset can be obtained via [this link](http://files.ytzhang.net/lmdis-rep/release-v1/aflw/aflw_pretrained_results.tar.gz); they detect 10 or 30 landmarks per image. 42 | 43 | The pretrained models for the CAT dataset can be obtained via [this link](http://files.ytzhang.net/lmdis-rep/release-v1/cat/cat_pretrained_results.tar.gz); they detect 10 or 20 landmarks per image. 44 | 45 | Running `./download_celeba.sh`, `./download_aflw.sh`, and `./download_cat.sh` will automatically download the pretrained models and data for the experiments on each dataset. The pretrained models will be saved in `pretrained_results/celeba_10`, `pretrained_results/celeba_30`, `pretrained_results/aflw_10`, `pretrained_results/aflw_30`, `pretrained_results/cat_10`, and `pretrained_results/cat_20`. The data will be saved in `data/celeba_data`, `data/aflw_data`, `data/aflw_images`, `data/cat_data`, and `data/cat_images`. Note that you should download the CelebA data yourself into `data/celeba_images`. 46 | 47 | ## Data and Model for Human3.6M 48 | 49 | Google Drive: [link](https://drive.google.com/drive/folders/1dFVEhg0UokpVK1ya5OJ7f7-Lxvrkc1Jz?usp=sharing) 50 | 51 | ## Demo on CelebA image samples using pre-trained model (quick demo) 52 | 53 | You can run a quick demo on CelebA images. 54 | 55 | - Download our pretrained model on CelebA. 56 | 57 | - Put the pretrained model in the directory defined in `one_step_test_celeba_demo.sh`, i.e., `pretrained_results/celeba_10` or `pretrained_results/celeba_30`. 58 | 59 | - After that, run 60 | 61 | ./one_step_test_celeba_demo.sh 62 | 63 | - You should be able to view the visualization of the landmark discovery results in the `demo/output` folder created under the root of the project. If `SPECIFIC_MODEL_DIR` is `pretrained_results/celeba_10`, there are 10 detected landmarks on each image; if `SPECIFIC_MODEL_DIR` is `pretrained_results/celeba_30`, there are 30. 64 | 65 | To perform detection on other human face images, put the images you are interested in into the `demo/input` folder and rerun `./one_step_test_celeba_demo.sh`; the detected landmarks on these images appear in the `demo/output` folder. A small preprocessing sketch for such images follows below.
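If your own images are not already face-centered crops, a preprocessing step like the following may help before dropping them into `demo/input` (purely illustrative, not part of the repo; `my_faces` is a hypothetical source folder, and `cv2` is the same OpenCV binding the repo's data modules use):

```python
import os
import cv2

def center_crop(src_path, dst_path, size=256):
    # Center-crop to a square and resize; the demo pipeline does its own
    # resizing afterwards, so this only normalizes the aspect ratio.
    im = cv2.imread(src_path)
    h, w = im.shape[:2]
    s = min(h, w)
    y0, x0 = (h - s) // 2, (w - s) // 2
    cv2.imwrite(dst_path, cv2.resize(im[y0:y0 + s, x0:x0 + s], (size, size)))

for name in os.listdir('my_faces'):  # hypothetical input folder
    center_crop(os.path.join('my_faces', name), os.path.join('demo/input', name))
```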
66 | 67 | ## Training 68 | 69 | - Train the model on the CelebA dataset for 10 landmarks: `python exp-ae-celeba-mafl-10.py` 70 | - Train the model on the CelebA dataset for 30 landmarks: `python exp-ae-celeba-mafl-30.py` 71 | - Train the model on the AFLW dataset for 10 landmarks: `python exp-ae-aflw-10.py`. The AFLW model is fine-tuned from the pretrained CelebA model, so `pretrained_results/celeba_10` must be downloaded first. 72 | - Train the model on the AFLW dataset for 30 landmarks: `python exp-ae-aflw-30.py`. The AFLW model is fine-tuned from the pretrained CelebA model, so `pretrained_results/celeba_30` must be downloaded first. 73 | - Train the model on the CAT dataset for 10 landmarks: `python exp-ae-cat-10.py` 74 | - Train the model on the CAT dataset for 20 landmarks: `python exp-ae-cat-20.py` 75 | 76 | ## Evaluation 77 | 78 | - Test the model on the CelebA dataset: `./one_step_test_celeba.sh` 79 | - Test the model on the AFLW dataset: `./one_step_test_aflw.sh` 80 | - Test the model on the CAT dataset: `./one_step_test_cat.sh` 81 | 82 | In `one_step_test_celeba.sh`, `one_step_test_aflw.sh`, and `one_step_test_cat.sh`, you can specify `SPECIFIC_MODEL_DIR` as the path of the folder containing the trained checkpoint, and `SNAPSHOT_ITER` as the snapshot step (iteration number) you would like to test. If the snapshot step is not specified, the script automatically tests the latest checkpoint. 83 | 84 | ## Visualization 85 | 86 | In the `vis` folder, call the MATLAB function `vppAutoKeypointImageRecon(result_path, step, sample_ids, save_to_file, type_ids)`. 87 | 88 | `result_path` is the folder containing the training results; `step` is the training step you want to inspect (the step must have been saved in the corresponding `test.snapshot` folder); `sample_ids` are the indices of the test samples you are interested in; `save_to_file` specifies whether to save the visualization figures; `type_ids` is 'data-encoded' or 'recon-encoded'. 89 | 90 | For example, `vppAutoKeypointImageRecon('../results/celeba_10/', 220000, 1:20, false, 'data-encoded')` will visualize the discovered keypoints on test images 1 to 20 of the test dataset. 91 | 92 | ## Remarks 93 | 94 | - The landmarks output by our model are in `yx` ordering (not the usual `xy`); see the sketch below. 95 | - Models on other datasets and more code updates are coming soon.
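To make the `yx` remark concrete, here is a hedged sketch of flipping saved landmarks to the usual `xy` ordering (the `.mat` key `structure_param`, the file path, and the array layout are assumptions for illustration; check the actual file produced by your test run):

```python
import numpy as np
from scipy.io import loadmat

out = loadmat('posterior_param.mat')      # hypothetical path to a test output
kp_yx = out['structure_param'][..., :2]   # assumed shape (N, K, >=2); leading pair is (y, x)
kp_xy = kp_yx[..., ::-1]                  # reverse the last axis: (y, x) -> (x, y)
print(kp_xy.shape)
```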
96 | -------------------------------------------------------------------------------- /nets/data/cat_80x80.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import json 5 | import time 6 | import threading 7 | import random 8 | from multiprocessing.dummy import Pool 9 | from multiprocessing import cpu_count 10 | import scipy.io as sio 11 | 12 | class Net: 13 | def __init__(self, subset_name='train', options = None): 14 | 15 | self._debug = False 16 | self._shuffle = False 17 | self._cache_size = 3000 18 | self._mean_reduce = False 19 | self._mean = [5.0, 10.0, 15.0] 20 | if options != None and options != {}: 21 | if 'cache_size' in options: 22 | self._cache_size = options['cache_size'] 23 | if 'mean_reduce' in options: 24 | self._mean_reduce = options['mean_reduce'] 25 | if 'shuffle' in options: 26 | self._shuffle = options['shuffle'] 27 | if 'debug' in options: 28 | self._debug = options['debug'] 29 | 30 | current_path = os.path.dirname(os.path.abspath(__file__)) 31 | root_path = current_path[:-9] 32 | self._cat_train = root_path+'data/cat_data/cat_train_images.txt' 33 | self._cat_test = root_path+'data/cat_data/cat_test_images.txt' 34 | self._impath = root_path+'data/cat_images/' 35 | 36 | with open(self._cat_train, 'r') as f: 37 | self._train_imlist = f.read().splitlines() 38 | with open(self._cat_test, 'r') as f: 39 | self._test_imlist = f.read().splitlines() 40 | if subset_name == 'train': 41 | self._imlist = self._train_imlist 42 | if subset_name == 'test': 43 | self._imlist = self._test_imlist 44 | self._num_samples = len(self._imlist) 45 | self._waitlist = list(range(len(self._imlist))) 46 | if self._shuffle: 47 | random.shuffle(self._waitlist) 48 | self._dataset = None 49 | self._cur_pos = 0 # num of sample done in this epoch 50 | self._cur_epoch = 0 # current num of epoch 51 | self._cur_iter = 0 # num of batches returned 52 | self._num_fields = 1 53 | self._out_h = 80 54 | self._out_w = 80 55 | 56 | self._image_cache = [] 57 | 58 | self._lock = threading.Lock() 59 | 60 | # self.set_dataset() 61 | 62 | self._pool_size = cpu_count() 63 | 64 | self._pool = Pool(self._pool_size) 65 | self._cache_thread = threading.Thread(target=self.preload_dataset) 66 | self._cache_thread.start() 67 | 68 | def read_image(self, i): 69 | image_name = self._impath + self._imlist[i] 70 | # The channel for cv2.imread is B, G, R 71 | if not os.path.exists(image_name): 72 | print(image_name) 73 | image_arr = cv2.imread(image_name) 74 | #image_arr = cv2.resize(image_arr, (100,100)) 75 | h,w,_ = image_arr.shape 76 | margin_h = (h-self._out_h)//2 77 | margin_w = (w-self._out_w)//2 78 | image_arr = image_arr[margin_h:margin_h+self._out_h, margin_w:margin_w+self._out_w] 79 | result = image_arr.astype(np.float32) / np.array(255., dtype=np.float32) 80 | result[:, :, [0, 1, 2]] = result[:, :, [2, 1, 0]] 81 | 82 | return result 83 | 84 | def __call__(self, *args, **kwargs): 85 | return self.next_batch(*args, **kwargs) 86 | 87 | def num_samples(self): 88 | return self._num_samples 89 | 90 | def epoch(self): 91 | return self._cur_epoch 92 | 93 | def iter(self): 94 | return self._cur_iter 95 | 96 | def num_fields(self): 97 | return self._num_fields 98 | 99 | def num_samples_finished(self): 100 | return self._cur_pos 101 | 102 | def reset(self): 103 | """ Reset the state of the data loader 104 | E.g., the reader points at the beginning of the dataset again 105 | :return: None 106 | """ 107 | self._cur_pos = 0 108 | 
self._cur_epoch = 0 109 | self._cur_iter = 0 110 | self._waitlist = list(range(len(self._imlist))) 111 | if self._shuffle: 112 | random.shuffle(self._waitlist) 113 | tmp = 0 114 | while self._cache_thread.isAlive(): 115 | tmp+=1 116 | self._cache_thread = threading.Thread(target=self.preload_dataset) 117 | self._lock.acquire() 118 | self._image_cache = [] 119 | self._lock.release() 120 | self._cache_thread.start() 121 | 122 | def preload_dataset(self): 123 | if self._debug: 124 | print("preload") 125 | if len(self._image_cache) > self._cache_size: 126 | return 127 | else: 128 | while len(self._image_cache) < 1000: 129 | if len(self._waitlist) < 1000: 130 | self._waitlist += list(range(len(self._imlist))) 131 | if self._shuffle: 132 | random.shuffle(self._waitlist) 133 | 134 | results = self._pool.map(self.read_image, self._waitlist[:1000]) 135 | del self._waitlist[:1000] 136 | self._lock.acquire() 137 | self._image_cache = self._image_cache + list(results) 138 | self._lock.release() 139 | if self._debug: 140 | print(len(self._image_cache)) 141 | 142 | def next_batch(self, batch_size): 143 | """ fetch the next batch 144 | :param batch_size: next batch_size 145 | :return: a tuple includes all data 146 | """ 147 | if batch_size < 0: 148 | batch_size = 0 149 | if self._cache_size < 3 * batch_size: 150 | self._cache_size = 3 * batch_size 151 | 152 | this_batch = [None] * self._num_fields 153 | 154 | if len(self._image_cache) < batch_size: 155 | if self._debug: 156 | print("Blocking!!, Should only appear once with proper setting") 157 | 158 | if not self._cache_thread.isAlive(): 159 | self._cache_thread = threading.Thread(target=self.preload_dataset) 160 | self._cache_thread.start() 161 | self._cache_thread.join() 162 | 163 | self._lock.acquire() 164 | this_batch[0] = self._image_cache[0:batch_size] 165 | del self._image_cache[0:batch_size] 166 | self._lock.release() 167 | else: 168 | self._lock.acquire() 169 | this_batch[0] = self._image_cache[0:batch_size] 170 | del self._image_cache[0:batch_size] 171 | self._lock.release() 172 | if not self._cache_thread.isAlive(): 173 | self._cache_thread = threading.Thread(target=self.preload_dataset) 174 | self._cache_thread.start() 175 | 176 | self._cur_iter += 1 177 | self._cur_pos = self._cur_pos + batch_size 178 | if self._cur_pos >= self._num_samples: 179 | self._cur_epoch += 1 180 | self._cur_pos = self._cur_pos % self._num_samples 181 | 182 | return this_batch 183 | 184 | @staticmethod 185 | def output_types(): # only used for net instance 186 | t = ["float32"] 187 | return t 188 | 189 | @staticmethod 190 | def output_shapes(): 191 | t = [(None, 80, 80, 3)] # None for batch size 192 | return t 193 | 194 | @staticmethod 195 | def output_ranges(): 196 | return [1.] 
197 | 
198 |     @staticmethod
199 |     def output_keys():
200 |         return ["data"]
201 | 
202 | 
203 | if __name__ == '__main__':
204 |     # Simple smoke test (requires the CAT data described in the README).
205 |     print(np.asarray(Net(subset_name='test').next_batch(4)[0]).shape)
206 | 
--------------------------------------------------------------------------------
/nets/data/aflw_80x80.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import threading
5 | import random
6 | from multiprocessing.dummy import Pool
7 | from multiprocessing import cpu_count
8 | 
9 | 
10 | class Net:
11 |     def __init__(self, subset_name='train', options=None):
12 | 
13 |         """ Module for loading AFLW data
14 |         :param subset_name: "train" or "test"
15 |         :param options: optional dict with 'cache_size', 'mean_reduce', and 'debug' keys
16 |         """
17 |         self._debug = False
18 |         self._shuffle = False
19 |         self._cache_size = 3000
20 |         self._mean_reduce = False
21 |         self._mean = [5.0, 10.0, 15.0]
22 |         if options is not None and options != {}:
23 |             if 'cache_size' in options:
24 |                 self._cache_size = options['cache_size']
25 |             if 'mean_reduce' in options:
26 |                 self._mean_reduce = options['mean_reduce']
27 |             # if 'shuffle' in options:
28 |             #     self._shuffle = options['shuffle']
29 |             if 'debug' in options:
30 |                 self._debug = options['debug']
31 | 
32 |         current_path = os.path.dirname(os.path.abspath(__file__))
33 |         root_path = current_path[:-9]  # strip the trailing 'nets/data'
34 |         self._aflw_train = root_path + 'data/aflw_data/aflw_train_images.txt'
35 |         self._aflw_test = root_path + 'data/aflw_data/aflw_test_images.txt'
36 |         self._impath = root_path + 'data/aflw_images/'
37 | 
38 |         with open(self._aflw_train, 'r') as f:
39 |             self._lines = f.read().splitlines()
40 |         self._train_imlist = list(self._lines)
41 |         with open(self._aflw_test, 'r') as f:
42 |             self._lines = f.read().splitlines()
43 |         self._test_imlist = list(self._lines)
44 |         if subset_name == 'train':
45 |             self._imlist = self._train_imlist
46 |         if subset_name == 'test':
47 |             self._imlist = self._test_imlist
48 | 
49 |         self._num_samples = len(self._imlist)
50 |         self._waitlist = list(range(len(self._imlist)))
51 |         if self._shuffle:
52 |             random.shuffle(self._waitlist)
53 |         self._dataset = None
54 |         self._cur_pos = 0  # number of samples finished in this epoch
55 |         self._cur_epoch = 0  # current epoch index
56 |         self._cur_iter = 0  # number of batches returned
57 |         self._num_fields = 1  # number of fields to return (image only)
58 |         self._out_h = 80
59 |         self._out_w = 80
60 | 
61 |         self._image_cache = []
62 | 
63 |         self._lock = threading.Lock()
64 | 
65 |         self._pool_size = cpu_count()
66 | 
67 |         self._pool = Pool(self._pool_size)
68 |         self._cache_thread = threading.Thread(target=self.preload_dataset)
69 |         self._cache_thread.start()
70 | 
71 |     def read_image(self, i):
72 |         image_name = self._impath + self._imlist[i]
73 |         # the channel order for cv2.imread is B, G, R
74 |         if not os.path.exists(image_name):
75 |             print(image_name)
76 |         image_arr = cv2.imread(image_name)
77 |         h, w, _ = image_arr.shape
78 |         image_arr = cv2.resize(image_arr, (self._out_h, self._out_w))  # cv2.resize takes (width, height); both are 80 here
79 |         result = image_arr.astype(np.float32) / np.array(255., dtype=np.float32)
80 |         result[:, :, [0, 1, 2]] = result[:, :, [2, 1, 0]]  # BGR -> RGB
81 |         return result
82 | 
83 |     def __call__(self, *args, **kwargs):
84 |         return self.next_batch(*args, **kwargs)
85 | 
86 |     def num_samples(self):
87 |         return self._num_samples
88 | 
89 |     def epoch(self):
90 |         return self._cur_epoch
91 | 
92 |     def iter(self):
93 |         return self._cur_iter
94 | 
95 |     def num_fields(self):
96 |         return self._num_fields
97 | 
98 |     def num_samples_finished(self):
99 |         return
self._cur_pos 100 | 101 | def reset(self): 102 | """ Reset the state of the data loader 103 | E.g., the reader points at the beginning of the dataset again 104 | :return: None 105 | """ 106 | self._cur_pos = 0 107 | self._cur_epoch = 0 108 | self._cur_iter = 0 109 | self._waitlist = list(range(len(self._imlist))) 110 | if self._shuffle: 111 | random.shuffle(self._waitlist) 112 | tmp = 0 113 | while self._cache_thread.isAlive(): 114 | tmp+=1 115 | self._cache_thread = threading.Thread(target=self.preload_dataset) 116 | self._lock.acquire() 117 | self._image_cache = [] 118 | self._lock.release() 119 | self._cache_thread.start() 120 | 121 | def preload_dataset(self): 122 | if self._debug: 123 | print("preload") 124 | if len(self._image_cache) > self._cache_size: 125 | return 126 | else: 127 | while len(self._image_cache) < 1000: 128 | if len(self._waitlist) < 1000: 129 | self._waitlist += list(range(len(self._imlist))) 130 | if self._shuffle: 131 | random.shuffle(self._waitlist) 132 | 133 | results = self._pool.map(self.read_image, self._waitlist[:1000]) 134 | del self._waitlist[:1000] 135 | self._lock.acquire() 136 | self._image_cache = self._image_cache + list(results) 137 | self._lock.release() 138 | if self._debug: 139 | print(len(self._image_cache)) 140 | 141 | def next_batch(self, batch_size): 142 | """ fetch the next batch 143 | :param batch_size: next batch_size 144 | :return: a tuple includes all data 145 | """ 146 | if batch_size < 0: 147 | batch_size = 0 148 | if self._cache_size < 3 * batch_size: 149 | self._cache_size = 3 * batch_size 150 | 151 | this_batch = [None] * self._num_fields 152 | 153 | if len(self._image_cache) < batch_size: 154 | if self._debug: 155 | print("Blocking!!, Should only appear once with proper setting") 156 | 157 | if not self._cache_thread.isAlive(): 158 | self._cache_thread = threading.Thread(target=self.preload_dataset) 159 | self._cache_thread.start() 160 | self._cache_thread.join() 161 | 162 | self._lock.acquire() 163 | this_batch[0] = self._image_cache[0:batch_size] 164 | del self._image_cache[0:batch_size] 165 | self._lock.release() 166 | else: 167 | self._lock.acquire() 168 | this_batch[0] = self._image_cache[0:batch_size] 169 | del self._image_cache[0:batch_size] 170 | self._lock.release() 171 | if not self._cache_thread.isAlive(): 172 | self._cache_thread = threading.Thread(target=self.preload_dataset) 173 | self._cache_thread.start() 174 | 175 | self._cur_iter += 1 176 | self._cur_pos = self._cur_pos + batch_size 177 | if self._cur_pos >= self._num_samples: 178 | self._cur_epoch += 1 179 | self._cur_pos = self._cur_pos % self._num_samples 180 | 181 | return this_batch 182 | 183 | @staticmethod 184 | def output_types(): # only used for net instance 185 | t = ["float32"] 186 | return t 187 | 188 | @staticmethod 189 | def output_shapes(): 190 | t = [(None, 80, 80, 3)] # None for batch size 191 | return t 192 | 193 | @staticmethod 194 | def output_ranges(): 195 | return [1.] 
196 | 
197 |     @staticmethod
198 |     def output_keys():
199 |         return ["data"]
200 | 
201 | 
202 | if __name__ == '__main__':
203 |     # Simple smoke test (requires the AFLW data described in the README).
204 |     print(np.asarray(Net(subset_name='test').next_batch(4)[0]).shape)
205 | 
--------------------------------------------------------------------------------
/net_modules/auto_struct/generic_encoder.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 | from net_modules.auto_struct.generic_structure_encoder import Factory as BaseFactory
3 | import tensorflow as tf
4 | import prettytensor as pt
5 | import zutils.tf_math_funcs as tmf
6 | import net_modules.auto_struct.utils as asu
7 | from zutils.py_utils import dummy_class_for_with
8 | 
9 | 
10 | class Factory(BaseFactory):
11 | 
12 |     # __metaclass__ = ABCMeta
13 | 
14 |     def __init__(self, output_channels, options):
15 |         """
16 |         :param output_channels: output channels of the encoding net
17 |         """
18 |         super().__init__(output_channels, options)
19 | 
20 |         self.structure_as_final_class = False
21 | 
22 |         self.stop_gradient_at_latent_for_patch = False
23 |         self.stop_gradient_at_latent_for_overall = False
24 | 
25 |         # patch feature dim (utility for inheriting classes)
26 |         if hasattr(self, "default_patch_feature_dim"):
27 |             self.patch_feature_dim = self.default_patch_feature_dim
28 |         else:
29 |             self.patch_feature_dim = None
30 |         if "patch_feature_dim" in self.options:
31 |             self.patch_feature_dim = self.options["patch_feature_dim"]
32 | 
33 |         # overall feature dim (utility for inheriting classes)
34 |         if hasattr(self, "default_overall_feature_dim"):
35 |             self.overall_feature_dim = self.default_overall_feature_dim
36 |         else:
37 |             self.overall_feature_dim = None
38 |         if "overall_feature_dim" in self.options:
39 |             self.overall_feature_dim = self.options["overall_feature_dim"]
40 | 
41 |     def __call__(self, input_tensor, condition_tensor=None, extra_inputs=None):
42 |         """Create the encoder network.
43 |         """
44 | 
45 |         latent_tensor, mos = self.patch_structure_overall_encode(
46 |             input_tensor, condition_tensor=condition_tensor, extra_inputs=extra_inputs
47 |         )
48 |         assert self.output_channels == tmf.get_shape(latent_tensor)[1], \
49 |             "wrong output_channels"
50 |         return latent_tensor, mos.extra_outputs
51 | 
52 |     def patch_structure_overall_encode(self, input_tensor, condition_tensor=None, extra_inputs=None):
53 |         """Create the encoder network from structure, patch, and overall features.
54 | """ 55 | 56 | # compute structures 57 | overall_feature, heatmap, structure_latent, mos = self.structure_encode( 58 | input_tensor, condition_tensor=condition_tensor, extra_inputs=extra_inputs) 59 | 60 | if "heatmap_for_patch_features" in mos.extra_outputs \ 61 | and mos.extra_outputs["heatmap_for_patch_features"] is not None: 62 | heatmap_for_patch_features = mos.extra_outputs["heatmap_for_patch_features"] 63 | else: 64 | heatmap_for_patch_features = heatmap 65 | 66 | aug_cache = mos.extra_outputs["aug_cache"] 67 | main_batch_size = aug_cache["main_batch_size"] 68 | 69 | with tf.variable_scope("deterministic"), tf.variable_scope("feature"): 70 | # compute patch features 71 | feature_map = mos(self.image2feature(overall_feature)) 72 | 73 | patch_features = None 74 | if feature_map is not None: 75 | # pool features 76 | batch_size, h, w, num_struct = tmf.get_shape(heatmap_for_patch_features) 77 | feature_channels = tmf.get_shape(feature_map)[-1] 78 | heatmap_e = tf.reshape( 79 | heatmap_for_patch_features, [batch_size, h*w, num_struct, 1]) 80 | feature_map_e = tf.reshape( 81 | feature_map, [batch_size, h*w, 1, feature_channels]) 82 | 83 | patch_features = tf.reduce_sum(feature_map_e * heatmap_e, axis=1) / \ 84 | tf.reduce_sum(heatmap_e, axis=1) # [batch_size, struct_num, feature_channels] 85 | 86 | # if tmf.get_shape(patch_features)[2] != self.patch_feature_dim: 87 | # always add an independent feature space 88 | if hasattr(self, "pt_defaults_scope_value"): 89 | pt_scope = pt.defaults_scope(**self.pt_defaults_scope_value()) 90 | else: 91 | pt_scope = dummy_class_for_with() 92 | with pt_scope: 93 | patch_features = pt.wrap(patch_features).group_connected( 94 | self.patch_feature_dim, activation_fn=None, 95 | tie_groups=self.options["tie_patch_feature_spaces"] 96 | if "tie_patch_feature_spaces" in self.options else False 97 | ) 98 | 99 | with tf.variable_scope("deterministic"): 100 | # use the main batch only 101 | heatmap = heatmap[:main_batch_size] 102 | overall_feature = overall_feature[:main_batch_size] 103 | if patch_features is not None: 104 | patch_features = mos(self.cleanup_augmentation_patchfeatures(patch_features, aug_cache)) 105 | 106 | mos.extra_outputs["for_decoder"]["patch_features"] = patch_features 107 | 108 | with tf.variable_scope("variational"): # use the scope name for backward compitability 109 | with tf.variable_scope("feature"): 110 | # compute patch latent 111 | patch_latent = None 112 | if patch_features is not None: 113 | if self.stop_gradient_at_latent_for_patch: 114 | patch_features_1 = tf.stop_gradient(patch_features) 115 | else: 116 | patch_features_1 = patch_features 117 | patch_latent = mos(self.feature2latent(patch_features_1)) 118 | 119 | # compute overall latent 120 | if self.stop_gradient_at_latent_for_overall: 121 | overall_feature_1 = tf.stop_gradient(overall_feature) 122 | else: 123 | overall_feature_1 = overall_feature 124 | overall_latent = self.overall2latent(overall_feature_1) 125 | 126 | with tf.variable_scope("both"): 127 | latent_tensor = mos(self._fuse_structure_patch_overall( 128 | structure_latent, patch_latent, overall_latent 129 | )) 130 | 131 | return latent_tensor, mos 132 | 133 | def cleanup_augmentation_patchfeatures(self, patch_features, aug_cache): 134 | return patch_features[:aug_cache["main_batch_size"]] 135 | 136 | def image2feature(self, image_tensor): 137 | return None 138 | 139 | def feature2latent(self, patch_features): 140 | batch_size, struct_num, feature_channels = tmf.get_shape(patch_features) 141 | return 
tf.reshape(patch_features, [batch_size, struct_num * feature_channels]) 142 | 143 | def overall2latent(self, image_features): 144 | return None 145 | 146 | def _fuse_structure_patch_overall(self, structure_latent, patch_latent, overall_latent): 147 | assert structure_latent is not None, "structure latent should not be None" 148 | if patch_latent is None and overall_latent is None: 149 | return self.fuse_structure(structure_latent) 150 | elif patch_latent is None: 151 | return self.fuse_structure_overall(structure_latent, overall_latent) 152 | elif overall_latent is None: 153 | return self.fuse_structure_patch(structure_latent, patch_latent) 154 | else: 155 | return self.fuse_structure_patch_overall(structure_latent, patch_latent, overall_latent) 156 | 157 | # override one of the following three methods based on which features are used -------------------- 158 | 159 | def fuse_structure(self, structure_latent): 160 | return structure_latent 161 | 162 | def fuse_structure_patch(self, structure_latent, patch_latent): 163 | return tf.concat([structure_latent, patch_latent], axis=-1) 164 | 165 | def fuse_structure_overall(self, structure_latent, overall_latent): 166 | return tf.concat([structure_latent, overall_latent], axis=-1) 167 | 168 | def fuse_structure_patch_overall(self, structure_latent, patch_latent, overall_latent): 169 | return tf.concat([structure_latent, patch_latent, overall_latent], axis=-1) 170 | # ------------------------------------------------------------------------------------------------- 171 | 172 | -------------------------------------------------------------------------------- /runner/preprocessing_data_module_wrapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.ndimage 3 | from zutils.py_utils import link_with_instance 4 | from copy import copy 5 | 6 | 7 | class Net: 8 | 9 | def __init__(self, data_module, options, mode=None): 10 | assert isinstance(options, dict), "wrong options" 11 | self.data_module = data_module 12 | self.opt_dict = options 13 | 14 | if mode is None: 15 | mode = "deterministic" 16 | assert mode in ("deterministic", "random") 17 | if mode == "deterministic": 18 | self.use_random = False 19 | else: 20 | self.use_random = True 21 | 22 | self.shorter_edge_length = self._shorter_edge_length() 23 | self.center_patch_size = self._center_patch_size() 24 | self.crop_size = self._crop_size() 25 | self._data_field_index = self.data_module.output_keys().index('data') 26 | 27 | # output shape 28 | original_output_shape = self.data_module.output_shapes() 29 | self._output_shape = copy(original_output_shape) 30 | if self.crop_size is not None: 31 | target_shape = self.crop_size 32 | elif self.center_patch_size is not None: 33 | target_shape = self.center_patch_size 34 | else: 35 | target_shape = None 36 | if target_shape is not None: 37 | _data_shape = self._output_shape[self._data_field_index] 38 | self._output_shape[self._data_field_index] = (_data_shape[0],) + target_shape + (_data_shape[3],) 39 | self.image_color_scaling = self._image_color_scaling() 40 | 41 | link_with_instance(self, self.data_module) 42 | 43 | def __call__(self, *args, **kwargs): 44 | return self.next_batch(*args, **kwargs) 45 | 46 | def _shorter_edge_length(self): 47 | if "image_shorter_edge_length" not in self.opt_dict: 48 | return None 49 | a = self.opt_dict["image_shorter_edge_length"] 50 | if a is None: 51 | return None 52 | if self.use_random: 53 | if "image_shorter_edge_length_list" not in self.opt_dict: 
54 |                 return [a]
55 |             b = self.opt_dict["image_shorter_edge_length_list"]
56 |             if b is None:
57 |                 return [a]
58 |             if isinstance(b, range):
59 |                 b = list(b)
60 |             return b
61 |         else:
62 |             return a
63 | 
64 |     def _center_patch_size(self):
65 |         if "image_center_patch_size" not in self.opt_dict:
66 |             return None
67 |         s = self.opt_dict["image_center_patch_size"]
68 |         if s is None:
69 |             return None
70 |         if isinstance(s, (int, float)):
71 |             return s, s
72 |         assert len(s) == 2, "wrong center patch specification"
73 |         w = s[0]
74 |         h = s[1]
75 |         return w, h
76 | 
77 |     def _crop_size(self):
78 |         if "image_crop_size" not in self.opt_dict:
79 |             return None
80 |         s = self.opt_dict["image_crop_size"]
81 |         if s is None:
82 |             return None
83 |         if isinstance(s, (int, float)):
84 |             return s, s
85 |         assert len(s) == 2, "wrong crop size specification"
86 |         w = s[0]
87 |         h = s[1]
88 |         return w, h
89 | 
90 |     def _image_color_scaling(self):
91 |         if "image_color_scaling" not in self.opt_dict:
92 |             return None
93 |         return self.opt_dict["image_color_scaling"]
94 | 
95 |     #def _image_background_colors(self):
96 |     #    if "image_background_color" not in self.opt_dict or self.opt_dict["image_background_color"] is None:
97 |     #        return None, None
98 |     #    assert "image_background_replace_color" in self.opt_dict and self.opt_dict["image_background_color"], \
99 |     #        "must specify image_background_replace_color"
100 |     #    return self.opt_dict["image_background_color"], self.opt_dict["image_background_color"]
101 | 
102 |     def output_shapes(self):
103 |         return self._output_shape
104 | 
105 |     @staticmethod
106 |     def robust_center_crop(im, crop_size):
107 |         im_rank = len(im.shape)
108 |         ch = crop_size[0]
109 |         cw = crop_size[1]
110 |         if ch > im.shape[0] or cw > im.shape[1]:
111 |             # pad (with edge values) any dimension that is smaller than the crop size
112 |             h_beg = max(ch - im.shape[0], 0) // 2
113 |             h_end = max(ch - im.shape[0], 0) - h_beg
114 |             w_beg = max(cw - im.shape[1], 0) // 2
115 |             w_end = max(cw - im.shape[1], 0) - w_beg
116 |             im = np.pad(im, [(h_beg, h_end), (w_beg, w_end)] + [(0, 0)] * (im_rank - 2), mode='edge')
117 |         if ch != im.shape[0] or cw != im.shape[1]:
118 |             # take the central ch x cw patch
119 |             h_beg = (im.shape[0] - ch) // 2
120 |             w_beg = (im.shape[1] - cw) // 2
121 |             im = im[h_beg:h_beg + ch, w_beg:w_beg + cw]
122 |         return im
123 | 
124 |     def process_data(self, index, input_list, output_list):
125 | 
126 |         im = input_list[index]
127 |         im_rank = len(im.shape)
128 | 
129 |         # resize shorter edge (can have randomness)
130 |         if self.shorter_edge_length is not None:
131 |             if isinstance(self.shorter_edge_length, list):
132 |                 # choose a target length at random
133 |                 n = len(self.shorter_edge_length)
134 |                 if n == 1:
135 |                     k = 0
136 |                 else:
137 |                     k = np.random.randint(low=0, high=n)
138 |                 target_shorter_edge = self.shorter_edge_length[k]
139 |             else:
140 |                 target_shorter_edge = self.shorter_edge_length
141 | 
142 |             actual_shorter_edge = min(im.shape[0], im.shape[1])
143 |             if actual_shorter_edge != target_shorter_edge:
144 |                 zoom_factor = [target_shorter_edge / actual_shorter_edge] * 2
145 |                 zoom_factor.extend([1] * (im_rank - 2))
146 |                 im = scipy.ndimage.zoom(im, zoom_factor)
147 | 
148 |         # crop center, add padding if necessary
149 |         canvas_size = None
150 |         if self.center_patch_size is None:
151 |             if not self.use_random and self.crop_size is not None:
152 |                 canvas_size = self.crop_size
153 |         else:
154 |             if self.crop_size is None:
155 |                 canvas_size = self.center_patch_size
156 |             else:
157 |                 if self.use_random:
158 |                     canvas_size = [
159 |                         max(self.center_patch_size[0], self.crop_size[0]),
160 |                         max(self.center_patch_size[1], self.crop_size[1])
161 |                     ]
162 |                 else:
163 |                     canvas_size = self.crop_size
164 | 
165 |         if canvas_size is not None:
166 |             im = self.robust_center_crop(im, canvas_size)
167 | 
168 |         # crop final patch (can have randomness)
169 |         if self.crop_size is not None:
170 |             canvas_size = im.shape[0:2]
171 |             if self.crop_size != canvas_size:
172 |                 assert self.use_random, "internal error: should not reach here if not using random"
173 |                 # high is exclusive, so +1 lets the offset cover the full valid range
174 |                 h_beg = np.random.randint(low=0, high=canvas_size[0] - self.crop_size[0] + 1)
175 |                 w_beg = np.random.randint(low=0, high=canvas_size[1] - self.crop_size[1] + 1)
176 |                 im = im[h_beg:(h_beg + self.crop_size[0]), w_beg:(w_beg + self.crop_size[1])]
177 | 
178 |         if self.image_color_scaling is not None:
179 |             # squash colors toward mid-gray by the given scaling factor
180 |             im = im * self.image_color_scaling + (1 - self.image_color_scaling) * 0.5
181 | 
182 |         output_list[index] = im
183 | 
184 |     def next_batch(self, batch_size):
185 | 
186 |         # get raw data
187 |         data_list = self.data_module(batch_size)
188 |         raw_data = data_list[self._data_field_index]
189 | 
190 |         if isinstance(raw_data, np.ndarray):
191 |             n = raw_data.shape[0]
192 |         else:
193 |             n = len(raw_data)
194 | 
195 |         processed_data = [None] * n
196 |         for i in range(n):  # implemented in an easily parallelizable way
197 |             self.process_data(i, raw_data, processed_data)
198 | 
199 |         processed_data = np.concatenate([np.expand_dims(v, axis=0) for v in processed_data], axis=0)
200 | 
201 |         output_data_list = copy(data_list)
202 |         output_data_list[self._data_field_index] = processed_data
203 | 
204 |         return output_data_list
205 | 
--------------------------------------------------------------------------------
/net_modules/spatial_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | import tensorflow as tf 16 | 17 | 18 | def _interpolate(im, x, y, out_size): 19 | with tf.variable_scope('_interpolate'): 20 | # constants 21 | num_batch = tf.shape(im)[0] 22 | height = tf.shape(im)[1] 23 | width = tf.shape(im)[2] 24 | channels = tf.shape(im)[3] 25 | 26 | x = tf.cast(x, 'float32') 27 | y = tf.cast(y, 'float32') 28 | height_f = tf.cast(height, 'float32') 29 | width_f = tf.cast(width, 'float32') 30 | out_height = out_size[0] 31 | out_width = out_size[1] 32 | zero = tf.zeros([], dtype='int32') 33 | max_y = tf.cast(tf.shape(im)[1] - 1, 'int32') 34 | max_x = tf.cast(tf.shape(im)[2] - 1, 'int32') 35 | 36 | # scale indices from [-1, 1] to [0, width/height] 37 | x = (x + 1.0) / 2.0 * (width_f) - 0.5 38 | y = (y + 1.0) / 2.0 * (height_f) - 0.5 39 | x = tf.clip_by_value(x, tf.cast(zero, "float"), tf.cast(max_x, "float")) 40 | y = tf.clip_by_value(y, tf.cast(zero, "float"), tf.cast(max_y, "float")) 41 | 42 | # do sampling 43 | x0 = tf.cast(tf.floor(x), 'int32') 44 | x1 = x0 + 1 45 | y0 = tf.cast(tf.floor(y), 'int32') 46 | y1 = y0 + 1 47 | 48 | x0 = tf.clip_by_value(x0, zero, max_x) 49 | x1 = tf.clip_by_value(x1, zero, max_x) 50 | y0 = tf.clip_by_value(y0, zero, max_y) 51 | y1 = tf.clip_by_value(y1, zero, max_y) 52 | dim2 = width 53 | dim1 = width * height 54 | base = _repeat(tf.range(num_batch) * dim1, out_height * out_width) 55 | base_y0 = base + y0 * dim2 56 | base_y1 = base + y1 * dim2 57 | idx_a = base_y0 + x0 58 | idx_b = base_y1 + x0 59 | idx_c = base_y0 + x1 60 | idx_d = base_y1 + x1 61 | 62 | # use indices to lookup pixels in the flat image and restore 63 | # channels dim 64 | im_flat = tf.reshape(im, tf.stack([-1, channels])) 65 | im_flat = tf.cast(im_flat, 'float32') 66 | Ia = tf.gather(im_flat, idx_a) 67 | Ib = tf.gather(im_flat, idx_b) 68 | Ic = tf.gather(im_flat, idx_c) 69 | Id = tf.gather(im_flat, idx_d) 70 | 71 | # and finally calculate interpolated values 72 | x0_f = tf.cast(x0, 'float32') 73 | x1_f = tf.cast(x1, 'float32') 74 | y0_f = tf.cast(y0, 'float32') 75 | y1_f = tf.cast(y1, 'float32') 76 | wa = tf.expand_dims(((x1_f - x) * (y1_f - y)), 1) 77 | wb = tf.expand_dims(((x1_f - x) * (y - y0_f)), 1) 78 | wc = tf.expand_dims(((x - x0_f) * (y1_f - y)), 1) 79 | wd = tf.expand_dims(((x - x0_f) * (y - y0_f)), 1) 80 | output = tf.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id]) 81 | return output 82 | 83 | 84 | def _repeat(x, n_repeats): 85 | with tf.variable_scope('_repeat'): 86 | rep = tf.transpose( 87 | tf.expand_dims(tf.ones(shape=tf.stack([n_repeats, ])), 1), [1, 0]) 88 | rep = tf.cast(rep, 'int32') 89 | x = tf.matmul(tf.reshape(x, (-1, 1)), rep) 90 | return tf.reshape(x, [-1]) 91 | 92 | 93 | def _meshgrid(height, width): 94 | with tf.variable_scope('_meshgrid'): 95 | # This should be equivalent to: 96 | # x_t, y_t = np.meshgrid(np.linspace(-1, 1, width), 97 | # np.linspace(-1, 1, height)) 98 | # ones = np.ones(np.prod(x_t.shape)) 99 | # grid = np.vstack([x_t.flatten(), y_t.flatten(), ones]) 100 | x_t = tf.matmul(tf.ones(shape=tf.stack([height, 1])), 101 | tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), [1, 0])) 102 | y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1), 103 | tf.ones(shape=tf.stack([1, width]))) 104 | 105 | x_t_flat = tf.reshape(x_t, (1, -1)) 106 | y_t_flat = tf.reshape(y_t, (1, -1)) 107 | 108 | ones = tf.ones_like(x_t_flat) 109 | grid = tf.concat(axis=0, values=[x_t_flat, y_t_flat, ones]) 110 | return grid 111 | 112 | 113 | def 
_transform(theta, input_dim, out_size): 114 | with tf.variable_scope('_transform'): 115 | num_batch = tf.shape(input_dim)[0] 116 | height = tf.shape(input_dim)[1] 117 | width = tf.shape(input_dim)[2] 118 | num_channels = tf.shape(input_dim)[3] 119 | theta = tf.reshape(theta, (-1, 2, 3)) 120 | theta = tf.cast(theta, 'float32') 121 | 122 | # grid of (x_t, y_t, 1), eq (1) in ref [1] 123 | height_f = tf.cast(height, 'float32') 124 | width_f = tf.cast(width, 'float32') 125 | out_height = out_size[0] 126 | out_width = out_size[1] 127 | grid = _meshgrid(out_height, out_width) 128 | grid = tf.expand_dims(grid, 0) 129 | grid = tf.reshape(grid, [-1]) 130 | grid = tf.tile(grid, tf.stack([num_batch])) 131 | grid = tf.reshape(grid, tf.stack([num_batch, 3, -1])) 132 | 133 | # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s) 134 | T_g = tf.matmul(theta, grid) 135 | x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1]) 136 | y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1]) 137 | x_s_flat = tf.reshape(x_s, [-1]) 138 | y_s_flat = tf.reshape(y_s, [-1]) 139 | 140 | input_transformed = _interpolate( 141 | input_dim, x_s_flat, y_s_flat, 142 | out_size) 143 | 144 | output = tf.reshape( 145 | input_transformed, tf.stack([num_batch, out_height, out_width, num_channels])) 146 | return output 147 | 148 | 149 | def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs): 150 | """Spatial Transformer Layer 151 | 152 | Implements a spatial transformer layer as described in [1]_. 153 | Based on [2]_ and edited by David Dao for Tensorflow. 154 | 155 | Parameters 156 | ---------- 157 | U : float 158 | The output of a convolutional net should have the 159 | shape [num_batch, height, width, num_channels]. 160 | theta: float 161 | The output of the 162 | localisation network should be [num_batch, 6]. 163 | out_size: tuple of two ints 164 | The size of the output of the network (height, width) 165 | 166 | References 167 | ---------- 168 | .. [1] Spatial Transformer Networks 169 | Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu 170 | Submitted on 5 Jun 2015 171 | .. [2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py 172 | 173 | Notes 174 | ----- 175 | To initialize the network to the identity transform init 176 | ``theta`` to : 177 | identity = np.array([[1., 0., 0.], 178 | [0., 1., 0.]]) 179 | identity = identity.flatten() 180 | theta = tf.Variable(initial_value=identity) 181 | 182 | """ 183 | 184 | with tf.variable_scope(name): 185 | output = _transform(theta, U, out_size) 186 | return output 187 | 188 | 189 | def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'): 190 | """Batch Spatial Transformer Layer 191 | 192 | Parameters 193 | ---------- 194 | 195 | U : float 196 | tensor of inputs [num_batch,height,width,num_channels] 197 | thetas : float 198 | a set of transformations for each input [num_batch,num_transforms,6] 199 | out_size : int 200 | the size of the output [out_height,out_width] 201 | 202 | Returns: float 203 | Tensor of size [num_batch*num_transforms,out_height,out_width,num_channels] 204 | """ 205 | with tf.variable_scope(name): 206 | num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2]) 207 | indices = [[i]*num_transforms for i in range(num_batch)] 208 | input_repeated = tf.gather(U, tf.reshape(indices, [-1])) 209 | return transformer(input_repeated, thetas, out_size) 210 | --------------------------------------------------------------------------------
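
The `transformer` above is self-contained, so it can be exercised directly. Here is a minimal sketch in the TF 1.x graph style this module is written in, driving it with the identity transform from its own docstring; the random input and the 80x80 shape are illustrative only.

    import numpy as np
    import tensorflow as tf

    from net_modules.spatial_transformer import transformer

    height, width, channels = 80, 80, 3
    U = tf.placeholder(tf.float32, [None, height, width, channels])

    # Identity affine parameters [[1, 0, 0], [0, 1, 0]], flattened to 6 values
    # and tiled across the (dynamic) batch dimension.
    identity = np.array([[1., 0., 0.], [0., 1., 0.]], dtype=np.float32).flatten()
    theta = tf.tile(tf.expand_dims(tf.constant(identity), 0), [tf.shape(U)[0], 1])

    # With the identity transform, the output is just a bilinear resampling of
    # the input at its original pixel locations.
    warped = transformer(U, theta, out_size=(height, width))

    with tf.Session() as sess:
        images = np.random.rand(4, height, width, channels).astype(np.float32)
        print(sess.run(warped, feed_dict={U: images}).shape)  # (4, 80, 80, 3)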