├── files
│   ├── approach.png
│   ├── teaser.png
│   ├── qual_res1.png
│   ├── qual_res2.png
│   ├── qual_res3.png
│   └── finegan_demo.gif
├── code
│   ├── miscc
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   └── config.py
│   ├── cfg
│   │   ├── eval.yml
│   │   └── train.yml
│   ├── main.py
│   ├── datasets.py
│   ├── inception.py
│   ├── model.py
│   └── trainer.py
├── models
│   └── README.md
├── data
│   └── README.md
├── LICENSE
└── README.md
/files/approach.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kkanshul/finegan/HEAD/files/approach.png
--------------------------------------------------------------------------------
/files/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kkanshul/finegan/HEAD/files/teaser.png
--------------------------------------------------------------------------------
/files/qual_res1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kkanshul/finegan/HEAD/files/qual_res1.png
--------------------------------------------------------------------------------
/files/qual_res2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kkanshul/finegan/HEAD/files/qual_res2.png
--------------------------------------------------------------------------------
/files/qual_res3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kkanshul/finegan/HEAD/files/qual_res3.png
--------------------------------------------------------------------------------
/files/finegan_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kkanshul/finegan/HEAD/files/finegan_demo.gif
--------------------------------------------------------------------------------
/code/miscc/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 |
--------------------------------------------------------------------------------
/code/miscc/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import errno
3 |
4 | def mkdir_p(path):
5 | try:
6 | os.makedirs(path)
7 | except OSError as exc: # Python >2.5
8 | if exc.errno == errno.EEXIST and os.path.isdir(path):
9 | pass
10 | else:
11 | raise
12 |
--------------------------------------------------------------------------------
/models/README.md:
--------------------------------------------------------------------------------
1 | ## Download pretrained models
2 | Pretrained generator models for CUB and Stanford Dogs are available at this [link](https://drive.google.com/file/d/1cKJAXRDQ-_a76bHWRqcIdmPXqpN8a1lR/view?usp=sharing). Download and extract them in the `models` directory.
3 | ```bash
4 | cd models
5 | unzip netG.zip
6 | cd ..
7 | ```
8 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | ## Data
2 | **Note**: You only need to download the data if you wish to train your own model.
3 |
4 | Download the formatted CUB data from this [link](https://drive.google.com/file/d/1ardy8L7Cb-Vn1ynQigaXpX_JHl0dhh2M/view?usp=sharing) and extract it inside the `data` directory:
5 | ```bash
6 | cd data
7 | unzip birds.zip
8 | cd ..
9 | ```
10 |
--------------------------------------------------------------------------------
/code/cfg/eval.yml:
--------------------------------------------------------------------------------
1 | DATASET_NAME: 'birds'
2 | SAVE_DIR: '../output/'
3 | GPU_ID: '3'
4 | WORKERS: 1 # 4
5 |
6 | SUPER_CATEGORIES: 20 # For CUB
7 | FINE_GRAINED_CATEGORIES: 200 # For CUB
8 | TEST_CHILD_CLASS: 125 # specify any value [0, FINE_GRAINED_CATEGORIES - 1]
9 | TEST_PARENT_CLASS: 0 # specify any value [0, SUPER_CATEGORIES - 1]
10 | TEST_BACKGROUND_CLASS: 0 # specify any value [0, FINE_GRAINED_CATEGORIES - 1]
11 | TIED_CODES: True
12 |
13 | TRAIN:
14 | FLAG: False
15 | NET_G: '../models/netG/netG_birds.pth'
16 | BATCH_SIZE: 1
17 |
18 |
19 | GAN:
20 | DF_DIM: 64
21 | GF_DIM: 64
22 | Z_DIM: 100
23 | R_NUM: 2
24 |
--------------------------------------------------------------------------------
/code/cfg/train.yml:
--------------------------------------------------------------------------------
1 | CONFIG_NAME: '3stages'
2 | DATASET_NAME: 'birds'
3 | DATA_DIR: '../data/birds'
4 | SAVE_DIR: '../output/vis'
5 | GPU_ID: '0'
6 | WORKERS: 4
7 |
8 | SUPER_CATEGORIES: 20 # For CUB
9 | FINE_GRAINED_CATEGORIES: 200 # For CUB
10 | TIED_CODES: True # Do NOT change this to False during training.
11 |
12 | TREE:
13 | BRANCH_NUM: 3
14 |
15 | TRAIN:
16 | FLAG: True
17 | NET_G: '' # Specify the generator path to resume training
18 | NET_D: '' # Specify the discriminator path to resume training
19 | BATCH_SIZE: 16
20 | MAX_EPOCH: 600
21 | HARDNEG_MAX_ITER: 1500
22 | SNAPSHOT_INTERVAL: 4000
23 | SNAPSHOT_INTERVAL_HARDNEG: 500
24 | DISCRIMINATOR_LR: 0.0002
25 | GENERATOR_LR: 0.0002
26 |
27 | GAN:
28 | DF_DIM: 64
29 | GF_DIM: 64
30 | Z_DIM: 100
31 | R_NUM: 2
32 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2019, Krishna Kumar Singh, Utkarsh Ojha, Yong Jae Lee
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24 |
--------------------------------------------------------------------------------
/code/miscc/config.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 |
4 | import os.path as osp
5 | import numpy as np
6 | from easydict import EasyDict as edict
7 |
8 |
9 | __C = edict()
10 | cfg = __C
11 |
12 | # Dataset name: flowers, birds
13 | __C.DATASET_NAME = 'birds'
14 | __C.CONFIG_NAME = ''
15 | __C.DATA_DIR = ''
16 | __C.SAVE_DIR = ''
17 | __C.GPU_ID = '0'
18 | __C.CUDA = True
19 |
20 | __C.WORKERS = 6
21 |
22 | __C.TREE = edict()
23 | __C.TREE.BRANCH_NUM = 3
24 | __C.TREE.BASE_SIZE = 64
25 | __C.SUPER_CATEGORIES = 20
26 | __C.FINE_GRAINED_CATEGORIES = 200
27 | __C.TEST_CHILD_CLASS = 0
28 | __C.TEST_PARENT_CLASS = 0
29 | __C.TEST_BACKGROUND_CLASS = 0
30 | __C.TIED_CODES = True
31 |
32 | # Test options
33 | __C.TEST = edict()
34 |
35 | # Training options
36 | __C.TRAIN = edict()
37 | __C.TRAIN.BATCH_SIZE = 64
38 | __C.TRAIN.BG_LOSS_WT = 10
39 | __C.TRAIN.VIS_COUNT = 64
40 | __C.TRAIN.MAX_EPOCH = 600
41 | __C.TRAIN.HARDNEG_MAX_ITER = 1500
42 | __C.TRAIN.SNAPSHOT_INTERVAL = 2000
43 | __C.TRAIN.SNAPSHOT_INTERVAL_HARDNEG = 500
44 | __C.TRAIN.DISCRIMINATOR_LR = 2e-4
45 | __C.TRAIN.GENERATOR_LR = 2e-4
46 | __C.TRAIN.FLAG = True
47 | __C.TRAIN.NET_G = ''
48 | __C.TRAIN.NET_D = ''
49 |
50 |
51 | # Modal options
52 | __C.GAN = edict()
53 | __C.GAN.DF_DIM = 64
54 | __C.GAN.GF_DIM = 64
55 | __C.GAN.Z_DIM = 100
56 | __C.GAN.NETWORK_TYPE = 'default'
57 | __C.GAN.R_NUM = 2
58 |
59 |
60 |
61 |
62 | def _merge_a_into_b(a, b):
63 | """Merge config dictionary a into config dictionary b, clobbering the
64 | options in b whenever they are also specified in a.
65 | """
66 | if type(a) is not edict:
67 | return
68 |
69 | for k, v in a.items():
70 | # a must specify keys that are in b
71 | if k not in b:
72 | raise KeyError('{} is not a valid config key'.format(k))
73 |
74 | # the types must match, too
75 | old_type = type(b[k])
76 | if old_type is not type(v):
77 | if isinstance(b[k], np.ndarray):
78 | v = np.array(v, dtype=b[k].dtype)
79 | else:
80 | raise ValueError(('Type mismatch ({} vs. {}) '
81 | 'for config key: {}').format(type(b[k]),
82 | type(v), k))
83 |
84 | # recursively merge dicts
85 | if type(v) is edict:
86 | try:
87 | _merge_a_into_b(a[k], b[k])
88 | except:
89 | print('Error under config key: {}'.format(k))
90 | raise
91 | else:
92 | b[k] = v
93 |
94 |
95 | def cfg_from_file(filename):
96 | """Load a config file and merge it into the default options."""
97 | import yaml
98 | with open(filename, 'r') as f:
99 | yaml_cfg = edict(yaml.load(f))
100 |
101 | _merge_a_into_b(yaml_cfg, __C)
102 |
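# Usage sketch (this is how main.py consumes this module): import the shared `cfg`
# object and call cfg_from_file() to override the defaults above, e.g.
#
#   from miscc.config import cfg, cfg_from_file
#   cfg_from_file('cfg/train.yml')
#   print(cfg.TRAIN.BATCH_SIZE)   # 16, taken from the YAML file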
--------------------------------------------------------------------------------
/code/main.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import torchvision.transforms as transforms
4 |
5 | import argparse
6 | import os
7 | import random
8 | import sys
9 | import pprint
10 | import datetime
11 | import dateutil.tz
12 | import time
13 | import pickle
14 |
15 | dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.')))
16 | sys.path.append(dir_path)
17 |
18 |
19 | from miscc.config import cfg, cfg_from_file
20 |
21 |
22 | def parse_args():
23 | parser = argparse.ArgumentParser(description='Train a GAN network')
24 | parser.add_argument('--cfg', dest='cfg_file',
25 | help='optional config file',
26 | default='cfg/birds_proGAN.yml', type=str)
27 | parser.add_argument('--gpu', dest='gpu_id', type=str, default='-1')
28 | parser.add_argument('--data_dir', dest='data_dir', type=str, default='')
29 | parser.add_argument('--manualSeed', type=int, help='manual seed')
30 | #parser.add_argument('--config_key',dest='config_key', type=str, help='configuration name', default = 'finegan_birds')
31 | args = parser.parse_args()
32 | return args
33 |
34 |
35 | if __name__ == "__main__":
36 | args = parse_args()
37 | if args.cfg_file is not None:
38 | cfg_from_file(args.cfg_file)
39 |
40 | if args.gpu_id != '-1':
41 | cfg.GPU_ID = args.gpu_id
42 | else:
43 | cfg.CUDA = False
44 |
45 | if args.data_dir != '':
46 | cfg.DATA_DIR = args.data_dir
47 | if cfg.TRAIN.FLAG:
48 | print('Using config:')
49 | pprint.pprint(cfg)
50 |
51 | if not cfg.TRAIN.FLAG:
52 | args.manualSeed = 45 # Change this to use a different random seed during evaluation
53 |
54 | elif args.manualSeed is None:
55 | args.manualSeed = random.randint(1, 10000)
56 | random.seed(args.manualSeed)
57 | torch.manual_seed(args.manualSeed)
58 | if cfg.CUDA:
59 | torch.cuda.manual_seed_all(args.manualSeed)
60 |
61 | # Evaluation part
62 | if not cfg.TRAIN.FLAG:
63 | from trainer import FineGAN_evaluator as evaluator
64 | algo = evaluator()
65 | algo.evaluate_finegan()
66 |
67 | # Training part
68 | else:
69 | now = datetime.datetime.now(dateutil.tz.tzlocal())
70 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
71 | output_dir = '../output/%s_%s' % \
72 | (cfg.DATASET_NAME, timestamp)
73 | pkl_filename = 'cfg.pickle'
74 |
75 | if not os.path.exists(output_dir):
76 | os.makedirs(output_dir)
77 |
78 | with open(os.path.join(output_dir, pkl_filename), 'wb') as pk:
79 | pickle.dump(cfg, pk, protocol=pickle.HIGHEST_PROTOCOL)
80 |
81 | bshuffle = True
82 |
83 | # Get data loader
84 | imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM-1))
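# With the default config (TREE.BASE_SIZE = 64, TREE.BRANCH_NUM = 3) this gives
# imsize = 256, so images are rescaled to 256 * 76 / 64 = 304 px before being
# randomly cropped back to 256 x 256 and randomly flipped.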
85 | image_transform = transforms.Compose([
86 | transforms.Scale(int(imsize * 76 / 64)),
87 | transforms.RandomCrop(imsize),
88 | transforms.RandomHorizontalFlip()])
89 |
90 |
91 | from datasets import Dataset
92 | dataset = Dataset(cfg.DATA_DIR,
93 | base_size=cfg.TREE.BASE_SIZE,
94 | transform=image_transform)
95 | assert dataset
96 | num_gpu = len(cfg.GPU_ID.split(','))
97 | dataloader = torch.utils.data.DataLoader(
98 | dataset, batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
99 | drop_last=True, shuffle=bshuffle, num_workers=int(cfg.WORKERS))
100 |
101 |
102 | from trainer import FineGAN_trainer as trainer
103 | algo = trainer(output_dir, dataloader, imsize)
104 |
105 | start_t = time.time()
106 | algo.train()
107 | end_t = time.time()
108 | print('Total time for training:', end_t - start_t)
109 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FineGAN
2 | PyTorch implementation for learning to synthesize images in a hierarchical, stage-wise manner by disentangling the background, object shape, and object appearance.
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | ### FineGAN: Unsupervised Hierarchical Disentanglement for Fine-grained Object Generation and Discovery
12 | [Krishna Kumar Singh*](http://krsingh.cs.ucdavis.edu), [Utkarsh Ojha*](https://utkarshojha.github.io/), [Yong Jae Lee](http://web.cs.ucdavis.edu/~yjlee/)
13 |
14 | [project](http://krsingh.cs.ucdavis.edu/krishna_files/papers/finegan/index.html) |
15 | [arxiv](https://arxiv.org/abs/1811.11155) | [demo video](https://www.youtube.com/watch?v=tkk0SeWGu-8) | [talk video](https://www.youtube.com/watch?v=8qkrPSjONhA&t=51m40s)
16 |
17 | **[CVPR 2019 (Oral Presentation)](http://cvpr2019.thecvf.com/)**
18 | ## Architecture
19 |
20 | ![FineGAN architecture](files/approach.png)
21 |
22 | ## Requirements
23 | - Linux
24 | - Python 2.7
25 | - PyTorch 0.4.1
26 | - TensorboardX 1.2
27 | - NVIDIA GPU + CUDA CuDNN
28 |
29 | ## Getting started
30 | ### Clone the repository
31 | ```bash
32 | git clone https://github.com/kkanshul/finegan
33 | cd finegan
34 | ```
35 | ### Setting up the data
36 | **Note**: You only need to download the data if you wish to train your own model.
37 |
38 | Download the formatted CUB data from this [link](https://drive.google.com/file/d/1ardy8L7Cb-Vn1ynQigaXpX_JHl0dhh2M/view?usp=sharing) and extract it inside the `data` directory:
39 | ```bash
40 | cd data
41 | unzip birds.zip
42 | cd ..
43 | ```
44 | ### Downloading pretrained models
45 |
46 | Pretrained generator models for CUB and Stanford Dogs are available at this [link](https://drive.google.com/file/d/1cKJAXRDQ-_a76bHWRqcIdmPXqpN8a1lR/view?usp=sharing). Download and extract them in the `models` directory.
47 | ```bash
48 | cd models
49 | unzip netG.zip
50 | cd ../code/
51 | ```
52 | ## Evaluating the model
53 | In `cfg/eval.yml`:
54 | - Specify the model path in `TRAIN.NET_G`.
55 | - Specify the output directory to save the generated images in `SAVE_DIR`.
56 | - Specify the number of super and fine-grained categories in `SUPER_CATEGORIES` and `FINE_GRAINED_CATEGORIES` according to our [paper](https://arxiv.org/abs/1811.11155).
57 | - Specify the option for using 'tied' latent codes in `TIED_CODES`:
58 | - if `True`, specify only the child code in `TEST_CHILD_CLASS`; the background and parent codes are derived from the child code in this case (a sketch of this mapping follows the list).
59 | - if `False`, i.e. the parent, child and background codes are set independently, specify each of them in `TEST_PARENT_CLASS`, `TEST_CHILD_CLASS` and `TEST_BACKGROUND_CLASS` respectively.
60 | - Run `python main.py --cfg cfg/eval.yml --gpu 0`
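
A minimal sketch of what the tied-code setting does, mirroring `child_to_parent` in `code/model.py` (the CUB defaults of 200 child and 20 parent categories are assumed):
```python
import torch

FINE_GRAINED_CATEGORIES = 200   # child categories (CUB default)
SUPER_CATEGORIES = 20           # parent categories (CUB default)

# One-hot child code selected via TEST_CHILD_CLASS (125 in cfg/eval.yml).
c_code = torch.zeros(1, FINE_GRAINED_CATEGORIES)
c_code[0, 125] = 1

# With TIED_CODES: True, the parent code is derived by grouping consecutive child
# classes (200 / 20 = 10 children per parent) and the background code is tied to
# the child code itself.
ratio = FINE_GRAINED_CATEGORIES // SUPER_CATEGORIES
parent_idx = torch.argmax(c_code, dim=1) // ratio   # child 125 -> parent 12
p_code = torch.zeros(1, SUPER_CATEGORIES)
p_code[0, parent_idx] = 1
bg_code = c_code
```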
61 |
62 | ## Training your own model
63 | In `cfg/train.yml`:
64 | - Specify the dataset location in `DATA_DIR`.
65 | - **NOTE**: If you wish to train on your own (different) dataset, please make sure it is formatted in the same way as the CUB dataset we provide (see the layout sketched after this list).
66 | - Specify the number of super and fine-grained categories that you wish for FineGAN to discover, in `SUPER_CATEGORIES` and `FINE_GRAINED_CATEGORIES`.
67 | - Specify the training hyperparameters in `TRAIN`.
68 | - Run `python main.py --cfg cfg/train.yml --gpu 0`
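
Based on what `code/datasets.py` reads (`images.txt`, `bounding_boxes.txt` and an `images/` folder), a custom dataset directory is expected to look roughly like this (the dataset name is illustrative):
```
data/my_dataset/
├── images.txt             # <image_id> <relative/path/to/image.jpg>, one entry per line
├── bounding_boxes.txt     # <image_id> <x> <y> <width> <height>, one entry per line
└── images/
    └── relative/path/to/image.jpg
```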
69 |
70 | ## Sample generation results of FineGAN
71 | ### 1. Stage-wise image generation results
72 |
73 |
74 | ### 2. Grouping among the generated images (child).
75 |
76 |
77 |
78 | ## Citation
79 | If you find this code useful in your research, consider citing our work:
80 | ```
81 | @inproceedings{singh-cvpr2019,
82 | title = {FineGAN: Unsupervised Hierarchical Disentanglement for Fine-Grained Object Generation and Discovery},
83 | author = {Krishna Kumar Singh and Utkarsh Ojha and Yong Jae Lee},
84 | booktitle = {CVPR},
85 | year = {2019}
86 | }
87 | ```
88 | ## Acknowledgement
89 | We thank the authors of [StackGAN++: Realistic Image Synthesis with Stacked Generative Adversarial Networks](https://arxiv.org/abs/1710.10916) for releasing their source code.
90 | ## Contact
91 | For any questions regarding our paper or code, contact [Krishna Kumar Singh](mailto:krsingh@ucdavis.edu) and [Utkarsh Ojha](mailto:uojha@ucdavis.edu).
92 |
--------------------------------------------------------------------------------
/code/datasets.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | from __future__ import unicode_literals
5 | import sys
6 |
7 | import torch.utils.data as data
8 | import torchvision.transforms as transforms
9 | from PIL import Image
10 | import PIL
11 | import os
12 | import os.path
13 | import pickle
14 | import random
15 | import numpy as np
16 | import pandas as pd
17 | from miscc.config import cfg
18 |
19 | import torch.utils.data as data
20 | from PIL import Image
21 | import os
22 | import os.path
23 | import six
24 | import string
25 | import sys
26 | import torch
27 | from copy import deepcopy
28 | if sys.version_info[0] == 2:
29 | import cPickle as pickle
30 | else:
31 | import pickle
32 |
33 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG',
34 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
35 |
36 |
37 | def is_image_file(filename):
38 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
39 |
40 |
41 | def get_imgs(img_path, imsize, bbox=None,
42 | transform=None, normalize=None):
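# Returns (retf, retc, warped_bbox): a normalized random 126x126 crop of the full
# image (used by the background stage), the bbox-centered object crop used by the
# parent/child stages, and the bounding box warped into the cropped/flipped full image.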
43 | img = Image.open(img_path).convert('RGB')
44 | width, height = img.size
45 | if bbox is not None:
46 | r = int(np.maximum(bbox[2], bbox[3]) * 0.75)
47 | center_x = int((2 * bbox[0] + bbox[2]) / 2)
48 | center_y = int((2 * bbox[1] + bbox[3]) / 2)
49 | y1 = np.maximum(0, center_y - r)
50 | y2 = np.minimum(height, center_y + r)
51 | x1 = np.maximum(0, center_x - r)
52 | x2 = np.minimum(width, center_x + r)
53 | fimg = deepcopy(img)
54 | fimg_arr = np.array(fimg)
55 | fimg = Image.fromarray(fimg_arr)
56 | cimg = img.crop([x1, y1, x2, y2])
57 |
58 | if transform is not None:
59 | cimg = transform(cimg)
60 |
61 |
62 | retf = []
63 | retc = []
64 | re_cimg = transforms.Scale(imsize[1])(cimg)
65 | retc.append(normalize(re_cimg))
66 |
67 | # We use the full image to get background patches.
68 |
69 | # We resize the full image to 126 x 126 (instead of 128 x 128) so that the input (full) image is
70 | # fully covered by the receptive fields of the final convolution layer of the background discriminator.
71 |
72 | my_crop_width = 126
73 | re_fimg = transforms.Scale(int(my_crop_width * 76 / 64))(fimg)
74 | re_width, re_height = re_fimg.size
75 |
76 | # random cropping
77 | x_crop_range = re_width-my_crop_width
78 | y_crop_range = re_height-my_crop_width
79 |
80 | crop_start_x = np.random.randint(x_crop_range)
81 | crop_start_y = np.random.randint(y_crop_range)
82 |
83 | crop_re_fimg = re_fimg.crop([crop_start_x, crop_start_y, crop_start_x + my_crop_width, crop_start_y + my_crop_width])
84 | warped_x1 = bbox[0] * re_width / width
85 | warped_y1 = bbox[1] * re_height / height
86 | warped_x2 = warped_x1 + (bbox[2] * re_width / width)
87 | warped_y2 = warped_y1 + (bbox[3] * re_height / height)
88 |
89 | warped_x1 =min(max(0, warped_x1 - crop_start_x), my_crop_width)
90 | warped_y1 =min(max(0, warped_y1 - crop_start_y), my_crop_width)
91 | warped_x2 =max(min(my_crop_width, warped_x2 - crop_start_x),0)
92 | warped_y2 =max(min(my_crop_width, warped_y2 - crop_start_y),0)
93 |
94 | # random flipping
95 | random_flag=np.random.randint(2)
96 | if(random_flag == 0):
97 | crop_re_fimg = crop_re_fimg.transpose(Image.FLIP_LEFT_RIGHT)
98 | flipped_x1 = my_crop_width - warped_x2
99 | flipped_x2 = my_crop_width - warped_x1
100 | warped_x1 = flipped_x1
101 | warped_x2 = flipped_x2
102 |
103 | retf.append(normalize(crop_re_fimg))
104 |
105 | warped_bbox = []
106 | warped_bbox.append(warped_y1)
107 | warped_bbox.append(warped_x1)
108 | warped_bbox.append(warped_y2)
109 | warped_bbox.append(warped_x2)
110 |
111 | return retf, retc, warped_bbox
112 |
113 |
114 |
115 | class Dataset(data.Dataset):
116 | def __init__(self, data_dir, base_size=64, transform = None):
117 |
118 | self.transform = transform
119 | self.norm = transforms.Compose([
120 | transforms.ToTensor(),
121 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
122 |
123 | self.imsize = []
124 | for i in range(cfg.TREE.BRANCH_NUM):
125 | self.imsize.append(base_size)
126 | base_size = base_size * 2
127 |
128 | self.data = []
129 | self.data_dir = data_dir
130 | self.bbox = self.load_bbox()
131 | self.filenames = self.load_filenames(data_dir)
132 | if cfg.TRAIN.FLAG:
133 | self.iterator = self.prepair_training_pairs
134 | else:
135 | self.iterator = self.prepair_test_pairs
136 |
137 |
138 | # only used in background stage
139 | def load_bbox(self):
140 | # Returns a dictionary with image filename as 'key' and its bounding box coordinates as 'value'
141 |
142 | data_dir = self.data_dir
143 | bbox_path = os.path.join(data_dir, 'bounding_boxes.txt')
144 | df_bounding_boxes = pd.read_csv(bbox_path,
145 | delim_whitespace=True,
146 | header=None).astype(int)
147 | filepath = os.path.join(data_dir, 'images.txt')
148 | df_filenames = \
149 | pd.read_csv(filepath, delim_whitespace=True, header=None)
150 | filenames = df_filenames[1].tolist()
151 | print('Total filenames: ', len(filenames), filenames[0])
152 | filename_bbox = {img_file[:-4]: [] for img_file in filenames}
153 | numImgs = len(filenames)
154 | for i in range(numImgs):
155 | bbox = df_bounding_boxes.iloc[i][1:].tolist()
156 | key = filenames[i][:-4]
157 | filename_bbox[key] = bbox
158 | return filename_bbox
159 |
160 |
161 | def load_filenames(self, data_dir):
162 | filepath = os.path.join(data_dir, 'images.txt')
163 | df_filenames = \
164 | pd.read_csv(filepath, delim_whitespace=True, header=None)
165 | filenames = df_filenames[1].tolist()
166 | filenames = [fname[:-4] for fname in filenames];
167 | print('Load filenames from: %s (%d)' % (filepath, len(filenames)))
168 | return filenames
169 |
170 |
171 | def prepair_training_pairs(self, index):
172 | key = self.filenames[index]
173 | if self.bbox is not None:
174 | bbox = self.bbox[key]
175 | else:
176 | bbox = None
177 | data_dir = self.data_dir
178 | img_name = '%s/images/%s.jpg' % (data_dir, key)
179 | fimgs, cimgs, warped_bbox = get_imgs(img_name, self.imsize,
180 | bbox, self.transform, normalize=self.norm)
181 |
182 | rand_class = random.sample(range(cfg.FINE_GRAINED_CATEGORIES), 1)  # Randomly generating child code during training
183 | c_code = torch.zeros([cfg.FINE_GRAINED_CATEGORIES,])
184 | c_code[rand_class] = 1
185 |
186 | return fimgs, cimgs, c_code, key, warped_bbox
187 |
188 | def prepair_test_pairs(self, index):
189 | key = self.filenames[index]
190 | if self.bbox is not None:
191 | bbox = self.bbox[key]
192 | else:
193 | bbox = None
194 | data_dir = self.data_dir
195 | c_code = self.c_code[index, :, :]
196 | img_name = '%s/images/%s.jpg' % (data_dir, key)
197 | _, imgs, _ = get_imgs(img_name, self.imsize,
198 | bbox, self.transform, normalize=self.norm)
199 |
200 | return imgs, c_code, key
201 |
202 | def __getitem__(self, index):
203 | return self.iterator(index)
204 |
205 | def __len__(self):
206 | return len(self.filenames)
207 |
--------------------------------------------------------------------------------
/code/inception.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.utils.model_zoo as model_zoo
5 | from miscc.config import cfg
6 |
7 |
8 | __all__ = ['Inception3', 'inception_v3']
9 |
10 |
11 | model_urls = {
12 | # Inception v3 ported from TensorFlow
13 | 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
14 | }
15 |
16 |
17 |
18 | def inception_v3(pretrained=True, **kwargs):
19 | r"""Inception v3 model architecture from
20 | `"Rethinking the Inception Architecture for Computer Vision" `_.
21 | Args:
22 | pretrained (bool): If True, returns a model pre-trained on ImageNet
23 | """
24 | if pretrained:
25 | if 'transform_input' not in kwargs:
26 | kwargs['transform_input'] = True
27 | model = Inception3(**kwargs)
28 | pretrained_dict = model_zoo.load_url(model_urls['inception_v3_google'])
29 | model_dict = model.state_dict()
30 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
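# Only checkpoint entries whose names match this model are kept, so the renamed
# classifier layers (fc_new and AuxLogits.fc_new) keep the truncated-normal
# initialization from __init__ instead of the ImageNet fc weights.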
31 | model_dict.update(pretrained_dict)
32 | model.load_state_dict(model_dict)
33 | print ("Inception pretrained on IMAGENET loaded")
34 | return model
35 |
36 | return Inception3(**kwargs)
37 |
38 |
39 | class Inception3(nn.Module):
40 |
41 | def __init__(self, num_classes=200, aux_logits=True, transform_input=False):
42 | super(Inception3, self).__init__()
43 | self.aux_logits = aux_logits
44 | self.transform_input = transform_input
45 | self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
46 | self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
47 | self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
48 | self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
49 | self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
50 | self.Mixed_5b = InceptionA(192, pool_features=32)
51 | self.Mixed_5c = InceptionA(256, pool_features=64)
52 | self.Mixed_5d = InceptionA(288, pool_features=64)
53 | self.Mixed_6a = InceptionB(288)
54 | self.Mixed_6b = InceptionC(768, channels_7x7=128)
55 | self.Mixed_6c = InceptionC(768, channels_7x7=160)
56 | self.Mixed_6d = InceptionC(768, channels_7x7=160)
57 | self.Mixed_6e = InceptionC(768, channels_7x7=192)
58 | if aux_logits:
59 | self.AuxLogits = InceptionAux(768, num_classes)
60 | self.Mixed_7a = InceptionD(768)
61 | self.Mixed_7b = InceptionE(1280)
62 | self.Mixed_7c = InceptionE(2048)
63 | self.fc_new = nn.Linear(2048, num_classes)
64 |
65 | for m in self.modules():
66 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
67 | import scipy.stats as stats
68 | stddev = m.stddev if hasattr(m, 'stddev') else 0.1
69 | X = stats.truncnorm(-2, 2, scale=stddev)
70 | values = torch.Tensor(X.rvs(m.weight.numel()))
71 | values = values.view(m.weight.size())
72 | m.weight.data.copy_(values)
73 | elif isinstance(m, nn.BatchNorm2d):
74 | nn.init.constant_(m.weight, 1)
75 | nn.init.constant_(m.bias, 0)
76 |
77 | def forward(self, x):
78 |
79 | #No preprocessing being done right now
80 |
81 | # 299 x 299 x 3
82 | x = self.Conv2d_1a_3x3(x)
83 | # 149 x 149 x 32
84 | x = self.Conv2d_2a_3x3(x)
85 | # 147 x 147 x 32
86 | x = self.Conv2d_2b_3x3(x)
87 | # 147 x 147 x 64
88 | x = F.max_pool2d(x, kernel_size=3, stride=2)
89 | # 73 x 73 x 64
90 | x = self.Conv2d_3b_1x1(x)
91 | # 73 x 73 x 80
92 | x = self.Conv2d_4a_3x3(x)
93 | # 71 x 71 x 192
94 | x = F.max_pool2d(x, kernel_size=3, stride=2)
95 | # 35 x 35 x 192
96 | x = self.Mixed_5b(x)
97 | # 35 x 35 x 256
98 | x = self.Mixed_5c(x)
99 | # 35 x 35 x 288
100 | x = self.Mixed_5d(x)
101 | # 35 x 35 x 288
102 | x = self.Mixed_6a(x)
103 | # 17 x 17 x 768
104 | x = self.Mixed_6b(x)
105 | # 17 x 17 x 768
106 | x = self.Mixed_6c(x)
107 | # 17 x 17 x 768
108 | x = self.Mixed_6d(x)
109 | # 17 x 17 x 768
110 | x = self.Mixed_6e(x)
111 | # 17 x 17 x 768
112 | if self.training and self.aux_logits:
113 | aux = self.AuxLogits(x)
114 | # 17 x 17 x 768
115 | x = self.Mixed_7a(x)
116 | # 8 x 8 x 1280
117 | x = self.Mixed_7b(x)
118 | # 8 x 8 x 2048
119 | x = self.Mixed_7c(x)
120 | # 8 x 8 x 2048
121 | x = F.avg_pool2d(x, kernel_size=8)
122 | # 1 x 1 x 2048
123 | x = F.dropout(x, training=self.training)
124 | # 1 x 1 x 2048
125 | x = x.view(x.size(0), -1)
126 | # 2048
127 | x = self.fc_new(x)
128 | # num_classes (200 by default)
129 | if self.training and self.aux_logits:
130 | return x, aux
131 | return x
132 |
133 |
134 | class InceptionA(nn.Module):
135 |
136 | def __init__(self, in_channels, pool_features):
137 | super(InceptionA, self).__init__()
138 | self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)
139 |
140 | self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1)
141 | self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2)
142 |
143 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
144 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
145 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1)
146 |
147 | self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)
148 |
149 | def forward(self, x):
150 | branch1x1 = self.branch1x1(x)
151 |
152 | branch5x5 = self.branch5x5_1(x)
153 | branch5x5 = self.branch5x5_2(branch5x5)
154 |
155 | branch3x3dbl = self.branch3x3dbl_1(x)
156 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
157 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
158 |
159 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
160 | branch_pool = self.branch_pool(branch_pool)
161 |
162 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
163 | return torch.cat(outputs, 1)
164 |
165 |
166 | class InceptionB(nn.Module):
167 |
168 | def __init__(self, in_channels):
169 | super(InceptionB, self).__init__()
170 | self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2)
171 |
172 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
173 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
174 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2)
175 |
176 | def forward(self, x):
177 | branch3x3 = self.branch3x3(x)
178 |
179 | branch3x3dbl = self.branch3x3dbl_1(x)
180 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
181 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
182 |
183 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
184 |
185 | outputs = [branch3x3, branch3x3dbl, branch_pool]
186 | return torch.cat(outputs, 1)
187 |
188 |
189 | class InceptionC(nn.Module):
190 |
191 | def __init__(self, in_channels, channels_7x7):
192 | super(InceptionC, self).__init__()
193 | self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1)
194 |
195 | c7 = channels_7x7
196 | self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1)
197 | self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
198 | self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0))
199 |
200 | self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1)
201 | self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
202 | self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
203 | self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
204 | self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))
205 |
206 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)
207 |
208 | def forward(self, x):
209 | branch1x1 = self.branch1x1(x)
210 |
211 | branch7x7 = self.branch7x7_1(x)
212 | branch7x7 = self.branch7x7_2(branch7x7)
213 | branch7x7 = self.branch7x7_3(branch7x7)
214 |
215 | branch7x7dbl = self.branch7x7dbl_1(x)
216 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
217 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
218 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
219 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
220 |
221 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
222 | branch_pool = self.branch_pool(branch_pool)
223 |
224 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
225 | return torch.cat(outputs, 1)
226 |
227 |
228 | class InceptionD(nn.Module):
229 |
230 | def __init__(self, in_channels):
231 | super(InceptionD, self).__init__()
232 | self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
233 | self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2)
234 |
235 | self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
236 | self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3))
237 | self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0))
238 | self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2)
239 |
240 | def forward(self, x):
241 | branch3x3 = self.branch3x3_1(x)
242 | branch3x3 = self.branch3x3_2(branch3x3)
243 |
244 | branch7x7x3 = self.branch7x7x3_1(x)
245 | branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
246 | branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
247 | branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
248 |
249 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
250 | outputs = [branch3x3, branch7x7x3, branch_pool]
251 | return torch.cat(outputs, 1)
252 |
253 |
254 | class InceptionE(nn.Module):
255 |
256 | def __init__(self, in_channels):
257 | super(InceptionE, self).__init__()
258 | self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1)
259 |
260 | self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1)
261 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
262 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
263 |
264 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1)
265 | self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1)
266 | self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
267 | self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
268 |
269 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)
270 |
271 | def forward(self, x):
272 | branch1x1 = self.branch1x1(x)
273 |
274 | branch3x3 = self.branch3x3_1(x)
275 | branch3x3 = [
276 | self.branch3x3_2a(branch3x3),
277 | self.branch3x3_2b(branch3x3),
278 | ]
279 | branch3x3 = torch.cat(branch3x3, 1)
280 |
281 | branch3x3dbl = self.branch3x3dbl_1(x)
282 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
283 | branch3x3dbl = [
284 | self.branch3x3dbl_3a(branch3x3dbl),
285 | self.branch3x3dbl_3b(branch3x3dbl),
286 | ]
287 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
288 |
289 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
290 | branch_pool = self.branch_pool(branch_pool)
291 |
292 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
293 | return torch.cat(outputs, 1)
294 |
295 |
296 | class InceptionAux(nn.Module):
297 |
298 | def __init__(self, in_channels, num_classes):
299 | super(InceptionAux, self).__init__()
300 | self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1)
301 | self.conv1 = BasicConv2d(128, 768, kernel_size=5)
302 | self.conv1.stddev = 0.01
303 | self.fc_new = nn.Linear(768, num_classes)
304 | self.fc_new.stddev = 0.001
305 |
306 | def forward(self, x):
307 | # 17 x 17 x 768
308 | x = F.avg_pool2d(x, kernel_size=5, stride=3)
309 | # 5 x 5 x 768
310 | x = self.conv0(x)
311 | # 5 x 5 x 128
312 | x = self.conv1(x)
313 | # 1 x 1 x 768
314 | x = x.view(x.size(0), -1)
315 | # 768
316 | x = self.fc_new(x)
317 | # num_classes
318 | return x
319 |
320 |
321 | class BasicConv2d(nn.Module):
322 |
323 | def __init__(self, in_channels, out_channels, **kwargs):
324 | super(BasicConv2d, self).__init__()
325 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
326 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
327 |
328 | def forward(self, x):
329 | x = self.conv(x)
330 | x = self.bn(x)
331 | return F.relu(x, inplace=True)
332 |
333 |
334 |
--------------------------------------------------------------------------------
/code/model.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.parallel
5 | from miscc.config import cfg
6 | from torch.autograd import Variable
7 | import torch.nn.functional as F
8 | from torch.nn import Upsample
9 |
10 |
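# Gated Linear Unit: splits the channels in half and returns
# first_half * sigmoid(second_half); this is why the conv layers preceding a GLU
# produce out_planes * 2 channels.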
11 | class GLU(nn.Module):
12 | def __init__(self):
13 | super(GLU, self).__init__()
14 |
15 | def forward(self, x):
16 | nc = x.size(1)
17 | assert nc % 2 == 0, 'channels dont divide 2!'
18 | nc = int(nc/2)
19 | return x[:, :nc] * F.sigmoid(x[:, nc:])
20 |
21 |
22 | def conv3x3(in_planes, out_planes):
23 | "3x3 convolution with padding"
24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1,
25 | padding=1, bias=False)
26 |
27 |
28 | def convlxl(in_planes, out_planes):
29 | "3x3 convolution with padding"
30 | return nn.Conv2d(in_planes, out_planes, kernel_size=13, stride=1,
31 | padding=1, bias=False)
32 |
33 |
34 | def child_to_parent(child_c_code, classes_child, classes_parent):
35 |
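# Maps a one-hot child code to its parent code (used when TIED_CODES is enabled):
# with Nc child and Np parent categories, consecutive groups of Nc / Np child
# classes share one parent class.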
36 | ratio = classes_child / classes_parent
37 | arg_parent = torch.argmax(child_c_code, dim = 1) / ratio
38 | parent_c_code = torch.zeros([child_c_code.size(0), classes_parent]).cuda()
39 | for i in range(child_c_code.size(0)):
40 | parent_c_code[i][arg_parent[i]] = 1
41 | return parent_c_code
42 |
43 |
44 | # ############## G networks ################################################
45 | # Upscale the spatial size by a factor of 2
46 | def upBlock(in_planes, out_planes):
47 | block = nn.Sequential(
48 | nn.Upsample(scale_factor=2, mode='nearest'),
49 | conv3x3(in_planes, out_planes * 2),
50 | nn.BatchNorm2d(out_planes * 2),
51 | GLU()
52 | )
53 | return block
54 |
55 | def sameBlock(in_planes, out_planes):
56 | block = nn.Sequential(
57 | conv3x3(in_planes, out_planes * 2),
58 | nn.BatchNorm2d(out_planes * 2),
59 | GLU()
60 | )
61 | return block
62 |
63 | # Keep the spatial size
64 | def Block3x3_relu(in_planes, out_planes):
65 | block = nn.Sequential(
66 | conv3x3(in_planes, out_planes * 2),
67 | nn.BatchNorm2d(out_planes * 2),
68 | GLU()
69 | )
70 | return block
71 |
72 |
73 | class ResBlock(nn.Module):
74 | def __init__(self, channel_num):
75 | super(ResBlock, self).__init__()
76 | self.block = nn.Sequential(
77 | conv3x3(channel_num, channel_num * 2),
78 | nn.BatchNorm2d(channel_num * 2),
79 | GLU(),
80 | conv3x3(channel_num, channel_num),
81 | nn.BatchNorm2d(channel_num)
82 | )
83 |
84 |
85 | def forward(self, x):
86 | residual = x
87 | out = self.block(x)
88 | out += residual
89 | return out
90 |
91 |
92 | class INIT_STAGE_G(nn.Module):
93 | def __init__(self, ngf, c_flag):
94 | super(INIT_STAGE_G, self).__init__()
95 | self.gf_dim = ngf
96 | self.c_flag= c_flag
97 |
98 | if self.c_flag==1 :
99 | self.in_dim = cfg.GAN.Z_DIM + cfg.SUPER_CATEGORIES
100 | elif self.c_flag==2:
101 | self.in_dim = cfg.GAN.Z_DIM + cfg.FINE_GRAINED_CATEGORIES
102 |
103 | self.define_module()
104 |
105 | def define_module(self):
106 | in_dim = self.in_dim
107 | ngf = self.gf_dim
108 | self.fc = nn.Sequential(
109 | nn.Linear(in_dim, ngf * 4 * 4 * 2, bias=False),
110 | nn.BatchNorm1d(ngf * 4 * 4 * 2),
111 | GLU())
112 |
113 | self.upsample1 = upBlock(ngf, ngf // 2)
114 | self.upsample2 = upBlock(ngf // 2, ngf // 4)
115 | self.upsample3 = upBlock(ngf // 4, ngf // 8)
116 | self.upsample4 = upBlock(ngf // 8, ngf // 16)
117 | self.upsample5 = upBlock(ngf // 16, ngf // 16)
118 |
119 |
120 | def forward(self, z_code, code):
121 |
122 | in_code = torch.cat((code, z_code), 1)
123 | out_code = self.fc(in_code)
124 | out_code = out_code.view(-1, self.gf_dim, 4, 4)
125 | out_code = self.upsample1(out_code)
126 | out_code = self.upsample2(out_code)
127 | out_code = self.upsample3(out_code)
128 | out_code = self.upsample4(out_code)
129 | out_code = self.upsample5(out_code)
130 |
131 | return out_code
132 |
133 |
134 | class NEXT_STAGE_G(nn.Module):
135 | def __init__(self, ngf, use_hrc = 1, num_residual=cfg.GAN.R_NUM):
136 | super(NEXT_STAGE_G, self).__init__()
137 | self.gf_dim = ngf
138 | if use_hrc == 1: # For parent stage
139 | self.ef_dim = cfg.SUPER_CATEGORIES
140 |
141 | else: # For child stage
142 | self.ef_dim = cfg.FINE_GRAINED_CATEGORIES
143 |
144 | self.num_residual = num_residual
145 | self.define_module()
146 |
147 | def _make_layer(self, block, channel_num):
148 | layers = []
149 | for i in range(self.num_residual):
150 | layers.append(block(channel_num))
151 | return nn.Sequential(*layers)
152 |
153 | def define_module(self):
154 | ngf = self.gf_dim
155 | efg = self.ef_dim
156 | self.jointConv = Block3x3_relu(ngf + efg, ngf)
157 | self.residual = self._make_layer(ResBlock, ngf)
158 | self.samesample = sameBlock(ngf, ngf // 2)
159 |
160 | def forward(self, h_code, code):
161 | s_size = h_code.size(2)
162 | code = code.view(-1, self.ef_dim, 1, 1)
163 | code = code.repeat(1, 1, s_size, s_size)
164 | h_c_code = torch.cat((code, h_code), 1)
165 | out_code = self.jointConv(h_c_code)
166 | out_code = self.residual(out_code)
167 | out_code = self.samesample(out_code)
168 | return out_code
169 |
170 |
171 | class GET_IMAGE_G(nn.Module):
172 | def __init__(self, ngf):
173 | super(GET_IMAGE_G, self).__init__()
174 | self.gf_dim = ngf
175 | self.img = nn.Sequential(
176 | conv3x3(ngf, 3),
177 | nn.Tanh()
178 | )
179 |
180 | def forward(self, h_code):
181 | out_img = self.img(h_code)
182 | return out_img
183 |
184 |
185 |
186 | class GET_MASK_G(nn.Module):
187 | def __init__(self, ngf):
188 | super(GET_MASK_G, self).__init__()
189 | self.gf_dim = ngf
190 | self.img = nn.Sequential(
191 | conv3x3(ngf, 1),
192 | nn.Sigmoid()
193 | )
194 |
195 | def forward(self, h_code):
196 | out_img = self.img(h_code)
197 | return out_img
198 |
199 |
200 | class G_NET(nn.Module):
201 | def __init__(self):
202 | super(G_NET, self).__init__()
203 | self.gf_dim = cfg.GAN.GF_DIM
204 | self.define_module()
205 | self.upsampling = Upsample(scale_factor = 2, mode = 'bilinear')
206 | self.scale_fimg = nn.UpsamplingBilinear2d(size = [126, 126])
207 |
208 | def define_module(self):
209 |
210 | #Background stage
211 | self.h_net1_bg = INIT_STAGE_G(self.gf_dim * 16, 2)
212 | self.img_net1_bg = GET_IMAGE_G(self.gf_dim) # Background generation network
213 |
214 | # Parent stage networks
215 | self.h_net1 = INIT_STAGE_G(self.gf_dim * 16, 1)
216 | self.h_net2 = NEXT_STAGE_G(self.gf_dim, use_hrc = 1)
217 | self.img_net2 = GET_IMAGE_G(self.gf_dim // 2) # Parent foreground generation network
218 | self.img_net2_mask= GET_MASK_G(self.gf_dim // 2) # Parent mask generation network
219 |
220 | # Child stage networks
221 | self.h_net3 = NEXT_STAGE_G(self.gf_dim // 2, use_hrc = 0)
222 | self.img_net3 = GET_IMAGE_G(self.gf_dim // 4) # Child foreground generation network
223 | self.img_net3_mask = GET_MASK_G(self.gf_dim // 4) # Child mask generation network
224 |
225 | def forward(self, z_code, c_code, p_code = None, bg_code = None):
226 |
227 | fake_imgs = [] # Will contain [background image, parent image, child image]
228 | fg_imgs = [] # Will contain [parent foreground, child foreground]
229 | mk_imgs = [] # Will contain [parent mask, child mask]
230 | fg_mk = [] # Will contain [masked parent foreground, masked child foreground]
231 |
232 | if cfg.TIED_CODES:
233 | p_code = child_to_parent(c_code, cfg.FINE_GRAINED_CATEGORIES, cfg.SUPER_CATEGORIES) # Obtaining the parent code from child code
234 | bg_code = c_code
235 |
236 | #Background stage
237 | h_code1_bg = self.h_net1_bg(z_code, bg_code)
238 | fake_img1 = self.img_net1_bg(h_code1_bg) # Background image
239 | fake_img1_126 = self.scale_fimg(fake_img1) # Resizing fake background image from 128x128 to the resolution which background discriminator expects: 126 x 126.
240 | fake_imgs.append(fake_img1_126)
241 |
242 | #Parent stage
243 | h_code1 = self.h_net1(z_code, p_code)
244 | h_code2 = self.h_net2(h_code1, p_code)
245 | fake_img2_foreground = self.img_net2(h_code2) # Parent foreground
246 | fake_img2_mask = self.img_net2_mask(h_code2) # Parent mask
247 | ones_mask_p = torch.ones_like(fake_img2_mask)
248 | opp_mask_p = ones_mask_p - fake_img2_mask
249 | fg_masked2 = torch.mul(fake_img2_foreground, fake_img2_mask)
250 | fg_mk.append(fg_masked2)
251 | bg_masked2 = torch.mul(fake_img1, opp_mask_p)
252 | fake_img2_final = fg_masked2 + bg_masked2 # Parent image
253 | fake_imgs.append(fake_img2_final)
254 | fg_imgs.append(fake_img2_foreground)
255 | mk_imgs.append(fake_img2_mask)
256 |
257 | #Child stage
258 | h_code3 = self.h_net3(h_code2, c_code)
259 | fake_img3_foreground = self.img_net3(h_code3) # Child foreground
260 | fake_img3_mask = self.img_net3_mask(h_code3) # Child mask
261 | ones_mask_c = torch.ones_like(fake_img3_mask)
262 | opp_mask_c = ones_mask_c - fake_img3_mask
263 | fg_masked3 = torch.mul(fake_img3_foreground, fake_img3_mask)
264 | fg_mk.append(fg_masked3)
265 | bg_masked3 = torch.mul(fake_img2_final, opp_mask_c)
266 | fake_img3_final = fg_masked3 + bg_masked3 # Child image
267 | fake_imgs.append(fake_img3_final)
268 | fg_imgs.append(fake_img3_foreground)
269 | mk_imgs.append(fake_img3_mask)
270 |
271 | return fake_imgs, fg_imgs, mk_imgs, fg_mk
272 |
273 |
274 | # ############## D networks ################################################
275 | def Block3x3_leakRelu(in_planes, out_planes):
276 | block = nn.Sequential(
277 | conv3x3(in_planes, out_planes),
278 | nn.BatchNorm2d(out_planes),
279 | nn.LeakyReLU(0.2, inplace=True)
280 | )
281 | return block
282 |
283 |
284 | # Downscale the spatial size by a factor of 2
285 | def downBlock(in_planes, out_planes):
286 | block = nn.Sequential(
287 | nn.Conv2d(in_planes, out_planes, 4, 2, 1, bias=False),
288 | nn.BatchNorm2d(out_planes),
289 | nn.LeakyReLU(0.2, inplace=True)
290 | )
291 | return block
292 |
293 |
294 |
295 | def encode_parent_and_child_img(ndf): # Defines the encoder network used for parent and child image
296 | encode_img = nn.Sequential(
297 | nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
298 | nn.LeakyReLU(0.2, inplace=True),
299 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
300 | nn.BatchNorm2d(ndf * 2),
301 | nn.LeakyReLU(0.2, inplace=True),
302 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
303 | nn.BatchNorm2d(ndf * 4),
304 | nn.LeakyReLU(0.2, inplace=True),
305 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
306 | nn.BatchNorm2d(ndf * 8),
307 | nn.LeakyReLU(0.2, inplace=True)
308 | )
309 | return encode_img
310 |
311 |
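# PatchGAN-style encoder for the 126x126 background image: each spatial unit of the
# discriminator output scores one local patch, which is what lets trainer.py weight
# out the patches that fall on the object bounding box.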
312 | def encode_background_img(ndf): # Defines the encoder network used for background image
313 | encode_img = nn.Sequential(
314 | nn.Conv2d(3, ndf, 4, 2, 0, bias=False),
315 | nn.LeakyReLU(0.2, inplace=True),
316 | nn.Conv2d(ndf, ndf * 2, 4, 2, 0, bias=False),
317 | nn.LeakyReLU(0.2, inplace=True),
318 | nn.Conv2d(ndf * 2, ndf * 4, 4, 1, 0, bias=False),
319 | nn.LeakyReLU(0.2, inplace=True),
320 | )
321 | return encode_img
322 |
323 |
324 | class D_NET(nn.Module):
325 | def __init__(self, stg_no):
326 | super(D_NET, self).__init__()
327 | self.df_dim = cfg.GAN.DF_DIM
328 | self.stg_no = stg_no
329 |
330 | if self.stg_no == 0:
331 | self.ef_dim = 1
332 | elif self.stg_no == 1:
333 | self.ef_dim = cfg.SUPER_CATEGORIES
334 | elif self.stg_no == 2:
335 | self.ef_dim = cfg.FINE_GRAINED_CATEGORIES
336 | else:
337 | print ("Invalid stage number. Set stage number as follows:")
338 | print ("0 - for background stage")
339 | print ("1 - for parent stage")
340 | print ("2 - for child stage")
341 | print ("...Exiting now")
342 | sys.exit(0)
343 | self.define_module()
344 |
345 | def define_module(self):
346 | ndf = self.df_dim
347 | efg = self.ef_dim
348 |
349 | if self.stg_no == 0:
350 |
351 | self.patchgan_img_code_s16 = encode_background_img(ndf)
352 | self.uncond_logits1 = nn.Sequential(
353 | nn.Conv2d(ndf * 4, 1, kernel_size=4, stride=1),
354 | nn.Sigmoid())
355 | self.uncond_logits2 = nn.Sequential(
356 | nn.Conv2d(ndf * 4, 1, kernel_size=4, stride=1),
357 | nn.Sigmoid())
358 |
359 | else:
360 | self.img_code_s16 = encode_parent_and_child_img(ndf)
361 | self.img_code_s32 = downBlock(ndf * 8, ndf * 16)
362 | self.img_code_s32_1 = Block3x3_leakRelu(ndf * 16, ndf * 8)
363 |
364 | self.logits = nn.Sequential(
365 | nn.Conv2d(ndf * 8, efg, kernel_size=4, stride=4))
366 |
367 | self.jointConv = Block3x3_leakRelu(ndf * 8, ndf * 8)
368 | self.uncond_logits = nn.Sequential(
369 | nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4),
370 | nn.Sigmoid())
371 |
372 |
373 | def forward(self, x_var):
374 |
375 | if self.stg_no == 0:
376 | x_code = self.patchgan_img_code_s16(x_var)
377 | classi_score = self.uncond_logits1(x_code) # Background vs Foreground classification score (0 - background and 1 - foreground)
378 | rf_score = self.uncond_logits2(x_code) # Real/Fake score for the background image
379 | return [classi_score, rf_score]
380 |
381 | elif self.stg_no > 0:
382 | x_code = self.img_code_s16(x_var)
383 | x_code = self.img_code_s32(x_code)
384 | x_code = self.img_code_s32_1(x_code)
385 | h_c_code = self.jointConv(x_code)
386 | code_pred = self.logits(h_c_code) # Predicts the parent code and child code in parent and child stage respectively
387 | rf_score = self.uncond_logits(x_code) # This score is not used in parent stage while training
388 | return [code_pred.view(-1, self.ef_dim), rf_score.view(-1)]
389 |
390 |
391 |
392 |
--------------------------------------------------------------------------------
/code/trainer.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from six.moves import range
3 | import sys
4 | import numpy as np
5 | import os
6 | import random
7 | import time
8 | from PIL import Image
9 | from copy import deepcopy
10 |
11 | import torch.backends.cudnn as cudnn
12 | import torch
13 | import torch.nn as nn
14 | from torch.autograd import Variable
15 | import torch.optim as optim
16 | import torchvision.utils as vutils
17 | from torch.nn.functional import softmax, log_softmax
18 | from torch.nn.functional import cosine_similarity
19 | from tensorboardX import summary
20 | from tensorboardX import FileWriter
21 |
22 | from miscc.config import cfg
23 | from miscc.utils import mkdir_p
24 |
25 | from model import G_NET, D_NET
26 |
27 |
28 | # ################## Shared functions ###################
29 |
30 | def child_to_parent(child_c_code, classes_child, classes_parent):
31 |
32 | ratio = classes_child / classes_parent
33 | arg_parent = torch.argmax(child_c_code, dim = 1) / ratio
34 | parent_c_code = torch.zeros([child_c_code.size(0), classes_parent]).cuda()
35 | for i in range(child_c_code.size(0)):
36 | parent_c_code[i][arg_parent[i]] = 1
37 | return parent_c_code
38 |
39 |
40 | def weights_init(m):
41 | classname = m.__class__.__name__
42 | if classname.find('Conv') != -1:
43 | nn.init.orthogonal(m.weight.data, 1.0)
44 | elif classname.find('BatchNorm') != -1:
45 | m.weight.data.normal_(1.0, 0.02)
46 | m.bias.data.fill_(0)
47 | elif classname.find('Linear') != -1:
48 | nn.init.orthogonal(m.weight.data, 1.0)
49 | if m.bias is not None:
50 | m.bias.data.fill_(0.0)
51 |
52 |
53 | def load_params(model, new_param):
54 | for p, new_p in zip(model.parameters(), new_param):
55 | p.data.copy_(new_p)
56 |
57 |
58 | def copy_G_params(model):
59 | flatten = deepcopy(list(p.data for p in model.parameters()))
60 | return flatten
61 |
62 | def load_network(gpus):
63 | netG = G_NET()
64 | netG.apply(weights_init)
65 | netG = torch.nn.DataParallel(netG, device_ids=gpus)
66 | print(netG)
67 |
68 | netsD = []
69 | for i in range(3): # 3 discriminators for background, parent and child stage
70 | netsD.append(D_NET(i))
71 |
72 | for i in range(len(netsD)):
73 | netsD[i].apply(weights_init)
74 | netsD[i] = torch.nn.DataParallel(netsD[i], device_ids=gpus)
75 |
76 | count = 0
77 |
78 | if cfg.TRAIN.NET_G != '':
79 | state_dict = torch.load(cfg.TRAIN.NET_G)
80 | netG.load_state_dict(state_dict)
81 | print('Load ', cfg.TRAIN.NET_G)
82 |
83 | istart = cfg.TRAIN.NET_G.rfind('_') + 1
84 | iend = cfg.TRAIN.NET_G.rfind('.')
85 | count = cfg.TRAIN.NET_G[istart:iend]
86 | count = int(count) + 1
87 |
88 | if cfg.TRAIN.NET_D != '':
89 | for i in range(len(netsD)):
90 | print('Load %s_%d.pth' % (cfg.TRAIN.NET_D, i))
91 | state_dict = torch.load('%s_%d.pth' % (cfg.TRAIN.NET_D, i))
92 | netsD[i].load_state_dict(state_dict)
93 |
94 | if cfg.CUDA:
95 | netG.cuda()
96 | for i in range(len(netsD)):
97 | netsD[i].cuda()
98 |
99 | return netG, netsD, len(netsD), count
100 |
101 |
102 | def define_optimizers(netG, netsD):
103 | optimizersD = []
104 | num_Ds = len(netsD)
105 | for i in range(num_Ds):
106 | opt = optim.Adam(netsD[i].parameters(),
107 | lr=cfg.TRAIN.DISCRIMINATOR_LR,
108 | betas=(0.5, 0.999))
109 | optimizersD.append(opt)
110 |
111 | optimizerG = []
112 | optimizerG.append(optim.Adam(netG.parameters(),
113 | lr=cfg.TRAIN.GENERATOR_LR,
114 | betas=(0.5, 0.999)))
115 |
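# The parent-stage discriminator (i == 1) and the code-prediction head of the
# child-stage discriminator (i == 2) are updated together with the generator, so
# their parameters get additional optimizers appended to the optimizerG list.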
116 | for i in range(num_Ds):
117 | if i==1:
118 | opt = optim.Adam(netsD[i].parameters(),
119 | lr=cfg.TRAIN.GENERATOR_LR,
120 | betas=(0.5, 0.999))
121 | optimizerG.append(opt)
122 | elif i==2:
123 | opt = optim.Adam([{'params':netsD[i].module.jointConv.parameters()},{'params':netsD[i].module.logits.parameters()}],
124 | lr=cfg.TRAIN.GENERATOR_LR,
125 | betas=(0.5, 0.999))
126 | optimizerG.append(opt)
127 |
128 | return optimizerG, optimizersD
129 |
130 |
131 | def save_model(netG, avg_param_G, netsD, epoch, model_dir):
132 | load_params(netG, avg_param_G)
133 | torch.save(
134 | netG.state_dict(),
135 | '%s/netG_%d.pth' % (model_dir, epoch))
136 | for i in range(len(netsD)):
137 | netD = netsD[i]
138 | torch.save(
139 | netD.state_dict(),
140 | '%s/netD%d.pth' % (model_dir, i))
141 | print('Save G/Ds models.')
142 |
143 |
144 | def save_img_results(imgs_tcpu, fake_imgs, num_imgs,
145 | count, image_dir, summary_writer):
146 | num = cfg.TRAIN.VIS_COUNT
147 |
148 | real_img = imgs_tcpu[-1][0:num]
149 | vutils.save_image(
150 | real_img, '%s/real_samples%09d.png' % (image_dir,count),
151 | normalize=True)
152 | real_img_set = vutils.make_grid(real_img).numpy()
153 | real_img_set = np.transpose(real_img_set, (1, 2, 0))
154 | real_img_set = real_img_set * 255
155 | real_img_set = real_img_set.astype(np.uint8)
156 |
157 | for i in range(len(fake_imgs)):
158 | fake_img = fake_imgs[i][0:num]
159 |
160 | vutils.save_image(
161 | fake_img.data, '%s/count_%09d_fake_samples%d.png' %
162 | (image_dir, count, i), normalize=True)
163 |
164 | fake_img_set = vutils.make_grid(fake_img.data).cpu().numpy()
165 |
166 | fake_img_set = np.transpose(fake_img_set, (1, 2, 0))
167 | fake_img_set = (fake_img_set + 1) * 255 / 2
168 | fake_img_set = fake_img_set.astype(np.uint8)
169 | summary_writer.flush()
170 |
171 |
172 |
173 | class FineGAN_trainer(object):
174 | def __init__(self, output_dir, data_loader, imsize):
175 | if cfg.TRAIN.FLAG:
176 | self.model_dir = os.path.join(output_dir, 'Model')
177 | self.image_dir = os.path.join(output_dir, 'Image')
178 | self.log_dir = os.path.join(output_dir, 'Log')
179 | mkdir_p(self.model_dir)
180 | mkdir_p(self.image_dir)
181 | mkdir_p(self.log_dir)
182 | self.summary_writer = FileWriter(self.log_dir)
183 |
184 | s_gpus = cfg.GPU_ID.split(',')
185 | self.gpus = [int(ix) for ix in s_gpus]
186 | self.num_gpus = len(self.gpus)
187 | torch.cuda.set_device(self.gpus[0])
188 | cudnn.benchmark = True
189 |
190 | self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
191 | self.max_epoch = cfg.TRAIN.MAX_EPOCH
192 | self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL
193 |
194 | self.data_loader = data_loader
195 | self.num_batches = len(self.data_loader)
196 |
197 |
198 |
199 | def prepare_data(self, data):
200 | fimgs, cimgs, c_code, _, warped_bbox = data
201 |
202 | real_vfimgs, real_vcimgs = [], []
203 | if cfg.CUDA:
204 | vc_code = Variable(c_code).cuda()
205 | for i in range(len(warped_bbox)):
206 | warped_bbox[i] = Variable(warped_bbox[i]).float().cuda()
207 |
208 | else:
209 | vc_code = Variable(c_code)
210 | for i in range(len(warped_bbox)):
211 | warped_bbox[i] = Variable(warped_bbox[i])
212 |
213 | if cfg.CUDA:
214 | real_vfimgs.append(Variable(fimgs[0]).cuda())
215 | real_vcimgs.append(Variable(cimgs[0]).cuda())
216 | else:
217 | real_vfimgs.append(Variable(fimgs[0]))
218 | real_vcimgs.append(Variable(cimgs[0]))
219 |
220 | return fimgs, real_vfimgs, real_vcimgs, vc_code, warped_bbox
221 |
222 | def train_Dnet(self, idx, count):
223 |         if idx == 0 or idx == 2: # The discriminator is trained only at the background and child stages (not at the parent stage)
224 | flag = count % 100
225 | batch_size = self.real_fimgs[0].size(0)
226 | criterion, criterion_one = self.criterion, self.criterion_one
227 |
228 | netD, optD = self.netsD[idx], self.optimizersD[idx]
229 | if idx == 0:
230 | real_imgs = self.real_fimgs[0]
231 |
232 | elif idx == 2:
233 | real_imgs = self.real_cimgs[0]
234 |
235 | fake_imgs = self.fake_imgs[idx]
236 | netD.zero_grad()
237 | real_logits = netD(real_imgs)
238 |
239 | if idx == 2:
240 | fake_labels = torch.zeros_like(real_logits[1])
241 | real_labels = torch.ones_like(real_logits[1])
242 | elif idx == 0:
243 |
244 | fake_labels = torch.zeros_like(real_logits[1])
245 | ext, output = real_logits
246 | weights_real = torch.ones_like(output)
247 | real_labels = torch.ones_like(output)
248 |
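                # For the background stage, real_logits is a pair of per-patch maps of size
                # self.n_out x self.n_out: (bg/fg classification map, real/fake map). The loop
                # below maps the warped bounding box from image coordinates to the output indices
                # [a1:a2, b1:b2] whose receptive fields (size self.recp_field, stride
                # self.patch_stride) overlap the box, and zeroes their weight so the
                # 'real background' loss ignores patches that fall on the object.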
249 | for i in range(batch_size):
250 | x1 = self.warped_bbox[0][i]
251 | x2 = self.warped_bbox[2][i]
252 | y1 = self.warped_bbox[1][i]
253 | y2 = self.warped_bbox[3][i]
254 |
255 | a1 = max(torch.tensor(0).float().cuda(), torch.ceil((x1 - self.recp_field)/self.patch_stride))
256 | a2 = min(torch.tensor(self.n_out - 1).float().cuda(), torch.floor((self.n_out - 1) - ((126 - self.recp_field) - x2)/self.patch_stride)) + 1
257 | b1 = max(torch.tensor(0).float().cuda(), torch.ceil((y1 - self.recp_field)/self.patch_stride))
258 | b2 = min(torch.tensor(self.n_out - 1).float().cuda(), torch.floor((self.n_out - 1) - ((126 - self.recp_field) - y2)/self.patch_stride)) + 1
259 |
260 | if (x1 != x2 and y1 != y2):
261 | weights_real[i, :, a1.type(torch.int) : a2.type(torch.int) , b1.type(torch.int) : b2.type(torch.int)] = 0.0
262 |
263 | norm_fact_real = weights_real.sum()
264 | norm_fact_fake = weights_real.shape[0]*weights_real.shape[1]*weights_real.shape[2]*weights_real.shape[3]
265 | real_logits = ext, output
266 |
267 | fake_logits = netD(fake_imgs.detach())
268 |
269 |
270 |
271 | if idx == 0: # Background stage
272 |
273 | errD_real_uncond = criterion(real_logits[1], real_labels) # Real/Fake loss for 'real background' (on patch level)
274 |                 errD_real_uncond = torch.mul(errD_real_uncond, weights_real) # Masking output units whose receptive fields lie within the bounding box
275 | errD_real_uncond = errD_real_uncond.mean()
276 |
277 | errD_real_uncond_classi = criterion(real_logits[0], weights_real) # Background/foreground classification loss
278 | errD_real_uncond_classi = errD_real_uncond_classi.mean()
279 |
280 | errD_fake_uncond = criterion(fake_logits[1], fake_labels) # Real/Fake loss for 'fake background' (on patch level)
281 | errD_fake_uncond = errD_fake_uncond.mean()
282 |
283 |                 if (norm_fact_real > 0): # Normalizing the real/fake loss for the background after accounting for the number of masked units in the output.
284 | errD_real = errD_real_uncond * ((norm_fact_fake * 1.0) /(norm_fact_real * 1.0))
285 | else:
286 | errD_real = errD_real_uncond
287 |
288 | errD_fake = errD_fake_uncond
289 | errD = ((errD_real + errD_fake) * cfg.TRAIN.BG_LOSS_WT) + errD_real_uncond_classi
290 |
291 | if idx == 2:
292 |
293 | errD_real = criterion_one(real_logits[1], real_labels) # Real/Fake loss for the real image
294 | errD_fake = criterion_one(fake_logits[1], fake_labels) # Real/Fake loss for the fake image
295 | errD = errD_real + errD_fake
296 |
297 | if (idx == 0 or idx == 2):
298 | errD.backward()
299 | optD.step()
300 |
301 | if (flag == 0):
302 | summary_D = summary.scalar('D_loss%d' % idx, errD.data[0])
303 | self.summary_writer.add_summary(summary_D, count)
304 | summary_D_real = summary.scalar('D_loss_real_%d' % idx, errD_real.data[0])
305 | self.summary_writer.add_summary(summary_D_real, count)
306 | summary_D_fake = summary.scalar('D_loss_fake_%d' % idx, errD_fake.data[0])
307 | self.summary_writer.add_summary(summary_D_fake, count)
308 |
309 | return errD
310 |
311 | def train_Gnet(self, count):
312 | self.netG.zero_grad()
313 | for myit in range(len(self.netsD)):
314 | self.netsD[myit].zero_grad()
315 |
316 | errG_total = 0
317 | flag = count % 100
318 | batch_size = self.real_fimgs[0].size(0)
319 | criterion_one, criterion_class, c_code, p_code = self.criterion_one, self.criterion_class, self.c_code, self.p_code
320 |
321 | for i in range(self.num_Ds):
322 |
323 | outputs = self.netsD[i](self.fake_imgs[i])
324 |
325 | if i == 0 or i == 2: # real/fake loss for background (0) and child (2) stage
326 | real_labels = torch.ones_like(outputs[1])
327 | errG = criterion_one(outputs[1], real_labels)
328 | if i==0:
329 | errG = errG * cfg.TRAIN.BG_LOSS_WT
330 | errG_classi = criterion_one(outputs[0], real_labels) # Background/Foreground classification loss for the fake background image (on patch level)
331 | errG = errG + errG_classi
332 | errG_total = errG_total + errG
333 |
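            # Mutual-information (code prediction) losses: p_code / c_code are one-hot, so
            # torch.nonzero(code)[:, 1] recovers the integer class index per sample, and the
            # cross-entropy between the discriminator's code prediction on the corresponding
            # masked foreground (self.fg_mk[i-1]) and that index serves as the InfoGAN-style
            # information term.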
334 | if i == 1: # Mutual information loss for the parent stage (1)
335 | pred_p = self.netsD[i](self.fg_mk[i-1])
336 | errG_info = criterion_class(pred_p[0], torch.nonzero(p_code.long())[:,1])
337 | elif i == 2: # Mutual information loss for the child stage (2)
338 | pred_c = self.netsD[i](self.fg_mk[i-1])
339 | errG_info = criterion_class(pred_c[0], torch.nonzero(c_code.long())[:,1])
340 |
341 | if(i>0):
342 | errG_total = errG_total + errG_info
343 |
344 | if flag == 0:
345 | if i>0:
346 | summary_D_class = summary.scalar('Information_loss_%d' % i, errG_info.data[0])
347 | self.summary_writer.add_summary(summary_D_class, count)
348 |
349 | if i == 0 or i == 2:
350 | summary_D = summary.scalar('G_loss%d' % i, errG.data[0])
351 | self.summary_writer.add_summary(summary_D, count)
352 |
353 | errG_total.backward()
354 | for myit in range(len(self.netsD)):
355 | self.optimizerG[myit].step()
356 | return errG_total
357 |
358 | def train(self):
359 | self.netG, self.netsD, self.num_Ds, start_count = load_network(self.gpus)
360 | avg_param_G = copy_G_params(self.netG)
361 |
362 | self.optimizerG, self.optimizersD = \
363 | define_optimizers(self.netG, self.netsD)
364 |
365 | self.criterion = nn.BCELoss(reduce=False)
366 | self.criterion_one = nn.BCELoss()
367 | self.criterion_class = nn.CrossEntropyLoss()
368 |
369 | self.real_labels = \
370 | Variable(torch.FloatTensor(self.batch_size).fill_(1))
371 | self.fake_labels = \
372 | Variable(torch.FloatTensor(self.batch_size).fill_(0))
373 |
374 | nz = cfg.GAN.Z_DIM
375 | noise = Variable(torch.FloatTensor(self.batch_size, nz))
376 | fixed_noise = \
377 | Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1))
378 | hard_noise = \
379 | Variable(torch.FloatTensor(self.batch_size, nz).normal_(0, 1)).cuda()
380 |
381 | self.patch_stride = float(4) # Receptive field stride given the current discriminator architecture for background stage
382 | self.n_out = 24 # Output size of the discriminator at the background stage; N X N where N = 24
383 |         self.recp_field = 34 # Receptive field of each member of the N x N output
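        # With these values, output unit k of the 24 x 24 background-discriminator grid sees
        # input pixels roughly [4k, 4k + 34) of the 126-pixel input (unit 0 covers [0, 34),
        # unit 23 covers [92, 126)). For example, with x1 = 49, a1 = max(0, ceil((49 - 34)/4)) = 4,
        # and unit 4 (pixels [16, 50)) is indeed the first unit whose receptive field reaches the
        # box. These constants must be recomputed if the background-stage discriminator
        # architecture in model.py changes.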
384 |
385 |
386 | if cfg.CUDA:
387 | self.criterion.cuda()
388 | self.criterion_one.cuda()
389 | self.criterion_class.cuda()
390 | self.real_labels = self.real_labels.cuda()
391 | self.fake_labels = self.fake_labels.cuda()
392 | noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
393 |
394 |         print("Starting normal FineGAN training...")
395 | count = start_count
396 | start_epoch = start_count // (self.num_batches)
397 |
398 | for epoch in range(start_epoch, self.max_epoch):
399 | start_t = time.time()
400 |
401 | for step, data in enumerate(self.data_loader, 0):
402 |
403 | self.imgs_tcpu, self.real_fimgs, self.real_cimgs, \
404 | self.c_code, self.warped_bbox = self.prepare_data(data)
405 |
406 | # Feedforward through Generator. Obtain stagewise fake images
407 | noise.data.normal_(0, 1)
408 | self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = \
409 | self.netG(noise, self.c_code)
410 |
411 | # Obtain the parent code given the child code
412 | self.p_code = child_to_parent(self.c_code, cfg.FINE_GRAINED_CATEGORIES, cfg.SUPER_CATEGORIES)
413 |
414 | # Update Discriminator networks
415 | errD_total = 0
416 | for i in range(self.num_Ds):
417 |                     if i == 0 or i == 2: # only at background and child stage
418 | errD = self.train_Dnet(i, count)
419 | errD_total += errD
420 |
421 | # Update the Generator networks
422 | errG_total = self.train_Gnet(count)
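                # Maintain an exponential moving average of the generator weights (decay 0.999);
                # the averaged weights are swapped in via load_params() when saving snapshots and
                # sample images below, after which the live weights are restored from backup_para.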
423 | for p, avg_p in zip(self.netG.parameters(), avg_param_G):
424 | avg_p.mul_(0.999).add_(0.001, p.data)
425 |
426 | count = count + 1
427 |
428 | if count % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:
429 | backup_para = copy_G_params(self.netG)
430 | save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
431 | # Save images
432 | load_params(self.netG, avg_param_G)
433 | self.netG.eval()
434 | with torch.set_grad_enabled(False):
435 | self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = \
436 | self.netG(fixed_noise, self.c_code)
437 | save_img_results(self.imgs_tcpu, (self.fake_imgs + self.fg_imgs + self.mk_imgs + self.fg_mk), self.num_Ds,
438 | count, self.image_dir, self.summary_writer)
439 | self.netG.train()
440 | load_params(self.netG, backup_para)
441 |
442 | end_t = time.time()
443 | print('''[%d/%d][%d]
444 | Loss_D: %.2f Loss_G: %.2f Time: %.2fs
445 | '''
446 | % (epoch, self.max_epoch, self.num_batches,
447 | errD_total.data[0], errG_total.data[0],
448 | end_t - start_t))
449 |
450 | save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
451 |
452 |         print("Done with normal training. Now performing hard negative training...")
453 | count = 0
454 | start_t = time.time()
455 | for step, data in enumerate(self.data_loader, 0):
456 |
457 | self.imgs_tcpu, self.real_fimgs, self.real_cimgs, \
458 | self.c_code, self.warped_bbox = self.prepare_data(data)
459 |
460 | if (count % 2) == 0: # Train on normal batch of images
461 |
462 | # Feedforward through Generator. Obtain stagewise fake images
463 | noise.data.normal_(0, 1)
464 | self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = \
465 | self.netG(noise, self.c_code)
466 |
467 | self.p_code = child_to_parent(self.c_code, cfg.FINE_GRAINED_CATEGORIES, cfg.SUPER_CATEGORIES)
468 |
469 | # Update discriminator networks
470 | errD_total = 0
471 | for i in range(self.num_Ds):
472 | if i == 0 or i == 2:
473 | errD = self.train_Dnet(i, count)
474 | errD_total += errD
475 |
476 |
477 | # Update the generator network
478 | errG_total = self.train_Gnet(count)
479 |
480 | else: # Train on degenerate images
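                # Hard negative mining: draw repeat_times batches of random noise and random
                # one-hot child codes, render them with the generator, and score each sample by
                # the child discriminator's softmax probability for its own code (computed on the
                # masked child foreground, self.fg_mk[1]). The batch_size samples with the lowest
                # probability, i.e. the codes the generator currently renders least convincingly,
                # are kept and used for this iteration's discriminator and generator updates.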
481 | repeat_times=10
482 | all_hard_z = Variable(torch.zeros(self.batch_size * repeat_times, nz)).cuda()
483 | all_hard_class = Variable(torch.zeros(self.batch_size * repeat_times, cfg.FINE_GRAINED_CATEGORIES)).cuda()
484 | all_logits = Variable(torch.zeros(self.batch_size * repeat_times,)).cuda()
485 |
486 | for hard_it in range(repeat_times):
487 | hard_noise = hard_noise.data.normal_(0,1)
488 | hard_class = Variable(torch.zeros([self.batch_size, cfg.FINE_GRAINED_CATEGORIES])).cuda()
489 | my_rand_id=[]
490 |
491 | for c_it in range(self.batch_size):
492 |                         rand_class = random.sample(range(cfg.FINE_GRAINED_CATEGORIES), 1)
493 | hard_class[c_it][rand_class] = 1
494 | my_rand_id.append(rand_class)
495 |
496 | all_hard_z[self.batch_size * hard_it : self.batch_size * (hard_it + 1)] = hard_noise.data
497 | all_hard_class[self.batch_size * hard_it : self.batch_size * (hard_it + 1)] = hard_class.data
498 | self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = self.netG(hard_noise.detach(), hard_class.detach())
499 |
500 | fake_logits = self.netsD[2](self.fg_mk[1].detach())
501 | smax_class = softmax(fake_logits[0], dim = 1)
502 |
503 | for b_it in range(self.batch_size):
504 | all_logits[(self.batch_size * hard_it) + b_it] = smax_class[b_it][my_rand_id[b_it]]
505 |
506 | sorted_val, indices_hard = torch.sort(all_logits)
507 | noise = all_hard_z[indices_hard[0 : self.batch_size]]
508 | self.c_code = all_hard_class[indices_hard[0 : self.batch_size]]
509 |
510 | self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = \
511 | self.netG(noise, self.c_code)
512 |
513 | self.p_code = child_to_parent(self.c_code, cfg.FINE_GRAINED_CATEGORIES, cfg.SUPER_CATEGORIES)
514 |
515 | # Update Discriminator networks
516 | errD_total = 0
517 | for i in range(self.num_Ds):
518 | if i == 0 or i == 2:
519 | errD = self.train_Dnet(i, count)
520 | errD_total += errD
521 |
522 | # Update generator network
523 | errG_total = self.train_Gnet(count)
524 |
525 | for p, avg_p in zip(self.netG.parameters(), avg_param_G):
526 | avg_p.mul_(0.999).add_(0.001, p.data)
527 | count = count + 1
528 |
529 | if count % cfg.TRAIN.SNAPSHOT_INTERVAL_HARDNEG == 0:
530 | backup_para = copy_G_params(self.netG)
531 | save_model(self.netG, avg_param_G, self.netsD, count+500000, self.model_dir)
532 | load_params(self.netG, avg_param_G)
533 | self.netG.eval()
534 | with torch.set_grad_enabled(False):
535 | self.fake_imgs, self.fg_imgs, self.mk_imgs, self.fg_mk = \
536 | self.netG(fixed_noise, self.c_code)
537 | save_img_results(self.imgs_tcpu, (self.fake_imgs + self.fg_imgs + self.mk_imgs + self.fg_mk), self.num_Ds,
538 | count, self.image_dir, self.summary_writer)
539 | self.netG.train()
540 | load_params(self.netG, backup_para)
541 |
542 | end_t = time.time()
543 |
544 | if (count % 100) == 0:
545 | print('''[%d/%d][%d]
546 | Loss_D: %.2f Loss_G: %.2f Time: %.2fs
547 | '''
548 | % (count, cfg.TRAIN.HARDNEG_MAX_ITER, self.num_batches,
549 | errD_total.data[0], errG_total.data[0],
550 | end_t - start_t))
551 |
552 | if (count == cfg.TRAIN.HARDNEG_MAX_ITER): # Hard negative training complete
553 | break
554 |
555 | save_model(self.netG, avg_param_G, self.netsD, count, self.model_dir)
556 | self.summary_writer.close()
557 |
558 |
559 |
560 | class FineGAN_evaluator(object):
561 |
562 | def __init__(self):
563 |
564 | self.save_dir = os.path.join(cfg.SAVE_DIR, 'images')
565 | mkdir_p(self.save_dir)
566 | s_gpus = cfg.GPU_ID.split(',')
567 | self.gpus = [int(ix) for ix in s_gpus]
568 | self.num_gpus = len(self.gpus)
569 | torch.cuda.set_device(self.gpus[0])
570 | cudnn.benchmark = True
571 | self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
572 |
573 |
574 | def evaluate_finegan(self):
575 | if cfg.TRAIN.NET_G == '':
576 |             print('Error: the path to the generator model was not found!')
577 | else:
578 | # Build and load the generator
579 | netG = G_NET()
580 | netG.apply(weights_init)
581 | netG = torch.nn.DataParallel(netG, device_ids=self.gpus)
582 | model_dict = netG.state_dict()
583 |
584 | state_dict = \
585 | torch.load(cfg.TRAIN.NET_G,
586 | map_location=lambda storage, loc: storage)
587 |
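            # Partial loading: keep only checkpoint entries whose keys exist in the current
            # (DataParallel-wrapped) model, merge them into model_dict and load the result, so a
            # checkpoint with extra or missing keys does not raise an error.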
588 | state_dict = {k: v for k, v in state_dict.items() if k in model_dict}
589 |
590 | model_dict.update(state_dict)
591 | netG.load_state_dict(model_dict)
592 |             print('Loaded ', cfg.TRAIN.NET_G)
593 |
594 | # Uncomment this to print Generator layers
595 | # print(netG)
596 |
597 | nz = cfg.GAN.Z_DIM
598 | noise = torch.FloatTensor(self.batch_size, nz)
599 | noise.data.normal_(0, 1)
600 |
601 | if cfg.CUDA:
602 | netG.cuda()
603 | noise = noise.cuda()
604 |
605 | netG.eval()
606 |
607 | background_class = cfg.TEST_BACKGROUND_CLASS
608 | parent_class = cfg.TEST_PARENT_CLASS
609 | child_class = cfg.TEST_CHILD_CLASS
610 | bg_code = torch.zeros([self.batch_size, cfg.FINE_GRAINED_CATEGORIES])
611 | p_code = torch.zeros([self.batch_size, cfg.SUPER_CATEGORIES])
612 | c_code = torch.zeros([self.batch_size, cfg.FINE_GRAINED_CATEGORIES])
613 |
614 | for j in range(self.batch_size):
615 | bg_code[j][background_class] = 1
616 | p_code[j][parent_class] = 1
617 | c_code[j][child_class] = 1
618 |
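            # At evaluation time the generator takes explicit one-hot background, parent and
            # child codes; in FineGAN_trainer above it is called with only noise and the child
            # code, with the remaining codes presumably handled inside the generator (see
            # model.py).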
619 | fake_imgs, fg_imgs, mk_imgs, fgmk_imgs = netG(noise, c_code, p_code, bg_code) # Forward pass through the generator
620 |
621 | self.save_image(fake_imgs[0][0], self.save_dir, 'background')
622 | self.save_image(fake_imgs[1][0], self.save_dir, 'parent_final')
623 | self.save_image(fake_imgs[2][0], self.save_dir, 'child_final')
624 | self.save_image(fg_imgs[0][0], self.save_dir, 'parent_foreground')
625 | self.save_image(fg_imgs[1][0], self.save_dir, 'child_foreground')
626 | self.save_image(mk_imgs[0][0], self.save_dir, 'parent_mask')
627 | self.save_image(mk_imgs[1][0], self.save_dir, 'child_mask')
628 | self.save_image(fgmk_imgs[0][0], self.save_dir, 'parent_foreground_masked')
629 | self.save_image(fgmk_imgs[1][0], self.save_dir, 'child_foreground_masked')
630 |
631 |
632 | def save_image(self, images, save_dir, iname):
633 |
634 | img_name = '%s.png' % (iname)
635 | full_path = os.path.join(save_dir, img_name)
636 |
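        # Names containing 'mask' but not 'foreground' are saved as single-channel masks in
        # [0, 1], tiled to 3 channels; all other outputs are treated as 3-channel images in
        # [-1, 1].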
637 | if (iname.find('mask') == -1) or (iname.find('foreground') != -1):
638 | img = images.add(1).div(2).mul(255).clamp(0, 255).byte()
639 | ndarr = img.permute(1, 2, 0).data.cpu().numpy()
640 | im = Image.fromarray(ndarr)
641 | im.save(full_path)
642 |
643 | else:
644 | img = images.mul(255).clamp(0, 255).byte()
645 | ndarr = img.data.cpu().numpy()
646 | ndarr = np.reshape(ndarr, (ndarr.shape[-1], ndarr.shape[-1], 1))
647 | ndarr = np.repeat(ndarr, 3, axis=2)
648 | im = Image.fromarray(ndarr)
649 | im.save(full_path)
650 |
651 |
652 |
--------------------------------------------------------------------------------