├── .gitignore
├── LICENSE
├── README.md
├── _config.yml
├── code
│   ├── clevr
│   │   ├── cfg
│   │   │   ├── clevr_eval.yml
│   │   │   └── clevr_train.yml
│   │   ├── main.py
│   │   ├── miscc
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   ├── datasets.py
│   │   │   └── utils.py
│   │   ├── model.py
│   │   └── trainer.py
│   ├── coco
│   │   ├── attngan
│   │   │   ├── DAMSMencoders
│   │   │   │   └── README.md
│   │   │   ├── GlobalAttention.py
│   │   │   ├── cfg
│   │   │   │   ├── coco_eval.yml
│   │   │   │   └── coco_train.yml
│   │   │   ├── datasets.py
│   │   │   ├── losses.py
│   │   │   ├── main.py
│   │   │   ├── miscc
│   │   │   │   ├── __init__.py
│   │   │   │   ├── config.py
│   │   │   │   ├── losses.py
│   │   │   │   └── utils.py
│   │   │   ├── model.py
│   │   │   ├── trainer.py
│   │   │   └── utils.py
│   │   └── stackgan
│   │       ├── cfg
│   │       │   ├── coco_s1_eval.yml
│   │       │   ├── coco_s1_train.yml
│   │       │   ├── coco_s2_eval.yml
│   │       │   └── coco_s2_train.yml
│   │       ├── main.py
│   │       ├── miscc
│   │       │   ├── __init__.py
│   │       │   ├── config.py
│   │       │   ├── datasets.py
│   │       │   └── utils.py
│   │       ├── model.py
│   │       └── trainer.py
│   └── multi-mnist
│       ├── cfg
│       │   ├── mnist_eval.yml
│       │   └── mnist_train.yml
│       ├── main.py
│       ├── miscc
│       │   ├── __init__.py
│       │   ├── config.py
│       │   ├── datasets.py
│       │   └── utils.py
│       ├── model.py
│       └── trainer.py
├── data
│   └── README.md
├── examples
│   ├── clevr_cogent.png
│   ├── clevr_example.png
│   ├── clevr_generated.png
│   ├── clevr_real_examples.png
│   ├── coco_attngan_example.png
│   ├── coco_bbox_example.png
│   ├── coco_examples.png
│   ├── coco_generated_examples.png
│   ├── coco_no_bbox.png
│   ├── coco_pathway.png
│   ├── coco_stackgan_example.png
│   ├── coco_stuff_example.png
│   ├── d_final_pathway.png
│   ├── d_global_pathway.png
│   ├── d_object_pathway.png
│   ├── datasets_examples.png
│   ├── g_final_pathway.png
│   ├── g_global_pathway.png
│   ├── g_object_pathway.png
│   ├── gan_graphic.png
│   ├── label_encoding.png
│   ├── layout_encoding.png
│   ├── model.png
│   ├── multi-mnist_example.png
│   ├── multi_mnist_digit_bottom_half.png
│   ├── multi_mnist_digit_example.png
│   ├── multi_mnist_digit_generalization.png
│   ├── multi_mnist_real.png
│   └── multi_mnist_size_example.png
├── index.md
├── models
│   └── README.md
├── poster
│   └── Generating Multiple Objects at Spatially Distinct Locations.pdf
├── requirements.txt
├── sample.sh
└── train.sh
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Tobias Hinz 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generating Multiple Objects at Spatially Distinct Locations 2 | Pytorch implementation for reproducing the results from the paper [Generating Multiple Objects at Spatially Distinct Locations](https://openreview.net/forum?id=H1edIiA9KQ) by Tobias Hinz, Stefan Heinrich, and Stefan Wermter accepted for publication at the [International Conference on Learning Representations](https://iclr.cc/) 2019. 3 | 4 | > *For more information and visualizations also see our [blog post](https://tohinz.github.io/blog/generating-multiple-objects-at-spatially-distinct-locations)* 5 | 6 | > Our poster can be found [here](https://postersession.ai/) 7 | 8 | Have a look at our follow-up work [Semantic Object Accuracy for Generative Text-to-Image Synthesis](https://arxiv.org/abs/1910.13321) with available [code](https://github.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis). 
9 | 10 | ![Model-Architecture](examples/model.png) 11 | 12 | # Dependencies 13 | - python 2.7 14 | - pytorch 0.4.1 15 | 16 | Please add the project folder to PYTHONPATH and install the required dependencies: 17 | 18 | ``` 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | # Data 23 | - Multi-MNIST: adapted from [here](https://github.com/aakhundov/tf-attend-infer-repeat) 24 | - contains the three data sets used in the paper: normal (three digits per image), split_digits (0-4 in top half of image, 5-9 in bottom half), and bottom_half_empty (no digits in bottom half of the image) 25 | - [download](https://www2.informatik.uni-hamburg.de/wtm/software/multiple-objects-gan/data-multi-mnist.zip) our data, save it to `data/` and extract 26 | - CLEVR: adapted from [here](https://github.com/facebookresearch/clevr-dataset-gen) 27 | - Main: [download](https://www2.informatik.uni-hamburg.de/wtm/software/multiple-objects-gan/data-clevr-main.zip) our data, save it to `data/` and extract 28 | - CoGenT: [download](https://www2.informatik.uni-hamburg.de/wtm/software/multiple-objects-gan/data-clevr-cogent.zip) our data, save it to `data/` and extract 29 | - MS-COCO: 30 | - [download](https://www2.informatik.uni-hamburg.de/wtm/software/multiple-objects-gan/data-ms-coco.zip) our preprocessed data (bounding boxes and bounding box labels), save it to `data/` and extract 31 | - obtain the train and validation images from the 2014 split [here](http://cocodataset.org/#download), extract and save them in `data/MS-COCO/train/` and `data/MS-COCO/test/` 32 | - for the StackGAN architecture: obtain the preprocessed char-CNN-RNN text embeddings from [here](https://github.com/hanzhanggit/StackGAN-Pytorch) and put the files in `data/MS-COCO/train/` and `data/MS-COCO/test/` 33 | - for the AttnGAN architecture: obtain the preprocessed metadata and the pre-trained DAMSM model from [here](https://github.com/taoxugit/AttnGAN) 34 | - extract the preprocessed metadata, then add the files downloaded in the first step (bounding boxes and bounding box labels) to the `data/coco/coco/train/` and `data/coco/coco/test/` folder 35 | - put the downloaded DAMSM model into `code/coco/attngan/DAMSMencoders/` and extract 36 | 37 | # Training 38 | - to start training run `sh train.sh data gpu-ids` where you choose the desired data set and architecture (mnist/clevr/coco-stackgan-1/coco-stackgan-2/coco-attngan) and which/how many gpus to train on 39 | - e.g. to train on the Multi-MNIST data set on one GPU: `sh train.sh mnist 0` 40 | - e.g. 
to train the AttnGAN architecture on the MS-COCO data set on three GPUs: `sh train.sh coco-attngan 0,1,2` 41 | - training parameters can be adapted via `code/dataset/cfg/dataset_train.yml` 42 | - make sure the DATA_DIR in the respective `code/dataset/cfg/dataset_train.yml` points to the correct path 43 | - results are stored in `output/` 44 | 45 | # Evaluating 46 | - update the eval cfg file in `code/dataset/cfg/dataset_eval.yml` and adapt the path of `NET_G` to point to the model you want to use (default path is to the pretrained models linked below) 47 | - run `sh sample.sh mnist/clevr/coco-stackgan-2/coco-attngan` to generate images using the specified model 48 | 49 | # Pretrained Models 50 | - pretrained model for Multi-MNIST: [download](https://www2.informatik.uni-hamburg.de/wtm/software/multiple-objects-gan/model-multi-mnist.zip), save to `models` and extract 51 | - pretrained model for CLEVR: [download](https://www2.informatik.uni-hamburg.de/wtm/software/multiple-objects-gan/model-clevr.zip), save to `models` and extract 52 | - pretrained model for MS-COCO: 53 | - StackGAN architecture: [download](https://www2.informatik.uni-hamburg.de/wtm/software/multiple-objects-gan/model-ms-coco-stackgan.zip), save to `models` and extract 54 | - AttnGAN architecture: [download](https://www2.informatik.uni-hamburg.de/wtm/software/multiple-objects-gan/model-ms-coco-attngan.zip), save to `models` and extract 55 | 56 | # Examples Generated by the Pretrained Models 57 | ### Multi-MNIST 58 | ![Multi-Mnist Examples](examples/multi-mnist_example.png) 59 | 60 | ### CLEVR 61 | ![CLEVR Examples](examples/clevr_example.png) 62 | 63 | ### MS-COCO 64 | ##### StackGAN Architecture 65 | ![COCO-StackGAN Examples](examples/coco_stackgan_example.png) 66 | 67 | ##### AttnGAN Architecture 68 | ![COCO-AttnGAN Examples](examples/coco_attngan_example.png) 69 | 70 | # Acknowledgement 71 | - Code for the experiments on Multi-MNIST and CLEVR data sets is adapted from [StackGAN-Pytorch](https://github.com/hanzhanggit/StackGAN-Pytorch). 72 | - Code for the experiments on MS-COCO with the StackGAN architecture is adapted from [StackGAN-Pytorch](https://github.com/hanzhanggit/StackGAN-Pytorch), while the code with the AttnGAN architecture is adapted from [AttnGAN](https://github.com/taoxugit/AttnGAN). 73 | 74 | # Citing 75 | If you find our model useful in your research please consider citing: 76 | 77 | ``` 78 | @inproceedings{hinz2019generating, 79 | title = {Generating Multiple Objects at Spatially Distinct Locations}, 80 | author = {Tobias Hinz and Stefan Heinrich and Stefan Wermter}, 81 | booktitle = {International Conference on Learning Representations}, 82 | year = {2019}, 83 | url = {https://openreview.net/forum?id=H1edIiA9KQ}, 84 | } 85 | ``` 86 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman 2 | 3 | # This is the default format. 4 | # For more see: https://github.com/mojombo/jekyll/wiki/Permalinks 5 | permalink: /:categories/:year/:month/:day/:title 6 | 7 | exclude: ["LICENSE"] 8 | highlighter: rouge 9 | 10 | # Themes are encouraged to use these universal variables 11 | # so be sure to set them if your theme uses them. 
12 | title : Generating Multiple Objects at Spatially Distinct Locations 13 | description: Summary of our paper published at ICLR 2019 14 | author : 15 | name : Tobias Hinz 16 | github : tohinz 17 | 18 | production_url : https://tohinz.github.io/multiple-objects-gan/ 19 | 20 | # Tell Github to use the kramdown markdown interpreter 21 | # (see https://help.github.com/articles/migrating-your-pages-site-from-maruku) 22 | markdown: kramdown 23 | 24 | # All Jekyll-Bootstrap specific configurations are namespaced into this hash 25 | # 26 | JB : 27 | version : 0.3.0 28 | 29 | BASE_PATH : https://tohinz.github.io/multiple-objects-gan/ 30 | 31 | # These paths are to the main pages Jekyll-Bootstrap ships with. 32 | # Some JB helpers refer to these paths; change them here if needed. 33 | # 34 | archive_path: nil 35 | categories_path : nil 36 | tags_path : nil 37 | atom_path : nil 38 | rss_path : nil 39 | 40 | # Settings for comments helper 41 | # Set 'provider' to the comment provider you want to use. 42 | # Set 'provider' to false to turn commenting off globally. 43 | # 44 | comments : 45 | provider : false 46 | 47 | # Settings for analytics helper 48 | # Set 'provider' to the analytics provider you want to use. 49 | # Set 'provider' to false to turn analytics off globally. 50 | # 51 | analytics : 52 | provider : false 53 | 54 | # Settings for sharing helper. 55 | # Sharing is for things like tweet, plusone, like, reddit buttons etc. 56 | # Set 'provider' to the sharing provider you want to use. 57 | # Set 'provider' to false to turn sharing off globally. 58 | # 59 | sharing : 60 | provider : false 61 | 62 | # Settings for all other include helpers can be defined by creating 63 | # a hash with key named for the given helper. ex: 64 | # 65 | # pages_list : 66 | # provider : "custom" 67 | # 68 | # Setting any helper's provider to 'custom' will bypass the helper code 69 | # and include your custom code. Your custom file must be defined at: 70 | # ./_includes/custom/[HELPER] 71 | # where [HELPER] is the name of the helper you are overriding. 
72 | -------------------------------------------------------------------------------- /code/clevr/cfg/clevr_eval.yml: -------------------------------------------------------------------------------- 1 | DATASET_NAME: 'clevr' 2 | GPU_ID: '0' 3 | Z_DIM: 100 4 | NET_G: '../../models/model-clevr-0039.pth' 5 | DATA_DIR: '../../data/clevr' 6 | WORKERS: 4 7 | IMSIZE: 64 8 | USE_BBOX_LAYOUT: True 9 | TRAIN: 10 | FLAG: False 11 | BATCH_SIZE: 1 12 | 13 | GAN: 14 | CONDITION_DIM: 16 15 | DF_DIM: 48 16 | GF_DIM: 96 17 | 18 | -------------------------------------------------------------------------------- /code/clevr/cfg/clevr_train.yml: -------------------------------------------------------------------------------- 1 | DATASET_NAME: 'clevr' 2 | GPU_ID: '0' 3 | Z_DIM: 100 4 | DATA_DIR: '../../data/clevr' 5 | IMSIZE: 64 6 | WORKERS: 4 7 | USE_BBOX_LAYOUT: True 8 | TRAIN: 9 | FLAG: True 10 | BATCH_SIZE: 128 11 | MAX_EPOCH: 40 12 | LR_DECAY_EPOCH: 10 13 | SNAPSHOT_INTERVAL: 10 14 | DISCRIMINATOR_LR: 0.0002 15 | GENERATOR_LR: 0.0002 16 | 17 | GAN: 18 | CONDITION_DIM: 16 19 | DF_DIM: 48 20 | GF_DIM: 96 21 | -------------------------------------------------------------------------------- /code/clevr/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.backends.cudnn as cudnn 3 | import torch 4 | import torchvision.transforms as transforms 5 | import torch.backends.cudnn as cudnn 6 | 7 | import argparse 8 | import os 9 | import random 10 | import sys 11 | import pprint 12 | import datetime 13 | import dateutil 14 | import dateutil.tz 15 | from shutil import copyfile 16 | import errno 17 | 18 | dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.'))) 19 | sys.path.append(dir_path) 20 | 21 | from miscc.datasets import TextDataset 22 | from miscc.config import cfg, cfg_from_file 23 | from miscc.utils import mkdir_p 24 | from trainer import GANTrainer 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser(description='Train a GAN network') 29 | parser.add_argument('--cfg', dest='cfg_file', 30 | help='optional config file', 31 | default='birds_stage1.yml', type=str) 32 | parser.add_argument('--gpu', dest='gpu_id', type=str, default='0') 33 | parser.add_argument('--data_dir', dest='data_dir', type=str, default='') 34 | parser.add_argument('--manualSeed', type=int, help='manual seed') 35 | args = parser.parse_args() 36 | return args 37 | 38 | if __name__ == "__main__": 39 | args = parse_args() 40 | if args.cfg_file is not None: 41 | cfg_from_file(args.cfg_file) 42 | if args.gpu_id != -1: 43 | cfg.GPU_ID = args.gpu_id 44 | if args.data_dir != '': 45 | cfg.DATA_DIR = args.data_dir 46 | print('Using config:') 47 | pprint.pprint(cfg) 48 | if args.manualSeed is None: 49 | args.manualSeed = random.randint(1, 10000) 50 | random.seed(args.manualSeed) 51 | torch.manual_seed(args.manualSeed) 52 | if cfg.CUDA: 53 | torch.cuda.manual_seed_all(args.manualSeed) 54 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 55 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 56 | output_dir = '../../output/%s_%s' % \ 57 | (cfg.DATASET_NAME, timestamp) 58 | 59 | cudnn.benchmark = True 60 | 61 | num_gpu = len(cfg.GPU_ID.split(',')) 62 | if cfg.TRAIN.FLAG: 63 | try: 64 | os.makedirs(output_dir) 65 | except OSError as exc: # output dir may already exist 66 | if exc.errno == errno.EEXIST and os.path.isdir(output_dir): 67 | pass 68 | else: 69 | raise 70 | 71 | copyfile(sys.argv[0], output_dir + "/" + sys.argv[0]) 72 | copyfile("trainer.py", 
output_dir + "/" + "trainer.py") 73 | copyfile("model.py", output_dir + "/" + "model.py") 74 | copyfile("miscc/utils.py", output_dir + "/" + "utils.py") 75 | copyfile("miscc/datasets.py", output_dir + "/" + "datasets.py") 76 | copyfile(args.cfg_file, output_dir + "/" + "cfg_file.yml") 77 | 78 | imsize=64 79 | img_transform = transforms.Compose([ 80 | transforms.ToTensor(), 81 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 82 | dataset = TextDataset(cfg.DATA_DIR, split="train", imsize=imsize, transform=img_transform) 83 | assert dataset 84 | dataloader = torch.utils.data.DataLoader( 85 | dataset, batch_size=cfg.TRAIN.BATCH_SIZE, 86 | drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS)) 87 | 88 | algo = GANTrainer(output_dir) 89 | algo.train(dataloader) 90 | else: 91 | imsize=64 92 | datapath= '%s/test/' % (cfg.DATA_DIR) 93 | img_transform = transforms.Compose([ 94 | transforms.ToTensor(), 95 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 96 | dataset = TextDataset(cfg.DATA_DIR, split="test", imsize=imsize, transform=img_transform) 97 | assert dataset 98 | dataloader = torch.utils.data.DataLoader( 99 | dataset, batch_size=cfg.TRAIN.BATCH_SIZE, 100 | drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS)) 101 | algo = GANTrainer(output_dir) 102 | algo.sample(dataloader, num_samples=25, draw_bbox=True) 103 | -------------------------------------------------------------------------------- /code/clevr/miscc/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | -------------------------------------------------------------------------------- /code/clevr/miscc/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import os.path as osp 5 | import numpy as np 6 | from easydict import EasyDict as edict 7 | 8 | 9 | __C = edict() 10 | cfg = __C 11 | 12 | # Dataset name: flowers, birds 13 | __C.DATASET_NAME = 'clevr' 14 | __C.CONFIG_NAME = '' 15 | __C.GPU_ID = '0' 16 | __C.CUDA = True 17 | __C.WORKERS = 6 18 | 19 | __C.NET_G = '' 20 | __C.NET_D = '' 21 | __C.DATA_DIR = '' 22 | __C.VIS_COUNT = 64 23 | 24 | __C.Z_DIM = 100 25 | __C.IMSIZE = 64 26 | 27 | __C.USE_LOCAL_PATHWAY = True 28 | __C.USE_BBOX_LAYOUT = True 29 | 30 | # Training options 31 | __C.TRAIN = edict() 32 | __C.TRAIN.FLAG = True 33 | __C.TRAIN.BATCH_SIZE = 64 34 | __C.TRAIN.MAX_EPOCH = 600 35 | __C.TRAIN.SNAPSHOT_INTERVAL = 50 36 | __C.TRAIN.LR_DECAY_EPOCH = 600 37 | __C.TRAIN.DISCRIMINATOR_LR = 2e-4 38 | __C.TRAIN.GENERATOR_LR = 2e-4 39 | 40 | # Modal options 41 | __C.GAN = edict() 42 | __C.GAN.CONDITION_DIM = 128 43 | __C.GAN.DF_DIM = 64 44 | __C.GAN.GF_DIM = 128 45 | __C.GAN.R_NUM = 4 46 | 47 | 48 | def _merge_a_into_b(a, b): 49 | """Merge config dictionary a into config dictionary b, clobbering the 50 | options in b whenever they are also specified in a. 51 | """ 52 | if type(a) is not edict: 53 | return 54 | 55 | for k, v in a.iteritems(): 56 | # a must specify keys that are in b 57 | if not b.has_key(k): 58 | raise KeyError('{} is not a valid config key'.format(k)) 59 | 60 | # the types must match, too 61 | old_type = type(b[k]) 62 | if old_type is not type(v): 63 | if isinstance(b[k], np.ndarray): 64 | v = np.array(v, dtype=b[k].dtype) 65 | else: 66 | raise ValueError(('Type mismatch ({} vs. 
{}) ' 67 | 'for config key: {}').format(type(b[k]), 68 | type(v), k)) 69 | 70 | # recursively merge dicts 71 | if type(v) is edict: 72 | try: 73 | _merge_a_into_b(a[k], b[k]) 74 | except: 75 | print('Error under config key: {}'.format(k)) 76 | raise 77 | else: 78 | b[k] = v 79 | 80 | 81 | def cfg_from_file(filename): 82 | """Load a config file and merge it into the default options.""" 83 | import yaml 84 | with open(filename, 'r') as f: 85 | yaml_cfg = edict(yaml.load(f)) 86 | 87 | _merge_a_into_b(yaml_cfg, __C) 88 | -------------------------------------------------------------------------------- /code/clevr/miscc/datasets.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import torch.utils.data as data 7 | import PIL 8 | import os 9 | import os.path 10 | import random 11 | import numpy as np 12 | import json 13 | import glob 14 | from PIL import Image 15 | import torchvision.transforms as transforms 16 | import torch 17 | import sys 18 | if sys.version_info[0] == 2: 19 | import cPickle as pickle 20 | else: 21 | import pickle 22 | 23 | from miscc.config import cfg 24 | from miscc.utils import * 25 | 26 | shape_dict = { 27 | "cube": 0, 28 | "cylinder": 1, 29 | "sphere": 2 30 | } 31 | 32 | color_dict = { 33 | "gray": 0, 34 | "red": 1, 35 | "blue": 2, 36 | "green": 3, 37 | "brown": 4, 38 | "purple": 5, 39 | "cyan": 6, 40 | "yellow": 7 41 | } 42 | 43 | 44 | class TextDataset(data.Dataset): 45 | def __init__(self, data_dir, imsize, split='train', transform=None): 46 | 47 | self.transform = transform 48 | self.imsize = imsize 49 | self.data = [] 50 | self.data_dir = data_dir 51 | self.split_dir = os.path.join(data_dir, split) 52 | self.img_dir = os.path.join(self.split_dir, "images") 53 | self.scene_dir = os.path.join(self.split_dir, "scenes") 54 | self.max_objects = 4 55 | 56 | self.filenames = self.load_filenames() 57 | 58 | def get_img(self, img_path): 59 | img = Image.open(img_path).convert('RGB') 60 | 61 | if self.transform is not None: 62 | img = self.transform(img) 63 | 64 | flip_img = random.random() < 0.5 65 | if flip_img: 66 | idx = [i for i in reversed(range(img.shape[2]))] 67 | idx = torch.LongTensor(idx) 68 | img = torch.index_select(img, 2, idx) 69 | 70 | return img, flip_img 71 | 72 | def load_bboxes(self): 73 | bbox_path = os.path.join(self.split_dir, 'bboxes.pickle') 74 | with open(bbox_path, "rb") as f: 75 | bboxes = pickle.load(f) 76 | bboxes = np.array(bboxes) 77 | return bboxes 78 | 79 | def load_labels(self): 80 | label_path = os.path.join(self.split_dir, 'labels.pickle') 81 | with open(label_path, "rb") as f: 82 | labels = pickle.load(f) 83 | labels = np.array(labels) 84 | return labels 85 | 86 | def load_filenames(self): 87 | filenames = [filename for filename in glob.glob(self.scene_dir + '/*.json')] 88 | print('Load scenes from: %s (%d)' % (self.scene_dir, len(filenames))) 89 | return filenames 90 | 91 | def calc_transformation_matrix(self, bbox): 92 | bbox = torch.from_numpy(bbox) 93 | bbox = bbox.view(-1, 4) 94 | transf_matrices_inv = compute_transformation_matrix_inverse(bbox) 95 | transf_matrices_inv = transf_matrices_inv.view(self.max_objects, 2, 3) 96 | transf_matrices = compute_transformation_matrix(bbox) 97 | transf_matrices = transf_matrices.view(self.max_objects, 2, 3) 98 | return transf_matrices, transf_matrices_inv 99 | 100 | def label_one_hot(self, 
label, dim): 101 | labels = torch.from_numpy(label) 102 | labels = labels.long() 103 | # remove -1 to enable one-hot converting 104 | labels[labels < 0] = dim-1 105 | label_one_hot = torch.FloatTensor(labels.shape[0], dim).fill_(0) 106 | label_one_hot = label_one_hot.scatter_(1, labels, 1).float() 107 | return label_one_hot 108 | 109 | def __getitem__(self, index): 110 | # load image 111 | key = self.filenames[index] 112 | with open(key, "rb") as f: 113 | json_file = json.load(f) 114 | img_name = self.img_dir +"/" + json_file["image_filename"] 115 | img, flip_img = self.get_img(img_name) 116 | 117 | # load bbox# 118 | bbox = np.zeros((self.max_objects, 4), dtype=np.float32) 119 | bbox[:] = -1.0 120 | for idx in range(len(json_file["objects"])): 121 | bbox[idx, :] = json_file["objects"][idx]["bbox"] 122 | bbox = bbox / float(self.imsize) 123 | 124 | # load label 125 | # shapes: 3; colors: 8; materials: 2 (not used), size: 3 (but given through bbox) 126 | label_shape = np.zeros(self.max_objects) 127 | label_color = np.zeros(self.max_objects) 128 | label_shape[:] = -1 129 | label_color[:] = -1 130 | for idx in range(len(json_file["objects"])): 131 | label_shape[idx] = shape_dict[json_file["objects"][idx]["shape"]] 132 | label_color[idx] = color_dict[json_file["objects"][idx]["color"]] 133 | 134 | label_shape = self.label_one_hot(np.expand_dims(label_shape, 1), 4) 135 | label_color = self.label_one_hot(np.expand_dims(label_color, 1), 9) 136 | label = torch.cat((label_shape, label_color), 1) 137 | 138 | if flip_img: 139 | bbox[:, 0] = 1.0 - bbox[:, 0] - bbox[:, 2] 140 | transformation_matrices = self.calc_transformation_matrix(bbox) 141 | 142 | return img, transformation_matrices, label, bbox 143 | 144 | def __len__(self): 145 | return len(self.filenames) 146 | -------------------------------------------------------------------------------- /code/clevr/miscc/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import numpy as np 4 | import cPickle as pickle 5 | import glob 6 | 7 | from copy import deepcopy 8 | from miscc.config import cfg 9 | 10 | from torch.nn import init 11 | import torch 12 | import torch.nn as nn 13 | import torchvision.utils as vutils 14 | from torch.autograd import grad 15 | from torch.autograd import Variable 16 | 17 | 18 | def compute_transformation_matrix_inverse(bbox): 19 | x, y = bbox[:, 0], bbox[:, 1] 20 | w, h = bbox[:, 2], bbox[:, 3] 21 | 22 | scale_x = 1.0 / w 23 | scale_y = 1.0 / h 24 | 25 | t_x = 2 * scale_x * (0.5 - (x + 0.5 * w)) 26 | t_y = 2 * scale_y * (0.5 - (y + 0.5 * h)) 27 | 28 | zeros = torch.FloatTensor(bbox.shape[0],1).fill_(0) 29 | 30 | transformation_matrix = torch.cat([scale_x.unsqueeze(-1), zeros, t_x.unsqueeze(-1), 31 | zeros, scale_y.unsqueeze(-1), t_y.unsqueeze(-1)], 1).view(-1, 2, 3) 32 | 33 | return transformation_matrix 34 | 35 | 36 | def compute_transformation_matrix(bbox): 37 | x, y = bbox[:, 0], bbox[:, 1] 38 | w, h = bbox[:, 2], bbox[:, 3] 39 | 40 | scale_x = w 41 | scale_y = h 42 | 43 | t_x = 2 * ((x + 0.5 * w) - 0.5) 44 | t_y = 2 * ((y + 0.5 * h) - 0.5) 45 | 46 | zeros = torch.FloatTensor(bbox.shape[0],1).fill_(0) 47 | 48 | transformation_matrix = torch.cat([scale_x.unsqueeze(-1), zeros, t_x.unsqueeze(-1), 49 | zeros, scale_y.unsqueeze(-1), t_y.unsqueeze(-1)], 1).view(-1, 2, 3) 50 | 51 | return transformation_matrix 52 | 53 | 54 | def load_validation_data(filenames, index): 55 | key = filenames[index] 56 | with open(key, "rb") as f: 57 | json_file = json.load(f) 
58 | img_name = self.img_dir + "/" + json_file["image_filename"] 59 | img, flip_img = get_img(img_name) 60 | 61 | # load bbox# 62 | bbox = np.zeros((self.max_objects, 4), dtype=np.float32) 63 | bbox[:] = -1.0 64 | for idx in range(len(json_file["objects"])): 65 | bbox[idx, :] = json_file["objects"][idx]["bbox"] 66 | bbox = bbox / float(self.imsize) 67 | 68 | # load label 69 | # shapes: 3; colors: 8; materials: 2 (not used), size: 3 (but given through bbox) 70 | label_shape = np.zeros(self.max_objects) 71 | label_color = np.zeros(self.max_objects) 72 | label_shape[:] = -1 73 | label_color[:] = -1 74 | for idx in range(len(json_file["objects"])): 75 | label_shape[idx] = shape_dict[json_file["objects"][idx]["shape"]] 76 | label_color[idx] = color_dict[json_file["objects"][idx]["color"]] 77 | 78 | label_shape = self.label_one_hot(np.expand_dims(label_shape, 1), 4) 79 | label_color = self.label_one_hot(np.expand_dims(label_color, 1), 9) 80 | label = torch.cat((label_shape, label_color), 1) 81 | 82 | if flip_img: 83 | bbox[:, 0] = 1.0 - bbox[:, 0] - bbox[:, 2] 84 | transformation_matrices = self.calc_transformation_matrix(bbox) 85 | 86 | return img, transformation_matrices, label 87 | 88 | return torch.from_numpy(labels), torch.from_numpy(bboxes) 89 | 90 | 91 | def compute_discriminator_loss(netD, real_imgs, fake_imgs, 92 | real_labels, fake_labels, 93 | local_label, transf_matrices, transf_matrices_inv, gpus): 94 | criterion = nn.BCEWithLogitsLoss() 95 | batch_size = real_imgs.size(0) 96 | fake = fake_imgs.detach() 97 | local_label = local_label.detach() 98 | local_label_cond = local_label[:, 0, :] + local_label[:, 1, :] + local_label[:, 2, :] + local_label[:, 3, :] 99 | local_label_cond[local_label_cond < 0] = 0 100 | real_features = nn.parallel.data_parallel(netD, (real_imgs, local_label, transf_matrices, transf_matrices_inv), gpus) 101 | fake_features = nn.parallel.data_parallel(netD, (fake, local_label, transf_matrices, transf_matrices_inv), gpus) 102 | # real pairs 103 | inputs = (real_features, local_label_cond) 104 | real_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus) 105 | # real_logits = torch.clamp(real_logits, 1e-8, 1-1e-8) 106 | errD_real = criterion(real_logits, real_labels) 107 | # wrong pairs 108 | inputs = (real_features[:(batch_size-1)], local_label_cond[1:]) 109 | wrong_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus) 110 | errD_wrong = criterion(wrong_logits, fake_labels[1:]) 111 | # fake pairs 112 | inputs = (fake_features, local_label_cond) 113 | fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus) 114 | errD_fake = criterion(fake_logits, fake_labels) 115 | 116 | if netD.get_uncond_logits is not None: 117 | real_logits = nn.parallel.data_parallel(netD.get_uncond_logits, (real_features), gpus) 118 | fake_logits = nn.parallel.data_parallel(netD.get_uncond_logits, (fake_features), gpus) 119 | uncond_errD_real = criterion(real_logits, real_labels) 120 | uncond_errD_fake = criterion(fake_logits, fake_labels) 121 | # 122 | errD = ((errD_real + uncond_errD_real) / 2. + 123 | (errD_fake + errD_wrong + uncond_errD_fake) / 3.) 124 | errD_real = (errD_real + uncond_errD_real) / 2. 125 | errD_fake = (errD_fake + uncond_errD_fake) / 2. 
126 | else: 127 | errD = errD_real + (errD_fake + errD_wrong) * 0.5 128 | return errD, errD_real.item(), errD_wrong.item(), errD_fake.item() 129 | 130 | 131 | def compute_generator_loss(netD, fake_imgs, real_labels, local_label, transf_matrices, transf_matrices_inv, gpus): 132 | criterion = nn.BCEWithLogitsLoss() 133 | local_label_cond = local_label[:, 0, :] + local_label[:, 1, :] + local_label[:, 2, :] + local_label[:, 3, :] 134 | local_label_cond[local_label_cond < 0] = 0 135 | fake_features = nn.parallel.data_parallel(netD, (fake_imgs, local_label, transf_matrices, transf_matrices_inv), gpus) 136 | # fake pairs 137 | inputs = (fake_features, local_label_cond) 138 | fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus) 139 | errD_fake = criterion(fake_logits, real_labels) 140 | if netD.get_uncond_logits is not None: 141 | fake_logits = nn.parallel.data_parallel(netD.get_uncond_logits, (fake_features), gpus) 142 | uncond_errD_fake = criterion(fake_logits, real_labels) 143 | errD_fake += uncond_errD_fake 144 | return errD_fake 145 | 146 | 147 | ############################# 148 | def weights_init(m): 149 | classname = m.__class__.__name__ 150 | if classname.find('Conv') != -1: 151 | m.weight.data.normal_(0.0, 0.02) 152 | elif classname.find('BatchNorm') != -1: 153 | m.weight.data.normal_(1.0, 0.02) 154 | m.bias.data.fill_(0) 155 | elif classname.find('Linear') != -1: 156 | m.weight.data.normal_(0.0, 0.02) 157 | if m.bias is not None: 158 | m.bias.data.fill_(0.0) 159 | 160 | 161 | ############################# 162 | def save_img_results(data_img, fake, epoch, image_dir): 163 | num = cfg.VIS_COUNT 164 | fake = fake[0:num] 165 | # data_img is changed to [0,1] 166 | if data_img is not None: 167 | data_img = data_img[0:num] 168 | vutils.save_image( 169 | data_img, '%s/real_samples.png' % image_dir, 170 | normalize=True) 171 | # fake.data is still [-1, 1] 172 | vutils.save_image( 173 | fake.data, '%s/fake_samples_epoch_%03d.png' % 174 | (image_dir, epoch), normalize=True) 175 | else: 176 | vutils.save_image( 177 | fake.data, '%s/lr_fake_samples_epoch_%03d.png' % 178 | (image_dir, epoch), normalize=True) 179 | 180 | 181 | def save_model(netG, netD, optimG, optimD, epoch, model_dir, saveD=False, saveOptim=False, max_to_keep=5): 182 | checkpoint = { 183 | 'epoch': epoch, 184 | 'netG': netG.state_dict(), 185 | 'optimG': optimG.state_dict() if saveOptim else {}, 186 | 'netD': netD.state_dict() if saveD else {}, 187 | 'optimD': optimD.state_dict() if saveOptim else {}} 188 | torch.save(checkpoint, "{}/checkpoint_{:04}.pth".format(model_dir, epoch)) 189 | print('Save G/D models') 190 | 191 | if max_to_keep is not None and max_to_keep > 0: 192 | checkpoint_list = sorted([ckpt for ckpt in glob.glob(model_dir + "/" + '*.pth')]) 193 | while len(checkpoint_list) > max_to_keep: 194 | os.remove(checkpoint_list[0]) 195 | checkpoint_list = checkpoint_list[1:] 196 | 197 | 198 | def mkdir_p(path): 199 | try: 200 | os.makedirs(path) 201 | except OSError as exc: # Python >2.5 202 | if exc.errno == errno.EEXIST and os.path.isdir(path): 203 | pass 204 | else: 205 | raise 206 | -------------------------------------------------------------------------------- /code/clevr/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | from miscc.config import cfg 5 | from miscc.utils import compute_transformation_matrix, compute_transformation_matrix_inverse 6 | from torch.autograd import Variable 
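# Overview comment (added for orientation; terminology follows the paper and the
# pathway figures in examples/): the networks below combine a global pathway,
# which processes the noise vector or the full image, with an object ("local")
# pathway that generates (STAGE1_G) or checks (STAGE1_D) features for each object
# at its bounding-box location. The spatial-transformer helper stn() moves these
# per-object features into and out of the box coordinates before the two pathways
# are merged by concatenation.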
7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | "3x3 convolution with padding" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=1, bias=False) 13 | 14 | 15 | # Upsale the spatial size by a factor of 2 16 | def upBlock(in_planes, out_planes): 17 | block = nn.Sequential( 18 | nn.Upsample(scale_factor=2, mode='nearest'), 19 | conv3x3(in_planes, out_planes), 20 | nn.BatchNorm2d(out_planes), 21 | nn.ReLU(True)) 22 | return block 23 | 24 | 25 | class ResBlock(nn.Module): 26 | def __init__(self, channel_num): 27 | super(ResBlock, self).__init__() 28 | self.block = nn.Sequential( 29 | conv3x3(channel_num, channel_num), 30 | nn.BatchNorm2d(channel_num), 31 | nn.ReLU(True), 32 | conv3x3(channel_num, channel_num), 33 | nn.BatchNorm2d(channel_num)) 34 | self.relu = nn.ReLU(inplace=True) 35 | 36 | def forward(self, x): 37 | residual = x 38 | out = self.block(x) 39 | out += residual 40 | out = self.relu(out) 41 | return out 42 | 43 | 44 | class D_GET_LOGITS(nn.Module): 45 | def __init__(self, ndf, nef, bcondition=True): 46 | super(D_GET_LOGITS, self).__init__() 47 | self.df_dim = ndf 48 | self.ef_dim = cfg.GAN.CONDITION_DIM 49 | self.bcondition = bcondition 50 | if bcondition: 51 | self.outlogits = nn.Sequential( 52 | conv3x3(ndf * 8 + 13, ndf * 8), 53 | nn.BatchNorm2d(ndf * 8), 54 | nn.LeakyReLU(0.2, inplace=True), 55 | nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4)) 56 | else: 57 | self.outlogits = nn.Sequential( 58 | nn.Conv2d(109, 1, kernel_size=4, stride=4)) 59 | 60 | def forward(self, h_code, c_code=None): 61 | # conditioning output 62 | if self.bcondition and c_code is not None: 63 | c_code = c_code.view(h_code.shape[0], 13, 1, 1) 64 | c_code = c_code.repeat(1, 1, 4, 4) 65 | h_c_code = torch.cat((h_code, c_code), 1) 66 | else: 67 | h_c_code = h_code 68 | 69 | output = self.outlogits(h_c_code) 70 | return output.view(-1) 71 | 72 | 73 | def stn(image, transformation_matrix, size): 74 | grid = torch.nn.functional.affine_grid(transformation_matrix, torch.Size(size)) 75 | out_image = torch.nn.functional.grid_sample(image, grid) 76 | 77 | return out_image 78 | 79 | 80 | class BBOX_NET(nn.Module): 81 | def __init__(self): 82 | super(BBOX_NET, self).__init__() 83 | self.c_dim = cfg.GAN.CONDITION_DIM 84 | self.encode = nn.Sequential( 85 | # 128 * 16 x 16 86 | conv3x3(self.c_dim, self.c_dim // 2, stride=2), 87 | nn.LeakyReLU(0.2, inplace=True), 88 | # 64 x 8 x 8 89 | conv3x3(self.c_dim // 2, self.c_dim // 4, stride=2), 90 | nn.BatchNorm2d(self.c_dim // 4), 91 | nn.LeakyReLU(0.2, inplace=True), 92 | # 32 x 4 x 4 93 | conv3x3(self.c_dim // 4, self.c_dim // 8, stride=2), 94 | nn.BatchNorm2d(self.c_dim // 8), 95 | nn.LeakyReLU(0.2, inplace=True), 96 | # 16 x 2 x 2 97 | ) 98 | 99 | def forward(self, labels, transf_matr_inv, num_objects): 100 | label_layout = torch.cuda.FloatTensor(labels.shape[0], self.c_dim, 16, 16).fill_(0) 101 | for idx in range(num_objects): 102 | current_label = labels[:, idx] 103 | current_label = current_label.view(current_label.shape[0], current_label.shape[1], 1, 1) 104 | current_label = current_label.repeat(1, 1, 16, 16) 105 | current_label = stn(current_label, transf_matr_inv[:, idx], current_label.shape) 106 | label_layout += current_label 107 | 108 | layout_encoding = self.encode(label_layout).view(labels.shape[0], -1) 109 | 110 | return layout_encoding 111 | 112 | # ############# Networks for stageI GAN ############# 113 | class STAGE1_G(nn.Module): 114 | def __init__(self): 115 | super(STAGE1_G, self).__init__() 116 | 
self.gf_dim = cfg.GAN.GF_DIM * 8 117 | self.ef_dim = cfg.GAN.CONDITION_DIM 118 | self.z_dim = cfg.Z_DIM 119 | self.define_module() 120 | 121 | def define_module(self): 122 | ninput = self.z_dim 123 | linput = 13 124 | ngf = self.gf_dim 125 | 126 | if cfg.USE_BBOX_LAYOUT or cfg.USE_BBOX_LAYOUT_S1: 127 | self.bbox_net = BBOX_NET() 128 | ninput += 8 129 | 130 | # -> ngf x 4 x 4 131 | self.fc = nn.Sequential( 132 | nn.Linear(ninput, ngf * 4 * 4, bias=False), 133 | nn.BatchNorm1d(ngf * 4 * 4), 134 | nn.ReLU(True)) 135 | 136 | # local pathway 137 | self.label = nn.Sequential( 138 | nn.Linear(linput, self.ef_dim, bias=False), 139 | nn.BatchNorm1d(self.ef_dim), 140 | nn.ReLU(True)) 141 | self.local1 = upBlock(self.ef_dim, ngf // 2) 142 | self.local2 = upBlock(ngf // 2, ngf // 4) 143 | 144 | # global pathway 145 | # ngf x 4 x 4 -> ngf/2 x 8 x 8 146 | self.upsample1 = upBlock(ngf, ngf // 2) 147 | # -> ngf/4 x 16 x 16 148 | self.upsample2 = upBlock(ngf // 2, ngf // 4) 149 | # -> ngf/8 x 32 x 32 150 | self.upsample3 = upBlock(ngf // 2, ngf // 8) 151 | # -> ngf/16 x 64 x 64 152 | self.upsample4 = upBlock(ngf // 8, ngf // 16) 153 | # -> 3 x 64 x 64 154 | self.img = nn.Sequential( 155 | conv3x3(ngf // 16, 3), 156 | nn.Tanh()) 157 | 158 | def forward(self, noise, transf_matrices_inv, label_one_hot, num_objects=4): 159 | local_labels = torch.cuda.FloatTensor(noise.shape[0], num_objects, self.ef_dim).fill_(0) 160 | 161 | # local pathway 162 | h_code_locals = torch.cuda.FloatTensor(noise.shape[0], self.gf_dim // 4, 16, 16).fill_(0) 163 | for idx in range(num_objects): 164 | current_label = self.label(label_one_hot[:, idx]) 165 | local_labels[:, idx] = current_label 166 | current_label = current_label.view(current_label.shape[0], self.ef_dim, 1, 1) 167 | current_label = current_label.repeat(1, 1, 4, 4) 168 | h_code_local = self.local1(current_label) 169 | h_code_local = self.local2(h_code_local) 170 | h_code_local = stn(h_code_local, transf_matrices_inv[:, idx], h_code_local.shape) 171 | h_code_locals += h_code_local 172 | 173 | # global pathway 174 | if cfg.USE_BBOX_LAYOUT: 175 | bbox_code = self.bbox_net(local_labels, transf_matrices_inv, num_objects) 176 | z_c_code = torch.cat((noise, bbox_code), 1) 177 | else: 178 | z_c_code = noise 179 | h_code = self.fc(z_c_code) 180 | h_code = h_code.view(-1, self.gf_dim, 4, 4) 181 | h_code = self.upsample1(h_code) 182 | h_code = self.upsample2(h_code) 183 | 184 | # combine local and global 185 | h_code = torch.cat((h_code, h_code_locals), 1) 186 | 187 | h_code = self.upsample3(h_code) 188 | h_code = self.upsample4(h_code) 189 | 190 | # state size 3 x 64 x 64 191 | fake_img = self.img(h_code) 192 | return fake_img 193 | 194 | 195 | class STAGE1_D(nn.Module): 196 | def __init__(self): 197 | super(STAGE1_D, self).__init__() 198 | self.df_dim = cfg.GAN.DF_DIM 199 | self.ef_dim = cfg.GAN.CONDITION_DIM 200 | self.define_module() 201 | 202 | def define_module(self): 203 | ndf, nef = self.df_dim, self.ef_dim 204 | linput = 13 205 | 206 | # local pathway 207 | self.local = nn.Sequential( 208 | nn.Conv2d(3 + linput, ndf * 2, 4, 1, 1, bias=False), 209 | nn.BatchNorm2d(ndf * 2), 210 | nn.LeakyReLU(0.2, inplace=True) 211 | ) 212 | 213 | self.act = nn.LeakyReLU(0.2, inplace=True) 214 | 215 | self.conv1 = nn.Conv2d(3, ndf, 4, 2, 1, bias=False) 216 | self.conv2 = nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False) 217 | self.bn2 = nn.BatchNorm2d(ndf * 2) 218 | self.conv3 = nn.Conv2d(ndf*4, ndf * 4, 4, 2, 1, bias=False) 219 | self.bn3 = nn.BatchNorm2d(ndf * 4) 220 | self.conv4 = 
nn.Conv2d(ndf*4, ndf * 8, 4, 2, 1, bias=False) 221 | self.bn4 = nn.BatchNorm2d(ndf * 8) 222 | 223 | self.get_cond_logits = D_GET_LOGITS(ndf, nef) 224 | self.get_uncond_logits = None 225 | 226 | def _encode_img(self, image, label, transf_matrices, transf_matrices_inv, num_objects=4): 227 | # local pathway 228 | h_code_locals = torch.cuda.FloatTensor(image.shape[0], self.df_dim * 2, 16, 16).fill_(0) 229 | 230 | for idx in range(num_objects): 231 | current_label = label[:, idx].view(label.shape[0], 13, 1, 1) 232 | current_label = current_label.repeat(1, 1, 16, 16) 233 | h_code_local = stn(image, transf_matrices[:, idx], (image.shape[0], image.shape[1], 16, 16)) 234 | h_code_local = torch.cat((h_code_local, current_label), 1) 235 | h_code_local = self.local(h_code_local) 236 | h_code_local = stn(h_code_local, transf_matrices_inv[:, idx], (h_code_local.shape[0], h_code_local.shape[1], 16, 16)) 237 | h_code_locals += h_code_local 238 | 239 | h_code = self.conv1(image) 240 | h_code = self.act(h_code) 241 | h_code = self.conv2(h_code) 242 | h_code = self.bn2(h_code) 243 | h_code = self.act(h_code) 244 | 245 | # combine global and local 246 | h_code = torch.cat((h_code, h_code_locals), 1) 247 | 248 | h_code = self.conv3(h_code) 249 | h_code = self.bn3(h_code) 250 | h_code = self.act(h_code) 251 | 252 | h_code = self.conv4(h_code) 253 | h_code = self.bn4(h_code) 254 | h_code = self.act(h_code) 255 | return h_code 256 | 257 | def forward(self, image, label, transf_matrices, transf_matrices_inv): 258 | img_embedding = self._encode_img(image, label, transf_matrices, transf_matrices_inv) 259 | 260 | return img_embedding 261 | -------------------------------------------------------------------------------- /code/clevr/trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from six.moves import range 3 | from PIL import Image 4 | 5 | import torch.backends.cudnn as cudnn 6 | import torch 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | import torch.optim as optim 10 | import os 11 | import time 12 | 13 | import numpy as np 14 | import torchfile 15 | 16 | from miscc.config import cfg 17 | from miscc.utils import mkdir_p 18 | from miscc.utils import weights_init 19 | from miscc.utils import save_img_results, save_model 20 | from miscc.utils import compute_discriminator_loss, compute_generator_loss 21 | from miscc.utils import compute_transformation_matrix, compute_transformation_matrix_inverse 22 | from miscc.utils import load_validation_data 23 | 24 | from tensorboard import summary 25 | from tensorboard import FileWriter 26 | 27 | 28 | class GANTrainer(object): 29 | def __init__(self, output_dir): 30 | if cfg.TRAIN.FLAG: 31 | self.model_dir = os.path.join(output_dir, 'Model') 32 | self.image_dir = os.path.join(output_dir, 'Image') 33 | self.log_dir = os.path.join(output_dir, 'Log') 34 | mkdir_p(self.model_dir) 35 | mkdir_p(self.image_dir) 36 | mkdir_p(self.log_dir) 37 | self.summary_writer = FileWriter(self.log_dir) 38 | 39 | self.max_epoch = cfg.TRAIN.MAX_EPOCH 40 | self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL 41 | self.max_objects = 4 42 | 43 | s_gpus = cfg.GPU_ID.split(',') 44 | self.gpus = [int(ix) for ix in s_gpus] 45 | self.num_gpus = len(self.gpus) 46 | self.batch_size = cfg.TRAIN.BATCH_SIZE 47 | torch.cuda.set_device(self.gpus[0]) 48 | cudnn.benchmark = True 49 | 50 | # ############# For training stageI GAN ############# 51 | def load_network_stageI(self): 52 | from model import 
STAGE1_G, STAGE1_D 53 | netG = STAGE1_G() 54 | netG.apply(weights_init) 55 | netD = STAGE1_D() 56 | netD.apply(weights_init) 57 | 58 | if cfg.NET_G != '': 59 | state_dict = \ 60 | torch.load(cfg.NET_G, map_location=lambda storage, loc: storage) 61 | netG.load_state_dict(state_dict["netG"]) 62 | print('Load from: ', cfg.NET_G) 63 | if cfg.NET_D != '': 64 | state_dict = \ 65 | torch.load(cfg.NET_D, map_location=lambda storage, loc: storage) 66 | netD.load_state_dict(state_dict) 67 | print('Load from: ', cfg.NET_D) 68 | if cfg.CUDA: 69 | netG.cuda() 70 | netD.cuda() 71 | return netG, netD 72 | 73 | def train(self, data_loader, stage=1): 74 | netG, netD = self.load_network_stageI() 75 | 76 | nz = cfg.Z_DIM 77 | batch_size = self.batch_size 78 | noise = Variable(torch.FloatTensor(batch_size, nz)) 79 | 80 | fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1), requires_grad=False) 81 | real_labels = Variable(torch.FloatTensor(batch_size).fill_(1)) 82 | fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0)) 83 | if cfg.CUDA: 84 | noise, fixed_noise = noise.cuda(), fixed_noise.cuda() 85 | real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda() 86 | 87 | generator_lr = cfg.TRAIN.GENERATOR_LR 88 | discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR 89 | lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH 90 | 91 | netG_para = [] 92 | for p in netG.parameters(): 93 | if p.requires_grad: 94 | netG_para.append(p) 95 | optimizerD = optim.Adam(netD.parameters(), lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999)) 96 | optimizerG = optim.Adam(netG_para, lr=cfg.TRAIN.GENERATOR_LR, betas=(0.5, 0.999)) 97 | 98 | print("Start training...") 99 | count = 0 100 | for epoch in range(self.max_epoch): 101 | start_t = time.time() 102 | if epoch % lr_decay_step == 0 and epoch > 0: 103 | generator_lr *= 0.5 104 | for param_group in optimizerG.param_groups: 105 | param_group['lr'] = generator_lr 106 | discriminator_lr *= 0.5 107 | for param_group in optimizerD.param_groups: 108 | param_group['lr'] = discriminator_lr 109 | 110 | for i, data in enumerate(data_loader, 0): 111 | ###################################################### 112 | # (1) Prepare training data 113 | ###################################################### 114 | real_img_cpu, transformation_matrices, label_one_hot, _ = data 115 | 116 | transf_matrices, transf_matrices_inv = tuple(transformation_matrices) 117 | transf_matrices = transf_matrices.detach() 118 | transf_matrices_inv = transf_matrices_inv.detach() 119 | 120 | real_imgs = Variable(real_img_cpu) 121 | if cfg.CUDA: 122 | real_imgs = real_imgs.cuda() 123 | label_one_hot = label_one_hot.cuda() 124 | transf_matrices = transf_matrices.cuda() 125 | transf_matrices_inv = transf_matrices_inv.cuda() 126 | 127 | ####################################################### 128 | # (2) Generate fake images 129 | ###################################################### 130 | noise.data.normal_(0, 1) 131 | inputs = (noise, transf_matrices_inv, label_one_hot) 132 | fake_imgs = nn.parallel.data_parallel(netG, inputs, self.gpus) 133 | 134 | ############################ 135 | # (3) Update D network 136 | ########################### 137 | netD.zero_grad() 138 | 139 | errD, errD_real, errD_wrong, errD_fake = \ 140 | compute_discriminator_loss(netD, real_imgs, fake_imgs, 141 | real_labels, fake_labels, 142 | label_one_hot, transf_matrices, transf_matrices_inv, self.gpus) 143 | errD.backward(retain_graph=True) 144 | optimizerD.step() 145 | ############################ 146 | # (2) Update G network 147 | 
########################### 148 | netG.zero_grad() 149 | errG = compute_generator_loss(netD, fake_imgs, real_labels, label_one_hot, 150 | transf_matrices, transf_matrices_inv, self.gpus) 151 | 152 | errG_total = errG 153 | errG_total.backward() 154 | optimizerG.step() 155 | 156 | ############################ 157 | # (3) Log results 158 | ########################### 159 | count = count + 1 160 | if i % 500 == 0: 161 | summary_D = summary.scalar('D_loss', errD.item()) 162 | summary_D_r = summary.scalar('D_loss_real', errD_real) 163 | summary_D_w = summary.scalar('D_loss_wrong', errD_wrong) 164 | summary_D_f = summary.scalar('D_loss_fake', errD_fake) 165 | summary_G = summary.scalar('G_loss', errG.item()) 166 | 167 | self.summary_writer.add_summary(summary_D, count) 168 | self.summary_writer.add_summary(summary_D_r, count) 169 | self.summary_writer.add_summary(summary_D_w, count) 170 | self.summary_writer.add_summary(summary_D_f, count) 171 | self.summary_writer.add_summary(summary_G, count) 172 | 173 | # save the image result for each epoch 174 | with torch.no_grad(): 175 | inputs = (noise, transf_matrices_inv, label_one_hot) 176 | fake = nn.parallel.data_parallel(netG, inputs, self.gpus) 177 | save_img_results(real_img_cpu, fake, epoch, self.image_dir) 178 | with torch.no_grad(): 179 | inputs = (noise, transf_matrices_inv, label_one_hot) 180 | fake = nn.parallel.data_parallel(netG, inputs, self.gpus) 181 | save_img_results(real_img_cpu, fake, epoch, self.image_dir) 182 | end_t = time.time() 183 | print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f 184 | Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f 185 | Total Time: %.2fsec 186 | ''' 187 | % (epoch, self.max_epoch, i, len(data_loader), 188 | errD.item(), errG.item(), 189 | errD_real, errD_wrong, errD_fake, (end_t - start_t))) 190 | if epoch % self.snapshot_interval == 0: 191 | save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir) 192 | # 193 | save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir) 194 | # 195 | self.summary_writer.close() 196 | 197 | 198 | def sample(self, data_loader, num_samples=25, draw_bbox=True, max_objects=4): 199 | from PIL import Image, ImageDraw, ImageFont 200 | import cPickle as pickle 201 | import torchvision 202 | import torchvision.utils as vutils 203 | netG, _ = self.load_network_stageI() 204 | netG.eval() 205 | 206 | # path to save generated samples 207 | save_dir = cfg.NET_G[:cfg.NET_G.find('.pth')] + "_samples_" + str(max_objects) + "_objects" 208 | print("saving to:", save_dir) 209 | mkdir_p(save_dir) 210 | 211 | nz = cfg.Z_DIM 212 | noise = Variable(torch.FloatTensor(9, nz)) 213 | if cfg.CUDA: 214 | noise = noise.cuda() 215 | 216 | imsize = 64 217 | 218 | # for count in range(num_samples): 219 | count = 0 220 | for i, data in enumerate(data_loader, 0): 221 | if count == num_samples: 222 | break 223 | ###################################################### 224 | # (1) Prepare training data 225 | ###################################################### 226 | real_img_cpu, transformation_matrices, label_one_hot, bbox = data 227 | 228 | transf_matrices, transf_matrices_inv = tuple(transformation_matrices) 229 | transf_matrices_inv = transf_matrices_inv.detach() 230 | 231 | real_img = Variable(real_img_cpu) 232 | if cfg.CUDA: 233 | real_img = real_img.cuda() 234 | label_one_hot = label_one_hot.cuda() 235 | transf_matrices_inv = transf_matrices_inv.cuda() 236 | 237 | transf_matrices_inv_batch = transf_matrices_inv.view(1, max_objects, 2, 3).repeat(9, 1, 1, 1) 238 | label_one_hot_batch 
= label_one_hot.view(1, max_objects, 13).repeat(9, 1, 1) 239 | 240 | ####################################################### 241 | # (2) Generate fake images 242 | ###################################################### 243 | noise.data.normal_(0, 1) 244 | inputs = (noise, transf_matrices_inv_batch, label_one_hot_batch) 245 | with torch.no_grad(): 246 | fake_imgs= nn.parallel.data_parallel(netG, inputs, self.gpus) 247 | 248 | data_img = torch.FloatTensor(20, 3, imsize, imsize).fill_(0) 249 | data_img[0] = real_img 250 | data_img[1:10] = fake_imgs 251 | 252 | if draw_bbox: 253 | for idx in range(max_objects): 254 | x, y, w, h = tuple([int(imsize*x) for x in bbox[0, idx]]) 255 | w = imsize-1 if w > imsize-1 else w 256 | h = imsize-1 if h > imsize-1 else h 257 | if x <= -1 or y <= -1: 258 | break 259 | data_img[:10, :, y, x:x + w] = 1 260 | data_img[:10, :, y:y + h, x] = 1 261 | data_img[:10, :, y+h, x:x + w] = 1 262 | data_img[:10, :, y:y + h, x + w] = 1 263 | 264 | # write caption into image 265 | shape_dict = { 266 | 0: "cube", 267 | 1: "cylinder", 268 | 2: "sphere", 269 | 3: "empty" 270 | } 271 | 272 | color_dict = { 273 | 0: "gray", 274 | 1: "red", 275 | 2: "blue", 276 | 3: "green", 277 | 4: "brown", 278 | 5: "purple", 279 | 6: "cyan", 280 | 7: "yellow", 281 | 8: "empty" 282 | } 283 | text_img = Image.new('L', (imsize * 10, imsize), color='white') 284 | d = ImageDraw.Draw(text_img) 285 | label = label_one_hot_batch[0] 286 | label = label.cpu().numpy() 287 | label_shape = label[:, :4] 288 | label_color = label[:, 4:] 289 | label_shape = np.argmax(label_shape, axis=1) 290 | label_color = np.argmax(label_color, axis=1) 291 | label_combined = ", ".join([color_dict[label_color[_]] + " " + shape_dict[label_shape[_]] 292 | for _ in range(max_objects)]) 293 | d.text((10, 10), label_combined) 294 | text_img = torchvision.transforms.functional.to_tensor(text_img) 295 | text_img = torch.chunk(text_img, 10, 2) 296 | text_img = torch.cat([text_img[i].view(1, 1, imsize, imsize) for i in range(10)], 0) 297 | data_img[10:] = text_img 298 | vutils.save_image(data_img, '{}/vis_{}.png'.format(save_dir, count), normalize=True, nrow=10) 299 | count += 1 300 | 301 | print("Saved {} files to {}".format(count, save_dir)) 302 | -------------------------------------------------------------------------------- /code/coco/attngan/DAMSMencoders/README.md: -------------------------------------------------------------------------------- 1 | Put the pre-trained DAMSM model into a folder `coco` 2 | -------------------------------------------------------------------------------- /code/coco/attngan/GlobalAttention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Global attention takes a matrix and a query metrix. 3 | Based on each query vector q, it computes a parameterized convex combination of the matrix 4 | based. 5 | H_1 H_2 H_3 ... H_n 6 | q q q q 7 | | | | | 8 | \ | | / 9 | ..... 10 | \ | / 11 | a 12 | Constructs a unit mapping. 13 | $$(H_1 + H_n, q) => (a)$$ 14 | Where H is of `batch x n x dim` and q is of `batch x dim`. 
15 | 16 | References: 17 | https://github.com/OpenNMT/OpenNMT-py/tree/fc23dfef1ba2f258858b2765d24565266526dc76/onmt/modules 18 | http://www.aclweb.org/anthology/D15-1166 19 | """ 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | 25 | def conv1x1(in_planes, out_planes): 26 | "1x1 convolution with padding" 27 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, 28 | padding=0, bias=False) 29 | 30 | 31 | def func_attention(query, context, gamma1): 32 | """ 33 | query: batch x ndf x queryL 34 | context: batch x ndf x ih x iw (sourceL=ihxiw) 35 | mask: batch_size x sourceL 36 | """ 37 | batch_size, queryL = query.size(0), query.size(2) 38 | ih, iw = context.size(2), context.size(3) 39 | sourceL = ih * iw 40 | 41 | # --> batch x sourceL x ndf 42 | context = context.view(batch_size, -1, sourceL) 43 | contextT = torch.transpose(context, 1, 2).contiguous() 44 | 45 | # Get attention 46 | # (batch x sourceL x ndf)(batch x ndf x queryL) 47 | # -->batch x sourceL x queryL 48 | attn = torch.bmm(contextT, query) # Eq. (7) in AttnGAN paper 49 | # --> batch*sourceL x queryL 50 | attn = attn.view(batch_size*sourceL, queryL) 51 | attn = nn.Softmax(dim=1)(attn) # Eq. (8) 52 | 53 | # --> batch x sourceL x queryL 54 | attn = attn.view(batch_size, sourceL, queryL) 55 | # --> batch*queryL x sourceL 56 | attn = torch.transpose(attn, 1, 2).contiguous() 57 | attn = attn.view(batch_size*queryL, sourceL) 58 | # Eq. (9) 59 | attn = attn * gamma1 60 | attn = nn.Softmax(dim=1)(attn) 61 | attn = attn.view(batch_size, queryL, sourceL) 62 | # --> batch x sourceL x queryL 63 | attnT = torch.transpose(attn, 1, 2).contiguous() 64 | 65 | # (batch x ndf x sourceL)(batch x sourceL x queryL) 66 | # --> batch x ndf x queryL 67 | weightedContext = torch.bmm(context, attnT) 68 | 69 | return weightedContext, attn.view(batch_size, -1, ih, iw) 70 | 71 | 72 | class GlobalAttentionGeneral(nn.Module): 73 | def __init__(self, idf, cdf): 74 | super(GlobalAttentionGeneral, self).__init__() 75 | self.conv_context = conv1x1(cdf, idf) 76 | self.sm = nn.Softmax(dim=1) 77 | self.mask = None 78 | 79 | def applyMask(self, mask): 80 | self.mask = mask # batch x sourceL 81 | 82 | def forward(self, input, context): 83 | """ 84 | input: batch x idf x ih x iw (queryL=ihxiw) 85 | context: batch x cdf x sourceL 86 | """ 87 | ih, iw = input.size(2), input.size(3) 88 | queryL = ih * iw 89 | batch_size, sourceL = context.size(0), context.size(2) 90 | 91 | # --> batch x queryL x idf 92 | target = input.view(batch_size, -1, queryL) 93 | targetT = torch.transpose(target, 1, 2).contiguous() 94 | # batch x cdf x sourceL --> batch x cdf x sourceL x 1 95 | sourceT = context.unsqueeze(3) 96 | # --> batch x idf x sourceL 97 | sourceT = self.conv_context(sourceT).squeeze(3) 98 | 99 | # Get attention 100 | # (batch x queryL x idf)(batch x idf x sourceL) 101 | # -->batch x queryL x sourceL 102 | attn = torch.bmm(targetT, sourceT) 103 | # --> batch*queryL x sourceL 104 | attn = attn.view(batch_size*queryL, sourceL) 105 | if self.mask is not None: 106 | # batch_size x sourceL --> batch_size*queryL x sourceL 107 | mask = self.mask.repeat(queryL, 1) 108 | attn.data.masked_fill_(mask.data, -float('inf')) 109 | # print(attn.shape) 110 | # exit() 111 | attn = self.sm(attn) # Eq. 
(2) 112 | # --> batch x queryL x sourceL 113 | attn = attn.view(batch_size, queryL, sourceL) 114 | # --> batch x sourceL x queryL 115 | attn = torch.transpose(attn, 1, 2).contiguous() 116 | 117 | # (batch x idf x sourceL)(batch x sourceL x queryL) 118 | # --> batch x idf x queryL 119 | weightedContext = torch.bmm(sourceT, attn) 120 | weightedContext = weightedContext.view(batch_size, -1, ih, iw) 121 | attn = attn.view(batch_size, -1, ih, iw) 122 | 123 | return weightedContext, attn 124 | -------------------------------------------------------------------------------- /code/coco/attngan/cfg/coco_eval.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'attn2' 2 | 3 | DATASET_NAME: 'coco' 4 | DATA_DIR: '../../../data/coco/coco' 5 | IMG_DIR: "../../../data/MS-COCO/test/val2014" 6 | GPU_ID: '0' 7 | WORKERS: 1 8 | 9 | B_VALIDATION: True 10 | TREE: 11 | BRANCH_NUM: 3 12 | 13 | 14 | TRAIN: 15 | FLAG: False 16 | NET_G: '../../../models/model-ms-coco-attngan-0100.pth' 17 | B_NET_D: False 18 | BATCH_SIZE: 50 19 | NET_E: 'DAMSMencoders/coco/text_encoder100.pth' 20 | 21 | 22 | GAN: 23 | DF_DIM: 96 24 | GF_DIM: 48 25 | Z_DIM: 100 26 | R_NUM: 3 27 | 28 | TEXT: 29 | EMBEDDING_DIM: 256 30 | CAPTIONS_PER_IMAGE: 5 31 | WORDS_NUM: 20 32 | -------------------------------------------------------------------------------- /code/coco/attngan/cfg/coco_train.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'glu-gan2' 2 | 3 | DATASET_NAME: 'coco' 4 | DATA_DIR: '../../../data/coco/coco' 5 | IMG_DIR: "../../../data/MS-COCO/train/train2014" 6 | GPU_ID: '0,1,2' 7 | WORKERS: 20 8 | 9 | 10 | TREE: 11 | BRANCH_NUM: 3 12 | 13 | 14 | TRAIN: 15 | FLAG: True 16 | NET_G: '' # '../models/coco_AttnGAN2.pth' 17 | B_NET_D: True 18 | BATCH_SIZE: 14 # 32 19 | MAX_EPOCH: 120 20 | SNAPSHOT_INTERVAL: 5 21 | DISCRIMINATOR_LR: 0.0002 22 | GENERATOR_LR: 0.0002 23 | # 24 | NET_E: 'DAMSMencoders/coco/text_encoder100.pth' 25 | SMOOTH: 26 | GAMMA1: 4.0 # 1,2,5 good 4 best 10&100bad 27 | GAMMA2: 5.0 28 | GAMMA3: 10.0 # 10good 1&100bad 29 | LAMBDA: 50.0 30 | 31 | 32 | GAN: 33 | DF_DIM: 96 34 | GF_DIM: 48 35 | Z_DIM: 100 36 | R_NUM: 3 37 | 38 | TEXT: 39 | EMBEDDING_DIM: 256 40 | CAPTIONS_PER_IMAGE: 5 41 | WORDS_NUM: 12 42 | -------------------------------------------------------------------------------- /code/coco/attngan/datasets.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | from nltk.tokenize import RegexpTokenizer 7 | from collections import defaultdict 8 | from miscc.config import cfg 9 | 10 | import torch 11 | import torch.utils.data as data 12 | from torch.autograd import Variable 13 | import torchvision.transforms as transforms 14 | 15 | import os 16 | import sys 17 | import numpy as np 18 | from PIL import Image 19 | import numpy.random as random 20 | 21 | if sys.version_info[0] == 2: 22 | import cPickle as pickle 23 | else: 24 | import pickle 25 | 26 | from miscc.utils import * 27 | 28 | def prepare_data(data, eval=False): 29 | if eval: 30 | imgs, captions, captions_lens, class_ids, keys, transformation_matrices, label, bbox = data 31 | else: 32 | imgs, captions, captions_lens, class_ids, keys, transformation_matrices, label = data 33 | 34 | # sort data by the length in a decreasing order 35 | sorted_cap_lens, 
sorted_cap_indices = \ 36 | torch.sort(captions_lens, 0, True) 37 | 38 | real_imgs = [] 39 | for i in range(len(imgs)): 40 | imgs[i] = imgs[i][sorted_cap_indices] 41 | if cfg.CUDA: 42 | real_imgs.append(Variable(imgs[i]).cuda()) 43 | else: 44 | real_imgs.append(Variable(imgs[i])) 45 | 46 | captions = captions[sorted_cap_indices].squeeze() 47 | class_ids = class_ids[sorted_cap_indices].numpy() 48 | transformation_matrices[0] = transformation_matrices[0][sorted_cap_indices] 49 | transformation_matrices[1] = transformation_matrices[1][sorted_cap_indices] 50 | label = label[sorted_cap_indices] 51 | # sent_indices = sent_indices[sorted_cap_indices] 52 | keys = [keys[i] for i in sorted_cap_indices.numpy()] 53 | # print('keys', type(keys), keys[-1]) # list 54 | if cfg.CUDA: 55 | captions = Variable(captions).cuda() 56 | sorted_cap_lens = Variable(sorted_cap_lens).cuda() 57 | transformation_matrices[0] = transformation_matrices[0].cuda() 58 | transformation_matrices[1] = transformation_matrices[1].cuda() 59 | label = label.cuda() 60 | else: 61 | captions = Variable(captions) 62 | sorted_cap_lens = Variable(sorted_cap_lens) 63 | 64 | if eval: 65 | bbox = bbox[sorted_cap_indices] 66 | return [real_imgs, captions, sorted_cap_lens, class_ids, keys, transformation_matrices, label, bbox] 67 | else: 68 | return [real_imgs, captions, sorted_cap_lens, class_ids, keys, transformation_matrices, label] 69 | 70 | 71 | def get_imgs(img_path, imsize, bbox=None, 72 | transform=None, normalize=None): 73 | img = Image.open(img_path).convert('RGB') 74 | if transform is not None: 75 | img = transform(img) 76 | 77 | img, bbox_scaled = crop_imgs(img, bbox) 78 | 79 | ret = [] 80 | if cfg.GAN.B_DCGAN: 81 | ret = [normalize(img)] 82 | else: 83 | for i in range(cfg.TREE.BRANCH_NUM): 84 | # print(imsize[i]) 85 | if i < (cfg.TREE.BRANCH_NUM - 1): 86 | re_img = transforms.ToPILImage()(img) 87 | re_img = transforms.Resize((imsize[i], imsize[i]))(re_img) 88 | else: 89 | re_img = transforms.ToPILImage()(img) 90 | ret.append(normalize(re_img)) 91 | 92 | return ret, bbox_scaled 93 | 94 | 95 | def crop_imgs(image, bbox, max_objects=3): 96 | ori_size = 268 97 | imsize = 256 98 | 99 | flip_img = random.random() < 0.5 100 | img_crop = ori_size - imsize 101 | h1 = int(np.floor((img_crop) * np.random.random())) 102 | w1 = int(np.floor((img_crop) * np.random.random())) 103 | 104 | bbox_scaled = np.zeros_like(bbox) 105 | bbox_scaled[...] 
= -1.0 106 | 107 | for idx in range(max_objects): 108 | bbox_tmp = bbox[idx] 109 | if bbox_tmp[0] == -1: 110 | break 111 | 112 | x_new = max(bbox_tmp[0] * float(ori_size) - h1, 0) / float(imsize) 113 | y_new = max(bbox_tmp[1] * float(ori_size) - w1, 0) / float(imsize) 114 | 115 | width_new = min((float(ori_size)/imsize) * bbox_tmp[2], 1.0) 116 | if x_new + width_new > 0.999: 117 | width_new = 1.0 - x_new - 0.001 118 | 119 | height_new = min((float(ori_size)/imsize) * bbox_tmp[3], 1.0) 120 | if y_new + height_new > 0.999: 121 | height_new = 1.0 - y_new - 0.001 122 | 123 | if flip_img: 124 | x_new = 1.0-x_new-width_new 125 | 126 | bbox_scaled[idx] = [x_new, y_new, width_new, height_new] 127 | 128 | cropped_image = image[:, w1: w1 + imsize, h1: h1 + imsize] 129 | 130 | if flip_img: 131 | idx = [i for i in reversed(range(cropped_image.shape[2]))] 132 | idx = torch.LongTensor(idx) 133 | transformed_image = torch.index_select(cropped_image, 2, idx) 134 | else: 135 | transformed_image = cropped_image 136 | 137 | return transformed_image, bbox_scaled 138 | 139 | 140 | class TextDataset(data.Dataset): 141 | def __init__(self, data_dir, img_dir, split='train', base_size=64, 142 | transform=None, target_transform=None, eval=False): 143 | self.transform = transform 144 | self.norm = transforms.Compose([ 145 | transforms.ToTensor(), 146 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 147 | self.target_transform = target_transform 148 | self.embeddings_num = cfg.TEXT.CAPTIONS_PER_IMAGE 149 | self.img_dir = img_dir 150 | self.split_dir = os.path.join(data_dir, split) 151 | self.eval = eval 152 | 153 | self.imsize = [] 154 | for i in range(cfg.TREE.BRANCH_NUM): 155 | self.imsize.append(base_size) 156 | base_size = base_size * 2 157 | 158 | self.data = [] 159 | self.data_dir = data_dir 160 | self.bbox = self.load_bbox() 161 | self.labels = self.load_labels() 162 | self.split_dir = os.path.join(data_dir, split) 163 | self.max_objects = 3 164 | 165 | self.filenames, self.captions, self.ixtoword, \ 166 | self.wordtoix, self.n_words = self.load_text_data(data_dir, split) 167 | 168 | self.class_id = self.load_class_id(self.split_dir, len(self.filenames)) 169 | self.number_example = len(self.filenames) 170 | 171 | def load_bbox(self): 172 | bbox_path = os.path.join(self.split_dir, 'bboxes.pickle') 173 | with open(bbox_path, "rb") as f: 174 | bboxes = pickle.load(f) 175 | bboxes = np.array(bboxes) 176 | print("bboxes: ", bboxes.shape) 177 | return bboxes 178 | 179 | def load_labels(self): 180 | label_path = os.path.join(self.split_dir, 'labels.pickle') 181 | with open(label_path, "rb") as f: 182 | labels = pickle.load(f) 183 | labels = np.array(labels) 184 | print("labels: ", labels.shape) 185 | return labels 186 | 187 | def load_captions(self, data_dir, filenames): 188 | all_captions = [] 189 | for i in range(len(filenames)): 190 | cap_path = '%s/text/%s.txt' % (data_dir, filenames[i]) 191 | with open(cap_path, "r") as f: 192 | captions = f.read().decode('utf8').split('\n') 193 | cnt = 0 194 | for cap in captions: 195 | if len(cap) == 0: 196 | continue 197 | cap = cap.replace("\ufffd\ufffd", " ") 198 | # picks out sequences of alphanumeric characters as tokens 199 | # and drops everything else 200 | tokenizer = RegexpTokenizer(r'\w+') 201 | tokens = tokenizer.tokenize(cap.lower()) 202 | # print('tokens', tokens) 203 | if len(tokens) == 0: 204 | print('cap', cap) 205 | continue 206 | 207 | tokens_new = [] 208 | for t in tokens: 209 | t = t.encode('ascii', 'ignore').decode('ascii') 210 | if len(t) > 0: 
211 | tokens_new.append(t) 212 | all_captions.append(tokens_new) 213 | cnt += 1 214 | if cnt == self.embeddings_num: 215 | break 216 | if cnt < self.embeddings_num: 217 | print('ERROR: the captions for %s less than %d' 218 | % (filenames[i], cnt)) 219 | return all_captions 220 | 221 | def build_dictionary(self, train_captions, test_captions): 222 | word_counts = defaultdict(float) 223 | captions = train_captions + test_captions 224 | for sent in captions: 225 | for word in sent: 226 | word_counts[word] += 1 227 | 228 | vocab = [w for w in word_counts if word_counts[w] >= 0] 229 | 230 | ixtoword = {} 231 | ixtoword[0] = '' 232 | wordtoix = {} 233 | wordtoix[''] = 0 234 | ix = 1 235 | for w in vocab: 236 | wordtoix[w] = ix 237 | ixtoword[ix] = w 238 | ix += 1 239 | 240 | train_captions_new = [] 241 | for t in train_captions: 242 | rev = [] 243 | for w in t: 244 | if w in wordtoix: 245 | rev.append(wordtoix[w]) 246 | # rev.append(0) # do not need '' token 247 | train_captions_new.append(rev) 248 | 249 | test_captions_new = [] 250 | for t in test_captions: 251 | rev = [] 252 | for w in t: 253 | if w in wordtoix: 254 | rev.append(wordtoix[w]) 255 | # rev.append(0) # do not need '' token 256 | test_captions_new.append(rev) 257 | 258 | return [train_captions_new, test_captions_new, 259 | ixtoword, wordtoix, len(ixtoword)] 260 | 261 | def load_text_data(self, data_dir, split): 262 | filepath = os.path.join(data_dir, 'captions.pickle') 263 | train_names = self.load_filenames(data_dir, 'train') 264 | test_names = self.load_filenames(data_dir, 'test') 265 | if not os.path.isfile(filepath): 266 | train_captions = self.load_captions(data_dir, train_names) 267 | test_captions = self.load_captions(data_dir, test_names) 268 | 269 | train_captions, test_captions, ixtoword, wordtoix, n_words = \ 270 | self.build_dictionary(train_captions, test_captions) 271 | with open(filepath, 'wb') as f: 272 | pickle.dump([train_captions, test_captions, 273 | ixtoword, wordtoix], f, protocol=2) 274 | print('Save to: ', filepath) 275 | else: 276 | with open(filepath, 'rb') as f: 277 | x = pickle.load(f) 278 | train_captions, test_captions = x[0], x[1] 279 | ixtoword, wordtoix = x[2], x[3] 280 | del x 281 | n_words = len(ixtoword) 282 | print('Load from: ', filepath) 283 | if split == 'train': 284 | # a list of list: each list contains 285 | # the indices of words in a sentence 286 | captions = train_captions 287 | filenames = train_names 288 | else: # split=='test' 289 | captions = test_captions 290 | filenames = test_names 291 | return filenames, captions, ixtoword, wordtoix, n_words 292 | 293 | def load_class_id(self, data_dir, total_num): 294 | if os.path.isfile(data_dir + '/class_info.pickle'): 295 | with open(data_dir + '/class_info.pickle', 'rb') as f: 296 | class_id = pickle.load(f) 297 | else: 298 | class_id = np.arange(total_num) 299 | return class_id 300 | 301 | def load_filenames(self, data_dir, split): 302 | filepath = '%s/%s/filenames.pickle' % (data_dir, split) 303 | if os.path.isfile(filepath): 304 | with open(filepath, 'rb') as f: 305 | filenames = pickle.load(f) 306 | print('Load filenames from: %s (%d)' % (filepath, len(filenames))) 307 | else: 308 | filenames = [] 309 | return filenames 310 | 311 | def get_caption(self, sent_ix): 312 | # a list of indices for a sentence 313 | sent_caption = np.asarray(self.captions[sent_ix]).astype('int64') 314 | if (sent_caption == 0).sum() > 0: 315 | print('ERROR: do not need END (0) token', sent_caption) 316 | num_words = len(sent_caption) 317 | # pad with 0s (i.e., 
'') 318 | x = np.zeros((cfg.TEXT.WORDS_NUM, 1), dtype='int64') 319 | x_len = num_words 320 | if num_words <= cfg.TEXT.WORDS_NUM: 321 | x[:num_words, 0] = sent_caption 322 | else: 323 | ix = list(np.arange(num_words)) # 1, 2, 3,..., maxNum 324 | np.random.shuffle(ix) 325 | ix = ix[:cfg.TEXT.WORDS_NUM] 326 | ix = np.sort(ix) 327 | x[:, 0] = sent_caption[ix] 328 | x_len = cfg.TEXT.WORDS_NUM 329 | return x, x_len 330 | 331 | def get_transformation_matrices(self, bbox): 332 | bbox = torch.from_numpy(bbox) 333 | bbox = bbox.view(-1, 4) 334 | transf_matrices_inv = compute_transformation_matrix_inverse(bbox) 335 | transf_matrices_inv = transf_matrices_inv.view(self.max_objects, 2, 3) 336 | transf_matrices = compute_transformation_matrix(bbox) 337 | transf_matrices = transf_matrices.view(self.max_objects, 2, 3) 338 | 339 | return transf_matrices, transf_matrices_inv 340 | 341 | def get_one_hot_labels(self, label): 342 | labels = torch.from_numpy(label) 343 | labels = labels.long() 344 | # remove -1 to enable one-hot converting 345 | labels[labels < 0] = 80 346 | label_one_hot = torch.FloatTensor(labels.shape[0], 81).fill_(0) 347 | label_one_hot = label_one_hot.scatter_(1, labels, 1).float() 348 | 349 | return label_one_hot 350 | 351 | def __getitem__(self, index): 352 | # 353 | key = self.filenames[index] 354 | cls_id = self.class_id[index] 355 | # 356 | if self.bbox is not None: 357 | bbox = self.bbox[index] 358 | 359 | img_name = '%s/%s.jpg' % (self.img_dir, key) 360 | imgs, bbox_scaled = get_imgs(img_name, self.imsize, 361 | bbox, self.transform, normalize=self.norm) 362 | transformation_matrices = self.get_transformation_matrices(bbox_scaled) 363 | 364 | # load label 365 | label = self.labels[index] 366 | label = self.get_one_hot_labels(label) 367 | 368 | # randomly select a sentence 369 | sent_ix = random.randint(0, self.embeddings_num) 370 | new_sent_ix = index * self.embeddings_num + sent_ix 371 | caps, cap_len = self.get_caption(new_sent_ix) 372 | if self.eval: 373 | return imgs, caps, cap_len, cls_id, key, transformation_matrices, label, bbox_scaled 374 | return imgs, caps, cap_len, cls_id, key, transformation_matrices, label 375 | 376 | def __len__(self): 377 | return len(self.filenames) 378 | -------------------------------------------------------------------------------- /code/coco/attngan/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import numpy as np 5 | from miscc.config import cfg 6 | 7 | from GlobalAttention import func_attention 8 | 9 | 10 | # ##################Loss for matching text-image################### 11 | def cosine_similarity(x1, x2, dim=1, eps=1e-8): 12 | """Returns cosine similarity between x1 and x2, computed along dim. 
13 | """ 14 | w12 = torch.sum(x1 * x2, dim) 15 | w1 = torch.norm(x1, 2, dim) 16 | w2 = torch.norm(x2, 2, dim) 17 | return (w12 / (w1 * w2).clamp(min=eps)).squeeze() 18 | 19 | 20 | def sent_loss(cnn_code, rnn_code, labels, class_ids, 21 | batch_size, eps=1e-8): 22 | # ### Mask mis-match samples ### 23 | # that come from the same class as the real sample ### 24 | masks = [] 25 | if class_ids is not None: 26 | for i in range(batch_size): 27 | mask = (class_ids == class_ids[i]).astype(np.uint8) 28 | mask[i] = 0 29 | masks.append(mask.reshape((1, -1))) 30 | masks = np.concatenate(masks, 0) 31 | # masks: batch_size x batch_size 32 | masks = torch.ByteTensor(masks) 33 | if cfg.CUDA: 34 | masks = masks.cuda() 35 | 36 | # --> seq_len x batch_size x nef 37 | if cnn_code.dim() == 2: 38 | cnn_code = cnn_code.unsqueeze(0) 39 | rnn_code = rnn_code.unsqueeze(0) 40 | 41 | # cnn_code_norm / rnn_code_norm: seq_len x batch_size x 1 42 | cnn_code_norm = torch.norm(cnn_code, 2, dim=2, keepdim=True) 43 | rnn_code_norm = torch.norm(rnn_code, 2, dim=2, keepdim=True) 44 | # scores* / norm*: seq_len x batch_size x batch_size 45 | scores0 = torch.bmm(cnn_code, rnn_code.transpose(1, 2)) 46 | norm0 = torch.bmm(cnn_code_norm, rnn_code_norm.transpose(1, 2)) 47 | scores0 = scores0 / norm0.clamp(min=eps) * cfg.TRAIN.SMOOTH.GAMMA3 48 | 49 | # --> batch_size x batch_size 50 | scores0 = scores0.squeeze() 51 | if class_ids is not None: 52 | scores0.data.masked_fill_(masks, -float('inf')) 53 | scores1 = scores0.transpose(0, 1) 54 | if labels is not None: 55 | loss0 = nn.CrossEntropyLoss()(scores0, labels) 56 | loss1 = nn.CrossEntropyLoss()(scores1, labels) 57 | else: 58 | loss0, loss1 = None, None 59 | return loss0, loss1 60 | 61 | 62 | def words_loss(img_features, words_emb, labels, 63 | cap_lens, class_ids, batch_size): 64 | """ 65 | words_emb(query): batch x nef x seq_len 66 | img_features(context): batch x nef x 17 x 17 67 | """ 68 | masks = [] 69 | att_maps = [] 70 | similarities = [] 71 | cap_lens = cap_lens.data.tolist() 72 | for i in range(batch_size): 73 | if class_ids is not None: 74 | mask = (class_ids == class_ids[i]).astype(np.uint8) 75 | mask[i] = 0 76 | masks.append(mask.reshape((1, -1))) 77 | # Get the i-th text description 78 | words_num = cap_lens[i] 79 | # -> 1 x nef x words_num 80 | word = words_emb[i, :, :words_num].unsqueeze(0).contiguous() 81 | # -> batch_size x nef x words_num 82 | word = word.repeat(batch_size, 1, 1) 83 | # batch x nef x 17*17 84 | context = img_features 85 | """ 86 | word(query): batch x nef x words_num 87 | context: batch x nef x 17 x 17 88 | weiContext: batch x nef x words_num 89 | attn: batch x words_num x 17 x 17 90 | """ 91 | weiContext, attn = func_attention(word, context, cfg.TRAIN.SMOOTH.GAMMA1) 92 | att_maps.append(attn[i].unsqueeze(0).contiguous()) 93 | # --> batch_size x words_num x nef 94 | word = word.transpose(1, 2).contiguous() 95 | weiContext = weiContext.transpose(1, 2).contiguous() 96 | # --> batch_size*words_num x nef 97 | word = word.view(batch_size * words_num, -1) 98 | weiContext = weiContext.view(batch_size * words_num, -1) 99 | # 100 | # -->batch_size*words_num 101 | row_sim = cosine_similarity(word, weiContext) 102 | # --> batch_size x words_num 103 | row_sim = row_sim.view(batch_size, words_num) 104 | 105 | # Eq. 
(10) 106 | row_sim.mul_(cfg.TRAIN.SMOOTH.GAMMA2).exp_() 107 | row_sim = row_sim.sum(dim=1, keepdim=True) 108 | row_sim = torch.log(row_sim) 109 | 110 | # --> 1 x batch_size 111 | # similarities(i, j): the similarity between the i-th image and the j-th text description 112 | similarities.append(row_sim) 113 | 114 | # batch_size x batch_size 115 | similarities = torch.cat(similarities, 1) 116 | if class_ids is not None: 117 | masks = np.concatenate(masks, 0) 118 | # masks: batch_size x batch_size 119 | masks = torch.ByteTensor(masks) 120 | if cfg.CUDA: 121 | masks = masks.cuda() 122 | 123 | similarities = similarities * cfg.TRAIN.SMOOTH.GAMMA3 124 | if class_ids is not None: 125 | similarities.data.masked_fill_(masks, -float('inf')) 126 | similarities1 = similarities.transpose(0, 1) 127 | if labels is not None: 128 | loss0 = nn.CrossEntropyLoss()(similarities, labels) 129 | loss1 = nn.CrossEntropyLoss()(similarities1, labels) 130 | else: 131 | loss0, loss1 = None, None 132 | return loss0, loss1, att_maps 133 | 134 | 135 | # ##################Loss for G and Ds############################## 136 | def discriminator_loss(netD, real_imgs, fake_imgs, conditions, 137 | real_labels, fake_labels, gpus, local_labels=None, 138 | transf_matrices=None, transf_matrices_inv=None): 139 | # Forward 140 | # real_features = netD(real_imgs) 141 | # fake_features = netD(fake_imgs.detach()) 142 | if local_labels is not None: 143 | inputs = (real_imgs, local_labels, transf_matrices, transf_matrices_inv) 144 | else: 145 | inputs = (real_imgs) 146 | real_features = nn.parallel.data_parallel(netD, inputs, gpus) 147 | # real_features = netD(real_imgs, local_labels, transf_matrices, transf_matrices_inv) 148 | if local_labels is not None: 149 | inputs = (fake_imgs.detach(), local_labels, transf_matrices, transf_matrices_inv) 150 | else: 151 | inputs = (fake_imgs.detach()) 152 | fake_features = nn.parallel.data_parallel(netD, inputs, gpus) 153 | # fake_features = netD(fake_imgs.detach(), local_labels, transf_matrices, transf_matrices_inv) 154 | # loss 155 | # 156 | cond_real_logits = netD.COND_DNET(real_features, conditions) 157 | cond_real_errD = nn.BCELoss()(cond_real_logits, real_labels) 158 | cond_fake_logits = netD.COND_DNET(fake_features, conditions) 159 | cond_fake_errD = nn.BCELoss()(cond_fake_logits, fake_labels) 160 | # 161 | batch_size = real_features.size(0) 162 | cond_wrong_logits = netD.COND_DNET(real_features[:(batch_size - 1)], conditions[1:batch_size]) 163 | cond_wrong_errD = nn.BCELoss()(cond_wrong_logits, fake_labels[1:batch_size]) 164 | 165 | if netD.UNCOND_DNET is not None: 166 | real_logits = netD.UNCOND_DNET(real_features) 167 | fake_logits = netD.UNCOND_DNET(fake_features) 168 | real_errD = nn.BCELoss()(real_logits, real_labels) 169 | fake_errD = nn.BCELoss()(fake_logits, fake_labels) 170 | errD = ((real_errD + cond_real_errD) / 2. + 171 | (fake_errD + cond_fake_errD + cond_wrong_errD) / 3.) 172 | else: 173 | errD = cond_real_errD + (cond_fake_errD + cond_wrong_errD) / 2. 
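# Summary of the terms combined just above: cond_real_errD pushes (real image, matching
# caption) towards the real label, cond_fake_errD pushes (generated image, caption)
# towards the fake label, and cond_wrong_errD pushes (real image, mismatched caption
# taken from the next sample in the batch) towards the fake label. When an unconditional
# head (UNCOND_DNET) is present, the plain real/fake terms are averaged in as well.
# Illustrative numbers only (not measured values): with real_errD=0.6, cond_real_errD=0.7,
# fake_errD=0.5, cond_fake_errD=0.4, cond_wrong_errD=0.9 the first branch gives
# errD = (0.6 + 0.7)/2 + (0.5 + 0.4 + 0.9)/3 = 0.65 + 0.60 = 1.25.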
174 | return errD 175 | 176 | 177 | def generator_loss(netsD, image_encoder, fake_imgs, real_labels, 178 | words_embs, sent_emb, match_labels, 179 | cap_lens, class_ids, gpus, local_labels=None, 180 | transf_matrices=None, transf_matrices_inv=None): 181 | numDs = len(netsD) 182 | batch_size = real_labels.size(0) 183 | logs = '' 184 | # Forward 185 | errG_total = 0 186 | for i in range(numDs): 187 | # features = netsD[i](fake_imgs[i]) 188 | if i == 0: 189 | inputs = (fake_imgs[i], local_labels, transf_matrices, transf_matrices_inv) 190 | else: 191 | inputs = (fake_imgs[i]) 192 | # features = netsD[i](fake_imgs[i], local_labels, transf_matrices, transf_matrices_inv) 193 | features = nn.parallel.data_parallel(netsD[i], inputs, gpus) 194 | cond_logits = netsD[i].COND_DNET(features, sent_emb) 195 | cond_errG = nn.BCELoss()(cond_logits, real_labels) 196 | if netsD[i].UNCOND_DNET is not None: 197 | logits = netsD[i].UNCOND_DNET(features) 198 | errG = nn.BCELoss()(logits, real_labels) 199 | g_loss = errG + cond_errG 200 | else: 201 | g_loss = cond_errG 202 | errG_total += g_loss 203 | # err_img = errG_total.data[0] 204 | logs += 'g_loss%d: %.2f ' % (i, g_loss.item()) 205 | 206 | # Ranking loss 207 | if i == (numDs - 1): 208 | # words_features: batch_size x nef x 17 x 17 209 | # sent_code: batch_size x nef 210 | region_features, cnn_code = image_encoder(fake_imgs[i]) 211 | w_loss0, w_loss1, _ = words_loss(region_features, words_embs, 212 | match_labels, cap_lens, 213 | class_ids, batch_size) 214 | w_loss = (w_loss0 + w_loss1) * \ 215 | cfg.TRAIN.SMOOTH.LAMBDA 216 | # err_words = err_words + w_loss.data[0] 217 | 218 | s_loss0, s_loss1 = sent_loss(cnn_code, sent_emb, 219 | match_labels, class_ids, batch_size) 220 | s_loss = (s_loss0 + s_loss1) * \ 221 | cfg.TRAIN.SMOOTH.LAMBDA 222 | # err_sent = err_sent + s_loss.data[0] 223 | 224 | errG_total += w_loss + s_loss 225 | logs += 'w_loss: %.2f s_loss: %.2f ' % (w_loss.item(), s_loss.item()) 226 | return errG_total, logs 227 | 228 | 229 | ################################################################## 230 | def KL_loss(mu, logvar): 231 | # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) 232 | KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) 233 | KLD = torch.mean(KLD_element).mul_(-0.5) 234 | return KLD 235 | -------------------------------------------------------------------------------- /code/coco/attngan/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from miscc.config import cfg, cfg_from_file 4 | from datasets import TextDataset 5 | from trainer import condGANTrainer as trainer 6 | 7 | import os 8 | import sys 9 | import time 10 | import random 11 | import pprint 12 | import datetime 13 | import dateutil.tz 14 | import argparse 15 | import numpy as np 16 | from shutil import copyfile 17 | 18 | import torch 19 | import torchvision.transforms as transforms 20 | 21 | dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.'))) 22 | sys.path.append(dir_path) 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser(description='Train a AttnGAN network') 27 | parser.add_argument('--cfg', dest='cfg_file', 28 | help='optional config file', 29 | default='cfg/bird_attn2.yml', type=str) 30 | # parser.add_argument('--gpu', dest='gpu_id', type=int, default=-1) 31 | parser.add_argument('--gpu', dest='gpu_id', type=str, default='0') 32 | parser.add_argument('--resume', dest='resume', type=str, default='') 33 | 
parser.add_argument('--data_dir', dest='data_dir', type=str, default='') 34 | parser.add_argument('--manualSeed', type=int, help='manual seed') 35 | args = parser.parse_args() 36 | return args 37 | 38 | 39 | def gen_example(wordtoix, algo): 40 | '''generate images from example sentences''' 41 | from nltk.tokenize import RegexpTokenizer 42 | filepath = '%s/example_filenames.txt' % (cfg.DATA_DIR) 43 | data_dic = {} 44 | with open(filepath, "r") as f: 45 | filenames = f.read().decode('utf8').split('\n') 46 | for name in filenames: 47 | if len(name) == 0: 48 | continue 49 | filepath = '%s/%s.txt' % (cfg.DATA_DIR, name) 50 | with open(filepath, "r") as f: 51 | print('Load from:', name) 52 | sentences = f.read().decode('utf8').split('\n') 53 | # a list of indices for a sentence 54 | captions = [] 55 | cap_lens = [] 56 | for sent in sentences: 57 | if len(sent) == 0: 58 | continue 59 | sent = sent.replace("\ufffd\ufffd", " ") 60 | tokenizer = RegexpTokenizer(r'\w+') 61 | tokens = tokenizer.tokenize(sent.lower()) 62 | if len(tokens) == 0: 63 | print('sent', sent) 64 | continue 65 | 66 | rev = [] 67 | for t in tokens: 68 | t = t.encode('ascii', 'ignore').decode('ascii') 69 | if len(t) > 0 and t in wordtoix: 70 | rev.append(wordtoix[t]) 71 | captions.append(rev) 72 | cap_lens.append(len(rev)) 73 | max_len = np.max(cap_lens) 74 | 75 | sorted_indices = np.argsort(cap_lens)[::-1] 76 | cap_lens = np.asarray(cap_lens) 77 | cap_lens = cap_lens[sorted_indices] 78 | cap_array = np.zeros((len(captions), max_len), dtype='int64') 79 | for i in range(len(captions)): 80 | idx = sorted_indices[i] 81 | cap = captions[idx] 82 | c_len = len(cap) 83 | cap_array[i, :c_len] = cap 84 | key = name[(name.rfind('/') + 1):] 85 | data_dic[key] = [cap_array, cap_lens, sorted_indices] 86 | algo.gen_example(data_dic) 87 | 88 | 89 | if __name__ == "__main__": 90 | args = parse_args() 91 | if args.cfg_file is not None: 92 | cfg_from_file(args.cfg_file) 93 | if args.gpu_id != -1: 94 | cfg.GPU_ID = args.gpu_id 95 | else: 96 | cfg.CUDA = False 97 | if args.data_dir != '': 98 | cfg.DATA_DIR = args.data_dir 99 | print('Using config:') 100 | pprint.pprint(cfg) 101 | if args.manualSeed is None: 102 | args.manualSeed = random.randint(1, 10000) 103 | random.seed(args.manualSeed) 104 | np.random.seed(args.manualSeed) 105 | torch.manual_seed(args.manualSeed) 106 | if cfg.CUDA: 107 | torch.cuda.manual_seed_all(args.manualSeed) 108 | 109 | if args.resume == "": 110 | resume = False 111 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 112 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 113 | output_dir = '../../../output/%s_%s_%s' % \ 114 | (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp) 115 | else: 116 | assert os.path.isdir(args.resume) 117 | resume = True 118 | output_dir = args.resume 119 | 120 | split_dir, bshuffle = 'train', True 121 | eval = False 122 | if not cfg.TRAIN.FLAG: 123 | split_dir = 'test' 124 | eval = True 125 | 126 | # Get data loader 127 | imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM - 1)) 128 | image_transform = transforms.Compose([ 129 | transforms.Resize((268, 268)), 130 | transforms.ToTensor()]) 131 | dataset = TextDataset(cfg.DATA_DIR, cfg.IMG_DIR, split_dir, 132 | base_size=cfg.TREE.BASE_SIZE, 133 | transform=image_transform, eval=eval) 134 | assert dataset 135 | dataloader = torch.utils.data.DataLoader( 136 | dataset, batch_size=cfg.TRAIN.BATCH_SIZE, 137 | drop_last=True, shuffle=bshuffle, num_workers=int(cfg.WORKERS)) 138 | 139 | # Define models and go to train/evaluate 140 | algo = 
trainer(output_dir, dataloader, dataset.n_words, dataset.ixtoword, resume) 141 | 142 | start_t = time.time() 143 | if cfg.TRAIN.FLAG: 144 | if not resume: 145 | copyfile(sys.argv[0], output_dir + "/" + sys.argv[0]) 146 | copyfile("trainer.py", output_dir + "/" + "trainer.py") 147 | copyfile("model.py", output_dir + "/" + "model.py") 148 | copyfile("miscc/utils.py", output_dir + "/" + "utils.py") 149 | copyfile("miscc/losses.py", output_dir + "/" + "losses.py") 150 | copyfile("datasets.py", output_dir + "/" + "datasets.py") 151 | copyfile(args.cfg_file, output_dir + "/" + "cfg_file.yml") 152 | algo.train() 153 | else: 154 | '''generate images from pre-extracted embeddings''' 155 | if cfg.B_VALIDATION: 156 | # algo.visualize_bbox(split_dir, num_samples=25, draw_bbox=True) 157 | # algo.sampling(split_dir) # generate images for the whole valid dataset 158 | algo.sample(split_dir, num_samples=25, draw_bbox=True) 159 | else: 160 | gen_example(dataset.wordtoix, algo, num_samples=30000) # generate images for customized captions 161 | end_t = time.time() 162 | -------------------------------------------------------------------------------- /code/coco/attngan/miscc/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | -------------------------------------------------------------------------------- /code/coco/attngan/miscc/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import os.path as osp 5 | import numpy as np 6 | from easydict import EasyDict as edict 7 | 8 | 9 | __C = edict() 10 | cfg = __C 11 | 12 | # Dataset name: flowers, birds 13 | __C.DATASET_NAME = 'birds' 14 | __C.CONFIG_NAME = '' 15 | __C.DATA_DIR = '' 16 | __C.IMG_DIR = '' 17 | __C.GPU_ID = '0' 18 | __C.CUDA = True 19 | __C.WORKERS = 6 20 | 21 | __C.RNN_TYPE = 'LSTM' # 'GRU' 22 | __C.B_VALIDATION = False 23 | 24 | __C.TREE = edict() 25 | __C.TREE.BRANCH_NUM = 3 26 | __C.TREE.BASE_SIZE = 64 27 | 28 | 29 | # Training options 30 | __C.TRAIN = edict() 31 | __C.TRAIN.BATCH_SIZE = 64 32 | __C.TRAIN.MAX_EPOCH = 600 33 | __C.TRAIN.SNAPSHOT_INTERVAL = 2000 34 | __C.TRAIN.DISCRIMINATOR_LR = 2e-4 35 | __C.TRAIN.GENERATOR_LR = 2e-4 36 | __C.TRAIN.ENCODER_LR = 2e-4 37 | __C.TRAIN.RNN_GRAD_CLIP = 0.25 38 | __C.TRAIN.FLAG = True 39 | __C.TRAIN.NET_E = '' 40 | __C.TRAIN.NET_G = '' 41 | __C.TRAIN.B_NET_D = True 42 | 43 | __C.TRAIN.SMOOTH = edict() 44 | __C.TRAIN.SMOOTH.GAMMA1 = 5.0 45 | __C.TRAIN.SMOOTH.GAMMA3 = 10.0 46 | __C.TRAIN.SMOOTH.GAMMA2 = 5.0 47 | __C.TRAIN.SMOOTH.LAMBDA = 1.0 48 | 49 | 50 | # Modal options 51 | __C.GAN = edict() 52 | __C.GAN.DF_DIM = 64 53 | __C.GAN.GF_DIM = 128 54 | __C.GAN.Z_DIM = 100 55 | __C.GAN.CONDITION_DIM = 100 56 | __C.GAN.R_NUM = 2 57 | __C.GAN.B_ATTENTION = True 58 | __C.GAN.B_DCGAN = False 59 | 60 | 61 | __C.TEXT = edict() 62 | __C.TEXT.CAPTIONS_PER_IMAGE = 10 63 | __C.TEXT.EMBEDDING_DIM = 256 64 | __C.TEXT.WORDS_NUM = 18 65 | 66 | 67 | def _merge_a_into_b(a, b): 68 | """Merge config dictionary a into config dictionary b, clobbering the 69 | options in b whenever they are also specified in a. 
70 | """ 71 | if type(a) is not edict: 72 | return 73 | 74 | for k, v in a.iteritems(): 75 | # a must specify keys that are in b 76 | if not b.has_key(k): 77 | raise KeyError('{} is not a valid config key'.format(k)) 78 | 79 | # the types must match, too 80 | old_type = type(b[k]) 81 | if old_type is not type(v): 82 | if isinstance(b[k], np.ndarray): 83 | v = np.array(v, dtype=b[k].dtype) 84 | else: 85 | raise ValueError(('Type mismatch ({} vs. {}) ' 86 | 'for config key: {}').format(type(b[k]), 87 | type(v), k)) 88 | 89 | # recursively merge dicts 90 | if type(v) is edict: 91 | try: 92 | _merge_a_into_b(a[k], b[k]) 93 | except: 94 | print('Error under config key: {}'.format(k)) 95 | raise 96 | else: 97 | b[k] = v 98 | 99 | 100 | def cfg_from_file(filename): 101 | """Load a config file and merge it into the default options.""" 102 | import yaml 103 | with open(filename, 'r') as f: 104 | yaml_cfg = edict(yaml.load(f)) 105 | 106 | _merge_a_into_b(yaml_cfg, __C) 107 | -------------------------------------------------------------------------------- /code/coco/attngan/miscc/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import numpy as np 5 | from miscc.config import cfg 6 | 7 | from GlobalAttention import func_attention 8 | 9 | 10 | # ##################Loss for matching text-image################### 11 | def cosine_similarity(x1, x2, dim=1, eps=1e-8): 12 | """Returns cosine similarity between x1 and x2, computed along dim. 13 | """ 14 | w12 = torch.sum(x1 * x2, dim) 15 | w1 = torch.norm(x1, 2, dim) 16 | w2 = torch.norm(x2, 2, dim) 17 | return (w12 / (w1 * w2).clamp(min=eps)).squeeze() 18 | 19 | 20 | def sent_loss(cnn_code, rnn_code, labels, class_ids, 21 | batch_size, eps=1e-8): 22 | # ### Mask mis-match samples ### 23 | # that come from the same class as the real sample ### 24 | masks = [] 25 | if class_ids is not None: 26 | for i in range(batch_size): 27 | mask = (class_ids == class_ids[i]).astype(np.uint8) 28 | mask[i] = 0 29 | masks.append(mask.reshape((1, -1))) 30 | masks = np.concatenate(masks, 0) 31 | # masks: batch_size x batch_size 32 | masks = torch.ByteTensor(masks) 33 | if cfg.CUDA: 34 | masks = masks.cuda() 35 | 36 | # --> seq_len x batch_size x nef 37 | if cnn_code.dim() == 2: 38 | cnn_code = cnn_code.unsqueeze(0) 39 | rnn_code = rnn_code.unsqueeze(0) 40 | 41 | # cnn_code_norm / rnn_code_norm: seq_len x batch_size x 1 42 | cnn_code_norm = torch.norm(cnn_code, 2, dim=2, keepdim=True) 43 | rnn_code_norm = torch.norm(rnn_code, 2, dim=2, keepdim=True) 44 | # scores* / norm*: seq_len x batch_size x batch_size 45 | scores0 = torch.bmm(cnn_code, rnn_code.transpose(1, 2)) 46 | norm0 = torch.bmm(cnn_code_norm, rnn_code_norm.transpose(1, 2)) 47 | scores0 = scores0 / norm0.clamp(min=eps) * cfg.TRAIN.SMOOTH.GAMMA3 48 | 49 | # --> batch_size x batch_size 50 | scores0 = scores0.squeeze() 51 | if class_ids is not None: 52 | scores0.data.masked_fill_(masks, -float('inf')) 53 | scores1 = scores0.transpose(0, 1) 54 | if labels is not None: 55 | loss0 = nn.CrossEntropyLoss()(scores0, labels) 56 | loss1 = nn.CrossEntropyLoss()(scores1, labels) 57 | else: 58 | loss0, loss1 = None, None 59 | return loss0, loss1 60 | 61 | 62 | def words_loss(img_features, words_emb, labels, 63 | cap_lens, class_ids, batch_size): 64 | """ 65 | words_emb(query): batch x nef x seq_len 66 | img_features(context): batch x nef x 17 x 17 67 | """ 68 | masks = [] 69 | att_maps = [] 70 | similarities = [] 71 | cap_lens = 
cap_lens.data.tolist() 72 | for i in range(batch_size): 73 | if class_ids is not None: 74 | mask = (class_ids == class_ids[i]).astype(np.uint8) 75 | mask[i] = 0 76 | masks.append(mask.reshape((1, -1))) 77 | # Get the i-th text description 78 | words_num = cap_lens[i] 79 | # -> 1 x nef x words_num 80 | word = words_emb[i, :, :words_num].unsqueeze(0).contiguous() 81 | # -> batch_size x nef x words_num 82 | word = word.repeat(batch_size, 1, 1) 83 | # batch x nef x 17*17 84 | context = img_features 85 | """ 86 | word(query): batch x nef x words_num 87 | context: batch x nef x 17 x 17 88 | weiContext: batch x nef x words_num 89 | attn: batch x words_num x 17 x 17 90 | """ 91 | weiContext, attn = func_attention(word, context, cfg.TRAIN.SMOOTH.GAMMA1) 92 | att_maps.append(attn[i].unsqueeze(0).contiguous()) 93 | # --> batch_size x words_num x nef 94 | word = word.transpose(1, 2).contiguous() 95 | weiContext = weiContext.transpose(1, 2).contiguous() 96 | # --> batch_size*words_num x nef 97 | word = word.view(batch_size * words_num, -1) 98 | weiContext = weiContext.view(batch_size * words_num, -1) 99 | # 100 | # -->batch_size*words_num 101 | row_sim = cosine_similarity(word, weiContext) 102 | # --> batch_size x words_num 103 | row_sim = row_sim.view(batch_size, words_num) 104 | 105 | # Eq. (10) 106 | row_sim.mul_(cfg.TRAIN.SMOOTH.GAMMA2).exp_() 107 | row_sim = row_sim.sum(dim=1, keepdim=True) 108 | row_sim = torch.log(row_sim) 109 | 110 | # --> 1 x batch_size 111 | # similarities(i, j): the similarity between the i-th image and the j-th text description 112 | similarities.append(row_sim) 113 | 114 | # batch_size x batch_size 115 | similarities = torch.cat(similarities, 1) 116 | if class_ids is not None: 117 | masks = np.concatenate(masks, 0) 118 | # masks: batch_size x batch_size 119 | masks = torch.ByteTensor(masks) 120 | if cfg.CUDA: 121 | masks = masks.cuda() 122 | 123 | similarities = similarities * cfg.TRAIN.SMOOTH.GAMMA3 124 | if class_ids is not None: 125 | similarities.data.masked_fill_(masks, -float('inf')) 126 | similarities1 = similarities.transpose(0, 1) 127 | if labels is not None: 128 | loss0 = nn.CrossEntropyLoss()(similarities, labels) 129 | loss1 = nn.CrossEntropyLoss()(similarities1, labels) 130 | else: 131 | loss0, loss1 = None, None 132 | return loss0, loss1, att_maps 133 | 134 | 135 | # ##################Loss for G and Ds############################## 136 | def discriminator_loss(netD, real_imgs, fake_imgs, conditions, 137 | real_labels, fake_labels, gpus, local_labels=None, 138 | transf_matrices=None, transf_matrices_inv=None): 139 | # Forward 140 | # real_features = netD(real_imgs) 141 | # fake_features = netD(fake_imgs.detach()) 142 | if local_labels is not None: 143 | inputs = (real_imgs, local_labels, transf_matrices, transf_matrices_inv) 144 | else: 145 | inputs = (real_imgs) 146 | real_features = nn.parallel.data_parallel(netD, inputs, gpus) 147 | # real_features = netD(real_imgs, local_labels, transf_matrices, transf_matrices_inv) 148 | if local_labels is not None: 149 | inputs = (fake_imgs.detach(), local_labels, transf_matrices, transf_matrices_inv) 150 | else: 151 | inputs = (fake_imgs.detach()) 152 | fake_features = nn.parallel.data_parallel(netD, inputs, gpus) 153 | # fake_features = netD(fake_imgs.detach(), local_labels, transf_matrices, transf_matrices_inv) 154 | # loss 155 | # 156 | cond_real_logits = netD.COND_DNET(real_features, conditions) 157 | cond_real_errD = nn.BCELoss()(cond_real_logits, real_labels) 158 | cond_fake_logits = netD.COND_DNET(fake_features, 
conditions) 159 | cond_fake_errD = nn.BCELoss()(cond_fake_logits, fake_labels) 160 | # 161 | batch_size = real_features.size(0) 162 | cond_wrong_logits = netD.COND_DNET(real_features[:(batch_size - 1)], conditions[1:batch_size]) 163 | cond_wrong_errD = nn.BCELoss()(cond_wrong_logits, fake_labels[1:batch_size]) 164 | 165 | if netD.UNCOND_DNET is not None: 166 | real_logits = netD.UNCOND_DNET(real_features) 167 | fake_logits = netD.UNCOND_DNET(fake_features) 168 | real_errD = nn.BCELoss()(real_logits, real_labels) 169 | fake_errD = nn.BCELoss()(fake_logits, fake_labels) 170 | errD = ((real_errD + cond_real_errD) / 2. + 171 | (fake_errD + cond_fake_errD + cond_wrong_errD) / 3.) 172 | else: 173 | errD = cond_real_errD + (cond_fake_errD + cond_wrong_errD) / 2. 174 | return errD 175 | 176 | 177 | def generator_loss(netsD, image_encoder, fake_imgs, real_labels, 178 | words_embs, sent_emb, match_labels, 179 | cap_lens, class_ids, gpus, local_labels=None, 180 | transf_matrices=None, transf_matrices_inv=None): 181 | numDs = len(netsD) 182 | batch_size = real_labels.size(0) 183 | logs = '' 184 | # Forward 185 | errG_total = 0 186 | for i in range(numDs): 187 | # features = netsD[i](fake_imgs[i]) 188 | if i == 0: 189 | inputs = (fake_imgs[i], local_labels, transf_matrices, transf_matrices_inv) 190 | else: 191 | inputs = (fake_imgs[i]) 192 | # features = netsD[i](fake_imgs[i], local_labels, transf_matrices, transf_matrices_inv) 193 | features = nn.parallel.data_parallel(netsD[i], inputs, gpus) 194 | cond_logits = netsD[i].COND_DNET(features, sent_emb) 195 | cond_errG = nn.BCELoss()(cond_logits, real_labels) 196 | if netsD[i].UNCOND_DNET is not None: 197 | logits = netsD[i].UNCOND_DNET(features) 198 | errG = nn.BCELoss()(logits, real_labels) 199 | g_loss = errG + cond_errG 200 | else: 201 | g_loss = cond_errG 202 | errG_total += g_loss 203 | # err_img = errG_total.data[0] 204 | logs += 'g_loss%d: %.2f ' % (i, g_loss.item()) 205 | 206 | # Ranking loss 207 | if i == (numDs - 1): 208 | # words_features: batch_size x nef x 17 x 17 209 | # sent_code: batch_size x nef 210 | region_features, cnn_code = image_encoder(fake_imgs[i]) 211 | w_loss0, w_loss1, _ = words_loss(region_features, words_embs, 212 | match_labels, cap_lens, 213 | class_ids, batch_size) 214 | w_loss = (w_loss0 + w_loss1) * \ 215 | cfg.TRAIN.SMOOTH.LAMBDA 216 | # err_words = err_words + w_loss.data[0] 217 | 218 | s_loss0, s_loss1 = sent_loss(cnn_code, sent_emb, 219 | match_labels, class_ids, batch_size) 220 | s_loss = (s_loss0 + s_loss1) * \ 221 | cfg.TRAIN.SMOOTH.LAMBDA 222 | # err_sent = err_sent + s_loss.data[0] 223 | 224 | errG_total += w_loss + s_loss 225 | logs += 'w_loss: %.2f s_loss: %.2f ' % (w_loss.item(), s_loss.item()) 226 | return errG_total, logs 227 | 228 | 229 | ################################################################## 230 | def KL_loss(mu, logvar): 231 | # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) 232 | KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) 233 | KLD = torch.mean(KLD_element).mul_(-0.5) 234 | return KLD 235 | -------------------------------------------------------------------------------- /code/coco/attngan/miscc/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import numpy as np 4 | from torch.nn import init 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from PIL import Image, ImageDraw, ImageFont 10 | from copy import deepcopy 11 | import skimage.transform 12 | 13 | from miscc.config import cfg 
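# The two helpers below turn bounding boxes, given as (x, y, w, h) fractions of the
# image, into 2x3 affine matrices. Under the usual torch.nn.functional.affine_grid /
# grid_sample convention, compute_transformation_matrix() (scale = w, h) produces a
# grid that reads out just the box region (crop-and-resize to the full output size),
# while compute_transformation_matrix_inverse() (scale = 1/w, 1/h) produces a grid
# that pastes a unit-sized feature map back onto the box location, with everything
# outside the box sampled out of range. A minimal, hypothetical sketch of the crop
# direction (illustrative only, not code taken from this repository):
#
#   import torch.nn.functional as F
#   bbox  = torch.tensor([[0.25, 0.25, 0.5, 0.5]])   # one (x, y, w, h) box, hypothetical
#   theta = compute_transformation_matrix(bbox)       # -> shape (1, 2, 3)
#   grid  = F.affine_grid(theta, (1, 3, 64, 64))      # desired output size (N, C, H, W)
#   crop  = F.grid_sample(full_features, grid)        # full_features: (1, 3, H, W), hypothetical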
14 | 15 | 16 | def compute_transformation_matrix_inverse(bbox): 17 | x, y = bbox[:, 0], bbox[:, 1] 18 | w, h = bbox[:, 2], bbox[:, 3] 19 | 20 | scale_x = 1.0 / w 21 | scale_y = 1.0 / h 22 | 23 | t_x = 2 * scale_x * (0.5 - (x + 0.5 * w)) 24 | t_y = 2 * scale_y * (0.5 - (y + 0.5 * h)) 25 | 26 | zeros = torch.FloatTensor(bbox.shape[0],1).fill_(0) 27 | 28 | transformation_matrix = torch.cat([scale_x.unsqueeze(-1), zeros, t_x.unsqueeze(-1), 29 | zeros, scale_y.unsqueeze(-1), t_y.unsqueeze(-1)], 1).view(-1, 2, 3) 30 | 31 | return transformation_matrix 32 | 33 | 34 | def compute_transformation_matrix(bbox): 35 | x, y = bbox[:, 0], bbox[:, 1] 36 | w, h = bbox[:, 2], bbox[:, 3] 37 | 38 | scale_x = w 39 | scale_y = h 40 | 41 | t_x = 2 * ((x + 0.5 * w) - 0.5) 42 | t_y = 2 * ((y + 0.5 * h) - 0.5) 43 | 44 | zeros = torch.FloatTensor(bbox.shape[0],1).fill_(0) 45 | 46 | transformation_matrix = torch.cat([scale_x.unsqueeze(-1), zeros, t_x.unsqueeze(-1), 47 | zeros, scale_y.unsqueeze(-1), t_y.unsqueeze(-1)], 1).view(-1, 2, 3) 48 | 49 | return transformation_matrix 50 | 51 | # For visualization ################################################ 52 | COLOR_DIC = {0:[128,64,128], 1:[244, 35,232], 53 | 2:[70, 70, 70], 3:[102,102,156], 54 | 4:[190,153,153], 5:[153,153,153], 55 | 6:[250,170, 30], 7:[220, 220, 0], 56 | 8:[107,142, 35], 9:[152,251,152], 57 | 10:[70,130,180], 11:[220,20, 60], 58 | 12:[255, 0, 0], 13:[0, 0, 142], 59 | 14:[119,11, 32], 15:[0, 60,100], 60 | 16:[0, 80, 100], 17:[0, 0, 230], 61 | 18:[0, 0, 70], 19:[0, 0, 0]} 62 | FONT_MAX = 50 63 | 64 | 65 | def drawCaption(convas, captions, ixtoword, vis_size, off1=2, off2=2): 66 | num = captions.size(0) 67 | img_txt = Image.fromarray(convas) 68 | # get a font 69 | # fnt = None # ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50) 70 | fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50) 71 | # get a drawing context 72 | d = ImageDraw.Draw(img_txt) 73 | sentence_list = [] 74 | for i in range(num): 75 | cap = captions[i].data.cpu().numpy() 76 | sentence = [] 77 | for j in range(len(cap)): 78 | if cap[j] == 0: 79 | break 80 | word = ixtoword[cap[j]].encode('ascii', 'ignore').decode('ascii') 81 | d.text(((j + off1) * (vis_size + off2), i * FONT_MAX), '%d:%s' % (j, word[:6]), 82 | font=fnt, fill=(255, 255, 255, 255)) 83 | sentence.append(word) 84 | sentence_list.append(sentence) 85 | return img_txt, sentence_list 86 | 87 | 88 | def build_super_images(real_imgs, captions, ixtoword, 89 | attn_maps, att_sze, lr_imgs=None, 90 | batch_size=cfg.TRAIN.BATCH_SIZE, 91 | max_word_num=cfg.TEXT.WORDS_NUM): 92 | nvis = 8 93 | real_imgs = real_imgs[:nvis] 94 | if lr_imgs is not None: 95 | lr_imgs = lr_imgs[:nvis] 96 | if att_sze == 17: 97 | vis_size = att_sze * 16 98 | else: 99 | vis_size = real_imgs.size(2) 100 | 101 | text_convas = \ 102 | np.ones([batch_size * FONT_MAX, 103 | (max_word_num + 2) * (vis_size + 2), 3], 104 | dtype=np.uint8) 105 | 106 | for i in range(max_word_num): 107 | istart = (i + 2) * (vis_size + 2) 108 | iend = (i + 3) * (vis_size + 2) 109 | text_convas[:, istart:iend, :] = COLOR_DIC[i] 110 | 111 | 112 | real_imgs = \ 113 | nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs) 114 | # [-1, 1] --> [0, 1] 115 | real_imgs.add_(1).div_(2).mul_(255) 116 | real_imgs = real_imgs.data.numpy() 117 | # b x c x h x w --> b x h x w x c 118 | real_imgs = np.transpose(real_imgs, (0, 2, 3, 1)) 119 | pad_sze = real_imgs.shape 120 | middle_pad = np.zeros([pad_sze[2], 2, 3]) 121 | post_pad = np.zeros([pad_sze[1], pad_sze[2], 3]) 122 | if 
lr_imgs is not None: 123 | lr_imgs = \ 124 | nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(lr_imgs) 125 | # [-1, 1] --> [0, 1] 126 | lr_imgs.add_(1).div_(2).mul_(255) 127 | lr_imgs = lr_imgs.data.numpy() 128 | # b x c x h x w --> b x h x w x c 129 | lr_imgs = np.transpose(lr_imgs, (0, 2, 3, 1)) 130 | 131 | # batch x seq_len x 17 x 17 --> batch x 1 x 17 x 17 132 | seq_len = max_word_num 133 | img_set = [] 134 | num = nvis # len(attn_maps) 135 | 136 | text_map, sentences = \ 137 | drawCaption(text_convas, captions, ixtoword, vis_size) 138 | text_map = np.asarray(text_map).astype(np.uint8) 139 | 140 | bUpdate = 1 141 | for i in range(num): 142 | attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze) 143 | # --> 1 x 1 x 17 x 17 144 | attn_max = attn.max(dim=1, keepdim=True) 145 | attn = torch.cat([attn_max[0], attn], 1) 146 | # 147 | attn = attn.view(-1, 1, att_sze, att_sze) 148 | attn = attn.repeat(1, 3, 1, 1).data.numpy() 149 | # n x c x h x w --> n x h x w x c 150 | attn = np.transpose(attn, (0, 2, 3, 1)) 151 | num_attn = attn.shape[0] 152 | # 153 | img = real_imgs[i] 154 | if lr_imgs is None: 155 | lrI = img 156 | else: 157 | lrI = lr_imgs[i] 158 | row = [lrI, middle_pad] 159 | row_merge = [img, middle_pad] 160 | row_beforeNorm = [] 161 | minVglobal, maxVglobal = 1, 0 162 | for j in range(num_attn): 163 | one_map = attn[j] 164 | if (vis_size // att_sze) > 1: 165 | one_map = \ 166 | skimage.transform.pyramid_expand(one_map, sigma=20, 167 | upscale=vis_size // att_sze) 168 | row_beforeNorm.append(one_map) 169 | minV = one_map.min() 170 | maxV = one_map.max() 171 | if minVglobal > minV: 172 | minVglobal = minV 173 | if maxVglobal < maxV: 174 | maxVglobal = maxV 175 | for j in range(seq_len + 1): 176 | if j < num_attn: 177 | one_map = row_beforeNorm[j] 178 | one_map = (one_map - minVglobal) / (maxVglobal - minVglobal) 179 | one_map *= 255 180 | # 181 | PIL_im = Image.fromarray(np.uint8(img)) 182 | PIL_att = Image.fromarray(np.uint8(one_map)) 183 | merged = \ 184 | Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0)) 185 | mask = Image.new('L', (vis_size, vis_size), (210)) 186 | merged.paste(PIL_im, (0, 0)) 187 | merged.paste(PIL_att, (0, 0), mask) 188 | merged = np.array(merged)[:, :, :3] 189 | else: 190 | one_map = post_pad 191 | merged = post_pad 192 | row.append(one_map) 193 | row.append(middle_pad) 194 | # 195 | row_merge.append(merged) 196 | row_merge.append(middle_pad) 197 | row = np.concatenate(row, 1) 198 | row_merge = np.concatenate(row_merge, 1) 199 | txt = text_map[i * FONT_MAX: (i + 1) * FONT_MAX] 200 | if txt.shape[1] != row.shape[1]: 201 | print('txt', txt.shape, 'row', row.shape) 202 | bUpdate = 0 203 | break 204 | row = np.concatenate([txt, row, row_merge], 0) 205 | img_set.append(row) 206 | if bUpdate: 207 | img_set = np.concatenate(img_set, 0) 208 | img_set = img_set.astype(np.uint8) 209 | return img_set, sentences 210 | else: 211 | return None 212 | 213 | 214 | def build_super_images2(real_imgs, captions, cap_lens, ixtoword, 215 | attn_maps, att_sze, vis_size=256, topK=5): 216 | batch_size = real_imgs.size(0) 217 | max_word_num = np.max(cap_lens) 218 | text_convas = np.ones([batch_size * FONT_MAX, 219 | max_word_num * (vis_size + 2), 3], 220 | dtype=np.uint8) 221 | 222 | real_imgs = \ 223 | nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs) 224 | # [-1, 1] --> [0, 1] 225 | real_imgs.add_(1).div_(2).mul_(255) 226 | real_imgs = real_imgs.data.numpy() 227 | # b x c x h x w --> b x h x w x c 228 | real_imgs = np.transpose(real_imgs, (0, 2, 3, 1)) 229 
| pad_sze = real_imgs.shape 230 | middle_pad = np.zeros([pad_sze[2], 2, 3]) 231 | 232 | # batch x seq_len x 17 x 17 --> batch x 1 x 17 x 17 233 | img_set = [] 234 | num = len(attn_maps) 235 | 236 | text_map, sentences = \ 237 | drawCaption(text_convas, captions, ixtoword, vis_size, off1=0) 238 | text_map = np.asarray(text_map).astype(np.uint8) 239 | 240 | bUpdate = 1 241 | for i in range(num): 242 | attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze) 243 | # 244 | attn = attn.view(-1, 1, att_sze, att_sze) 245 | attn = attn.repeat(1, 3, 1, 1).data.numpy() 246 | # n x c x h x w --> n x h x w x c 247 | attn = np.transpose(attn, (0, 2, 3, 1)) 248 | num_attn = cap_lens[i] 249 | thresh = 2./float(num_attn) 250 | # 251 | img = real_imgs[i] 252 | row = [] 253 | row_merge = [] 254 | row_txt = [] 255 | row_beforeNorm = [] 256 | conf_score = [] 257 | for j in range(num_attn): 258 | one_map = attn[j] 259 | mask0 = one_map > (2. * thresh) 260 | conf_score.append(np.sum(one_map * mask0)) 261 | mask = one_map > thresh 262 | one_map = one_map * mask 263 | if (vis_size // att_sze) > 1: 264 | one_map = \ 265 | skimage.transform.pyramid_expand(one_map, sigma=20, 266 | upscale=vis_size // att_sze) 267 | minV = one_map.min() 268 | maxV = one_map.max() 269 | one_map = (one_map - minV) / (maxV - minV) 270 | row_beforeNorm.append(one_map) 271 | sorted_indices = np.argsort(conf_score)[::-1] 272 | 273 | for j in range(num_attn): 274 | one_map = row_beforeNorm[j] 275 | one_map *= 255 276 | # 277 | PIL_im = Image.fromarray(np.uint8(img)) 278 | PIL_att = Image.fromarray(np.uint8(one_map)) 279 | merged = \ 280 | Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0)) 281 | mask = Image.new('L', (vis_size, vis_size), (180)) # (210) 282 | merged.paste(PIL_im, (0, 0)) 283 | merged.paste(PIL_att, (0, 0), mask) 284 | merged = np.array(merged)[:, :, :3] 285 | 286 | row.append(np.concatenate([one_map, middle_pad], 1)) 287 | # 288 | row_merge.append(np.concatenate([merged, middle_pad], 1)) 289 | # 290 | txt = text_map[i * FONT_MAX:(i + 1) * FONT_MAX, 291 | j * (vis_size + 2):(j + 1) * (vis_size + 2), :] 292 | row_txt.append(txt) 293 | # reorder 294 | row_new = [] 295 | row_merge_new = [] 296 | txt_new = [] 297 | for j in range(num_attn): 298 | idx = sorted_indices[j] 299 | row_new.append(row[idx]) 300 | row_merge_new.append(row_merge[idx]) 301 | txt_new.append(row_txt[idx]) 302 | row = np.concatenate(row_new[:topK], 1) 303 | row_merge = np.concatenate(row_merge_new[:topK], 1) 304 | txt = np.concatenate(txt_new[:topK], 1) 305 | if txt.shape[1] != row.shape[1]: 306 | print('Warnings: txt', txt.shape, 'row', row.shape, 307 | 'row_merge_new', row_merge_new.shape) 308 | bUpdate = 0 309 | break 310 | row = np.concatenate([txt, row_merge], 0) 311 | img_set.append(row) 312 | if bUpdate: 313 | img_set = np.concatenate(img_set, 0) 314 | img_set = img_set.astype(np.uint8) 315 | return img_set, sentences 316 | else: 317 | return None 318 | 319 | 320 | #################################################################### 321 | def weights_init(m): 322 | classname = m.__class__.__name__ 323 | if classname.find('Conv') != -1: 324 | nn.init.orthogonal_(m.weight.data, 1.0) 325 | elif classname.find('BatchNorm') != -1: 326 | m.weight.data.normal_(1.0, 0.02) 327 | m.bias.data.fill_(0) 328 | elif classname.find('Linear') != -1: 329 | nn.init.orthogonal_(m.weight.data, 1.0) 330 | if m.bias is not None: 331 | m.bias.data.fill_(0.0) 332 | 333 | 334 | def load_params(model, new_param): 335 | for p, new_p in zip(model.parameters(), new_param): 336 | 
p.data.copy_(new_p) 337 | 338 | 339 | def copy_G_params(model): 340 | flatten = deepcopy(list(p.data for p in model.parameters())) 341 | return flatten 342 | 343 | 344 | def mkdir_p(path): 345 | try: 346 | os.makedirs(path) 347 | except OSError as exc: # Python >2.5 348 | if exc.errno == errno.EEXIST and os.path.isdir(path): 349 | pass 350 | else: 351 | raise 352 | -------------------------------------------------------------------------------- /code/coco/attngan/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import numpy as np 4 | from torch.nn import init 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from PIL import Image, ImageDraw, ImageFont 10 | from copy import deepcopy 11 | import skimage.transform 12 | 13 | from miscc.config import cfg 14 | 15 | 16 | def compute_transformation_matrix_inverse(bbox): 17 | x, y = bbox[:, 0], bbox[:, 1] 18 | w, h = bbox[:, 2], bbox[:, 3] 19 | 20 | scale_x = 1.0 / w 21 | scale_y = 1.0 / h 22 | 23 | t_x = 2 * scale_x * (0.5 - (x + 0.5 * w)) 24 | t_y = 2 * scale_y * (0.5 - (y + 0.5 * h)) 25 | 26 | zeros = torch.FloatTensor(bbox.shape[0],1).fill_(0) 27 | 28 | transformation_matrix = torch.cat([scale_x.unsqueeze(-1), zeros, t_x.unsqueeze(-1), 29 | zeros, scale_y.unsqueeze(-1), t_y.unsqueeze(-1)], 1).view(-1, 2, 3) 30 | 31 | return transformation_matrix 32 | 33 | 34 | def compute_transformation_matrix(bbox): 35 | x, y = bbox[:, 0], bbox[:, 1] 36 | w, h = bbox[:, 2], bbox[:, 3] 37 | 38 | scale_x = w 39 | scale_y = h 40 | 41 | t_x = 2 * ((x + 0.5 * w) - 0.5) 42 | t_y = 2 * ((y + 0.5 * h) - 0.5) 43 | 44 | zeros = torch.FloatTensor(bbox.shape[0],1).fill_(0) 45 | 46 | transformation_matrix = torch.cat([scale_x.unsqueeze(-1), zeros, t_x.unsqueeze(-1), 47 | zeros, scale_y.unsqueeze(-1), t_y.unsqueeze(-1)], 1).view(-1, 2, 3) 48 | 49 | return transformation_matrix 50 | 51 | # For visualization ################################################ 52 | COLOR_DIC = {0:[128,64,128], 1:[244, 35,232], 53 | 2:[70, 70, 70], 3:[102,102,156], 54 | 4:[190,153,153], 5:[153,153,153], 55 | 6:[250,170, 30], 7:[220, 220, 0], 56 | 8:[107,142, 35], 9:[152,251,152], 57 | 10:[70,130,180], 11:[220,20, 60], 58 | 12:[255, 0, 0], 13:[0, 0, 142], 59 | 14:[119,11, 32], 15:[0, 60,100], 60 | 16:[0, 80, 100], 17:[0, 0, 230], 61 | 18:[0, 0, 70], 19:[0, 0, 0]} 62 | FONT_MAX = 50 63 | 64 | 65 | def drawCaption(convas, captions, ixtoword, vis_size, off1=2, off2=2): 66 | num = captions.size(0) 67 | img_txt = Image.fromarray(convas) 68 | # get a font 69 | # fnt = None # ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50) 70 | fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50) 71 | # get a drawing context 72 | d = ImageDraw.Draw(img_txt) 73 | sentence_list = [] 74 | for i in range(num): 75 | cap = captions[i].data.cpu().numpy() 76 | sentence = [] 77 | for j in range(len(cap)): 78 | if cap[j] == 0: 79 | break 80 | word = ixtoword[cap[j]].encode('ascii', 'ignore').decode('ascii') 81 | d.text(((j + off1) * (vis_size + off2), i * FONT_MAX), '%d:%s' % (j, word[:6]), 82 | font=fnt, fill=(255, 255, 255, 255)) 83 | sentence.append(word) 84 | sentence_list.append(sentence) 85 | return img_txt, sentence_list 86 | 87 | 88 | def build_super_images(real_imgs, captions, ixtoword, 89 | attn_maps, att_sze, lr_imgs=None, 90 | batch_size=cfg.TRAIN.BATCH_SIZE, 91 | max_word_num=cfg.TEXT.WORDS_NUM): 92 | nvis = 8 93 | real_imgs = real_imgs[:nvis] 94 | if lr_imgs is not None: 95 | lr_imgs = lr_imgs[:nvis] 96 | if att_sze 
== 17: 97 | vis_size = att_sze * 16 98 | else: 99 | vis_size = real_imgs.size(2) 100 | 101 | text_convas = \ 102 | np.ones([batch_size * FONT_MAX, 103 | (max_word_num + 2) * (vis_size + 2), 3], 104 | dtype=np.uint8) 105 | 106 | for i in range(max_word_num): 107 | istart = (i + 2) * (vis_size + 2) 108 | iend = (i + 3) * (vis_size + 2) 109 | text_convas[:, istart:iend, :] = COLOR_DIC[i] 110 | 111 | 112 | real_imgs = \ 113 | nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs) 114 | # [-1, 1] --> [0, 1] 115 | real_imgs.add_(1).div_(2).mul_(255) 116 | real_imgs = real_imgs.data.numpy() 117 | # b x c x h x w --> b x h x w x c 118 | real_imgs = np.transpose(real_imgs, (0, 2, 3, 1)) 119 | pad_sze = real_imgs.shape 120 | middle_pad = np.zeros([pad_sze[2], 2, 3]) 121 | post_pad = np.zeros([pad_sze[1], pad_sze[2], 3]) 122 | if lr_imgs is not None: 123 | lr_imgs = \ 124 | nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(lr_imgs) 125 | # [-1, 1] --> [0, 1] 126 | lr_imgs.add_(1).div_(2).mul_(255) 127 | lr_imgs = lr_imgs.data.numpy() 128 | # b x c x h x w --> b x h x w x c 129 | lr_imgs = np.transpose(lr_imgs, (0, 2, 3, 1)) 130 | 131 | # batch x seq_len x 17 x 17 --> batch x 1 x 17 x 17 132 | seq_len = max_word_num 133 | img_set = [] 134 | num = nvis # len(attn_maps) 135 | 136 | text_map, sentences = \ 137 | drawCaption(text_convas, captions, ixtoword, vis_size) 138 | text_map = np.asarray(text_map).astype(np.uint8) 139 | 140 | bUpdate = 1 141 | for i in range(num): 142 | attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze) 143 | # --> 1 x 1 x 17 x 17 144 | attn_max = attn.max(dim=1, keepdim=True) 145 | attn = torch.cat([attn_max[0], attn], 1) 146 | # 147 | attn = attn.view(-1, 1, att_sze, att_sze) 148 | attn = attn.repeat(1, 3, 1, 1).data.numpy() 149 | # n x c x h x w --> n x h x w x c 150 | attn = np.transpose(attn, (0, 2, 3, 1)) 151 | num_attn = attn.shape[0] 152 | # 153 | img = real_imgs[i] 154 | if lr_imgs is None: 155 | lrI = img 156 | else: 157 | lrI = lr_imgs[i] 158 | row = [lrI, middle_pad] 159 | row_merge = [img, middle_pad] 160 | row_beforeNorm = [] 161 | minVglobal, maxVglobal = 1, 0 162 | for j in range(num_attn): 163 | one_map = attn[j] 164 | if (vis_size // att_sze) > 1: 165 | one_map = \ 166 | skimage.transform.pyramid_expand(one_map, sigma=20, 167 | upscale=vis_size // att_sze) 168 | row_beforeNorm.append(one_map) 169 | minV = one_map.min() 170 | maxV = one_map.max() 171 | if minVglobal > minV: 172 | minVglobal = minV 173 | if maxVglobal < maxV: 174 | maxVglobal = maxV 175 | for j in range(seq_len + 1): 176 | if j < num_attn: 177 | one_map = row_beforeNorm[j] 178 | one_map = (one_map - minVglobal) / (maxVglobal - minVglobal) 179 | one_map *= 255 180 | # 181 | PIL_im = Image.fromarray(np.uint8(img)) 182 | PIL_att = Image.fromarray(np.uint8(one_map)) 183 | merged = \ 184 | Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0)) 185 | mask = Image.new('L', (vis_size, vis_size), (210)) 186 | merged.paste(PIL_im, (0, 0)) 187 | merged.paste(PIL_att, (0, 0), mask) 188 | merged = np.array(merged)[:, :, :3] 189 | else: 190 | one_map = post_pad 191 | merged = post_pad 192 | row.append(one_map) 193 | row.append(middle_pad) 194 | # 195 | row_merge.append(merged) 196 | row_merge.append(middle_pad) 197 | row = np.concatenate(row, 1) 198 | row_merge = np.concatenate(row_merge, 1) 199 | txt = text_map[i * FONT_MAX: (i + 1) * FONT_MAX] 200 | if txt.shape[1] != row.shape[1]: 201 | print('txt', txt.shape, 'row', row.shape) 202 | bUpdate = 0 203 | break 204 | row = 
np.concatenate([txt, row, row_merge], 0) 205 | img_set.append(row) 206 | if bUpdate: 207 | img_set = np.concatenate(img_set, 0) 208 | img_set = img_set.astype(np.uint8) 209 | return img_set, sentences 210 | else: 211 | return None 212 | 213 | 214 | def build_super_images2(real_imgs, captions, cap_lens, ixtoword, 215 | attn_maps, att_sze, vis_size=256, topK=5): 216 | batch_size = real_imgs.size(0) 217 | max_word_num = np.max(cap_lens) 218 | text_convas = np.ones([batch_size * FONT_MAX, 219 | max_word_num * (vis_size + 2), 3], 220 | dtype=np.uint8) 221 | 222 | real_imgs = \ 223 | nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs) 224 | # [-1, 1] --> [0, 1] 225 | real_imgs.add_(1).div_(2).mul_(255) 226 | real_imgs = real_imgs.data.numpy() 227 | # b x c x h x w --> b x h x w x c 228 | real_imgs = np.transpose(real_imgs, (0, 2, 3, 1)) 229 | pad_sze = real_imgs.shape 230 | middle_pad = np.zeros([pad_sze[2], 2, 3]) 231 | 232 | # batch x seq_len x 17 x 17 --> batch x 1 x 17 x 17 233 | img_set = [] 234 | num = len(attn_maps) 235 | 236 | text_map, sentences = \ 237 | drawCaption(text_convas, captions, ixtoword, vis_size, off1=0) 238 | text_map = np.asarray(text_map).astype(np.uint8) 239 | 240 | bUpdate = 1 241 | for i in range(num): 242 | attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze) 243 | # 244 | attn = attn.view(-1, 1, att_sze, att_sze) 245 | attn = attn.repeat(1, 3, 1, 1).data.numpy() 246 | # n x c x h x w --> n x h x w x c 247 | attn = np.transpose(attn, (0, 2, 3, 1)) 248 | num_attn = cap_lens[i] 249 | thresh = 2./float(num_attn) 250 | # 251 | img = real_imgs[i] 252 | row = [] 253 | row_merge = [] 254 | row_txt = [] 255 | row_beforeNorm = [] 256 | conf_score = [] 257 | for j in range(num_attn): 258 | one_map = attn[j] 259 | mask0 = one_map > (2. 
* thresh) 260 | conf_score.append(np.sum(one_map * mask0)) 261 | mask = one_map > thresh 262 | one_map = one_map * mask 263 | if (vis_size // att_sze) > 1: 264 | one_map = \ 265 | skimage.transform.pyramid_expand(one_map, sigma=20, 266 | upscale=vis_size // att_sze) 267 | minV = one_map.min() 268 | maxV = one_map.max() 269 | one_map = (one_map - minV) / (maxV - minV) 270 | row_beforeNorm.append(one_map) 271 | sorted_indices = np.argsort(conf_score)[::-1] 272 | 273 | for j in range(num_attn): 274 | one_map = row_beforeNorm[j] 275 | one_map *= 255 276 | # 277 | PIL_im = Image.fromarray(np.uint8(img)) 278 | PIL_att = Image.fromarray(np.uint8(one_map)) 279 | merged = \ 280 | Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0)) 281 | mask = Image.new('L', (vis_size, vis_size), (180)) # (210) 282 | merged.paste(PIL_im, (0, 0)) 283 | merged.paste(PIL_att, (0, 0), mask) 284 | merged = np.array(merged)[:, :, :3] 285 | 286 | row.append(np.concatenate([one_map, middle_pad], 1)) 287 | # 288 | row_merge.append(np.concatenate([merged, middle_pad], 1)) 289 | # 290 | txt = text_map[i * FONT_MAX:(i + 1) * FONT_MAX, 291 | j * (vis_size + 2):(j + 1) * (vis_size + 2), :] 292 | row_txt.append(txt) 293 | # reorder 294 | row_new = [] 295 | row_merge_new = [] 296 | txt_new = [] 297 | for j in range(num_attn): 298 | idx = sorted_indices[j] 299 | row_new.append(row[idx]) 300 | row_merge_new.append(row_merge[idx]) 301 | txt_new.append(row_txt[idx]) 302 | row = np.concatenate(row_new[:topK], 1) 303 | row_merge = np.concatenate(row_merge_new[:topK], 1) 304 | txt = np.concatenate(txt_new[:topK], 1) 305 | if txt.shape[1] != row.shape[1]: 306 | print('Warnings: txt', txt.shape, 'row', row.shape, 307 | 'row_merge_new', row_merge_new.shape) 308 | bUpdate = 0 309 | break 310 | row = np.concatenate([txt, row_merge], 0) 311 | img_set.append(row) 312 | if bUpdate: 313 | img_set = np.concatenate(img_set, 0) 314 | img_set = img_set.astype(np.uint8) 315 | return img_set, sentences 316 | else: 317 | return None 318 | 319 | 320 | #################################################################### 321 | def weights_init(m): 322 | classname = m.__class__.__name__ 323 | if classname.find('Conv') != -1: 324 | nn.init.orthogonal_(m.weight.data, 1.0) 325 | elif classname.find('BatchNorm') != -1: 326 | m.weight.data.normal_(1.0, 0.02) 327 | m.bias.data.fill_(0) 328 | elif classname.find('Linear') != -1: 329 | nn.init.orthogonal_(m.weight.data, 1.0) 330 | if m.bias is not None: 331 | m.bias.data.fill_(0.0) 332 | 333 | 334 | def load_params(model, new_param): 335 | for p, new_p in zip(model.parameters(), new_param): 336 | p.data.copy_(new_p) 337 | 338 | 339 | def copy_G_params(model): 340 | flatten = deepcopy(list(p.data for p in model.parameters())) 341 | return flatten 342 | 343 | 344 | def mkdir_p(path): 345 | try: 346 | os.makedirs(path) 347 | except OSError as exc: # Python >2.5 348 | if exc.errno == errno.EEXIST and os.path.isdir(path): 349 | pass 350 | else: 351 | raise 352 | -------------------------------------------------------------------------------- /code/coco/stackgan/cfg/coco_s1_eval.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'stageI' 2 | 3 | DATASET_NAME: 'coco' 4 | EMBEDDING_TYPE: 'cnn-rnn' 5 | GPU_ID: '0' 6 | Z_DIM: 100 7 | NET_G: '../../../models/pretrained_model.pth' 8 | IMG_DIR: "../../../data/MS-COCO/test/val2014" 9 | DATA_DIR: '../../../data/MS-COCO' 10 | WORKERS: 4 11 | IMSIZE: 256 12 | STAGE: 2 13 | USE_BBOX_LAYOUT: True 14 | TRAIN: 15 | FLAG: False 16 | 
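  # FLAG: False makes main.py skip training and call GANTrainer.sample() on the test images instead;
  # NET_G above names the pretrained generator checkpoint used for that sampling run.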
BATCH_SIZE: 100 17 | 18 | GAN: 19 | CONDITION_DIM: 128 20 | DF_DIM: 96 21 | GF_DIM: 192 22 | R_NUM: 2 23 | 24 | TEXT: 25 | DIMENSION: 1024 26 | -------------------------------------------------------------------------------- /code/coco/stackgan/cfg/coco_s1_train.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'stageI' 2 | 3 | DATASET_NAME: 'coco' 4 | EMBEDDING_TYPE: 'cnn-rnn' 5 | GPU_ID: '0,1' 6 | Z_DIM: 100 7 | IMG_DIR: "../../../data/MS-COCO/train/train2014" 8 | DATA_DIR: '../../../data/MS-COCO' 9 | IMSIZE: 64 10 | WORKERS: 12 11 | STAGE: 1 12 | USE_BBOX_LAYOUT: True 13 | TRAIN: 14 | FLAG: True 15 | BATCH_SIZE: 128 16 | MAX_EPOCH: 120 17 | LR_DECAY_EPOCH: 20 18 | SNAPSHOT_INTERVAL: 10 19 | DISCRIMINATOR_LR: 0.0002 20 | GENERATOR_LR: 0.0002 21 | COEFF: 22 | KL: 2.0 23 | 24 | GAN: 25 | CONDITION_DIM: 128 26 | DF_DIM: 96 27 | GF_DIM: 192 28 | 29 | TEXT: 30 | DIMENSION: 1024 31 | -------------------------------------------------------------------------------- /code/coco/stackgan/cfg/coco_s2_eval.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'stageI' 2 | 3 | DATASET_NAME: 'coco' 4 | EMBEDDING_TYPE: 'cnn-rnn' 5 | GPU_ID: '0' 6 | Z_DIM: 100 7 | NET_G: '../../../models/model-coco-stackgan-stage-ii-0110.pth' 8 | IMG_DIR: "../../../data/MS-COCO/test/val2014" 9 | DATA_DIR: '../../../data/MS-COCO' 10 | WORKERS: 4 11 | IMSIZE: 256 12 | STAGE: 2 13 | USE_BBOX_LAYOUT: True 14 | TRAIN: 15 | FLAG: False 16 | BATCH_SIZE: 100 17 | 18 | GAN: 19 | CONDITION_DIM: 128 20 | DF_DIM: 96 21 | GF_DIM: 192 22 | R_NUM: 2 23 | 24 | TEXT: 25 | DIMENSION: 1024 26 | -------------------------------------------------------------------------------- /code/coco/stackgan/cfg/coco_s2_train.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'stageII' 2 | 3 | DATASET_NAME: 'coco' 4 | EMBEDDING_TYPE: 'cnn-rnn' 5 | GPU_ID: '0,1' 6 | Z_DIM: 100 7 | STAGE1_G: '../../../data/models/pretrained_model.pth' 8 | IMG_DIR: "../../../data/MS-COCO/train/train2014" 9 | DATA_DIR: '../../../data/MS-COCO' 10 | IMSIZE: 256 11 | WORKERS: 16 12 | STAGE: 2 13 | USE_BBOX_LAYOUT: True 14 | TRAIN: 15 | FLAG: True 16 | BATCH_SIZE: 40 17 | MAX_EPOCH: 120 18 | LR_DECAY_EPOCH: 20 19 | SNAPSHOT_INTERVAL: 10 20 | DISCRIMINATOR_LR: 0.0002 21 | GENERATOR_LR: 0.0002 22 | COEFF: 23 | KL: 2.0 24 | 25 | GAN: 26 | CONDITION_DIM: 128 27 | DF_DIM: 96 28 | GF_DIM: 192 29 | R_NUM: 2 30 | 31 | TEXT: 32 | DIMENSION: 1024 33 | -------------------------------------------------------------------------------- /code/coco/stackgan/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.backends.cudnn as cudnn 3 | import torch 4 | import torchvision.transforms as transforms 5 | import torch.backends.cudnn as cudnn 6 | 7 | import argparse 8 | import os 9 | import random 10 | import sys 11 | import pprint 12 | import datetime 13 | import dateutil 14 | import dateutil.tz 15 | from shutil import copyfile 16 | 17 | 18 | dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.'))) 19 | sys.path.append(dir_path) 20 | 21 | from miscc.datasets import TextDataset 22 | from miscc.config import cfg, cfg_from_file 23 | from miscc.utils import mkdir_p 24 | from trainer import GANTrainer 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser(description='Train a GAN network') 29 | parser.add_argument('--cfg', 
dest='cfg_file', 30 | help='optional config file', 31 | default='birds_stage1.yml', type=str) 32 | parser.add_argument('--gpu', dest='gpu_id', type=str, default='0') 33 | parser.add_argument('--data_dir', dest='data_dir', type=str, default='') 34 | parser.add_argument('--manualSeed', type=int, help='manual seed') 35 | args = parser.parse_args() 36 | return args 37 | 38 | if __name__ == "__main__": 39 | args = parse_args() 40 | if args.cfg_file is not None: 41 | cfg_from_file(args.cfg_file) 42 | if args.gpu_id != -1: 43 | cfg.GPU_ID = args.gpu_id 44 | if args.data_dir != '': 45 | cfg.DATA_DIR = args.data_dir 46 | print('Using config:') 47 | pprint.pprint(cfg) 48 | if args.manualSeed is None: 49 | args.manualSeed = random.randint(1, 10000) 50 | random.seed(args.manualSeed) 51 | torch.manual_seed(args.manualSeed) 52 | if cfg.CUDA: 53 | torch.cuda.manual_seed_all(args.manualSeed) 54 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 55 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 56 | output_dir = '../../..//output/%s_%s_%s' % \ 57 | (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp) 58 | 59 | cudnn.benchmark = True 60 | 61 | num_gpu = len(cfg.GPU_ID.split(',')) 62 | if cfg.TRAIN.FLAG: 63 | try: 64 | os.makedirs(output_dir) 65 | except OSError as exc: # Python >2.5 66 | if exc.errno == errno.EEXIST and os.path.isdir(path): 67 | pass 68 | else: 69 | raise 70 | 71 | copyfile(sys.argv[0], output_dir + "/" + sys.argv[0]) 72 | copyfile("trainer.py", output_dir + "/" + "trainer.py") 73 | copyfile("model.py", output_dir + "/" + "model.py") 74 | copyfile("miscc/utils.py", output_dir + "/" + "utils.py") 75 | copyfile("miscc/datasets.py", output_dir + "/" + "datasets.py") 76 | copyfile(args.cfg_file, output_dir + "/" + "cfg_file.yml") 77 | 78 | if cfg.STAGE == 1: 79 | resize = 76 80 | imsize=64 81 | elif cfg.STAGE == 2: 82 | resize = 268 83 | imsize = 256 84 | 85 | img_transform = transforms.Compose([ 86 | transforms.Resize((resize, resize)), 87 | transforms.ToTensor(), 88 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 89 | dataset = TextDataset(cfg.DATA_DIR, cfg.IMG_DIR, split="train", imsize=imsize, transform=img_transform, 90 | crop=True, stage=cfg.STAGE) 91 | assert dataset 92 | dataloader = torch.utils.data.DataLoader( 93 | dataset, batch_size=cfg.TRAIN.BATCH_SIZE, 94 | drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS)) 95 | 96 | algo = GANTrainer(output_dir) 97 | algo.train(dataloader, cfg.STAGE) 98 | else: 99 | datapath= '%s/test/' % (cfg.DATA_DIR) 100 | algo = GANTrainer(output_dir) 101 | algo.sample(datapath, num_samples=25, stage=cfg.STAGE, draw_bbox=True) 102 | -------------------------------------------------------------------------------- /code/coco/stackgan/miscc/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | -------------------------------------------------------------------------------- /code/coco/stackgan/miscc/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import os.path as osp 5 | import numpy as np 6 | from easydict import EasyDict as edict 7 | 8 | 9 | __C = edict() 10 | cfg = __C 11 | 12 | # Dataset name: flowers, birds 13 | __C.DATASET_NAME = 'coco' 14 | __C.EMBEDDING_TYPE = 'cnn-rnn' 15 | __C.CONFIG_NAME = '' 16 | __C.GPU_ID = '0' 17 | __C.CUDA = True 18 | __C.WORKERS = 6 19 | 20 | __C.NET_G = '' 21 | 
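# Optional checkpoint paths: the eval configs fill in NET_G with a trained generator, and
# coco_s2_train.yml sets STAGE1_G to the Stage-I generator that is reused when training Stage-II.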
__C.NET_D = '' 22 | __C.STAGE1_G = '' 23 | __C.DATA_DIR = '' 24 | __C.IMG_DIR = '' 25 | __C.VIS_COUNT = 64 26 | 27 | __C.Z_DIM = 100 28 | __C.IMSIZE = 64 29 | __C.STAGE = 1 30 | 31 | __C.USE_LOCAL_PATHWAY = True 32 | __C.USE_BBOX_LAYOUT = True 33 | 34 | # Training options 35 | __C.TRAIN = edict() 36 | __C.TRAIN.FLAG = True 37 | __C.TRAIN.BATCH_SIZE = 64 38 | __C.TRAIN.MAX_EPOCH = 600 39 | __C.TRAIN.SNAPSHOT_INTERVAL = 50 40 | __C.TRAIN.PRETRAINED_MODEL = '' 41 | __C.TRAIN.PRETRAINED_EPOCH = 600 42 | __C.TRAIN.LR_DECAY_EPOCH = 600 43 | __C.TRAIN.DISCRIMINATOR_LR = 2e-4 44 | __C.TRAIN.GENERATOR_LR = 2e-4 45 | 46 | __C.TRAIN.COEFF = edict() 47 | __C.TRAIN.COEFF.KL = 2.0 48 | 49 | # Modal options 50 | __C.GAN = edict() 51 | __C.GAN.CONDITION_DIM = 128 52 | __C.GAN.DF_DIM = 64 53 | __C.GAN.GF_DIM = 128 54 | __C.GAN.R_NUM = 4 55 | 56 | __C.TEXT = edict() 57 | __C.TEXT.DIMENSION = 1024 58 | 59 | 60 | def _merge_a_into_b(a, b): 61 | """Merge config dictionary a into config dictionary b, clobbering the 62 | options in b whenever they are also specified in a. 63 | """ 64 | if type(a) is not edict: 65 | return 66 | 67 | for k, v in a.iteritems(): 68 | # a must specify keys that are in b 69 | if not b.has_key(k): 70 | raise KeyError('{} is not a valid config key'.format(k)) 71 | 72 | # the types must match, too 73 | old_type = type(b[k]) 74 | if old_type is not type(v): 75 | if isinstance(b[k], np.ndarray): 76 | v = np.array(v, dtype=b[k].dtype) 77 | else: 78 | raise ValueError(('Type mismatch ({} vs. {}) ' 79 | 'for config key: {}').format(type(b[k]), 80 | type(v), k)) 81 | 82 | # recursively merge dicts 83 | if type(v) is edict: 84 | try: 85 | _merge_a_into_b(a[k], b[k]) 86 | except: 87 | print('Error under config key: {}'.format(k)) 88 | raise 89 | else: 90 | b[k] = v 91 | 92 | 93 | def cfg_from_file(filename): 94 | """Load a config file and merge it into the default options.""" 95 | import yaml 96 | with open(filename, 'r') as f: 97 | yaml_cfg = edict(yaml.load(f)) 98 | 99 | _merge_a_into_b(yaml_cfg, __C) 100 | -------------------------------------------------------------------------------- /code/coco/stackgan/miscc/datasets.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | 7 | import torch.utils.data as data 8 | import PIL 9 | import os 10 | import os.path 11 | import random 12 | import numpy as np 13 | from PIL import Image 14 | import torchvision.transforms as transforms 15 | import torch 16 | import sys 17 | if sys.version_info[0] == 2: 18 | import cPickle as pickle 19 | else: 20 | import pickle 21 | 22 | from miscc.config import cfg 23 | 24 | 25 | class TextDataset(data.Dataset): 26 | def __init__(self, data_dir, img_dir, imsize, split='train', embedding_type='cnn-rnn', transform=None, crop=True, stage=1): 27 | 28 | self.transform = transform 29 | self.imsize = imsize 30 | self.crop = crop 31 | self.data = [] 32 | self.data_dir = data_dir 33 | self.split_dir = os.path.join(data_dir, split) 34 | self.img_dir = img_dir 35 | self.max_objects = 3 36 | self.stage = stage 37 | 38 | self.filenames = self.load_filenames() 39 | self.bboxes = self.load_bboxes() 40 | self.labels = self.load_labels() 41 | self.embeddings = self.load_embedding(self.split_dir, embedding_type) 42 | 43 | def get_img(self, img_path): 44 | img = Image.open(img_path).convert('RGB') 45 | 46 | if self.transform is not None: 
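            # transform is the torchvision pipeline built in main.py:
            # Resize -> ToTensor -> Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), so images end up in [-1, 1]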
47 | img = self.transform(img) 48 | 49 | return img 50 | 51 | def load_bboxes(self): 52 | bbox_path = os.path.join(self.split_dir, 'bboxes.pickle') 53 | with open(bbox_path, "rb") as f: 54 | bboxes = pickle.load(f) 55 | bboxes = np.array(bboxes) 56 | return bboxes 57 | 58 | def load_labels(self): 59 | label_path = os.path.join(self.split_dir, 'labels.pickle') 60 | with open(label_path, "rb") as f: 61 | labels = pickle.load(f) 62 | labels = np.array(labels) 63 | return labels 64 | 65 | def load_all_captions(self): 66 | caption_dict = {} 67 | for key in self.filenames: 68 | caption_name = '%s/text/%s.txt' % (self.data_dir, key) 69 | captions = self.load_captions(caption_name) 70 | caption_dict[key] = captions 71 | return caption_dict 72 | 73 | def load_captions(self, caption_name): 74 | cap_path = caption_name 75 | with open(cap_path, "r") as f: 76 | captions = f.read().decode('utf8').split('\n') 77 | captions = [cap.replace("\ufffd\ufffd", " ") 78 | for cap in captions if len(cap) > 0] 79 | return captions 80 | 81 | def load_embedding(self, data_dir, embedding_type): 82 | if embedding_type == 'cnn-rnn': 83 | embedding_filename = '/char-CNN-RNN-embeddings.pickle' 84 | elif embedding_type == 'cnn-gru': 85 | embedding_filename = '/char-CNN-GRU-embeddings.pickle' 86 | elif embedding_type == 'skip-thought': 87 | embedding_filename = '/skip-thought-embeddings.pickle' 88 | 89 | with open(data_dir + embedding_filename, 'rb') as f: 90 | embeddings = pickle.load(f) 91 | embeddings = np.array(embeddings) 92 | return embeddings 93 | 94 | def load_filenames(self): 95 | filepath = os.path.join(self.split_dir, 'filenames.pickle') 96 | with open(filepath, 'rb') as f: 97 | filenames = pickle.load(f) 98 | print('Load filenames from: %s (%d)' % (filepath, len(filenames))) 99 | return filenames 100 | 101 | def crop_imgs(self, image, bbox): 102 | ori_size = image.shape[1] 103 | imsize = self.imsize 104 | 105 | flip_img = random.random() < 0.5 106 | img_crop = ori_size - self.imsize 107 | h1 = int(np.floor((img_crop) * np.random.random())) 108 | w1 = int(np.floor((img_crop) * np.random.random())) 109 | 110 | if self.stage == 1: 111 | bbox_scaled = np.zeros_like(bbox) 112 | bbox_scaled[...] = -1.0 113 | 114 | for idx in range(self.max_objects): 115 | bbox_tmp = bbox[idx] 116 | if bbox_tmp[0] == -1: 117 | break 118 | 119 | x_new = max(bbox_tmp[0] * float(ori_size) - h1, 0) / float(imsize) 120 | y_new = max(bbox_tmp[1] * float(ori_size) - w1, 0) / float(imsize) 121 | 122 | width_new = min((float(ori_size)/imsize) * bbox_tmp[2], 1.0) 123 | if x_new + width_new > 0.999: 124 | width_new = 1.0 - x_new - 0.001 125 | 126 | height_new = min((float(ori_size)/imsize) * bbox_tmp[3], 1.0) 127 | if y_new + height_new > 0.999: 128 | height_new = 1.0 - y_new - 0.001 129 | 130 | if flip_img: 131 | x_new = 1.0-x_new-width_new 132 | 133 | bbox_scaled[idx] = [x_new, y_new, width_new, height_new] 134 | else: 135 | # need two bboxes for stage 1 G and stage 2 G 136 | bbox_scaled = [np.zeros_like(bbox), np.zeros_like(bbox)] 137 | bbox_scaled[0][...] = -1.0 138 | bbox_scaled[1][...] 
= -1.0 139 | 140 | for idx in range(self.max_objects): 141 | bbox_tmp = bbox[idx] 142 | if bbox_tmp[0] == -1: 143 | break 144 | 145 | # scale bboxes for stage 1 G 146 | stage1_size = 64 147 | stage1_ori_size = 76 148 | x_new = max(bbox_tmp[0] * float(stage1_ori_size) - h1, 0) / float(stage1_size) 149 | y_new = max(bbox_tmp[1] * float(stage1_ori_size) - w1, 0) / float(stage1_size) 150 | 151 | width_new = min((float(stage1_ori_size) / stage1_size) * bbox_tmp[2], 1.0) 152 | if x_new + width_new > 0.999: 153 | width_new = 1.0 - x_new - 0.001 154 | 155 | height_new = min((float(stage1_ori_size) / stage1_size) * bbox_tmp[3], 1.0) 156 | if y_new + height_new > 0.999: 157 | height_new = 1.0 - y_new - 0.001 158 | 159 | if flip_img: 160 | x_new = 1.0 - x_new - width_new 161 | 162 | bbox_scaled[0][idx] = [x_new, y_new, width_new, height_new] 163 | 164 | # scale bboxes for stage 2 G 165 | x_new = max(bbox_tmp[0] * float(ori_size) - h1, 0) / float(imsize) 166 | y_new = max(bbox_tmp[1] * float(ori_size) - w1, 0) / float(imsize) 167 | 168 | width_new = min((float(ori_size) / imsize) * bbox_tmp[2], 1.0) 169 | if x_new + width_new > 0.999: 170 | width_new = 1.0 - x_new - 0.001 171 | 172 | height_new = min((float(ori_size) / imsize) * bbox_tmp[3], 1.0) 173 | if y_new + height_new > 0.999: 174 | height_new = 1.0 - y_new - 0.001 175 | 176 | if flip_img: 177 | x_new = 1.0 - x_new - width_new 178 | 179 | bbox_scaled[1][idx] = [x_new, y_new, width_new, height_new] 180 | 181 | 182 | cropped_image = image[:, w1: w1 + imsize, h1: h1 + imsize] 183 | 184 | if flip_img: 185 | idx = [i for i in reversed(range(cropped_image.shape[2]))] 186 | idx = torch.LongTensor(idx) 187 | transformed_image = torch.index_select(cropped_image, 2, idx) 188 | else: 189 | transformed_image = cropped_image 190 | 191 | return transformed_image, bbox_scaled 192 | 193 | 194 | def __getitem__(self, index): 195 | # load image 196 | key = self.filenames[index] 197 | img_name = self.img_dir +"/" + key + ".jpg" 198 | img = self.get_img(img_name) 199 | 200 | # load bbox 201 | bbox = self.bboxes[index] 202 | 203 | # load label 204 | label = self.labels[index] 205 | 206 | # load caption embedding 207 | embeddings = self.embeddings[index, :, :] 208 | embedding_ix = random.randint(0, embeddings.shape[0]-1) 209 | embedding = embeddings[embedding_ix, :] 210 | 211 | if self.crop: 212 | img, bbox = self.crop_imgs(img, bbox) 213 | 214 | return img, bbox, label, embedding 215 | 216 | def __len__(self): 217 | return len(self.filenames) 218 | -------------------------------------------------------------------------------- /code/coco/stackgan/miscc/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import numpy as np 4 | import cPickle as pickle 5 | import glob 6 | 7 | from copy import deepcopy 8 | from miscc.config import cfg 9 | 10 | from torch.nn import init 11 | import torch 12 | import torch.nn as nn 13 | import torchvision.utils as vutils 14 | from torch.autograd import grad 15 | from torch.autograd import Variable 16 | 17 | 18 | def compute_transformation_matrix_inverse(bbox): 19 | x, y = bbox[:, 0], bbox[:, 1] 20 | w, h = bbox[:, 2], bbox[:, 3] 21 | 22 | scale_x = 1.0 / w 23 | scale_y = 1.0 / h 24 | 25 | t_x = 2 * scale_x * (0.5 - (x + 0.5 * w)) 26 | t_y = 2 * scale_y * (0.5 - (y + 0.5 * h)) 27 | 28 | zeros = torch.cuda.FloatTensor(bbox.shape[0],1).fill_(0) 29 | 30 | transformation_matrix = torch.cat([scale_x.unsqueeze(-1), zeros, t_x.unsqueeze(-1), 31 | zeros, 
scale_y.unsqueeze(-1), t_y.unsqueeze(-1)], 1).view(-1, 2, 3) 32 | 33 | return transformation_matrix 34 | 35 | 36 | def compute_transformation_matrix(bbox): 37 | x, y = bbox[:, 0], bbox[:, 1] 38 | w, h = bbox[:, 2], bbox[:, 3] 39 | 40 | scale_x = w 41 | scale_y = h 42 | 43 | t_x = 2 * ((x + 0.5 * w) - 0.5) 44 | t_y = 2 * ((y + 0.5 * h) - 0.5) 45 | 46 | zeros = torch.cuda.FloatTensor(bbox.shape[0],1).fill_(0) 47 | 48 | transformation_matrix = torch.cat([scale_x.unsqueeze(-1), zeros, t_x.unsqueeze(-1), 49 | zeros, scale_y.unsqueeze(-1), t_y.unsqueeze(-1)], 1).view(-1, 2, 3) 50 | 51 | return transformation_matrix 52 | 53 | 54 | def load_validation_data(datapath, ori_size=76, imsize=64): 55 | 56 | with open(datapath + "bboxes.pickle", "rb") as f: 57 | bboxes = pickle.load(f) 58 | bboxes = np.array(bboxes) 59 | 60 | with open(datapath + "labels.pickle", "rb") as f: 61 | labels = pickle.load(f) 62 | labels = np.array(labels) 63 | 64 | return torch.from_numpy(labels), torch.from_numpy(bboxes) 65 | 66 | 67 | ############################# 68 | def KL_loss(mu, logvar): 69 | KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) 70 | KLD = torch.mean(KLD_element).mul_(-0.5) 71 | return KLD 72 | 73 | 74 | def compute_discriminator_loss(netD, real_imgs, fake_imgs, 75 | real_labels, fake_labels, 76 | local_label, transf_matrices, transf_matrices_inv, 77 | conditions, gpus): 78 | criterion = nn.BCEWithLogitsLoss() 79 | batch_size = real_imgs.size(0) 80 | cond = conditions.detach() 81 | fake = fake_imgs.detach() 82 | local_label = local_label.detach() 83 | real_features = nn.parallel.data_parallel(netD, (real_imgs, local_label, transf_matrices, transf_matrices_inv), gpus) 84 | fake_features = nn.parallel.data_parallel(netD, (fake, local_label, transf_matrices, transf_matrices_inv), gpus) 85 | # real pairs 86 | inputs = (real_features, cond) 87 | real_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus) 88 | errD_real = criterion(real_logits, real_labels) 89 | # wrong pairs 90 | inputs = (real_features[:(batch_size-1)], cond[1:]) 91 | wrong_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus) 92 | errD_wrong = criterion(wrong_logits, fake_labels[1:]) 93 | # fake pairs 94 | inputs = (fake_features, cond) 95 | fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus) 96 | errD_fake = criterion(fake_logits, fake_labels) 97 | 98 | if netD.get_uncond_logits is not None: 99 | real_logits = nn.parallel.data_parallel(netD.get_uncond_logits, (real_features), gpus) 100 | fake_logits = nn.parallel.data_parallel(netD.get_uncond_logits, (fake_features), gpus) 101 | uncond_errD_real = criterion(real_logits, real_labels) 102 | uncond_errD_fake = criterion(fake_logits, fake_labels) 103 | # 104 | errD = ((errD_real + uncond_errD_real) / 2. + 105 | (errD_fake + errD_wrong + uncond_errD_fake) / 3.) 106 | errD_real = (errD_real + uncond_errD_real) / 2. 107 | errD_fake = (errD_fake + uncond_errD_fake) / 2. 
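    # without an unconditional head, only the conditional real / wrong (mismatched caption) / fake terms are combined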
108 | else: 109 | errD = errD_real + (errD_fake + errD_wrong) * 0.5 110 | return errD, errD_real.item(), errD_wrong.item(), errD_fake.item() 111 | 112 | 113 | def compute_generator_loss(netD, fake_imgs, real_labels, local_label, transf_matrices, transf_matrices_inv, conditions, gpus): 114 | criterion = nn.BCEWithLogitsLoss() 115 | cond = conditions.detach() 116 | fake_features = nn.parallel.data_parallel(netD, (fake_imgs, local_label, transf_matrices, transf_matrices_inv), gpus) 117 | # fake pairs 118 | inputs = (fake_features, cond) 119 | fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus) 120 | errD_fake = criterion(fake_logits, real_labels) 121 | if netD.get_uncond_logits is not None: 122 | fake_logits = nn.parallel.data_parallel(netD.get_uncond_logits, (fake_features), gpus) 123 | uncond_errD_fake = criterion(fake_logits, real_labels) 124 | errD_fake += uncond_errD_fake 125 | return errD_fake 126 | 127 | 128 | ############################# 129 | def weights_init(m): 130 | classname = m.__class__.__name__ 131 | if classname.find('Conv') != -1: 132 | m.weight.data.normal_(0.0, 0.02) 133 | elif classname.find('BatchNorm') != -1: 134 | m.weight.data.normal_(1.0, 0.02) 135 | m.bias.data.fill_(0) 136 | elif classname.find('Linear') != -1: 137 | m.weight.data.normal_(0.0, 0.02) 138 | if m.bias is not None: 139 | m.bias.data.fill_(0.0) 140 | 141 | 142 | ############################# 143 | def save_img_results(data_img, fake, epoch, image_dir): 144 | num = cfg.VIS_COUNT 145 | fake = fake[0:num] 146 | # data_img is changed to [0,1] 147 | if data_img is not None: 148 | data_img = data_img[0:num] 149 | vutils.save_image( 150 | data_img, '%s/real_samples.png' % image_dir, 151 | normalize=True) 152 | # fake.data is still [-1, 1] 153 | vutils.save_image( 154 | fake.data, '%s/fake_samples_epoch_%03d.png' % 155 | (image_dir, epoch), normalize=True) 156 | else: 157 | vutils.save_image( 158 | fake.data, '%s/lr_fake_samples_epoch_%03d.png' % 159 | (image_dir, epoch), normalize=True) 160 | 161 | 162 | def save_model(netG, netD, optimG, optimD, epoch, model_dir, saveD=False, saveOptim=False, max_to_keep=5): 163 | checkpoint = { 164 | 'epoch': epoch, 165 | 'netG': netG.state_dict(), 166 | 'optimG': optimG.state_dict() if saveOptim else {}, 167 | 'netD': netD.state_dict() if saveD else {}, 168 | 'optimD': optimD.state_dict() if saveOptim else {}} 169 | torch.save(checkpoint, "{}/checkpoint_{:04}.pth".format(model_dir, epoch)) 170 | print('Save G/D models') 171 | 172 | if max_to_keep is not None and max_to_keep > 0: 173 | checkpoint_list = sorted([ckpt for ckpt in glob.glob(model_dir + "/" + '*.pth')]) 174 | while len(checkpoint_list) > max_to_keep: 175 | os.remove(checkpoint_list[0]) 176 | checkpoint_list = checkpoint_list[1:] 177 | 178 | 179 | def mkdir_p(path): 180 | try: 181 | os.makedirs(path) 182 | except OSError as exc: # Python >2.5 183 | if exc.errno == errno.EEXIST and os.path.isdir(path): 184 | pass 185 | else: 186 | raise 187 | -------------------------------------------------------------------------------- /code/multi-mnist/cfg/mnist_eval.yml: -------------------------------------------------------------------------------- 1 | DATASET_NAME: 'multi-mnist' 2 | GPU_ID: '0' 3 | Z_DIM: 100 4 | NET_G: '../../models/model-multi-mnist-0019.pth' 5 | DATA_DIR: '../../data/Multi-MNIST' 6 | WORKERS: 4 7 | IMSIZE: 64 8 | USE_BBOX_LAYOUT: True 9 | 10 | TRAIN: 11 | FLAG: False 12 | BATCH_SIZE: 100 13 | 14 | GAN: 15 | CONDITION_DIM: 128 16 | DF_DIM: 64 17 | GF_DIM: 128 18 | R_NUM: 2 19 | 
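# Sampling configuration: with TRAIN.FLAG False, main.py skips training and calls GANTrainer.sample()
# on the test split; NET_G above points at the generator checkpoint that load_network_stageI in
# trainer.py restores. One way to run it, assuming the working directory is code/multi-mnist/
# (the relative paths above assume this):
#   python main.py --cfg cfg/mnist_eval.yml --gpu 0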
20 | -------------------------------------------------------------------------------- /code/multi-mnist/cfg/mnist_train.yml: -------------------------------------------------------------------------------- 1 | DATASET_NAME: 'multi-mnist' 2 | GPU_ID: '0' 3 | Z_DIM: 100 4 | DATA_DIR: '../../data/Multi-MNIST' 5 | IMSIZE: 64 6 | WORKERS: 4 7 | USE_BBOX_LAYOUT: True 8 | TRAIN: 9 | FLAG: True 10 | BATCH_SIZE: 128 11 | MAX_EPOCH: 20 12 | LR_DECAY_EPOCH: 10 13 | SNAPSHOT_INTERVAL: 5 14 | DISCRIMINATOR_LR: 0.0002 15 | GENERATOR_LR: 0.0002 16 | 17 | GAN: 18 | CONDITION_DIM: 128 19 | DF_DIM: 64 20 | GF_DIM: 128 21 | -------------------------------------------------------------------------------- /code/multi-mnist/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.backends.cudnn as cudnn 3 | import torch 4 | import torchvision.transforms as transforms 5 | import torch.backends.cudnn as cudnn 6 | 7 | import argparse 8 | import os 9 | import random 10 | import sys 11 | import pprint 12 | import datetime 13 | import dateutil 14 | import dateutil.tz 15 | from shutil import copyfile 16 | 17 | 18 | dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.'))) 19 | sys.path.append(dir_path) 20 | 21 | from miscc.datasets import TextDataset 22 | from miscc.config import cfg, cfg_from_file 23 | from miscc.utils import mkdir_p 24 | from trainer import GANTrainer 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser(description='Train a GAN network') 29 | parser.add_argument('--cfg', dest='cfg_file', 30 | help='optional config file', 31 | default='birds_stage1.yml', type=str) 32 | parser.add_argument('--gpu', dest='gpu_id', type=str, default='0') 33 | parser.add_argument('--data_dir', dest='data_dir', type=str, default='') 34 | parser.add_argument('--manualSeed', type=int, help='manual seed') 35 | args = parser.parse_args() 36 | return args 37 | 38 | if __name__ == "__main__": 39 | args = parse_args() 40 | if args.cfg_file is not None: 41 | cfg_from_file(args.cfg_file) 42 | if args.gpu_id != -1: 43 | cfg.GPU_ID = args.gpu_id 44 | if args.data_dir != '': 45 | cfg.DATA_DIR = args.data_dir 46 | print('Using config:') 47 | pprint.pprint(cfg) 48 | if args.manualSeed is None: 49 | args.manualSeed = random.randint(1, 10000) 50 | random.seed(args.manualSeed) 51 | torch.manual_seed(args.manualSeed) 52 | if cfg.CUDA: 53 | torch.cuda.manual_seed_all(args.manualSeed) 54 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 55 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 56 | output_dir = '../../output/%s_%s' % \ 57 | (cfg.DATASET_NAME, timestamp) 58 | 59 | cudnn.benchmark = True 60 | 61 | num_gpu = len(cfg.GPU_ID.split(',')) 62 | if cfg.TRAIN.FLAG: 63 | try: 64 | os.makedirs(output_dir) 65 | except OSError as exc: # Python >2.5 66 | if exc.errno == errno.EEXIST and os.path.isdir(path): 67 | pass 68 | else: 69 | raise 70 | 71 | copyfile(sys.argv[0], output_dir + "/" + sys.argv[0]) 72 | copyfile("trainer.py", output_dir + "/" + "trainer.py") 73 | copyfile("model.py", output_dir + "/" + "model.py") 74 | copyfile("miscc/utils.py", output_dir + "/" + "utils.py") 75 | copyfile("miscc/datasets.py", output_dir + "/" + "datasets.py") 76 | copyfile(args.cfg_file, output_dir + "/" + "cfg_file.yml") 77 | 78 | imsize=64 79 | 80 | img_transform = transforms.Compose([ 81 | transforms.ToTensor(), 82 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 83 | dataset = TextDataset(cfg.DATA_DIR, 
split="train", imsize=imsize, transform=img_transform, 84 | crop=True) 85 | assert dataset 86 | dataloader = torch.utils.data.DataLoader( 87 | dataset, batch_size=cfg.TRAIN.BATCH_SIZE, 88 | drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS)) 89 | 90 | algo = GANTrainer(output_dir) 91 | algo.train(dataloader) 92 | else: 93 | datapath = os.path.join(cfg.DATA_DIR, "test") 94 | algo = GANTrainer(output_dir) 95 | algo.sample(datapath, num_samples=25, draw_bbox=True) 96 | -------------------------------------------------------------------------------- /code/multi-mnist/miscc/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | -------------------------------------------------------------------------------- /code/multi-mnist/miscc/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import os.path as osp 5 | import numpy as np 6 | from easydict import EasyDict as edict 7 | 8 | 9 | __C = edict() 10 | cfg = __C 11 | 12 | # Dataset name: flowers, birds 13 | __C.DATASET_NAME = 'multi-mnist' 14 | __C.CONFIG_NAME = '' 15 | __C.GPU_ID = '0' 16 | __C.CUDA = True 17 | __C.WORKERS = 4 18 | 19 | __C.NET_G = '' 20 | __C.NET_D = '' 21 | __C.DATA_DIR = '' 22 | __C.VIS_COUNT = 64 23 | 24 | __C.Z_DIM = 100 25 | __C.IMSIZE = 64 26 | 27 | __C.USE_LOCAL_PATHWAY = True 28 | __C.USE_BBOX_LAYOUT = True 29 | 30 | # Training options 31 | __C.TRAIN = edict() 32 | __C.TRAIN.FLAG = True 33 | __C.TRAIN.BATCH_SIZE = 64 34 | __C.TRAIN.MAX_EPOCH = 600 35 | __C.TRAIN.SNAPSHOT_INTERVAL = 50 36 | __C.TRAIN.PRETRAINED_MODEL = '' 37 | __C.TRAIN.PRETRAINED_EPOCH = 600 38 | __C.TRAIN.LR_DECAY_EPOCH = 600 39 | __C.TRAIN.DISCRIMINATOR_LR = 2e-4 40 | __C.TRAIN.GENERATOR_LR = 2e-4 41 | 42 | # Modal options 43 | __C.GAN = edict() 44 | __C.GAN.CONDITION_DIM = 128 45 | __C.GAN.DF_DIM = 64 46 | __C.GAN.GF_DIM = 128 47 | __C.GAN.R_NUM = 4 48 | 49 | 50 | def _merge_a_into_b(a, b): 51 | """Merge config dictionary a into config dictionary b, clobbering the 52 | options in b whenever they are also specified in a. 53 | """ 54 | if type(a) is not edict: 55 | return 56 | 57 | for k, v in a.iteritems(): 58 | # a must specify keys that are in b 59 | if not b.has_key(k): 60 | raise KeyError('{} is not a valid config key'.format(k)) 61 | 62 | # the types must match, too 63 | old_type = type(b[k]) 64 | if old_type is not type(v): 65 | if isinstance(b[k], np.ndarray): 66 | v = np.array(v, dtype=b[k].dtype) 67 | else: 68 | raise ValueError(('Type mismatch ({} vs. 
{}) ' 69 | 'for config key: {}').format(type(b[k]), 70 | type(v), k)) 71 | 72 | # recursively merge dicts 73 | if type(v) is edict: 74 | try: 75 | _merge_a_into_b(a[k], b[k]) 76 | except: 77 | print('Error under config key: {}'.format(k)) 78 | raise 79 | else: 80 | b[k] = v 81 | 82 | 83 | def cfg_from_file(filename): 84 | """Load a config file and merge it into the default options.""" 85 | import yaml 86 | with open(filename, 'r') as f: 87 | yaml_cfg = edict(yaml.load(f)) 88 | 89 | _merge_a_into_b(yaml_cfg, __C) 90 | -------------------------------------------------------------------------------- /code/multi-mnist/miscc/datasets.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | 7 | import torch.utils.data as data 8 | import PIL 9 | import os 10 | import os.path 11 | import random 12 | import numpy as np 13 | from PIL import Image 14 | import torchvision.transforms as transforms 15 | import torch 16 | import sys 17 | if sys.version_info[0] == 2: 18 | import cPickle as pickle 19 | else: 20 | import pickle 21 | 22 | from miscc.config import cfg 23 | 24 | 25 | class TextDataset(data.Dataset): 26 | def __init__(self, data_dir, imsize, split='train', transform=None, crop=False): 27 | 28 | self.transform = transform 29 | self.imsize = imsize 30 | self.crop = crop 31 | self.data = [] 32 | self.data_dir = data_dir 33 | self.split_dir = os.path.join(data_dir, split, "normal") 34 | self.img_dir = self.split_dir + "/imgs/" 35 | self.max_objects = 3 36 | 37 | self.filenames = self.load_filenames() 38 | self.bboxes = self.load_bboxes() 39 | self.labels = self.load_labels() 40 | 41 | def get_img(self, img_path): 42 | img = Image.open(img_path) 43 | 44 | if self.transform is not None: 45 | img = self.transform(img) 46 | 47 | return img 48 | 49 | def load_bboxes(self): 50 | bbox_path = os.path.join(self.split_dir, 'bboxes.pickle') 51 | with open(bbox_path, "rb") as f: 52 | bboxes = pickle.load(f) 53 | bboxes = np.array(bboxes, dtype=np.double) 54 | return bboxes 55 | 56 | def load_labels(self): 57 | label_path = os.path.join(self.split_dir, 'labels.pickle') 58 | with open(label_path, "rb") as f: 59 | labels = pickle.load(f) 60 | labels = np.array(labels) 61 | return labels 62 | 63 | def load_filenames(self): 64 | filepath = os.path.join(self.split_dir, 'filenames.pickle') 65 | with open(filepath, 'rb') as f: 66 | filenames = pickle.load(f) 67 | print('Load filenames from: %s (%d)' % (filepath, len(filenames))) 68 | return filenames 69 | 70 | def __getitem__(self, index): 71 | # load image 72 | key = self.filenames[index] 73 | key = key.split("/")[-1] 74 | img_name = self.split_dir + "/imgs/" + key 75 | img = self.get_img(img_name) 76 | 77 | # load bbox 78 | bbox = self.bboxes[index].astype(np.double) 79 | 80 | # load label 81 | label = self.labels[index] 82 | 83 | return img, bbox, label 84 | 85 | def __len__(self): 86 | return len(self.filenames) 87 | -------------------------------------------------------------------------------- /code/multi-mnist/miscc/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import numpy as np 4 | import cPickle as pickle 5 | import glob 6 | 7 | from copy import deepcopy 8 | from miscc.config import cfg 9 | 10 | from torch.nn import init 11 | import torch 12 | import torch.nn as nn 13 | import 
torchvision.utils as vutils 14 | from torch.autograd import grad 15 | from torch.autograd import Variable 16 | 17 | 18 | def compute_transformation_matrix_inverse(bbox): 19 | x, y = bbox[:, 0], bbox[:, 1] 20 | w, h = bbox[:, 2], bbox[:, 3] 21 | 22 | scale_x = 1.0 / w 23 | scale_y = 1.0 / h 24 | 25 | t_x = 2 * scale_x * (0.5 - (x + 0.5 * w)) 26 | t_y = 2 * scale_y * (0.5 - (y + 0.5 * h)) 27 | 28 | zeros = torch.cuda.DoubleTensor(bbox.shape[0],1).fill_(0) 29 | 30 | transformation_matrix = torch.cat([scale_x.unsqueeze(-1), zeros, t_x.unsqueeze(-1), 31 | zeros, scale_y.unsqueeze(-1), t_y.unsqueeze(-1)], 1).view(-1, 2, 3) 32 | 33 | return transformation_matrix 34 | 35 | 36 | def compute_transformation_matrix(bbox): 37 | x, y = bbox[:, 0], bbox[:, 1] 38 | w, h = bbox[:, 2], bbox[:, 3] 39 | 40 | scale_x = w 41 | scale_y = h 42 | 43 | t_x = 2 * ((x + 0.5 * w) - 0.5) 44 | t_y = 2 * ((y + 0.5 * h) - 0.5) 45 | 46 | zeros = torch.cuda.DoubleTensor(bbox.shape[0],1).fill_(0) 47 | 48 | transformation_matrix = torch.cat([scale_x.unsqueeze(-1), zeros, t_x.unsqueeze(-1), 49 | zeros, scale_y.unsqueeze(-1), t_y.unsqueeze(-1)], 1).view(-1, 2, 3) 50 | 51 | return transformation_matrix 52 | 53 | 54 | def pad_imgs(img, pad=2): 55 | m = nn.ConstantPad2d((pad, pad, pad, pad), 0) 56 | return m(img) 57 | 58 | 59 | def load_validation_data(datapath): 60 | with open(os.path.join(datapath, "normal", "bboxes.pickle"), "rb") as f: 61 | bboxes = pickle.load(f) 62 | bboxes = np.array(bboxes) 63 | 64 | with open(os.path.join(datapath, "normal", "labels.pickle"), "rb") as f: 65 | labels = pickle.load(f) 66 | labels = np.array(labels) 67 | 68 | return torch.from_numpy(labels), torch.from_numpy(bboxes) 69 | 70 | 71 | def compute_discriminator_loss(netD, real_imgs, fake_imgs, 72 | real_labels, fake_labels, 73 | local_label, transf_matrices, transf_matrices_inv, gpus): 74 | criterion = nn.BCEWithLogitsLoss() 75 | batch_size = real_imgs.size(0) 76 | fake = fake_imgs.detach() 77 | local_label = local_label.detach() 78 | local_label_cond = local_label[:, 0, :] + local_label[:, 1, :] + local_label[:, 2, :] 79 | real_features = nn.parallel.data_parallel(netD, (real_imgs, local_label, transf_matrices, transf_matrices_inv), gpus) 80 | fake_features = nn.parallel.data_parallel(netD, (fake, local_label, transf_matrices, transf_matrices_inv), gpus) 81 | # real pairs 82 | inputs = (real_features, local_label_cond) 83 | real_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus) 84 | errD_real = criterion(real_logits, real_labels) 85 | # wrong pairs 86 | inputs = (real_features[:(batch_size-1)], local_label_cond[1:]) 87 | wrong_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus) 88 | errD_wrong = criterion(wrong_logits, fake_labels[1:]) 89 | # fake pairs 90 | inputs = (fake_features, local_label_cond) 91 | fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus) 92 | errD_fake = criterion(fake_logits, fake_labels) 93 | 94 | if netD.get_uncond_logits is not None: 95 | real_logits = nn.parallel.data_parallel(netD.get_uncond_logits, (real_features), gpus) 96 | fake_logits = nn.parallel.data_parallel(netD.get_uncond_logits, (fake_features), gpus) 97 | uncond_errD_real = criterion(real_logits, real_labels) 98 | uncond_errD_fake = criterion(fake_logits, fake_labels) 99 | # 100 | errD = ((errD_real + uncond_errD_real) / 2. + 101 | (errD_fake + errD_wrong + uncond_errD_fake) / 3.) 102 | errD_real = (errD_real + uncond_errD_real) / 2. 103 | errD_fake = (errD_fake + uncond_errD_fake) / 2. 
104 | else: 105 | errD = errD_real + (errD_fake + errD_wrong) * 0.5 106 | return errD, errD_real.item(), errD_wrong.item(), errD_fake.item() 107 | 108 | 109 | def compute_generator_loss(netD, fake_imgs, real_labels, local_label, transf_matrices, transf_matrices_inv, gpus): 110 | criterion = nn.BCEWithLogitsLoss() 111 | local_label = local_label.detach() 112 | local_label_cond = local_label[:, 0, :] + local_label[:, 1, :] + local_label[:, 2, :] 113 | fake_features = nn.parallel.data_parallel(netD, (fake_imgs, local_label, transf_matrices, transf_matrices_inv), gpus) 114 | # fake pairs 115 | inputs = (fake_features, local_label_cond) 116 | fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus) 117 | errD_fake = criterion(fake_logits, real_labels) 118 | if netD.get_uncond_logits is not None: 119 | fake_logits = nn.parallel.data_parallel(netD.get_uncond_logits, (fake_features), gpus) 120 | # fake_logits = torch.clamp(fake_logits, 1e-8, 1-1e-8) 121 | uncond_errD_fake = criterion(fake_logits, real_labels) 122 | errD_fake += uncond_errD_fake 123 | return errD_fake 124 | 125 | 126 | ############################# 127 | def weights_init(m): 128 | classname = m.__class__.__name__ 129 | if classname.find('Conv') != -1: 130 | m.weight.data.normal_(0.0, 0.02) 131 | elif classname.find('BatchNorm') != -1: 132 | m.weight.data.normal_(1.0, 0.02) 133 | m.bias.data.fill_(0) 134 | elif classname.find('Linear') != -1: 135 | m.weight.data.normal_(0.0, 0.02) 136 | if m.bias is not None: 137 | m.bias.data.fill_(0.0) 138 | 139 | 140 | ############################# 141 | def save_img_results(data_img, fake, epoch, image_dir): 142 | num = cfg.VIS_COUNT 143 | fake = fake[0:num] 144 | # data_img is changed to [0,1] 145 | if data_img is not None: 146 | data_img = data_img[0:num] 147 | vutils.save_image( 148 | data_img, '%s/real_samples.png' % image_dir, 149 | normalize=True) 150 | # fake.data is still [-1, 1] 151 | vutils.save_image( 152 | fake.data, '%s/fake_samples_epoch_%03d.png' % 153 | (image_dir, epoch), normalize=True) 154 | else: 155 | vutils.save_image( 156 | fake.data, '%s/lr_fake_samples_epoch_%03d.png' % 157 | (image_dir, epoch), normalize=True) 158 | 159 | 160 | def save_model(netG, netD, optimG, optimD, epoch, model_dir, saveD=False, saveOptim=False, max_to_keep=5): 161 | checkpoint = { 162 | 'epoch': epoch, 163 | 'netG': netG.state_dict(), 164 | 'optimG': optimG.state_dict() if saveOptim else {}, 165 | 'netD': netD.state_dict() if saveD else {}, 166 | 'optimD': optimD.state_dict() if saveOptim else {}} 167 | torch.save(checkpoint, "{}/checkpoint_{:04}.pth".format(model_dir, epoch)) 168 | print('Save G/D models') 169 | 170 | if max_to_keep is not None and max_to_keep > 0: 171 | checkpoint_list = sorted([ckpt for ckpt in glob.glob(model_dir + "/" + '*.pth')]) 172 | while len(checkpoint_list) > max_to_keep: 173 | os.remove(checkpoint_list[0]) 174 | checkpoint_list = checkpoint_list[1:] 175 | 176 | 177 | def mkdir_p(path): 178 | try: 179 | os.makedirs(path) 180 | except OSError as exc: # Python >2.5 181 | if exc.errno == errno.EEXIST and os.path.isdir(path): 182 | pass 183 | else: 184 | raise 185 | -------------------------------------------------------------------------------- /code/multi-mnist/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | from miscc.config import cfg 5 | from miscc.utils import compute_transformation_matrix, compute_transformation_matrix_inverse 6 | 
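# Note on the two helpers imported above: given boxes (x, y, w, h) in [0, 1],
# compute_transformation_matrix builds the 2x3 affine matrix that makes stn() crop a bounding box out
# of a full-size map (object pathway of the discriminator), while compute_transformation_matrix_inverse
# builds the matrix that pastes a bbox-sized feature map back onto its location in the global layout
# (BBOX_NET and the generator's object pathway). For a centred box (0.25, 0.25, 0.5, 0.5) this gives
# scale 0.5 with zero translation, and scale 2.0 with zero translation for the inverse.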
from torch.autograd import Variable 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | "3x3 convolution with padding" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=1, bias=False) 13 | 14 | 15 | # Upsale the spatial size by a factor of 2 16 | def upBlock(in_planes, out_planes): 17 | block = nn.Sequential( 18 | nn.Upsample(scale_factor=2, mode='nearest'), 19 | conv3x3(in_planes, out_planes), 20 | nn.BatchNorm2d(out_planes), 21 | nn.ReLU(True)) 22 | return block 23 | 24 | 25 | class ResBlock(nn.Module): 26 | def __init__(self, channel_num): 27 | super(ResBlock, self).__init__() 28 | self.block = nn.Sequential( 29 | conv3x3(channel_num, channel_num), 30 | nn.BatchNorm2d(channel_num), 31 | nn.ReLU(True), 32 | conv3x3(channel_num, channel_num), 33 | nn.BatchNorm2d(channel_num)) 34 | self.relu = nn.ReLU(inplace=True) 35 | 36 | def forward(self, x): 37 | residual = x 38 | out = self.block(x) 39 | out += residual 40 | out = self.relu(out) 41 | return out 42 | 43 | class D_GET_LOGITS(nn.Module): 44 | def __init__(self, ndf, nef, bcondition=True): 45 | super(D_GET_LOGITS, self).__init__() 46 | self.df_dim = ndf 47 | self.ef_dim = nef 48 | self.bcondition = bcondition 49 | if bcondition: 50 | self.outlogits = nn.Sequential( 51 | conv3x3(ndf * 8 + nef, ndf * 8), 52 | nn.BatchNorm2d(ndf * 8), 53 | nn.LeakyReLU(0.2, inplace=True), 54 | nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4)) 55 | else: 56 | self.outlogits = nn.Sequential( 57 | nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4)) 58 | 59 | def forward(self, h_code, c_code=None): 60 | # conditioning output 61 | if self.bcondition and c_code is not None: 62 | c_code = c_code.view(c_code.shape[0], 10, 1, 1) 63 | c_code = c_code.repeat(1, 1, 4, 4) 64 | # state size (ngf+egf) x 4 x 4 65 | h_c_code = torch.cat((h_code, c_code), 1) 66 | else: 67 | h_c_code = h_code 68 | 69 | output = self.outlogits(h_c_code) 70 | return output.view(-1) 71 | 72 | 73 | def stn(image, transformation_matrix, size): 74 | grid = torch.nn.functional.affine_grid(transformation_matrix, torch.Size(size)) 75 | out_image = torch.nn.functional.grid_sample(image, grid) 76 | 77 | return out_image 78 | 79 | 80 | class BBOX_NET(nn.Module): 81 | def __init__(self): 82 | super(BBOX_NET, self).__init__() 83 | self.c_dim = 128 84 | self.encode = nn.Sequential( 85 | # 128 * 16 x 16 86 | conv3x3(10, self.c_dim // 2, stride=2), 87 | nn.LeakyReLU(0.2, inplace=True), 88 | # 64 x 8 x 8 89 | conv3x3(self.c_dim // 2, self.c_dim // 4, stride=2), 90 | nn.BatchNorm2d(self.c_dim // 4), 91 | nn.LeakyReLU(0.2, inplace=True), 92 | # 32 x 4 x 4 93 | conv3x3(self.c_dim // 4, self.c_dim // 8, stride=2), 94 | nn.BatchNorm2d(self.c_dim // 8), 95 | nn.LeakyReLU(0.2, inplace=True), 96 | # 16 x 2 x 2 97 | ) 98 | 99 | def forward(self, labels, transf_matr_inv, num_digits): 100 | label_layout = torch.cuda.FloatTensor(labels.shape[0], 10, 16, 16).fill_(0) 101 | for idx in range(num_digits): 102 | current_label = labels[:, idx] 103 | current_label = current_label.view(current_label.shape[0], current_label.shape[1], 1, 1) 104 | current_label = current_label.repeat(1, 1, 16, 16) 105 | current_label = stn(current_label, transf_matr_inv[:, idx], current_label.shape) 106 | label_layout += current_label 107 | 108 | layout_encoding = self.encode(label_layout).view(labels.shape[0], -1) 109 | 110 | return layout_encoding 111 | 112 | # ############# Networks for stageI GAN ############# 113 | class STAGE1_G(nn.Module): 114 | def __init__(self): 115 | super(STAGE1_G, 
self).__init__() 116 | self.gf_dim = cfg.GAN.GF_DIM * 8 117 | self.ef_dim = 10 118 | self.z_dim = cfg.Z_DIM 119 | self.define_module() 120 | 121 | def define_module(self): 122 | ninput = self.z_dim 123 | linput = self.ef_dim 124 | ngf = self.gf_dim 125 | 126 | if cfg.USE_BBOX_LAYOUT: 127 | self.bbox_net = BBOX_NET() 128 | ninput += 64 129 | 130 | # -> ngf x 4 x 4 131 | self.fc = nn.Sequential( 132 | nn.Linear(ninput, ngf * 4 * 4, bias=False), 133 | nn.BatchNorm1d(ngf * 4 * 4), 134 | nn.ReLU(True)) 135 | 136 | # local pathway 137 | self.label = nn.Sequential( 138 | nn.Linear(linput, self.ef_dim, bias=False), 139 | nn.BatchNorm1d(self.ef_dim), 140 | nn.ReLU(True)) 141 | self.local1 = upBlock(self.ef_dim, ngf // 2) 142 | self.local2 = upBlock(ngf // 2, ngf // 4) 143 | 144 | # global pathway 145 | # ngf x 4 x 4 -> ngf/2 x 8 x 8 146 | self.upsample1 = upBlock(ngf, ngf // 2) 147 | # -> ngf/4 x 16 x 16 148 | self.upsample2 = upBlock(ngf // 2, ngf // 4) 149 | # -> ngf/8 x 32 x 32 150 | self.upsample3 = upBlock(ngf // 2, ngf // 8) 151 | # -> ngf/16 x 64 x 64 152 | self.upsample4 = upBlock(ngf // 8, ngf // 16) 153 | # -> 3 x 64 x 64 154 | self.img = nn.Sequential( 155 | conv3x3(ngf // 16, 1), 156 | nn.Tanh()) 157 | 158 | def forward(self, noise, transf_matrices_inv, label_one_hot, num_digits_per_image=3): 159 | # local pathway 160 | h_code_locals = torch.cuda.FloatTensor(noise.shape[0], self.gf_dim // 4, 16, 16).fill_(0) 161 | 162 | for idx in range(num_digits_per_image): 163 | current_label = label_one_hot[:, idx] 164 | current_label = current_label.view(current_label.shape[0], self.ef_dim, 1, 1) 165 | current_label = current_label.repeat(1, 1, 4, 4) 166 | h_code_local = self.local1(current_label) 167 | h_code_local = self.local2(h_code_local) 168 | h_code_local = stn(h_code_local, transf_matrices_inv[:, idx], h_code_local.shape) 169 | h_code_locals += h_code_local 170 | 171 | # global pathway 172 | if cfg.USE_BBOX_LAYOUT: 173 | bbox_code = self.bbox_net(label_one_hot, transf_matrices_inv, num_digits_per_image) 174 | z_c_code = torch.cat((noise, bbox_code), 1) 175 | else: 176 | z_c_code = torch.cat((noise), 1) 177 | h_code = self.fc(z_c_code) 178 | h_code = h_code.view(-1, self.gf_dim, 4, 4) 179 | h_code = self.upsample1(h_code) 180 | h_code = self.upsample2(h_code) 181 | 182 | # combine local and global 183 | h_code = torch.cat((h_code, h_code_locals), 1) 184 | 185 | h_code = self.upsample3(h_code) 186 | h_code = self.upsample4(h_code) 187 | 188 | # state size 3 x 64 x 64 189 | fake_img = self.img(h_code) 190 | return None, fake_img 191 | 192 | 193 | class STAGE1_D(nn.Module): 194 | def __init__(self): 195 | super(STAGE1_D, self).__init__() 196 | self.df_dim = cfg.GAN.DF_DIM 197 | self.ef_dim = 10 198 | self.define_module() 199 | 200 | def define_module(self): 201 | ndf, nef = self.df_dim, self.ef_dim 202 | 203 | # local pathway 204 | self.local = nn.Sequential( 205 | nn.Conv2d(1 + 10, ndf * 2, 4, 1, 1, bias=False), 206 | nn.BatchNorm2d(ndf * 2), 207 | nn.LeakyReLU(0.2, inplace=True) 208 | ) 209 | 210 | self.act = nn.LeakyReLU(0.2, inplace=True) 211 | 212 | self.conv1 = nn.Conv2d(1, ndf, 4, 2, 1, bias=False) 213 | self.conv2 = nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False) 214 | self.bn2 = nn.BatchNorm2d(ndf * 2) 215 | self.conv3 = nn.Conv2d(ndf*4, ndf * 4, 4, 2, 1, bias=False) 216 | self.bn3 = nn.BatchNorm2d(ndf * 4) 217 | self.conv4 = nn.Conv2d(ndf*4, ndf * 8, 4, 2, 1, bias=False) 218 | self.bn4 = nn.BatchNorm2d(ndf * 8) 219 | 220 | self.get_cond_logits = D_GET_LOGITS(ndf, nef) 221 | 
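        # only the label-conditioned head is used here; compute_*_loss in miscc/utils.py skips the
        # unconditional terms whenever get_uncond_logits is None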
self.get_uncond_logits = None 222 | 223 | def _encode_img(self, image, label, transf_matrices, transf_matrices_inv, num_digits_per_image=3): 224 | # local pathway 225 | h_code_locals = torch.cuda.FloatTensor(image.shape[0], self.df_dim * 2, 16, 16).fill_(0) 226 | 227 | for idx in range(num_digits_per_image): 228 | current_label = label[:, idx].view(label.shape[0], 10, 1, 1) 229 | current_label = current_label.repeat(1, 1, 16, 16) 230 | h_code_local = stn(image, transf_matrices[:, idx], (image.shape[0], image.shape[1], 16, 16)) 231 | h_code_local = torch.cat((h_code_local, current_label), 1) 232 | h_code_local = self.local(h_code_local) 233 | h_code_local = stn(h_code_local, transf_matrices_inv[:, idx], (h_code_local.shape[0], h_code_local.shape[1], 16, 16)) 234 | h_code_locals += h_code_local 235 | 236 | h_code = self.conv1(image) 237 | h_code = self.act(h_code) 238 | h_code = self.conv2(h_code) 239 | h_code = self.bn2(h_code) 240 | h_code = self.act(h_code) 241 | 242 | # combine local and global 243 | h_code = torch.cat((h_code, h_code_locals), 1) 244 | 245 | h_code = self.conv3(h_code) 246 | h_code = self.bn3(h_code) 247 | h_code = self.act(h_code) 248 | 249 | h_code = self.conv4(h_code) 250 | h_code = self.bn4(h_code) 251 | h_code = self.act(h_code) 252 | return h_code 253 | 254 | def forward(self, image, label, transf_matrices, transf_matrices_inv): 255 | img_embedding = self._encode_img(image, label, transf_matrices, transf_matrices_inv) 256 | 257 | return img_embedding 258 | -------------------------------------------------------------------------------- /code/multi-mnist/trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from six.moves import range 3 | from PIL import Image 4 | 5 | import torch.backends.cudnn as cudnn 6 | import torch 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | import torch.optim as optim 10 | import os 11 | import time 12 | 13 | import numpy as np 14 | import torchfile 15 | 16 | from miscc.config import cfg 17 | from miscc.utils import mkdir_p 18 | from miscc.utils import weights_init 19 | from miscc.utils import save_img_results, save_model 20 | from miscc.utils import compute_discriminator_loss, compute_generator_loss 21 | from miscc.utils import compute_transformation_matrix, compute_transformation_matrix_inverse 22 | from miscc.utils import load_validation_data, pad_imgs 23 | 24 | from tensorboard import summary 25 | from tensorboard import FileWriter 26 | 27 | class GANTrainer(object): 28 | def __init__(self, output_dir): 29 | if cfg.TRAIN.FLAG: 30 | self.model_dir = os.path.join(output_dir, 'Model') 31 | self.image_dir = os.path.join(output_dir, 'Image') 32 | self.log_dir = os.path.join(output_dir, 'Log') 33 | mkdir_p(self.model_dir) 34 | mkdir_p(self.image_dir) 35 | mkdir_p(self.log_dir) 36 | self.summary_writer = FileWriter(self.log_dir) 37 | 38 | self.max_epoch = cfg.TRAIN.MAX_EPOCH 39 | self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL 40 | self.max_objects = 3 41 | 42 | s_gpus = cfg.GPU_ID.split(',') 43 | self.gpus = [int(ix) for ix in s_gpus] 44 | self.num_gpus = len(self.gpus) 45 | self.batch_size = cfg.TRAIN.BATCH_SIZE 46 | torch.cuda.set_device(self.gpus[0]) 47 | cudnn.benchmark = True 48 | 49 | # ############# For training stageI GAN ############# 50 | def load_network_stageI(self): 51 | from model import STAGE1_G, STAGE1_D 52 | netG = STAGE1_G() 53 | netG.apply(weights_init) 54 | print(netG) 55 | netD = STAGE1_D() 56 | 
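        # same normal(0, 0.02) initialisation as the generator (see miscc.utils.weights_init)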
netD.apply(weights_init) 57 | print(netD) 58 | 59 | if cfg.NET_G != '': 60 | state_dict = \ 61 | torch.load(cfg.NET_G, map_location=lambda storage, loc: storage) 62 | netG.load_state_dict(state_dict["netG"]) 63 | print('Load from: ', cfg.NET_G) 64 | if cfg.NET_D != '': 65 | state_dict = \ 66 | torch.load(cfg.NET_D, map_location=lambda storage, loc: storage) 67 | netD.load_state_dict(state_dict) 68 | print('Load from: ', cfg.NET_D) 69 | if cfg.CUDA: 70 | netG.cuda() 71 | netD.cuda() 72 | return netG, netD 73 | 74 | 75 | def train(self, data_loader): 76 | netG, netD = self.load_network_stageI() 77 | 78 | nz = cfg.Z_DIM 79 | batch_size = self.batch_size 80 | noise = Variable(torch.FloatTensor(batch_size, nz)) 81 | 82 | fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1), requires_grad=False) 83 | real_labels = Variable(torch.FloatTensor(batch_size).fill_(1)) 84 | fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0)) 85 | 86 | if cfg.CUDA: 87 | noise, fixed_noise = noise.cuda(), fixed_noise.cuda() 88 | real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda() 89 | 90 | generator_lr = cfg.TRAIN.GENERATOR_LR 91 | discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR 92 | lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH 93 | 94 | netG_para = [] 95 | for p in netG.parameters(): 96 | if p.requires_grad: 97 | netG_para.append(p) 98 | optimizerD = optim.Adam(netD.parameters(), lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999)) 99 | optimizerG = optim.Adam(netG_para, lr=cfg.TRAIN.GENERATOR_LR, betas=(0.5, 0.999)) 100 | 101 | print("Starting training...") 102 | count = 0 103 | for epoch in range(self.max_epoch): 104 | start_t = time.time() 105 | if epoch % lr_decay_step == 0 and epoch > 0: 106 | generator_lr *= 0.5 107 | for param_group in optimizerG.param_groups: 108 | param_group['lr'] = generator_lr 109 | discriminator_lr *= 0.5 110 | for param_group in optimizerD.param_groups: 111 | param_group['lr'] = discriminator_lr 112 | 113 | for i, data in enumerate(data_loader, 0): 114 | ###################################################### 115 | # (1) Prepare training data 116 | ###################################################### 117 | real_img_cpu, bbox, label = data 118 | 119 | real_imgs = Variable(real_img_cpu) 120 | if cfg.CUDA: 121 | real_imgs = real_imgs.cuda() 122 | bbox = bbox.cuda() 123 | label_one_hot = label.cuda().float() 124 | 125 | bbox = bbox.view(-1, 4) 126 | transf_matrices_inv = compute_transformation_matrix_inverse(bbox).float() 127 | transf_matrices_inv = transf_matrices_inv.view(real_imgs.shape[0], self.max_objects, 2, 3) 128 | transf_matrices = compute_transformation_matrix(bbox).float() 129 | transf_matrices = transf_matrices.view(real_imgs.shape[0], self.max_objects, 2, 3) 130 | 131 | ####################################################### 132 | # (2) Generate fake images 133 | ###################################################### 134 | noise.data.normal_(0, 1) 135 | inputs = (noise, transf_matrices_inv, label_one_hot) 136 | # _, fake_imgs = nn.parallel.data_parallel(netG, inputs, self.gpus) 137 | _, fake_imgs = netG(noise, transf_matrices_inv, label_one_hot) 138 | 139 | ############################ 140 | # (3) Update D network 141 | ########################### 142 | netD.zero_grad() 143 | errD, errD_real, errD_wrong, errD_fake = \ 144 | compute_discriminator_loss(netD, real_imgs, fake_imgs, 145 | real_labels, fake_labels, 146 | label_one_hot, transf_matrices, transf_matrices_inv, self.gpus) 147 | errD.backward(retain_graph=True) 148 | optimizerD.step() 149 | 
############################ 150 | # (2) Update G network 151 | ########################### 152 | netG.zero_grad() 153 | errG = compute_generator_loss(netD, fake_imgs, real_labels, label_one_hot, 154 | transf_matrices, transf_matrices_inv, self.gpus) 155 | errG_total = errG 156 | errG_total.backward() 157 | optimizerG.step() 158 | 159 | ############################ 160 | # (3) Log results 161 | ########################### 162 | count += 1 163 | if i % 500 == 0: 164 | summary_D = summary.scalar('D_loss', errD.item()) 165 | summary_D_r = summary.scalar('D_loss_real', errD_real) 166 | summary_D_w = summary.scalar('D_loss_wrong', errD_wrong) 167 | summary_D_f = summary.scalar('D_loss_fake', errD_fake) 168 | summary_G = summary.scalar('G_loss', errG.item()) 169 | 170 | self.summary_writer.add_summary(summary_D, count) 171 | self.summary_writer.add_summary(summary_D_r, count) 172 | self.summary_writer.add_summary(summary_D_w, count) 173 | self.summary_writer.add_summary(summary_D_f, count) 174 | self.summary_writer.add_summary(summary_G, count) 175 | 176 | # save the image result for each epoch 177 | with torch.no_grad(): 178 | inputs = (noise, transf_matrices_inv, label_one_hot) 179 | lr_fake, fake = nn.parallel.data_parallel(netG, inputs, self.gpus) 180 | real_img_cpu = pad_imgs(real_img_cpu) 181 | fake = pad_imgs(fake) 182 | save_img_results(real_img_cpu, fake, epoch, self.image_dir) 183 | if lr_fake is not None: 184 | save_img_results(None, lr_fake, epoch, self.image_dir) 185 | with torch.no_grad(): 186 | inputs = (noise, transf_matrices_inv, label_one_hot) 187 | lr_fake, fake = nn.parallel.data_parallel(netG, inputs, self.gpus) 188 | real_img_cpu = pad_imgs(real_img_cpu) 189 | fake = pad_imgs(fake) 190 | save_img_results(real_img_cpu, fake, epoch, self.image_dir) 191 | if lr_fake is not None: 192 | save_img_results(None, lr_fake, epoch, self.image_dir) 193 | end_t = time.time() 194 | print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f 195 | Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f 196 | Total Time: %.2fsec 197 | ''' 198 | % (epoch, self.max_epoch, i, len(data_loader), 199 | errD.item(), errG.item(), 200 | errD_real, errD_wrong, errD_fake, (end_t - start_t))) 201 | if epoch % self.snapshot_interval == 0: 202 | save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir) 203 | # 204 | save_model(netG, netD, optimizerG, optimizerD, epoch, self.model_dir) 205 | # 206 | self.summary_writer.close() 207 | 208 | def sample(self, datapath, num_samples=25, draw_bbox=True, num_digits_per_img=3, change_bbox_size=False): 209 | from PIL import Image, ImageDraw, ImageFont 210 | import cPickle as pickle 211 | import torchvision 212 | import torchvision.utils as vutils 213 | img_dir = os.path.join(datapath, "normal", "imgs/") 214 | netG, _ = self.load_network_stageI() 215 | netG.eval() 216 | test_set_size = 10000 217 | 218 | label, bbox = load_validation_data(datapath) 219 | if num_digits_per_img < 3: 220 | label = label[:, :num_digits_per_img, :] 221 | bbox = bbox[:, :num_digits_per_img, ...] 
222 | elif num_digits_per_img > 3: 223 | def get_one_hot(targets, nb_classes): 224 | res = np.eye(nb_classes)[np.array(targets).reshape(-1)] 225 | return res.reshape(list(targets.shape) + [nb_classes]) 226 | 227 | labels_sample = np.random.randint(0, 10, size=(bbox.shape[0], num_digits_per_img-3)) 228 | labels_sample = get_one_hot(labels_sample, 10) 229 | labels_new = np.zeros((label.shape[0], num_digits_per_img, 10)) 230 | labels_new[:, :3, :] = label 231 | labels_new[:, 3:, :] = labels_sample 232 | label = torch.from_numpy(labels_new) 233 | 234 | bboxes_x = np.random.random((bbox.shape[0], num_digits_per_img-3, 1)) 235 | bboxes_y = np.random.random((bbox.shape[0], num_digits_per_img-3, 1)) 236 | bboxes_w = np.random.randint(10, 20, size=(bbox.shape[0], num_digits_per_img-3, 1)) / 64.0 237 | bboxes_h = np.random.randint(16, 20, size=(bbox.shape[0], num_digits_per_img-3, 1)) / 64.0 238 | 239 | bbox_new_concat = np.concatenate((bboxes_x, bboxes_y, bboxes_w, bboxes_h), axis=2) 240 | bbox_new = np.zeros([bbox.shape[0], num_digits_per_img, 4]) 241 | bbox_new[:, :3, :] = bbox 242 | bbox_new[:, 3:, :] = bbox_new_concat 243 | bbox = torch.from_numpy(bbox_new) 244 | 245 | if change_bbox_size: 246 | bbox_idx = np.random.randint(0, bbox.shape[1]) 247 | scale_x = np.random.random(bbox.shape[0]) 248 | scale_x[scale_x < 0.5] = 0.5 249 | scale_y = np.random.random(bbox.shape[0]) 250 | scale_y[scale_y < 0.5] = 0.5 251 | 252 | bbox[:, bbox_idx, 2] *= torch.from_numpy(scale_x) 253 | bbox[:, bbox_idx, 3] *= torch.from_numpy(scale_y) 254 | 255 | filepath = os.path.join(datapath, "normal", 'filenames.pickle') 256 | with open(filepath, 'rb') as f: 257 | filenames = pickle.load(f) 258 | # path to save generated samples 259 | save_dir = cfg.NET_G[:cfg.NET_G.find('.pth')] + "_samples_" + str(num_digits_per_img) + "_digits" 260 | if change_bbox_size: 261 | save_dir += "_change_bbox_size" 262 | print("Saving {} to {}:".format(num_samples, save_dir)) 263 | mkdir_p(save_dir) 264 | if cfg.CUDA: 265 | bbox = bbox.cuda() 266 | label_one_hot = label.cuda().float() 267 | 268 | ####################################### 269 | bbox_ = bbox.clone() 270 | bbox = bbox.view(-1, 4) 271 | transf_matrices_inv = compute_transformation_matrix_inverse(bbox).float() 272 | transf_matrices_inv = transf_matrices_inv.view(test_set_size, num_digits_per_img, 2, 3) 273 | ####################################### 274 | 275 | nz = cfg.Z_DIM 276 | noise = Variable(torch.FloatTensor(9, nz)) 277 | if cfg.CUDA: 278 | noise = noise.cuda() 279 | 280 | imsize = 64 281 | 282 | for count in range(num_samples): 283 | index = int(np.random.randint(0, test_set_size, 1)) 284 | key = filenames[index].split("/")[-1] 285 | img_name = img_dir + key 286 | img = Image.open(img_name) 287 | val_image = torchvision.transforms.functional.to_tensor(img) 288 | val_image = val_image.view(1, 1, imsize, imsize) 289 | val_image = (val_image - 0.5) * 2 290 | 291 | transf_matrices_inv_batch = transf_matrices_inv[index] 292 | label_one_hot_batch = label_one_hot[index] 293 | 294 | transf_matrices_inv_batch = transf_matrices_inv_batch.view(1, num_digits_per_img, 2, 3).repeat(9, 1, 1, 1) 295 | label_one_hot_batch = label_one_hot_batch.view(1, num_digits_per_img, 10).repeat(9, 1, 1) 296 | 297 | if cfg.CUDA: 298 | label_one_hot_batch = label_one_hot_batch.cuda() 299 | 300 | ####################################################### 301 | # (2) Generate fake images 302 | ###################################################### 303 | noise.data.normal_(0, 1) 304 | inputs = (noise, 
transf_matrices_inv_batch, label_one_hot_batch, num_digits_per_img) 305 | _, fake_imgs = nn.parallel.data_parallel(netG, inputs, self.gpus) 306 | 307 | data_img = torch.FloatTensor(20, 1, imsize, imsize).fill_(0) 308 | data_img[0] = val_image 309 | data_img[1:10] = fake_imgs 310 | 311 | if draw_bbox: 312 | for idx in range(num_digits_per_img): 313 | x, y, w, h = tuple([int(imsize*x) for x in bbox_[index, idx]]) 314 | w = imsize-1 if w > imsize-1 else w 315 | h = imsize-1 if h > imsize-1 else h 316 | while x + w >= 64: 317 | x -= 1 318 | w -= 1 319 | while y + h >= 64: 320 | y -= 1 321 | h -= 1 322 | if x <= -1: 323 | break 324 | data_img[:10, :, y, x:x + w] = 1 325 | data_img[:10, :, y:y + h, x] = 1 326 | data_img[:10, :, y+h, x:x + w] = 1 327 | data_img[:10, :, y:y + h, x + w] = 1 328 | 329 | # write digit identities into image 330 | text_img = Image.new('L', (imsize*10, imsize), color = 'white') 331 | d = ImageDraw.Draw(text_img) 332 | label = label_one_hot_batch[0] 333 | label = label.cpu().numpy() 334 | label = np.argmax(label, axis=1) 335 | label = ", ".join([str(label[_]) for _ in range(num_digits_per_img)]) 336 | d.text((10,10), label) 337 | text_img = torchvision.transforms.functional.to_tensor(text_img) 338 | text_img = torch.chunk(text_img, 10, 2) 339 | text_img = torch.cat([text_img[i].view(1, 1, imsize, imsize) for i in range(10)], 0) 340 | data_img[10:] = text_img 341 | vutils.save_image(data_img, '{}/vis_{}.png'.format(save_dir, count), normalize=True, nrow=10) 342 | print("Saved {} files to {}".format(count+1, save_dir)) 343 | 344 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data 2 | Check out the instructions in the *Data* section of the [main readme](https://github.com/tohinz/multiple-objects-gan#data). 
3 | -------------------------------------------------------------------------------- /examples/clevr_cogent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/clevr_cogent.png -------------------------------------------------------------------------------- /examples/clevr_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/clevr_example.png -------------------------------------------------------------------------------- /examples/clevr_generated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/clevr_generated.png -------------------------------------------------------------------------------- /examples/clevr_real_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/clevr_real_examples.png -------------------------------------------------------------------------------- /examples/coco_attngan_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/coco_attngan_example.png -------------------------------------------------------------------------------- /examples/coco_bbox_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/coco_bbox_example.png -------------------------------------------------------------------------------- /examples/coco_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/coco_examples.png -------------------------------------------------------------------------------- /examples/coco_generated_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/coco_generated_examples.png -------------------------------------------------------------------------------- /examples/coco_no_bbox.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/coco_no_bbox.png -------------------------------------------------------------------------------- /examples/coco_pathway.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/coco_pathway.png -------------------------------------------------------------------------------- /examples/coco_stackgan_example.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/coco_stackgan_example.png -------------------------------------------------------------------------------- /examples/coco_stuff_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/coco_stuff_example.png -------------------------------------------------------------------------------- /examples/d_final_pathway.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/d_final_pathway.png -------------------------------------------------------------------------------- /examples/d_global_pathway.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/d_global_pathway.png -------------------------------------------------------------------------------- /examples/d_object_pathway.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/d_object_pathway.png -------------------------------------------------------------------------------- /examples/datasets_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/datasets_examples.png -------------------------------------------------------------------------------- /examples/g_final_pathway.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/g_final_pathway.png -------------------------------------------------------------------------------- /examples/g_global_pathway.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/g_global_pathway.png -------------------------------------------------------------------------------- /examples/g_object_pathway.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/g_object_pathway.png -------------------------------------------------------------------------------- /examples/gan_graphic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/gan_graphic.png -------------------------------------------------------------------------------- /examples/label_encoding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/label_encoding.png -------------------------------------------------------------------------------- /examples/layout_encoding.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/layout_encoding.png -------------------------------------------------------------------------------- /examples/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/model.png -------------------------------------------------------------------------------- /examples/multi-mnist_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/multi-mnist_example.png -------------------------------------------------------------------------------- /examples/multi_mnist_digit_bottom_half.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/multi_mnist_digit_bottom_half.png -------------------------------------------------------------------------------- /examples/multi_mnist_digit_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/multi_mnist_digit_example.png -------------------------------------------------------------------------------- /examples/multi_mnist_digit_generalization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/multi_mnist_digit_generalization.png -------------------------------------------------------------------------------- /examples/multi_mnist_real.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/multi_mnist_real.png -------------------------------------------------------------------------------- /examples/multi_mnist_size_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/examples/multi_mnist_size_example.png -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | Put the pretrained models here. 
2 | -------------------------------------------------------------------------------- /poster/Generating Multiple Objects at Spatially Distinct Locations.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/multiple-objects-gan/6b0bd4d559f59897b0b9df61c2972ab142adcbe9/poster/Generating Multiple Objects at Spatially Distinct Locations.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | backports.functools-lru-cache==1.5 2 | cloudpickle==0.6.1 3 | cycler==0.10.0 4 | dask==0.20.0 5 | decorator==4.3.0 6 | easydict==1.9 7 | funcsigs==1.0.2 8 | kiwisolver==1.0.1 9 | matplotlib==2.2.3 10 | mock==2.0.0 11 | networkx==2.2 12 | nltk==3.3 13 | numpy==1.15.4 14 | pbr==5.1.0 15 | Pillow==5.3.0 16 | pkg-resources==0.0.0 17 | protobuf==3.6.1 18 | pyparsing==2.3.0 19 | python-dateutil==2.7.5 20 | pytz==2018.7 21 | PyWavelets==1.0.1 22 | PyYAML==3.13 23 | scikit-image==0.14.1 24 | scipy==1.1.0 25 | six==1.11.0 26 | subprocess32==3.5.3 27 | tensorboard==1.0.0a4 28 | toolz==0.9.0 29 | torch==0.4.1 30 | torchfile==0.1.0 31 | torchvision==0.2.1 32 | Werkzeug==0.14.1 33 | -------------------------------------------------------------------------------- /sample.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DATASET="$1" 3 | 4 | if [ "$DATASET" = "mnist" ]; then 5 | echo "Sampling from the Multi-MNIST data set." 6 | echo "Going to Multi-MNIST folder." 7 | cd code/multi-mnist 8 | python main.py --cfg cfg/mnist_eval.yml 9 | elif [ "$DATASET" = "clevr" ]; then 10 | echo "Sampling from the CLEVR data set." 11 | echo "Going to CLEVR folder." 12 | cd code/clevr 13 | python main.py --cfg cfg/clevr_eval.yml 14 | elif [ "$DATASET" = "coco-stackgan-1" ]; then 15 | echo "Sampling from the MS-COCO data set." 16 | echo "Going to MS-COCO folder." 17 | cd code/coco/stackgan 18 | python main.py --cfg cfg/coco_s1_eval.yml 19 | elif [ "$DATASET" = "coco-stackgan-2" ]; then 20 | echo "Sampling from the MS-COCO data set." 21 | echo "Going to MS-COCO folder." 22 | cd code/coco/stackgan 23 | python main.py --cfg cfg/coco_s2_eval.yml 24 | elif [ "$DATASET" = "coco-attngan" ]; then 25 | echo "Sampling from the MS-COCO data set." 26 | echo "Going to MS-COCO folder." 27 | cd code/coco/attngan 28 | python main.py --cfg cfg/coco_eval.yml 29 | else 30 | echo "Dataset argument must be either \"mnist\", \"clevr\", \"coco-stackgan-1\", \"coco-stackgan-2\", or \"coco-attngan\"." 31 | fi 32 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DATASET="$1" 3 | GPU="$2" 4 | 5 | if [ "$DATASET" = "mnist" ]; then 6 | echo "Starting training on the Multi-MNIST data set." 7 | cd code/multi-mnist 8 | python main.py --cfg cfg/mnist_train.yml --gpu "$GPU" 9 | cd ../../ 10 | elif [ "$DATASET" = "clevr" ]; then 11 | echo "Starting training on the CLEVR data set." 12 | cd code/clevr 13 | python main.py --cfg cfg/clevr_train.yml --gpu "$GPU" 14 | cd ../../ 15 | elif [ "$DATASET" = "coco-stackgan-1" ]; then 16 | echo "Starting training on the MS-COCO data set."
17 | cd code/coco/stackgan 18 | python main.py --cfg cfg/coco_s1_train.yml --gpu "$GPU" 19 | cd ../../../ 20 | elif [ "$DATASET" = "coco-stackgan-2" ]; then 21 | echo "Starting training on the MS-COCO data set." 22 | cd code/coco/stackgan 23 | python main.py --cfg cfg/coco_s2_train.yml --gpu "$GPU" 24 | cd ../../../ 25 | elif [ "$DATASET" = "coco-attngan" ]; then 26 | echo "Starting training on the MS-COCO data set." 27 | cd code/coco/attngan 28 | python main.py --cfg cfg/coco_train.yml --gpu "$GPU" 29 | cd ../../../ 30 | else 31 | echo "Dataset argument must be either \"mnist\", \"clevr\", \"coco-stackgan-1\", \"coco-stackgan-2\", or \"coco-attngan\"." 32 | fi 33 | --------------------------------------------------------------------------------