├── .gitignore
├── LICENSE.txt
├── README.md
├── cm_22.png
├── config.py
├── generation
│   ├── .gitignore
│   ├── data.py
│   ├── experiments
│   │   └── config
│   │       ├── CSM
│   │       │   ├── config.py
│   │       │   ├── model.py
│   │       │   ├── options.py
│   │       │   ├── preprocessing.py
│   │       │   ├── summaries.py
│   │       │   └── write_output.py
│   │       ├── LSM
│   │       │   ├── config.py
│   │       │   ├── model.py
│   │       │   ├── options.py
│   │       │   ├── preprocessing.py
│   │       │   ├── summaries.py
│   │       │   └── write_output.py
│   │       ├── PM
│   │       │   ├── config.py
│   │       │   ├── model.py
│   │       │   ├── options.py
│   │       │   ├── preprocessing.py
│   │       │   ├── summaries.py
│   │       │   └── write_output.py
│   │       ├── PM_class
│   │       │   ├── config.py
│   │       │   ├── model.py
│   │       │   ├── options.py
│   │       │   ├── preprocessing.py
│   │       │   ├── summaries.py
│   │       │   └── write_output.py
│   │       └── template
│   │           ├── config.py
│   │           ├── model.py
│   │           ├── options.py
│   │           ├── preprocessing.py
│   │           ├── summaries.py
│   │           └── write_output.py
│   ├── generate.sh
│   ├── generate_conditioned.sh
│   ├── run.py
│   ├── test_runner.py
│   └── tools
│       ├── 01_chic10k_to_fashion_joined.py
│       ├── 02_create_clothdataset.py
│       ├── 03_create_pose_input.sh
│       ├── 04_run_deepercut.sh
│       ├── 05_run_fits.sh
│       ├── 06_render_bodies.py
│       ├── 07_create_additional_conditioning.py
│       ├── 08_prepare_directinpaint.sh
│       ├── 09_pack_db.sh
│       └── geometric_median.py
├── gp_tools
│   ├── __init__.py
│   ├── tf.py
│   └── write.py
├── requirements.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | generation/data
3 | generation/experiments/states
4 | .#*
5 | .*_cache
6 | generation/experiments/features
7 | .DS_Store
8 | generation/results
9 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017 University of Tuebingen, Christoph Lassner.
2 |
3 | Code and annotations are available under the
4 | Creative Commons Attribution-Noncommercial 4.0 International License
5 | (https://creativecommons.org/licenses/by-nc/4.0/).
6 |
7 | Some portions of the code originate from
8 | https://github.com/affinelayer/pix2pix-tensorflow . Thanks to Christopher Hesse
9 | for making his implementation available. For these portions, his license
10 | applies:
11 |
12 | MIT License
13 |
14 | Copyright (c) 2017 Christopher Hesse
15 |
16 | Permission is hereby granted, free of charge, to any person obtaining a copy
17 | of this software and associated documentation files (the "Software"), to deal
18 | in the Software without restriction, including without limitation the rights
19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
20 | copies of the Software, and to permit persons to whom the Software is
21 | furnished to do so, subject to the following conditions:
22 |
23 | The above copyright notice and this permission notice shall be included in all
24 | copies or substantial portions of the Software.
25 |
26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
32 | SOFTWARE.
33 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Generating People code repository
2 |
3 | Requirements:
4 |
5 | * OpenCV (on Ubuntu, e.g., install libopencv-dev and python-opencv).
6 | * SMPL (download at http://smpl.is.tue.mpg.de/downloads) and unzip to a
7 | place of your choice.
8 | * Edit the file `config.py` to set up the paths (see the example below this list).
9 | * `tensorflow` or `tensorflow-gpu`, version >= 1.1.0 (not added to the
10 |   requirements so that it does not force the GPU or non-GPU variant on you).
11 | * Only if you want to run pose estimation and 3D fitting to integrate new data
12 | into the dataset: set up the Unite the People repository
13 | (https://github.com/classner/up) and adjust its path in `config.py`.
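
The relevant block near the top of `config.py` (shown here with the repository
defaults; point the paths at your own locations, and `UP_FP` is only needed for
the optional pose estimation / 3D fitting step):

```
SMPL_FP = path.expanduser("~/smpl")
UP_FP = path.expanduser("~/git/up")
CHICTOPIA_DATA_FP = path.expanduser("~/datasets/chictopia10k")
```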
14 |
15 | The rest of the requirements are then automatically installed when running:
16 |
17 | ```
18 | python setup.py develop
19 | ```
20 |
21 | ## Setting up the data
22 |
23 | The scripts in `generation/tools/` transform the Chictopia data to construct
24 | the final database. Go through the scripts in order to create it. Otherwise,
25 | download the pre-processed data from our website
26 | (http://files.is.tuebingen.mpg.de/classner/gp/), unzip it to the folder
27 | `generation/data/pose/extracted` and only run the last script
28 |
29 | ```
30 | ./09_pack_db.sh full
31 | ```
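
If you build the database yourself instead, the numbered scripts in
`generation/tools/` are meant to be run in order. A rough sketch of that
pipeline follows (the invocations are illustrative only; set up the paths in
`config.py` first and check each script for required arguments):

```
cd generation/tools
python 01_chic10k_to_fashion_joined.py
python 02_create_clothdataset.py
./03_create_pose_input.sh
./04_run_deepercut.sh            # pose estimation; needs the UP repository
./05_run_fits.sh                 # 3D body fitting; needs the UP repository
python 06_render_bodies.py
python 07_create_additional_conditioning.py
./08_prepare_directinpaint.sh
./09_pack_db.sh full
```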
32 |
33 | ## Training / running models
34 |
35 | Model configuration and training artifacts are in the `experiments` folder. The
36 | `config` subfolder contains model configurations (LSM=latent sketch module,
37 | CSM=conditional sketch module, PM=portray module, PM_class=portray module with
38 | class input). You can track the contents of this folder with git since it's
39 | lightweight and no artifacts are stored there. To create a new model, just copy
40 | `template` (or link to the files in it) and change `options.py` in the new
41 | folder.
42 |
43 | To run training/validation/testing use
44 |
45 | ```
46 | ./run.py [train,val,trainval,test,{sample}] experiments/config/modelname
47 | ```
48 |
49 | where `trainval` runs training on the combined training and validation sets.
50 | Artifacts during training are written to `experiments/states/modelname` (you
51 | can run TensorBoard there for monitoring). The generated results of testing
52 | are stored in `experiments/features/modelname/runstate`, where `runstate` is
53 | either a training stage or a point in time (when sampling). You can use the
54 | `test_runner.py` script to automatically scan for newly created training
55 | checkpoints and validate/test them with the command
56 |
57 | ```
58 | ./test_runner.py experiments/states/modelname [val, test]
59 | ```
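
For example, a typical workflow for a single model could look like this (using
the CSM configuration; substitute your own model name):

```
./run.py train experiments/config/CSM
tensorboard --logdir experiments/states/CSM    # monitor training progress
./test_runner.py experiments/states/CSM val    # evaluate new checkpoints as they appear
```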
60 |
61 | Pre-trained models can be downloaded from
62 | http://files.is.tuebingen.mpg.de/classner/gp .
63 |
64 | ## Generating people
65 |
66 | If you have trained or downloaded the LSM and PM models, you can use a
67 | convenience script to sample people. For this, navigate to the `generation`
68 | folder and run
69 |
70 | ```
71 | ./generate.sh n_people [out_folder]
72 | ```
73 |
74 | to generate `n_people` samples into the optionally specified `out_folder`. If
75 | unspecified, the output folder defaults to `generated`.
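
For example, to sample ten people into a folder named `samples`:

```
./generate.sh 10 samples
```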
76 |
77 | ## Citing
78 |
79 | If you use this code for your research, please consider citing us:
80 |
81 | ```
82 | @INPROCEEDINGS{Lassner:GeneratingPeople:2017,
83 | author = {Christoph Lassner and Gerard Pons-Moll and Peter V. Gehler},
84 | title = {A Generative Model for People in Clothing},
85 | year = {2017},
86 | booktitle = {Proceedings of the IEEE International Conference on Computer Vision}
87 | }
88 | ```
89 |
90 | ## Acknowledgements
91 |
92 | Our models are strongly inspired by the pix2pix line of work by Isola et al.
93 | (https://phillipi.github.io/pix2pix/). Parts of the code are inspired by the
94 | implementation by Christopher Hesse (https://affinelayer.com/pix2pix/). Overall,
95 | this repository is structured similarly to the DeepLab project
96 | (http://liangchiehchen.com/projects/DeepLabv2_resnet.html), which enables
97 | efficient model specification, tracking and training, and combines that with
98 | the advantages of TensorBoard.
99 |
--------------------------------------------------------------------------------
/cm_22.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/classner/generating_people/ee6b54945d395efe23ddceb4a2daca3fffc7e89d/cm_22.png
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | """Configuration values for the project."""
3 | import os.path as path
4 | import click
5 | if __name__ != '__main__':
6 | import matplotlib
7 | import scipy.misc as sm
8 | import numpy as np
9 |
10 |
11 | ############################ EDIT HERE ########################################
12 | SMPL_FP = path.expanduser("~/smpl")
13 | UP_FP = path.expanduser("~/git/up")
14 | CHICTOPIA_DATA_FP = path.expanduser("~/datasets/chictopia10k")
15 | GEN_DATA_FP = path.abspath(path.join(path.dirname(__file__),
16 | 'data', 'intermediate'))
17 | EXP_DATA_FP = path.abspath(path.join(path.dirname(__file__),
18 | 'generation', 'data'))
19 |
20 | if __name__ != '__main__':
21 | CMAP = matplotlib.colors.ListedColormap(sm.imread(path.join(
22 | path.dirname(__file__), "cm_22.png"))[0]
23 | .astype(np.float32) / 255.)
24 |
25 |
26 | ###############################################################################
27 | # Infrastructure. Don't edit. #
28 | ###############################################################################
29 |
30 | @click.command()
31 | @click.argument('key', type=click.STRING)
32 | def cli(key):
33 | """Print a config value to STDOUT."""
34 | if key in globals().keys():
35 | print globals()[key]
36 | else:
37 | raise Exception("Requested configuration value not available! "
38 | "Available keys: " +
39 | str([kval for kval in globals().keys() if kval.isupper()]) +
40 | ".")
41 |
42 |
43 | if __name__ == '__main__':
44 | cli() # pylint: disable=no-value-for-parameter
45 |
46 |
--------------------------------------------------------------------------------
/generation/.gitignore:
--------------------------------------------------------------------------------
1 | experiments/states
2 | experiments/features
3 | data
4 |
--------------------------------------------------------------------------------
/generation/data.py:
--------------------------------------------------------------------------------
1 | """Common loading infrastructure."""
2 | import tempfile
3 | import subprocess
4 | import shutil
5 | import math
6 | import os.path as path
7 | from glob import glob
8 | import logging
9 | import clustertools.db.tools as cdbt
10 | import tensorflow as tf
11 |
12 |
13 | LOGGER = logging.getLogger(__name__)
14 | TMP_DIRS = []
15 |
16 |
17 | def prepare(path_queue, data_def, config):
18 | reader = tf.TFRecordReader()
19 | _, serialized_example = reader.read(path_queue)
20 | return cdbt.decode_tf_tensors(
21 | serialized_example, data_def[0], data_def[1], as_list=False)
22 |
23 |
24 | def get_dataset(EXP_DATA_FP, mode, config):
25 | global TMP_DIRS
26 | suffix = config["dset_suffix"]
27 | dataset_dir = path.join(EXP_DATA_FP, config['dset_type'], suffix)
28 | LOGGER.info("Preparing dataset from `%s`.", dataset_dir)
29 | if not path.exists(dataset_dir):
30 | raise Exception("input_dir does not exist: %s." % (dataset_dir))
31 | input_paths = []
32 | if mode in ['train', 'trainval']:
33 | input_paths.extend(glob(path.join(dataset_dir, "train_p*.tfrecords")))
34 | if mode in ['trainval', 'val', 'sample']:
35 | input_paths.extend(glob(path.join(dataset_dir, 'val_p*.tfrecords')))
36 | if mode in ['test']:
37 | input_paths.extend(glob(path.join(dataset_dir, 'test_p*.tfrecords')))
38 | if len(input_paths) == 0:
39 | # Assume that it's a directory instead of tfrecord structures.
40 | # Create a temporary directory with tfrecord files.
41 | tmp_dir = tempfile.mkdtemp(dir=dataset_dir, prefix='tmp_')
42 | TMP_DIRS.append(tmp_dir)
43 | input_paths = []
44 | if mode in ['train', 'trainval']:
45 | input_paths.extend(glob(path.join(dataset_dir, "train")))
46 | if mode in ['trainval', 'val', 'sample']:
47 | input_paths.extend(glob(path.join(dataset_dir, 'val')))
48 | if mode in ['test']:
49 | input_paths.extend(glob(path.join(dataset_dir, 'test')))
50 | for fp in input_paths:
51 | subprocess.check_call([
52 | "tfrpack",
53 | fp,
54 | "--out_fp",
55 | path.join(tmp_dir, path.basename(fp)),
56 | ])
57 | input_paths = glob(path.join(tmp_dir, "*.tfrecords"))
58 | if len(input_paths) == 0:
59 | raise Exception("`%s` contains no files for this mode" % (dataset_dir))
60 | nsamples, colnames, coltypes = cdbt.scan_tfdb(input_paths)
61 | data_def = (colnames, coltypes)
62 | LOGGER.info("Found %d samples with columns: %s.", nsamples, data_def[0])
63 | with tf.name_scope("load_data"):
64 | path_queue = tf.train.string_input_producer(
65 | input_paths,
66 | capacity=200 + 8 * config["batch_size"],
67 | shuffle=(mode == "train"))
68 | if mode in ['train', 'trainval']:
69 | # Prepare for shuffling.
70 | loader_list = [prepare(path_queue, data_def, config)
71 | for _ in range(config["num_threads"])]
72 | else:
73 | loader_list = [prepare(path_queue, data_def, config)]
74 | steps_per_epoch = int(math.ceil(float(nsamples) / config["batch_size"]))
75 | return nsamples, steps_per_epoch, loader_list
76 |
77 |
78 | def cleanup():
79 | """Deletes temprorary directories if necessary."""
80 | for tmp_fp in TMP_DIRS:
81 | shutil.rmtree(tmp_fp)
82 |
--------------------------------------------------------------------------------
/generation/experiments/config/CSM/config.py:
--------------------------------------------------------------------------------
1 | ../template/config.py
--------------------------------------------------------------------------------
/generation/experiments/config/CSM/model.py:
--------------------------------------------------------------------------------
1 | ../template/model.py
--------------------------------------------------------------------------------
/generation/experiments/config/CSM/options.py:
--------------------------------------------------------------------------------
1 | config = {
2 | # Model. #######################
3 |     # Supported are: 'portray', 'vae', 'cvae'.
4 | "model_version": 'cvae',
5 | # What modes are available for this model.
6 | "supp_modes": ['train', 'val', 'trainval', 'test', 'sample'],
7 | # Number of Generator Filters in the first layer.
8 | # Increases gradually to max. 8 * ngf.
9 | "ngf": 64, # 64
10 |     # Number of Discriminator Filters in the first layer.
11 | # Increases gradually to max. 8 * ndf.
12 | "ndf": 64, # 64
13 | # Number of latent space dimensions (z) for the vae.
14 | "nz": 32,
15 | # If "input_as_class", use softmax xentropy. Otherwise sigmoid xentropy.
16 | "iac_softmax": False,
17 | # No interconnections inside of the model.
18 | "cvae_noconn": True,
19 | # Omit variational sampling, making the CVAE a CAE.
20 | "cvae_nosampling": False,
21 | # How many samples to draw per image for CVAE.
22 | "cvae_nsamples_per_image": 5,
23 | # Instead of building a y encoder, just downscale y.
24 | "cvae_downscale_y": False,
25 | # Use batchnorm for the vector encoding in the pix2pix model.
26 | "pix2pix_zbatchnorm": True,
27 |
28 | # Data and preprocessing. ########################
29 | # Whether to use on-line data augmentation by flipping.
30 | "flip": False,
31 | # Whether to treat input data as classification or image.
32 | "input_as_class": True,
33 | # Whether to treat conditioning data as classification or image.
34 | "conditioning_as_class": True,
35 | # If input_as_class True, it is possible to give class weights here. (If
36 | # undesired, set to `None`.) This is a dictionary from class=>weight.
37 | "class_weights": {
38 | 18: 10.,
39 | 19: 10.,
40 | 20: 10.,
41 | 21: 10.,
42 | },
43 | # If True, weights are applied only for recall computation. Otherwise,
44 | # an approximation is used for Precision as well assuming that all weights
45 | # apart from the 1.0 weights are equal.
46 | "class_weights_only_recall": True,
47 | # Scale the images up to this size before cropping
48 | # them to `crop_size`.
49 | "scale_size": 286,
50 | "crop_size": 256,
51 | # Type of image content.
52 | "dset_type": "people",
53 | "dset_suffix": "full",
54 |
55 | # Optimizer ####################
56 | "max_epochs": 70,
57 | "max_steps": None,
58 | "batch_size": 40,
59 | "lr": 0.0002, # Adam lr.
60 | "beta1": 0.5, # Adam beta1 param.
61 | "gan_weight": 0.0,
62 | "recon_weight": 1.0,
63 | "latent_weight": 7.0,
64 |
65 | # Infrastructure. ##############
66 | # Save summaries after every x seconds.
67 | "summary_freq": 30,
68 | # Create traces after every x batches.
69 | "trace_freq": 0,
70 | # After every x epochs, save model.
71 | "save_freq": 10,
72 | # Keep x saves.
73 | "kept_saves": 10,
74 | # After every x epochs, render images.
75 | "display_freq": 10,
76 | # Random seed to use.
77 | "seed": 538728914,
78 | }
79 |
--------------------------------------------------------------------------------
/generation/experiments/config/CSM/preprocessing.py:
--------------------------------------------------------------------------------
1 | ../template/preprocessing.py
--------------------------------------------------------------------------------
/generation/experiments/config/CSM/summaries.py:
--------------------------------------------------------------------------------
1 | ../template/summaries.py
--------------------------------------------------------------------------------
/generation/experiments/config/CSM/write_output.py:
--------------------------------------------------------------------------------
1 | ../template/write_output.py
--------------------------------------------------------------------------------
/generation/experiments/config/LSM/config.py:
--------------------------------------------------------------------------------
1 | ../template/config.py
--------------------------------------------------------------------------------
/generation/experiments/config/LSM/model.py:
--------------------------------------------------------------------------------
1 | ../template/model.py
--------------------------------------------------------------------------------
/generation/experiments/config/LSM/options.py:
--------------------------------------------------------------------------------
1 | config = {
2 | # Model. #######################
3 |     # Supported are: 'portray', 'vae', 'cvae'.
4 | "model_version": 'vae',
5 | # What modes are available for this model.
6 | "supp_modes": ['train', 'val', 'trainval', 'test', 'sample'],
7 | # Number of Generator Filters in the first layer.
8 | # Increases gradually to max. 8 * ngf.
9 | "ngf": 64, # 64
10 |     # Number of Discriminator Filters in the first layer.
11 | # Increases gradually to max. 8 * ndf.
12 | "ndf": 64, # 64
13 | # Number of latent space dimensions (z) for the vae.
14 | "nz": 32,
15 | # If "input_as_class", use softmax xentropy. Otherwise sigmoid xentropy.
16 | "iac_softmax": False,
17 | # No interconnections inside of the model.
18 | "cvae_noconn": False,
19 | # Omit variational sampling, making the CVAE a CAE.
20 | "cvae_nosampling": False,
21 | # How many samples to draw per image for CVAE.
22 | "cvae_nsamples_per_image": 5,
23 | # Instead of building a y encoder, just downscale y.
24 | "cvae_downscale_y": False,
25 | # Use batchnorm for the vector encoding in the pix2pix model.
26 | "pix2pix_zbatchnorm": True,
27 |
28 | # Data and preprocessing. ########################
29 | # Whether to use on-line data augmentation by flipping.
30 | "flip": False,
31 | # Whether to treat input data as classification or image.
32 | "input_as_class": True,
33 | # Whether to treat conditioning data as classification or image.
34 | "conditioning_as_class": False,
35 | # If input_as_class True, it is possible to give class weights here. (If
36 | # undesired, set to `None`.) This is a dictionary from class=>weight.
37 | "class_weights": {
38 | 18: 10.,
39 | 19: 10.,
40 | 20: 10.,
41 | 21: 10.,
42 | },
43 | # If True, weights are applied only for recall computation. Otherwise,
44 | # an approximation is used for Precision as well assuming that all weights
45 | # apart from the 1.0 weights are equal.
46 | "class_weights_only_recall": True,
47 | # Scale the images up to this size before cropping
48 | # them to `crop_size`.
49 | "scale_size": 286,
50 | "crop_size": 256,
51 | # Type of image content. Currently only supports "people".
52 | "dset_type": "people",
53 | "dset_suffix": "full",
54 |
55 | # Optimizer ####################
56 | "max_epochs": 150,
57 | "max_steps": None,
58 | "batch_size": 30,
59 | "lr": 0.0002, # Adam lr.
60 | "beta1": 0.5, # Adam beta1 param.
61 | "gan_weight": 0.0,
62 | "recon_weight": 1.0,
63 | "latent_weight": 7.0,
64 |
65 | # Infrastructure. ##############
66 | # Save summaries after every x seconds.
67 | "summary_freq": 30,
68 | # Create traces after every x batches.
69 | "trace_freq": 0,
70 | # After every x epochs, save model.
71 | "save_freq": 10,
72 | # Keep x saves.
73 | "kept_saves": 10,
74 | # After every x epochs, render images.
75 | "display_freq": 10,
76 | # Random seed to use.
77 | "seed": 538728914,
78 | }
79 |
--------------------------------------------------------------------------------
/generation/experiments/config/LSM/preprocessing.py:
--------------------------------------------------------------------------------
1 | ../template/preprocessing.py
--------------------------------------------------------------------------------
/generation/experiments/config/LSM/summaries.py:
--------------------------------------------------------------------------------
1 | ../template/summaries.py
--------------------------------------------------------------------------------
/generation/experiments/config/LSM/write_output.py:
--------------------------------------------------------------------------------
1 | ../template/write_output.py
--------------------------------------------------------------------------------
/generation/experiments/config/PM/config.py:
--------------------------------------------------------------------------------
1 | ../template/config.py
--------------------------------------------------------------------------------
/generation/experiments/config/PM/model.py:
--------------------------------------------------------------------------------
1 | ../template/model.py
--------------------------------------------------------------------------------
/generation/experiments/config/PM/options.py:
--------------------------------------------------------------------------------
1 | config = {
2 | # Model. #######################
3 | # Supported are: 'portray', 'vae', 'cvae'.
4 | "model_version": 'portray',
5 | # What modes are available for this model.
6 | "supp_modes": ['train', 'val', 'trainval', 'test'],
7 | # Number of Generator Filters in the first layer.
8 | # Increases gradually to max. 8 * ngf.
9 | "ngf": 64,
10 |     # Number of Discriminator Filters in the first layer.
11 | # Increases gradually to max. 8 * ndf.
12 | "ndf": 64,
13 | # Number of latent space dimensions (z) for the vae.
14 | "nz": 512,
15 | # If "input_as_class", use softmax xentropy. Otherwise sigmoid xentropy.
16 | "iac_softmax": True,
17 | # No interconnections inside of the model.
18 | "cvae_noconn": False,
19 | # Omit variational sampling, making the CVAE a CAE.
20 | "cvae_nosampling": True,
21 | # How many samples to draw per image for CVAE.
22 | "cvae_nsamples_per_image": 5,
23 | # Instead of building a y encoder, just downscale y.
24 | "cvae_downscale_y": False,
25 |
26 | # Data. ########################
27 | # Whether to use on-line data augmentation by flipping.
28 | "flip": False,
29 | # Whether to treat input data as classification or image.
30 | "input_as_class": False,
31 | # Whether to treat conditioning data as classification or image (CVAE only).
32 | "conditioning_as_class": False,
33 | # If input_as_class True, it is possible to give class weights here. (If
34 | # undesired, set to `None`.) This is a dictionary from class=>weight.
35 | "class_weights": None,
36 | # If True, weights are applied only for recall computation. Otherwise,
37 | # an approximation is used for Precision as well assuming that all weights
38 | # apart from the 1.0 weights are equal.
39 | "class_weights_only_recall": True,
40 | # Scale the images up to this size before cropping
41 | # them to `crop_size`.
42 | "scale_size": 286,
43 | "crop_size": 256,
44 | # Type of image content.
45 | "dset_type": "people",
46 | "dset_suffix": "full",
47 |
48 | # Optimizer ####################
49 | "max_epochs": 150,
50 | "max_steps": None,
51 | "batch_size": 1,
52 | "lr": 0.0002, # Adam lr.
53 | "beta1": 0.5, # Adam beta1 param.
54 | "gan_weight": 1.0,
55 | "recon_weight": 100.0,
56 | "latent_weight": 0.,
57 |
58 | # Infrastructure. ##############
59 | # Save summaries after every x seconds.
60 | "summary_freq": 30,
61 | # Create traces after every x batches.
62 | "trace_freq": 0,
63 | # After every x epochs, save model.
64 | "save_freq": 10,
65 | # Keep x saves.
66 | "kept_saves": 10,
67 | # After every x epochs, render images.
68 | "display_freq": 10,
69 | # Random seed to use.
70 | "seed": 538728914,
71 | }
72 |
--------------------------------------------------------------------------------
/generation/experiments/config/PM/preprocessing.py:
--------------------------------------------------------------------------------
1 | ../template/preprocessing.py
--------------------------------------------------------------------------------
/generation/experiments/config/PM/summaries.py:
--------------------------------------------------------------------------------
1 | ../template/summaries.py
--------------------------------------------------------------------------------
/generation/experiments/config/PM/write_output.py:
--------------------------------------------------------------------------------
1 | ../template/write_output.py
--------------------------------------------------------------------------------
/generation/experiments/config/PM_class/config.py:
--------------------------------------------------------------------------------
1 | ../template/config.py
--------------------------------------------------------------------------------
/generation/experiments/config/PM_class/model.py:
--------------------------------------------------------------------------------
1 | ../template/model.py
--------------------------------------------------------------------------------
/generation/experiments/config/PM_class/options.py:
--------------------------------------------------------------------------------
1 | config = {
2 | # Model. #######################
3 | # Supported are: 'portray', 'vae', 'cvae'.
4 | "model_version": 'portray',
5 | # What modes are available for this model.
6 | "supp_modes": ['train', 'val', 'trainval', 'test'],
7 | # Number of Generator Filters in the first layer.
8 | # Increases gradually to max. 8 * ngf.
9 | "ngf": 64,
10 |     # Number of Discriminator Filters in the first layer.
11 | # Increases gradually to max. 8 * ndf.
12 | "ndf": 64,
13 | # Number of latent space dimensions (z) for the vae.
14 | "nz": 512,
15 | # If "input_as_class", use softmax xentropy. Otherwise sigmoid xentropy.
16 | "iac_softmax": True,
17 | # No interconnections inside of the model.
18 | "cvae_noconn": False,
19 | # Omit variational sampling, making the CVAE a CAE.
20 | "cvae_nosampling": True,
21 | # How many samples to draw per image for CVAE.
22 | "cvae_nsamples_per_image": 5,
23 | # Instead of building a y encoder, just downscale y.
24 | "cvae_downscale_y": False,
25 |
26 | # Data. ########################
27 | # Whether to use on-line data augmentation by flipping.
28 | "flip": False,
29 | # Whether to treat input data as classification or image.
30 | "input_as_class": True,
31 | # Whether to treat conditioning data as classification or image (CVAE only).
32 | "conditioning_as_class": False,
33 | # If input_as_class True, it is possible to give class weights here. (If
34 | # undesired, set to `None`.) This is a dictionary from class=>weight.
35 | "class_weights": None,
36 | # If True, weights are applied only for recall computation. Otherwise,
37 | # an approximation is used for Precision as well assuming that all weights
38 | # apart from the 1.0 weights are equal.
39 | "class_weights_only_recall": True,
40 | # Scale the images up to this size before cropping
41 | # them to `crop_size`.
42 | "scale_size": 286,
43 | "crop_size": 256,
44 | # Type of image content. Currently only supports "people".
45 | "dset_type": "people",
46 | "dset_suffix": "full",
47 |
48 | # Optimizer ####################
49 | "max_epochs": 150,
50 | "max_steps": None,
51 | "batch_size": 1,
52 | "lr": 0.0002, # Adam lr.
53 | "beta1": 0.5, # Adam beta1 param.
54 | "gan_weight": 1.0,
55 | "recon_weight": 100.0,
56 | "latent_weight": 0.,
57 |
58 | # Infrastructure. ##############
59 | # Save summaries after every x seconds.
60 | "summary_freq": 30,
61 | # Create traces after every x batches.
62 | "trace_freq": 0,
63 | # After every x epochs, save model.
64 | "save_freq": 10,
65 | # Keep x saves.
66 | "kept_saves": 10,
67 | # After every x epochs, render images.
68 | "display_freq": 10,
69 | # Random seed to use.
70 | "seed": 538728914,
71 | }
72 |
--------------------------------------------------------------------------------
/generation/experiments/config/PM_class/preprocessing.py:
--------------------------------------------------------------------------------
1 | ../template/preprocessing.py
--------------------------------------------------------------------------------
/generation/experiments/config/PM_class/summaries.py:
--------------------------------------------------------------------------------
1 | ../template/summaries.py
--------------------------------------------------------------------------------
/generation/experiments/config/PM_class/write_output.py:
--------------------------------------------------------------------------------
1 | ../template/write_output.py
--------------------------------------------------------------------------------
/generation/experiments/config/template/config.py:
--------------------------------------------------------------------------------
1 | import imp
2 | import logging
3 | import os.path as path
4 |
5 | LOGGER = logging.getLogger(__name__)
6 |
7 |
8 | def get_config():
9 | CONF_FP = path.join(path.dirname(__file__), "options.py")
10 | LOGGER.info("Loading experiment configuration from `%s`...", CONF_FP)
11 | options = imp.load_source('_options',
12 | path.abspath(path.join(path.dirname(__file__),
13 | 'options.py')))
14 | # with open(CONF_FP, 'r') as inf:
15 | # config = json.loads(inf.read())
16 | LOGGER.info("Done.")
17 | return options.config
18 |
19 |
20 | def adjust_config(config, mode):
21 | # Don't misuse this!
22 | # Results are better if batchnorm is always used in 'training' mode and
23 | # normalizes the observed distribution. That's why it's important to leave
24 | # the batchsize unchanged.
25 | #if mode not in ['train', 'trainval']:
26 | # config['batch_size'] = 1
27 | return config
28 |
--------------------------------------------------------------------------------
/generation/experiments/config/template/model.py:
--------------------------------------------------------------------------------
1 | """Defining the model or model family."""
2 | import collections
3 | import logging
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 | import tensorflow.contrib.layers as tfl
8 |
9 | from gp_tools.tf import get_or_load_variable, get_val_or_initializer
10 |
11 | # flake8: noqa: E501
12 | LOGGER = logging.getLogger(__name__)
13 | EPS = 1e-12
14 | #: Defines the fields of a model. Required:
15 | #: global_step, train
16 | Model = collections.namedtuple(
17 | "Model",
18 | "global_step, train, " # Required
19 | "z_mean, z_log_sigma_sq, z, y, outputs, predict_real, predict_fake, "
20 | "discrim_loss, gen_loss_GAN, gen_loss_recon, gen_loss_latent, gen_accuracy")
21 |
22 |
23 | # Graph components. ###########################################################
24 | def conv(batch_input, out_channels, stride, orig_graph):
25 | """Convolution without bias as in Isola's paper."""
26 | with tf.variable_scope("conv"):
27 | in_channels = batch_input.get_shape()[3]
28 | filter = get_or_load_variable(
29 | orig_graph,
30 | "filter",
31 | [4, 4, in_channels, out_channels],
32 | dtype=tf.float32,
33 | initializer=tf.random_normal_initializer(0, 0.02))
34 | # [batch, in_height, in_width, in_channels],
35 | # [filter_width, filter_height, in_channels, out_channels]
36 | # => [batch, out_height, out_width, out_channels]
37 | padded_input = tf.pad(batch_input,
38 | [[0, 0], [1, 1], [1, 1], [0, 0]],
39 | mode="CONSTANT")
40 | conv = tf.nn.conv2d(padded_input, filter, [1, stride, stride, 1],
41 | padding="VALID")
42 | return conv
43 |
44 |
45 | def lrelu(x, a):
46 | with tf.name_scope("lrelu"):
47 | # adding these together creates the leak part and linear part
48 | # then cancels them out by subtracting/adding an absolute value term
49 | # leak: a*x/2 - a*abs(x)/2
50 | # linear: x/2 + abs(x)/2
51 | return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)
52 |
53 |
54 | def batchnorm(input, orig_graph, is_training):
55 | return tfl.batch_norm(
56 | input,
57 | decay=0.9,
58 | scale=True,
59 | epsilon=1E-5,
60 | activation_fn=None,
61 | param_initializers={
62 | 'beta': get_val_or_initializer(orig_graph,
63 | tf.constant_initializer(0.),
64 | 'BatchNorm/beta'),
65 | 'gamma': get_val_or_initializer(orig_graph,
66 | tf.random_normal_initializer(1.0,
67 | 0.02),
68 | 'BatchNorm/gamma'),
69 | 'moving_mean': get_val_or_initializer(orig_graph,
70 | tf.constant_initializer(0.),
71 | 'BatchNorm/moving_mean'),
72 | 'moving_variance': get_val_or_initializer(orig_graph,
73 | tf.ones_initializer(),
74 | 'BatchNorm/moving_variance')
75 | },
76 | is_training=is_training,
77 | fused=True, # new implementation with a fused kernel => speedup.
78 | )
79 |
80 |
81 | def deconv(batch_input, out_channels, orig_graph):
82 | with tf.variable_scope("deconv"):
83 | batch, in_height, in_width, in_channels = \
84 | batch_input.get_shape().as_list()
85 | filter = get_or_load_variable(
86 | orig_graph,
87 | "filter", [4, 4, out_channels, in_channels], dtype=tf.float32,
88 | initializer=tf.random_normal_initializer(0, 0.02))
89 | # [batch, in_height, in_width, in_channels],
90 | # [filter_width, filter_height, out_channels, in_channels]
91 | # => [batch, out_height, out_width, out_channels]
92 | conv = tf.nn.conv2d_transpose(
93 | batch_input, filter,
94 | [batch, in_height * 2, in_width * 2, out_channels],
95 | [1, 2, 2, 1],
96 | padding="SAME")
97 | return conv
98 |
99 |
100 | # Processing part. ############################################################
101 | def create_generator(generator_inputs,
102 | generator_outputs_channels,
103 | mode,
104 | config,
105 | full_graph,
106 | conditioning,
107 | rwgrid,
108 | is_training):
109 | if generator_inputs is None:
110 | assert config["model_version"] in ['vae', 'cvae']
111 | LOGGER.info("Omitting encoder network - sampling...")
112 | if (config.get("pix2pix_zbatchnorm", False) and
113 | config["model_version"] == 'portray'):
114 | assert config["batch_size"] > 1
115 | layers = []
116 | if config["model_version"] == 'cvae':
117 | # Build y encoder.
118 | y_layers = []
119 | # y_encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
120 | with tf.variable_scope("y_encoder_1"):
121 | if config["cvae_downscale_y"]:
122 | output = tf.image.resize_images(
123 | conditioning,
124 | [conditioning.get_shape().as_list()[1] // 2,
125 | conditioning.get_shape().as_list()[2] // 2],
126 | method=tf.image.ResizeMethod.BILINEAR)
127 | else:
128 | output = conv(conditioning, config["ngf"], 2, full_graph)
129 | y_layers.append(output)
130 | layer_specs = [
131 | config["ngf"] * 2, # y_encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
132 | config["ngf"] * 4, # y_encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
133 | config["ngf"] * 8, # y_encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
134 | config["ngf"] * 8, # y_encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
135 | config["ngf"] * 8, # y_encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
136 | config["ngf"] * 8, # y_encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
137 | config["ngf"] * 8, # y_encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
138 | ]
139 | for ch_idx, out_channels in enumerate(layer_specs):
140 | with tf.variable_scope("y_encoder_%d" % (len(y_layers) + 1)):
141 | if config["cvae_downscale_y"]:
142 | output = tf.image.resize_images(
143 | y_layers[-1],
144 | [y_layers[-1].get_shape().as_list()[1] // 2,
145 | y_layers[-1].get_shape().as_list()[2] // 2],
146 | method=tf.image.ResizeMethod.BILINEAR)
147 | else:
148 | rectified = lrelu(y_layers[-1], 0.2)
149 | # [batch, in_height, in_width, in_channels] =>
150 | # [batch, in_height/2, in_width/2, out_channels]
151 | convolved = conv(rectified, out_channels, 2, full_graph)
152 | if ch_idx != len(layer_specs) - 1:
153 | output = batchnorm(convolved, full_graph, is_training)
154 | else:
155 | output = convolved
156 | y_layers.append(output)
157 | y = y_layers[-1]
158 | else:
159 | y = tf.zeros((1,), dtype=tf.float32)
160 |
161 | if generator_inputs is not None:
162 | # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
163 | with tf.variable_scope("encoder_1"):
164 | output = conv(generator_inputs, config["ngf"], 2, full_graph)
165 | layers.append(output)
166 |
167 | layer_specs = [
168 | config["ngf"] * 2, # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
169 | config["ngf"] * 4, # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
170 | config["ngf"] * 8, # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
171 | config["ngf"] * 8, # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
172 | config["ngf"] * 8, # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
173 | config["ngf"] * 8, # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
174 | ]
175 | if (config['model_version'] not in ['vae', 'cvae'] or
176 | config['cvae_nosampling']):
177 | layer_specs.append(config["ngf"] * 8)
178 | # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
179 |
180 | for ch_idx, out_channels in enumerate(layer_specs):
181 | with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
182 | if config['model_version'] == 'cvae':
183 | if not config['cvae_noconn'] or ch_idx == 0:
184 | input = tf.concat([layers[-1],
185 | y_layers[len(layers) - 1]],
186 | axis=3)
187 | else:
188 | input = layers[-1]
189 | else:
190 | input = layers[-1]
191 | rectified = lrelu(input, 0.2)
192 | # [batch, in_height, in_width, in_channels] =>
193 | # [batch, in_height/2, in_width/2, out_channels]
194 | convolved = conv(rectified, out_channels, 2, full_graph)
195 | if (config['model_version'] in ['vae', 'cvae'] or
196 | ch_idx != len(layer_specs) - 1 or
197 | (ch_idx == len(layer_specs) - 1 and
198 | config.get("pix2pix_zbatchnorm", False) and
199 | config["model_version"] == 'portray') or
200 | config['cvae_noconn']):
201 | output = batchnorm(convolved, full_graph, is_training)
202 | else:
203 | output = convolved
204 | layers.append(output)
205 |
206 | if (config['model_version'] in ["vae", 'cvae']
207 | and not config['cvae_nosampling']):
208 | # VAE infrastructure.
209 | with tf.variable_scope("vae_encoder"):
210 | with tf.variable_scope("z_mean"):
211 | if config['model_version'] == 'cvae':
212 | input = tf.concat([layers[-1],
213 | y_layers[len(layers) - 1]],
214 | axis=3)
215 | else:
216 | input = layers[-1]
217 | weights = get_or_load_variable(
218 | full_graph,
219 | "weights", [np.prod(input.get_shape().as_list()[1:]),
220 | config["nz"]],
221 | dtype=tf.float32,
222 | initializer=tf.random_normal_initializer(0, 0.02))
223 | biases = get_or_load_variable(
224 | full_graph, "biases", [config["nz"], ],
225 | dtype=tf.float32)
226 | z_mean = tf.add(
227 | tf.matmul(tf.reshape(input,
228 | (config["batch_size"], -1)),
229 | weights),
230 | biases)
231 | with tf.variable_scope("z_log_sigma_sq"):
232 | weights = get_or_load_variable(
233 | full_graph, "weights", [
234 | np.prod(input.get_shape().as_list()[1:]),
235 | config["nz"]],
236 | dtype=tf.float32,
237 | initializer=tf.random_normal_initializer(0, 0.02))
238 | biases = get_or_load_variable(full_graph,
239 | "biases", [config["nz"], ],
240 | dtype=tf.float32)
241 | if config['model_version'] == 'cvae':
242 | input = tf.concat([layers[-1],
243 | y_layers[len(layers) - 1]],
244 | axis=3)
245 | else:
246 | input = layers[-1]
247 | # Save the 0.5, should be learned.
248 | z_log_sigma_sq = tf.add(
249 | tf.matmul(
250 | tf.reshape(input, (config["batch_size"], -1)),
251 | weights),
252 | biases)
253 |
254 | if (config['model_version'] in ['vae', 'cvae']
255 | and not config["cvae_nosampling"]):
256 | with tf.variable_scope("vae_latent_representation"):
257 | if rwgrid is None:
258 | eps = tf.random_normal((config["batch_size"],
259 | config["nz"]), 0, 1,
260 | dtype=tf.float32)
261 | if generator_inputs is not None:
262 | # z = mu + sigma*epsilon
263 | z = tf.add(z_mean,
264 | tf.multiply(tf.sqrt(tf.exp(z_log_sigma_sq)), eps))
265 | else:
266 | z = eps
267 | z_mean = tf.constant(np.zeros((config["nz"],),
268 | dtype=np.float32),
269 | dtype=tf.float32)
270 | z_log_sigma_sq = tf.constant(np.ones((config["nz"],),
271 | dtype=np.float32),
272 | dtype=tf.float32)
273 | z = z[:, None, None, :]
274 | else:
275 | z = tf.py_func(rwgrid.sample, [], tf.float32)
276 | z_mean = tf.constant(np.zeros((config["nz"],),
277 | dtype=np.float32),
278 | dtype=tf.float32)
279 | z_log_sigma_sq = tf.constant(np.ones((config["nz"],),
280 | dtype=np.float32),
281 | dtype=tf.float32)
282 | z.set_shape([config["batch_size"], 1, 1, config["nz"]])
283 | layers.append(z)
284 | else:
285 | if generator_inputs is None:
286 | raise Exception(
287 | "Sampling required for this model configuration!"
288 | "Model must be VAE or CVAE and cvae_nosampling may not be "
289 | "set!")
290 | z = layers[-1]
291 | z_mean = tf.constant(np.zeros((z.get_shape().as_list()[3],),
292 | dtype=np.float32),
293 | dtype=tf.float32)
294 | z_log_sigma_sq = tf.constant(np.ones((z.get_shape().as_list()[3],),
295 | dtype=np.float32),
296 | dtype=tf.float32)
297 |
298 | layer_specs = [
299 | (config["ngf"] * 8, 0.5), # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
300 | (config["ngf"] * 8, 0.5), # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
301 | (config["ngf"] * 8, 0.5), # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
302 | (config["ngf"] * 8, 0.0), # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
303 | (config["ngf"] * 4, 0.0), # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
304 | (config["ngf"] * 2, 0.0), # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
305 | (config["ngf"], 0.0), # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
306 | ]
307 | if config["model_version"] in ['vae', 'cvae']:
308 | for spec_idx in [0, 1, 2]:
309 | layer_specs[spec_idx] = (config["ngf"] * 8, 0.)
310 |
311 | if generator_inputs is None:
312 | assert config["model_version"] in ['vae', 'cvae']
313 | num_encoder_layers = 8
314 | else:
315 | num_encoder_layers = len(layers)
316 | for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
317 | skip_layer = num_encoder_layers - decoder_layer - 1
318 | with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
319 | if config["model_version"] == 'cvae' and decoder_layer == 0:
320 | layers.append(tf.concat([y, z], axis=3))
321 | if config["model_version"] == 'portray' and decoder_layer != 0:
322 | input = tf.concat([layers[-1], layers[skip_layer]], axis=3)
323 | else:
324 | input = layers[-1]
325 | if (decoder_layer == 0 and
326 | config["model_version"] in ['vae', 'cvae']):
327 | # Don't use a ReLU on the latent encoding.
328 | rectified = input
329 | else:
330 | rectified = tf.nn.relu(input)
331 | # [batch, in_height, in_width, in_channels] =>
332 | # [batch, in_height*2, in_width*2, out_channels]
333 | rs = rectified.get_shape().as_list()
334 | if rs[0] is None:
335 | rectified.set_shape([config["batch_size"],
336 | rs[1], rs[2], rs[3]])
337 | output = deconv(rectified, out_channels, full_graph)
338 | output = batchnorm(output, full_graph, is_training)
339 | if dropout > 0.0:
340 | output = tf.nn.dropout(output, keep_prob=1 - dropout)
341 | layers.append(output)
342 |
343 | # decoder_1: [batch, 128, 128, ngf * 2] =>
344 | # [batch, 256, 256, generator_outputs_channels]
345 | with tf.variable_scope("decoder_1"):
346 | input = layers[-1]
347 | rectified = tf.nn.relu(input)
348 | output = deconv(rectified, generator_outputs_channels, full_graph)
349 | unnormalized_output = output
350 | if (config["input_as_class"] and
351 | config["model_version"] in ['vae', 'cvae']):
352 | output = tf.sigmoid(output) - 0.5
353 | else:
354 | output = tf.tanh(output)
355 | layers.append(output)
356 |
357 | return layers[-1], z_mean, z_log_sigma_sq, z, unnormalized_output, y
358 |
359 |
360 | def create_discriminator(discrim_inputs, discrim_targets, config, full_graph,
361 | is_training):
362 | n_layers = 3
363 | if discrim_inputs is None or discrim_targets is None:
364 | LOGGER.info("Omitting discriminator, no inputs or targets provided.")
365 | return tf.constant(np.zeros((config["batch_size"], 30, 30, 1),
366 | dtype=np.float32),
367 | dtype=tf.float32)
368 | layers = []
369 |
370 | # 2x [batch, height, width, in_channels] =>
371 | # [batch, height, width, in_channels * 2]
372 | input = tf.concat([discrim_inputs, discrim_targets], axis=3)
373 |
374 | # layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf]
375 | with tf.variable_scope("layer_1"):
376 | convolved = conv(input, config["ndf"], 2, full_graph)
377 | rectified = lrelu(convolved, 0.2)
378 | layers.append(rectified)
379 |
380 | # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2]
381 | # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4]
382 | # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8]
383 | for i in range(n_layers):
384 | with tf.variable_scope("layer_%d" % (len(layers) + 1)):
385 | out_channels = config["ndf"] * min(2**(i+1), 8)
386 | stride = 1 if i == n_layers - 1 else 2 # last layer has stride 1
387 | convolved = conv(layers[-1], out_channels, stride, full_graph)
388 | normalized = batchnorm(convolved, full_graph, is_training)
389 | rectified = lrelu(normalized, 0.2)
390 | layers.append(rectified)
391 |
392 | # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1]
393 | with tf.variable_scope("layer_%d" % (len(layers) + 1)):
394 | convolved = conv(rectified, 1, 1, full_graph)
395 | unnormalized_output = convolved
396 | output = tf.sigmoid(convolved)
397 | layers.append(output)
398 |
399 | return layers[-1], unnormalized_output
400 |
401 |
402 | # Interface function. #########################################################
403 | def create_model(mode, examples, config, load_info):
404 | LOGGER.info("Building model...")
405 | if (config["model_version"] == 'portray' and
406 | config.get("portray_additional_conditioning", False)):
407 | inputs = tf.concat([examples.inputs, examples.conditioning], axis=3)
408 | else:
409 | inputs = examples.inputs
410 | conditioning = examples.conditioning
411 | targets = examples.targets
412 | if config["model_version"] == 'cvae':
413 | discrim_ref = conditioning
414 | else:
415 | discrim_ref = inputs
416 |
417 | with tf.variable_scope("generator"):
418 | if mode == 'sample' and config["model_version"] in ['cvae', 'vae']:
419 | if config["input_as_class"]:
420 | out_channels = 22
421 | else:
422 | out_channels = 3
423 | else:
424 | out_channels = int(targets.get_shape()[-1])
425 | # We leave batchnorm always in training mode, because it gives slightly
426 | # better performance to normalize the observed distribution.
427 | outputs, z_mean, z_log_sigma_sq, z, unnormalized_outputs, y =\
428 | create_generator(inputs, out_channels, mode,
429 | config, load_info, conditioning, None,
430 | True)
431 | #mode in ['train', 'trainval'])
432 |
433 | # Create two copies of discriminator, one for real pairs and one for fake
434 | # pairs. Both share the same underlying variables.
435 | with tf.name_scope("real_discriminator"):
436 | with tf.variable_scope("discriminator"):
437 | if (config["gan_weight"] != 0. and
438 | mode not in ['sample', 'transform']):
439 | # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
440 | predict_real, predict_real_unnorm = create_discriminator(
441 | discrim_ref, targets, config, load_info, True)
442 | #mode in ['train', 'trainval'])
443 | else:
444 | predict_real = tf.constant(np.zeros((config["batch_size"],
445 | 30, 30, 1),
446 | dtype=np.float32))
447 | with tf.name_scope("fake_discriminator"):
448 | with tf.variable_scope("discriminator", reuse=True):
449 | if (config["gan_weight"] != 0. and
450 | mode not in ['sample', 'transform']):
451 | # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
452 | predict_fake, predict_fake_unnorm = create_discriminator(
453 | discrim_ref, outputs, config, load_info, True)
454 | #mode in ['train', 'trainval'])
455 | else:
456 | predict_fake = tf.constant(np.ones((config["batch_size"],
457 | 30, 30, 1),
458 | dtype=np.float32))
459 | # Loss. ###################################
460 | if config["model_version"] in ['vae', 'cvae']:
461 | reduction_op = tf.reduce_sum
462 | else:
463 | reduction_op = tf.reduce_mean
464 | with tf.name_scope("discriminator_loss"):
465 | # minimizing -tf.log will try to get inputs to 1
466 | # predict_real => 1
467 | # predict_fake => 0
468 | if config["gan_weight"] != 0. and mode not in ['sample', 'transform']:
469 | discrim_loss = tf.reduce_mean(
470 | reduction_op(-(tf.log(predict_real + EPS) +
471 | tf.log(1 - predict_fake + EPS)),
472 | axis=[1, 2, 3]))
473 | else:
474 | discrim_loss = tf.constant(0, tf.float32)
475 |
476 | with tf.name_scope("generator_loss"):
477 | if targets is not None:
478 | # predict_fake => 1
479 | # abs(targets - outputs) => 0
480 | gen_loss_GAN = tf.reduce_mean(
481 | reduction_op(-tf.log(predict_fake + EPS), axis=[1, 2, 3]))
482 | if (config["input_as_class"] and
483 | config["model_version"] in ['vae', 'cvae']):
484 | labels = targets + .5
485 | if config["class_weights"] is not None:
486 | LOGGER.info("Using class weights (unmentioned classes "
487 | "have weight one): %s.",
488 | config["class_weights"])
489 | # Determine loss weight matrix.
490 | ones = tf.constant(
491 | np.ones((labels.get_shape().as_list()[:3]),
492 | dtype=np.float32))
493 | cwm = tf.identity(ones)
494 | for cw_tuple in config["class_weights"].items():
495 | cwm = tf.where(
496 | tf.equal(labels[:, :, :, cw_tuple[0]], 1.),
497 | ones * cw_tuple[1], # if condition is True
498 | cwm # if condition is False
499 | )
500 | if not config.get("class_weights_only_recall", True):
501 | # Assuming the scaling is equal for all classes
502 | # (except the factor 1 ones) this works.
503 | cwm = tf.where(
504 | tf.equal(tf.argmax(unnormalized_outputs,
505 | axis=3),
506 | cw_tuple[0]),
507 | ones * cw_tuple[1],
508 | cwm
509 | )
510 | if config["iac_softmax"]:
511 | gen_loss_recon = tf.nn.softmax_cross_entropy_with_logits(
512 | logits=unnormalized_outputs, labels=labels)
513 | else:
514 | gen_loss_recon = tf.nn.sigmoid_cross_entropy_with_logits(
515 | logits=unnormalized_outputs, labels=labels)
516 | if config["class_weights"] is not None:
517 | cwm = tf.stack([cwm] * 22, axis=3)
518 | if config["class_weights"] is not None:
519 | # Apply.
520 | gen_loss_recon *= cwm
521 | gen_loss_recon = tf.reduce_mean(
522 | reduction_op(gen_loss_recon, axis=[1, 2]))
523 | else:
524 | # L1.
525 | gen_loss_recon = tf.reduce_mean(
526 | reduction_op(tf.abs(targets - outputs),
527 | axis=[1, 2, 3]))
528 | if (config["model_version"] in ['vae', 'cvae'] and
529 | config["latent_weight"] != 0. and
530 | not config["cvae_nosampling"]):
531 | gen_loss_latent = tf.reduce_mean(
532 | -0.5 * tf.reduce_sum(1 + z_log_sigma_sq
533 | - tf.square(z_mean)
534 | - tf.exp(z_log_sigma_sq), [1, ]))
535 | else:
536 | gen_loss_latent = tf.constant(0, tf.float32)
537 | gen_loss = (gen_loss_GAN * config["gan_weight"] +
538 | gen_loss_recon * config["recon_weight"] +
539 | gen_loss_latent * config["latent_weight"])
540 | else:
541 | gen_loss_GAN = tf.constant(0, tf.float32)
542 | gen_loss_recon = tf.constant(0, tf.float32)
543 | gen_loss_latent = tf.constant(0, tf.float32)
544 | gen_loss = tf.constant(0, tf.float32)
545 |
546 | with tf.variable_scope("global_step"):
547 | global_step = get_or_load_variable(load_info,
548 | "global_step",
549 | (1,),
550 | dtype=tf.int64,
551 | initializer=tf.constant_initializer(
552 | value=0,
553 | dtype=tf.int64),
554 | trainable=False,
555 | sloppy=True)
556 | incr_global_step = tf.assign(global_step, global_step+1)
557 |
558 | if targets is not None:
559 | # For batchnorm running statistics updates.
560 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
561 | with tf.name_scope("discriminator_train"):
562 | with tf.control_dependencies(update_ops):
563 | if config["gan_weight"] != 0.:
564 | discrim_tvars = [var for var in tf.trainable_variables()
565 | if var.name.startswith("discriminator")]
566 | discrim_optim = tf.train.AdamOptimizer(config["lr"],
567 | config["beta1"])
568 | discrim_train = discrim_optim.minimize(
569 | discrim_loss,
570 | var_list=discrim_tvars)
571 | else:
572 | dummyvar = tf.Variable(0)
573 | discrim_train = tf.assign(dummyvar, 0)
574 | with tf.name_scope("generator_train"):
575 | with tf.control_dependencies([discrim_train]):
576 | gen_tvars = [var for var in tf.trainable_variables()
577 | if var.name.startswith("generator")]
578 | gen_optim = tf.train.AdamOptimizer(config["lr"],
579 | config["beta1"])
580 | gen_train = gen_optim.minimize(gen_loss, var_list=gen_tvars)
581 | train = tf.group(incr_global_step, gen_train)
582 | display_discrim_loss = discrim_loss
583 | display_gen_loss_GAN = (gen_loss_GAN *
584 | config["gan_weight"])
585 | display_gen_loss_recon = (gen_loss_recon *
586 | config["recon_weight"])
587 | display_gen_loss_latent = (gen_loss_latent *
588 | config["latent_weight"])
589 | if (config["model_version"] in ['vae', 'cvae'] and
590 | config["input_as_class"]):
591 | gen_accuracy = tf.reduce_mean(
592 | tf.cast(
593 | tf.equal(tf.argmax(unnormalized_outputs, axis=3),
594 | tf.argmax(targets, axis=3)), tf.float32))
595 | else:
596 | gen_accuracy = tf.constant(0., dtype=tf.float32)
597 | else:
598 | train = incr_global_step
599 | display_discrim_loss = discrim_loss
600 | display_gen_loss_GAN = gen_loss_GAN
601 | display_gen_loss_recon = gen_loss_recon
602 | display_gen_loss_latent = gen_loss_latent
603 | gen_accuracy = tf.constant(0., dtype=tf.float32)
604 | LOGGER.info("Model complete.")
605 | return Model(
606 | z_mean=z_mean,
607 | z_log_sigma_sq=z_log_sigma_sq,
608 | z=z,
609 | y=y,
610 | predict_real=predict_real,
611 | predict_fake=predict_fake,
612 | discrim_loss=display_discrim_loss,
613 | gen_loss_GAN=display_gen_loss_GAN,
614 | gen_loss_recon=display_gen_loss_recon,
615 | gen_loss_latent=display_gen_loss_latent,
616 | gen_accuracy=gen_accuracy,
617 | outputs=outputs,
618 | train=train,
619 | global_step=global_step,
620 | )
621 |
--------------------------------------------------------------------------------
/generation/experiments/config/template/options.py:
--------------------------------------------------------------------------------
1 | config = {
2 | # Model. #######################
3 | # Supported are: 'portray', 'vae', 'cvae'.
4 | "model_version": 'cvae',
5 | # What modes are available for this model.
6 | "supp_modes": ['train', 'val', 'trainval', 'test'], # 'sample'
7 | # Number of Generator Filters in the first layer.
8 | # Increases gradually to max. 8 * ngf.
9 | "ngf": 64, # 64
10 |     # Number of Discriminator Filters in the first layer.
11 | # Increases gradually to max. 8 * ndf.
12 | "ndf": 64, # 64
13 | # Number of latent space dimensions (z) for the vae.
14 | "nz": 512,
15 | # If "input_as_class", use softmax xentropy. Otherwise sigmoid xentropy.
16 | "iac_softmax": True,
17 | # No interconnections inside of the model.
18 | "cvae_noconn": False,
19 | # Omit variational sampling, making the CVAE a CAE.
20 | "cvae_nosampling": True,
21 | # How many samples to draw per image for CVAE.
22 | "cvae_nsamples_per_image": 5,
23 | # Instead of building a y encoder, just downscale y.
24 | "cvae_downscale_y": False,
25 | # Use additional color conditioning as input.
26 | "portray_additional_conditioning": False,
27 |
28 | # Data and preprocessing. ########################
29 | # Whether to use on-line data augmentation by flipping.
30 | "flip": False,
31 | # Whether to treat input data as classification or image.
32 | "input_as_class": False,
33 | # Whether to treat conditioning data as classification or image.
34 | "conditioning_as_class": False,
35 | # If input_as_class is True, it is possible to give class weights here.
36 | # (If undesired, set to `None`.) This is a dictionary from class=>weight.
37 | "class_weights": None,
38 | # If True, weights are applied only for the recall computation. Otherwise,
39 | # an approximation is used for precision as well, assuming that all weights
40 | # apart from the 1.0 weights are equal.
41 | "class_weights_only_recall": True,
42 | # Scale the images up to this size before cropping
43 | # them to `crop_size`.
44 | "scale_size": 286,
45 | "crop_size": 256,
46 | # Type of image content.
47 | "dset_type": "people",
48 | "dset_suffix": "full",
49 |
50 | # Optimizer ####################
51 | "max_epochs": 150,
52 | "max_steps": None,
53 | "batch_size": 30,
54 | "lr": 0.0002, # Adam lr.
55 | "beta1": 0.5, # Adam beta1 param.
56 | "gan_weight": 0.0,
57 | "recon_weight": 100.0,
58 | "latent_weight": 0.0,
59 |
60 | # Infrastructure. ##############
61 | # Save summaries after every x seconds.
62 | "summary_freq": 30,
63 | # Create traces after every x batches.
64 | "trace_freq": 0,
65 | # After every x epochs, save model.
66 | "save_freq": 10,
67 | # Keep x saves.
68 | "kept_saves": 10,
69 | # After every x epochs, render images.
70 | "display_freq": 10,
71 | # Random seed to use.
72 | "seed": 538728914,
73 | }
74 |
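These defaults can be overridden at run time: run.py parses the --custom_options string with ast.literal_eval and merges the result into the configuration dict via update(). A minimal sketch of that merge (the option values here are only examples):

import ast

config = {"batch_size": 30, "lr": 0.0002, "max_epochs": 150}
# String as it would arrive from `--custom_options "{'batch_size': 10, 'lr': 1e-4}"`.
custom_options = "{'batch_size': 10, 'lr': 1e-4}"
config.update(ast.literal_eval(custom_options))
assert config["batch_size"] == 10 and config["max_epochs"] == 150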
--------------------------------------------------------------------------------
/generation/experiments/config/template/preprocessing.py:
--------------------------------------------------------------------------------
1 | """Set up the preprocessing pipeline."""
2 | import collections
3 | import logging
4 |
5 | import tensorflow as tf
6 | from gp_tools.tf import one_hot
7 |
8 | LOGGER = logging.getLogger(__name__)
9 |
10 |
11 | #: Define what an example looks like for this model family.
12 | Examples = collections.namedtuple(
13 | "Examples",
14 | "paths, inputs, targets, conditioning")
15 |
16 |
17 | def transform(image, mode, config, nearest=False):
18 | """Apply all relevant image transformations."""
19 | r = image
20 | if nearest:
21 | rm = tf.image.ResizeMethod.NEAREST_NEIGHBOR
22 | else:
23 | rm = tf.image.ResizeMethod.BILINEAR
24 | if config["flip"] and mode in ['train', 'trainval']:
25 | r = tf.image.random_flip_left_right(r, seed=config['seed'])
26 | if (config["scale_size"] != config["crop_size"] and
27 | mode in ['train', 'trainval']):
28 | r = tf.image.resize_images(r,
29 | [config["scale_size"],
30 | config["scale_size"]],
31 | method=rm)
32 | if config["scale_size"] > config["crop_size"]:
33 | offset = tf.cast(tf.floor(
34 | tf.random_uniform([2],
35 | 0,
36 | config["scale_size"] -
37 | config["crop_size"] + 1,
38 | seed=config["seed"])),
39 | dtype=tf.int32)
40 | r = r[offset[0]:offset[0] + config["crop_size"],
41 | offset[1]:offset[1] + config["crop_size"],
42 | :]
43 | elif config["scale_size"] < config["crop_size"]:
44 | raise Exception("scale size cannot be less than crop size")
45 | else:
46 | if (image.get_shape().as_list()[0] != config["crop_size"] or
47 | image.get_shape().as_list()[1] != config["crop_size"]):
48 | r = tf.image.resize_images(r,
49 | [config["crop_size"],
50 | config["crop_size"]],
51 | method=rm)
52 | return r
53 |
54 |
55 | def prepare(load_dict, mode, config):
56 | conditioning = None
57 | paths = load_dict['original_filename']
58 | if config["model_version"] in ['vae', 'cvae']:
59 | # Ignore the image. Only care for the labels.
60 | if config["input_as_class"]:
61 | labels = load_dict['labels']
62 | labels.set_shape((config["scale_size"], config["scale_size"], 3))
63 | labels = labels[:, :, 0]
64 | labels = tf.cast(one_hot(labels, 22), tf.float32) - 0.5
65 | labels = transform(labels, mode, config, True)
66 | labels.set_shape((config["crop_size"], config["crop_size"], 22))
67 | else:
68 | labels = load_dict['label_vis']
69 | labels.set_shape((config["scale_size"], config["scale_size"], 3))
70 | labels = tf.cast(labels, tf.float32)
71 | labels = labels * 2. / 255. - 1.
72 | labels = transform(labels, mode, config)
73 | labels.set_shape((config["crop_size"], config["crop_size"], 3))
74 | inputs = labels
75 | # Conditioning?
76 | if config["model_version"] == 'cvae':
77 | if config["conditioning_as_class"]:
78 | conditioning = load_dict["bodysegments"]
79 | conditioning.set_shape((config["scale_size"],
80 | config["scale_size"], 3))
81 | conditioning = conditioning[:, :, 0]
82 | conditioning = tf.cast(one_hot(conditioning, 7),
83 | tf.float32) - 0.5
84 | conditioning = transform(conditioning, mode, config, True)
85 | conditioning.set_shape((config["crop_size"],
86 | config["crop_size"], 7))
87 | else:
88 | conditioning = load_dict["bodysegments_vis"]
89 | conditioning.set_shape((config["scale_size"],
90 | config["scale_size"], 3))
91 | conditioning = (tf.cast(conditioning, tf.float32) *
92 | 2. / 255. - 1.)
93 | conditioning = transform(conditioning, mode, config)
94 | conditioning.set_shape((config["crop_size"],
95 | config["crop_size"], 3))
96 | else:
97 | if config["input_as_class"]:
98 | inputs = load_dict['labels']
99 | inputs.set_shape((config["scale_size"], config["scale_size"], 3))
100 | inputs = inputs[:, :, 0]
101 | inputs = tf.cast(one_hot(inputs, 22), tf.float32) - 0.5
102 | inputs = transform(inputs, mode, config, True)
103 | inputs.set_shape((config["crop_size"], config["crop_size"], 22))
104 | else:
105 | inputs = load_dict['label_vis']
106 | inputs.set_shape((config["scale_size"], config["scale_size"], 3))
107 | inputs = tf.cast(inputs, tf.float32) * 2. / 255. - 1.
108 | inputs = transform(inputs, mode, config)
109 | inputs.set_shape((config["crop_size"], config["crop_size"], 3))
110 | labels = load_dict['image']
111 | labels.set_shape((config["scale_size"], config["scale_size"], 3))
112 | labels = tf.cast(labels, tf.float32) * 2. / 255. - 1.
113 | labels = transform(labels, mode, config)
114 | labels.set_shape((config["crop_size"], config["crop_size"], 3))
115 | if config.get("portray_additional_conditioning", None):
116 | conditioning = load_dict[config['portray_additional_conditioning']]
117 | conditioning.set_shape((config["scale_size"],
118 | config["scale_size"], 3))
119 | conditioning = (tf.cast(conditioning, tf.float32) *
120 | 2. / 255. - 1.)
121 | conditioning = transform(conditioning, mode, config)
122 | conditioning.set_shape((config["crop_size"],
123 | config["crop_size"], 3))
124 | inputs = tf.cast(inputs, tf.float32)
125 | labels = tf.cast(labels, tf.float32)
126 | if conditioning is not None:
127 | conditioning = tf.cast(conditioning, tf.float32)
128 | return paths, inputs, labels, conditioning
129 | else:
130 | return paths, inputs, labels
131 |
132 |
133 | def preprocess(nsamples, loader_list, mode, config):
134 | if mode == "sample" and config["model_version"] != 'cvae':
135 | return Examples(
136 | paths=None,
137 | inputs=None,
138 | targets=None,
139 | conditioning=None
140 | )
141 | """
142 | if mode == 'sample' and config["model_version"] == 'cvae':
143 | LOGGER.info("Producing %d samples per image.",
144 | config["cvae_nsamples_per_image"])
145 | mult_input_paths = []
146 | for fp in input_paths:
147 | mult_input_paths.extend([fp] * config["cvae_nsamples_per_image"])
148 | input_paths = mult_input_paths
149 | """
150 | min_after_dequeue = 200
151 | capacity = min_after_dequeue + 4 * config["batch_size"]
152 | with tf.name_scope("preprocess"):
153 | example_list = [prepare(load_dict, mode, config)
154 | for load_dict in loader_list]
155 | if mode in ['train', 'trainval']:
156 | prep_tuple = tf.train.shuffle_batch_join(
157 | example_list,
158 | batch_size=config["batch_size"],
159 | capacity=capacity,
160 | min_after_dequeue=min_after_dequeue)
161 | else:
162 | real_bs = config["batch_size"]
163 | if mode == 'sample' and config["model_version"] == 'cvae':
164 | assert config["batch_size"] % config["cvae_nsamples_per_image"] == 0, (
165 | "cvae_nsamples_per_image must be a divisor of batch_size!")
166 | real_bs = config["batch_size"] // config["cvae_nsamples_per_image"]
167 | prep_tuple = tf.train.batch(
168 | example_list[0],
169 | batch_size=real_bs,
170 | capacity=capacity,
171 | num_threads=1, # For determinism.
172 | allow_smaller_final_batch=False)
173 | paths = prep_tuple[0]
174 | inputs = prep_tuple[1]
175 | targets = prep_tuple[2]
176 | if len(prep_tuple) > 3:
177 | conditioning = prep_tuple[3]
178 | else:
179 | conditioning = None
180 | if mode == 'sample':
181 | assert config["model_version"] in ['cvae', 'vae']
182 | targets = None
183 | inputs = None
184 | if config["model_version"] == 'vae':
185 | paths = None
186 | else:
187 | path_batches = [tf.string_join([tf.identity(paths), tf.constant("_s" + str(batch_idx))])
188 | for batch_idx in range(config["cvae_nsamples_per_image"])]
189 | paths = tf.concat(path_batches, axis=0)
190 | conditioning = tf.concat([tf.identity(conditioning)
191 | for _ in range(config["cvae_nsamples_per_image"])],
192 | axis=0)
193 | return Examples(
194 | paths=paths,
195 | inputs=inputs,
196 | targets=targets,
197 | conditioning=conditioning,
198 | )
199 |
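The preprocessing above follows two value conventions: class maps are one-hot encoded and shifted to [-0.5, 0.5], while RGB inputs are rescaled from [0, 255] to [-1, 1]. A small numpy sketch of the same mappings (shapes and the 22-class count taken from this template):

import numpy as np

labels = np.random.randint(0, 22, size=(256, 256))      # 22 clothing classes.
one_hot = np.eye(22, dtype=np.float32)[labels] - 0.5     # -> values in [-0.5, 0.5]
image = np.random.randint(0, 256, size=(256, 256, 3)).astype(np.float32)
image = image * 2. / 255. - 1.                            # -> values in [-1, 1]
assert one_hot.shape == (256, 256, 22)
assert image.min() >= -1. and image.max() <= 1.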
--------------------------------------------------------------------------------
/generation/experiments/config/template/summaries.py:
--------------------------------------------------------------------------------
1 | """Summarize the networks actions."""
2 | import os.path as path
3 | import logging
4 | import clustertools.visualization as vs
5 | import tensorflow as tf
6 | import numpy as np
7 | import cv2
8 | import sys
9 | sys.path.insert(0, path.join(path.dirname(__file__), '..', '..', '..', '..'))
10 | from config import CMAP # noqa: E402
11 |
12 |
13 | LOGGER = logging.getLogger(__name__)
14 |
15 |
16 | def postprocess_colormap(cls, postprocess=True):
17 | """Create a colormap out of the classes and postprocess the face."""
18 | batch = vs.apply_colormap(cls, vmin=0, vmax=21, cmap=CMAP)
19 | cmap = vs.apply_colormap(np.array(range(22), dtype='uint8'),
20 | vmin=0, vmax=21, cmap=CMAP)
21 | COLSET = cmap[18:22]
22 | FCOL = cmap[11]
23 | if postprocess:
24 | kernel = np.ones((2, 2), dtype=np.uint8)
25 | for im in batch:
26 | for col in COLSET:
27 | # Extract the map of the matching color.
28 | colmap = np.all(im == col, axis=2).astype(np.uint8)
29 | # Erode.
30 | while np.sum(colmap) > 10:
31 | colmap = cv2.erode(colmap, kernel)
32 | # Prepare the original map for remapping.
33 | im[np.all(im == col, axis=2)] = FCOL
34 | # Backproject.
35 | im[colmap == 1] = col
36 | return batch[:, :, :, :3]
37 |
38 |
39 | def deprocess(config, image, argmax=False, postprocess=True):
40 | if argmax:
41 | def cfunc(x): return postprocess_colormap(x, postprocess)
42 | return tf.py_func(cfunc, [tf.argmax(image, 3)], tf.uint8)
43 | else:
44 | return tf.image.convert_image_dtype((image + 1) / 2, dtype=tf.uint8,
45 | saturate=True)
46 |
47 |
48 | def create_summaries(mode, examples, model, config):
49 | LOGGER.info("Setting up summaries and fetches...")
50 | with tf.variable_scope("deprocessing"):
51 | # Deprocess images. #######################################################
52 | if mode != 'sample':
53 | # Inputs.
54 | with tf.name_scope("deprocess_inputs"):
55 | deprocessed_inputs = deprocess(config,
56 | examples.inputs,
57 | config["input_as_class"])
58 | # Targets.
59 | with tf.name_scope("deprocess_targets"):
60 | deprocessed_targets = deprocess(
61 | config,
62 | examples.targets,
63 | (config["model_version"] in ['vae', 'cvae'] and
64 | config["input_as_class"]))
65 | else:
66 | deprocessed_inputs = None
67 | deprocessed_targets = None
68 | if mode != 'sample' or config["model_version"] == 'cvae':
69 | paths = examples.paths
70 | else:
71 | paths = None
72 | if config["model_version"] == 'cvae':
73 | with tf.name_scope("deprocessed_conditioning"):
74 | deprocessed_conditioning = deprocess(
75 | config,
76 | examples.conditioning,
77 | config["conditioning_as_class"])
78 | elif (config["model_version"] == 'portray' and
79 | examples.conditioning is not None):
80 | deprocessed_conditioning = deprocess(config,
81 | examples.conditioning)
82 | else:
83 | deprocessed_conditioning = None
84 | with tf.name_scope("deprocess_outputs"):
85 | deprocessed_outputs = deprocess(
86 | config,
87 | model.outputs,
88 | (config["input_as_class"] and
89 | config["model_version"] in ['vae', 'cvae']))
90 | if (config["input_as_class"] and
91 | config["model_version"] in ['vae', 'cvae']):
92 | with tf.name_scope("deprocess_unpostprocessed_outputs"):
93 | deprocessed_unpostprocessed_outputs = deprocess(
94 | config, model.outputs, True, False)
95 | else:
96 | deprocessed_unpostprocessed_outputs = None
97 | # Encode the images. ######################################################
98 | display_fetches = dict()
99 | with tf.name_scope("encode_images"):
100 | for name, res in [
101 | ('inputs', deprocessed_inputs),
102 | ('conditioning', deprocessed_conditioning),
103 | ('outputs', deprocessed_outputs),
104 | ('unpostprocessed_outputs',
105 | deprocessed_unpostprocessed_outputs),
106 | ('targets', deprocessed_targets),
107 | ]:
108 | if res is not None:
109 | display_fetches[name] = tf.map_fn(tf.image.encode_png,
110 | res,
111 | dtype=tf.string,
112 | name=name+'_pngs')
113 | if mode != 'sample':
114 | display_fetches['y'] = model.y
115 | display_fetches['z'] = model.z
116 |
117 | if mode != 'sample' or config["model_version"] == 'cvae':
118 | display_fetches['paths'] = paths
119 |
120 | # Create the summaries. ###################################################
121 | if deprocessed_inputs is not None:
122 | with tf.name_scope("inputs_summary"):
123 | tf.summary.image("inputs", deprocessed_inputs)
124 | if deprocessed_targets is not None:
125 | with tf.name_scope("targets_summary"):
126 | tf.summary.image("targets", deprocessed_targets)
127 | with tf.name_scope("outputs_summary"):
128 | tf.summary.image("outputs", deprocessed_outputs)
129 | if deprocessed_conditioning is not None:
130 | with tf.name_scope("conditioning_summary"):
131 | tf.summary.image("conditioning", deprocessed_conditioning)
132 | with tf.name_scope("predict_real_summary"):
133 | tf.summary.image("predict_real",
134 | tf.image.convert_image_dtype(model.predict_real,
135 | dtype=tf.uint8))
136 | with tf.name_scope("predict_fake_summary"):
137 | tf.summary.image("predict_fake",
138 | tf.image.convert_image_dtype(model.predict_fake,
139 | dtype=tf.uint8))
140 | tf.summary.histogram("z_mean", model.z_mean)
141 | tf.summary.histogram("z_log_sigma_sq", model.z_log_sigma_sq)
142 | tf.summary.histogram("z", model.z)
143 |
144 | if mode in ['train', 'trainval']:
145 | tf.summary.scalar("loss/discriminator", model.discrim_loss)
146 | tf.summary.scalar("loss/generator_GAN", model.gen_loss_GAN)
147 | tf.summary.scalar("loss/generator_recon", model.gen_loss_recon)
148 | tf.summary.scalar("loss/generator_latent", model.gen_loss_latent)
149 | tf.summary.scalar("loss/generator_accuracy", model.gen_accuracy)
150 | test_fetches = {}
151 | else:
152 | # These fetches will be evaluated and averaged at test time.
153 | test_fetches = {}
154 | test_fetches["loss/discriminator"] = model.discrim_loss
155 | test_fetches["loss/generator_GAN"] = model.gen_loss_GAN
156 | test_fetches["loss/generator_recon"] = model.gen_loss_recon
157 | test_fetches["loss/generator_latent"] = model.gen_loss_latent
158 | test_fetches["loss/generator_accuracy"] = model.gen_accuracy
159 |
160 | LOGGER.info("Summaries and fetches complete.")
161 | return display_fetches, test_fetches
162 |
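deprocess above uses tf.py_func to run the numpy-based colormap post-processing inside the graph. A minimal sketch of that wrapping, with a trivial stand-in for postprocess_colormap (not the real colormap):

import numpy as np
import tensorflow as tf

def to_uint8_rgb(cls):
    # Trivial stand-in: render class ids as a grayscale RGB image.
    return np.stack([cls.astype(np.uint8) * 10] * 3, axis=-1)

logits = tf.random_normal([2, 8, 8, 22])
classes = tf.argmax(logits, axis=3)
images = tf.py_func(to_uint8_rgb, [classes], tf.uint8)

with tf.Session() as sess:
    print(sess.run(images).shape)  # (2, 8, 8, 3)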
--------------------------------------------------------------------------------
/generation/experiments/config/template/write_output.py:
--------------------------------------------------------------------------------
1 | """Summaries and outputs."""
2 | import os
3 | import os.path as path
4 | from collections import OrderedDict
5 | import logging
6 | from gp_tools.write import append_index
7 |
8 |
9 | LOGGER = logging.getLogger(__name__)
10 |
11 |
12 | def save_grid(fetches, image_dir, config, rwgrid, batch):
13 | index_path = os.path.join(path.dirname(image_dir), "index.html")
14 | if os.path.exists(index_path):
15 | index = open(index_path, "a")
16 | else:
17 | index = open(index_path, "w")
18 | index.write("<html><body><table><tr>")
19 | index.write("<th>Grid</th>")
20 | for col_idx in range(rwgrid.gridspec[1]):
21 | index.write("<th>%d</th>" % col_idx)
22 | index.write("</tr>\n")
23 | outputs = fetches['outputs']
24 | for sample_idx in range(config["batch_size"]):
25 | y_pos = ((batch * config["batch_size"] + sample_idx) //
26 | rwgrid.gridspec[1])
27 | x_pos = ((batch * config["batch_size"] + sample_idx) %
28 | rwgrid.gridspec[1])
29 | if y_pos >= rwgrid.gridspec[0]:
30 | break
31 | filename = "%08d-%08d.png" % (y_pos, x_pos)
32 | out_path = os.path.join(image_dir, filename)
33 | with open(out_path, "wb") as f:
34 | f.write(outputs[sample_idx])
35 | if x_pos == 0:
36 | index.write("<tr><td>%d</td>" % (y_pos))
37 | index.write('<td><img src="images/%s"></td>' % (filename))
38 | if x_pos == rwgrid.gridspec[1] - 1:
39 | index.write("</tr>\n")
40 | return index_path
41 |
42 |
43 | warned_y = False
44 | warned_z = False
45 |
46 |
47 | def save_images(fetches, image_dir, mode, config, step=None, batch=0):
48 | global warned_y, warned_z
49 | image_dir = path.join(image_dir, 'images')
50 | if not path.exists(image_dir):
51 | os.makedirs(image_dir)
52 | row_infos = []
53 | for im_idx in range(config["batch_size"]):
54 | if step is not None:
55 | row_info = OrderedDict([('step', (str(step), 'text')), ])
56 | else:
57 | row_info = OrderedDict()
58 | if mode in ['train', 'trainval', 'val', 'test', 'transform']:
59 | in_path = fetches["paths"][im_idx]
60 | name, _ = os.path.splitext(os.path.basename(in_path))
61 | elif mode == 'sample':
62 | name = str(config["batch_size"] * batch + im_idx)
63 | if 'paths' in fetches.keys():
64 | in_path = fetches["paths"][im_idx]
65 | fname, _ = os.path.splitext(os.path.basename(in_path))
66 | name += '-' + fname
67 | if step is not None:
68 | name = str(step) + '_' + name
69 | row_info["name"] = (name, 'text')
70 | if 'inputs' in fetches.keys():
71 | row_info["inputs"] = (fetches['inputs'][im_idx], 'image')
72 | if 'conditioning' in fetches.keys():
73 | row_info["conditioning"] = (fetches["conditioning"][im_idx],
74 | 'image')
75 | if 'outputs' in fetches.keys():
76 | row_info["outputs"] = (fetches["outputs"][im_idx], 'image')
77 | if 'targets' in fetches.keys():
78 | row_info["targets"] = (fetches["targets"][im_idx], 'image')
79 | if 'y' in fetches.keys():
80 | try:
81 | row_info["y"] = (fetches["y"][im_idx], 'plain')
82 | except Exception:
83 | if not warned_y:
84 | LOGGER.warn("Insufficient info for storing y!")
85 | warned_y = True
86 | if 'z' in fetches.keys():
87 | try:
88 | row_info["z"] = (fetches["z"][im_idx], 'plain')
89 | except Exception:
90 | if not warned_z:
91 | LOGGER.warn("Insufficient info for storing z!")
92 | warned_z = True
93 | row_infos.append(row_info)
94 | LOGGER.debug("Processed image %d.",
95 | batch * config["batch_size"] + im_idx + 1)
96 | index_fp = append_index(row_infos, image_dir, mode)
97 | return index_fp
98 |
--------------------------------------------------------------------------------
/generation/generate.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e # Exit on error.
3 | if [ -z ${1+x} ]; then
4 | echo Please specify the number of people! >&2; exit 1
5 | fi
6 | npeople=$1
7 | re='^[0-9]+$'
8 | if ! [[ $1 =~ $re ]] ; then
9 | echo "Error: specify a number" >&2; exit 1
10 | fi
11 |
12 | if [ -z ${2+x} ]; then
13 | out_fp=generated
14 | else
15 | out_fp=$2
16 | fi
17 | if [ -e ${out_fp} ]; then
18 | echo "Output folder exists: ${out_fp}. Please pick a non-existing folder." >&2
19 | exit 1
20 | fi
21 |
22 | # Check environment.
23 | if [ -e tmp ]; then
24 | echo "The directory 'tmp' exists, maybe from an incomplete previous run?" >&2
25 | echo "If so, please delete it and rerun so that it can be used cleanly." >&2
26 | exit 1
27 | fi
28 | if [ -e data/people/tmp ]; then
29 | echo "The directory 'data/people/tmp' exists, maybe from an incomplete previous run?" >&2
30 | echo "If so, please delete it and rerun so that it can be used cleanly." >&2
31 | exit 1
32 | fi
33 | if [ ! -d experiments/states/LSM ]; then
34 | echo "State folder for the latent sketch module not found at " >&2
35 | echo "'experiments/states/LSM'. Either run the training (./run.py trainval experiments/config/LSM) " >&2
36 | echo "or download a pretrained model from http://gp.is.tuebingen.mpg.de." >&2
37 | exit 1
38 | fi
39 | if [ ! -d experiments/states/PM ]; then
40 | echo "State folder for the portray module not found at " >&2
41 | echo "'experiments/states/PM'. Either run the training (./run.py trainval experiments/config/PM) " >&2
42 | echo "or download a pretrained model from http://gp.is.tuebingen.mpg.de." >&2
43 | exit 1
44 | fi
45 |
46 | echo Generating $1 people...
47 | echo Sampling sketches...
48 | ./run.py sample experiments/config/LSM --out_fp tmp --n_samples ${npeople}
49 | echo Done.
50 | echo Preparing for portray module...
51 | mkdir tmp/portray_dset
52 | for sample_idx in $(seq 0 $((${npeople}-1))); do
53 | fullid=$(printf "%04d" ${sample_idx})
54 | # Simulate full dataset.
55 | cp tmp/images/${sample_idx}_outputs.png tmp/portray_dset/${fullid}_bodysegments:png.png
56 | cp tmp/images/${sample_idx}_outputs.png tmp/portray_dset/${fullid}_bodysegments_vis:png.png
57 | cp tmp/images/${sample_idx}_outputs.png tmp/portray_dset/${fullid}_image:png.png
58 | cp tmp/images/${sample_idx}_outputs.png tmp/portray_dset/${fullid}_labels:png.png
59 | cp tmp/images/${sample_idx}_outputs.png tmp/portray_dset/${fullid}_label_vis:png.png
60 | echo ${fullid}_sample.png > tmp/portray_dset/${fullid}_original_filename.txt
61 | done
62 | echo Creating archive...
63 | mkdir -p data/people/tmp
64 | tfrpack tmp/portray_dset --out_fp data/people/tmp/test
65 | echo Done.
66 | echo Creating images...
67 | ./run.py test experiments/config/PM --override_dset_suffix tmp --out_fp ${out_fp}
68 | echo Done.
69 | echo Cleaning up...
70 | rm -rf tmp
71 | rm -rf data/people/tmp
72 | echo Done.
73 |
74 |
--------------------------------------------------------------------------------
/generation/generate_conditioned.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e # Exit on error.
3 | # Argument and environment checks.
4 | if [ -z ${1+x} ]; then
5 | echo Please specify the number of people! >&2; exit 1
6 | fi
7 | npeople=$1
8 | re='^[0-9]+$'
9 | if ! [[ $1 =~ $re ]] ; then
10 | echo "Error: specify a number" >&2; exit 1
11 | fi
12 | if [ -z ${2+x} ] || [ ! ${2: -4} == ".png" ] || [ ! -f $2 ]; then
13 | echo "Error: please provide a path to the image and .pkl file with the "\
14 | "conditioning (provide the image filename in the form "\
15 | "'00001_image.png'; '00001_body.pkl' must exist as well." >&2; exit 1
16 | else
17 | image_fn=$(basename $2)
18 | body_fn="${image_fn%_*}_body.pkl"
19 | image_fp=$2
20 | body_fp="$(dirname ${image_fp})/${body_fn}"
21 | if [ ! -f ${body_fp} ]; then
22 | echo "Error: please provide a path to the image and .pkl file with the "\
23 | "conditioning (provide the image filename in the form "\
24 | "'00001_image.png'; '00001_body.pkl' must exist as well. "\
25 | "Did't find '${body_fp}'.">&2; exit 1
26 | fi
27 | fi
28 | if [ -z ${3+x} ]; then
29 | out_fp=generated
30 | else
31 | out_fp=$3
32 | fi
33 | if [ -e ${out_fp} ]; then
34 | echo "Output folder exists: ${out_fp}. Please pick a non-existing folder." >&2
35 | exit 1
36 | fi
37 |
38 | # Check environment.
39 | if [ -e tmp ]; then
40 | echo "The directory 'tmp' exists, maybe from an incomplete previous run?" >&2
41 | echo "If so, please delete it and rerun so that it can be used cleanly." >&2
42 | exit 1
43 | fi
44 | if [ -e data/people/tmp ]; then
45 | echo "The directory 'data/people/tmp' exists, maybe from an incomplete previous run?" >&2
46 | echo "If so, please delete it and rerun so that it can be used cleanly." >&2
47 | exit 1
48 | fi
49 | if [ ! -d experiments/states/CSM ]; then
50 | echo "State folder for the conditional sketch module not found at " >&2
51 | echo "'experiments/states/CSM'. Either run the training (./run.py trainval experiments/config/CSM) " >&2
52 | echo "or download a pretrained model from http://gp.is.tuebingen.mpg.de." >&2
53 | exit 1
54 | fi
55 | if [ ! -d experiments/states/PM ]; then
56 | echo "State folder for the portray module not found at " >&2
57 | echo "'experiments/states/PM'. Either run the training (./run.py trainval experiments/config/PM) " >&2
58 | echo "or download a pretrained model from http://gp.is.tuebingen.mpg.de." >&2
59 | exit 1
60 | fi
61 |
62 | echo Generating ${npeople} people from conditioning provided in ${body_fp}...
63 | echo Creating 2D conditioning segments...
64 | mkdir -p tmp/dset/test
65 | cp ${image_fp} tmp/dset/test/0_image.png
66 | cp ${body_fp} tmp/dset/test/0_image.png_body.pkl
67 | mkdir -p tmp/prepared_input/test
68 | ./tools/06_render_bodies.py --dset_folder=tmp/dset --out_folder=tmp/prepared_input
69 | cp ${image_fp} tmp/prepared_input/test/0_image.png
70 | cp tmp/prepared_input/test/0_bodysegments.png tmp/prepared_input/test/0_labels.png
71 | cp tmp/prepared_input/test/0_bodysegments_vis.png tmp/prepared_input/test/0_labels_vis.png
72 | echo "0_conditioning.png" > tmp/prepared_input/test/0_original_filename.txt
73 | cp tmp/prepared_input/test/0_bodysegments_vis.png tmp/prepared_input/test/0_segcolors.png
74 | echo Creating archive...
75 | mkdir -p data/people/tmp
76 | tfrpack tmp/prepared_input/test --out_fp data/people/tmp/val
77 | echo Sampling with conditioning...
78 | mkdir -p tmp/csm_out
79 | ./run.py sample experiments/config/CSM --override_dset_suffix tmp --out_fp tmp/csm_out --n_samples ${npeople}
80 | echo Done.
81 | echo Preparing for portray module...
82 | mkdir tmp/portray_dset
83 | for sample_idx in $(seq 0 $((${npeople}-1))); do
84 | fullid=$(printf "%04d" ${sample_idx})
85 | # Simulate full dataset.
86 | cp tmp/csm_out/images/${sample_idx}-0_conditioning_outputs.png tmp/portray_dset/${fullid}_bodysegments:png.png
87 | cp tmp/csm_out/images/${sample_idx}-0_conditioning_outputs.png tmp/portray_dset/${fullid}_bodysegments_vis:png.png
88 | cp tmp/csm_out/images/${sample_idx}-0_conditioning_outputs.png tmp/portray_dset/${fullid}_image:png.png
89 | cp tmp/csm_out/images/${sample_idx}-0_conditioning_outputs.png tmp/portray_dset/${fullid}_labels:png.png
90 | cp tmp/csm_out/images/${sample_idx}-0_conditioning_outputs.png tmp/portray_dset/${fullid}_label_vis:png.png
91 | echo ${fullid}_sample.png > tmp/portray_dset/${fullid}_original_filename.txt
92 | done
93 | echo Creating archive...
94 | mkdir -p data/people/tmp2
95 | tfrpack tmp/portray_dset --out_fp data/people/tmp2/test
96 | echo Done.
97 | echo Creating images...
98 | ./run.py test experiments/config/PM --override_dset_suffix tmp2 --out_fp ${out_fp}
99 | echo Done.
100 | echo Cleaning up...
101 | rm -rf tmp
102 | rm -rf data/people/tmp
103 | rm -rf data/people/tmp2
104 | echo Done.
105 |
--------------------------------------------------------------------------------
/generation/run.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | """Main control for the experiments."""
3 | import os
4 | import os.path as path
5 | import ast
6 | import sys
7 | from glob import glob
8 | import signal
9 | import imp
10 | import logging
11 | import time
12 | import numpy as np
13 | import socket
14 |
15 | import tqdm
16 | import click
17 | import tensorflow as tf
18 | from tensorflow.python.client import timeline
19 |
20 | from clustertools.log import LOGFORMAT
21 | import data
22 | sys.path.insert(0, path.join('..'))
23 | from config import EXP_DATA_FP # noqa: E402
24 |
25 |
26 | LOGGER = logging.getLogger(__name__)
27 | EXP_DIR = path.join(path.dirname(__file__), 'experiments')
28 |
29 |
30 | @click.command()
31 | @click.argument("mode",
32 | type=click.Choice([
33 | "train", "val", "trainval", "test", "sample", "transform"]
34 | ))
35 | @click.argument("exp_name",
36 | type=click.Path(exists=True, writable=True, file_okay=False))
37 | @click.option("--num_threads", type=click.INT, default=8,
38 | help="Number of data preprocessing threads.")
39 | @click.option("--no_checkpoint", type=click.BOOL, is_flag=True,
40 | help="Ignore checkpoints.")
41 | @click.option("--checkpoint", type=click.Path(exists=True, dir_okay=False),
42 | default=None, help="Checkpoint to use for restoring (+.meta).")
43 | @click.option("--n_samples", type=click.INT, default=100,
44 | help="The number of samples to sample.")
45 | @click.option("--out_fp", type=click.Path(writable=True), default=None,
46 | help="If specified, write test or sample results there.")
47 | @click.option("--override_dset_suffix", type=click.STRING, default=None,
48 | help="If specified, override the configure dset_suffix.")
49 | @click.option("--custom_options", type=click.STRING, default="",
50 | help="Provide model specific custom options.")
51 | @click.option("--no_output", type=click.BOOL, default=False, is_flag=True,
52 | help="Don't store results in test modes.")
53 | def cli(**args):
54 | """Main control for the experiments."""
55 | mode = args['mode']
56 | exp_name = args['exp_name'].strip("/")
57 | assert exp_name.startswith(path.join("experiments", "config"))
58 | exp_purename = path.basename(exp_name)
59 | exp_feat_fp = path.join("experiments", "features", exp_purename)
60 | exp_log_fp = path.join("experiments", "states", exp_purename)
61 | if not path.exists(exp_feat_fp):
62 | os.makedirs(exp_feat_fp)
63 | if not path.exists(exp_log_fp):
64 | os.makedirs(exp_log_fp)
65 | # Set up file logging.
66 | fh = logging.FileHandler(path.join(exp_log_fp, 'run.py.log'))
67 | fh.setLevel(logging.INFO)
68 | formatter = logging.Formatter(LOGFORMAT)
69 | fh.setFormatter(formatter)
70 | LOGGER.addHandler(fh)
71 | LOGGER.info("Running on host: %s", socket.getfqdn())
72 | if "JOB_ID" in os.environ.keys():
73 | LOGGER.info("Condor job id: %s", os.environ["JOB_ID"])
74 | LOGGER.info("Running mode `%s` for experiment `%s`.", mode, exp_name)
75 | # Configuration.
76 | exp_config_mod = imp.load_source('_exp_config',
77 | path.join(exp_name, 'config.py'))
78 | exp_config = exp_config_mod.adjust_config(
79 | exp_config_mod.get_config(), mode)
80 | assert mode in exp_config["supp_modes"], (
81 | "Unsupported mode by this model: %s, available: %s." % (
82 | mode, str(exp_config["supp_modes"])))
83 | if args["override_dset_suffix"] is not None:
84 | LOGGER.warn("Overriding dset suffix to `%s`!",
85 | args["override_dset_suffix"])
86 | exp_config["dset_suffix"] = args["override_dset_suffix"]
87 | if args['custom_options'] != '':
88 | custom_options = ast.literal_eval(args["custom_options"])
89 | exp_config.update(custom_options)
90 | exp_config['num_threads'] = args["num_threads"]
91 | exp_config['n_samples'] = args["n_samples"]
92 | LOGGER.info("Configuration:")
93 | for key, val in exp_config.items():
94 | LOGGER.info("%s = %s", key, val)
95 | # Data setup.
96 | with tf.device('/cpu:0'):
97 | nsamples, steps_per_epoch, loader_list = \
98 | data.get_dataset(EXP_DATA_FP, mode, exp_config)
99 | LOGGER.info("%d examples prepared, %d steps per epoch.",
100 | nsamples, steps_per_epoch)
101 | LOGGER.info("Setting up preprocessing...")
102 | exp_prep_mod = imp.load_source('_exp_preprocessing',
103 | path.join(exp_name, 'preprocessing.py'))
104 | with tf.device('/cpu:0'):
105 | examples = exp_prep_mod.preprocess(nsamples, loader_list, mode,
106 | exp_config)
107 | # Checkpointing.
108 | if args['no_checkpoint']:
109 | assert args['checkpoint'] is None
110 | if not args["no_checkpoint"]:
111 | LOGGER.info("Looking for checkpoints...")
112 | if args['checkpoint'] is not None:
113 | checkpoint = args['checkpoint'][:-5]
114 | else:
115 | checkpoint = tf.train.latest_checkpoint(exp_log_fp)
116 | if checkpoint is None:
117 | LOGGER.info("No checkpoint found. Continuing without.")
118 | checkpoint, load_session, load_graph = None, None, None
119 | else:
120 | load_graph = tf.Graph()
121 | with load_graph.as_default():
122 | meta_file = checkpoint + '.meta'
123 | LOGGER.info("Restoring graph from `%s`...", meta_file)
124 | rest_saver = tf.train.import_meta_graph(meta_file,
125 | clear_devices=True)
126 | LOGGER.info("Graph restored. Loading checkpoint `%s`...",
127 | checkpoint)
128 | load_session = tf.Session()
129 | rest_saver.restore(load_session, checkpoint)
130 | else:
131 | checkpoint, load_session, load_graph = None, None, None
132 | if mode not in ['train', 'trainval'] and load_session is None:
133 | raise Exception("The mode %s requires a checkpoint!" % (mode))
134 | # Build model.
135 | model_mod = imp.load_source('_model',
136 | path.join(exp_name, 'model.py'))
137 | model = model_mod.create_model(
138 | mode, examples, exp_config, (load_session, load_graph))
139 | del load_graph # Free graph.
140 | if load_session is not None:
141 | load_session.close()
142 | # Setup summaries.
143 | summary_mod = imp.load_source('_summaries',
144 | path.join(exp_name, 'summaries.py'))
145 | display_fetches, test_fetches = summary_mod.create_summaries(
146 | mode, examples, model, exp_config)
147 | # Stats.
148 | with tf.name_scope("parameter_count"):
149 | parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v))
150 | for v in tf.trainable_variables()])
151 | # Prepare output.
152 | out_mod = imp.load_source("_write_output",
153 | path.join(exp_name, 'write_output.py'))
154 | # Preparing session.
155 | if mode in ['train', 'trainval']:
156 | saver = tf.train.Saver(max_to_keep=exp_config["kept_saves"])
157 | else:
158 | saver = None
159 | sess_config = tf.ConfigProto(log_device_placement=False)
160 | sess_config.gpu_options.allow_growth = True
161 | sw = tf.summary.FileWriter(path.join(exp_log_fp, mode))
162 | summary_op = tf.summary.merge_all()
163 | prepared_session = tf.Session(config=sess_config)
164 | initializer = tf.global_variables_initializer()
165 | epoch = 0
166 | with prepared_session as sess:
167 | # from tensorflow.python import debug as tf_debug
168 | # sess = tf_debug.LocalCLIDebugWrapperSession(sess)
169 | LOGGER.info("Parameter count: %d.", sess.run(parameter_count))
170 | LOGGER.info("Starting queue runners...")
171 | coord = tf.train.Coordinator()
172 | threads = tf.train.start_queue_runners(sess=sess, coord=coord)
173 | LOGGER.info("Initializing variables...")
174 | sess.run(initializer)
175 | fetches = {}
176 | fetches["global_step"] = model.global_step
177 | global_step = sess.run(fetches)["global_step"][0]
178 | LOGGER.info("On global step: %d.", global_step)
179 | if len(glob(path.join(exp_log_fp, mode, 'events.*'))) == 0:
180 | LOGGER.info("Summarizing graph...")
181 | sw.add_graph(sess.graph, global_step=global_step)
182 | if mode in ['val', 'test']:
183 | image_dir = path.join(exp_feat_fp, 'step_' + str(global_step))
184 | elif mode in ['sample']:
185 | image_dir = path.join(exp_feat_fp,
186 | time.strftime("%Y-%m-%d_%H-%M-%S",
187 | time.gmtime()))
188 | else:
189 | image_dir = exp_log_fp
190 | if args["out_fp"] is not None:
191 | image_dir = args["out_fp"]
192 | if not args["no_output"]:
193 | LOGGER.info("Writing image status to `%s`.", image_dir)
194 | else:
195 | image_dir = None
196 | if mode in ['val', 'test', 'sample', 'transform']:
197 | shutdown_requested = [False]
198 | def SIGINT_handler(signal, frame): # noqa: E306
199 | LOGGER.warn("Received SIGINT.")
200 | shutdown_requested[0] = True
201 | signal.signal(signal.SIGINT, SIGINT_handler)
202 | # run a single epoch over all input data
203 | if mode in ['val', 'test', 'transform']:
204 | num_ex = steps_per_epoch
205 | else:
206 | if exp_config['model_version'] == 'cvae':
207 | num_ex = steps_per_epoch * exp_config["cvae_nsamples_per_image"]
208 | else:
209 | num_ex = int(np.ceil(float(args["n_samples"]) /
210 | exp_config["batch_size"]))
211 | av_results = dict((name, []) for name in test_fetches.keys())
212 | av_placeholders = dict((name, tf.placeholder(tf.float32))
213 | for name in test_fetches.keys())
214 | for name in test_fetches.keys():
215 | tf.summary.scalar(name, av_placeholders[name],
216 | collections=['evaluation'])
217 | test_summary = tf.summary.merge_all('evaluation')
218 | display_fetches.update(test_fetches)
219 | for b_id in tqdm.tqdm(range(num_ex)):
220 | results = sess.run(display_fetches)
221 | if not args['no_output']:
222 | index_fp = out_mod.save_images(results, image_dir, mode,
223 | exp_config, batch=b_id)
224 | # Check for problems with this result.
225 | results_valid = True
226 | for key in test_fetches.keys():
227 | if not np.isfinite(results[key]):
228 | if 'paths' in results.keys():
229 | LOGGER.warn("There's a problem with results for "
230 | "%s! Skipping.", results['paths'][0])
231 | else:
232 | LOGGER.warn("Erroneous result for batch %d!",
233 | b_id)
234 | results_valid = False
235 | break
236 | if results_valid:
237 | for key in test_fetches.keys():
238 | av_results[key].append(results[key])
239 | if shutdown_requested[0]:
240 | break
241 | LOGGER.info("Results:")
242 | feed_results = dict()
243 | for key in test_fetches.keys():
244 | av_results[key] = np.mean(av_results[key])
245 | feed_results[av_placeholders[key]] = av_results[key]
246 | LOGGER.info(" %s: %s", key, av_results[key])
247 | if not shutdown_requested[0]:
248 | sw.add_summary(sess.run(test_summary, feed_dict=feed_results),
249 | global_step)
250 | else:
251 | LOGGER.warn("Not writing results to tf summary due to "
252 | "incomplete evaluation.")
253 | if not args['no_output']:
254 | LOGGER.info("Wrote index at `%s`.", index_fp)
255 | elif mode in ["train", "trainval"]:
256 | # Training.
257 | max_steps = 2**32
258 | last_summary_written = time.time()
259 | if exp_config["max_epochs"] is not None:
260 | max_steps = steps_per_epoch * exp_config["max_epochs"]
261 | if exp_config["max_steps"] is not None:
262 | max_steps = exp_config["max_steps"]
263 | shutdown_requested = [False] # Needs to be mutable to access.
264 | # Register signal handler to save on Ctrl-C.
265 | def SIGINT_handler(signal, frame): # noqa: E306
266 | LOGGER.warn("Received SIGINT. Saving model...")
267 | saver.save(sess,
268 | path.join(exp_log_fp, "model"),
269 | global_step=model.global_step)
270 | shutdown_requested[0] = True
271 | signal.signal(signal.SIGINT, SIGINT_handler)
272 | pbar = tqdm.tqdm(total=(max_steps - global_step) *
273 | exp_config["batch_size"])
274 | for step in range(global_step, max_steps):
275 | def should(freq, epochs=False):
276 | if epochs:
277 | return freq > 0 and ((epoch + 1) % freq == 0 and
278 | (step + 1) % steps_per_epoch == 0
279 | or step == max_steps - 1)
280 | else:
281 | return freq > 0 and ((step + 1) % freq == 0 or
282 | step == max_steps - 1)
283 | options = None
284 | run_metadata = None
285 | if should(exp_config["trace_freq"]):
286 | options = tf.RunOptions(
287 | trace_level=tf.RunOptions.FULL_TRACE)
288 | run_metadata = tf.RunMetadata()
289 | # Setup fetches.
290 | fetches = {
291 | "train": model.train,
292 | "global_step": model.global_step,
293 | }
294 | if ((time.time() - last_summary_written) >
295 | exp_config["summary_freq"]):
296 | fetches["summary"] = summary_op
297 | if (should(exp_config["display_freq"], epochs=True) or
298 | should(exp_config["save_freq"], epochs=True) or
299 | step == max_steps - 1):
300 | fetches["display"] = display_fetches
301 | # Run!
302 | results = sess.run(fetches, options=options,
303 | run_metadata=run_metadata)
304 | # Write.
305 | if (should(exp_config["save_freq"], epochs=True) or
306 | results["global_step"] == 1 or
307 | step == max_steps - 1):
308 | # Save directly at first iteration to make sure this is
309 | # working.
310 | LOGGER.info("Saving model...")
311 | gs = model.global_step
312 | saver.save(sess,
313 | path.join(exp_log_fp, "model"),
314 | global_step=gs)
315 | if "summary" in results.keys():
316 | sw.add_summary(results["summary"],
317 | results["global_step"])
318 | last_summary_written = time.time()
319 | if "display" in results.keys():
320 | LOGGER.info("saving display images")
321 | out_mod.save_images(results["display"],
322 | image_dir,
323 | mode,
324 | exp_config,
325 | step=results["global_step"][0])
326 | if should(exp_config["trace_freq"]):
327 | LOGGER.info("recording trace")
328 | sw.add_run_metadata(
329 | run_metadata, "step_%d" % results["global_step"])
330 | trace = timeline.Timeline(
331 | step_stats=run_metadata.step_stats)
332 | with open(path.join(
333 | exp_log_fp,
334 | "timeline.json"), "w") as trace_file:
335 | trace_file.write(trace.generate_chrome_trace_format())
336 | # Enter 'chrome://tracing' in chrome to open the file.
337 | epoch = results["global_step"] // steps_per_epoch
338 | pbar.update(exp_config["batch_size"])
339 | if shutdown_requested[0]:
340 | break
341 | pbar.close()
342 | LOGGER.info("Shutting down...")
343 | coord.request_stop()
344 | coord.join(threads)
345 | data.cleanup()
346 | LOGGER.info("Done.")
347 |
348 |
349 | if __name__ == '__main__':
350 | logging.basicConfig(level=logging.INFO, format=LOGFORMAT)
351 | logging.getLogger("clustertools.db.tools").setLevel(logging.WARN)
352 | logging.getLogger("PIL.Image").setLevel(logging.WARN)
353 | cli()
354 |
--------------------------------------------------------------------------------
/generation/test_runner.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | import os
3 | import os.path as path
4 | import subprocess
5 | import time
6 | from glob import glob
7 | import logging
8 | import click
9 | from tensorflow.tensorboard.backend.event_processing.event_accumulator import (
10 | EventAccumulator)
11 | from clustertools.log import LOGFORMAT
12 |
13 |
14 | LOGGER = logging.getLogger(__name__)
15 |
16 |
17 | def get_unevaluated_checkpoint(result_fp, events_acc):
18 | LOGGER.info("Scanning for unevaluated checkpoints...")
19 | # Get all checkpoints.
20 | checkpoints = sorted(
21 | [int(path.basename(val)[6:-5])
22 | for val in glob(path.join(result_fp, 'model-*.meta'))])
23 | LOGGER.debug("Available checkpoints: %s.", checkpoints)
24 | events_acc.Reload()
25 | available_tags = events_acc.Tags()['scalars']
26 | if available_tags:
27 | test_tag = events_acc.Tags()['scalars'][0]
28 | LOGGER.debug("Using tag `%s` for check.", test_tag)
29 | recorded_steps = sorted([ev.step
30 | for ev in events_acc.Scalars(test_tag)])
31 | LOGGER.debug("Recorded steps: %s.", recorded_steps)
32 | else:
33 | LOGGER.debug("No recorded steps found.")
34 | recorded_steps = []
35 | # Merge.
36 | for cp_idx in checkpoints:
37 | if cp_idx not in recorded_steps:
38 | LOGGER.info("Detected unevaluated checkpoint: %d.", cp_idx)
39 | return cp_idx
40 | LOGGER.info("Scan complete. No new checkpoints found.")
41 | return None
42 |
43 |
44 | @click.command()
45 | @click.argument("state_fp", type=click.Path(exists=True, readable=True))
46 | @click.argument("monitor_set", type=click.Choice(["val", "test"]))
47 | @click.option("--check_interval", type=click.INT, default=60,
48 | help="Interval in seconds between checks for new checkpoints.")
49 | def cli(state_fp, monitor_set, check_interval=60):
50 | """Start a process providing validation/test results for a training."""
51 | LOGGER.info("Starting monitoring for result path `%s` and set `%s`.",
52 | state_fp, monitor_set)
53 | if not path.exists(path.join(state_fp, monitor_set)):
54 | os.makedirs(path.join(state_fp, monitor_set))
55 | events_acc = EventAccumulator(path.join(state_fp, monitor_set))
56 | while True:
57 | cp_idx = get_unevaluated_checkpoint(state_fp, events_acc)
58 | if cp_idx is not None:
59 | LOGGER.info("Running evaluation for checkpoint %d...", cp_idx)
60 | subprocess.check_call(["./run.py",
61 | monitor_set,
62 | path.join("experiments", "config",
63 | path.basename(state_fp)),
64 | "--checkpoint",
65 | path.join(state_fp, 'model-%d.meta' % (
66 | cp_idx)),
67 | "--no_output"])
68 | else:
69 | time.sleep(check_interval)
70 |
71 |
72 | if __name__ == '__main__':
73 | logging.basicConfig(level=logging.INFO, format=LOGFORMAT)
74 | cli() # pylint: disable=no-value-for-parameter
75 |
--------------------------------------------------------------------------------
/generation/tools/01_chic10k_to_fashion_joined.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | """Assemble the fashion dataset."""
3 | import os
4 | import os.path as path
5 | import sys
6 | import scipy.misc as sm
7 | from glob import glob
8 | import numpy as np
9 | import click
10 | import logging
11 | import cv2
12 | import dlib
13 | from clustertools.log import LOGFORMAT
14 | from clustertools.visualization import apply_colormap
15 | import clustertools.db.tools as cdbt
16 |
17 | sys.path.insert(0, path.join('..', '..'))
18 | from config import CHICTOPIA_DATA_FP, CMAP # noqa: E402
19 |
20 |
21 | LOGGER = logging.getLogger(__name__)
22 |
23 |
24 | def getface(sketch, BORDER=0):
25 | # Face has value 11.
26 | minind_x = 255
27 | maxind_x = 0
28 | for row_idx in range(sketch.shape[0]):
29 | if 11 in sketch[row_idx, :]:
30 | minind_x = min([minind_x, np.argmax(sketch[row_idx, :] == 11)])
31 | maxind_x = max(maxind_x, sketch.shape[1] - 1 - np.argmax(sketch[row_idx, ::-1] == 11))
32 | minind_y = 255
33 | maxind_y = 0
34 | for col_idx in range(sketch.shape[1]):
35 | if 11 in sketch[:, col_idx]:
36 | minind_y = min([minind_y, np.argmax(sketch[:, col_idx] == 11)])
37 | maxind_y = max([maxind_y, sketch.shape[0] - 1 - np.argmax(sketch[::-1, col_idx] == 11)])
38 | LOGGER.debug("Without border: min_y=%d, max_y=%d, min_x=%d, max_x=%d",
39 | minind_y, maxind_y, minind_x, maxind_x)
40 | minind_x = max([0, minind_x - BORDER])
41 | maxind_x = min([sketch.shape[1] - 1, maxind_x + BORDER])
42 | minind_y = max([0, minind_y - BORDER])
43 | maxind_y = min([sketch.shape[0] - 1, maxind_y + BORDER])
44 | LOGGER.debug("With border: min_y=%d, max_y=%d, min_x=%d, max_x=%d",
45 | minind_y, maxind_y, minind_x, maxind_x)
46 | # Make the area rectangular.
47 | if maxind_y - minind_y != maxind_x - minind_x:
48 | if maxind_y - minind_y > maxind_x - minind_x:
49 | # Height is longer, enlarge width.
50 | diff = maxind_y - minind_y - maxind_x + minind_x
51 | if minind_x < int(np.floor(diff / 2.)):
52 | maxind_x = maxind_x + (diff - minind_x)
53 | minind_x = 0
54 | elif sketch.shape[1] - maxind_x - int(np.ceil(diff / 2.)) < 0:
55 | minind_x = minind_x - (diff - sketch.shape[1] + maxind_x)
56 | maxind_x = sketch.shape[1] - 1
57 | else:
58 | minind_x = minind_x - int(np.floor(diff / 2.))
59 | maxind_x = maxind_x + int(np.ceil(diff / 2.))
60 | else:
61 | # Width is longer, enlarge height.
62 | diff = - (maxind_y - minind_y - maxind_x + minind_x)
63 | if minind_y < int(np.floor(diff / 2.)):
64 | maxind_y = maxind_y + (diff - minind_y)
65 | minind_y = 0
66 | elif sketch.shape[0] - maxind_y - int(np.ceil(diff / 2.)) < 0:
67 | minind_y = minind_y - (diff - sketch.shape[0] + maxind_y)
68 | maxind_y = sketch.shape[0] - 1
69 | else:
70 | minind_y = minind_y - int(np.floor(diff / 2.))
71 | maxind_y = maxind_y + int(np.ceil(diff / 2.))
72 | if maxind_y - minind_y <= 0 or maxind_x - minind_x <= 0:
73 | LOGGER.warn("No face detected in image!")
74 | return None
75 | return minind_y, maxind_y, minind_x, maxind_x
76 |
77 |
78 | fdetector = dlib.get_frontal_face_detector()
79 | spredictor_fp = path.join(path.dirname(__file__),
80 | 'shape_predictor_68_face_landmarks.dat')
81 | if not path.exists(spredictor_fp):
82 | LOGGER.critical("Please download and unpack the face shape model from "
83 | "http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 "
84 | "to `%s`.", spredictor_fp)
85 | sys.exit(1)
86 | spredictor = dlib.shape_predictor(spredictor_fp)
87 | def prepare(im_idx):
88 | # Load.
89 | image = sm.imread(path.join(CHICTOPIA_DATA_FP, 'JPEGImages',
90 | '%s.jpg' % (im_idx)))
91 | if image.ndim != 3:
92 | return []
93 | resize_factor = 513. / max(image.shape[:2])
94 | im_resized = sm.imresize(image, resize_factor)
95 | annotation = sm.imread(path.join(CHICTOPIA_DATA_FP, 'SegmentationClassAug',
96 | '%s.png' % (im_idx)))
97 | resannot = sm.imresize(annotation, resize_factor, interp='nearest')
98 | # Holes.
99 | kernel = np.ones((7, 7), np.uint8)
100 | closed_annot = cv2.morphologyEx(resannot, cv2.MORPH_CLOSE, kernel)
101 | grad = cv2.morphologyEx(resannot, cv2.MORPH_BLACKHAT, kernel)
102 | to_fill = np.logical_and(resannot == 0, grad > 0)
103 | resannot[to_fill] = closed_annot[to_fill]
104 | # Face detection.
105 | FDEBUG = False
106 | # For debugging.
107 | if FDEBUG:
108 | win = dlib.image_window()
109 | win.clear_overlay()
110 | win.set_image(im_resized)
111 | face_box = getface(resannot)
112 | max_IOU = 0.
113 | if face_box is not None:
114 | most_likely_det = None
115 | dets, _, _ = fdetector.run(im_resized, 1, -1)
116 | for k, d in enumerate(dets):
117 | # Calculate IOU with ~ground truth.
118 | ar_pred = (d.right() - d.left()) * (d.bottom() - d.top())
119 | face_points = (resannot == 11)
120 | face_pos = np.where(face_points)
121 | inters_x = np.logical_and(face_pos[1] >= d.left(),
122 | face_pos[1] < d.right())
123 | inters_y = np.logical_and(face_pos[0] >= d.top(),
124 | face_pos[0] < d.bottom())
125 | inters_p = np.sum(np.logical_and(inters_x, inters_y))
126 | outs_p = np.sum(face_points) - inters_p
127 | IOU = float(inters_p) / (outs_p + ar_pred)
128 | if IOU > 1.:
129 | import ipdb; ipdb.set_trace()
130 | if IOU > 0.3 and IOU > max_IOU:
131 | most_likely_det = d
132 | max_IOU = IOU
133 | if most_likely_det is not None:
134 | shape = spredictor(im_resized, most_likely_det)
135 | # Save hat, hair and sunglasses (likely to cover eyes or nose).
136 | hat = (resannot == 1)
137 | hair = (resannot == 2)
138 | sungl = (resannot == 3)
139 | # Add annotations:
140 | an_lm = {
141 | (48, 67): 18, # lips
142 | (27, 35): 19, # nose
143 | (36, 41): 20, # leye
144 | (42, 47): 21, # reye
145 | }
146 | for rng, ann_id in an_lm.items():
147 | poly = np.empty((2, rng[1] - rng[0]),
148 | dtype=np.int64)
149 | for point_idx, point_id in enumerate(range(*rng)):
150 | poly[0, point_idx] = shape.part(point_id).x
151 | poly[1, point_idx] = shape.part(point_id).y
152 | # Draw additional annotations.
153 | poly = poly.T.copy()
154 | cv2.fillPoly(
155 | resannot,
156 | [poly],
157 | (ann_id,))
158 | # Write back hat, hair and sungl.
159 | resannot[hat] = 1
160 | resannot[hair] = 2
161 | resannot[sungl] = 3
162 | if FDEBUG:
163 | win.add_overlay(shape)
164 | win.add_overlay(most_likely_det)
165 | dlib.hit_enter_to_continue()
166 | else:
167 | # No reliable face found.
168 | return []
169 | return [(
170 | '%s.jpg' % (im_idx),
171 | im_resized,
172 | np.dstack([resannot] * 3),
173 | apply_colormap(resannot, vmin=0, vmax=21, cmap=CMAP)[:, :, :3]
174 | )]
175 |
176 |
177 | @click.command()
178 | def cli():
179 | """Assemble a unified fashion dataset."""
180 | np.random.seed(1)
181 | out_fp = path.join(path.dirname(__file__), '..', '..', 'data')
182 | LOGGER.info("Using output directory `%s`.", out_fp)
183 | if not path.exists(out_fp):
184 | os.mkdir(out_fp)
185 | db_fns = glob(path.join(out_fp, '*.tfrecords'))
186 | for db_fn in db_fns:
187 | os.unlink(db_fn)
188 | chic10k_root = CHICTOPIA_DATA_FP
189 | chic10k_im_fps = sorted(glob(path.join(chic10k_root, 'JPEGImages', '*.jpg')))
190 | chic10k_ids = [path.basename(im_fp)[:path.basename(im_fp).index('.')]
191 | for im_fp in chic10k_im_fps]
192 | perm = np.random.permutation(chic10k_ids)
193 | train_ids = perm[:int(len(perm) * 0.8)]
194 | val_ids = perm[int(len(perm) * 0.8):int(len(perm) * 0.9)]
195 | test_ids = perm[int(len(perm) * 0.9):]
196 | creator = cdbt.TFRecordCreator([
197 | ('original_filename', cdbt.SPECTYPES.text),
198 | # It is critical here to use lossless compression - the adversary will
199 | # otherwise pick up JPEG compression cues.
200 | ('image', cdbt.SPECTYPES.imlossless),
201 | ('labels', cdbt.SPECTYPES.imlossless),
202 | ('label_vis', cdbt.SPECTYPES.imlossless),
203 | ],
204 | examples_per_file=300)
205 | for pname, pids in zip(['train', 'val', 'test'],
206 | [train_ids, val_ids, test_ids]):
207 | creator.open(path.join(out_fp, pname))
208 | creator.add_to_dset(prepare,
209 | pids,
210 | num_threads=16,
211 | progress=True)
212 | creator.close()
213 | LOGGER.info("Done.")
214 |
215 |
216 | if __name__ == '__main__':
217 | logging.basicConfig(level=logging.INFO, format=LOGFORMAT)
218 | cli()
219 |
--------------------------------------------------------------------------------
/generation/tools/02_create_clothdataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | """Create the dataset."""
3 | # pylint: disable=wrong-import-position, invalid-name
4 | import os
5 | import os.path as path
6 | import sys
7 | import logging
8 |
9 | import numpy as np
10 | import scipy
11 | import scipy.misc
12 | import click
13 |
14 | sys.path.insert(0, path.join(path.dirname(__file__), '..', '..'))
15 | from config import CMAP # pylint: disable=no-name-in-module
16 | from clustertools.log import LOGFORMAT
17 | from clustertools.visualization import apply_colormap
18 | import clustertools.db.tools as cdb
19 |
20 |
21 | LOGGER = logging.getLogger(__name__)
22 |
23 |
24 | # Build lr-swapper.
25 | chick10k_mapping = {
26 | 0: 0, # Background.
27 | 1: 1, # Hat.
28 | 2: 2, # Hair.
29 | 3: 3, # Sunglasses.
30 | 4: 4, # Top.
31 | 5: 5, # Skirt.
32 | 6: 6, # ??
33 | 7: 7, # Dress.
34 | 8: 8, # Belt.
35 | 9: 10, # Left shoe.
36 | 10: 9, # Right shoe.
37 | 11: 11, # Skin, face.
38 | 12: 13, # Skin, left leg.
39 | 13: 12, # Skin, right leg.
40 | 14: 15, # Skin, left arm.
41 | 15: 14, # Skin, right arm.
42 | 16: 16, # Bag.
43 | 17: 17, # ??.
44 | 18: 18, # lips.
45 | 19: 19, # nose.
46 | 20: 21, # leye.
47 | 21: 20, # reye.
48 | }
49 |
50 | def lrswap_regions(annotations):
51 | """Swap left and right annotations."""
52 | assert annotations.ndim == 2
53 | swapspec = chick10k_mapping
54 | swapped = np.empty_like(annotations)
55 | for py in range(annotations.shape[0]):
56 | for px in range(annotations.shape[1]):
57 | swapped[py, px] = swapspec[annotations[py, px]]
58 | return swapped
59 |
60 | def pad_height(image, crop):
61 | """Pad the height to given crop size."""
62 | if image.ndim == 2:
63 | image = image[:, :, None]
64 | image = np.vstack((np.zeros((int(np.floor(max([0, crop - image.shape[0]]) / 2.)),
65 | image.shape[1],
66 | image.shape[2]), dtype=image.dtype),
67 | image,
68 | np.zeros((int(np.ceil(max([0, crop - image.shape[0]]) / 2.)),
69 | image.shape[1],
70 | image.shape[2]), dtype=image.dtype)))
71 | if image.shape[2] == 1: # pylint: disable=no-else-return
72 | return image[:, :, 0]
73 | else:
74 | return image
75 |
76 | def pad_width(image, crop):
77 | """Pad the width to given crop size."""
78 | if image.ndim == 2:
79 | image = image[:, :, None]
80 | image = np.hstack((np.zeros((image.shape[0],
81 | int(np.floor(max([0, crop - image.shape[1]]) / 2.)),
82 | image.shape[2]), dtype=image.dtype),
83 | image,
84 | np.zeros((image.shape[0],
85 | int(np.ceil(max([0, crop - image.shape[1]]) / 2.)),
86 | image.shape[2]), dtype=image.dtype)))
87 | if image.shape[2] == 1:
88 | return image[:, :, 0]
89 | else:
90 | return image
91 |
92 | crop = 0
93 | def convert(inputs):
94 | imname = inputs['original_filename']
95 | image = inputs['image']
96 | labels = inputs['labels']
97 | label_vis = inputs['label_vis']
98 | results = []
99 | segmentation = labels[:, :, 0]
100 | norm_factor = float(crop) / max(image.shape[:2])
101 | image = scipy.misc.imresize(image, norm_factor, interp='bilinear')
102 | segmentation = scipy.misc.imresize(segmentation, norm_factor, interp='nearest')
103 | if image.shape[0] < crop:
104 | # Pad height.
105 | image = pad_height(image, crop)
106 | segmentation = pad_height(segmentation, crop)
107 | if image.shape[1] < crop:
108 | image = pad_width(image, crop)
109 | segmentation = pad_width(segmentation, crop)
110 | labels = np.dstack([segmentation] * 3)
111 | label_vis = apply_colormap(segmentation, vmax=21, vmin=0, cmap=CMAP)[:, :, :3]
112 | results.append([imname, image * (labels != 0), labels, label_vis])
113 | # Swapped version.
114 | imname = path.splitext(imname)[0] + '_swapped' + path.splitext(imname)[1]
115 | image = image[:, ::-1]
116 | segmentation = segmentation[:, ::-1]
117 | segmentation = lrswap_regions(segmentation)
118 | labels = np.dstack([segmentation] * 3)
119 | label_vis = apply_colormap(segmentation, vmax=21, vmin=0, cmap=CMAP)[:, :, :3]
120 | results.append([imname, image * (labels != 0), labels, label_vis])
121 | return results
122 |
123 |
124 | @click.command()
125 | @click.argument("suffix", type=click.STRING)
126 | @click.option("--crop_size", type=click.INT, default=286,
127 | help="Crop size for the images.")
128 | def cli(suffix, crop_size): # pylint: disable=too-many-locals, too-many-arguments
129 | """Create clothing segmentation to fashion image dataset."""
130 | global crop
131 | np.random.seed(1)
132 | crop = crop_size
133 | LOGGER.info("Creating generation dataset with target "
134 | "image size %f and suffix `%s`.",
135 | crop, suffix)
136 | assert ' ' not in suffix
137 | dset_fp = path.join(path.dirname(__file__), '..', 'data', 'people', suffix)
138 | if path.exists(dset_fp):
139 | if not click.confirm("Dataset folder exists: `%s`! Continue?" % (
140 | dset_fp)):
141 | return
142 | else:
143 | os.makedirs(dset_fp)
144 | converter = cdb.TFConverter([
145 | ('original_filename', cdb.SPECTYPES.text),
146 | ('image', cdb.SPECTYPES.imlossless),
147 | ('labels', cdb.SPECTYPES.imlossless),
148 | ('label_vis', cdb.SPECTYPES.imlossless),
149 | ])
150 | LOGGER.info("Processing...")
151 | for pname in ['train', 'val', 'test']:
152 | converter.open(
153 | path.join(path.dirname(__file__), '..', '..', 'data', pname),
154 | path.join(dset_fp, pname))
155 | converter.convert_dset(convert,
156 | num_threads=16,
157 | progress=True)
158 | converter.close()
159 | LOGGER.info("Done.")
160 |
161 |
162 | if __name__ == '__main__':
163 | logging.basicConfig(level=logging.INFO, format=LOGFORMAT)
164 | cli() # pylint: disable=no-value-for-parameter
165 |
--------------------------------------------------------------------------------
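The padding helpers in 02_create_clothdataset.py above centre the content in the target crop, splitting the missing rows (or columns) with floor on one side and ceil on the other. A minimal numpy sketch of the same behaviour, with made-up shapes that are not part of the pipeline:

import numpy as np

# A toy 200x286 image padded to a 286x286 crop: 43 zero rows go on top and
# 43 below, mirroring pad_height's floor/ceil split of the missing rows.
img = np.zeros((200, 286, 3), dtype=np.uint8)
crop = 286
missing = max(0, crop - img.shape[0])
padded = np.pad(img, ((missing // 2, missing - missing // 2), (0, 0), (0, 0)),
                mode='constant')
assert padded.shape == (286, 286, 3)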
/generation/tools/03_create_pose_input.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ ! -d ../data/pose ]; then
4 | mkdir ../data/pose
5 | fi
6 |
7 | if [ ! -d ../data/pose/extracted ]; then
8 | mkdir ../data/pose/extracted
9 | mkdir ../data/pose/extracted/{train,val,test}
10 | mkdir ../data/pose/input
11 | mkdir ../data/pose/input/{train,val,test}
12 | fi
13 |
14 | for part in train val test; do
15 | tfrcat ../data/people/$1/${part} --out_fp ../data/pose/extracted/${part}
16 | cp ../data/pose/extracted/${part}/*_image*.png ../data/pose/input/${part}/
17 | done
18 |
19 |
--------------------------------------------------------------------------------
/generation/tools/04_run_deepercut.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | up_fp=$(../../config.py UP_FP)
4 |
5 | for part in train val test; do
6 | ${up_fp}/pose/pose_deepercut.py ../data/pose/input/${part}
7 | done
8 |
--------------------------------------------------------------------------------
/generation/tools/05_run_fits.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | trap "exit" INT
4 |
5 | for dset_part in train val test; do
6 | $(../../config.py UP_FP)/3dfit/bodyfit.py \
7 | ../data/pose/input/$dset_part/ --use_inner_penetration --only_missing \
8 | --allow_subsampling
9 | done
10 |
--------------------------------------------------------------------------------
/generation/tools/06_render_bodies.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | import imp
3 | import os.path as path
4 | from glob import glob
5 | import logging
6 | import click
7 | import tqdm
8 | import numpy as np
9 | import scipy.misc as sm
10 |
11 | from clustertools import visualization as vs
12 | from clustertools.log import LOGFORMAT
13 | import up_tools.model as upm
14 | import up_tools.render_segmented_views as upr
15 | import pymp
16 | config = imp.load_source(
17 | 'config',
18 | path.abspath(path.join(path.dirname(__file__),
19 | "..", "..", "config.py")))
20 |
21 |
22 | LOGGER = logging.getLogger(__name__)
23 |
24 |
25 | def process_image(im_fp, dset_part, out_folder):
26 | bn = path.basename(im_fp)
27 | dn = path.dirname(im_fp)
28 | img_idx = int(bn[:bn.find("_")])
29 | body_fp = path.join(dn, bn + '_body.pkl')
30 | im = sm.imread(im_fp)
31 | if not path.exists(body_fp):
32 | raise Exception("Body fit not found for `%s` (`%s`)!" % (im_fp, body_fp))
33 | rendering = upr.render_body_impl(body_fp,
34 | resolution=(im.shape[0], im.shape[1]),
35 | quiet=True,
36 | use_light=False)[0]
37 | annotation = upm.regions_to_classes(rendering, upm.six_region_groups,
38 | warn_id=str(img_idx))
39 | out_fp = path.join(out_folder, dset_part,
40 | "{:0{width}d}_bodysegments.png".format(
41 | img_idx, width=bn.find("_")))
42 | sm.imsave(out_fp, annotation)
43 | out_fp = path.join(out_folder, dset_part,
44 | "{:0{width}d}_bodysegments_vis.png".format(
45 | img_idx, width=bn.find("_")))
46 | sm.imsave(out_fp, vs.apply_colormap(annotation, vmin=0, vmax=6,
47 | cmap=config.CMAP)[:, :, 0:3])
48 |
49 |
50 | @click.command()
51 | @click.option("--dset_folder",
52 | type=click.Path(exists=True, readable=True, file_okay=False),
53 | default=path.join('..', 'data', 'pose', 'input'),
54 | help="The dataset folder to process.")
55 | @click.option("--out_folder",
56 | type=click.Path(exists=True, writable=True, file_okay=False),
57 | default=path.join('..', 'data', 'pose', 'extracted'),
58 | help="The output folder.")
59 | def cli(dset_folder=path.join('..', 'data', 'pose', 'input'),
60 | out_folder=path.join('..', 'data', 'pose', 'extracted')):
61 | LOGGER.info("Processing...")
62 | for part in ['train', 'val', 'test']:
63 | im_fps = sorted(glob(path.join(dset_folder, part, '*_image*.png')))
64 | # Filter.
65 | im_fps = [im_fp for im_fp in im_fps
66 | if ('pose' not in path.basename(im_fp) and
67 | 'body' not in path.basename(im_fp))]
68 | with pymp.Parallel(12, if_=False) as p:
69 | for im_fp in p.iterate(tqdm.tqdm(im_fps)):
70 | process_image(im_fp, part, out_folder)
71 | LOGGER.info("Done.")
72 |
73 | if __name__ == '__main__':
74 | logging.basicConfig(level=logging.INFO, format=LOGFORMAT)
75 | cli()
76 |
--------------------------------------------------------------------------------
/generation/tools/07_create_additional_conditioning.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | import imp
3 | import os.path as path
4 | from glob import glob
5 | import logging
6 | import click
7 | import tqdm
8 | import scipy.misc as sm
9 | import numpy as np
10 |
11 | from clustertools.log import LOGFORMAT
12 | from geometric_median import geometric_median
13 | import pymp
14 | config = imp.load_source(
15 | 'config',
16 | path.abspath(path.join(path.dirname(__file__),
17 | "..", "..", "config.py")))
18 |
19 |
20 | LOGGER = logging.getLogger(__name__)
21 |
22 |
23 | def process_image(im_fp, dset_part):
24 | bn = path.basename(im_fp)
25 | dn = path.dirname(im_fp)
26 | img_idx = int(bn[:bn.find("_")])
27 | im = sm.imread(im_fp)
28 | segmentation = sm.imread(path.join(path.dirname(im_fp),
29 | "{0:0{width}d}_labels:png.png".format(
30 | img_idx, width=bn.find("_"))))
31 | if segmentation.ndim == 3:
32 | segmentation = segmentation[:, :, 0]
33 | region_ids = sorted(np.unique(segmentation))
34 | img_colors = np.zeros_like(im)
35 | for region_id in region_ids:
36 | if region_id == 0:
37 | # Background is already 0-labeled.
38 | continue
39 | region_id_set = [region_id]
40 | if region_id == 9: # left shoe.
41 | region_id_set.append(10)
42 | elif region_id == 10:
43 | region_id_set.append(9)
44 | elif region_id in [11, 14, 15, 19]: # Skin.
45 | region_id_set = [11, 14, 15, 19]
46 | elif region_id in [12, 13]: # legs.
47 | region_id_set = [12, 13]
48 | elif region_id in [20, 21]:
49 | region_id_set = [20, 21] # eyes.
50 | region_colors = np.vstack(
51 | [im[segmentation == idx] for idx in region_id_set])
52 | med_color = geometric_median(region_colors)
53 | img_colors[segmentation == region_id] = med_color
54 | sm.imsave(path.join(dn, "{0:0{width}d}_segcolors.png".format(
55 | img_idx, width=bn.find("_"))), img_colors)
56 |
57 |
58 | @click.command()
59 | def cli():
60 | LOGGER.info("Processing...")
61 | for part in ['train', 'val', 'test']:
62 | im_fps = sorted(glob(path.join('..', 'data', 'pose', 'extracted', part,
63 | '*_image*.png')))
64 | # Filter.
65 | im_fps = [im_fp for im_fp in im_fps
66 | if ('pose' not in path.basename(im_fp) and
67 | 'body' not in path.basename(im_fp))]
68 | with pymp.Parallel(12, if_=True) as p:
69 | for im_fp in p.iterate(tqdm.tqdm(im_fps)):
70 | process_image(im_fp, part)
71 | LOGGER.info("Done.")
72 |
73 |
74 | if __name__ == '__main__':
75 | logging.basicConfig(level=logging.INFO, format=LOGFORMAT)
76 | cli()
77 |
--------------------------------------------------------------------------------
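07_create_additional_conditioning.py above paints every segmentation region with the geometric median colour of its pixels (merging symmetric regions such as the two shoes before computing the median). A toy sketch of the core idea, using a hypothetical 2x2 image with a single foreground region; the label value 4 is arbitrary:

import numpy as np
from geometric_median import geometric_median

# Hypothetical 2x2 image: top row is foreground (label 4), bottom is background.
im = np.array([[[255, 0, 0], [249, 6, 6]],
               [[0, 0, 0], [0, 0, 0]]], dtype=np.uint8)
seg = np.array([[4, 4],
                [0, 0]], dtype=np.uint8)
colors = np.zeros_like(im)
for rid in np.unique(seg):
    if rid == 0:  # Background stays zero, as in process_image.
        continue
    med = geometric_median(im[seg == rid].astype(np.float64))
    colors[seg == rid] = med.astype(np.uint8)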
/generation/tools/08_prepare_directinpaint.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env zsh
2 |
3 | if [ "${#}" -ne 2 ]; then
4 | echo "Provide input and output dataset names!"
5 | exit 1
6 | fi
7 |
8 | indset=$1
9 | outdset=$2
10 |
11 | mkdir ../data/people/${outdset}
12 | mkdir ../data/people/${outdset}/test
13 | for im in ../data/people/${indset}/*/*_pix2pix.png; do
14 | echo $(basename ${im})
15 | cp ${im} ../data/people/${outdset}/test/
16 | done
17 |
--------------------------------------------------------------------------------
/generation/tools/09_pack_db.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | trap "exit" INT
4 |
5 | out_suff=$1
6 | if [ ! -d ../data/people ]; then
7 | mkdir ../data/people
8 | fi
9 |
10 | if [ -d ../data/people/${out_suff} ]; then
11 | echo "A dataset with this suffix already exists!"
12 | exit 1
13 | fi
14 | mkdir ../data/people/${out_suff}
15 |
16 | for dset_part in train val test; do
17 | ~/git/clustertools/clustertools/scripts/tfrpack.py \
18 | ../data/pose/extracted/${dset_part} \
19 | --out_fp ../data/people/${out_suff}/${dset_part}
20 | done
21 |
--------------------------------------------------------------------------------
/generation/tools/geometric_median.py:
--------------------------------------------------------------------------------
1 | # Taken from here:
2 | # http://stackoverflow.com/questions/30299267/geometric-median-of-multidimensional-points
3 | # All credits go to user 'orlp', the code is released by him under the zlib
4 | # license.
5 | import numpy as np
6 | from scipy.spatial.distance import cdist, euclidean
7 |
8 | def geometric_median(X, eps=1e-5):
9 | y = np.mean(X, 0)
10 |
11 | while True:
12 | D = cdist(X, [y])
13 | nonzeros = (D != 0)[:, 0]
14 |
15 | Dinv = 1 / D[nonzeros]
16 | Dinvs = np.sum(Dinv)
17 | W = Dinv / Dinvs
18 | T = np.sum(W * X[nonzeros], 0)
19 |
20 | num_zeros = len(X) - np.sum(nonzeros)
21 | if num_zeros == 0:
22 | y1 = T
23 | elif num_zeros == len(X):
24 | return y
25 | else:
26 | R = (T - y) * Dinvs
27 | r = np.linalg.norm(R)
28 | rinv = 0 if r == 0 else num_zeros/r
29 | y1 = max(0, 1-rinv)*T + min(1, rinv)*y
30 |
31 | if euclidean(y, y1) < eps:
32 | return y1
33 |
34 | y = y1
35 |
--------------------------------------------------------------------------------
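A quick usage sketch with made-up points: unlike the mean, the geometric median barely moves when a single outlier is added, which is why it is used for the region colours.

import numpy as np
from geometric_median import geometric_median

pts = np.array([[0., 0.], [1., 0.], [0., 1.], [10., 10.]])
print(np.mean(pts, 0))        # [2.75 2.75] -- dragged towards the outlier.
print(geometric_median(pts))  # stays close to the three clustered points.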
/gp_tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/classner/generating_people/ee6b54945d395efe23ddceb4a2daca3fffc7e89d/gp_tools/__init__.py
--------------------------------------------------------------------------------
/gp_tools/tf.py:
--------------------------------------------------------------------------------
1 | """Tensorflow tools."""
2 | import logging
3 | import numpy as np
4 | import tensorflow as tf
5 |
6 |
7 | LOGGER = logging.getLogger(__name__)
8 |
9 |
10 | def get_val_or_initializer(load_tpl, initializer, varname,
11 | sloppy=False, dtype=tf.float32):
12 | """Get the variable value from an existing graph or the initializer."""
13 | orig_sess, orig_graph = load_tpl
14 | full_name = tf.get_variable_scope().name + '/' + varname + ':0'
15 | if orig_graph is not None:
16 | LOGGER.debug("Restoring `%s`." % (full_name))
17 | with orig_graph.as_default():
18 | orig_var = [var for var in tf.global_variables()
19 | if var.name == full_name]
20 | if len(orig_var) != 1:
21 | if not sloppy:
22 | LOGGER.critical("Missing value for variable `%s`!" % (
23 | full_name))
24 | import ipdb; ipdb.set_trace() # noqa: E702
25 | else:
26 | orig_var = None
27 | init_value = None
28 | else:
29 | orig_var = orig_var[0]
30 | init_value = orig_var.eval(session=orig_sess)
31 | if orig_var is not None:
32 | init_value = tf.convert_to_tensor(init_value, dtype=dtype)
33 | else:
34 | init_value = None
35 | if init_value is not None:
36 | return lambda *args, **kwargs: tf.cast(
37 | init_value, kwargs.get('dtype', tf.float32))
38 | else:
39 | return initializer
40 |
41 |
42 | def get_or_load_variable(load_tpl, *args, **kwargs):
43 | """Get a variable value from an existing graph or create one."""
44 | orig_sess, orig_graph = load_tpl
45 | full_name = tf.get_variable_scope().name + '/' + args[0] + ':0'
46 | if orig_graph is not None:
47 | LOGGER.debug("Restoring `%s`." % (full_name))
48 | with orig_graph.as_default():
49 | orig_var = [var for var in tf.global_variables()
50 | if var.name == full_name]
51 | if len(orig_var) != 1:
52 | if 'sloppy' not in kwargs.keys() or not kwargs['sloppy']:
53 | LOGGER.critical("Missing value for variable `%s`!" % (
54 | full_name))
55 | import ipdb; ipdb.set_trace() # noqa: E702
56 | else:
57 | orig_var = None
58 | init_value = None
59 | else:
60 | orig_var = orig_var[0]
61 | init_value = orig_var.eval(session=orig_sess)
62 | if orig_var is not None:
63 | new_dt = tf.float32
64 | if 'dtype' in kwargs.keys():
65 | new_dt = kwargs['dtype']
66 | init_value = tf.convert_to_tensor(init_value, dtype=new_dt)
67 | else:
68 | init_value = None
69 | if init_value is not None:
70 | trainable = True
71 | if "trainable" in kwargs.keys():
72 | trainable = kwargs["trainable"]
73 | return tf.get_variable(args[0],
74 | initializer=init_value,
75 | trainable=trainable)
76 | else:
77 | if 'sloppy' in kwargs.keys():
78 | del kwargs['sloppy']
79 | return tf.get_variable(*args, **kwargs)
80 |
81 |
82 | def one_hot(inputs, num_classes):
83 | """
84 | One hot encoding with fixed number of classes.
85 |
86 | # noqa: E501
87 | See also: http://stackoverflow.com/questions/35226198/is-this-one-hot-encoding-in-tensorflow-fast-or-flawed-for-any-reason
88 | """
89 | inshape = inputs.get_shape().as_list()
90 | assert len(inshape) <= 2
91 | for shcomp in inshape:
92 | assert shcomp is not None
93 | input_vec = tf.reshape(inputs, (-1, 1))
94 | table = tf.constant(np.identity(num_classes, dtype=np.float32))
95 | embeddings = tf.nn.embedding_lookup(table, tf.cast(input_vec, tf.int32))
96 | outshape = inshape + [num_classes, ]
97 | output = tf.reshape(embeddings, outshape)
98 | return output
99 |
--------------------------------------------------------------------------------
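A small shape check for one_hot, assuming the TensorFlow 1.x API used by the rest of the code; the label values are made up:

import numpy as np
import tensorflow as tf
from gp_tools.tf import one_hot

# A (2, 3) map of class ids becomes a (2, 3, 4) one-hot volume.
labels = tf.constant(np.array([[0., 1., 2.],
                               [3., 0., 1.]], dtype=np.float32))
encoded = one_hot(labels, 4)
with tf.Session() as sess:
    print(sess.run(encoded).shape)    # (2, 3, 4)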
/gp_tools/write.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os.path as path
3 |
4 |
5 | def append_index(row_infos, image_dir, mode):
6 | """Append or create the presentation html file for the images."""
7 | index_path = path.join(path.dirname(image_dir), "index.html")
8 | if path.exists(index_path):
9 | index = open(index_path, "a")
10 | else:
11 | index = open(index_path, "w")
12 | index.write("\n\n\n")
13 | colnames = [key for key, val in row_infos[0].items()
14 | if val[1] in ['image', 'text']]
15 | for coln in colnames:
16 | index.write("%s | \n" % (coln))
17 | index.write("
\n")
18 | for row_info in row_infos:
19 | index.write("\n")
20 | for coln, (colc, colt) in row_info.items():
21 | if colt == 'text':
22 | index.write("%s | " % (colc))
23 | elif colt == 'image':
24 | filename = path.join(image_dir,
25 | row_info['name'][0] + '_' + coln + '.png')
26 | with open(filename, 'wb') as outf:
27 | outf.write(colc)
28 | index.write(" | " % (
29 | path.basename(filename)))
30 | elif colt == 'plain':
31 | filename = path.join(image_dir,
32 | row_info['name'][0] + '_' + coln + '.npy')
33 | np.save(filename, colc)
34 | else:
35 | raise Exception("Unsupported mode: %s." % (mode))
36 | index.write("
\n")
37 | return index_path
38 |
--------------------------------------------------------------------------------
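For reference, append_index expects each row as a dict mapping a column name to a (content, type) pair, with a 'name' column that determines the output file names. A minimal sketch with hypothetical column names and output paths:

import os
import numpy as np
from gp_tools.write import append_index

image_dir = 'results/images'   # hypothetical output location
os.makedirs(image_dir)         # append_index expects the folder to exist
row = {
    'name': ('000001', 'text'),
    'split': ('test', 'text'),
    'latent': (np.zeros(8), 'plain'),  # saved as results/images/000001_latent.npy
}
print(append_index([row], image_dir, 'test'))  # -> results/index.html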
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | scipy
3 | tqdm
4 | click
5 | pillow
6 | dlib
7 | pymp-pypi
8 | git+https://github.com/classner/clustertools.git#egg=clustertools
9 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | # -*- coding: utf-8 -*-
3 | """
4 | The setup script for this project.
5 | @author: Christoph Lassner
6 | """
7 | from setuptools import setup
8 | from pip.req import parse_requirements
9 |
10 | VERSION = '0.1'
11 | REQS = [str(ir.req) for ir in parse_requirements('requirements.txt',
12 | session='tmp')]
13 |
14 | setup(
15 | name='gp_tools',
16 | author='Christoph Lassner',
17 | author_email='mail@christophlassner.de',
18 | packages=['gp_tools'],
19 | dependency_links=['http://github.com/classner/clustertools/tarball/master#egg=clustertools'],
20 | include_package_data=True,
21 | install_requires=REQS,
22 | version=VERSION,
23 | license='Creative Commons Non-Commercial 4.0',
24 | )
25 |
--------------------------------------------------------------------------------