├── files ├── approach.png ├── teaser.png ├── qual_res1.png ├── qual_res2.png ├── qual_res3.png └── finegan_demo.gif ├── code ├── miscc │ ├── __init__.py │ ├── utils.py │ └── config.py ├── cfg │ ├── eval.yml │ └── train.yml ├── main.py ├── datasets.py ├── inception.py ├── model.py └── trainer.py ├── models └── README.md ├── data └── README.md ├── LICENSE └── README.md /files/approach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kkanshul/finegan/HEAD/files/approach.png -------------------------------------------------------------------------------- /files/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kkanshul/finegan/HEAD/files/teaser.png -------------------------------------------------------------------------------- /files/qual_res1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kkanshul/finegan/HEAD/files/qual_res1.png -------------------------------------------------------------------------------- /files/qual_res2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kkanshul/finegan/HEAD/files/qual_res2.png -------------------------------------------------------------------------------- /files/qual_res3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kkanshul/finegan/HEAD/files/qual_res3.png -------------------------------------------------------------------------------- /files/finegan_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kkanshul/finegan/HEAD/files/finegan_demo.gif -------------------------------------------------------------------------------- /code/miscc/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | -------------------------------------------------------------------------------- /code/miscc/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | 4 | def mkdir_p(path): 5 | try: 6 | os.makedirs(path) 7 | except OSError as exc: # Python >2.5 8 | if exc.errno == errno.EEXIST and os.path.isdir(path): 9 | pass 10 | else: 11 | raise 12 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | ## Download pretrained models 2 | Pretrained generator models for CUB, Stanford Dogs are available at this [link](https://drive.google.com/file/d/1cKJAXRDQ-_a76bHWRqcIdmPXqpN8a1lR/view?usp=sharing). Download and extract them in the `models` directory. 3 | ```bash 4 | cd models 5 | unzip netG.zip 6 | cd .. 7 | ``` 8 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | ## Data 2 | **Note**: You only need to download the data if you wish to train your own model. 3 | 4 | Download the formatted CUB data from this [link](https://drive.google.com/file/d/1ardy8L7Cb-Vn1ynQigaXpX_JHl0dhh2M/view?usp=sharing) and extract it inside the `data` directory 5 | ```bash 6 | cd data 7 | unzip birds.zip 8 | cd .. 
9 | ``` 10 | -------------------------------------------------------------------------------- /code/cfg/eval.yml: -------------------------------------------------------------------------------- 1 | DATASET_NAME: 'birds' 2 | SAVE_DIR: '../output/' 3 | GPU_ID: '3' 4 | WORKERS: 1 # 4 5 | 6 | SUPER_CATEGORIES: 20 # For CUB 7 | FINE_GRAINED_CATEGORIES: 200 # For CUB 8 | TEST_CHILD_CLASS: 125 # specify any value [0, FINE_GRAINED_CATEGORIES - 1] 9 | TEST_PARENT_CLASS: 0 # specify any value [0, SUPER_CATEGORIES - 1] 10 | TEST_BACKGROUND_CLASS: 0 # specify any value [0, FINE_GRAINED_CATEGORIES - 1] 11 | TIED_CODES: True 12 | 13 | TRAIN: 14 | FLAG: False 15 | NET_G: '../models/netG/netG_birds.pth' 16 | BATCH_SIZE: 1 17 | 18 | 19 | GAN: 20 | DF_DIM: 64 21 | GF_DIM: 64 22 | Z_DIM: 100 23 | R_NUM: 2 24 | -------------------------------------------------------------------------------- /code/cfg/train.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: '3stages' 2 | DATASET_NAME: 'birds' 3 | DATA_DIR: '../data/birds' 4 | SAVE_DIR: '../output/vis' 5 | GPU_ID: '0' 6 | WORKERS: 4 7 | 8 | SUPER_CATEGORIES: 20 # For CUB 9 | FINE_GRAINED_CATEGORIES: 200 # For CUB 10 | TIED_CODES: True # Do NOT change this to False during training. 11 | 12 | TREE: 13 | BRANCH_NUM: 3 14 | 15 | TRAIN: 16 | FLAG: True 17 | NET_G: '' # Specify the generator path to resume training 18 | NET_D: '' # Specify the discriminator path to resume training 19 | BATCH_SIZE: 16 20 | MAX_EPOCH: 600 21 | HARDNEG_MAX_ITER: 1500 22 | SNAPSHOT_INTERVAL: 4000 23 | SNAPSHOT_INTERVAL_HARDNEG: 500 24 | DISCRIMINATOR_LR: 0.0002 25 | GENERATOR_LR: 0.0002 26 | 27 | GAN: 28 | DF_DIM: 64 29 | GF_DIM: 64 30 | Z_DIM: 100 31 | R_NUM: 2 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, Krishna Kumar Singh, Utkarsh Ojha, Yong Jae Lee 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
24 | -------------------------------------------------------------------------------- /code/miscc/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import os.path as osp 5 | import numpy as np 6 | from easydict import EasyDict as edict 7 | 8 | 9 | __C = edict() 10 | cfg = __C 11 | 12 | # Dataset name: flowers, birds 13 | __C.DATASET_NAME = 'birds' 14 | __C.CONFIG_NAME = '' 15 | __C.DATA_DIR = '' 16 | __C.SAVE_DIR = '' 17 | __C.GPU_ID = '0' 18 | __C.CUDA = True 19 | 20 | __C.WORKERS = 6 21 | 22 | __C.TREE = edict() 23 | __C.TREE.BRANCH_NUM = 3 24 | __C.TREE.BASE_SIZE = 64 25 | __C.SUPER_CATEGORIES = 20 26 | __C.FINE_GRAINED_CATEGORIES = 200 27 | __C.TEST_CHILD_CLASS = 0 28 | __C.TEST_PARENT_CLASS = 0 29 | __C.TEST_BACKGROUND_CLASS = 0 30 | __C.TIED_CODES = True 31 | 32 | # Test options 33 | __C.TEST = edict() 34 | 35 | # Training options 36 | __C.TRAIN = edict() 37 | __C.TRAIN.BATCH_SIZE = 64 38 | __C.TRAIN.BG_LOSS_WT = 10 39 | __C.TRAIN.VIS_COUNT = 64 40 | __C.TRAIN.MAX_EPOCH = 600 41 | __C.TRAIN.HARDNEG_MAX_ITER = 1500 42 | __C.TRAIN.SNAPSHOT_INTERVAL = 2000 43 | __C.TRAIN.SNAPSHOT_INTERVAL_HARDNEG = 500 44 | __C.TRAIN.DISCRIMINATOR_LR = 2e-4 45 | __C.TRAIN.GENERATOR_LR = 2e-4 46 | __C.TRAIN.FLAG = True 47 | __C.TRAIN.NET_G = '' 48 | __C.TRAIN.NET_D = '' 49 | 50 | 51 | # Modal options 52 | __C.GAN = edict() 53 | __C.GAN.DF_DIM = 64 54 | __C.GAN.GF_DIM = 64 55 | __C.GAN.Z_DIM = 100 56 | __C.GAN.NETWORK_TYPE = 'default' 57 | __C.GAN.R_NUM = 2 58 | 59 | 60 | 61 | 62 | def _merge_a_into_b(a, b): 63 | """Merge config dictionary a into config dictionary b, clobbering the 64 | options in b whenever they are also specified in a. 65 | """ 66 | if type(a) is not edict: 67 | return 68 | 69 | for k, v in a.iteritems(): 70 | # a must specify keys that are in b 71 | if not b.has_key(k): 72 | raise KeyError('{} is not a valid config key'.format(k)) 73 | 74 | # the types must match, too 75 | old_type = type(b[k]) 76 | if old_type is not type(v): 77 | if isinstance(b[k], np.ndarray): 78 | v = np.array(v, dtype=b[k].dtype) 79 | else: 80 | raise ValueError(('Type mismatch ({} vs. 
{}) ' 81 | 'for config key: {}').format(type(b[k]), 82 | type(v), k)) 83 | 84 | # recursively merge dicts 85 | if type(v) is edict: 86 | try: 87 | _merge_a_into_b(a[k], b[k]) 88 | except: 89 | print('Error under config key: {}'.format(k)) 90 | raise 91 | else: 92 | b[k] = v 93 | 94 | 95 | def cfg_from_file(filename): 96 | """Load a config file and merge it into the default options.""" 97 | import yaml 98 | with open(filename, 'r') as f: 99 | yaml_cfg = edict(yaml.load(f)) 100 | 101 | _merge_a_into_b(yaml_cfg, __C) 102 | -------------------------------------------------------------------------------- /code/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torchvision.transforms as transforms 4 | 5 | import argparse 6 | import os 7 | import random 8 | import sys 9 | import pprint 10 | import datetime 11 | import dateutil.tz 12 | import time 13 | import pickle 14 | 15 | dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.'))) 16 | sys.path.append(dir_path) 17 | 18 | 19 | from miscc.config import cfg, cfg_from_file 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser(description='Train a GAN network') 24 | parser.add_argument('--cfg', dest='cfg_file', 25 | help='optional config file', 26 | default='cfg/birds_proGAN.yml', type=str) 27 | parser.add_argument('--gpu', dest='gpu_id', type=str, default='-1') 28 | parser.add_argument('--data_dir', dest='data_dir', type=str, default='') 29 | parser.add_argument('--manualSeed', type=int, help='manual seed') 30 | #parser.add_argument('--config_key',dest='config_key', type=str, help='configuration name', default = 'finegan_birds') 31 | args = parser.parse_args() 32 | return args 33 | 34 | 35 | if __name__ == "__main__": 36 | args = parse_args() 37 | if args.cfg_file is not None: 38 | cfg_from_file(args.cfg_file) 39 | 40 | if args.gpu_id != '-1': 41 | cfg.GPU_ID = args.gpu_id 42 | else: 43 | cfg.CUDA = False 44 | 45 | if args.data_dir != '': 46 | cfg.DATA_DIR = args.data_dir 47 | if cfg.TRAIN.FLAG: 48 | print('Using config:') 49 | pprint.pprint(cfg) 50 | 51 | if not cfg.TRAIN.FLAG: 52 | args.manualSeed = 45 # Change this to have different random seed during evaluation 53 | 54 | elif args.manualSeed is None: 55 | args.manualSeed = random.randint(1, 10000) 56 | random.seed(args.manualSeed) 57 | torch.manual_seed(args.manualSeed) 58 | if cfg.CUDA: 59 | torch.cuda.manual_seed_all(args.manualSeed) 60 | 61 | # Evaluation part 62 | if not cfg.TRAIN.FLAG: 63 | from trainer import FineGAN_evaluator as evaluator 64 | algo = evaluator() 65 | algo.evaluate_finegan() 66 | 67 | # Training part 68 | else: 69 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 70 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 71 | output_dir = '../output/%s_%s' % \ 72 | (cfg.DATASET_NAME, timestamp) 73 | pkl_filename = 'cfg.pickle' 74 | 75 | if not os.path.exists(output_dir): 76 | os.makedirs(output_dir) 77 | 78 | with open(os.path.join(output_dir, pkl_filename), 'wb') as pk: 79 | pickle.dump(cfg, pk, protocol=pickle.HIGHEST_PROTOCOL) 80 | 81 | bshuffle = True 82 | 83 | # Get data loader 84 | imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM-1)) 85 | image_transform = transforms.Compose([ 86 | transforms.Scale(int(imsize * 76 / 64)), 87 | transforms.RandomCrop(imsize), 88 | transforms.RandomHorizontalFlip()]) 89 | 90 | 91 | from datasets import Dataset 92 | dataset = Dataset(cfg.DATA_DIR, 93 | base_size=cfg.TREE.BASE_SIZE, 94 | 
transform=image_transform) 95 | assert dataset 96 | num_gpu = len(cfg.GPU_ID.split(',')) 97 | dataloader = torch.utils.data.DataLoader( 98 | dataset, batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu, 99 | drop_last=True, shuffle=bshuffle, num_workers=int(cfg.WORKERS)) 100 | 101 | 102 | from trainer import FineGAN_trainer as trainer 103 | algo = trainer(output_dir, dataloader, imsize) 104 | 105 | start_t = time.time() 106 | algo.train() 107 | end_t = time.time() 108 | print('Total time for training:', end_t - start_t) 109 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FineGAN 2 | Pytorch implementation for learning to synthesize images in a hierarchical, stagewise manner by disentangling background, object shape and object appearance. 3 |
4 |
5 | 6 |
7 |
8 | 9 | 10 | 11 | ### FineGAN: Unsupervised Hierarchical Disentanglement for Fine-grained Object Generation and Discovery 12 | [Krishna Kumar Singh*](http://krsingh.cs.ucdavis.edu), [Utkarsh Ojha*](https://utkarshojha.github.io/), [Yong Jae Lee](http://web.cs.ucdavis.edu/~yjlee/) 13 |
14 | [project](http://krsingh.cs.ucdavis.edu/krishna_files/papers/finegan/index.html) | 15 | [arxiv](https://arxiv.org/abs/1811.11155) | [demo video](https://www.youtube.com/watch?v=tkk0SeWGu-8) | [talk video](https://www.youtube.com/watch?v=8qkrPSjONhA&t=51m40s) 16 |
17 | **[CVPR 2019 (Oral Presentation)](http://cvpr2019.thecvf.com/)** 18 | ## Architecture 19 |
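FineGAN generates an image in three stages: a background stage, a parent stage that captures the object's shape, and a child stage that captures its appearance. Each later stage composites its masked foreground onto the output of the previous stage. The toy sketch below reproduces only the compositing arithmetic of `G_NET.forward` in `code/model.py`; random tensors stand in for the actual network outputs, so the variable names here are illustrative only.

```python
import torch

# Illustrative stand-ins for the generator outputs; in FineGAN these come from
# the background, parent and child sub-networks of G_NET (code/model.py).
B = 4                                    # batch size
bg_img = torch.rand(B, 3, 128, 128)      # background stage image
p_fg   = torch.rand(B, 3, 128, 128)      # parent foreground (object shape)
p_mask = torch.rand(B, 1, 128, 128)      # parent mask, values in [0, 1]
p_img  = p_mask * p_fg + (1 - p_mask) * bg_img   # parent-stage image

c_fg   = torch.rand(B, 3, 128, 128)      # child foreground (object appearance)
c_mask = torch.rand(B, 1, 128, 128)      # child mask, values in [0, 1]
c_img  = c_mask * c_fg + (1 - c_mask) * p_img    # final child-stage image
print(c_img.shape)                       # torch.Size([4, 3, 128, 128])
```

Three discriminators (`D_NET` with stage numbers 0, 1 and 2 in `code/model.py`) correspond to the background, parent and child stages respectively.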
20 | 21 | 22 | ## Requirements 23 | - Linux 24 | - Python 2.7 25 | - Pytorch 0.4.1 26 | - TensorboardX 1.2 27 | - NVIDIA GPU + CUDA CuDNN 28 | 29 | ## Getting started 30 | ### Clone the repository 31 | ```bash 32 | git clone https://github.com/kkanshul/finegan 33 | cd finegan 34 | ``` 35 | ### Setting up the data 36 | **Note**: You only need to download the data if you wish to train your own model. 37 | 38 | Download the formatted CUB data from this [link](https://drive.google.com/file/d/1ardy8L7Cb-Vn1ynQigaXpX_JHl0dhh2M/view?usp=sharing) and extract it inside the `data` directory. 39 | ```bash 40 | cd data 41 | unzip birds.zip 42 | cd .. 43 | ``` 44 | ### Downloading pretrained models 45 | 46 | Pretrained generator models for CUB, Stanford Dogs are available at this [link](https://drive.google.com/file/d/1cKJAXRDQ-_a76bHWRqcIdmPXqpN8a1lR/view?usp=sharing). Download and extract them in the `models` directory. 47 | ```bash 48 | cd models 49 | unzip netG.zip 50 | cd ../code/ 51 | ``` 52 | ## Evaluating the model 53 | In `cfg/eval.yml`: 54 | - Specify the model path in `TRAIN.NET_G`. 55 | - Specify the output directory to save the generated images in `SAVE_DIR`. 56 | - Specify the number of super and fine-grained categories in `SUPER_CATEGORIES` and `FINE_GRAINED_CATEGORIES` according to our [paper](https://arxiv.org/abs/1811.11155). 57 | - Specify the option for using 'tied' latent codes in `TIED_CODES`: 58 | - if `True`, specify the child code in `TEST_CHILD_CLASS`. The background and parent codes are derived through the child code in this case (see the sketch at the end of this README). 59 | - if `False`, i.e. no relationship between parent, child or background code, specify each of them in `TEST_PARENT_CLASS`, `TEST_CHILD_CLASS` and `TEST_BACKGROUND_CLASS` respectively. 60 | - Run `python main.py --cfg cfg/eval.yml --gpu 0` 61 | 62 | ## Training your own model 63 | In `cfg/train.yml`: 64 | - Specify the dataset location in `DATA_DIR`. 65 | - **NOTE**: If you wish to train this on your own (different) dataset, please make sure it is formatted in a way similar to the CUB dataset that we've provided. 66 | - Specify the number of super and fine-grained categories that you wish for FineGAN to discover, in `SUPER_CATEGORIES` and `FINE_GRAINED_CATEGORIES`. 67 | - Specify the training hyperparameters in `TRAIN`. 68 | - Run `python main.py --cfg cfg/train.yml --gpu 0` 69 | 70 | ## Sample generation results of FineGAN 71 | ### 1. Stage wise image generation results 72 | 73 | 74 | ### 2. Grouping among the generated images (child). 75 | 76 | 77 | 78 | ## Citation 79 | If you find this code useful in your research, consider citing our work: 80 | ``` 81 | @inproceedings{singh-cvpr2019, 82 | title = {FineGAN: Unsupervised Hierarchical Disentanglement for Fine-Grained Object Generation and Discovery}, 83 | author = {Krishna Kumar Singh and Utkarsh Ojha and Yong Jae Lee}, 84 | booktitle = {CVPR}, 85 | year = {2019} 86 | } 87 | ``` 88 | ## Acknowledgement 89 | We thank the authors of [StackGAN++: Realistic Image Synthesis with Stacked Generative Adversarial Networks](https://arxiv.org/abs/1710.10916) for releasing their source code. 90 | ## Contact 91 | For any questions regarding our paper or code, contact [Krishna Kumar Singh](mailto:krsingh@ucdavis.edu) and [Utkarsh Ojha](mailto:uojha@ucdavis.edu). 
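## Note on tied latent codes

When `TIED_CODES` is `True` (the value set in `cfg/eval.yml`, and one that must stay `True` during training per `cfg/train.yml`), the parent and background codes are not chosen independently: they are derived from the child code by `child_to_parent` in `code/model.py`. The sketch below illustrates that mapping with the CUB defaults; `tied_codes` is a hypothetical helper written only for this illustration.

```python
def tied_codes(test_child_class, fine_grained_categories=200, super_categories=20):
    # Mirrors child_to_parent() in code/model.py: the parent class is the child
    # class integer-divided by the child-to-parent category ratio, and the
    # background code simply reuses the child code.
    ratio = fine_grained_categories // super_categories   # 200 / 20 = 10 for CUB
    parent_class = test_child_class // ratio
    background_class = test_child_class
    return parent_class, background_class

# Example: TEST_CHILD_CLASS = 125 (the value in cfg/eval.yml) maps to parent class 12.
print(tied_codes(125))   # (12, 125)
```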
92 | -------------------------------------------------------------------------------- /code/datasets.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | import sys 6 | 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | from PIL import Image 10 | import PIL 11 | import os 12 | import os.path 13 | import pickle 14 | import random 15 | import numpy as np 16 | import pandas as pd 17 | from miscc.config import cfg 18 | 19 | import torch.utils.data as data 20 | from PIL import Image 21 | import os 22 | import os.path 23 | import six 24 | import string 25 | import sys 26 | import torch 27 | from copy import deepcopy 28 | if sys.version_info[0] == 2: 29 | import cPickle as pickle 30 | else: 31 | import pickle 32 | 33 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', 34 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP'] 35 | 36 | 37 | def is_image_file(filename): 38 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 39 | 40 | 41 | def get_imgs(img_path, imsize, bbox=None, 42 | transform=None, normalize=None): 43 | img = Image.open(img_path).convert('RGB') 44 | width, height = img.size 45 | if bbox is not None: 46 | r = int(np.maximum(bbox[2], bbox[3]) * 0.75) 47 | center_x = int((2 * bbox[0] + bbox[2]) / 2) 48 | center_y = int((2 * bbox[1] + bbox[3]) / 2) 49 | y1 = np.maximum(0, center_y - r) 50 | y2 = np.minimum(height, center_y + r) 51 | x1 = np.maximum(0, center_x - r) 52 | x2 = np.minimum(width, center_x + r) 53 | fimg = deepcopy(img) 54 | fimg_arr = np.array(fimg) 55 | fimg = Image.fromarray(fimg_arr) 56 | cimg = img.crop([x1, y1, x2, y2]) 57 | 58 | if transform is not None: 59 | cimg = transform(cimg) 60 | 61 | 62 | retf = [] 63 | retc = [] 64 | re_cimg = transforms.Scale(imsize[1])(cimg) 65 | retc.append(normalize(re_cimg)) 66 | 67 | # We use full image to get background patches 68 | 69 | # We resize the full image to be 126 X 126 (instead of 128 X 128) for the full coverage of the input (full) image by 70 | # the receptive fields of the final convolution layer of background discriminator 71 | 72 | my_crop_width = 126 73 | re_fimg = transforms.Scale(int(my_crop_width * 76 / 64))(fimg) 74 | re_width, re_height = re_fimg.size 75 | 76 | # random cropping 77 | x_crop_range = re_width-my_crop_width 78 | y_crop_range = re_height-my_crop_width 79 | 80 | crop_start_x = np.random.randint(x_crop_range) 81 | crop_start_y = np.random.randint(y_crop_range) 82 | 83 | crop_re_fimg = re_fimg.crop([crop_start_x, crop_start_y, crop_start_x + my_crop_width, crop_start_y + my_crop_width]) 84 | warped_x1 = bbox[0] * re_width / width 85 | warped_y1 = bbox[1] * re_height / height 86 | warped_x2 = warped_x1 + (bbox[2] * re_width / width) 87 | warped_y2 = warped_y1 + (bbox[3] * re_height / height) 88 | 89 | warped_x1 =min(max(0, warped_x1 - crop_start_x), my_crop_width) 90 | warped_y1 =min(max(0, warped_y1 - crop_start_y), my_crop_width) 91 | warped_x2 =max(min(my_crop_width, warped_x2 - crop_start_x),0) 92 | warped_y2 =max(min(my_crop_width, warped_y2 - crop_start_y),0) 93 | 94 | # random flipping 95 | random_flag=np.random.randint(2) 96 | if(random_flag == 0): 97 | crop_re_fimg = crop_re_fimg.transpose(Image.FLIP_LEFT_RIGHT) 98 | flipped_x1 = my_crop_width - warped_x2 99 | flipped_x2 = my_crop_width - warped_x1 100 | warped_x1 = flipped_x1 101 | 
warped_x2 = flipped_x2 102 | 103 | retf.append(normalize(crop_re_fimg)) 104 | 105 | warped_bbox = [] 106 | warped_bbox.append(warped_y1) 107 | warped_bbox.append(warped_x1) 108 | warped_bbox.append(warped_y2) 109 | warped_bbox.append(warped_x2) 110 | 111 | return retf, retc, warped_bbox 112 | 113 | 114 | 115 | class Dataset(data.Dataset): 116 | def __init__(self, data_dir, base_size=64, transform = None): 117 | 118 | self.transform = transform 119 | self.norm = transforms.Compose([ 120 | transforms.ToTensor(), 121 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 122 | 123 | self.imsize = [] 124 | for i in range(cfg.TREE.BRANCH_NUM): 125 | self.imsize.append(base_size) 126 | base_size = base_size * 2 127 | 128 | self.data = [] 129 | self.data_dir = data_dir 130 | self.bbox = self.load_bbox() 131 | self.filenames = self.load_filenames(data_dir) 132 | if cfg.TRAIN.FLAG: 133 | self.iterator = self.prepair_training_pairs 134 | else: 135 | self.iterator = self.prepair_test_pairs 136 | 137 | 138 | # only used in background stage 139 | def load_bbox(self): 140 | # Returns a dictionary with image filename as 'key' and its bounding box coordinates as 'value' 141 | 142 | data_dir = self.data_dir 143 | bbox_path = os.path.join(data_dir, 'bounding_boxes.txt') 144 | df_bounding_boxes = pd.read_csv(bbox_path, 145 | delim_whitespace=True, 146 | header=None).astype(int) 147 | filepath = os.path.join(data_dir, 'images.txt') 148 | df_filenames = \ 149 | pd.read_csv(filepath, delim_whitespace=True, header=None) 150 | filenames = df_filenames[1].tolist() 151 | print('Total filenames: ', len(filenames), filenames[0]) 152 | filename_bbox = {img_file[:-4]: [] for img_file in filenames} 153 | numImgs = len(filenames) 154 | for i in xrange(0, numImgs): 155 | bbox = df_bounding_boxes.iloc[i][1:].tolist() 156 | key = filenames[i][:-4] 157 | filename_bbox[key] = bbox 158 | return filename_bbox 159 | 160 | 161 | def load_filenames(self, data_dir): 162 | filepath = os.path.join(data_dir, 'images.txt') 163 | df_filenames = \ 164 | pd.read_csv(filepath, delim_whitespace=True, header=None) 165 | filenames = df_filenames[1].tolist() 166 | filenames = [fname[:-4] for fname in filenames]; 167 | print('Load filenames from: %s (%d)' % (filepath, len(filenames))) 168 | return filenames 169 | 170 | 171 | def prepair_training_pairs(self, index): 172 | key = self.filenames[index] 173 | if self.bbox is not None: 174 | bbox = self.bbox[key] 175 | else: 176 | bbox = None 177 | data_dir = self.data_dir 178 | img_name = '%s/images/%s.jpg' % (data_dir, key) 179 | fimgs, cimgs, warped_bbox = get_imgs(img_name, self.imsize, 180 | bbox, self.transform, normalize=self.norm) 181 | 182 | rand_class= random.sample(range(cfg.FINE_GRAINED_CATEGORIES),1); # Randomly generating child code during training 183 | c_code = torch.zeros([cfg.FINE_GRAINED_CATEGORIES,]) 184 | c_code[rand_class] = 1 185 | 186 | return fimgs, cimgs, c_code, key, warped_bbox 187 | 188 | def prepair_test_pairs(self, index): 189 | key = self.filenames[index] 190 | if self.bbox is not None: 191 | bbox = self.bbox[key] 192 | else: 193 | bbox = None 194 | data_dir = self.data_dir 195 | c_code = self.c_code[index, :, :] 196 | img_name = '%s/images/%s.jpg' % (data_dir, key) 197 | _, imgs, _ = get_imgs(img_name, self.imsize, 198 | bbox, self.transform, normalize=self.norm) 199 | 200 | return imgs, c_code, key 201 | 202 | def __getitem__(self, index): 203 | return self.iterator(index) 204 | 205 | def __len__(self): 206 | return len(self.filenames) 207 | 
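# NOTE: Dataset assumes DATA_DIR follows the layout of the formatted CUB download
# described in data/README.md. Based on load_filenames(), load_bbox() and
# prepair_training_pairs() above, that layout is:
#
#   DATA_DIR/
#       images.txt            # whitespace-separated: <image id> <relative image path ending in .jpg>
#       bounding_boxes.txt    # whitespace-separated: <image id> <x> <y> <width> <height>
#       images/               # the .jpg files, at the relative paths listed in images.txt
#                             #   (for CUB these are nested per class, e.g. 001.Black_footed_Albatross/...)
#
# A custom dataset organised the same way should work with this Dataset class unchanged.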
-------------------------------------------------------------------------------- /code/inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.utils.model_zoo as model_zoo 5 | from miscc.config import cfg 6 | 7 | 8 | __all__ = ['Inception3', 'inception_v3'] 9 | 10 | 11 | model_urls = { 12 | # Inception v3 ported from TensorFlow 13 | 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth', 14 | } 15 | 16 | 17 | 18 | def inception_v3(pretrained=True, **kwargs): 19 | r"""Inception v3 model architecture from 20 | `"Rethinking the Inception Architecture for Computer Vision" `_. 21 | Args: 22 | pretrained (bool): If True, returns a model pre-trained on ImageNet 23 | """ 24 | if pretrained: 25 | if 'transform_input' not in kwargs: 26 | kwargs['transform_input'] = True 27 | model = Inception3(**kwargs) 28 | pretrained_dict = model_zoo.load_url(model_urls['inception_v3_google']) 29 | model_dict = model.state_dict() 30 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 31 | model_dict.update(pretrained_dict) 32 | model.load_state_dict(model_dict) 33 | print ("Inception pretrained on IMAGENET loaded") 34 | return model 35 | 36 | return Inception3(**kwargs) 37 | 38 | 39 | class Inception3(nn.Module): 40 | 41 | def __init__(self, num_classes=200, aux_logits=True, transform_input=False): 42 | super(Inception3, self).__init__() 43 | self.aux_logits = aux_logits 44 | self.transform_input = transform_input 45 | self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2) 46 | self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) 47 | self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) 48 | self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) 49 | self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) 50 | self.Mixed_5b = InceptionA(192, pool_features=32) 51 | self.Mixed_5c = InceptionA(256, pool_features=64) 52 | self.Mixed_5d = InceptionA(288, pool_features=64) 53 | self.Mixed_6a = InceptionB(288) 54 | self.Mixed_6b = InceptionC(768, channels_7x7=128) 55 | self.Mixed_6c = InceptionC(768, channels_7x7=160) 56 | self.Mixed_6d = InceptionC(768, channels_7x7=160) 57 | self.Mixed_6e = InceptionC(768, channels_7x7=192) 58 | if aux_logits: 59 | self.AuxLogits = InceptionAux(768, num_classes) 60 | self.Mixed_7a = InceptionD(768) 61 | self.Mixed_7b = InceptionE(1280) 62 | self.Mixed_7c = InceptionE(2048) 63 | self.fc_new = nn.Linear(2048, num_classes) 64 | 65 | for m in self.modules(): 66 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 67 | import scipy.stats as stats 68 | stddev = m.stddev if hasattr(m, 'stddev') else 0.1 69 | X = stats.truncnorm(-2, 2, scale=stddev) 70 | values = torch.Tensor(X.rvs(m.weight.numel())) 71 | values = values.view(m.weight.size()) 72 | m.weight.data.copy_(values) 73 | elif isinstance(m, nn.BatchNorm2d): 74 | nn.init.constant_(m.weight, 1) 75 | nn.init.constant_(m.bias, 0) 76 | 77 | def forward(self, x): 78 | 79 | #No preprocessing being done right now 80 | 81 | # 299 x 299 x 3 82 | x = self.Conv2d_1a_3x3(x) 83 | # 149 x 149 x 32 84 | x = self.Conv2d_2a_3x3(x) 85 | # 147 x 147 x 32 86 | x = self.Conv2d_2b_3x3(x) 87 | # 147 x 147 x 64 88 | x = F.max_pool2d(x, kernel_size=3, stride=2) 89 | # 73 x 73 x 64 90 | x = self.Conv2d_3b_1x1(x) 91 | # 73 x 73 x 80 92 | x = self.Conv2d_4a_3x3(x) 93 | # 71 x 71 x 192 94 | x = F.max_pool2d(x, kernel_size=3, 
stride=2) 95 | # 35 x 35 x 192 96 | x = self.Mixed_5b(x) 97 | # 35 x 35 x 256 98 | x = self.Mixed_5c(x) 99 | # 35 x 35 x 288 100 | x = self.Mixed_5d(x) 101 | # 35 x 35 x 288 102 | x = self.Mixed_6a(x) 103 | # 17 x 17 x 768 104 | x = self.Mixed_6b(x) 105 | # 17 x 17 x 768 106 | x = self.Mixed_6c(x) 107 | # 17 x 17 x 768 108 | x = self.Mixed_6d(x) 109 | # 17 x 17 x 768 110 | x = self.Mixed_6e(x) 111 | # 17 x 17 x 768 112 | if self.training and self.aux_logits: 113 | aux = self.AuxLogits(x) 114 | # 17 x 17 x 768 115 | x = self.Mixed_7a(x) 116 | # 8 x 8 x 1280 117 | x = self.Mixed_7b(x) 118 | # 8 x 8 x 2048 119 | x = self.Mixed_7c(x) 120 | # 8 x 8 x 2048 121 | x = F.avg_pool2d(x, kernel_size=8) 122 | # 1 x 1 x 2048 123 | x = F.dropout(x, training=self.training) 124 | # 1 x 1 x 2048 125 | x = x.view(x.size(0), -1) 126 | # 2048 127 | x = self.fc_new(x) 128 | # 1000 (num_classes) 129 | if self.training and self.aux_logits: 130 | return x, aux 131 | return x 132 | 133 | 134 | class InceptionA(nn.Module): 135 | 136 | def __init__(self, in_channels, pool_features): 137 | super(InceptionA, self).__init__() 138 | self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1) 139 | 140 | self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1) 141 | self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2) 142 | 143 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) 144 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) 145 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1) 146 | 147 | self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1) 148 | 149 | def forward(self, x): 150 | branch1x1 = self.branch1x1(x) 151 | 152 | branch5x5 = self.branch5x5_1(x) 153 | branch5x5 = self.branch5x5_2(branch5x5) 154 | 155 | branch3x3dbl = self.branch3x3dbl_1(x) 156 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 157 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 158 | 159 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 160 | branch_pool = self.branch_pool(branch_pool) 161 | 162 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 163 | return torch.cat(outputs, 1) 164 | 165 | 166 | class InceptionB(nn.Module): 167 | 168 | def __init__(self, in_channels): 169 | super(InceptionB, self).__init__() 170 | self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2) 171 | 172 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) 173 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) 174 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2) 175 | 176 | def forward(self, x): 177 | branch3x3 = self.branch3x3(x) 178 | 179 | branch3x3dbl = self.branch3x3dbl_1(x) 180 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 181 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 182 | 183 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 184 | 185 | outputs = [branch3x3, branch3x3dbl, branch_pool] 186 | return torch.cat(outputs, 1) 187 | 188 | 189 | class InceptionC(nn.Module): 190 | 191 | def __init__(self, in_channels, channels_7x7): 192 | super(InceptionC, self).__init__() 193 | self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1) 194 | 195 | c7 = channels_7x7 196 | self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1) 197 | self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 198 | self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0)) 199 | 200 | self.branch7x7dbl_1 = 
BasicConv2d(in_channels, c7, kernel_size=1) 201 | self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 202 | self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 203 | self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 204 | self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 205 | 206 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 207 | 208 | def forward(self, x): 209 | branch1x1 = self.branch1x1(x) 210 | 211 | branch7x7 = self.branch7x7_1(x) 212 | branch7x7 = self.branch7x7_2(branch7x7) 213 | branch7x7 = self.branch7x7_3(branch7x7) 214 | 215 | branch7x7dbl = self.branch7x7dbl_1(x) 216 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 217 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 218 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 219 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 220 | 221 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 222 | branch_pool = self.branch_pool(branch_pool) 223 | 224 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 225 | return torch.cat(outputs, 1) 226 | 227 | 228 | class InceptionD(nn.Module): 229 | 230 | def __init__(self, in_channels): 231 | super(InceptionD, self).__init__() 232 | self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) 233 | self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2) 234 | 235 | self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) 236 | self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)) 237 | self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)) 238 | self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2) 239 | 240 | def forward(self, x): 241 | branch3x3 = self.branch3x3_1(x) 242 | branch3x3 = self.branch3x3_2(branch3x3) 243 | 244 | branch7x7x3 = self.branch7x7x3_1(x) 245 | branch7x7x3 = self.branch7x7x3_2(branch7x7x3) 246 | branch7x7x3 = self.branch7x7x3_3(branch7x7x3) 247 | branch7x7x3 = self.branch7x7x3_4(branch7x7x3) 248 | 249 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 250 | outputs = [branch3x3, branch7x7x3, branch_pool] 251 | return torch.cat(outputs, 1) 252 | 253 | 254 | class InceptionE(nn.Module): 255 | 256 | def __init__(self, in_channels): 257 | super(InceptionE, self).__init__() 258 | self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1) 259 | 260 | self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1) 261 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 262 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 263 | 264 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1) 265 | self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1) 266 | self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 267 | self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 268 | 269 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 270 | 271 | def forward(self, x): 272 | branch1x1 = self.branch1x1(x) 273 | 274 | branch3x3 = self.branch3x3_1(x) 275 | branch3x3 = [ 276 | self.branch3x3_2a(branch3x3), 277 | self.branch3x3_2b(branch3x3), 278 | ] 279 | branch3x3 = torch.cat(branch3x3, 1) 280 | 281 | branch3x3dbl = self.branch3x3dbl_1(x) 282 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 283 | branch3x3dbl = [ 284 | self.branch3x3dbl_3a(branch3x3dbl), 285 | 
self.branch3x3dbl_3b(branch3x3dbl), 286 | ] 287 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 288 | 289 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 290 | branch_pool = self.branch_pool(branch_pool) 291 | 292 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 293 | return torch.cat(outputs, 1) 294 | 295 | 296 | class InceptionAux(nn.Module): 297 | 298 | def __init__(self, in_channels, num_classes): 299 | super(InceptionAux, self).__init__() 300 | self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1) 301 | self.conv1 = BasicConv2d(128, 768, kernel_size=5) 302 | self.conv1.stddev = 0.01 303 | self.fc_new = nn.Linear(768, num_classes) 304 | self.fc_new.stddev = 0.001 305 | 306 | def forward(self, x): 307 | # 17 x 17 x 768 308 | x = F.avg_pool2d(x, kernel_size=5, stride=3) 309 | # 5 x 5 x 768 310 | x = self.conv0(x) 311 | # 5 x 5 x 128 312 | x = self.conv1(x) 313 | # 1 x 1 x 768 314 | x = x.view(x.size(0), -1) 315 | # 768 316 | x = self.fc_new(x) 317 | # 1000 318 | return x 319 | 320 | 321 | class BasicConv2d(nn.Module): 322 | 323 | def __init__(self, in_channels, out_channels, **kwargs): 324 | super(BasicConv2d, self).__init__() 325 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 326 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 327 | 328 | def forward(self, x): 329 | x = self.conv(x) 330 | x = self.bn(x) 331 | return F.relu(x, inplace=True) 332 | 333 | 334 | -------------------------------------------------------------------------------- /code/model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.parallel 5 | from miscc.config import cfg 6 | from torch.autograd import Variable 7 | import torch.nn.functional as F 8 | from torch.nn import Upsample 9 | 10 | 11 | class GLU(nn.Module): 12 | def __init__(self): 13 | super(GLU, self).__init__() 14 | 15 | def forward(self, x): 16 | nc = x.size(1) 17 | assert nc % 2 == 0, 'channels dont divide 2!' 
18 | nc = int(nc/2) 19 | return x[:, :nc] * F.sigmoid(x[:, nc:]) 20 | 21 | 22 | def conv3x3(in_planes, out_planes): 23 | "3x3 convolution with padding" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, 25 | padding=1, bias=False) 26 | 27 | 28 | def convlxl(in_planes, out_planes): 29 | "3x3 convolution with padding" 30 | return nn.Conv2d(in_planes, out_planes, kernel_size=13, stride=1, 31 | padding=1, bias=False) 32 | 33 | 34 | def child_to_parent(child_c_code, classes_child, classes_parent): 35 | 36 | ratio = classes_child / classes_parent 37 | arg_parent = torch.argmax(child_c_code, dim = 1) / ratio 38 | parent_c_code = torch.zeros([child_c_code.size(0), classes_parent]).cuda() 39 | for i in range(child_c_code.size(0)): 40 | parent_c_code[i][arg_parent[i]] = 1 41 | return parent_c_code 42 | 43 | 44 | # ############## G networks ################################################ 45 | # Upsale the spatial size by a factor of 2 46 | def upBlock(in_planes, out_planes): 47 | block = nn.Sequential( 48 | nn.Upsample(scale_factor=2, mode='nearest'), 49 | conv3x3(in_planes, out_planes * 2), 50 | nn.BatchNorm2d(out_planes * 2), 51 | GLU() 52 | ) 53 | return block 54 | 55 | def sameBlock(in_planes, out_planes): 56 | block = nn.Sequential( 57 | conv3x3(in_planes, out_planes * 2), 58 | nn.BatchNorm2d(out_planes * 2), 59 | GLU() 60 | ) 61 | return block 62 | 63 | # Keep the spatial size 64 | def Block3x3_relu(in_planes, out_planes): 65 | block = nn.Sequential( 66 | conv3x3(in_planes, out_planes * 2), 67 | nn.BatchNorm2d(out_planes * 2), 68 | GLU() 69 | ) 70 | return block 71 | 72 | 73 | class ResBlock(nn.Module): 74 | def __init__(self, channel_num): 75 | super(ResBlock, self).__init__() 76 | self.block = nn.Sequential( 77 | conv3x3(channel_num, channel_num * 2), 78 | nn.BatchNorm2d(channel_num * 2), 79 | GLU(), 80 | conv3x3(channel_num, channel_num), 81 | nn.BatchNorm2d(channel_num) 82 | ) 83 | 84 | 85 | def forward(self, x): 86 | residual = x 87 | out = self.block(x) 88 | out += residual 89 | return out 90 | 91 | 92 | class INIT_STAGE_G(nn.Module): 93 | def __init__(self, ngf, c_flag): 94 | super(INIT_STAGE_G, self).__init__() 95 | self.gf_dim = ngf 96 | self.c_flag= c_flag 97 | 98 | if self.c_flag==1 : 99 | self.in_dim = cfg.GAN.Z_DIM + cfg.SUPER_CATEGORIES 100 | elif self.c_flag==2: 101 | self.in_dim = cfg.GAN.Z_DIM + cfg.FINE_GRAINED_CATEGORIES 102 | 103 | self.define_module() 104 | 105 | def define_module(self): 106 | in_dim = self.in_dim 107 | ngf = self.gf_dim 108 | self.fc = nn.Sequential( 109 | nn.Linear(in_dim, ngf * 4 * 4 * 2, bias=False), 110 | nn.BatchNorm1d(ngf * 4 * 4 * 2), 111 | GLU()) 112 | 113 | self.upsample1 = upBlock(ngf, ngf // 2) 114 | self.upsample2 = upBlock(ngf // 2, ngf // 4) 115 | self.upsample3 = upBlock(ngf // 4, ngf // 8) 116 | self.upsample4 = upBlock(ngf // 8, ngf // 16) 117 | self.upsample5 = upBlock(ngf // 16, ngf // 16) 118 | 119 | 120 | def forward(self, z_code, code): 121 | 122 | in_code = torch.cat((code, z_code), 1) 123 | out_code = self.fc(in_code) 124 | out_code = out_code.view(-1, self.gf_dim, 4, 4) 125 | out_code = self.upsample1(out_code) 126 | out_code = self.upsample2(out_code) 127 | out_code = self.upsample3(out_code) 128 | out_code = self.upsample4(out_code) 129 | out_code = self.upsample5(out_code) 130 | 131 | return out_code 132 | 133 | 134 | class NEXT_STAGE_G(nn.Module): 135 | def __init__(self, ngf, use_hrc = 1, num_residual=cfg.GAN.R_NUM): 136 | super(NEXT_STAGE_G, self).__init__() 137 | self.gf_dim = ngf 138 | if use_hrc == 1: # 
For parent stage 139 | self.ef_dim = cfg.SUPER_CATEGORIES 140 | 141 | else: # For child stage 142 | self.ef_dim = cfg.FINE_GRAINED_CATEGORIES 143 | 144 | self.num_residual = num_residual 145 | self.define_module() 146 | 147 | def _make_layer(self, block, channel_num): 148 | layers = [] 149 | for i in range(self.num_residual): 150 | layers.append(block(channel_num)) 151 | return nn.Sequential(*layers) 152 | 153 | def define_module(self): 154 | ngf = self.gf_dim 155 | efg = self.ef_dim 156 | self.jointConv = Block3x3_relu(ngf + efg, ngf) 157 | self.residual = self._make_layer(ResBlock, ngf) 158 | self.samesample = sameBlock(ngf, ngf // 2) 159 | 160 | def forward(self, h_code, code): 161 | s_size = h_code.size(2) 162 | code = code.view(-1, self.ef_dim, 1, 1) 163 | code = code.repeat(1, 1, s_size, s_size) 164 | h_c_code = torch.cat((code, h_code), 1) 165 | out_code = self.jointConv(h_c_code) 166 | out_code = self.residual(out_code) 167 | out_code = self.samesample(out_code) 168 | return out_code 169 | 170 | 171 | class GET_IMAGE_G(nn.Module): 172 | def __init__(self, ngf): 173 | super(GET_IMAGE_G, self).__init__() 174 | self.gf_dim = ngf 175 | self.img = nn.Sequential( 176 | conv3x3(ngf, 3), 177 | nn.Tanh() 178 | ) 179 | 180 | def forward(self, h_code): 181 | out_img = self.img(h_code) 182 | return out_img 183 | 184 | 185 | 186 | class GET_MASK_G(nn.Module): 187 | def __init__(self, ngf): 188 | super(GET_MASK_G, self).__init__() 189 | self.gf_dim = ngf 190 | self.img = nn.Sequential( 191 | conv3x3(ngf, 1), 192 | nn.Sigmoid() 193 | ) 194 | 195 | def forward(self, h_code): 196 | out_img = self.img(h_code) 197 | return out_img 198 | 199 | 200 | class G_NET(nn.Module): 201 | def __init__(self): 202 | super(G_NET, self).__init__() 203 | self.gf_dim = cfg.GAN.GF_DIM 204 | self.define_module() 205 | self.upsampling = Upsample(scale_factor = 2, mode = 'bilinear') 206 | self.scale_fimg = nn.UpsamplingBilinear2d(size = [126, 126]) 207 | 208 | def define_module(self): 209 | 210 | #Background stage 211 | self.h_net1_bg = INIT_STAGE_G(self.gf_dim * 16, 2) 212 | self.img_net1_bg = GET_IMAGE_G(self.gf_dim) # Background generation network 213 | 214 | # Parent stage networks 215 | self.h_net1 = INIT_STAGE_G(self.gf_dim * 16, 1) 216 | self.h_net2 = NEXT_STAGE_G(self.gf_dim, use_hrc = 1) 217 | self.img_net2 = GET_IMAGE_G(self.gf_dim // 2) # Parent foreground generation network 218 | self.img_net2_mask= GET_MASK_G(self.gf_dim // 2) # Parent mask generation network 219 | 220 | # Child stage networks 221 | self.h_net3 = NEXT_STAGE_G(self.gf_dim // 2, use_hrc = 0) 222 | self.img_net3 = GET_IMAGE_G(self.gf_dim // 4) # Child foreground generation network 223 | self.img_net3_mask = GET_MASK_G(self.gf_dim // 4) # Child mask generation network 224 | 225 | def forward(self, z_code, c_code, p_code = None, bg_code = None): 226 | 227 | fake_imgs = [] # Will contain [background image, parent image, child image] 228 | fg_imgs = [] # Will contain [parent foreground, child foreground] 229 | mk_imgs = [] # Will contain [parent mask, child mask] 230 | fg_mk = [] # Will contain [masked parent foreground, masked child foreground] 231 | 232 | if cfg.TIED_CODES: 233 | p_code = child_to_parent(c_code, cfg.FINE_GRAINED_CATEGORIES, cfg.SUPER_CATEGORIES) # Obtaining the parent code from child code 234 | bg_code = c_code 235 | 236 | #Background stage 237 | h_code1_bg = self.h_net1_bg(z_code, bg_code) 238 | fake_img1 = self.img_net1_bg(h_code1_bg) # Background image 239 | fake_img1_126 = self.scale_fimg(fake_img1) # Resizing fake background 
image from 128x128 to the resolution which background discriminator expects: 126 x 126. 240 | fake_imgs.append(fake_img1_126) 241 | 242 | #Parent stage 243 | h_code1 = self.h_net1(z_code, p_code) 244 | h_code2 = self.h_net2(h_code1, p_code) 245 | fake_img2_foreground = self.img_net2(h_code2) # Parent foreground 246 | fake_img2_mask = self.img_net2_mask(h_code2) # Parent mask 247 | ones_mask_p = torch.ones_like(fake_img2_mask) 248 | opp_mask_p = ones_mask_p - fake_img2_mask 249 | fg_masked2 = torch.mul(fake_img2_foreground, fake_img2_mask) 250 | fg_mk.append(fg_masked2) 251 | bg_masked2 = torch.mul(fake_img1, opp_mask_p) 252 | fake_img2_final = fg_masked2 + bg_masked2 # Parent image 253 | fake_imgs.append(fake_img2_final) 254 | fg_imgs.append(fake_img2_foreground) 255 | mk_imgs.append(fake_img2_mask) 256 | 257 | #Child stage 258 | h_code3 = self.h_net3(h_code2, c_code) 259 | fake_img3_foreground = self.img_net3(h_code3) # Child foreground 260 | fake_img3_mask = self.img_net3_mask(h_code3) # Child mask 261 | ones_mask_c = torch.ones_like(fake_img3_mask) 262 | opp_mask_c = ones_mask_c - fake_img3_mask 263 | fg_masked3 = torch.mul(fake_img3_foreground, fake_img3_mask) 264 | fg_mk.append(fg_masked3) 265 | bg_masked3 = torch.mul(fake_img2_final, opp_mask_c) 266 | fake_img3_final = fg_masked3 + bg_masked3 # Child image 267 | fake_imgs.append(fake_img3_final) 268 | fg_imgs.append(fake_img3_foreground) 269 | mk_imgs.append(fake_img3_mask) 270 | 271 | return fake_imgs, fg_imgs, mk_imgs, fg_mk 272 | 273 | 274 | # ############## D networks ################################################ 275 | def Block3x3_leakRelu(in_planes, out_planes): 276 | block = nn.Sequential( 277 | conv3x3(in_planes, out_planes), 278 | nn.BatchNorm2d(out_planes), 279 | nn.LeakyReLU(0.2, inplace=True) 280 | ) 281 | return block 282 | 283 | 284 | # Downsale the spatial size by a factor of 2 285 | def downBlock(in_planes, out_planes): 286 | block = nn.Sequential( 287 | nn.Conv2d(in_planes, out_planes, 4, 2, 1, bias=False), 288 | nn.BatchNorm2d(out_planes), 289 | nn.LeakyReLU(0.2, inplace=True) 290 | ) 291 | return block 292 | 293 | 294 | 295 | def encode_parent_and_child_img(ndf): # Defines the encoder network used for parent and child image 296 | encode_img = nn.Sequential( 297 | nn.Conv2d(3, ndf, 4, 2, 1, bias=False), 298 | nn.LeakyReLU(0.2, inplace=True), 299 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 300 | nn.BatchNorm2d(ndf * 2), 301 | nn.LeakyReLU(0.2, inplace=True), 302 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 303 | nn.BatchNorm2d(ndf * 4), 304 | nn.LeakyReLU(0.2, inplace=True), 305 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), 306 | nn.BatchNorm2d(ndf * 8), 307 | nn.LeakyReLU(0.2, inplace=True) 308 | ) 309 | return encode_img 310 | 311 | 312 | def encode_background_img(ndf): # Defines the encoder network used for background image 313 | encode_img = nn.Sequential( 314 | nn.Conv2d(3, ndf, 4, 2, 0, bias=False), 315 | nn.LeakyReLU(0.2, inplace=True), 316 | nn.Conv2d(ndf, ndf * 2, 4, 2, 0, bias=False), 317 | nn.LeakyReLU(0.2, inplace=True), 318 | nn.Conv2d(ndf * 2, ndf * 4, 4, 1, 0, bias=False), 319 | nn.LeakyReLU(0.2, inplace=True), 320 | ) 321 | return encode_img 322 | 323 | 324 | class D_NET(nn.Module): 325 | def __init__(self, stg_no): 326 | super(D_NET, self).__init__() 327 | self.df_dim = cfg.GAN.DF_DIM 328 | self.stg_no = stg_no 329 | 330 | if self.stg_no == 0: 331 | self.ef_dim = 1 332 | elif self.stg_no == 1: 333 | self.ef_dim = cfg.SUPER_CATEGORIES 334 | elif self.stg_no == 2: 335 | self.ef_dim 
= cfg.FINE_GRAINED_CATEGORIES 336 | else: 337 | print ("Invalid stage number. Set stage number as follows:") 338 | print ("0 - for background stage") 339 | print ("1 - for parent stage") 340 | print ("2 - for child stage") 341 | print ("...Exiting now") 342 | sys.exit(0) 343 | self.define_module() 344 | 345 | def define_module(self): 346 | ndf = self.df_dim 347 | efg = self.ef_dim 348 | 349 | if self.stg_no == 0: 350 | 351 | self.patchgan_img_code_s16 = encode_background_img(ndf) 352 | self.uncond_logits1 = nn.Sequential( 353 | nn.Conv2d(ndf * 4, 1, kernel_size=4, stride=1), 354 | nn.Sigmoid()) 355 | self.uncond_logits2 = nn.Sequential( 356 | nn.Conv2d(ndf * 4, 1, kernel_size=4, stride=1), 357 | nn.Sigmoid()) 358 | 359 | else: 360 | self.img_code_s16 = encode_parent_and_child_img(ndf) 361 | self.img_code_s32 = downBlock(ndf * 8, ndf * 16) 362 | self.img_code_s32_1 = Block3x3_leakRelu(ndf * 16, ndf * 8) 363 | 364 | self.logits = nn.Sequential( 365 | nn.Conv2d(ndf * 8, efg, kernel_size=4, stride=4)) 366 | 367 | self.jointConv = Block3x3_leakRelu(ndf * 8, ndf * 8) 368 | self.uncond_logits = nn.Sequential( 369 | nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), 370 | nn.Sigmoid()) 371 | 372 | 373 | def forward(self, x_var): 374 | 375 | if self.stg_no == 0: 376 | x_code = self.patchgan_img_code_s16(x_var) 377 | classi_score = self.uncond_logits1(x_code) # Background vs Foreground classification score (0 - background and 1 - foreground) 378 | rf_score = self.uncond_logits2(x_code) # Real/Fake score for the background image 379 | return [classi_score, rf_score] 380 | 381 | elif self.stg_no > 0: 382 | x_code = self.img_code_s16(x_var) 383 | x_code = self.img_code_s32(x_code) 384 | x_code = self.img_code_s32_1(x_code) 385 | h_c_code = self.jointConv(x_code) 386 | code_pred = self.logits(h_c_code) # Predicts the parent code and child code in parent and child stage respectively 387 | rf_score = self.uncond_logits(x_code) # This score is not used in parent stage while training 388 | return [code_pred.view(-1, self.ef_dim), rf_score.view(-1)] 389 | 390 | 391 | 392 | -------------------------------------------------------------------------------- /code/trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from six.moves import range 3 | import sys 4 | import numpy as np 5 | import os 6 | import random 7 | import time 8 | from PIL import Image 9 | from copy import deepcopy 10 | 11 | import torch.backends.cudnn as cudnn 12 | import torch 13 | import torch.nn as nn 14 | from torch.autograd import Variable 15 | import torch.optim as optim 16 | import torchvision.utils as vutils 17 | from torch.nn.functional import softmax, log_softmax 18 | from torch.nn.functional import cosine_similarity 19 | from tensorboardX import summary 20 | from tensorboardX import FileWriter 21 | 22 | from miscc.config import cfg 23 | from miscc.utils import mkdir_p 24 | 25 | from model import G_NET, D_NET 26 | 27 | 28 | # ################## Shared functions ################### 29 | 30 | def child_to_parent(child_c_code, classes_child, classes_parent): 31 | 32 | ratio = classes_child / classes_parent 33 | arg_parent = torch.argmax(child_c_code, dim = 1) / ratio 34 | parent_c_code = torch.zeros([child_c_code.size(0), classes_parent]).cuda() 35 | for i in range(child_c_code.size(0)): 36 | parent_c_code[i][arg_parent[i]] = 1 37 | return parent_c_code 38 | 39 | 40 | def weights_init(m): 41 | classname = m.__class__.__name__ 42 | if classname.find('Conv') 
!= -1: 43 | nn.init.orthogonal(m.weight.data, 1.0) 44 | elif classname.find('BatchNorm') != -1: 45 | m.weight.data.normal_(1.0, 0.02) 46 | m.bias.data.fill_(0) 47 | elif classname.find('Linear') != -1: 48 | nn.init.orthogonal(m.weight.data, 1.0) 49 | if m.bias is not None: 50 | m.bias.data.fill_(0.0) 51 | 52 | 53 | def load_params(model, new_param): 54 | for p, new_p in zip(model.parameters(), new_param): 55 | p.data.copy_(new_p) 56 | 57 | 58 | def copy_G_params(model): 59 | flatten = deepcopy(list(p.data for p in model.parameters())) 60 | return flatten 61 | 62 | def load_network(gpus): 63 | netG = G_NET() 64 | netG.apply(weights_init) 65 | netG = torch.nn.DataParallel(netG, device_ids=gpus) 66 | print(netG) 67 | 68 | netsD = [] 69 | for i in range(3): # 3 discriminators for background, parent and child stage 70 | netsD.append(D_NET(i)) 71 | 72 | for i in range(len(netsD)): 73 | netsD[i].apply(weights_init) 74 | netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus) 75 | 76 | count = 0 77 | 78 | if cfg.TRAIN.NET_G != '': 79 | state_dict = torch.load(cfg.TRAIN.NET_G) 80 | netG.load_state_dict(state_dict) 81 | print('Load ', cfg.TRAIN.NET_G) 82 | 83 | istart = cfg.TRAIN.NET_G.rfind('_') + 1 84 | iend = cfg.TRAIN.NET_G.rfind('.') 85 | count = cfg.TRAIN.NET_G[istart:iend] 86 | count = int(count) + 1 87 | 88 | if cfg.TRAIN.NET_D != '': 89 | for i in range(len(netsD)): 90 | print('Load %s_%d.pth' % (cfg.TRAIN.NET_D, i)) 91 | state_dict = torch.load('%s_%d.pth' % (cfg.TRAIN.NET_D, i)) 92 | netsD[i].load_state_dict(state_dict) 93 | 94 | if cfg.CUDA: 95 | netG.cuda() 96 | for i in range(len(netsD)): 97 | netsD[i].cuda() 98 | 99 | return netG, netsD, len(netsD), count 100 | 101 | 102 | def define_optimizers(netG, netsD): 103 | optimizersD = [] 104 | num_Ds = len(netsD) 105 | for i in range(num_Ds): 106 | opt = optim.Adam(netsD[i].parameters(), 107 | lr=cfg.TRAIN.DISCRIMINATOR_LR, 108 | betas=(0.5, 0.999)) 109 | optimizersD.append(opt) 110 | 111 | optimizerG = [] 112 | optimizerG.append(optim.Adam(netG.parameters(), 113 | lr=cfg.TRAIN.GENERATOR_LR, 114 | betas=(0.5, 0.999))) 115 | 116 | for i in range(num_Ds): 117 | if i==1: 118 | opt = optim.Adam(netsD[i].parameters(), 119 | lr=cfg.TRAIN.GENERATOR_LR, 120 | betas=(0.5, 0.999)) 121 | optimizerG.append(opt) 122 | elif i==2: 123 | opt = optim.Adam([{'params':netsD[i].module.jointConv.parameters()},{'params':netsD[i].module.logits.parameters()}], 124 | lr=cfg.TRAIN.GENERATOR_LR, 125 | betas=(0.5, 0.999)) 126 | optimizerG.append(opt) 127 | 128 | return optimizerG, optimizersD 129 | 130 | 131 | def save_model(netG, avg_param_G, netsD, epoch, model_dir): 132 | load_params(netG, avg_param_G) 133 | torch.save( 134 | netG.state_dict(), 135 | '%s/netG_%d.pth' % (model_dir, epoch)) 136 | for i in range(len(netsD)): 137 | netD = netsD[i] 138 | torch.save( 139 | netD.state_dict(), 140 | '%s/netD%d.pth' % (model_dir, i)) 141 | print('Save G/Ds models.') 142 | 143 | 144 | def save_img_results(imgs_tcpu, fake_imgs, num_imgs, 145 | count, image_dir, summary_writer): 146 | num = cfg.TRAIN.VIS_COUNT 147 | 148 | real_img = imgs_tcpu[-1][0:num] 149 | vutils.save_image( 150 | real_img, '%s/real_samples%09d.png' % (image_dir,count), 151 | normalize=True) 152 | real_img_set = vutils.make_grid(real_img).numpy() 153 | real_img_set = np.transpose(real_img_set, (1, 2, 0)) 154 | real_img_set = real_img_set * 255 155 | real_img_set = real_img_set.astype(np.uint8) 156 | 157 | for i in range(len(fake_imgs)): 158 | fake_img = fake_imgs[i][0:num] 159 | 160 | vutils.save_image( 
161 | fake_img.data, '%s/count_%09d_fake_samples%d.png' % 162 | (image_dir, count, i), normalize=True) 163 | 164 | fake_img_set = vutils.make_grid(fake_img.data).cpu().numpy() 165 | 166 | fake_img_set = np.transpose(fake_img_set, (1, 2, 0)) 167 | fake_img_set = (fake_img_set + 1) * 255 / 2 168 | fake_img_set = fake_img_set.astype(np.uint8) 169 | summary_writer.flush() 170 | 171 | 172 | 173 | class FineGAN_trainer(object): 174 | def __init__(self, output_dir, data_loader, imsize): 175 | if cfg.TRAIN.FLAG: 176 | self.model_dir = os.path.join(output_dir, 'Model') 177 | self.image_dir = os.path.join(output_dir, 'Image') 178 | self.log_dir = os.path.join(output_dir, 'Log') 179 | mkdir_p(self.model_dir) 180 | mkdir_p(self.image_dir) 181 | mkdir_p(self.log_dir) 182 | self.summary_writer = FileWriter(self.log_dir) 183 | 184 | s_gpus = cfg.GPU_ID.split(',') 185 | self.gpus = [int(ix) for ix in s_gpus] 186 | self.num_gpus = len(self.gpus) 187 | torch.cuda.set_device(self.gpus[0]) 188 | cudnn.benchmark = True 189 | 190 | self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus 191 | self.max_epoch = cfg.TRAIN.MAX_EPOCH 192 | self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL 193 | 194 | self.data_loader = data_loader 195 | self.num_batches = len(self.data_loader) 196 | 197 | 198 | 199 | def prepare_data(self, data): 200 | fimgs, cimgs, c_code, _, warped_bbox = data 201 | 202 | real_vfimgs, real_vcimgs = [], [] 203 | if cfg.CUDA: 204 | vc_code = Variable(c_code).cuda() 205 | for i in range(len(warped_bbox)): 206 | warped_bbox[i] = Variable(warped_bbox[i]).float().cuda() 207 | 208 | else: 209 | vc_code = Variable(c_code) 210 | for i in range(len(warped_bbox)): 211 | warped_bbox[i] = Variable(warped_bbox[i]) 212 | 213 | if cfg.CUDA: 214 | real_vfimgs.append(Variable(fimgs[0]).cuda()) 215 | real_vcimgs.append(Variable(cimgs[0]).cuda()) 216 | else: 217 | real_vfimgs.append(Variable(fimgs[0])) 218 | real_vcimgs.append(Variable(cimgs[0])) 219 | 220 | return fimgs, real_vfimgs, real_vcimgs, vc_code, warped_bbox 221 | 222 | def train_Dnet(self, idx, count): 223 | if idx == 0 or idx == 2: # Discriminator is only trained in background and child stage. 
(NOT in parent stage) 224 | flag = count % 100 225 | batch_size = self.real_fimgs[0].size(0) 226 | criterion, criterion_one = self.criterion, self.criterion_one 227 | 228 | netD, optD = self.netsD[idx], self.optimizersD[idx] 229 | if idx == 0: 230 | real_imgs = self.real_fimgs[0] 231 | 232 | elif idx == 2: 233 | real_imgs = self.real_cimgs[0] 234 | 235 | fake_imgs = self.fake_imgs[idx] 236 | netD.zero_grad() 237 | real_logits = netD(real_imgs) 238 | 239 | if idx == 2: 240 | fake_labels = torch.zeros_like(real_logits[1]) 241 | real_labels = torch.ones_like(real_logits[1]) 242 | elif idx == 0: 243 | 244 | fake_labels = torch.zeros_like(real_logits[1]) 245 | ext, output = real_logits 246 | weights_real = torch.ones_like(output) 247 | real_labels = torch.ones_like(output) 248 | 249 | for i in range(batch_size): 250 | x1 = self.warped_bbox[0][i] 251 | x2 = self.warped_bbox[2][i] 252 | y1 = self.warped_bbox[1][i] 253 | y2 = self.warped_bbox[3][i] 254 | 255 | a1 = max(torch.tensor(0).float().cuda(), torch.ceil((x1 - self.recp_field)/self.patch_stride)) 256 | a2 = min(torch.tensor(self.n_out - 1).float().cuda(), torch.floor((self.n_out - 1) - ((126 - self.recp_field) - x2)/self.patch_stride)) + 1 257 | b1 = max(torch.tensor(0).float().cuda(), torch.ceil((y1 - self.recp_field)/self.patch_stride)) 258 | b2 = min(torch.tensor(self.n_out - 1).float().cuda(), torch.floor((self.n_out - 1) - ((126 - self.recp_field) - y2)/self.patch_stride)) + 1 259 | 260 | if (x1 != x2 and y1 != y2): 261 | weights_real[i, :, a1.type(torch.int) : a2.type(torch.int) , b1.type(torch.int) : b2.type(torch.int)] = 0.0 262 | 263 | norm_fact_real = weights_real.sum() 264 | norm_fact_fake = weights_real.shape[0]*weights_real.shape[1]*weights_real.shape[2]*weights_real.shape[3] 265 | real_logits = ext, output 266 | 267 | fake_logits = netD(fake_imgs.detach()) 268 | 269 | 270 | 271 | if idx == 0: # Background stage 272 | 273 | errD_real_uncond = criterion(real_logits[1], real_labels) # Real/Fake loss for 'real background' (on patch level) 274 | errD_real_uncond = torch.mul(errD_real_uncond, weights_real) # Masking output units whose receptive fields lie within the bounding box 275 | errD_real_uncond = errD_real_uncond.mean() 276 | 277 | errD_real_uncond_classi = criterion(real_logits[0], weights_real) # Background/foreground classification loss 278 | errD_real_uncond_classi = errD_real_uncond_classi.mean() 279 | 280 | errD_fake_uncond = criterion(fake_logits[1], fake_labels) # Real/Fake loss for 'fake background' (on patch level) 281 | errD_fake_uncond = errD_fake_uncond.mean() 282 | 283 | if (norm_fact_real > 0): # Normalizing the real/fake loss for background after accounting for the number of masked units in the output.
284 | errD_real = errD_real_uncond * ((norm_fact_fake * 1.0) /(norm_fact_real * 1.0)) 285 | else: 286 | errD_real = errD_real_uncond 287 | 288 | errD_fake = errD_fake_uncond 289 | errD = ((errD_real + errD_fake) * cfg.TRAIN.BG_LOSS_WT) + errD_real_uncond_classi 290 | 291 | if idx == 2: 292 | 293 | errD_real = criterion_one(real_logits[1], real_labels) # Real/Fake loss for the real image 294 | errD_fake = criterion_one(fake_logits[1], fake_labels) # Real/Fake loss for the fake image 295 | errD = errD_real + errD_fake 296 | 297 | if (idx == 0 or idx == 2): 298 | errD.backward() 299 | optD.step() 300 | 301 | if (flag == 0): 302 | summary_D = summary.scalar('D_loss%d' % idx, errD.data[0]) 303 | self.summary_writer.add_summary(summary_D, count) 304 | summary_D_real = summary.scalar('D_loss_real_%d' % idx, errD_real.data[0]) 305 | self.summary_writer.add_summary(summary_D_real, count) 306 | summary_D_fake = summary.scalar('D_loss_fake_%d' % idx, errD_fake.data[0]) 307 | self.summary_writer.add_summary(summary_D_fake, count) 308 | 309 | return errD 310 | 311 | def train_Gnet(self, count): 312 | self.netG.zero_grad() 313 | for myit in range(len(self.netsD)): 314 | self.netsD[myit].zero_grad() 315 | 316 | errG_total = 0 317 | flag = count % 100 318 | batch_size = self.real_fimgs[0].size(0) 319 | criterion_one, criterion_class, c_code, p_code = self.criterion_one, self.criterion_class, self.c_code, self.p_code 320 | 321 | for i in range(self.num_Ds): 322 | 323 | outputs = self.netsD[i](self.fake_imgs[i]) 324 | 325 | if i == 0 or i == 2: # real/fake loss for background (0) and child (2) stage 326 | real_labels = torch.ones_like(outputs[1]) 327 | errG = criterion_one(outputs[1], real_labels) 328 | if i==0: 329 | errG = errG * cfg.TRAIN.BG_LOSS_WT 330 | errG_classi = criterion_one(outputs[0], real_labels) # Background/Foreground classification loss for the fake background image (on patch level) 331 | errG = errG + errG_classi 332 | errG_total = errG_total + errG 333 | 334 | if i == 1: # Mutual information loss for the parent stage (1) 335 | pred_p = self.netsD[i](self.fg_mk[i-1]) 336 | errG_info = criterion_class(pred_p[0], torch.nonzero(p_code.long())[:,1]) 337 | elif i == 2: # Mutual information loss for the child stage (2) 338 | pred_c = self.netsD[i](self.fg_mk[i-1]) 339 | errG_info = criterion_class(pred_c[0], torch.nonzero(c_code.long())[:,1]) 340 | 341 | if(i>0): 342 | errG_total = errG_total + errG_info 343 | 344 | if flag == 0: 345 | if i>0: 346 | summary_D_class = summary.scalar('Information_loss_%d' % i, errG_info.data[0]) 347 | self.summary_writer.add_summary(summary_D_class, count) 348 | 349 | if i == 0 or i == 2: 350 | summary_D = summary.scalar('G_loss%d' % i, errG.data[0]) 351 | self.summary_writer.add_summary(summary_D, count) 352 | 353 | errG_total.backward() 354 | for myit in range(len(self.netsD)): 355 | self.optimizerG[myit].step() 356 | return errG_total 357 | 358 | def train(self): 359 | self.netG, self.netsD, self.num_Ds, start_count = load_network(self.gpus) 360 | avg_param_G = copy_G_params(self.netG) 361 | 362 | self.optimizerG, self.optimizersD = \ 363 | define_optimizers(self.netG, self.netsD) 364 | 365 | self.criterion = nn.BCELoss(reduce=False) 366 | self.criterion_one = nn.BCELoss() 367 | self.criterion_class = nn.CrossEntropyLoss() 368 | 369 | self.real_labels = \ 370 | Variable(torch.FloatTensor(self.batch_size).fill_(1)) 371 | self.fake_labels = \ 372 | Variable(torch.FloatTensor(self.batch_size).fill_(0)) 373 | 374 | nz = cfg.GAN.Z_DIM 375 | noise = 
Variable(torch.FloatTensor(self.batch_size, nz)) 376 | fixed_noise = \ 377 | Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1)) 378 | hard_noise = \ 379 | Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1)).cuda() 380 | 381 | self.patch_stride = float(4) # Receptive field stride given the current discriminator architecture for background stage 382 | self.n_out = 24 # Output size of the discriminator at the background stage; N X N where N = 24 383 | self.recp_field = 34 # Receptive field of each member of the N X N output 384 | 385 | 386 | if cfg.CUDA: 387 | self.criterion.cuda() 388 | self.criterion_one.cuda() 389 | self.criterion_class.cuda() 390 | self.real_labels = self.real_labels.cuda() 391 | self.fake_labels = self.fake_labels.cuda() 392 | noise, fixed_noise = noise.cuda(), fixed_noise.cuda() 393 | 394 | print ("Starting normal FineGAN training..") 395 | count = start_count 396 | start_epoch = start_count // (self.num_batches) 397 | 398 | for epoch in range(start_epoch, self.max_epoch): 399 | start_t = time.time() 400 | 401 | for step, data in enumerate(self.data_loader, 0): 402 | 403 | self.imgs_tcpu, self.real_fimgs, self.real_cimgs, \ 404 | self.c_code, self.warped_bbox = self.prepare_data(data) 405 | 406 | # Feedforward through Generator. Obtain stagewise fake images 407 | noise.data.normal_(0, 1) 408 | self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = \ 409 | self.netG(noise, self.c_code) 410 | 411 | # Obtain the parent code given the child code 412 | self.p_code = child_to_parent(self.c_code, cfg.FINE_GRAINED_CATEGORIES, cfg.SUPER_CATEGORIES) 413 | 414 | # Update Discriminator networks 415 | errD_total = 0 416 | for i in range(self.num_Ds): 417 | if i == 0 or i == 2: # only at background and child stage 418 | errD = self.train_Dnet(i, count) 419 | errD_total += errD 420 | 421 | # Update the Generator networks 422 | errG_total = self.train_Gnet(count) 423 | for p, avg_p in zip(self.netG.parameters(), avg_param_G): 424 | avg_p.mul_(0.999).add_(0.001, p.data) 425 | 426 | count = count + 1 427 | 428 | if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0: 429 | backup_para = copy_G_params(self.netG) 430 | save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir) 431 | # Save images 432 | load_params(self.netG, avg_param_G) 433 | self.netG.eval() 434 | with torch.set_grad_enabled(False): 435 | self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = \ 436 | self.netG(fixed_noise, self.c_code) 437 | save_img_results(self.imgs_tcpu, (self.fake_imgs + self.fg_imgs + self.mk_imgs + self.fg_mk), self.num_Ds, 438 | count, self.image_dir, self.summary_writer) 439 | self.netG.train() 440 | load_params(self.netG, backup_para) 441 | 442 | end_t = time.time() 443 | print('''[%d/%d][%d] 444 | Loss_D: %.2f Loss_G: %.2f Time: %.2fs 445 | ''' 446 | % (epoch, self.max_epoch, self.num_batches, 447 | errD_total.data[0], errG_total.data[0], 448 | end_t - start_t)) 449 | 450 | save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir) 451 | 452 | print ("Done with the normal training. Now performing hard negative training..") 453 | count = 0 454 | start_t = time.time() 455 | for step, data in enumerate(self.data_loader, 0): 456 | 457 | self.imgs_tcpu, self.real_fimgs, self.real_cimgs, \ 458 | self.c_code, self.warped_bbox = self.prepare_data(data) 459 | 460 | if (count % 2) == 0: # Train on normal batch of images 461 | 462 | # Feedforward through Generator.
Obtain stagewise fake images 463 | noise.data.normal_(0, 1) 464 | self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = \ 465 | self.netG(noise, self.c_code) 466 | 467 | self.p_code = child_to_parent(self.c_code, cfg.FINE_GRAINED_CATEGORIES, cfg.SUPER_CATEGORIES) 468 | 469 | # Update discriminator networks 470 | errD_total = 0 471 | for i in range(self.num_Ds): 472 | if i == 0 or i == 2: 473 | errD = self.train_Dnet(i, count) 474 | errD_total += errD 475 | 476 | 477 | # Update the generator network 478 | errG_total = self.train_Gnet(count) 479 | 480 | else: # Train on degenerate images 481 | repeat_times=10 482 | all_hard_z = Variable(torch.zeros(self.batch_size * repeat_times, nz)).cuda() 483 | all_hard_class = Variable(torch.zeros(self.batch_size * repeat_times, cfg.FINE_GRAINED_CATEGORIES)).cuda() 484 | all_logits = Variable(torch.zeros(self.batch_size * repeat_times,)).cuda() 485 | 486 | for hard_it in range(repeat_times): 487 | hard_noise = hard_noise.data.normal_(0,1) 488 | hard_class = Variable(torch.zeros([self.batch_size, cfg.FINE_GRAINED_CATEGORIES])).cuda() 489 | my_rand_id=[] 490 | 491 | for c_it in range(self.batch_size): 492 | rand_class = random.sample(range(cfg.FINE_GRAINED_CATEGORIES),1); 493 | hard_class[c_it][rand_class] = 1 494 | my_rand_id.append(rand_class) 495 | 496 | all_hard_z[self.batch_size * hard_it : self.batch_size * (hard_it + 1)] = hard_noise.data 497 | all_hard_class[self.batch_size * hard_it : self.batch_size * (hard_it + 1)] = hard_class.data 498 | self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = self.netG(hard_noise.detach(), hard_class.detach()) 499 | 500 | fake_logits = self.netsD[2](self.fg_mk[1].detach()) 501 | smax_class = softmax(fake_logits[0], dim = 1) 502 | 503 | for b_it in range(self.batch_size): 504 | all_logits[(self.batch_size * hard_it) + b_it] = smax_class[b_it][my_rand_id[b_it]] 505 | 506 | sorted_val, indices_hard = torch.sort(all_logits) 507 | noise = all_hard_z[indices_hard[0 : self.batch_size]] 508 | self.c_code = all_hard_class[indices_hard[0 : self.batch_size]] 509 | 510 | self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = \ 511 | self.netG(noise, self.c_code) 512 | 513 | self.p_code = child_to_parent(self.c_code, cfg.FINE_GRAINED_CATEGORIES, cfg.SUPER_CATEGORIES) 514 | 515 | # Update Discriminator networks 516 | errD_total = 0 517 | for i in range(self.num_Ds): 518 | if i == 0 or i == 2: 519 | errD = self.train_Dnet(i, count) 520 | errD_total += errD 521 | 522 | # Update generator network 523 | errG_total = self.train_Gnet(count) 524 | 525 | for p, avg_p in zip(self.netG.parameters(), avg_param_G): 526 | avg_p.mul_(0.999).add_(0.001, p.data) 527 | count = count + 1 528 | 529 | if count % cfg.TRAIN.SNAPSHOT_INTERVAL_HARDNEG == 0: 530 | backup_para = copy_G_params(self.netG) 531 | save_model(self.netG, avg_param_G, self.netsD, count+500000, self.model_dir) 532 | load_params(self.netG, avg_param_G) 533 | self.netG.eval() 534 | with torch.set_grad_enabled(False): 535 | self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = \ 536 | self.netG(fixed_noise, self.c_code) 537 | save_img_results(self.imgs_tcpu, (self.fake_imgs + self.fg_imgs + self.mk_imgs + self.fg_mk), self.num_Ds, 538 | count, self.image_dir, self.summary_writer) 539 | self.netG.train() 540 | load_params(self.netG, backup_para) 541 | 542 | end_t = time.time() 543 | 544 | if (count % 100) == 0: 545 | print('''[%d/%d][%d] 546 | Loss_D: %.2f Loss_G: %.2f Time: %.2fs 547 | ''' 548 | % (count, cfg.TRAIN.HARDNEG_MAX_ITER, self.num_batches, 549 | 
errD_total.data[0], errG_total.data[0], 550 | end_t - start_t)) 551 | 552 | if (count == cfg.TRAIN.HARDNEG_MAX_ITER): # Hard negative training complete 553 | break 554 | 555 | save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir) 556 | self.summary_writer.close() 557 | 558 | 559 | 560 | class FineGAN_evaluator(object): 561 | 562 | def __init__(self): 563 | 564 | self.save_dir = os.path.join(cfg.SAVE_DIR, 'images') 565 | mkdir_p(self.save_dir) 566 | s_gpus = cfg.GPU_ID.split(',') 567 | self.gpus = [int(ix) for ix in s_gpus] 568 | self.num_gpus = len(self.gpus) 569 | torch.cuda.set_device(self.gpus[0]) 570 | cudnn.benchmark = True 571 | self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus 572 | 573 | 574 | def evaluate_finegan(self): 575 | if cfg.TRAIN.NET_G == '': 576 | print('Error: the path for model not found!') 577 | else: 578 | # Build and load the generator 579 | netG = G_NET() 580 | netG.apply(weights_init) 581 | netG = torch.nn.DataParallel(netG, device_ids=self.gpus) 582 | model_dict = netG.state_dict() 583 | 584 | state_dict = \ 585 | torch.load(cfg.TRAIN.NET_G, 586 | map_location=lambda storage, loc: storage) 587 | 588 | state_dict = {k: v for k, v in state_dict.items() if k in model_dict} 589 | 590 | model_dict.update(state_dict) 591 | netG.load_state_dict(model_dict) 592 | print('Load ', cfg.TRAIN.NET_G) 593 | 594 | # Uncomment this to print Generator layers 595 | # print(netG) 596 | 597 | nz = cfg.GAN.Z_DIM 598 | noise = torch.FloatTensor(self.batch_size, nz) 599 | noise.data.normal_(0, 1) 600 | 601 | if cfg.CUDA: 602 | netG.cuda() 603 | noise = noise.cuda() 604 | 605 | netG.eval() 606 | 607 | background_class = cfg.TEST_BACKGROUND_CLASS 608 | parent_class = cfg.TEST_PARENT_CLASS 609 | child_class = cfg.TEST_CHILD_CLASS 610 | bg_code = torch.zeros([self.batch_size, cfg.FINE_GRAINED_CATEGORIES]) 611 | p_code = torch.zeros([self.batch_size, cfg.SUPER_CATEGORIES]) 612 | c_code = torch.zeros([self.batch_size, cfg.FINE_GRAINED_CATEGORIES]) 613 | 614 | for j in range(self.batch_size): 615 | bg_code[j][background_class] = 1 616 | p_code[j][parent_class] = 1 617 | c_code[j][child_class] = 1 618 | 619 | fake_imgs, fg_imgs, mk_imgs, fgmk_imgs = netG(noise, c_code, p_code, bg_code) # Forward pass through the generator 620 | 621 | self.save_image(fake_imgs[0][0], self.save_dir, 'background') 622 | self.save_image(fake_imgs[1][0], self.save_dir, 'parent_final') 623 | self.save_image(fake_imgs[2][0], self.save_dir, 'child_final') 624 | self.save_image(fg_imgs[0][0], self.save_dir, 'parent_foreground') 625 | self.save_image(fg_imgs[1][0], self.save_dir, 'child_foreground') 626 | self.save_image(mk_imgs[0][0], self.save_dir, 'parent_mask') 627 | self.save_image(mk_imgs[1][0], self.save_dir, 'child_mask') 628 | self.save_image(fgmk_imgs[0][0], self.save_dir, 'parent_foreground_masked') 629 | self.save_image(fgmk_imgs[1][0], self.save_dir, 'child_foreground_masked') 630 | 631 | 632 | def save_image(self, images, save_dir, iname): 633 | 634 | img_name = '%s.png' % (iname) 635 | full_path = os.path.join(save_dir, img_name) 636 | 637 | if (iname.find('mask') == -1) or (iname.find('foreground') != -1): 638 | img = images.add(1).div(2).mul(255).clamp(0, 255).byte() 639 | ndarr = img.permute(1, 2, 0).data.cpu().numpy() 640 | im = Image.fromarray(ndarr) 641 | im.save(full_path) 642 | 643 | else: 644 | img = images.mul(255).clamp(0, 255).byte() 645 | ndarr = img.data.cpu().numpy() 646 | ndarr = np.reshape(ndarr, (ndarr.shape[-1], ndarr.shape[-1], 1)) 647 | ndarr = np.repeat(ndarr, 3, 
axis=2) 648 | im = Image.fromarray(ndarr) 649 | im.save(full_path) 650 | 651 | 652 | --------------------------------------------------------------------------------
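
The background-stage discriminator in `train_Dnet` above produces a patch grid of real/fake outputs, and patches whose receptive field overlaps the object bounding box are zero-weighted before the loss is averaged. Below is a minimal, self-contained sketch of that bounding-box-to-patch-grid mapping, using the constants set in `train()` (126-pixel input, 24x24 output, receptive field 34, stride 4); `patch_mask` is an illustrative helper, not a function in the repository.

```python
import math
import torch

def patch_mask(x1, y1, x2, y2, n_out=24, recp_field=34, patch_stride=4.0, img_size=126):
    # Weight map over the n_out x n_out discriminator outputs: 1 where the patch's
    # receptive field lies outside the warped bounding box, 0 where it falls inside.
    weights = torch.ones(n_out, n_out)
    a1 = max(0, math.ceil((x1 - recp_field) / patch_stride))
    a2 = min(n_out - 1, math.floor((n_out - 1) - ((img_size - recp_field) - x2) / patch_stride)) + 1
    b1 = max(0, math.ceil((y1 - recp_field) / patch_stride))
    b2 = min(n_out - 1, math.floor((n_out - 1) - ((img_size - recp_field) - y2) / patch_stride)) + 1
    if x1 != x2 and y1 != y2:
        weights[a1:a2, b1:b2] = 0.0
    return weights

mask = patch_mask(30, 20, 90, 80)
print(int(mask.sum()), "of", mask.numel(), "patches are kept as real background")
```

In the trainer the same weight map also rescales the real loss by `norm_fact_fake / norm_fact_real`, so that images with large boxes (few unmasked patches) still contribute a comparable loss magnitude.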
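`train()` also keeps an exponential moving average of the generator weights: `copy_G_params` snapshots them, each iteration blends in the current parameters with decay 0.999, and `load_params` swaps the averaged weights in before images or checkpoints are saved. A minimal sketch of that pattern follows; note that the positional `add_(0.001, p.data)` used above is the legacy PyTorch signature, written `add_(p.data, alpha=0.001)` in recent releases.

```python
import torch
import torch.nn as nn

def update_ema(model, ema_params, decay=0.999):
    # ema <- decay * ema + (1 - decay) * current weights
    for p, avg_p in zip(model.parameters(), ema_params):
        avg_p.mul_(decay).add_(p.data, alpha=1.0 - decay)

net = nn.Linear(4, 2)                                      # stand-in for netG
ema = [p.data.clone() for p in net.parameters()]           # like copy_G_params(netG)
net.weight.data.add_(0.1 * torch.randn_like(net.weight))   # pretend an optimizer step happened
update_ema(net, ema)

backup = [p.data.clone() for p in net.parameters()]        # keep the live weights
for p, avg_p in zip(net.parameters(), ema):                # like load_params(netG, avg_param_G)
    p.data.copy_(avg_p)
# ... save or visualise with the averaged weights, then restore:
for p, b in zip(net.parameters(), backup):
    p.data.copy_(b)
```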
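After the normal phase, every other iteration of the hard-negative loop samples `repeat_times` candidate batches of noise and one-hot child codes, scores each generated sample by the child discriminator's softmax confidence for its own code, and re-trains on the `batch_size` least-confident samples. A compact sketch of that selection step, where `score_fn` is a hypothetical stand-in for running `netG` and then `netsD[2]` on the masked foreground:

```python
import torch
import torch.nn.functional as F

def mine_hard_negatives(score_fn, batch_size, nz, n_classes, repeat_times=10):
    all_z = torch.zeros(batch_size * repeat_times, nz)
    all_code = torch.zeros(batch_size * repeat_times, n_classes)
    all_conf = torch.zeros(batch_size * repeat_times)
    for it in range(repeat_times):
        z = torch.randn(batch_size, nz)
        labels = torch.randint(0, n_classes, (batch_size,))
        code = F.one_hot(labels, n_classes).float()
        logits = score_fn(z, code)                       # (batch_size, n_classes) class logits
        conf = F.softmax(logits, dim=1)[torch.arange(batch_size), labels]
        sl = slice(batch_size * it, batch_size * (it + 1))
        all_z[sl], all_code[sl], all_conf[sl] = z, code, conf
    order = torch.argsort(all_conf)                      # least confident (hardest) first
    keep = order[:batch_size]
    return all_z[keep], all_code[keep]

hard_z, hard_code = mine_hard_negatives(
    lambda z, c: torch.randn(z.size(0), c.size(1)),      # random stand-in scorer
    batch_size=16, nz=100, n_classes=200)
```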
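Finally, `FineGAN_evaluator.evaluate_finegan` drives the generator with three one-hot codes (background, parent, child) plus a noise vector, and the generator returns stage-wise images, foregrounds, and masks. A small sketch of just the code construction; the class counts here are illustrative (in practice they come from `cfg.FINE_GRAINED_CATEGORIES` and `cfg.SUPER_CATEGORIES`), and the four-argument call mirrors the evaluator above.

```python
import torch

batch_size, nz = 4, 100
n_child, n_parent = 200, 20               # illustrative values; read them from cfg in practice
child_cls, parent_cls, bg_cls = 3, 1, 0   # arbitrary example picks

noise = torch.randn(batch_size, nz)
bg_code = torch.zeros(batch_size, n_child)
p_code = torch.zeros(batch_size, n_parent)
c_code = torch.zeros(batch_size, n_child)
bg_code[:, bg_cls] = 1
p_code[:, parent_cls] = 1
c_code[:, child_cls] = 1

# With a loaded generator (see evaluate_finegan above):
# fake_imgs, fg_imgs, mk_imgs, fgmk_imgs = netG(noise, c_code, p_code, bg_code)
```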