├── README.md
├── checkpoint
│   └── coco
│       └── .gitkeep
├── coco.py
├── data
│   └── coco
│       └── .gitkeep
├── engine.py
├── gnn.py
├── logger.py
├── main.py
├── model
│   ├── embedding
│   │   ├── coco_bert_80x768.pkl
│   │   ├── coco_bert_80x768_ec.pkl
│   │   ├── coco_char2vec_80x300.pkl
│   │   ├── coco_char2vec_80x300_ec.pkl
│   │   ├── coco_fasttext_80x300.pkl
│   │   ├── coco_fasttext_80x300_ec.pkl
│   │   ├── coco_glove_word2vec_80x300.pkl
│   │   ├── coco_glove_word2vec_80x300_ec.pkl
│   │   ├── coco_roberta_80x768.pkl
│   │   └── coco_roberta_80x768_ec.pkl
│   └── topology
│       ├── coco_adj.pkl
│       ├── coco_bert_base_cosine_adj.pkl_emb
│       ├── coco_char2vec_cosine_adj.pkl_emb
│       ├── coco_glove_cosine_adj.pkl_emb
│       ├── coco_glove_word2vec.pkl
│       └── coco_mod.txt
├── models.py
├── neptune.txt
├── prepare.py
├── requirements.txt
└── util.py

/README.md:
--------------------------------------------------------------------------------
1 | # Modular Graph Transformer Networks (MGTN)
2 | This project implements multi-label image classification based on Modular Graph Transformer Networks (MGTN).
3 | 
4 | ### Requirements
5 | Please install the following packages:
6 | - numpy
7 | - pytorch (1.*)
8 | - torchnet
9 | - torchvision
10 | - tqdm
11 | - networkx
12 | 
13 | ### Download the best checkpoint
14 | checkpoint/coco/mgtn_final_86.9762.pth.tar ([Dropbox](https://www.dropbox.com/s/fr2286gwxsg80kq/mgtn_final_86.9762.pth.tar?dl=0))
15 | 
16 | ### Performance
17 | 
18 | | Method                  | mAP       | CP       | CR         | CF1       | OP       | OR        | OF1       |
19 | | ----------------------- | --------- | -------- | ---------- | --------- | -------- | --------- | --------- |
20 | | CNN-RNN                 | 61.2      | -        | -          | -         | -        | -         | -         |
21 | | SRN                     | 77.1      | 81.6     | 65.4       | 71.2      | 82.7     | 69.9      | 75.8      |
22 | | Baseline(ResNet101)     | 77.3      | 80.2     | 66.7       | 72.8      | 83.9     | 70.8      | 76.8      |
23 | | Multi-Evidence          | –         | 80.4     | 70.2       | 74.9      | 85.2     | 72.5      | 78.4      |
24 | | ML-GCN (2019)           | 82.4      | 84.4     | 71.4       | 77.4      | 85.8     | 74.5      | 79.8      |
25 | | ML-GCN (ResNeXt50 swsl) | 86.2      | 85.8     | 77.3       | 81.3      | 86.2     | 79.7      | 82.8      |
26 | | A-GCN                   | 83.1      | 84.7     | 72.3       | 78.0      | 85.6     | 75.5      | 80.3      |
27 | | KSSNet                  | 83.7      | 84.6     | 73.2       | 77.2      | 87.8     | 76.2      | 81.5      |
28 | | SGTN (Our**)            | 86.6      | 77.2     | **82.2**   | 79.6      | 76.0     | **82.6**  | 79.2      |
29 | | **MGTN (Base)**         | 86.9      | **89.4** | 74.5       | 81.3      | **90.9** | 76.3      | 83.0      |
30 | | **MGTN (Final)**        | **87.0**  | 86.1     | 77.9       | **81.8**  | 87.7     | 79.4      | **83.4**  |
31 | 
32 | ** SGTN (Our): https://github.com/ReML-AI/sgtn
33 | 
34 | ### MGTN on COCO
35 | 
36 | ```sh
37 | python main.py data/coco --image-size 448 --workers 8 --batch-size 32 --lr 0.03 --learning-rate-decay 0.1 --epoch_step 20 30 --embedding model/embedding/coco_glove_word2vec_80x300_ec.pkl --adj-strong-threshold 0.4 --adj-weak-threshold 0.2 --device_ids 0 1 2 3
38 | ```
39 | 
40 | ### How to cite this work?
41 | ```
42 | @inproceedings{Nguyen:AAAI:2021,
43 |   author = {Nguyen, Hoang D.
and Vu, Xuan-Son and Le, Duc-Trong}, 44 | title = {Modular Graph Transformer Networks for Multi-Label Image Classification}, 45 | booktitle = {Proceedings of the AAAI Conference on Artificial Intelligence}, 46 | series = {AAAI '21}, 47 | year = {2021}, 48 | publisher = {AAAI} 49 | } 50 | ``` 51 | 52 | 53 | 54 | ## Reference 55 | This project is based on the following implementations: 56 | 57 | - https://github.com/ReML-AI/sgtn 58 | - https://github.com/durandtibo/wildcat.pytorch 59 | - https://github.com/tkipf/pygcn 60 | - https://github.com/Megvii-Nanjing/ML_GCN/ 61 | - https://github.com/seongjunyun/Graph_Transformer_Networks 62 | 63 | 64 | -------------------------------------------------------------------------------- /checkpoint/coco/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReML-AI/MGTN/76bf9ea1f036eec2374576f1d7509f8a2c5dd065/checkpoint/coco/.gitkeep -------------------------------------------------------------------------------- /coco.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import json 3 | import os 4 | import subprocess 5 | from PIL import Image 6 | import numpy as np 7 | import torch 8 | import pickle 9 | from util import * 10 | 11 | urls = {'train_img':'http://images.cocodataset.org/zips/train2014.zip', 12 | 'val_img' : 'http://images.cocodataset.org/zips/val2014.zip', 13 | 'annotations':'http://images.cocodataset.org/annotations/annotations_trainval2014.zip'} 14 | 15 | def download_coco2014(root, phase): 16 | if not os.path.exists(root): 17 | os.makedirs(root) 18 | tmpdir = os.path.join(root, 'tmp/') 19 | data = os.path.join(root, 'data/') 20 | if not os.path.exists(data): 21 | os.makedirs(data) 22 | if not os.path.exists(tmpdir): 23 | os.makedirs(tmpdir) 24 | if phase == 'train': 25 | filename = 'train2014.zip' 26 | elif phase == 'val': 27 | filename = 'val2014.zip' 28 | cached_file = os.path.join(tmpdir, filename) 29 | if not os.path.exists(cached_file): 30 | print('Downloading: "{}" to {}\n'.format(urls[phase + '_img'], cached_file)) 31 | os.chdir(tmpdir) 32 | subprocess.call('wget ' + urls[phase + '_img'], shell=True) 33 | os.chdir(root) 34 | # extract file 35 | img_data = os.path.join(data, filename.split('.')[0]) 36 | if not os.path.exists(img_data): 37 | print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=data)) 38 | command = 'unzip {} -d {}'.format(cached_file,data) 39 | os.system(command) 40 | print('[dataset] Done!') 41 | 42 | # train/val images/annotations 43 | cached_file = os.path.join(tmpdir, 'annotations_trainval2014.zip') 44 | if not os.path.exists(cached_file): 45 | print('Downloading: "{}" to {}\n'.format(urls['annotations'], cached_file)) 46 | os.chdir(tmpdir) 47 | subprocess.Popen('wget ' + urls['annotations'], shell=True) 48 | os.chdir(root) 49 | annotations_data = os.path.join(data, 'annotations') 50 | if not os.path.exists(annotations_data): 51 | print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=data)) 52 | command = 'unzip {} -d {}'.format(cached_file, data) 53 | os.system(command) 54 | print('[annotation] Done!') 55 | 56 | anno = os.path.join(data, '{}_anno.json'.format(phase)) 57 | img_id = {} 58 | annotations_id = {} 59 | if not os.path.exists(anno): 60 | annotations_file = json.load(open(os.path.join(annotations_data, 'instances_{}2014.json'.format(phase)))) 61 | annotations = annotations_file['annotations'] 62 | 
category = annotations_file['categories'] 63 | category_id = {} 64 | for cat in category: 65 | category_id[cat['id']] = cat['name'] 66 | cat2idx = categoty_to_idx(sorted(category_id.values())) 67 | images = annotations_file['images'] 68 | for annotation in annotations: 69 | if annotation['image_id'] not in annotations_id: 70 | annotations_id[annotation['image_id']] = set() 71 | annotations_id[annotation['image_id']].add(cat2idx[category_id[annotation['category_id']]]) 72 | for img in images: 73 | if img['id'] not in annotations_id: 74 | continue 75 | if img['id'] not in img_id: 76 | img_id[img['id']] = {} 77 | img_id[img['id']]['file_name'] = img['file_name'] 78 | img_id[img['id']]['labels'] = list(annotations_id[img['id']]) 79 | anno_list = [] 80 | for k, v in img_id.items(): 81 | anno_list.append(v) 82 | json.dump(anno_list, open(anno, 'w')) 83 | if not os.path.exists(os.path.join(data, 'category.json')): 84 | json.dump(cat2idx, open(os.path.join(data, 'category.json'), 'w')) 85 | del img_id 86 | del anno_list 87 | del images 88 | del annotations_id 89 | del annotations 90 | del category 91 | del category_id 92 | print('[json] Done!') 93 | 94 | def categoty_to_idx(category): 95 | cat2idx = {} 96 | for cat in category: 97 | cat2idx[cat] = len(cat2idx) 98 | return cat2idx 99 | 100 | 101 | class COCO2014(data.Dataset): 102 | def __init__(self, root, transform=None, phase='train', emb_name=None): 103 | self.root = root 104 | self.phase = phase 105 | self.img_list = [] 106 | self.transform = transform 107 | download_coco2014(root, phase) 108 | self.get_anno() 109 | self.num_classes = len(self.cat2idx) 110 | 111 | with open(emb_name, 'rb') as f: 112 | self.emb = pickle.load(f) 113 | self.emb_name = emb_name 114 | 115 | def get_anno(self): 116 | list_path = os.path.join(self.root, 'data', '{}_anno.json'.format(self.phase)) 117 | self.img_list = json.load(open(list_path, 'r')) 118 | self.cat2idx = json.load(open(os.path.join(self.root, 'data', 'category.json'), 'r')) 119 | 120 | def __len__(self): 121 | return len(self.img_list) 122 | 123 | def __getitem__(self, index): 124 | item = self.img_list[index] 125 | return self.get(item) 126 | 127 | def get(self, item): 128 | filename = item['file_name'] 129 | labels = sorted(item['labels']) 130 | img_path = os.path.join(self.root, 'data', '{}2014'.format(self.phase), filename) 131 | img = Image.open(img_path).convert('RGB') 132 | if self.transform is not None: 133 | img = self.transform(img) 134 | target = np.zeros(self.num_classes, np.float32) - 1 135 | target[labels] = 1 136 | return (img, filename, self.emb), target 137 | -------------------------------------------------------------------------------- /data/coco/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReML-AI/MGTN/76bf9ea1f036eec2374576f1d7509f8a2c5dd065/data/coco/.gitkeep -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import time 4 | import torch.backends.cudnn as cudnn 5 | import torch.nn.parallel 6 | import torch.optim 7 | import torch.utils.data 8 | import torchnet as tnt 9 | import torchvision.transforms as transforms 10 | import torch.nn as nn 11 | from util import * 12 | from logger import * 13 | 14 | tqdm.monitor_interval = 0 15 | class MultiLabelEngine(object): 16 | def __init__(self, state={}): 17 | self.state = { 18 | 'use_gpu': 
torch.cuda.is_available(), 19 | 'image_size': 224, 20 | 'batch_size': 64, 21 | 'batch_size_test': 64, 22 | 'workers': 25, 23 | 'device_ids': None, 24 | 'evaluate': False, 25 | 'lr_decay': 0.1, 26 | 'start_epoch': 0, 27 | 'max_epochs': 300, 28 | 'epoch_step': [], 29 | 'difficult_examples': False, 30 | 'mlt': 0.999, 31 | 'use_pb': True, 32 | 'print_freq': 0, 33 | 'arch': '', 34 | 'resume': None, 35 | 'save_model_path': None, 36 | 'filename_previous_best': None, 37 | 'meter_loss': tnt.meter.AverageValueMeter(), 38 | 'batch_time': tnt.meter.AverageValueMeter(), 39 | 'data_time': tnt.meter.AverageValueMeter(), 40 | 'best_score': 0 41 | } 42 | self.state.update(state) 43 | self.state.setdefault('train_transform', transforms.Compose([ 44 | transforms.Resize((512, 512)), 45 | MultiScaleCrop(self.state['image_size'], scales=(1.0, 0.875, 0.75, 0.66, 0.5), max_distort=2), 46 | transforms.RandomHorizontalFlip(), 47 | transforms.ToTensor(), 48 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 49 | std=[0.229, 0.224, 0.225])])) 50 | self.state.setdefault('train_target_transform', None) 51 | self.state.setdefault('val_transform', transforms.Compose([ 52 | Warp(self.state['image_size']), 53 | transforms.ToTensor(), 54 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 55 | std=[0.229, 0.224, 0.225])])) 56 | self.state.setdefault('val_target_transform', None) 57 | self.state['ap_meter'] = AveragePrecisionMeter(self.state['difficult_examples']) 58 | 59 | def on_start_epoch(self, training, model, criterion, data_loader, optimizer=None, display=True): 60 | self.state['meter_loss'].reset() 61 | self.state['batch_time'].reset() 62 | self.state['data_time'].reset() 63 | self.state['ap_meter'].reset() 64 | 65 | def on_end_epoch(self, training, model, criterion, data_loader, optimizer=None, display=True): 66 | map = 100 * self.state['ap_meter'].value().mean() 67 | loss = self.state['meter_loss'].value()[0] 68 | OP, OR, OF1, CP, CR, CF1 = self.state['ap_meter'].overall() 69 | OP_k, OR_k, OF1_k, CP_k, CR_k, CF1_k = self.state['ap_meter'].overall_topk(3) 70 | if display: 71 | if training: 72 | print('Epoch: [{0}]\t Loss {loss:.4f}\t mAP {map:.3f}'.format(self.state['epoch'], loss=loss, map=map)) 73 | print('OP: {OP:.4f}\t OR: {OR:.4f}\t OF1: {OF1:.4f}\t CP: {CP:.4f}\t CR: {CR:.4f}\t CF1: {CF1:.4f}' 74 | .format(OP=OP, OR=OR, OF1=OF1, CP=CP, CR=CR, CF1=CF1)) 75 | send_log('epoch', self.state['epoch']) 76 | send_log('train_loss', loss) 77 | send_log('train_map', map) 78 | else: 79 | print('Test: \t Loss {loss:.4f}\t mAP {map:.3f}'.format(loss=loss, map=map)) 80 | print('OP: {OP:.4f}\t OR: {OR:.4f}\t OF1: {OF1:.4f}\t CP: {CP:.4f}\t CR: {CR:.4f}\t CF1: {CF1:.4f}' 81 | .format(OP=OP, OR=OR, OF1=OF1, CP=CP, CR=CR, CF1=CF1)) 82 | print('OP_3: {OP:.4f}\t OR_3: {OR:.4f}\t OF1_3: {OF1:.4f}\t CP_3: {CP:.4f}\t CR_3: {CR:.4f}\t CF1_3: {CF1:.4f}' 83 | .format(OP=OP_k, OR=OR_k, OF1=OF1_k, CP=CP_k, CR=CR_k, CF1=CF1_k)) 84 | send_log('test_loss', loss) 85 | send_log('test_map', map) 86 | return map 87 | 88 | def on_start_batch(self, training, model, criterion, data_loader, optimizer=None, display=True): 89 | self.state['target_gt'] = self.state['target'].clone() 90 | self.state['target'][self.state['target'] == 0] = 1 91 | self.state['target'][self.state['target'] == -1] = 0 92 | 93 | input = self.state['input'] 94 | self.state['img'] = input[0] 95 | self.state['name'] = input[1] 96 | 97 | 98 | def on_end_batch(self, training, model, criterion, data_loader, optimizer=None, display=True): 99 | # record loss 100 | 
self.state['loss_batch'] = self.state['loss'].item() 101 | self.state['meter_loss'].add(self.state['loss_batch']) 102 | 103 | # measure mAP 104 | self.state['ap_meter'].add(self.state['output'].data, self.state['target_gt']) 105 | 106 | if display and self.state['print_freq'] != 0 and self.state['iteration'] % self.state['print_freq'] == 0: 107 | loss = self.state['meter_loss'].value()[0] 108 | batch_time = self.state['batch_time'].value()[0] 109 | data_time = self.state['data_time'].value()[0] 110 | if training: 111 | print('Epoch: [{0}][{1}/{2}]\t' 112 | 'Time {batch_time_current:.3f} ({batch_time:.3f})\t' 113 | 'Data {data_time_current:.3f} ({data_time:.3f})\t' 114 | 'Loss {loss_current:.4f} ({loss:.4f})'.format( 115 | self.state['epoch'], self.state['iteration'], len(data_loader), 116 | batch_time_current=self.state['batch_time_current'], 117 | batch_time=batch_time, data_time_current=self.state['data_time_batch'], 118 | data_time=data_time, loss_current=self.state['loss_batch'], loss=loss)) 119 | else: 120 | print('Test: [{0}/{1}]\t' 121 | 'Time {batch_time_current:.3f} ({batch_time:.3f})\t' 122 | 'Data {data_time_current:.3f} ({data_time:.3f})\t' 123 | 'Loss {loss_current:.4f} ({loss:.4f})'.format( 124 | self.state['iteration'], len(data_loader), batch_time_current=self.state['batch_time_current'], 125 | batch_time=batch_time, data_time_current=self.state['data_time_batch'], 126 | data_time=data_time, loss_current=self.state['loss_batch'], loss=loss)) 127 | 128 | def on_forward(self, training, model, criterion, data_loader, optimizer=None, display=True): 129 | with torch.set_grad_enabled(training): 130 | img_var = torch.autograd.Variable(self.state['img']).float() 131 | target_var = torch.autograd.Variable(self.state['target']).float() 132 | # compute output 133 | self.state['output'] = model(img_var) 134 | self.state['loss'] = criterion(self.state['output'], target_var) 135 | if training: 136 | optimizer.zero_grad() 137 | self.state['loss'].backward() 138 | optimizer.step() 139 | else: 140 | torch.cuda.empty_cache() 141 | 142 | 143 | def learning(self, model, criterion, train_dataset, val_dataset, optimizer=None): 144 | 145 | # define train and val transform 146 | train_dataset.transform = self.state['train_transform'] 147 | train_dataset.target_transform = self.state['train_target_transform'] 148 | val_dataset.transform = self.state['val_transform'] 149 | val_dataset.target_transform = self.state['val_target_transform'] 150 | 151 | # data loading code 152 | train_loader = torch.utils.data.DataLoader(train_dataset, 153 | batch_size=self.state['batch_size'], shuffle=True, 154 | num_workers=self.state['workers']) 155 | 156 | val_loader = torch.utils.data.DataLoader(val_dataset, 157 | batch_size=self.state['batch_size_test'], shuffle=False, 158 | num_workers=self.state['workers']) 159 | 160 | # optionally resume from a checkpoint 161 | if self.state['resume'] is not None: 162 | if os.path.isfile(self.state['resume']): 163 | print("=> loading checkpoint '{}'".format(self.state['resume'])) 164 | checkpoint = torch.load(self.state['resume']) 165 | self.state['start_epoch'] = checkpoint['epoch'] 166 | self.state['best_score'] = checkpoint['best_score'] 167 | model.load_state_dict(checkpoint['state_dict']) 168 | print("=> loaded checkpoint '{}' (epoch {})" 169 | .format(self.state['evaluate'], checkpoint['epoch'])) 170 | else: 171 | print("=> no checkpoint found at '{}'".format(self.state['resume'])) 172 | 173 | 174 | if self.state['use_gpu']: 175 | train_loader.pin_memory = True 176 | 
val_loader.pin_memory = True 177 | cudnn.benchmark = False 178 | 179 | model = torch.nn.DataParallel(model, device_ids=self.state['device_ids']).cuda() 180 | 181 | criterion = criterion.cuda() 182 | 183 | if self.state['evaluate']: 184 | self.validate(val_loader, model, criterion) 185 | return 186 | 187 | for epoch in range(self.state['start_epoch'], self.state['max_epochs']): 188 | self.state['epoch'] = epoch 189 | lr = self.adjust_learning_rate(optimizer) 190 | print('lr:',lr, '|', 'step:' ,self.state['epoch_step'],'|', 'decay: ', self.state['lr_decay']) 191 | 192 | # train for one epoch 193 | self.train(train_loader, model, criterion, optimizer, epoch) 194 | # evaluate on validation set 195 | prec1 = self.validate(val_loader, model, criterion) 196 | 197 | # remember best prec@1 and save checkpoint 198 | is_best = prec1 > self.state['best_score'] 199 | self.state['best_score'] = max(prec1, self.state['best_score']) 200 | self.save_checkpoint({ 201 | 'epoch': epoch + 1, 202 | 'arch': self.state['arch'], 203 | 'state_dict': model.module.state_dict() if self.state['use_gpu'] else model.state_dict(), 204 | 'best_score': self.state['best_score'], 205 | }, is_best) 206 | 207 | print(' *** best={best:.3f}'.format(best=self.state['best_score'])) 208 | set_log_property('top', float(self.state['best_score'])) 209 | 210 | return self.state['best_score'] 211 | 212 | def run(self, training, data_loader, model, criterion, optimizer=None, epoch=None): 213 | if training: 214 | # switch to train mode 215 | model.train() 216 | else: 217 | # switch to evaluate mode 218 | model.eval() 219 | 220 | self.on_start_epoch(training, model, criterion, data_loader, optimizer) 221 | 222 | if self.state['use_pb']: 223 | data_loader = tqdm(data_loader, desc='Training' if training else 'Test') 224 | 225 | end = time.time() 226 | for i, (input, target) in enumerate(data_loader): 227 | # measure data loading time 228 | self.state['iteration'] = i 229 | self.state['data_time_batch'] = time.time() - end 230 | self.state['data_time'].add(self.state['data_time_batch']) 231 | 232 | self.state['input'] = input 233 | self.state['target'] = target 234 | 235 | self.on_start_batch(training, model, criterion, data_loader, optimizer) 236 | 237 | if self.state['use_gpu']: 238 | self.state['target'] = self.state['target'].cuda() 239 | 240 | self.on_forward(training, model, criterion, data_loader, optimizer) 241 | 242 | # measure elapsed time 243 | self.state['batch_time_current'] = time.time() - end 244 | self.state['batch_time'].add(self.state['batch_time_current']) 245 | end = time.time() 246 | # measure accuracy 247 | self.on_end_batch(training, model, criterion, data_loader, optimizer) 248 | 249 | return self.on_end_epoch(training, model, criterion, data_loader, optimizer) 250 | 251 | def train(self, data_loader, model, criterion, optimizer, epoch): 252 | return self.run(True, data_loader, model, criterion, optimizer, epoch) 253 | 254 | def validate(self, data_loader, model, criterion): 255 | return self.run(False, data_loader, model, criterion) 256 | 257 | def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'): 258 | if self.state['save_model_path'] is not None: 259 | filename_ = filename 260 | filename = os.path.join(self.state['save_model_path'], filename_) 261 | if not os.path.exists(self.state['save_model_path']): 262 | os.makedirs(self.state['save_model_path']) 263 | print('save model {filename}'.format(filename=filename)) 264 | torch.save(state, filename) 265 | if is_best: 266 | filename_best = 
'model_best.pth.tar' 267 | if self.state['save_model_path'] is not None: 268 | filename_best = os.path.join(self.state['save_model_path'], filename_best) 269 | shutil.copyfile(filename, filename_best) 270 | if self.state['save_model_path'] is not None: 271 | if self.state['filename_previous_best'] is not None: 272 | os.remove(self.state['filename_previous_best']) 273 | filename_best = os.path.join(self.state['save_model_path'], 'model_best_{score:.4f}.pth.tar'.format(score=state['best_score'])) 274 | shutil.copyfile(filename, filename_best) 275 | self.state['filename_previous_best'] = filename_best 276 | 277 | def adjust_learning_rate(self, optimizer): 278 | """Sets the learning rate to the initial LR decayed by a fraction every epoch steps""" 279 | lr_list = [] 280 | decay = self.state['lr_decay'] if sum(self.state['epoch'] == np.array(self.state['epoch_step'])) > 0 else 1.0 281 | for param_group in optimizer.param_groups: 282 | param_group['lr'] = param_group['lr'] * decay 283 | lr_list.append(param_group['lr']) 284 | return np.unique(lr_list) 285 | 286 | class GraphMultiLabelEngine(MultiLabelEngine): 287 | def on_forward(self, training, model, criterion, data_loader, optimizer=None, display=True): 288 | with torch.set_grad_enabled(training): 289 | img_var = torch.autograd.Variable(self.state['img']).float() 290 | target_var = torch.autograd.Variable(self.state['target']).float() 291 | emb_var = torch.autograd.Variable(self.state['emb']).float().detach() # one hot 292 | # compute output 293 | self.state['output'] = model(img_var, emb_var) 294 | self.state['loss'] = criterion(self.state['output'], target_var) 295 | if training: 296 | optimizer.zero_grad() 297 | self.state['loss'].backward() 298 | nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) 299 | optimizer.step() 300 | else: 301 | torch.cuda.empty_cache() 302 | 303 | def on_start_batch(self, training, model, criterion, data_loader, optimizer=None, display=True): 304 | MultiLabelEngine.on_start_batch(self, training, model, criterion, data_loader, optimizer, display) 305 | self.state['emb'] = self.state['input'][2] 306 | 307 | -------------------------------------------------------------------------------- /gnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | 7 | class GConv(nn.Module): 8 | """ 9 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 10 | """ 11 | 12 | def __init__(self, in_features, out_features, bias=False): 13 | super(GConv, self).__init__() 14 | self.in_features = in_features 15 | self.out_features = out_features 16 | self.weight = nn.Parameter(torch.Tensor(in_features, out_features)) 17 | if bias: 18 | self.bias = nn.Parameter(torch.Tensor(1, 1, out_features)) 19 | else: 20 | self.register_parameter('bias', None) 21 | self.reset_parameters() 22 | 23 | def reset_parameters(self): 24 | stdv = 1. 
/ math.sqrt(self.weight.size(1)) 25 | self.weight.data.uniform_(-stdv, stdv) 26 | if self.bias is not None: 27 | self.bias.data.uniform_(-stdv, stdv) 28 | 29 | def forward(self, input, adj): 30 | support = torch.matmul(input, self.weight) 31 | output = torch.matmul(adj, support) 32 | if self.bias is not None: 33 | return output + self.bias 34 | else: 35 | return output 36 | 37 | def __repr__(self): 38 | return self.__class__.__name__ + ' (' \ 39 | + str(self.in_features) + ' -> ' \ 40 | + str(self.out_features) + ')' 41 | 42 | class GTN(nn.Module): 43 | ''' 44 | The code is provided by Yun et al. 45 | Seongjun Yun, Minbyul Jeong, Raehyun Kim, Jaewoo Kang, Hyunwoo J. Kim, Graph Transformer Networks, 46 | In Advances in Neural Information Processing Systems (NeurIPS 2019). 47 | https://github.com/seongjunyun/Graph_Transformer_Networks 48 | ''' 49 | 50 | def __init__(self, num_edge, num_channels, w_in, w_out, num_class,num_layers,norm): 51 | super(GTN, self).__init__() 52 | self.num_edge = num_edge 53 | self.num_channels = num_channels 54 | self.w_in = w_in 55 | self.w_out = w_out 56 | self.num_class = num_class 57 | self.num_layers = num_layers 58 | self.is_norm = norm 59 | layers = [] 60 | for i in range(num_layers): 61 | if i == 0: 62 | layers.append(GTLayer(num_edge, num_channels, first=True)) 63 | else: 64 | layers.append(GTLayer(num_edge, num_channels, first=False)) 65 | self.layers = nn.ModuleList(layers) 66 | self.weight = nn.Parameter(torch.Tensor(w_in, w_out)) 67 | self.bias = nn.Parameter(torch.Tensor(w_out)) 68 | self.loss = nn.CrossEntropyLoss() 69 | self.linear1 = nn.Linear(self.w_out*self.num_channels, self.w_out) 70 | self.linear2 = nn.Linear(self.w_out, self.num_class) 71 | self.reset_parameters() 72 | 73 | def reset_parameters(self): 74 | nn.init.xavier_uniform_(self.weight) 75 | nn.init.zeros_(self.bias) 76 | 77 | def gcn_conv(self,X,H): 78 | X = torch.mm(X, self.weight) 79 | H = self.norm(H, add=True) 80 | return torch.mm(H.t(),X) 81 | 82 | def normalization(self, H): 83 | for i in range(self.num_channels): 84 | if i==0: 85 | H_ = self.norm(H[i,:,:]).unsqueeze(0) 86 | else: 87 | H_ = torch.cat((H_,self.norm(H[i,:,:]).unsqueeze(0)), dim=0) 88 | return H_ 89 | 90 | def norm(self, H, add=False): 91 | H = H.t() 92 | if add == False: 93 | H = H*((torch.eye(H.shape[0])==0).type(torch.FloatTensor)) 94 | else: 95 | H = H*((torch.eye(H.shape[0])==0).type(torch.FloatTensor)) + torch.eye(H.shape[0]).type(torch.FloatTensor) 96 | deg = torch.sum(H, dim=1) 97 | deg_inv = deg.pow(-1) 98 | deg_inv[deg_inv == float('inf')] = 0 99 | deg_inv = deg_inv*torch.eye(H.shape[0]).type(torch.FloatTensor) 100 | H = torch.mm(deg_inv,H) 101 | H = H.t() 102 | return H 103 | 104 | def forward(self, A, X, target_x, target): 105 | A = A.unsqueeze(0).permute(0,3,1,2) 106 | Ws = [] 107 | for i in range(self.num_layers): 108 | if i == 0: 109 | H, W = self.layers[i](A) 110 | else: 111 | H = self.normalization(H) 112 | H, W = self.layers[i](A, H) 113 | Ws.append(W) 114 | 115 | for i in range(self.num_channels): 116 | if i==0: 117 | X_ = F.relu(self.gcn_conv(X,H[i])) 118 | else: 119 | X_tmp = F.relu(self.gcn_conv(X,H[i])) 120 | X_ = torch.cat((X_,X_tmp), dim=1) 121 | X_ = self.linear1(X_) 122 | X_ = F.relu(X_) 123 | y = self.linear2(X_[target_x]) 124 | loss = self.loss(y, target) 125 | return loss, y, Ws 126 | 127 | class GTLayer(nn.Module): 128 | 129 | def __init__(self, in_channels, out_channels, first=True): 130 | super(GTLayer, self).__init__() 131 | self.in_channels = in_channels 132 | self.out_channels = 
out_channels 133 | self.first = first 134 | if self.first == True: 135 | self.conv1 = GTConv(in_channels, out_channels) 136 | self.conv2 = GTConv(in_channels, out_channels) 137 | else: 138 | self.conv1 = GTConv(in_channels, out_channels) 139 | 140 | def forward(self, A, H_=None): 141 | if self.first == True: 142 | a = self.conv1(A) 143 | b = self.conv2(A) 144 | H = torch.bmm(a,b) 145 | W = [(F.softmax(self.conv1.weight, dim=1)).detach(),(F.softmax(self.conv2.weight, dim=1)).detach()] 146 | else: 147 | a = self.conv1(A) 148 | H = torch.bmm(H_,a) 149 | W = [(F.softmax(self.conv1.weight, dim=1)).detach()] 150 | return H,W 151 | 152 | class GTConv(nn.Module): 153 | 154 | def __init__(self, in_channels, out_channels): 155 | super(GTConv, self).__init__() 156 | self.in_channels = in_channels 157 | self.out_channels = out_channels 158 | self.weight = nn.Parameter(torch.Tensor(out_channels,in_channels,1,1)) 159 | self.bias = None 160 | self.scale = nn.Parameter(torch.Tensor([0.1]), requires_grad=False) 161 | self.reset_parameters() 162 | def reset_parameters(self): 163 | n = self.in_channels 164 | nn.init.constant_(self.weight, 0.1) 165 | if self.bias is not None: 166 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 167 | bound = 1 / math.sqrt(fan_in) 168 | nn.init.uniform_(self.bias, -bound, bound) 169 | 170 | def forward(self, A): 171 | A = torch.sum(A.cuda()*F.softmax(self.weight, dim=1).cuda(), dim=1) 172 | return A 173 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import neptune 2 | 3 | use_neptune = False 4 | 5 | def init_log(args): 6 | global use_neptune 7 | if args.neptune: 8 | with open(args.neptune_path, 'r') as f: 9 | nep = f.readlines() 10 | neptune.init(nep[0].strip(), api_token=nep[1].strip()) 11 | neptune.create_experiment(params=vars(args), upload_source_files=['*.py']) 12 | use_neptune = True 13 | 14 | 15 | def send_log(key, value): 16 | global use_neptune 17 | if use_neptune: 18 | try: 19 | neptune.send_metric(key, value) 20 | except: 21 | print("Log failed: ", key, value) 22 | 23 | def set_log_property(key, value): 24 | if use_neptune: 25 | try: 26 | neptune.set_property(key, value) 27 | except: 28 | print("Log property failed: ", key, value) 29 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from engine import * 3 | from models import * 4 | from coco import * 5 | from util import * 6 | from logger import * 7 | 8 | parser = argparse.ArgumentParser(description='Graph Multi-Label Classification Training') 9 | parser.add_argument('data', metavar='DIR', 10 | help='path to dataset (e.g. 
data/)')
11 | parser.add_argument('--image-size', '-i', default=448, type=int,
12 |                     metavar='N', help='image size (default: 448)')
13 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
14 |                     help='number of data loading workers (default: 8)')
15 | parser.add_argument('--epochs', default=200, type=int, metavar='N',
16 |                     help='number of total epochs to run')
17 | parser.add_argument('--epoch_step', default=[30,40], type=int, nargs='+',
18 |                     help='epochs at which the learning rate is decayed')
19 | parser.add_argument('--device_ids', default=[0], type=int, nargs='+',
20 |                     help='GPU device ids to use (default: [0])')
21 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
22 |                     help='manual epoch number (useful on restarts)')
23 | parser.add_argument('-b', '--batch-size', default=16, type=int,
24 |                     metavar='N', help='mini-batch size (default: 16)')
25 | parser.add_argument('-bt', '--batch-size-test', default=None, type=int,
26 |                     metavar='N', help='mini-batch size for test (default: same as --batch-size)')
27 | parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
28 |                     metavar='LR', help='initial learning rate')
29 | parser.add_argument('--lrd', '--learning-rate-decay', default=0.1, type=float,
30 |                     metavar='LRD', help='learning rate decay')
31 | parser.add_argument('--lrp', '--learning-rate-pretrained', default=0.1, type=float,
32 |                     metavar='LR', help='learning rate for pre-trained layers')
33 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
34 |                     help='momentum')
35 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
36 |                     metavar='W', help='weight decay (default: 1e-4)')
37 | parser.add_argument('--print-freq', '-p', default=0, type=int,
38 |                     metavar='N', help='print frequency (default: 0, i.e. disabled)')
39 | parser.add_argument('--embedding', default='model/embedding/coco_glove_word2vec_80x300.pkl',
40 |                     type=str, metavar='EMB', help='path to embedding (default: glove)')
41 | parser.add_argument('--embedding-length', default=300, type=int, metavar='EMB',
42 |                     help='embedding length (default: 300)')
43 | parser.add_argument('--adj-file', default='model/topology/coco_adj.pkl', type=str, metavar='ADJ',
44 |                     help='Adj file (default: model/topology/coco_adj.pkl)')
45 | parser.add_argument('--adj-strong-threshold', default=0.4, type=float, metavar='ADJTS',
46 |                     help='Adj strong threshold (default: 0.4)')
47 | parser.add_argument('--adj-weak-threshold', default=0.2, type=float, metavar='ADJTW',
48 |                     help='Adj weak threshold (default: 0.2)')
49 | parser.add_argument('--mod-file', default='model/topology/coco_mod.txt', type=str, metavar='MOD',
50 |                     help='Modularity file (default: model/topology/coco_mod.txt)')
51 | parser.add_argument('--mlt', '--multi-learning-threshold', default=0.999, type=float, metavar='MLT',
52 |                     help='Multi-learning threshold (default: 0.999)')
53 | parser.add_argument('--exp-name', dest='exp_name', default='coco', type=str, metavar='COCO',
54 |                     help='experiment name, used as a sub-directory for saving checkpoints')
55 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
56 |                     help='path to latest checkpoint (default: none)')
57 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
58 |                     help='evaluate model on validation set')
59 | parser.add_argument('-n', '--neptune', dest='neptune', action='store_true',
60 |                     help='run with neptune')
61 | parser.add_argument('--neptune-path', default='neptune.txt', type=str, metavar='PATH',
62 |                     help='Neptune API keys (default: 
neptune.txt)') 63 | 64 | def main(): 65 | args = parser.parse_args() 66 | torch.backends.cudnn.benchmark = False 67 | init_log(args) 68 | 69 | train_dataset = COCO2014(args.data, phase="train", emb_name=args.embedding) 70 | val_dataset = COCO2014(args.data, phase="val", emb_name=args.embedding) 71 | num_classes = 80 72 | 73 | print('Embedding:', args.embedding, '(', args.embedding_length, ')') 74 | print('Adjacency file:', args.adj_file) 75 | print('Adjacency Strong Threshold:', args.adj_strong_threshold) 76 | print('Adjacency Weak Threshold:', args.adj_weak_threshold) 77 | print('Modularity file:', args.mod_file) 78 | 79 | if args.adj_strong_threshold < args.adj_weak_threshold: 80 | args.adj_weak_threshold = args.adj_strong_threshold 81 | 82 | model = mgtn_resnet(num_classes=num_classes, 83 | t1=args.adj_strong_threshold, 84 | t2=args.adj_weak_threshold, 85 | adj_file=args.adj_file, 86 | mod_file=args.mod_file, 87 | emb_features=args.embedding_length, 88 | ml_threshold=args.mlt) 89 | # define loss function (criterion) 90 | criterion = nn.MultiLabelSoftMarginLoss() 91 | 92 | # define optimizer 93 | optimizer = torch.optim.SGD(model.get_config_optim(args.lr, args.lrp), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 94 | 95 | model_path = "checkpoint/coco/%s" % args.exp_name 96 | if not os.path.exists(model_path): 97 | os.makedirs(model_path) 98 | 99 | state = { 100 | "batch_size": args.batch_size, 101 | "batch_size_test": args.batch_size if args.batch_size_test is None else args.batch_size_test, 102 | "image_size": args.image_size, 103 | "max_epochs": args.epochs, 104 | "evaluate": args.evaluate, 105 | "resume": args.resume, 106 | "num_classes": num_classes, 107 | "difficult_examples": True, 108 | "save_model_path": model_path, 109 | "workers": args.workers, 110 | "epoch_step": args.epoch_step, 111 | "lr": args.lr, 112 | "lr_decay": args.lrd, 113 | 'mlt': args.mlt, 114 | "device_ids": args.device_ids, 115 | "neptune": args.neptune, 116 | "evaluate": True if args.evaluate else False 117 | } 118 | 119 | engine = GraphMultiLabelEngine(state) 120 | engine.learning(model, criterion, train_dataset, val_dataset, optimizer) 121 | 122 | 123 | if __name__ == "__main__": 124 | main() 125 | -------------------------------------------------------------------------------- /model/embedding/coco_bert_80x768.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReML-AI/MGTN/76bf9ea1f036eec2374576f1d7509f8a2c5dd065/model/embedding/coco_bert_80x768.pkl -------------------------------------------------------------------------------- /model/embedding/coco_bert_80x768_ec.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReML-AI/MGTN/76bf9ea1f036eec2374576f1d7509f8a2c5dd065/model/embedding/coco_bert_80x768_ec.pkl -------------------------------------------------------------------------------- /model/embedding/coco_char2vec_80x300.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReML-AI/MGTN/76bf9ea1f036eec2374576f1d7509f8a2c5dd065/model/embedding/coco_char2vec_80x300.pkl -------------------------------------------------------------------------------- /model/embedding/coco_char2vec_80x300_ec.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ReML-AI/MGTN/76bf9ea1f036eec2374576f1d7509f8a2c5dd065/model/embedding/coco_char2vec_80x300_ec.pkl -------------------------------------------------------------------------------- /model/embedding/coco_fasttext_80x300.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReML-AI/MGTN/76bf9ea1f036eec2374576f1d7509f8a2c5dd065/model/embedding/coco_fasttext_80x300.pkl -------------------------------------------------------------------------------- /model/embedding/coco_fasttext_80x300_ec.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReML-AI/MGTN/76bf9ea1f036eec2374576f1d7509f8a2c5dd065/model/embedding/coco_fasttext_80x300_ec.pkl -------------------------------------------------------------------------------- /model/embedding/coco_glove_word2vec_80x300.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReML-AI/MGTN/76bf9ea1f036eec2374576f1d7509f8a2c5dd065/model/embedding/coco_glove_word2vec_80x300.pkl -------------------------------------------------------------------------------- /model/embedding/coco_glove_word2vec_80x300_ec.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReML-AI/MGTN/76bf9ea1f036eec2374576f1d7509f8a2c5dd065/model/embedding/coco_glove_word2vec_80x300_ec.pkl -------------------------------------------------------------------------------- /model/embedding/coco_roberta_80x768.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReML-AI/MGTN/76bf9ea1f036eec2374576f1d7509f8a2c5dd065/model/embedding/coco_roberta_80x768.pkl -------------------------------------------------------------------------------- /model/embedding/coco_roberta_80x768_ec.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReML-AI/MGTN/76bf9ea1f036eec2374576f1d7509f8a2c5dd065/model/embedding/coco_roberta_80x768_ec.pkl -------------------------------------------------------------------------------- /model/topology/coco_adj.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReML-AI/MGTN/76bf9ea1f036eec2374576f1d7509f8a2c5dd065/model/topology/coco_adj.pkl -------------------------------------------------------------------------------- /model/topology/coco_bert_base_cosine_adj.pkl_emb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReML-AI/MGTN/76bf9ea1f036eec2374576f1d7509f8a2c5dd065/model/topology/coco_bert_base_cosine_adj.pkl_emb -------------------------------------------------------------------------------- /model/topology/coco_char2vec_cosine_adj.pkl_emb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReML-AI/MGTN/76bf9ea1f036eec2374576f1d7509f8a2c5dd065/model/topology/coco_char2vec_cosine_adj.pkl_emb -------------------------------------------------------------------------------- /model/topology/coco_glove_cosine_adj.pkl_emb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReML-AI/MGTN/76bf9ea1f036eec2374576f1d7509f8a2c5dd065/model/topology/coco_glove_cosine_adj.pkl_emb 
-------------------------------------------------------------------------------- /model/topology/coco_glove_word2vec.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReML-AI/MGTN/76bf9ea1f036eec2374576f1d7509f8a2c5dd065/model/topology/coco_glove_word2vec.pkl -------------------------------------------------------------------------------- /model/topology/coco_mod.txt: -------------------------------------------------------------------------------- 1 | 1 2 1 2 1 1 1 1 1 1 1 1 1 2 2 2 1 2 1 2 1 1 1 1 1 1 2 2 1 2 1 1 2 1 1 2 1 1 2 1 1 2 1 2 1 1 2 2 1 1 2 1 2 1 2 2 1 2 1 1 1 2 1 1 1 1 1 1 1 2 2 2 1 1 1 1 1 2 2 1 -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torchvision.models as models 2 | from util import * 3 | from gnn import * 4 | import torch 5 | import torch.nn as nn 6 | 7 | class MGTNResnet(nn.Module): 8 | def __init__(self, model_name, num_classes, emb_features=300, t1=0.0, t2=0.0, adj_file=None, mod_file=None, ml_threshold=0.999): 9 | super(MGTNResnet, self).__init__() 10 | 11 | _mods = np.loadtxt(mod_file, dtype=int) 12 | 13 | self.backbones = nn.ModuleList() 14 | # Create multiple backbones 15 | for i in range(int(max(_mods)) - int(min(_mods)) + 1): 16 | model = load_model(model_name) 17 | backbone = nn.Sequential( 18 | model.conv1, 19 | model.bn1, 20 | model.relu, 21 | model.maxpool, 22 | model.layer1, 23 | model.layer2, 24 | model.layer3, 25 | model.layer4, 26 | nn.MaxPool2d(14, 14), 27 | nn.Flatten(1) 28 | ) 29 | self.backbones.append(backbone) 30 | 31 | self.num_classes = num_classes 32 | 33 | # Graph Convolutions 34 | self.gc1 = GConv(emb_features, 2048) 35 | self.gc2 = GConv(2048, 4096) 36 | self.relu = nn.LeakyReLU(0.2) 37 | 38 | # Topology 39 | self.A = torch.stack([ 40 | torch.eye(num_classes).type(torch.FloatTensor), 41 | torch.from_numpy(AdjacencyHelper.gen_A(num_classes, 1.0, t1, adj_file)).type(torch.FloatTensor), 42 | torch.from_numpy(AdjacencyHelper.gen_A(num_classes, t1, t2, adj_file)).type(torch.FloatTensor) 43 | ]).unsqueeze(0) 44 | 45 | self.gtn = GTLayer(self.A.shape[1], 1, first=True) 46 | self.mods = nn.Parameter(torch.from_numpy(AdjacencyHelper.gen_M(_mods, dims=2048, t=ml_threshold)).float()) 47 | 48 | def forward(self, img, emb): 49 | fs = [] 50 | for i in range(len(self.backbones)): 51 | fs.append(self.backbones[i](img)) 52 | f = torch.cat(fs, 1) 53 | 54 | adj, _ = self.gtn.forward(self.A) 55 | adj = torch.squeeze(adj, 0) + torch.eye(self.num_classes).type(torch.FloatTensor).cuda() 56 | adj = AdjacencyHelper.gen_adj(adj) 57 | 58 | w = self.gc1(emb[0], adj) 59 | w = self.relu(w) 60 | w = self.gc2(w, adj) 61 | w = torch.mul(w, self.mods) 62 | 63 | w = w.transpose(0, 1) 64 | y = torch.matmul(f, w) 65 | return y 66 | 67 | def get_config_optim(self, lr, lrp): 68 | config_optim = [] 69 | for backbone in self.backbones: 70 | config_optim.append({'params': backbone.parameters(), 'lr': lr * lrp}) 71 | config_optim.append({'params': self.gc1.parameters(), 'lr': lr}) 72 | config_optim.append({'params': self.gc2.parameters(), 'lr': lr}) 73 | return config_optim 74 | 75 | def mgtn_resnet(num_classes, t1, t2, pretrained=True, adj_file=None, mod_file=None, emb_features=300, ml_threshold=0.999): 76 | return MGTNResnet('resnext50_32x4d_swsl', num_classes, t1=t1, t2=t2, adj_file=adj_file, mod_file=mod_file, emb_features=emb_features, ml_threshold=ml_threshold) 77 | 
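# Usage sketch (illustrative only, not part of the training pipeline): builds the
# MGTN model defined above and runs one forward pass on random inputs. It assumes
# a CUDA device (forward() moves tensors to the GPU internally), the bundled
# topology/embedding files, and internet access so torch.hub can fetch the
# ResNeXt-50 SWSL backbones. The embedding pickle is assumed to hold an 80x300
# array (one row per COCO label), matching what coco.py feeds through the loader.
import pickle

import torch

if __name__ == '__main__' and torch.cuda.is_available():
    model = mgtn_resnet(num_classes=80, t1=0.4, t2=0.2,
                        adj_file='model/topology/coco_adj.pkl',
                        mod_file='model/topology/coco_mod.txt',
                        emb_features=300, ml_threshold=0.999).cuda().eval()

    with open('model/embedding/coco_glove_word2vec_80x300.pkl', 'rb') as f:
        emb = torch.as_tensor(pickle.load(f), dtype=torch.float32)
    emb = emb.unsqueeze(0).cuda()             # add a batch dim; forward() reads emb[0]

    img = torch.randn(2, 3, 448, 448).cuda()  # two dummy 448x448 RGB images
    with torch.no_grad():
        scores = model(img, emb)              # (2, 80) raw label scores
    print(scores.shape)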
-------------------------------------------------------------------------------- /neptune.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /prepare.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import scipy as sp 4 | from scipy.sparse import linalg 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('-i', '--input-path', default='model/embedding/coco_glove_word2vec_80x300.pkl', 9 | type=str, help='Input Path') 10 | parser.add_argument('-a', '--adj-path', default='model/adjacency/coco_adj.pkl', 11 | type=str, help='Adjacency Path') 12 | parser.add_argument('-o', '--output-path', default='model/embedding/coco_glove_word2vec_80x300_ec.pkl', 13 | type=str, help='Output Path') 14 | parser.add_argument('-n', '--normalise', action='store_true', help='perform normalisation') 15 | parser.add_argument('-ec', '--eigenc', action='store_true', help='perform EC transformation') 16 | 17 | def load_adj(adj_file='data/coco/coco_adj.pkl'): 18 | result = pickle.load(open(adj_file, 'rb')) 19 | _adj = result['adj'] 20 | _nums = result['nums'] 21 | return (_adj, _nums) 22 | 23 | def eigenvector_centrality(adj): 24 | import networkx as nx 25 | graph = nx.from_numpy_matrix(adj) 26 | centrality = nx.eigenvector_centrality(graph) 27 | return np.array(tuple(centrality.values())) 28 | 29 | def rowmul(arr2d, arr1d): 30 | return np.array(arr2d) * np.array(arr1d)[:, None] 31 | 32 | def normalize(mx): 33 | """Row-normalize sparse matrix""" 34 | rowsum = np.array(mx.sum(1)) 35 | r_inv = np.power(rowsum, -1).flatten() 36 | r_inv[np.isinf(r_inv)] = 0. 
37 | r_mat_inv = sp.diags(r_inv) 38 | mx = r_mat_inv.dot(mx) 39 | return mx 40 | 41 | def eigs(adj): 42 | eigenvalue, eigenvector = linalg.eigs(adj, k=1, which='LR') 43 | return (eigenvalue, eigenvector) 44 | 45 | def adjust(adj, t=0.4): 46 | _adj = np.array(adj) 47 | _nums = adj.shape[0] 48 | _nums = _nums[:, np.newaxis] 49 | _adj = _adj / _nums 50 | _adj[_adj < t] = 0 51 | _adj[_adj >= t] = 1 52 | _adj = _adj * 0.25 / (_adj.sum(0, keepdims=True) + 1e-6) 53 | _adj = _adj + np.identity(adj.shape[0], np.int) 54 | return _adj 55 | 56 | def main(): 57 | global args 58 | args = parser.parse_args() 59 | 60 | with open(args.input_path, 'rb') as finp: 61 | inp = pickle.load(finp) 62 | 63 | with open(args.adj_path, 'rb') as fadj: 64 | result = pickle.load(fadj) 65 | adj = result['adj'] 66 | 67 | if args.eigenc: 68 | print('Eigenvector Centrality Transformation') 69 | ec = eigenvector_centrality(adj) 70 | ec = ec * 10 # scale up by 10x 71 | out = rowmul(inp, ec) 72 | with open(args.output_path, 'wb') as fout: 73 | pickle.dump(out, fout, protocol=pickle.HIGHEST_PROTOCOL) 74 | print("Written to", args.output_path) 75 | 76 | 77 | if __name__ == '__main__': 78 | main() 79 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torchnet 2 | neptune_client 3 | psutil 4 | tqdm 5 | networkx 6 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import math 2 | from urllib.request import urlretrieve 3 | import torch 4 | from PIL import Image 5 | import multiprocessing as mp 6 | import itertools 7 | from tqdm import tqdm 8 | import numpy as np 9 | import random 10 | import torch.nn.functional as F 11 | import torchvision.models as models 12 | 13 | def download_url(url, destination=None, progress_bar=True): 14 | """Download a URL to a local file. 15 | 16 | Parameters 17 | ---------- 18 | url : str 19 | The URL to download. 20 | destination : str, None 21 | The destination of the file. If None is given the file is saved to a temporary directory. 22 | progress_bar : bool 23 | Whether to show a command-line progress bar while downloading. 24 | 25 | Returns 26 | ------- 27 | filename : str 28 | The location of the downloaded file. 
29 | 30 | Notes 31 | ----- 32 | Progress bar use/example adapted from tqdm documentation: https://github.com/tqdm/tqdm 33 | """ 34 | 35 | def my_hook(t): 36 | last_b = [0] 37 | 38 | def inner(b=1, bsize=1, tsize=None): 39 | if tsize is not None: 40 | t.total = tsize 41 | if b > 0: 42 | t.update((b - last_b[0]) * bsize) 43 | last_b[0] = b 44 | 45 | return inner 46 | 47 | if progress_bar: 48 | with tqdm(unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1]) as t: 49 | filename, _ = urlretrieve(url, filename=destination, reporthook=my_hook(t)) 50 | else: 51 | filename, _ = urlretrieve(url, filename=destination) 52 | 53 | def load_model(model_name): 54 | if model_name == 'resnet18': 55 | return models.resnet18(pretrained=True) 56 | if model_name == 'resnet34': 57 | return models.resnet34(pretrained=True) 58 | if model_name == 'resnet50': 59 | return models.resnet50(pretrained=True) 60 | if model_name == 'resnet101': 61 | return models.resnet101(pretrained=True) 62 | if model_name == 'resnet152': 63 | return models.resnet152(pretrained=True) 64 | if model_name == 'resnext50_32x4d': 65 | return models.resnext50_32x4d(pretrained=True) 66 | if model_name == 'resnext101_32x8d': 67 | return models.resnext101_32x8d(pretrained=True) 68 | if model_name == 'wide_resnet50_2': 69 | return models.wide_resnet50_2(pretrained=True) 70 | if model_name == 'wide_resnet101_2': 71 | return models.wide_resnet101_2(pretrained=True) 72 | if model_name == 'resnet18_swsl': 73 | return torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnet18_swsl') 74 | if model_name == 'resnet50_swsl': 75 | return torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnet50_swsl') 76 | if model_name == 'resnext50_32x4d_swsl': 77 | return torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnext50_32x4d_swsl') 78 | if model_name == 'resnext101_32x4d_swsl': 79 | return torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnext101_32x4d_swsl') 80 | if model_name == 'resnext101_32x8d_swsl': 81 | return torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnext101_32x8d_swsl') 82 | if model_name == 'resnext101_32x16d_swsl': 83 | return torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnext101_32x16d_swsl') 84 | if model_name == 'resnet18_ssl': 85 | return torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnet18_ssl') 86 | if model_name == 'resnet50_ssl': 87 | return torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnet50_ssl') 88 | if model_name == 'resnext50_32x4d_ssl': 89 | return torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnext50_32x4d_ssl') 90 | if model_name == 'resnext101_32x4d_ssl': 91 | return torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnext101_32x4d_ssl') 92 | if model_name == 'resnext101_32x8d_ssl': 93 | return torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnext101_32x8d_ssl') 94 | if model_name == 'resnext101_32x16d_ssl': 95 | return torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnext101_32x16d_ssl') 96 | return None 97 | 98 | 99 | class Warp(object): 100 | def __init__(self, size, interpolation=Image.BILINEAR): 101 | self.size = int(size) 102 | self.interpolation = interpolation 103 | 104 | def __call__(self, img): 105 | return img.resize((self.size, self.size), self.interpolation) 106 | 107 | def __str__(self): 108 | return self.__class__.__name__ 
+ ' (size={size}, interpolation={interpolation})'.format(size=self.size, 109 | interpolation=self.interpolation) 110 | class MultiScaleCrop(object): 111 | 112 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 113 | self.scales = scales if scales is not None else [1, 875, .75, .66] 114 | self.max_distort = max_distort 115 | self.fix_crop = fix_crop 116 | self.more_fix_crop = more_fix_crop 117 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 118 | self.interpolation = Image.BILINEAR 119 | 120 | def __call__(self, img): 121 | im_size = img.size 122 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 123 | crop_img_group = img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) 124 | ret_img_group = crop_img_group.resize((self.input_size[0], self.input_size[1]), self.interpolation) 125 | return ret_img_group 126 | 127 | def _sample_crop_size(self, im_size): 128 | image_w, image_h = im_size[0], im_size[1] 129 | 130 | # find a crop size 131 | base_size = min(image_w, image_h) 132 | crop_sizes = [int(base_size * x) for x in self.scales] 133 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 134 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 135 | 136 | pairs = [] 137 | for i, h in enumerate(crop_h): 138 | for j, w in enumerate(crop_w): 139 | if abs(i - j) <= self.max_distort: 140 | pairs.append((w, h)) 141 | 142 | crop_pair = random.choice(pairs) 143 | if not self.fix_crop: 144 | w_offset = random.randint(0, image_w - crop_pair[0]) 145 | h_offset = random.randint(0, image_h - crop_pair[1]) 146 | else: 147 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 148 | 149 | return crop_pair[0], crop_pair[1], w_offset, h_offset 150 | 151 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 152 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 153 | return random.choice(offsets) 154 | 155 | @staticmethod 156 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 157 | w_step = (image_w - crop_w) // 4 158 | h_step = (image_h - crop_h) // 4 159 | 160 | ret = list() 161 | ret.append((0, 0)) # upper left 162 | ret.append((4 * w_step, 0)) # upper right 163 | ret.append((0, 4 * h_step)) # lower left 164 | ret.append((4 * w_step, 4 * h_step)) # lower right 165 | ret.append((2 * w_step, 2 * h_step)) # center 166 | 167 | if more_fix_crop: 168 | ret.append((0, 2 * h_step)) # center left 169 | ret.append((4 * w_step, 2 * h_step)) # center right 170 | ret.append((2 * w_step, 4 * h_step)) # lower center 171 | ret.append((2 * w_step, 0 * h_step)) # upper center 172 | 173 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 174 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 175 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 176 | ret.append((3 * w_step, 3 * h_step)) # lower righ quarter 177 | 178 | return ret 179 | 180 | 181 | def __str__(self): 182 | return self.__class__.__name__ 183 | 184 | class AveragePrecisionMeter(object): 185 | """ 186 | The APMeter measures the average precision per class. 
187 | The APMeter is designed to operate on `NxK` Tensors `output` and 188 | `target`, and optionally a `Nx1` Tensor weight where (1) the `output` 189 | contains model output scores for `N` examples and `K` classes that ought to 190 | be higher when the model is more convinced that the example should be 191 | positively labeled, and smaller when the model believes the example should 192 | be negatively labeled (for instance, the output of a sigmoid function); (2) 193 | the `target` contains only values 0 (for negative examples) and 1 194 | (for positive examples); and (3) the `weight` ( > 0) represents weight for 195 | each sample. 196 | """ 197 | 198 | def __init__(self, difficult_examples=False, workers=8): 199 | super(AveragePrecisionMeter, self).__init__() 200 | self.reset() 201 | self.difficult_examples = difficult_examples 202 | self.workers = workers 203 | 204 | def reset(self): 205 | """Resets the meter with empty member variables""" 206 | self.scores = torch.FloatTensor(torch.FloatStorage()) 207 | self.targets = torch.LongTensor(torch.LongStorage()) 208 | 209 | def add(self, output, target): 210 | """ 211 | Args: 212 | output (Tensor): NxK tensor that for each of the N examples 213 | indicates the probability of the example belonging to each of 214 | the K classes, according to the model. The probabilities should 215 | sum to one over all classes 216 | target (Tensor): binary NxK tensort that encodes which of the K 217 | classes are associated with the N-th input 218 | (eg: a row [0, 1, 0, 1] indicates that the example is 219 | associated with classes 2 and 4) 220 | weight (optional, Tensor): Nx1 tensor representing the weight for 221 | each example (each weight > 0) 222 | """ 223 | if not torch.is_tensor(output): 224 | output = torch.from_numpy(output) 225 | if not torch.is_tensor(target): 226 | target = torch.from_numpy(target) 227 | 228 | if output.dim() == 1: 229 | output = output.view(-1, 1) 230 | else: 231 | assert output.dim() == 2, \ 232 | 'wrong output size (should be 1D or 2D with one column \ 233 | per class)' 234 | if target.dim() == 1: 235 | target = target.view(-1, 1) 236 | else: 237 | assert target.dim() == 2, \ 238 | 'wrong target size (should be 1D or 2D with one column \ 239 | per class)' 240 | if self.scores.numel() > 0: 241 | assert target.size(1) == self.targets.size(1), \ 242 | 'dimensions for output should match previously added examples.' 
243 | 244 | # make sure storage is of sufficient size 245 | if self.scores.storage().size() < self.scores.numel() + output.numel(): 246 | new_size = math.ceil(self.scores.storage().size() * 1.5) 247 | self.scores.storage().resize_(int(new_size + output.numel())) 248 | self.targets.storage().resize_(int(new_size + output.numel())) 249 | 250 | # store scores and targets 251 | offset = self.scores.size(0) if self.scores.dim() > 0 else 0 252 | self.scores.resize_(offset + output.size(0), output.size(1)) 253 | self.targets.resize_(offset + target.size(0), target.size(1)) 254 | self.scores.narrow(0, offset, output.size(0)).copy_(output) 255 | self.targets.narrow(0, offset, target.size(0)).copy_(target) 256 | 257 | def value(self): 258 | """Returns the model's average precision for each class 259 | Return: 260 | ap (FloatTensor): 1xK tensor, with avg precision for each class k 261 | """ 262 | if self.scores.numel() == 0: 263 | return 0 264 | 265 | with mp.Pool(self.workers) as pool: 266 | results = pool.starmap(AveragePrecisionMeter.average_precision_process, list(zip(self.scores.T.tolist( 267 | ), self.targets.T.tolist(), itertools.repeat(self.difficult_examples)))) 268 | return torch.tensor(results) 269 | 270 | @staticmethod 271 | def average_precision_process(output, target, difficult_examples=True): 272 | output = torch.tensor(output) 273 | target = torch.tensor(target) 274 | ap = AveragePrecisionMeter.average_precision( 275 | output, target, difficult_examples) 276 | return ap 277 | 278 | @staticmethod 279 | def average_precision(output, target, difficult_examples=True): 280 | 281 | # sort examples 282 | sorted, indices = torch.sort(output, dim=0, descending=True) 283 | 284 | # Computes prec@i 285 | pos_count = 0. 286 | total_count = 0. 287 | precision_at_i = 0. 
288 | for i in indices: 289 | label = target[i] 290 | if difficult_examples and label == 0: 291 | continue 292 | if label == 1: 293 | pos_count += 1 294 | total_count += 1 295 | if label == 1: 296 | precision_at_i += pos_count / total_count 297 | precision_at_i /= pos_count 298 | return precision_at_i 299 | 300 | def overall(self): 301 | if self.scores.numel() == 0: 302 | return 0 303 | scores = self.scores.cpu().numpy() 304 | targets = self.targets.cpu().numpy() 305 | targets[targets == -1] = 0 306 | return self.evaluation(scores, targets) 307 | 308 | def overall_topk(self, k): 309 | targets = self.targets.cpu().numpy() 310 | targets[targets == -1] = 0 311 | n, c = self.scores.size() 312 | scores = np.zeros((n, c)) - 1 313 | index = self.scores.topk(k, 1, True, True)[1].cpu().numpy() 314 | tmp = self.scores.cpu().numpy() 315 | for i in range(n): 316 | for ind in index[i]: 317 | scores[i, ind] = 1 if tmp[i, ind] >= 0 else -1 318 | return self.evaluation(scores, targets) 319 | 320 | @staticmethod 321 | def evaluation(scores_, targets_): 322 | n, n_class = scores_.shape 323 | Nc, Np, Ng = np.zeros(n_class), np.zeros(n_class), np.zeros(n_class) 324 | for k in range(n_class): 325 | scores = scores_[:, k] 326 | targets = targets_[:, k] 327 | targets[targets == -1] = 0 328 | Ng[k] = np.sum(targets == 1) 329 | Np[k] = np.sum(scores >= 0) 330 | Nc[k] = np.sum(targets * (scores >= 0)) 331 | Np[Np == 0] = 1 332 | OP = np.sum(Nc) / np.sum(Np) 333 | OR = np.sum(Nc) / np.sum(Ng) 334 | OF1 = (2 * OP * OR) / (OP + OR) 335 | 336 | CP = np.sum(Nc / Np) / n_class 337 | CR = np.sum(Nc / Ng) / n_class 338 | CF1 = (2 * CP * CR) / (CP + CR) 339 | return OP, OR, OF1, CP, CR, CF1 340 | 341 | class AdjacencyHelper: 342 | 343 | @staticmethod 344 | def gen_A(num_classes, t1, t2, adj_file): 345 | import pickle 346 | result = pickle.load(open(adj_file, 'rb')) 347 | _adj = result['adj'] 348 | _nums = result['nums'] 349 | _nums = _nums[:, np.newaxis] 350 | _adj = _adj / _nums 351 | _adj[_adj >= t1] = 0 352 | _adj[_adj < t2] = 0 353 | _adj[(_adj >= t2) & (_adj < t1)] = 1 354 | _adj = _adj * 0.25 / (_adj.sum(0, keepdims=True) + 1e-6) 355 | # _adj = _adj + np.identity(num_classes, np.int) 356 | return _adj 357 | 358 | @staticmethod 359 | def gen_M(mods, dims=2048, t=0.999): 360 | n = mods.shape[0] 361 | m = int(max(mods)) - int(min(mods)) + 1 362 | if m <= 1: 363 | return np.ones((n, dims)) 364 | x = [] 365 | for i in range(n): 366 | for j in range(m): 367 | x.append(np.repeat(t if (mods[i] - int(min(mods))) == j else ((1 - t) / (m - 1)), dims)) 368 | x = np.concatenate(x) 369 | return x.reshape((n, dims * m)) 370 | 371 | @staticmethod 372 | def gen_adj(A): 373 | D = torch.pow(A.sum(1).float(), -0.5) 374 | D = torch.diag(D) 375 | adj = torch.matmul(torch.matmul(A, D).t(), D) 376 | return adj 377 | 378 | --------------------------------------------------------------------------------
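The topology helpers above are the core of how MGTN turns label co-occurrence statistics into graph structure: `gen_A` keeps only the edges whose conditional probability falls inside a threshold band (the `--adj-strong-threshold`/`--adj-weak-threshold` pair in `main.py`), and `gen_adj` applies the symmetric degree normalization used in `MGTNResnet.forward`. Below is a minimal sketch on made-up three-class statistics; the counts and the temporary file are purely illustrative, and it assumes it is run from the repository root so `util.py` can be imported.

```python
import os
import pickle
import tempfile

import numpy as np
import torch

from util import AdjacencyHelper

# Made-up co-occurrence statistics for 3 labels: 'adj' counts joint occurrences,
# 'nums' counts per-label occurrences (same structure as coco_adj.pkl).
stats = {'adj': np.array([[0., 8., 3.],
                          [8., 0., 5.],
                          [3., 5., 0.]]),
         'nums': np.array([10., 10., 10.])}

with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as f:
    pickle.dump(stats, f)
    adj_file = f.name

# Conditional probabilities P(j|i) = adj[i, j] / nums[i] are binarized per band:
# edges with P in [t2, t1) are kept, everything else is dropped, then re-weighted.
strong = AdjacencyHelper.gen_A(num_classes=3, t1=1.0, t2=0.4, adj_file=adj_file)
weak = AdjacencyHelper.gen_A(num_classes=3, t1=0.4, t2=0.2, adj_file=adj_file)
print(strong)   # edges for P(j|i) >= 0.4
print(weak)     # edges for 0.2 <= P(j|i) < 0.4

# MGTNResnet.forward adds self-loops and applies the D^-1/2 normalization.
A = torch.from_numpy(strong).float() + torch.eye(3)
print(AdjacencyHelper.gen_adj(A))

os.unlink(adj_file)
```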