├── .gitignore
├── README.md
├── common_flags.py
├── data
│   ├── flying_animals_data
│   │   └── fa_data_decode.py
│   ├── flying_animals_utils.py
│   ├── multi_dsprites_utils.py
│   ├── multi_texture_data
│   │   ├── bg.png
│   │   ├── ellipse_2.png
│   │   ├── square_2.png
│   │   └── tex.png
│   ├── multi_texture_utils.py
│   └── objects_room_utils.py
├── doc
│   ├── flying_animals.gif
│   ├── multi_texture.gif
│   ├── objects_room.gif
│   └── pc.gif
├── eval
│   └── eval_VAE.py
├── main.py
├── model
│   ├── Summary.py
│   ├── __init__.py
│   ├── globalVAE_graph.py
│   ├── nets.py
│   ├── train_graph.py
│   ├── traverse_graph.py
│   └── utils
│       ├── __init__.py
│       ├── convolution_utils.py
│       ├── generic_utils.py
│       └── loss_utils.py
├── sample_imgs
│   ├── flying_animals
│   │   ├── 01.png
│   │   └── 02.png
│   ├── multi_dsprites
│   │   ├── 01.png
│   │   └── 02.png
│   ├── multi_texture
│   │   └── 01.png
│   └── objects_room
│       ├── 01.png
│       ├── 02.png
│       ├── 03.png
│       └── 04.png
├── script
│   ├── flying_animals
│   │   ├── disentangle.sh
│   │   ├── pretrain_inpainter.sh
│   │   ├── test_segmentation.sh
│   │   ├── train_CIS.sh
│   │   └── train_VAE.sh
│   ├── multi_dsprites
│   │   ├── disentangle.sh
│   │   ├── pretrain_inpainter.sh
│   │   ├── test_segmentation.sh
│   │   ├── train_CIS.sh
│   │   └── train_VAE.sh
│   ├── multi_texture
│   │   ├── disentangle.sh
│   │   ├── perceptual_consistency
│   │   │   ├── finetune_PC.sh
│   │   │   ├── test_segmentation.sh
│   │   │   ├── train_CIS.sh
│   │   │   └── train_VAE.sh
│   │   ├── pretrain_inpainter.sh
│   │   ├── test_segmentation.sh
│   │   ├── train_CIS.sh
│   │   └── train_VAE.sh
│   └── objects_room
│       ├── disentangle.sh
│       ├── pretrain_inpainter.sh
│       ├── test_segmentation.sh
│       ├── train_CIS.sh
│       └── train_VAE.sh
├── tb.sh
├── test_segmentation.py
└── trainer
    ├── __init__.py
    ├── train_CIS.py
    ├── train_PC.py
    ├── train_VAE.py
    ├── train_end2end.py
    ├── train_globalVAE.py
    └── train_inpainter.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
data/flying_animals_data/*.npz
data/flying_animals_data/data
data/objects_room_data/*.tfrecords
data/multi_dsprites_data/*.tfrecords

save_checkpoint/
checkpoint/
script0/
tb.sh
resnet/
*_select/
outputs/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# [CVPR 2020] Learning to Manipulate Individual Objects in an Image
This repo contains the implementation of the method described in the paper

[Learning to Manipulate Individual Objects in an Image](https://arxiv.org/pdf/2004.05495.pdf)

published at the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2020.


### Introduction:
We describe a method to train a generative model with latent factors that are (approximately) independent and localized. This means that perturbing the latent variables affects only local regions of the synthesized image, corresponding to objects. Unlike other unsupervised generative models, ours enables object-centric manipulation, without requiring object-level annotations, or any form of annotation for that matter. For more details, please check our paper.
![Disentangled manipulation on Multi-Texture](doc/multi_texture.gif)

![Disentangled manipulation on Objects Room](doc/objects_room.gif)

![Disentangled manipulation on Flying Animals](doc/flying_animals.gif)

## Running the code
### Prerequisites

This code was tested with the following packages. Note that other versions of them might work but are untested.

* Ubuntu 16.04
* python3
* tensorflow-gpu==1.14.0
* python-gflags 3.1.2
* keras 2.3.1
* imageio 2.6.1
* numpy 1.17.2
* gitpython 3.0.5

### Datasets


#### Multi-dSprites and Objects Room

Download the two existing datasets with the following commands:
```
mkdir data/multi_dsprites_data data/objects_room_data
wget https://storage.googleapis.com/multi-object-datasets/multi_dsprites/multi_dsprites_colored_on_colored.tfrecords -P data/multi_dsprites_data
wget https://storage.googleapis.com/multi-object-datasets/objects_room/objects_room_train.tfrecords -P data/objects_room_data
```
These two datasets are TFRecords files and can be used without pre-processing.


#### Multi-Texture

The components are already included in [data/multi\_texture\_data](data/multi_texture_data) and are used to generate images on the fly during training and testing.


#### Flying Animals

Please download the zip file from
[this link](https://drive.google.com/open?id=1xs9CdR8HC_RxfuEbZnD_hmMqQusAuhbO), put it in [data/flying\_animals\_data](data/flying_animals_data) and then run the following commands to decode the raw images into .npz files.
```
cd data/flying_animals_data
unzip data.zip
python fa_data_decode.py
```
These commands generate img_data.npz and img_data_test.npz in [data/flying\_animals\_data](data/flying_animals_data) for training and testing respectively.

### Training

To stabilize and speed up adversarial training, our training consists of three steps. Default hyperparameter settings for the four datasets and three steps are included in [script/](script). Modify arguments in the scripts, e.g. the path of output checkpoints, when necessary.

#### 1. Pretrain inpainting network

Pretrain the inpainting network on the task of predicting pixel values inside box-shaped occlusions.
```
sh script/dataset_name/pretrain_inpainter.sh
```
Pretrained checkpoints of the inpainting network for each dataset can be downloaded [here](https://drive.google.com/drive/folders/1AcFb2kfFpEuD-Wi_Iz_Z9-mkgs3anEOF?usp=sharing). You can directly restore a downloaded checkpoint to skip this step.

#### 2. Spatial disentanglement

Update the inpainting network and the segmentation network adversarially for spatial disentanglement.

```
sh script/dataset_name/train_CIS.sh
```
Note that while for the other datasets we train the segmentation network from scratch, for the flying animals dataset we suggest initializing ResNetV2-50 with a checkpoint pretrained on ImageNet, which can be found [here](https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models). Please download the checkpoint by running
```
mkdir resnet && cd resnet
wget http://download.tensorflow.org/models/resnet_v2_50_2017_04_14.tar.gz
tar -xvf resnet_v2_50_2017_04_14.tar.gz
```

#### 3. Train VAE

Train the encoder and decoder to learn a disentangled latent space.
```
sh script/dataset_name/train_VAE.sh
```

### IoU Evaluation

For a trained model, you can measure its segmentation performance with the script [test\_segmentation.py](./test_segmentation.py).
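The reported score is a permutation-invariant mean-IoU: since the output channels are unordered, predicted masks are first matched to the ground-truth objects. The repo's own implementation is Permute_IoU in [model/utils/generic_utils.py](model/utils/generic_utils.py); the NumPy/SciPy sketch below is only illustrative and may differ from it in details.
```
import numpy as np
from scipy.optimize import linear_sum_assignment

def permuted_mean_iou(gt_masks, pred_masks):
    # gt_masks: [H, W, M], pred_masks: [H, W, N] with N >= M, values in [0, 1]
    M, N = gt_masks.shape[-1], pred_masks.shape[-1]
    iou = np.zeros((M, N))
    for i in range(M):
        for j in range(N):
            g, p = gt_masks[..., i] > 0.5, pred_masks[..., j] > 0.5
            union = np.logical_or(g, p).sum()
            iou[i, j] = np.logical_and(g, p).sum() / union if union > 0 else 0.0
    # assign each ground-truth object the prediction channel that maximizes
    # the total IoU (Hungarian matching), then average over objects
    rows, cols = linear_sum_assignment(-iou)
    return iou[rows, cols].mean()
```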
Example test scripts are provided as script/dataset_name/test_segmentation.sh. Edit the script with the path to the checkpoint file and run it to compute the mean and standard deviation of the mean-IoU scores over 10 subsets.
```
sh script/dataset_name/test_segmentation.sh
```

### Disentanglement

After finishing all training steps, you can visualize the disentanglement of the latent space by feeding a target image into the model and varying one latent dimension at a time, to see whether the perturbation leads to only one type of semantic variation of one particular object in the synthesized image.

Script examples for disentanglement visualization are provided as script/dataset_name/disentangle.sh. Edit them with the paths to the checkpoint and output directories.
```
sh script/dataset_name/disentangle.sh
```
Modify some of the arguments when necessary to set which objects and dimensions to perturb and the range over which the latent factors vary.

### Perceptual cycle-consistency
We demonstrate the effectiveness of the perceptual cycle-consistency constraint on Multi-Texture, where each image contains two objects of different identities, an ellipse and a square. Training scripts for the experiments are provided in [this folder](./script/multi_texture/perceptual_consistency). The first three training steps are the same as described in [Training](./README.md#Training), without enforcing perceptual cycle-consistency. We then finetune the model with the perceptual cycle-consistency constraint by running
```
sh script/multi_texture/perceptual_consistency/finetune_PC.sh
```
Finetuning decreases the identity switching rate and improves identity consistency. As shown in the figure below, the finetuned model (middle) consistently captures the ellipse in channel 0, while the un-finetuned model (right) sometimes assigns the square to channel 0.
![Identity consistency on Multi-Texture: finetuned model (middle) vs. un-finetuned model (right)](doc/pc.gif)

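One plausible formalization of this metric — an illustrative sketch only, not necessarily the definition used by [test\_segmentation.py](./test_segmentation.py) — counts an image as "switched" when the object is not captured by the channel that captures it in the majority of images:
```
import numpy as np

def identity_switching_rate(ellipse_channel_ids):
    # ellipse_channel_ids: for each test image, the index of the output channel
    # whose mask covers the ellipse (hypothetical per-image labels)
    ids = np.asarray(ellipse_channel_ids)
    majority = np.bincount(ids).argmax()    # channel that usually takes the ellipse
    return float(np.mean(ids != majority))  # fraction of images that "switch"
```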
To compute the identity switching rate of the segmentation network, run
```
sh script/multi_texture/perceptual_consistency/test_segmentation.sh
```
We provide checkpoints for the two models [here](https://drive.google.com/drive/folders/1WCBgnPim9l5aMjbgAg1wBETd3QoLY-Wd?usp=sharing). If you'd like to explore the effectiveness by yourself, we recommend downloading the [model](https://drive.google.com/drive/folders/1X5kDp-1swauBKaFXF32gRy9wnJEvtCKe?usp=sharing) that has been trained for the first three steps and restoring it to finetune with perceptual consistency.

## Downloads

You can download our trained models for all datasets [here](https://drive.google.com/drive/folders/1AcFb2kfFpEuD-Wi_Iz_Z9-mkgs3anEOF?usp=sharing), including pretrained inpainting networks and final checkpoints of all modules.

## Citation

If you use this code in an academic context, please cite the following publication:

```
@InProceedings{Yang_2020_CVPR,
author = {Yang, Yanchao and Chen, Yutong and Soatto, Stefano},
title = {Learning to Manipulate Individual Objects in an Image},
booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2020}
}
```
--------------------------------------------------------------------------------
/common_flags.py:
--------------------------------------------------------------------------------
import gflags
FLAGS = gflags.FLAGS



# data
gflags.DEFINE_string('dataset', 'multi_dsprites/ multi_texture/ objects_room/ flying_animals', 'Dataset used')
gflags.DEFINE_string('root_dir', "/your/path/to/dataset", 'Folder containing the dataset')


gflags.DEFINE_integer("takenum", -1, 'number of examples to take; default -1: the entire dataset. Not used for flying_animals')
gflags.DEFINE_integer('skipnum', 2000, 'number of examples to skip; default 2000, reserved for the test set. Not used for flying_animals')
gflags.DEFINE_bool('shuffle', False, '')


# dir
gflags.DEFINE_string('checkpoint_dir', "checkpoint", "Experiment folder. It will contain "
                     "the saved checkpoints and tensorboard logs or disentanglement results")
#gflags.DEFINE_string('output_dir', "./outputs/0", "Containing outputs result when doing evaluation")
gflags.DEFINE_integer('summaries_secs', 40, 'number of seconds between computations of summaries, used in train_inpainter')
gflags.DEFINE_integer('summaries_steps', 100, 'number of steps between computations of summaries, used in train_CIS')
gflags.DEFINE_integer('ckpt_secs', 3600, 'number of seconds between checkpoint savings')
gflags.DEFINE_integer('ckpt_steps', 10000, 'number of steps between checkpoint savings')


# resume
gflags.DEFINE_bool('resume_fullmodel', False, 'whether to resume a full model')
gflags.DEFINE_bool('resume_inpainter', True, 'resume the pretrained inpainter for train_CIS; inpainter_ckpt needed')
gflags.DEFINE_bool('resume_resnet', False, 'whether to use a pretrained resnet (effective when resume_fullmodel=False)')
gflags.DEFINE_bool('resume_CIS', False, 'whether to resume the inpainter and generator')
# checkpoints to load, used for resumed training or evaluation
gflags.DEFINE_string('fullmodel_ckpt', '?', 'checkpoint of the full model: inpainter+generator (train_CIS) or inpainter+generator+VAE (train_end2end)')
gflags.DEFINE_string('CIS_ckpt', '?', 'checkpoint of inpainter + generator')
gflags.DEFINE_string('mask_ckpt', '?', '')
gflags.DEFINE_string('tex_ckpt', '?', '')
gflags.DEFINE_string('generator_ckpt', '?', 'checkpoint of the mask generator')
gflags.DEFINE_string('inpainter_ckpt', '?', 'checkpoint of the pretrained inpainter')
gflags.DEFINE_string('resnet_ckpt', 'resnet/resnet_v2_50.ckpt', 'checkpoint of the pretrained resnet')
# TODO: VAE (texture and shape)



gflags.DEFINE_integer('max_training_hrs', 72, 'maximum training hours')
# copy the sh

# mode
gflags.DEFINE_string('mode', 'train_CIS', 'pretrain_inpainter / train_CIS / train_VAE / eval_segment / eval_VAE / train_supGenerator')
gflags.DEFINE_string('sh_path', './train.sh', 'absolute path of the running shell script')



#
gflags.DEFINE_integer('batch_size', 32, 'batch_size')
gflags.DEFINE_integer('num_branch', 6, 'number of output channels of the segmentation')
gflags.DEFINE_integer('nobj', -1, 'number of objects, only used in evaluation or fixed-number training')

# network
gflags.DEFINE_string('model', 'resnet_v2_50', 'resnet_v2_50 or segnet')
# VAE
gflags.DEFINE_integer('tex_dim', 4, 'dimension of the texture latent space')
gflags.DEFINE_integer('mask_dim', 10, 'dimension of the mask latent space')
gflags.DEFINE_integer('bg_dim', 10, 'dimension of the bg latent space')
gflags.DEFINE_float('VAE_weight', 0, 'weight of the tex_error and mask_error losses for the Generator when training end2end')
gflags.DEFINE_float('CIS_weight', 1, 'weight of the CIS loss for the Generator when training end2end')
gflags.DEFINE_float('tex_beta', 10, 'ratio of the tex_error loss and the tex_kl loss')
gflags.DEFINE_float('mask_gamma', 50000, '')
gflags.DEFINE_float('mask_capacity_inc', 1e-5, 'increment of the mask capacity at each step')
gflags.DEFINE_float('bg_beta', 10, 'ratio of the bg_error loss and the bg_kl loss')

# hyperparameters
gflags.DEFINE_float('gen_lr', 1e-3, 'learning rate')
gflags.DEFINE_float('inp_lr', 1e-4, 'learning rate')
gflags.DEFINE_float('VAE_lr', 1e-4, 'learning rate')
gflags.DEFINE_float('epsilon', 40, 'epsilon in the denominator of IRR')
gflags.DEFINE_float('gen_clip_value', -1, "generator's grad_clip_value; -1 means no clipping")
gflags.DEFINE_integer('iters_inp', 1, 'iteration # of the inpainter')
gflags.DEFINE_integer('iters_gen', 3, 'iteration # of the generator')
gflags.DEFINE_integer('iters_gen_vae', 3, 'iteration # of the generator and vae, used when training end2end')
gflags.DEFINE_float('ita', 1e-3, 'weight of the perceptual consistency loss, used in train_PC mode')


# flying_animals only
gflags.DEFINE_integer('max_num', 5, 'max number of objects in the image')
#gflags.DEFINE_integer('min_num', 1, 'min number of objects in the image')
gflags.DEFINE_integer('bg_num', 100, 'number of backgrounds')
gflags.DEFINE_integer('ani_num', 240, 'number of animals')



# automatically set flags
gflags.DEFINE_integer('img_height', 64, '')
gflags.DEFINE_integer('img_width', 64, '')
gflags.DEFINE_integer('n_bg', 1, '')


# traverse
gflags.DEFINE_string('input_img', './?', 'input image path')
gflags.DEFINE_string('traverse_type', 'tex', 'tex or branch')
gflags.DEFINE_integer('top_kdim', 5, 'k dimensions with the largest KL divergence to traverse')
gflags.DEFINE_string('traverse_branch', 'all', 'all or #1,#2,#3')
gflags.DEFINE_float('traverse_range', '5', 'k: traverse z_mean +- k*sigma')
gflags.DEFINE_float('traverse_start', '-1', 'k: traverse from z_mean + k*sigma')
gflags.DEFINE_float('traverse_end', '1', 'k: traverse to z_mean + k*sigma')

gflags.DEFINE_string('VAE_loss', 'CE', 'CE or L1')
gflags.DEFINE_bool('PC', False, 'Experiment for perceptual consistency')
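These flags are consumed by main.py, which parses sys.argv with argv = FLAGS(argv). A minimal usage sketch follows (the flag values here are illustrative, not recommended settings; see the scripts under script/ for real configurations):

# minimal sketch: register the flags by importing common_flags, then parse a fake argv
from common_flags import FLAGS

remaining = FLAGS(['main.py', '--mode=train_CIS', '--dataset=multi_dsprites',
                   '--checkpoint_dir=checkpoint/demo', '--batch_size=32'])
print(FLAGS.mode, FLAGS.dataset, FLAGS.batch_size)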
--------------------------------------------------------------------------------
/data/flying_animals_data/fa_data_decode.py:
--------------------------------------------------------------------------------
import numpy as np
import imageio
import os
import argparse


# background / foreground indices excluded from the test set
rm_bg = [34,80,71,72,38,28,85,88,93,86,75,70,42,30,24,64,65,90,82,25,
         29,21,40,4,51,79,33,68,83,91,59,6,87,45,94,99,
         23,78,36,19,77,39,62,52,81,56,98,43]
rm_animals = [2,3,7,8,11,20,4,6,9,19]


parser = argparse.ArgumentParser(description='decode raw images into .npz')
parser.add_argument('--data_dir', type=str, default='./data', help='directory of the data: data_dir/foregrounds, backgrounds, masks')

args = parser.parse_args()

data_dir = args.data_dir

data = {'background': [], 'foreground': [], 'mask': []}
test_data = {'background': [], 'foreground': [], 'mask': []}

for i in range(241):
    fg = imageio.imread(os.path.join(data_dir, 'foregrounds', '{}.png'.format(i)))
    data['foreground'].append(fg)

    mask = imageio.imread(os.path.join(data_dir, 'masks', '{}.png'.format(i)))
    mask = mask.astype(np.bool)
    data['mask'].append(mask)

    if i not in rm_animals:
        test_data['foreground'].append(fg)
        test_data['mask'].append(mask)


for i in range(101):
    bg = imageio.imread(os.path.join(data_dir, 'backgrounds', '{}.png'.format(i)))
    data['background'].append(bg)

    if i not in rm_bg:
        test_data['background'].append(bg)



np.savez('img_data', background=data['background'], foreground=data['foreground'], mask=data['mask'])
np.savez('img_data_test', background=test_data['background'], foreground=test_data['foreground'], mask=test_data['mask'])
--------------------------------------------------------------------------------
/data/flying_animals_utils.py:
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import random 4 | import imageio 5 | import os 6 | import functools 7 | import scipy.ndimage as ndimage 8 | 9 | #online generator for flying animals dataset 10 | 11 | H, W = 192,256 12 | def convert(imgs): 13 | return (imgs/255).astype(np.float32) 14 | 15 | def check_occlusion(pos, pos_list): 16 | for p in pos_list: 17 | dist = abs(pos[0]-p[0])+abs(pos[1]-p[1]) 18 | if dist<=97: 19 | return True 20 | return False 21 | 22 | def generate_params(data, max_num, num): 23 | deter_params = [] 24 | bg_num = data['background'].shape[0]-1 25 | ani_num = data['foreground'].shape[0]-1 26 | for i in range(num): #size of the validation set 27 | param = dict() 28 | param['bg_index'], param['number'] = random.randint(1, bg_num), get_number(random.uniform(0,1)) 29 | param['fg_indices'] = [random.randint(1,ani_num) for k in range(param['number'])] +[0]*(max_num-param['number']) 30 | 31 | pos_list = [] 32 | params = [] 33 | for k in range(param['number']): 34 | f = random.uniform(0.9, 1.2) 35 | dx, dy = random.uniform(-W*2.2//5, W*2.2//5), random.uniform(-H*2.2//5, H*2.2//5) 36 | while check_occlusion([dx, dy], pos_list): 37 | dx, dy = random.uniform(-W*2.2//5, W*2.2//5), random.uniform(-H*2.2//5, H*2.2//5) 38 | pos_list.append([dx,dy]) 39 | p= [f,0,W/2-f*W/2-f*dx,0,f,H/2-f*H/2-f*dy,0,0] 40 | params.append(p) 41 | params += [[1,0,0,0,1,0,0,0]]*(max_num-param['number']) 42 | param['params'] = params #factor translation 43 | deter_params.append(param) 44 | return deter_params # list of dicts 5*100 element 45 | 46 | def get_number(n): 47 | #probability of distribution -- number of animals in an image 48 | if n<=0.8: 49 | return 5 #0.8 50 | elif n<=0.9: 51 | return 4 #0.1 52 | elif n<=0.96: 53 | return 3 #0.06 54 | elif n<=0.99: 55 | return 2 #0.03 56 | else: 57 | return 1 #0.01 58 | 59 | def params_gen(data, max_num, deterministic_params=None): 60 | backgrounds = convert(data['background']) 61 | foregrounds = convert(data['foreground']) 62 | ani_masks = data['mask'].astype(np.float32) 63 | 64 | bg_num = data['background'].shape[0]-1 65 | ani_num = data['foreground'].shape[0]-1 66 | 67 | step = 0 68 | while True: 69 | step += 1 70 | if deterministic_params: #200 images for validation 71 | param = deterministic_params[(step-1)%len(deterministic_params)] 72 | bg = backgrounds[param['bg_index']] 73 | texes = [foregrounds[i] for i in param['fg_indices']] 74 | masks = [ani_masks[i] for i in param['fg_indices']] 75 | yield bg, np.stack(texes, axis=0), np.stack(masks, axis=0), param['params'] 76 | else: #online generated data (infinite) 77 | number = get_number(random.uniform(0,1)) 78 | params = [] 79 | pos_list = [] 80 | texes = [] 81 | masks = [] 82 | bg = backgrounds[random.randint(1,bg_num)] 83 | for i in range(number): 84 | ind = random.randint(1,ani_num) 85 | texes.append(foregrounds[ind]) 86 | masks.append(ani_masks[ind]) 87 | f = random.uniform(0.9, 1.2) #input /output 88 | dx, dy = random.uniform(-W*2.2//5, W*2.2//5), random.uniform(-H*2.2//5, H*2.2//5) 89 | while check_occlusion([dx, dy], pos_list): 90 | dx, dy = random.uniform(-W*2.2//5, W*2.2//5), random.uniform(-H*2.2//5, H*2.2//5) 91 | pos_list.append([dx,dy]) 92 | param = [f,0,W/2-f*W/2-f*dx,0,f,H/2-f*H/2-f*dy,0,0] 93 | params.append(param) 94 | params += [[1,0,0,0,1,0,0,0]]*(max_num-number) 95 | texes += [backgrounds[0]]*(max_num-number) 96 | masks += [ani_masks[0]]*(max_num-number) 97 | 98 | yield bg, np.stack(texes, 
axis=0), np.stack(masks, axis=0), params 99 | 100 | 101 | def generate_image(bg, texes, masks, params, max_num): 102 | #given the randomly selected foregrounds, backgrounds and their factors of variation, synthesize the final image. 103 | 104 | #zoom and shift transform 105 | texes = tf.contrib.image.transform(texes, transforms=params,interpolation='BILINEAR') 106 | masks = tf.contrib.image.transform(masks, transforms=params,interpolation='BILINEAR') 107 | texes = tf.clip_by_value(texes, 0, 1) 108 | masks = tf.clip_by_value(masks, 0, 1) 109 | 110 | cum_mask = tf.zeros_like(masks[0]) 111 | depth_masks = [] 112 | perturbed_texes = [] 113 | for i in range(0,max_num): #depth order -> from near to far 114 | perturbed_tex = texes[i]+tf.random.uniform([], minval=-0.36,maxval=0.36, dtype=tf.float32) #brightness variation 115 | perturbed_texes.append(perturbed_tex) 116 | m = masks[i]*(1-cum_mask) 117 | cum_mask += m 118 | depth_masks.append(m) 119 | 120 | 121 | bg = bg + tf.random.uniform([], minval=-0.36,maxval=0.36, dtype=tf.float32) 122 | perturbed_texes = [bg] + perturbed_texes 123 | depth_masks = [1-cum_mask] + depth_masks 124 | perturbed_texes = tf.stack(perturbed_texes, axis=0) #C H W 3 125 | depth_masks = tf.stack(depth_masks, axis=0) #C H W 1 126 | 127 | img = tf.reduce_sum(perturbed_texes*depth_masks, axis=0) #H W 3 128 | 129 | 130 | data = {} 131 | data['img'] = img #float 0~1 132 | data['masks'] = tf.transpose(depth_masks, perm=[1,2,3,0]) #C H W 1 -> H W 1 C (bg first channel) 133 | return data 134 | 135 | 136 | def dataset(data_path, batch_size, max_num=5, phase='train'): 137 | """ 138 | Args: 139 | data_path: the path of npz data/flying_animals_data/img_data.npz 140 | batch_size: batchsize 141 | max_num: max number of animals in an image 142 | phase: train: infinitely online generating image/ val: deterministic 200 / test: deterministic 2000 143 | """ 144 | assert max_num==5,'please re-assign the distribution in get_number(), flying_animals_utils.py, if you need to reset max_num' 145 | 146 | if phase in ['val', 'test']: 147 | assert 200%batch_size==0 148 | data_path = os.path.abspath(data_path) 149 | data_path = data_path.split('.')[0]+'_test.npz' #data/flying_animals_data/img_data_test.npz 150 | 151 | data = np.load(data_path) 152 | deterministic_params = generate_params(data, max_num=max_num, num=200 if phase=='val' else 2000) if not phase=='train' else None 153 | partial_fn = functools.partial(params_gen, data=data, max_num=max_num, 154 | deterministic_params=deterministic_params) 155 | dataset = tf.data.Dataset.from_generator( 156 | partial_fn, 157 | (tf.float32, tf.float32, tf.float32, tf.float32), 158 | (tf.TensorShape([H,W,3]),tf.TensorShape([max_num,H,W,3]),tf.TensorShape([max_num,H,W,1]),tf.TensorShape([max_num,8]))) 159 | dataset = dataset.map(lambda bg,t,m,p: generate_image(bg,t,m,p,max_num=max_num), 160 | num_parallel_calls=1) 161 | 162 | dataset = dataset.batch(batch_size) 163 | dataset = dataset.prefetch(10) 164 | return dataset -------------------------------------------------------------------------------- /data/multi_dsprites_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Multi-dSprites dataset reader.""" 16 | 17 | import functools 18 | import tensorflow as tf 19 | COMPRESSION_TYPE = tf.io.TFRecordOptions.get_compression_type_string('GZIP') 20 | IMAGE_SIZE = [64, 64] 21 | # The maximum number of foreground and background entities in each variant 22 | # of the provided datasets. The values correspond to the number of 23 | # segmentation masks returned per scene. 24 | MAX_NUM_ENTITIES = { 25 | 'binarized': 4, 26 | 'colored_on_grayscale': 6, 27 | 'colored_on_colored': 5 28 | } 29 | BYTE_FEATURES = ['mask', 'image'] 30 | def feature_descriptions(max_num_entities, is_grayscale=False): 31 | """Create a dictionary describing the dataset features. 32 | Args: 33 | max_num_entities: int. The maximum number of foreground and background 34 | entities in each image. This corresponds to the number of segmentation 35 | masks and generative factors returned per scene. 36 | is_grayscale: bool. Whether images are grayscale. Otherwise they're assumed 37 | to be RGB. 38 | Returns: 39 | A dictionary which maps feature names to `tf.Example`-compatible shape and 40 | data type descriptors. 41 | """ 42 | 43 | num_channels = 1 if is_grayscale else 3 44 | return { 45 | 'image': tf.io.FixedLenFeature(IMAGE_SIZE+[num_channels], tf.string), #shape dtype 46 | 'mask': tf.io.FixedLenFeature(IMAGE_SIZE+[max_num_entities, 1], tf.string), 47 | 'x': tf.io.FixedLenFeature([max_num_entities], tf.float32), 48 | 'y': tf.io.FixedLenFeature([max_num_entities], tf.float32), 49 | 'shape': tf.io.FixedLenFeature([max_num_entities], tf.float32), 50 | 'color': tf.io.FixedLenFeature([max_num_entities, num_channels], tf.float32), 51 | 'visibility': tf.io.FixedLenFeature([max_num_entities], tf.float32), 52 | 'orientation': tf.io.FixedLenFeature([max_num_entities], tf.float32), 53 | 'scale': tf.io.FixedLenFeature([max_num_entities], tf.float32), 54 | } 55 | 56 | def _decode(example_proto, features): 57 | # Parse the input `tf.Example` proto using a feature description dictionary. 58 | single_example = tf.io.parse_single_example(example_proto, features) 59 | for k in BYTE_FEATURES: #mask image 60 | single_example[k] = tf.squeeze(tf.decode_raw(single_example[k], tf.uint8), 61 | axis=-1) # height width entities channels 62 | # To return masks in the canonical [entities, height, width, channels] format, 63 | # we need to transpose the tensor axes. 
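  # Note: unlike the canonical [entities, height, width, channels] layout described
  # above, this modified reader keeps masks as [height, width, channels, entities]
  # (H W 1 M), which is the layout the rest of this codebase expects.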
64 | single_example['mask'] = tf.transpose(single_example['mask'], [0, 1, 3, 2]) #H W 1 M 65 | 66 | return map(single_example) 67 | 68 | def map(x): 69 | img = x['image'] 70 | img = tf.cast(img, tf.float32) 71 | img = img/255 72 | mask = x['mask'] 73 | mask = tf.cast(mask, tf.float32) 74 | mask = mask/255 #0~1 75 | data = {} 76 | data['img'] = img 77 | data['masks'] = mask 78 | return data 79 | 80 | def dataset(tfrecords_path, batch_size, phase='train'): 81 | if phase=='test': 82 | skipnum, takenum = 0,2000 83 | shuffle = False 84 | elif phase=='val': 85 | skipnum, takenum = 2000,1000 86 | shuffle = False 87 | else: 88 | skipnum, takenum = 3000, -1 89 | shuffle = True 90 | max_num_entities = MAX_NUM_ENTITIES['colored_on_colored'] #colored on colored -> 5 91 | raw_dataset = tf.data.TFRecordDataset( 92 | tfrecords_path, compression_type=COMPRESSION_TYPE) 93 | raw_dataset = raw_dataset.skip(skipnum).take(takenum) 94 | features = feature_descriptions(max_num_entities, False) 95 | partial_decode_fn = functools.partial(_decode, features=features) 96 | 97 | dataset = raw_dataset.map(partial_decode_fn,num_parallel_calls=1) 98 | if shuffle: 99 | dataset = dataset.shuffle(seed=479, buffer_size=50000, reshuffle_each_iteration=True) 100 | dataset = dataset.repeat().batch(batch_size) 101 | dataset = dataset.prefetch(10) 102 | return dataset 103 | -------------------------------------------------------------------------------- /data/multi_texture_data/bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/data/multi_texture_data/bg.png -------------------------------------------------------------------------------- /data/multi_texture_data/ellipse_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/data/multi_texture_data/ellipse_2.png -------------------------------------------------------------------------------- /data/multi_texture_data/square_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/data/multi_texture_data/square_2.png -------------------------------------------------------------------------------- /data/multi_texture_data/tex.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/data/multi_texture_data/tex.png -------------------------------------------------------------------------------- /data/multi_texture_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import random 4 | import imageio 5 | import os 6 | import functools 7 | import scipy.ndimage as ndimage 8 | H,W=64,64 9 | pos_choice = [-24,-18,-12,-6,0,6,12,18,24] 10 | n_pos = 9 11 | 12 | def check_occlusion(pos, pos_list): 13 | for p in pos_list: 14 | dist = abs(pos[0]-p[0])+abs(pos[1]-p[1]) 15 | if dist<=10: 16 | return True 17 | return False 18 | 19 | def 
generate_params(data_path, num, max_num, PC): 20 | deterministic_params = [] 21 | if PC: 22 | num = n_pos*n_pos*n_pos*n_pos-1 23 | for i in range(num): 24 | param = dict() 25 | param['number'] = 2 if PC else get_number(random.uniform(0,1)) 26 | param['ind'], param['mat'] = [],[] 27 | pos_list = [] 28 | 29 | dr_=[pos_choice[i%n_pos], pos_choice[((i//n_pos)//n_pos)%n_pos]] 30 | dc_=[pos_choice[(i//n_pos)%n_pos], pos_choice[(((i//n_pos)//n_pos)//n_pos)%n_pos]] 31 | 32 | for k in range(param['number']): 33 | if PC: 34 | param['ind'].append(k) 35 | else: 36 | param['ind'].append(random.randint(0,1)) #shape 0 square 1 ellipse 37 | 38 | if PC: 39 | dr, dc = dr_[k], dc_[k] 40 | else: 41 | dr = random.uniform(-H/2,H/2) 42 | dc = random.uniform(-W/2,W/2) 43 | while check_occlusion([dr,dc], pos_list): 44 | dr = random.uniform(-H/2,H/2) 45 | dc = random.uniform(-W/2,W/2) 46 | pos_list.append([dr,dc]) 47 | 48 | mat = np.zeros([3,4]) 49 | mat[0][0], mat[0][3] = 1, dr 50 | mat[1][1], mat[1][3] = 1, dc 51 | mat[2][2] = 1 52 | param['mat'].append(mat) 53 | param['hue'] = [-0.2,0,0.2] if PC else [random.uniform(-0.45,0.45) for h_ in range(max_num+1)] 54 | deterministic_params.append(param) 55 | return deterministic_params 56 | 57 | def multi_texture_gen(data_path, max_num=4, deterministic_params=None, PC=False): 58 | tex0 = imageio.imread(os.path.join(data_path,'tex.png')) 59 | tex0 = (tex0/255).astype(np.float32) 60 | square = imageio.imread(os.path.join(data_path,'square_2.png')).reshape(64,64,1) 61 | square = (square/255).astype(np.float32) 62 | ellipse = imageio.imread(os.path.join(data_path,'ellipse_2.png')).reshape(64,64,1) 63 | ellipse = (ellipse/255).astype(np.float32) 64 | masks = [ellipse, square] 65 | 66 | step = 0 67 | while True: 68 | param = deterministic_params[step%(len(deterministic_params))] if deterministic_params else None 69 | step += 1 70 | if param: 71 | number = param['number'] 72 | elif PC: 73 | number = 2 74 | else: 75 | number = get_number(random.uniform(0,1)) 76 | shape_masks = [] 77 | shape_texes = [] # 78 | cum_mask = np.zeros_like(masks[0]) # occlusion 79 | pos_list = [] 80 | for i in range(number): #place the randomly selected and transformed shape on the background in the ascending depth order 81 | if param: 82 | ind = param['ind'][i] 83 | elif PC: 84 | ind = i 85 | else: 86 | ind = random.randint(0, 1) #choose the shape 87 | shape = masks[ind].copy() 88 | tex = tex0.copy() 89 | if param: 90 | mat = param['mat'][i] 91 | else: 92 | if PC: 93 | dr = random.choice(pos_choice) 94 | dc = random.choice(pos_choice) 95 | else: 96 | dr = random.uniform(-H/2,H/2) 97 | dc = random.uniform(-W/2,W/2) 98 | while check_occlusion([dr,dc], pos_list): 99 | dr = random.uniform(-H/2,H/2) 100 | dc = random.uniform(-W/2,W/2) 101 | pos_list.append([dr,dc]) 102 | mat = np.zeros([3,4]) 103 | mat[0][0], mat[0][3] = 1, dr 104 | mat[1][1], mat[1][3] = 1, dc 105 | mat[2][2] = 1 106 | shape = ndimage.affine_transform(shape, mat, output_shape=(64,64,1)) 107 | shape = np.clip(shape, 0,1) 108 | shape = shape*(1-cum_mask) 109 | cum_mask += shape 110 | shape_masks.append(shape) 111 | shape_texes.append(tex) 112 | for i in range(number, max_num):#pad the returned element 113 | shape_masks.append(np.zeros_like(shape)) 114 | shape_texes.append(np.zeros_like(tex)) 115 | hue_value = param['hue'] if deterministic_params else [random.uniform(-0.45,0.45) for h_ in range(max_num+1)] 116 | yield (number, np.stack(shape_masks, axis=0), np.stack(shape_texes, axis=0), hue_value) 117 | 118 | def get_number(n): 119 | if 
n<=0.8: 120 | return 4 #0.8 121 | elif n<=0.9: 122 | return 3 #0.1 123 | elif n<=0.97: 124 | return 2 #0.07 125 | else: 126 | return 1 #0.03 127 | 128 | def combine(number, masks, texes, hue_value, bg, max_num): 129 | hue_texes = [] 130 | for i in range(max_num): 131 | hue_texes.append(tf.image.adjust_hue(texes[i], hue_value[i+1])) #randomly perturb the hue value 132 | hue_texes = tf.stack(hue_texes, axis=0) #C H W 1 133 | fg = tf.reduce_sum(masks*hue_texes, axis=0)#np.sum(masks*texes, axis=0, keepdims=False) #H W 3 134 | bg_mask = 1-tf.reduce_sum(masks, axis=0) # H W 1 135 | bg = tf.image.adjust_hue(bg, hue_value[0]) 136 | 137 | all_masks = tf.concat([tf.expand_dims(bg_mask,axis=0), masks], axis=0) #H W 1 + C H W 1 138 | 139 | data = {} 140 | data['img'] = bg_mask*bg+fg 141 | data['masks'] = tf.transpose(all_masks, perm=[1,2,3,0]) #C H W 1 -> H W 1 C 142 | return data 143 | 144 | def dataset(data_path, batch_size, max_num=4, phase='train', PC=False): 145 | """ 146 | Args: 147 | data_path: the path of combination elements 148 | batch_size: 149 | max_num: max_num: max number of objects in an image 150 | phase: train: infinitely online generating image/ val: deterministic 200 / test: deterministic 2000 151 | PC: experiment for Perceptual consistency 152 | """ 153 | if PC: 154 | assert max_num==2, 'set max_num as 2 in experiment for Perceptual consistency' 155 | else: 156 | assert max_num==4,'please re-assign the distribution in get_number(), multi_texture_utils.py, if you need to reset max_num' 157 | deterministic_params = generate_params(data_path, num=200 if phase=='val' else 2000, max_num=max_num, PC=PC) if not phase=='train' else None 158 | 159 | partial_fn = functools.partial(multi_texture_gen, 160 | data_path=data_path, max_num=max_num, deterministic_params=deterministic_params, PC=PC) 161 | 162 | dataset = tf.data.Dataset.from_generator( 163 | partial_fn,#(data_path, max_num, zoom, rotation), 164 | (tf.int32, tf.float32, tf.float32, tf.float32), 165 | (tf.TensorShape([]),tf.TensorShape([max_num,H,W,1]), tf.TensorShape([max_num,H,W,3]), tf.TensorShape([max_num+1]))) 166 | 167 | bg0 = imageio.imread(os.path.join(data_path,'bg.png')) 168 | bg0 = tf.convert_to_tensor(bg0/255, dtype=tf.float32) #0~1 169 | dataset = dataset.map(lambda n,m,t,h: combine(n,m,t,h, bg=bg0,max_num=max_num), num_parallel_calls=1) 170 | dataset = dataset.batch(batch_size) 171 | # # print (dataset) 172 | dataset = dataset.prefetch(10) 173 | return dataset 174 | 175 | 176 | -------------------------------------------------------------------------------- /data/objects_room_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Objects Room dataset reader.""" 16 | 17 | import functools 18 | import tensorflow as tf 19 | 20 | 21 | COMPRESSION_TYPE = tf.io.TFRecordOptions.get_compression_type_string('GZIP') 22 | IMAGE_SIZE = [64, 64] 23 | # The maximum number of foreground and background entities in each variant 24 | # of the provided datasets. The values correspond to the number of 25 | # segmentation masks returned per scene. 26 | MAX_NUM_ENTITIES = { 27 | 'train': 7, 28 | 'six_objects': 10, 29 | 'empty_room': 4, 30 | 'identical_color': 10 31 | } 32 | BYTE_FEATURES = ['mask', 'image'] 33 | 34 | 35 | def feature_descriptions(max_num_entities): 36 | """Create a dictionary describing the dataset features. 37 | Args: 38 | max_num_entities: int. The maximum number of foreground and background 39 | entities in each image. This corresponds to the number of segmentation 40 | masks returned per scene. 41 | Returns: 42 | A dictionary which maps feature names to `tf.Example`-compatible shape and 43 | data type descriptors. 44 | """ 45 | return { 46 | 'image': tf.FixedLenFeature(IMAGE_SIZE+[3], tf.string), 47 | 'mask': tf.FixedLenFeature([max_num_entities]+IMAGE_SIZE+[1], tf.string), 48 | } 49 | 50 | 51 | def _decode(example_proto, features, random_sky): 52 | # Parse the input `tf.Example` proto using a feature description dictionary. 53 | single_example = tf.parse_single_example(example_proto, features) 54 | for k in BYTE_FEATURES: 55 | single_example[k] = tf.squeeze(tf.decode_raw(single_example[k], tf.uint8), 56 | axis=-1) 57 | # sky floor half-wall1(2) half-wall2(3) objects objects objects 58 | mask = tf.transpose(single_example['mask'], [1, 2, 3, 0]) #H W 1 7 59 | mask = tf.concat([mask[:,:,:,0:2],mask[:,:,:,2:3]+mask[:,:,:,3:4],mask[:,:,:,4:]], axis=-1) #H W 1 6 merge the wall 60 | single_example['mask'] = mask 61 | return map(single_example, random_sky) 62 | 63 | 64 | def map(x, random_sky): 65 | img = tf.cast(x['image']/255, tf.float32) #0~1 66 | mask = tf.cast(x['mask']/255, tf.float32) 67 | 68 | data = {} 69 | data['masks'] = mask 70 | 71 | sky_mask = mask[:,:,:,0]#H W 1 72 | scale = tf.random_uniform(shape=[], minval=0.2, maxval=1, dtype=tf.float32) if random_sky else 1 73 | var_img = img*(1-sky_mask)+img*scale*sky_mask #0~1 74 | data['img'] = var_img 75 | return data 76 | 77 | def dataset(tfrecords_path, batch_size, phase='train'): 78 | if phase=='test': 79 | skipnum, takenum = 0,2000 80 | shuffle = False 81 | elif phase=='val': 82 | skipnum, takenum = 2000,1000 83 | shuffle = False 84 | else: 85 | skipnum, takenum = 3000, -1 86 | shuffle = True 87 | 88 | max_num_entities = MAX_NUM_ENTITIES['train'] 89 | raw_dataset = tf.data.TFRecordDataset( 90 | tfrecords_path, compression_type=COMPRESSION_TYPE, 91 | buffer_size=50, num_parallel_reads=2) 92 | features = feature_descriptions(max_num_entities) 93 | partial_decode_fn = functools.partial(_decode, features=features, random_sky=(phase=='train')) #val constant sky 94 | 95 | dataset = raw_dataset.skip(skipnum).take(takenum) 96 | dataset = dataset.map(partial_decode_fn, num_parallel_calls=1) 97 | if shuffle: 98 | dataset = dataset.shuffle(seed=479, buffer_size=batch_size*100, reshuffle_each_iteration=True) 99 | dataset = dataset.repeat().batch(batch_size) 100 | dataset = dataset.prefetch(10) 101 | return dataset 102 | 103 | 104 | # import imageio 105 | # import numpy as np 106 | # bs = 4 107 | # dataset = dataset('./objects_room_data/objects_room_train.tfrecords',val=True, 108 | # 
batch_size=bs, skipnum=0, takenum=-1, 109 | # shuffle=False, map_parallel_calls=1) 110 | 111 | # iterator = dataset.make_one_shot_iterator() 112 | 113 | # data_batch = iterator.get_next() 114 | # edge_batch = bin_edge_map(data_batch['img'], 'objects_room') 115 | 116 | # sess = tf.Session() 117 | # for i in range(3): 118 | # data, edges = sess.run((data_batch, edge_batch)) 119 | # for k in range(bs): 120 | # img = data['img'][k,:,:,:] 121 | # masks = data['masks'][k,:,:,:,:] 122 | # imageio.imwrite('debug/{}_{}img.png'.format(i,k), (img*255).astype(np.uint8)) 123 | 124 | # edge = edges[k,:,:,:] #H W 2 125 | # show = np.concatenate([edge, np.zeros_like(edge[:,:,0:1])], axis=-1)#H W 3 126 | # imageio.imwrite('debug/{}_{}edge.png'.format(i,k), (show*255).astype(np.uint8)) 127 | # for m in range(6): 128 | # imageio.imwrite('debug/{}_{}mask{}.png'.format(i,k,m), (masks[:,:,:,m]*img*255).astype(np.uint8)) -------------------------------------------------------------------------------- /doc/flying_animals.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/doc/flying_animals.gif -------------------------------------------------------------------------------- /doc/multi_texture.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/doc/multi_texture.gif -------------------------------------------------------------------------------- /doc/objects_room.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/doc/objects_room.gif -------------------------------------------------------------------------------- /doc/pc.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/doc/pc.gif -------------------------------------------------------------------------------- /eval/eval_VAE.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import gflags 4 | #https://github.com/google/python-gflags 5 | import sys 6 | sys.path.append("..") 7 | import pprint 8 | from keras.utils.generic_utils import Progbar 9 | import model.Summary as Summary 10 | from model.utils.generic_utils import myprint, myinput, Permute_IoU 11 | from model.traverse_graph import Traverse_Graph 12 | import imageio 13 | import numpy as np 14 | import time 15 | 16 | 17 | def convert2float(img): 18 | return (img/255).astype(np.float32) 19 | def convert2int(img): 20 | return (img*255).astype(np.uint8) 21 | 22 | 23 | def pad_img(img): 24 | H,W,C = img.shape 25 | pad_img = (np.ones([H+4, W+4,C])*255).astype(np.uint8) 26 | pad_img[2:2+H,2:2+W,:] = img #H W 3 27 | return pad_img 28 | 29 | def eval(FLAGS): 30 | graph = Traverse_Graph(FLAGS) 31 | graph.build() 32 | 33 | restore_vars = tf.global_variables('VAE') + tf.global_variables('Generator') + tf.global_variables('Fusion') 34 | saver = tf.train.Saver(restore_vars) 35 | 36 | 
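    # Only variables under the 'VAE', 'Generator' and 'Fusion' scopes are restored
    # from the checkpoint below; all other variables keep their fresh initialization
    # from global_variables_initializer().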
#CIS_saver = tf.train.Saver(tf.global_variables('Generator')) 37 | with tf.Session() as sess: 38 | sess.run(tf.compat.v1.global_variables_initializer()) 39 | assert os.path.isfile(FLAGS.fullmodel_ckpt+'.index') 40 | saver.restore(sess, FLAGS.fullmodel_ckpt) 41 | # CIS_saver.restore(sess, FLAGS.CIS_ckpt) 42 | 43 | #saver.save(sess, '/home/yutong/Learning-to-manipulate-individual-objects-in-an-image-Implementation/save_checkpoint/md/model', global_step=0) 44 | myprint("resume model {}".format(FLAGS.fullmodel_ckpt)) 45 | fetches = { 46 | 'image_batch': graph.image_batch, 47 | 'generated_masks': graph.generated_masks, 48 | 'traverse_results': graph.traverse_results, 49 | 'out_bg': graph.out_bg, 50 | 'in_bg': graph.in_bg 51 | } 52 | assert FLAGS.batch_size==1 53 | input_img = convert2float(imageio.imread(FLAGS.input_img)) 54 | input_img = np.expand_dims(input_img, axis=0) 55 | 56 | results = sess.run(fetches, feed_dict={graph.image_batch0: input_img}) 57 | img = convert2int(results['image_batch'][0]) 58 | 59 | imageio.imwrite(os.path.join(FLAGS.checkpoint_dir, 'img.png'), img) 60 | for i in range(FLAGS.num_branch): 61 | imageio.imwrite(os.path.join(FLAGS.checkpoint_dir, 'segment_{}.png'.format(i)), convert2int(results['generated_masks'][0,:,:,:,i]*results['image_batch'][0])) 62 | 63 | outputs = np.array(results['traverse_results']) 64 | 65 | if FLAGS.traverse_type=='tex': 66 | nch = 3 67 | ndim = FLAGS.tex_dim 68 | elif FLAGS.traverse_type=='bg': 69 | nch = 3 70 | ndim = FLAGS.bg_dim 71 | else: 72 | nch = 1 73 | ndim = FLAGS.mask_dim 74 | 75 | if FLAGS.traverse_type=='bg': 76 | traverse_branch = [FLAGS.num_branch-1] 77 | else: 78 | traverse_branch = [i for i in range(0,FLAGS.num_branch) if FLAGS.traverse_branch=='all' or str(i) in FLAGS.traverse_branch.split(',')] 79 | traverse_value = list(np.linspace(FLAGS.traverse_start, FLAGS.traverse_end, 60)) 80 | 81 | 82 | if FLAGS.dataset == 'flying_animals': 83 | outputs = np.reshape(outputs, [len(traverse_branch), FLAGS.top_kdim, len(traverse_value),FLAGS.img_height//2,FLAGS.img_width//2,-1]) 84 | else: 85 | outputs = np.reshape(outputs, [len(traverse_branch), FLAGS.top_kdim, len(traverse_value),FLAGS.img_height,FLAGS.img_width,-1]) 86 | #tbranch * tdim * step * H * W * 3 87 | 88 | 89 | branches = [] 90 | for i in range(len(traverse_branch)): 91 | values = [[None for jj in range(FLAGS.top_kdim) ] for ii in range(len(traverse_value))] 92 | b = traverse_branch[i] 93 | out = outputs[i] #tdim * step* H * W * 3 94 | for d in range(FLAGS.top_kdim): 95 | gif_imgs = [] 96 | for j in range(len(traverse_value)): 97 | img = (out[d,j,:,:,:]*255).astype(np.uint8) 98 | gif_imgs.append(img) 99 | values[j][d] = pad_img(img) 100 | name = 'branch{}_var{}.gif'.format(b, d) 101 | imageio.mimsave(os.path.join(FLAGS.checkpoint_dir, name), gif_imgs, duration=1/30) 102 | 103 | #values len(traverse_value) * kdim (img) 104 | value_slices = [np.concatenate(values[j], axis=1) for j in range(len(traverse_value))] #group different dimensions along the axis x 105 | #len(traverse_value)*(H*W*3) 106 | branches.append(value_slices) 107 | merge_slices = [np.concatenate([branches[i][j] for i in range(len(traverse_branch))], axis=0) for j in range(len(traverse_value))] 108 | 109 | 110 | #imageio.mimsave(os.path.join(FLAGS.checkpoint_dir, 'output.gif'), merge_slices, duration=1/30) 111 | 112 | 113 | 114 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import 
tensorflow as tf
from itertools import count
import os
import gflags
from git import Repo
import sys
import pprint
from common_flags import FLAGS
import warnings
from trainer import train_inpainter, train_CIS, train_VAE, train_PC
from eval import eval_VAE
from model.utils.generic_utils import myprint, myinput
import random
import numpy as np

def save_log(source, trg_dir, print_flags_dict, sha):
    file_name = source.split('/')[-1]
    new_file = os.path.join(trg_dir, file_name)
    log_name = 'log'
    while os.path.isfile(new_file):
        new_file = new_file[:-3]+'_c.sh'  #.sh
        log_name += '_c'
    os.system('cp '+source+' '+new_file)
    myprint("Save "+source+" as "+new_file)
    log_file = os.path.join(trg_dir, log_name+'.txt')
    with open(log_file, 'w') as log_stream:
        log_stream.write('commit:' + sha + '\n')
        pprint.pprint(print_flags_dict, log_stream)
    with open(new_file, 'a') as sh_stream:
        sh_stream.write('\n#commit:'+sha)
    myprint('Corresponding log file '+log_file)
    myinput("Enter to continue")
    os.system('chmod a=rx '+log_file)
    os.system('chmod a=rx '+new_file)
    return

def complete_FLAGS(FLAGS):
    # complete some configuration given the specified dataset

    img_size_dict = {'multi_texture': (64,64),
                     'multi_dsprites': (64,64),
                     'objects_room': (64,64),
                     'flying_animals': (192,256)}
    max_num_dict = {'multi_texture': 4,
                    'objects_room': 5,
                    'multi_dsprites': 4,
                    'flying_animals': 5}
    FLAGS.img_height, FLAGS.img_width = img_size_dict[FLAGS.dataset]
    FLAGS.max_num = max_num_dict[FLAGS.dataset]
    if FLAGS.PC and FLAGS.dataset=='multi_texture':
        FLAGS.max_num = 2

    if FLAGS.mode == 'pretrain_inpainter':
        FLAGS.num_branch = 2
    else:
        assert FLAGS.num_branch >= FLAGS.max_num+1

    FLAGS.n_bg = 3 if FLAGS.dataset=='objects_room' else 1
    return

def main(argv):
    try:
        argv = FLAGS(argv)
    except gflags.FlagsError as e:
        print('FlagsError: ', e)
        sys.exit(1)
    else:
        tf.compat.v1.set_random_seed(479)
        random.seed(101)
        complete_FLAGS(FLAGS)
        pp = pprint.PrettyPrinter()
        print_flags_dict = {}
        for key in FLAGS.__flags.keys():
            print_flags_dict[key] = getattr(FLAGS, key)
        pp.pprint(print_flags_dict)
        myinput("Press enter to continue")

        repo = Repo()
        sha = repo.head.object.hexsha
        FLAGS.checkpoint_dir = FLAGS.checkpoint_dir[:-1] if FLAGS.checkpoint_dir[-1]=='/' else FLAGS.checkpoint_dir
        if os.path.exists(FLAGS.checkpoint_dir):
            I = myinput(FLAGS.checkpoint_dir+' already exists. \n Are you sure to'
                        ' place the outputs in the same dir? Y or Y! or N\n'
                        'Y: resume training, save previous outputs in the dir and continue saving outputs in it\n'
                        'Y!: restart training, delete previous outputs in the dir\n'
                        'N to quit \n')
            if I in ['Y','y']:
                save_log(FLAGS.sh_path, FLAGS.checkpoint_dir, print_flags_dict, sha)
                import time
                tf.compat.v1.set_random_seed(time.localtime()[5]*10)
                random.seed(time.localtime()[4]*10)  # new random seed
            elif I in ['N','n']:
                sys.exit(1)
            else:
                os.system('rm -f -r '+FLAGS.checkpoint_dir+'/*')
                save_log(FLAGS.sh_path, FLAGS.checkpoint_dir, print_flags_dict, sha)
        else:
            os.makedirs(FLAGS.checkpoint_dir)
            save_log(FLAGS.sh_path, FLAGS.checkpoint_dir, print_flags_dict, sha)


        assert FLAGS.mode in ['pretrain_inpainter', 'train_CIS', 'train_VAE', 'train_PC', 'train_end2end',
                              'eval_CIS', 'eval_VAE']

        if FLAGS.mode == 'pretrain_inpainter':
            train_inpainter.train(FLAGS)
        elif FLAGS.mode == 'train_CIS':
            train_CIS.train(FLAGS)
        elif FLAGS.mode == 'eval_VAE':
            eval_VAE.eval(FLAGS)
        # elif FLAGS.mode == 'train_end2end':
        #     train_end2end.train(FLAGS)
        elif FLAGS.mode == 'train_PC':
            train_PC.train(FLAGS)
        elif FLAGS.mode == 'train_VAE':
            train_VAE.train(FLAGS)
        else:
            pass
        # elif FLAGS.mode == 'eval_CIS':
        #     eval_CIS.eval(FLAGS)

if __name__ == '__main__':
    main(sys.argv)
--------------------------------------------------------------------------------
/model/Summary.py:
--------------------------------------------------------------------------------
import tensorflow as tf

def convert2uint8(img_list):
    # float [0~1] -> uint8 [0~255]
    new_img_list = [tf.cast(img*255, tf.uint8) for img in img_list]
    return new_img_list

def collect_globalVAE_summary(graph, FLAGS):
    ori = graph.image_batch[0]
    reconstr = graph.out_imgs[0]
    show_list = convert2uint8([ori, reconstr])
    tf.compat.v1.summary.image('image output', tf.stack(show_list, axis=0), max_outputs=len(show_list), collections=["globalVAE_Sum"])

    tf.summary.scalar('Reconstruction_Loss', graph.loss, collections=["globalVAE_Sum"])
    tf.summary.scalar('latent_space', graph.kl_var, collections=['globalVAE_Sum_kl'])

    for grad, var in graph.train_vars_grads:
        tf.summary.histogram(var.op.name+'/grad', grad, collections=['globalVAE_Sum'])

    return tf.summary.merge(tf.compat.v1.get_collection("globalVAE_Sum")), \
           tf.summary.merge(tf.compat.v1.get_collection("globalVAE_Sum_kl"))


def collect_CIS_summary(graph, FLAGS):
    #---- images to show, same as Inpainter_Sum ----
    ori = graph.image_batch[0]
    edge = tf.concat([graph.edge_map[0], tf.zeros_like(graph.edge_map[0,:,:,0:1])], axis=-1)  #H W 3
    mean = graph.unconditioned_mean[0]
    show_list = convert2uint8([ori, edge, mean])  #0~255
    #show_list = [tf.cast(edge*128, tf.uint8)]
    tf.compat.v1.summary.image('image_edge_unconditionedMean',
        tf.stack(show_list, axis=0), max_outputs=len(show_list),
        collections=["CIS_Sum"])

    for i in range(FLAGS.num_branch):
        mask = graph.generated_masks[0,:,:,:,i]
        #aug = mask*ori + ori*(1-mask)*0.2
        context = ori*(1-mask)  # H W 3
        GT = ori*mask
        predict = graph.pred_intensities[0,:,:,:,i]*mask
        show_list = convert2uint8([GT, context, predict])
        tf.compat.v1.summary.image('branch{}'.format(i),
            tf.stack(show_list, axis=0),
/model/Summary.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def convert2uint8(img_list): 4 | #float[0~1] -> uint8[0~255] 5 | new_img_list = [tf.cast(img*255, tf.uint8) for img in img_list] 6 | return new_img_list 7 | 8 | def collect_globalVAE_summary(graph, FLAGS): 9 | ori = graph.image_batch[0] 10 | reconstr = graph.out_imgs[0] 11 | show_list = convert2uint8([ori, reconstr]) 12 | tf.compat.v1.summary.image('image output', tf.stack(show_list, axis=0), max_outputs=len(show_list), collections=["globalVAE_Sum"]) 13 | 14 | tf.summary.scalar('Reconstruction_Loss', graph.loss, collections=["globalVAE_Sum"]) 15 | tf.summary.scalar('latent_space', graph.kl_var, collections=['globalVAE_Sum_kl']) 16 | 17 | for grad, var in graph.train_vars_grads: 18 | tf.summary.histogram(var.op.name+'/grad', grad, collections=['globalVAE_Sum']) 19 | 20 | return tf.summary.merge(tf.compat.v1.get_collection("globalVAE_Sum")), \ 21 | tf.summary.merge(tf.compat.v1.get_collection("globalVAE_Sum_kl")) 22 | 23 | 24 | def collect_CIS_summary(graph, FLAGS): 25 | #----images to show, same as Inpainter_Sum---- 26 | ori = graph.image_batch[0] 27 | edge = tf.concat([graph.edge_map[0], tf.zeros_like(graph.edge_map[0,:,:,0:1])], axis=-1) #H W 3 28 | mean = graph.unconditioned_mean[0] 29 | show_list = convert2uint8([ori, edge,mean]) #0~255 30 | #show_list = [tf.cast(edge*128, tf.uint8)] 31 | tf.compat.v1.summary.image('image_edge_unconditionedMean', 32 | tf.stack(show_list, axis=0), max_outputs=len(show_list), 33 | collections=["CIS_Sum"]) 34 | 35 | for i in range(FLAGS.num_branch): 36 | mask = graph.generated_masks[0,:,:,:,i] 37 | #aug = mask*ori + ori*(1-mask)*0.2 38 | context = ori *(1-mask) # H W 3 39 | GT = ori*mask 40 | predict = graph.pred_intensities[0,:,:,:,i]*mask 41 | show_list = convert2uint8([GT,context,predict]) 42 | tf.compat.v1.summary.image('branch{}'.format(i), 43 | tf.stack(show_list, axis=0), max_outputs=len(show_list), 44 | collections=["CIS_Sum"]) 45 | 46 | #-----curves to show 47 | #plot multiple curves in one figure 48 | tf.summary.scalar('Inpainter_Loss', graph.loss['Inpainter'], collections=['CIS_Sum']) 49 | tf.summary.scalar('Generator_Loss', graph.loss['Generator_var'], collections=['CIS_Sum_Generator']) 50 | tf.summary.scalar('Inpainter_Loss/ branch', graph.loss['Inpainter_branch_var'], collections=['CIS_Sum_branch']) 51 | tf.summary.scalar('Generator_Loss/ branch', graph.loss['Generator_branch_var'], collections=['CIS_Sum_branch']) 52 | tf.summary.scalar('Generator_Loss/ denominator', graph.loss['Generator_denominator_var'], collections=['CIS_Sum_branch']) 53 | tf.summary.scalar('IoU Validation',graph.loss['EvalIoU_var'], collections=['CIS_eval']) 54 | 55 | #------histogram to show 56 | for grad, var in graph.train_vars_grads['Generator']: 57 | tf.summary.histogram(var.op.name+'/grad', grad, collections=['CIS_Sum']) 58 | for grad, var in graph.train_vars_grads['Inpainter']: 59 | tf.summary.histogram(var.op.name+'/grad', grad, collections=['CIS_Sum']) 60 | return tf.summary.merge(tf.compat.v1.get_collection("CIS_Sum")), \ 61 | tf.summary.merge(tf.compat.v1.get_collection("CIS_Sum_Generator")), \ 62 | tf.summary.merge(tf.compat.v1.get_collection("CIS_Sum_branch")), \ 63 | tf.summary.merge(tf.compat.v1.get_collection("CIS_eval")) 64 | 65 | 66 | def collect_VAE_summary(graph, FLAGS): 67 | ori = graph.image_batch[0] 68 | fusion = graph.fusion_outputs[0] 69 | show_list = convert2uint8([ori, fusion]) 70 | tf.compat.v1.summary.image('image output', tf.stack(show_list, axis=0), max_outputs=len(show_list), collections=["VAE_Sum"]) 71 | 72 | 73 | seg_masks = tf.transpose(graph.generated_masks[0,:,:,:,:]*tf.expand_dims(ori, axis=-1),[3,0,1,2]) #N H W 3 74 | tf.compat.v1.summary.image('segmentation', tf.cast(seg_masks*255,tf.uint8), max_outputs=FLAGS.num_branch, collections=["VAE_Sum"]) 75 | 76 | for i in range(FLAGS.num_branch-FLAGS.n_bg): 77 | seg = graph.generated_masks[0,:,:,:,i]*ori 78 | out_mask = tf.tile(graph.VAE_outmasks[0,:,:,:,i],[1,1,3]) 79 | out_tex = graph.VAE_outtexes[0,:,:,:,i] 80 | fusion = graph.VAE_fusion[0,:,:,:,i] 81 | show_list = convert2uint8([seg, out_mask, out_tex, fusion]) 82 | tf.compat.v1.summary.image('branch{}'.format(i), tf.stack(show_list, axis=0), max_outputs=len(show_list), collections=["VAE_Sum"]) 83 | 84 | #background 85 | seg = tf.reduce_sum(graph.generated_masks[0,:,:,:,-1*FLAGS.n_bg:], axis=-1)*ori 86 | out_bg_tex = graph.VAE_outtex_bg[0,:,:,:] #H W 3 87 | show_list = convert2uint8([seg, out_bg_tex]) 88 | tf.compat.v1.summary.image('background', tf.stack(show_list, axis=0), max_outputs=len(show_list), collections=["VAE_Sum"]) 89 | 90 | tf.summary.scalar('Tex_error', graph.loss['tex_error'], collections=["VAE_Sum"]) 91 | tf.summary.scalar('Mask_error', graph.loss['mask_error'], collections=["VAE_Sum"]) 92 | tf.summary.scalar('BG_error', graph.loss['bg_error'], collections=["VAE_Sum"]) 93 | 94 | tf.summary.scalar('VAEFusion_error', graph.loss['VAE_fusion_error'], collections=["VAE_Sum"]) 95 | tf.summary.scalar('Fusion_Loss', graph.loss['Fusion'], collections=["VAE_Sum"]) 96 | 97 | tf.summary.scalar('Tex_latent_space', graph.loss['tex_kl_var'], collections=['VAE_Sum_tex']) 98 | tf.summary.scalar('Mask_latent_space', graph.loss['mask_kl_var'], collections=['VAE_Sum_mask']) 99 | tf.summary.scalar('BG_latent_space', graph.loss['bg_kl_var'], collections=['VAE_Sum_bg']) 100 | 101 | for grad, var in graph.train_vars_grads['VAE//separate']: 102 | 
tf.summary.histogram(var.op.name+'/grad', grad, collections=['VAE_Sum']) 103 | for grad, var in graph.train_vars_grads['VAE//fusion']: 104 | tf.summary.histogram(var.op.name+'/grad', grad, collections=['VAE_Sum']) 105 | for grad, var in graph.train_vars_grads['Fusion']: 106 | tf.summary.histogram(var.op.name+'/grad', grad, collections=['VAE_Sum']) 107 | 108 | return tf.summary.merge(tf.compat.v1.get_collection("VAE_Sum")), \ 109 | tf.summary.merge(tf.compat.v1.get_collection("VAE_Sum_tex")), \ 110 | tf.summary.merge(tf.compat.v1.get_collection("VAE_Sum_mask")), \ 111 | tf.summary.merge(tf.compat.v1.get_collection("VAE_Sum_bg")) 112 | 113 | def collect_end2end_summary(graph, FLAGS): 114 | ori = graph.image_batch[0] 115 | fusion = graph.fusion_outputs[0] 116 | show_list = convert2uint8([ori, fusion]) 117 | tf.compat.v1.summary.image('image output', tf.stack(show_list, axis=0), max_outputs=len(show_list), collections=["end2end_Sum"]) 118 | 119 | seg_masks = tf.transpose(graph.generated_masks[0,:,:,:,:]*tf.expand_dims(ori, axis=-1),[3,0,1,2]) #N H W 3 120 | tf.compat.v1.summary.image('segmentation', tf.cast(seg_masks*255,tf.uint8), max_outputs=FLAGS.num_branch, collections=["end2end_Sum"]) 121 | 122 | for i in range(FLAGS.num_branch-FLAGS.n_bg): 123 | seg = graph.generated_masks[0,:,:,:,i]*ori 124 | out_mask = tf.tile(graph.VAE_outmasks[0,:,:,:,i],[1,1,3]) 125 | out_tex = graph.VAE_outtexes[0,:,:,:,i] 126 | fusion = graph.VAE_fusion[0,:,:,:,i] 127 | show_list = convert2uint8([seg, out_mask, out_tex, fusion]) 128 | tf.compat.v1.summary.image('branch{}'.format(i), tf.stack(show_list, axis=0), max_outputs=len(show_list), collections=["end2end_Sum"]) 129 | 130 | #background 131 | seg = tf.reduce_sum(graph.generated_masks[0,:,:,:,-1*FLAGS.n_bg:], axis=-1)*ori 132 | out_bg_tex = graph.VAE_outtex_bg[0,:,:,:] #H W 3 133 | show_list = convert2uint8([seg, out_bg_tex]) 134 | tf.compat.v1.summary.image('background', tf.stack(show_list, axis=0), max_outputs=len(show_list), collections=["end2end_Sum"]) 135 | 136 | 137 | #-----curve to show------------- 138 | #tf.summary.scalar('CIS', graph.loss['CIS'], collections=['end2end_Sum']) 139 | #tf.summary.scalar('Inpainter_Loss', graph.loss['Inpainter'], collections=['end2end_Sum']) 140 | tf.summary.scalar('Tex_error', graph.loss['tex_error'], collections=["end2end_Sum"]) 141 | tf.summary.scalar('Mask_error', graph.loss['mask_error'], collections=["end2end_Sum"]) 142 | tf.summary.scalar('BG_error', graph.loss['bg_error'], collections=["end2end_Sum"]) 143 | 144 | tf.summary.scalar('VAEFusion_error', graph.loss['VAE_fusion_error'], collections=["end2end_Sum"]) 145 | tf.summary.scalar('Fusion_Loss', graph.loss['Fusion'], collections=["end2end_Sum"]) 146 | 147 | tf.summary.scalar('Tex_latent_space', graph.loss['tex_kl_var'], collections=['end2end_Sum_tex']) 148 | tf.summary.scalar('Mask_latent_space', graph.loss['mask_kl_var'], collections=['end2end_Sum_mask']) 149 | tf.summary.scalar('BG_latent_space', graph.loss['bg_kl_var'], collections=['end2end_Sum_bg']) 150 | 151 | tf.summary.scalar('IoU Validation',graph.loss['EvalIoU_var'], collections=['CIS_eval']) 152 | 153 | #-----histogram to show----------- 154 | for grad, var in graph.train_vars_grads['VAE//separate/texVAE']: 155 | tf.summary.histogram(var.op.name+'/grad', grad, collections=['end2end_Sum']) 156 | for grad, var in graph.train_vars_grads['VAE//separate/bgVAE']: 157 | tf.summary.histogram(var.op.name+'/grad', grad, collections=['end2end_Sum']) 158 | for grad, var in graph.train_vars_grads['VAE//fusion']: 
159 | tf.summary.histogram(var.op.name+'/grad', grad, collections=['end2end_Sum']) 160 | for grad, var in graph.train_vars_grads['Fusion']: 161 | tf.summary.histogram(var.op.name+'/grad', grad, collections=['end2end_Sum']) 162 | 163 | return tf.summary.merge(tf.compat.v1.get_collection("end2end_Sum")), \ 164 | tf.summary.merge(tf.compat.v1.get_collection("end2end_Sum_tex")), \ 165 | tf.summary.merge(tf.compat.v1.get_collection("end2end_Sum_bg")), \ 166 | tf.summary.merge(tf.compat.v1.get_collection("CIS_eval")) 167 | 168 | def collect_PC_summary(graph, FLAGS): 169 | ori = graph.image_batch[0] 170 | show_list = [ori] 171 | for i in range(FLAGS.num_branch): 172 | seg = graph.generated_masks[0,:,:,:,i]*ori 173 | show_list.append(seg) 174 | show_list = convert2uint8(show_list) 175 | tf.compat.v1.summary.image('original image', tf.stack(show_list, axis=0), max_outputs=len(show_list), collections=["PC_Sum"]) 176 | 177 | for k in range(6): 178 | new = graph.new_imgs[k][0,:,:,:] #H W 3 179 | show_list = [new] 180 | for i in range(FLAGS.num_branch): 181 | seg = graph.new_generated_masks[k][0,:,:,:,i]*new 182 | show_list.append(seg) 183 | show_list.append(tf.tile(graph.new_labels[k][0,:,:,:,i],[1,1,3])) 184 | show_list = convert2uint8(show_list) 185 | tf.compat.v1.summary.image('perturbed image '+str(k), tf.stack(show_list, axis=0), max_outputs=len(show_list), collections=["PC_Sum"]) 186 | tf.summary.scalar('CIS', graph.loss['CIS'], collections=['PC_Sum']) 187 | tf.summary.scalar('Inpainter_Loss', graph.loss['Inpainter'], collections=['PC_Sum']) 188 | tf.summary.scalar('PC', graph.loss['PC'], collections=['PC_Sum']) 189 | tf.summary.scalar('IoU Validation',graph.loss['EvalIoU_var'], collections=['CIS_eval']) 190 | tf.summary.scalar('Identity Switching Rate',graph.switching_rate, collections=['CIS_eval']) 191 | 192 | return tf.summary.merge(tf.compat.v1.get_collection("PC_Sum")), \ 193 | tf.summary.merge(tf.compat.v1.get_collection("CIS_eval")), \ 194 | 195 | def collect_inpainter_summary(graph, FLAGS): 196 | #---------image to show------------- 197 | 198 | # original image edge_map 199 | ori = graph.image_batch[0] 200 | edge = tf.concat([graph.edge_map[0], tf.zeros_like(graph.edge_map[0,:,:,0:1])], axis=-1) #H W 3 201 | show_list = convert2uint8([ori, edge, graph.unconditioned_mean[0]]) #0~255 202 | 203 | tf.compat.v1.summary.image('image_edge', 204 | tf.stack(show_list, axis=0), max_outputs=len(show_list), 205 | collections=["Inpainter_Sum"]) 206 | 207 | for i in range(FLAGS.num_branch): 208 | mask = graph.generated_masks[0,:,:,:,i] 209 | context = ori *(1-mask) 210 | GT = ori*mask 211 | predict = graph.pred_intensities[0,:,:,:,i]*mask 212 | show_list = convert2uint8([GT,context,predict]) 213 | tf.compat.v1.summary.image('branch{}'.format(i), 214 | tf.stack(show_list, axis=0), max_outputs=len(show_list), 215 | collections=["Inpainter_Sum"]) 216 | 217 | loss = graph.loss['Inpainter'] 218 | tf.summary.scalar('Inpainter_Loss', loss, collections=['Inpainter_Sum']) 219 | 220 | for grad, var in graph.train_vars_grads['Inpainter']: 221 | tf.summary.histogram(var.op.name+'/grad', grad, collections=['Inpainter_Sum']) 222 | 223 | return tf.summary.merge(tf.compat.v1.get_collection('Inpainter_Sum')) -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') -------------------------------------------------------------------------------- 
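# --- Added note (not part of the original source) ---
# globalVAE_graph.py below trains a plain beta-VAE on whole images. The objective it
# assembles from the reconstruction branch and gaussian_kl() is
#
#   L = reconstr(x, x_hat) + tex_beta * sum_d KL( q(z_d | x) || N(0, 1) ),
#   KL per latent dimension d: -0.5 * (1 + log(sigma_d^2) - mu_d^2 - sigma_d^2),
#
# with the KL averaged over the batch and summed over the latent dimensions.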
/model/globalVAE_graph.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | from data import multi_texture_utils, flying_animals_utils, multi_dsprites_utils, objects_room_utils 4 | from .utils.generic_utils import bin_edge_map, train_op,myprint, myinput, erode_dilate, tf_resize_imgs, tf_normalize_imgs 5 | from .utils.loss_utils import Generator_Loss, Inpainter_Loss, Supervised_Generator_Loss 6 | from .nets import Generator_forward, Inpainter_forward, VAE_forward, Fusion_forward, encoder_decoder, gaussian_kl 7 | 8 | 9 | class Train_Graph(object): 10 | def __init__(self, FLAGS): 11 | self.config = FLAGS 12 | #load data 13 | self.batch_size = FLAGS.batch_size 14 | self.img_height, self.img_width = FLAGS.img_height, FLAGS.img_width 15 | #hyperparameters 16 | def build(self): 17 | train_dataset = self.load_training_data() 18 | self.train_iterator = train_dataset.make_one_shot_iterator() 19 | train_batch = self.train_iterator.get_next() 20 | 21 | self.image_batch, self.GT_masks = train_batch['img'], train_batch['masks'] 22 | self.image_batch.set_shape([None, self.img_height, self.img_width, 3]) 23 | 24 | with tf.compat.v1.variable_scope("VAE") as scope: 25 | z_mean, z_log_sigma_sq, out_logit = encoder_decoder(x=self.image_batch, output_ch=3, latent_dim=self.config.tex_dim, training=True) 26 | 27 | self.latent_loss_dim = tf.reduce_mean(gaussian_kl(z_mean, z_log_sigma_sq), 0) #average on batch dim, 28 | self.latent_loss = tf.reduce_sum(self.latent_loss_dim) 29 | 30 | 31 | self.out_imgs = tf.nn.sigmoid(out_logit) #0~1 32 | 33 | if self.config.VAE_loss == 'L1': 34 | self.reconstr_loss = tf.reduce_sum(tf.reduce_mean(tf.abs(self.image_batch-self.out_imgs), axis=0)) #B H W 3 35 | else: 36 | self.reconstr_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=self.image_batch, logits=out_logit) #B H W 3 37 | self.reconstr_loss = tf.reduce_sum(self.reconstr_loss, axis=[1,2,3]) #B, 38 | self.reconstr_loss = tf.reduce_mean(self.reconstr_loss) 39 | 40 | self.loss = self.reconstr_loss+self.config.tex_beta*self.latent_loss 41 | 42 | 43 | 44 | #------------------------------ 45 | with tf.name_scope('train_op'): 46 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 47 | self.incr_global_step = tf.assign(self.global_step, self.global_step+1) 48 | self.train_ops, self.train_vars_grads = self.get_train_ops_grads() 49 | 50 | with tf.name_scope('summary_vars'): 51 | self.kl_var = tf.Variable(0.0, name='kl_var') 52 | # 53 | def get_train_ops_grads(self): 54 | optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.config.VAE_lr) 55 | train_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, 'VAE') 56 | update_op = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS, 'VAE') 57 | train_ops, train_vars_grads = train_op(loss=self.loss, 58 | var_list=train_vars, optimizer=optimizer, gradient_clip_value=-1) 59 | train_ops = tf.group([train_ops, update_op]) 60 | return train_ops, train_vars_grads 61 | 62 | def load_training_data(self): 63 | if self.config.dataset == 'multi_texture': 64 | return multi_texture_utils.dataset(self.config.root_dir, val=False, 65 | batch_size=self.batch_size, max_num=self.config.max_num, 66 | zoom=('z' in self.config.variant), 67 | rotation=('r' in self.config.variant), 68 | texture_transform=self.config.texture_transform) 69 | elif self.config.dataset == 'flying_animals': 70 | return flying_animals_utils.dataset(self.config.root_dir,val=False, 71 | 
batch_size=self.batch_size, max_num=self.config.max_num) 72 | elif self.config.dataset == 'multi_dsprites': 73 | return multi_dsprites_utils.dataset(self.config.root_dir,val=False, 74 | batch_size=self.batch_size, skipnum=self.config.skipnum, takenum=self.config.takenum, 75 | shuffle=True, map_parallel_calls=tf.data.experimental.AUTOTUNE) 76 | elif self.config.dataset == 'objects_room': 77 | return objects_room_utils.dataset(self.config.root_dir,val=False, 78 | batch_size=self.batch_size, skipnum=self.config.skipnum, takenum=self.config.takenum, 79 | shuffle=True, map_parallel_calls=tf.data.experimental.AUTOTUNE) 80 | else: 81 | raise IOError("Unknown Dataset") 82 | -------------------------------------------------------------------------------- /model/nets.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from .utils.convolution_utils import gen_conv, gen_deconv, conv, deconv 3 | from .utils.convolution_utils import _dilated_conv2d, conv2d, deconv2d, conv2d_transpose, InstanceNorm, fully_connect 4 | from tensorflow.contrib.slim.nets import resnet_v2 5 | from .utils.generic_utils import bin_edge_map, erode_dilate 6 | from .utils.loss_utils import region_error 7 | import math 8 | import numpy as np 9 | 10 | def Generator_forward(images, dataset, num_mask, model='resnet_v2_50', scope='Generator', reuse=None, training=True): 11 | if 'resnet' in model: 12 | generated_masks, logits = generator_resnet(images, dataset, num_mask, 13 | reuse=reuse, training=training, scope=scope, model=model) #params 2e7 14 | else: 15 | generated_masks, logits = generator_segnet(images, num_mask, 16 | scope=scope, reuse=reuse, training=training) 17 | return tf.expand_dims(generated_masks, axis=-2), tf.expand_dims(logits, axis=-2) # B H W 1 C B H W 1 C 18 | 19 | def Inpainter_forward(num_branch, input_masks, images, scope, dataset, reuse=None, training=True): 20 | edge_map = bin_edge_map(images, dataset) 21 | pred_intensities = [] 22 | for m in range(num_branch): 23 | mask = input_masks[:,:,:,:,m] #B H W 1 24 | reuse = True if reuse==True else (m>0) 25 | pred_intensity = inpaint_net(images, mask, edge_map, scope=scope, reuse=reuse, training=training) 26 | pred_intensities.append(pred_intensity) 27 | unconditioned_mean = inpaint_net(tf.zeros_like(images), tf.ones_like(mask), 28 | edge_map, scope=scope, reuse=True, training=training) 29 | pred_intensities = tf.stack(pred_intensities, axis=-1) #B H W 3 C 30 | return pred_intensities, unconditioned_mean, edge_map 31 | 32 | def inpaint_net(image, mask, edge, scope, reuse=None, training=True): #params 3e6 33 | # intensity_masked 34 | # B H W 3 35 | image = image - 0.5 #0~1 -> -0.5~0.5 36 | intensity_masked = image*(1-mask) 37 | orisize = intensity_masked.get_shape().as_list()[1:-1] #[H,W] 38 | C = intensity_masked.get_shape().as_list()[-1] # 3 39 | f=0.5 40 | #edge B H W 2 41 | edge_in_channels = edge.get_shape().as_list()[-1] #2 42 | 43 | ones_x = tf.ones_like(intensity_masked)[:, :, :, 0:1] # B H W 1 44 | intensity_masked = tf.concat([intensity_masked, ones_x, 1-mask], axis=3) # B H W C+2 45 | intensity_in_channels = intensity_masked.get_shape().as_list()[-1] 46 | 47 | with tf.variable_scope(scope, reuse=reuse): 48 | 49 | aconv1 = conv( edge, 'aconv1', shape=[7,7, edge_in_channels, int(64*f)], stride=2, reuse=reuse, training=training ) # h/2(192), 64 50 | aconv2 = conv( aconv1, 'aconv2', shape=[5,5,int(64*f), int(128*f)], stride=2, reuse=reuse, training=training ) # h/4(96), 128 51 | aconv3 = conv( 
aconv2, 'aconv3', shape=[5,5,int(128*f),int(256*f)], stride=2, reuse=reuse, training=training ) # h/8(48), 256 52 | aconv31= conv( aconv3, 'aconv31', shape=[3,3,int(256*f),int(256*f)], stride=1, reuse=reuse, training=training ) 53 | aconv4 = conv( aconv31, 'aconv4', shape=[3,3,int(256*f),int(512*f)], stride=2, reuse=reuse, training=training ) # h/16(24), 512 54 | aconv41= conv( aconv4, 'aconv41', shape=[3,3,int(512*f),int(512*f)], stride=1, reuse=reuse, training=training ) 55 | aconv5 = conv( aconv41, 'aconv5', shape=[3,3,int(512*f),int(512*f)], stride=2, reuse=reuse, training=training ) # h/32(12), 512 56 | aconv51= conv( aconv5, 'aconv51', shape=[3,3,int(512*f),int(512*f)], stride=1, reuse=reuse, training=training ) 57 | aconv6 = conv( aconv51, 'aconv6', shape=[3,3,int(512*f),int(512*f)], stride=2, reuse=reuse, training=training ) # h/64(6), 512 58 | 59 | bconv1 = conv( intensity_masked, 'bconv1', shape=[7,7, intensity_in_channels, int(64*f)], stride=2, reuse=reuse, training=training ) # h/2(192), 64 60 | bconv2 = conv( bconv1, 'bconv2', shape=[5,5,int(64*f), int(128*f)], stride=2, reuse=reuse, training=training ) # h/4(96), 128 61 | bconv3 = conv( bconv2, 'bconv3', shape=[5,5,int(128*f),int(256*f)], stride=2, reuse=reuse, training=training ) # h/8(48), 256 62 | bconv31= conv( bconv3, 'bconv31', shape=[3,3,int(256*f),int(256*f)], stride=1, reuse=reuse, training=training ) 63 | bconv4 = conv( bconv31, 'bconv4', shape=[3,3,int(256*f),int(512*f)], stride=2, reuse=reuse, training=training ) # h/16(24), 512 64 | bconv41= conv( bconv4, 'bconv41', shape=[3,3,int(512*f),int(512*f)], stride=1, reuse=reuse, training=training ) 65 | bconv5 = conv( bconv41, 'bconv5', shape=[3,3,int(512*f),int(512*f)], stride=2, reuse=reuse, training=training ) # h/32(12), 512 66 | bconv51= conv( bconv5, 'bconv51', shape=[3,3,int(512*f),int(512*f)], stride=1, reuse=reuse, training=training ) 67 | bconv6 = conv( bconv51, 'bconv6', shape=[3,3,int(512*f),int(512*f)], stride=2, reuse=reuse, training=training ) # h/64(6), 512 68 | 69 | #conv6 = tf.add( aconv6, bconv6 ) 70 | conv6 = tf.concat( (aconv6, bconv6), 3 ) #h/64(6) 512*2*f 71 | outsz = bconv51.get_shape() # h/32(12), 512*f 72 | deconv5 = deconv( conv6, size=[outsz[1],outsz[2]], name='deconv5', shape=[4,4,int(512*2*f),int(512*f)], reuse=reuse, training=training ) 73 | concat5 = tf.concat( (deconv5,bconv51,aconv51), 3 ) # h/32(12), 512*3*f 74 | 75 | intensity5 = conv( concat5, 'intensity5', shape=[3,3,int(512*3*f),C], stride=1, reuse=reuse, training=training, activation=tf.identity ) # h/32(12), C 76 | outsz = bconv41.get_shape() # h/16(24), 512*f 77 | deconv4 = deconv( concat5, size=[outsz[1],outsz[2]], name='deconv4', shape=[4,4,int(512*3*f),int(512*f)], reuse=reuse, training=training ) 78 | upintensity4 = deconv( intensity5, size=[outsz[1],outsz[2]], name='upintensity4', shape=[4,4,C,C], reuse=reuse, training=training, activation=tf.identity ) 79 | concat4 = tf.concat( (deconv4,bconv41,aconv41,upintensity4), 3 ) # h/16(24), 512*3*f+C 80 | 81 | intensity4 = conv( concat4, 'intensity4', shape=[3,3,int(512*3*f+C),C], stride=1, reuse=reuse, training=training, activation=tf.identity ) # h/16(24), C 82 | outsz = bconv31.get_shape() # h/8(48), 256*f 83 | deconv3 = deconv( concat4, size=[outsz[1],outsz[2]], name='deconv3', shape=[4,4,int(512*3*f+C),int(256*f)], reuse=reuse, training=training ) 84 | upintensity3 = deconv( intensity4, size=[outsz[1],outsz[2]], name='upintensity3', shape=[4,4,C,C], reuse=reuse, training=training, activation=tf.identity ) 85 | concat3 = 
tf.concat( (deconv3,bconv31,aconv31,upintensity3), 3 ) # h/8(48), 256*3*f+C 86 | 87 | intensity3 = conv( concat3, 'intensity3', shape=[3,3,int(256*3*f+C),C], stride=1, reuse=reuse, training=training, activation=tf.identity ) # h/8(48), C 88 | outsz = bconv2.get_shape() # h/4(96), 128*f 89 | deconv2 = deconv( concat3, size=[outsz[1],outsz[2]], name='deconv2', shape=[4,4,int(256*3*f+C),int(128*f)], reuse=reuse, training=training ) 90 | upintensity2 = deconv( intensity3, size=[outsz[1],outsz[2]], name='upintensity2', shape=[4,4,C,C], reuse=reuse, training=training, activation=tf.identity ) 91 | concat2 = tf.concat( (deconv2,bconv2,aconv2,upintensity2), 3 ) # h/4(96), 128*3*f+C 92 | 93 | intensity2 = conv( concat2, 'intensity2', shape=[3,3,int(128*3*f+C),C], stride=1, reuse=reuse, training=training, activation=tf.identity ) # h/4(96), C 94 | outsz = bconv1.get_shape() # h/2(192), 64*f 95 | deconv1 = deconv( concat2, size=[outsz[1],outsz[2]], name='deconv1', shape=[4,4,int(128*3*f+C),int(64*f)], reuse=reuse, training=training ) 96 | upintensity1 = deconv( intensity2, size=[outsz[1],outsz[2]], name='upintensity1', shape=[4,4,C,C], reuse=reuse, training=training, activation=tf.identity ) 97 | concat1 = tf.concat( (deconv1,bconv1,aconv1,upintensity1), 3 ) # h/2(192), 64*3*f+C 98 | 99 | intensity1 = conv( concat1, 'intensity1', shape=[5,5,int(64*3*f+C),C], stride=1, reuse=reuse, training=training, activation=tf.identity ) # h/2(192), C 100 | pred_intensity = tf.image.resize_images(intensity1, size=orisize) 101 | 102 | pred_intensity = pred_intensity + 0.5 103 | return pred_intensity 104 | 105 | 106 | def generator_resnet(images, dataset, num_mask, scope, model='resnet_v2_50', reuse=None, training=True): 107 | #images = 0~1 108 | images = (images-0.5)*2 #-1 ~ 1 109 | assert dataset in ['multi_texture','flying_animals','multi_dsprites','objects_room'] 110 | if dataset in ['multi_texture', 'multi_dsprites','objects_room']: 111 | images = tf.image.resize_images(images, size=(128,128)) #64*64 -> 128*128 112 | #pad to 32k+1 113 | x = tf.pad(images, paddings=[[0,0],[0,1],[0,1],[0,0]], mode='REFLECT') 114 | dilations = [6, 12, 18, 24] if dataset=='flying_animals' else [2,4,6,8] 115 | o = [] 116 | with tf.compat.v1.variable_scope(scope, reuse=reuse): 117 | if model=='resnet_v2_50': 118 | net, end_points = resnet_v2.resnet_v2_50(x, None, is_training=training, global_pool=False, output_stride=4, reuse=reuse, scope=None) 119 | elif model=='resnet_v2_101': 120 | net, end_points = resnet_v2.resnet_v2_101(x, None, is_training=training, global_pool=False, output_stride=4, reuse=reuse, scope=None) 121 | else: 122 | raise IOError("Only resnet_v2_50 or resnet_v2_101 available") 123 | #classification 124 | with tf.compat.v1.variable_scope(scope, reuse=tf.AUTO_REUSE): 125 | for i, d in enumerate(dilations): 126 | o.append(_dilated_conv2d(net, 3, num_mask, d, name='aspp/conv%d' % (i+1), biased=True)) 127 | logits = tf.add_n(o) 128 | 129 | if dataset == 'flying_animals': 130 | logits = tf.image.resize_images(logits, size=(193,257)) #B H W C align the feature 131 | logits = logits[:,:-1,:-1,:] #192 256 132 | else: 133 | logits = tf.image.resize_images(logits, size=(129,129)) 134 | logits = logits[:,:-1,:-1,:] #128 128 135 | logits = tf.image.resize_images(logits, size=(64,64)) #64 64 136 | generated_masks = tf.nn.softmax(logits, axis=-1) # B H W cnum 137 | return generated_masks, logits 138 | 139 | def generator_segnet(images, num_mask, scope, reuse=None, training=True, div=10.0): #cnum=32 #params1.5e6 140 | """Mask 
network. 141 | Args: 142 | image: input rgb image [0, 1] 143 | num_mask: number of mask 144 | Returns: 145 | mask: mask region [0, 1], 1 is fully masked, 0 is not. *num_mask 146 | """ 147 | 148 | mask_channels = num_mask # probability of each mask 149 | x = images 150 | cnum = 64 151 | with tf.compat.v1.variable_scope(scope, reuse=reuse): 152 | # stage1 153 | x_0 = gen_conv(x, cnum, 5, 1, name='conv1', training=training) # --------------------------- 154 | x = gen_conv(x_0, 2*cnum, 3, 2, name='conv2_downsample', training=training) # Skip connection 155 | x_1 = gen_conv(x, 2*cnum, 3, 1, name='conv3', training=training) # ------------------- 156 | x = gen_conv(x_1, 4*cnum, 3, 2, name='conv4_downsample', training=training) 157 | x = gen_conv(x, 4*cnum, 3, 1, name='conv5', training=training) 158 | x_2 = gen_conv(x, 4*cnum, 3, 1, name='conv6', training=training) # ----------------- 159 | x = gen_conv(x_2, 4*cnum, 3, rate=2, name='conv7_atrous', training=training) 160 | x = gen_conv(x, 4*cnum, 3, rate=4, name='conv8_atrous', training=training) 161 | x = gen_conv(x, 4*cnum, 3, rate=8, name='conv9_atrous', training=training) 162 | x = gen_conv(x, 4*cnum, 3, rate=16, name='conv10_atrous', training=training) 163 | x = gen_conv(x, 4*cnum, 3, 1, name='conv11', training=training) + x_2 #------------- 164 | x = gen_conv(x, 4*cnum, 3, 1, name='conv12', training=training) 165 | x = gen_deconv(x, 2*cnum, name='conv13_upsample', training=training) 166 | x = gen_conv(x, 2*cnum, 3, 1, name='conv14', training=training) + x_1 # -------------------- 167 | x = gen_deconv(x, cnum, name='conv15_upsample', training=training) + x_0 #------------------- 168 | x = gen_conv(x, cnum//2, 3, 1, name='conv16', training=training) 169 | x = gen_conv(x, mask_channels, 3, 1, activation=tf.identity, 170 | name='conv17', training=training) 171 | # Division by constant experimentally improved training 172 | x = tf.divide(x, tf.constant(div)) 173 | generated_mask = tf.nn.softmax(x, axis=-1) #soft mask normalization #B*H*W*num_mask 174 | return generated_mask, x 175 | 176 | 177 | def _sample_z(z_mean, z_log_sigma_sq): 178 | eps_shape = tf.shape(z_mean) 179 | eps = tf.random_normal( eps_shape, 0, 1, dtype=tf.float32 ) 180 | # z = mu + sigma * epsilon 181 | z = tf.add(z_mean,tf.multiply(tf.sqrt(tf.exp(z_log_sigma_sq)), eps)) 182 | return z 183 | 184 | def encoder(x, latent_dim, training=True): 185 | B, H, W, C = x.get_shape().as_list() 186 | conv1 = conv2d(x, filter_shape=[4,4,C,32], stride=2, padding='SAME', name='conv1', biased=True, dilation=1) 187 | conv1 = tf.nn.relu(conv1) 188 | 189 | conv2 = conv2d(conv1, filter_shape=[4,4,32,32], stride=2, padding='SAME', name='conv2', biased=True, dilation=1) 190 | conv2 = tf.nn.relu(conv2) 191 | 192 | conv3 = conv2d(conv2, filter_shape=[4,4,32,32], stride=2, padding='SAME', name='conv3', biased=True, dilation=1) 193 | conv3 = tf.nn.relu(conv3) 194 | 195 | conv4 = conv2d(conv3, filter_shape=[4,4,32,32], stride=2, padding='SAME', name='conv4', biased=True, dilation=1) 196 | conv4 = tf.nn.relu(conv4) 197 | #B 4 4 32 198 | 199 | flatten = tf.reshape(conv4, [-1,(H//16)*(W//16)*32]) 200 | fc1 = fully_connect(flatten, weight_shape=[(H//16)*(W//16)*32, 256], name='fc1', biased=True) 201 | fc1 = tf.nn.relu(fc1) 202 | fc2 = fully_connect(fc1, weight_shape=[256, 256], name='fc2', biased=True) 203 | fc2 = tf.nn.relu(fc2) 204 | 205 | z_mean = fully_connect(fc2, weight_shape=[256, latent_dim], name='fc_zmean', biased=True, bias_init_value=0.0) 206 | z_log_sigma_sq = fully_connect(fc2, weight_shape=[256, 
latent_dim], name='fc_logsigmasq', biased=True, bias_init_value=0.0) 207 | return z_mean, z_log_sigma_sq 208 | 209 | def decoder(z, output_ch, latent_dim, x, training=True): 210 | B, H, W, C = x.get_shape().as_list() 211 | up_fc2 = fully_connect(z, weight_shape=[latent_dim, 256], name='up_fc2', biased=True) 212 | up_fc2 = tf.nn.relu(up_fc2) 213 | 214 | up_fc1 = fully_connect(up_fc2, weight_shape=[256, (H//16)*(W//16)*32], name='up_fc1', biased=True) 215 | up_fc1 = tf.nn.relu(up_fc1) 216 | up_fc1 = tf.reshape(up_fc1, [-1,H//16,W//16,32]) 217 | 218 | deconv4 = deconv2d(up_fc1, filter_shape=[4,4,32,32], output_size=[H//8,W//8], name='deconv4', padding='SAME', biased=True) 219 | deconv4 = tf.nn.leaky_relu(deconv4) 220 | 221 | deconv3 = deconv2d(deconv4, filter_shape=[4,4,32,32], output_size=[H//4,W//4], name='deconv3', padding='SAME', biased=True) 222 | deconv3 = tf.nn.leaky_relu(deconv3) 223 | 224 | deconv2 = deconv2d(deconv3, filter_shape=[4,4,32,32], output_size=[H//2,W//2], name='deconv2', padding='SAME', biased=True) 225 | deconv2 = tf.nn.leaky_relu(deconv2) 226 | 227 | deconv1 = deconv2d(deconv2, filter_shape=[4,4,32,output_ch], output_size=[H,W], name='deconv1', padding='SAME', biased=True) 228 | 229 | out_logit = tf.identity(deconv1) 230 | return out_logit 231 | 232 | def tex_mask_fusion(tex, mask): 233 | inputs = tf.concat([tex, mask], axis=-1)#B H W (3+1) 234 | 235 | inputs = tf.pad(inputs, paddings=[[0,0],[1,2],[1,2],[0,0]], mode='REFLECT') 236 | conv1 = conv2d(inputs, filter_shape=[4,4,4,32], stride=1, padding='VALID', name='conv1', biased=True, dilation=1) 237 | conv1 = tf.nn.relu(conv1) 238 | 239 | conv1 = tf.pad(conv1, paddings=[[0,0],[1,2],[1,2],[0,0]], mode='REFLECT') 240 | conv2 = conv2d(conv1, filter_shape=[4,4,32,32], stride=1, padding='VALID', name='conv2', biased=True, dilation=1) 241 | conv2 = tf.nn.relu(conv2) 242 | 243 | conv2 = tf.pad(conv2, paddings=[[0,0],[1,2],[1,2],[0,0]], mode='REFLECT') 244 | conv3 = conv2d(conv2, filter_shape=[4,4,32,3], stride=1, padding='VALID', name='conv3', biased=True, dilation=1) 245 | output = tf.nn.sigmoid(conv3) 246 | 247 | return output, conv3 248 | 249 | def encoder_decoder(x, output_ch, latent_dim,training=True): 250 | z_mean, z_log_sigma_sq = encoder(x, latent_dim, training=training) 251 | z = _sample_z(z_mean, z_log_sigma_sq) 252 | out_logit = decoder(z, output_ch, latent_dim, x, training=training) 253 | return z_mean, z_log_sigma_sq, out_logit 254 | 255 | def gaussian_kl(mean, log_sigma_sq): 256 | latent_loss = -0.5 * (1 + log_sigma_sq 257 | - tf.square(mean) 258 | - tf.exp(log_sigma_sq)) #B*Z_dim 259 | return latent_loss # B*z_dim 260 | 261 | def VAE_forward(image, masks, bg_dim, tex_dim, mask_dim, scope='VAE', reuse=None, training=True, augmentation=False): 262 | B, H, W, C = image.get_shape().as_list() 263 | num_branch = masks.get_shape().as_list()[-1] #B H W 1 M 264 | 265 | with tf.compat.v1.variable_scope(scope, reuse=reuse): 266 | tex_kl, out_texes = [],[] 267 | mask_kl, out_masks_logit = [], [] 268 | fusion_error, out_fusion = [], [] 269 | latent_zs = {'tex':[], 'mask':[], 'bg':None} 270 | 271 | for i in range(num_branch): 272 | 273 | 274 | inputs = image*masks[:,:,:,:,i] #B H W 3 275 | with tf.compat.v1.variable_scope('separate/texVAE', reuse=tf.compat.v1.AUTO_REUSE): 276 | z_mean, z_log_sigma_sq, out_logit = encoder_decoder(inputs, output_ch=3, latent_dim=tex_dim, training=training) 277 | out_texes.append(tf.nn.sigmoid(out_logit)) #B, 278 | tex_kl.append(tf.reduce_mean(gaussian_kl(z_mean, z_log_sigma_sq), 0)) #B,dim -> 
dim 279 | latent_zs['tex'].append(z_mean) 280 | 281 | inputs = masks[:,:,:,:,i] 282 | with tf.compat.v1.variable_scope('separate/maskVAE', reuse=tf.compat.v1.AUTO_REUSE): 283 | z_mean, z_log_sigma_sq, out_logit = encoder_decoder(inputs, output_ch=1, latent_dim=mask_dim, training=training) 284 | out_masks_logit.append(out_logit) 285 | mask_kl.append(tf.reduce_mean(gaussian_kl(z_mean, z_log_sigma_sq), 0)) 286 | latent_zs['mask'].append(z_mean) 287 | 288 | #fuse tex and mask 289 | tex, mask = out_texes[-1], tf.nn.sigmoid(out_masks_logit[-1]) #B H W 3 B H W 1 290 | with tf.compat.v1.variable_scope('fusion', reuse=tf.compat.v1.AUTO_REUSE): 291 | fus_output, fus_output_logits = tex_mask_fusion(tex, mask) #B H W 3 292 | out_fusion.append(fus_output) 293 | error = tf.nn.sigmoid_cross_entropy_with_logits(labels=image*masks[:,:,:,:,i], logits=fus_output_logits) # B H W 3 294 | error = tf.reduce_mean(tf.reduce_sum(error, axis=[1,2,3]), axis=0) #B H W 3 -> B -> mean scalar 295 | fusion_error.append(error) 296 | 297 | #KL divergence loss 298 | tex_kl = tf.reduce_mean(tf.stack(tex_kl, axis=0), axis=0)# branch,dim -> dim, 299 | 300 | #reconstruction error 301 | out_texes = tf.stack(out_texes, axis=-1) # B H W 3 M 302 | tex_error = tf.reduce_mean(region_error(X=out_texes, Y=image, region=masks)) #BHW3M BHW3 BHW1M ->B,M -> scalar 303 | 304 | out_masks_logit = tf.stack(out_masks_logit, axis=-1) #B H W 1 M 305 | out_masks = tf.nn.sigmoid(out_masks_logit) #B H W 1 M 306 | 307 | if not augmentation: 308 | mask_error_pixel = tf.nn.sigmoid_cross_entropy_with_logits(labels=masks, logits=out_masks_logit) #B H W 1 M 309 | mask_error_sum = tf.reduce_sum(mask_error_pixel, axis=[1,2,3]) #B,M 310 | mask_error = tf.reduce_mean(mask_error_sum) 311 | mask_kl = tf.reduce_mean(tf.stack(mask_kl, axis=0), axis=0) 312 | else: 313 | #-----------------data augmentation-------------- 314 | #----------generate more position variation to help VAE decompose position feature--------------- 315 | rep = 2 316 | aug_masks = tf.tile(masks, [2*rep,1,1,1,1]) #B H W 1 M -> 2*rep*B H W 1 M 317 | aug_masks = tf.transpose(aug_masks, perm=[0,4,1,2,3]) #2*rep*B M H W 1 318 | aug_masks = tf.reshape(aug_masks, [2*B*rep*num_branch,H,W,1]) 319 | dx = tf.random.uniform(shape=[2*rep*B*num_branch,1],dtype=tf.dtypes.float32,minval=-1*30,maxval=30) 320 | dy = tf.random.uniform(shape=[2*rep*B*num_branch,1],dtype=tf.dtypes.float32,minval=-1*30,maxval=30) 321 | aug_masks = tf.contrib.image.translate(aug_masks,translations=tf.concat([dx,dy], axis=-1),interpolation='NEAREST') 322 | aug_masks = tf.random.shuffle(aug_masks, seed=None) 323 | inputs = aug_masks[0:rep*B*num_branch] #rep*B*M,H,W,1 324 | 325 | with tf.compat.v1.variable_scope('separate/maskVAE', reuse=tf.compat.v1.AUTO_REUSE): 326 | z_mean, z_log_sigma_sq, out_logit = encoder_decoder(inputs, output_ch=1, latent_dim=mask_dim, training=training) 327 | mask_error_pixel= tf.nn.sigmoid_cross_entropy_with_logits(labels=inputs, logits=out_logit) #B*rep*M H W 1 328 | mask_error_sum = tf.reduce_sum(mask_error_pixel, axis=[1,2,3]) #B,M 329 | mask_error = tf.reduce_mean(mask_error_sum) 330 | mask_kl= tf.reduce_mean(gaussian_kl(z_mean, z_log_sigma_sq), 0) #reduce batch 331 | 332 | 333 | out_fusion = tf.stack(out_fusion, axis=-1) #B H W 3 M 334 | fusion_error = tf.reduce_mean(tf.stack(fusion_error, axis=0)) #average on all branch 335 | 336 | # for background, we only learn the representation of its texture 337 | bg_mask = 1-tf.reduce_sum(masks, axis=-1) 338 | inputs = image*bg_mask #B H W 1 339 | with 
tf.compat.v1.variable_scope('separate/bgVAE', reuse=tf.compat.v1.AUTO_REUSE): 340 | z_mean, z_log_sigma_sq, out_logit = encoder_decoder(inputs, output_ch=3, latent_dim=bg_dim, training=training) 341 | out_bg = tf.nn.sigmoid(out_logit) 342 | bg_error = tf.reduce_mean(region_error(X=out_bg, Y=image, region=tf.expand_dims(bg_mask, axis=-1))) 343 | bg_kl= tf.reduce_mean(gaussian_kl(z_mean, z_log_sigma_sq), 0) #B,dim -> dim 344 | latent_zs['bg'] = z_mean 345 | 346 | loss = {'mask_kl': mask_kl, 'tex_kl': tex_kl, 'bg_kl':bg_kl, 347 | 'mask_error':mask_error, 'tex_error': tex_error, 'bg_error': bg_error , 'fusion_error': fusion_error} 348 | outputs = {'out_masks': out_masks, 'out_texes': out_texes, 'out_bg': out_bg, 'out_fusion': out_fusion} 349 | latent_zs['tex'] = tf.stack(latent_zs['tex'], axis=0) #branch-1, B, dim 350 | latent_zs['mask'] = tf.stack(latent_zs['mask'], axis=0) #branch-1 B dim 351 | 352 | return loss, outputs, latent_zs 353 | 354 | 355 | def Fusion_forward(inputs, scope='Fusion', training=True, reuse=None): 356 | #inputs B H W N*C 357 | B, H, W, C = inputs.get_shape().as_list() 358 | with tf.compat.v1.variable_scope(scope, reuse=reuse): 359 | 360 | x = tf.pad(inputs, paddings=[[0,0],[1,2],[1,2],[0,0]], mode='REFLECT') 361 | conv1 = conv2d(x, filter_shape=[4,4,C,32], stride=1, padding='VALID', name='conv1', biased=True, dilation=1) 362 | conv1 = tf.nn.relu(conv1) 363 | 364 | conv1 = tf.pad(conv1, paddings=[[0,0],[1,2],[1,2],[0,0]], mode='REFLECT') 365 | conv2 = conv2d(conv1, filter_shape=[4,4,32,32], stride=1, padding='VALID', name='conv2', biased=True, dilation=1) 366 | conv2 = tf.nn.relu(conv2) 367 | 368 | conv2 = tf.pad(conv2, paddings=[[0,0],[1,2],[1,2],[0,0]], mode='REFLECT') 369 | conv3 = conv2d(conv2, filter_shape=[4,4,32,32], stride=1, padding='VALID', name='conv3', biased=True, dilation=1) 370 | conv3 = tf.nn.relu(conv3) 371 | 372 | 373 | conv3 = tf.pad(conv3, paddings=[[0,0],[1,2],[1,2],[0,0]], mode='REFLECT') 374 | conv4 = conv2d(conv3, filter_shape=[4,4,32,3], stride=1, padding='VALID', name='conv4', biased=True, dilation=1) 375 | out = tf.nn.sigmoid(conv4) 376 | 377 | return out 378 | 379 | 380 | def Perturbation_forward(var_num, image, generated_masks,VAE_outputs0, latent_zs, mask_top_dims, scope='VAE/'): 381 | B, H, W, C, M = generated_masks.get_shape().as_list() 382 | assert M==3 383 | mask_dim = latent_zs['mask'][0].get_shape().as_list()[1] 384 | tex_dim = latent_zs['tex'][0].get_shape().as_list()[1] 385 | k = tf.random.uniform(shape=[], dtype=tf.int32, minval=0, maxval=2) 386 | b = tf.cond(tf.math.equal(k,0), lambda:1, lambda:0) 387 | 388 | new_outs, new_labels = [], [] 389 | for i in range(var_num): 390 | 391 | dim_x, dim_y = mask_top_dims[0], mask_top_dims[1] 392 | new_x = tf.random.uniform(shape=[B,1], dtype=tf.float32, minval=-1, maxval=1) #B,1 393 | new_y = tf.random.uniform(shape=[B,1], dtype=tf.float32, minval=-1, maxval=1) #B,1 394 | mask_z = tf.concat([latent_zs['mask'][k,:,:dim_x], 395 | new_x, 396 | latent_zs['mask'][k,:,dim_x+1:dim_y], 397 | new_y, 398 | latent_zs['mask'][k,:,dim_y+1:]], axis=-1) 399 | 400 | with tf.compat.v1.variable_scope(scope+'/separate/maskVAE', reuse=tf.compat.v1.AUTO_REUSE): 401 | new_mask_logit = decoder(z=mask_z, output_ch=1, latent_dim=mask_dim, x=image, training=True) 402 | new_mask = tf.nn.sigmoid(new_mask_logit) #B H W 1 403 | 404 | another_mask = VAE_outputs0['out_masks'][:,:,:,:,b] 405 | new_mask = (1-another_mask)*new_mask 406 | bg_mask = 1-new_mask-another_mask 407 | new_masks = tf.cond(tf.math.equal(k,0), 408 | 
lambda:tf.stack([new_mask,another_mask,bg_mask], axis=-1), 409 | lambda:tf.stack([another_mask,new_mask,bg_mask], axis=-1)) 410 | 411 | with tf.compat.v1.variable_scope(scope+'/fusion', reuse=tf.compat.v1.AUTO_REUSE): 412 | new_fus_output, null = tex_mask_fusion(VAE_outputs0['out_texes'][:,:,:,:,k], new_mask) #B H W 3 413 | 414 | foregrounds = tf.stack([new_fus_output, VAE_outputs0['out_fusion'][:,:,:,:,b]], axis=-1) #B H W 3 M 415 | backgrounds = tf.expand_dims(VAE_outputs0['out_bg'], axis=-1)* \ 416 | tf.expand_dims(bg_mask, axis=-1) 417 | 418 | fusion_inputs = tf.concat([foregrounds, backgrounds], axis=-1) #B H W 3 fg_branch+1 419 | fusion_inputs = tf.reshape(fusion_inputs, [B,fusion_inputs.get_shape()[1],fusion_inputs.get_shape()[2],-1]) 420 | fusion_outputs = Fusion_forward(inputs=fusion_inputs, scope='Fusion/', training=True, reuse=tf.compat.v1.AUTO_REUSE) #B H W 3 421 | new_outs.append(fusion_outputs) 422 | new_labels.append(new_masks) 423 | 424 | return new_outs, new_labels 425 | 426 | -------------------------------------------------------------------------------- /model/train_graph.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | from data import multi_texture_utils, flying_animals_utils, multi_dsprites_utils, objects_room_utils 4 | from .utils.generic_utils import bin_edge_map, train_op,myprint, myinput, erode_dilate, tf_resize_imgs, tf_normalize_imgs, reorder_mask 5 | from .utils.loss_utils import Generator_Loss, Inpainter_Loss, Supervised_Generator_Loss 6 | from .nets import Generator_forward, Inpainter_forward, VAE_forward, Fusion_forward, Perturbation_forward 7 | 8 | mode2scopes = { 9 | 'pretrain_inpainter': ['Inpainter'], 10 | 'train_CIS': ['Inpainter','Generator'], 11 | 'train_PC': ['Inpainter','Generator'], 12 | 'train_VAE': ['VAE//separate', 'VAE//fusion', 'Fusion'], 13 | 'train_end2end': ['Inpainter','Generator','VAE//separate','VAE//fusion', 'Fusion'] 14 | } 15 | 16 | class Train_Graph(object): 17 | def __init__(self, FLAGS): 18 | self.config = FLAGS 19 | self.batch_size = FLAGS.batch_size 20 | self.num_branch = FLAGS.num_branch 21 | self.img_height, self.img_width = FLAGS.img_height, FLAGS.img_width 22 | def build(self): 23 | self.is_training = tf.placeholder_with_default(True, shape=(), name="is_training") 24 | 25 | train_dataset, val_dataset = self.load_training_data(), self.load_val_data() 26 | self.train_iterator = train_dataset.make_one_shot_iterator() 27 | self.val_iterator = val_dataset.make_initializable_iterator() 28 | train_batch = self.train_iterator.get_next() 29 | 30 | current_batch = tf.cond(self.is_training, lambda: train_batch, lambda: self.val_iterator.get_next()) 31 | 32 | self.image_batch, self.GT_masks = current_batch['img'], current_batch['masks'] 33 | self.image_batch.set_shape([self.batch_size, self.img_height, self.img_width, 3]) 34 | self.GT_masks.set_shape([self.batch_size, self.img_height, self.img_width, 1, None]) 35 | if self.config.mode == 'pretrain_inpainter': 36 | self.generated_masks = self.Random_boxes() 37 | else: 38 | with tf.name_scope("Generator") as scope: 39 | self.generated_masks, generated_logits = Generator_forward(self.image_batch, self.config.dataset, 40 | self.num_branch, model=self.config.model, training=self.is_training, reuse=None, scope=scope) 41 | 42 | 43 | self.loss = {} 44 | with tf.name_scope("Inpainter") as scope: 45 | self.pred_intensities, self.unconditioned_mean, self.edge_map = \ 46 | Inpainter_forward(self.num_branch, 
self.generated_masks, self.image_batch, dataset=self.config.dataset, 47 | reuse=None, training=self.is_training, scope=scope) 48 | 49 | self.loss['Inpainter'], self.loss['Inpainter_branch'] = Inpainter_Loss(self.generated_masks, self.pred_intensities, self.image_batch) 50 | self.loss['Generator'], self.loss['Generator_branch'], self.loss['Generator_denominator'], self.loss['Generator_numerator'] = \ 51 | Generator_Loss(self.generated_masks, self.pred_intensities, self.image_batch, 52 | self.unconditioned_mean, self.config.epsilon) 53 | 54 | if self.config.dataset == 'flying_animals': 55 | #normalize self.image_batch 56 | self.image_batch = tf_normalize_imgs(self.image_batch) 57 | #-----------VAE---------------- 58 | if self.config.mode in ['train_VAE','train_end2end', 'train_PC']: 59 | with tf.name_scope("VAE") as scope: 60 | #-------------erode dilate smooth------------ 61 | filter_masks = [] 62 | for i in range(self.config.num_branch): 63 | filter_masks.append(erode_dilate(self.generated_masks[:,:,:,:,i])) 64 | self.generated_masks = tf.stack(filter_masks, axis=-1)#B H W 1 num_branch 65 | self.generated_masks = reorder_mask(self.generated_masks) 66 | if self.config.dataset == 'flying_animals': 67 | #resize image to (96, 128) 68 | self.image_batch = tf_resize_imgs(self.image_batch, size=[self.img_height//2,self.img_width//2]) 69 | self.generated_masks = tf_resize_imgs(self.generated_masks, size=[self.img_height//2, self.img_width//2]) 70 | 71 | VAE_loss, VAE_outputs, latent_zs = VAE_forward(image=self.image_batch, masks=self.generated_masks[:,:,:,:,:-1*self.config.n_bg], 72 | bg_dim=self.config.bg_dim, tex_dim=self.config.tex_dim, mask_dim=self.config.mask_dim, 73 | scope=scope, reuse=None, training=self.is_training, augmentation=self.config.PC) 74 | self.VAE_outputs = VAE_outputs 75 | if self.config.mode=='train_PC': 76 | #----------perceptual consistency--------------- 77 | mask_top_dims = tf.sort(tf.math.top_k(input=VAE_loss['mask_kl'], k=2)[1]) 78 | # choose latent dimensions with highest kl divergence to perturb, as those dimensions have semantic meanings 79 | self.new_imgs, self.new_labels = Perturbation_forward( 80 | var_num=6, image=self.image_batch, generated_masks=self.generated_masks, 81 | VAE_outputs0 = VAE_outputs, latent_zs = latent_zs, 82 | mask_top_dims=mask_top_dims) 83 | 84 | self.new_generated_masks = [] 85 | PC_loss = [] 86 | for i in range(6): 87 | with tf.name_scope("Generator/") as scope: 88 | new_masks, new_logits = Generator_forward(self.new_imgs[i], self.config.dataset, 89 | self.num_branch, model=self.config.model, training=True, reuse=tf.AUTO_REUSE, scope=scope) 90 | self.new_generated_masks.append(new_masks) 91 | loss = tf.nn.softmax_cross_entropy_with_logits(labels=tf.stop_gradient(self.new_labels[i]), 92 | logits=new_logits) #B H W 1 93 | loss = tf.reduce_sum(loss, axis=[1,2,3])/self.config.batch_size 94 | PC_loss.append(loss) 95 | self.loss['PC'] = tf.reduce_mean(tf.stack(PC_loss,axis=0)) #num_var -> average 96 | self.loss['CIS'] = self.loss['Generator'] 97 | self.loss['Generator'] = self.loss['CIS'] + self.loss['PC']*self.config.ita 98 | else: 99 | self.mask_capacity = tf.placeholder(shape=(), name='mask_capacity', dtype=tf.float32) 100 | self.loss['tex_kl'], self.loss['mask_kl'], self.loss['bg_kl'] = VAE_loss['tex_kl'], VAE_loss['mask_kl'], VAE_loss['bg_kl'] 101 | self.loss['tex_kl_sum'], self.loss['bg_kl_sum'] = tf.reduce_sum(VAE_loss['tex_kl']), tf.reduce_sum(VAE_loss['bg_kl']) 102 | self.loss['mask_kl_sum'] = 
tf.abs(tf.reduce_sum(VAE_loss['mask_kl'])-self.mask_capacity) 103 | 104 | self.loss['tex_error'], self.loss['mask_error'], self.loss['bg_error'], self.loss['VAE_fusion_error'] = \ 105 | VAE_loss['tex_error'], VAE_loss['mask_error'], VAE_loss['bg_error'], VAE_loss['fusion_error'] 106 | 107 | 108 | self.loss['VAE//separate/texVAE'] = self.loss['tex_error']+self.config.tex_beta*self.loss['tex_kl_sum'] 109 | self.loss['VAE//separate/maskVAE'] = self.loss['mask_error']+self.config.mask_gamma*self.loss['mask_kl_sum'] 110 | self.loss['VAE//separate/bgVAE'] = self.loss['bg_error']+self.config.bg_beta*self.loss['bg_kl_sum'] 111 | 112 | self.loss['VAE//separate'] = self.loss['VAE//separate/texVAE']+self.loss['VAE//separate/maskVAE']+self.loss['VAE//separate/bgVAE'] 113 | self.loss['VAE//fusion'] = self.loss['VAE_fusion_error'] 114 | 115 | 116 | 117 | #-----------fuse all branch--------------- 118 | foregrounds = VAE_outputs['out_fusion'] #B H W 3 (num_branch-1) 119 | background_mask = 1-tf.reduce_sum(VAE_outputs['out_masks'], axis=-1, keepdims=True) #B H W 1 1 120 | backgrounds = tf.expand_dims(VAE_outputs['out_bg'], axis=-1)*background_mask # B H W 3 1 121 | 122 | fusion_inputs = tf.concat([foregrounds, backgrounds], axis=-1) #B H W 3 fg_branch+1 123 | fusion_inputs = tf.reshape(fusion_inputs, [-1,fusion_inputs.get_shape()[1],fusion_inputs.get_shape()[2],3*fusion_inputs.get_shape()[4]]) 124 | 125 | with tf.name_scope("Fusion") as scope: 126 | self.fusion_outputs = Fusion_forward(inputs=fusion_inputs, scope=scope, training=self.is_training, reuse=None) 127 | Fusion_error = tf.abs(self.fusion_outputs-self.image_batch) # B H W 3 128 | self.loss['Fusion'] = tf.reduce_mean(tf.reduce_sum(Fusion_error, axis=[1,2,3])) #B -> ,(average on batch) 129 | 130 | self.loss['CIS'] = self.loss['Generator'] 131 | self.loss['Generator'] = (self.loss['tex_error']+self.loss['mask_error'])*self.config.VAE_weight + \ 132 | self.loss['CIS']*self.config.CIS_weight 133 | 134 | self.VAE_outtexes, self.VAE_outtex_bg = VAE_outputs['out_texes'], VAE_outputs['out_bg'] 135 | self.VAE_outmasks = VAE_outputs['out_masks'] #no background #B H W (num_branch-1) 136 | self.VAE_fusion = VAE_outputs['out_fusion'] 137 | 138 | if self.config.dataset == 'flying_animals': 139 | #resize VAE_out 140 | import functools 141 | resize = functools.partial(tf_resize_imgs, size=[self.img_height, self.img_width]) 142 | self.generated_masks, self.image_batch = resize(self.generated_masks), resize(self.image_batch) 143 | self.VAE_outtexes, self.VAE_outtex_bg, self.VAE_outmasks = resize(self.VAE_outtexes), resize(self.VAE_outtex_bg), resize(self.VAE_outmasks) 144 | self.VAE_fusion, self.fusion_outputs = resize(self.VAE_fusion), resize(self.fusion_outputs) 145 | 146 | #------------------------------ 147 | with tf.name_scope('train_op'): 148 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 149 | self.vae_global_step = tf.Variable(0, name='vae_global_step', trainable=False) 150 | self.incr_global_step = tf.assign(self.global_step, self.global_step+1) 151 | self.incr_vae_global_step = tf.assign(self.vae_global_step, self.vae_global_step+1) 152 | self.train_ops, self.train_vars_grads = self.get_train_ops_grads() 153 | 154 | with tf.name_scope('summary_vars'): 155 | self.loss['Generator_var'] = tf.Variable(0.0, name='Generator_var') 156 | self.loss['Generator_branch_var'] = tf.Variable(0.0, name='Generator_branch_var') 157 | self.loss['Inpainter_branch_var'] = tf.Variable(0.0, name='Inpainter_branch_var') # do tf.summary 158 | 
self.loss['Generator_denominator_var'] = tf.Variable(0.0, name='Generator_denominator_var') 159 | self.loss['EvalIoU_var'] = tf.Variable(0.0, name='EvalIoU_var') 160 | if self.config.mode == 'train_PC': 161 | self.switching_rate = tf.Variable(0.0, name='switching_rate') 162 | if self.config.mode in ['train_VAE','train_end2end']: 163 | self.loss['tex_kl_var'] = tf.Variable(0.0, name='tex_kl_var') 164 | self.loss['mask_kl_var'] = tf.Variable(0.0, name='mask_kl_var') 165 | self.loss['bg_kl_var'] = tf.Variable(0.0, name='bg_kl_var') 166 | # 167 | def get_train_ops_grads(self): 168 | #generate train_ops 169 | lr_dict = {'Generator':self.config.gen_lr, 'Inpainter':self.config.inp_lr, 170 | 'VAE//separate':self.config.VAE_lr, 'VAE//fusion':self.config.VAE_lr, 'Fusion':self.config.VAE_lr} 171 | scopes = mode2scopes[self.config.mode] 172 | train_ops, train_vars_grads = {}, {} 173 | for m in scopes: 174 | optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=lr_dict[m]) 175 | train_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, m) 176 | update_op = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS, m) 177 | train_ops[m], train_vars_grads[m] = train_op(loss=self.loss[m],var_list=train_vars, optimizer=optimizer) 178 | train_ops[m] = tf.group([train_ops[m], update_op]) 179 | return train_ops, train_vars_grads 180 | 181 | def load_training_data(self): 182 | if self.config.dataset == 'multi_texture': 183 | return multi_texture_utils.dataset(self.config.root_dir, phase='train', 184 | batch_size=self.batch_size, max_num=self.config.max_num, PC=self.config.PC) 185 | elif self.config.dataset == 'flying_animals': 186 | return flying_animals_utils.dataset(self.config.root_dir, phase='train', 187 | batch_size=self.batch_size, max_num=self.config.max_num) 188 | elif self.config.dataset == 'multi_dsprites': 189 | return multi_dsprites_utils.dataset(self.config.root_dir, phase='train', batch_size=self.batch_size) 190 | elif self.config.dataset == 'objects_room': 191 | return objects_room_utils.dataset(self.config.root_dir, phase='train', batch_size=self.batch_size) 192 | else: 193 | raise IOError("Unknown Dataset") 194 | # B H W 3 195 | def load_val_data(self): 196 | if self.config.dataset == 'multi_texture': 197 | return multi_texture_utils.dataset(self.config.root_dir, phase='val', 198 | batch_size=8, max_num=self.config.max_num, PC=self.config.PC) 199 | elif self.config.dataset == 'flying_animals': 200 | return flying_animals_utils.dataset(self.config.root_dir, phase='val', 201 | batch_size=8, max_num=self.config.max_num) 202 | elif self.config.dataset == 'multi_dsprites': 203 | return multi_dsprites_utils.dataset(self.config.root_dir, phase='val', batch_size=8) 204 | elif self.config.dataset == 'objects_room': 205 | return objects_room_utils.dataset(self.config.root_dir, phase='val', batch_size=8) 206 | else: 207 | raise IOError("Unknown Dataset") 208 | 209 | def Random_boxes(self): 210 | b,h,w = self.batch_size, self.config.img_height, self.config.img_width 211 | r1 = tf.random.uniform(shape=[], minval=0, maxval=h*2//3, dtype=tf.int32) 212 | r2 = tf.random.uniform(shape=[], minval=r1+h//5, maxval=h-1, dtype=tf.int32) 213 | c1 = tf.random.uniform(shape=[], minval=0, maxval=w*2//3, dtype=tf.int32) 214 | c2 = tf.random.uniform(shape=[], minval=c1+w//5, maxval=w-1, dtype=tf.int32) 215 | ones = tf.ones([b,h,w,1]) 216 | zeros = tf.zeros([b,h,w,1]) 217 | random_box = tf.concat([zeros[:,0:r1,:,:],ones[:,r1:r2,:,:],zeros[:,r2:,:,:]], axis=1) 218 | random_box = 
tf.concat([zeros[:,:,0:c1,:],random_box[:,:,c1:c2,:],zeros[:,:,c2:,:]], axis=2) 219 | random_boxes = tf.stack([random_box, 1-random_box], axis=-1) 220 | #B H W 1 221 | return random_boxes #B H W 1 2 222 | -------------------------------------------------------------------------------- /model/traverse_graph.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | from .utils.generic_utils import train_op,myprint, myinput, tf_resize_imgs, reorder_mask, erode_dilate 4 | from .nets import Generator_forward, encoder, decoder, tex_mask_fusion, Fusion_forward, gaussian_kl 5 | import numpy as np 6 | 7 | class Traverse_Graph(object): 8 | 9 | def __init__(self, FLAGS): 10 | self.config = FLAGS 11 | #load data 12 | self.batch_size = FLAGS.batch_size 13 | self.num_branch = FLAGS.num_branch 14 | self.img_height, self.img_width = FLAGS.img_height, FLAGS.img_width 15 | 16 | assert self.config.traverse_type in ['tex', 'mask', 'bg'] 17 | self.traverse_type = self.config.traverse_type 18 | if self.traverse_type == 'bg': 19 | self.traverse_branch = [self.num_branch-1] 20 | else: 21 | self.traverse_branch = [i for i in range(0,self.num_branch) if self.config.traverse_branch=='all' or str(i) in self.config.traverse_branch.split(',')] 22 | 23 | ndim_dict = {'tex':self.config.tex_dim, 'mask':self.config.mask_dim, 'bg':self.config.bg_dim} 24 | self.ndim = ndim_dict[self.config.traverse_type] 25 | 26 | self.traverse_value = list(np.linspace(self.config.traverse_start, self.config.traverse_end, 60)) 27 | self.VAE_outputs = [] 28 | self.traverse_results = [] 29 | 30 | def build(self): 31 | self.image_batch0 = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None,self.img_height,self.img_width,3]) 32 | with tf.name_scope("Generator") as scope: 33 | self.generated_masks, null = Generator_forward(self.image_batch0, self.config.dataset, 34 | self.num_branch, model=self.config.model, training=False, reuse=None, scope=scope) 35 | 36 | filter_masks = [] 37 | for i in range(self.config.num_branch): 38 | filter_masks.append(erode_dilate(self.generated_masks[:,:,:,:,i])) 39 | self.generated_masks = tf.stack(filter_masks, axis=-1)#B H W 1 num_branch 40 | self.generated_masks = reorder_mask(self.generated_masks) 41 | 42 | if self.config.dataset == 'flying_animals': 43 | #resize image to (96, 128) 44 | self.image_batch = tf_resize_imgs(self.image_batch0, size=[self.img_height//2,self.img_width//2]) 45 | self.generated_masks = tf_resize_imgs(self.generated_masks, size=[self.img_height//2, self.img_width//2]) 46 | else: 47 | self.image_batch = tf.identity(self.image_batch0) 48 | 49 | segmented_img = self.generated_masks*tf.expand_dims(self.image_batch, axis=-1)#B H W 3 M 50 | 51 | 52 | 53 | #Fusion 54 | bg_mask = 1-tf.reduce_sum(self.generated_masks[:,:,:,:,:-1*self.config.n_bg], axis=-1) 55 | with tf.compat.v1.variable_scope('VAE//separate/bgVAE', reuse=tf.compat.v1.AUTO_REUSE): 56 | bg_z_mean, bg_z_log_sigma_sq = encoder(x=bg_mask*self.image_batch, latent_dim=self.config.bg_dim, training=False) 57 | out_bg_logit = decoder(bg_z_mean, output_ch=3, latent_dim=self.config.bg_dim, x=bg_mask*self.image_batch, training=False) 58 | out_bg = tf.nn.sigmoid(out_bg_logit) 59 | 60 | self.out_bg = out_bg 61 | self.in_bg = bg_mask*self.image_batch 62 | for i in self.traverse_branch: 63 | with tf.compat.v1.variable_scope('VAE//separate', reuse=tf.compat.v1.AUTO_REUSE): 64 | with tf.compat.v1.variable_scope('texVAE', reuse=tf.compat.v1.AUTO_REUSE): 65 | tex_z_mean, 
tex_z_log_sigma_sq = encoder(x=self.generated_masks[:,:,:,:,i]*self.image_batch, latent_dim=self.config.tex_dim, training=False) 66 | out_tex_logit = decoder(tex_z_mean, output_ch=3, latent_dim=self.config.tex_dim, x=self.generated_masks[:,:,:,:,i]*self.image_batch, training=False) 67 | out_tex = tf.nn.sigmoid(out_tex_logit) 68 | 69 | with tf.compat.v1.variable_scope('maskVAE', reuse=tf.compat.v1.AUTO_REUSE): 70 | mask_z_mean, mask_z_log_sigma_sq = encoder(x=self.generated_masks[:,:,:,:,i], latent_dim=self.config.mask_dim, training=False) 71 | out_mask_logit = decoder(mask_z_mean, output_ch=1, latent_dim=self.config.mask_dim, x=self.generated_masks[:,:,:,:,i], training=False) 72 | out_mask = tf.nn.sigmoid(out_mask_logit) 73 | if self.traverse_type=='bg': 74 | output_ch = 3 75 | scope = 'VAE//separate/bgVAE' 76 | z_mean, z_log_sigma_sq = bg_z_mean, bg_z_log_sigma_sq 77 | elif self.traverse_type=='tex': 78 | output_ch = 3 79 | scope = 'VAE//separate/texVAE' 80 | z_mean, z_log_sigma_sq = tex_z_mean, tex_z_log_sigma_sq 81 | else: 82 | output_ch = 1 83 | scope = 'VAE//separate/maskVAE' 84 | z_mean, z_log_sigma_sq = mask_z_mean, mask_z_log_sigma_sq 85 | 86 | KLs = gaussian_kl(z_mean, z_log_sigma_sq)[0] #B,zdim -> zdim, 87 | self.traverse_dim = tf.math.top_k(KLs, k=self.config.top_kdim, sorted=True).indices 88 | 89 | for d_count in range(self.config.top_kdim): 90 | d = self.traverse_dim[d_count] 91 | delta_unit = tf.one_hot(indices=[d], depth=self.ndim) 92 | delta_unit = tf.tile(delta_unit, [self.batch_size,1]) #B,dim 93 | 94 | for k in self.traverse_value: 95 | shifted_z = z_mean + delta_unit*k #B,dim 96 | with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE): 97 | out_logit = decoder(shifted_z, output_ch=output_ch, latent_dim=self.ndim, x=self.generated_masks[:,:,:,:,i], training=False) 98 | out = tf.nn.sigmoid(out_logit) 99 | 100 | if self.traverse_type=='bg': 101 | foregrounds = segmented_img[:,:,:,:,:-self.config.n_bg] #B H W 3 C 102 | backgrounds = tf.expand_dims(bg_mask*out, axis=-1) #B H W 3 1 103 | else: 104 | if self.traverse_type=='tex': 105 | tex,mask = out, self.generated_masks[:,:,:,:,i] 106 | else: #mask 107 | tex,mask = out_tex, out 108 | with tf.compat.v1.variable_scope('VAE//fusion', reuse=tf.compat.v1.AUTO_REUSE): 109 | VAE_fusion_out, null = tex_mask_fusion(tex=tex, mask=mask) 110 | foregrounds = tf.concat([segmented_img[:,:,:,:,0:i], 111 | tf.expand_dims(VAE_fusion_out, axis=-1), 112 | segmented_img[:,:,:,:,i+1:-self.config.n_bg]], axis=-1) 113 | background_mask = 1-tf.reduce_sum(self.generated_masks[:,:,:,:,0:i],axis=-1)-mask-tf.reduce_sum(self.generated_masks[:,:,:,:,i+1:-self.config.n_bg],axis=-1) 114 | backgrounds = tf.expand_dims(background_mask*out_bg, axis=-1) 115 | 116 | fusion_inputs = tf.concat([foregrounds, backgrounds], axis=-1) 117 | fusion_inputs = tf.reshape(fusion_inputs, [-1,fusion_inputs.get_shape()[1],fusion_inputs.get_shape()[2],3*fusion_inputs.get_shape()[4]]) 118 | fusion_output = Fusion_forward(inputs=fusion_inputs, scope='Fusion/', training=False, reuse=tf.compat.v1.AUTO_REUSE) 119 | self.traverse_results.append(fusion_output) 120 | # def load_val_data(self): 121 | # #now only support multi_dsprites 122 | # if self.config.dataset == 'multi_texture': 123 | # return multi_texture_utils.dataset(self.config.root_dir, val=True, 124 | # batch_size=self.config.max_num, max_num=self.config.max_num, 125 | # zoom=('z' in self.config.variant), 126 | # rotation=('r' in self.config.variant), 127 | # texture_transform=self.config.texture_transform) 
128 | # elif self.config.dataset == 'flying_animals': 129 | # return flying_animals_utils.dataset(self.config.root_dir,val=True, 130 | # batch_size=self.config.max_num, max_num=self.config.max_num) 131 | # elif self.config.dataset == 'multi_dsprites': 132 | # return multi_dsprites_utils.dataset(self.config.root_dir,val=True, 133 | # batch_size=self.batch_size, skipnum=0, takenum=self.config.skipnum, 134 | # shuffle=False, map_parallel_calls=1) 135 | # elif self.config.dataset == 'objects_room': 136 | # return objects_room_utils.dataset(self.config.root_dir,val=True, 137 | # batch_size=self.batch_size, skipnum=0, takenum=self.config.skipnum, 138 | # shuffle=False, map_parallel_calls=1) 139 | # else: 140 | # raise IOError("Unknown Dataset") -------------------------------------------------------------------------------- /model/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/model/utils/__init__.py -------------------------------------------------------------------------------- /model/utils/convolution_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | def fully_connect(x, weight_shape, name, biased=True, bias_init_value=0.0): 3 | with tf.compat.v1.variable_scope(name): 4 | weight = tf.compat.v1.get_variable("weight", weight_shape, initializer=tf.contrib.layers.xavier_initializer()) 5 | o = tf.matmul(x, weight) 6 | if biased: 7 | b = tf.get_variable('bias', shape=[weight_shape[-1]], initializer=tf.constant_initializer(bias_init_value)) 8 | o = o+b 9 | return o 10 | 11 | 12 | def _dilated_conv2d(x, kernel_size, num_o, dilation_factor, name, biased=False): 13 | num_x = x.shape[-1].value #C last 14 | with tf.compat.v1.variable_scope(name): 15 | w = tf.get_variable('weights', shape=[kernel_size, kernel_size, num_x, num_o], initializer=tf.contrib.layers.xavier_initializer_conv2d()) 16 | o = tf.nn.atrous_conv2d(x, w, dilation_factor, padding='SAME') 17 | if biased: 18 | b = tf.get_variable('biases', shape=[num_o], initializer=tf.compat.v1.constant_initializer(0.0)) 19 | o = tf.nn.bias_add(o, b) 20 | return o 21 | 22 | def conv2d_transpose(x, filter_shape, output_shape, stride, name, padding='SAME', dilation=1, biased=True): 23 | with tf.compat.v1.variable_scope(name): 24 | w = tf.get_variable('weights', shape=filter_shape, initializer=tf.contrib.layers.xavier_initializer_conv2d()) 25 | o = tf.nn.conv2d_transpose(x, filters=w, output_shape=output_shape, 26 | strides=stride, padding=padding, dilations=dilation) 27 | if biased: 28 | b = tf.get_variable('biases', shape=[filter_shape[-2]], initializer=tf.compat.v1.constant_initializer(0.0)) # filter_shape is [k, k, output_channel, input_channel] for transposed conv 29 | o = tf.nn.bias_add(o, b) 30 | return o 31 | 32 | def conv2d(x, filter_shape, stride, name, padding='SAME', dilation=1, biased=True): 33 | with tf.compat.v1.variable_scope(name): 34 | w = tf.get_variable('weights', shape=filter_shape, initializer=tf.contrib.layers.xavier_initializer_conv2d()) 35 | o = tf.nn.conv2d(x, filters=w, strides=stride, padding=padding, dilations=dilation) 36 | if biased: 37 | b = tf.get_variable('biases', shape=[filter_shape[-1]], initializer=tf.compat.v1.constant_initializer(0.0)) 38 | o = tf.nn.bias_add(o, b) 39 | return o 40 |
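`deconv2d` just below (and `deconv` further down) upsamples by resizing and then convolving, rather than by a strided transposed convolution — a common way to avoid checkerboard artifacts. A small NumPy sketch of the pattern; the shapes and the nearest-neighbor factor are illustrative:

```python
import numpy as np

def resize_conv_upsample(x, w, scale=2):
    """Nearest-neighbor resize followed by a 'same'-padded correlation,
    mirroring the resize-then-conv pattern used below."""
    x = x.repeat(scale, axis=0).repeat(scale, axis=1)  # (H,W,Cin) -> (2H,2W,Cin)
    H, W, _ = x.shape
    k, cout = w.shape[0], w.shape[-1]                  # w: (k,k,Cin,Cout)
    pad = k // 2
    xp = np.pad(x, ((pad, pad), (pad, pad), (0, 0)))
    out = np.zeros((H, W, cout))
    for i in range(H):
        for j in range(W):
            out[i, j] = np.tensordot(xp[i:i+k, j:j+k], w, axes=([0, 1, 2], [0, 1, 2]))
    return out
```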
41 | def deconv2d(x, filter_shape, output_size, name, padding='SAME', biased=True): 42 | with tf.compat.v1.variable_scope(name): 43 | x = tf.image.resize_images(x, size=output_size) 44 | w = tf.get_variable('weights', shape=filter_shape, initializer=tf.contrib.layers.xavier_initializer_conv2d()) 45 | o = tf.nn.conv2d(x, filters=w, strides=1, padding=padding) 46 | if biased: 47 | b = tf.get_variable('biases', shape=[filter_shape[-1]], initializer=tf.compat.v1.constant_initializer(0.0)) 48 | o = tf.nn.bias_add(o, b) 49 | return o 50 | 51 | def InstanceNorm(x): 52 | return tf.contrib.layers.instance_norm(x, center=False, scale=False, epsilon=1e-5, trainable=False) 53 | 54 | def gen_conv(x, cnum, ksize, stride=1, rate=1, name='conv', 55 | padding='SAME', activation=tf.nn.elu, training=True, 56 | kernel_initializer=None): 57 | """Define conv for generator. 58 | Args: 59 | x: Input. 60 | cnum: Channel number. 61 | ksize: Kernel size. 62 | stride: Convolution stride. 63 | rate: Dilation rate for dilated conv. 64 | name: Name of layers. 65 | padding: Padding mode. Default to SAME. 66 | activation: Activation function after convolution. 67 | training: Whether the current graph is for training or inference; used for bn. 68 | Returns: 69 | tf.Tensor: output 70 | """ 71 | x = tf.layers.conv2d(x,cnum, ksize, stride, dilation_rate=rate, 72 | activation=None, padding=padding, name=name, 73 | kernel_initializer=kernel_initializer) 74 | # We empirically found BN to help even when its parameters are not trained (it acts as a regularizer) 75 | x = tf.layers.batch_normalization(x) 76 | x = activation(x) 77 | 78 | return x 79 | 80 | def gen_deconv(x, cnum, name='upsample', padding='SAME', training=True): 81 | """Define deconv for generator. 82 | The deconv is defined as an x2 resize_nearest_neighbor operation followed by 83 | an additional gen_conv operation. 84 | Args: 85 | x: Input. 86 | cnum: Channel number. 87 | name: Name of layers. 88 | training: Whether the current graph is for training or inference; used for bn. 89 | Returns: 90 | tf.Tensor: output 91 | """ 92 | with tf.compat.v1.variable_scope(name): 93 | x = resize(x, func=tf.compat.v1.image.resize_nearest_neighbor) 94 | x = gen_conv( 95 | x, cnum, 3, 1, name=name+'_conv', padding=padding, 96 | training=training) 97 | return x 98 | 99 | def resize(x, scale=2, to_shape=None, align_corners=True, dynamic=False, 100 | func=tf.compat.v1.image.resize_bilinear, name='resize'): 101 | """ 102 | This resize operation is used to scale the input according to some given 103 | scale and function.
104 | """ 105 | 106 | if dynamic: 107 | xs = tf.cast(tf.shape(x), tf.float32) 108 | new_xs = [tf.cast(xs[1]*scale, tf.int32), 109 | tf.cast(xs[2]*scale, tf.int32)] 110 | else: 111 | xs = x.get_shape().as_list() 112 | new_xs = [int(xs[1]*scale), int(xs[2]*scale)] 113 | with tf.compat.v1.variable_scope(name): 114 | if to_shape is None: 115 | x = func(x, new_xs, align_corners=align_corners) 116 | else: 117 | x = func(x, [to_shape[0], to_shape[1]], 118 | align_corners=align_corners) 119 | return x 120 | 121 | 122 | def conv(inputs, name, shape, stride, padding='SAME', dilations=1 ,reuse=None, training=True, activation=tf.nn.leaky_relu, 123 | init_w=tf.contrib.layers.xavier_initializer_conv2d(), init_b=tf.constant_initializer(0.0)): 124 | with tf.compat.v1.variable_scope(name, reuse=reuse) as scope: 125 | kernel = tf.contrib.framework.model_variable('weights', shape=shape, initializer=init_w, trainable=True) 126 | conv = tf.nn.conv2d(inputs, kernel, [1, stride, stride, 1], padding=padding, dilations=dilations) 127 | biases = tf.contrib.framework.model_variable('biases', shape=[shape[3]], initializer=init_b, trainable=True) 128 | conv = tf.nn.bias_add(conv, biases) 129 | if activation: 130 | conv = activation(conv) 131 | return conv 132 | 133 | def deconv(inputs, size, name, shape, reuse=None, training=True, activation=tf.nn.leaky_relu): 134 | deconv = tf.image.resize_images( inputs, size=size ) 135 | deconv = conv( deconv, name=name, shape=shape, stride=1, reuse=reuse, training=training, activation=activation ) 136 | return deconv -------------------------------------------------------------------------------- /model/utils/generic_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import tensorflow as tf 4 | import random 5 | import os 6 | import imageio 7 | import math 8 | 9 | 10 | def erode_dilate(img): 11 | img = -1*tf.nn.max_pool2d(-1*img, ksize=3, strides=1, padding='SAME') 12 | img = tf.nn.max_pool2d(img, ksize=3, strides=1, padding='SAME') 13 | return img 14 | 15 | def tf_normalize_imgs(imgs): 16 | #normalize to 0~1 17 | #B H W 3 18 | max_value = tf.reduce_max(imgs, axis=[1,2], keepdims=True) 19 | min_value = tf.reduce_min(imgs, axis=[1,2], keepdims=True) 20 | imgs = tf.math.divide_no_nan(imgs-min_value, max_value-min_value) 21 | return imgs 22 | 23 | def tf_resize_imgs(imgs, size): 24 | #default bilinear 25 | shape = imgs.get_shape().as_list() 26 | assert len(shape) <=5 and len(shape)>=3 27 | if len(shape) <= 4: #B H W C 28 | return tf.image.resize(imgs, size=size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) 29 | else: #B H W C M 30 | resize_imgs = [tf.image.resize(imgs[:,:,:,:,i], size=size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) for i in range(shape[-1])] 31 | resize_imgs = tf.stack(resize_imgs, axis=-1) 32 | return resize_imgs 33 | 34 | def reorder_mask(masks): 35 | #put background segmentation at the last channels 36 | B, H, W, C, M = masks.get_shape().as_list() 37 | pixel_sum = tf.reduce_sum(masks, axis=[0,1,2,3]) #C, 38 | ind = tf.math.argmax(pixel_sum) 39 | reordered = tf.concat([masks[:,:,:,:,0:ind],masks[:,:,:,:,ind+1:],masks[:,:,:,:,ind:ind+1]], axis=-1) 40 | reordered.set_shape([B,H,W,C,M]) 41 | return reordered 42 | 43 | def myprint(string): 44 | print ("\033[0;30;42m"+string+"\033[0m") 45 | 46 | def myinput(string=""): 47 | return input ("\033[0;30;41m"+string+"\033[0m") 48 | 49 | def bin_edge_map(imgs, dataset): 50 | assert dataset in ['flying_animals', 'multi_texture', 
49 | def bin_edge_map(imgs, dataset): 50 | assert dataset in ['flying_animals', 'multi_texture', 'multi_dsprites', 'objects_room'] 51 | if dataset == 'flying_animals': 52 | d_xy = tf.reduce_sum(tf.image.sobel_edges(imgs), axis=-2) 53 | d_xy = d_xy/15 54 | d_xy = tf.cast(tf.abs(d_xy)>0.04, tf.float32) 55 | else: 56 | d_xy = tf.abs(tf.image.sobel_edges(imgs)) # B H W 3 2 57 | d_xy = tf.reduce_sum(d_xy,axis=-2, keepdims=False) #B H W 2 58 | #adjust the range to 0~1 59 | max_val = tf.math.reduce_max(d_xy, axis=[1,2], keepdims=True) 60 | min_val = tf.math.reduce_min(d_xy, axis=[1,2], keepdims=True) 61 | d_xy = tf.math.divide_no_nan(d_xy-min_val,max_val-min_val) #0~1 62 | #sparse binary edge map 63 | threshold = {'multi_texture': 0.5, 'multi_dsprites':0.0001, 'objects_room':0.2} 64 | d_xy = tf.cast(d_xy>threshold[dataset], tf.float32) 65 | return d_xy 66 | 67 | def Permute_IoU(label, pred): 68 | A, B = label, pred 69 | H,W,_,N = A.shape 70 | ans = 0 71 | nc = 0 #number of non-empty channels 72 | ans_perfg, arg_maxIoU = [], [] 73 | for i in range(N): 74 | src = np.expand_dims(A[:,:,:,i], axis=-1)>0.04 #H W 1 1 #binary 75 | if np.sum(src) > 4: 76 | nc += 1 77 | trg = B>0.04 #H W 1 C 78 | U = np.sum(src+trg, axis=(0,1,2)) #H W 1 C -> C 79 | I = np.sum(src*trg, axis=(0,1,2)) #H W 1 C -> C 80 | eps = 1e-8 81 | IoU = I/(eps+U) 82 | ans += np.max(IoU) 83 | arg_maxIoU.append(np.argmax(IoU)) 84 | else: #empty channel 85 | arg_maxIoU.append(-1) 86 | assert nc > 0 87 | return ans/nc, arg_maxIoU 88 | 89 | def train_op(loss, var_list, optimizer, gradient_clip_value=-1): 90 | grads_and_vars = optimizer.compute_gradients(loss, var_list=var_list) 91 | 92 | clipped_grad_and_vars = [ (ClipIfNotNone(grad, gradient_clip_value),var) \ 93 | for grad, var in grads_and_vars ] 94 | train_operation = optimizer.apply_gradients(clipped_grad_and_vars) 95 | 96 | return train_operation, clipped_grad_and_vars 97 | 98 | def ClipIfNotNone(grad, clipvalue): 99 | if clipvalue==-1: 100 | return grad 101 | else: 102 | return tf.clip_by_value(grad, -clipvalue, clipvalue) -------------------------------------------------------------------------------- /model/utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import tensorflow as tf 3 | 4 | 5 | 6 | def region_error(X, Y, region): 7 | #X: B H W 3 or B H W 3 M 8 | #Y: B H W 3 or B H W 3 M 9 | #region: B H W 1 or B H W 1 M 10 | X = tf.cond(tf.equal(tf.rank(X), 5), lambda:X, lambda:tf.expand_dims(X, axis=-1)) 11 | Y = tf.cond(tf.equal(tf.rank(Y), 5), lambda:Y, lambda:tf.expand_dims(Y, axis=-1)) 12 | error = tf.abs(X-Y)*region #B H W 3 M 13 | error = tf.reduce_sum(error, axis=[1,2,3]) #B,M 14 | return error #B,M 15 |
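`region_error` above reduces a masked L1 difference to one scalar per branch; `Generator_Loss` further down turns it into an information-recovery ratio (IRR) against the inpainter's unconditioned prediction. A NumPy sketch of that ratio under the shape conventions noted in the comments (`epsilon` mirrors the `--epsilon` flag in the training scripts):

```python
import numpy as np

def irr(pred, uncond, image, masks, epsilon=30.0):
    """pred, uncond: (B,H,W,3,M) inpainted intensities with / without context;
    image: (B,H,W,3,1), broadcast against the M masks; masks: (B,H,W,1,M)."""
    err = lambda a, b: (np.abs(a - b) * masks).sum(axis=(1, 2, 3))  # -> (B, M)
    return 1.0 - err(pred, image) / (epsilon + err(uncond, image))
```

The segmenter is trained to minimize the IRR, i.e. to propose regions whose content the inpainter cannot recover from the surrounding context.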
16 | def Supervised_Generator_Loss(pred, GT): 17 | #pred, GT: B H W 1 C 18 | # use cross entropy over the C mask channels 19 | error = tf.compat.v1.nn.softmax_cross_entropy_with_logits(labels=GT, logits=pred, dim=-1) #B H W 1 20 | error = tf.reduce_mean(error) #B H W 1 -> per-pixel mean 21 | return error 22 | 23 | def Generator_Loss(masks, pred_intensities, image, unconditioned_mean, epsilon): 24 | #masks B H W 1 C 25 | #pred_intensities B H W 3 C 26 | #unconditioned_mean B H W 3 27 | #image B H W 3 28 | numerator = region_error(pred_intensities, image, masks) #B,C 29 | denominator = epsilon + region_error(unconditioned_mean, image, masks) # B,C 30 | IRR = 1-tf.math.divide(numerator, denominator) #B,C 31 | perbranch_loss = tf.reduce_mean(IRR, axis=0) #C, 32 | loss = tf.reduce_sum(perbranch_loss) 33 | return loss, IRR, denominator, numerator #scalar 34 | 35 | def Inpainter_Loss(masks, pred_intensities, image): 36 | #pred_intensities 0~1 37 | #image 0~1 38 | B, H, W, C = image.get_shape().as_list() #B H W C 39 | num_pixel = H*W*C 40 | loss0 = region_error(pred_intensities, image, masks) #B,M 41 | perbranch_loss = tf.reduce_mean(loss0, axis=0) # M, 42 | loss = tf.reduce_sum(perbranch_loss) #M->scalar 43 | loss = tf.math.divide(loss, num_pixel) 44 | return loss, loss0 -------------------------------------------------------------------------------- /sample_imgs/flying_animals/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/sample_imgs/flying_animals/01.png -------------------------------------------------------------------------------- /sample_imgs/flying_animals/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/sample_imgs/flying_animals/02.png -------------------------------------------------------------------------------- /sample_imgs/multi_dsprites/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/sample_imgs/multi_dsprites/01.png -------------------------------------------------------------------------------- /sample_imgs/multi_dsprites/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/sample_imgs/multi_dsprites/02.png -------------------------------------------------------------------------------- /sample_imgs/multi_texture/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/sample_imgs/multi_texture/01.png -------------------------------------------------------------------------------- /sample_imgs/objects_room/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/sample_imgs/objects_room/01.png -------------------------------------------------------------------------------- /sample_imgs/objects_room/02.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/sample_imgs/objects_room/02.png -------------------------------------------------------------------------------- /sample_imgs/objects_room/03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/sample_imgs/objects_room/03.png -------------------------------------------------------------------------------- /sample_imgs/objects_room/04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/sample_imgs/objects_room/04.png -------------------------------------------------------------------------------- /script/flying_animals/disentangle.sh: -------------------------------------------------------------------------------- 1 | ABSPATH=$(readlink -f $0) 2 | 3 | 4 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \ 5 | --checkpoint_dir=outputs/flying_animals/02_bg_range6 \ 6 | --dataset=flying_animals --max_num=5 --num_branch=6 --root_dir=data/flying_animals_data/img_data.npz \ 7 | --mode=eval_VAE --model=resnet_v2_50 \ 8 | \ 9 | \ 10 | --fullmodel_ckpt='path to checkpoint' \ 11 | --tex_dim=20 --bg_dim=30 --mask_dim=20 \ 12 | \ 13 | --input_img=sample_imgs/flying_animals/02.png \ 14 | --traverse_type=bg --top_kdim=4 --traverse_branch=5 \ 15 | --batch_size=1 --traverse_start=-6 --traverse_end=6 \ 16 | 17 | #checkpoint_dir: folder containing outputs 18 | 19 | #input_img: target image (only a single image is supported) 20 | 21 | #traverse_type: 22 | #mask: traverse shape latent space 23 | #tex: traverse texture/color latent space 24 | #bg: traverse background latent space 25 | 26 | #top_kdim: 27 | #choose the k dimensions with the largest KL divergence to traverse 28 | #these dimensions should encode the k most significant factors of the object.
29 | #results of traversing the dimension with the k-th largest KL divergence are output as branch{i}_var{k}.gif 30 | 31 | 32 | #traverse_branch: which branches to traverse 33 | #(only effective when traverse_type in ['tex','mask']) 34 | #'all': generate results of traversing all branches except the background branch 35 | #'0,1,2': manually choose branches to traverse 36 | -------------------------------------------------------------------------------- /script/flying_animals/pretrain_inpainter.sh: -------------------------------------------------------------------------------- 1 | ABSPATH=$(readlink -f $0) 2 | 3 | 4 | CUDA_VISIBLE_DEVICES=0 python main.py \ 5 | --checkpoint_dir=checkpoint/flying_animals/pretrain_inpainter \ 6 | --sh_path=$ABSPATH \ 7 | \ 8 | --dataset=flying_animals --root_dir=data/flying_animals_data/img_data.npz \ 9 | --mode=pretrain_inpainter --max_num=5 \ 10 | \ 11 | --batch_size=4 --inp_lr=1e-4 \ 12 | \ 13 | --summaries_secs=180 --ckpt_secs=5000 \ -------------------------------------------------------------------------------- /script/flying_animals/test_segmentation.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python test_segmentation.py \ 2 | --data_path ../data/flying_animals_data/img_data.npz \ 3 | --ckpt_path 'path to checkpoint including segmentation network' \ 4 | --batch_size 8 --num_branch 6 --dataset_name flying_animals 5 | -------------------------------------------------------------------------------- /script/flying_animals/train_CIS.sh: -------------------------------------------------------------------------------- 1 | ABSPATH=$(readlink -f $0) 2 | 3 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \ 4 | --checkpoint_dir=checkpoint/flying_animals/CIS \ 5 | --dataset=flying_animals --max_num=5 --num_branch=6 --root_dir=data/flying_animals_data/img_data.npz \ 6 | --mode=train_CIS --model=resnet_v2_50 \ 7 | \ 8 | \ 9 | --resume_inpainter=True --resume_fullmodel=False --resume_resnet=True \ 10 | --fullmodel_ckpt=/ \ 11 | --inpainter_ckpt='path of pretrained inpainter ckpt to resume here' 12 | 'e.g. checkpoint/flying_animals/pretrain_inpainter/inpainter-100000' \ 13 | --resnet_ckpt=resnet/resnet_v2_50/resnet_v2_50.ckpt \ 14 | \ 15 | --batch_size=4 --inp_lr=1e-4 --gen_lr=1e-4 \ 16 | --epsilon=100 --iters_inp=1 --iters_gen=2 \ 17 | --ckpt_steps=10000 --summaries_steps=500 \ -------------------------------------------------------------------------------- /script/flying_animals/train_VAE.sh: -------------------------------------------------------------------------------- 1 | ABSPATH=$(readlink -f $0) 2 | 3 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \ 4 | --checkpoint_dir=checkpoint/flying_animals/VAE \ 5 | --dataset=flying_animals --max_num=5 --num_branch=6 --root_dir=data/flying_animals_data/img_data.npz \ 6 | --mode=train_VAE --model=resnet_v2_50 \ 7 | \ 8 | \ 9 | --resume_CIS=True --resume_fullmodel=False \ 10 | --fullmodel_ckpt=/ \ 11 | --CIS_ckpt='path of CIS ckpt to resume here, e.g.
checkpoint/flying_animals/CIS/model-100000' \ 12 | \ 13 | --batch_size=4 --VAE_lr=1e-4 \ 14 | --tex_dim=20 --tex_beta=4 \ 15 | --bg_dim=30 --bg_beta=4 \ 16 | --mask_dim=20 --mask_gamma=500 --mask_capacity_inc=5e-5 \ 17 | --ckpt_steps=10000 --summaries_steps=1000 \ -------------------------------------------------------------------------------- /script/multi_dsprites/disentangle.sh: -------------------------------------------------------------------------------- 1 | ABSPATH=$(readlink -f $0) 2 | 3 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \ 4 | --checkpoint_dir=outputs/multi_dsprites/02 \ 5 | --dataset=multi_dsprites --num_branch=5 --root_dir=data/multi_dsprites_data/multi_dsprites_colored_on_colored.tfrecords \ 6 | --mode=eval_VAE --model=resnet_v2_50 \ 7 | \ 8 | --fullmodel_ckpt='path to checkpoint containing segmentation network and encoder-decoder' \ 9 | --tex_dim=5 --bg_dim=5 --mask_dim=10 \ 10 | \ 11 | --input_img=sample_imgs/multi_dsprites/02.png \ 12 | --traverse_type=mask --top_kdim=5 --traverse_branch=2,3 \ 13 | --batch_size=1 --traverse_range=0.5 \ 14 | 15 | #checkpoint_dir: folder containing outputs 16 | 17 | #input_img: target image (only a single image is supported) 18 | 19 | #traverse_type: 20 | #mask: traverse shape latent space 21 | #tex: traverse texture/color latent space 22 | #bg: traverse background latent space 23 | 24 | #top_kdim: 25 | #choose the k dimensions with the largest KL divergence to traverse 26 | #these dimensions should encode the k most significant factors of the object. 27 | #results of traversing the dimension with the k-th largest KL divergence are output as branch{i}_var{k}.gif 28 | 29 | 30 | #traverse_branch: which branches to traverse 31 | #(only effective when traverse_type in ['tex','mask']) 32 | #'all': generate results of traversing all branches except the background branch 33 | #'0,1,2': manually choose branches to traverse 34 | 35 | -------------------------------------------------------------------------------- /script/multi_dsprites/pretrain_inpainter.sh: -------------------------------------------------------------------------------- 1 | ABSPATH=$(readlink -f $0) 2 | 3 | 4 | CUDA_VISIBLE_DEVICES=0 python main.py \ 5 | --checkpoint_dir=checkpoint/multi_dsprites/pretrain_inpainter \ 6 | --sh_path=$ABSPATH \ 7 | \ 8 | --dataset=multi_dsprites --root_dir=data/multi_dsprites_data/multi_dsprites_colored_on_colored.tfrecords \ 9 | --mode=pretrain_inpainter \ 10 | \ 11 | --batch_size=16 --inp_lr=1e-4 \ 12 | \ 13 | --summaries_secs=60 --ckpt_secs=10000 \ -------------------------------------------------------------------------------- /script/multi_dsprites/test_segmentation.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python test_segmentation.py \ 2 | --data_path ../data/multi_dsprites_data/multi_dsprites_colored_on_colored.tfrecords \ 3 | --ckpt_path 'path to checkpoint including segmentation network' \ 4 | --batch_size 8 --num_branch 5 --dataset_name multi_dsprites -------------------------------------------------------------------------------- /script/multi_dsprites/train_CIS.sh: -------------------------------------------------------------------------------- 1 | ABSPATH=$(readlink -f $0) 2 | 3 | 4 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \ 5 | --checkpoint_dir=checkpoint/multi_dsprites/CIS \ 6 | --dataset=multi_dsprites --num_branch=5 --root_dir=data/multi_dsprites_data/multi_dsprites_colored_on_colored.tfrecords \ 7 | --mode=train_CIS --model=resnet_v2_50 \ 8
| \ 9 | \ 10 | --resume_inpainter=True --resume_fullmodel=False \ 11 | --fullmodel_ckpt=/ \ 12 | --inpainter_ckpt='path of pretrained inpainter ckpt to resume here' 13 | 'e.g. checkpoint/multi_dsprites/pretrain_inpainter/inpainter-50000' \ 14 | \ 15 | --batch_size=8 --inp_lr=1e-4 --gen_lr=3e-5 \ 16 | --epsilon=30 --iters_inp=1 --iters_gen=3 \ 17 | --ckpt_steps=4000 \ 18 | 19 | -------------------------------------------------------------------------------- /script/multi_dsprites/train_VAE.sh: -------------------------------------------------------------------------------- 1 | ABSPATH=$(readlink -f $0) 2 | 3 | 4 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \ 5 | --checkpoint_dir=checkpoint/multi_dsprites/VAE \ 6 | --dataset=multi_dsprites --num_branch=5 --root_dir=data/multi_dsprites_data/multi_dsprites_colored_on_colored.tfrecords \ 7 | --mode=train_VAE --model=resnet_v2_50 \ 8 | \ 9 | \ 10 | --resume_CIS=True --resume_fullmodel=False \ 11 | --fullmodel_ckpt= \ 12 | --CIS_ckpt='path of CIS ckpt to resume here, e.g. checkpoint/multi_dsprites/CIS/model-100000' \ 13 | \ 14 | --batch_size=16 --VAE_lr=1e-4 \ 15 | --tex_dim=5 --tex_beta=6 \ 16 | --bg_dim=5 --bg_beta=6 \ 17 | --mask_dim=10 --mask_gamma=500 --mask_capacity_inc=2e-5 \ 18 | --ckpt_steps=20000 --summaries_steps=1000 \ 19 | 20 | -------------------------------------------------------------------------------- /script/multi_texture/disentangle.sh: -------------------------------------------------------------------------------- 1 | ABSPATH=$(readlink -f $0) 2 | 3 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \ 4 | --checkpoint_dir=outputs/multi_texture \ 5 | --dataset=multi_texture --max_num=4 --num_branch=5 --root_dir=data/multi_texture_data/ \ 6 | --mode=eval_VAE --model=resnet_v2_50 \ 7 | \ 8 | \ 9 | --fullmodel_ckpt='path to checkpoint' \ 10 | --tex_dim=5 --mask_dim=10 --bg_dim=5 \ 11 | \ 12 | --input_img=sample_imgs/multi_texture/01.png \ 13 | --traverse_type=mask --top_kdim=2 --traverse_branch=0,1,2,3 --traverse_range=0.5 \ 14 | --batch_size=1 \ 15 | 16 | #checkpoint_dir: folder containing outputs 17 | 18 | #input_img: target image (only a single image is supported) 19 | 20 | #traverse_type: 21 | #mask: traverse shape latent space 22 | #tex: traverse texture/color latent space 23 | #bg: traverse background latent space 24 | 25 | #top_kdim: 26 | #choose the k dimensions with the largest KL divergence to traverse 27 | #these dimensions should encode the k most significant factors of the object. 28 | #results of traversing the dimension with the k-th largest KL divergence are output as branch{i}_var{k}.gif 29 | 30 | 31 | #traverse_branch: which branches to traverse 32 | #(only effective when traverse_type in ['tex','mask']) 33 | #'all': generate results of traversing all branches except the background branch 34 | #'0,1,2': manually choose branches to traverse 35 | -------------------------------------------------------------------------------- /script/multi_texture/perceptual_consistency/finetune_PC.sh: -------------------------------------------------------------------------------- 1 | ABSPATH=$(readlink -f $0) 2 | 3 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \ 4 | --checkpoint_dir=checkpoint/multi_texture/CIS/PC/finetune \ 5 | --dataset=multi_texture --max_num=2 --num_branch=3 --root_dir=data/multi_texture_data/ \ 6 | --mode=train_PC --model=segnet --PC=True \ 7 | \ 8 | --fullmodel_ckpt='path of pretrained ckpt to finetune here' 9 | 'e.g.
checkpoint/multi_texture/pc/model-0' \ 10 | --batch_size=4 --inp_lr=1e-4 --gen_lr=3e-5 \ 11 | --bg_dim=5 --tex_dim=5 --mask_dim=10 \ 12 | --epsilon=50 --ita=1e-3 --iters_inp=1 --iters_gen_vae=3 \ 13 | --ckpt_steps=10000 --summaries_steps=100 \ 14 | -------------------------------------------------------------------------------- /script/multi_texture/perceptual_consistency/test_segmentation.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python test_segmentation.py \ 2 | --data_path data/multi_texture_data \ 3 | --ckpt_path 'path to checkpoint including segmentation network' \ 4 | --batch_size 8 --num_branch 3 --dataset_name multi_texture --PC=True --model=segnet 5 | -------------------------------------------------------------------------------- /script/multi_texture/perceptual_consistency/train_CIS.sh: -------------------------------------------------------------------------------- 1 | ABSPATH=$(readlink -f $0) 2 | 3 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \ 4 | --checkpoint_dir=checkpoint/multi_texture/CIS/PC/CIS \ 5 | --dataset=multi_texture --max_num=2 --num_branch=3 --root_dir=data/multi_texture_data/ \ 6 | --mode=train_CIS --model=segnet --PC=True \ 7 | \ 8 | \ 9 | --resume_inpainter=True --resume_fullmodel=False \ 10 | --fullmodel_ckpt= \ 11 | --inpainter_ckpt='path of pretrained inpainter ckpt to resume here' 12 | 'e.g. checkpoint/multi_texture/pretrain_inpainter/inpainter-50000' \ 13 | \ 14 | --batch_size=8 --inp_lr=1e-4 --gen_lr=3e-5 \ 15 | --epsilon=50 --iters_inp=1 --iters_gen=3 \ 16 | --ckpt_steps=5000 \ -------------------------------------------------------------------------------- /script/multi_texture/perceptual_consistency/train_VAE.sh: -------------------------------------------------------------------------------- 1 | ABSPATH=$(readlink -f $0) 2 | 3 | 4 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \ 5 | --checkpoint_dir=checkpoint/multi_texture/CIS/PC/VAE \ 6 | --dataset=multi_texture --max_num=2 --num_branch=3 --root_dir=data/multi_texture_data/ --PC=True \ 7 | --mode=train_VAE --model=segnet \ 8 | \ 9 | \ 10 | --resume_CIS=True --resume_fullmodel=False \ 11 | --fullmodel_ckpt= \ 12 | --CIS_ckpt='path of CIS ckpt to resume here' \ 13 | \ 14 | --batch_size=16 --VAE_lr=1e-4 \ 15 | --tex_dim=5 --tex_beta=4 \ 16 | --bg_dim=5 --bg_beta=6 \ 17 | --mask_dim=10 --mask_gamma=500 --mask_capacity_inc=2e-5 \ 18 | --ckpt_steps=20000 --summaries_steps=1000 \ -------------------------------------------------------------------------------- /script/multi_texture/pretrain_inpainter.sh: -------------------------------------------------------------------------------- 1 | ABSPATH=$(readlink -f $0) 2 | 3 | 4 | CUDA_VISIBLE_DEVICES=0 python main.py \ 5 | --checkpoint_dir=checkpoint/multi_texture/pretrain_inpainter \ 6 | --sh_path=$ABSPATH \ 7 | \ 8 | --dataset=multi_texture --max_num=4 --root_dir=data/multi_texture_data \ 9 | --mode=pretrain_inpainter \ 10 | \ 11 | --batch_size=16 --inp_lr=1e-4 \ 12 | \ 13 | --summaries_secs=60 --ckpt_secs=10000 \ -------------------------------------------------------------------------------- /script/multi_texture/test_segmentation.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python test_segmentation.py \ 2 | --data_path ../data/multi_texture_data \ 3 | --ckpt_path 'path to checkpoint including segmentation network' \ 4 | --batch_size 8 --num_branch 5 --dataset_name multi_texture 
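The `test_segmentation.sh` scripts invoke `test_segmentation.py` (shown later in this dump), which scores predictions with the permutation-invariant IoU from `model/utils/generic_utils.py` (`Permute_IoU`): each non-empty ground-truth mask is matched to its best-overlapping predicted channel. A compact NumPy sketch of that matching, assuming binarizable masks:

```python
import numpy as np

def permute_iou(gt, pred, thresh=0.04):
    """gt: (H,W,1,N) ground-truth masks; pred: (H,W,1,C) predicted masks."""
    total, nonempty = 0.0, 0
    for i in range(gt.shape[-1]):
        src = gt[..., i:i + 1] > thresh          # (H,W,1,1)
        if src.sum() <= 4:                       # skip (near-)empty GT channels
            continue
        trg = pred > thresh                      # (H,W,1,C)
        inter = (src & trg).sum(axis=(0, 1, 2))  # (C,)
        union = (src | trg).sum(axis=(0, 1, 2))  # (C,)
        total += (inter / (union + 1e-8)).max()  # best-matching channel
        nonempty += 1
    return total / nonempty
```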
-------------------------------------------------------------------------------- /script/multi_texture/train_CIS.sh: -------------------------------------------------------------------------------- 1 | ABSPATH=$(readlink -f $0) 2 | 3 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \ 4 | --checkpoint_dir=checkpoint/multi_texture/CIS \ 5 | --dataset=multi_texture --max_num=4 --num_branch=5 --root_dir=data/multi_texture_data/ \ 6 | --mode=train_CIS --model=resnet_v2_50 \ 7 | \ 8 | \ 9 | --resume_inpainter=True --resume_fullmodel=False \ 10 | --fullmodel_ckpt=/ \ 11 | --inpainter_ckpt='path of pretrained inpainter ckpt to resume here' 12 | 'e.g. checkpoint/multi_texture/pretrain_inpainter/inpainter-50000' \ 13 | \ 14 | --batch_size=8 --inp_lr=1e-4 --gen_lr=3e-5 \ 15 | --epsilon=30 --iters_inp=1 --iters_gen=3 \ 16 | --ckpt_steps=4000 \ -------------------------------------------------------------------------------- /script/multi_texture/train_VAE.sh: -------------------------------------------------------------------------------- 1 | ABSPATH=$(readlink -f $0) 2 | 3 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \ 4 | --checkpoint_dir=checkpoint/multi_texture/VAE \ 5 | --dataset=multi_texture --max_num=4 --num_branch=5 --root_dir=data/multi_texture_data/ \ 6 | --mode=train_VAE --model=resnet_v2_50 \ 7 | \ 8 | \ 9 | --resume_CIS=True --resume_fullmodel=False \ 10 | --fullmodel_ckpt= \ 11 | --CIS_ckpt='path of CIS ckpt to resume here, e.g. checkpoint/multi_texture/CIS/model-100000' \ 12 | \ 13 | --batch_size=16 --VAE_lr=1e-4 \ 14 | --tex_dim=5 --tex_beta=4 \ 15 | --bg_dim=5 --bg_beta=6 \ 16 | --mask_dim=10 --mask_gamma=500 --mask_capacity_inc=2e-5 \ 17 | --ckpt_steps=20000 --summaries_steps=1000 \ -------------------------------------------------------------------------------- /script/objects_room/disentangle.sh: -------------------------------------------------------------------------------- 1 | ABSPATH=$(readlink -f $0) 2 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \ 3 | --checkpoint_dir=outputs/objects_room/tex_split \ 4 | --dataset=objects_room --num_branch=6 --root_dir=data/objects_room_data/objects_room_train.tfrecords \ 5 | --mode=eval_VAE --model=resnet_v2_50 \ 6 | \ 7 | \ 8 | --fullmodel_ckpt='path to model checkpoint' \ 9 | --tex_dim=5 --bg_dim=10 --mask_dim=10 \ 10 | \ 11 | --input_img=sample_imgs/objects_room/01.png \ 12 | --traverse_type=tex --top_kdim=1 --traverse_branch=0,2 --traverse_start=-0.5 --traverse_end=1.5 \ 13 | --batch_size=1 \ 14 | 15 | #checkpoint_dir: folder containing outputs 16 | 17 | #input_img: target image (only a single image is supported) 18 | 19 | #traverse_type: 20 | #mask: traverse shape latent space 21 | #tex: traverse texture/color latent space 22 | #bg: traverse background latent space 23 | 24 | #top_kdim: 25 | #choose the k dimensions with the largest KL divergence to traverse 26 | #these dimensions should encode the k most significant factors of the object.
27 | #results of traversing the dimension with the k-th largest KL divergence are output as branch{i}_var{k}.gif 28 | 29 | 30 | #traverse_branch: which branches to traverse 31 | #(only effective when traverse_type in ['tex','mask']) 32 | #'all': generate results of traversing all branches except the background branch 33 | #'0,1,2': manually choose branches to traverse -------------------------------------------------------------------------------- /script/objects_room/pretrain_inpainter.sh: -------------------------------------------------------------------------------- 1 | ABSPATH=$(readlink -f $0) 2 | 3 | CUDA_VISIBLE_DEVICES=0 python main.py \ 4 | --checkpoint_dir=checkpoint/objects_room/pretrain_inpainter \ 5 | --sh_path=$ABSPATH \ 6 | \ 7 | --dataset=objects_room --root_dir=data/objects_room_data/objects_room_train.tfrecords \ 8 | --mode=pretrain_inpainter \ 9 | \ 10 | --batch_size=16 --inp_lr=3e-5 \ 11 | \ 12 | --summaries_secs=60 --ckpt_secs=10000 \ 13 | 14 | -------------------------------------------------------------------------------- /script/objects_room/test_segmentation.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python test_segmentation.py \ 2 | --data_path ../data/objects_room_data/objects_room_train.tfrecords \ 3 | --ckpt_path 'path to checkpoint including segmentation network' \ 4 | --batch_size 8 --num_branch 6 --dataset_name objects_room 5 | -------------------------------------------------------------------------------- /script/objects_room/train_CIS.sh: -------------------------------------------------------------------------------- 1 | ABSPATH=$(readlink -f $0) 2 | 3 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \ 4 | --checkpoint_dir=checkpoint/objects_room/CIS \ 5 | --dataset=objects_room --num_branch=6 --root_dir=data/objects_room_data/objects_room_train.tfrecords \ 6 | --mode=train_CIS --model=resnet_v2_50 \ 7 | \ 8 | \ 9 | --resume_inpainter=True --resume_fullmodel=False \ 10 | --fullmodel_ckpt=/ \ 11 | --inpainter_ckpt='path of pretrained inpainter ckpt to resume here' 12 | 'e.g. checkpoint/objects_room/pretrain_inpainter/inpainter-50000' \ 13 | \ 14 | --batch_size=8 --inp_lr=1e-4 --gen_lr=3e-5 \ 15 | --epsilon=30 --iters_inp=1 --iters_gen=3 \ 16 | --ckpt_steps=4000 \
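The `--iters_inp`/`--iters_gen` flags in `train_CIS.sh` above interleave inpainter and segmenter updates; `trainer/train_CIS.py` (later in this dump) derives `sum_iters = FLAGS.iters_gen + FLAGS.iters_inp` and cycles through it. One plausible reading of that schedule as a sketch — the exact ordering within a cycle is an assumption:

```python
def pick_phase(step, iters_inp=1, iters_gen=3):
    """Assumed convention: within each cycle of iters_inp + iters_gen steps,
    the first iters_inp steps train the inpainter, the rest the segmenter."""
    return 'inpainter' if step % (iters_inp + iters_gen) < iters_inp else 'generator'

# with the defaults above, steps 0, 4, 8, ... update the inpainter
assert [pick_phase(s) for s in range(5)] == \
    ['inpainter', 'generator', 'generator', 'generator', 'inpainter']
```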
-------------------------------------------------------------------------------- /script/objects_room/train_VAE.sh: -------------------------------------------------------------------------------- 1 | ABSPATH=$(readlink -f $0) 2 | 3 | 4 | CUDA_VISIBLE_DEVICES=0 python main.py --sh_path=$ABSPATH \ 5 | --checkpoint_dir=checkpoint/objects_room/VAE \ 6 | --dataset=objects_room --num_branch=6 --root_dir=data/objects_room_data/objects_room_train.tfrecords \ 7 | --mode=train_VAE --model=resnet_v2_50 \ 8 | \ 9 | \ 10 | --resume_CIS=True --resume_fullmodel=False \ 11 | --fullmodel_ckpt=/ \ 12 | --CIS_ckpt='path of CIS ckpt to resume here, e.g. checkpoint/objects_room/CIS/model-100000' \ 13 | \ 14 | --batch_size=8 --VAE_lr=1e-4 \ 15 | --tex_dim=5 --tex_beta=6 \ 16 | --bg_dim=5 --bg_beta=6 \ 17 | --mask_dim=10 --mask_gamma=500 --mask_capacity_inc=2e-4 \ 18 | --ckpt_steps=20000 --summaries_steps=1000 \ -------------------------------------------------------------------------------- /tb.sh: -------------------------------------------------------------------------------- 1 | tensorboard \ 2 | --logdir='path to logdir' \ 3 | --port=6106 4 | -------------------------------------------------------------------------------- /test_segmentation.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import sys 4 | sys.path.append("..") 5 | import pprint 6 | from model.utils.generic_utils import myprint, myinput, Permute_IoU, reorder_mask 7 | from model.nets import Generator_forward 8 | import imageio 9 | import numpy as np 10 | import time 11 | from data import multi_texture_utils, flying_animals_utils, objects_room_utils, multi_dsprites_utils 12 | import argparse 13 | 14 | parser = argparse.ArgumentParser(description='test the segmentation mean IoU') 15 | parser.add_argument('--data_path', type=str, help='dir of the test set') 16 | parser.add_argument('--batch_size', default=8, type=int, help='batchsize', required=False) 17 | parser.add_argument('--dataset_name', default='flying_animals', type=str, help='flying_animals / multi_texture / multi_dsprites / objects_room') 18 | parser.add_argument('--ckpt_path', default='./', type=str) 19 | parser.add_argument('--num_branch', default=6, type=int, help='#branch should match the checkpoint and network') 20 | parser.add_argument('--PC', default=False, type=lambda s: str(s).lower() in ('true','1'), help='whether to test perceptual consistency and output the identity switching rate') # type=bool would parse any non-empty string (even "False") as True 21 | parser.add_argument('--model',default='resnet_v2_50',type=str, help='segmentation network model, resnet_v2_50 or segnet') 22 | args = parser.parse_args() 23 | 24 | #usage: python test_segmentation.py --data_path ... --ckpt_path ... (see script/*/test_segmentation.sh for full examples) 25 | 26 | 27 | data_path = args.data_path 28 | batch_size = args.batch_size 29 | dataset_name = args.dataset_name 30 | num_branch = args.num_branch 31 | PC, model = args.PC, args.model 32 | 33 | if dataset_name == 'flying_animals': 34 | img_height, img_width = 192, 256 35 | dataset = flying_animals_utils.dataset(data_path=data_path,batch_size=batch_size, max_num=5, phase='test') 36 | elif dataset_name == 'multi_texture': 37 | img_height, img_width = 64, 64 38 | dataset = multi_texture_utils.dataset(data_path=data_path, batch_size=batch_size, max_num=2 if PC else 4, phase='test',PC=PC) 39 | elif dataset_name == 'multi_dsprites': 40 | img_height, img_width = 64, 64 41 | dataset = multi_dsprites_utils.dataset(tfrecords_path=data_path,batch_size=batch_size, phase='test') 42 | elif dataset_name == 'objects_room': 43 | img_height, img_width = 64, 64 44 | dataset = objects_room_utils.dataset(tfrecords_path=data_path,batch_size=batch_size, phase='test') 45 | 46 | ckpt_path = args.ckpt_path 47 | 48 | 49 | iterator = dataset.make_one_shot_iterator() 50 | test_batch = iterator.get_next() 51 | img, tf_GT_masks = test_batch['img'], test_batch['masks'] 52 | img.set_shape([batch_size, img_height, img_width, 3]) 53 | tf_GT_masks.set_shape([batch_size, img_height, img_width, 1, None]) 54 | 55 | with tf.name_scope("Generator") as scope: 56 | tf_generated_masks, null = Generator_forward(img, dataset_name, 57 | num_branch, model=model, training=False, reuse=None, scope=scope) 58 |
tf_generated_masks = reorder_mask(tf_generated_masks) #place the background at the last channel 59 | 60 | restore_vars = tf.global_variables('Generator') 61 | saver = tf.train.Saver(restore_vars) 62 | 63 | with tf.Session() as sess: 64 | saver.restore(sess, ckpt_path) 65 | scores = [] 66 | fetches = {'GT_masks': tf_GT_masks, 'generated_masks': tf_generated_masks, 'img':img} 67 | 68 | if PC: 69 | #test perceptual consistency 70 | num = 9*9*9*9-1 71 | assert num%batch_size==0 72 | niter = num//batch_size 73 | score, arg_maxIoUs = 0, [] 74 | for u in range(niter): 75 | results = sess.run(fetches) 76 | for j in range(batch_size): 77 | s, arg_maxIoU = Permute_IoU(label=results['GT_masks'][j], pred=results['generated_masks'][j]) 78 | score += s 79 | arg_maxIoUs.append(arg_maxIoU) 80 | score = score/num 81 | arg_maxIoUs = np.stack(arg_maxIoUs, axis=0) 82 | count = np.sum(arg_maxIoUs, axis=0) 83 | switching_rate = np.min(count)/num 84 | print("IoU: {} identity switching rate: {} ".format(score, switching_rate)) 85 | 86 | 87 | else: 88 | for i in range(10): #10 subsets 89 | #200 images in each subset 90 | assert 200%batch_size==0 91 | niter = 200//batch_size 92 | score = [] 93 | for u in range(niter): 94 | results = sess.run(fetches) 95 | for j in range(batch_size): 96 | s, null = Permute_IoU(label=results['GT_masks'][j], pred=results['generated_masks'][j]) 97 | score.append(s) 98 | scores.append(score) #10*200 99 | print("subset {}: mean {} variance{}\n".format(i+1, np.mean(scores[i]), np.var(scores[i]))) 100 | mean_IoU = np.mean(scores, axis=-1) #10, 101 | print("mean of mean_IoU: {} std of mean_IoU: {}\n".format(np.mean(mean_IoU), np.std(mean_IoU, ddof=1))) 102 | 103 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenYutongTHU/Learning-to-manipulate-individual-objects-in-an-image-Implementation/db75a5505f7fe2c83c0ded08f425ef11759544bd/trainer/__init__.py -------------------------------------------------------------------------------- /trainer/train_CIS.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import gflags 4 | #https://github.com/google/python-gflags 5 | import sys 6 | sys.path.append("..") 7 | import pprint 8 | from keras.utils.generic_utils import Progbar 9 | import model.Summary as Summary 10 | from model.utils.generic_utils import myprint, myinput, Permute_IoU 11 | from model.train_graph import Train_Graph 12 | import imageio 13 | import numpy as np 14 | import time 15 | import re 16 | 17 | def train(FLAGS): 18 | graph = Train_Graph(FLAGS) 19 | graph.build() 20 | 21 | summary_op, generator_summary_op, branch_summary_op, eval_summary_op = Summary.collect_CIS_summary(graph, FLAGS) 22 | with tf.name_scope("parameter_count"): 23 | total_parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) \ 24 | for v in tf.trainable_variables()]) 25 | 26 | save_vars = tf.global_variables('Inpainter')+tf.global_variables('Generator')+ \ 27 | tf.global_variables('train_op') #including global step 28 | if FLAGS.resume_inpainter: 29 | assert os.path.isfile(FLAGS.inpainter_ckpt+'.index') 30 | inpainter_saver = tf.train.Saver(tf.trainable_variables('Inpainter'))#only restore the trainable variables 31 | 32 | if FLAGS.resume_resnet: 33 | assert os.path.isfile(FLAGS.resnet_ckpt) 34 | 
resnet_reader=tf.compat.v1.train.NewCheckpointReader(FLAGS.resnet_ckpt) 35 | resnet_map = resnet_reader.get_variable_to_shape_map() 36 | resnet_dict = dict() 37 | for v in tf.trainable_variables('Generator//resnet_v2'): 38 | if 'resnet_v2_50/'+v.op.name[21:] in resnet_map.keys(): 39 | resnet_dict['resnet_v2_50/'+v.op.name[21:]] = v 40 | resnet_var_name = [v.name for v in tf.trainable_variables('Generator//resnet_v2') \ 41 | if 'resnet_v2_50/'+v.op.name[21:] in resnet_map.keys()] 42 | resnet_saver = tf.train.Saver(resnet_dict) 43 | 44 | saver = tf.train.Saver(save_vars, max_to_keep=100) 45 | branch_writers = [tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir, "branch"+str(m))) \ 46 | for m in range(FLAGS.num_branch)] #save generator loss for each branch 47 | sv = tf.train.Supervisor(logdir=os.path.join(FLAGS.checkpoint_dir, "CIS_Sum"), 48 | saver=None, save_summaries_secs=0) 49 | 50 | with sv.managed_session() as sess: 51 | myprint ("Number of total params: {0} \n".format( \ 52 | sess.run(total_parameter_count))) 53 | if FLAGS.resume_fullmodel: 54 | assert os.path.isfile(FLAGS.fullmodel_ckpt+'.index') 55 | saver.restore(sess, FLAGS.fullmodel_ckpt) 56 | myprint ("Resumed training from model {}".format(FLAGS.fullmodel_ckpt)) 57 | myprint ("Start from step {}".format(sess.run(graph.global_step))) 58 | myprint ("Save checkpoint in {}".format(FLAGS.checkpoint_dir)) 59 | if not os.path.dirname(FLAGS.fullmodel_ckpt) == FLAGS.checkpoint_dir: 60 | print ("\033[0;30;41m"+"Warning: checkpoint dir and fullmodel ckpt do not match"+"\033[0m") 61 | myprint ("Please make sure that new checkpoint will be saved in the same dir with the resumed model") 62 | else: 63 | if FLAGS.resume_inpainter: 64 | assert os.path.isfile(FLAGS.inpainter_ckpt+'.index') 65 | inpainter_saver.restore(sess, FLAGS.inpainter_ckpt) 66 | myprint ("Load pretrained inpainter {}".format(FLAGS.inpainter_ckpt)) 67 | 68 | if FLAGS.resume_resnet: 69 | resnet_saver.restore(sess, FLAGS.resnet_ckpt) 70 | myprint ("Load pretrained resnet {}".format(FLAGS.resnet_ckpt)) 71 | if not FLAGS.resume_resnet and not FLAGS.resume_inpainter: 72 | myprint ("Train from scratch") 73 | myinput('Press enter to continue') 74 | 75 | start_time = time.time() 76 | step = sess.run(graph.global_step) 77 | progbar = Progbar(target=FLAGS.ckpt_steps) #100k 78 | 79 | sum_iters = FLAGS.iters_gen + FLAGS.iters_inp 80 | 81 | while (time.time()-start_time) should have an VAE step 78 | 79 | if vae_step % FLAGS.summaries_steps == 0: 80 | fetches['tex_kl'], fetches['mask_kl'], fetches['bg_kl'] = graph.loss['tex_kl'], graph.loss['mask_kl'], graph.loss['bg_kl'] 81 | fetches['Fusion'] = graph.loss['Fusion'] 82 | fetches['summary'] = summary_op 83 | 84 | 85 | results = sess.run(fetches, feed_dict={graph.is_training: True, graph.mask_capacity: mask_capacity}) 86 | progbar.update(vae_step%FLAGS.ckpt_steps) 87 | 88 | if vae_step % FLAGS.summaries_steps == 0 : 89 | print (" Step:%3dk time:%4.4fmin " \ 90 | %(vae_step/1000, (time.time()-start_time)/60)) 91 | sv.summary_writer.add_summary(results['summary'], vae_step) 92 | 93 | for d in range(FLAGS.tex_dim): 94 | tex_summary = sess.run(tex_latent_summary_op, feed_dict={graph.loss['tex_kl_var']: results['tex_kl'][d]}) 95 | tex_latent_writers[d].add_summary(tex_summary, vae_step) 96 | 97 | for d in range(FLAGS.bg_dim): 98 | bg_summary = sess.run(bg_latent_summary_op, feed_dict={graph.loss['bg_kl_var']: results['bg_kl'][d]}) 99 | bg_latent_writers[d].add_summary(bg_summary, vae_step) 100 | 101 | for d in range(FLAGS.mask_dim): 102 | 
mask_summary = sess.run(mask_latent_summary_op, feed_dict={graph.loss['mask_kl_var']: results['mask_kl'][d]}) 103 | mask_latent_writers[d].add_summary(mask_summary, vae_step) 104 | 105 | 106 | if vae_step % FLAGS.ckpt_steps == 0: 107 | saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'model'), global_step=vae_step) 108 | progbar = Progbar(target=FLAGS.ckpt_steps) 109 | 110 | vae_step = results['vae_step'] 111 | 112 | myprint("Training completed") -------------------------------------------------------------------------------- /trainer/train_end2end.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import gflags 4 | #https://github.com/google/python-gflags 5 | import sys 6 | sys.path.append("..") 7 | import pprint 8 | from keras.utils.generic_utils import Progbar 9 | import model.Summary as Summary 10 | from model.utils.generic_utils import myprint, myinput, Permute_IoU 11 | from model.train_graph import Train_Graph 12 | import imageio 13 | import numpy as np 14 | import time 15 | import re 16 | 17 | def train(FLAGS): 18 | # learner 19 | graph = Train_Graph(FLAGS) 20 | graph.build() 21 | 22 | summary_op, tex_latent_summary_op, bg_latent_summary_op, eval_summary_op = Summary.collect_end2end_summary(graph, FLAGS) 23 | # train 24 | #define model saver 25 | with tf.name_scope("parameter_count"): 26 | total_parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) \ 27 | for v in tf.trainable_variables()]) 28 | 29 | save_vars = tf.global_variables() 30 | # tf.global_variables('Inpainter')+tf.global_variables('Generator')+ \ 31 | # tf.global_variables('VAE')+tf.global_variables('Fusion') \ 32 | # +tf.global_variables('train_op') #including global step 33 | 34 | if FLAGS.resume_CIS: 35 | CIS_vars = tf.global_variables('Inpainter')+tf.global_variables('Generator') 36 | CIS_saver = tf.train.Saver(CIS_vars, max_to_keep=100) 37 | 38 | mask_saver = tf.train.Saver(tf.global_variables('VAE//separate/maskVAE/'), max_to_keep=100) 39 | tex_saver = tf.train.Saver(tf.global_variables('VAE//separate/texVAE/'), max_to_keep=100) 40 | 41 | saver = tf.train.Saver(save_vars, max_to_keep=100) 42 | branch_writers = [tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir,'branch'+str(m))) for m in range(FLAGS.num_branch)] 43 | tex_latent_writers = [tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir, "tex_latent"+str(m))) for m in range(FLAGS.tex_dim)] 44 | bg_latent_writers = [tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir, "bg_latent"+str(m))) for m in range(FLAGS.bg_dim)] 45 | #mask_latent_writers = [tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir, "mask_latent"+str(m))) for m in range(FLAGS.mask_dim)] 46 | 47 | 48 | 49 | sv = tf.train.Supervisor(logdir=os.path.join(FLAGS.checkpoint_dir, "end2end_Sum"), 50 | saver=None, save_summaries_secs=0) #not saved automatically for flexibility 51 | 52 | with sv.managed_session() as sess: 53 | myprint ("Number of total params: {0} \n".format( \ 54 | sess.run(total_parameter_count))) 55 | if FLAGS.resume_fullmodel: 56 | assert os.path.isfile(FLAGS.fullmodel_ckpt+'.index') 57 | saver.restore(sess, FLAGS.fullmodel_ckpt) 58 | myprint ("Resumed training from model {}".format(FLAGS.fullmodel_ckpt)) 59 | myprint ("Start from step {} vae_step{}".format(sess.run(graph.global_step), sess.run(graph.vae_global_step))) 60 | myprint ("Save checkpoint in {}".format(FLAGS.checkpoint_dir)) 61 | if not os.path.dirname(FLAGS.fullmodel_ckpt) == FLAGS.checkpoint_dir: 62 | print 
("\033[0;30;41m"+"Warning: checkpoint dir and fullmodel ckpt do not match"+"\033[0m") 63 | #myprint ("Please make sure that the checkpoint will be saved in the same dir with the resumed model") 64 | else: 65 | if os.path.isfile(FLAGS.mask_ckpt+'.index'): 66 | mask_saver.restore(sess, FLAGS.mask_ckpt) 67 | myprint ("Load pretrained maskVAE {}".format(FLAGS.mask_ckpt)) 68 | if os.path.isfile(FLAGS.tex_ckpt+'.index'): 69 | tex_saver.restore(sess, FLAGS.tex_ckpt) 70 | myprint ("Load pretrained texVAE {}".format(FLAGS.tex_ckpt)) 71 | if FLAGS.resume_CIS: 72 | assert os.path.isfile(FLAGS.CIS_ckpt+'.index') 73 | CIS_saver.restore(sess, FLAGS.CIS_ckpt) 74 | myprint ("Load pretrained inpainter and generator {}".format(FLAGS.CIS_ckpt)) 75 | else: 76 | myprint ("Train from scratch") 77 | myinput('Press enter to continue') 78 | 79 | start_time = time.time() 80 | step = sess.run(graph.global_step) 81 | vae_step = sess.run(graph.vae_global_step) 82 | progbar = Progbar(target=FLAGS.ckpt_steps) #100k 83 | 84 | sum_iters = FLAGS.iters_gen_vae + FLAGS.iters_inp 85 | 86 | while (time.time()-start_time) should have an VAE step 98 | fetches['vae_global_step'], fetches['vae_global_step_inc'] = graph.vae_global_step, graph.incr_vae_global_step 99 | 100 | if step % FLAGS.summaries_steps == 0: 101 | fetches["Inpainter_Loss"],fetches["Generator_Loss"] = graph.loss['Inpainter'], graph.loss['Generator'] 102 | fetches["VAE//texVAE"], fetches["VAE//texVAE_BG"], fetches['VAE//fusion'] = graph.loss['VAE//separate/texVAE/'], graph.loss['VAE//separate/texVAE_BG/'], graph.loss['VAE//fusion'] 103 | fetches['tex_kl'], fetches['bg_kl'] = graph.loss['tex_kl'], graph.loss['bg_kl'] 104 | fetches['summary'] = summary_op 105 | 106 | if step % FLAGS.ckpt_steps == 0: 107 | fetches['generated_masks'] = graph.generated_masks 108 | fetches['GT_masks'] = graph.GT_masks 109 | 110 | results = sess.run(fetches, feed_dict={graph.is_training: True, graph.mask_capacity: mask_capacity}) 111 | progbar.update(step%FLAGS.ckpt_steps) 112 | 113 | if step % FLAGS.summaries_steps == 0 : 114 | print (" Step:%3dk time:%4.4fmin VAELoss%4.2f" \ 115 | %(step/1000, (time.time()-start_time)/60, results["VAE//texVAE"]+results['VAE//fusion']+results['VAE//texVAE_BG'])) 116 | sv.summary_writer.add_summary(results['summary'], step) 117 | 118 | for d in range(FLAGS.tex_dim): 119 | tex_summary = sess.run(tex_latent_summary_op, feed_dict={graph.loss['tex_kl_var']: results['tex_kl'][d]}) 120 | tex_latent_writers[d].add_summary(tex_summary, step) 121 | 122 | for d in range(FLAGS.bg_dim): 123 | bg_summary = sess.run(bg_latent_summary_op, feed_dict={graph.loss['bg_kl_var']: results['bg_kl'][d]}) 124 | bg_latent_writers[d].add_summary(bg_summary, step) 125 | 126 | # for d in range(FLAGS.mask_dim): 127 | # mask_summary = sess.run(mask_latent_summary_op, feed_dict={graph.loss['mask_kl_var']: results['mask_kl'][d]}) 128 | # mask_latent_writers[d].add_summary(mask_summary, step) 129 | 130 | 131 | if step % FLAGS.ckpt_steps == 0: 132 | saver.save(sess, os.path.join(FLAGS.checkpoint_dir, 'model'), global_step=step) 133 | progbar = Progbar(target=FLAGS.ckpt_steps) 134 | 135 | #evaluation 136 | sess.run(graph.val_iterator.initializer) 137 | fetches = {'GT_masks':graph.GT_masks, 'generated_masks':graph.generated_masks} 138 | 139 | if FLAGS.dataset in ['multi_texture', 'flying_animals']: 140 | #note that for multi_texture bg_num is just a fake number it represents number of samples for each type of image 141 | score = [[]]*FLAGS.max_num 142 | for bg in range(FLAGS.bg_num): 143 | 
135 | #evaluation 136 | sess.run(graph.val_iterator.initializer) 137 | fetches = {'GT_masks':graph.GT_masks, 'generated_masks':graph.generated_masks} 138 | 139 | if FLAGS.dataset in ['multi_texture', 'flying_animals']: 140 | #note that for multi_texture, bg_num is not a real background count; it is the number of validation samples for each type of image 141 | score = [[] for _ in range(FLAGS.max_num)] #independent lists; [[]]*n would alias a single list 142 | for bg in range(FLAGS.bg_num): 143 | results_val=sess.run(fetches, feed_dict={graph.is_training: False}) 144 | for k in range(FLAGS.max_num): 145 | #score[k].append(Permute_IoU(results_val['GT_masks'][k], results_val['generated_masks'][k])) 146 | score[k] = score[k] + [Permute_IoU(label=results_val['GT_masks'][k], pred=results_val['generated_masks'][k])[0]] #Permute_IoU returns (mean IoU, matching); keep the IoU only 147 | for k in range(FLAGS.max_num): 148 | eval_summary = sess.run(eval_summary_op, feed_dict={graph.loss['EvalIoU_var']: np.mean(score[k])}) 149 | branch_writers[k+1].add_summary(eval_summary, step) 150 | else: 151 | num_sample = FLAGS.skipnum 152 | niter = num_sample//FLAGS.batch_size 153 | assert num_sample%FLAGS.batch_size==0 154 | score = 0 155 | for it in range(niter): 156 | results_val = sess.run(fetches, feed_dict={graph.is_training:False}) 157 | for k in range(FLAGS.batch_size): 158 | score += Permute_IoU(label=results_val['GT_masks'][k], pred=results_val['generated_masks'][k])[0] #keep the scalar IoU 159 | score = score/num_sample 160 | eval_summary = sess.run(eval_summary_op, feed_dict={graph.loss['EvalIoU_var']: score}) 161 | sv.summary_writer.add_summary(eval_summary, step) 162 | 163 | step = results['step'] 164 | vae_step = results['vae_global_step'] 165 | 166 | myprint("Training completed") -------------------------------------------------------------------------------- /trainer/train_globalVAE.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import gflags 4 | #https://github.com/google/python-gflags 5 | import sys 6 | sys.path.append("..") 7 | import pprint 8 | from keras.utils.generic_utils import Progbar 9 | import model.Summary as Summary 10 | from model.utils.generic_utils import myprint, myinput, Permute_IoU 11 | from model.globalVAE_graph import Train_Graph 12 | import imageio 13 | import numpy as np 14 | import time 15 | import re 16 | 17 | def train(FLAGS): 18 | # learner 19 | graph = Train_Graph(FLAGS) 20 | graph.build() 21 | 22 | summary_op, latent_summary_op = Summary.collect_globalVAE_summary(graph, FLAGS) 23 | # train 24 | #define model saver 25 | with tf.name_scope("parameter_count"): 26 | total_parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) \ 27 | for v in tf.trainable_variables()]) 28 | 29 | save_vars = tf.global_variables() 30 | saver = tf.train.Saver(save_vars, max_to_keep=100) 31 | 32 | latent_writers = [tf.summary.FileWriter(os.path.join(FLAGS.checkpoint_dir, "latent"+str(m))) \ 33 | for m in range(FLAGS.tex_dim)] 34 | sv = tf.train.Supervisor(logdir=os.path.join(FLAGS.checkpoint_dir, "globalVAE_Sum"), 35 | saver=None, save_summaries_secs=0) #not saved automatically for flexibility 36 | 37 | with sv.managed_session() as sess: 38 | myprint ("Number of total params: {0} \n".format( \ 39 | sess.run(total_parameter_count))) 40 | if FLAGS.resume_fullmodel: 41 | assert os.path.isfile(FLAGS.fullmodel_ckpt+'.index') 42 | saver.restore(sess, FLAGS.fullmodel_ckpt) 43 | myprint ("Resumed training from model {}".format(FLAGS.fullmodel_ckpt)) 44 | myprint ("Start from step {}".format(sess.run(graph.global_step))) 45 | myprint ("Save checkpoint in {}".format(FLAGS.checkpoint_dir)) 46 | if not os.path.dirname(FLAGS.fullmodel_ckpt) == FLAGS.checkpoint_dir: 47 | print ("\033[0;30;41m"+"Warning: checkpoint dir and fullmodel ckpt do not match"+"\033[0m") 48 | #myprint ("Please make sure that the checkpoint will be saved in the same dir with the resumed model") 49 | else: 50 | myprint ("Train from scratch") 51 | myinput('Press enter to continue') 52 | 53 | start_time = time.time() 54 |
step = sess.run(graph.global_step) 55 | progbar = Progbar(target=FLAGS.ckpt_steps) #100k 56 | 57 | while (time.time()-start_time)