├── .gitignore ├── LICENSE ├── OP-GAN ├── code │ ├── GlobalAttention.py │ ├── cfg │ │ ├── cfg_file_eval.yml │ │ └── cfg_file_train.yml │ ├── datasets.py │ ├── main.py │ ├── miscc │ │ ├── __init__.py │ │ ├── config.py │ │ ├── losses.py │ │ └── utils.py │ ├── model.py │ └── trainer.py ├── data │ └── .gitignore ├── environment.yml ├── models │ └── .gitignore ├── output │ └── .gitignore ├── sample.sh └── train.sh ├── README.md └── SOA ├── README.md ├── calculate_soa.py ├── captions ├── label_00_person.pkl ├── label_01_bicycle.pkl ├── label_02_car.pkl ├── label_03_motorcycle.pkl ├── label_04_plane.pkl ├── label_05_bus.pkl ├── label_06_train.pkl ├── label_07_truck.pkl ├── label_08_boat.pkl ├── label_09_trafficlight.pkl ├── label_10_hydrant.pkl ├── label_11_stopsign.pkl ├── label_12_parkingmeter.pkl ├── label_13_bench.pkl ├── label_14_bird.pkl ├── label_15_cat.pkl ├── label_16_dog.pkl ├── label_17_horse.pkl ├── label_18_sheep.pkl ├── label_19_cow.pkl ├── label_20_elephant.pkl ├── label_21_bear.pkl ├── label_22_zebra.pkl ├── label_23_giraffe.pkl ├── label_24_backpack.pkl ├── label_25_umbrella.pkl ├── label_26_handbag.pkl ├── label_27_tie.pkl ├── label_28_suitcase.pkl ├── label_29_frisbee.pkl ├── label_30_skis.pkl ├── label_31_snowboard.pkl ├── label_32_ball.pkl ├── label_33_kite.pkl ├── label_34_baseballbat.pkl ├── label_35_baseballglove.pkl ├── label_36_skateboard.pkl ├── label_37_surfboard.pkl ├── label_38_racket.pkl ├── label_39_bottle.pkl ├── label_40_wineglass.pkl ├── label_41_cup.pkl ├── label_42_fork.pkl ├── label_43_knife.pkl ├── label_44_spoon.pkl ├── label_45_bowl.pkl ├── label_46_banana.pkl ├── label_47_apple.pkl ├── label_48_sandwich.pkl ├── label_49_oranges.pkl ├── label_50_broccoli.pkl ├── label_51_carrot.pkl ├── label_52_hotdog.pkl ├── label_53_pizza.pkl ├── label_54_donut.pkl ├── label_55_cake.pkl ├── label_56_chair.pkl ├── label_57_sofa.pkl ├── label_58_pottedplant.pkl ├── label_59_bed.pkl ├── label_60_table.pkl ├── label_61_toilet.pkl ├── label_62_monitor.pkl ├── label_63_laptop.pkl ├── label_64_computermouse.pkl ├── label_65_remote.pkl ├── label_66_keyboard.pkl ├── label_67_cellphone.pkl ├── label_68_microwave.pkl ├── label_69_oven.pkl ├── label_70_toaster.pkl ├── label_71_sink.pkl ├── label_72_refrigerator.pkl ├── label_73_book.pkl ├── label_74_clock.pkl ├── label_75_vase.pkl ├── label_76_scissor.pkl ├── label_77_teddybear.pkl ├── label_78_hairdrier.pkl └── label_79_toothbrush.pkl ├── cfg └── yolov3.cfg ├── darknet.py ├── data └── coco.names ├── dataset.py ├── requirements.txt └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | SOA/output/* 2 | SOA/yolov3.weights 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | 109 | .idea/ 110 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Tobias Hinz 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /OP-GAN/code/GlobalAttention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Global attention takes a matrix and a query matrix. 3 | For each query vector q, it computes a parameterized convex combination of the matrix 4 | based on that query. 5 | H_1 H_2 H_3 ... H_n 6 | q q q q 7 | | | | | 8 | \ | | / 9 | ..... 10 | \ | / 11 | a 12 | Constructs a unit mapping. 13 | $$(H_1 + H_n, q) => (a)$$ 14 | Where H is of `batch x n x dim` and q is of `batch x dim`.
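In short (a rough summary of func_attention below): similarity scores are computed as context^T * query, normalized with two softmax steps (the second one scaled by gamma1 over the source locations), and the context is then re-weighted by these attention weights, giving one weighted context vector per query position.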
15 | 16 | References: 17 | https://github.com/OpenNMT/OpenNMT-py/tree/fc23dfef1ba2f258858b2765d24565266526dc76/onmt/modules 18 | http://www.aclweb.org/anthology/D15-1166 19 | """ 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | 25 | def conv1x1(in_planes, out_planes): 26 | "1x1 convolution with padding" 27 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, 28 | padding=0, bias=False) 29 | 30 | 31 | def func_attention(query, context, gamma1): 32 | """ 33 | query: batch x ndf x queryL 34 | context: batch x ndf x ih x iw (sourceL=ihxiw) 35 | mask: batch_size x sourceL 36 | """ 37 | batch_size, queryL = query.size(0), query.size(2) 38 | ih, iw = context.size(2), context.size(3) 39 | sourceL = ih * iw 40 | 41 | # --> batch x sourceL x ndf 42 | context = context.view(batch_size, -1, sourceL) 43 | contextT = torch.transpose(context, 1, 2).contiguous() 44 | 45 | # Get attention 46 | # (batch x sourceL x ndf)(batch x ndf x queryL) 47 | # -->batch x sourceL x queryL 48 | attn = torch.bmm(contextT, query) # Eq. (7) in AttnGAN paper 49 | # --> batch*sourceL x queryL 50 | attn = attn.view(batch_size*sourceL, queryL) 51 | attn = nn.Softmax(dim=1)(attn) # Eq. (8) 52 | 53 | # --> batch x sourceL x queryL 54 | attn = attn.view(batch_size, sourceL, queryL) 55 | # --> batch*queryL x sourceL 56 | attn = torch.transpose(attn, 1, 2).contiguous() 57 | attn = attn.view(batch_size*queryL, sourceL) 58 | # Eq. (9) 59 | attn = attn * gamma1 60 | attn = nn.Softmax(dim=1)(attn) 61 | attn = attn.view(batch_size, queryL, sourceL) 62 | # --> batch x sourceL x queryL 63 | attnT = torch.transpose(attn, 1, 2).contiguous() 64 | 65 | # (batch x ndf x sourceL)(batch x sourceL x queryL) 66 | # --> batch x ndf x queryL 67 | weightedContext = torch.bmm(context, attnT) 68 | 69 | return weightedContext, attn.view(batch_size, -1, ih, iw) 70 | 71 | 72 | class GlobalAttentionGeneral(nn.Module): 73 | def __init__(self, idf, cdf): 74 | super(GlobalAttentionGeneral, self).__init__() 75 | self.conv_context = conv1x1(cdf, idf) 76 | self.sm = nn.Softmax(dim=1) 77 | self.mask = None 78 | 79 | def applyMask(self, mask): 80 | self.mask = mask # batch x sourceL 81 | 82 | def forward(self, input, context): 83 | """ 84 | input: batch x idf x ih x iw (queryL=ihxiw) 85 | context: batch x cdf x sourceL 86 | """ 87 | ih, iw = input.size(2), input.size(3) 88 | queryL = ih * iw 89 | batch_size, sourceL = context.size(0), context.size(2) 90 | 91 | # --> batch x queryL x idf 92 | target = input.view(batch_size, -1, queryL) 93 | targetT = torch.transpose(target, 1, 2).contiguous() 94 | # batch x cdf x sourceL --> batch x cdf x sourceL x 1 95 | sourceT = context.unsqueeze(3) 96 | # --> batch x idf x sourceL 97 | sourceT = self.conv_context(sourceT).squeeze(3) 98 | 99 | # Get attention 100 | # (batch x queryL x idf)(batch x idf x sourceL) 101 | # -->batch x queryL x sourceL 102 | attn = torch.bmm(targetT, sourceT) 103 | # --> batch*queryL x sourceL 104 | attn = attn.view(batch_size*queryL, sourceL) 105 | if self.mask is not None: 106 | # batch_size x sourceL --> batch_size*queryL x sourceL 107 | mask = self.mask.repeat(queryL, 1) 108 | attn.data.masked_fill_(mask.data, -float('inf')) 109 | # print(attn.shape) 110 | # exit() 111 | attn = self.sm(attn) # Eq. 
(2) 112 | # --> batch x queryL x sourceL 113 | attn = attn.view(batch_size, queryL, sourceL) 114 | # --> batch x sourceL x queryL 115 | attn = torch.transpose(attn, 1, 2).contiguous() 116 | 117 | # (batch x idf x sourceL)(batch x sourceL x queryL) 118 | # --> batch x idf x queryL 119 | weightedContext = torch.bmm(sourceT, attn) 120 | weightedContext = weightedContext.view(batch_size, -1, ih, iw) 121 | attn = attn.view(batch_size, -1, ih, iw) 122 | 123 | return weightedContext, attn 124 | -------------------------------------------------------------------------------- /OP-GAN/code/cfg/cfg_file_eval.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'attn2' 2 | 3 | DATASET_NAME: 'coco' 4 | DATA_DIR: 'data' 5 | WORKERS: 1 6 | 7 | TREE: 8 | BRANCH_NUM: 3 9 | 10 | TRAIN: 11 | FLAG: False 12 | NET_G: 'models/op-gan.pth' 13 | B_NET_D: False 14 | BATCH_SIZE: [50] 15 | NET_E: 'models/coco/text_encoder100.pth' 16 | OPTIMIZE_DATA_LOADING: False 17 | GENERATED_BBOXES: True 18 | 19 | GAN: 20 | DISC_FEAT_DIM: 96 21 | GEN_FEAT_DIM: 48 22 | GLOBAL_Z_DIM: 100 23 | TEXT_CONDITION_DIM: 100 24 | RESIDUAL_NUM: 3 25 | 26 | TEXT: 27 | EMBEDDING_DIM: 256 28 | CAPTIONS_PER_IMAGE: 5 29 | WORDS_NUM: 20 30 | 31 | -------------------------------------------------------------------------------- /OP-GAN/code/cfg/cfg_file_train.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'glu-gan2' 2 | 3 | DATASET_NAME: 'coco' 4 | DATA_DIR: 'data' 5 | WORKERS: 8 6 | DEBUG: False 7 | 8 | TREE: 9 | BRANCH_NUM: 3 10 | 11 | TRAIN: 12 | FLAG: True 13 | NET_G: "" # 14 | B_NET_D: True 15 | BATCH_SIZE: [32, 28, 24, 24, 20, 16, 16, 16, 12, 12, 12] # batch sizes for the different batches containing [0, 1, ...] 
objects per image; [24] if OPTIMIZE_DATA_LOADING is False 16 | MAX_EPOCH: 120 17 | DISCRIMINATOR_LR: 0.0002 18 | GENERATOR_LR: 0.0002 19 | NET_E: 'models/coco/text_encoder100.pth' 20 | BBOX_LOSS: True 21 | OPTIMIZE_DATA_LOADING: True 22 | EMPTY_CACHE: True 23 | SMOOTH: 24 | GAMMA1: 4.0 # 1,2,5 good 4 best 10&100bad 25 | GAMMA2: 5.0 26 | GAMMA3: 10.0 # 10good 1&100bad 27 | LAMBDA: 50.0 28 | 29 | GAN: 30 | DISC_FEAT_DIM: 96 31 | GEN_FEAT_DIM: 48 32 | GLOBAL_Z_DIM: 100 33 | TEXT_CONDITION_DIM: 100 34 | RESIDUAL_NUM: 3 35 | 36 | TEXT: 37 | EMBEDDING_DIM: 256 38 | CAPTIONS_PER_IMAGE: 5 39 | WORDS_NUM: 12 40 | -------------------------------------------------------------------------------- /OP-GAN/code/datasets.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | 7 | from nltk.tokenize import RegexpTokenizer 8 | from collections import defaultdict 9 | 10 | import torch.utils.data as data 11 | from torch.autograd import Variable 12 | import torchvision.transforms as transforms 13 | 14 | import numpy.random as random 15 | import pickle 16 | 17 | from miscc.utils import * 18 | 19 | logger = logging.getLogger() 20 | 21 | 22 | def prepare_data(data, eval=False): 23 | if eval: 24 | imgs, captions, captions_lens, class_ids, keys, transformation_matrices, label, bbox = data 25 | else: 26 | imgs, captions, captions_lens, class_ids, keys, transformation_matrices, label = data 27 | 28 | # sort data by the length in a decreasing order 29 | sorted_cap_lens, sorted_cap_indices = torch.sort(captions_lens, 0, True) 30 | 31 | real_imgs = [] 32 | for i in range(len(imgs)): 33 | imgs[i] = imgs[i][sorted_cap_indices] 34 | real_imgs.append(Variable(imgs[i]).to(cfg.DEVICE).detach()) 35 | 36 | captions = captions[sorted_cap_indices].squeeze() 37 | class_ids = class_ids[sorted_cap_indices].numpy() 38 | transformation_matrices[0] = transformation_matrices[0][sorted_cap_indices] 39 | transformation_matrices[1] = transformation_matrices[1][sorted_cap_indices] 40 | label = label[sorted_cap_indices] 41 | keys = [keys[i] for i in sorted_cap_indices.numpy()] 42 | if cfg.CUDA: 43 | captions = Variable(captions).cuda().detach() 44 | sorted_cap_lens = Variable(sorted_cap_lens).cuda().detach() 45 | transformation_matrices[0] = transformation_matrices[0].cuda().detach() 46 | transformation_matrices[1] = transformation_matrices[1].cuda().detach() 47 | label = label.cuda().detach() 48 | else: 49 | captions = Variable(captions).detach() 50 | sorted_cap_lens = Variable(sorted_cap_lens).detach() 51 | 52 | if eval: 53 | return [real_imgs, captions, sorted_cap_lens, class_ids, keys, transformation_matrices, label, bbox] 54 | else: 55 | return [real_imgs, captions, sorted_cap_lens, class_ids, keys, transformation_matrices, label] 56 | 57 | 58 | def get_imgs(img_path, imsize, max_objects, bbox=None, transform=None, normalize=None): 59 | img = Image.open(img_path).convert('RGB') 60 | if transform is not None: 61 | img = transform(img) 62 | 63 | img, bbox_scaled = crop_imgs(img, bbox, max_objects=max_objects) 64 | 65 | ret = [] 66 | if cfg.GAN.B_DCGAN: 67 | ret = [normalize(img)] 68 | else: 69 | for i in range(cfg.TREE.BRANCH_NUM): 70 | re_img = transforms.ToPILImage()(img) 71 | re_img = transforms.Resize((imsize[i], imsize[i]))(re_img) 72 | ret.append(normalize(re_img)) 73 | 74 | return ret, bbox_scaled 75 | 76 | 77 | def crop_imgs(image, 
bbox, max_objects): 78 | ori_size = 268 79 | imsize = 256 80 | 81 | flip_img = random.random() < 0.5 82 | img_crop = ori_size - imsize 83 | h1 = int(np.floor((img_crop) * np.random.random())) 84 | w1 = int(np.floor((img_crop) * np.random.random())) 85 | 86 | bbox_scaled = np.zeros_like(bbox) 87 | bbox_scaled[...] = -1.0 88 | 89 | for idx in range(max_objects): 90 | bbox_tmp = bbox[idx] 91 | if bbox_tmp[0] == -1: 92 | break 93 | 94 | x_new = max(bbox_tmp[0] * float(ori_size) - h1, 0) / float(imsize) 95 | y_new = max(bbox_tmp[1] * float(ori_size) - w1, 0) / float(imsize) 96 | 97 | width_new = min((float(ori_size)/imsize) * bbox_tmp[2], 1.0) 98 | if x_new + width_new > 0.999: 99 | width_new = 1.0 - x_new - 0.001 100 | 101 | height_new = min((float(ori_size)/imsize) * bbox_tmp[3], 1.0) 102 | if y_new + height_new > 0.999: 103 | height_new = 1.0 - y_new - 0.001 104 | 105 | if flip_img: 106 | x_new = 1.0-x_new-width_new 107 | 108 | bbox_scaled[idx] = [x_new, y_new, width_new, height_new] 109 | 110 | cropped_image = image[:, w1: w1 + imsize, h1: h1 + imsize] 111 | 112 | if flip_img: 113 | idx = [i for i in reversed(range(cropped_image.shape[2]))] 114 | idx = torch.LongTensor(idx) 115 | transformed_image = torch.index_select(cropped_image, 2, idx) 116 | else: 117 | transformed_image = cropped_image 118 | 119 | return transformed_image, bbox_scaled 120 | 121 | 122 | class TextDataset(data.Dataset): 123 | def __init__(self, data_dir, img_dir, split='train', base_size=64, 124 | transform=None, target_transform=None, eval=False, use_generated_bboxes=False): 125 | self.transform = transform 126 | self.norm = transforms.Compose([ 127 | transforms.ToTensor(), 128 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 129 | self.target_transform = target_transform 130 | self.embeddings_num = cfg.TEXT.CAPTIONS_PER_IMAGE 131 | self.img_dir = os.path.join(data_dir, img_dir) 132 | self.split_dir = os.path.join(data_dir, split) 133 | self.eval = eval 134 | self.use_generated_bboxes = use_generated_bboxes 135 | 136 | self.imsize = [] 137 | for i in range(cfg.TREE.BRANCH_NUM): 138 | self.imsize.append(base_size) 139 | base_size = base_size * 2 140 | 141 | self.data = [] 142 | self.data_dir = data_dir 143 | self.bbox = self.load_bbox() 144 | self.labels = self.load_labels() 145 | self.split_dir = os.path.join(data_dir, split) 146 | self.max_objects = 3 147 | if cfg.TRAIN.OPTIMIZE_DATA_LOADING or self.use_generated_bboxes: 148 | self.max_objects = 10 149 | 150 | self.filenames, self.captions, self.ixtoword, \ 151 | self.wordtoix, self.n_words = self.load_text_data(data_dir, split) 152 | 153 | self.class_id = self.load_class_id(self.split_dir, len(self.filenames)) 154 | self.number_example = len(self.filenames) 155 | 156 | def load_bbox(self): 157 | bbox_path = os.path.join(self.split_dir, 'bboxes.pickle') 158 | if self.use_generated_bboxes: 159 | bbox_path = os.path.join(self.split_dir, 'bboxes_generated.pickle') 160 | elif cfg.TRAIN.OPTIMIZE_DATA_LOADING: 161 | bbox_path = os.path.join(self.split_dir, 'bboxes_large.pickle') 162 | with open(bbox_path, "rb") as f: 163 | bboxes = pickle.load(f, encoding='latin1') 164 | bboxes = np.array(bboxes) 165 | logger.info("Load bounding boxes: %s", bboxes.shape) 166 | return bboxes 167 | 168 | def load_labels(self): 169 | label_path = os.path.join(self.split_dir, 'labels.pickle') 170 | if self.use_generated_bboxes: 171 | label_path = os.path.join(self.split_dir, 'labels_generated.pickle') 172 | elif cfg.TRAIN.OPTIMIZE_DATA_LOADING: 173 | label_path = 
os.path.join(self.split_dir, 'labels_large.pickle') 174 | with open(label_path, "rb") as f: 175 | labels = pickle.load(f, encoding='latin1') 176 | labels = np.array(labels) 177 | logger.info("Load Labels: %s", labels.shape) 178 | return labels 179 | 180 | def load_captions(self, data_dir, filenames): 181 | all_captions = [] 182 | for i in range(len(filenames)): 183 | cap_path = '%s/text/%s.txt' % (data_dir, filenames[i]) 184 | with open(cap_path, "r") as f: 185 | captions = f.read().decode('utf8').split('\n') 186 | cnt = 0 187 | for cap in captions: 188 | if len(cap) == 0: 189 | continue 190 | cap = cap.replace("\ufffd\ufffd", " ") 191 | # picks out sequences of alphanumeric characters as tokens 192 | # and drops everything else 193 | tokenizer = RegexpTokenizer(r'\w+') 194 | tokens = tokenizer.tokenize(cap.lower()) 195 | # logger.info('tokens: %s', tokens) 196 | if len(tokens) == 0: 197 | logger.info('cap: %s', cap) 198 | continue 199 | 200 | tokens_new = [] 201 | for t in tokens: 202 | t = t.encode('ascii', 'ignore').decode('ascii') 203 | if len(t) > 0: 204 | tokens_new.append(t) 205 | all_captions.append(tokens_new) 206 | cnt += 1 207 | if cnt == self.embeddings_num: 208 | break 209 | if cnt < self.embeddings_num: 210 | logger.error('ERROR: the captions for %s less than %d' 211 | % (filenames[i], cnt)) 212 | return all_captions 213 | 214 | def build_dictionary(self, train_captions, test_captions): 215 | word_counts = defaultdict(float) 216 | captions = train_captions + test_captions 217 | for sent in captions: 218 | for word in sent: 219 | word_counts[word] += 1 220 | 221 | vocab = [w for w in word_counts if word_counts[w] >= 0] 222 | 223 | ixtoword = {} 224 | ixtoword[0] = '' 225 | wordtoix = {} 226 | wordtoix[''] = 0 227 | ix = 1 228 | for w in vocab: 229 | wordtoix[w] = ix 230 | ixtoword[ix] = w 231 | ix += 1 232 | 233 | train_captions_new = [] 234 | for t in train_captions: 235 | rev = [] 236 | for w in t: 237 | if w in wordtoix: 238 | rev.append(wordtoix[w]) 239 | # rev.append(0) # do not need '' token 240 | train_captions_new.append(rev) 241 | 242 | test_captions_new = [] 243 | for t in test_captions: 244 | rev = [] 245 | for w in t: 246 | if w in wordtoix: 247 | rev.append(wordtoix[w]) 248 | # rev.append(0) # do not need '' token 249 | test_captions_new.append(rev) 250 | 251 | return [train_captions_new, test_captions_new, 252 | ixtoword, wordtoix, len(ixtoword)] 253 | 254 | def load_text_data(self, data_dir, split): 255 | filepath = os.path.join(data_dir, 'captions.pickle') 256 | train_names = self.load_filenames(data_dir, 'train') 257 | test_names = self.load_filenames(data_dir, 'test') 258 | if not os.path.isfile(filepath): 259 | train_captions = self.load_captions(data_dir, train_names) 260 | test_captions = self.load_captions(data_dir, test_names) 261 | 262 | train_captions, test_captions, ixtoword, wordtoix, n_words = \ 263 | self.build_dictionary(train_captions, test_captions) 264 | with open(filepath, 'wb') as f: 265 | pickle.dump([train_captions, test_captions, 266 | ixtoword, wordtoix], f, protocol=2) 267 | logger.info('Save captions to: %s', filepath) 268 | else: 269 | with open(filepath, 'rb') as f: 270 | x = pickle.load(f, encoding='latin1') 271 | train_captions, test_captions = x[0], x[1] 272 | ixtoword, wordtoix = x[2], x[3] 273 | del x 274 | n_words = len(ixtoword) 275 | logger.info('Load captions from: %s', filepath) 276 | if split == 'train': 277 | # a list of list: each list contains 278 | # the indices of words in a sentence 279 | captions = train_captions 280 
| filenames = train_names 281 | else: # split=='test' 282 | captions = test_captions 283 | filenames = test_names 284 | logger.info("Captions: %s", len(captions)) 285 | return filenames, captions, ixtoword, wordtoix, n_words 286 | 287 | def load_class_id(self, data_dir, total_num): 288 | if os.path.isfile(data_dir + '/class_info.pickle'): 289 | with open(data_dir + '/class_info.pickle', 'rb') as f: 290 | class_id = pickle.load(f, encoding='latin1') 291 | else: 292 | class_id = np.arange(total_num) 293 | return class_id 294 | 295 | def load_filenames(self, data_dir, split): 296 | filepath = '%s/%s/filenames.pickle' % (data_dir, split) 297 | if os.path.isfile(filepath): 298 | with open(filepath, 'rb') as f: 299 | filenames = pickle.load(f, encoding='latin1') 300 | logger.info('Load filenames from: %s (%d)' % (filepath, len(filenames))) 301 | else: 302 | filenames = [] 303 | return filenames 304 | 305 | def get_caption(self, sent_ix): 306 | # a list of indices for a sentence 307 | sent_caption = np.asarray(self.captions[sent_ix]).astype('int64') 308 | if (sent_caption == 0).sum() > 0: 309 | logger.error('ERROR: do not need END (0) token', sent_caption) 310 | num_words = len(sent_caption) 311 | # pad with 0s (i.e., '') 312 | x = np.zeros((cfg.TEXT.WORDS_NUM, 1), dtype='int64') 313 | x_len = num_words 314 | if num_words <= cfg.TEXT.WORDS_NUM: 315 | x[:num_words, 0] = sent_caption 316 | else: 317 | ix = list(np.arange(num_words)) # 1, 2, 3,..., maxNum 318 | np.random.shuffle(ix) 319 | ix = ix[:cfg.TEXT.WORDS_NUM] 320 | ix = np.sort(ix) 321 | x[:, 0] = sent_caption[ix] 322 | x_len = cfg.TEXT.WORDS_NUM 323 | return x, x_len 324 | 325 | def get_transformation_matrices(self, bbox): 326 | bbox = torch.from_numpy(bbox) 327 | bbox = bbox.view(-1, 4) 328 | transf_matrices_inv = compute_transformation_matrix_inverse(bbox) 329 | transf_matrices_inv = transf_matrices_inv.view(self.max_objects, 2, 3) 330 | transf_matrices = compute_transformation_matrix(bbox) 331 | transf_matrices = transf_matrices.view(self.max_objects, 2, 3) 332 | 333 | return transf_matrices, transf_matrices_inv 334 | 335 | def get_one_hot_labels(self, label): 336 | labels = torch.from_numpy(label) 337 | labels = labels.long() 338 | # remove -1 to enable one-hot converting 339 | labels[labels < 0] = cfg.TEXT.CLASSES_NUM - 1 340 | label_one_hot = torch.FloatTensor(labels.shape[0], cfg.TEXT.CLASSES_NUM).fill_(0) 341 | label_one_hot = label_one_hot.scatter_(1, labels, 1).float() 342 | 343 | return label_one_hot 344 | 345 | def __getitem__(self, index): 346 | # 347 | key = self.filenames[index] 348 | cls_id = self.class_id[index] 349 | # 350 | if self.bbox is not None: 351 | if self.use_generated_bboxes: 352 | rand_num = np.random.randint(0, 5, 1) 353 | bbox = self.bbox[index, rand_num].squeeze() 354 | else: 355 | bbox = self.bbox[index] 356 | 357 | img_name = '%s/%s.jpg' % (self.img_dir, key) 358 | imgs, bbox_scaled = get_imgs(img_name, self.imsize, self.max_objects, 359 | bbox, self.transform, normalize=self.norm) 360 | 361 | transformation_matrices = self.get_transformation_matrices(bbox_scaled) 362 | 363 | # load label 364 | if self.use_generated_bboxes: 365 | label = np.expand_dims(self.labels[index, rand_num].squeeze(), 1) 366 | else: 367 | label = self.labels[index] 368 | 369 | label = self.get_one_hot_labels(label) 370 | 371 | # random select a sentence 372 | sent_ix = random.randint(0, self.embeddings_num) 373 | new_sent_ix = index * self.embeddings_num + sent_ix 374 | caps, cap_len = self.get_caption(new_sent_ix) 375 | 376 | if 
self.eval: 377 | return imgs, caps, cap_len, cls_id, key, transformation_matrices, label, bbox_scaled 378 | return imgs, caps, cap_len, cls_id, key, transformation_matrices, label 379 | 380 | def __len__(self): 381 | return len(self.filenames) 382 | -------------------------------------------------------------------------------- /OP-GAN/code/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import logging 4 | 5 | from miscc.config import cfg, cfg_from_file 6 | from datasets import TextDataset 7 | from miscc.utils import initialize_logging, mkdir_p 8 | from trainer import condGANTrainer as trainer 9 | 10 | import os 11 | import sys 12 | import time 13 | import random 14 | import pprint 15 | import datetime 16 | import dateutil.tz 17 | import argparse 18 | import numpy as np 19 | from shutil import copyfile 20 | import pickle 21 | 22 | import torch 23 | import torchvision.transforms as transforms 24 | 25 | dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.'))) 26 | sys.path.append(dir_path) 27 | 28 | logger = logging.getLogger() 29 | 30 | 31 | def parse_args(): 32 | parser = argparse.ArgumentParser(description='Train a AttnGAN network') 33 | parser.add_argument('--cfg', dest='cfg_file', help='config file', type=str) 34 | parser.add_argument('--resume', dest='resume', type=str, default='') 35 | parser.add_argument('--data_dir', dest='data_dir', type=str, default='') 36 | parser.add_argument('--net_g', dest='net_g', type=str, default='') 37 | parser.add_argument('--max_objects', type=int, default=10) 38 | args = parser.parse_args() 39 | return args 40 | 41 | 42 | def get_dataset_indices(split="train", num_max_objects=10): 43 | if cfg.TRAIN.OPTIMIZE_DATA_LOADING: 44 | label_path = os.path.join(os.path.join(cfg.DATA_DIR, split), 'labels_large.pickle') 45 | with open(label_path, "rb") as f: 46 | labels = pickle.load(f, encoding='latin1') 47 | labels = np.array(labels) 48 | dataset_indices = [] 49 | 50 | for _i in range(num_max_objects+1): 51 | dataset_indices.append([]) 52 | 53 | for index, label in enumerate(labels): 54 | for idx, l in enumerate(label): 55 | if l == -1: 56 | dataset_indices[idx].append(index) 57 | break 58 | else: 59 | dataset_indices[-1].append(index) 60 | 61 | return dataset_indices 62 | 63 | 64 | if __name__ == "__main__": 65 | args = parse_args() 66 | if args.cfg_file is not None: 67 | cfg_from_file(args.cfg_file) 68 | 69 | if cfg.SEED == -1: 70 | cfg.SEED = random.randint(1, 10000) 71 | random.seed(cfg.SEED) 72 | np.random.seed(cfg.SEED) 73 | torch.manual_seed(cfg.SEED) 74 | if cfg.CUDA: 75 | torch.cuda.manual_seed_all(cfg.SEED) 76 | 77 | if args.resume == "": 78 | resume = False 79 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 80 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 81 | output_dir = os.path.join(cfg.OUTPUT_DIR, '%s_%s_%s_%s' 82 | % (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp, cfg.SEED)) 83 | mkdir_p(output_dir) 84 | else: 85 | assert os.path.isdir(args.resume) 86 | resume = True 87 | output_dir = args.resume 88 | initialize_logging(output_dir, to_file=True) 89 | logger.info("Using output dir: %s" % output_dir) 90 | logger.info("Using seed {}".format(cfg.SEED)) 91 | 92 | if not (torch.cuda.is_available() and cfg.CUDA): 93 | cfg.CUDA = False 94 | cfg.DEVICE = torch.device('cpu') 95 | else: 96 | cfg.CUDA = True 97 | cfg.DEVICE = torch.device('cuda:0') 98 | logger.info('USING DEVICE %s' % cfg.DEVICE) 99 | 100 | if args.data_dir != '': 101 | 
cfg.DATA_DIR = args.data_dir 102 | if args.net_g != "": 103 | cfg.TRAIN.NET_G = args.net_g 104 | logger.info('Using config: ') 105 | pprint.pprint(cfg) 106 | 107 | split_dir, bshuffle = 'train', True 108 | eval = False 109 | img_dir = "train/train2014" 110 | if not cfg.TRAIN.FLAG: 111 | split_dir = 'test' 112 | img_dir = "test/val2014" 113 | eval = True 114 | 115 | # Get data loader 116 | imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM - 1)) 117 | image_transform = transforms.Compose([ 118 | transforms.Resize((268, 268)), 119 | transforms.ToTensor()]) 120 | 121 | if cfg.TRAIN.OPTIMIZE_DATA_LOADING: 122 | num_max_objects = 10 123 | dataset_indices = get_dataset_indices(num_max_objects=num_max_objects, split="train"\ 124 | if cfg.TRAIN.FLAG else "test") 125 | dataset = TextDataset(cfg.DATA_DIR, img_dir, split_dir, base_size=cfg.TREE.BASE_SIZE, 126 | transform=image_transform, eval=eval, use_generated_bboxes=cfg.TRAIN.GENERATED_BBOXES) 127 | assert dataset 128 | dataset_subsets = [] 129 | dataloaders = [] 130 | for max_objects in range(num_max_objects+1): 131 | subset = torch.utils.data.Subset(dataset, dataset_indices[max_objects]) 132 | subset_to_load = ( 133 | torch.utils.data.Subset(subset, list(range(cfg.DEBUG_NUM_DATAPOINTS // num_max_objects))) 134 | if cfg.DEBUG else subset 135 | ) 136 | dataset_subsets.append(subset) 137 | dataloader = torch.utils.data.DataLoader(subset_to_load, batch_size=cfg.TRAIN.BATCH_SIZE[max_objects], 138 | drop_last=True, shuffle=bshuffle, num_workers=int(cfg.WORKERS)) 139 | dataloaders.append(dataloader) 140 | 141 | algo = trainer(output_dir, dataloaders, dataset.n_words, dataset.ixtoword, resume) 142 | 143 | else: 144 | dataset = TextDataset(cfg.DATA_DIR, img_dir, split_dir, base_size=cfg.TREE.BASE_SIZE, 145 | transform=image_transform, eval=eval, use_generated_bboxes=cfg.TRAIN.GENERATED_BBOXES) 146 | assert dataset 147 | dataset_to_load = ( 148 | torch.utils.data.Subset(dataset, list(range(cfg.DEBUG_NUM_DATAPOINTS))) if cfg.DEBUG else dataset 149 | ) 150 | dataloader = torch.utils.data.DataLoader(dataset_to_load, batch_size=cfg.TRAIN.BATCH_SIZE[0], 151 | drop_last=True, shuffle=bshuffle, num_workers=int(cfg.WORKERS)) 152 | 153 | algo = trainer(output_dir, dataloader, dataset.n_words, dataset.ixtoword, resume) 154 | 155 | start_t = time.time() 156 | if cfg.TRAIN.FLAG: 157 | if not resume: 158 | copyfile("code/main.py", os.path.join(output_dir, "main.py")) 159 | copyfile("code/trainer.py", os.path.join(output_dir, "trainer.py")) 160 | copyfile("code/model.py", os.path.join(output_dir, "model.py")) 161 | copyfile("code/miscc/utils.py", os.path.join(output_dir, "utils.py")) 162 | copyfile("code/miscc/losses.py", os.path.join(output_dir, "losses.py")) 163 | copyfile("code/GlobalAttention.py", os.path.join(output_dir, "GlobalAttention.py")) 164 | copyfile("code/datasets.py", os.path.join(output_dir, "datasets.py")) 165 | copyfile(args.cfg_file, os.path.join(output_dir, "cfg_file_train.yml")) 166 | algo.train() 167 | end_t = time.time() 168 | logger.info('Total time for training: %s', end_t - start_t) 169 | else: 170 | '''generate images from pre-extracted embeddings''' 171 | assert not cfg.TRAIN.OPTIMIZE_DATA_LOADING, "\"cfg.TRAIN.OPTIMIZE_DATA_LOADING\" " \ 172 | "not valid for sampling since we use" \ 173 | "generated bounding boxes at test time." 
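        # Note (descriptive): at test time the model is conditioned on generated bounding boxes
        # (cfg.TRAIN.GENERATED_BBOXES, set to True in cfg_file_eval.yml), which is why the
        # per-object-count data-loading optimization used for training cannot be used here.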
174 | use_generated_bboxes = cfg.TRAIN.GENERATED_BBOXES 175 | algo.sampling(split_dir, num_samples=500) # generate images 176 | -------------------------------------------------------------------------------- /OP-GAN/code/miscc/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | -------------------------------------------------------------------------------- /OP-GAN/code/miscc/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import logging 5 | 6 | import numpy as np 7 | from easydict import EasyDict as edict 8 | 9 | logger = logging.getLogger() 10 | 11 | __C = edict() 12 | cfg = __C 13 | 14 | __C.DATASET_NAME = 'coco' 15 | __C.CONFIG_NAME = '' 16 | __C.DATA_DIR = 'data' 17 | __C.OUTPUT_DIR = 'output' 18 | __C.CUDA = True 19 | __C.WORKERS = 6 20 | __C.SEED = -1 21 | __C.DEBUG = False 22 | __C.DEBUG_NUM_DATAPOINTS = 100 23 | 24 | __C.RNN_TYPE = 'LSTM' # 'GRU' 25 | 26 | __C.TREE = edict() 27 | __C.TREE.BRANCH_NUM = 3 28 | __C.TREE.BASE_SIZE = 64 29 | 30 | # Training options 31 | __C.TRAIN = edict() 32 | __C.TRAIN.BATCH_SIZE = [24] 33 | __C.TRAIN.MAX_EPOCH = 120 34 | __C.TRAIN.DISCRIMINATOR_LR = 2e-4 35 | __C.TRAIN.GENERATOR_LR = 2e-4 36 | __C.TRAIN.ENCODER_LR = 2e-4 37 | __C.TRAIN.RNN_GRAD_CLIP = 0.25 38 | __C.TRAIN.FLAG = True 39 | __C.TRAIN.NET_E = '' 40 | __C.TRAIN.NET_G = '' 41 | __C.TRAIN.BBOX_LOSS = True 42 | __C.TRAIN.B_NET_D = True 43 | __C.TRAIN.OPTIMIZE_DATA_LOADING = True 44 | __C.TRAIN.EMPTY_CACHE = False 45 | __C.TRAIN.GENERATED_BBOXES = False 46 | 47 | __C.TRAIN.SMOOTH = edict() 48 | __C.TRAIN.SMOOTH.GAMMA1 = 5.0 49 | __C.TRAIN.SMOOTH.GAMMA3 = 10.0 50 | __C.TRAIN.SMOOTH.GAMMA2 = 5.0 51 | __C.TRAIN.SMOOTH.LAMBDA = 1.0 52 | 53 | # Modal options 54 | __C.GAN = edict() 55 | __C.GAN.DISC_FEAT_DIM = 96 56 | __C.GAN.GEN_FEAT_DIM = 48 57 | __C.GAN.GLOBAL_Z_DIM = 100 58 | __C.GAN.LOCAL_Z_DIM = 32 59 | __C.GAN.TEXT_CONDITION_DIM = 100 60 | __C.GAN.INIT_LABEL_DIM = 100 61 | __C.GAN.NEXT_LABEL_DIM = 256 // 2 62 | __C.GAN.RESIDUAL_NUM = 3 63 | __C.GAN.B_ATTENTION = True 64 | __C.GAN.B_DCGAN = False 65 | __C.GAN.LAYOUT_SPATIAL_DIM = 16 66 | 67 | __C.TEXT = edict() 68 | __C.TEXT.CAPTIONS_PER_IMAGE = 10 69 | __C.TEXT.EMBEDDING_DIM = 256 70 | __C.TEXT.WORDS_NUM = 12 71 | __C.TEXT.CLASSES_NUM = 81 72 | 73 | 74 | def _merge_a_into_b(a, b): 75 | """Merge config dictionary a into config dictionary b, clobbering the 76 | options in b whenever they are also specified in a. 77 | """ 78 | if type(a) is not edict: 79 | return 80 | 81 | for k, v in a.items(): 82 | # a must specify keys that are in b 83 | if not k in b: 84 | raise KeyError('{} is not a valid config key'.format(k)) 85 | 86 | # the types must match, too 87 | old_type = type(b[k]) 88 | if old_type is not type(v): 89 | if isinstance(b[k], np.ndarray): 90 | v = np.array(v, dtype=b[k].dtype) 91 | else: 92 | raise ValueError(('Type mismatch ({} vs. 
{}) ' 93 | 'for config key: {}').format(type(b[k]), 94 | type(v), k)) 95 | 96 | # recursively merge dicts 97 | if type(v) is edict: 98 | try: 99 | _merge_a_into_b(a[k], b[k]) 100 | except: 101 | logger.info('Error under config key: {}'.format(k)) 102 | raise 103 | else: 104 | b[k] = v 105 | 106 | 107 | def cfg_from_file(filename): 108 | """Load a config file and merge it into the default options.""" 109 | import yaml 110 | with open(filename, 'r') as f: 111 | yaml_cfg = edict(yaml.load(f, Loader=yaml.FullLoader)) 112 | 113 | _merge_a_into_b(yaml_cfg, __C) 114 | -------------------------------------------------------------------------------- /OP-GAN/code/miscc/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import numpy as np 5 | from miscc.config import cfg 6 | 7 | from GlobalAttention import func_attention 8 | 9 | 10 | # ##################Loss for matching text-image################### 11 | def cosine_similarity(x1, x2, dim=1, eps=1e-8): 12 | """Returns cosine similarity between x1 and x2, computed along dim. 13 | """ 14 | w12 = torch.sum(x1 * x2, dim) 15 | w1 = torch.norm(x1, 2, dim) 16 | w2 = torch.norm(x2, 2, dim) 17 | return (w12 / (w1 * w2).clamp(min=eps)).squeeze() 18 | 19 | 20 | def sent_loss(cnn_code, rnn_code, labels, class_ids, 21 | batch_size, eps=1e-8): 22 | # ### Mask mis-match samples ### 23 | # that come from the same class as the real sample ### 24 | masks = [] 25 | if class_ids is not None: 26 | for i in range(batch_size): 27 | mask = (class_ids == class_ids[i]).astype(np.uint8) 28 | mask[i] = 0 29 | masks.append(mask.reshape((1, -1))) 30 | masks = np.concatenate(masks, 0) 31 | # masks: batch_size x batch_size 32 | masks = torch.BoolTensor(masks).to(cfg.DEVICE) 33 | 34 | # --> seq_len x batch_size x nef 35 | if cnn_code.dim() == 2: 36 | cnn_code = cnn_code.unsqueeze(0) 37 | rnn_code = rnn_code.unsqueeze(0) 38 | 39 | # cnn_code_norm / rnn_code_norm: seq_len x batch_size x 1 40 | cnn_code_norm = torch.norm(cnn_code, 2, dim=2, keepdim=True) 41 | rnn_code_norm = torch.norm(rnn_code, 2, dim=2, keepdim=True) 42 | # scores* / norm*: seq_len x batch_size x batch_size 43 | scores0 = torch.bmm(cnn_code, rnn_code.transpose(1, 2)) 44 | norm0 = torch.bmm(cnn_code_norm, rnn_code_norm.transpose(1, 2)) 45 | scores0 = scores0 / norm0.clamp(min=eps) * cfg.TRAIN.SMOOTH.GAMMA3 46 | 47 | # --> batch_size x batch_size 48 | scores0 = scores0.squeeze() 49 | if class_ids is not None: 50 | scores0.data.masked_fill_(masks, -float('inf')) 51 | scores1 = scores0.transpose(0, 1) 52 | if labels is not None: 53 | loss0 = nn.CrossEntropyLoss()(scores0, labels) 54 | loss1 = nn.CrossEntropyLoss()(scores1, labels) 55 | else: 56 | loss0, loss1 = None, None 57 | return loss0, loss1 58 | 59 | 60 | def words_loss(img_features, words_emb, labels, 61 | cap_lens, class_ids, batch_size): 62 | """ 63 | words_emb(query): batch x nef x seq_len 64 | img_features(context): batch x nef x 17 x 17 65 | """ 66 | masks = [] 67 | att_maps = [] 68 | similarities = [] 69 | cap_lens = cap_lens.data.tolist() 70 | for i in range(batch_size): 71 | if class_ids is not None: 72 | mask = (class_ids == class_ids[i]).astype(np.uint8) 73 | mask[i] = 0 74 | masks.append(mask.reshape((1, -1))) 75 | # Get the i-th text description 76 | words_num = cap_lens[i] 77 | # -> 1 x nef x words_num 78 | word = words_emb[i, :, :words_num].unsqueeze(0).contiguous() 79 | # -> batch_size x nef x words_num 80 | word = word.repeat(batch_size, 1, 1) 81 | # batch x 
nef x 17*17 82 | context = img_features 83 | """ 84 | word(query): batch x nef x words_num 85 | context: batch x nef x 17 x 17 86 | weiContext: batch x nef x words_num 87 | attn: batch x words_num x 17 x 17 88 | """ 89 | weiContext, attn = func_attention(word, context, cfg.TRAIN.SMOOTH.GAMMA1) 90 | att_maps.append(attn[i].unsqueeze(0).contiguous()) 91 | # --> batch_size x words_num x nef 92 | word = word.transpose(1, 2).contiguous() 93 | weiContext = weiContext.transpose(1, 2).contiguous() 94 | # --> batch_size*words_num x nef 95 | word = word.view(batch_size * words_num, -1) 96 | weiContext = weiContext.view(batch_size * words_num, -1) 97 | # 98 | # -->batch_size*words_num 99 | row_sim = cosine_similarity(word, weiContext) 100 | # --> batch_size x words_num 101 | row_sim = row_sim.view(batch_size, words_num) 102 | 103 | # Eq. (10) 104 | row_sim.mul_(cfg.TRAIN.SMOOTH.GAMMA2).exp_() 105 | row_sim = row_sim.sum(dim=1, keepdim=True) 106 | row_sim = torch.log(row_sim) 107 | 108 | # --> 1 x batch_size 109 | # similarities(i, j): the similarity between the i-th image and the j-th text description 110 | similarities.append(row_sim) 111 | 112 | # batch_size x batch_size 113 | similarities = torch.cat(similarities, 1) 114 | if class_ids is not None: 115 | masks = np.concatenate(masks, 0) 116 | # masks: batch_size x batch_size 117 | masks = torch.BoolTensor(masks).to(cfg.DEVICE) 118 | 119 | similarities = similarities * cfg.TRAIN.SMOOTH.GAMMA3 120 | if class_ids is not None: 121 | similarities.data.masked_fill_(masks, -float('inf')) 122 | similarities1 = similarities.transpose(0, 1) 123 | if labels is not None: 124 | loss0 = nn.CrossEntropyLoss()(similarities, labels) 125 | loss1 = nn.CrossEntropyLoss()(similarities1, labels) 126 | else: 127 | loss0, loss1 = None, None 128 | return loss0, loss1, att_maps 129 | 130 | 131 | # ##################Loss for G and Ds############################## 132 | def discriminator_loss(netD, real_imgs, fake_imgs, conditions, 133 | real_labels, fake_labels, local_labels=None, 134 | transf_matrices=None, transf_matrices_inv=None, cfg=None, max_objects=None): 135 | # Forward 136 | criterion = nn.BCELoss() 137 | if local_labels is not None: 138 | inputs = (real_imgs, local_labels, transf_matrices, transf_matrices_inv, max_objects) 139 | else: 140 | inputs = (real_imgs) 141 | real_features = netD(*inputs) 142 | if local_labels is not None: 143 | inputs = (fake_imgs.detach(), local_labels, transf_matrices, transf_matrices_inv, max_objects) 144 | else: 145 | inputs = (fake_imgs.detach()) 146 | fake_features = netD(*inputs) 147 | 148 | if cfg.TRAIN.BBOX_LOSS: 149 | if local_labels is not None: 150 | inputs = (fake_imgs.detach(), local_labels, torch.flip(transf_matrices, [0]), 151 | torch.flip(transf_matrices_inv, [0]), max_objects) 152 | else: 153 | inputs = (fake_imgs.detach()) 154 | real_features_wrong_bbox = netD(*inputs) 155 | # loss 156 | cond_real_logits = netD.COND_DNET(real_features, conditions) 157 | cond_real_errD = criterion(cond_real_logits, real_labels) 158 | cond_fake_logits = netD.COND_DNET(fake_features, conditions) 159 | cond_fake_errD = criterion(cond_fake_logits, fake_labels) 160 | 161 | batch_size = real_features.size(0) 162 | cond_wrong_logits = netD.COND_DNET(real_features[:(batch_size - 1)], 163 | conditions[1:batch_size]) 164 | cond_wrong_errD = criterion(cond_wrong_logits, fake_labels[1:batch_size]) 165 | 166 | if cfg.TRAIN.BBOX_LOSS: 167 | cond_wrong_bbox = netD.COND_DNET(real_features_wrong_bbox, 168 | conditions) 169 | cond_wrong_bbox_errD = 
criterion(cond_wrong_bbox, fake_labels) 170 | 171 | if netD.UNCOND_DNET is not None: 172 | real_logits = netD.UNCOND_DNET(real_features) 173 | fake_logits = netD.UNCOND_DNET(fake_features) 174 | real_errD = criterion(real_logits, real_labels) 175 | fake_errD = criterion(fake_logits, fake_labels) 176 | if cfg.TRAIN.BBOX_LOSS: 177 | wrong_bbox_logits = netD.UNCOND_DNET(real_features_wrong_bbox) 178 | wrong_bbox_errD = criterion(wrong_bbox_logits, fake_labels) 179 | errD = ((real_errD + cond_real_errD) / 2. + 180 | (fake_errD + cond_fake_errD + cond_wrong_errD + cond_wrong_bbox_errD + wrong_bbox_errD) / 5.) 181 | 182 | else: 183 | errD = ((real_errD + cond_real_errD) / 2. + 184 | (fake_errD + cond_fake_errD + cond_wrong_errD) / 3.) 185 | else: 186 | if cfg.TRAIN.BBOX_LOSS: 187 | errD = cond_real_errD + (cond_fake_errD + cond_wrong_errD + cond_wrong_bbox_errD) / 3. 188 | else: 189 | errD = cond_real_errD + (cond_fake_errD + cond_wrong_errD) / 2. 190 | return errD 191 | 192 | 193 | def generator_loss(netsD, image_encoder, fake_imgs, real_labels, 194 | words_embs, sent_emb, match_labels, 195 | cap_lens, class_ids, local_labels=None, 196 | transf_matrices=None, transf_matrices_inv=None, max_objects=None): 197 | numDs = len(netsD) 198 | batch_size = real_labels.size(0) 199 | criterion = nn.BCELoss() 200 | # Forward 201 | errG_total = 0 202 | for i in range(numDs): 203 | inputs = (fake_imgs[i], local_labels, transf_matrices, transf_matrices_inv, max_objects) 204 | features = netsD[i](*inputs) 205 | cond_logits = netsD[i].COND_DNET(features, sent_emb) 206 | cond_errG = criterion(cond_logits, real_labels) 207 | if netsD[i].UNCOND_DNET is not None: 208 | logits = netsD[i].UNCOND_DNET(features) 209 | errG = criterion(logits, real_labels) 210 | g_loss = errG + cond_errG 211 | else: 212 | g_loss = cond_errG 213 | errG_total += g_loss 214 | 215 | # Ranking loss 216 | if i == (numDs - 1): 217 | # words_features: batch_size x nef x 17 x 17 218 | # sent_code: batch_size x nef 219 | region_features, cnn_code = image_encoder(fake_imgs[i]) 220 | w_loss0, w_loss1, _ = words_loss(region_features, words_embs, 221 | match_labels, cap_lens, 222 | class_ids, batch_size) 223 | w_loss = (w_loss0 + w_loss1) * \ 224 | cfg.TRAIN.SMOOTH.LAMBDA 225 | # err_words = err_words + w_loss.data[0] 226 | 227 | s_loss0, s_loss1 = sent_loss(cnn_code, sent_emb, 228 | match_labels, class_ids, batch_size) 229 | s_loss = (s_loss0 + s_loss1) * \ 230 | cfg.TRAIN.SMOOTH.LAMBDA 231 | # err_sent = err_sent + s_loss.data[0] 232 | 233 | errG_total += w_loss + s_loss 234 | return errG_total 235 | 236 | 237 | ################################################################## 238 | def KL_loss(mu, logvar): 239 | # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) 240 | KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) 241 | KLD = torch.mean(KLD_element).mul_(-0.5) 242 | return KLD 243 | -------------------------------------------------------------------------------- /OP-GAN/code/miscc/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import errno 4 | import sys 5 | 6 | import numpy as np 7 | from torch.nn import init 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from PIL import Image, ImageDraw, ImageFont 13 | from copy import deepcopy 14 | import skimage.transform 15 | 16 | from miscc.config import cfg 17 | 18 | 19 | def compute_transformation_matrix_inverse(bbox): 20 | x, y = bbox[:, 0], bbox[:, 1] 21 | w, h = bbox[:, 2], bbox[:, 3] 22 | 
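    # Descriptive note: each bbox row is [x, y, width, height] in normalized [0, 1] image
    # coordinates (see crop_imgs in datasets.py). Both functions here return 2x3 affine
    # parameters for torch.nn.functional.affine_grid / grid_sample (applied via stn() in
    # model.py), which work on a [-1, 1] grid -- hence the factor of 2 in the translation
    # terms. This inverse variant (scales 1/w, 1/h) roughly pastes a full-size feature map
    # into the bbox location, while compute_transformation_matrix below (scales w, h)
    # extracts the bbox region.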
23 | scale_x = 1.0 / w 24 | scale_y = 1.0 / h 25 | 26 | t_x = 2 * scale_x * (0.5 - (x + 0.5 * w)) 27 | t_y = 2 * scale_y * (0.5 - (y + 0.5 * h)) 28 | 29 | zeros = torch.FloatTensor(bbox.shape[0],1).fill_(0) 30 | 31 | transformation_matrix = torch.cat([scale_x.unsqueeze(-1), zeros, t_x.unsqueeze(-1), 32 | zeros, scale_y.unsqueeze(-1), t_y.unsqueeze(-1)], 1).view(-1, 2, 3) 33 | 34 | return transformation_matrix 35 | 36 | 37 | def compute_transformation_matrix(bbox): 38 | x, y = bbox[:, 0], bbox[:, 1] 39 | w, h = bbox[:, 2], bbox[:, 3] 40 | 41 | scale_x = w 42 | scale_y = h 43 | 44 | t_x = 2 * ((x + 0.5 * w) - 0.5) 45 | t_y = 2 * ((y + 0.5 * h) - 0.5) 46 | 47 | zeros = torch.FloatTensor(bbox.shape[0],1).fill_(0) 48 | 49 | transformation_matrix = torch.cat([scale_x.unsqueeze(-1), zeros, t_x.unsqueeze(-1), 50 | zeros, scale_y.unsqueeze(-1), t_y.unsqueeze(-1)], 1).view(-1, 2, 3) 51 | 52 | return transformation_matrix 53 | 54 | # For visualization ################################################ 55 | COLOR_DIC = {0:[128,64,128], 1:[244, 35,232], 56 | 2:[70, 70, 70], 3:[102,102,156], 57 | 4:[190,153,153], 5:[153,153,153], 58 | 6:[250,170, 30], 7:[220, 220, 0], 59 | 8:[107,142, 35], 9:[152,251,152], 60 | 10:[70,130,180], 11:[220,20, 60], 61 | 12:[255, 0, 0], 13:[0, 0, 142], 62 | 14:[119,11, 32], 15:[0, 60,100], 63 | 16:[0, 80, 100], 17:[0, 0, 230], 64 | 18:[0, 0, 70], 19:[0, 0, 0]} 65 | FONT_MAX = 50 66 | 67 | 68 | def drawCaption(convas, captions, ixtoword, vis_size, off1=2, off2=2): 69 | num = captions.size(0) 70 | img_txt = Image.fromarray(convas) 71 | # get a font 72 | # fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50) 73 | fnt = ImageFont.load_default() 74 | # get a drawing context 75 | d = ImageDraw.Draw(img_txt) 76 | sentence_list = [] 77 | for i in range(num): 78 | cap = captions[i].data.cpu().numpy() 79 | sentence = [] 80 | for j in range(len(cap)): 81 | if cap[j] == 0: 82 | break 83 | word = ixtoword[cap[j]].encode('ascii', 'ignore').decode('ascii') 84 | d.text(((j + off1) * (vis_size + off2), i * FONT_MAX), '%d:%s' % (j, word[:6]), 85 | font=fnt, fill=(255, 255, 255, 255)) 86 | sentence.append(word) 87 | sentence_list.append(sentence) 88 | return img_txt, sentence_list 89 | 90 | 91 | def build_super_images(real_imgs, captions, ixtoword, 92 | attn_maps, att_sze, lr_imgs=None, 93 | batch_size=cfg.TRAIN.BATCH_SIZE[0], 94 | max_word_num=cfg.TEXT.WORDS_NUM): 95 | nvis = min(8, len(attn_maps)) 96 | real_imgs = real_imgs[:nvis] 97 | if lr_imgs is not None: 98 | lr_imgs = lr_imgs[:nvis] 99 | if att_sze == 17: 100 | vis_size = att_sze * 16 101 | else: 102 | vis_size = real_imgs.size(2) 103 | 104 | text_convas = \ 105 | np.ones([batch_size * FONT_MAX, 106 | (max_word_num + 2) * (vis_size + 2), 3], 107 | dtype=np.uint8) 108 | 109 | for i in range(max_word_num): 110 | istart = (i + 2) * (vis_size + 2) 111 | iend = (i + 3) * (vis_size + 2) 112 | text_convas[:, istart:iend, :] = COLOR_DIC[i] 113 | 114 | 115 | real_imgs = nn.Upsample(size=(vis_size, vis_size), mode='bilinear', align_corners=False)(real_imgs) 116 | # [-1, 1] --> [0, 1] 117 | real_imgs.add_(1).div_(2).mul_(255) 118 | real_imgs = real_imgs.data.numpy() 119 | # b x c x h x w --> b x h x w x c 120 | real_imgs = np.transpose(real_imgs, (0, 2, 3, 1)) 121 | pad_sze = real_imgs.shape 122 | middle_pad = np.zeros([pad_sze[2], 2, 3]) 123 | post_pad = np.zeros([pad_sze[1], pad_sze[2], 3]) 124 | if lr_imgs is not None: 125 | lr_imgs = nn.Upsample(size=(vis_size, vis_size), mode='bilinear', align_corners=False)(lr_imgs) 126 | # [-1, 1] 
--> [0, 1] 127 | lr_imgs.add_(1).div_(2).mul_(255) 128 | lr_imgs = lr_imgs.data.numpy() 129 | # b x c x h x w --> b x h x w x c 130 | lr_imgs = np.transpose(lr_imgs, (0, 2, 3, 1)) 131 | 132 | # batch x seq_len x 17 x 17 --> batch x 1 x 17 x 17 133 | seq_len = max_word_num 134 | img_set = [] 135 | num = nvis # len(attn_maps) 136 | 137 | text_map, sentences = \ 138 | drawCaption(text_convas, captions, ixtoword, vis_size) 139 | text_map = np.asarray(text_map).astype(np.uint8) 140 | 141 | bUpdate = 1 142 | for i in range(num): 143 | attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze) 144 | # --> 1 x 1 x 17 x 17 145 | attn_max = attn.max(dim=1, keepdim=True) 146 | attn = torch.cat([attn_max[0], attn], 1) 147 | # 148 | attn = attn.view(-1, 1, att_sze, att_sze) 149 | attn = attn.repeat(1, 3, 1, 1).data.numpy() 150 | # n x c x h x w --> n x h x w x c 151 | attn = np.transpose(attn, (0, 2, 3, 1)) 152 | num_attn = attn.shape[0] 153 | # 154 | img = real_imgs[i] 155 | if lr_imgs is None: 156 | lrI = img 157 | else: 158 | lrI = lr_imgs[i] 159 | row = [lrI, middle_pad] 160 | row_merge = [img, middle_pad] 161 | row_beforeNorm = [] 162 | minVglobal, maxVglobal = 1, 0 163 | for j in range(num_attn): 164 | one_map = attn[j] 165 | if (vis_size // att_sze) > 1: 166 | one_map = \ 167 | skimage.transform.pyramid_expand(one_map, sigma=20, 168 | upscale=vis_size // att_sze, 169 | multichannel=True) 170 | row_beforeNorm.append(one_map) 171 | minV = one_map.min() 172 | maxV = one_map.max() 173 | if minVglobal > minV: 174 | minVglobal = minV 175 | if maxVglobal < maxV: 176 | maxVglobal = maxV 177 | for j in range(seq_len + 1): 178 | if j < num_attn: 179 | one_map = row_beforeNorm[j] 180 | one_map = (one_map - minVglobal) / (maxVglobal - minVglobal) 181 | one_map *= 255 182 | # 183 | PIL_im = Image.fromarray(np.uint8(img)) 184 | PIL_att = Image.fromarray(np.uint8(one_map)) 185 | merged = \ 186 | Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0)) 187 | mask = Image.new('L', (vis_size, vis_size), (210)) 188 | merged.paste(PIL_im, (0, 0)) 189 | merged.paste(PIL_att, (0, 0), mask) 190 | merged = np.array(merged)[:, :, :3] 191 | else: 192 | one_map = post_pad 193 | merged = post_pad 194 | row.append(one_map) 195 | row.append(middle_pad) 196 | # 197 | row_merge.append(merged) 198 | row_merge.append(middle_pad) 199 | row = np.concatenate(row, 1) 200 | row_merge = np.concatenate(row_merge, 1) 201 | txt = text_map[i * FONT_MAX: (i + 1) * FONT_MAX] 202 | if txt.shape[1] != row.shape[1]: 203 | print('txt', txt.shape, 'row', row.shape) 204 | bUpdate = 0 205 | break 206 | row = np.concatenate([txt, row, row_merge], 0) 207 | img_set.append(row) 208 | if bUpdate: 209 | img_set = np.concatenate(img_set, 0) 210 | img_set = img_set.astype(np.uint8) 211 | return img_set, sentences 212 | else: 213 | return None 214 | 215 | 216 | def build_super_images2(real_imgs, captions, cap_lens, ixtoword, 217 | attn_maps, att_sze, vis_size=256, topK=5): 218 | batch_size = real_imgs.size(0) 219 | max_word_num = np.max(cap_lens) 220 | text_convas = np.ones([batch_size * FONT_MAX, 221 | max_word_num * (vis_size + 2), 3], 222 | dtype=np.uint8) 223 | 224 | real_imgs = \ 225 | nn.Upsample(size=(vis_size, vis_size), mode='bilinear', align_corners=False)(real_imgs) 226 | # [-1, 1] --> [0, 1] 227 | real_imgs.add_(1).div_(2).mul_(255) 228 | real_imgs = real_imgs.data.numpy() 229 | # b x c x h x w --> b x h x w x c 230 | real_imgs = np.transpose(real_imgs, (0, 2, 3, 1)) 231 | pad_sze = real_imgs.shape 232 | middle_pad = np.zeros([pad_sze[2], 2, 3]) 233 
| 234 | # batch x seq_len x 17 x 17 --> batch x 1 x 17 x 17 235 | img_set = [] 236 | num = len(attn_maps) 237 | 238 | text_map, sentences = \ 239 | drawCaption(text_convas, captions, ixtoword, vis_size, off1=0) 240 | text_map = np.asarray(text_map).astype(np.uint8) 241 | 242 | bUpdate = 1 243 | for i in range(num): 244 | attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze) 245 | # 246 | attn = attn.view(-1, 1, att_sze, att_sze) 247 | attn = attn.repeat(1, 3, 1, 1).data.numpy() 248 | # n x c x h x w --> n x h x w x c 249 | attn = np.transpose(attn, (0, 2, 3, 1)) 250 | num_attn = cap_lens[i] 251 | thresh = 2./float(num_attn) 252 | # 253 | img = real_imgs[i] 254 | row = [] 255 | row_merge = [] 256 | row_txt = [] 257 | row_beforeNorm = [] 258 | conf_score = [] 259 | for j in range(num_attn): 260 | one_map = attn[j] 261 | mask0 = one_map > (2. * thresh) 262 | conf_score.append(np.sum(one_map * mask0)) 263 | mask = one_map > thresh 264 | one_map = one_map * mask 265 | if (vis_size // att_sze) > 1: 266 | one_map = \ 267 | skimage.transform.pyramid_expand(one_map, sigma=20, 268 | upscale=vis_size // att_sze) 269 | minV = one_map.min() 270 | maxV = one_map.max() 271 | one_map = (one_map - minV) / (maxV - minV) 272 | row_beforeNorm.append(one_map) 273 | sorted_indices = np.argsort(conf_score)[::-1] 274 | 275 | for j in range(num_attn): 276 | one_map = row_beforeNorm[j] 277 | one_map *= 255 278 | # 279 | PIL_im = Image.fromarray(np.uint8(img)) 280 | PIL_att = Image.fromarray(np.uint8(one_map)) 281 | merged = \ 282 | Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0)) 283 | mask = Image.new('L', (vis_size, vis_size), (180)) # (210) 284 | merged.paste(PIL_im, (0, 0)) 285 | merged.paste(PIL_att, (0, 0), mask) 286 | merged = np.array(merged)[:, :, :3] 287 | 288 | row.append(np.concatenate([one_map, middle_pad], 1)) 289 | # 290 | row_merge.append(np.concatenate([merged, middle_pad], 1)) 291 | # 292 | txt = text_map[i * FONT_MAX:(i + 1) * FONT_MAX, 293 | j * (vis_size + 2):(j + 1) * (vis_size + 2), :] 294 | row_txt.append(txt) 295 | # reorder 296 | row_new = [] 297 | row_merge_new = [] 298 | txt_new = [] 299 | for j in range(num_attn): 300 | idx = sorted_indices[j] 301 | row_new.append(row[idx]) 302 | row_merge_new.append(row_merge[idx]) 303 | txt_new.append(row_txt[idx]) 304 | row = np.concatenate(row_new[:topK], 1) 305 | row_merge = np.concatenate(row_merge_new[:topK], 1) 306 | txt = np.concatenate(txt_new[:topK], 1) 307 | if txt.shape[1] != row.shape[1]: 308 | print('Warnings: txt', txt.shape, 'row', row.shape, 309 | 'row_merge_new', row_merge_new.shape) 310 | bUpdate = 0 311 | break 312 | row = np.concatenate([txt, row_merge], 0) 313 | img_set.append(row) 314 | if bUpdate: 315 | img_set = np.concatenate(img_set, 0) 316 | img_set = img_set.astype(np.uint8) 317 | return img_set, sentences 318 | else: 319 | return None 320 | 321 | 322 | #################################################################### 323 | def weights_init(m): 324 | classname = m.__class__.__name__ 325 | if classname.find('Conv') != -1: 326 | nn.init.orthogonal_(m.weight.data, 1.0) 327 | elif classname.find('BatchNorm') != -1: 328 | m.weight.data.normal_(1.0, 0.02) 329 | m.bias.data.fill_(0) 330 | elif classname.find('Linear') != -1: 331 | nn.init.orthogonal_(m.weight.data, 1.0) 332 | if m.bias is not None: 333 | m.bias.data.fill_(0.0) 334 | 335 | 336 | def load_params(model, new_param): 337 | for p, new_p in zip(model.parameters(), new_param): 338 | p.data.copy_(new_p) 339 | 340 | 341 | def copy_G_params(model): 342 | flatten = 
deepcopy(list(p.data for p in model.parameters())) 343 | return flatten 344 | 345 | 346 | def count_learnable_params(model): 347 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 348 | 349 | 350 | def mkdir_p(path): 351 | try: 352 | os.makedirs(path) 353 | except OSError as exc: # Python >2.5 354 | if exc.errno == errno.EEXIST and os.path.isdir(path): 355 | pass 356 | else: 357 | raise 358 | 359 | 360 | class DataParallelPassThrough(nn.parallel.DataParallel): 361 | """ 362 | Use this so the following still works. 363 | >>> net = SomeModule(10, 20) 364 | >>> print(net.some_sub_module) 365 | >>> net = DistributedDataParallelPassthrough(net) 366 | >>> print(net.some_sub_module) 367 | While otherwise, with `nn.parallel.DataParallel`, this would give a ModuleAttributeError. 368 | https://github.com/pytorch/pytorch/issues/16885 369 | """ 370 | def __getattr__(self, name): 371 | try: 372 | return super().__getattr__(name) 373 | except AttributeError: 374 | return getattr(self.module, name) 375 | 376 | 377 | def initialize_logging(output_dir, to_file=True): 378 | logger = logging.getLogger() 379 | logger.setLevel(logging.INFO) 380 | 381 | formatter = logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(message)s', 382 | datefmt='%d-%m-%Y %H:%M:%S') 383 | ch = logging.StreamHandler(stream=sys.stdout) 384 | ch.setLevel(logging.INFO) 385 | ch.setFormatter(formatter) 386 | logger.addHandler(ch) 387 | 388 | if to_file: 389 | fh = logging.FileHandler(os.path.join(output_dir, 'output.log')) 390 | fh.setLevel(logging.INFO) 391 | fh.setFormatter(formatter) 392 | logger.addHandler(fh) 393 | -------------------------------------------------------------------------------- /OP-GAN/code/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.parallel 6 | from torch.autograd import Variable 7 | from torchvision import models 8 | import torch.nn.functional as F 9 | 10 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 11 | 12 | from miscc.config import cfg 13 | from GlobalAttention import GlobalAttentionGeneral as ATT_NET 14 | 15 | logger = logging.getLogger() 16 | 17 | 18 | def stn(image, transformation_matrix, size): 19 | grid = torch.nn.functional.affine_grid(transformation_matrix, torch.Size(size), align_corners=False) 20 | out_image = torch.nn.functional.grid_sample(image, grid, align_corners=False) 21 | 22 | return out_image 23 | 24 | 25 | class GLU(nn.Module): 26 | def __init__(self): 27 | super(GLU, self).__init__() 28 | 29 | def forward(self, x): 30 | nc = x.size(1) 31 | assert nc % 2 == 0, 'channels dont divide 2!' 
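        # GLU: split the channels in half and use the second half, passed through a sigmoid, to gate the first half element-wise.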
32 | nc = int(nc/2) 33 | return x[:, :nc] * torch.sigmoid(x[:, nc:]) 34 | 35 | 36 | def conv1x1(in_planes, out_planes, bias=False): 37 | "1x1 convolution with padding" 38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, 39 | padding=0, bias=bias) 40 | 41 | 42 | def conv3x3(in_planes, out_planes, stride=1): 43 | "3x3 convolution with padding" 44 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 45 | padding=1, bias=False) 46 | 47 | 48 | # Upsale the spatial size by a factor of 2 49 | def upBlock(in_planes, out_planes): 50 | block = nn.Sequential( 51 | # nn.functional.interpolate(scale_factor=2, mode="nearest"), 52 | nn.Upsample(scale_factor=2, mode='nearest'), 53 | conv3x3(in_planes, out_planes * 2), 54 | nn.BatchNorm2d(out_planes * 2), 55 | GLU()) 56 | return block 57 | 58 | 59 | # Keep the spatial size 60 | def Block3x3_relu(in_planes, out_planes): 61 | block = nn.Sequential( 62 | conv3x3(in_planes, out_planes * 2), 63 | nn.BatchNorm2d(out_planes * 2), 64 | GLU()) 65 | return block 66 | 67 | 68 | class ResBlock(nn.Module): 69 | def __init__(self, channel_num): 70 | super(ResBlock, self).__init__() 71 | self.block = nn.Sequential( 72 | conv3x3(channel_num, channel_num * 2), 73 | nn.BatchNorm2d(channel_num * 2), 74 | GLU(), 75 | conv3x3(channel_num, channel_num), 76 | nn.BatchNorm2d(channel_num)) 77 | 78 | def forward(self, x): 79 | residual = x 80 | out = self.block(x) 81 | out += residual 82 | return out 83 | 84 | 85 | def channel_pool(input, kernel_size): 86 | b, c, h, w = input.size() 87 | input = input.view(b, c, h * w).permute(0, 2, 1) 88 | stride = c 89 | pooled = torch.nn.functional.max_pool1d(input, kernel_size, stride) 90 | pooled = pooled.permute(0, 2, 1).view(b, -1, h, w) 91 | assert pooled.shape[1] == 1 92 | return pooled.squeeze() 93 | 94 | 95 | def merge_tensors(source, new_features, idx): 96 | """This method deals with the fact that some bboxes overlap each other. 97 | To deal with this we use the simple heuristic that the smaller bbox contains the object in the foreground. 
98 | As such features in smaller bboxes replace the features of larger bboxes in overlapping areas""" 99 | if idx == 0: 100 | return new_features 101 | else: 102 | nz = torch.nonzero(new_features) 103 | source[nz[:, 0], nz[:, 1], nz[:, 2], nz[:, 3]] = new_features[nz[:, 0], nz[:, 1], nz[:, 2], nz[:, 3]] 104 | return source 105 | 106 | 107 | class BBOX_NET(nn.Module): 108 | # some code is modified from vae examples 109 | # (https://github.com/pytorch/examples/blob/master/vae/main.py) 110 | def __init__(self): 111 | super(BBOX_NET, self).__init__() 112 | self.c_dim = cfg.GAN.INIT_LABEL_DIM 113 | self.input_dim = cfg.GAN.LAYOUT_SPATIAL_DIM 114 | self.encode = nn.Sequential( 115 | # 128 * 16 x 16 116 | conv3x3(self.c_dim, self.c_dim // 2, stride=2), 117 | nn.LeakyReLU(0.2, inplace=True), 118 | # 64 x 8 x 8 119 | conv3x3(self.c_dim // 2, self.c_dim // 4, stride=2), 120 | nn.BatchNorm2d(self.c_dim // 4), 121 | nn.LeakyReLU(0.2, inplace=True), 122 | # 32 x 4 x 4 123 | conv3x3(self.c_dim // 4, self.c_dim // 8, stride=2), 124 | nn.BatchNorm2d(self.c_dim // 8), 125 | nn.LeakyReLU(0.2, inplace=True), 126 | # 16 x 2 x 2 127 | ) 128 | 129 | def forward(self, labels, transf_matr_inv, max_objects): 130 | label_layout = labels.new_zeros(labels.shape[0], self.c_dim, self.input_dim, self.input_dim) 131 | for idx in range(max_objects): 132 | current_label = labels[:, idx] 133 | current_label = current_label.view(current_label.shape[0], current_label.shape[1], 1, 1) 134 | current_label = current_label.repeat(1, 1, self.input_dim, self.input_dim) 135 | current_label = stn(current_label, transf_matr_inv[:, idx], current_label.shape) 136 | label_layout += current_label 137 | 138 | layout_encoding = self.encode(label_layout).view(labels.shape[0], -1) 139 | 140 | return layout_encoding 141 | 142 | 143 | # ############## Text2Image Encoder-Decoder ####### 144 | class RNN_ENCODER(nn.Module): 145 | def __init__(self, ntoken, ninput=300, drop_prob=0.5, 146 | nhidden=128, nlayers=1, bidirectional=True): 147 | super(RNN_ENCODER, self).__init__() 148 | self.n_steps = cfg.TEXT.WORDS_NUM 149 | self.ntoken = ntoken # size of the dictionary 150 | self.ninput = ninput # size of each embedding vector 151 | self.drop_prob = drop_prob # probability of an element to be zeroed 152 | self.nlayers = nlayers # Number of recurrent layers 153 | self.bidirectional = bidirectional 154 | self.rnn_type = cfg.RNN_TYPE 155 | if bidirectional: 156 | self.num_directions = 2 157 | else: 158 | self.num_directions = 1 159 | # number of features in the hidden state 160 | self.nhidden = nhidden // self.num_directions 161 | 162 | self.define_module() 163 | self.init_weights() 164 | 165 | def define_module(self): 166 | """ 167 | nn.LSTM and nn.GRU will give a warning if nlayers=1 and dropout>0, saying dropout is only used 168 | when nlayers>1. That's okay. 
169 | """ 170 | self.encoder = nn.Embedding(self.ntoken, self.ninput) 171 | self.drop = nn.Dropout(self.drop_prob) 172 | if self.rnn_type == 'LSTM': 173 | # dropout: If non-zero, introduces a dropout layer on 174 | # the outputs of each RNN layer except the last layer 175 | self.rnn = nn.LSTM(self.ninput, self.nhidden, 176 | self.nlayers, batch_first=True, 177 | dropout=self.drop_prob, 178 | bidirectional=self.bidirectional) 179 | elif self.rnn_type == 'GRU': 180 | self.rnn = nn.GRU(self.ninput, self.nhidden, 181 | self.nlayers, batch_first=True, 182 | dropout=self.drop_prob, 183 | bidirectional=self.bidirectional) 184 | else: 185 | raise NotImplementedError 186 | 187 | def init_weights(self): 188 | initrange = 0.1 189 | self.encoder.weight.data.uniform_(-initrange, initrange) 190 | # Do not need to initialize RNN parameters, which have been initialized 191 | # http://pytorch.org/docs/master/_modules/torch/nn/modules/rnn.html#LSTM 192 | # self.decoder.weight.data.uniform_(-initrange, initrange) 193 | # self.decoder.bias.data.fill_(0) 194 | 195 | def init_hidden(self, bsz): 196 | weight = next(self.parameters()).data 197 | if self.rnn_type == 'LSTM': 198 | return (Variable(weight.new(self.nlayers * self.num_directions, 199 | bsz, self.nhidden).zero_()), 200 | Variable(weight.new(self.nlayers * self.num_directions, 201 | bsz, self.nhidden).zero_())) 202 | else: 203 | return Variable(weight.new(self.nlayers * self.num_directions, 204 | bsz, self.nhidden).zero_()) 205 | 206 | def forward(self, captions, cap_lens, hidden, mask=None): 207 | # input: torch.LongTensor of size batch x n_steps 208 | # --> emb: batch x n_steps x ninput 209 | emb = self.drop(self.encoder(captions)) 210 | # 211 | # Returns: a PackedSequence object 212 | cap_lens = cap_lens.data.tolist() 213 | emb = pack_padded_sequence(emb, cap_lens, batch_first=True) 214 | # #hidden and memory (num_layers * num_directions, batch, hidden_size): 215 | # tensor containing the initial hidden state for each element in batch. 
216 | # #output (batch, seq_len, hidden_size * num_directions) 217 | # #or a PackedSequence object: 218 | # tensor containing output features (h_t) from the last layer of RNN 219 | output, hidden = self.rnn(emb, hidden) 220 | # PackedSequence object 221 | # --> (batch, seq_len, hidden_size * num_directions) 222 | output = pad_packed_sequence(output, batch_first=True)[0] 223 | # output = self.drop(output) 224 | # --> batch x hidden_size*num_directions x seq_len 225 | words_emb = output.transpose(1, 2) 226 | # --> batch x num_directions*hidden_size 227 | if self.rnn_type == 'LSTM': 228 | sent_emb = hidden[0].transpose(0, 1).contiguous() 229 | else: 230 | sent_emb = hidden.transpose(0, 1).contiguous() 231 | sent_emb = sent_emb.view(-1, self.nhidden * self.num_directions) 232 | return words_emb, sent_emb 233 | 234 | 235 | class CNN_ENCODER(nn.Module): 236 | def __init__(self, feat_dim): 237 | super(CNN_ENCODER, self).__init__() 238 | if cfg.TRAIN.FLAG: 239 | self.feat_dim = feat_dim 240 | else: 241 | self.feat_dim = cfg.TEXT.EMBEDDING_DIM 242 | 243 | model = models.inception_v3(init_weights=False) 244 | url = 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth' 245 | state_dict = torch.hub.load_state_dict_from_url(url, model_dir='models/hub') 246 | model.load_state_dict(state_dict) 247 | for param in model.parameters(): 248 | param.requires_grad = False 249 | logger.info('Load pretrained model from %s', url) 250 | # logger.info(model) 251 | 252 | self.define_module(model) 253 | self.init_trainable_weights() 254 | 255 | def define_module(self, model): 256 | self.Conv2d_1a_3x3 = model.Conv2d_1a_3x3 257 | self.Conv2d_2a_3x3 = model.Conv2d_2a_3x3 258 | self.Conv2d_2b_3x3 = model.Conv2d_2b_3x3 259 | self.Conv2d_3b_1x1 = model.Conv2d_3b_1x1 260 | self.Conv2d_4a_3x3 = model.Conv2d_4a_3x3 261 | self.Mixed_5b = model.Mixed_5b 262 | self.Mixed_5c = model.Mixed_5c 263 | self.Mixed_5d = model.Mixed_5d 264 | self.Mixed_6a = model.Mixed_6a 265 | self.Mixed_6b = model.Mixed_6b 266 | self.Mixed_6c = model.Mixed_6c 267 | self.Mixed_6d = model.Mixed_6d 268 | self.Mixed_6e = model.Mixed_6e 269 | self.Mixed_7a = model.Mixed_7a 270 | self.Mixed_7b = model.Mixed_7b 271 | self.Mixed_7c = model.Mixed_7c 272 | 273 | self.emb_features = conv1x1(768, self.feat_dim) 274 | self.emb_cnn_code = nn.Linear(2048, self.feat_dim) 275 | 276 | def init_trainable_weights(self): 277 | initrange = 0.1 278 | self.emb_features.weight.data.uniform_(-initrange, initrange) 279 | self.emb_cnn_code.weight.data.uniform_(-initrange, initrange) 280 | 281 | def forward(self, x): 282 | features = None 283 | # --> fixed-size input: batch x 3 x 299 x 299 284 | # x = nn.functional.interpolate(x, size=(299, 299), mode='bilinear') 285 | x = nn.Upsample(size=(299, 299), mode='bilinear', align_corners=False)(x) 286 | # 299 x 299 x 3 287 | x = self.Conv2d_1a_3x3(x) 288 | # 149 x 149 x 32 289 | x = self.Conv2d_2a_3x3(x) 290 | # 147 x 147 x 32 291 | x = self.Conv2d_2b_3x3(x) 292 | # 147 x 147 x 64 293 | x = F.max_pool2d(x, kernel_size=3, stride=2) 294 | # 73 x 73 x 64 295 | x = self.Conv2d_3b_1x1(x) 296 | # 73 x 73 x 80 297 | x = self.Conv2d_4a_3x3(x) 298 | # 71 x 71 x 192 299 | 300 | x = F.max_pool2d(x, kernel_size=3, stride=2) 301 | # 35 x 35 x 192 302 | x = self.Mixed_5b(x) 303 | # 35 x 35 x 256 304 | x = self.Mixed_5c(x) 305 | # 35 x 35 x 288 306 | x = self.Mixed_5d(x) 307 | # 35 x 35 x 288 308 | 309 | x = self.Mixed_6a(x) 310 | # 17 x 17 x 768 311 | x = self.Mixed_6b(x) 312 | # 17 x 17 x 768 313 | x = self.Mixed_6c(x) 314 | # 17 x 
17 x 768 315 | x = self.Mixed_6d(x) 316 | # 17 x 17 x 768 317 | x = self.Mixed_6e(x) 318 | # 17 x 17 x 768 319 | 320 | # image region features 321 | features = x 322 | # 17 x 17 x 768 323 | 324 | x = self.Mixed_7a(x) 325 | # 8 x 8 x 1280 326 | x = self.Mixed_7b(x) 327 | # 8 x 8 x 2048 328 | x = self.Mixed_7c(x) 329 | # 8 x 8 x 2048 330 | x = F.avg_pool2d(x, kernel_size=8) 331 | # 1 x 1 x 2048 332 | # x = F.dropout(x, training=self.training) 333 | # 1 x 1 x 2048 334 | x = x.view(x.size(0), -1) 335 | # 2048 336 | 337 | # global image features 338 | cnn_code = self.emb_cnn_code(x) 339 | # 512 340 | if features is not None: 341 | features = self.emb_features(features) 342 | return features, cnn_code 343 | 344 | 345 | # ############## G networks ################### 346 | class CA_NET(nn.Module): 347 | # some code is modified from vae examples 348 | # (https://github.com/pytorch/examples/blob/master/vae/main.py) 349 | def __init__(self): 350 | super(CA_NET, self).__init__() 351 | self.t_dim = cfg.TEXT.EMBEDDING_DIM 352 | self.c_dim = cfg.GAN.TEXT_CONDITION_DIM 353 | self.fc = nn.Linear(self.t_dim, self.c_dim * 4, bias=True) 354 | self.relu = GLU() 355 | 356 | def encode(self, text_embedding): 357 | x = self.relu(self.fc(text_embedding)) 358 | mu = x[:, :self.c_dim] 359 | logvar = x[:, self.c_dim:] 360 | return mu, logvar 361 | 362 | def reparametrize(self, mu, logvar): 363 | std = logvar.mul(0.5).exp() 364 | eps = torch.zeros_like(std).normal_() 365 | eps = Variable(eps) 366 | mult = eps.mul(std) 367 | added = mult.add(mu) 368 | return added 369 | 370 | def forward(self, text_embedding): 371 | mu, logvar = self.encode(text_embedding) 372 | c_code = self.reparametrize(mu, logvar) 373 | return c_code, mu, logvar 374 | 375 | 376 | class INIT_STAGE_G(nn.Module): 377 | def __init__(self): 378 | super(INIT_STAGE_G, self).__init__() 379 | self.gen_feat_dim = cfg.GAN.GEN_FEAT_DIM * 16 380 | self.define_module() 381 | 382 | def define_module(self): 383 | self.bbox_net = BBOX_NET() 384 | 385 | layout_cond_dim = cfg.GAN.INIT_LABEL_DIM // 8 * (cfg.GAN.LAYOUT_SPATIAL_DIM // 8)**2 386 | cond_inp_dim = cfg.GAN.TEXT_CONDITION_DIM + cfg.GAN.GLOBAL_Z_DIM + layout_cond_dim 387 | self.fc = nn.Sequential( 388 | nn.Linear(cond_inp_dim, self.gen_feat_dim * 4 * 4 * 2, bias=False), 389 | nn.BatchNorm1d(self.gen_feat_dim * 4 * 4 * 2), 390 | GLU() 391 | ) 392 | 393 | # local pathway 394 | label_inp_dim = cfg.GAN.TEXT_CONDITION_DIM + cfg.TEXT.CLASSES_NUM + cfg.GAN.LOCAL_Z_DIM 395 | self.label = nn.Sequential( 396 | nn.Linear(label_inp_dim, cfg.GAN.INIT_LABEL_DIM, bias=False), 397 | nn.BatchNorm1d(cfg.GAN.INIT_LABEL_DIM), 398 | nn.ReLU(inplace=True) 399 | ) 400 | self.local1 = upBlock(cfg.GAN.INIT_LABEL_DIM, self.gen_feat_dim // 2) 401 | self.local2 = upBlock(self.gen_feat_dim // 2, self.gen_feat_dim // 4) 402 | 403 | self.upsample1 = upBlock(self.gen_feat_dim, self.gen_feat_dim // 2) 404 | self.upsample2 = upBlock(self.gen_feat_dim // 2, self.gen_feat_dim // 4) 405 | self.upsample3 = upBlock(self.gen_feat_dim // 2, self.gen_feat_dim // 8) 406 | self.upsample4 = upBlock(self.gen_feat_dim // 8, self.gen_feat_dim // 16) 407 | 408 | def forward(self, z_code, local_noise, c_code, transf_matrices_inv, label_one_hot, max_objects, op=True): 409 | """ 410 | :param z_code: batch x cfg.GAN.Z_DIM 411 | :param c_code: batch x cfg.TEXT.EMBEDDING_DIM 412 | :return: batch x ngf/16 x 64 x 64 413 | """ 414 | local_labels = z_code.new_zeros(z_code.shape[0], max_objects, cfg.GAN.INIT_LABEL_DIM) 415 | 416 | # object pathway 417 | h_code_locals 
= z_code.new_zeros(z_code.shape[0], self.gen_feat_dim // 4, 16, 16) 418 | 419 | if op: 420 | for idx in range(max_objects): 421 | current_label = self.label(torch.cat((c_code, label_one_hot[:, idx], local_noise), 1)) 422 | local_labels[:, idx] = current_label 423 | current_label = current_label.view(current_label.shape[0], cfg.GAN.INIT_LABEL_DIM, 1, 1) 424 | current_label = current_label.repeat(1, 1, 4, 4) 425 | h_code_local = self.local1(current_label) 426 | h_code_local = self.local2(h_code_local) 427 | h_code_local = stn(h_code_local, transf_matrices_inv[:, idx], h_code_local.shape) 428 | h_code_locals = merge_tensors(h_code_locals, h_code_local, idx) 429 | 430 | bbox_code = self.bbox_net(local_labels, transf_matrices_inv, max_objects) 431 | c_z_code = torch.cat((c_code, z_code, bbox_code), 1) 432 | # state size ngf x 4 x 4 433 | out_code = self.fc(c_z_code) 434 | out_code = out_code.view(-1, self.gen_feat_dim, 4, 4) 435 | # state size ngf/3 x 8 x 8 436 | out_code = self.upsample1(out_code) 437 | # state size ngf/4 x 16 x 16 438 | out_code = self.upsample2(out_code) 439 | 440 | # combine local and global pathways 441 | out_code = torch.cat((out_code, h_code_locals), 1) 442 | 443 | # state size ngf/8 x 32 x 32 444 | out_code32 = self.upsample3(out_code) 445 | # state size ngf/16 x 64 x 64 446 | out_code64 = self.upsample4(out_code32) 447 | 448 | return out_code64 449 | 450 | 451 | class NEXT_STAGE_G(nn.Module): 452 | def __init__(self): 453 | super(NEXT_STAGE_G, self).__init__() 454 | self.gen_feat_dim = cfg.GAN.GEN_FEAT_DIM 455 | self.text_emb_dim = cfg.TEXT.EMBEDDING_DIM 456 | self.label_dim = cfg.GAN.NEXT_LABEL_DIM 457 | self.num_residual = cfg.GAN.RESIDUAL_NUM 458 | self.define_module() 459 | 460 | def _make_layer(self, block, channel_num): 461 | layers = [] 462 | for i in range(cfg.GAN.RESIDUAL_NUM): 463 | layers.append(block(channel_num)) 464 | return nn.Sequential(*layers) 465 | 466 | def define_module(self): 467 | self.att = ATT_NET(self.gen_feat_dim, self.text_emb_dim) 468 | self.residual = self._make_layer(ResBlock, self.gen_feat_dim * 2) 469 | self.upsample = upBlock(self.gen_feat_dim * 3, self.gen_feat_dim) 470 | 471 | # local pathway 472 | label_input_dim = cfg.GAN.TEXT_CONDITION_DIM + cfg.TEXT.CLASSES_NUM # no noise anymore 473 | self.label = nn.Sequential( 474 | nn.Linear(label_input_dim, self.label_dim, bias=False), 475 | nn.BatchNorm1d(self.label_dim), 476 | nn.ReLU(True) 477 | ) 478 | 479 | self.local1 = upBlock(self.label_dim + self.gen_feat_dim, self.gen_feat_dim * 2) 480 | self.local2 = upBlock(self.gen_feat_dim * 2, self.gen_feat_dim) 481 | 482 | def forward(self, h_code, c_code, word_embs, mask, transf_matrices, transf_matrices_inv, label_one_hot, 483 | max_objects, op=True): 484 | """ 485 | h_code1(query): batch x idf x ih x iw (queryL=ihxiw) 486 | word_embs(context): batch x cdf x sourceL (sourceL=seq_len) 487 | c_code1: batch x idf x queryL 488 | att1: batch x sourceL x queryL 489 | """ 490 | _hw = h_code.shape[2] 491 | self.att.applyMask(mask) 492 | c_code_att, att = self.att(h_code, word_embs) 493 | h_c_code = torch.cat((h_code, c_code_att), 1) 494 | out_code = self.residual(h_c_code) 495 | 496 | # object pathways 497 | h_code_locals = h_code.new_zeros(h_code.shape[0], self.gen_feat_dim, _hw, _hw) 498 | if op: 499 | for idx in range(max_objects): 500 | current_label = self.label(torch.cat((c_code, label_one_hot[:, idx]), 1)) 501 | current_label = current_label.view(h_code.shape[0], self.label_dim, 1, 1) 502 | current_label = current_label.repeat(1, 1, _hw//4, 
_hw//4) 503 | current_patch = stn(h_code, transf_matrices[:, idx], (h_code.shape[0], h_code.shape[1], _hw//4, _hw//4)) 504 | # logger.info(current_label.shape) 505 | # logger.info(current_patch.shape) 506 | current_input = torch.cat((current_patch, current_label), 1) 507 | # logger.info(current_input.shape) 508 | h_code_local = self.local1(current_input) 509 | h_code_local = self.local2(h_code_local) 510 | h_code_local = stn(h_code_local, transf_matrices_inv[:, idx], h_code_locals.shape) 511 | h_code_locals = merge_tensors(h_code_locals, h_code_local, idx) 512 | 513 | out_code = torch.cat((out_code, h_code_locals), 1) 514 | 515 | # state size ngf/2 x 2in_size x 2in_size 516 | out_code = self.upsample(out_code) 517 | 518 | return out_code, att 519 | 520 | 521 | class GET_IMAGE_G(nn.Module): 522 | def __init__(self): 523 | super(GET_IMAGE_G, self).__init__() 524 | self.img = nn.Sequential( 525 | conv3x3(cfg.GAN.GEN_FEAT_DIM, 3), 526 | nn.Tanh() 527 | ) 528 | 529 | def forward(self, h_code): 530 | out_img = self.img(h_code) 531 | return out_img 532 | 533 | 534 | class G_NET(nn.Module): 535 | def __init__(self): 536 | super(G_NET, self).__init__() 537 | self.ca_net = CA_NET() 538 | 539 | if cfg.TREE.BRANCH_NUM > 0: 540 | self.h_net1 = INIT_STAGE_G() 541 | self.img_net1 = GET_IMAGE_G() 542 | # gf x 64 x 64 543 | if cfg.TREE.BRANCH_NUM > 1: 544 | self.h_net2 = NEXT_STAGE_G() 545 | self.img_net2 = GET_IMAGE_G() 546 | if cfg.TREE.BRANCH_NUM > 2: 547 | self.h_net3 = NEXT_STAGE_G() 548 | self.img_net3 = GET_IMAGE_G() 549 | 550 | def forward(self, z_code, local_noise, sent_emb, word_embs, mask, transf_matrices, transf_matrices_inv, 551 | label_one_hot, max_objects, op=[True, True, True]): 552 | """ 553 | :param z_code: batch x cfg.GAN.Z_DIM 554 | :param sent_emb: batch x cfg.TEXT.EMBEDDING_DIM 555 | :param word_embs: batch x cdf x seq_len 556 | :param mask: batch x seq_len 557 | :return: 558 | """ 559 | fake_imgs = [] 560 | att_maps = [] 561 | c_code, mu, logvar = self.ca_net(sent_emb) 562 | 563 | if cfg.TREE.BRANCH_NUM > 0: 564 | h_code1 = self.h_net1(z_code, local_noise, c_code, transf_matrices_inv, label_one_hot, max_objects, op[0]) 565 | fake_img1 = self.img_net1(h_code1) 566 | fake_imgs.append(fake_img1) 567 | if cfg.TREE.BRANCH_NUM > 1: 568 | h_code2, att1 = self.h_net2(h_code1, c_code, word_embs, mask, transf_matrices, 569 | transf_matrices_inv, label_one_hot, max_objects, op[1]) 570 | fake_img2 = self.img_net2(h_code2) 571 | fake_imgs.append(fake_img2) 572 | if att1 is not None: 573 | att_maps.append(att1) 574 | if cfg.TREE.BRANCH_NUM > 2: 575 | h_code3, att2 = self.h_net3(h_code2, c_code, word_embs, mask, transf_matrices, 576 | transf_matrices_inv, label_one_hot, max_objects, op[2]) 577 | fake_img3 = self.img_net3(h_code3) 578 | fake_imgs.append(fake_img3) 579 | if att2 is not None: 580 | att_maps.append(att2) 581 | 582 | return fake_imgs, att_maps, mu, logvar 583 | 584 | 585 | # ############## D networks ########################## 586 | def Block3x3_leakRelu(in_planes, out_planes): 587 | block = nn.Sequential( 588 | conv3x3(in_planes, out_planes), 589 | nn.BatchNorm2d(out_planes), 590 | nn.LeakyReLU(0.2, inplace=True) 591 | ) 592 | return block 593 | 594 | 595 | # Downsale the spatial size by a factor of 2 596 | def downBlock(in_planes, out_planes): 597 | block = nn.Sequential( 598 | nn.Conv2d(in_planes, out_planes, 4, 2, 1, bias=False), 599 | nn.BatchNorm2d(out_planes), 600 | nn.LeakyReLU(0.2, inplace=True) 601 | ) 602 | return block 603 | 604 | 605 | # Downsale the spatial size by a 
factor of 16 606 | def encode_image_by_16times(ndf): 607 | encode_img = nn.Sequential( 608 | # --> state size. ndf x in_size/2 x in_size/2 609 | nn.Conv2d(3, ndf, 4, 2, 1, bias=False), 610 | nn.LeakyReLU(0.2, inplace=True), 611 | # --> state size 2ndf x x in_size/4 x in_size/4 612 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 613 | nn.BatchNorm2d(ndf * 2), 614 | nn.LeakyReLU(0.2, inplace=True), 615 | # --> state size 4ndf x in_size/8 x in_size/8 616 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 617 | nn.BatchNorm2d(ndf * 4), 618 | nn.LeakyReLU(0.2, inplace=True), 619 | # --> state size 8ndf x in_size/16 x in_size/16 620 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), 621 | nn.BatchNorm2d(ndf * 8), 622 | nn.LeakyReLU(0.2, inplace=True) 623 | ) 624 | return encode_img 625 | 626 | 627 | class D_GET_LOGITS(nn.Module): 628 | def __init__(self, ndf, nef, bcondition=False): 629 | super(D_GET_LOGITS, self).__init__() 630 | self.df_dim = ndf 631 | self.ef_dim = nef 632 | self.bcondition = bcondition 633 | if self.bcondition: 634 | self.jointConv = Block3x3_leakRelu(ndf * 8 + nef, ndf * 8) 635 | 636 | self.outlogits = nn.Sequential( 637 | nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), 638 | nn.Sigmoid()) 639 | 640 | def forward(self, h_code, c_code=None): 641 | if self.bcondition and c_code is not None: 642 | # conditioning output 643 | c_code = c_code.view(-1, self.ef_dim, 1, 1) 644 | c_code = c_code.repeat(1, 1, 4, 4) 645 | # state size (ngf+egf) x 4 x 4 646 | h_c_code = torch.cat((h_code, c_code), 1) 647 | # state size ngf x in_size x in_size 648 | h_c_code = self.jointConv(h_c_code) 649 | else: 650 | h_c_code = h_code 651 | 652 | output = self.outlogits(h_c_code) 653 | return output.view(-1) 654 | 655 | 656 | # For 64 x 64 images 657 | class D_NET64(nn.Module): 658 | def __init__(self, b_jcu=True): 659 | super(D_NET64, self).__init__() 660 | if b_jcu: 661 | self.UNCOND_DNET = D_GET_LOGITS(cfg.GAN.DISC_FEAT_DIM, cfg.TEXT.EMBEDDING_DIM, bcondition=False) 662 | else: 663 | self.UNCOND_DNET = None 664 | self.COND_DNET = D_GET_LOGITS(cfg.GAN.DISC_FEAT_DIM, cfg.TEXT.EMBEDDING_DIM, bcondition=True) 665 | self.define_module() 666 | 667 | def define_module(self): 668 | self.act = nn.LeakyReLU(0.2, inplace=True) 669 | 670 | # global pathway 671 | # --> state size. 
ndf x in_size/2 x in_size/2 672 | self.conv1 = nn.Conv2d(3, cfg.GAN.DISC_FEAT_DIM, 4, 2, 1, bias=False) 673 | # --> state size 2ndf x x in_size/4 x in_size/4 674 | self.conv2 = nn.Conv2d(cfg.GAN.DISC_FEAT_DIM, cfg.GAN.DISC_FEAT_DIM * 2, 4, 2, 1, bias=False) 675 | self.bn2 = nn.BatchNorm2d(cfg.GAN.DISC_FEAT_DIM * 2) 676 | # --> state size 4ndf x in_size/8 x in_size/8 677 | self.conv3 = nn.Conv2d(cfg.GAN.DISC_FEAT_DIM * 4, cfg.GAN.DISC_FEAT_DIM * 4, 4, 2, 1, bias=False) 678 | self.bn3 = nn.BatchNorm2d(cfg.GAN.DISC_FEAT_DIM * 4) 679 | # --> state size 8ndf x in_size/16 x in_size/16 680 | self.conv4 = nn.Conv2d(cfg.GAN.DISC_FEAT_DIM * 4, cfg.GAN.DISC_FEAT_DIM * 8, 4, 2, 1, bias=False) 681 | self.bn4 = nn.BatchNorm2d(cfg.GAN.DISC_FEAT_DIM * 8) 682 | 683 | # object pathway 684 | self.local = nn.Sequential( 685 | nn.Conv2d(3 + cfg.TEXT.CLASSES_NUM, cfg.GAN.DISC_FEAT_DIM * 2, 4, 1, 1, bias=False), 686 | nn.BatchNorm2d(cfg.GAN.DISC_FEAT_DIM * 2), 687 | nn.LeakyReLU(0.2, inplace=True) 688 | ) 689 | 690 | def forward(self, image, label, transf_matrices, transf_matrices_inv, max_objects): 691 | # object pathway 692 | h_code_locals = image.new_zeros(image.shape[0], cfg.GAN.DISC_FEAT_DIM * 2, 16, 16, dtype=torch.float) 693 | 694 | for idx in range(max_objects): 695 | current_label = label[:, idx].view(label.shape[0], cfg.TEXT.CLASSES_NUM, 1, 1) 696 | current_label = current_label.repeat(1, 1, 16, 16) 697 | h_code_local = stn(image, transf_matrices[:, idx], (image.shape[0], image.shape[1], 16, 16)) 698 | h_code_local = torch.cat((h_code_local, current_label), 1) 699 | h_code_local = self.local(h_code_local) 700 | h_code_local = stn(h_code_local, transf_matrices_inv[:, idx], 701 | (h_code_local.shape[0], h_code_local.shape[1], 16, 16)) 702 | h_code_locals = merge_tensors(h_code_locals, h_code_local, idx) 703 | 704 | h_code = self.conv1(image) 705 | h_code = self.act(h_code) 706 | h_code = self.conv2(h_code) 707 | h_code = self.bn2(h_code) 708 | h_code = self.act(h_code) 709 | 710 | h_code = torch.cat((h_code, h_code_locals), 1) 711 | 712 | h_code = self.conv3(h_code) 713 | h_code = self.bn3(h_code) 714 | h_code = self.act(h_code) 715 | 716 | h_code = self.conv4(h_code) 717 | h_code = self.bn4(h_code) 718 | x_code4 = self.act(h_code) 719 | 720 | return x_code4 721 | 722 | 723 | # For 128 x 128 images 724 | class D_NET128(nn.Module): 725 | def __init__(self, b_jcu=True): 726 | super(D_NET128, self).__init__() 727 | self.img_code_s32 = downBlock(cfg.GAN.DISC_FEAT_DIM * 8, cfg.GAN.DISC_FEAT_DIM * 16) 728 | self.img_code_s32_1 = Block3x3_leakRelu(cfg.GAN.DISC_FEAT_DIM * 16, cfg.GAN.DISC_FEAT_DIM * 8) 729 | self.encode_img = nn.Sequential( 730 | # --> state size. 
ndf x in_size/2 x in_size/2 731 | nn.Conv2d(3, cfg.GAN.DISC_FEAT_DIM, 4, 2, 1, bias=False), 732 | nn.LeakyReLU(0.2, inplace=True), 733 | # --> state size 2ndf x x in_size/4 x in_size/4 734 | nn.Conv2d(cfg.GAN.DISC_FEAT_DIM, cfg.GAN.DISC_FEAT_DIM * 2, 4, 2, 1, bias=False), 735 | nn.BatchNorm2d(cfg.GAN.DISC_FEAT_DIM * 2), 736 | nn.LeakyReLU(0.2, inplace=True), 737 | ) 738 | self.encode_final = nn.Sequential( 739 | nn.Conv2d(cfg.GAN.DISC_FEAT_DIM * 4, cfg.GAN.DISC_FEAT_DIM * 4, 4, 2, 1, bias=False), 740 | nn.BatchNorm2d(cfg.GAN.DISC_FEAT_DIM * 4), 741 | nn.LeakyReLU(0.2, inplace=True), 742 | # --> state size 8ndf x in_size/16 x in_size/16 743 | nn.Conv2d(cfg.GAN.DISC_FEAT_DIM * 4, cfg.GAN.DISC_FEAT_DIM * 8, 4, 2, 1, bias=False), 744 | nn.BatchNorm2d(cfg.GAN.DISC_FEAT_DIM * 8), 745 | nn.LeakyReLU(0.2, inplace=True) 746 | ) 747 | # 748 | if b_jcu: 749 | self.UNCOND_DNET = D_GET_LOGITS(cfg.GAN.DISC_FEAT_DIM, cfg.TEXT.EMBEDDING_DIM, bcondition=False) 750 | else: 751 | self.UNCOND_DNET = None 752 | self.COND_DNET = D_GET_LOGITS(cfg.GAN.DISC_FEAT_DIM, cfg.TEXT.EMBEDDING_DIM, bcondition=True) 753 | 754 | self.local = nn.Sequential( 755 | nn.Conv2d(3 + cfg.TEXT.CLASSES_NUM, cfg.GAN.DISC_FEAT_DIM, 4, 1, 1, bias=False), 756 | nn.BatchNorm2d(cfg.GAN.DISC_FEAT_DIM), 757 | nn.LeakyReLU(0.2, inplace=True), 758 | nn.Conv2d(cfg.GAN.DISC_FEAT_DIM, cfg.GAN.DISC_FEAT_DIM * 2, 4, 1, 1, bias=False), 759 | nn.BatchNorm2d(cfg.GAN.DISC_FEAT_DIM * 2), 760 | nn.LeakyReLU(0.2, inplace=True), 761 | ) 762 | 763 | def forward(self, image, label, transf_matrices, transf_matrices_inv, max_objects): 764 | # object pathway 765 | h_code_locals = image.new_zeros(image.shape[0], cfg.GAN.DISC_FEAT_DIM * 2, 32, 32, dtype=torch.float) 766 | 767 | for idx in range(max_objects): 768 | current_label = label[:, idx].view(label.shape[0], cfg.TEXT.CLASSES_NUM, 1, 1) 769 | current_label = current_label.repeat(1, 1, 32, 32) 770 | h_code_local = stn(image, transf_matrices[:, idx], (image.shape[0], image.shape[1], 32, 32)) 771 | h_code_local = torch.cat((h_code_local, current_label), 1) 772 | h_code_local = self.local(h_code_local) 773 | h_code_local = stn(h_code_local, transf_matrices_inv[:, idx], 774 | (h_code_local.shape[0], h_code_local.shape[1], 32, 32)) 775 | h_code_locals = merge_tensors(h_code_locals, h_code_local, idx) 776 | 777 | x_code_32 = self.encode_img(image) # 32 x 32 x df*2 778 | x_code_32 = torch.cat((x_code_32, h_code_locals), 1) # 32 x 32 x df*4 779 | 780 | x_code8 = self.encode_final(x_code_32) # 8 x 8 x 8df 781 | x_code4 = self.img_code_s32(x_code8) # 4 x 4 x 16df 782 | x_code4 = self.img_code_s32_1(x_code4) # 4 x 4 x 8df 783 | return x_code4 784 | 785 | 786 | # For 256 x 256 images 787 | class D_NET256(nn.Module): 788 | def __init__(self, b_jcu=True): 789 | super(D_NET256, self).__init__() 790 | self.img_code_s16 = encode_image_by_16times(cfg.GAN.DISC_FEAT_DIM) 791 | self.encode_img = nn.Sequential( 792 | # --> state size. 
ndf x in_size/2 x in_size/2 793 | nn.Conv2d(3, cfg.GAN.DISC_FEAT_DIM, 4, 2, 1, bias=False), 794 | nn.LeakyReLU(0.2, inplace=True), 795 | # --> state size 2ndf x x in_size/4 x in_size/4 796 | nn.Conv2d(cfg.GAN.DISC_FEAT_DIM, cfg.GAN.DISC_FEAT_DIM * 2, 4, 2, 1, bias=False), 797 | nn.BatchNorm2d(cfg.GAN.DISC_FEAT_DIM * 2), 798 | nn.LeakyReLU(0.2, inplace=True), 799 | ) 800 | self.encode_final = nn.Sequential( 801 | nn.Conv2d(cfg.GAN.DISC_FEAT_DIM * 4, cfg.GAN.DISC_FEAT_DIM * 4, 4, 2, 1, bias=False), 802 | nn.BatchNorm2d(cfg.GAN.DISC_FEAT_DIM * 4), 803 | nn.LeakyReLU(0.2, inplace=True), 804 | # --> state size 8ndf x in_size/16 x in_size/16 805 | nn.Conv2d(cfg.GAN.DISC_FEAT_DIM * 4, cfg.GAN.DISC_FEAT_DIM * 8, 4, 2, 1, bias=False), 806 | nn.BatchNorm2d(cfg.GAN.DISC_FEAT_DIM * 8), 807 | nn.LeakyReLU(0.2, inplace=True) 808 | ) 809 | self.img_code_s32 = downBlock(cfg.GAN.DISC_FEAT_DIM * 8, cfg.GAN.DISC_FEAT_DIM * 16) 810 | self.img_code_s64 = downBlock(cfg.GAN.DISC_FEAT_DIM * 16, cfg.GAN.DISC_FEAT_DIM * 32) 811 | self.img_code_s64_1 = Block3x3_leakRelu(cfg.GAN.DISC_FEAT_DIM * 32, cfg.GAN.DISC_FEAT_DIM * 16) 812 | self.img_code_s64_2 = Block3x3_leakRelu(cfg.GAN.DISC_FEAT_DIM * 16, cfg.GAN.DISC_FEAT_DIM * 8) 813 | if b_jcu: 814 | self.UNCOND_DNET = D_GET_LOGITS(cfg.GAN.DISC_FEAT_DIM, cfg.TEXT.EMBEDDING_DIM, bcondition=False) 815 | else: 816 | self.UNCOND_DNET = None 817 | self.COND_DNET = D_GET_LOGITS(cfg.GAN.DISC_FEAT_DIM, cfg.TEXT.EMBEDDING_DIM, bcondition=True) 818 | 819 | self.local = nn.Sequential( 820 | nn.Conv2d(3 + cfg.TEXT.CLASSES_NUM, cfg.GAN.DISC_FEAT_DIM, 4, 1, 1, bias=False), 821 | nn.BatchNorm2d(cfg.GAN.DISC_FEAT_DIM), 822 | nn.LeakyReLU(0.2, inplace=True), 823 | nn.Conv2d(cfg.GAN.DISC_FEAT_DIM, cfg.GAN.DISC_FEAT_DIM * 2, 4, 1, 1, bias=False), 824 | nn.BatchNorm2d(cfg.GAN.DISC_FEAT_DIM * 2), 825 | nn.LeakyReLU(0.2, inplace=True), 826 | ) 827 | 828 | def forward(self, image, label, transf_matrices, transf_matrices_inv, max_objects): 829 | # object pathway 830 | h_code_locals = image.new_zeros(image.shape[0], cfg.GAN.DISC_FEAT_DIM * 2, 64, 64, dtype=torch.float) 831 | 832 | for idx in range(max_objects): 833 | current_label = label[:, idx].view(label.shape[0], cfg.TEXT.CLASSES_NUM, 1, 1) 834 | current_label = current_label.repeat(1, 1, 64, 64) 835 | h_code_local = stn(image, transf_matrices[:, idx], (image.shape[0], image.shape[1], 64, 64)) 836 | h_code_local = torch.cat((h_code_local, current_label), 1) 837 | h_code_local = self.local(h_code_local) 838 | h_code_local = stn(h_code_local, transf_matrices_inv[:, idx], 839 | (h_code_local.shape[0], h_code_local.shape[1], 64, 64)) 840 | h_code_locals = merge_tensors(h_code_locals, h_code_local, idx) 841 | 842 | x_code_64 = self.encode_img(image) 843 | x_code_64 = torch.cat((x_code_64, h_code_locals), 1) 844 | 845 | x_code16 = self.encode_final(x_code_64) 846 | x_code8 = self.img_code_s32(x_code16) 847 | x_code4 = self.img_code_s64(x_code8) 848 | x_code4 = self.img_code_s64_1(x_code4) 849 | x_code4 = self.img_code_s64_2(x_code4) 850 | 851 | return x_code4 852 | -------------------------------------------------------------------------------- /OP-GAN/code/trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import logging 4 | 5 | from six.moves import range 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torch.autograd import Variable 11 | import torch.backends.cudnn as cudnn 12 | 13 | from PIL import Image 14 | from 
tqdm import tqdm 15 | import os 16 | import numpy as np 17 | import glob 18 | 19 | from miscc.config import cfg 20 | from miscc.utils import mkdir_p 21 | from miscc.utils import build_super_images, count_learnable_params, DataParallelPassThrough 22 | from miscc.utils import weights_init, load_params, copy_G_params 23 | from model import G_NET 24 | from datasets import prepare_data 25 | from model import RNN_ENCODER, CNN_ENCODER 26 | 27 | from miscc.losses import words_loss 28 | from miscc.losses import discriminator_loss, generator_loss, KL_loss 29 | 30 | logger = logging.getLogger() 31 | 32 | 33 | # ################# Text to image task############################ # 34 | class condGANTrainer(object): 35 | def __init__(self, output_dir, data_loader, n_words, ixtoword, resume): 36 | if cfg.TRAIN.FLAG: 37 | self.model_dir = os.path.join(output_dir, 'Model') 38 | self.image_dir = os.path.join(output_dir, 'Image') 39 | mkdir_p(self.model_dir) 40 | mkdir_p(self.image_dir) 41 | 42 | self.batch_size = cfg.TRAIN.BATCH_SIZE 43 | self.max_epoch = cfg.TRAIN.MAX_EPOCH 44 | self.resume = resume 45 | 46 | self.n_gpu = torch.cuda.device_count() 47 | 48 | self.n_words = n_words 49 | self.ixtoword = ixtoword 50 | 51 | if cfg.TRAIN.OPTIMIZE_DATA_LOADING: 52 | self.data_loader = data_loader 53 | self.num_batches = 0 54 | self.subset_lengths = [] 55 | for _idx in range(len(self.data_loader)): 56 | self.num_batches += len(self.data_loader[_idx]) 57 | self.subset_lengths.append(len(self.data_loader[_idx])) 58 | else: 59 | self.data_loader = data_loader 60 | self.num_batches = len(self.data_loader) 61 | 62 | if cfg.CUDA: 63 | torch.cuda.set_device(cfg.DEVICE) 64 | cudnn.benchmark = True 65 | 66 | def build_models(self): 67 | # ###################encoders######################################## # 68 | if cfg.TRAIN.NET_E == '': 69 | raise Exception('Error: no pretrained text encoder') 70 | 71 | image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM) 72 | img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder') 73 | state_dict = torch.load(img_encoder_path, map_location=lambda storage, loc: storage) 74 | image_encoder.load_state_dict(state_dict) 75 | for p in image_encoder.parameters(): 76 | p.requires_grad = False 77 | logger.info('Load image encoder from: %s', img_encoder_path) 78 | image_encoder.eval() 79 | 80 | text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM) 81 | state_dict = torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage) 82 | text_encoder.load_state_dict(state_dict) 83 | for p in text_encoder.parameters(): 84 | p.requires_grad = False 85 | logger.info('Load text encoder from: %s', cfg.TRAIN.NET_E) 86 | text_encoder.eval() 87 | 88 | # #######################generator and discriminators############## # 89 | netsD = [] 90 | from model import D_NET64, D_NET128, D_NET256 91 | netG = G_NET() 92 | if cfg.TREE.BRANCH_NUM > 0: 93 | netsD.append(D_NET64()) 94 | if cfg.TREE.BRANCH_NUM > 1: 95 | netsD.append(D_NET128()) 96 | if cfg.TREE.BRANCH_NUM > 2: 97 | netsD.append(D_NET256()) 98 | 99 | netG.apply(weights_init) 100 | for i in range(len(netsD)): 101 | netsD[i].apply(weights_init) 102 | logger.info('# of params in netG: %s' % count_learnable_params(netG)) 103 | logger.info('# of netsD: %s', len(netsD)) 104 | logger.info('# of params in netsD: %s' % [count_learnable_params(netD) for netD in netsD]) 105 | epoch = 0 106 | 107 | if self.resume: 108 | checkpoint_list = sorted([ckpt for ckpt in glob.glob(self.model_dir + "/" + '*.pth')]) 109 | latest_checkpoint = 
checkpoint_list[-1] 110 | state_dict = torch.load(latest_checkpoint, map_location=lambda storage, loc: storage) 111 | 112 | netG.load_state_dict(state_dict["netG"]) 113 | for i in range(len(netsD)): 114 | netsD[i].load_state_dict(state_dict["netD"][i]) 115 | epoch = int(latest_checkpoint[-8:-4]) + 1 116 | logger.info("Resuming training from checkpoint {} at epoch {}.".format(latest_checkpoint, epoch)) 117 | 118 | # 119 | if cfg.TRAIN.NET_G != '': 120 | state_dict = torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage) 121 | netG.load_state_dict(state_dict) 122 | logger.info('Load G from: %s', cfg.TRAIN.NET_G) 123 | istart = cfg.TRAIN.NET_G.rfind('_') + 1 124 | iend = cfg.TRAIN.NET_G.rfind('.') 125 | epoch = cfg.TRAIN.NET_G[istart:iend] 126 | epoch = int(epoch) + 1 127 | if cfg.TRAIN.B_NET_D: 128 | Gname = cfg.TRAIN.NET_G 129 | for i in range(len(netsD)): 130 | s_tmp = Gname[:Gname.rfind('/')] 131 | Dname = '%s/netD%d.pth' % (s_tmp, i) 132 | logger.info('Load D from: %s', Dname) 133 | state_dict = torch.load(Dname, map_location=lambda storage, loc: storage) 134 | netsD[i].load_state_dict(state_dict) 135 | # ########################################################### # 136 | if cfg.CUDA: 137 | text_encoder.to(cfg.DEVICE) 138 | image_encoder.to(cfg.DEVICE) 139 | netG.to(cfg.DEVICE) 140 | if self.n_gpu > 1: 141 | netG = DataParallelPassThrough(netG, ) 142 | for i in range(len(netsD)): 143 | netsD[i].to(cfg.DEVICE) 144 | if self.n_gpu > 1: 145 | netsD[i] = DataParallelPassThrough(netsD[i], ) 146 | return [text_encoder, image_encoder, netG, netsD, epoch] 147 | 148 | def define_optimizers(self, netG, netsD): 149 | optimizersD = [] 150 | num_Ds = len(netsD) 151 | for i in range(num_Ds): 152 | opt = optim.Adam(netsD[i].parameters(), 153 | lr=cfg.TRAIN.DISCRIMINATOR_LR, 154 | betas=(0.5, 0.999)) 155 | optimizersD.append(opt) 156 | 157 | optimizerG = optim.Adam(netG.parameters(), 158 | lr=cfg.TRAIN.GENERATOR_LR, 159 | betas=(0.5, 0.999)) 160 | 161 | if self.resume: 162 | checkpoint_list = sorted([ckpt for ckpt in glob.glob(self.model_dir + "/" + '*.pth')]) 163 | latest_checkpoint = checkpoint_list[-1] 164 | state_dict = torch.load(latest_checkpoint, map_location=lambda storage, loc: storage) 165 | optimizerG.load_state_dict(state_dict["optimG"]) 166 | 167 | for i in range(len(netsD)): 168 | optimizersD[i].load_state_dict(state_dict["optimD"][i]) 169 | 170 | return optimizerG, optimizersD 171 | 172 | def prepare_labels(self): 173 | if cfg.TRAIN.OPTIMIZE_DATA_LOADING: 174 | batch_sizes = self.batch_size 175 | real_labels, fake_labels, match_labels = [], [], [] 176 | for batch_size in batch_sizes: 177 | real_labels.append(Variable(torch.FloatTensor(batch_size).fill_(1).to(cfg.DEVICE).detach())) 178 | fake_labels.append(Variable(torch.FloatTensor(batch_size).fill_(0).to(cfg.DEVICE).detach())) 179 | match_labels.append(Variable(torch.LongTensor(range(batch_size)).to(cfg.DEVICE).detach())) 180 | else: 181 | batch_size = self.batch_size[0] 182 | real_labels = Variable(torch.FloatTensor(batch_size).fill_(1).to(cfg.DEVICE).detach()) 183 | fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0).to(cfg.DEVICE).detach()) 184 | match_labels = Variable(torch.LongTensor(range(batch_size)).to(cfg.DEVICE).detach()) 185 | 186 | return real_labels, fake_labels, match_labels 187 | 188 | def save_model(self, netG, avg_param_G, netsD, optimG, optimsD, epoch, max_to_keep=5, interval=5): 189 | netDs_state_dicts = [] 190 | optimDs_state_dicts = [] 191 | for i in range(len(netsD)): 192 | netD = netsD[i] 
193 | optimD = optimsD[i] 194 | netDs_state_dicts.append(netD.state_dict()) 195 | optimDs_state_dicts.append(optimD.state_dict()) 196 | 197 | backup_para = copy_G_params(netG) 198 | load_params(netG, avg_param_G) 199 | checkpoint = { 200 | 'epoch': epoch, 201 | 'netG': netG.state_dict(), 202 | 'optimG': optimG.state_dict(), 203 | 'netD': netDs_state_dicts, 204 | 'optimD': optimDs_state_dicts} 205 | torch.save(checkpoint, "{}/checkpoint_{:04}.pth".format(self.model_dir, epoch)) 206 | logger.info('Save G/D models') 207 | 208 | load_params(netG, backup_para) 209 | 210 | if max_to_keep is not None and max_to_keep > 0: 211 | checkpoint_list_all = sorted([ckpt for ckpt in glob.glob(self.model_dir + "/" + '*.pth')]) 212 | checkpoint_list = [] 213 | checkpoint_list_tmp = [] 214 | 215 | for ckpt in checkpoint_list_all: 216 | ckpt_epoch = int(ckpt[-8:-4]) 217 | if ckpt_epoch % interval == 0: 218 | checkpoint_list.append(ckpt) 219 | else: 220 | checkpoint_list_tmp.append(ckpt) 221 | 222 | while len(checkpoint_list) > max_to_keep: 223 | os.remove(checkpoint_list[0]) 224 | checkpoint_list = checkpoint_list[1:] 225 | 226 | ckpt_tmp = len(checkpoint_list_tmp) 227 | for idx in range(ckpt_tmp-1): 228 | os.remove(checkpoint_list_tmp[idx]) 229 | 230 | def set_requires_grad_value(self, models_list, brequires): 231 | for i in range(len(models_list)): 232 | for p in models_list[i].parameters(): 233 | p.requires_grad = brequires 234 | 235 | def save_img_results(self, netG, noise, sent_emb, words_embs, mask, 236 | image_encoder, captions, cap_lens, 237 | gen_iterations, transf_matrices_inv, label_one_hot, local_noise, 238 | transf_matrices, max_objects, subset_idx, name='current'): 239 | # Save images 240 | inputs = (noise, local_noise, sent_emb, words_embs, mask, transf_matrices, transf_matrices_inv, 241 | label_one_hot, max_objects) 242 | fake_imgs, attention_maps, _, _ = netG(*inputs) 243 | for i in range(len(attention_maps)): 244 | if len(fake_imgs) > 1: 245 | img = fake_imgs[i + 1].detach().cpu() 246 | lr_img = fake_imgs[i].detach().cpu() 247 | else: 248 | img = fake_imgs[0].detach().cpu() 249 | lr_img = None 250 | attn_maps = attention_maps[i] 251 | att_sze = attn_maps.size(2) 252 | img_set, _ = build_super_images(img, captions, self.ixtoword, attn_maps, att_sze, lr_imgs=lr_img, 253 | batch_size=self.batch_size[0]) 254 | if img_set is not None: 255 | im = Image.fromarray(img_set) 256 | fullpath = '%s/G_%s_%d_%d.png' % (self.image_dir, name, gen_iterations, i) 257 | im.save(fullpath) 258 | 259 | # for i in range(len(netsD)): 260 | i = -1 261 | img = fake_imgs[i].detach() 262 | region_features, _ = image_encoder(img) 263 | att_sze = region_features.size(2) 264 | if cfg.TRAIN.OPTIMIZE_DATA_LOADING: 265 | _, _, att_maps = words_loss(region_features.detach(), words_embs.detach(), 266 | None, cap_lens, None, self.batch_size[subset_idx]) 267 | else: 268 | _, _, att_maps = words_loss(region_features.detach(), words_embs.detach(), 269 | None, cap_lens, None, self.batch_size[0]) 270 | img_set, _ = build_super_images(fake_imgs[i].detach().cpu(), captions, self.ixtoword, att_maps, att_sze) 271 | if img_set is not None: 272 | im = Image.fromarray(img_set) 273 | fullpath = '%s/D_%s_%d.png' % (self.image_dir, name, gen_iterations) 274 | im.save(fullpath) 275 | 276 | def train(self): 277 | torch.autograd.set_detect_anomaly(True) 278 | 279 | text_encoder, image_encoder, netG, netsD, start_epoch = self.build_models() 280 | avg_param_G = copy_G_params(netG) 281 | optimizerG, optimizersD = self.define_optimizers(netG, netsD) 
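        # avg_param_G keeps an exponential moving average of the generator weights (updated after every optimizerG.step() with decay 0.999); the averaged weights are temporarily swapped in when saving checkpoints and sample images.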
282 | real_labels, fake_labels, match_labels = self.prepare_labels() 283 | 284 | if cfg.TRAIN.OPTIMIZE_DATA_LOADING: 285 | batch_sizes = self.batch_size 286 | noise, local_noise, fixed_noise = [], [], [] 287 | for batch_size in batch_sizes: 288 | noise.append(Variable(torch.FloatTensor(batch_size, cfg.GAN.GLOBAL_Z_DIM)).to(cfg.DEVICE)) 289 | local_noise.append(Variable(torch.FloatTensor(batch_size, cfg.GAN.LOCAL_Z_DIM)).to(cfg.DEVICE)) 290 | fixed_noise.append(Variable(torch.FloatTensor(batch_size, cfg.GAN.GLOBAL_Z_DIM).normal_(0, 1)).to(cfg.DEVICE)) 291 | else: 292 | batch_size = self.batch_size[0] 293 | noise = Variable(torch.FloatTensor(batch_size, cfg.GAN.GLOBAL_Z_DIM)).to(cfg.DEVICE) 294 | local_noise = Variable(torch.FloatTensor(batch_size, cfg.GAN.LOCAL_Z_DIM)).to(cfg.DEVICE) 295 | fixed_noise = Variable(torch.FloatTensor(batch_size, cfg.GAN.GLOBAL_Z_DIM).normal_(0, 1)).to(cfg.DEVICE) 296 | 297 | for epoch in range(start_epoch, self.max_epoch): 298 | logger.info("Epoch nb: %s" % epoch) 299 | gen_iterations = 0 300 | if cfg.TRAIN.OPTIMIZE_DATA_LOADING: 301 | data_iter = [] 302 | for _idx in range(len(self.data_loader)): 303 | data_iter.append(iter(self.data_loader[_idx])) 304 | total_batches_left = sum([len(self.data_loader[i]) for i in range(len(self.data_loader))]) 305 | current_probability = [len(self.data_loader[i]) for i in range(len(self.data_loader))] 306 | current_probability_percent = [current_probability[i] / float(total_batches_left) for i in 307 | range(len(current_probability))] 308 | else: 309 | data_iter = iter(self.data_loader) 310 | 311 | _dataset = tqdm(range(self.num_batches)) 312 | for step in _dataset: 313 | ###################################################### 314 | # (1) Prepare training data and Compute text embeddings 315 | ###################################################### 316 | if cfg.TRAIN.OPTIMIZE_DATA_LOADING: 317 | subset_idx = np.random.choice(range(len(self.data_loader)), size=None, 318 | p=current_probability_percent) 319 | total_batches_left -= 1 320 | if total_batches_left > 0: 321 | current_probability[subset_idx] -= 1 322 | current_probability_percent = [current_probability[i] / float(total_batches_left) for i in 323 | range(len(current_probability))] 324 | 325 | max_objects = subset_idx 326 | data = data_iter[subset_idx].next() 327 | else: 328 | data = data_iter.next() 329 | max_objects = 3 330 | _dataset.set_description('Obj-{}'.format(max_objects)) 331 | 332 | imgs, captions, cap_lens, class_ids, keys, transformation_matrices, label_one_hot = prepare_data(data) 333 | transf_matrices = transformation_matrices[0] 334 | transf_matrices_inv = transformation_matrices[1] 335 | 336 | with torch.no_grad(): 337 | if cfg.TRAIN.OPTIMIZE_DATA_LOADING: 338 | hidden = text_encoder.init_hidden(batch_sizes[subset_idx]) 339 | else: 340 | hidden = text_encoder.init_hidden(batch_size) 341 | # words_embs: batch_size x nef x seq_len 342 | # sent_emb: batch_size x nef 343 | words_embs, sent_emb = text_encoder(captions, cap_lens, hidden) 344 | words_embs, sent_emb = words_embs.detach(), sent_emb.detach() 345 | mask = (captions == 0).bool() 346 | num_words = words_embs.size(2) 347 | if mask.size(1) > num_words: 348 | mask = mask[:, :num_words] 349 | 350 | ####################################################### 351 | # (2) Generate fake images 352 | ###################################################### 353 | if cfg.TRAIN.OPTIMIZE_DATA_LOADING: 354 | noise[subset_idx].data.normal_(0, 1) 355 | local_noise[subset_idx].data.normal_(0, 1) 356 | inputs = 
(noise[subset_idx], local_noise[subset_idx], sent_emb, words_embs, mask, transf_matrices, 357 | transf_matrices_inv, label_one_hot, max_objects) 358 | else: 359 | noise.data.normal_(0, 1) 360 | local_noise.data.normal_(0, 1) 361 | inputs = (noise, local_noise, sent_emb, words_embs, mask, transf_matrices, transf_matrices_inv, 362 | label_one_hot, max_objects) 363 | 364 | inputs = tuple((inp.to(cfg.DEVICE) if isinstance(inp, torch.Tensor) else inp) for inp in inputs) 365 | fake_imgs, _, mu, logvar = netG(*inputs) 366 | 367 | ####################################################### 368 | # (3) Update D network 369 | ###################################################### 370 | # errD_total = 0 371 | D_logs = '' 372 | for i in range(len(netsD)): 373 | netsD[i].zero_grad() 374 | if cfg.TRAIN.OPTIMIZE_DATA_LOADING: 375 | errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i], 376 | sent_emb, real_labels[subset_idx], fake_labels[subset_idx], 377 | local_labels=label_one_hot, transf_matrices=transf_matrices, 378 | transf_matrices_inv=transf_matrices_inv, cfg=cfg, 379 | max_objects=max_objects) 380 | else: 381 | errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i], 382 | sent_emb, real_labels, fake_labels, 383 | local_labels=label_one_hot, transf_matrices=transf_matrices, 384 | transf_matrices_inv=transf_matrices_inv, cfg=cfg, 385 | max_objects=max_objects) 386 | 387 | # backward and update parameters 388 | errD.backward() 389 | optimizersD[i].step() 390 | D_logs += 'errD%d: %.2f ' % (i, errD.item()) 391 | 392 | ####################################################### 393 | # (4) Update G network: maximize log(D(G(z))) 394 | ###################################################### 395 | # compute total loss for training G 396 | # step += 1 397 | gen_iterations += 1 398 | 399 | # do not need to compute gradient for Ds 400 | netG.zero_grad() 401 | if cfg.TRAIN.OPTIMIZE_DATA_LOADING: 402 | errG_total = \ 403 | generator_loss(netsD, image_encoder, fake_imgs, real_labels[subset_idx], 404 | words_embs, sent_emb, match_labels[subset_idx], cap_lens, class_ids, 405 | local_labels=label_one_hot, transf_matrices=transf_matrices, 406 | transf_matrices_inv=transf_matrices_inv, max_objects=max_objects) 407 | else: 408 | errG_total = \ 409 | generator_loss(netsD, image_encoder, fake_imgs, real_labels, 410 | words_embs, sent_emb, match_labels, cap_lens, class_ids, 411 | local_labels=label_one_hot, transf_matrices=transf_matrices, 412 | transf_matrices_inv=transf_matrices_inv, max_objects=max_objects) 413 | kl_loss = KL_loss(mu, logvar) 414 | errG_total += kl_loss 415 | # backward and update parameters 416 | errG_total.backward() 417 | optimizerG.step() 418 | for p, avg_p in zip(netG.parameters(), avg_param_G): 419 | avg_p.mul_(0.999).add_(p.data, alpha=0.001) 420 | 421 | if cfg.TRAIN.EMPTY_CACHE: 422 | torch.cuda.empty_cache() 423 | 424 | # save images 425 | if ( 426 | 2 * gen_iterations == self.num_batches 427 | or 2 * gen_iterations + 1 == self.num_batches 428 | or gen_iterations + 1 == self.num_batches 429 | ): 430 | logger.info('Saving images...') 431 | backup_para = copy_G_params(netG) 432 | load_params(netG, avg_param_G) 433 | if cfg.TRAIN.OPTIMIZE_DATA_LOADING: 434 | self.save_img_results(netG, fixed_noise[subset_idx], sent_emb, 435 | words_embs, mask, image_encoder, 436 | captions, cap_lens, epoch, transf_matrices_inv, 437 | label_one_hot, local_noise[subset_idx], transf_matrices, 438 | max_objects, subset_idx, name='average') 439 | else: 440 | self.save_img_results(netG, fixed_noise, sent_emb, 441 | 
words_embs, mask, image_encoder, 442 | captions, cap_lens, epoch, transf_matrices_inv, 443 | label_one_hot, local_noise, transf_matrices, 444 | max_objects, None, name='average') 445 | load_params(netG, backup_para) 446 | 447 | self.save_model(netG, avg_param_G, netsD, optimizerG, optimizersD, epoch) 448 | self.save_model(netG, avg_param_G, netsD, optimizerG, optimizersD, epoch) 449 | 450 | def sampling(self, split_dir, num_samples=30000): 451 | if cfg.TRAIN.NET_G == '': 452 | logger.error('Error: the path for morels is not found!') 453 | else: 454 | if split_dir == 'test': 455 | split_dir = 'valid' 456 | # Build and load the generator 457 | if cfg.GAN.B_DCGAN: 458 | netG = G_DCGAN() 459 | else: 460 | netG = G_NET() 461 | netG.apply(weights_init) 462 | netG.to(cfg.DEVICE) 463 | netG.eval() 464 | # 465 | text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM) 466 | state_dict = torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage) 467 | text_encoder.load_state_dict(state_dict) 468 | text_encoder = text_encoder.to(cfg.DEVICE) 469 | text_encoder.eval() 470 | logger.info('Loaded text encoder from: %s', cfg.TRAIN.NET_E) 471 | 472 | batch_size = self.batch_size[0] 473 | nz = cfg.GAN.GLOBAL_Z_DIM 474 | noise = Variable(torch.FloatTensor(batch_size, nz)).to(cfg.DEVICE) 475 | local_noise = Variable(torch.FloatTensor(batch_size, cfg.GAN.LOCAL_Z_DIM)).to(cfg.DEVICE) 476 | 477 | model_dir = cfg.TRAIN.NET_G 478 | state_dict = torch.load(model_dir, map_location=lambda storage, loc: storage) 479 | netG.load_state_dict(state_dict["netG"]) 480 | max_objects = 10 481 | logger.info('Load G from: %s', model_dir) 482 | 483 | # the path to save generated images 484 | s_tmp = model_dir[:model_dir.rfind('.pth')].split("/")[-1] 485 | save_dir = '%s/%s/%s' % ("../output", s_tmp, split_dir) 486 | mkdir_p(save_dir) 487 | logger.info("Saving images to: {}".format(save_dir)) 488 | 489 | number_batches = num_samples // batch_size 490 | if number_batches < 1: 491 | number_batches = 1 492 | 493 | data_iter = iter(self.data_loader) 494 | 495 | for step in tqdm(range(number_batches)): 496 | data = data_iter.next() 497 | 498 | imgs, captions, cap_lens, class_ids, keys, transformation_matrices, label_one_hot, _ = prepare_data( 499 | data, eval=True) 500 | 501 | transf_matrices = transformation_matrices[0] 502 | transf_matrices_inv = transformation_matrices[1] 503 | 504 | hidden = text_encoder.init_hidden(batch_size) 505 | # words_embs: batch_size x nef x seq_len 506 | # sent_emb: batch_size x nef 507 | words_embs, sent_emb = text_encoder(captions, cap_lens, hidden) 508 | words_embs, sent_emb = words_embs.detach(), sent_emb.detach() 509 | mask = (captions == 0) 510 | num_words = words_embs.size(2) 511 | if mask.size(1) > num_words: 512 | mask = mask[:, :num_words] 513 | 514 | ####################################################### 515 | # (2) Generate fake images 516 | ###################################################### 517 | noise.data.normal_(0, 1) 518 | local_noise.data.normal_(0, 1) 519 | inputs = (noise, local_noise, sent_emb, words_embs, mask, transf_matrices, transf_matrices_inv, label_one_hot, max_objects) 520 | inputs = tuple((inp.to(cfg.DEVICE) if isinstance(inp, torch.Tensor) else inp) for inp in inputs) 521 | 522 | with torch.no_grad(): 523 | fake_imgs, _, mu, logvar = netG(*inputs) 524 | for batch_idx, j in enumerate(range(batch_size)): 525 | s_tmp = '%s/%s' % (save_dir, keys[j]) 526 | folder = s_tmp[:s_tmp.rfind('/')] 527 | if not os.path.isdir(folder): 528 | logger.info('Make 
a new folder: %s', folder) 529 | mkdir_p(folder) 530 | k = -1 531 | # for k in range(len(fake_imgs)): 532 | im = fake_imgs[k][j].data.cpu().numpy() 533 | # [-1, 1] --> [0, 255] 534 | im = (im + 1.0) * 127.5 535 | im = im.astype(np.uint8) 536 | im = np.transpose(im, (1, 2, 0)) 537 | im = Image.fromarray(im) 538 | fullpath = '%s_s%d.png' % (s_tmp, step*batch_size+batch_idx) 539 | im.save(fullpath) 540 | -------------------------------------------------------------------------------- /OP-GAN/data/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /OP-GAN/environment.yml: -------------------------------------------------------------------------------- 1 | name: tsenv 2 | channels: 3 | - defaults 4 | - conda-forge 5 | - pytorch 6 | dependencies: 7 | - cudatoolkit=11.0 8 | - python=3.8.5 9 | - pytorch 10 | - torchvision 11 | - tqdm 12 | - easydict 13 | - nltk 14 | - scikit-image 15 | - python-dateutil 16 | - pip: -------------------------------------------------------------------------------- /OP-GAN/models/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /OP-GAN/output/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /OP-GAN/sample.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GPU=$1 3 | export CUDA_VISIBLE_DEVICES=${GPU} 4 | export PYTHONUNBUFFERED=1 5 | if [ -z "$GPU" ] 6 | then 7 | echo "Starting training on CPU." 8 | else 9 | echo "Starting training on GPU ${GPU}." 10 | fi 11 | python3 -u code/main.py --cfg code/cfg/cfg_file_eval.yml 12 | echo "Done." 13 | -------------------------------------------------------------------------------- /OP-GAN/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GPU=$1 3 | export CUDA_VISIBLE_DEVICES=${GPU} 4 | export PYTHONUNBUFFERED=1 5 | if [ -z "$GPU" ] 6 | then 7 | echo "Starting training on CPU." 8 | else 9 | echo "Starting training on GPU ${GPU}." 10 | fi 11 | python3 -u code/main.py --cfg code/cfg/cfg_file_train.yml 12 | echo "Done." 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semantic Object Accuracy for Generative Text-to-Image Synthesis 2 | Code for our paper [Semantic Object Accuracy for Generative Text-to-Image Synthesis](https://ieeexplore.ieee.org/document/9184960) ([Arxiv Version](https://arxiv.org/abs/1910.13321)) published in TPAMI 2020. 3 | 4 | Summary in our [blog post](https://www.tobiashinz.com/2019/10/30/semantic-object-accuracy-for-generative-text-to-image-synthesis). 5 | 6 | Semantic Object Accuracy (SOA) is a score we introduce to evaluate the quality of generative text-to-image models. For this, we provide captions from the MS-COCO data set from which the evaluated model should generate images. 
We then use a pre-trained object detector to check whether the generated images contain the object that was specified in the caption. 7 | E.g. when an image is generated from the caption `a car is driving down the street` we check if the generated image actually contains a car. For more details check section 4 of our [paper](https://arxiv.org/abs/1910.13321). 8 | 9 | We also perform a user study in which humans rate the images generated by several state-of-the-art models trained on the MS-COCO dataset. 10 | We then compare the ranking obtained through our user study with the rankings obtained by different quantitative evaluation metrics. We show that popular metrics, such as the Inception Score, do not correlate with how humans rate the generated images, whereas SOA strongly correlates with human judgement. 11 | 12 | Contents: 13 | * [Calculate SOA Scores](#calculate-soa-scores-semantic-object-accuracy) 14 | * [Use Our Model (OP-GAN)](#use-our-model-op-gan) 15 | 16 | ## Calculate SOA Scores (Semantic Object Accuracy) 17 | 18 | How to calculate the SOA scores for a model: 19 | 20 | 1. Go to ``SOA``. The captions are in ``SOA/captions`` 21 | 1. each file is named ``label_XX_XX.pkl``, indicating which label the captions in the file belong to 22 | 2. load the file with pickle 23 | * ```python 24 | import pickle 25 | with open("label_XX_XX.pkl", "rb") as f: 26 | captions = pickle.load(f) 27 | ``` 28 | 3. each file is a list and each entry in the list is a dictionary containing information about the caption: 29 | * ```python 30 | [{'image_id': XX, 'id': XX, 'idx': [XX, XX], 'caption': u'XX'}, ...] 31 | ``` 32 | * where ``'idx': [XX, XX]`` gives the indices for the validation captions in the commonly used captions file from [AttnGAN](https://github.com/taoxugit/AttnGAN) 33 | 2. Use your model to generate images from the specified captions 34 | 35 | 1. each caption file contains the relevant captions for the given label 36 | 2. create a new and empty folder 37 | 3. use each caption file to generate images for each caption and save the images in a folder within the previously created empty folder, i.e. for each of the labels (0-79) there should be a new folder in the previously created folder and the folder structure should look like this 38 | * images 39 | * label_00 -> folder contains images generated from captions for label 0 40 | * label_01 -> folder contains images generated from captions for label 1 41 | * ... 42 | * label_79 -> folder contains images generated from captions for label 79 43 | 4. each new folder (that contains generated images) should contain the string "label_XX" somewhere in its name (make sure that integers are formatted to two digits, e.g. "00", "02", ...) -> ideally give the folders the same name as the label files 44 | 5. generate **three images for each caption** in each file 45 | * exception: for label "00" (person) randomly sample 30,000 captions and generate one image each for a total of 30,000 images 46 | 6. in the end you should have 80 folders in the folder created in step (2.ii), each folder should have the string "label_XX" in it for identification, and each folder should contain the generated images for this label 47 | 48 | 3. Once you have generated images for each label you can calculate the SOA scores: 49 | 1. Install requirements from ``SOA/requirements.txt`` (we use Python 3.5.2) 50 | 2. 
[download](https://www2.informatik.uni-hamburg.de/wtm/software/semantic-object-accuracy/yolov3.weights.tar.gz) the YOLOv3 weights file and save it as ``SOA/yolov3.weights`` 51 | 3. run ``python calculate_soa.py --images path/to/folder/created-in-step-2ii --output path/to/folder/where-results-are-saved --gpu 0`` 52 | 53 | 4. If you also want to calculate IoU values, check the detailed instructions [here](SOA/README.md) 54 | 5. Calculating the SOA scores takes about 30-45 minutes (tested with an NVIDIA GTX 1080 Ti) depending on your hardware (not including the time it takes to generate the images) 55 | 6. More detailed information (if needed) can be found [here](SOA/README.md) 56 | 57 | ## Use Our Model (OP-GAN) 58 | #### Dependencies 59 | - python 3.8.5 60 | - pytorch 1.7.1 61 | 62 | Go to ``OP-GAN``. 63 | Please add the project folder to PYTHONPATH and install the required dependencies: 64 | 65 | ``` 66 | conda env create -f environment.yml 67 | ``` 68 | 69 | #### Data 70 | - MS-COCO: 71 | - [download](https://www2.informatik.uni-hamburg.de/wtm/software/semantic-object-accuracy/data.tar.gz) our preprocessed data (bounding boxes, bounding box labels, preprocessed captions), save to `data/` and extract 72 | - the preprocessed captions are the same as those used in the [AttnGAN implementation](https://github.com/taoxugit/AttnGAN) 73 | - the bounding boxes used for evaluation at test time were generated with code from the [Obj-GAN](https://github.com/jamesli1618/Obj-GAN) 74 | - obtain the train and validation images from the 2014 split [here](http://cocodataset.org/#download), extract and save them in `data/train/` and `data/test/` 75 | - download the pre-trained DAMSM for COCO model from [here](https://github.com/taoxugit/AttnGAN), put it into `models/` and extract 76 | 77 | #### Training 78 | - to start training run `sh train.sh gpu-ids` where you choose which GPUs to train on 79 | - e.g. `sh train.sh 0,1,2,3` 80 | - training parameters can be adapted via `code/cfg/cfg_file_train.yml`; if you train on more/fewer GPUs or have more VRAM, adjust the batch sizes as needed 81 | - make sure the DATA_DIR in the respective `code/cfg/cfg_file_train.yml` points to the correct path 82 | - results are stored in `output/` 83 | 84 | #### Evaluating 85 | - update the eval cfg file in `code/cfg/cfg_file_eval.yml` and adapt the path of `NET_G` to point to the model you want to use (default path is to the pretrained model linked below) 86 | - run `sh sample.sh gpu-ids` to generate images using the specified model 87 | - e.g. `sh sample.sh 0` 88 | 89 | #### Pretrained Models 90 | - OP-GAN: [download](https://www2.informatik.uni-hamburg.de/wtm/software/semantic-object-accuracy/op-gan.pth) and save to `models` 91 | 92 | 93 | ## Acknowledgement 94 | - Code and preprocessed metadata for the experiments on MS-COCO are adapted from [AttnGAN](https://github.com/taoxugit/AttnGAN) and [AttnGAN+OP](https://github.com/tohinz/multiple-objects-gan). 95 | - Code to generate bounding boxes for evaluation at test time is from the [Obj-GAN](https://github.com/jamesli1618/Obj-GAN) implementation. 96 | - Code for using YOLOv3 is adapted from [here](https://pjreddie.com/darknet/), [here](https://github.com/eriklindernoren/PyTorch-YOLOv3), and [here](https://github.com/ayooshkathuria/pytorch-yolo-v3). 
97 | 98 | ## Citing 99 | If you find our model useful in your research, please consider citing: 100 | 101 | ``` 102 | @article{hinz2019semantic, 103 | title = {Semantic Object Accuracy for Generative Text-to-Image Synthesis}, 104 | author = {Tobias Hinz and Stefan Heinrich and Stefan Wermter}, 105 | journal = {arXiv preprint arXiv:1910.13321}, 106 | year = {2019}, 107 | } 108 | ``` 109 | -------------------------------------------------------------------------------- /SOA/README.md: -------------------------------------------------------------------------------- 1 | # Details for calculating the SOA Scores 2 | 3 | ## Work with the caption files 4 | To use the captions, load a caption file, iterate over the captions, and generate images: 5 | ```python 6 | import pickle 7 | import my_model 8 | 9 | # load the caption file 10 | with open("label_01_bicycle.pkl", "rb") as f: 11 | captions = pickle.load(f) 12 | 13 | # iterate over the captions and generate three images each 14 | for cap_idx, caption in enumerate(captions): 15 | current_caption = caption["caption"] 16 | for idx in range(3): 17 | my_generated_image = my_model(current_caption) 18 | save("images/label_01_bicycle/my_generated_image_{}_{}.png".format(cap_idx, idx)) 19 | ``` 20 | 21 | Alternatively, if you're working with the ``captions.pickle`` file from the [AttnGAN](https://github.com/taoxugit/AttnGAN) and their dataloader, you can use the provided ``idx`` to index the corresponding captions directly: 22 | ```python 23 | import pickle 24 | import my_model 25 | 26 | # load the AttnGAN captions file 27 | with open("captions.pickle", "rb") as f: 28 | attngan_captions = pickle.load(f) 29 | test_captions = attngan_captions[1] 30 | 31 | # load the caption file 32 | with open("label_01_bicycle.pkl", "rb") as f: 33 | captions = pickle.load(f) 34 | 35 | # iterate over the captions and generate three images each 36 | for caption in captions: 37 | current_caption_idx = caption["idx"] 38 | # new_ix is the index for the filenames 39 | new_ix = [current_caption_idx[0]] 40 | # new_sent_ix is the index to the exact caption, e.g. use it for 41 | # caps, cap_len = get_caption(new_sent_ix) 42 | new_sent_ix = [current_caption_idx[0]*5+current_caption_idx[1]] 43 | for idx in range(3): 44 | ... 45 | ``` 46 | 47 | For the file ``label_00_person.pkl`` we randomly sample 30,000 captions and generate one image each: 48 | ```python 49 | import pickle 50 | import random 51 | import my_model 52 | 53 | # load the caption file 54 | with open("label_00_person.pkl", "rb") as f: 55 | captions = pickle.load(f) 56 | 57 | caption_subset = random.sample(captions, 30000) 58 | 59 | # iterate over the sampled captions and generate one image each 60 | for idx, caption in enumerate(caption_subset): 61 | current_caption = caption["caption"] 62 | my_generated_image = my_model(current_caption) 63 | save("images/label_00_person/my_generated_image_{}.png".format(idx)) 64 | ``` 65 | 66 | ## Calculating IoU Scores 67 | For the IoU scores it is important that you use the same label mappings as we do (we use the standard mapping). Our labels can be found in ``data/coco.names`` where each label is mapped to the line it is on, i.e. ``person=0, bicycle=1, ...`` 68 | 69 | In order to calculate the IoU scores you need to save the "ground truth" information, i.e. the bounding boxes you give your model as input, so we can compare them with the bounding boxes from the detection network. 
70 | We expect the information about the bounding boxes as a pickle file which is a dictionary of the form 71 | ```python 72 | output_dict = {"name_of_the_generated_image": [[], [label_int], [bbox]], 73 | ...} 74 | # for example: 75 | output_dict = {"my_generated_image": [[], [1, 1], [[0.1, 0.1, 0.3, 0.5], [0.6, 0.2, 0.2, 0.4]]]} 76 | ``` 77 | Here, ``label_int`` is a list of the integer labels you use as conditioning (e.g. ``person=0, bicycle=1, ...``) and ``bbox`` 78 | is a list of the bounding boxes ``[x, y, width, height]`` where the values are normalized to be between ``[0,1]`` and the coordinate system starts at the top left corner of the image, i.e. a bounding box of ``[0, 0, 0.5, 0.5]`` covers the top left quarter of the image. 79 | The ``output_dict`` should be saved in the same folder as the images for which it was created. 80 | 81 | ```python 82 | import pickle 83 | import my_model 84 | 85 | # load the caption file 86 | with open("label_01_bicycle.pkl", "rb") as f: 87 | captions = pickle.load(f) 88 | 89 | # this is the dictionary we use to save the bounding boxes 90 | output_dict = {} 91 | 92 | # iterate over the captions and generate three images each 93 | for cap_idx, caption in enumerate(captions): 94 | current_caption = caption["caption"] 95 | for idx in range(3): 96 | my_generated_image = my_model(current_caption) 97 | save("images/label_01_bicycle/my_generated_image_{}_{}.png".format(cap_idx, idx)) 98 | # label_int is a list of the integer values for labels you used as input to the network 99 | # bbox is a list with the corresponding bounding boxes [x, y, width, height] 100 | # e.g. label_int = [1, 1] 101 | # bbox = [[0.1, 0.1, 0.3, 0.5], [0.6, 0.2, 0.2, 0.4]] 102 | output_dict["my_generated_image_{}_{}.png".format(cap_idx, idx)] = [[], label_int, bbox] 103 | 104 | with open("images/label_01_bicycle/ground_truth_label_01_bicycle.pkl", "wb") as f: 105 | pickle.dump(output_dict, f) 106 | ``` 107 | 108 | Finally, you should have the 80 folders with images as before, but now each folder should also contain a ``.pkl`` file with the ground truth information of the given layout. 
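If you want to sanity-check the boxes you save in this format, the IoU between two ``[x, y, width, height]`` boxes can be computed as in the following minimal sketch. This is for illustration only and is not the ``get_iou`` helper that ``calculate_soa.py`` itself uses; the function name ``iou_xywh`` is a placeholder.

```python
# Minimal sketch of IoU for two boxes in the [x, y, width, height] format described above
# (values normalized to [0, 1], origin at the top-left corner of the image).
# Illustration only; not the get_iou helper used by calculate_soa.py.
def iou_xywh(box_a, box_b):
    # corner coordinates of both boxes
    ax1, ay1, ax2, ay2 = box_a[0], box_a[1], box_a[0] + box_a[2], box_a[1] + box_a[3]
    bx1, by1, bx2, by2 = box_b[0], box_b[1], box_b[0] + box_b[2], box_b[1] + box_b[3]

    # width and height of the intersection rectangle (zero if the boxes do not overlap)
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    intersection = inter_w * inter_h

    union = box_a[2] * box_a[3] + box_b[2] * box_b[3] - intersection
    return intersection / union if union > 0 else 0.0

# a box covering the top left quarter vs. a box covering the left half of the image
print(iou_xywh([0.0, 0.0, 0.5, 0.5], [0.0, 0.0, 0.5, 1.0]))  # 0.5
```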
109 | Run the same command as before but with the ``--iou`` flag: ``python calculate_soa.py --images path/to/folder/created-in-first-step --output path/to/folder/where-results-are-saved --gpu 0 --iou`` 110 | -------------------------------------------------------------------------------- /SOA/calculate_soa.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | import numpy as np 7 | import argparse 8 | import os 9 | import pickle as pkl 10 | from tqdm import tqdm 11 | import glob 12 | import shutil 13 | 14 | from darknet import Darknet 15 | from dataset import YoloDataset 16 | from util import * 17 | 18 | 19 | def arg_parse(): 20 | """ 21 | Parse arguments to the detect module 22 | 23 | """ 24 | parser = argparse.ArgumentParser(description='YOLO v3 Detection Module') 25 | 26 | parser.add_argument("--images", dest='images', help="Image / Directory containing images to perform detection upon", 27 | type=str) 28 | parser.add_argument("--output", dest='output', help="Image / Directory to store detections to", 29 | default="output", type=str) 30 | parser.add_argument("--bs", dest="bs", help="Batch size", default=50) 31 | parser.add_argument("--confidence", dest="confidence", help="Object Confidence to filter predictions", default=0.5) 32 | parser.add_argument("--nms_thresh", dest="nms_thresh", help="NMS Threshold", default=0.4) 33 | parser.add_argument("--cfg", dest='cfgfile', help="Config file", default="cfg/yolov3.cfg", type=str) 34 | parser.add_argument("--weights", dest='weightsfile', help="weightsfile", default="yolov3.weights", type=str) 35 | parser.add_argument("--resolution", dest='resolution', default="256", type=str, 36 | help="Input resolution of the network. Increase to increase accuracy. 
Decrease to increase speed") 37 | parser.add_argument("--image_size", dest='image_size', help="Size of evaluated images", default=256, type=int) 38 | parser.add_argument('--iou', dest='iou', action='store_true') 39 | parser.add_argument('--gpu', dest='gpu', type=str, default="0") 40 | 41 | return parser.parse_args() 42 | 43 | 44 | def run_yolo(args): 45 | images = args.images 46 | batch_size = int(args.bs) 47 | confidence = float(args.confidence) 48 | nms_thresh = float(args.nms_thresh) 49 | img_size = args.image_size 50 | 51 | # check that the given folder contains exactly 80 folders 52 | _all_dirs = os.listdir(images) 53 | _num_folders = 0 54 | for _dir in _all_dirs: 55 | if os.path.isdir(os.path.join(images, _dir)): 56 | _num_folders += 1 57 | if _num_folders != 80: 58 | print("") 59 | print("****************************************************************************") 60 | print("\tWARNING") 61 | print("\tDid not find exactly 80 folders ({} folders found) in {}.".format(_num_folders, images)) 62 | print("\tFor the final calculation please make sure the folder {} contains one subfolder for each of the labels.".format(images)) 63 | print("\tCalculating scores on {}/80 labels now, but results will not be conclusive.".format(_num_folders)) 64 | print("****************************************************************************") 65 | 66 | if not os.path.exists(args.output): 67 | os.makedirs(args.output) 68 | 69 | CUDA = torch.cuda.is_available() 70 | 71 | classes = load_classes('data/coco.names') 72 | 73 | # Set up the neural network 74 | print("Loading network.....") 75 | model = Darknet(args.cfgfile) 76 | model.load_weights(args.weightsfile) 77 | print("Network successfully loaded") 78 | 79 | model.net_info["height"] = args.resolution 80 | inp_dim = int(model.net_info["height"]) 81 | assert inp_dim % 32 == 0 82 | assert inp_dim > 32 83 | 84 | # If there's a GPU available, put the model on GPU 85 | if CUDA: 86 | _gpu = int(args.gpu) 87 | torch.cuda.set_device(_gpu) 88 | model.cuda() 89 | print("Using GPU: {}".format(_gpu)) 90 | 91 | # Set the model in evaluation mode 92 | model.eval() 93 | print("saving to {}".format(args.output)) 94 | 95 | # go through all folders of generated images 96 | for dir in tqdm(os.listdir(images)): 97 | full_dir = os.path.join(images, dir) 98 | 99 | # check if there exists a ground truth file (which would contain bboxes etc to calculate IoU) 100 | ground_truth_file = [_file for _file in os.listdir(full_dir) if _file.endswith(".pkl")] 101 | if len(ground_truth_file) > 0: 102 | shutil.copyfile(os.path.join(full_dir, ground_truth_file[0]), 103 | os.path.join(args.output, "ground_truth_{}.pkl".format(dir))) 104 | 105 | # check if detection was already run for this label 106 | if os.path.isfile(os.path.join(args.output, "detected_{}.pkl".format(dir))): 107 | print("Detection already run for {}. 
Continuing with next label.".format(dir)) 108 | continue 109 | 110 | # create dataset from images in the current folder 111 | image_transform = transforms.Compose([ 112 | transforms.Resize((img_size, img_size)), 113 | transforms.ToTensor(), 114 | transforms.Normalize((0., 0., 0.), (1, 1, 1))]) 115 | dataset = YoloDataset(full_dir, transform=image_transform) 116 | assert dataset 117 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 118 | drop_last=False, shuffle=False, num_workers=4) 119 | 120 | num_batches = len(dataloader) 121 | dataloader = iter(dataloader) 122 | output_dict = {} 123 | 124 | # get YOLO predictions for images in current folder 125 | for idx in tqdm(range(num_batches)): 126 | data = dataloader.next() 127 | imgs, filenames = data 128 | if CUDA: 129 | imgs = imgs.cuda() 130 | 131 | with torch.no_grad(): 132 | predictions = model(imgs, CUDA) 133 | predictions = non_max_suppression(predictions, confidence, nms_thresh) 134 | 135 | for img, preds in zip(filenames, predictions): 136 | img_preds_name = [] 137 | img_preds_id = [] 138 | img_bboxs = [] 139 | if preds is not None and len(preds) > 0: 140 | for pred in preds: 141 | pred_id = int(pred[-1]) 142 | pred_name = classes[pred_id] 143 | 144 | bbox_x = pred[0] / img_size 145 | bbox_y = pred[1] / img_size 146 | bbox_width = (pred[2] - pred[0]) / img_size 147 | bbox_height = (pred[3] - pred[1]) / img_size 148 | 149 | img_preds_id.append(pred_id) 150 | img_preds_name.append(pred_name) 151 | img_bboxs.append([bbox_x.cpu().numpy(), bbox_y.cpu().numpy(), 152 | bbox_width.cpu().numpy(), bbox_height.cpu().numpy()]) 153 | output_dict[img.split("/")[-1]] = [img_preds_name, img_preds_id, img_bboxs] 154 | 155 | with open(os.path.join(args.output, "detected_{}.pkl".format(dir)), "wb") as f: 156 | pkl.dump(output_dict, f) 157 | 158 | 159 | def calc_recall(predicted_bbox, label): 160 | """Calculate how often a given object (label) was detected in the images""" 161 | correctly_recognized = 0 162 | num_images_total = len(predicted_bbox.keys()) 163 | for key in predicted_bbox.keys(): 164 | predictions = predicted_bbox[key] 165 | for recognized_label in predictions[1]: 166 | if recognized_label == label: 167 | correctly_recognized += 1 168 | break 169 | if num_images_total == 0: 170 | return 0, 0, 0 171 | accuracy = float(correctly_recognized) / num_images_total 172 | return accuracy, correctly_recognized, num_images_total 173 | 174 | 175 | def calc_iou(predicted_bbox, gt_bbox, label): 176 | """Calculate max IoU between correctly detected objects and provided ground truths for each image""" 177 | ious = [] 178 | 179 | # iterate over the predictions for all images 180 | for key in predicted_bbox.keys(): 181 | predicted_bboxes = [] 182 | # get predictions for the image 183 | predictions = predicted_bbox[key] 184 | # check if it recognized an object of the given label and if yes get its predicted bounding box 185 | # for all detected objects of the given label 186 | for recognized_label, pred_bbox in zip(predictions[1], predictions[2]): 187 | if recognized_label == label: 188 | predicted_bboxes.append(pred_bbox) 189 | 190 | gt_bboxes = [] 191 | # get the ground truth information of the current image 192 | gts = gt_bbox[key] 193 | 194 | if gts[1] is None or len(gts[1]) == 0: 195 | continue 196 | else: 197 | # gts should e.g. 
be [[], 198 | # [7, 1], -> integer values for the object labels 199 | # [[0.1, 0.2, 0.3, 0.5], [0.4, 0.3, 0.3, 0.5]] -> bounding boxes 200 | assert type(gts[1]) is list and type(gts[2]) is list,\ 201 | "Expected lists as entries of the ground truth bounding box file" 202 | for real_label, real_bbox in zip(gts[1], gts[2]): 203 | if real_label == label: 204 | assert all([_val >= 0 and _val <= 1 for _val in real_bbox]), \ 205 | "Bounding box entries should be between 0 and 1 but are: {}.".format(real_bbox) 206 | gt_bboxes.append(real_bbox) 207 | 208 | # calculate all IoUs between ground truth bounding boxes of the given label 209 | # and predicted bounding boxes of the given label 210 | all_current_ious = [] 211 | for current_predicted_bbox in predicted_bboxes: 212 | for current_gt_bbox in gt_bboxes: 213 | current_iou = get_iou(current_predicted_bbox, current_gt_bbox) 214 | all_current_ious.append(current_iou) 215 | # choose the maximum value as the IoU for this image 216 | if len(all_current_ious) > 0: 217 | ious.append(max(all_current_ious)) 218 | if len(ious) == 0: 219 | return 0.0 220 | avg_iou = sum(ious) / float(len(ious)) 221 | return avg_iou 222 | 223 | 224 | def calc_overall_class_average_accuracy(dict): 225 | """Calculate SOA-C""" 226 | accuracy = 0 227 | for label in dict.keys(): 228 | accuracy += dict[label]["accuracy"] 229 | overall_accuracy = accuracy / len(dict.keys()) 230 | return overall_accuracy 231 | 232 | 233 | def calc_image_weighted_average_accuracy(dict): 234 | """Calculate SOA-I""" 235 | accuracy = 0 236 | total_images = 0 237 | for label in dict.keys(): 238 | num_images = dict[label]["images_total"] 239 | accuracy += num_images * dict[label]["accuracy"] 240 | total_images += num_images 241 | overall_accuracy = accuracy / total_images 242 | return overall_accuracy 243 | 244 | 245 | def calc_split_class_average_accuracy(dict): 246 | """Calculate SOA-C-Top/Bot-40""" 247 | num_img_list = [] 248 | for label in dict.keys(): 249 | num_img_list.append([label, dict[label]["images_total"]]) 250 | num_img_list.sort(key=lambda x: x[1]) 251 | sorted_label_list = [x[0] for x in num_img_list] 252 | 253 | bottom_40_accuracy = 0 254 | top_40_accuracy = 0 255 | for label in dict.keys(): 256 | if sorted_label_list.index(label) < 40: 257 | bottom_40_accuracy += dict[label]["accuracy"] 258 | else: 259 | top_40_accuracy += dict[label]["accuracy"] 260 | bottom_40_accuracy /= 0.5*len(dict.keys()) 261 | top_40_accuracy /= 0.5*len(dict.keys()) 262 | 263 | return top_40_accuracy, bottom_40_accuracy 264 | 265 | 266 | def calc_overall_class_average_iou(dict): 267 | """Calculate SOA-C-IoU""" 268 | iou = 0 269 | for label in dict.keys(): 270 | if dict[label]["iou"] is not None and dict[label]["iou"] >= 0: 271 | iou += dict[label]["iou"] 272 | overall_iou = iou / len(dict.keys()) 273 | return overall_iou 274 | 275 | 276 | def calc_image_weighted_average_iou(dict): 277 | """Calculate SOA-I-IoU""" 278 | iou = 0 279 | total_images = 0 280 | for label in dict.keys(): 281 | num_images = dict[label]["images_total"] 282 | if dict[label]["iou"] is not None and dict[label]["iou"] >= 0: 283 | iou += num_images * dict[label]["iou"] 284 | total_images += num_images 285 | overall_iou = iou / total_images 286 | return overall_iou 287 | 288 | 289 | def calc_split_class_average_iou(dict): 290 | """Calculate SOA-C-IoU-Top/Bot-40""" 291 | num_img_list = [] 292 | for label in dict.keys(): 293 | num_img_list.append([label, dict[label]["images_total"]]) 294 | num_img_list.sort(key=lambda x: x[1]) 295 | sorted_label_list 
= [x[0] for x in num_img_list] 296 | 297 | bottom_40_iou = 0 298 | top_40_iou = 0 299 | for label in dict.keys(): 300 | if sorted_label_list.index(label) < 40: 301 | if dict[label]["iou"] is not None and dict[label]["iou"] >= 0: 302 | bottom_40_iou += dict[label]["iou"] 303 | else: 304 | if dict[label]["iou"] is not None and dict[label]["iou"] >= 0: 305 | top_40_iou += dict[label]["iou"] 306 | bottom_40_iou /= 0.5*len(dict.keys()) 307 | top_40_iou /= 0.5*len(dict.keys()) 308 | 309 | return top_40_iou, bottom_40_iou 310 | 311 | 312 | def calc_soa(args): 313 | """Calculate SOA scores""" 314 | results_dict = {} 315 | 316 | # find detection results 317 | yolo_detected_files = [os.path.join(args.output, _file) for _file in os.listdir(args.output) 318 | if _file.endswith(".pkl") and _file.startswith("detected_")] 319 | 320 | # go through yolo detection and check how often it detected the desired object (based on the label) 321 | for yolo_file in yolo_detected_files: 322 | yolo = load_file(yolo_file) 323 | label = get_label(yolo_file) 324 | acc, correctly_recog, num_imgs_total = calc_recall(yolo, label) 325 | 326 | results_dict[label] = {} 327 | results_dict[label]["accuracy"] = acc 328 | results_dict[label]["images_recognized"] = correctly_recog 329 | results_dict[label]["images_total"] = num_imgs_total 330 | 331 | # calculate SOA-C and SOA-I 332 | print("") 333 | class_average_acc = calc_overall_class_average_accuracy(results_dict) 334 | print("Class average accuracy for all classes (SOA-C) is: {:6.4f}".format(class_average_acc)) 335 | 336 | image_average_acc = calc_image_weighted_average_accuracy(results_dict) 337 | print("Image weighted average accuracy (SOA-I) is: {:6.4f}".format(image_average_acc)) 338 | 339 | top_40_class_average_acc, bottom_40_class_average_acc = calc_split_class_average_accuracy(results_dict) 340 | print("Top (SOA-C-Top40) and Bottom (SOA-C-Bot40) 40 class average accuracy is: {:6.4f} and {:6.4f}". 341 | format(top_40_class_average_acc, bottom_40_class_average_acc)) 342 | 343 | # if IoU is true calculate the IoU scores, too 344 | if args.iou: 345 | ground_truth_files = [os.path.join(args.output, _file) for _file in os.listdir(args.output) 346 | if _file.endswith(".pkl") and _file.startswith("ground_truth_")] 347 | 348 | yolo_detected_files = sorted(yolo_detected_files) 349 | ground_truth_files = sorted(ground_truth_files) 350 | 351 | for yolo_file, gt_file in zip(yolo_detected_files, ground_truth_files): 352 | yolo = load_file(yolo_file) 353 | gt = load_file(gt_file) 354 | label = get_label(yolo_file) 355 | iou = calc_iou(yolo, gt, label) 356 | 357 | results_dict[label]["iou"] = iou 358 | 359 | print("") 360 | class_average_iou = calc_overall_class_average_iou(results_dict) 361 | print("Class average IoU for all classes (SOA-C-IoU) is: {:6.4f}".format(class_average_iou)) 362 | 363 | image_average_iou = calc_image_weighted_average_iou(results_dict) 364 | print("Image weighted average IoU (SOA-I-IoU) is: {:6.4f}".format(image_average_iou)) 365 | 366 | top_40_class_average_iou, bottom_40_class_average_iou = calc_split_class_average_iou(results_dict) 367 | print("Top (SOA-C-Top40-IoU) and Bottom (SOA-C-Bot40-IoU) 40 class average IoU is: {:6.4f} and {:6.4f}". 
368 | format(top_40_class_average_iou, bottom_40_class_average_iou)) 369 | 370 | # store results 371 | with open(os.path.join(args.output, "result_file.pkl"), "wb") as f: 372 | pkl.dump(results_dict, f) 373 | 374 | 375 | if __name__ == '__main__': 376 | args = arg_parse() 377 | 378 | # use YOLOv3 on all images 379 | print("Using YOLOv3 Network on Generated Images...") 380 | run_yolo(args) 381 | 382 | # calculate score 383 | print("Calculating SOA Score...") 384 | calc_soa(args) 385 | 386 | -------------------------------------------------------------------------------- /SOA/captions/label_00_person.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_00_person.pkl -------------------------------------------------------------------------------- /SOA/captions/label_01_bicycle.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_01_bicycle.pkl -------------------------------------------------------------------------------- /SOA/captions/label_02_car.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_02_car.pkl -------------------------------------------------------------------------------- /SOA/captions/label_03_motorcycle.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_03_motorcycle.pkl -------------------------------------------------------------------------------- /SOA/captions/label_04_plane.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_04_plane.pkl -------------------------------------------------------------------------------- /SOA/captions/label_05_bus.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_05_bus.pkl -------------------------------------------------------------------------------- /SOA/captions/label_06_train.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_06_train.pkl -------------------------------------------------------------------------------- /SOA/captions/label_07_truck.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_07_truck.pkl 
-------------------------------------------------------------------------------- /SOA/captions/label_08_boat.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_08_boat.pkl -------------------------------------------------------------------------------- /SOA/captions/label_09_trafficlight.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_09_trafficlight.pkl -------------------------------------------------------------------------------- /SOA/captions/label_10_hydrant.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_10_hydrant.pkl -------------------------------------------------------------------------------- /SOA/captions/label_11_stopsign.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_11_stopsign.pkl -------------------------------------------------------------------------------- /SOA/captions/label_12_parkingmeter.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_12_parkingmeter.pkl -------------------------------------------------------------------------------- /SOA/captions/label_13_bench.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_13_bench.pkl -------------------------------------------------------------------------------- /SOA/captions/label_14_bird.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_14_bird.pkl -------------------------------------------------------------------------------- /SOA/captions/label_15_cat.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_15_cat.pkl -------------------------------------------------------------------------------- /SOA/captions/label_16_dog.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_16_dog.pkl -------------------------------------------------------------------------------- /SOA/captions/label_17_horse.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_17_horse.pkl -------------------------------------------------------------------------------- /SOA/captions/label_18_sheep.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_18_sheep.pkl -------------------------------------------------------------------------------- /SOA/captions/label_19_cow.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_19_cow.pkl -------------------------------------------------------------------------------- /SOA/captions/label_20_elephant.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_20_elephant.pkl -------------------------------------------------------------------------------- /SOA/captions/label_21_bear.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_21_bear.pkl -------------------------------------------------------------------------------- /SOA/captions/label_22_zebra.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_22_zebra.pkl -------------------------------------------------------------------------------- /SOA/captions/label_23_giraffe.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_23_giraffe.pkl -------------------------------------------------------------------------------- /SOA/captions/label_24_backpack.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_24_backpack.pkl -------------------------------------------------------------------------------- /SOA/captions/label_25_umbrella.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_25_umbrella.pkl -------------------------------------------------------------------------------- /SOA/captions/label_26_handbag.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_26_handbag.pkl -------------------------------------------------------------------------------- /SOA/captions/label_27_tie.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_27_tie.pkl -------------------------------------------------------------------------------- /SOA/captions/label_28_suitcase.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_28_suitcase.pkl -------------------------------------------------------------------------------- /SOA/captions/label_29_frisbee.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_29_frisbee.pkl -------------------------------------------------------------------------------- /SOA/captions/label_30_skis.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_30_skis.pkl -------------------------------------------------------------------------------- /SOA/captions/label_31_snowboard.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_31_snowboard.pkl -------------------------------------------------------------------------------- /SOA/captions/label_32_ball.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_32_ball.pkl -------------------------------------------------------------------------------- /SOA/captions/label_33_kite.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_33_kite.pkl -------------------------------------------------------------------------------- /SOA/captions/label_34_baseballbat.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_34_baseballbat.pkl -------------------------------------------------------------------------------- /SOA/captions/label_35_baseballglove.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_35_baseballglove.pkl -------------------------------------------------------------------------------- /SOA/captions/label_36_skateboard.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_36_skateboard.pkl -------------------------------------------------------------------------------- /SOA/captions/label_37_surfboard.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_37_surfboard.pkl -------------------------------------------------------------------------------- /SOA/captions/label_38_racket.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_38_racket.pkl -------------------------------------------------------------------------------- /SOA/captions/label_39_bottle.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_39_bottle.pkl -------------------------------------------------------------------------------- /SOA/captions/label_40_wineglass.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_40_wineglass.pkl -------------------------------------------------------------------------------- /SOA/captions/label_41_cup.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_41_cup.pkl -------------------------------------------------------------------------------- /SOA/captions/label_42_fork.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_42_fork.pkl -------------------------------------------------------------------------------- /SOA/captions/label_43_knife.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_43_knife.pkl -------------------------------------------------------------------------------- /SOA/captions/label_44_spoon.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_44_spoon.pkl -------------------------------------------------------------------------------- /SOA/captions/label_45_bowl.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_45_bowl.pkl -------------------------------------------------------------------------------- /SOA/captions/label_46_banana.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_46_banana.pkl -------------------------------------------------------------------------------- /SOA/captions/label_47_apple.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_47_apple.pkl -------------------------------------------------------------------------------- /SOA/captions/label_48_sandwich.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_48_sandwich.pkl -------------------------------------------------------------------------------- /SOA/captions/label_49_oranges.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_49_oranges.pkl -------------------------------------------------------------------------------- /SOA/captions/label_50_broccoli.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_50_broccoli.pkl -------------------------------------------------------------------------------- /SOA/captions/label_51_carrot.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_51_carrot.pkl -------------------------------------------------------------------------------- /SOA/captions/label_52_hotdog.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_52_hotdog.pkl -------------------------------------------------------------------------------- /SOA/captions/label_53_pizza.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_53_pizza.pkl -------------------------------------------------------------------------------- /SOA/captions/label_54_donut.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_54_donut.pkl -------------------------------------------------------------------------------- /SOA/captions/label_55_cake.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_55_cake.pkl -------------------------------------------------------------------------------- /SOA/captions/label_56_chair.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_56_chair.pkl -------------------------------------------------------------------------------- /SOA/captions/label_57_sofa.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_57_sofa.pkl -------------------------------------------------------------------------------- /SOA/captions/label_58_pottedplant.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_58_pottedplant.pkl -------------------------------------------------------------------------------- /SOA/captions/label_59_bed.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_59_bed.pkl -------------------------------------------------------------------------------- /SOA/captions/label_60_table.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_60_table.pkl -------------------------------------------------------------------------------- /SOA/captions/label_61_toilet.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_61_toilet.pkl -------------------------------------------------------------------------------- /SOA/captions/label_62_monitor.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_62_monitor.pkl 
-------------------------------------------------------------------------------- /SOA/captions/label_63_laptop.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_63_laptop.pkl -------------------------------------------------------------------------------- /SOA/captions/label_64_computermouse.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_64_computermouse.pkl -------------------------------------------------------------------------------- /SOA/captions/label_65_remote.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_65_remote.pkl -------------------------------------------------------------------------------- /SOA/captions/label_66_keyboard.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_66_keyboard.pkl -------------------------------------------------------------------------------- /SOA/captions/label_67_cellphone.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_67_cellphone.pkl -------------------------------------------------------------------------------- /SOA/captions/label_68_microwave.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_68_microwave.pkl -------------------------------------------------------------------------------- /SOA/captions/label_69_oven.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_69_oven.pkl -------------------------------------------------------------------------------- /SOA/captions/label_70_toaster.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_70_toaster.pkl -------------------------------------------------------------------------------- /SOA/captions/label_71_sink.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_71_sink.pkl -------------------------------------------------------------------------------- /SOA/captions/label_72_refrigerator.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_72_refrigerator.pkl -------------------------------------------------------------------------------- /SOA/captions/label_73_book.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_73_book.pkl -------------------------------------------------------------------------------- /SOA/captions/label_74_clock.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_74_clock.pkl -------------------------------------------------------------------------------- /SOA/captions/label_75_vase.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_75_vase.pkl -------------------------------------------------------------------------------- /SOA/captions/label_76_scissor.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_76_scissor.pkl -------------------------------------------------------------------------------- /SOA/captions/label_77_teddybear.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_77_teddybear.pkl -------------------------------------------------------------------------------- /SOA/captions/label_78_hairdrier.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_78_hairdrier.pkl -------------------------------------------------------------------------------- /SOA/captions/label_79_toothbrush.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tohinz/semantic-object-accuracy-for-generative-text-to-image-synthesis/bd00ed59fe6e3621b6baa977b74838c6b09d7d62/SOA/captions/label_79_toothbrush.pkl -------------------------------------------------------------------------------- /SOA/cfg/yolov3.cfg: -------------------------------------------------------------------------------- 1 | [net] 2 | # Testing 3 | batch=1 4 | subdivisions=1 5 | # Training 6 | # batch=64 7 | # subdivisions=16 8 | width= 320 9 | height = 320 10 | channels=3 11 | momentum=0.9 12 | decay=0.0005 13 | angle=0 14 | saturation = 1.5 15 | exposure = 1.5 16 | hue=.1 17 | 18 | learning_rate=0.001 19 | burn_in=1000 20 | max_batches = 500200 21 | policy=steps 22 | steps=400000,450000 23 | scales=.1,.1 24 | 25 | [convolutional] 26 | batch_normalize=1 27 | filters=32 28 
| size=3 29 | stride=1 30 | pad=1 31 | activation=leaky 32 | 33 | # Downsample 34 | 35 | [convolutional] 36 | batch_normalize=1 37 | filters=64 38 | size=3 39 | stride=2 40 | pad=1 41 | activation=leaky 42 | 43 | [convolutional] 44 | batch_normalize=1 45 | filters=32 46 | size=1 47 | stride=1 48 | pad=1 49 | activation=leaky 50 | 51 | [convolutional] 52 | batch_normalize=1 53 | filters=64 54 | size=3 55 | stride=1 56 | pad=1 57 | activation=leaky 58 | 59 | [shortcut] 60 | from=-3 61 | activation=linear 62 | 63 | # Downsample 64 | 65 | [convolutional] 66 | batch_normalize=1 67 | filters=128 68 | size=3 69 | stride=2 70 | pad=1 71 | activation=leaky 72 | 73 | [convolutional] 74 | batch_normalize=1 75 | filters=64 76 | size=1 77 | stride=1 78 | pad=1 79 | activation=leaky 80 | 81 | [convolutional] 82 | batch_normalize=1 83 | filters=128 84 | size=3 85 | stride=1 86 | pad=1 87 | activation=leaky 88 | 89 | [shortcut] 90 | from=-3 91 | activation=linear 92 | 93 | [convolutional] 94 | batch_normalize=1 95 | filters=64 96 | size=1 97 | stride=1 98 | pad=1 99 | activation=leaky 100 | 101 | [convolutional] 102 | batch_normalize=1 103 | filters=128 104 | size=3 105 | stride=1 106 | pad=1 107 | activation=leaky 108 | 109 | [shortcut] 110 | from=-3 111 | activation=linear 112 | 113 | # Downsample 114 | 115 | [convolutional] 116 | batch_normalize=1 117 | filters=256 118 | size=3 119 | stride=2 120 | pad=1 121 | activation=leaky 122 | 123 | [convolutional] 124 | batch_normalize=1 125 | filters=128 126 | size=1 127 | stride=1 128 | pad=1 129 | activation=leaky 130 | 131 | [convolutional] 132 | batch_normalize=1 133 | filters=256 134 | size=3 135 | stride=1 136 | pad=1 137 | activation=leaky 138 | 139 | [shortcut] 140 | from=-3 141 | activation=linear 142 | 143 | [convolutional] 144 | batch_normalize=1 145 | filters=128 146 | size=1 147 | stride=1 148 | pad=1 149 | activation=leaky 150 | 151 | [convolutional] 152 | batch_normalize=1 153 | filters=256 154 | size=3 155 | stride=1 156 | pad=1 157 | activation=leaky 158 | 159 | [shortcut] 160 | from=-3 161 | activation=linear 162 | 163 | [convolutional] 164 | batch_normalize=1 165 | filters=128 166 | size=1 167 | stride=1 168 | pad=1 169 | activation=leaky 170 | 171 | [convolutional] 172 | batch_normalize=1 173 | filters=256 174 | size=3 175 | stride=1 176 | pad=1 177 | activation=leaky 178 | 179 | [shortcut] 180 | from=-3 181 | activation=linear 182 | 183 | [convolutional] 184 | batch_normalize=1 185 | filters=128 186 | size=1 187 | stride=1 188 | pad=1 189 | activation=leaky 190 | 191 | [convolutional] 192 | batch_normalize=1 193 | filters=256 194 | size=3 195 | stride=1 196 | pad=1 197 | activation=leaky 198 | 199 | [shortcut] 200 | from=-3 201 | activation=linear 202 | 203 | 204 | [convolutional] 205 | batch_normalize=1 206 | filters=128 207 | size=1 208 | stride=1 209 | pad=1 210 | activation=leaky 211 | 212 | [convolutional] 213 | batch_normalize=1 214 | filters=256 215 | size=3 216 | stride=1 217 | pad=1 218 | activation=leaky 219 | 220 | [shortcut] 221 | from=-3 222 | activation=linear 223 | 224 | [convolutional] 225 | batch_normalize=1 226 | filters=128 227 | size=1 228 | stride=1 229 | pad=1 230 | activation=leaky 231 | 232 | [convolutional] 233 | batch_normalize=1 234 | filters=256 235 | size=3 236 | stride=1 237 | pad=1 238 | activation=leaky 239 | 240 | [shortcut] 241 | from=-3 242 | activation=linear 243 | 244 | [convolutional] 245 | batch_normalize=1 246 | filters=128 247 | size=1 248 | stride=1 249 | pad=1 250 | activation=leaky 251 | 252 | 
[convolutional] 253 | batch_normalize=1 254 | filters=256 255 | size=3 256 | stride=1 257 | pad=1 258 | activation=leaky 259 | 260 | [shortcut] 261 | from=-3 262 | activation=linear 263 | 264 | [convolutional] 265 | batch_normalize=1 266 | filters=128 267 | size=1 268 | stride=1 269 | pad=1 270 | activation=leaky 271 | 272 | [convolutional] 273 | batch_normalize=1 274 | filters=256 275 | size=3 276 | stride=1 277 | pad=1 278 | activation=leaky 279 | 280 | [shortcut] 281 | from=-3 282 | activation=linear 283 | 284 | # Downsample 285 | 286 | [convolutional] 287 | batch_normalize=1 288 | filters=512 289 | size=3 290 | stride=2 291 | pad=1 292 | activation=leaky 293 | 294 | [convolutional] 295 | batch_normalize=1 296 | filters=256 297 | size=1 298 | stride=1 299 | pad=1 300 | activation=leaky 301 | 302 | [convolutional] 303 | batch_normalize=1 304 | filters=512 305 | size=3 306 | stride=1 307 | pad=1 308 | activation=leaky 309 | 310 | [shortcut] 311 | from=-3 312 | activation=linear 313 | 314 | 315 | [convolutional] 316 | batch_normalize=1 317 | filters=256 318 | size=1 319 | stride=1 320 | pad=1 321 | activation=leaky 322 | 323 | [convolutional] 324 | batch_normalize=1 325 | filters=512 326 | size=3 327 | stride=1 328 | pad=1 329 | activation=leaky 330 | 331 | [shortcut] 332 | from=-3 333 | activation=linear 334 | 335 | 336 | [convolutional] 337 | batch_normalize=1 338 | filters=256 339 | size=1 340 | stride=1 341 | pad=1 342 | activation=leaky 343 | 344 | [convolutional] 345 | batch_normalize=1 346 | filters=512 347 | size=3 348 | stride=1 349 | pad=1 350 | activation=leaky 351 | 352 | [shortcut] 353 | from=-3 354 | activation=linear 355 | 356 | 357 | [convolutional] 358 | batch_normalize=1 359 | filters=256 360 | size=1 361 | stride=1 362 | pad=1 363 | activation=leaky 364 | 365 | [convolutional] 366 | batch_normalize=1 367 | filters=512 368 | size=3 369 | stride=1 370 | pad=1 371 | activation=leaky 372 | 373 | [shortcut] 374 | from=-3 375 | activation=linear 376 | 377 | [convolutional] 378 | batch_normalize=1 379 | filters=256 380 | size=1 381 | stride=1 382 | pad=1 383 | activation=leaky 384 | 385 | [convolutional] 386 | batch_normalize=1 387 | filters=512 388 | size=3 389 | stride=1 390 | pad=1 391 | activation=leaky 392 | 393 | [shortcut] 394 | from=-3 395 | activation=linear 396 | 397 | 398 | [convolutional] 399 | batch_normalize=1 400 | filters=256 401 | size=1 402 | stride=1 403 | pad=1 404 | activation=leaky 405 | 406 | [convolutional] 407 | batch_normalize=1 408 | filters=512 409 | size=3 410 | stride=1 411 | pad=1 412 | activation=leaky 413 | 414 | [shortcut] 415 | from=-3 416 | activation=linear 417 | 418 | 419 | [convolutional] 420 | batch_normalize=1 421 | filters=256 422 | size=1 423 | stride=1 424 | pad=1 425 | activation=leaky 426 | 427 | [convolutional] 428 | batch_normalize=1 429 | filters=512 430 | size=3 431 | stride=1 432 | pad=1 433 | activation=leaky 434 | 435 | [shortcut] 436 | from=-3 437 | activation=linear 438 | 439 | [convolutional] 440 | batch_normalize=1 441 | filters=256 442 | size=1 443 | stride=1 444 | pad=1 445 | activation=leaky 446 | 447 | [convolutional] 448 | batch_normalize=1 449 | filters=512 450 | size=3 451 | stride=1 452 | pad=1 453 | activation=leaky 454 | 455 | [shortcut] 456 | from=-3 457 | activation=linear 458 | 459 | # Downsample 460 | 461 | [convolutional] 462 | batch_normalize=1 463 | filters=1024 464 | size=3 465 | stride=2 466 | pad=1 467 | activation=leaky 468 | 469 | [convolutional] 470 | batch_normalize=1 471 | filters=512 472 | size=1 
473 | stride=1 474 | pad=1 475 | activation=leaky 476 | 477 | [convolutional] 478 | batch_normalize=1 479 | filters=1024 480 | size=3 481 | stride=1 482 | pad=1 483 | activation=leaky 484 | 485 | [shortcut] 486 | from=-3 487 | activation=linear 488 | 489 | [convolutional] 490 | batch_normalize=1 491 | filters=512 492 | size=1 493 | stride=1 494 | pad=1 495 | activation=leaky 496 | 497 | [convolutional] 498 | batch_normalize=1 499 | filters=1024 500 | size=3 501 | stride=1 502 | pad=1 503 | activation=leaky 504 | 505 | [shortcut] 506 | from=-3 507 | activation=linear 508 | 509 | [convolutional] 510 | batch_normalize=1 511 | filters=512 512 | size=1 513 | stride=1 514 | pad=1 515 | activation=leaky 516 | 517 | [convolutional] 518 | batch_normalize=1 519 | filters=1024 520 | size=3 521 | stride=1 522 | pad=1 523 | activation=leaky 524 | 525 | [shortcut] 526 | from=-3 527 | activation=linear 528 | 529 | [convolutional] 530 | batch_normalize=1 531 | filters=512 532 | size=1 533 | stride=1 534 | pad=1 535 | activation=leaky 536 | 537 | [convolutional] 538 | batch_normalize=1 539 | filters=1024 540 | size=3 541 | stride=1 542 | pad=1 543 | activation=leaky 544 | 545 | [shortcut] 546 | from=-3 547 | activation=linear 548 | 549 | ###################### 550 | 551 | [convolutional] 552 | batch_normalize=1 553 | filters=512 554 | size=1 555 | stride=1 556 | pad=1 557 | activation=leaky 558 | 559 | [convolutional] 560 | batch_normalize=1 561 | size=3 562 | stride=1 563 | pad=1 564 | filters=1024 565 | activation=leaky 566 | 567 | [convolutional] 568 | batch_normalize=1 569 | filters=512 570 | size=1 571 | stride=1 572 | pad=1 573 | activation=leaky 574 | 575 | [convolutional] 576 | batch_normalize=1 577 | size=3 578 | stride=1 579 | pad=1 580 | filters=1024 581 | activation=leaky 582 | 583 | [convolutional] 584 | batch_normalize=1 585 | filters=512 586 | size=1 587 | stride=1 588 | pad=1 589 | activation=leaky 590 | 591 | [convolutional] 592 | batch_normalize=1 593 | size=3 594 | stride=1 595 | pad=1 596 | filters=1024 597 | activation=leaky 598 | 599 | [convolutional] 600 | size=1 601 | stride=1 602 | pad=1 603 | filters=255 604 | activation=linear 605 | 606 | 607 | [yolo] 608 | mask = 6,7,8 609 | anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326 610 | classes=80 611 | num=9 612 | jitter=.3 613 | ignore_thresh = .5 614 | truth_thresh = 1 615 | random=1 616 | 617 | 618 | [route] 619 | layers = -4 620 | 621 | [convolutional] 622 | batch_normalize=1 623 | filters=256 624 | size=1 625 | stride=1 626 | pad=1 627 | activation=leaky 628 | 629 | [upsample] 630 | stride=2 631 | 632 | [route] 633 | layers = -1, 61 634 | 635 | 636 | 637 | [convolutional] 638 | batch_normalize=1 639 | filters=256 640 | size=1 641 | stride=1 642 | pad=1 643 | activation=leaky 644 | 645 | [convolutional] 646 | batch_normalize=1 647 | size=3 648 | stride=1 649 | pad=1 650 | filters=512 651 | activation=leaky 652 | 653 | [convolutional] 654 | batch_normalize=1 655 | filters=256 656 | size=1 657 | stride=1 658 | pad=1 659 | activation=leaky 660 | 661 | [convolutional] 662 | batch_normalize=1 663 | size=3 664 | stride=1 665 | pad=1 666 | filters=512 667 | activation=leaky 668 | 669 | [convolutional] 670 | batch_normalize=1 671 | filters=256 672 | size=1 673 | stride=1 674 | pad=1 675 | activation=leaky 676 | 677 | [convolutional] 678 | batch_normalize=1 679 | size=3 680 | stride=1 681 | pad=1 682 | filters=512 683 | activation=leaky 684 | 685 | [convolutional] 686 | size=1 687 | stride=1 688 | pad=1 689 | 
filters=255 690 | activation=linear 691 | 692 | 693 | [yolo] 694 | mask = 3,4,5 695 | anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326 696 | classes=80 697 | num=9 698 | jitter=.3 699 | ignore_thresh = .5 700 | truth_thresh = 1 701 | random=1 702 | 703 | 704 | 705 | [route] 706 | layers = -4 707 | 708 | [convolutional] 709 | batch_normalize=1 710 | filters=128 711 | size=1 712 | stride=1 713 | pad=1 714 | activation=leaky 715 | 716 | [upsample] 717 | stride=2 718 | 719 | [route] 720 | layers = -1, 36 721 | 722 | 723 | 724 | [convolutional] 725 | batch_normalize=1 726 | filters=128 727 | size=1 728 | stride=1 729 | pad=1 730 | activation=leaky 731 | 732 | [convolutional] 733 | batch_normalize=1 734 | size=3 735 | stride=1 736 | pad=1 737 | filters=256 738 | activation=leaky 739 | 740 | [convolutional] 741 | batch_normalize=1 742 | filters=128 743 | size=1 744 | stride=1 745 | pad=1 746 | activation=leaky 747 | 748 | [convolutional] 749 | batch_normalize=1 750 | size=3 751 | stride=1 752 | pad=1 753 | filters=256 754 | activation=leaky 755 | 756 | [convolutional] 757 | batch_normalize=1 758 | filters=128 759 | size=1 760 | stride=1 761 | pad=1 762 | activation=leaky 763 | 764 | [convolutional] 765 | batch_normalize=1 766 | size=3 767 | stride=1 768 | pad=1 769 | filters=256 770 | activation=leaky 771 | 772 | [convolutional] 773 | size=1 774 | stride=1 775 | pad=1 776 | filters=255 777 | activation=linear 778 | 779 | 780 | [yolo] 781 | mask = 0,1,2 782 | anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326 783 | classes=80 784 | num=9 785 | jitter=.3 786 | ignore_thresh = .5 787 | truth_thresh = 1 788 | random=1 789 | 790 | -------------------------------------------------------------------------------- /SOA/darknet.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | import numpy as np 8 | import pickle as pkl 9 | 10 | from util import predict_transform 11 | 12 | 13 | def parse_cfg(cfgfile): 14 | """ 15 | Takes a configuration file 16 | 17 | Returns a list of blocks. Each blocks describes a block in the neural 18 | network to be built. 
Block is represented as a dictionary in the list 19 | 20 | """ 21 | file = open(cfgfile, 'r') 22 | lines = file.read().split('\n') #store the lines in a list 23 | lines = [x for x in lines if len(x) > 0] #get read of the empty lines 24 | lines = [x for x in lines if x[0] != '#'] 25 | lines = [x.rstrip().lstrip() for x in lines] 26 | 27 | block = {} 28 | blocks = [] 29 | 30 | for line in lines: 31 | if line[0] == "[": #This marks the start of a new block 32 | if len(block) != 0: 33 | blocks.append(block) 34 | block = {} 35 | block["type"] = line[1:-1].rstrip() 36 | else: 37 | key,value = line.split("=") 38 | block[key.rstrip()] = value.lstrip() 39 | blocks.append(block) 40 | 41 | return blocks 42 | 43 | 44 | class MaxPoolStride1(nn.Module): 45 | def __init__(self, kernel_size): 46 | super(MaxPoolStride1, self).__init__() 47 | self.kernel_size = kernel_size 48 | self.pad = kernel_size - 1 49 | 50 | def forward(self, x): 51 | padded_x = F.pad(x, (0,self.pad,0,self.pad), mode="replicate") 52 | pooled_x = nn.MaxPool2d(self.kernel_size, self.pad)(padded_x) 53 | return pooled_x 54 | 55 | 56 | class EmptyLayer(nn.Module): 57 | def __init__(self): 58 | super(EmptyLayer, self).__init__() 59 | 60 | 61 | class DetectionLayer(nn.Module): 62 | def __init__(self, anchors): 63 | super(DetectionLayer, self).__init__() 64 | self.anchors = anchors 65 | 66 | def forward(self, x, inp_dim, num_classes, confidence): 67 | x = x.data 68 | global CUDA 69 | prediction = x 70 | prediction = predict_transform(prediction, inp_dim, self.anchors, num_classes, confidence, CUDA) 71 | return prediction 72 | 73 | 74 | class Upsample(nn.Module): 75 | def __init__(self, stride=2): 76 | super(Upsample, self).__init__() 77 | self.stride = stride 78 | 79 | def forward(self, x): 80 | stride = self.stride 81 | assert(x.data.dim() == 4) 82 | B = x.data.size(0) 83 | C = x.data.size(1) 84 | H = x.data.size(2) 85 | W = x.data.size(3) 86 | ws = stride 87 | hs = stride 88 | x = x.view(B, C, H, 1, W, 1).expand(B, C, H, stride, W, stride).contiguous().view(B, C, H*stride, W*stride) 89 | return x 90 | 91 | 92 | class ReOrgLayer(nn.Module): 93 | def __init__(self, stride = 2): 94 | super(ReOrgLayer, self).__init__() 95 | self.stride= stride 96 | 97 | def forward(self,x): 98 | assert(x.data.dim() == 4) 99 | B,C,H,W = x.data.shape 100 | hs = self.stride 101 | ws = self.stride 102 | assert(H % hs == 0), "The stride " + str(self.stride) + " is not a proper divisor of height " + str(H) 103 | assert(W % ws == 0), "The stride " + str(self.stride) + " is not a proper divisor of height " + str(W) 104 | x = x.view(B,C, H // hs, hs, W // ws, ws).transpose(-2,-3).contiguous() 105 | x = x.view(B,C, H // hs * W // ws, hs, ws) 106 | x = x.view(B,C, H // hs * W // ws, hs*ws).transpose(-1,-2).contiguous() 107 | x = x.view(B, C, ws*hs, H // ws, W // ws).transpose(1,2).contiguous() 108 | x = x.view(B, C*ws*hs, H // ws, W // ws) 109 | return x 110 | 111 | 112 | def create_modules(blocks): 113 | net_info = blocks[0] #Captures the information about the input and pre-processing 114 | 115 | module_list = nn.ModuleList() 116 | 117 | index = 0 #indexing blocks helps with implementing route layers (skip connections) 118 | 119 | prev_filters = 3 120 | 121 | output_filters = [] 122 | 123 | for x in blocks: 124 | module = nn.Sequential() 125 | 126 | if (x["type"] == "net"): 127 | continue 128 | 129 | #If it's a convolutional layer 130 | if (x["type"] == "convolutional"): 131 | #Get the info about the layer 132 | activation = x["activation"] 133 | try: 134 | 
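# note: cfg blocks without a batch_normalize key (e.g. the filters=255 convolutions
# right before each [yolo] section) raise KeyError here, so they keep their conv bias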
batch_normalize = int(x["batch_normalize"]) 135 | bias = False 136 | except: 137 | batch_normalize = 0 138 | bias = True 139 | 140 | filters= int(x["filters"]) 141 | padding = int(x["pad"]) 142 | kernel_size = int(x["size"]) 143 | stride = int(x["stride"]) 144 | 145 | if padding: 146 | pad = (kernel_size - 1) // 2 147 | else: 148 | pad = 0 149 | 150 | #Add the convolutional layer 151 | conv = nn.Conv2d(prev_filters, filters, kernel_size, stride, pad, bias = bias) 152 | module.add_module("conv_{0}".format(index), conv) 153 | 154 | #Add the Batch Norm Layer 155 | if batch_normalize: 156 | bn = nn.BatchNorm2d(filters) 157 | module.add_module("batch_norm_{0}".format(index), bn) 158 | 159 | #Check the activation. 160 | #It is either Linear or a Leaky ReLU for YOLO 161 | if activation == "leaky": 162 | activn = nn.LeakyReLU(0.1, inplace = True) 163 | module.add_module("leaky_{0}".format(index), activn) 164 | 165 | #If it's an upsampling layer 166 | #We use Bilinear2dUpsampling 167 | 168 | elif (x["type"] == "upsample"): 169 | stride = int(x["stride"]) 170 | # upsample = Upsample(stride) 171 | upsample = nn.Upsample(scale_factor = 2, mode = "nearest") 172 | module.add_module("upsample_{}".format(index), upsample) 173 | 174 | #If it is a route layer 175 | elif (x["type"] == "route"): 176 | x["layers"] = x["layers"].split(',') 177 | 178 | #Start of a route 179 | start = int(x["layers"][0]) 180 | 181 | #end, if there exists one. 182 | try: 183 | end = int(x["layers"][1]) 184 | except: 185 | end = 0 186 | 187 | #Positive anotation 188 | if start > 0: 189 | start = start - index 190 | 191 | if end > 0: 192 | end = end - index 193 | 194 | route = EmptyLayer() 195 | module.add_module("route_{0}".format(index), route) 196 | 197 | if end < 0: 198 | filters = output_filters[index + start] + output_filters[index + end] 199 | else: 200 | filters= output_filters[index + start] 201 | 202 | #shortcut corresponds to skip connection 203 | elif x["type"] == "shortcut": 204 | from_ = int(x["from"]) 205 | shortcut = EmptyLayer() 206 | module.add_module("shortcut_{}".format(index), shortcut) 207 | 208 | elif x["type"] == "maxpool": 209 | stride = int(x["stride"]) 210 | size = int(x["size"]) 211 | if stride != 1: 212 | maxpool = nn.MaxPool2d(size, stride) 213 | else: 214 | maxpool = MaxPoolStride1(size) 215 | 216 | module.add_module("maxpool_{}".format(index), maxpool) 217 | 218 | #Yolo is the detection layer 219 | elif x["type"] == "yolo": 220 | mask = x["mask"].split(",") 221 | mask = [int(x) for x in mask] 222 | 223 | anchors = x["anchors"].split(",") 224 | anchors = [int(a) for a in anchors] 225 | anchors = [(anchors[i], anchors[i+1]) for i in range(0, len(anchors),2)] 226 | anchors = [anchors[i] for i in mask] 227 | 228 | detection = DetectionLayer(anchors) 229 | module.add_module("Detection_{}".format(index), detection) 230 | else: 231 | print("Something I dunno") 232 | assert False 233 | 234 | module_list.append(module) 235 | prev_filters = filters 236 | output_filters.append(filters) 237 | index += 1 238 | 239 | return (net_info, module_list) 240 | 241 | 242 | class Darknet(nn.Module): 243 | def __init__(self, cfgfile): 244 | super(Darknet, self).__init__() 245 | self.blocks = parse_cfg(cfgfile) 246 | self.net_info, self.module_list = create_modules(self.blocks) 247 | self.header = torch.IntTensor([0,0,0,0]) 248 | self.seen = 0 249 | 250 | def get_blocks(self): 251 | return self.blocks 252 | 253 | def get_module_list(self): 254 | return self.module_list 255 | 256 | def forward(self, x, CUDA): 257 | detections 
= [] 258 | modules = self.blocks[1:] 259 | outputs = {} #We cache the outputs for the route layer 260 | 261 | write = 0 262 | for i in range(len(modules)): 263 | 264 | module_type = (modules[i]["type"]) 265 | if module_type == "convolutional" or module_type == "upsample" or module_type == "maxpool": 266 | 267 | x = self.module_list[i](x) 268 | outputs[i] = x 269 | 270 | elif module_type == "route": 271 | layers = modules[i]["layers"] 272 | layers = [int(a) for a in layers] 273 | 274 | if (layers[0]) > 0: 275 | layers[0] = layers[0] - i 276 | 277 | if len(layers) == 1: 278 | x = outputs[i + (layers[0])] 279 | 280 | else: 281 | if (layers[1]) > 0: 282 | layers[1] = layers[1] - i 283 | 284 | map1 = outputs[i + layers[0]] 285 | map2 = outputs[i + layers[1]] 286 | 287 | x = torch.cat((map1, map2), 1) 288 | outputs[i] = x 289 | elif module_type == "shortcut": 290 | from_ = int(modules[i]["from"]) 291 | x = outputs[i-1] + outputs[i+from_] 292 | outputs[i] = x 293 | elif module_type == 'yolo': 294 | 295 | anchors = self.module_list[i][0].anchors 296 | #Get the input dimensions 297 | inp_dim = int (self.net_info["height"]) 298 | 299 | #Get the number of classes 300 | num_classes = int (modules[i]["classes"]) 301 | 302 | #Output the result 303 | x = x.data 304 | x = predict_transform(x, inp_dim, anchors, num_classes, CUDA) 305 | 306 | if type(x) == int: 307 | continue 308 | 309 | if not write: 310 | detections = x 311 | write = 1 312 | 313 | else: 314 | detections = torch.cat((detections, x), 1) 315 | 316 | outputs[i] = outputs[i-1] 317 | 318 | try: 319 | return detections 320 | except: 321 | return 0 322 | 323 | def load_weights(self, weightfile): 324 | 325 | #Open the weights file 326 | fp = open(weightfile, "rb") 327 | 328 | #The first 4 values are header information 329 | # 1. Major version number 330 | # 2. Minor Version Number 331 | # 3. Subversion number 332 | # 4. IMages seen 333 | header = np.fromfile(fp, dtype = np.int32, count=5) 334 | self.header = torch.from_numpy(header) 335 | self.seen = self.header[3] 336 | 337 | #The rest of the values are the weights 338 | # Let's load them up 339 | weights = np.fromfile(fp, dtype = np.float32) 340 | 341 | ptr = 0 342 | for i in range(len(self.module_list)): 343 | module_type = self.blocks[i + 1]["type"] 344 | 345 | if module_type == "convolutional": 346 | model = self.module_list[i] 347 | try: 348 | batch_normalize = int(self.blocks[i+1]["batch_normalize"]) 349 | except: 350 | batch_normalize = 0 351 | 352 | conv = model[0] 353 | 354 | if (batch_normalize): 355 | bn = model[1] 356 | 357 | #Get the number of weights of Batch Norm Layer 358 | num_bn_biases = bn.bias.numel() 359 | 360 | #Load the weights 361 | bn_biases = torch.from_numpy(weights[ptr:ptr + num_bn_biases]) 362 | ptr += num_bn_biases 363 | 364 | bn_weights = torch.from_numpy(weights[ptr: ptr + num_bn_biases]) 365 | ptr += num_bn_biases 366 | 367 | bn_running_mean = torch.from_numpy(weights[ptr: ptr + num_bn_biases]) 368 | ptr += num_bn_biases 369 | 370 | bn_running_var = torch.from_numpy(weights[ptr: ptr + num_bn_biases]) 371 | ptr += num_bn_biases 372 | 373 | #Cast the loaded weights into dims of model weights. 
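# (layout of the .weights file per batch-norm conv block: BN biases, BN scales,
# running means, and running variances are read above in that order; the convolution
# kernel weights follow and are read further below)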
374 | bn_biases = bn_biases.view_as(bn.bias.data) 375 | bn_weights = bn_weights.view_as(bn.weight.data) 376 | bn_running_mean = bn_running_mean.view_as(bn.running_mean) 377 | bn_running_var = bn_running_var.view_as(bn.running_var) 378 | 379 | #Copy the data to model 380 | bn.bias.data.copy_(bn_biases) 381 | bn.weight.data.copy_(bn_weights) 382 | bn.running_mean.copy_(bn_running_mean) 383 | bn.running_var.copy_(bn_running_var) 384 | 385 | else: 386 | #Number of biases 387 | num_biases = conv.bias.numel() 388 | 389 | #Load the weights 390 | conv_biases = torch.from_numpy(weights[ptr: ptr + num_biases]) 391 | ptr = ptr + num_biases 392 | 393 | #reshape the loaded weights according to the dims of the model weights 394 | conv_biases = conv_biases.view_as(conv.bias.data) 395 | 396 | #Finally copy the data 397 | conv.bias.data.copy_(conv_biases) 398 | 399 | 400 | #Let us load the weights for the Convolutional layers 401 | num_weights = conv.weight.numel() 402 | 403 | #Do the same as above for weights 404 | conv_weights = torch.from_numpy(weights[ptr:ptr+num_weights]) 405 | ptr = ptr + num_weights 406 | 407 | conv_weights = conv_weights.view_as(conv.weight.data) 408 | conv.weight.data.copy_(conv_weights) 409 | -------------------------------------------------------------------------------- /SOA/data/coco.names: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush 81 | -------------------------------------------------------------------------------- /SOA/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | import torch 7 | import torch.utils.data as data 8 | from torch.autograd import Variable 9 | import torchvision.transforms as transforms 10 | 11 | import os 12 | import os.path as osp 13 | from PIL import Image 14 | 15 | 16 | class YoloDataset(data.Dataset): 17 | def __init__(self, data_dir, transform=None, imsize=256.): 18 | self.transform = transform 19 | self.data_dir = data_dir 20 | self.imsize = imsize 21 | 22 | self.filenames = self.load_filenames(data_dir) 23 | 24 | def load_filenames(self, data_dir): 25 | filenames = [osp.join(data_dir, img) for img in os.listdir(data_dir) if os.path.splitext(img)[1] == '.png' 26 | or os.path.splitext(img)[1] == '.jpeg' or os.path.splitext(img)[1] == '.jpg'] 27 | return filenames 28 | 29 | def load_img(self, img_name): 30 | 
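# convert to 3-channel RGB so grayscale or RGBA samples match the channels=3 input expected by the detector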
img = Image.open(img_name).convert('RGB') 31 | if self.transform is not None: 32 | img = self.transform(img) 33 | 34 | return img 35 | 36 | def __getitem__(self, index): 37 | filename = self.filenames[index] 38 | img = self.load_img(filename) 39 | 40 | return img, filename 41 | 42 | def __len__(self): 43 | return len(self.filenames) 44 | -------------------------------------------------------------------------------- /SOA/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.17.3 2 | Pillow==6.2.1 3 | pkg-resources==0.0.0 4 | six==1.12.0 5 | torch==1.3.0 6 | torchvision==0.4.1 7 | tqdm==4.27.0 8 | -------------------------------------------------------------------------------- /SOA/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import pickle 7 | 8 | 9 | def load_classes(namesfile): 10 | fp = open(namesfile, "r") 11 | names = fp.read().split("\n")[:-1] 12 | return names 13 | 14 | 15 | def load_file(path): 16 | with open(path, "rb") as f: 17 | _file = pickle.load(f) 18 | return _file 19 | 20 | 21 | def get_label(path): 22 | idx = path.find("label_") 23 | try: 24 | label = int(path[idx+6:idx+8]) 25 | except: 26 | label = int(path[idx + 6:idx + 7]) 27 | return label 28 | 29 | 30 | def predict_transform(prediction, inp_dim, anchors, num_classes, CUDA=True): 31 | batch_size = prediction.size(0) 32 | stride = inp_dim // prediction.size(2) 33 | grid_size = inp_dim // stride 34 | bbox_attrs = 5 + num_classes 35 | num_anchors = len(anchors) 36 | 37 | anchors = [(a[0] / stride, a[1] / stride) for a in anchors] 38 | 39 | prediction = prediction.view(batch_size, bbox_attrs * num_anchors, grid_size * grid_size) 40 | prediction = prediction.transpose(1, 2).contiguous() 41 | prediction = prediction.view(batch_size, grid_size * grid_size * num_anchors, bbox_attrs) 42 | 43 | # Sigmoid the centre_X, centre_Y. 
and object confidencce 44 | prediction[:, :, 0] = torch.sigmoid(prediction[:, :, 0]) 45 | prediction[:, :, 1] = torch.sigmoid(prediction[:, :, 1]) 46 | prediction[:, :, 4] = torch.sigmoid(prediction[:, :, 4]) 47 | 48 | # Add the center offsets 49 | grid_len = np.arange(grid_size) 50 | a, b = np.meshgrid(grid_len, grid_len) 51 | 52 | x_offset = torch.FloatTensor(a).view(-1, 1) 53 | y_offset = torch.FloatTensor(b).view(-1, 1) 54 | 55 | if CUDA: 56 | x_offset = x_offset.cuda() 57 | y_offset = y_offset.cuda() 58 | 59 | x_y_offset = torch.cat((x_offset, y_offset), 1).repeat(1, num_anchors).view(-1, 2).unsqueeze(0) 60 | 61 | prediction[:, :, :2] += x_y_offset 62 | 63 | # log space transform height and the width 64 | anchors = torch.FloatTensor(anchors) 65 | 66 | if CUDA: 67 | anchors = anchors.cuda() 68 | 69 | anchors = anchors.repeat(grid_size * grid_size, 1).unsqueeze(0) 70 | prediction[:, :, 2:4] = torch.exp(prediction[:, :, 2:4]) * anchors 71 | 72 | # Softmax the class scores 73 | prediction[:, :, 5: 5 + num_classes] = torch.sigmoid((prediction[:, :, 5: 5 + num_classes])) 74 | 75 | prediction[:, :, :4] *= stride 76 | 77 | return prediction 78 | 79 | 80 | def xywh2xyxy(x): 81 | y = x.new(x.shape) 82 | y[..., 0] = x[..., 0] - x[..., 2] / 2 83 | y[..., 1] = x[..., 1] - x[..., 3] / 2 84 | y[..., 2] = x[..., 0] + x[..., 2] / 2 85 | y[..., 3] = x[..., 1] + x[..., 3] / 2 86 | return y 87 | 88 | 89 | def bbox_wh_iou(wh1, wh2): 90 | wh2 = wh2.t() 91 | w1, h1 = wh1[0], wh1[1] 92 | w2, h2 = wh2[0], wh2[1] 93 | inter_area = torch.min(w1, w2) * torch.min(h1, h2) 94 | union_area = (w1 * h1 + 1e-16) + w2 * h2 - inter_area 95 | return inter_area / union_area 96 | 97 | 98 | def bbox_iou(box1, box2, x1y1x2y2=True): 99 | """ 100 | Returns the IoU of two bounding boxes 101 | """ 102 | if not x1y1x2y2: 103 | # Transform from center and width to exact coordinates 104 | b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2 105 | b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2 106 | b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2 107 | b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2 108 | else: 109 | # Get the coordinates of bounding boxes 110 | b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3] 111 | b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3] 112 | 113 | # get the corrdinates of the intersection rectangle 114 | inter_rect_x1 = torch.max(b1_x1, b2_x1) 115 | inter_rect_y1 = torch.max(b1_y1, b2_y1) 116 | inter_rect_x2 = torch.min(b1_x2, b2_x2) 117 | inter_rect_y2 = torch.min(b1_y2, b2_y2) 118 | # Intersection area 119 | inter_area = torch.clamp(inter_rect_x2 - inter_rect_x1 + 1, min=0) * torch.clamp( 120 | inter_rect_y2 - inter_rect_y1 + 1, min=0 121 | ) 122 | # Union Area 123 | b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1) 124 | b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1) 125 | 126 | iou = inter_area / (b1_area + b2_area - inter_area + 1e-16) 127 | 128 | return iou 129 | 130 | 131 | def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4): 132 | """ 133 | Removes detections with lower object confidence score than 'conf_thres' and performs 134 | Non-Maximum Suppression to further filter detections. 
135 | Returns detections with shape: 136 | (x1, y1, x2, y2, object_conf, class_score, class_pred) 137 | """ 138 | 139 | # From (center x, center y, width, height) to (x1, y1, x2, y2) 140 | prediction[..., :4] = xywh2xyxy(prediction[..., :4]) 141 | output = [None for _ in range(len(prediction))] 142 | for image_i, image_pred in enumerate(prediction): 143 | # Filter out confidence scores below threshold 144 | image_pred = image_pred[image_pred[:, 4] >= conf_thres] 145 | # If none are remaining => process next image 146 | if not image_pred.size(0): 147 | continue 148 | # Object confidence times class confidence 149 | score = image_pred[:, 4] * image_pred[:, 5:].max(1)[0] 150 | # Sort by it 151 | image_pred = image_pred[(-score).argsort()] 152 | class_confs, class_preds = image_pred[:, 5:].max(1, keepdim=True) 153 | detections = torch.cat((image_pred[:, :5], class_confs.float(), class_preds.float()), 1) 154 | # Perform non-maximum suppression 155 | keep_boxes = [] 156 | while detections.size(0): 157 | large_overlap = bbox_iou(detections[0, :4].unsqueeze(0), detections[:, :4]) > nms_thres 158 | label_match = detections[0, -1] == detections[:, -1] 159 | # Indices of boxes with lower confidence scores, large IOUs and matching labels 160 | invalid = large_overlap & label_match 161 | weights = detections[invalid, 4:5] 162 | # Merge overlapping bboxes by order of confidence 163 | detections[0, :4] = (weights * detections[invalid, :4]).sum(0) / weights.sum() 164 | keep_boxes += [detections[0]] 165 | detections = detections[~invalid] 166 | if keep_boxes: 167 | output[image_i] = torch.stack(keep_boxes) 168 | 169 | return output 170 | 171 | 172 | def get_iou(bb1, bb2): 173 | """ 174 | Calculate the Intersection over Union (IoU) of two bounding boxes. 175 | 176 | Parameters 177 | ---------- 178 | bb1 : dict 179 | Keys: {'x1', 'x2', 'y1', 'y2'} 180 | The (x1, y1) position is at the top left corner, 181 | the (x2, y2) position is at the bottom right corner 182 | bb2 : dict 183 | Keys: {'x1', 'x2', 'y1', 'y2'} 184 | The (x, y) position is at the top left corner, 185 | the (x2, y2) position is at the bottom right corner 186 | 187 | Returns 188 | ------- 189 | float 190 | in [0, 1] 191 | """ 192 | bb1_x1 = int(bb1[0]*256) 193 | bb1_y1 = int(bb1[1]*256) 194 | bb1_x2 = bb1_x1+int(bb1[2]*256) 195 | bb1_y2 = bb1_y1+int(bb1[3]*256) 196 | 197 | bb1_x1 = 0 if bb1_x1 < 0 else bb1_x1 198 | bb1_y1 = 0 if bb1_y1 < 0 else bb1_y1 199 | bb1_x2 = 0 if bb1_x2 < 0 else bb1_x2 200 | bb1_y2 = 0 if bb1_y2 < 0 else bb1_y2 201 | 202 | 203 | bb2_x1 = int(bb2[0]*256) 204 | bb2_y1 = int(bb2[1]*256) 205 | bb2_x2 = bb2_x1 + int(bb2[2]*256) 206 | bb2_y2 = bb2_y1 + int(bb2[3]*256) 207 | 208 | bb2_x1 = 0 if bb2_x1 < 0 else bb2_x1 209 | bb2_y1 = 0 if bb2_y1 < 0 else bb2_y1 210 | bb2_x2 = 0 if bb2_x2 < 0 else bb2_x2 211 | bb2_y2 = 0 if bb2_y2 < 0 else bb2_y2 212 | 213 | if not bb1_x1 < bb1_x2: 214 | bb1_x2 = bb1_x1 + 1 215 | if not bb1_y1 < bb1_y2: 216 | bb1_y2 = bb1_y1 + 1 217 | if not bb2_x1 < bb2_x2: 218 | bb2_x2 = bb2_x1 +1 219 | if not bb2_y1 < bb2_y2: 220 | bb2_y2 = bb2_y1 + 1 221 | 222 | # determine the coordinates of the intersection rectangle 223 | x_left = max(bb1_x1, bb2_x1) 224 | y_top = max(bb1_y1, bb2_y1) 225 | x_right = min(bb1_x2, bb2_x2) 226 | y_bottom = min(bb1_y2, bb2_y2) 227 | 228 | if x_right < x_left or y_bottom < y_top: 229 | return 0.0 230 | 231 | # The intersection of two axis-aligned bounding boxes is always an 232 | # axis-aligned bounding box 233 | intersection_area = (x_right - x_left) * (y_bottom - y_top) 234 | 
235 | # compute the area of both AABBs 236 | bb1_area = (bb1_x2 - bb1_x1) * (bb1_y2 - bb1_y1) 237 | bb2_area = (bb2_x2 - bb2_x1) * (bb2_y2 - bb2_y1) 238 | 239 | # compute the intersection over union by taking the intersection 240 | # area and dividing it by the sum of prediction + ground-truth 241 | # areas - the intersection area 242 | iou = intersection_area / float(bb1_area + bb2_area - intersection_area) 243 | assert iou >= 0.0 244 | assert iou <= 1.0 245 | return iou --------------------------------------------------------------------------------
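For reference, a minimal sketch of how the pieces above fit together: a Darknet model built from cfg/yolov3.cfg, a YoloDataset over a folder of generated images, and the util.py helpers for class names and non-maximum suppression. This is not the official evaluation entry point (that is SOA/calculate_soa.py); the image folder name and batch size are placeholders, the 320x320 resize simply mirrors the width/height set in the cfg, and it assumes the working directory is SOA/ with yolov3.weights downloaded there.

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from darknet import Darknet
from dataset import YoloDataset
from util import load_classes, non_max_suppression

cuda = torch.cuda.is_available()

# build the detector from the cfg above and load the pretrained COCO weights
model = Darknet("cfg/yolov3.cfg")
model.load_weights("yolov3.weights")
model.eval()
if cuda:
    model = model.cuda()

# resize to the 320x320 input declared in cfg/yolov3.cfg
transform = transforms.Compose([
    transforms.Resize((320, 320)),
    transforms.ToTensor(),
])
dataset = YoloDataset("output/generated_images", transform=transform)  # placeholder folder
loader = DataLoader(dataset, batch_size=16, shuffle=False)

classes = load_classes("data/coco.names")

with torch.no_grad():
    for imgs, filenames in loader:
        if cuda:
            imgs = imgs.cuda()
        detections = model(imgs, cuda)  # (batch, num_boxes, 5 + 80)
        detections = non_max_suppression(detections, conf_thres=0.5, nms_thres=0.4)
        for fname, dets in zip(filenames, detections):
            if dets is None:
                print(fname, "-> no objects detected")
                continue
            # last column of each detection row is the predicted COCO class index
            names = sorted({classes[int(c)] for c in dets[:, -1]})
            print(fname, "->", names)

Turning these raw detections into the SOA scores themselves is handled by calculate_soa.py, presumably together with the get_label and get_iou helpers defined in util.py.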