├── .gitignore
├── README.md
├── common_flags.py
├── data
│   ├── flying_animals_data
│   │   └── fa_data_decode.py
│   ├── flying_animals_utils.py
│   ├── multi_dsprites_utils.py
│   ├── multi_texture_data
│   │   ├── bg.png
│   │   ├── ellipse_2.png
│   │   ├── square_2.png
│   │   └── tex.png
│   ├── multi_texture_utils.py
│   └── objects_room_utils.py
├── doc
│   ├── flying_animals.gif
│   ├── multi_texture.gif
│   ├── objects_room.gif
│   └── pc.gif
├── eval
│   └── eval_VAE.py
├── main.py
├── model
│   ├── Summary.py
│   ├── __init__.py
│   ├── globalVAE_graph.py
│   ├── nets.py
│   ├── train_graph.py
│   ├── traverse_graph.py
│   └── utils
│       ├── __init__.py
│       ├── convolution_utils.py
│       ├── generic_utils.py
│       └── loss_utils.py
├── sample_imgs
│   ├── flying_animals
│   │   ├── 01.png
│   │   └── 02.png
│   ├── multi_dsprites
│   │   ├── 01.png
│   │   └── 02.png
│   ├── multi_texture
│   │   └── 01.png
│   └── objects_room
│       ├── 01.png
│       ├── 02.png
│       ├── 03.png
│       └── 04.png
├── script
│   ├── flying_animals
│   │   ├── disentangle.sh
│   │   ├── pretrain_inpainter.sh
│   │   ├── test_segmentation.sh
│   │   ├── train_CIS.sh
│   │   └── train_VAE.sh
│   ├── multi_dsprites
│   │   ├── disentangle.sh
│   │   ├── pretrain_inpainter.sh
│   │   ├── test_segmentation.sh
│   │   ├── train_CIS.sh
│   │   └── train_VAE.sh
│   ├── multi_texture
│   │   ├── disentangle.sh
│   │   ├── perceptual_consistency
│   │   │   ├── finetune_PC.sh
│   │   │   ├── test_segmentation.sh
│   │   │   ├── train_CIS.sh
│   │   │   └── train_VAE.sh
│   │   ├── pretrain_inpainter.sh
│   │   ├── test_segmentation.sh
│   │   ├── train_CIS.sh
│   │   └── train_VAE.sh
│   └── objects_room
│       ├── disentangle.sh
│       ├── pretrain_inpainter.sh
│       ├── test_segmentation.sh
│       ├── train_CIS.sh
│       └── train_VAE.sh
├── tb.sh
├── test_segmentation.py
└── trainer
    ├── __init__.py
    ├── train_CIS.py
    ├── train_PC.py
    ├── train_VAE.py
    ├── train_end2end.py
    ├── train_globalVAE.py
    └── train_inpainter.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | data/flying_animals_data/*.npz
131 | data/flying_animals_data/data
132 | data/objects_room_data/*.tfrecords
133 | data/multi_dsprites_data/*.tfrecords
134 |
135 | save_checkpoint/
136 | checkpoint/
137 | script0/
138 | tb.sh
139 | resnet/
140 | *_select/
141 | outputs/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [CVPR 2020] Learning to Manipulate Individual Objects in an Image
2 | This repo contains the implementation of the method described in the paper
3 |
4 | [Learning to Manipulate Individual Objects in an Image](https://arxiv.org/pdf/2004.05495.pdf)
5 |
6 | Published at the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) 2020.
7 |
8 |
9 |
10 | ### Introduction:
11 | We describe a method to train a generative model with latent factors that are (approximately) independent and localized. This means that perturbing the latent variables affects only local regions of the synthesized image, corresponding to objects. Unlike other unsupervised generative models, ours enables object-centric manipulation, without requiring object-level annotations, or any form of annotation for that matter. For more details, please check our paper.
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | ## Running the code
25 | ### Prerequisites
26 |
27 | This code was tested with the following packages. Note that other versions might work but are untested.
28 |
29 | * Ubuntu 16.04
30 | * python3
31 | * tensorflow-gpu==1.14.0
32 | * python-gflags 3.1.2
33 | * keras 2.3.1
34 | * imageio 2.6.1
35 | * numpy 1.17.2
36 | * gitpython 3.0.5
37 |
38 | ### Datasets
39 |
40 |
41 | #### Multi-dSprites and Objects Room
42 |
43 | Download two existing datasets with the following commands:
44 | ```
45 | mkdir data/multi_dsprites_data data/objects_room_data
46 | wget https://storage.googleapis.com/multi-object-datasets/multi_dsprites/multi_dsprites_colored_on_colored.tfrecords -P data/multi_dsprites_data
47 | wget https://storage.googleapis.com/multi-object-datasets/objects_room/objects_room_train.tfrecords -P data/objects_room_data
48 | ```
49 | These two datasets are TFRecords files and can be used without pre-processing.
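To sanity-check the download, you can iterate over one batch with the reader in [data/multi\_dsprites\_utils.py](data/multi_dsprites_utils.py). A minimal sketch, assuming TensorFlow 1.14's session API and that it is run from the repository root:
```
import sys
sys.path.append('data')
import tensorflow as tf
import multi_dsprites_utils

# 'test' reads the first 2000 examples without shuffling.
ds = multi_dsprites_utils.dataset(
    'data/multi_dsprites_data/multi_dsprites_colored_on_colored.tfrecords',
    batch_size=4, phase='test')
batch = ds.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    data = sess.run(batch)
    print(data['img'].shape)    # (4, 64, 64, 3), float in [0, 1]
    print(data['masks'].shape)  # (4, 64, 64, 1, 5), one channel per entity
```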
50 |
51 |
52 | #### Multi-Texture
53 |
54 | The components are already included in [data/multi\_texture\_data](data/multi_texture_data) and will be automatically used to generate images online while training and testing.
55 |
56 |
57 | #### Flying Animals
58 |
59 | Please download the zip file from
60 | [this link](https://drive.google.com/open?id=1xs9CdR8HC_RxfuEbZnD_hmMqQusAuhbO), put it in [data/flying\_animals\_data](data/flying_animals_data) and then run the following commands to decode the raw images into .npz file.
61 | ```
62 | cd data/flying_animals_data
63 | unzip data.zip
64 | python fa_data_decode.py
65 | ```
66 | These commands generate img_data.npz and img_data_test.npz in [data/flying\_animals\_data](data/flying_animals_data) for training and testing, respectively.
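You can verify the decoded archives with NumPy; fa_data_decode.py stores three arrays per archive (101 backgrounds and 241 foreground/mask pairs in img_data.npz, with the held-out indices removed in img_data_test.npz):
```
import numpy as np

data = np.load('data/flying_animals_data/img_data.npz')
for key in ('background', 'foreground', 'mask'):
    print(key, data[key].shape, data[key].dtype)
```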
67 |
68 | ### Training
69 |
70 | To stabilize and speed up adversarial training, our training consists of three steps. Default hyperparameter settings for the four datasets and three steps are included in [script/](script). Modify arguments in the scripts (e.g., the path of output checkpoints) when necessary.
71 |
72 | #### 1. Pretrain the inpainting network
73 |
74 | Pretrain the inpainting network on the task of predicting the pixels inside box-shaped occlusions.
75 | ```
76 | sh script/dataset_name/pretrain_inpainter.sh
77 | ```
78 | Pretrained checkpoints of the inpainting network for each dataset can be downloaded [here](https://drive.google.com/drive/folders/1AcFb2kfFpEuD-Wi_Iz_Z9-mkgs3anEOF?usp=sharing). You can restore a downloaded checkpoint directly to skip this step.
79 |
80 | #### 2. Spatial disentanglement
81 |
82 | Update the inpainting network and the segmentation network adversarially for spatial disentanglement.
83 |
84 | ```
85 | sh script/dataset_name/train_CIS.sh
86 | ```
87 | Note that while we train the segmentation network from scratch for the other datasets, for the Flying Animals dataset we suggest initializing ResNet-V2-50 with a checkpoint pretrained on ImageNet, which can be found [here](https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models). Download the checkpoint by running
88 | ```
89 | mkdir resnet && cd resnet
90 | wget http://download.tensorflow.org/models/resnet_v2_50_2017_04_14.tar.gz
91 | tar -xvf resnet_v2_50_2017_04_14.tar.gz
92 | ```
93 |
94 | #### 3. Train VAE
95 | Train the encoder and decoder to learn a disentangled latent space.
96 | ```
97 | sh script/dataset_name/train_VAE.sh
98 | ```
99 |
100 | ### IoU Evaluation
101 |
102 | For a trained model, you can measure its segmentation performance with [test\_segmentation.py](./test_segmentation.py). Example test scripts are provided as script/dataset_name/test_segmentation.sh. Edit one with the path to the checkpoint file and run it to compute the mean and standard deviation of mean-IoU scores over 10 subsets.
103 | ```
104 | sh script/dataset_name/test_segmentation.sh
105 | ```
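For reference, the mean-IoU score matches each ground-truth mask with its best-overlapping predicted segment before averaging. A self-contained sketch of this idea (illustrative only; the repository's metric lives in `Permute_IoU` in model/utils/generic_utils.py and may differ in details such as the matching strategy):
```
import numpy as np

def mean_iou(gt_masks, pred_masks):
    # gt_masks: [N_gt, H, W] binary; pred_masks: [N_pred, H, W] binary
    ious = []
    for gt in gt_masks:
        best = 0.0
        for pred in pred_masks:
            inter = np.logical_and(gt, pred).sum()
            union = np.logical_or(gt, pred).sum()
            if union > 0:
                best = max(best, inter / union)
        ious.append(best)
    return float(np.mean(ious))
```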
106 |
107 | ### Disentanglement
108 |
109 | After finishing all training steps, you can visualize the disentanglement of the latent space by feeding a target image into the model and varying one latent dimension at a time, checking whether the perturbation leads to only one type of semantic variation of one particular object in the synthesized image.
110 |
111 | Example scripts for disentanglement visualization are provided as script/dataset_name/disentangle.sh. Edit them with the paths to the checkpoint and output directories.
112 | ```
113 | sh script/dataset_name/disentangle.sh
114 | ```
115 | Modify the arguments when necessary to set which objects and dimensions to perturb and the range over which the latent factors vary.
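Conceptually, each traversal sweeps a single latent coordinate over z_mean ± k·sigma and re-decodes the image at every step ([eval/eval\_VAE.py](eval/eval_VAE.py) uses 60 steps between traverse_start and traverse_end). A schematic sketch, with a hypothetical `decode` handle standing in for the graph built by model/traverse_graph.py:
```
import numpy as np

def traverse(z_mean, z_sigma, decode, dim, start=-1.0, end=1.0, steps=60):
    # sweep dimension `dim` over z_mean[dim] + k * z_sigma[dim], k in [start, end]
    frames = []
    for k in np.linspace(start, end, steps):
        z = z_mean.copy()
        z[dim] = z_mean[dim] + k * z_sigma[dim]
        frames.append(decode(z))  # one synthesized image per step
    return frames
```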
116 |
117 | ### Perceptual cycle-consistency
118 | We demonstrate the effectiveness of the perceptual cycle-consistency constraint on Multi-Texture, where each image contains two objects of different identities, an ellipse and a square. Training scripts for the experiments are provided in [this folder](./script/multi_texture/perceptual_consistency). The first three training steps are the same as described in [Training](./README.md#Training), without enforcing perceptual cycle-consistency. Then we finetune the model with the perceptual cycle-consistency constraint by running
119 | ```
120 | sh script/multi_texture/perceptual_consistency/finetune_PC.sh
121 | ```
122 | Finetuning decreases the identity-switching rate and improves identity consistency. As shown in the figure below, the finetuned model (middle) consistently captures the ellipse in channel 0, while the un-finetuned model (right) sometimes assigns the square to channel 0.
123 |
124 |
125 |
126 |
127 |
128 | To compute the identity-switching rate of the segmentation network, run
129 | ```
130 | sh script/multi_texture/perceptual_consistency/test_segmentation.sh
131 | ```
132 | We provide checkpoints for the two models [here](https://drive.google.com/drive/folders/1WCBgnPim9l5aMjbgAg1wBETd3QoLY-Wd?usp=sharing). If you'd like to explore the effect yourself, we recommend downloading the [model](https://drive.google.com/drive/folders/1X5kDp-1swauBKaFXF32gRy9wnJEvtCKe?usp=sharing) that has been trained for the first three steps and restoring it for finetuning with perceptual consistency.
133 |
134 | ## Downloads
135 |
136 | You can download our trained models for all datasets [here](https://drive.google.com/drive/folders/1AcFb2kfFpEuD-Wi_Iz_Z9-mkgs3anEOF?usp=sharing), including pretrained inpainting networks and final checkpoints of all modules.
137 |
138 | ## Citation
139 |
140 | If you use this code in an academic context, please cite the following publication:
141 |
142 | ```
143 |
144 | @InProceedings{Yang_2020_CVPR,
145 | author = {Yang, Yanchao and Chen, Yutong and Soatto, Stefano},
146 | title = {Learning to Manipulate Individual Objects in an Image},
147 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
148 | month = {June},
149 | year = {2020}
150 | }
151 | ```
152 |
153 |
--------------------------------------------------------------------------------
/common_flags.py:
--------------------------------------------------------------------------------
1 | import gflags
2 | FLAGS = gflags.FLAGS
3 |
4 |
5 |
6 | #data
7 | gflags.DEFINE_string('dataset', 'multi_dsprites/ multi_texture/ objects_room/ flying_animals', 'Dataset used: multi_dsprites, multi_texture, objects_room, or flying_animals')
8 | gflags.DEFINE_string('root_dir',"/your/path/to/dataset", 'Folder containing the dataset')
9 |
10 |
11 | gflags.DEFINE_integer("takenum", -1, 'take number default: the entire dataset. not used for flying_animals')
12 | gflags.DEFINE_integer('skipnum',2000,'skip number default: 2k used for testset not used for flying_animals')
13 | gflags.DEFINE_bool('shuffle',False,'')
14 |
15 |
16 | #dir
17 | gflags.DEFINE_string('checkpoint_dir', "checkpoint", "Experiment folder. It will contain "
18 |     "the saved checkpoints and tensorboard logs or disentanglement results")
19 | #gflags.DEFINE_string('output_dir',"./outputs/0","Containing outputs result when doing evaluation")
20 | gflags.DEFINE_integer('summaries_secs', 40, 'number of seconds between summary computations, used in train_inpainter')
21 | gflags.DEFINE_integer('summaries_steps', 100, 'number of steps between summary computations, used in train_CIS')
22 | gflags.DEFINE_integer('ckpt_secs', 3600, 'number of seconds between checkpoint saves')
23 | gflags.DEFINE_integer('ckpt_steps', 10000, 'number of steps between checkpoint saves')
24 |
25 |
26 | #resume
27 | gflags.DEFINE_bool('resume_fullmodel', False, 'whether to resume a fullmodel')
28 | gflags.DEFINE_bool('resume_inpainter', True, 'resume pretrained inpainter for train_CIS; inpainter_ckpt needed')
29 | gflags.DEFINE_bool('resume_resnet', False, 'whether to use pretrained resnet (effective when resume_fullmodel=False)')
30 | gflags.DEFINE_bool('resume_CIS', False, 'whether to resume inpainter and generator')
31 | #checkpoint to load
32 | # used for resumed training or evaluation
33 | gflags.DEFINE_string('fullmodel_ckpt', '?', 'checkpoint of the full model: inpainter+generator (train_CIS) or inpainter+generator+VAE (train_end2end)')
34 | gflags.DEFINE_string('CIS_ckpt', '?', 'checkpoint of inpainter + generator')
35 | gflags.DEFINE_string('mask_ckpt', '?', '')
36 | gflags.DEFINE_string('tex_ckpt', '?', '')
37 | gflags.DEFINE_string('generator_ckpt', '?', 'checkpoint of mask generator')
38 | gflags.DEFINE_string('inpainter_ckpt', '?', 'checkpoint of pretrained inpainter')
39 | gflags.DEFINE_string('resnet_ckpt', 'resnet/resnet_v2_50.ckpt', 'checkpoint of pretrained resnet')
40 | #TODO: VAE (texture and shape)
41 |
42 |
43 |
44 | gflags.DEFINE_integer('max_training_hrs', 72,'maximum training hours')
45 | #copy the sh
46 |
47 | #mode
48 | gflags.DEFINE_string('mode', 'train_CIS', 'pretrain_inpainter / train_CIS / train_VAE / eval_segment / eval_VAE /train_supGenerator')
49 | gflags.DEFINE_string('sh_path','./train.sh', 'absolute path of the running shell')
50 |
51 |
52 |
53 | #
54 | gflags.DEFINE_integer('batch_size', 32, 'batch_size')
55 | gflags.DEFINE_integer('num_branch', 6, 'output channel of segmentation')
56 | gflags.DEFINE_integer('nobj', -1, 'number of objects, only used in evaluation or fixed_number training')
57 |
58 | #network
59 | gflags.DEFINE_string('model', 'resnet_v2_50', 'resnet_v2_50 or segnet')
60 | #VAE
61 | gflags.DEFINE_integer('tex_dim', 4, 'dimension of texture latent space')
62 | gflags.DEFINE_integer('mask_dim', 10, 'dimension of mask latent space')
63 | gflags.DEFINE_integer('bg_dim', 10, 'dimension of bg latent space')
64 | gflags.DEFINE_float('VAE_weight', 0,'weight of tex_error and mask_error loss for Generator when training End2End')
65 | gflags.DEFINE_float('CIS_weight', 1,'weight of CIS loss for Generator when training End2End')
66 | gflags.DEFINE_float('tex_beta', 10,'ratio of tex_error loss and tex_kl loss')
67 | gflags.DEFINE_float('mask_gamma', 50000,'')
68 | gflags.DEFINE_float('mask_capacity_inc',1e-5, 'increment of mask capacity at each step')
69 | gflags.DEFINE_float('bg_beta', 10,'ratio of bg_error loss and bg_kl loss')
70 |
71 | #hyperparameters
72 | gflags.DEFINE_float('gen_lr',1e-3,'learning rate')
73 | gflags.DEFINE_float('inp_lr',1e-4,'learning rate')
74 | gflags.DEFINE_float('VAE_lr',1e-4,'learning rate')
75 | gflags.DEFINE_float('epsilon', 40, 'epsilon in the denominator of IRR')
76 | gflags.DEFINE_float('gen_clip_value', -1, "generator's grad_clip_value; -1 means no clip")
77 | gflags.DEFINE_integer('iters_inp', 1, 'iteration # of inpainter')
78 | gflags.DEFINE_integer('iters_gen', 3, 'iteration # of generator')
79 | gflags.DEFINE_integer('iters_gen_vae', 3, 'iteration # of generator and vae used when training end2end')
80 | gflags.DEFINE_float('ita',1e-3,'weight of perceptual consistency loss used in train_PC mode')
81 |
82 |
83 | #flying animals dataset only
84 | gflags.DEFINE_integer('max_num',5,'max number of objects in the image')
85 | #gflags.DEFINE_integer('min_num',1,'min number of objects in the image')
86 | gflags.DEFINE_integer('bg_num', 100, 'number of bg')
87 | gflags.DEFINE_integer('ani_num',240,'')
88 |
89 |
90 |
91 | #automatically set flags
92 | gflags.DEFINE_integer('img_height',64,'')
93 | gflags.DEFINE_integer('img_width',64,'')
94 | gflags.DEFINE_integer('n_bg',1,'')
95 |
96 |
97 | #traverse
98 | gflags.DEFINE_string('input_img','./?','input image path')
99 | gflags.DEFINE_string('traverse_type', 'tex', 'tex or branch')
100 | gflags.DEFINE_integer('top_kdim', 5, 'k dimensions with largest KL divergence to traverse')
101 | gflags.DEFINE_string('traverse_branch', 'all', 'all or #1,#2,#3')
102 | gflags.DEFINE_float('traverse_range', '5', 'k: traverse over z_mean +- k*sigma')
103 | gflags.DEFINE_float('traverse_start', '-1', 'start value of k in z_mean + k*sigma')
104 | gflags.DEFINE_float('traverse_end', '1', 'end value of k in z_mean + k*sigma')
105 |
106 | gflags.DEFINE_string('VAE_loss','CE','CE or L1')
107 | gflags.DEFINE_bool('PC', False, 'Experiment for perceptual consistency')
--------------------------------------------------------------------------------
/data/flying_animals_data/fa_data_decode.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import imageio
3 | import os
4 | import argparse
5 |
6 |
7 | rm_bg = [34,80,71,72,38,28,85,88,93,86,75,70,42,30,24,64,65,90,82,25,
8 | 29,21,40,4,51,79,33,68,83,91,59,6,87,45,94,99,
9 | 23,78,36,19,77,39,62,52,81,56,98,43]
10 | rm_animals = [2,3,7,8,11,20,4,6,9,19]
11 |
12 |
13 | parser = argparse.ArgumentParser(description='decode raw images into .npz')
14 | parser.add_argument('--data_dir', type=str, default='./data', help='directory of the data data_dir/foregrounds,backgrounds,masks')
15 |
16 | args = parser.parse_args()
17 |
18 | data_dir = args.data_dir
19 |
20 | data = {'background':[], 'foreground':[], 'mask':[]}
21 | test_data = {'background':[], 'foreground':[], 'mask':[]}
22 |
23 | for i in range(241):
24 | fg = imageio.imread(os.path.join(data_dir, 'foregrounds', '{}.png'.format(i)))
25 | data['foreground'].append(fg)
26 |
27 | mask = imageio.imread(os.path.join(data_dir, 'masks', '{}.png'.format(i)))
28 | mask = mask.astype(np.bool)
29 | data['mask'].append(mask)
30 |
31 |     if i not in rm_animals:
32 | test_data['foreground'].append(fg)
33 | test_data['mask'].append(mask)
34 |
35 |
36 | for i in range(101):
37 | bg = imageio.imread(os.path.join(data_dir, 'backgrounds', '{}.png'.format(i)))
38 | data['background'].append(bg)
39 |
40 |     if i not in rm_bg:
41 | test_data['background'].append(bg)
42 |
43 |
44 |
45 | np.savez('img_data',background=data['background'], foreground=data['foreground'], mask=data['mask'])
46 | np.savez('img_data_test',background=test_data['background'], foreground=test_data['foreground'], mask=test_data['mask'])
--------------------------------------------------------------------------------
/data/flying_animals_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import random
4 | import imageio
5 | import os
6 | import functools
7 | import scipy.ndimage as ndimage
8 |
9 | #online generator for flying animals dataset
10 |
11 | H, W = 192,256
12 | def convert(imgs):
13 | return (imgs/255).astype(np.float32)
14 |
15 | def check_occlusion(pos, pos_list):
16 | for p in pos_list:
17 | dist = abs(pos[0]-p[0])+abs(pos[1]-p[1])
18 | if dist<=97:
19 | return True
20 | return False
21 |
22 | def generate_params(data, max_num, num):
23 | deter_params = []
24 | bg_num = data['background'].shape[0]-1
25 | ani_num = data['foreground'].shape[0]-1
26 | for i in range(num): #size of the validation set
27 | param = dict()
28 | param['bg_index'], param['number'] = random.randint(1, bg_num), get_number(random.uniform(0,1))
29 | param['fg_indices'] = [random.randint(1,ani_num) for k in range(param['number'])] +[0]*(max_num-param['number'])
30 |
31 | pos_list = []
32 | params = []
33 | for k in range(param['number']):
34 | f = random.uniform(0.9, 1.2)
35 | dx, dy = random.uniform(-W*2.2//5, W*2.2//5), random.uniform(-H*2.2//5, H*2.2//5)
36 | while check_occlusion([dx, dy], pos_list):
37 | dx, dy = random.uniform(-W*2.2//5, W*2.2//5), random.uniform(-H*2.2//5, H*2.2//5)
38 | pos_list.append([dx,dy])
39 | p= [f,0,W/2-f*W/2-f*dx,0,f,H/2-f*H/2-f*dy,0,0]
40 | params.append(p)
41 | params += [[1,0,0,0,1,0,0,0]]*(max_num-param['number'])
42 | param['params'] = params #factor translation
43 | deter_params.append(param)
44 |     return deter_params # list of dicts, one per deterministic example
45 |
46 | def get_number(n):
47 | #probability of distribution -- number of animals in an image
48 | if n<=0.8:
49 | return 5 #0.8
50 | elif n<=0.9:
51 | return 4 #0.1
52 | elif n<=0.96:
53 | return 3 #0.06
54 | elif n<=0.99:
55 | return 2 #0.03
56 | else:
57 | return 1 #0.01
58 |
59 | def params_gen(data, max_num, deterministic_params=None):
60 | backgrounds = convert(data['background'])
61 | foregrounds = convert(data['foreground'])
62 | ani_masks = data['mask'].astype(np.float32)
63 |
64 | bg_num = data['background'].shape[0]-1
65 | ani_num = data['foreground'].shape[0]-1
66 |
67 | step = 0
68 | while True:
69 | step += 1
70 | if deterministic_params: #200 images for validation
71 | param = deterministic_params[(step-1)%len(deterministic_params)]
72 | bg = backgrounds[param['bg_index']]
73 | texes = [foregrounds[i] for i in param['fg_indices']]
74 | masks = [ani_masks[i] for i in param['fg_indices']]
75 | yield bg, np.stack(texes, axis=0), np.stack(masks, axis=0), param['params']
76 | else: #online generated data (infinite)
77 | number = get_number(random.uniform(0,1))
78 | params = []
79 | pos_list = []
80 | texes = []
81 | masks = []
82 | bg = backgrounds[random.randint(1,bg_num)]
83 | for i in range(number):
84 | ind = random.randint(1,ani_num)
85 | texes.append(foregrounds[ind])
86 | masks.append(ani_masks[ind])
87 |             f = random.uniform(0.9, 1.2) #random zoom factor
88 | dx, dy = random.uniform(-W*2.2//5, W*2.2//5), random.uniform(-H*2.2//5, H*2.2//5)
89 | while check_occlusion([dx, dy], pos_list):
90 | dx, dy = random.uniform(-W*2.2//5, W*2.2//5), random.uniform(-H*2.2//5, H*2.2//5)
91 | pos_list.append([dx,dy])
92 | param = [f,0,W/2-f*W/2-f*dx,0,f,H/2-f*H/2-f*dy,0,0]
93 | params.append(param)
94 | params += [[1,0,0,0,1,0,0,0]]*(max_num-number)
95 | texes += [backgrounds[0]]*(max_num-number)
96 | masks += [ani_masks[0]]*(max_num-number)
97 |
98 | yield bg, np.stack(texes, axis=0), np.stack(masks, axis=0), params
99 |
100 |
101 | def generate_image(bg, texes, masks, params, max_num):
102 | #given the randomly selected foregrounds, backgrounds and their factors of variation, synthesize the final image.
103 |
104 | #zoom and shift transform
105 | texes = tf.contrib.image.transform(texes, transforms=params,interpolation='BILINEAR')
106 | masks = tf.contrib.image.transform(masks, transforms=params,interpolation='BILINEAR')
107 | texes = tf.clip_by_value(texes, 0, 1)
108 | masks = tf.clip_by_value(masks, 0, 1)
109 |
110 | cum_mask = tf.zeros_like(masks[0])
111 | depth_masks = []
112 | perturbed_texes = []
113 | for i in range(0,max_num): #depth order -> from near to far
114 | perturbed_tex = texes[i]+tf.random.uniform([], minval=-0.36,maxval=0.36, dtype=tf.float32) #brightness variation
115 | perturbed_texes.append(perturbed_tex)
116 | m = masks[i]*(1-cum_mask)
117 | cum_mask += m
118 | depth_masks.append(m)
119 |
120 |
121 | bg = bg + tf.random.uniform([], minval=-0.36,maxval=0.36, dtype=tf.float32)
122 | perturbed_texes = [bg] + perturbed_texes
123 | depth_masks = [1-cum_mask] + depth_masks
124 | perturbed_texes = tf.stack(perturbed_texes, axis=0) #C H W 3
125 | depth_masks = tf.stack(depth_masks, axis=0) #C H W 1
126 |
127 | img = tf.reduce_sum(perturbed_texes*depth_masks, axis=0) #H W 3
128 |
129 |
130 | data = {}
131 | data['img'] = img #float 0~1
132 | data['masks'] = tf.transpose(depth_masks, perm=[1,2,3,0]) #C H W 1 -> H W 1 C (bg first channel)
133 | return data
134 |
135 |
136 | def dataset(data_path, batch_size, max_num=5, phase='train'):
137 | """
138 | Args:
139 | data_path: the path of npz data/flying_animals_data/img_data.npz
140 | batch_size: batchsize
141 | max_num: max number of animals in an image
142 | phase: train: infinitely online generating image/ val: deterministic 200 / test: deterministic 2000
143 | """
144 | assert max_num==5,'please re-assign the distribution in get_number(), flying_animals_utils.py, if you need to reset max_num'
145 |
146 | if phase in ['val', 'test']:
147 | assert 200%batch_size==0
148 | data_path = os.path.abspath(data_path)
149 | data_path = data_path.split('.')[0]+'_test.npz' #data/flying_animals_data/img_data_test.npz
150 |
151 | data = np.load(data_path)
152 | deterministic_params = generate_params(data, max_num=max_num, num=200 if phase=='val' else 2000) if not phase=='train' else None
153 | partial_fn = functools.partial(params_gen, data=data, max_num=max_num,
154 | deterministic_params=deterministic_params)
155 | dataset = tf.data.Dataset.from_generator(
156 | partial_fn,
157 | (tf.float32, tf.float32, tf.float32, tf.float32),
158 | (tf.TensorShape([H,W,3]),tf.TensorShape([max_num,H,W,3]),tf.TensorShape([max_num,H,W,1]),tf.TensorShape([max_num,8])))
159 | dataset = dataset.map(lambda bg,t,m,p: generate_image(bg,t,m,p,max_num=max_num),
160 | num_parallel_calls=1)
161 |
162 | dataset = dataset.batch(batch_size)
163 | dataset = dataset.prefetch(10)
164 | return dataset
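165 | 
166 | # Example usage (commented out), following the convention in objects_room_utils.py.
167 | # Shapes assume max_num=5; 'val' requires 200 % batch_size == 0.
168 | # ds = dataset('data/flying_animals_data/img_data.npz', batch_size=4, phase='val')
169 | # batch = ds.make_one_shot_iterator().get_next()
170 | # with tf.Session() as sess:
171 | #     data = sess.run(batch)
172 | #     # data['img']: [4, 192, 256, 3], data['masks']: [4, 192, 256, 1, 6] (bg first)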
--------------------------------------------------------------------------------
/data/multi_dsprites_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Multi-dSprites dataset reader."""
16 |
17 | import functools
18 | import tensorflow as tf
19 | COMPRESSION_TYPE = tf.io.TFRecordOptions.get_compression_type_string('GZIP')
20 | IMAGE_SIZE = [64, 64]
21 | # The maximum number of foreground and background entities in each variant
22 | # of the provided datasets. The values correspond to the number of
23 | # segmentation masks returned per scene.
24 | MAX_NUM_ENTITIES = {
25 | 'binarized': 4,
26 | 'colored_on_grayscale': 6,
27 | 'colored_on_colored': 5
28 | }
29 | BYTE_FEATURES = ['mask', 'image']
30 | def feature_descriptions(max_num_entities, is_grayscale=False):
31 | """Create a dictionary describing the dataset features.
32 | Args:
33 | max_num_entities: int. The maximum number of foreground and background
34 | entities in each image. This corresponds to the number of segmentation
35 | masks and generative factors returned per scene.
36 | is_grayscale: bool. Whether images are grayscale. Otherwise they're assumed
37 | to be RGB.
38 | Returns:
39 | A dictionary which maps feature names to `tf.Example`-compatible shape and
40 | data type descriptors.
41 | """
42 |
43 | num_channels = 1 if is_grayscale else 3
44 | return {
45 | 'image': tf.io.FixedLenFeature(IMAGE_SIZE+[num_channels], tf.string), #shape dtype
46 | 'mask': tf.io.FixedLenFeature(IMAGE_SIZE+[max_num_entities, 1], tf.string),
47 | 'x': tf.io.FixedLenFeature([max_num_entities], tf.float32),
48 | 'y': tf.io.FixedLenFeature([max_num_entities], tf.float32),
49 | 'shape': tf.io.FixedLenFeature([max_num_entities], tf.float32),
50 | 'color': tf.io.FixedLenFeature([max_num_entities, num_channels], tf.float32),
51 | 'visibility': tf.io.FixedLenFeature([max_num_entities], tf.float32),
52 | 'orientation': tf.io.FixedLenFeature([max_num_entities], tf.float32),
53 | 'scale': tf.io.FixedLenFeature([max_num_entities], tf.float32),
54 | }
55 |
56 | def _decode(example_proto, features):
57 | # Parse the input `tf.Example` proto using a feature description dictionary.
58 | single_example = tf.io.parse_single_example(example_proto, features)
59 | for k in BYTE_FEATURES: #mask image
60 | single_example[k] = tf.squeeze(tf.decode_raw(single_example[k], tf.uint8),
61 | axis=-1) # height width entities channels
62 | # To return masks in the canonical [entities, height, width, channels] format,
63 | # we need to transpose the tensor axes.
64 | single_example['mask'] = tf.transpose(single_example['mask'], [0, 1, 3, 2]) #H W 1 M
65 |
66 | return map(single_example)
67 |
68 | def map(x):
69 | img = x['image']
70 | img = tf.cast(img, tf.float32)
71 | img = img/255
72 | mask = x['mask']
73 | mask = tf.cast(mask, tf.float32)
74 | mask = mask/255 #0~1
75 | data = {}
76 | data['img'] = img
77 | data['masks'] = mask
78 | return data
79 |
80 | def dataset(tfrecords_path, batch_size, phase='train'):
81 | if phase=='test':
82 | skipnum, takenum = 0,2000
83 | shuffle = False
84 | elif phase=='val':
85 | skipnum, takenum = 2000,1000
86 | shuffle = False
87 | else:
88 | skipnum, takenum = 3000, -1
89 | shuffle = True
90 | max_num_entities = MAX_NUM_ENTITIES['colored_on_colored'] #colored on colored -> 5
91 | raw_dataset = tf.data.TFRecordDataset(
92 | tfrecords_path, compression_type=COMPRESSION_TYPE)
93 | raw_dataset = raw_dataset.skip(skipnum).take(takenum)
94 | features = feature_descriptions(max_num_entities, False)
95 | partial_decode_fn = functools.partial(_decode, features=features)
96 |
97 | dataset = raw_dataset.map(partial_decode_fn,num_parallel_calls=1)
98 | if shuffle:
99 | dataset = dataset.shuffle(seed=479, buffer_size=50000, reshuffle_each_iteration=True)
100 | dataset = dataset.repeat().batch(batch_size)
101 | dataset = dataset.prefetch(10)
102 | return dataset
103 |
--------------------------------------------------------------------------------
/data/multi_texture_data/bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/data/multi_texture_data/bg.png
--------------------------------------------------------------------------------
/data/multi_texture_data/ellipse_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/data/multi_texture_data/ellipse_2.png
--------------------------------------------------------------------------------
/data/multi_texture_data/square_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/data/multi_texture_data/square_2.png
--------------------------------------------------------------------------------
/data/multi_texture_data/tex.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/data/multi_texture_data/tex.png
--------------------------------------------------------------------------------
/data/multi_texture_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import random
4 | import imageio
5 | import os
6 | import functools
7 | import scipy.ndimage as ndimage
8 | H,W=64,64
9 | pos_choice = [-24,-18,-12,-6,0,6,12,18,24]
10 | n_pos = 9
11 |
12 | def check_occlusion(pos, pos_list):
13 | for p in pos_list:
14 | dist = abs(pos[0]-p[0])+abs(pos[1]-p[1])
15 | if dist<=10:
16 | return True
17 | return False
18 |
19 | def generate_params(data_path, num, max_num, PC):
20 | deterministic_params = []
21 | if PC:
22 | num = n_pos*n_pos*n_pos*n_pos-1
23 | for i in range(num):
24 | param = dict()
25 | param['number'] = 2 if PC else get_number(random.uniform(0,1))
26 | param['ind'], param['mat'] = [],[]
27 | pos_list = []
28 |
29 | dr_=[pos_choice[i%n_pos], pos_choice[((i//n_pos)//n_pos)%n_pos]]
30 | dc_=[pos_choice[(i//n_pos)%n_pos], pos_choice[(((i//n_pos)//n_pos)//n_pos)%n_pos]]
31 |
32 | for k in range(param['number']):
33 | if PC:
34 | param['ind'].append(k)
35 | else:
36 | param['ind'].append(random.randint(0,1)) #shape 0 square 1 ellipse
37 |
38 | if PC:
39 | dr, dc = dr_[k], dc_[k]
40 | else:
41 | dr = random.uniform(-H/2,H/2)
42 | dc = random.uniform(-W/2,W/2)
43 | while check_occlusion([dr,dc], pos_list):
44 | dr = random.uniform(-H/2,H/2)
45 | dc = random.uniform(-W/2,W/2)
46 | pos_list.append([dr,dc])
47 |
48 | mat = np.zeros([3,4])
49 | mat[0][0], mat[0][3] = 1, dr
50 | mat[1][1], mat[1][3] = 1, dc
51 | mat[2][2] = 1
52 | param['mat'].append(mat)
53 | param['hue'] = [-0.2,0,0.2] if PC else [random.uniform(-0.45,0.45) for h_ in range(max_num+1)]
54 | deterministic_params.append(param)
55 | return deterministic_params
56 |
57 | def multi_texture_gen(data_path, max_num=4, deterministic_params=None, PC=False):
58 | tex0 = imageio.imread(os.path.join(data_path,'tex.png'))
59 | tex0 = (tex0/255).astype(np.float32)
60 | square = imageio.imread(os.path.join(data_path,'square_2.png')).reshape(64,64,1)
61 | square = (square/255).astype(np.float32)
62 | ellipse = imageio.imread(os.path.join(data_path,'ellipse_2.png')).reshape(64,64,1)
63 | ellipse = (ellipse/255).astype(np.float32)
64 | masks = [ellipse, square]
65 |
66 | step = 0
67 | while True:
68 | param = deterministic_params[step%(len(deterministic_params))] if deterministic_params else None
69 | step += 1
70 | if param:
71 | number = param['number']
72 | elif PC:
73 | number = 2
74 | else:
75 | number = get_number(random.uniform(0,1))
76 | shape_masks = []
77 |         shape_texes = []
78 | cum_mask = np.zeros_like(masks[0]) # occlusion
79 | pos_list = []
80 | for i in range(number): #place the randomly selected and transformed shape on the background in the ascending depth order
81 | if param:
82 | ind = param['ind'][i]
83 | elif PC:
84 | ind = i
85 | else:
86 | ind = random.randint(0, 1) #choose the shape
87 | shape = masks[ind].copy()
88 | tex = tex0.copy()
89 | if param:
90 | mat = param['mat'][i]
91 | else:
92 | if PC:
93 | dr = random.choice(pos_choice)
94 | dc = random.choice(pos_choice)
95 | else:
96 | dr = random.uniform(-H/2,H/2)
97 | dc = random.uniform(-W/2,W/2)
98 | while check_occlusion([dr,dc], pos_list):
99 | dr = random.uniform(-H/2,H/2)
100 | dc = random.uniform(-W/2,W/2)
101 | pos_list.append([dr,dc])
102 | mat = np.zeros([3,4])
103 | mat[0][0], mat[0][3] = 1, dr
104 | mat[1][1], mat[1][3] = 1, dc
105 | mat[2][2] = 1
106 | shape = ndimage.affine_transform(shape, mat, output_shape=(64,64,1))
107 | shape = np.clip(shape, 0,1)
108 | shape = shape*(1-cum_mask)
109 | cum_mask += shape
110 | shape_masks.append(shape)
111 | shape_texes.append(tex)
112 | for i in range(number, max_num):#pad the returned element
113 | shape_masks.append(np.zeros_like(shape))
114 | shape_texes.append(np.zeros_like(tex))
115 | hue_value = param['hue'] if deterministic_params else [random.uniform(-0.45,0.45) for h_ in range(max_num+1)]
116 | yield (number, np.stack(shape_masks, axis=0), np.stack(shape_texes, axis=0), hue_value)
117 |
118 | def get_number(n):
119 | if n<=0.8:
120 | return 4 #0.8
121 | elif n<=0.9:
122 | return 3 #0.1
123 | elif n<=0.97:
124 | return 2 #0.07
125 | else:
126 | return 1 #0.03
127 |
128 | def combine(number, masks, texes, hue_value, bg, max_num):
129 | hue_texes = []
130 | for i in range(max_num):
131 | hue_texes.append(tf.image.adjust_hue(texes[i], hue_value[i+1])) #randomly perturb the hue value
132 | hue_texes = tf.stack(hue_texes, axis=0) #C H W 1
133 | fg = tf.reduce_sum(masks*hue_texes, axis=0)#np.sum(masks*texes, axis=0, keepdims=False) #H W 3
134 | bg_mask = 1-tf.reduce_sum(masks, axis=0) # H W 1
135 | bg = tf.image.adjust_hue(bg, hue_value[0])
136 |
137 | all_masks = tf.concat([tf.expand_dims(bg_mask,axis=0), masks], axis=0) #H W 1 + C H W 1
138 |
139 | data = {}
140 | data['img'] = bg_mask*bg+fg
141 | data['masks'] = tf.transpose(all_masks, perm=[1,2,3,0]) #C H W 1 -> H W 1 C
142 | return data
143 |
144 | def dataset(data_path, batch_size, max_num=4, phase='train', PC=False):
145 | """
146 | Args:
147 | data_path: the path of combination elements
148 | batch_size:
149 | max_num: max_num: max number of objects in an image
150 | phase: train: infinitely online generating image/ val: deterministic 200 / test: deterministic 2000
151 | PC: experiment for Perceptual consistency
152 | """
153 | if PC:
154 | assert max_num==2, 'set max_num as 2 in experiment for Perceptual consistency'
155 | else:
156 | assert max_num==4,'please re-assign the distribution in get_number(), multi_texture_utils.py, if you need to reset max_num'
157 | deterministic_params = generate_params(data_path, num=200 if phase=='val' else 2000, max_num=max_num, PC=PC) if not phase=='train' else None
158 |
159 | partial_fn = functools.partial(multi_texture_gen,
160 | data_path=data_path, max_num=max_num, deterministic_params=deterministic_params, PC=PC)
161 |
162 | dataset = tf.data.Dataset.from_generator(
163 | partial_fn,#(data_path, max_num, zoom, rotation),
164 | (tf.int32, tf.float32, tf.float32, tf.float32),
165 | (tf.TensorShape([]),tf.TensorShape([max_num,H,W,1]), tf.TensorShape([max_num,H,W,3]), tf.TensorShape([max_num+1])))
166 |
167 | bg0 = imageio.imread(os.path.join(data_path,'bg.png'))
168 | bg0 = tf.convert_to_tensor(bg0/255, dtype=tf.float32) #0~1
169 | dataset = dataset.map(lambda n,m,t,h: combine(n,m,t,h, bg=bg0,max_num=max_num), num_parallel_calls=1)
170 | dataset = dataset.batch(batch_size)
171 |
172 | dataset = dataset.prefetch(10)
173 | return dataset
174 |
175 |
176 |
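177 | # Example usage (commented out); shapes assume max_num=4:
178 | # ds = dataset('data/multi_texture_data', batch_size=4, max_num=4, phase='val')
179 | # batch = ds.make_one_shot_iterator().get_next()
180 | # with tf.Session() as sess:
181 | #     data = sess.run(batch)
182 | #     # data['img']: [4, 64, 64, 3], data['masks']: [4, 64, 64, 1, 5] (bg first)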
--------------------------------------------------------------------------------
/data/objects_room_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Objects Room dataset reader."""
16 |
17 | import functools
18 | import tensorflow as tf
19 |
20 |
21 | COMPRESSION_TYPE = tf.io.TFRecordOptions.get_compression_type_string('GZIP')
22 | IMAGE_SIZE = [64, 64]
23 | # The maximum number of foreground and background entities in each variant
24 | # of the provided datasets. The values correspond to the number of
25 | # segmentation masks returned per scene.
26 | MAX_NUM_ENTITIES = {
27 | 'train': 7,
28 | 'six_objects': 10,
29 | 'empty_room': 4,
30 | 'identical_color': 10
31 | }
32 | BYTE_FEATURES = ['mask', 'image']
33 |
34 |
35 | def feature_descriptions(max_num_entities):
36 | """Create a dictionary describing the dataset features.
37 | Args:
38 | max_num_entities: int. The maximum number of foreground and background
39 | entities in each image. This corresponds to the number of segmentation
40 | masks returned per scene.
41 | Returns:
42 | A dictionary which maps feature names to `tf.Example`-compatible shape and
43 | data type descriptors.
44 | """
45 | return {
46 | 'image': tf.FixedLenFeature(IMAGE_SIZE+[3], tf.string),
47 | 'mask': tf.FixedLenFeature([max_num_entities]+IMAGE_SIZE+[1], tf.string),
48 | }
49 |
50 |
51 | def _decode(example_proto, features, random_sky):
52 | # Parse the input `tf.Example` proto using a feature description dictionary.
53 | single_example = tf.parse_single_example(example_proto, features)
54 | for k in BYTE_FEATURES:
55 | single_example[k] = tf.squeeze(tf.decode_raw(single_example[k], tf.uint8),
56 | axis=-1)
57 | # sky floor half-wall1(2) half-wall2(3) objects objects objects
58 | mask = tf.transpose(single_example['mask'], [1, 2, 3, 0]) #H W 1 7
59 | mask = tf.concat([mask[:,:,:,0:2],mask[:,:,:,2:3]+mask[:,:,:,3:4],mask[:,:,:,4:]], axis=-1) #H W 1 6 merge the wall
60 | single_example['mask'] = mask
61 | return map(single_example, random_sky)
62 |
63 |
64 | def map(x, random_sky):
65 | img = tf.cast(x['image']/255, tf.float32) #0~1
66 | mask = tf.cast(x['mask']/255, tf.float32)
67 |
68 | data = {}
69 | data['masks'] = mask
70 |
71 | sky_mask = mask[:,:,:,0]#H W 1
72 | scale = tf.random_uniform(shape=[], minval=0.2, maxval=1, dtype=tf.float32) if random_sky else 1
73 | var_img = img*(1-sky_mask)+img*scale*sky_mask #0~1
74 | data['img'] = var_img
75 | return data
76 |
77 | def dataset(tfrecords_path, batch_size, phase='train'):
78 | if phase=='test':
79 | skipnum, takenum = 0,2000
80 | shuffle = False
81 | elif phase=='val':
82 | skipnum, takenum = 2000,1000
83 | shuffle = False
84 | else:
85 | skipnum, takenum = 3000, -1
86 | shuffle = True
87 |
88 | max_num_entities = MAX_NUM_ENTITIES['train']
89 | raw_dataset = tf.data.TFRecordDataset(
90 | tfrecords_path, compression_type=COMPRESSION_TYPE,
91 | buffer_size=50, num_parallel_reads=2)
92 | features = feature_descriptions(max_num_entities)
93 | partial_decode_fn = functools.partial(_decode, features=features, random_sky=(phase=='train')) #val constant sky
94 |
95 | dataset = raw_dataset.skip(skipnum).take(takenum)
96 | dataset = dataset.map(partial_decode_fn, num_parallel_calls=1)
97 | if shuffle:
98 | dataset = dataset.shuffle(seed=479, buffer_size=batch_size*100, reshuffle_each_iteration=True)
99 | dataset = dataset.repeat().batch(batch_size)
100 | dataset = dataset.prefetch(10)
101 | return dataset
102 |
103 |
104 | # Example usage (commented out), matching the current dataset() signature:
105 | # import imageio
106 | # import numpy as np
107 | # bs = 4
108 | # ds = dataset('./objects_room_data/objects_room_train.tfrecords',
109 | #              batch_size=bs, phase='val')
110 | # iterator = ds.make_one_shot_iterator()
111 | # data_batch = iterator.get_next()
112 | # sess = tf.Session()
113 | # for i in range(3):
114 | #     data = sess.run(data_batch)
115 | #     for k in range(bs):
116 | #         img = data['img'][k,:,:,:]        # H W 3
117 | #         masks = data['masks'][k,:,:,:,:]  # H W 1 6 (sky, floor, merged walls, objects)
118 | #         imageio.imwrite('debug/{}_{}img.png'.format(i,k), (img*255).astype(np.uint8))
119 | #         for m in range(6):
120 | #             imageio.imwrite('debug/{}_{}mask{}.png'.format(i,k,m), (masks[:,:,:,m]*img*255).astype(np.uint8))
--------------------------------------------------------------------------------
/doc/flying_animals.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/doc/flying_animals.gif
--------------------------------------------------------------------------------
/doc/multi_texture.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/doc/multi_texture.gif
--------------------------------------------------------------------------------
/doc/objects_room.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/doc/objects_room.gif
--------------------------------------------------------------------------------
/doc/pc.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/doc/pc.gif
--------------------------------------------------------------------------------
/eval/eval_VAE.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 | import gflags
4 | #https://github.com/google/python-gflags
5 | import sys
6 | sys.path.append("..")
7 | import pprint
8 | from keras.utils.generic_utils import Progbar
9 | import model.Summary as Summary
10 | from model.utils.generic_utils import myprint, myinput, Permute_IoU
11 | from model.traverse_graph import Traverse_Graph
12 | import imageio
13 | import numpy as np
14 | import time
15 |
16 |
17 | def convert2float(img):
18 | return (img/255).astype(np.float32)
19 | def convert2int(img):
20 | return (img*255).astype(np.uint8)
21 |
22 |
23 | def pad_img(img):
24 | H,W,C = img.shape
25 | pad_img = (np.ones([H+4, W+4,C])*255).astype(np.uint8)
26 | pad_img[2:2+H,2:2+W,:] = img #H W 3
27 | return pad_img
28 |
29 | def eval(FLAGS):
30 | graph = Traverse_Graph(FLAGS)
31 | graph.build()
32 |
33 | restore_vars = tf.global_variables('VAE') + tf.global_variables('Generator') + tf.global_variables('Fusion')
34 | saver = tf.train.Saver(restore_vars)
35 |
36 | #CIS_saver = tf.train.Saver(tf.global_variables('Generator'))
37 | with tf.Session() as sess:
38 | sess.run(tf.compat.v1.global_variables_initializer())
39 | assert os.path.isfile(FLAGS.fullmodel_ckpt+'.index')
40 | saver.restore(sess, FLAGS.fullmodel_ckpt)
41 | # CIS_saver.restore(sess, FLAGS.CIS_ckpt)
42 |
43 |
44 | myprint("resume model {}".format(FLAGS.fullmodel_ckpt))
45 | fetches = {
46 | 'image_batch': graph.image_batch,
47 | 'generated_masks': graph.generated_masks,
48 | 'traverse_results': graph.traverse_results,
49 | 'out_bg': graph.out_bg,
50 | 'in_bg': graph.in_bg
51 | }
52 | assert FLAGS.batch_size==1
53 | input_img = convert2float(imageio.imread(FLAGS.input_img))
54 | input_img = np.expand_dims(input_img, axis=0)
55 |
56 | results = sess.run(fetches, feed_dict={graph.image_batch0: input_img})
57 | img = convert2int(results['image_batch'][0])
58 |
59 | imageio.imwrite(os.path.join(FLAGS.checkpoint_dir, 'img.png'), img)
60 | for i in range(FLAGS.num_branch):
61 | imageio.imwrite(os.path.join(FLAGS.checkpoint_dir, 'segment_{}.png'.format(i)), convert2int(results['generated_masks'][0,:,:,:,i]*results['image_batch'][0]))
62 |
63 | outputs = np.array(results['traverse_results'])
64 |
65 | if FLAGS.traverse_type=='tex':
66 | nch = 3
67 | ndim = FLAGS.tex_dim
68 | elif FLAGS.traverse_type=='bg':
69 | nch = 3
70 | ndim = FLAGS.bg_dim
71 | else:
72 | nch = 1
73 | ndim = FLAGS.mask_dim
74 |
75 | if FLAGS.traverse_type=='bg':
76 | traverse_branch = [FLAGS.num_branch-1]
77 | else:
78 | traverse_branch = [i for i in range(0,FLAGS.num_branch) if FLAGS.traverse_branch=='all' or str(i) in FLAGS.traverse_branch.split(',')]
79 | traverse_value = list(np.linspace(FLAGS.traverse_start, FLAGS.traverse_end, 60))
80 |
81 |
82 | if FLAGS.dataset == 'flying_animals':
83 | outputs = np.reshape(outputs, [len(traverse_branch), FLAGS.top_kdim, len(traverse_value),FLAGS.img_height//2,FLAGS.img_width//2,-1])
84 | else:
85 | outputs = np.reshape(outputs, [len(traverse_branch), FLAGS.top_kdim, len(traverse_value),FLAGS.img_height,FLAGS.img_width,-1])
86 | #tbranch * tdim * step * H * W * 3
87 |
88 |
89 | branches = []
90 | for i in range(len(traverse_branch)):
91 | values = [[None for jj in range(FLAGS.top_kdim) ] for ii in range(len(traverse_value))]
92 | b = traverse_branch[i]
93 | out = outputs[i] #tdim * step* H * W * 3
94 | for d in range(FLAGS.top_kdim):
95 | gif_imgs = []
96 | for j in range(len(traverse_value)):
97 | img = (out[d,j,:,:,:]*255).astype(np.uint8)
98 | gif_imgs.append(img)
99 | values[j][d] = pad_img(img)
100 | name = 'branch{}_var{}.gif'.format(b, d)
101 | imageio.mimsave(os.path.join(FLAGS.checkpoint_dir, name), gif_imgs, duration=1/30)
102 |
103 | #values len(traverse_value) * kdim (img)
104 | value_slices = [np.concatenate(values[j], axis=1) for j in range(len(traverse_value))] #group different dimensions along the axis x
105 | #len(traverse_value)*(H*W*3)
106 | branches.append(value_slices)
107 | merge_slices = [np.concatenate([branches[i][j] for i in range(len(traverse_branch))], axis=0) for j in range(len(traverse_value))]
108 |
109 |
110 | #imageio.mimsave(os.path.join(FLAGS.checkpoint_dir, 'output.gif'), merge_slices, duration=1/30)
111 |
112 |
113 |
114 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from itertools import count
3 | import os
4 | import gflags
5 | from git import Repo
6 | import sys
7 | import pprint
8 | from common_flags import FLAGS
9 | import warnings
10 | from trainer import train_inpainter, train_CIS, train_VAE, train_PC
11 | from eval import eval_VAE
12 | from model.utils.generic_utils import myprint, myinput
13 | import random
14 | import numpy as np
15 |
16 | def save_log(source, trg_dir, print_flags_dict, sha):
17 | file_name = source.split('/')[-1]
18 | new_file = os.path.join(trg_dir, file_name)
19 | log_name = 'log'
20 | while os.path.isfile(new_file):
21 | new_file = new_file[:-3]+'_c.sh' #.sh
22 | log_name += '_c'
23 | os.system('cp '+source+' '+ new_file)
24 | myprint ("Save "+source +" as "+new_file)
25 | log_file = os.path.join(trg_dir, log_name+'.txt')
26 | with open(log_file,'w') as log_stream:
27 | log_stream.write('commit:' + sha + '\n')
28 | pprint.pprint(print_flags_dict, log_stream)
29 | with open(new_file, 'a') as sh_stream:
30 | sh_stream.write('\n#commit:'+sha)
31 | myprint('Corresponding log file '+log_file)
32 | myinput("Enter to continue")
33 | os.system('chmod a=rx '+log_file)
34 | os.system('chmod a=rx '+new_file)
35 | return
36 |
37 | def complete_FLAGS(FLAGS):
38 |     #complete some configuration given the specified dataset
39 |
40 | img_size_dict = {'multi_texture': (64,64),
41 | 'multi_dsprites': (64,64),
42 | 'objects_room': (64,64),
43 | 'flying_animals': (192,256)}
44 | max_num_dict = {'multi_texture':4,
45 | 'objects_room': 5,
46 | 'multi_dsprites':4,
47 | 'flying_animals':5}
48 | FLAGS.img_height, FLAGS.img_width = img_size_dict[FLAGS.dataset]
49 | FLAGS.max_num = max_num_dict[FLAGS.dataset]
50 | if FLAGS.PC and FLAGS.dataset=='multi_texture':
51 | FLAGS.max_num=2
52 |
53 | if FLAGS.mode == 'pretrain_inpainter':
54 | FLAGS.num_branch = 2
55 | else:
56 | assert FLAGS.num_branch >= FLAGS.max_num+1
57 |
58 | FLAGS.n_bg = 3 if FLAGS.dataset=='objects_room' else 1
59 | return
60 |
61 | def main(argv):
62 | try:
63 | argv = FLAGS(argv)
64 | except gflags.FlagsError as e:
65 | print ('FlagsError: ', e)
66 | sys.exit(1)
67 | else:
68 | tf.compat.v1.set_random_seed(479)
69 | random.seed(101)
70 | complete_FLAGS(FLAGS)
71 | pp = pprint.PrettyPrinter()
72 | print_flags_dict = {}
73 | for key in FLAGS.__flags.keys():
74 | print_flags_dict[key] = getattr(FLAGS, key)
75 | pp.pprint(print_flags_dict)
76 | myinput("Press enter to continue")
77 |
78 | repo = Repo()
79 | sha = repo.head.object.hexsha
80 | FLAGS.checkpoint_dir = FLAGS.checkpoint_dir[:-1] if FLAGS.checkpoint_dir[-1]=='/' else FLAGS.checkpoint_dir
81 | if os.path.exists(FLAGS.checkpoint_dir):
82 | I = myinput(FLAGS.checkpoint_dir+' already exists. \n Are you sure to'
83 | ' place the outputs in the same dir? Y or Y! or N\n'
84 | 'Y: resume training, save previous outputs in the dir and continue saving outputs in it\n'
85 | 'Y!: restart training, delete previous outputs in the dir\n'
86 | 'N to quit \n')
87 | if I in ['Y','y']:
88 | save_log(FLAGS.sh_path, FLAGS.checkpoint_dir, print_flags_dict, sha)
89 | import time
90 | tf.compat.v1.set_random_seed(time.localtime()[5]*10)
91 | random.seed(time.localtime()[4]*10) #new random seed
92 | elif I in ['N','n']:
93 | sys.exit(1)
94 | else:
95 | os.system('rm -f -r '+FLAGS.checkpoint_dir+'/*')
96 | save_log(FLAGS.sh_path, FLAGS.checkpoint_dir, print_flags_dict, sha)
97 | else:
98 | os.makedirs(FLAGS.checkpoint_dir)
99 | save_log(FLAGS.sh_path, FLAGS.checkpoint_dir, print_flags_dict, sha)
100 |
101 |
102 |         assert FLAGS.mode in ['pretrain_inpainter', 'train_CIS', 'train_VAE', 'train_PC', 'train_end2end',
103 |                               'eval_CIS', 'eval_VAE']
104 |
105 | if FLAGS.mode == 'pretrain_inpainter':
106 | train_inpainter.train(FLAGS)
107 | elif FLAGS.mode == 'train_CIS':
108 | train_CIS.train(FLAGS)
109 | elif FLAGS.mode == 'eval_VAE':
110 | eval_VAE.eval(FLAGS)
111 | # elif FLAGS.mode == 'train_end2end':
112 | # train_end2end.train(FLAGS)
113 | elif FLAGS.mode == 'train_PC':
114 | train_PC.train(FLAGS)
115 | elif FLAGS.mode == 'train_VAE':
116 | train_VAE.train(FLAGS)
117 | else:
118 | pass
119 |
120 | # elif FLAGS.mode == 'eval_CIS':
121 | # eval_CIS.eval(FLAGS)
122 |
123 | if __name__ == '__main__':
124 | main(sys.argv)
125 |
--------------------------------------------------------------------------------
/model/Summary.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | def convert2uint8(img_list):
4 | #float[0~1] -> uint8[0~255]
5 | new_img_list = [tf.cast(img*255, tf.uint8) for img in img_list]
6 | return new_img_list
7 |
8 | def collect_globalVAE_summary(graph, FLAGS):
9 | ori = graph.image_batch[0]
10 | reconstr = graph.out_imgs[0]
11 | show_list = convert2uint8([ori, reconstr])
12 | tf.compat.v1.summary.image('image output', tf.stack(show_list, axis=0), max_outputs=len(show_list), collections=["globalVAE_Sum"])
13 |
14 | tf.summary.scalar('Reconstruction_Loss', graph.loss, collections=["globalVAE_Sum"])
15 | tf.summary.scalar('latent_space', graph.kl_var, collections=['globalVAE_Sum_kl'])
16 |
17 | for grad, var in graph.train_vars_grads:
18 | tf.summary.histogram(var.op.name+'/grad', grad, collections=['globalVAE_Sum'])
19 |
20 | return tf.summary.merge(tf.compat.v1.get_collection("globalVAE_Sum")), \
21 | tf.summary.merge(tf.compat.v1.get_collection("globalVAE_Sum_kl"))
22 |
23 |
24 | def collect_CIS_summary(graph, FLAGS):
25 | #----image to show same with Inpainter_Sum----
26 | ori = graph.image_batch[0]
27 | edge = tf.concat([graph.edge_map[0], tf.zeros_like(graph.edge_map[0,:,:,0:1])], axis=-1) #H W 3
28 | mean = graph.unconditioned_mean[0]
29 | show_list = convert2uint8([ori, edge,mean]) #0~255
30 | #show_list = [tf.cast(edge*128, tf.uint8)]
31 | tf.compat.v1.summary.image('image_edge_unconditionedMean',
32 | tf.stack(show_list, axis=0), max_outputs=len(show_list),
33 | collections=["CIS_Sum"])
34 |
35 | for i in range(FLAGS.num_branch):
36 | mask = graph.generated_masks[0,:,:,:,i]
37 | #aug = mask*ori + ori*(1-mask)*0.2
38 | context = ori *(1-mask) # H W 3
39 | GT = ori*mask
40 | predict = graph.pred_intensities[0,:,:,:,i]*mask
41 | show_list = convert2uint8([GT,context,predict])
42 | tf.compat.v1.summary.image('branch{}'.format(i),
43 | tf.stack(show_list, axis=0), max_outputs=len(show_list),
44 | collections=["CIS_Sum"])
45 |
46 | #-----curves to show
47 | #plot multiple curves in one figure
48 | tf.summary.scalar('Inpainter_Loss', graph.loss['Inpainter'], collections=['CIS_Sum'])
49 | tf.summary.scalar('Generator_Loss', graph.loss['Generator_var'], collections=['CIS_Sum_Generator'])
50 | tf.summary.scalar('Inpainter_Loss/branch', graph.loss['Inpainter_branch_var'], collections=['CIS_Sum_branch'])
51 | tf.summary.scalar('Generator_Loss/branch', graph.loss['Generator_branch_var'], collections=['CIS_Sum_branch'])
52 | tf.summary.scalar('Generator_Loss/denominator', graph.loss['Generator_denominator_var'], collections=['CIS_Sum_branch'])
53 | tf.summary.scalar('IoU Validation',graph.loss['EvalIoU_var'], collections=['CIS_eval'])
54 |
55 | #------histogram to show
56 | for grad, var in graph.train_vars_grads['Generator']:
57 | tf.summary.histogram(var.op.name+'/grad', grad, collections=['CIS_Sum'])
58 | for grad, var in graph.train_vars_grads['Inpainter']:
59 | tf.summary.histogram(var.op.name+'/grad', grad, collections=['CIS_Sum'])
60 | return tf.summary.merge(tf.compat.v1.get_collection("CIS_Sum")), \
61 | tf.summary.merge(tf.compat.v1.get_collection("CIS_Sum_Generator")), \
62 | tf.summary.merge(tf.compat.v1.get_collection("CIS_Sum_branch")), \
63 | tf.summary.merge(tf.compat.v1.get_collection("CIS_eval"))
64 |
65 |
66 | def collect_VAE_summary(graph, FLAGS):
67 | ori = graph.image_batch[0]
68 | fusion = graph.fusion_outputs[0]
69 | show_list = convert2uint8([ori, fusion])
70 | tf.compat.v1.summary.image('image output', tf.stack(show_list, axis=0), max_outputs=len(show_list), collections=["VAE_Sum"])
71 |
72 |
73 | seg_masks = tf.transpose(graph.generated_masks[0,:,:,:,:]*tf.expand_dims(ori, axis=-1),[3,0,1,2]) #N H W 3
74 | tf.compat.v1.summary.image('segmentation', tf.cast(seg_masks*255,tf.uint8), max_outputs=FLAGS.num_branch, collections=["VAE_Sum"])
75 |
76 | for i in range(FLAGS.num_branch-FLAGS.n_bg):
77 | seg = graph.generated_masks[0,:,:,:,i]*ori
78 | out_mask = tf.tile(graph.VAE_outmasks[0,:,:,:,i],[1,1,3])
79 | out_tex = graph.VAE_outtexes[0,:,:,:,i]
80 | fusion = graph.VAE_fusion[0,:,:,:,i]
81 | show_list = convert2uint8([seg, out_mask, out_tex, fusion])
82 | tf.compat.v1.summary.image('branch{}'.format(i), tf.stack(show_list, axis=0), max_outputs=len(show_list), collections=["VAE_Sum"])
83 |
84 | #background
85 | seg = tf.reduce_sum(graph.generated_masks[0,:,:,:,-1*FLAGS.n_bg:], axis=-1)*ori
86 | out_bg_tex = graph.VAE_outtex_bg[0,:,:,:] #H W 3
87 | show_list = convert2uint8([seg, out_bg_tex])
88 | tf.compat.v1.summary.image('background', tf.stack(show_list, axis=0), max_outputs=len(show_list), collections=["VAE_Sum"])
89 |
90 | tf.summary.scalar('Tex_error', graph.loss['tex_error'], collections=["VAE_Sum"])
91 | tf.summary.scalar('Mask_error', graph.loss['mask_error'], collections=["VAE_Sum"])
92 | tf.summary.scalar('BG_error', graph.loss['bg_error'], collections=["VAE_Sum"])
93 |
94 | tf.summary.scalar('VAEFusion_error', graph.loss['VAE_fusion_error'], collections=["VAE_Sum"])
95 | tf.summary.scalar('Fusion_Loss', graph.loss['Fusion'], collections=["VAE_Sum"])
96 |
97 | tf.summary.scalar('Tex_latent_space', graph.loss['tex_kl_var'], collections=['VAE_Sum_tex'])
98 | tf.summary.scalar('Mask_latent_space', graph.loss['mask_kl_var'], collections=['VAE_Sum_mask'])
99 | tf.summary.scalar('BG_latent_space', graph.loss['bg_kl_var'], collections=['VAE_Sum_bg'])
100 |
101 | for grad, var in graph.train_vars_grads['VAE//separate']:
102 | tf.summary.histogram(var.op.name+'/grad', grad, collections=['VAE_Sum'])
103 | for grad, var in graph.train_vars_grads['VAE//fusion']:
104 | tf.summary.histogram(var.op.name+'/grad', grad, collections=['VAE_Sum'])
105 | for grad, var in graph.train_vars_grads['Fusion']:
106 | tf.summary.histogram(var.op.name+'/grad', grad, collections=['VAE_Sum'])
107 |
108 | return tf.summary.merge(tf.compat.v1.get_collection("VAE_Sum")), \
109 | tf.summary.merge(tf.compat.v1.get_collection("VAE_Sum_tex")), \
110 | tf.summary.merge(tf.compat.v1.get_collection("VAE_Sum_mask")), \
111 | tf.summary.merge(tf.compat.v1.get_collection("VAE_Sum_bg"))
112 |
113 | def collect_end2end_summary(graph, FLAGS):
114 | ori = graph.image_batch[0]
115 | fusion = graph.fusion_outputs[0]
116 | show_list = convert2uint8([ori, fusion])
117 | tf.compat.v1.summary.image('image output', tf.stack(show_list, axis=0), max_outputs=len(show_list), collections=["end2end_Sum"])
118 |
119 | seg_masks = tf.transpose(graph.generated_masks[0,:,:,:,:]*tf.expand_dims(ori, axis=-1),[3,0,1,2]) #N H W 3
120 | tf.compat.v1.summary.image('segmentation', tf.cast(seg_masks*255,tf.uint8), max_outputs=FLAGS.num_branch, collections=["end2end_Sum"])
121 |
122 | for i in range(FLAGS.num_branch-FLAGS.n_bg):
123 | seg = graph.generated_masks[0,:,:,:,i]*ori
124 | out_mask = tf.tile(graph.VAE_outmasks[0,:,:,:,i],[1,1,3])
125 | out_tex = graph.VAE_outtexes[0,:,:,:,i]
126 | fusion = graph.VAE_fusion[0,:,:,:,i]
127 | show_list = convert2uint8([seg, out_mask, out_tex, fusion])
128 | tf.compat.v1.summary.image('branch{}'.format(i), tf.stack(show_list, axis=0), max_outputs=len(show_list), collections=["end2end_Sum"])
129 |
130 | #background
131 | seg = tf.reduce_sum(graph.generated_masks[0,:,:,:,-1*FLAGS.n_bg:], axis=-1)*ori
132 | out_bg_tex = graph.VAE_outtex_bg[0,:,:,:] #H W 3
133 | show_list = convert2uint8([seg, out_bg_tex])
134 | tf.compat.v1.summary.image('background', tf.stack(show_list, axis=0), max_outputs=len(show_list), collections=["end2end_Sum"])
135 |
136 |
137 | #-----curve to show-------------
138 | #tf.summary.scalar('CIS', graph.loss['CIS'], collections=['end2end_Sum'])
139 | #tf.summary.scalar('Inpainter_Loss', graph.loss['Inpainter'], collections=['end2end_Sum'])
140 | tf.summary.scalar('Tex_error', graph.loss['tex_error'], collections=["end2end_Sum"])
141 | tf.summary.scalar('Mask_error', graph.loss['mask_error'], collections=["end2end_Sum"])
142 | tf.summary.scalar('BG_error', graph.loss['bg_error'], collections=["end2end_Sum"])
143 |
144 | tf.summary.scalar('VAEFusion_error', graph.loss['VAE_fusion_error'], collections=["end2end_Sum"])
145 | tf.summary.scalar('Fusion_Loss', graph.loss['Fusion'], collections=["end2end_Sum"])
146 |
147 | tf.summary.scalar('Tex_latent_space', graph.loss['tex_kl_var'], collections=['end2end_Sum_tex'])
148 | tf.summary.scalar('Mask_latent_space', graph.loss['mask_kl_var'], collections=['end2end_Sum_mask']) # NB: 'end2end_Sum_mask' is never merged in the return below
149 | tf.summary.scalar('BG_latent_space', graph.loss['bg_kl_var'], collections=['end2end_Sum_bg'])
150 |
151 | tf.summary.scalar('IoU Validation',graph.loss['EvalIoU_var'], collections=['CIS_eval'])
152 |
153 | #-----histogram to show-----------
154 | for grad, var in graph.train_vars_grads['VAE//separate/texVAE']:
155 | tf.summary.histogram(var.op.name+'/grad', grad, collections=['end2end_Sum'])
156 | for grad, var in graph.train_vars_grads['VAE//separate/bgVAE']:
157 | tf.summary.histogram(var.op.name+'/grad', grad, collections=['end2end_Sum'])
158 | for grad, var in graph.train_vars_grads['VAE//fusion']:
159 | tf.summary.histogram(var.op.name+'/grad', grad, collections=['end2end_Sum'])
160 | for grad, var in graph.train_vars_grads['Fusion']:
161 | tf.summary.histogram(var.op.name+'/grad', grad, collections=['end2end_Sum'])
162 |
163 | return tf.summary.merge(tf.compat.v1.get_collection("end2end_Sum")), \
164 | tf.summary.merge(tf.compat.v1.get_collection("end2end_Sum_tex")), \
165 | tf.summary.merge(tf.compat.v1.get_collection("end2end_Sum_bg")), \
166 | tf.summary.merge(tf.compat.v1.get_collection("CIS_eval"))
167 |
168 | def collect_PC_summary(graph, FLAGS):
169 | ori = graph.image_batch[0]
170 | show_list = [ori]
171 | for i in range(FLAGS.num_branch):
172 | seg = graph.generated_masks[0,:,:,:,i]*ori
173 | show_list.append(seg)
174 | show_list = convert2uint8(show_list)
175 | tf.compat.v1.summary.image('original image', tf.stack(show_list, axis=0), max_outputs=len(show_list), collections=["PC_Sum"])
176 |
177 | for k in range(6):
178 | new = graph.new_imgs[k][0,:,:,:] #H W 3
179 | show_list = [new]
180 | for i in range(FLAGS.num_branch):
181 | seg = graph.new_generated_masks[k][0,:,:,:,i]*new
182 | show_list.append(seg)
183 | show_list.append(tf.tile(graph.new_labels[k][0,:,:,:,i],[1,1,3]))
184 | show_list = convert2uint8(show_list)
185 | tf.compat.v1.summary.image('perturbed image '+str(k), tf.stack(show_list, axis=0), max_outputs=len(show_list), collections=["PC_Sum"])
186 | tf.summary.scalar('CIS', graph.loss['CIS'], collections=['PC_Sum'])
187 | tf.summary.scalar('Inpainter_Loss', graph.loss['Inpainter'], collections=['PC_Sum'])
188 | tf.summary.scalar('PC', graph.loss['PC'], collections=['PC_Sum'])
189 | tf.summary.scalar('IoU Validation',graph.loss['EvalIoU_var'], collections=['CIS_eval'])
190 | tf.summary.scalar('Identity Switching Rate',graph.switching_rate, collections=['CIS_eval'])
191 |
192 | return tf.summary.merge(tf.compat.v1.get_collection("PC_Sum")), \
193 | tf.summary.merge(tf.compat.v1.get_collection("CIS_eval"))
194 |
195 | def collect_inpainter_summary(graph, FLAGS):
196 | #---------image to show-------------
197 |
198 | # original image edge_map
199 | ori = graph.image_batch[0]
200 | edge = tf.concat([graph.edge_map[0], tf.zeros_like(graph.edge_map[0,:,:,0:1])], axis=-1) #H W 3
201 | show_list = convert2uint8([ori, edge, graph.unconditioned_mean[0]]) #0~255
202 |
203 | tf.compat.v1.summary.image('image_edge',
204 | tf.stack(show_list, axis=0), max_outputs=len(show_list),
205 | collections=["Inpainter_Sum"])
206 |
207 | for i in range(FLAGS.num_branch):
208 | mask = graph.generated_masks[0,:,:,:,i]
209 | context = ori *(1-mask)
210 | GT = ori*mask
211 | predict = graph.pred_intensities[0,:,:,:,i]*mask
212 | show_list = convert2uint8([GT,context,predict])
213 | tf.compat.v1.summary.image('branch{}'.format(i),
214 | tf.stack(show_list, axis=0), max_outputs=len(show_list),
215 | collections=["Inpainter_Sum"])
216 |
217 | loss = graph.loss['Inpainter']
218 | tf.summary.scalar('Inpainter_Loss', loss, collections=['Inpainter_Sum'])
219 |
220 | for grad, var in graph.train_vars_grads['Inpainter']:
221 | tf.summary.histogram(var.op.name+'/grad', grad, collections=['Inpainter_Sum'])
222 |
223 | return tf.summary.merge(tf.compat.v1.get_collection('Inpainter_Sum'))
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('..')
--------------------------------------------------------------------------------
/model/globalVAE_graph.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 | from data import multi_texture_utils, flying_animals_utils, multi_dsprites_utils, objects_room_utils
4 | from .utils.generic_utils import train_op
5 | from .nets import encoder_decoder, gaussian_kl
6 |
7 |
8 |
9 | class Train_Graph(object):
10 | def __init__(self, FLAGS):
11 | self.config = FLAGS
12 | #load data
13 | self.batch_size = FLAGS.batch_size
14 | self.img_height, self.img_width = FLAGS.img_height, FLAGS.img_width
15 | #hyperparameters
16 | def build(self):
17 | train_dataset = self.load_training_data()
18 | self.train_iterator = train_dataset.make_one_shot_iterator()
19 | train_batch = self.train_iterator.get_next()
20 |
21 | self.image_batch, self.GT_masks = train_batch['img'], train_batch['masks']
22 | self.image_batch.set_shape([None, self.img_height, self.img_width, 3])
23 |
24 | with tf.compat.v1.variable_scope("VAE") as scope:
25 | z_mean, z_log_sigma_sq, out_logit = encoder_decoder(x=self.image_batch, output_ch=3, latent_dim=self.config.tex_dim, training=True)
26 |
27 | self.latent_loss_dim = tf.reduce_mean(gaussian_kl(z_mean, z_log_sigma_sq), 0) #average on batch dim,
28 | self.latent_loss = tf.reduce_sum(self.latent_loss_dim)
29 |
30 |
31 | self.out_imgs = tf.nn.sigmoid(out_logit) #0~1
32 |
33 | if self.config.VAE_loss == 'L1':
34 | self.reconstr_loss = tf.reduce_sum(tf.reduce_mean(tf.abs(self.image_batch-self.out_imgs), axis=0)) #B H W 3
35 | else:
36 | self.reconstr_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=self.image_batch, logits=out_logit) #B H W 3
37 | self.reconstr_loss = tf.reduce_sum(self.reconstr_loss, axis=[1,2,3]) #B,
38 | self.reconstr_loss = tf.reduce_mean(self.reconstr_loss)
39 |
40 | self.loss = self.reconstr_loss+self.config.tex_beta*self.latent_loss
41 |
42 |
43 |
44 | #------------------------------
45 | with tf.name_scope('train_op'):
46 | self.global_step = tf.Variable(0, name='global_step', trainable=False)
47 | self.incr_global_step = tf.assign(self.global_step, self.global_step+1)
48 | self.train_ops, self.train_vars_grads = self.get_train_ops_grads()
49 |
50 | with tf.name_scope('summary_vars'):
51 | self.kl_var = tf.Variable(0.0, name='kl_var')
52 | #
53 | def get_train_ops_grads(self):
54 | optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.config.VAE_lr)
55 | train_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, 'VAE')
56 | update_op = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS, 'VAE')
57 | train_ops, train_vars_grads = train_op(loss=self.loss,
58 | var_list=train_vars, optimizer=optimizer, gradient_clip_value=-1)
59 | train_ops = tf.group([train_ops, update_op])
60 | return train_ops, train_vars_grads
61 |
62 | def load_training_data(self):
63 | if self.config.dataset == 'multi_texture':
64 | return multi_texture_utils.dataset(self.config.root_dir, val=False,
65 | batch_size=self.batch_size, max_num=self.config.max_num,
66 | zoom=('z' in self.config.variant),
67 | rotation=('r' in self.config.variant),
68 | texture_transform=self.config.texture_transform)
69 | elif self.config.dataset == 'flying_animals':
70 | return flying_animals_utils.dataset(self.config.root_dir,val=False,
71 | batch_size=self.batch_size, max_num=self.config.max_num)
72 | elif self.config.dataset == 'multi_dsprites':
73 | return multi_dsprites_utils.dataset(self.config.root_dir,val=False,
74 | batch_size=self.batch_size, skipnum=self.config.skipnum, takenum=self.config.takenum,
75 | shuffle=True, map_parallel_calls=tf.data.experimental.AUTOTUNE)
76 | elif self.config.dataset == 'objects_room':
77 | return objects_room_utils.dataset(self.config.root_dir,val=False,
78 | batch_size=self.batch_size, skipnum=self.config.skipnum, takenum=self.config.takenum,
79 | shuffle=True, map_parallel_calls=tf.data.experimental.AUTOTUNE)
80 | else:
81 | raise IOError("Unknown Dataset")
82 |
--------------------------------------------------------------------------------
/model/nets.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from .utils.convolution_utils import gen_conv, gen_deconv, conv, deconv
3 | from .utils.convolution_utils import _dilated_conv2d, conv2d, deconv2d, conv2d_transpose, InstanceNorm, fully_connect
4 | from tensorflow.contrib.slim.nets import resnet_v2
5 | from .utils.generic_utils import bin_edge_map, erode_dilate
6 | from .utils.loss_utils import region_error
7 | import math
8 | import numpy as np
9 |
10 | def Generator_forward(images, dataset, num_mask, model='resnet_v2_50', scope='Generator', reuse=None, training=True):
11 | if 'resnet' in model:
12 | generated_masks, logits = generator_resnet(images, dataset, num_mask,
13 | reuse=reuse, training=training, scope=scope, model=model) #params 2e7
14 | else:
15 | generated_masks, logits = generator_segnet(images, num_mask,
16 | scope=scope, reuse=reuse, training=training)
17 | return tf.expand_dims(generated_masks, axis=-2), tf.expand_dims(logits, axis=-2) # each: B H W 1 C
18 |
19 | def Inpainter_forward(num_branch, input_masks, images, scope, dataset, reuse=None, training=True):
20 | edge_map = bin_edge_map(images, dataset)
21 | pred_intensities = []
22 | for m in range(num_branch):
23 | mask = input_masks[:,:,:,:,m] #B H W 1
24 | reuse = True if reuse==True else (m>0)
25 | pred_intensity = inpaint_net(images, mask, edge_map, scope=scope, reuse=reuse, training=training)
26 | pred_intensities.append(pred_intensity)
27 | unconditioned_mean = inpaint_net(tf.zeros_like(images), tf.ones_like(mask),
28 | edge_map, scope=scope, reuse=True, training=training)
29 | pred_intensities = tf.stack(pred_intensities, axis=-1) #B H W 3 C
30 | return pred_intensities, unconditioned_mean, edge_map
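# unconditioned_mean is the inpainter run on zeroed images with an all-ones
# mask, i.e. with no image context at all; it appears to act as an
# image-independent baseline that is passed to Generator_Loss alongside the
# per-branch conditioned predictions (see utils/loss_utils.py).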
31 |
32 | def inpaint_net(image, mask, edge, scope, reuse=None, training=True): #params 3e6
33 | # intensity_masked
34 | # B H W 3
35 | image = image - 0.5 #0~1 -> -0.5~0.5
36 | intensity_masked = image*(1-mask)
37 | orisize = intensity_masked.get_shape().as_list()[1:-1] #[H,W]
38 | C = intensity_masked.get_shape().as_list()[-1] # 3
39 | f=0.5
40 | #edge B H W 2
41 | edge_in_channels = edge.get_shape().as_list()[-1] #2
42 |
43 | ones_x = tf.ones_like(intensity_masked)[:, :, :, 0:1] # B H W 1
44 | intensity_masked = tf.concat([intensity_masked, ones_x, 1-mask], axis=3) # B H W C+2
45 | intensity_in_channels = intensity_masked.get_shape().as_list()[-1]
46 |
47 | with tf.variable_scope(scope, reuse=reuse):
48 |
49 | aconv1 = conv( edge, 'aconv1', shape=[7,7, edge_in_channels, int(64*f)], stride=2, reuse=reuse, training=training ) # h/2(192), 64
50 | aconv2 = conv( aconv1, 'aconv2', shape=[5,5,int(64*f), int(128*f)], stride=2, reuse=reuse, training=training ) # h/4(96), 128
51 | aconv3 = conv( aconv2, 'aconv3', shape=[5,5,int(128*f),int(256*f)], stride=2, reuse=reuse, training=training ) # h/8(48), 256
52 | aconv31= conv( aconv3, 'aconv31', shape=[3,3,int(256*f),int(256*f)], stride=1, reuse=reuse, training=training )
53 | aconv4 = conv( aconv31, 'aconv4', shape=[3,3,int(256*f),int(512*f)], stride=2, reuse=reuse, training=training ) # h/16(24), 512
54 | aconv41= conv( aconv4, 'aconv41', shape=[3,3,int(512*f),int(512*f)], stride=1, reuse=reuse, training=training )
55 | aconv5 = conv( aconv41, 'aconv5', shape=[3,3,int(512*f),int(512*f)], stride=2, reuse=reuse, training=training ) # h/32(12), 512
56 | aconv51= conv( aconv5, 'aconv51', shape=[3,3,int(512*f),int(512*f)], stride=1, reuse=reuse, training=training )
57 | aconv6 = conv( aconv51, 'aconv6', shape=[3,3,int(512*f),int(512*f)], stride=2, reuse=reuse, training=training ) # h/64(6), 512
58 |
59 | bconv1 = conv( intensity_masked, 'bconv1', shape=[7,7, intensity_in_channels, int(64*f)], stride=2, reuse=reuse, training=training ) # h/2(192), 64
60 | bconv2 = conv( bconv1, 'bconv2', shape=[5,5,int(64*f), int(128*f)], stride=2, reuse=reuse, training=training ) # h/4(96), 128
61 | bconv3 = conv( bconv2, 'bconv3', shape=[5,5,int(128*f),int(256*f)], stride=2, reuse=reuse, training=training ) # h/8(48), 256
62 | bconv31= conv( bconv3, 'bconv31', shape=[3,3,int(256*f),int(256*f)], stride=1, reuse=reuse, training=training )
63 | bconv4 = conv( bconv31, 'bconv4', shape=[3,3,int(256*f),int(512*f)], stride=2, reuse=reuse, training=training ) # h/16(24), 512
64 | bconv41= conv( bconv4, 'bconv41', shape=[3,3,int(512*f),int(512*f)], stride=1, reuse=reuse, training=training )
65 | bconv5 = conv( bconv41, 'bconv5', shape=[3,3,int(512*f),int(512*f)], stride=2, reuse=reuse, training=training ) # h/32(12), 512
66 | bconv51= conv( bconv5, 'bconv51', shape=[3,3,int(512*f),int(512*f)], stride=1, reuse=reuse, training=training )
67 | bconv6 = conv( bconv51, 'bconv6', shape=[3,3,int(512*f),int(512*f)], stride=2, reuse=reuse, training=training ) # h/64(6), 512
68 |
69 | #conv6 = tf.add( aconv6, bconv6 )
70 | conv6 = tf.concat( (aconv6, bconv6), 3 ) #h/64(6) 512*2*f
71 | outsz = bconv51.get_shape() # h/32(12), 512*f
72 | deconv5 = deconv( conv6, size=[outsz[1],outsz[2]], name='deconv5', shape=[4,4,int(512*2*f),int(512*f)], reuse=reuse, training=training )
73 | concat5 = tf.concat( (deconv5,bconv51,aconv51), 3 ) # h/32(12), 512*3*f
74 |
75 | intensity5 = conv( concat5, 'intensity5', shape=[3,3,int(512*3*f),C], stride=1, reuse=reuse, training=training, activation=tf.identity ) # h/32(12), C
76 | outsz = bconv41.get_shape() # h/16(24), 512*f
77 | deconv4 = deconv( concat5, size=[outsz[1],outsz[2]], name='deconv4', shape=[4,4,int(512*3*f),int(512*f)], reuse=reuse, training=training )
78 | upintensity4 = deconv( intensity5, size=[outsz[1],outsz[2]], name='upintensity4', shape=[4,4,C,C], reuse=reuse, training=training, activation=tf.identity )
79 | concat4 = tf.concat( (deconv4,bconv41,aconv41,upintensity4), 3 ) # h/16(24), 512*3*f+C
80 |
81 | intensity4 = conv( concat4, 'intensity4', shape=[3,3,int(512*3*f+C),C], stride=1, reuse=reuse, training=training, activation=tf.identity ) # h/16(24), C
82 | outsz = bconv31.get_shape() # h/8(48), 256*f
83 | deconv3 = deconv( concat4, size=[outsz[1],outsz[2]], name='deconv3', shape=[4,4,int(512*3*f+C),int(256*f)], reuse=reuse, training=training )
84 | upintensity3 = deconv( intensity4, size=[outsz[1],outsz[2]], name='upintensity3', shape=[4,4,C,C], reuse=reuse, training=training, activation=tf.identity )
85 | concat3 = tf.concat( (deconv3,bconv31,aconv31,upintensity3), 3 ) # h/8(48), 256*3*f+C
86 |
87 | intensity3 = conv( concat3, 'intensity3', shape=[3,3,int(256*3*f+C),C], stride=1, reuse=reuse, training=training, activation=tf.identity ) # h/8(48), C
88 | outsz = bconv2.get_shape() # h/4(96), 128*f
89 | deconv2 = deconv( concat3, size=[outsz[1],outsz[2]], name='deconv2', shape=[4,4,int(256*3*f+C),int(128*f)], reuse=reuse, training=training )
90 | upintensity2 = deconv( intensity3, size=[outsz[1],outsz[2]], name='upintensity2', shape=[4,4,C,C], reuse=reuse, training=training, activation=tf.identity )
91 | concat2 = tf.concat( (deconv2,bconv2,aconv2,upintensity2), 3 ) # h/4(96), 128*3*f+C
92 |
93 | intensity2 = conv( concat2, 'intensity2', shape=[3,3,int(128*3*f+C),C], stride=1, reuse=reuse, training=training, activation=tf.identity ) # h/4(96), C
94 | outsz = bconv1.get_shape() # h/2(192), 64*f
95 | deconv1 = deconv( concat2, size=[outsz[1],outsz[2]], name='deconv1', shape=[4,4,int(128*3*f+C),int(64*f)], reuse=reuse, training=training )
96 | upintensity1 = deconv( intensity2, size=[outsz[1],outsz[2]], name='upintensity1', shape=[4,4,C,C], reuse=reuse, training=training, activation=tf.identity )
97 | concat1 = tf.concat( (deconv1,bconv1,aconv1,upintensity1), 3 ) # h/2(192), 64*3*f+C
98 |
99 | intensity1 = conv( concat1, 'intensity1', shape=[5,5,int(64*3*f+C),C], stride=1, reuse=reuse, training=training, activation=tf.identity ) # h/2(192), C
100 | pred_intensity = tf.image.resize_images(intensity1, size=orisize)
101 |
102 | pred_intensity = pred_intensity + 0.5
103 | return pred_intensity
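# Architecture note: inpaint_net is a two-stream, U-Net-style network. The
# a-branch encodes the edge map and the b-branch the masked intensities; their
# h/64 features are concatenated, and the decoder upsamples with skip
# connections from both encoders while feeding coarse intensity predictions
# (intensity5 ... intensity1) back in at each finer scale.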
104 |
105 |
106 | def generator_resnet(images, dataset , num_mask, scope, model='resnet_v2_50', reuse=None, training=True):
107 | #images = 0~1
108 | images = (images-0.5)*2 #-1 ~ 1
109 | assert dataset in ['multi_texture','flying_animals','multi_dsprites','objects_room']
110 | if dataset in ['multi_texture', 'multi_dsprites','objects_room']:
111 | images = tf.image.resize_images(images, size=(128,128)) #64*64 -> 128*128
112 | #pad to 32k+1
113 | x = tf.pad(images, paddings=[[0,0],[0,1],[0,1],[0,0]], mode='REFLECT')
114 | dilations = [6, 12, 18, 24] if dataset=='flying_animals' else [2,4,6,8]
115 | o = []
116 | with tf.compat.v1.variable_scope(scope, reuse=reuse):
117 | if model=='resnet_v2_50':
118 | net, end_points = resnet_v2.resnet_v2_50(x, None, is_training=training, global_pool=False, output_stride=4, reuse=reuse, scope=None)
119 | elif model=='resnet_v2_101':
120 | net, end_points = resnet_v2.resnet_v2_101(x, None, is_training=training, global_pool=False, output_stride=4, reuse=reuse, scope=None)
121 | else:
122 | raise IOError("Only resnet_v2_50 or resnet_v2_101 available")
123 | #classification
124 | with tf.compat.v1.variable_scope(scope, reuse=tf.AUTO_REUSE):
125 | for i, d in enumerate(dilations):
126 | o.append(_dilated_conv2d(net, 3, num_mask, d, name='aspp/conv%d' % (i+1), biased=True))
127 | logits = tf.add_n(o)
128 |
129 | if dataset == 'flying_animals':
130 | logits = tf.image.resize_images(logits, size=(193,257)) #B H W C align the feature
131 | logits = logits[:,:-1,:-1,:] #192 256
132 | else:
133 | logits = tf.image.resize_images(logits, size=(129,129))
134 | logits = logits[:,:-1,:-1,:] #128 128
135 | logits = tf.image.resize_images(logits, size=(64,64)) #64 64
136 | generated_masks = tf.nn.softmax(logits, axis=-1) # B H W cnum
137 | return generated_masks, logits
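# The head above follows the ASPP pattern from DeepLab: parallel dilated 3x3
# convolutions (rates 2/4/6/8, or 6/12/18/24 for flying_animals) over shared
# ResNet features, summed by tf.add_n into per-branch mask logits.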
138 |
139 | def generator_segnet(images, num_mask, scope, reuse=None, training=True, div=10.0): #cnum=64 #params 1.5e6
140 | """Mask network.
141 | Args:
142 | images: input RGB images in [0, 1]
143 | num_mask: number of masks to generate
144 | Returns:
145 | masks: soft masks in [0, 1], one channel per branch (1 = fully masked, 0 = unmasked)
146 | """
147 |
148 | mask_channels = num_mask # probability of each mask
149 | x = images
150 | cnum = 64
151 | with tf.compat.v1.variable_scope(scope, reuse=reuse):
152 | # stage1
153 | x_0 = gen_conv(x, cnum, 5, 1, name='conv1', training=training) # ---------------------------
154 | x = gen_conv(x_0, 2*cnum, 3, 2, name='conv2_downsample', training=training) # Skip connection
155 | x_1 = gen_conv(x, 2*cnum, 3, 1, name='conv3', training=training) # -------------------
156 | x = gen_conv(x_1, 4*cnum, 3, 2, name='conv4_downsample', training=training)
157 | x = gen_conv(x, 4*cnum, 3, 1, name='conv5', training=training)
158 | x_2 = gen_conv(x, 4*cnum, 3, 1, name='conv6', training=training) # -----------------
159 | x = gen_conv(x_2, 4*cnum, 3, rate=2, name='conv7_atrous', training=training)
160 | x = gen_conv(x, 4*cnum, 3, rate=4, name='conv8_atrous', training=training)
161 | x = gen_conv(x, 4*cnum, 3, rate=8, name='conv9_atrous', training=training)
162 | x = gen_conv(x, 4*cnum, 3, rate=16, name='conv10_atrous', training=training)
163 | x = gen_conv(x, 4*cnum, 3, 1, name='conv11', training=training) + x_2 #-------------
164 | x = gen_conv(x, 4*cnum, 3, 1, name='conv12', training=training)
165 | x = gen_deconv(x, 2*cnum, name='conv13_upsample', training=training)
166 | x = gen_conv(x, 2*cnum, 3, 1, name='conv14', training=training) + x_1 # --------------------
167 | x = gen_deconv(x, cnum, name='conv15_upsample', training=training) + x_0 #-------------------
168 | x = gen_conv(x, cnum//2, 3, 1, name='conv16', training=training)
169 | x = gen_conv(x, mask_channels, 3, 1, activation=tf.identity,
170 | name='conv17', training=training)
171 | # Division by constant experimentally improved training
172 | x = tf.divide(x, tf.constant(div))
173 | generated_mask = tf.nn.softmax(x, axis=-1) #soft mask normalization #B*H*W*num_mask
174 | return generated_mask, x
175 |
176 |
177 | def _sample_z(z_mean, z_log_sigma_sq):
178 | eps_shape = tf.shape(z_mean)
179 | eps = tf.random_normal( eps_shape, 0, 1, dtype=tf.float32 )
180 | # z = mu + sigma * epsilon
181 | z = tf.add(z_mean,tf.multiply(tf.sqrt(tf.exp(z_log_sigma_sq)), eps))
182 | return z
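# This is the reparameterization trick: sampling z = mu + sigma*eps with
# eps ~ N(0, 1) keeps the sample differentiable w.r.t. z_mean and
# z_log_sigma_sq, so the encoder can be trained by backpropagation through
# the stochastic layer.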
183 |
184 | def encoder(x, latent_dim, training=True):
185 | B, H, W, C = x.get_shape().as_list()
186 | conv1 = conv2d(x, filter_shape=[4,4,C,32], stride=2, padding='SAME', name='conv1', biased=True, dilation=1)
187 | conv1 = tf.nn.relu(conv1)
188 |
189 | conv2 = conv2d(conv1, filter_shape=[4,4,32,32], stride=2, padding='SAME', name='conv2', biased=True, dilation=1)
190 | conv2 = tf.nn.relu(conv2)
191 |
192 | conv3 = conv2d(conv2, filter_shape=[4,4,32,32], stride=2, padding='SAME', name='conv3', biased=True, dilation=1)
193 | conv3 = tf.nn.relu(conv3)
194 |
195 | conv4 = conv2d(conv3, filter_shape=[4,4,32,32], stride=2, padding='SAME', name='conv4', biased=True, dilation=1)
196 | conv4 = tf.nn.relu(conv4)
197 | #B 4 4 32
198 |
199 | flatten = tf.reshape(conv4, [-1,(H//16)*(W//16)*32])
200 | fc1 = fully_connect(flatten, weight_shape=[(H//16)*(W//16)*32, 256], name='fc1', biased=True)
201 | fc1 = tf.nn.relu(fc1)
202 | fc2 = fully_connect(fc1, weight_shape=[256, 256], name='fc2', biased=True)
203 | fc2 = tf.nn.relu(fc2)
204 |
205 | z_mean = fully_connect(fc2, weight_shape=[256, latent_dim], name='fc_zmean', biased=True, bias_init_value=0.0)
206 | z_log_sigma_sq = fully_connect(fc2, weight_shape=[256, latent_dim], name='fc_logsigmasq', biased=True, bias_init_value=0.0)
207 | return z_mean, z_log_sigma_sq
208 |
209 | def decoder(z, output_ch, latent_dim, x, training=True):
210 | B, H, W, C = x.get_shape().as_list()
211 | up_fc2 = fully_connect(z, weight_shape=[latent_dim, 256], name='up_fc2', biased=True)
212 | up_fc2 = tf.nn.relu(up_fc2)
213 |
214 | up_fc1 = fully_connect(up_fc2, weight_shape=[256, (H//16)*(W//16)*32], name='up_fc1', biased=True)
215 | up_fc1 = tf.nn.relu(up_fc1)
216 | up_fc1 = tf.reshape(up_fc1, [-1,H//16,W//16,32])
217 |
218 | deconv4 = deconv2d(up_fc1, filter_shape=[4,4,32,32], output_size=[H//8,W//8], name='deconv4', padding='SAME', biased=True)
219 | deconv4 = tf.nn.leaky_relu(deconv4)
220 |
221 | deconv3 = deconv2d(deconv4, filter_shape=[4,4,32,32], output_size=[H//4,W//4], name='deconv3', padding='SAME', biased=True)
222 | deconv3 = tf.nn.leaky_relu(deconv3)
223 |
224 | deconv2 = deconv2d(deconv3, filter_shape=[4,4,32,32], output_size=[H//2,W//2], name='deconv2', padding='SAME', biased=True)
225 | deconv2 = tf.nn.leaky_relu(deconv2)
226 |
227 | deconv1 = deconv2d(deconv2, filter_shape=[4,4,32,output_ch], output_size=[H,W], name='deconv1', padding='SAME', biased=True)
228 |
229 | out_logit = tf.identity(deconv1)
230 | return out_logit
231 |
232 | def tex_mask_fusion(tex, mask):
233 | inputs = tf.concat([tex, mask], axis=-1)#B H W (3+1)
234 |
235 | inputs = tf.pad(inputs, paddings=[[0,0],[1,2],[1,2],[0,0]], mode='REFLECT')
236 | conv1 = conv2d(inputs, filter_shape=[4,4,4,32], stride=1, padding='VALID', name='conv1', biased=True, dilation=1)
237 | conv1 = tf.nn.relu(conv1)
238 |
239 | conv1 = tf.pad(conv1, paddings=[[0,0],[1,2],[1,2],[0,0]], mode='REFLECT')
240 | conv2 = conv2d(conv1, filter_shape=[4,4,32,32], stride=1, padding='VALID', name='conv2', biased=True, dilation=1)
241 | conv2 = tf.nn.relu(conv2)
242 |
243 | conv2 = tf.pad(conv2, paddings=[[0,0],[1,2],[1,2],[0,0]], mode='REFLECT')
244 | conv3 = conv2d(conv2, filter_shape=[4,4,32,3], stride=1, padding='VALID', name='conv3', biased=True, dilation=1)
245 | output = tf.nn.sigmoid(conv3)
246 |
247 | return output, conv3
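# Padding note: the asymmetric REFLECT pad [[1,2],[1,2]] plus a 4x4 VALID
# convolution is a hand-rolled 'SAME' scheme for an even kernel size:
# output width = (W + 1 + 2 - 4) + 1 = W, so each conv here preserves the
# spatial resolution.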
248 |
249 | def encoder_decoder(x, output_ch, latent_dim,training=True):
250 | z_mean, z_log_sigma_sq = encoder(x, latent_dim, training=training)
251 | z = _sample_z(z_mean, z_log_sigma_sq)
252 | out_logit = decoder(z, output_ch, latent_dim, x, training=training)
253 | return z_mean, z_log_sigma_sq, out_logit
254 |
255 | def gaussian_kl(mean, log_sigma_sq):
256 | latent_loss = -0.5 * (1 + log_sigma_sq
257 | - tf.square(mean)
258 | - tf.exp(log_sigma_sq)) #B*Z_dim
259 | return latent_loss # B*z_dim
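# This is the standard closed form for KL(N(mu, sigma^2) || N(0, 1)) per
# latent dimension:
#   KL = log(1/sigma) + (sigma^2 + mu^2)/2 - 1/2
#      = -0.5 * (1 + log sigma^2 - mu^2 - sigma^2)
# e.g. mu=0, sigma=1 gives KL=0, and mu=1, sigma=1 gives KL=0.5.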
260 |
261 | def VAE_forward(image, masks, bg_dim, tex_dim, mask_dim, scope='VAE', reuse=None, training=True, augmentation=False):
262 | B, H, W, C = image.get_shape().as_list()
263 | num_branch = masks.get_shape().as_list()[-1] #B H W 1 M
264 |
265 | with tf.compat.v1.variable_scope(scope, reuse=reuse):
266 | tex_kl, out_texes = [],[]
267 | mask_kl, out_masks_logit = [], []
268 | fusion_error, out_fusion = [], []
269 | latent_zs = {'tex':[], 'mask':[], 'bg':None}
270 |
271 | for i in range(num_branch):
272 |
273 |
274 | inputs = image*masks[:,:,:,:,i] #B H W 3
275 | with tf.compat.v1.variable_scope('separate/texVAE', reuse=tf.compat.v1.AUTO_REUSE):
276 | z_mean, z_log_sigma_sq, out_logit = encoder_decoder(inputs, output_ch=3, latent_dim=tex_dim, training=training)
277 | out_texes.append(tf.nn.sigmoid(out_logit)) #B,
278 | tex_kl.append(tf.reduce_mean(gaussian_kl(z_mean, z_log_sigma_sq), 0)) #B,dim -> dim
279 | latent_zs['tex'].append(z_mean)
280 |
281 | inputs = masks[:,:,:,:,i]
282 | with tf.compat.v1.variable_scope('separate/maskVAE', reuse=tf.compat.v1.AUTO_REUSE):
283 | z_mean, z_log_sigma_sq, out_logit = encoder_decoder(inputs, output_ch=1, latent_dim=mask_dim, training=training)
284 | out_masks_logit.append(out_logit)
285 | mask_kl.append(tf.reduce_mean(gaussian_kl(z_mean, z_log_sigma_sq), 0))
286 | latent_zs['mask'].append(z_mean)
287 |
288 | #fuse tex and mask
289 | tex, mask = out_texes[-1], tf.nn.sigmoid(out_masks_logit[-1]) #B H W 3 B H W 1
290 | with tf.compat.v1.variable_scope('fusion', reuse=tf.compat.v1.AUTO_REUSE):
291 | fus_output, fus_output_logits = tex_mask_fusion(tex, mask) #B H W 3
292 | out_fusion.append(fus_output)
293 | error = tf.nn.sigmoid_cross_entropy_with_logits(labels=image*masks[:,:,:,:,i], logits=fus_output_logits) # B H W 3
294 | error = tf.reduce_mean(tf.reduce_sum(error, axis=[1,2,3]), axis=0) #B H W 3 -> B -> mean scalar
295 | fusion_error.append(error)
296 |
297 | #KL divergence loss
298 | tex_kl = tf.reduce_mean(tf.stack(tex_kl, axis=0), axis=0)# branch,dim -> dim,
299 |
300 | #reconstruction error
301 | out_texes = tf.stack(out_texes, axis=-1) # B H W 3 M
302 | tex_error = tf.reduce_mean(region_error(X=out_texes, Y=image, region=masks)) #BHW3M BHW3 BHW1M ->B,M -> scalar
303 |
304 | out_masks_logit = tf.stack(out_masks_logit, axis=-1) #B H W 1 M
305 | out_masks = tf.nn.sigmoid(out_masks_logit) #B H W 1 M
306 |
307 | if not augmentation:
308 | mask_error_pixel = tf.nn.sigmoid_cross_entropy_with_logits(labels=masks, logits=out_masks_logit) #B H W 1 M
309 | mask_error_sum = tf.reduce_sum(mask_error_pixel, axis=[1,2,3]) #B,M
310 | mask_error = tf.reduce_mean(mask_error_sum)
311 | mask_kl = tf.reduce_mean(tf.stack(mask_kl, axis=0), axis=0)
312 | else:
313 | #-----------------data augmentation--------------
314 | #----------generate more position variation to help VAE decompose position feature---------------
315 | rep = 2
316 | aug_masks = tf.tile(masks, [2*rep,1,1,1,1]) #B H W 1 M -> 2*rep*B H W 1 M
317 | aug_masks = tf.transpose(aug_masks, perm=[0,4,1,2,3]) #2*rep*B M H W 1
318 | aug_masks = tf.reshape(aug_masks, [2*B*rep*num_branch,H,W,1])
319 | dx = tf.random.uniform(shape=[2*rep*B*num_branch,1],dtype=tf.dtypes.float32,minval=-1*30,maxval=30)
320 | dy = tf.random.uniform(shape=[2*rep*B*num_branch,1],dtype=tf.dtypes.float32,minval=-1*30,maxval=30)
321 | aug_masks = tf.contrib.image.translate(aug_masks,translations=tf.concat([dx,dy], axis=-1),interpolation='NEAREST')
322 | aug_masks = tf.random.shuffle(aug_masks, seed=None)
323 | inputs = aug_masks[0:rep*B*num_branch] #rep*B*M,H,W,1
324 |
325 | with tf.compat.v1.variable_scope('separate/maskVAE', reuse=tf.compat.v1.AUTO_REUSE):
326 | z_mean, z_log_sigma_sq, out_logit = encoder_decoder(inputs, output_ch=1, latent_dim=mask_dim, training=training)
327 | mask_error_pixel= tf.nn.sigmoid_cross_entropy_with_logits(labels=inputs, logits=out_logit) #B*rep*M H W 1
328 | mask_error_sum = tf.reduce_sum(mask_error_pixel, axis=[1,2,3]) #B,M
329 | mask_error = tf.reduce_mean(mask_error_sum)
330 | mask_kl= tf.reduce_mean(gaussian_kl(z_mean, z_log_sigma_sq), 0) #reduce batch
331 |
332 |
333 | out_fusion = tf.stack(out_fusion, axis=-1) #B H W 3 M
334 | fusion_error = tf.reduce_mean(tf.stack(fusion_error, axis=0)) #average on all branch
335 |
336 | # for background, we only learn the representation of its texture
337 | bg_mask = 1-tf.reduce_sum(masks, axis=-1)
338 | inputs = image*bg_mask #B H W 1
339 | with tf.compat.v1.variable_scope('separate/bgVAE', reuse=tf.compat.v1.AUTO_REUSE):
340 | z_mean, z_log_sigma_sq, out_logit = encoder_decoder(inputs, output_ch=3, latent_dim=bg_dim, training=training)
341 | out_bg = tf.nn.sigmoid(out_logit)
342 | bg_error = tf.reduce_mean(region_error(X=out_bg, Y=image, region=tf.expand_dims(bg_mask, axis=-1)))
343 | bg_kl= tf.reduce_mean(gaussian_kl(z_mean, z_log_sigma_sq), 0) #B,dim -> dim
344 | latent_zs['bg'] = z_mean
345 |
346 | loss = {'mask_kl': mask_kl, 'tex_kl': tex_kl, 'bg_kl':bg_kl,
347 | 'mask_error':mask_error, 'tex_error': tex_error, 'bg_error': bg_error , 'fusion_error': fusion_error}
348 | outputs = {'out_masks': out_masks, 'out_texes': out_texes, 'out_bg': out_bg, 'out_fusion': out_fusion}
349 | latent_zs['tex'] = tf.stack(latent_zs['tex'], axis=0) #branch-1, B, dim
350 | latent_zs['mask'] = tf.stack(latent_zs['mask'], axis=0) #branch-1 B dim
351 |
352 | return loss, outputs, latent_zs
353 |
354 |
355 | def Fusion_forward(inputs, scope='Fusion', training=True, reuse=None):
356 | #inputs B H W N*C
357 | B, H, W, C = inputs.get_shape().as_list()
358 | with tf.compat.v1.variable_scope(scope, reuse=reuse):
359 |
360 | x = tf.pad(inputs, paddings=[[0,0],[1,2],[1,2],[0,0]], mode='REFLECT')
361 | conv1 = conv2d(x, filter_shape=[4,4,C,32], stride=1, padding='VALID', name='conv1', biased=True, dilation=1)
362 | conv1 = tf.nn.relu(conv1)
363 |
364 | conv1 = tf.pad(conv1, paddings=[[0,0],[1,2],[1,2],[0,0]], mode='REFLECT')
365 | conv2 = conv2d(conv1, filter_shape=[4,4,32,32], stride=1, padding='VALID', name='conv2', biased=True, dilation=1)
366 | conv2 = tf.nn.relu(conv2)
367 |
368 | conv2 = tf.pad(conv2, paddings=[[0,0],[1,2],[1,2],[0,0]], mode='REFLECT')
369 | conv3 = conv2d(conv2, filter_shape=[4,4,32,32], stride=1, padding='VALID', name='conv3', biased=True, dilation=1)
370 | conv3 = tf.nn.relu(conv3)
371 |
372 |
373 | conv3 = tf.pad(conv3, paddings=[[0,0],[1,2],[1,2],[0,0]], mode='REFLECT')
374 | conv4 = conv2d(conv3, filter_shape=[4,4,32,3], stride=1, padding='VALID', name='conv4', biased=True, dilation=1)
375 | out = tf.nn.sigmoid(conv4)
376 |
377 | return out
378 |
379 |
380 | def Perturbation_forward(var_num, image, generated_masks,VAE_outputs0, latent_zs, mask_top_dims, scope='VAE/'):
381 | B, H, W, C, M = generated_masks.get_shape().as_list()
382 | assert M==3
383 | mask_dim = latent_zs['mask'][0].get_shape().as_list()[1]
384 | tex_dim = latent_zs['tex'][0].get_shape().as_list()[1]
385 | k = tf.random.uniform(shape=[], dtype=tf.int32, minval=0, maxval=2)
386 | b = tf.cond(tf.math.equal(k,0), lambda:1, lambda:0)
387 |
388 | new_outs, new_labels = [], []
389 | for i in range(var_num):
390 |
391 | dim_x, dim_y = mask_top_dims[0], mask_top_dims[1]
392 | new_x = tf.random.uniform(shape=[B,1], dtype=tf.float32, minval=-1, maxval=1) #B,1
393 | new_y = tf.random.uniform(shape=[B,1], dtype=tf.float32, minval=-1, maxval=1) #B,1
394 | mask_z = tf.concat([latent_zs['mask'][k,:,:dim_x],
395 | new_x,
396 | latent_zs['mask'][k,:,dim_x+1:dim_y],
397 | new_y,
398 | latent_zs['mask'][k,:,dim_y+1:]], axis=-1)
399 |
400 | with tf.compat.v1.variable_scope(scope+'/separate/maskVAE', reuse=tf.compat.v1.AUTO_REUSE):
401 | new_mask_logit = decoder(z=mask_z, output_ch=1, latent_dim=mask_dim, x=image, training=True)
402 | new_mask = tf.nn.sigmoid(new_mask_logit) #B H W 1
403 |
404 | another_mask = VAE_outputs0['out_masks'][:,:,:,:,b]
405 | new_mask = (1-another_mask)*new_mask
406 | bg_mask = 1-new_mask-another_mask
407 | new_masks = tf.cond(tf.math.equal(k,0),
408 | lambda:tf.stack([new_mask,another_mask,bg_mask], axis=-1),
409 | lambda:tf.stack([another_mask,new_mask,bg_mask], axis=-1))
410 |
411 | with tf.compat.v1.variable_scope(scope+'/fusion', reuse=tf.compat.v1.AUTO_REUSE):
412 | new_fus_output, null = tex_mask_fusion(VAE_outputs0['out_texes'][:,:,:,:,k], new_mask) #B H W 3
413 |
414 | foregrounds = tf.stack([new_fus_output, VAE_outputs0['out_fusion'][:,:,:,:,b]], axis=-1) #B H W 3 M
415 | backgrounds = tf.expand_dims(VAE_outputs0['out_bg'], axis=-1)* \
416 | tf.expand_dims(bg_mask, axis=-1)
417 |
418 | fusion_inputs = tf.concat([foregrounds, backgrounds], axis=-1) #B H W 3 fg_branch+1
419 | fusion_inputs = tf.reshape(fusion_inputs, [B,fusion_inputs.get_shape()[1],fusion_inputs.get_shape()[2],-1])
420 | fusion_outputs = Fusion_forward(inputs=fusion_inputs, scope='Fusion/', training=True, reuse=tf.compat.v1.AUTO_REUSE) #B H W 3
421 | new_outs.append(fusion_outputs)
422 | new_labels.append(new_masks)
423 |
424 | return new_outs, new_labels
425 |
426 |
--------------------------------------------------------------------------------
/model/train_graph.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 | from data import multi_texture_utils, flying_animals_utils, multi_dsprites_utils, objects_room_utils
4 | from .utils.generic_utils import train_op, erode_dilate, tf_resize_imgs, tf_normalize_imgs, reorder_mask
5 | from .utils.loss_utils import Generator_Loss, Inpainter_Loss
6 | from .nets import Generator_forward, Inpainter_forward, VAE_forward, Fusion_forward, Perturbation_forward
7 |
8 | mode2scopes = {
9 | 'pretrain_inpainter': ['Inpainter'],
10 | 'train_CIS': ['Inpainter','Generator'],
11 | 'train_PC': ['Inpainter','Generator'],
12 | 'train_VAE': ['VAE//separate', 'VAE//fusion', 'Fusion'],
13 | 'train_end2end': ['Inpainter','Generator','VAE//separate','VAE//fusion', 'Fusion']
14 | }
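# The doubled slashes in scope keys like 'VAE//separate' are intentional:
# `with tf.name_scope("VAE") as scope:` yields the string "VAE/", and nesting
# "separate/..." variable scopes under it evidently produces variable names of
# the form "VAE//separate/...", which these strings must match for the
# tf.get_collection filtering in get_train_ops_grads below.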
15 |
16 | class Train_Graph(object):
17 | def __init__(self, FLAGS):
18 | self.config = FLAGS
19 | self.batch_size = FLAGS.batch_size
20 | self.num_branch = FLAGS.num_branch
21 | self.img_height, self.img_width = FLAGS.img_height, FLAGS.img_width
22 | def build(self):
23 | self.is_training = tf.placeholder_with_default(True, shape=(), name="is_training")
24 |
25 | train_dataset, val_dataset = self.load_training_data(), self.load_val_data()
26 | self.train_iterator = train_dataset.make_one_shot_iterator()
27 | self.val_iterator = val_dataset.make_initializable_iterator()
28 | train_batch = self.train_iterator.get_next()
29 |
30 | current_batch = tf.cond(self.is_training, lambda: train_batch, lambda: self.val_iterator.get_next())
31 |
32 | self.image_batch, self.GT_masks = current_batch['img'], current_batch['masks']
33 | self.image_batch.set_shape([self.batch_size, self.img_height, self.img_width, 3])
34 | self.GT_masks.set_shape([self.batch_size, self.img_height, self.img_width, 1, None])
35 | if self.config.mode == 'pretrain_inpainter':
36 | self.generated_masks = self.Random_boxes()
37 | else:
38 | with tf.name_scope("Generator") as scope:
39 | self.generated_masks, generated_logits = Generator_forward(self.image_batch, self.config.dataset,
40 | self.num_branch, model=self.config.model, training=self.is_training, reuse=None, scope=scope)
41 |
42 |
43 | self.loss = {}
44 | with tf.name_scope("Inpainter") as scope:
45 | self.pred_intensities, self.unconditioned_mean, self.edge_map = \
46 | Inpainter_forward(self.num_branch, self.generated_masks, self.image_batch, dataset=self.config.dataset,
47 | reuse=None, training=self.is_training, scope=scope)
48 |
49 | self.loss['Inpainter'], self.loss['Inpainter_branch'] = Inpainter_Loss(self.generated_masks, self.pred_intensities, self.image_batch)
50 | self.loss['Generator'], self.loss['Generator_branch'], self.loss['Generator_denominator'], self.loss['Generator_numerator'] = \
51 | Generator_Loss(self.generated_masks, self.pred_intensities, self.image_batch,
52 | self.unconditioned_mean, self.config.epsilon)
53 |
54 | if self.config.dataset == 'flying_animals':
55 | #normalize self.image_batch
56 | self.image_batch = tf_normalize_imgs(self.image_batch)
57 | #-----------VAE----------------
58 | if self.config.mode in ['train_VAE','train_end2end', 'train_PC']:
59 | with tf.name_scope("VAE") as scope:
60 | #-------------erode dilate smooth------------
61 | filter_masks = []
62 | for i in range(self.config.num_branch):
63 | filter_masks.append(erode_dilate(self.generated_masks[:,:,:,:,i]))
64 | self.generated_masks = tf.stack(filter_masks, axis=-1)#B H W 1 num_branch
65 | self.generated_masks = reorder_mask(self.generated_masks)
66 | if self.config.dataset == 'flying_animals':
67 | #resize image to (96, 128)
68 | self.image_batch = tf_resize_imgs(self.image_batch, size=[self.img_height//2,self.img_width//2])
69 | self.generated_masks = tf_resize_imgs(self.generated_masks, size=[self.img_height//2, self.img_width//2])
70 |
71 | VAE_loss, VAE_outputs, latent_zs = VAE_forward(image=self.image_batch, masks=self.generated_masks[:,:,:,:,:-1*self.config.n_bg],
72 | bg_dim=self.config.bg_dim, tex_dim=self.config.tex_dim, mask_dim=self.config.mask_dim,
73 | scope=scope, reuse=None, training=self.is_training, augmentation=self.config.PC)
74 | self.VAE_outputs = VAE_outputs
75 | if self.config.mode=='train_PC':
76 | #----------perceptual consistency---------------
77 | mask_top_dims = tf.sort(tf.math.top_k(input=VAE_loss['mask_kl'], k=2)[1])
78 | # choose latent dimensions with highest kl divergence to perturb, as those dimensions have semantic meanings
79 | self.new_imgs, self.new_labels = Perturbation_forward(
80 | var_num=6, image=self.image_batch, generated_masks=self.generated_masks,
81 | VAE_outputs0 = VAE_outputs, latent_zs = latent_zs,
82 | mask_top_dims=mask_top_dims)
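# Worked example of the dimension selection above: with
# VAE_loss['mask_kl'] = [0.1, 2.3, 0.05, 1.7], tf.math.top_k(k=2)[1] is
# [1, 3] (indices of the two highest-KL dimensions, in descending value
# order) and tf.sort makes them ascending, so Perturbation_forward perturbs
# latent dims 1 and 3 as (dim_x, dim_y).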
83 |
84 | self.new_generated_masks = []
85 | PC_loss = []
86 | for i in range(6):
87 | with tf.name_scope("Generator/") as scope:
88 | new_masks, new_logits = Generator_forward(self.new_imgs[i], self.config.dataset,
89 | self.num_branch, model=self.config.model, training=True, reuse=tf.AUTO_REUSE, scope=scope)
90 | self.new_generated_masks.append(new_masks)
91 | loss = tf.nn.softmax_cross_entropy_with_logits(labels=tf.stop_gradient(self.new_labels[i]),
92 | logits=new_logits) #B H W 1
93 | loss = tf.reduce_sum(loss, axis=[1,2,3])/self.config.batch_size
94 | PC_loss.append(loss)
95 | self.loss['PC'] = tf.reduce_mean(tf.stack(PC_loss,axis=0)) #num_var -> average
96 | self.loss['CIS'] = self.loss['Generator']
97 | self.loss['Generator'] = self.loss['CIS'] + self.loss['PC']*self.config.ita
98 | else:
99 | self.mask_capacity = tf.placeholder(shape=(), name='mask_capacity', dtype=tf.float32)
100 | self.loss['tex_kl'], self.loss['mask_kl'], self.loss['bg_kl'] = VAE_loss['tex_kl'], VAE_loss['mask_kl'], VAE_loss['bg_kl']
101 | self.loss['tex_kl_sum'], self.loss['bg_kl_sum'] = tf.reduce_sum(VAE_loss['tex_kl']), tf.reduce_sum(VAE_loss['bg_kl'])
102 | self.loss['mask_kl_sum'] = tf.abs(tf.reduce_sum(VAE_loss['mask_kl'])-self.mask_capacity)
103 |
104 | self.loss['tex_error'], self.loss['mask_error'], self.loss['bg_error'], self.loss['VAE_fusion_error'] = \
105 | VAE_loss['tex_error'], VAE_loss['mask_error'], VAE_loss['bg_error'], VAE_loss['fusion_error']
106 |
107 |
108 | self.loss['VAE//separate/texVAE'] = self.loss['tex_error']+self.config.tex_beta*self.loss['tex_kl_sum']
109 | self.loss['VAE//separate/maskVAE'] = self.loss['mask_error']+self.config.mask_gamma*self.loss['mask_kl_sum']
110 | self.loss['VAE//separate/bgVAE'] = self.loss['bg_error']+self.config.bg_beta*self.loss['bg_kl_sum']
111 |
112 | self.loss['VAE//separate'] = self.loss['VAE//separate/texVAE']+self.loss['VAE//separate/maskVAE']+self.loss['VAE//separate/bgVAE']
113 | self.loss['VAE//fusion'] = self.loss['VAE_fusion_error']
114 |
115 |
116 |
117 | #-----------fuse all branch---------------
118 | foregrounds = VAE_outputs['out_fusion'] #B H W 3 (num_branch-1)
119 | background_mask = 1-tf.reduce_sum(VAE_outputs['out_masks'], axis=-1, keepdims=True) #B H W 1 1
120 | backgrounds = tf.expand_dims(VAE_outputs['out_bg'], axis=-1)*background_mask # B H W 3 1
121 |
122 | fusion_inputs = tf.concat([foregrounds, backgrounds], axis=-1) #B H W 3 fg_branch+1
123 | fusion_inputs = tf.reshape(fusion_inputs, [-1,fusion_inputs.get_shape()[1],fusion_inputs.get_shape()[2],3*fusion_inputs.get_shape()[4]])
124 |
125 | with tf.name_scope("Fusion") as scope:
126 | self.fusion_outputs = Fusion_forward(inputs=fusion_inputs, scope=scope, training=self.is_training, reuse=None)
127 | Fusion_error = tf.abs(self.fusion_outputs-self.image_batch) # B H W 3
128 | self.loss['Fusion'] = tf.reduce_mean(tf.reduce_sum(Fusion_error, axis=[1,2,3])) #B -> ,(average on batch)
129 |
130 | self.loss['CIS'] = self.loss['Generator']
131 | self.loss['Generator'] = (self.loss['tex_error']+self.loss['mask_error'])*self.config.VAE_weight + \
132 | self.loss['CIS']*self.config.CIS_weight
133 |
134 | self.VAE_outtexes, self.VAE_outtex_bg = VAE_outputs['out_texes'], VAE_outputs['out_bg']
135 | self.VAE_outmasks = VAE_outputs['out_masks'] #no background #B H W (num_branch-1)
136 | self.VAE_fusion = VAE_outputs['out_fusion']
137 |
138 | if self.config.dataset == 'flying_animals':
139 | #resize VAE_out
140 | import functools
141 | resize = functools.partial(tf_resize_imgs, size=[self.img_height, self.img_width])
142 | self.generated_masks, self.image_batch = resize(self.generated_masks), resize(self.image_batch)
143 | self.VAE_outtexes, self.VAE_outtex_bg, self.VAE_outmasks = resize(self.VAE_outtexes), resize(self.VAE_outtex_bg), resize(self.VAE_outmasks)
144 | self.VAE_fusion, self.fusion_outputs = resize(self.VAE_fusion), resize(self.fusion_outputs)
145 |
146 | #------------------------------
147 | with tf.name_scope('train_op'):
148 | self.global_step = tf.Variable(0, name='global_step', trainable=False)
149 | self.vae_global_step = tf.Variable(0, name='vae_global_step', trainable=False)
150 | self.incr_global_step = tf.assign(self.global_step, self.global_step+1)
151 | self.incr_vae_global_step = tf.assign(self.vae_global_step, self.vae_global_step+1)
152 | self.train_ops, self.train_vars_grads = self.get_train_ops_grads()
153 |
154 | with tf.name_scope('summary_vars'):
155 | self.loss['Generator_var'] = tf.Variable(0.0, name='Generator_var')
156 | self.loss['Generator_branch_var'] = tf.Variable(0.0, name='Generator_branch_var')
157 | self.loss['Inpainter_branch_var'] = tf.Variable(0.0, name='Inpainter_branch_var') # do tf.summary
158 | self.loss['Generator_denominator_var'] = tf.Variable(0.0, name='Generator_denominator_var')
159 | self.loss['EvalIoU_var'] = tf.Variable(0.0, name='EvalIoU_var')
160 | if self.config.mode == 'train_PC':
161 | self.switching_rate = tf.Variable(0.0, name='switching_rate')
162 | if self.config.mode in ['train_VAE','train_end2end']:
163 | self.loss['tex_kl_var'] = tf.Variable(0.0, name='tex_kl_var')
164 | self.loss['mask_kl_var'] = tf.Variable(0.0, name='mask_kl_var')
165 | self.loss['bg_kl_var'] = tf.Variable(0.0, name='bg_kl_var')
166 | #
167 | def get_train_ops_grads(self):
168 | #generate train_ops
169 | lr_dict = {'Generator':self.config.gen_lr, 'Inpainter':self.config.inp_lr,
170 | 'VAE//separate':self.config.VAE_lr, 'VAE//fusion':self.config.VAE_lr, 'Fusion':self.config.VAE_lr}
171 | scopes = mode2scopes[self.config.mode]
172 | train_ops, train_vars_grads = {}, {}
173 | for m in scopes:
174 | optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=lr_dict[m])
175 | train_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, m)
176 | update_op = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS, m)
177 | train_ops[m], train_vars_grads[m] = train_op(loss=self.loss[m],var_list=train_vars, optimizer=optimizer)
178 | train_ops[m] = tf.group([train_ops[m], update_op])
179 | return train_ops, train_vars_grads
180 |
181 | def load_training_data(self):
182 | if self.config.dataset == 'multi_texture':
183 | return multi_texture_utils.dataset(self.config.root_dir, phase='train',
184 | batch_size=self.batch_size, max_num=self.config.max_num, PC=self.config.PC)
185 | elif self.config.dataset == 'flying_animals':
186 | return flying_animals_utils.dataset(self.config.root_dir, phase='train',
187 | batch_size=self.batch_size, max_num=self.config.max_num)
188 | elif self.config.dataset == 'multi_dsprites':
189 | return multi_dsprites_utils.dataset(self.config.root_dir, phase='train', batch_size=self.batch_size)
190 | elif self.config.dataset == 'objects_room':
191 | return objects_room_utils.dataset(self.config.root_dir, phase='train', batch_size=self.batch_size)
192 | else:
193 | raise IOError("Unknown Dataset")
194 | # each dataset yields image batches of shape B H W 3
195 | def load_val_data(self):
196 | if self.config.dataset == 'multi_texture':
197 | return multi_texture_utils.dataset(self.config.root_dir, phase='val',
198 | batch_size=8, max_num=self.config.max_num, PC=self.config.PC)
199 | elif self.config.dataset == 'flying_animals':
200 | return flying_animals_utils.dataset(self.config.root_dir, phase='val',
201 | batch_size=8, max_num=self.config.max_num)
202 | elif self.config.dataset == 'multi_dsprites':
203 | return multi_dsprites_utils.dataset(self.config.root_dir, phase='val', batch_size=8)
204 | elif self.config.dataset == 'objects_room':
205 | return objects_room_utils.dataset(self.config.root_dir, phase='val', batch_size=8)
206 | else:
207 | raise IOError("Unknown Dataset")
208 |
209 | def Random_boxes(self):
210 | b,h,w = self.batch_size, self.config.img_height, self.config.img_width
211 | r1 = tf.random.uniform(shape=[], minval=0, maxval=h*2//3, dtype=tf.int32)
212 | r2 = tf.random.uniform(shape=[], minval=r1+h//5, maxval=h-1, dtype=tf.int32)
213 | c1 = tf.random.uniform(shape=[], minval=0, maxval=w*2//3, dtype=tf.int32)
214 | c2 = tf.random.uniform(shape=[], minval=c1+w//5, maxval=w-1, dtype=tf.int32)
215 | ones = tf.ones([b,h,w,1])
216 | zeros = tf.zeros([b,h,w,1])
217 | random_box = tf.concat([zeros[:,0:r1,:,:],ones[:,r1:r2,:,:],zeros[:,r2:,:,:]], axis=1)
218 | random_box = tf.concat([zeros[:,:,0:c1,:],random_box[:,:,c1:c2,:],zeros[:,:,c2:,:]], axis=2)
219 | random_boxes = tf.stack([random_box, 1-random_box], axis=-1)
220 | #B H W 1
221 | return random_boxes #B H W 1 2
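# e.g. for a 64x64 input: r1 is drawn from [0, 42), r2 from [r1+12, 63) (and
# likewise for columns), so every box spans at least h//5 x w//5 pixels; the
# stacked pair (box, 1-box) hands the inpainter one masked region plus its
# complement during pretraining.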
222 |
--------------------------------------------------------------------------------
/model/traverse_graph.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 | from .utils.generic_utils import train_op,myprint, myinput, tf_resize_imgs, reorder_mask, erode_dilate
4 | from .nets import Generator_forward, encoder, decoder, tex_mask_fusion, Fusion_forward, gaussian_kl
5 | import numpy as np
6 |
7 | class Traverse_Graph(object):
8 |
9 | def __init__(self, FLAGS):
10 | self.config = FLAGS
11 | #load data
12 | self.batch_size = FLAGS.batch_size
13 | self.num_branch = FLAGS.num_branch
14 | self.img_height, self.img_width = FLAGS.img_height, FLAGS.img_width
15 |
16 | assert self.config.traverse_type in ['tex', 'mask', 'bg']
17 | self.traverse_type = self.config.traverse_type
18 | if self.traverse_type == 'bg':
19 | self.traverse_branch = [self.num_branch-1]
20 | else:
21 | self.traverse_branch = [i for i in range(0,self.num_branch) if self.config.traverse_branch=='all' or str(i) in self.config.traverse_branch.split(',')]
22 |
23 | ndim_dict = {'tex':self.config.tex_dim, 'mask':self.config.mask_dim, 'bg':self.config.bg_dim}
24 | self.ndim = ndim_dict[self.config.traverse_type]
25 |
26 | self.traverse_value = list(np.linspace(self.config.traverse_start, self.config.traverse_end, 60))
27 | self.VAE_outputs = []
28 | self.traverse_results = []
29 |
30 | def build(self):
31 | self.image_batch0 = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None,self.img_height,self.img_width,3])
32 | with tf.name_scope("Generator") as scope:
33 | self.generated_masks, null = Generator_forward(self.image_batch0, self.config.dataset,
34 | self.num_branch, model=self.config.model, training=False, reuse=None, scope=scope)
35 |
36 | filter_masks = []
37 | for i in range(self.config.num_branch):
38 | filter_masks.append(erode_dilate(self.generated_masks[:,:,:,:,i]))
39 | self.generated_masks = tf.stack(filter_masks, axis=-1)#B H W 1 num_branch
40 | self.generated_masks = reorder_mask(self.generated_masks)
41 |
42 | if self.config.dataset == 'flying_animals':
43 | #resize image to (96, 128)
44 | self.image_batch = tf_resize_imgs(self.image_batch0, size=[self.img_height//2,self.img_width//2])
45 | self.generated_masks = tf_resize_imgs(self.generated_masks, size=[self.img_height//2, self.img_width//2])
46 | else:
47 | self.image_batch = tf.identity(self.image_batch0)
48 |
49 | segmented_img = self.generated_masks*tf.expand_dims(self.image_batch, axis=-1)#B H W 3 M
50 |
51 |
52 |
53 | #Fusion
54 | bg_mask = 1-tf.reduce_sum(self.generated_masks[:,:,:,:,:-1*self.config.n_bg], axis=-1)
55 | with tf.compat.v1.variable_scope('VAE//separate/bgVAE', reuse=tf.compat.v1.AUTO_REUSE):
56 | bg_z_mean, bg_z_log_sigma_sq = encoder(x=bg_mask*self.image_batch, latent_dim=self.config.bg_dim, training=False)
57 | out_bg_logit = decoder(bg_z_mean, output_ch=3, latent_dim=self.config.bg_dim, x=bg_mask*self.image_batch, training=False)
58 | out_bg = tf.nn.sigmoid(out_bg_logit)
59 |
60 | self.out_bg = out_bg
61 | self.in_bg = bg_mask*self.image_batch
62 | for i in self.traverse_branch:
63 | with tf.compat.v1.variable_scope('VAE//separate', reuse=tf.compat.v1.AUTO_REUSE):
64 | with tf.compat.v1.variable_scope('texVAE', reuse=tf.compat.v1.AUTO_REUSE):
65 | tex_z_mean, tex_z_log_sigma_sq = encoder(x=self.generated_masks[:,:,:,:,i]*self.image_batch, latent_dim=self.config.tex_dim, training=False)
66 | out_tex_logit = decoder(tex_z_mean, output_ch=3, latent_dim=self.config.tex_dim, x=self.generated_masks[:,:,:,:,i]*self.image_batch, training=False)
67 | out_tex = tf.nn.sigmoid(out_tex_logit)
68 |
69 | with tf.compat.v1.variable_scope('maskVAE', reuse=tf.compat.v1.AUTO_REUSE):
70 | mask_z_mean, mask_z_log_sigma_sq = encoder(x=self.generated_masks[:,:,:,:,i], latent_dim=self.config.mask_dim, training=False)
71 | out_mask_logit = decoder(mask_z_mean, output_ch=1, latent_dim=self.config.mask_dim, x=self.generated_masks[:,:,:,:,i], training=False)
72 | out_mask = tf.nn.sigmoid(out_mask_logit)
73 | if self.traverse_type=='bg':
74 | output_ch = 3
75 | scope = 'VAE//separate/bgVAE'
76 | z_mean, z_log_sigma_sq = bg_z_mean, bg_z_log_sigma_sq
77 | elif self.traverse_type=='tex':
78 | output_ch = 3
79 | scope = 'VAE//separate/texVAE'
80 | z_mean, z_log_sigma_sq = tex_z_mean, tex_z_log_sigma_sq
81 | else:
82 | output_ch = 1
83 | scope = 'VAE//separate/maskVAE'
84 | z_mean, z_log_sigma_sq = mask_z_mean, mask_z_log_sigma_sq
85 |
86 | KLs = gaussian_kl(z_mean, z_log_sigma_sq)[0] #B,zdim -> zdim,
87 | self.traverse_dim = tf.math.top_k(KLs, k=self.config.top_kdim, sorted=True).indices
88 |
89 | for d_count in range(self.config.top_kdim):
90 | d = self.traverse_dim[d_count]
91 | delta_unit = tf.one_hot(indices=[d], depth=self.ndim)
92 | delta_unit = tf.tile(delta_unit, [self.batch_size,1]) #B,dim
93 |
94 | for k in self.traverse_value:
95 | shifted_z = z_mean + delta_unit*k #B,dim
96 | with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
97 | out_logit = decoder(shifted_z, output_ch=output_ch, latent_dim=self.ndim, x=self.generated_masks[:,:,:,:,i], training=False)
98 | out = tf.nn.sigmoid(out_logit)
99 |
100 | if self.traverse_type=='bg':
101 | foregrounds = segmented_img[:,:,:,:,:-self.config.n_bg] #B H W 3 C
102 | backgrounds = tf.expand_dims(bg_mask*out, axis=-1) #B H W 3 1
103 | else:
104 | if self.traverse_type=='tex':
105 | tex,mask = out, self.generated_masks[:,:,:,:,i]
106 | else: #mask
107 | tex,mask = out_tex, out
108 | with tf.compat.v1.variable_scope('VAE//fusion', reuse=tf.compat.v1.AUTO_REUSE):
109 | VAE_fusion_out, null = tex_mask_fusion(tex=tex, mask=mask)
110 | foregrounds = tf.concat([segmented_img[:,:,:,:,0:i],
111 | tf.expand_dims(VAE_fusion_out, axis=-1),
112 | segmented_img[:,:,:,:,i+1:-self.config.n_bg]], axis=-1)
113 | background_mask = 1-tf.reduce_sum(self.generated_masks[:,:,:,:,0:i],axis=-1)-mask-tf.reduce_sum(self.generated_masks[:,:,:,:,i+1:-self.config.n_bg],axis=-1)
114 | backgrounds = tf.expand_dims(background_mask*out_bg, axis=-1)
115 |
116 | fusion_inputs = tf.concat([foregrounds, backgrounds], axis=-1)
117 | fusion_inputs = tf.reshape(fusion_inputs, [-1,fusion_inputs.get_shape()[1],fusion_inputs.get_shape()[2],3*fusion_inputs.get_shape()[4]])
118 | fusion_output = Fusion_forward(inputs=fusion_inputs, scope='Fusion/', training=False, reuse=tf.compat.v1.AUTO_REUSE)
119 | self.traverse_results.append(fusion_output)
120 | # def load_val_data(self):
121 | # #now only support multi_dsprites
122 | # if self.config.dataset == 'multi_texture':
123 | # return multi_texture_utils.dataset(self.config.root_dir, val=True,
124 | # batch_size=self.config.max_num, max_num=self.config.max_num,
125 | # zoom=('z' in self.config.variant),
126 | # rotation=('r' in self.config.variant),
127 | # texture_transform=self.config.texture_transform)
128 | # elif self.config.dataset == 'flying_animals':
129 | # return flying_animals_utils.dataset(self.config.root_dir,val=True,
130 | # batch_size=self.config.max_num, max_num=self.config.max_num)
131 | # elif self.config.dataset == 'multi_dsprites':
132 | # return multi_dsprites_utils.dataset(self.config.root_dir,val=True,
133 | # batch_size=self.batch_size, skipnum=0, takenum=self.config.skipnum,
134 | # shuffle=False, map_parallel_calls=1)
135 | # elif self.config.dataset == 'objects_room':
136 | # return objects_room_utils.dataset(self.config.root_dir,val=True,
137 | # batch_size=self.batch_size, skipnum=0, takenum=self.config.skipnum,
138 | # shuffle=False, map_parallel_calls=1)
139 | # else:
140 | # raise IOError("Unknown Dataset")
--------------------------------------------------------------------------------
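
Traverse_Graph above boils down to three steps: encode the selected branch, rank latent dimensions by their KL divergence against the unit prior, then sweep each of the top-k dimensions around the encoded mean and decode every shifted code into one animation frame. A self-contained sketch of that sweep (names are illustrative; decode stands in for the repo's decoder):

    import numpy as np

    def traverse(z_mean, kl_per_dim, decode, top_kdim=4, start=-6.0, end=6.0, steps=60):
        #cf. tf.math.top_k(KLs, k=top_kdim) and self.traverse_value in Traverse_Graph
        frames = {}
        for d in np.argsort(kl_per_dim)[::-1][:top_kdim]:
            delta = np.zeros_like(z_mean)
            seq = []
            for k in np.linspace(start, end, steps):
                delta[d] = k
                seq.append(decode(z_mean + delta))  #one frame of the traversal gif
            frames[d] = seq
        return frames
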
/model/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/model/utils/__init__.py
--------------------------------------------------------------------------------
/model/utils/convolution_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | def fully_connect(x, weight_shape, name, biased=True, bias_init_value=0.0):
3 | with tf.compat.v1.variable_scope(name):
4 | weight = tf.compat.v1.get_variable("weight", weight_shape, initializer=tf.contrib.layers.xavier_initializer())
5 | o = tf.matmul(x, weight)
6 | if biased:
7 | b = tf.get_variable('bias', shape=[weight_shape[-1]], initializer=tf.constant_initializer(bias_init_value))
8 | o = o+b
9 | return o
10 |
11 |
12 | def _dilated_conv2d(x, kernel_size, num_o, dilation_factor, name, biased=False):
13 | num_x = x.shape[-1].value #C last
14 | with tf.compat.v1.variable_scope(name):
15 | w = tf.get_variable('weights', shape=[kernel_size, kernel_size, num_x, num_o], initializer=tf.contrib.layers.xavier_initializer_conv2d())
16 | o = tf.nn.atrous_conv2d(x, w, dilation_factor, padding='SAME')
17 | if biased:
18 | b = tf.get_variable('biases', shape=[num_o], initializer=tf.compat.v1.constant_initializer(0.0))
19 | o = tf.nn.bias_add(o, b)
20 | return o
21 |
22 | def conv2d_transpose(x, filter_shape, output_shape, stride, name, padding='SAME', dilation=1, biased=True):
23 | with tf.compat.v1.variable_scope(name):
24 | w = tf.get_variable('weights', shape=filter_shape, initializer=tf.contrib.layers.xavier_initializer_conv2d())
25 | o = tf.nn.conv2d_transpose(x, filters=w, output_shape=output_shape,
26 | strides=stride, padding=padding, dilations=dilation)
27 | if biased:
28 | b = tf.get_variable('biases', shape=[filter_shape[-2]], initializer=tf.compat.v1.constant_initializer(0.0)) # output_channel input_channel
29 | o = tf.nn.bias_add(o, b)
30 | return o
31 |
32 | def conv2d(x, filter_shape, stride, name, padding='SAME', dilation=1, biased=True):
33 | with tf.compat.v1.variable_scope(name):
34 | w = tf.get_variable('weights', shape=filter_shape, initializer=tf.contrib.layers.xavier_initializer_conv2d())
35 | o = tf.nn.conv2d(x, filters=w, strides=stride, padding=padding, dilations=dilation)
36 | if biased:
37 | b = tf.get_variable('biases', shape=[filter_shape[-1]], initializer=tf.compat.v1.constant_initializer(0.0))
38 | o = tf.nn.bias_add(o, b)
39 | return o
40 |
41 | def deconv2d(x, filter_shape, output_size, name, padding='SAME', biased=True):
42 | with tf.compat.v1.variable_scope(name):
43 | x = tf.image.resize_images(x, size=output_size)
44 | w = tf.get_variable('weights', shape=filter_shape, initializer=tf.contrib.layers.xavier_initializer_conv2d())
45 | o = tf.nn.conv2d(x, filters=w, strides=1, padding=padding)
46 | if biased:
47 | b = tf.get_variable('biases', shape=[filter_shape[-1]], initializer=tf.compat.v1.constant_initializer(0.0))
48 | o = tf.nn.bias_add(o, b)
49 | return o
50 |
51 | def InstanceNorm(x):
52 | return tf.contrib.layers.instance_norm(x, center=False, scale=False, epsilon=1e-5, trainable=False)
53 |
54 | def gen_conv(x, cnum, ksize, stride=1, rate=1, name='conv',
55 | padding='SAME', activation=tf.nn.elu, training=True,
56 | kernel_initializer=None):
57 | """Define conv for generator.
58 | Args:
59 | x: Input.
60 | cnum: Channel number.
61 | ksize: Kernel size.
62 | stride: Convolution stride.
63 |         rate: Rate for dilated conv.
64 |         name: Name of layers.
65 |         padding: Default to SAME.
66 | activation: Activation function after convolution.
67 | training: If current graph is for training or inference, used for bn.
68 | Returns:
69 | tf.Tensor: output
70 | """
71 | x = tf.layers.conv2d(x,cnum, ksize, stride, dilation_rate=rate,
72 | activation=None, padding=padding, name=name,
73 | kernel_initializer=kernel_initializer)
74 | # We empirically found BN to help if not trained (works as regularizer)
75 | x = tf.layers.batch_normalization(x)
76 | x = activation(x)
77 |
78 | return x
79 |
80 | def gen_deconv(x, cnum, name='upsample', padding='SAME', training=True):
81 | """Define deconv for generator.
82 | The deconv is defined to be a x2 resize_nearest_neighbor operation with
83 | additional gen_conv operation.
84 | Args:
85 | x: Input.
86 | cnum: Channel number.
87 | name: Name of layers.
88 | training: If current graph is for training or inference, used for bn.
89 | Returns:
90 | tf.Tensor: output
91 | """
92 | with tf.compat.v1.variable_scope(name):
93 | x = resize(x, func=tf.compat.v1.image.resize_nearest_neighbor)
94 | x = gen_conv(
95 | x, cnum, 3, 1, name=name+'_conv', padding=padding,
96 | training=training)
97 | return x
98 |
99 | def resize(x, scale=2, to_shape=None, align_corners=True, dynamic=False,
100 | func=tf.compat.v1.image.resize_bilinear, name='resize'):
101 | """
102 | This resize operation is used to scale the input according to some given
103 | scale and function.
104 | """
105 |
106 | if dynamic:
107 | xs = tf.cast(tf.shape(x), tf.float32)
108 | new_xs = [tf.cast(xs[1]*scale, tf.int32),
109 | tf.cast(xs[2]*scale, tf.int32)]
110 | else:
111 | xs = x.get_shape().as_list()
112 | new_xs = [int(xs[1]*scale), int(xs[2]*scale)]
113 | with tf.compat.v1.variable_scope(name):
114 | if to_shape is None:
115 | x = func(x, new_xs, align_corners=align_corners)
116 | else:
117 | x = func(x, [to_shape[0], to_shape[1]],
118 | align_corners=align_corners)
119 | return x
120 |
121 |
122 | def conv(inputs, name, shape, stride, padding='SAME', dilations=1 ,reuse=None, training=True, activation=tf.nn.leaky_relu,
123 | init_w=tf.contrib.layers.xavier_initializer_conv2d(), init_b=tf.constant_initializer(0.0)):
124 | with tf.compat.v1.variable_scope(name, reuse=reuse) as scope:
125 | kernel = tf.contrib.framework.model_variable('weights', shape=shape, initializer=init_w, trainable=True)
126 | conv = tf.nn.conv2d(inputs, kernel, [1, stride, stride, 1], padding=padding, dilations=dilations)
127 | biases = tf.contrib.framework.model_variable('biases', shape=[shape[3]], initializer=init_b, trainable=True)
128 | conv = tf.nn.bias_add(conv, biases)
129 | if activation:
130 | conv = activation(conv)
131 | return conv
132 |
133 | def deconv(inputs, size, name, shape, reuse=None, training=True, activation=tf.nn.leaky_relu):
134 | deconv = tf.image.resize_images( inputs, size=size )
135 | deconv = conv( deconv, name=name, shape=shape, stride=1, reuse=reuse, training=training, activation=activation )
136 | return deconv
--------------------------------------------------------------------------------
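
Note the design choice in deconv: instead of a transposed convolution it resizes and then applies a stride-1 conv, a common remedy for checkerboard artifacts. A hypothetical usage sketch of the conv/deconv helpers above, assuming TF 1.x (they depend on tf.contrib, which only exists there):

    import tensorflow as tf  #TF 1.x assumed
    from model.utils.convolution_utils import conv, deconv

    x = tf.compat.v1.placeholder(tf.float32, [8, 64, 64, 3])
    h = conv(x, name='enc1', shape=[3, 3, 3, 32], stride=2)          #8x32x32x32, leaky ReLU
    h = conv(h, name='enc2', shape=[3, 3, 32, 64], stride=2)         #8x16x16x64
    y = deconv(h, size=[32, 32], name='dec1', shape=[3, 3, 64, 32])  #resize to 32x32, then 3x3 conv
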
/model/utils/generic_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import tensorflow as tf
4 | import random
5 | import os
6 | import imageio
7 | import math
8 |
9 |
10 | def erode_dilate(img):
11 |     img = -1*tf.nn.max_pool2d(-1*img, ksize=3, strides=1, padding='SAME') #min-pool = erosion
12 |     img = tf.nn.max_pool2d(img, ksize=3, strides=1, padding='SAME') #max-pool = dilation
13 |     return img #morphological opening: removes small speckles from the masks
14 |
15 | def tf_normalize_imgs(imgs):
16 | #normalize to 0~1
17 | #B H W 3
18 | max_value = tf.reduce_max(imgs, axis=[1,2], keepdims=True)
19 | min_value = tf.reduce_min(imgs, axis=[1,2], keepdims=True)
20 | imgs = tf.math.divide_no_nan(imgs-min_value, max_value-min_value)
21 | return imgs
22 |
23 | def tf_resize_imgs(imgs, size):
24 |     #nearest-neighbor resize (keeps masks binary)
25 | shape = imgs.get_shape().as_list()
26 | assert len(shape) <=5 and len(shape)>=3
27 | if len(shape) <= 4: #B H W C
28 | return tf.image.resize(imgs, size=size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
29 | else: #B H W C M
30 | resize_imgs = [tf.image.resize(imgs[:,:,:,:,i], size=size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) for i in range(shape[-1])]
31 | resize_imgs = tf.stack(resize_imgs, axis=-1)
32 | return resize_imgs
33 |
34 | def reorder_mask(masks):
35 | #put background segmentation at the last channels
36 | B, H, W, C, M = masks.get_shape().as_list()
37 |     pixel_sum = tf.reduce_sum(masks, axis=[0,1,2,3]) #M, per-branch pixel sums
38 |     ind = tf.math.argmax(pixel_sum) #the branch covering the most pixels is taken as background
39 | reordered = tf.concat([masks[:,:,:,:,0:ind],masks[:,:,:,:,ind+1:],masks[:,:,:,:,ind:ind+1]], axis=-1)
40 | reordered.set_shape([B,H,W,C,M])
41 | return reordered
42 |
43 | def myprint(string):
44 | print ("\033[0;30;42m"+string+"\033[0m")
45 |
46 | def myinput(string=""):
47 | return input ("\033[0;30;41m"+string+"\033[0m")
48 |
49 | def bin_edge_map(imgs, dataset):
50 | assert dataset in ['flying_animals', 'multi_texture', 'multi_dsprites', 'objects_room']
51 | if dataset == 'flying_animals':
52 | d_xy = tf.reduce_sum(tf.image.sobel_edges(imgs), axis=-2)
53 | d_xy = d_xy/15
54 | d_xy = tf.cast(tf.abs(d_xy)>0.04, tf.float32)
55 | else:
56 | d_xy = tf.abs(tf.image.sobel_edges(imgs)) # B H W 3 2
57 | d_xy = tf.reduce_sum(d_xy,axis=-2, keepdims=False) #B H W 2
58 | #adjust the range to 0~1
59 | max_val = tf.math.reduce_max(d_xy, axis=[1,2], keepdims=True)
60 | min_val = tf.math.reduce_min(d_xy, axis=[1,2], keepdims=True)
61 | d_xy = tf.math.divide_no_nan(d_xy-min_val,max_val-min_val) #0~1
62 | #sparse binary edge map
63 | threshold = {'multi_texture': 0.5, 'multi_dsprites':0.0001, 'objects_room':0.2}
64 | d_xy = tf.cast(d_xy>threshold[dataset], tf.float32)
65 | return d_xy
66 |
67 | def Permute_IoU(label, pred):
68 | A, B = label, pred
69 |     H,W,_,N = A.shape
70 |     ans = 0
71 |     nc = 0 #number of non-empty channels
72 | ans_perfg, arg_maxIoU = [], []
73 | for i in range(N):
74 | src = np.expand_dims(A[:,:,:,i], axis=-1)>0.04 #H W 1 1 #binary
75 | if np.sum(src) > 4:
76 | nc += 1
77 | trg = B>0.04 #H W 1 C
78 | U = np.sum(src+trg, axis=(0,1,2)) #H W 1 C -> C
79 | I = np.sum(src*trg, axis=(0,1,2)) #H W 1 C -> C
80 | eps = 1e-8
81 | IoU = I/(eps+U)
82 | ans += np.max(IoU)
83 | arg_maxIoU.append(np.argmax(IoU))
84 | else: #empty channel
85 | arg_maxIoU.append(-1)
86 | assert nc >0
87 | return ans/nc, arg_maxIoU
88 |
89 | def train_op(loss, var_list, optimizer, gradient_clip_value=-1):
90 | grads_and_vars = optimizer.compute_gradients(loss, var_list=var_list)
91 |
92 | clipped_grad_and_vars = [ (ClipIfNotNone(grad, gradient_clip_value),var) \
93 | for grad, var in grads_and_vars ]
94 | train_operation = optimizer.apply_gradients(clipped_grad_and_vars)
95 |
96 | return train_operation, clipped_grad_and_vars
97 |
98 | def ClipIfNotNone(grad, clipvalue):
99 | if clipvalue==-1:
100 | return grad
101 | else:
102 | return tf.clip_by_value(grad, -clipvalue, clipvalue)
--------------------------------------------------------------------------------
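
Permute_IoU scores a segmentation without assuming any channel ordering: every non-empty ground-truth channel (more than 4 foreground pixels) is matched to the predicted branch with the highest IoU. A toy sanity check built on the function above (hypothetical data):

    import numpy as np
    from model.utils.generic_utils import Permute_IoU

    H, W = 4, 4
    gt = np.zeros((H, W, 1, 1), dtype=np.float32)
    gt[0:2, 0:3, 0, 0] = 1.0                         #one 6-pixel object (<=4 pixels would be skipped)
    pred = np.zeros((H, W, 1, 3), dtype=np.float32)  #three predicted branches
    pred[0:2, 0:3, 0, 2] = 1.0                       #branch 2 reproduces the object exactly
    iou, arg_maxIoU = Permute_IoU(label=gt, pred=pred)
    print(iou, arg_maxIoU)                           #-> ~1.0, [2]
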
/model/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import tensorflow as tf
3 |
4 |
5 |
6 | def region_error(X, Y, region):
7 | #A B H W c / B H W 3 M
8 | #B B H W 3 / B H W 3 M
9 | #region B H W 1 /B H W 1 M
10 | X = tf.cond(tf.equal(tf.rank(X), 5), lambda:X, lambda:tf.expand_dims(X, axis=-1))
11 | Y = tf.cond(tf.equal(tf.rank(Y), 5), lambda:Y, lambda:tf.expand_dims(Y, axis=-1))
12 | error = tf.abs(X-Y)*region #B H W 3 M
13 | error = tf.reduce_sum(error, axis=[1,2,3]) #B,M
14 | return error #B,M
15 |
16 | def Supervised_Generator_Loss(pred, GT):
17 | #B H W 1 C
18 | # use cross entropy
19 | error = tf.compat.v1.nn.softmax_cross_entropy_with_logits(labels=GT, logits=pred, dim=-1)#, axis=-1)#B H W 1 #???
20 | error = tf.reduce_mean(error) #B H W 1 -> perpixel
21 | return error
22 |
23 | def Generator_Loss(masks, pred_intensities, image, unconditioned_mean, epsilon):
24 | #masks B H W 1 C
25 | #pred_intensities B H W 3 C
26 | #unconditioned_mean B H W 3
27 | #image B H W 3
28 | numerator = region_error(pred_intensities, image, masks) #B,C
29 | denominator = epsilon + region_error(unconditioned_mean, image, masks) # B,C
30 | IRR = 1-tf.math.divide(numerator, denominator) #B,C
31 | perbranch_loss = tf.reduce_mean(IRR, axis=0) #C,
32 | loss = tf.reduce_sum(perbranch_loss)
33 | return loss, IRR, denominator, numerator #scalar,
34 |
35 | def Inpainter_Loss(masks, pred_intensities, image):
36 | #pred_intensities 0~1
37 | #image 0~1
38 | B, H, W, C = image.get_shape().as_list() #B H W C
39 | num_pixel = H*W*C
40 | loss0 = region_error(pred_intensities, image, masks) #B,M
41 | perbranch_loss = tf.reduce_mean(loss0, axis=0) # M,
42 | loss = tf.reduce_sum(perbranch_loss) #M->scalar
43 | loss = tf.math.divide(loss, num_pixel)
44 | return loss, loss0
--------------------------------------------------------------------------------
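
Generator_Loss implements the CIS objective: per branch, IRR = 1 - (inpainting error inside the mask) / (epsilon + error of the unconditioned mean). A region the inpainter can predict from its surroundings yields IRR near 1; the generator is trained to lower this score, i.e. to propose regions that cannot be inpainted from context. A toy NumPy rendering of the same arithmetic (illustrative only):

    import numpy as np

    def region_error_np(x, y, region):
        return np.sum(np.abs(x - y) * region, axis=(1, 2, 3))  #B,M

    B, H, W = 1, 8, 8
    image = np.random.rand(B, H, W, 3, 1)
    masks = np.ones((B, H, W, 1, 1))                 #one branch covering the whole image
    pred = image + 0.05                              #inpainter almost recovers the region
    uncond = np.full_like(image, image.mean())       #unconditioned mean intensity
    num = region_error_np(pred, image, masks)
    den = 100.0 + region_error_np(uncond, image, masks)  #epsilon=100 as in the flying_animals script
    print(1.0 - num / den)                           #IRR close to 1: region is predictable from context
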
/sample_imgs/flying_animals/01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/sample_imgs/flying_animals/01.png
--------------------------------------------------------------------------------
/sample_imgs/flying_animals/02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/sample_imgs/flying_animals/02.png
--------------------------------------------------------------------------------
/sample_imgs/multi_dsprites/01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/sample_imgs/multi_dsprites/01.png
--------------------------------------------------------------------------------
/sample_imgs/multi_dsprites/02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/sample_imgs/multi_dsprites/02.png
--------------------------------------------------------------------------------
/sample_imgs/multi_texture/01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/sample_imgs/multi_texture/01.png
--------------------------------------------------------------------------------
/sample_imgs/objects_room/01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/sample_imgs/objects_room/01.png
--------------------------------------------------------------------------------
/sample_imgs/objects_room/02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/sample_imgs/objects_room/02.png
--------------------------------------------------------------------------------
/sample_imgs/objects_room/03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/sample_imgs/objects_room/03.png
--------------------------------------------------------------------------------
/sample_imgs/objects_room/04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/sample_imgs/objects_room/04.png
--------------------------------------------------------------------------------
/script/flying_animals/disentangle.sh:
--------------------------------------------------------------------------------
1 | ABSPATH=$(readlink -f $0)
2 |
3 |
4 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \
5 | --checkpoint_dir=outputs/flying_animals/02_bg_range6 \
6 | --dataset=flying_animals --max_num=5 --num_branch=6 --root_dir=data/flying_animals_data/img_data.npz \
7 | --mode=eval_VAE --model=resnet_v2_50 \
8 | \
9 | \
10 | --fullmodel_ckpt='path to checkpoint' \
11 | --tex_dim=20 --bg_dim=30 --mask_dim=20 \
12 | \
13 | --input_img=sample_imgs/flying_animals/02.png \
14 | --traverse_type=bg --top_kdim=4 --traverse_branch=5 \
15 | --batch_size=1 --traverse_start=-6 --traverse_end=6 \
16 |
17 | #checkpoint_dir: folder containing outputs
18 |
19 | #input_img: target image (only a single image is supported)
20 |
21 | #traverse_type:
22 | #mask: traverse shape latent space
23 | #tex: traverse texture/color latent space
24 | #bg: traverse background latent space
25 |
26 | #top_kdim:
27 | #traverse the k dimensions with the largest KL divergence
28 | #these dimensions should encode the k most significant factors of variation of the object.
29 | #results of traversing the dimension with the kth-largest KL divergence are saved as branch{i}_var{k}.gif
30 |
31 |
32 | #traverse_branch: which branches to traverse
33 | #(only effective when traverse_type in ['tex','mask'])
34 | #'all': generate results of traversing all branches except the background branch
35 | #'0,1,2': manually choose branches to traverse
36 |
--------------------------------------------------------------------------------
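
The top_kdim note above mirrors what Traverse_Graph does internally: compute the per-dimension Gaussian KL against the unit prior and keep the k largest. A NumPy sketch of that ranking (illustrative values):

    import numpy as np

    mu, log_var = np.random.randn(30), np.random.randn(30)  #bg_dim=30 in this script
    kl = 0.5 * (mu**2 + np.exp(log_var) - log_var - 1.0)    #KL(N(mu, sigma^2) || N(0, 1)) per dim
    dims = np.argsort(kl)[::-1][:4]                         #top_kdim=4: the dims traversed, in KL order
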
/script/flying_animals/pretrain_inpainter.sh:
--------------------------------------------------------------------------------
1 | ABSPATH=$(readlink -f $0)
2 |
3 |
4 | CUDA_VISIBLE_DEVICES=0 python main.py \
5 | --checkpoint_dir=checkpoint/flying_animals/pretrain_inpainter \
6 | --sh_path=$ABSPATH \
7 | \
8 | --dataset=flying_animals --root_dir=data/flying_animals_data/img_data.npz \
9 | --mode=pretrain_inpainter --max_num=5 \
10 | \
11 | --batch_size=4 --inp_lr=1e-4 \
12 | \
13 | --summaries_secs=180 --ckpt_secs=5000 \
--------------------------------------------------------------------------------
/script/flying_animals/test_segmentation.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python test_segmentation.py \
2 | --data_path ../data/flying_animals_data/img_data.npz \
3 | --ckpt_path 'path to checkpoint including segmentation network' \
4 | --batch_size 8 --num_branch 6 --dataset_name flying_animals
5 |
--------------------------------------------------------------------------------
/script/flying_animals/train_CIS.sh:
--------------------------------------------------------------------------------
1 | ABSPATH=$(readlink -f $0)
2 |
3 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \
4 | --checkpoint_dir=checkpoint/flying_animals/CIS \
5 | --dataset=flying_animals --max_num=5 --num_branch=6 --root_dir=data/flying_animals_data/img_data.npz \
6 | --mode=train_CIS --model=resnet_v2_50 \
7 | \
8 | \
9 | --resume_inpainter=True --resume_fullmodel=False --resume_resnet=True \
10 | --fullmodel_ckpt=/ \
11 | --inpainter_ckpt='path of pretrained inpainter ckpt to resume here, e.g. checkpoint/flying_animals/pretrain_inpainter/inpainter-100000' \
12 | \
13 | --resnet_ckpt=resnet/resnet_v2_50/resnet_v2_50.ckpt \
14 | \
15 | --batch_size=4 --inp_lr=1e-4 --gen_lr=1e-4 \
16 | --epsilon=100 --iters_inp=1 --iters_gen=2 \
17 | --ckpt_steps=10000 --summaries_steps=500 \
--------------------------------------------------------------------------------
/script/flying_animals/train_VAE.sh:
--------------------------------------------------------------------------------
1 | ABSPATH=$(readlink -f $0)
2 |
3 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \
4 | --checkpoint_dir=checkpoint/flying_animals/VAE \
5 | --dataset=flying_animals --max_num=5 --num_branch=6 --root_dir=data/flying_animals_data/img_data.npz \
6 | --mode=train_VAE --model=resnet_v2_50 \
7 | \
8 | \
9 | --resume_CIS=True --resume_fullmodel=False \
10 | --fullmodel_ckpt=/ \
11 | --CIS_ckpt='path of CIS ckpt to resume here, e.g. checkpoint/flying_animals/CIS/model-100000' \
12 | \
13 | --batch_size=4 --VAE_lr=1e-4 \
14 | --tex_dim=20 --tex_beta=4 \
15 | --bg_dim=30 --bg_beta=4 \
16 | --mask_dim=20 --mask_gamma=500 --mask_capacity_inc=5e-5 \
17 | --ckpt_steps=10000 --summaries_steps=1000 \
--------------------------------------------------------------------------------
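
The mask_gamma / mask_capacity_inc pair suggests a capacity-annealed KL term for the mask VAE in the style of Burgess et al. (2018), with the capacity target growing by mask_capacity_inc per training step. This reading is an assumption about the flags; the loss code itself is not shown here:

    def mask_vae_loss(recon_loss, kl, step, gamma=500.0, capacity_inc=5e-5):
        #assumed objective: recon + gamma * |KL - C(step)|, with C annealed from 0
        capacity = capacity_inc * step
        return recon_loss + gamma * abs(kl - capacity)
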
/script/multi_dsprites/disentangle.sh:
--------------------------------------------------------------------------------
1 | ABSPATH=$(readlink -f $0)
2 |
3 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \
4 | --checkpoint_dir=outputs/multi_dsprites/02 \
5 | --dataset=multi_dsprites --num_branch=5 --root_dir=data/multi_dsprites_data/multi_dsprites_colored_on_colored.tfrecords \
6 | --mode=eval_VAE --model=resnet_v2_50 \
7 | \
8 | --fullmodel_ckpt='path to checkpoint containing segmentation network and encoder-decoder' \
9 | --tex_dim=5 --bg_dim=5 --mask_dim=10 \
10 | \
11 | --input_img=sample_imgs/multi_dsprites/02.png \
12 | --traverse_type=mask --top_kdim=5 --traverse_branch=2,3 \
13 | --batch_size=1 --traverse_start=-0.5 --traverse_end=0.5 \
14 |
15 | #checkpoint_dir: folder containing outputs
16 |
17 | #input_img: target image (only a single image is supported)
18 |
19 | #traverse_type:
20 | #mask: traverse shape latent space
21 | #tex: traverse texture/color latent space
22 | #bg: traverse background latent space
23 |
24 | #top_kdim:
25 | #traverse the k dimensions with the largest KL divergence
26 | #these dimensions should encode the k most significant factors of variation of the object.
27 | #results of traversing the dimension with the kth-largest KL divergence are saved as branch{i}_var{k}.gif
28 |
29 |
30 | #traverse_branch: which branches to traverse
31 | #(only effective when traverse_type in ['tex','mask'])
32 | #'all': generate results of traversing all branches except the background branch
33 | #'0,1,2': manually choose branches to traverse
34 |
35 |
--------------------------------------------------------------------------------
/script/multi_dsprites/pretrain_inpainter.sh:
--------------------------------------------------------------------------------
1 | ABSPATH=$(readlink -f $0)
2 |
3 |
4 | CUDA_VISIBLE_DEVICES=0 python main.py \
5 | --checkpoint_dir=checkpoint/multi_dsprites/pretrain_inpainter \
6 | --sh_path=$ABSPATH \
7 | \
8 | --dataset=multi_dsprites --root_dir=data/multi_dsprites_data/multi_dsprites_colored_on_colored.tfrecords \
9 | --mode=pretrain_inpainter \
10 | \
11 | --batch_size=16 --inp_lr=1e-4 \
12 | \
13 | --summaries_secs=60 --ckpt_secs=10000 \
--------------------------------------------------------------------------------
/script/multi_dsprites/test_segmentation.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python test_segmentation.py \
2 | --data_path ../data/multi_dsprites_data/multi_dsprites_colored_on_colored.tfrecords \
3 | --ckpt_path 'path to checkpoint including segmentation network' \
4 | --batch_size 8 --num_branch 5 --dataset_name multi_dsprites
--------------------------------------------------------------------------------
/script/multi_dsprites/train_CIS.sh:
--------------------------------------------------------------------------------
1 | ABSPATH=$(readlink -f $0)
2 |
3 |
4 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \
5 | --checkpoint_dir=checkpoint/multi_dsprites/CIS \
6 | --dataset=multi_dsprites --num_branch=5 --root_dir=data/multi_dsprites_data/multi_dsprites_colored_on_colored.tfrecords \
7 | --mode=train_CIS --model=resnet_v2_50 \
8 | \
9 | \
10 | --resume_inpainter=True --resume_fullmodel=False \
11 | --fullmodel_ckpt=/ \
12 | --inpainter_ckpt='path of pretrained inpainter ckpt to resume here, e.g. checkpoint/multi_dsprites/pretrain_inpainter/inpainter-50000' \
13 | \
14 | \
15 | --batch_size=8 --inp_lr=1e-4 --gen_lr=3e-5 \
16 | --epsilon=30 --iters_inp=1 --iters_gen=3 \
17 | --ckpt_steps=4000 \
18 |
19 |
--------------------------------------------------------------------------------
/script/multi_dsprites/train_VAE.sh:
--------------------------------------------------------------------------------
1 | ABSPATH=$(readlink -f $0)
2 |
3 |
4 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \
5 | --checkpoint_dir=checkpoint/multi_dsprites/VAE \
6 | --dataset=multi_dsprites --num_branch=5 --root_dir=data/multi_dsprites_data/multi_dsprites_colored_on_colored.tfrecords \
7 | --mode=train_VAE --model=resnet_v2_50 \
8 | \
9 | \
10 | --resume_CIS=True --resume_fullmodel=False \
11 | --fullmodel_ckpt= \
12 | --CIS_ckpt='path of CIS ckpt to resume here, e.g. checkpoint/multi_dsprites/CIS/model-100000' \
13 | \
14 | --batch_size=16 --VAE_lr=1e-4 \
15 | --tex_dim=5 --tex_beta=6 \
16 | --bg_dim=5 --bg_beta=6 \
17 | --mask_dim=10 --mask_gamma=500 --mask_capacity_inc=2e-5 \
18 | --ckpt_steps=20000 --summaries_steps=1000 \
19 |
20 |
--------------------------------------------------------------------------------
/script/multi_texture/disentangle.sh:
--------------------------------------------------------------------------------
1 | ABSPATH=$(readlink -f $0)
2 |
3 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \
4 | --checkpoint_dir=outputs/multi_texture \
5 | --dataset=multi_texture --max_num=4 --num_branch=5 --root_dir=data/multi_texture_data/ \
6 | --mode=eval_VAE --model=resnet_v2_50 \
7 | \
8 | \
9 | --fullmodel_ckpt='path to checkpoint' \
10 | --tex_dim=5 --mask_dim=10 --bg_dim=5 \
11 | \
12 | --input_img=sample_imgs/multi_texture/01.png \
13 | --traverse_type=mask --top_kdim=2 --traverse_branch=0,1,2,3 --traverse_start=-0.5 --traverse_end=0.5 \
14 | --batch_size=1 \
15 |
16 | #checkpoint_dir: folder containing outputs
17 |
18 | #input_img: target image (only a single image is supported)
19 |
20 | #traverse_type:
21 | #mask: traverse shape latent space
22 | #tex: traverse texture/color latent space
23 | #bg: traverse background latent space
24 |
25 | #top_kdim:
26 | #traverse the k dimensions with the largest KL divergence
27 | #these dimensions should encode the k most significant factors of variation of the object.
28 | #results of traversing the dimension with the kth-largest KL divergence are saved as branch{i}_var{k}.gif
29 |
30 |
31 | #traverse_branch: which branches to traverse
32 | #(only effective when traverse_type in ['tex','mask'])
33 | #'all': generate results of traversing all branches except the background branch
34 | #'0,1,2': manually choose branches to traverse
35 |
--------------------------------------------------------------------------------
/script/multi_texture/perceptual_consistency/finetune_PC.sh:
--------------------------------------------------------------------------------
1 | ABSPATH=$(readlink -f $0)
2 |
3 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \
4 | --checkpoint_dir=checkpoint/multi_texture/CIS/PC/finetune \
5 | --dataset=multi_texture --max_num=2 --num_branch=3 --root_dir=data/multi_texture_data/ \
6 | --mode=train_PC --model=segnet --PC=True \
7 | \
8 | --fullmodel_ckpt='path of pretrained ckpt to finetune here, e.g. checkpoint/multi_texture/pc/model-0' \
9 | \
10 | --batch_size=4 --inp_lr=1e-4 --gen_lr=3e-5 \
11 | --bg_dim=5 --tex_dim=5 --mask_dim=10 \
12 | --epsilon=50 --ita=1e-3 --iters_inp=1 --iters_gen_vae=3 \
13 | --ckpt_steps=10000 --summaries_steps=100 \
14 |
--------------------------------------------------------------------------------
/script/multi_texture/perceptual_consistency/test_segmentation.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python test_segmentation.py \
2 | --data_path data/multi_texture_data \
3 | --ckpt_path 'path to checkpoint including segmentation network' \
4 | --batch_size 8 --num_branch 3 --dataset_name multi_texture --PC=True --model=segnet
5 |
--------------------------------------------------------------------------------
/script/multi_texture/perceptual_consistency/train_CIS.sh:
--------------------------------------------------------------------------------
1 | ABSPATH=$(readlink -f $0)
2 |
3 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \
4 | --checkpoint_dir=checkpoint/multi_texture/CIS/PC/CIS \
5 | --dataset=multi_texture --max_num=2 --num_branch=3 --root_dir=data/multi_texture_data/ \
6 | --mode=train_CIS --model=segnet --PC=True \
7 | \
8 | \
9 | --resume_inpainter=True --resume_fullmodel=False \
10 | --fullmodel_ckpt= \
11 | --inpainter_ckpt='path of pretrained inpainter ckpt to resume here, e.g. checkpoint/multi_texture/pretrain_inpainter/inpainter-50000' \
12 | \
13 | \
14 | --batch_size=8 --inp_lr=1e-4 --gen_lr=3e-5 \
15 | --epsilon=50 --iters_inp=1 --iters_gen=3 \
16 | --ckpt_steps=5000 \
--------------------------------------------------------------------------------
/script/multi_texture/perceptual_consistency/train_VAE.sh:
--------------------------------------------------------------------------------
1 | ABSPATH=$(readlink -f $0)
2 |
3 |
4 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \
5 | --checkpoint_dir=checkpoint/multi_texture/CIS/PC/VAE \
6 | --dataset=multi_texture --max_num=2 --num_branch=3 --root_dir=data/multi_texture_data/ --PC=True \
7 | --mode=train_VAE --model=segnet \
8 | \
9 | \
10 | --resume_CIS=True --resume_fullmodel=False \
11 | --fullmodel_ckpt= \
12 | --CIS_ckpt='path of CIS ckpt to resume here' \
13 | \
14 | --batch_size=16 --VAE_lr=1e-4 \
15 | --tex_dim=5 --tex_beta=4 \
16 | --bg_dim=5 --bg_beta=6 \
17 | --mask_dim=10 --mask_gamma=500 --mask_capacity_inc=2e-5 \
18 | --ckpt_steps=20000 --summaries_steps=1000 \
--------------------------------------------------------------------------------
/script/multi_texture/pretrain_inpainter.sh:
--------------------------------------------------------------------------------
1 | ABSPATH=$(readlink -f $0)
2 |
3 |
4 | CUDA_VISIBLE_DEVICES=0 python main.py \
5 | --checkpoint_dir=checkpoint/multi_texture/pretrain_inpainter \
6 | --sh_path=$ABSPATH \
7 | \
8 | --dataset=multi_texture --max_num=4 --root_dir=data/multi_texture_data \
9 | --mode=pretrain_inpainter \
10 | \
11 | --batch_size=16 --inp_lr=1e-4 \
12 | \
13 | --summaries_secs=60 --ckpt_secs=10000 \
--------------------------------------------------------------------------------
/script/multi_texture/test_segmentation.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python test_segmentation.py \
2 | --data_path ../data/multi_texture_data \
3 | --ckpt_path 'path to checkpoint including segmentation network' \
4 | --batch_size 8 --num_branch 5 --dataset_name multi_texture
--------------------------------------------------------------------------------
/script/multi_texture/train_CIS.sh:
--------------------------------------------------------------------------------
1 | ABSPATH=$(readlink -f $0)
2 |
3 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \
4 | --checkpoint_dir=checkpoint/multi_texture/CIS \
5 | --dataset=multi_texture --max_num=4 --num_branch=5 --root_dir=data/multi_texture_data/ \
6 | --mode=train_CIS --model=resnet_v2_50 \
7 | \
8 | \
9 | --resume_inpainter=True --resume_fullmodel=False \
10 | --fullmodel_ckpt=/ \
11 | --inpainter_ckpt='path of pretrained inpainter ckpt to resume here, e.g. checkpoint/multi_texture/pretrain_inpainter/inpainter-50000' \
12 | \
13 | \
14 | --batch_size=8 --inp_lr=1e-4 --gen_lr=3e-5 \
15 | --epsilon=30 --iters_inp=1 --iters_gen=3 \
16 | --ckpt_steps=4000 \
--------------------------------------------------------------------------------
/script/multi_texture/train_VAE.sh:
--------------------------------------------------------------------------------
1 | ABSPATH=$(readlink -f $0)
2 |
3 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \
4 | --checkpoint_dir=checkpoint/multi_texture/VAE \
5 | --dataset=multi_texture --max_num=4 --num_branch=5 --root_dir=data/multi_texture_data/ \
6 | --mode=train_VAE --model=resnet_v2_50 \
7 | \
8 | \
9 | --resume_CIS=True --resume_fullmodel=False \
10 | --fullmodel_ckpt= \
11 | --CIS_ckpt='path of CIS ckpt to resume here, e.g. checkpoint/multi_texture/CIS/model-100000' \
12 | \
13 | --batch_size=16 --VAE_lr=1e-4 \
14 | --tex_dim=5 --tex_beta=4 \
15 | --bg_dim=5 --bg_beta=6 \
16 | --mask_dim=10 --mask_gamma=500 --mask_capacity_inc=2e-5 \
17 | --ckpt_steps=20000 --summaries_steps=1000 \
--------------------------------------------------------------------------------
/script/objects_room/disentangle.sh:
--------------------------------------------------------------------------------
1 | ABSPATH=$(readlink -f $0)
2 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \
3 | --checkpoint_dir=outputs/objects_room/tex_split \
4 | --dataset=objects_room --num_branch=6 --root_dir=data/objects_room_data/objects_room_train.tfrecords \
5 | --mode=eval_VAE --model=resnet_v2_50 \
6 | \
7 | \
8 | --fullmodel_ckpt='path to model checkpoint' \
9 | --tex_dim=5 --bg_dim=10 --mask_dim=10 \
10 | \
11 | --input_img=sample_imgs/objects_room/01.png \
12 | --traverse_type=tex --top_kdim=1 --traverse_branch=0,2 --traverse_start=-0.5 --traverse_end=1.5 \
13 | --batch_size=1 \
14 |
15 | #checkpoint_dir: folder containing outputs
16 |
17 | #input_img: target image (only a single image is supported)
18 |
19 | #traverse_type:
20 | #mask: traverse shape latent space
21 | #tex: traverse texture/color latent space
22 | #bg: traverse background latent space
23 |
24 | #top_kdim:
25 | #traverse the k dimensions with the largest KL divergence
26 | #these dimensions should encode the k most significant factors of variation of the object.
27 | #results of traversing the dimension with the kth-largest KL divergence are saved as branch{i}_var{k}.gif
28 |
29 |
30 | #traverse_branch: which branches to traverse
31 | #(only effective when traverse_type in ['tex','mask'])
32 | #'all': generate results of traversing all branches except the background branch
33 | #'0,1,2': manually choose branches to traverse
--------------------------------------------------------------------------------
/script/objects_room/pretrain_inpainter.sh:
--------------------------------------------------------------------------------
1 | ABSPATH=$(readlink -f $0)
2 |
3 | CUDA_VISIBLE_DEVICES=0 python main.py \
4 | --checkpoint_dir=checkpoint/objects_room/pretrain_inpainter \
5 | --sh_path=$ABSPATH \
6 | \
7 | --dataset=objects_room --root_dir=data/objects_room_data/objects_room_train.tfrecords \
8 | --mode=pretrain_inpainter \
9 | \
10 | --batch_size=16 --inp_lr=3e-5 \
11 | \
12 | --summaries_secs=60 --ckpt_secs=10000 \
13 |
14 |
--------------------------------------------------------------------------------
/script/objects_room/test_segmentation.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python test_segmentation.py \
2 | --data_path ../data/objects_room_data/objects_room_train.tfrecords \
3 | --ckpt_path 'path to checkpoint including segmentation network' \
4 | --batch_size 8 --num_branch 6 --dataset_name objects_room
5 |
--------------------------------------------------------------------------------
/script/objects_room/train_CIS.sh:
--------------------------------------------------------------------------------
1 | ABSPATH=$(readlink -f $0)
2 |
3 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \
4 | --checkpoint_dir=checkpoint/objects_room/CIS \
5 | --dataset=objects_room --num_branch=6 --root_dir=data/objects_room_data/objects_room_train.tfrecords \
6 | --mode=train_CIS --model=resnet_v2_50 \
7 | \
8 | \
9 | --resume_inpainter=True --resume_fullmodel=False \
10 | --fullmodel_ckpt=/ \
11 | --inpainter_ckpt='path of pretrained inpainter ckpt to resume here, e.g. checkpoint/objects_room/pretrain_inpainter/inpainter-50000' \
12 | \
13 | \
14 | --batch_size=8 --inp_lr=1e-4 --gen_lr=3e-5 \
15 | --epsilon=30 --iters_inp=1 --iters_gen=3 \
16 | --ckpt_steps=4000 \
--------------------------------------------------------------------------------
/script/objects_room/train_VAE.sh:
--------------------------------------------------------------------------------
1 | ABSPATH=$(readlink -f $0)
2 |
3 |
4 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \
5 | --checkpoint_dir=checkpoint/objects_room/VAE \
6 | --dataset=objects_room --num_branch=6 --root_dir=data/objects_room_data/objects_room_train.tfrecords \
7 | --mode=train_VAE --model=resnet_v2_50 \
8 | \
9 | \
10 | --resume_CIS=True --resume_fullmodel=False \
11 | --fullmodel_ckpt=/ \
12 | --CIS_ckpt='path of CIS ckpt to resume here, e.g. checkpoint/objects_room/CIS/model-100000' \
13 | \
14 | --batch_size=8 --VAE_lr=1e-4 \
15 | --tex_dim=5 --tex_beta=6 \
16 | --bg_dim=5 --bg_beta=6 \
17 | --mask_dim=10 --mask_gamma=500 --mask_capacity_inc=2e-4 \
18 | --ckpt_steps=20000 --summaries_steps=1000 \
--------------------------------------------------------------------------------
/tb.sh:
--------------------------------------------------------------------------------
1 | tensorboard \
2 | --logdir='path to logdir' \
3 | --port=6106
4 |
--------------------------------------------------------------------------------
/test_segmentation.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 | import sys
4 | sys.path.append("..")
5 | import pprint
6 | from model.utils.generic_utils import myprint, myinput, Permute_IoU, reorder_mask
7 | from model.nets import Generator_forward
8 | import imageio
9 | import numpy as np
10 | import time
11 | from data import multi_texture_utils, flying_animals_utils, objects_room_utils, multi_dsprites_utils
12 | import argparse
13 |
14 | parser = argparse.ArgumentParser(description='test the segmentation mean IoU')
15 | parser.add_argument('--data_path', type=str, help='dir of the test set')
16 | parser.add_argument('--batch_size', default=8, type=int, help='batchsize', required=False)
17 | parser.add_argument('--dataset_name', default='flying_animals', type=str, help='flying_animals / multi_texture / multi_dsprites / objects_room')
18 | parser.add_argument('--ckpt_path', default='./', type=str)
19 | parser.add_argument('--num_branch', default=6, type=int, help='#branch should match the checkpoint and network')
21 | parser.add_argument('--PC', default=False, type=lambda s: s.lower() in ('true','1'), help='whether to test perceptual consistency, output identity switching rate') #type=bool would treat any non-empty string, including "False", as True
21 | parser.add_argument('--model',default='resnet_v2_50',type=str, help='segmentation network model, resnet_v2_50 or segnet')
22 | args = parser.parse_args()
23 |
24 | #usage: python test_segmentation.py --data_path ... --ckpt_path ... (see script/*/test_segmentation.sh)
25 |
26 |
27 | data_path = args.data_path
28 | batch_size = args.batch_size
29 | dataset_name = args.dataset_name
30 | num_branch = args.num_branch
31 | PC, model = args.PC, args.model
32 |
33 | if dataset_name == 'flying_animals':
34 | img_height, img_width = 192, 256
35 | dataset = flying_animals_utils.dataset(data_path=data_path,batch_size=batch_size, max_num=5, phase='test')
36 | elif dataset_name == 'multi_texture':
37 | img_height, img_width = 64, 64
38 | dataset = multi_texture_utils.dataset(data_path=data_path, batch_size=batch_size, max_num=2 if PC else 4, phase='test',PC=PC)
39 | elif dataset_name == 'multi_dsprites':
40 | img_height, img_width = 64, 64
41 | dataset = multi_dsprites_utils.dataset(tfrecords_path=data_path,batch_size=batch_size, phase='test')
42 | elif dataset_name == 'objects_room':
43 | img_height, img_width = 64, 64
44 | dataset = objects_room_utils.dataset(tfrecords_path=data_path,batch_size=batch_size, phase='test')
45 |
46 | ckpt_path = args.ckpt_path
47 |
48 |
49 | iterator = dataset.make_one_shot_iterator()
50 | test_batch = iterator.get_next()
51 | img, tf_GT_masks = test_batch['img'], test_batch['masks']
52 | img.set_shape([batch_size, img_height, img_width, 3])
53 | tf_GT_masks.set_shape([batch_size, img_height, img_width, 1, None])
54 |
55 | with tf.name_scope("Generator") as scope:
56 | tf_generated_masks, null = Generator_forward(img, dataset_name,
57 | num_branch, model=model, training=False, reuse=None, scope=scope)
58 | tf_generated_masks = reorder_mask(tf_generated_masks) #place the background at the last channel
59 |
60 | restore_vars = tf.global_variables('Generator')
61 | saver = tf.train.Saver(restore_vars)
62 |
63 | with tf.Session() as sess:
64 | saver.restore(sess, ckpt_path)
65 | scores = []
66 | fetches = {'GT_masks': tf_GT_masks, 'generated_masks': tf_generated_masks, 'img':img}
67 |
68 | if PC:
69 | #test perceptual consistency
70 | num = 9*9*9*9-1
71 | assert num%batch_size==0
72 | niter = num//batch_size
73 | score, arg_maxIoUs = 0, []
74 | for u in range(niter):
75 | results = sess.run(fetches)
76 | for j in range(batch_size):
77 | s, arg_maxIoU = Permute_IoU(label=results['GT_masks'][j], pred=results['generated_masks'][j])
78 | score += s
79 | arg_maxIoUs.append(arg_maxIoU)
80 | score = score/num
81 | arg_maxIoUs = np.stack(arg_maxIoUs, axis=0)
82 | count = np.sum(arg_maxIoUs, axis=0)
83 | switching_rate = np.min(count)/num
84 | print("IoU: {} identity switching rate: {} ".format(score, switching_rate))
85 |
86 |
87 | else:
88 | for i in range(10): #10 subsets
89 | #200 images in each subset
90 | assert 200%batch_size==0
91 | niter = 200//batch_size
92 | score = []
93 | for u in range(niter):
94 | results = sess.run(fetches)
95 | for j in range(batch_size):
96 | s, null = Permute_IoU(label=results['GT_masks'][j], pred=results['generated_masks'][j])
97 | score.append(s)
98 | scores.append(score) #10*200
99 |             print("subset {}: mean {} variance {}\n".format(i+1, np.mean(scores[i]), np.var(scores[i])))
100 | mean_IoU = np.mean(scores, axis=-1) #10,
101 | print("mean of mean_IoU: {} std of mean_IoU: {}\n".format(np.mean(mean_IoU), np.std(mean_IoU, ddof=1)))
102 |
103 |
--------------------------------------------------------------------------------
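
The reported numbers are the mean and the sample standard deviation (ddof=1) of the ten per-subset mean IoUs, equivalent to this NumPy computation on hypothetical scores:

    import numpy as np

    scores = np.random.rand(10, 200)       #10 subsets x 200 per-image IoUs
    subset_means = scores.mean(axis=-1)    #one mean IoU per subset
    print(subset_means.mean(), subset_means.std(ddof=1))
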
/trainer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/trainer/__init__.py
--------------------------------------------------------------------------------
/trainer/train_CIS.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 | import gflags
4 | #https://github.com/google/python-gflags
5 | import sys
6 | sys.path.append("..")
7 | import pprint
8 | from keras.utils.generic_utils import Progbar
9 | import model.Summary as Summary
10 | from model.utils.generic_utils import myprint, myinput, Permute_IoU
11 | from model.train_graph import Train_Graph
12 | import imageio
13 | import numpy as np
14 | import time
15 | import re
16 |
17 | def train(FLAGS):
18 | graph = Train_Graph(FLAGS)
19 | graph.build()
20 |
21 | summary_op, generator_summary_op, branch_summary_op, eval_summary_op = Summary.collect_CIS_summary(graph, FLAGS)
22 | with tf.name_scope("parameter_count"):
23 | total_parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) \
24 | for v in tf.trainable_variables()])
25 |
26 | save_vars = tf.global_variables('Inpainter')+tf.global_variables('Generator')+ \
27 | tf.global_variables('train_op') #including global step
28 | if FLAGS.resume_inpainter:
29 | assert os.path.isfile(FLAGS.inpainter_ckpt+'.index')
30 | inpainter_saver = tf.train.Saver(tf.trainable_variables('Inpainter'))#only restore the trainable variables
31 |
32 | if FLAGS.resume_resnet:
33 | assert os.path.isfile(FLAGS.resnet_ckpt)
34 | resnet_reader=tf.compat.v1.train.NewCheckpointReader(FLAGS.resnet_ckpt)
35 | resnet_map = resnet_reader.get_variable_to_shape_map()
36 | resnet_dict = dict()
37 | for v in tf.trainable_variables('Generator//resnet_v2'):
38 | if 'resnet_v2_50/'+v.op.name[21:] in resnet_map.keys():
39 | resnet_dict['resnet_v2_50/'+v.op.name[21:]] = v
40 | resnet_var_name = [v.name for v in tf.trainable_variables('Generator//resnet_v2') \
41 | if 'resnet_v2_50/'+v.op.name[21:] in resnet_map.keys()]
42 | resnet_saver = tf.train.Saver(resnet_dict)
43 |
44 | saver = tf.train.Saver(save_vars, max_to_keep=100)
45 | branch_writers = [tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir, "branch"+str(m))) \
46 | for m in range(FLAGS.num_branch)] #save generator loss for each branch
47 | sv = tf.train.Supervisor(logdir=os.path.join(FLAGS.checkpoint_dir, "CIS_Sum"),
48 | saver=None, save_summaries_secs=0)
49 |
50 | with sv.managed_session() as sess:
51 | myprint ("Number of total params: {0} \n".format( \
52 | sess.run(total_parameter_count)))
53 | if FLAGS.resume_fullmodel:
54 | assert os.path.isfile(FLAGS.fullmodel_ckpt+'.index')
55 | saver.restore(sess, FLAGS.fullmodel_ckpt)
56 | myprint ("Resumed training from model {}".format(FLAGS.fullmodel_ckpt))
57 | myprint ("Start from step {}".format(sess.run(graph.global_step)))
58 | myprint ("Save checkpoint in {}".format(FLAGS.checkpoint_dir))
59 | if not os.path.dirname(FLAGS.fullmodel_ckpt) == FLAGS.checkpoint_dir:
60 | print ("\033[0;30;41m"+"Warning: checkpoint dir and fullmodel ckpt do not match"+"\033[0m")
61 | myprint ("Please make sure that new checkpoint will be saved in the same dir with the resumed model")
62 | else:
63 | if FLAGS.resume_inpainter:
64 | assert os.path.isfile(FLAGS.inpainter_ckpt+'.index')
65 | inpainter_saver.restore(sess, FLAGS.inpainter_ckpt)
66 | myprint ("Load pretrained inpainter {}".format(FLAGS.inpainter_ckpt))
67 |
68 | if FLAGS.resume_resnet:
69 | resnet_saver.restore(sess, FLAGS.resnet_ckpt)
70 | myprint ("Load pretrained resnet {}".format(FLAGS.resnet_ckpt))
71 | if not FLAGS.resume_resnet and not FLAGS.resume_inpainter:
72 | myprint ("Train from scratch")
73 | myinput('Press enter to continue')
74 |
75 | start_time = time.time()
76 | step = sess.run(graph.global_step)
77 | progbar = Progbar(target=FLAGS.ckpt_steps) #100k
78 |
79 | sum_iters = FLAGS.iters_gen + FLAGS.iters_inp
80 |
81 |         while (time.time()-start_time) < ...:  #(remainder of this file truncated)
--------------------------------------------------------------------------------
/trainer/train_VAE.py:
--------------------------------------------------------------------------------
78 |
79 | if vae_step % FLAGS.summaries_steps == 0:
80 | fetches['tex_kl'], fetches['mask_kl'], fetches['bg_kl'] = graph.loss['tex_kl'], graph.loss['mask_kl'], graph.loss['bg_kl']
81 | fetches['Fusion'] = graph.loss['Fusion']
82 | fetches['summary'] = summary_op
83 |
84 |
85 | results = sess.run(fetches, feed_dict={graph.is_training: True, graph.mask_capacity: mask_capacity})
86 | progbar.update(vae_step%FLAGS.ckpt_steps)
87 |
88 | if vae_step % FLAGS.summaries_steps == 0 :
89 | print (" Step:%3dk time:%4.4fmin " \
90 | %(vae_step/1000, (time.time()-start_time)/60))
91 | sv.summary_writer.add_summary(results['summary'], vae_step)
92 |
93 | for d in range(FLAGS.tex_dim):
94 | tex_summary = sess.run(tex_latent_summary_op, feed_dict={graph.loss['tex_kl_var']: results['tex_kl'][d]})
95 | tex_latent_writers[d].add_summary(tex_summary, vae_step)
96 |
97 | for d in range(FLAGS.bg_dim):
98 | bg_summary = sess.run(bg_latent_summary_op, feed_dict={graph.loss['bg_kl_var']: results['bg_kl'][d]})
99 | bg_latent_writers[d].add_summary(bg_summary, vae_step)
100 |
101 | for d in range(FLAGS.mask_dim):
102 | mask_summary = sess.run(mask_latent_summary_op, feed_dict={graph.loss['mask_kl_var']: results['mask_kl'][d]})
103 | mask_latent_writers[d].add_summary(mask_summary, vae_step)
104 |
105 |
106 | if vae_step % FLAGS.ckpt_steps == 0:
107 | saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'model'), global_step=vae_step)
108 | progbar = Progbar(target=FLAGS.ckpt_steps)
109 |
110 | vae_step = results['vae_step']
111 |
112 | myprint("Training completed")
--------------------------------------------------------------------------------
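
The CIS stage alternates inpainter and generator updates; with iters_inp=1 and iters_gen=2 (sum_iters=3), each cycle runs one inpainter step followed by two generator steps. The branching itself is not visible above (the while-loop is truncated), so the following is an assumed sketch of the schedule, not the original code:

    def pick_update(step, iters_inp, iters_gen):
        #assumed alternation over a cycle of sum_iters steps
        sum_iters = iters_inp + iters_gen
        return 'inpainter' if step % sum_iters < iters_inp else 'generator'

    print([pick_update(s, 1, 2) for s in range(6)])  #['inpainter', 'generator', 'generator', ...]
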
/trainer/train_end2end.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 | import gflags
4 | #https://github.com/google/python-gflags
5 | import sys
6 | sys.path.append("..")
7 | import pprint
8 | from keras.utils.generic_utils import Progbar
9 | import model.Summary as Summary
10 | from model.utils.generic_utils import myprint, myinput, Permute_IoU
11 | from model.train_graph import Train_Graph
12 | import imageio
13 | import numpy as np
14 | import time
15 | import re
16 |
17 | def train(FLAGS):
18 | # learner
19 | graph = Train_Graph(FLAGS)
20 | graph.build()
21 |
22 | summary_op, tex_latent_summary_op, bg_latent_summary_op, eval_summary_op = Summary.collect_end2end_summary(graph, FLAGS)
23 | # train
24 | #define model saver
25 | with tf.name_scope("parameter_count"):
26 | total_parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) \
27 | for v in tf.trainable_variables()])
28 |
29 | save_vars = tf.global_variables()
30 | # tf.global_variables('Inpainter')+tf.global_variables('Generator')+ \
31 | # tf.global_variables('VAE')+tf.global_variables('Fusion') \
32 | # +tf.global_variables('train_op') #including global step
33 |
34 | if FLAGS.resume_CIS:
35 | CIS_vars = tf.global_variables('Inpainter')+tf.global_variables('Generator')
36 | CIS_saver = tf.train.Saver(CIS_vars, max_to_keep=100)
37 |
38 | mask_saver = tf.train.Saver(tf.global_variables('VAE//separate/maskVAE/'), max_to_keep=100)
39 | tex_saver = tf.train.Saver(tf.global_variables('VAE//separate/texVAE/'), max_to_keep=100)
40 |
41 | saver = tf.train.Saver(save_vars, max_to_keep=100)
42 | branch_writers = [tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir,'branch'+str(m))) for m in range(FLAGS.num_branch)]
43 | tex_latent_writers = [tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir, "tex_latent"+str(m))) for m in range(FLAGS.tex_dim)]
44 | bg_latent_writers = [tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir, "bg_latent"+str(m))) for m in range(FLAGS.bg_dim)]
45 | #mask_latent_writers = [tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir, "mask_latent"+str(m))) for m in range(FLAGS.mask_dim)]
46 |
47 |
48 |
49 | sv = tf.train.Supervisor(logdir=os.path.join(FLAGS.checkpoint_dir, "end2end_Sum"),
50 | saver=None, save_summaries_secs=0) #not saved automatically for flexibility
51 |
52 | with sv.managed_session() as sess:
53 | myprint ("Number of total params: {0} \n".format( \
54 | sess.run(total_parameter_count)))
55 | if FLAGS.resume_fullmodel:
56 | assert os.path.isfile(FLAGS.fullmodel_ckpt+'.index')
57 | saver.restore(sess, FLAGS.fullmodel_ckpt)
58 | myprint ("Resumed training from model {}".format(FLAGS.fullmodel_ckpt))
59 |             myprint ("Start from step {} vae_step {}".format(sess.run(graph.global_step), sess.run(graph.vae_global_step)))
60 | myprint ("Save checkpoint in {}".format(FLAGS.checkpoint_dir))
61 | if not os.path.dirname(FLAGS.fullmodel_ckpt) == FLAGS.checkpoint_dir:
62 | print ("\033[0;30;41m"+"Warning: checkpoint dir and fullmodel ckpt do not match"+"\033[0m")
63 | #myprint ("Please make sure that the checkpoint will be saved in the same dir with the resumed model")
64 | else:
65 | if os.path.isfile(FLAGS.mask_ckpt+'.index'):
66 | mask_saver.restore(sess, FLAGS.mask_ckpt)
67 | myprint ("Load pretrained maskVAE {}".format(FLAGS.mask_ckpt))
68 | if os.path.isfile(FLAGS.tex_ckpt+'.index'):
69 | tex_saver.restore(sess, FLAGS.tex_ckpt)
70 | myprint ("Load pretrained texVAE {}".format(FLAGS.tex_ckpt))
71 | if FLAGS.resume_CIS:
72 | assert os.path.isfile(FLAGS.CIS_ckpt+'.index')
73 | CIS_saver.restore(sess, FLAGS.CIS_ckpt)
74 | myprint ("Load pretrained inpainter and generator {}".format(FLAGS.CIS_ckpt))
75 | else:
76 | myprint ("Train from scratch")
77 | myinput('Press enter to continue')
78 |
79 | start_time = time.time()
80 | step = sess.run(graph.global_step)
81 | vae_step = sess.run(graph.vae_global_step)
82 | progbar = Progbar(target=FLAGS.ckpt_steps) #100k
83 |
84 | sum_iters = FLAGS.iters_gen_vae + FLAGS.iters_inp
85 |
86 | while (time.time()-start_time) < ...:  # time-budget loop; the condition and lines 87-96 were lost in this dump
97 | # ... => this iteration should have a VAE step
98 | fetches['vae_global_step'], fetches['vae_global_step_inc'] = graph.vae_global_step, graph.incr_vae_global_step
99 |
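Lines 87-96 above are missing from this dump. Given `sum_iters = FLAGS.iters_gen_vae + FLAGS.iters_inp` and the conditional fetch of `vae_global_step` on line 98, the lost code presumably builds the `fetches` dict and alternates inpainter-only iterations with generator+VAE iterations. A hypothetical sketch of such a dispatch (the attributes `graph.train_inp_op` and `graph.train_gen_vae_op` are assumed names, not necessarily the repo's):

    # Hypothetical reconstruction, not the repo's actual code.
    fetches = {'step': graph.global_step, 'step_inc': graph.incr_global_step}
    if step % sum_iters < FLAGS.iters_inp:
        fetches['train_op'] = graph.train_inp_op       # inpainter iteration
    else:
        fetches['train_op'] = graph.train_gen_vae_op   # generator+VAE iteration
        # => should have a VAE step, so also fetch and advance the VAE's own
        # counter (this is what line 98 does)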
100 | if step % FLAGS.summaries_steps == 0:
101 | fetches["Inpainter_Loss"],fetches["Generator_Loss"] = graph.loss['Inpainter'], graph.loss['Generator']
102 | fetches["VAE//texVAE"], fetches["VAE//texVAE_BG"], fetches['VAE//fusion'] = graph.loss['VAE//separate/texVAE/'], graph.loss['VAE//separate/texVAE_BG/'], graph.loss['VAE//fusion']
103 | fetches['tex_kl'], fetches['bg_kl'] = graph.loss['tex_kl'], graph.loss['bg_kl']
104 | fetches['summary'] = summary_op
105 |
106 | if step % FLAGS.ckpt_steps == 0:
107 | fetches['generated_masks'] = graph.generated_masks
108 | fetches['GT_masks'] = graph.GT_masks
109 |
110 | results = sess.run(fetches, feed_dict={graph.is_training: True, graph.mask_capacity: mask_capacity})
111 | progbar.update(step%FLAGS.ckpt_steps)
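`mask_capacity`, fed to the graph above, is computed in the lost lines; in capacity-annealed VAEs the capacity is typically a KL target raised linearly with the number of VAE steps (Burgess et al., 2018, "Understanding disentangling in beta-VAE"). A hypothetical schedule, with `capacity_max` and `capacity_steps` as assumed parameters:

    # Hypothetical linear KL-capacity schedule; parameter values are examples.
    def capacity_schedule(vae_step, capacity_max=25.0, capacity_steps=100000):
        """Raise the KL capacity linearly from 0 to capacity_max."""
        return min(capacity_max, capacity_max * vae_step / float(capacity_steps))

    mask_capacity = capacity_schedule(vae_step)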
112 |
113 | if step % FLAGS.summaries_steps == 0:
114 | print (" Step:%3dk time:%4.4fmin VAELoss:%4.2f" \
115 | %(step/1000, (time.time()-start_time)/60, results["VAE//texVAE"]+results['VAE//fusion']+results['VAE//texVAE_BG']))
116 | sv.summary_writer.add_summary(results['summary'], step)
117 |
118 | for d in range(FLAGS.tex_dim):
119 | tex_summary = sess.run(tex_latent_summary_op, feed_dict={graph.loss['tex_kl_var']: results['tex_kl'][d]})
120 | tex_latent_writers[d].add_summary(tex_summary, step)
121 |
122 | for d in range(FLAGS.bg_dim):
123 | bg_summary = sess.run(bg_latent_summary_op, feed_dict={graph.loss['bg_kl_var']: results['bg_kl'][d]})
124 | bg_latent_writers[d].add_summary(bg_summary, step)
125 |
126 | # for d in range(FLAGS.mask_dim):
127 | # mask_summary = sess.run(mask_latent_summary_op, feed_dict={graph.loss['mask_kl_var']: results['mask_kl'][d]})
128 | # mask_latent_writers[d].add_summary(mask_summary, step)
129 |
130 |
131 | if step % FLAGS.ckpt_steps == 0:
132 | saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'model'), global_step=step)
133 | progbar = Progbar(target=FLAGS.ckpt_steps)
134 |
135 | #evaluation
136 | sess.run(graph.val_iterator.initializer)
137 | fetches = {'GT_masks':graph.GT_masks, 'generated_masks':graph.generated_masks}
138 |
139 | if FLAGS.dataset in ['multi_texture', 'flying_animals']:
140 | # for multi_texture, bg_num is not a real background count; it is the number of validation samples drawn per image type
141 | score = [[] for _ in range(FLAGS.max_num)]  # independent lists; [[]]*n would alias a single list
142 | for bg in range(FLAGS.bg_num):
143 | results_val=sess.run(fetches, feed_dict={graph.is_training: False})
144 | for k in range(FLAGS.max_num):
145 | # accumulate this sample's permutation-invariant IoU for branch k
146 | score[k].append(Permute_IoU(label=results_val['GT_masks'][k], pred=results_val['generated_masks'][k]))
147 | for k in range(FLAGS.max_num):
148 | eval_summary = sess.run(eval_summary_op, feed_dict={graph.loss['EvalIoU_var']: np.mean(score[k])})
149 | branch_writers[k+1].add_summary(eval_summary, step)
150 | else:
151 | num_sample = FLAGS.skipnum  # reused here as the number of validation samples to score
152 | niter = num_sample//FLAGS.batch_size
153 | assert num_sample%FLAGS.batch_size==0
154 | score = 0
155 | for it in range(niter):
156 | results_val = sess.run(fetches, feed_dict={graph.is_training:False})
157 | for k in range(FLAGS.batch_size):
158 | score += Permute_IoU(label=results_val['GT_masks'][k], pred=results_val['generated_masks'][k])
159 | score = score/num_sample
160 | eval_summary = sess.run(eval_summary_op, feed_dict={graph.loss['EvalIoU_var']: score})
161 | sv.summary_writer.add_summary(eval_summary, step)
162 |
163 | step = results['step']
164 | vae_step = results.get('vae_global_step', vae_step)  # only fetched on generator/VAE iterations
165 |
166 | myprint("Training completed")
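`Permute_IoU` (imported from model/utils/generic_utils.py) scores predicted masks against ground truth up to a permutation of the mask channels, since which branch discovers which object is arbitrary. A minimal sketch of that idea, assuming masks come as (K, H, W) arrays of K binary maps and using scipy's Hungarian solver; the repo's own implementation may differ:

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    def permute_iou(label, pred):
        """Mean IoU under the best one-to-one matching of mask channels."""
        K = label.shape[0]
        iou = np.zeros((K, K))
        for i in range(K):
            for j in range(K):
                inter = np.logical_and(label[i], pred[j]).sum()
                union = np.logical_or(label[i], pred[j]).sum()
                iou[i, j] = inter / union if union > 0 else 1.0
        rows, cols = linear_sum_assignment(-iou)  # maximize total matched IoU
        return iou[rows, cols].mean()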
--------------------------------------------------------------------------------
/trainer/train_globalVAE.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 | import gflags
4 | #https://github.com/google/python-gflags
5 | import sys
6 | sys.path.append("..")
7 | import pprint
8 | from keras.utils.generic_utils import Progbar
9 | import model.Summary as Summary
10 | from model.utils.generic_utils import myprint, myinput, Permute_IoU
11 | from model.globalVAE_graph import Train_Graph
12 | import imageio
13 | import numpy as np
14 | import time
15 | import re
16 |
17 | def train(FLAGS):
18 | # learner
19 | graph = Train_Graph(FLAGS)
20 | graph.build()
21 |
22 | summary_op, latent_summary_op = Summary.collect_globalVAE_summary(graph, FLAGS)
23 | # train
24 | #define model saver
25 | with tf.name_scope("parameter_count"):
26 | total_parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) \
27 | for v in tf.trainable_variables()])
28 |
29 | save_vars = tf.global_variables()
30 | saver = tf.train.Saver(save_vars, max_to_keep=100)
31 |
32 | latent_writers = [tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir, "latent"+str(m))) \
33 | for m in range(FLAGS.tex_dim)]
34 | sv = tf.train.Supervisor(logdir=os.path.join(FLAGS.checkpoint_dir, "globalVAE_Sum"),
35 | saver=None, save_summaries_secs=0) # disable automatic saving; summaries and checkpoints are written manually below for flexibility
36 |
37 | with sv.managed_session() as sess:
38 | myprint ("Number of total params: {0} \n".format( \
39 | sess.run(total_parameter_count)))
40 | if FLAGS.resume_fullmodel:
41 | assert os.path.isfile(FLAGS.fullmodel_ckpt+'.index')
42 | saver.restore(sess, FLAGS.fullmodel_ckpt)
43 | myprint ("Resumed training from model {}".format(FLAGS.fullmodel_ckpt))
44 | myprint ("Start from step {}".format(sess.run(graph.global_step)))
45 | myprint ("Save checkpoint in {}".format(FLAGS.checkpoint_dir))
46 | if os.path.dirname(FLAGS.fullmodel_ckpt) != FLAGS.checkpoint_dir:
47 | print ("\033[0;30;41m"+"Warning: checkpoint dir and fullmodel ckpt do not match"+"\033[0m")
48 | #myprint ("Please make sure that the checkpoint will be saved in the same dir with the resumed model")
49 | else:
50 | myprint ("Train from scratch")
51 | myinput('Press enter to continue')
52 |
53 | start_time = time.time()
54 | step = sess.run(graph.global_step)
55 | progbar = Progbar(target=FLAGS.ckpt_steps) #100k
56 |
57 | while (time.time()-start_time)