├── .gitignore
├── README.md
├── configs
│   ├── coco_resnet101.yaml
│   ├── coco_ssgrl.yaml
│   ├── vg500_resnet101.yaml
│   ├── vg500_ssgrl.yaml
│   ├── voc2012_resnet101.yaml
│   └── voc2012_ssgrl.yaml
├── evaluate.py
├── lib
│   ├── data_loader.py
│   ├── dataset.py
│   ├── metrics.py
│   └── util.py
├── models
│   ├── __init__.py
│   ├── resnet101.py
│   ├── ssgrl.py
│   ├── ssgrl_backbone.py
│   └── ssgrl_utils.py
├── scripts
│   ├── coco.py
│   ├── label_analysis.py
│   ├── label_count.py
│   ├── preprocessing_ssgrl.py
│   ├── vg500.py
│   └── voc2012.py
└── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.vscode
*/__pycache__
data
initmodels
logs
checkpoints
checkpoints_backup
temp
tmp

train_coloss.py
infer.py

scripts/ssgrl_analysis.ipynb
scripts/.ipynb_checkpoints

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MultiLabelClassification

This is a multi-label classification codebase in PyTorch. It currently supports ResNet101 and SSGRL (an implementation of the paper "Learning Semantic-Specific Graph Representation for Multi-Label Image Recognition", based on the official repository [HCPLab-SYSU/SSGRL](https://github.com/HCPLab-SYSU/SSGRL)), with training on Pascal VOC 2012, COCO, and Visual Genome.

## Requirements
- Python 3.6
- PyTorch 1.1
- TorchVision 0.3

## Data preparation
Download the datasets and symlink their paths as follows:
```bash
mkdir data
ln -s /path/to/mscoco data/coco
ln -s /path/to/VisualGenome1.4 data/VisualGenome1.4
ln -s /path/to/VOC2012 data/VOC2012

mkdir tmp
ln -s /path/to/glove.840B.300d.txt tmp/glove.840B.300d.txt
```

Run the following scripts to preprocess the datasets and generate the data required by the SSGRL model.
```bash
python scripts/voc2012.py
python scripts/coco.py
python scripts/vg500.py

python scripts/preprocessing_ssgrl.py --data [voc2012, coco, vg500]
```

## Training

```bash
python train.py --config $cfg_file_path
```
For example, with the default optimizer (Adam) and loss (BCElogitloss), train the ResNet101 model on different datasets:
```bash
python train.py --config configs/coco_resnet101.yaml
python train.py --config configs/voc2012_resnet101.yaml
```
Train the SSGRL model on different datasets:
```bash
python train.py --config configs/coco_ssgrl.yaml
python train.py --config configs/voc2012_ssgrl.yaml
```

To resume training, run `train.py` with the `--resume` argument, as shown below.
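For instance, to continue an interrupted COCO run from the latest checkpoint (a minimal sketch; `--resume` is a boolean flag, and `Trainer.run` restores the epoch, global step, model/optimizer state, and best score from `ckpt_latest_path`):
```bash
python train.py --config configs/coco_resnet101.yaml --resume
```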
## Pretrained models
Pretrained models are provided on [Google Drive](https://drive.google.com/open?id=10Ex1hEWCZw8Gop0DN-kvnPVlVfuzTbll).

## Evaluation

```bash
python evaluate.py --config $cfg_file_path
```
For example:
```bash
python evaluate.py --config configs/vg500_resnet101.yaml
python evaluate.py --config configs/vg500_ssgrl.yaml
```

## Results
The pretrained multi-label classification models are evaluated with mean average precision (mAP); the results are reported as follows:

| models    | VOC2012 | COCO  | VG500 |
| --------- | ------- | ----- | ----- |
| ResNet101 | 0.901   | 0.802 | 0.293 |
| SSGRL     | 0.923   | 0.837 | 0.334 |

## Acknowledgements
Thanks to the official [SSGRL](https://github.com/HCPLab-SYSU/SSGRL) implementation and the awesome PyTorch team.
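## Notes on config files
The configs below use two custom YAML tags, `!cat` (concatenate strings with `_`) and `!join` (build a filesystem path), which `load_cfg` in `lib/util.py` registers before parsing. A minimal sketch of how they expand; the constructors are the ones from `load_cfg`, while the inline snippet is a shortened, illustrative version of `configs/coco_resnet101.yaml`:
```python
import os
import yaml

# The same constructors that load_cfg registers.
yaml.add_constructor('!cat', lambda loader, node: '_'.join(loader.construct_sequence(node)))
yaml.add_constructor('!join', lambda loader, node: os.path.join(*loader.construct_sequence(node)))

cfg = yaml.load("""
data: &data coco
model: &model resnet101
name: &name !cat [*data, *model]
log_dir: !join ['logs', *name]
""", Loader=yaml.Loader)

print(cfg['name'])     # coco_resnet101
print(cfg['log_dir'])  # logs/coco_resnet101
```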

--------------------------------------------------------------------------------
/configs/coco_resnet101.yaml:
--------------------------------------------------------------------------------
data: &data coco
model: &model resnet101
num_classes: 80
train_path: !join ['temp', *data, 'train.txt']
val_path: !join ['temp', *data, 'val.txt']
label_path: !join ['temp', *data, 'label.txt']

loss: BCElogitloss
optimizer: Adam
initmodel: ./initmodels/resnet101-5d3b4d8f.pth

batch_size: 32
scale_size: 640
crop_size: 576

num_workers: 8
max_epoch: 200
lr: 0.00001

topk: 3
threshold: 0.5

name: &name !cat [*data, *model]
log_dir: !join ['logs', *name]
ckpt_dir: &ckpt_dir !join ['checkpoints', *name]
ckpt_latest_path: !join [*ckpt_dir, 'latest_model.pth']
ckpt_best_path: !join [*ckpt_dir, 'best_model.pth']

output_dir: !join ['tmp', *name]

--------------------------------------------------------------------------------
/configs/coco_ssgrl.yaml:
--------------------------------------------------------------------------------
data: &data coco
model: &model ssgrl
num_classes: 80
train_path: !join ['temp', *data, 'train.txt']
val_path: !join ['temp', *data, 'val.txt']
label_path: !join ['temp', *data, 'label.txt']
embedding_path: !join ['temp', *data, 'embeddings.npy']
graph_path: !join ['temp', *data, 'graph.npy']

loss: BCElogitloss
optimizer: Adam
initmodel: ./initmodels/resnet101-5d3b4d8f.pth

batch_size: 8
scale_size: 640
crop_size: 576

num_workers: 8
max_epoch: 200
lr: 0.00001

topk: 3
threshold: 0.5

name: &name !cat [*data, *model]
log_dir: !join ['logs', *name]
ckpt_dir: &ckpt_dir !join ['checkpoints', *name]
ckpt_latest_path: !join [*ckpt_dir, 'latest_model.pth']
ckpt_best_path: !join [*ckpt_dir, 'best_model.pth']

output_dir: !join ['tmp', *name]

--------------------------------------------------------------------------------
/configs/vg500_resnet101.yaml:
--------------------------------------------------------------------------------
data: &data vg500
model: &model resnet101
num_classes: 500
train_path: !join ['temp', *data, 'train.txt']
val_path: !join ['temp', *data, 'val.txt']
label_path: !join ['temp', *data, 'label.txt']

loss: BCElogitloss
optimizer: Adam
initmodel: ./initmodels/resnet101-5d3b4d8f.pth

batch_size: 32
scale_size: 640
crop_size: 576

num_workers: 8
max_epoch: 200
lr: 0.00001

topk: 3
threshold: 0.5

name: &name !cat [*data, *model]
log_dir: !join ['logs', *name]
ckpt_dir: &ckpt_dir !join ['checkpoints', *name]
ckpt_latest_path: !join [*ckpt_dir, 'latest_model.pth']
ckpt_best_path: !join [*ckpt_dir, 'best_model.pth']

output_dir: !join ['tmp', *name]

--------------------------------------------------------------------------------
/configs/vg500_ssgrl.yaml:
--------------------------------------------------------------------------------
data: &data vg500
model: &model ssgrl
num_classes: 500
train_path: !join ['temp', *data, 'train.txt']
val_path: !join ['temp', *data, 'val.txt']
label_path: !join ['temp', *data, 'label.txt']
embedding_path: !join ['temp', *data, 'embeddings.npy']
graph_path: !join ['temp', *data, 'graph.npy']

loss: BCElogitloss
optimizer: Adam
initmodel: ./initmodels/resnet101-5d3b4d8f.pth

batch_size: 8
scale_size: 640
crop_size: 576

num_workers: 8
max_epoch: 200
lr: 0.00001

topk: 3
threshold: 0.5

name: &name !cat [*data, *model]
log_dir: !join ['logs', *name]
ckpt_dir: &ckpt_dir !join ['checkpoints', *name]
ckpt_latest_path: !join [*ckpt_dir, 'latest_model.pth']
ckpt_best_path: !join [*ckpt_dir, 'best_model.pth']

output_dir: !join ['tmp', *name]

--------------------------------------------------------------------------------
/configs/voc2012_resnet101.yaml:
--------------------------------------------------------------------------------
data: &data voc2012
model: &model resnet101
num_classes: 20
train_path: !join ['temp', *data, 'train.txt']
val_path: !join ['temp', *data, 'val.txt']
label_path: !join ['temp', *data, 'label.txt']

loss: BCElogitloss
optimizer: Adam
initmodel: ./initmodels/resnet101-5d3b4d8f.pth

batch_size: 32
scale_size: 640
crop_size: 576

num_workers: 8
max_epoch: 200
lr: 0.00001

topk: 3
threshold: 0.5

name: &name !cat [*data, *model]
log_dir: !join ['logs', *name]
ckpt_dir: &ckpt_dir !join ['checkpoints', *name]
ckpt_latest_path: !join [*ckpt_dir, 'latest_model.pth']
ckpt_best_path: !join [*ckpt_dir, 'best_model.pth']

output_dir: !join ['tmp', *name]

--------------------------------------------------------------------------------
/configs/voc2012_ssgrl.yaml:
--------------------------------------------------------------------------------
data: &data voc2012
model: &model ssgrl
num_classes: 20
train_path: !join ['temp', *data, 'train.txt']
val_path: !join ['temp', *data, 'val.txt']
label_path: !join ['temp', *data, 'label.txt']
embedding_path: !join ['temp', *data, 'embeddings.npy']
graph_path: !join ['temp', *data, 'graph.npy']

loss: BCElogitloss
optimizer: Adam
initmodel: ./initmodels/resnet101-5d3b4d8f.pth

batch_size: 8
scale_size: 640
crop_size: 576

num_workers: 8
max_epoch: 200
lr: 0.00001

topk: 3
threshold: 0.5

name: &name !cat [*data, *model]
log_dir: !join ['logs', *name]
ckpt_dir: &ckpt_dir !join ['checkpoints', *name]
ckpt_latest_path: !join [*ckpt_dir, 'latest_model.pth']
ckpt_best_path: !join [*ckpt_dir, 'best_model.pth']

output_dir: !join ['tmp', *name]

--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Created by: jasonseu
# Created on: 2021-1-19
# Email: zhuxuelin23@gmail.com
#
# Copyright © 2021 - CPSS Group
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import os
import json
import argparse
from argparse import Namespace
from tqdm import tqdm

import numpy as np
import torch
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader

from models import model_factory
from lib.util import *
from lib.metrics import *
from lib.dataset import MLDataset

torch.backends.cudnn.benchmark = True


class Evaluator(object):
    def __init__(self, args):
        super(Evaluator, self).__init__()
        self.args = args

        test_transform = transforms.Compose([
            transforms.Resize((args.scale_size, args.scale_size)),
            transforms.CenterCrop(args.crop_size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
        ml_dataset = MLDataset(args.val_path, args.label_path, test_transform)
        self.data = ml_dataset.data
        self.labels = ml_dataset.labels
        self.data_loader = DataLoader(
            dataset=ml_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            pin_memory=True
        )

        self.model = model_factory[args.model](args, args.num_classes)
        self.model.cuda()

        if args.loss == 'BCElogitloss':
            self.criterion = nn.BCEWithLogitsLoss()
        elif args.loss == 'tencentloss':
            self.criterion = TencentLoss(args.num_classes)
        elif args.loss == 'focalloss':
            self.criterion = FocalLoss()

        self.voc12_mAP = VOC12mAP(args.num_classes)
        self.average_topk_meter = TopkAverageMeter(args.num_classes, topk=args.topk)
        self.average_threshold_meter = ThresholdAverageMeter(args.num_classes, threshold=args.threshold)

    def run(self):
        model_dict = torch.load(self.args.ckpt_best_path)
        self.model.load_state_dict(model_dict)
        print('loading best checkpoint success')

        fw = open(os.path.join(self.args.output_dir, 'prediction.txt'), 'w')
        self.model.eval()
        self.voc12_mAP.reset()
        self.average_topk_meter.reset()
        self.average_threshold_meter.reset()
        desc = "EVALUATION - loss: {:.4f}"
        pbar = tqdm(total=len(self.data_loader), leave=False, desc=desc.format(0))
        with torch.no_grad():
            for i, batch in enumerate(self.data_loader):
                x, y = batch[0].cuda(), batch[1].cuda()
                pred_y = self.model(x)
                loss = self.criterion(pred_y, y)
                loss = loss.cpu().numpy()

                y = y.cpu().numpy()
                confidence = torch.sigmoid(pred_y)
                confidence = confidence.cpu().numpy()
                self.voc12_mAP.update(confidence, y)
                self.average_topk_meter.update(confidence, y)
                self.average_threshold_meter.update(confidence, y)

                topk_inds = np.argsort(-confidence)[:, :self.args.topk]
                i *= self.args.batch_size  # offset of this batch in the dataset (shuffle=False)
                for j in range(x.size(0)):
                    img_name = os.path.basename(self.data[i+j][0])
                    pred_labels = [self.labels[ind] for ind in topk_inds[j]]
                    fw.write('{}\t{}\n'.format(img_name, ' '.join(pred_labels)))
                pbar.desc = desc.format(loss)
                pbar.update(1)
        pbar.close()
        fw.close()

        ap_list, mAP = self.voc12_mAP.compute()
        self.average_topk_meter.compute()
        self.average_threshold_meter.compute()

        res = {
            'mAP': mAP,
            'ap_list': ap_list,
            'topk_cp': self.average_topk_meter.cp,
            'topk_cr': self.average_topk_meter.cr,
            'topk_cf1': self.average_topk_meter.cf1,
            'topk_op': self.average_topk_meter.op,
            'topk_or': self.average_topk_meter.or_,
            'topk_of1': self.average_topk_meter.of1,
            'threshold_cp': self.average_threshold_meter.cp,
            'threshold_cr': self.average_threshold_meter.cr,
            'threshold_cf1': self.average_threshold_meter.cf1,
            'threshold_op': self.average_threshold_meter.op,
            'threshold_or': self.average_threshold_meter.or_,
            'threshold_of1': self.average_threshold_meter.of1,
        }
        with open(os.path.join(self.args.output_dir, 'result.json'), 'w') as fw:
            json.dump(res, fw)

        print('model {} data {} mAP: {}'.format(self.args.model, self.args.data, mAP))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/voc2012_resnet101.yaml')
    args = parser.parse_args()
    cfg = load_cfg(args.config)
    args = Namespace(**cfg)
    print(args)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    evaluator = Evaluator(args)
    evaluator.run()

--------------------------------------------------------------------------------
/lib/data_loader.py:
--------------------------------------------------------------------------------
from torch.utils.data import DataLoader
from torchvision import transforms


def get_transform(args, split):
    scale_size = args.scale_size
    crop_size = args.crop_size

    if split == 'train':
        transform = transforms.Compose([
            transforms.Resize((scale_size, scale_size)),
            transforms.RandomChoice([
                transforms.RandomCrop(640),
                transforms.RandomCrop(576),
                transforms.RandomCrop(512),
                transforms.RandomCrop(384),
                transforms.RandomCrop(320)
            ]),
            transforms.Resize((crop_size, crop_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    else:
        transform = transforms.Compose([
            transforms.Resize((scale_size, scale_size)),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    return transform


def get_loader(dataset, args, split):
    transform = get_transform(args, split)
    dataset.transform = transform

    shuffle = True if split == 'train' else False
    data_loader = DataLoader(
        dataset=dataset,
        num_workers=args.num_workers,
        batch_size=args.batch_size,
        shuffle=shuffle
    )

    return data_loader

--------------------------------------------------------------------------------
/lib/dataset.py:
--------------------------------------------------------------------------------
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Created by: jasonseu
# Created on: 2021-1-19
# Email: zhuxuelin23@gmail.com
#
# Copyright © 2021 - CPSS Group
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import numpy as np
from PIL import Image
from torch.utils.data import Dataset


class MLDataset(Dataset):
    def __init__(self, data_path, label_path, transform):
        super(MLDataset, self).__init__()

        self.labels = [line.strip() for line in open(label_path)]
        self.num_classes = len(self.labels)
        self.label2id = {label: i for i, label in enumerate(self.labels)}

        self.data = []
        with open(data_path, 'r') as fr:
            for line in fr.readlines():
                image_path, image_label = line.strip().split('\t')
                image_label = [self.label2id[l] for l in image_label.split(',')]
                self.data.append([image_path, image_label])
        self.transform = transform

    def __getitem__(self, index):
        image_path, image_label = self.data[index]
        image_data = Image.open(image_path).convert('RGB')
        x = self.transform(image_data)

        # multi-hot encoding for the label set
        y = np.zeros(self.num_classes).astype(np.float32)
        y[image_label] = 1.0
        return x, y

    def __len__(self):
        return len(self.data)

--------------------------------------------------------------------------------
/lib/metrics.py:
--------------------------------------------------------------------------------
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Created by: jasonseu
# Created on: 2021-3-9
# Email: zhuxuelin23@gmail.com
#
# Copyright © 2021 - CPSS Group
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import abc
import numpy as np


class VOC12mAP(object):
    def __init__(self, num_classes):
        super(VOC12mAP, self).__init__()
        self._num_classes = num_classes

    def reset(self):
        self._predicted = np.array([], dtype=np.float32).reshape(0, self._num_classes)
        self._gt_label = np.array([], dtype=np.float32).reshape(0, self._num_classes)

    def update(self, predicted, gt_label):
        self._predicted = np.vstack((self._predicted, predicted))
        self._gt_label = np.vstack((self._gt_label, gt_label))

    def compute(self):
        return self._voc12_mAP()

    def _voc12_mAP(self):
        sample_num, num_classes = self._gt_label.shape
        ap_list = []

        for class_id in range(num_classes):
            confidence = self._predicted[:, class_id]
            sorted_ind = np.argsort(-confidence)
            sorted_label = self._gt_label[sorted_ind, class_id]

            tp = (sorted_label == 1).astype(np.int64)  # true positive
            fp = (sorted_label == 0).astype(np.int64)  # false positive
            tp_num = max(sum(tp), np.finfo(np.float64).eps)
            tp = np.cumsum(tp)
            fp = np.cumsum(fp)
            recall = tp / float(tp_num)
            precision = tp / np.arange(1, sample_num + 1, dtype=np.float64)

            ap = self._voc_AP(recall, precision, tp_num)  # average precision
            ap_list.append(ap)

        mAP = np.mean(ap_list)  # mean average precision
        return ap_list, mAP

    def _voc_AP(self, recall, precision, tp_num):
        mrec = np.concatenate(([0.], recall, [1.]))
        mpre = np.concatenate(([0.], precision, [0.]))
        for i in range(mpre.size - 1, 0, -1):
            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
        i = np.where(mrec[1:] != mrec[:-1])[0]
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
        return ap


class AverageLoss(object):
    def __init__(self):
        super(AverageLoss, self).__init__()

    def reset(self):
        self._sum = 0
        self._counter = 0

    def update(self, loss, n=1):
        self._sum += loss * n
        self._counter += n

    def compute(self):
        return self._sum / self._counter


class AverageMeter(object):
    def __init__(self, num_classes):
        super(AverageMeter, self).__init__()
        self.num_classes = num_classes

    def reset(self):
        self._right_pred_counter = np.zeros(self.num_classes)  # correctly predicted image per-class counter
        self._pred_counter = np.zeros(self.num_classes)        # predicted image per-class counter
        self._gt_counter = np.zeros(self.num_classes)          # ground-truth image per-class counter

    def update(self, confidence, gt_label):
        self._count(confidence, gt_label)

    def compute(self):
        self._op = sum(self._right_pred_counter) / sum(self._pred_counter)
        self._or = sum(self._right_pred_counter) / sum(self._gt_counter)
        self._of1 = 2 * self._op * self._or / (self._op + self._or)
        self._right_pred_counter = np.maximum(self._right_pred_counter, np.finfo(np.float64).eps)
        self._pred_counter = np.maximum(self._pred_counter, np.finfo(np.float64).eps)
        self._gt_counter = np.maximum(self._gt_counter, np.finfo(np.float64).eps)
        self._cp = np.mean(self._right_pred_counter / self._pred_counter)
        self._cr = np.mean(self._right_pred_counter / self._gt_counter)
        self._cf1 = 2 * self._cp * self._cr / (self._cp + self._cr)

    @abc.abstractmethod
    def _count(self, confidence, gt_label):
        pass

    @property
    def op(self):  # overall precision
        return self._op

    @property  # overall recall
    def or_(self):
        return self._or

    @property  # overall F1
    def of1(self):
        return self._of1

    @property  # per-class precision
    def cp(self):
        return self._cp

    @property  # per-class recall
    def cr(self):
        return self._cr

    @property  # per-class F1
    def cf1(self):
        return self._cf1


class TopkAverageMeter(AverageMeter):
    def __init__(self, num_classes, topk=3):
        super(TopkAverageMeter, self).__init__(num_classes)
        self.topk = topk

    def _count(self, confidence, gt_label):
        sample_num = confidence.shape[0]
        sorted_inds = np.argsort(-confidence, axis=-1)
        for i in range(sample_num):
            sample_gt_label = gt_label[i]
            topk_inds = sorted_inds[i][:self.topk]
            self._gt_counter[sample_gt_label == 1] += 1
            self._pred_counter[topk_inds] += 1
            correct_inds = topk_inds[sample_gt_label[topk_inds] == 1]
            self._right_pred_counter[correct_inds] += 1


class ThresholdAverageMeter(AverageMeter):
    def __init__(self, num_classes, threshold=0.5):
        super(ThresholdAverageMeter, self).__init__(num_classes)
        self.threshold = threshold

    def _count(self, confidence, gt_label):
        sample_num = confidence.shape[0]
        for i in range(sample_num):
            sample_gt_label = gt_label[i]
            self._gt_counter[sample_gt_label == 1] += 1
            inds = np.argwhere(confidence[i] > self.threshold).squeeze(-1)
            self._pred_counter[inds] += 1
            correct_inds = inds[sample_gt_label[inds] == 1]
            self._right_pred_counter[correct_inds] += 1
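A quick usage note: `VOC12mAP` accumulates confidence scores and multi-hot ground-truth labels across batches, then computes per-class interpolated average precision and its mean. A minimal sketch with toy arrays (hypothetical values; run from the repository root so `lib` is importable):

```python
import numpy as np
from lib.metrics import VOC12mAP

# Toy scores for 4 samples over 2 classes, with multi-hot ground truth.
scores = np.array([[0.9, 0.1],
                   [0.8, 0.7],
                   [0.3, 0.6],
                   [0.2, 0.4]], dtype=np.float32)
labels = np.array([[1, 0],
                   [1, 1],
                   [0, 1],
                   [0, 0]], dtype=np.float32)

meter = VOC12mAP(num_classes=2)
meter.reset()                  # reset() must be called before the first update()
meter.update(scores, labels)   # may be called once per batch
ap_list, mAP = meter.compute()
print(ap_list, mAP)            # per-class APs and their mean
```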

--------------------------------------------------------------------------------
/lib/util.py:
--------------------------------------------------------------------------------
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Created by: jasonseu
# Created on: 2021-1-19
# Email: zhuxuelin23@gmail.com
#
# Copyright © 2021 - CPSS Group
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import os
import math
import yaml

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np


class EarlyStopping(object):
    def __init__(self, patience):
        super(EarlyStopping, self).__init__()
        self.patience = patience
        self.counter = 0
        self.best_score = None

    def __call__(self, score):
        is_save, is_terminate = True, False
        if self.best_score is None:
            self.best_score = score
        elif self.best_score >= score:
            self.counter += 1
            if self.counter >= self.patience:
                is_terminate = True
            is_save = False
        else:
            self.best_score = score
            self.counter = 0
        return is_save, is_terminate

class TencentLoss(object):
    def __init__(self, class_num, pos_weight=12.0):
        super(TencentLoss, self).__init__()
        self.pos_weight = torch.FloatTensor(class_num).fill_(pos_weight).cuda()
        self.pre_status = torch.IntTensor(class_num).fill_(-1).cuda()
        self.t = None

    def __call__(self, input, target):
        r = self._get_adaptive_weight(target)
        output = F.binary_cross_entropy_with_logits(input, target, weight=r, pos_weight=self.pos_weight)
        return output

    def _get_adaptive_weight(self, target):
        class_status = torch.sum(target, dim=0)
        cur_status = class_status > torch.tensor(0.0).cuda()
        cur_status = cur_status.type_as(self.pre_status)
        if torch.all(torch.eq(self.pre_status, cur_status)):
            self.t += 1
        else:
            self.t = 1
            self.pre_status = cur_status

        pos_r = max(0.01, math.log10(10 / (0.01 + self.t)))
        neg_r = max(0.01, math.log10(10 / (8 + self.t)))
        pos_r = target.clone().fill_(pos_r)
        neg_r = target.clone().fill_(neg_r)

        r = torch.where(target == 1, pos_r, neg_r)
        return r


class FocalLoss(object):
    def __init__(self, alpha=0.5, gamma=2.0):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def __call__(self, input, target):
        input_prob = torch.sigmoid(input)
        hard_easy_weight = (1 - input_prob) * target + input_prob * (1 - target)
        posi_nega_weight = self.alpha * target + (1 - self.alpha) * (1 - target)
        focal_weight = (posi_nega_weight * torch.pow(hard_easy_weight, self.gamma)).detach()
        focal_loss = F.binary_cross_entropy_with_logits(input, target, weight=focal_weight)
        return focal_loss


class CoocurrenceLoss(object):
    def __init__(self, label_comatrix_path) -> None:
        super().__init__()
        label_comatrix = np.load(label_comatrix_path).astype(np.float32)
        self.gt = torch.flatten(torch.tensor(label_comatrix)).cuda()
        self.kl_loss = nn.KLDivLoss()

    def __call__(self, input):
        temp = torch.bmm(input.unsqueeze(2), input.unsqueeze(1))
        batch_size = temp.size(0)
        temp = torch.mean(temp.view(batch_size, -1), dim=0)
        coloss = self.kl_loss(temp, self.gt)
        return coloss


def load_cfg(cfg_path):
    yaml.add_constructor('!cat', lambda loader, node: '_'.join(loader.construct_sequence(node)))
    yaml.add_constructor('!join', lambda loader, node: os.path.join(*loader.construct_sequence(node)))
    with open(cfg_path, 'r') as fr:
        cfg = yaml.load(fr, Loader=yaml.Loader)
    return cfg

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from models.resnet101 import ResNet101
from models.ssgrl import SSGRL

model_factory = {
    'resnet101': ResNet101,
    'ssgrl': SSGRL
}

--------------------------------------------------------------------------------
/models/resnet101.py:
--------------------------------------------------------------------------------
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Created by: jasonseu
# Created on: 2021-1-19
# Email: zhuxuelin23@gmail.com
#
# Copyright © 2021 - CPSS Group
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import torch
from torch import nn
from torchvision import models


class ResNet101(nn.Module):

    def __init__(self, args, num_labels):
        super(ResNet101, self).__init__()
        self.network = models.resnet101(pretrained=False, num_classes=num_labels)

        print('loading pretrained model from imagenet')
        model_dict = self.network.state_dict()
        resnet_pretrained = torch.load(args.initmodel)
        pretrain_dict = {k: v for k, v in resnet_pretrained.items() if not k.startswith('fc')}
        model_dict.update(pretrain_dict)
        self.network.load_state_dict(model_dict)

        # freeze everything except layer4 and the classifier; note that
        # requires_grad must be set on the parameters, not on the modules
        for param in self.network.parameters():
            param.requires_grad = False
        for param in self.network.layer4.parameters():
            param.requires_grad = True
        for param in self.network.fc.parameters():
            param.requires_grad = True

    def forward(self, x):
        x = self.network(x)
        return x

--------------------------------------------------------------------------------
/models/ssgrl.py:
--------------------------------------------------------------------------------
import numpy as np

import torch
import torch.nn as nn

from models.ssgrl_backbone import resnet101
from models.ssgrl_utils import Semantic, GGNN, Element_Wise_Layer


class SSGRL(nn.Module):
    def __init__(self, args, num_classes=80):
        super(SSGRL, self).__init__()
        self.args = args

        self.num_classes = num_classes
        self.image_feature_dim = 2048
        self.output_dim = 2048
        self.word_feature_dim = 300
        self.word_file = self.args.embedding_path
        self.graph_file = self.args.graph_path
        self.time_step = 3

        self._word_features = self._load_features()
        self._in_matrix, self._out_matrix = self.load_matrix()

        self.word_semantic = Semantic(
            num_classes=self.num_classes,
            image_feature_dim=self.image_feature_dim,
            word_feature_dim=self.word_feature_dim
        )

        self.graph_net = GGNN(
            input_dim=self.image_feature_dim,
            time_step=self.time_step,
            in_matrix=self._in_matrix,
            out_matrix=self._out_matrix
        )

        self.fc_output = nn.Linear(2 * self.image_feature_dim, self.output_dim)
        self.classifiers = Element_Wise_Layer(self.num_classes, self.output_dim)

        self.resnet_101 = resnet101()
        self._load_pretrain_model()
        for param in self.resnet_101.parameters():
            param.requires_grad = False
        for param in self.resnet_101.layer4.parameters():
            param.requires_grad = True

    def forward(self, x):
        batch_size = x.size()[0]
        img_feature_map = self.resnet_101(x)
        graph_net_input = self.word_semantic(batch_size, img_feature_map, self._word_features)
        graph_net_feature = self.graph_net(graph_net_input)

        output = torch.cat((graph_net_feature.view(batch_size * self.num_classes, -1),
                            graph_net_input.view(-1, self.image_feature_dim)), 1)
        output = self.fc_output(output)
        output = torch.tanh(output)
        output = output.contiguous().view(batch_size, self.num_classes, self.output_dim)
        result = self.classifiers(output)
        return result

    def _load_pretrain_model(self):
        model_dict = self.resnet_101.state_dict()
        print('loading pretrained model from imagenet')
        resnet_pretrained = torch.load(self.args.initmodel)
        pretrain_dict = {k: v for k, v in resnet_pretrained.items() if not k.startswith('fc')}
        model_dict.update(pretrain_dict)
        self.resnet_101.load_state_dict(model_dict)

    def _load_features(self):
        return torch.from_numpy(np.load(self.word_file).astype(np.float32)).cuda()

    def load_matrix(self):
        mat = np.load(self.graph_file)
        _in_matrix, _out_matrix = mat.astype(np.float32), mat.T.astype(np.float32)
        _in_matrix = torch.from_numpy(_in_matrix).cuda()
        _out_matrix = torch.from_numpy(_out_matrix).cuda()
        return _in_matrix, _out_matrix

--------------------------------------------------------------------------------
/models/ssgrl_backbone.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.utils.model_zoo as model_zoo


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=80, avg_pool_kernel_size=7):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # the classification head is removed: SSGRL consumes the spatial
        # feature map, so only a stride-2 average pooling is applied
        self.avgpool1 = nn.AvgPool2d(2, stride=2)
        # self.avgpool2 = nn.AvgPool2d(avg_pool_kernel_size, stride=1)
        # self.fc = nn.Linear(8192, num_classes)
        # self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool1(x)
        # x = self.avgpool2(x)
        # x = x.view(x.size(0), -1)
        # x = self.fc(x)

        return x


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
    return model


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    # pretrained weights are loaded by SSGRL._load_pretrain_model instead
    # if pretrained:
    #     model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model

--------------------------------------------------------------------------------
/models/ssgrl_utils.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class Semantic(nn.Module):
    def __init__(self, num_classes, image_feature_dim, word_feature_dim, intermediary_dim=1024):
        super(Semantic, self).__init__()
        self.num_classes = num_classes
        self.image_feature_dim = image_feature_dim
        self.word_feature_dim = word_feature_dim
        self.intermediary_dim = intermediary_dim
        self.fc_1 = nn.Linear(self.image_feature_dim, self.intermediary_dim, bias=False)
        self.fc_2 = nn.Linear(self.word_feature_dim, self.intermediary_dim, bias=False)
        self.fc_3 = nn.Linear(self.intermediary_dim, self.intermediary_dim)
        self.fc_a = nn.Linear(self.intermediary_dim, 1)

    def forward(self, batch_size, img_feature_map, word_features):
        # temp, t = self.ex(batch_size, img_feature_map, word_features)
        # img_feature_map: 2 x 2048 x 3 x 3
        convsize = img_feature_map.size(3)

        f_wh_feature = img_feature_map.permute((0, 2, 3, 1)).contiguous().view(batch_size * convsize * convsize, -1)
        f_wh_feature = self.fc_1(f_wh_feature)  # 18 x 1024
        f_wd_feature = self.fc_2(word_features).view(self.num_classes, 1, self.intermediary_dim)  # 80 x 1024 -> 80 x 1 x 1024
        lb_feature = self.fc_3(torch.tanh(f_wd_feature * f_wh_feature).view(-1, self.intermediary_dim))  # 80 x 18 x 1024 -> 1440 x 1024

        coefficient = self.fc_a(lb_feature)  # 1440 x 1
        coefficient = coefficient.view(self.num_classes, batch_size, -1).transpose(0, 1)  # 80 x 2 x 9 -> 2 x 80 x 9
        # b = torch.all(torch.eq(coefficient, t))

        coefficient = F.softmax(coefficient, dim=2)  # 2 x 80 x 9
        img_feature_map = img_feature_map.permute(0, 2, 3, 1).view(batch_size, convsize * convsize, -1)  # 2 x 3 x 3 x 2048 -> 2 x 9 x 2048

        graph_net_input = torch.bmm(coefficient, img_feature_map)  # 2 x 80 x 2048
        return graph_net_input

    # def ex(self, batch_size, img_feature_map, word_features):
    #     convsize = img_feature_map.size()[3]

    #     img_feature_map = torch.transpose(torch.transpose(img_feature_map, 1, 2), 2, 3)
    #     f_wh_feature = img_feature_map.contiguous().view(batch_size*convsize*convsize, -1)  # 18 x 2048
    #     f_wh_feature = self.fc_1(f_wh_feature).view(batch_size*convsize*convsize, 1, -1).repeat(1, self.num_classes, 1)  # 18 x 80 x 1024

    #     f_wd_feature = self.fc_2(word_features).view(1, self.num_classes, 1024).repeat(batch_size*convsize*convsize, 1, 1)  # 18 x 80 x 1024
    #     lb_feature = self.fc_3(torch.tanh(f_wh_feature*f_wd_feature).view(-1, 1024))  # 18 x 80 x 1024 -> 1440 x 1024
    #     coefficient = self.fc_a(lb_feature)  # 1440 x 1

    #     t = self.fc_a(self.fc_3(torch.tanh(f_wh_feature*f_wd_feature).transpose(0, 1).contiguous().view(-1, 1024)))
    #     # 1440 x 1 -> 2 x 3 x 3 x 80 -> 2 x 80 x 3 x 3 -> 2 x 80 x 9
    #     coefficient = torch.transpose(torch.transpose(coefficient.view(batch_size, convsize, convsize, self.num_classes), 2, 3), 1, 2).view(batch_size, self.num_classes, -1)
    #     t = coefficient
    #     coefficient = F.softmax(coefficient, dim=2)
    #     coefficient = coefficient.view(batch_size, self.num_classes, convsize, convsize)  # 2 x 80 x 3 x 3
    #     coefficient = torch.transpose(torch.transpose(coefficient, 1, 2), 2, 3)  # 2 x 3 x 3 x 80
    #     coefficient = coefficient.view(batch_size, convsize, convsize, self.num_classes, 1).repeat(1, 1, 1, 1, self.image_feature_dim)  # 2 x 3 x 3 x 80 x 2048
    #     img_feature_map = img_feature_map.view(batch_size, convsize, convsize, 1, self.image_feature_dim).repeat(1, 1, 1, self.num_classes, 1) * coefficient  # 2 x 3 x 3 x 80 x 2048
    #     graph_net_input = torch.sum(torch.sum(img_feature_map, 1), 1)  # 2 x 80 x 2048
    #     return graph_net_input, t


class GGNN(nn.Module):
    def __init__(self, input_dim, time_step, in_matrix, out_matrix):
        super(GGNN, self).__init__()
        self.input_dim = input_dim
        self.time_step = time_step
        self._in_matrix = in_matrix
        self._out_matrix = out_matrix

        self.fc_eq3_w = nn.Linear(2 * input_dim, input_dim)
        self.fc_eq3_u = nn.Linear(input_dim, input_dim)
        self.fc_eq4_w = nn.Linear(2 * input_dim, input_dim)
        self.fc_eq4_u = nn.Linear(input_dim, input_dim)
        self.fc_eq5_w = nn.Linear(2 * input_dim, input_dim)
        self.fc_eq5_u = nn.Linear(input_dim, input_dim)

    def forward(self, input):
        batch_size = input.size(0)
        input = input.view(-1, self.input_dim)
        node_num = self._in_matrix.size(0)
        batch_aog_nodes = input.view(batch_size, node_num, self.input_dim)
        batch_in_matrix = self._in_matrix.repeat(batch_size, 1).view(batch_size, node_num, -1)
        batch_out_matrix = self._out_matrix.repeat(batch_size, 1).view(batch_size, node_num, -1)
        for t in range(self.time_step):
            # eq(2): aggregate messages over incoming and outgoing edges
            av = torch.cat((torch.bmm(batch_in_matrix, batch_aog_nodes), torch.bmm(batch_out_matrix, batch_aog_nodes)), 2)
            av = av.view(batch_size * node_num, -1)
            flatten_aog_nodes = batch_aog_nodes.view(batch_size * node_num, -1)
            # eq(3): update gate
            zv = torch.sigmoid(self.fc_eq3_w(av) + self.fc_eq3_u(flatten_aog_nodes))
            # eq(4): reset gate
            rv = torch.sigmoid(self.fc_eq4_w(av) + self.fc_eq4_u(flatten_aog_nodes))
            # eq(5): candidate state
            hv = torch.tanh(self.fc_eq5_w(av) + self.fc_eq5_u(rv * flatten_aog_nodes))

            flatten_aog_nodes = (1 - zv) * flatten_aog_nodes + zv * hv
            batch_aog_nodes = flatten_aog_nodes.view(batch_size, node_num, -1)
        return batch_aog_nodes


class Element_Wise_Layer(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(Element_Wise_Layer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.Tensor(in_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        for i in range(self.in_features):
            self.weight[i].data.uniform_(-stdv, stdv)
        if self.bias is not None:
            for i in range(self.in_features):
                self.bias[i].data.uniform_(-stdv, stdv)

    def forward(self, input):
        x = input * self.weight
        x = torch.sum(x, 2)
        if self.bias is not None:
            x = x + self.bias
        return x

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)

--------------------------------------------------------------------------------
/scripts/coco.py:
--------------------------------------------------------------------------------
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Created by: jasonseu
# Created on: 2021-1-19
# Email: zhuxuelin23@gmail.com
#
# Copyright © 2021 - CPSS Group
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import os
import json
from collections import defaultdict

coco_dir = 'data/coco'
save_dir = 'temp/coco'
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

data = json.load(open(os.path.join(coco_dir, 'annotations/instances_train2014.json')))
categories = []
catId2catName = {}
for line in data['categories']:
    catId2catName[line['id']] = line['name']
    categories.append(line['name'])
imgId2imgName = {}
for line in data['images']:
    imgId2imgName[line['id']] = line['file_name']

imgName2catName = defaultdict(list)
for line in data['annotations']:
    cat_id = line['category_id']
    cat_name = catId2catName[cat_id]
    img_id = line['image_id']
    img_name = imgId2imgName[img_id]
    img_path = os.path.join(coco_dir, 'train2014', img_name)
    if cat_name not in imgName2catName[img_path]:
        imgName2catName[img_path].append(cat_name)

train_data = ['{}\t{}\n'.format(k, ','.join(v)) for k, v in imgName2catName.items()]
print(f"total training data number: {len(train_data)}")

with open(os.path.join(save_dir, 'label.txt'), 'w') as fw:
    fw.writelines(['{}\n'.format(x) for x in categories])
with open(os.path.join(save_dir, 'train.txt'), 'w') as fw:
    fw.writelines(train_data)


data = json.load(open(os.path.join(coco_dir, 'annotations/instances_val2014.json')))
catId2catName = {}
for line in data['categories']:
    catId2catName[line['id']] = line['name']
imgId2imgName = {}
for line in data['images']:
    imgId2imgName[line['id']] = line['file_name']

imgName2catName = defaultdict(list)
for line in data['annotations']:
    cat_id = line['category_id']
    cat_name = catId2catName[cat_id]
    img_id = line['image_id']
    img_name = imgId2imgName[img_id]
    img_path = os.path.join(coco_dir, 'val2014', img_name)
    if cat_name not in imgName2catName[img_path]:
        imgName2catName[img_path].append(cat_name)

imgName2catName = ['{}\t{}\n'.format(k, ','.join(v)) for k, v in imgName2catName.items()]
print(f"total test data number: {len(imgName2catName)}")

with open(os.path.join(save_dir, 'val.txt'), 'w') as fw:
    fw.writelines(imgName2catName)

--------------------------------------------------------------------------------
/scripts/label_analysis.py:
--------------------------------------------------------------------------------
import os
import argparse

import numpy as np
from matplotlib import pyplot as plt


def draw_heatmap(matrix, num_labels, name):
    xLabel = list(range(num_labels))
    yLabel = list(range(num_labels))
    fig = plt.figure()
    ax = fig.add_subplot(111)
    # ax.set_yticks(range(num_labels))
    # ax.set_yticklabels(yLabel)
    # ax.set_xticks(range(num_labels))
    # ax.set_xticklabels(xLabel)
    im = ax.imshow(matrix, cmap=plt.cm.hot_r)
    plt.colorbar(im)
    plt.savefig('tmp/{}_heatmap.jpg'.format(name))


def main(name):
    data_path = os.path.join('temp', name, 'train.txt')
    label_path = os.path.join('temp', name, 'label.txt')
    label_list = [line.strip() for line in open(label_path)]
    label2id = {l: i for i, l in enumerate(label_list)}
    num_labels = len(label_list)
    coocurrence_matrix = np.zeros((num_labels, num_labels))
    for line in open(data_path):
        temp = line.strip().split('\t')[-1].split(',')
        labelid_list = [label2id[t] for t in temp]
        for i in range(len(labelid_list)):
            for j in range(i + 1, len(labelid_list)):
                x = labelid_list[i]
                y = labelid_list[j]
                coocurrence_matrix[x, y] += 1
                coocurrence_matrix[y, x] += 1
    draw_heatmap(coocurrence_matrix, num_labels, name)
    np.save(os.path.join('temp', name, 'label_coocurrence.npy'), coocurrence_matrix)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default='coco')
    args = parser.parse_args()
    main(args.data)

--------------------------------------------------------------------------------
/scripts/label_count.py:
--------------------------------------------------------------------------------
label_num = {}
with open('temp/voc2012/train.txt', 'r') as fr:
    for line in fr:
        temp = line.strip().split('\t')[-1].split(',')
        for t in temp:
            if t in label_num.keys():
                label_num[t] += 1
            else:
                label_num[t] = 1

print('voc2012 per-label sample counts:', label_num)


# sample_num = 0
# label_num = 0
# with open('temp/coco/train.txt', 'r') as fr:
#     for line in fr:
#         temp = line.strip().split('\t')[-1].split(',')
#         label_num += len(temp)
#         sample_num += 1

# print('coco average label number:', label_num / sample_num)


# sample_num = 0
# label_num = 0
# with open('temp/visual_genome/vg500_train.txt', 'r') as fr:
#     for line in fr:
#         temp = line.strip().split('\t')[-1].split(',')
#         label_num += len(temp)
#         sample_num += 1

# print('visual_genome average label number:', label_num / sample_num)

--------------------------------------------------------------------------------
/scripts/preprocessing_ssgrl.py:
--------------------------------------------------------------------------------
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Created by: jasonseu
# Created on: 2021-3-9
# Email: zhuxuelin23@gmail.com
#
# Copyright © 2021 - CPSS Group
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import os
import argparse
import numpy as np


glove_path = 'tmp/glove.840B.300d.txt'

# generate the label adjacency matrix and category embeddings for SSGRL
def preprocessing_for_ssgrl(data):
    dir_name = os.path.join('temp', data)

    label_path = os.path.join(dir_name, 'label.txt')
    train_path = os.path.join(dir_name, 'train.txt')
    val_path = os.path.join(dir_name, 'val.txt')
    graph_path = os.path.join(dir_name, 'graph.npy')
    embed_path = os.path.join(dir_name, 'embeddings.npy')

    categories = [line.strip() for line in open(label_path).readlines()]
    cate2id = {cat: i for i, cat in enumerate(categories)}
    adjacency_matrix = np.zeros((len(categories), len(categories)))

    with open(train_path, 'r') as fr:
        samples = [line.strip().split('\t')[1].split(',') for line in fr.readlines()]
    with open(val_path, 'r') as fr:
        samples.extend([line.strip().split('\t')[1].split(',') for line in fr.readlines()])

    for temp in samples:
        for i in temp:
            for j in temp:
                adjacency_matrix[cate2id[i], cate2id[j]] += 1

    for i in range(adjacency_matrix.shape[0]):
        adjacency_matrix[i] = adjacency_matrix[i] / adjacency_matrix[i, i]
        adjacency_matrix[i, i] = 0.0

    np.save(graph_path, adjacency_matrix)

    # generate category embeddings
    with open(glove_path, 'r') as fr:
        embeddings = dict([line.split(' ', 1) for line in fr.readlines()])

    data_embeddings = []
    for cat in categories:
        # the pretrained GloVe vocabulary lacks these single-token VOC2012 labels
        if cat == 'diningtable':
            cat = 'dining table'
        if cat == 'tvmonitor':
            cat = 'tv monitor'
        if cat == 'pottedplant':
            cat = 'potted plant'
        # a category with two or more words (e.g. 'traffic light') is split and
        # the per-word embeddings are averaged
        temp = np.array([list(map(float, embeddings[t].split())) for t in cat.split()])
        if temp.shape[0] > 1:
            temp = temp.mean(axis=0, keepdims=True)
        data_embeddings.append(temp[0])

    data_embeddings = np.array(data_embeddings)
    np.save(embed_path, data_embeddings)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, required=True, choices=['coco', 'voc2012', 'vg500'])
    args = parser.parse_args()

    preprocessing_for_ssgrl(args.data)

--------------------------------------------------------------------------------
/scripts/vg500.py:
--------------------------------------------------------------------------------
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Created by: jasonseu
# Created on: 2021-1-19
# Email: zhuxuelin23@gmail.com
#
# Copyright © 2021 - CPSS Group
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import os
import json
import random
from random import shuffle

from collections import Counter

random.seed(202)

target_dir = 'temp/vg500'
image_dir1 = 'data/VisualGenome1.4/VG_100K'
image_dir2 = 'data/VisualGenome1.4/VG_100K_2'
if not os.path.exists(target_dir):
    os.makedirs(target_dir)

data = json.load(open('data/VisualGenome1.4/objects.json'))
print('total image number: {}'.format(len(data)))

tags = [temp['names'][0] for line in data for temp in line['objects']]
counter = Counter()
counter.update(tags)
print('total tag number: {}'.format(len(tags)))
print('total unique tag number: {}'.format(len(counter)))

tags500 = [k for k, _ in counter.most_common()[:500]]
tag_set = set(tags500)  # set for fast membership tests

vg500_dict = []
for line in data:
    temp = {}
    temp['image_id'] = line['image_id']
    temp['objects'] = []
    for obj in line['objects']:
        if obj['names'][0] in tag_set and obj['names'][0] not in temp['objects']:
            temp['objects'].append(obj['names'][0])
    if len(temp['objects']) > 0:
        vg500_dict.append(temp)

vg500 = []
for item in vg500_dict:
    labels = ','.join(item['objects'])
    img_name = '{}.jpg'.format(item['image_id'])
    img_path = os.path.join(image_dir1, img_name)
    if not os.path.exists(img_path):
        img_path = os.path.join(image_dir2, img_name)
    if not os.path.exists(img_path):
        raise Exception('file {} not found!'.format(img_path))
    vg500.append('{}\t{}\n'.format(img_path, labels))

shuffle(vg500)
train_num = int(len(vg500) * 0.8)
train_split = vg500[:train_num]
test_split = vg500[train_num:]

print('total number of train dataset: {}'.format(len(train_split)))
# print('total number of validation dataset: {}'.format(len(validation_split)))
print('total number of test dataset: {}'.format(len(test_split)))

with open(os.path.join(target_dir, 'train.txt'), 'w') as fw:
    fw.writelines(train_split)
# with open('data/vg500/val.txt', 'w') as fw:
#     fw.writelines(validation_split)
with open(os.path.join(target_dir, 'val.txt'), 'w') as fw:
    fw.writelines(test_split)
with open(os.path.join(target_dir, 'label.txt'), 'w') as fw:
    fw.writelines(['{}\n'.format(t) for t in tags500])

--------------------------------------------------------------------------------
/scripts/voc2012.py:
--------------------------------------------------------------------------------
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Created by: jasonseu
# Created on: 2021-1-20
# Email: zhuxuelin23@gmail.com
#
# Copyright © 2021 - CPSS Group
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import os
from xml.dom.minidom import parse


data_dir = 'data/VOC2012'
anno_dir = os.path.join(data_dir, 'Annotations')
image_dir = os.path.join(data_dir, 'JPEGImages')
target_dir = 'temp/voc2012'
if not os.path.exists(target_dir):
    os.makedirs(target_dir)

train_txt = os.path.join(data_dir, 'ImageSets/Main/train.txt')
train_imgIds = [t.strip() for t in open(train_txt)]

label_set = set()
train_data = []
for img_id in train_imgIds:
    xml_path = os.path.join(anno_dir, '{}.xml'.format(img_id))
    dom_tree = parse(xml_path)
    root = dom_tree.documentElement
    objects = root.getElementsByTagName('object')
    labels = set()
    for obj in objects:
        if (obj.getElementsByTagName('difficult')[0].firstChild.data) == '1':
            continue
        tag = obj.getElementsByTagName('name')[0].firstChild.data.lower()
        labels.add(tag)
        label_set.add(tag)
    image_path = os.path.join(image_dir, '{}.jpg'.format(img_id))
    if not os.path.exists(image_path):
        raise Exception('file {} not found!'.format(image_path))
    train_data.append('{}\t{}\n'.format(image_path, ','.join(list(labels))))

with open(os.path.join(target_dir, 'train.txt'), 'w') as fw:
--------------------------------------------------------------------------------
/scripts/voc2012.py:
--------------------------------------------------------------------------------
1 | # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | # Created by: jasonseu
3 | # Created on: 2021-1-20
4 | # Email: zhuxuelin23@gmail.com
5 | #
6 | # Copyright © 2021 - CPSS Group
7 | # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
8 | import os
9 | from xml.dom.minidom import parse
10 | 
11 | 
12 | data_dir = 'data/VOC2012'
13 | anno_dir = os.path.join(data_dir, 'Annotations')
14 | image_dir = os.path.join(data_dir, 'JPEGImages')
15 | target_dir = 'temp/voc2012'
16 | if not os.path.exists(target_dir):
17 |     os.makedirs(target_dir)
18 | 
19 | 
20 | def collect_annotations(split_txt):
21 |     # parse the VOC XML annotation of every image id listed in split_txt and
22 |     # return (sample lines "<image_path>\t<label1,label2,...>", labels seen)
23 |     img_ids = [t.strip() for t in open(split_txt)]
24 |     records, labels_seen = [], set()
25 |     for img_id in img_ids:
26 |         xml_path = os.path.join(anno_dir, '{}.xml'.format(img_id))
27 |         root = parse(xml_path).documentElement
28 |         labels = set()
29 |         for obj in root.getElementsByTagName('object'):
30 |             if obj.getElementsByTagName('difficult')[0].firstChild.data == '1':
31 |                 continue  # skip objects flagged as difficult
32 |             tag = obj.getElementsByTagName('name')[0].firstChild.data.lower()
33 |             labels.add(tag)
34 |         labels_seen.update(labels)
35 |         image_path = os.path.join(image_dir, '{}.jpg'.format(img_id))
36 |         if not os.path.exists(image_path):
37 |             raise Exception('file {} not found!'.format(image_path))
38 |         records.append('{}\t{}\n'.format(image_path, ','.join(labels)))
39 |     return records, labels_seen
40 | 
41 | 
42 | train_data, label_set = collect_annotations(os.path.join(data_dir, 'ImageSets/Main/train.txt'))
43 | val_data, _ = collect_annotations(os.path.join(data_dir, 'ImageSets/Main/val.txt'))  # label.txt only lists labels seen in train
44 | 
45 | with open(os.path.join(target_dir, 'train.txt'), 'w') as fw:
46 |     fw.writelines(train_data)
47 | with open(os.path.join(target_dir, 'val.txt'), 'w') as fw:
48 |     fw.writelines(val_data)
49 | with open(os.path.join(target_dir, 'label.txt'), 'w') as fw:
50 |     for label in sorted(label_set):
51 |         fw.write(label + '\n')
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | # Created by: jasonseu
3 | # Created on: 2021-1-19
4 | # Email: zhuxuelin23@gmail.com
5 | #
6 | # Copyright © 2021 - CPSS Group
7 | # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
8 | import os
9 | import argparse
10 | from argparse import Namespace
11 | 
12 | import torch
13 | from torch import nn
14 | from torch.optim import Adam, SGD, lr_scheduler
15 | from torchvision import transforms
16 | from torch.utils.data import DataLoader
17 | from torch.utils.tensorboard import SummaryWriter
18 | 
19 | from models import model_factory
20 | from lib.util import *
21 | from lib.metrics import *
22 | from lib.dataset import MLDataset
23 | 
24 | torch.backends.cudnn.benchmark = True
25 | 
26 | class Trainer(object):
27 |     def __init__(self, args):
28 |         super(Trainer, self).__init__()
29 |         train_transform = transforms.Compose([
30 |             transforms.Resize((args.scale_size, args.scale_size)),
31 |             transforms.RandomChoice([  # random multi-scale crop, resized back to crop_size below
32 |                 transforms.RandomCrop(640),
33 |                 transforms.RandomCrop(576),
34 |                 transforms.RandomCrop(512),
35 |                 transforms.RandomCrop(384),
36 |                 transforms.RandomCrop(320)
37 |             ]),
38 |             transforms.Resize((args.crop_size, args.crop_size)),
39 |             transforms.ToTensor(),
40 |             transforms.Normalize(
41 |                 mean=[0.485, 0.456, 0.406],
42 |                 std=[0.229, 0.224, 0.225]
43 |             )
44 |         ])
45 |         train_dataset = MLDataset(args.train_path, args.label_path, train_transform)
46 |         self.train_loader = DataLoader(
47 |             dataset=train_dataset,
48 |             batch_size=args.batch_size,
49 |             shuffle=True,
50 |             num_workers=args.num_workers,
51 |             pin_memory=True,
52 |             drop_last=True
53 |         )
54 |         val_transform = transforms.Compose([
55 |             transforms.Resize((args.scale_size, args.scale_size)),
56 |             transforms.CenterCrop(args.crop_size),
57 |             transforms.ToTensor(),
58 |             transforms.Normalize(
59 |                 mean=[0.485, 0.456, 0.406],
60 |                 std=[0.229, 0.224, 0.225]
61 |             )
62 |         ])
63 |         val_dataset = MLDataset(args.val_path, args.label_path, val_transform)
64 |         self.val_loader = DataLoader(
65 |             dataset=val_dataset,
66 |             batch_size=args.batch_size,
67 |             shuffle=False,
68 |             num_workers=args.num_workers,
69 |             pin_memory=True
70 |         )
71 | 
72 |         self.model = model_factory[args.model](args, args.num_classes)
73 |         self.model.cuda()
74 | 
75 |         trainable_parameters = filter(lambda param: param.requires_grad, self.model.parameters())  # skip parameters frozen by the model definition
76 |         if args.optimizer == 'Adam':
77 |             self.optimizer = Adam(trainable_parameters, lr=args.lr)
78 |         elif args.optimizer == 'SGD':
79 |             self.optimizer = SGD(trainable_parameters, lr=args.lr)
80 | 
81 |         self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='max', patience=2, verbose=True)  # stepped on validation mAP, hence mode='max'
82 |         if args.loss == 'BCElogitloss':
83 |             self.criterion = nn.BCEWithLogitsLoss()
84 |         elif args.loss == 'tencentloss':
85 |             self.criterion = TencentLoss(args.num_classes)
86 |         elif args.loss == 'focalloss':
87 |             self.criterion = FocalLoss()
88 |         self.early_stopping = EarlyStopping(patience=5)
89 | 
90 |         self.voc12_mAP = VOC12mAP(args.num_classes)
91 |         self.average_loss = AverageLoss()
92 |         self.average_topk_meter = TopkAverageMeter(args.num_classes, topk=args.topk)
93 |         self.average_threshold_meter = ThresholdAverageMeter(args.num_classes, threshold=args.threshold)
94 | 
95 |         self.args = args
96 |         self.global_step = 0
97 |         self.writer = SummaryWriter(log_dir=args.log_dir)
98 | 
99 |     def run(self):
100 |         s_epoch = 0
101 |         if self.args.resume:
102 |             checkpoint = torch.load(self.args.ckpt_latest_path)
103 |             s_epoch = checkpoint['epoch']
104 |             self.global_step = checkpoint['global_step']
105 |             self.model.load_state_dict(checkpoint['model_state_dict'])
106 |             self.optimizer.load_state_dict(checkpoint['optim_state_dict'])
107 |             self.early_stopping.best_score = checkpoint['best_score']  # note: the LR scheduler's state is not checkpointed
108 |             print('checkpoint loaded successfully (epoch {})'.format(s_epoch))
109 | 
110 |         for epoch in range(s_epoch, self.args.max_epoch):
111 |             self.train(epoch)
112 |             save_dict = {
113 |                 'epoch': epoch + 1,
114 |                 'global_step': self.global_step,
115 |                 'model_state_dict': self.model.state_dict(),
116 |                 'optim_state_dict': self.optimizer.state_dict(),
117 |                 'best_score': self.early_stopping.best_score
118 |             }
119 |             torch.save(save_dict, self.args.ckpt_latest_path)
120 | 
121 |             mAP = self.validation(epoch)
122 |             self.lr_scheduler.step(mAP)
123 |             is_save, is_terminate = self.early_stopping(mAP)
124 |             if is_terminate:
125 |                 break
126 |             if is_save:
127 |                 torch.save(self.model.state_dict(), self.args.ckpt_best_path)
128 | 
129 |     def train(self, epoch):
130 |         self.model.train()
131 |         if self.args.model == 'ssgrl':  # keep the frozen backbone in eval mode; only layer4 is fine-tuned
132 |             self.model.resnet_101.eval()
133 |             self.model.resnet_101.layer4.train()
134 |         for _, batch in enumerate(self.train_loader):
135 |             x, y = batch[0].cuda(), batch[1].cuda()
136 |             pred_y = self.model(x)
137 |             loss = self.criterion(pred_y, y)
138 |             self.optimizer.zero_grad()
139 |             loss.backward()
140 |             self.optimizer.step()
141 | 
142 |             if self.global_step % 400 == 0:
143 |                 self.writer.add_scalar('Loss/train', loss, self.global_step)
144 |                 print('TRAIN [epoch {}] loss: {:.4f}'.format(epoch, loss))
145 | 
146 |             self.global_step += 1
147 | 
148 |     def validation(self, epoch):
149 |         self.model.eval()
150 |         self.voc12_mAP.reset()
151 |         self.average_loss.reset()
152 |         self.average_topk_meter.reset()
153 |         self.average_threshold_meter.reset()
154 |         with torch.no_grad():
155 |             for _, batch in enumerate(self.val_loader):
156 |                 x, y = batch[0].cuda(), batch[1].cuda()
157 |                 pred_y = self.model(x)
158 |                 loss = self.criterion(pred_y, y)
159 | 
160 |                 y = y.cpu().numpy()
161 |                 pred_y = pred_y.cpu().numpy()
162 |                 loss = loss.cpu().numpy()
163 |                 self.voc12_mAP.update(pred_y, y)
164 |                 self.average_loss.update(loss, x.size(0))
165 |                 self.average_topk_meter.update(pred_y, y)
166 |                 self.average_threshold_meter.update(pred_y, y)
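167 |                 # the meters above only accumulate per-batch statistics; the
168 |                 # averaging / average-precision computation happens once per
169 |                 # epoch in the compute() calls below (see lib/metrics.py)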
170 | 
171 |         _, mAP = self.voc12_mAP.compute()
172 |         mLoss = self.average_loss.compute()
173 |         self.average_topk_meter.compute()
174 |         self.average_threshold_meter.compute()
175 |         self.writer.add_scalar('Loss/val', mLoss, self.global_step)
176 |         self.writer.add_scalar('mAP/val', mAP, self.global_step)
177 | 
178 |         print("Validation [epoch {}] mAP: {:.4f} loss: {:.4f}".format(epoch, mAP, mLoss))
179 |         return mAP
180 | 
181 | 
182 | if __name__ == "__main__":
183 |     parser = argparse.ArgumentParser()
184 |     parser.add_argument('--config', type=str, default='configs/coco_resnet101.yaml')
185 |     parser.add_argument('--resume', action='store_true', default=False)
186 |     args = parser.parse_args()
187 |     cfg = load_cfg(args.config)
188 |     cfg['resume'] = args.resume
189 |     args = Namespace(**cfg)  # replace the CLI namespace with the full config
190 |     print(args)
191 |     if not os.path.exists(args.log_dir):
192 |         os.makedirs(args.log_dir)
193 |     if not os.path.exists(args.ckpt_dir):
194 |         os.makedirs(args.ckpt_dir)
195 | 
196 |     trainer = Trainer(args)
197 |     trainer.run()
198 | 
--------------------------------------------------------------------------------