├── .DS_Store ├── LICENSE ├── README.md ├── checkpoint ├── COCO_84.6.txt └── VOC2007_95.0.txt ├── data ├── __init__.py ├── coco.py └── voc.py ├── figs ├── .DS_Store ├── motivation.png └── vis.png ├── main.py ├── models ├── TDRG.py ├── __init__.py └── trans_utils │ ├── position_encoding.py │ └── transformer.py ├── trainer.py └── util.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iCVTEAM/TDRG/6ea0c9dedf84e38efbf24f0e481bce8c2e028323/.DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TDRG 2 | Pytorch implementation of [Transformer-based Dual Relation Graph for Multi-label Image Recognition. 
ICCV 2021](https://openaccess.thecvf.com/content/ICCV2021/html/Zhao_Transformer-Based_Dual_Relation_Graph_for_Multi-Label_Image_Recognition_ICCV_2021_paper.html) 3 | 4 | ![TDRG](https://github.com/iCVTEAM/TDRG/blob/master/figs/motivation.png) 5 | 6 | ## Prerequisites 7 | 8 | Python 3.6+ 9 | 10 | Pytorch 1.6 11 | 12 | CUDA 10.1 13 | 14 | Tesla V100 × 2 15 | 16 | ## Datasets 17 | 18 | - MS-COCO: [train](http://images.cocodataset.org/zips/train2014.zip) [val](http://images.cocodataset.org/zips/val2014.zip) [annotations](http://images.cocodataset.org/annotations/annotations_trainval2014.zip) 19 | - VOC 2007: [trainval](http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar) [test](http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar) [test_anno](http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtestnoimgs_06-Nov-2007.tar) 20 | 21 | ## Model 22 | 23 | - MS-COCO: the name of parameters in our original network is different from the public one, hence if you want to test the performance of TDRG on MS-COCO, please download the [checkpoint](https://drive.google.com/file/d/1roXvhXxivzVxjBsJ0_kqLbGnQfb7GSMp/view?usp=sharing) into `checkpoint/COCO2014` folder and replace the function `load_checkpoint` with `load_origin_checkpoint` in `trainer.py`. 24 | 25 | ## Train 26 | 27 | ``` 28 | CUDA_VISIBLE_DEVICES=0,1 python main.py --data COCO2014 --data_root_dir $DATA_PATH$ --save_dir $SAVE_PATH$ --i 448 --lr 0.03 -b 64 29 | ``` 30 | 31 | ## Test 32 | 33 | ``` 34 | python main.py --data COCO2014 --data_root_dir $DATA_PATH$ --save_dir $SAVE_PATH$ --i 448 --lr 0.03 -b 64 -e --resume checkpoint/COCO2014/checkpoint_COCO.pth 35 | ``` 36 | 37 | ## Visualization 38 | 39 | ![vis](https://github.com/iCVTEAM/TDRG/blob/master/figs/vis.png) 40 | 41 | ## Citation 42 | 43 | - If you find this work is helpful, please cite our paper 44 | 45 | ``` 46 | @InProceedings{Zhao2021TDRG, 47 | author = {Zhao, Jiawei and Yan, Ke and Zhao, Yifan and Guo, Xiaowei and Huang, Feiyue and Li, Jia}, 48 | title = {Transformer-Based Dual Relation Graph for Multi-Label Image Recognition}, 49 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 50 | month = {October}, 51 | year = {2021}, 52 | pages = {163-172} 53 | } 54 | ``` 55 | 56 | -------------------------------------------------------------------------------- /checkpoint/COCO_84.6.txt: -------------------------------------------------------------------------------- 1 | tensor([0.9735, 0.7160, 0.5518, 0.8753, 0.9544, 0.9616, 0.9802, 0.8900, 0.6906, 2 | 0.8170, 0.8371, 0.9134, 0.7074, 0.7631, 0.7623, 0.9328, 0.8966, 0.8491, 3 | 0.9007, 0.8151, 0.9683, 0.6998, 0.8079, 0.8437, 0.8495, 0.9331, 0.7825, 4 | 0.8046, 0.9119, 0.8679, 0.9864, 0.8605, 0.8032, 0.9489, 0.9969, 0.4524, 5 | 0.5779, 0.9446, 0.8051, 0.8920, 0.9764, 0.7145, 0.9105, 0.8244, 0.9464, 6 | 0.8987, 0.8250, 0.8888, 0.7207, 0.9919, 0.9512, 0.7157, 0.8285, 0.8222, 7 | 0.7730, 0.6551, 0.9651, 0.9231, 0.9764, 0.9533, 0.8742, 0.6747, 0.8878, 8 | 0.8215, 0.7654, 0.9672, 0.8802, 0.9890, 0.8609, 0.4081, 0.9784, 0.7645, 9 | 0.8751, 0.9673, 0.7718, 0.8953, 0.8746, 0.8005, 0.8089, 0.9945]) 10 | * Test 11 | Loss: 0.1566 mAP: 0.8456 Data_time: 0.0401 Batch_time: 1.6352 12 | OP: 0.866 OR: 0.763 OF1: 0.812 CP: 0.863 CR: 0.730 CF1: 0.791 13 | OP_3: 0.912 OR_3: 0.669 OF1_3: 0.772 CP_3: 0.901 CR_3: 0.643 CF1_3: 0.750 -------------------------------------------------------------------------------- /checkpoint/VOC2007_95.0.txt: 
-------------------------------------------------------------------------------- 1 | tensor([0.9990, 0.9887, 0.9836, 0.9870, 0.8191, 0.9584, 0.9778, 0.9803, 0.8514, 2 | 0.9549, 0.8949, 0.9878, 0.9861, 0.9708, 0.9907, 0.8627, 0.9775, 0.8723, 3 | 0.9912, 0.9527]) 4 | * Test 5 | Loss: 0.2230 mAP: 0.9495 Data_time: 0.3335 Batch_time: 2.9532 6 | OP: 0.871 OR: 0.916 OF1: 0.893 CP: 0.855 CR: 0.902 CF1: 0.878 7 | OP_3: 0.873 OR_3: 0.914 OF1_3: 0.893 CP_3: 0.857 CR_3: 0.899 CF1_3: 0.878 -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import os, sys, pdb 2 | from PIL import Image 3 | import random 4 | 5 | import torch 6 | from torch.utils.data import DataLoader 7 | import torchvision.transforms as transforms 8 | 9 | from .coco import COCO2014 10 | from .voc import VOC2007, VOC2012 11 | 12 | data_dict = {'COCO2014': COCO2014, 13 | 'VOC2007': VOC2007, 14 | 'VOC2012': VOC2012} 15 | 16 | def collate_fn(batch): 17 | ret_batch = dict() 18 | for k in batch[0].keys(): 19 | if k == 'image' or k == 'target': 20 | ret_batch[k] = torch.cat([b[k].unsqueeze(0) for b in batch]) 21 | else: 22 | ret_batch[k] = [b[k] for b in batch] 23 | return ret_batch 24 | 25 | class MultiScaleCrop(object): 26 | 27 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 28 | self.scales = scales if scales is not None else [1, 875, .75, .66] 29 | self.max_distort = max_distort 30 | self.fix_crop = fix_crop 31 | self.more_fix_crop = more_fix_crop 32 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 33 | self.interpolation = Image.BILINEAR 34 | 35 | def __call__(self, img): 36 | im_size = img.size 37 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 38 | crop_img_group = img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) 39 | ret_img_group = crop_img_group.resize((self.input_size[0], self.input_size[1]), self.interpolation) 40 | return ret_img_group 41 | 42 | def _sample_crop_size(self, im_size): 43 | image_w, image_h = im_size[0], im_size[1] 44 | 45 | # find a crop size 46 | base_size = min(image_w, image_h) 47 | crop_sizes = [int(base_size * x) for x in self.scales] 48 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 49 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 50 | 51 | pairs = [] 52 | for i, h in enumerate(crop_h): 53 | for j, w in enumerate(crop_w): 54 | if abs(i - j) <= self.max_distort: 55 | pairs.append((w, h)) 56 | 57 | crop_pair = random.choice(pairs) 58 | if not self.fix_crop: 59 | w_offset = random.randint(0, image_w - crop_pair[0]) 60 | h_offset = random.randint(0, image_h - crop_pair[1]) 61 | else: 62 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 63 | 64 | return crop_pair[0], crop_pair[1], w_offset, h_offset 65 | 66 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 67 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 68 | return random.choice(offsets) 69 | 70 | @staticmethod 71 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 72 | w_step = (image_w - crop_w) // 4 73 | h_step = (image_h - crop_h) // 4 74 | 75 | ret = list() 76 | ret.append((0, 0)) # upper left 77 | ret.append((4 * w_step, 0)) # upper right 78 | ret.append((0, 4 * h_step)) # lower 
left 79 | ret.append((4 * w_step, 4 * h_step)) # lower right 80 | ret.append((2 * w_step, 2 * h_step)) # center 81 | 82 | if more_fix_crop: 83 | ret.append((0, 2 * h_step)) # center left 84 | ret.append((4 * w_step, 2 * h_step)) # center right 85 | ret.append((2 * w_step, 4 * h_step)) # lower center 86 | ret.append((2 * w_step, 0 * h_step)) # upper center 87 | 88 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 89 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 90 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 91 | ret.append((3 * w_step, 3 * h_step)) # lower righ quarter 92 | 93 | return ret 94 | 95 | def __str__(self): 96 | return self.__class__.__name__ 97 | 98 | 99 | def get_transform(args, is_train=True): 100 | # ImageNet 101 | mean = [0.485, 0.456, 0.406] 102 | std = [0.229, 0.224, 0.225] 103 | # COCO 104 | #mean = [0.471, 0.448, 0.408] 105 | #std = [0.234, 0.239, 0.242] 106 | # VOC 107 | #mean = [0.448, 0.425, 0.392] 108 | #std = [0.241, 0.236, 0.238] 109 | if is_train: 110 | transform = transforms.Compose([ 111 | # transforms.RandomResizedCrop(args.image_size, scale=(0.1, 1.5), ratio=(1.0, 1.0)), 112 | # transforms.RandomResizedCrop(args.image_size, scale=(0.1, 2.0), ratio=(1.0, 1.0)), 113 | transforms.Resize((args.image_size+64, args.image_size+64)), 114 | MultiScaleCrop(args.image_size, scales=(1.0, 0.875, 0.75, 0.66, 0.5), max_distort=2), 115 | #MultiScaleCrop(args.image_size, scales=(1.0, 0.875, 0.75), max_distort=2), 116 | transforms.RandomHorizontalFlip(), 117 | transforms.ToTensor(), 118 | transforms.Normalize(mean=mean, std=std) 119 | ]) 120 | else: 121 | transform = transforms.Compose([ 122 | transforms.Resize((args.image_size,args.image_size)), 123 | transforms.ToTensor(), 124 | transforms.Normalize(mean=mean, std=std) 125 | ]) 126 | return transform 127 | 128 | def make_data_loader(args, is_train=True): 129 | root_dir = os.path.join(args.data_root_dir, args.data) 130 | 131 | # Build val_loader 132 | transform = get_transform(args, is_train=False) 133 | if args.data == 'COCO2014': 134 | val_dataset = COCO2014(root_dir, phase='val', transform=transform) 135 | elif args.data in ('VOC2007', 'VOC2012'): 136 | val_dataset = data_dict['VOC2007'](root_dir, phase='test', transform=transform) 137 | else: 138 | raise NotImplementedError('Value error: No matched dataset!') 139 | 140 | num_classes = val_dataset[0]['target'].size(-1) 141 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, 142 | num_workers=args.num_workers, pin_memory=True, 143 | collate_fn=collate_fn, drop_last=False) 144 | 145 | if not is_train: 146 | return None, val_loader, num_classes 147 | 148 | # Build train_loader 149 | transform = get_transform(args, is_train=True) 150 | if args.data == 'COCO2014': 151 | train_dataset = COCO2014(root_dir, phase='train', transform=transform) 152 | elif args.data in ('VOC2007', 'VOC2012'): 153 | train_dataset = data_dict[args.data](root_dir, phase='trainval', transform=transform) 154 | else: 155 | raise NotImplementedError('Value error: No matched dataset!') 156 | 157 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, 158 | num_workers=args.num_workers, pin_memory=True, 159 | collate_fn=collate_fn, drop_last=True) 160 | 161 | 162 | return train_loader, val_loader, num_classes 163 | -------------------------------------------------------------------------------- /data/coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
import json 3 | import subprocess 4 | from PIL import Image 5 | # import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset 8 | import pickle 9 | 10 | urls = {'train_img':'http://images.cocodataset.org/zips/train2014.zip', 11 | 'val_img' : 'http://images.cocodataset.org/zips/val2014.zip', 12 | 'annotations':'http://images.cocodataset.org/annotations/annotations_trainval2014.zip'} 13 | 14 | def download_coco2014(root, phase): 15 | work_dir = os.getcwd() 16 | tmpdir = os.path.join(root, 'tmp/') 17 | if not os.path.exists(root): 18 | os.makedirs(root) 19 | if not os.path.exists(tmpdir): 20 | os.makedirs(tmpdir) 21 | if phase == 'train': 22 | filename = 'train2014.zip' 23 | elif phase == 'val': 24 | filename = 'val2014.zip' 25 | cached_file = os.path.join(tmpdir, filename) 26 | if not os.path.exists(cached_file): 27 | print('Downloading: "{}" to {}\n'.format(urls[phase + '_img'], cached_file)) 28 | os.chdir(tmpdir) 29 | subprocess.call('wget ' + urls[phase + '_img'], shell=True) 30 | os.chdir(root) 31 | # extract file 32 | img_data = os.path.join(root, filename.split('.')[0]) 33 | if not os.path.exists(img_data): 34 | print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root)) 35 | command = 'unzip {} -d {}'.format(cached_file,root) 36 | os.system(command) 37 | print('[dataset] Done!') 38 | 39 | # train/val images/annotations 40 | cached_file = os.path.join(tmpdir, 'annotations_trainval2014.zip') 41 | if not os.path.exists(cached_file): 42 | print('Downloading: "{}" to {}\n'.format(urls['annotations'], cached_file)) 43 | os.chdir(tmpdir) 44 | # subprocess.Popen('wget ' + urls['annotations'], shell=True) 45 | subprocess.call('wget ' + urls['annotations'], shell=True) 46 | os.chdir(root) 47 | annotations_data = os.path.join(root, 'annotations') 48 | if not os.path.exists(annotations_data): 49 | print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root)) 50 | command = 'unzip {} -d {}'.format(cached_file, root) 51 | os.system(command) 52 | print('[annotation] Done!') 53 | 54 | annotations_data = os.path.join(root, 'annotations') 55 | anno = os.path.join(root, '{}_anno.json'.format(phase)) 56 | img_id = {} 57 | annotations_id = {} 58 | if not os.path.exists(anno): 59 | annotations_file = json.load(open(os.path.join(annotations_data, 'instances_{}2014.json'.format(phase)))) 60 | annotations = annotations_file['annotations'] 61 | category = annotations_file['categories'] 62 | category_id = {} 63 | for cat in category: 64 | category_id[cat['id']] = cat['name'] 65 | cat2idx = categoty_to_idx(sorted(category_id.values())) 66 | images = annotations_file['images'] 67 | for annotation in annotations: 68 | if annotation['image_id'] not in annotations_id: 69 | annotations_id[annotation['image_id']] = set() 70 | annotations_id[annotation['image_id']].add(cat2idx[category_id[annotation['category_id']]]) 71 | for img in images: 72 | if img['id'] not in annotations_id: 73 | continue 74 | if img['id'] not in img_id: 75 | img_id[img['id']] = {} 76 | img_id[img['id']]['file_name'] = img['file_name'] 77 | img_id[img['id']]['labels'] = list(annotations_id[img['id']]) 78 | anno_list = [] 79 | for k, v in img_id.items(): 80 | anno_list.append(v) 81 | json.dump(anno_list, open(anno, 'w')) 82 | if not os.path.exists(os.path.join(root, 'category.json')): 83 | json.dump(cat2idx, open(os.path.join(root, 'category.json'), 'w')) 84 | del img_id 85 | del anno_list 86 | del images 87 | del annotations_id 88 | del annotations 89 | del category 90 
| del category_id 91 | print('[json] Done!') 92 | os.chdir(work_dir) 93 | 94 | def categoty_to_idx(category): 95 | cat2idx = {} 96 | for cat in category: 97 | cat2idx[cat] = len(cat2idx) 98 | return cat2idx 99 | 100 | 101 | class COCO2014(Dataset): 102 | def __init__(self, root, transform=None, phase='train'): 103 | self.root = os.path.abspath(root) 104 | self.phase = phase 105 | self.img_list = [] 106 | self.transform = transform 107 | download_coco2014(self.root, phase) 108 | self.get_anno() 109 | self.num_classes = len(self.cat2idx) 110 | print('[dataset] COCO2014 classification phase={} number of classes={} number of images={}'.format(phase, self.num_classes, len(self.img_list))) 111 | 112 | def get_anno(self): 113 | list_path = os.path.join(self.root, '{}_anno.json'.format(self.phase)) 114 | self.img_list = json.load(open(list_path, 'r')) 115 | #self.img_list = self.img_list[:20000] 116 | self.cat2idx = json.load(open(os.path.join(self.root, 'category.json'), 'r')) 117 | 118 | def __len__(self): 119 | return len(self.img_list) 120 | 121 | def __getitem__(self, index): 122 | item = self.img_list[index] 123 | filename = item['file_name'] 124 | labels = sorted(item['labels']) 125 | img = Image.open(os.path.join(self.root, '{}2014'.format(self.phase), filename)).convert('RGB') 126 | if self.transform is not None: 127 | img = self.transform(img) 128 | # target = np.zeros(self.num_classes, np.float32) - 1 129 | target = torch.zeros(self.num_classes, dtype=torch.float32) - 1 130 | target[labels] = 1 131 | data = {'image':img, 'name': filename, 'target': target} 132 | return data 133 | # return image, target 134 | # return (img, filename), target 135 | -------------------------------------------------------------------------------- /data/voc.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | import tarfile 4 | from urllib.parse import urlparse 5 | from urllib.request import urlretrieve 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset 9 | from PIL import Image 10 | from tqdm import tqdm 11 | 12 | object_categories = ['aeroplane', 'bicycle', 'bird', 'boat', 13 | 'bottle', 'bus', 'car', 'cat', 'chair', 14 | 'cow', 'diningtable', 'dog', 'horse', 15 | 'motorbike', 'person', 'pottedplant', 16 | 'sheep', 'sofa', 'train', 'tvmonitor'] 17 | 18 | urls2007 = { 19 | 'devkit': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCdevkit_18-May-2011.tar', 20 | 'trainval_2007': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar', 21 | 'test_images_2007': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar', 22 | 'test_anno_2007': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtestnoimgs_06-Nov-2007.tar', 23 | } 24 | 25 | urls2012 = { 26 | 'devkit': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCdevkit_18-May-2011.tar', 27 | # 'trainval_2012': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_06-Nov-2012.tar', 28 | 'trainval_2012': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar', 29 | # 'test_images_2012': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtest_06-Nov-2012.tar', 30 | 'test_images_2012': 'http://pjreddie.com/media/files/VOC2012test.tar', 31 | # 'test_anno_2012': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtestnoimgs_06-Nov-2012.tar', 32 | } 33 | 34 | 35 | def download_url(url, destination=None, progress_bar=True): 36 | """Download a URL to a local file. 
37 | 38 | Parameters 39 | ---------- 40 | url : str 41 | The URL to download. 42 | destination : str, None 43 | The destination of the file. If None is given the file is saved to a temporary directory. 44 | progress_bar : bool 45 | Whether to show a command-line progress bar while downloading. 46 | 47 | Returns 48 | ------- 49 | filename : str 50 | The location of the downloaded file. 51 | 52 | Notes 53 | ----- 54 | Progress bar use/example adapted from tqdm documentation: https://github.com/tqdm/tqdm 55 | """ 56 | 57 | def my_hook(t): 58 | last_b = [0] 59 | 60 | def inner(b=1, bsize=1, tsize=None): 61 | if tsize is not None: 62 | t.total = tsize 63 | if b > 0: 64 | t.update((b - last_b[0]) * bsize) 65 | last_b[0] = b 66 | 67 | return inner 68 | 69 | if progress_bar: 70 | with tqdm(unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1]) as t: 71 | filename, _ = urlretrieve(url, filename=destination, reporthook=my_hook(t)) 72 | else: 73 | filename, _ = urlretrieve(url, filename=destination) 74 | 75 | 76 | def read_image_label(file): 77 | print('[dataset] read ' + file) 78 | data = dict() 79 | with open(file, 'r') as f: 80 | for line in f: 81 | tmp = line.split(' ') 82 | name = tmp[0] 83 | label = int(tmp[-1]) 84 | data[name] = label 85 | return data 86 | 87 | 88 | def read_object_labels(root, dataset, phase): 89 | path_labels = os.path.join(root, 'VOCdevkit', dataset, 'ImageSets', 'Main') 90 | labeled_data = dict() 91 | num_classes = len(object_categories) 92 | 93 | for i in range(num_classes): 94 | file = os.path.join(path_labels, object_categories[i] + '_' + phase + '.txt') 95 | data = read_image_label(file) 96 | 97 | if i == 0: 98 | for (name, label) in data.items(): 99 | labels = np.zeros(num_classes) 100 | labels[i] = label 101 | labeled_data[name] = labels 102 | else: 103 | for (name, label) in data.items(): 104 | labeled_data[name][i] = label 105 | 106 | return labeled_data 107 | 108 | 109 | def write_object_labels_csv(file, labeled_data): 110 | # write a csv file 111 | print('[dataset] write file %s' % file) 112 | with open(file, 'w') as csvfile: 113 | fieldnames = ['name'] 114 | fieldnames.extend(object_categories) 115 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames) 116 | 117 | writer.writeheader() 118 | for (name, labels) in labeled_data.items(): 119 | example = {'name': name} 120 | for i in range(20): 121 | example[fieldnames[i + 1]] = int(labels[i]) 122 | writer.writerow(example) 123 | 124 | csvfile.close() 125 | 126 | 127 | def read_object_labels_csv(file, header=True): 128 | images = [] 129 | num_categories = 0 130 | print('[dataset] read', file) 131 | with open(file, 'r') as f: 132 | reader = csv.reader(f) 133 | rownum = 0 134 | for row in reader: 135 | if header and rownum == 0: 136 | header = row 137 | else: 138 | if num_categories == 0: 139 | num_categories = len(row) - 1 140 | name = row[0] 141 | labels = torch.from_numpy((np.asarray(row[1:num_categories + 1])).astype(np.float32)) 142 | item = (name, labels) 143 | images.append(item) 144 | rownum += 1 145 | return images 146 | 147 | 148 | # def find_images_classification(root, dataset, phase): 149 | # path_labels = os.path.join(root, 'VOCdevkit', dataset, 'ImageSets', 'Main') 150 | # images = [] 151 | # file = os.path.join(path_labels, phase + '.txt') 152 | # with open(file, 'r') as f: 153 | # for line in f: 154 | # images.append(line) 155 | # return images 156 | 157 | 158 | def download_voc2007(root): 159 | path_devkit = os.path.join(root, 'VOCdevkit') 160 | path_images = os.path.join(root, 'VOCdevkit', 
'VOC2007', 'JPEGImages') 161 | tmpdir = os.path.join(root, 'tmp') 162 | 163 | # create directory 164 | if not os.path.exists(root): 165 | os.makedirs(root) 166 | 167 | if not os.path.exists(path_devkit): 168 | 169 | if not os.path.exists(tmpdir): 170 | os.makedirs(tmpdir) 171 | 172 | parts = urlparse(urls2007['devkit']) 173 | filename = os.path.basename(parts.path) 174 | cached_file = os.path.join(tmpdir, filename) 175 | 176 | if not os.path.exists(cached_file): 177 | print('Downloading: "{}" to {}\n'.format(urls2007['devkit'], cached_file)) 178 | download_url(urls2007['devkit'], cached_file) 179 | 180 | # extract file 181 | print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root)) 182 | cwd = os.getcwd() 183 | tar = tarfile.open(cached_file, "r") 184 | os.chdir(root) 185 | tar.extractall() 186 | tar.close() 187 | os.chdir(cwd) 188 | print('[dataset] Done!') 189 | 190 | # train/val images/annotations 191 | if not os.path.exists(path_images): 192 | 193 | # download train/val images/annotations 194 | parts = urlparse(urls2007['trainval_2007']) 195 | filename = os.path.basename(parts.path) 196 | cached_file = os.path.join(tmpdir, filename) 197 | 198 | if not os.path.exists(cached_file): 199 | print('Downloading: "{}" to {}\n'.format(urls2007['trainval_2007'], cached_file)) 200 | download_url(urls2007['trainval_2007'], cached_file) 201 | 202 | # extract file 203 | print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root)) 204 | cwd = os.getcwd() 205 | tar = tarfile.open(cached_file, "r") 206 | os.chdir(root) 207 | tar.extractall() 208 | tar.close() 209 | os.chdir(cwd) 210 | print('[dataset] Done!') 211 | 212 | # test images 213 | test_image = os.path.join(path_devkit, 'VOC2007/JPEGImages/000001.jpg') 214 | if not os.path.exists(test_image): 215 | 216 | # download test images 217 | parts = urlparse(urls2007['test_images_2007']) 218 | filename = os.path.basename(parts.path) 219 | cached_file = os.path.join(tmpdir, filename) 220 | 221 | if not os.path.exists(cached_file): 222 | print('Downloading: "{}" to {}\n'.format(urls2007['test_images_2007'], cached_file)) 223 | download_url(urls2007['test_images_2007'], cached_file) 224 | 225 | # extract file 226 | print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root)) 227 | cwd = os.getcwd() 228 | tar = tarfile.open(cached_file, "r") 229 | os.chdir(root) 230 | tar.extractall() 231 | tar.close() 232 | os.chdir(cwd) 233 | print('[dataset] Done!') 234 | 235 | # test annotations 236 | test_anno = os.path.join(path_devkit, 'VOC2007/ImageSets/Main/aeroplane_test.txt') 237 | if not os.path.exists(test_anno): 238 | 239 | # download test annotations 240 | parts = urlparse(urls2007['test_anno_2007']) 241 | filename = os.path.basename(parts.path) 242 | cached_file = os.path.join(tmpdir, filename) 243 | 244 | if not os.path.exists(cached_file): 245 | print('Downloading: "{}" to {}\n'.format(urls2007['test_anno_2007'], cached_file)) 246 | download_url(urls2007['test_anno_2007'], cached_file) 247 | 248 | # extract file 249 | print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root)) 250 | cwd = os.getcwd() 251 | tar = tarfile.open(cached_file, "r") 252 | os.chdir(root) 253 | tar.extractall() 254 | tar.close() 255 | os.chdir(cwd) 256 | print('[dataset] Done!') 257 | 258 | 259 | def download_voc2012(root): 260 | path_devkit = os.path.join(root, 'VOCdevkit') 261 | path_images = os.path.join(root, 'VOCdevkit', 'VOC2012', 'JPEGImages') 262 
| tmpdir = os.path.join(root, 'tmp') 263 | 264 | # create directory 265 | if not os.path.exists(root): 266 | os.makedirs(root) 267 | 268 | if not os.path.exists(path_devkit): 269 | 270 | if not os.path.exists(tmpdir): 271 | os.makedirs(tmpdir) 272 | 273 | parts = urlparse(urls2012['devkit']) 274 | filename = os.path.basename(parts.path) 275 | cached_file = os.path.join(tmpdir, filename) 276 | 277 | if not os.path.exists(cached_file): 278 | print('Downloading: "{}" to {}\n'.format(urls2012['devkit'], cached_file)) 279 | download_url(urls2012['devkit'], cached_file) 280 | 281 | # extract file 282 | print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root)) 283 | cwd = os.getcwd() 284 | tar = tarfile.open(cached_file, "r") 285 | os.chdir(root) 286 | tar.extractall() 287 | tar.close() 288 | os.chdir(cwd) 289 | print('[dataset] Done!') 290 | 291 | # train/val images/annotations 292 | if not os.path.exists(path_images): 293 | 294 | # download train/val images/annotations 295 | parts = urlparse(urls2012['trainval_2012']) 296 | filename = os.path.basename(parts.path) 297 | cached_file = os.path.join(tmpdir, filename) 298 | 299 | if not os.path.exists(cached_file): 300 | print('Downloading: "{}" to {}\n'.format(urls2012['trainval_2012'], cached_file)) 301 | download_url(urls2012['trainval_2012'], cached_file) 302 | 303 | # extract file 304 | print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root)) 305 | cwd = os.getcwd() 306 | tar = tarfile.open(cached_file, "r") 307 | os.chdir(root) 308 | tar.extractall() 309 | tar.close() 310 | os.chdir(cwd) 311 | print('[dataset] Done!') 312 | 313 | # test images 314 | test_image = os.path.join(path_devkit, 'VOC2012/JPEGImages/2012_000001.jpg') 315 | if not os.path.exists(test_image): 316 | 317 | # download test images 318 | parts = urlparse(urls2012['test_images_2012']) 319 | filename = os.path.basename(parts.path) 320 | cached_file = os.path.join(tmpdir, filename) 321 | 322 | if not os.path.exists(cached_file): 323 | print('Downloading: "{}" to {}\n'.format(urls2012['test_images_2012'], cached_file)) 324 | download_url(urls2012['test_images_2012'], cached_file) 325 | 326 | # extract file 327 | print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root)) 328 | cwd = os.getcwd() 329 | tar = tarfile.open(cached_file, "r") 330 | os.chdir(root) 331 | tar.extractall() 332 | tar.close() 333 | os.chdir(cwd) 334 | print('[dataset] Done!') 335 | 336 | 337 | class VOC2007(Dataset): 338 | def __init__(self, root, phase, transform=None): 339 | self.root = os.path.abspath(root) 340 | self.path_devkit = os.path.join(self.root, 'VOCdevkit') 341 | self.path_images = os.path.join(self.root, 'VOCdevkit', 'VOC2007', 'JPEGImages') 342 | self.phase = phase 343 | self.transform = transform 344 | download_voc2007(self.root) 345 | 346 | # define path of csv file 347 | path_csv = os.path.join(self.root, 'files', 'VOC2007') 348 | # define filename of csv file 349 | file_csv = os.path.join(path_csv, 'classification_' + phase + '.csv') 350 | 351 | # create the csv file if necessary 352 | if not os.path.exists(file_csv): 353 | if not os.path.exists(path_csv): # create dir if necessary 354 | os.makedirs(path_csv) 355 | # generate csv file 356 | labeled_data = read_object_labels(self.root, 'VOC2007', self.phase) 357 | # write csv file 358 | write_object_labels_csv(file_csv, labeled_data) 359 | 360 | self.classes = object_categories 361 | self.images = read_object_labels_csv(file_csv) 362 | 
print('[dataset] VOC 2007 classification phase={} number of classes={} number of images={}'.format(phase, len(self.classes), len(self.images))) 363 | 364 | def __getitem__(self, index): 365 | filename, target = self.images[index] 366 | img = Image.open(os.path.join(self.path_images, filename + '.jpg')).convert('RGB') 367 | if self.transform is not None: 368 | img = self.transform(img) 369 | 370 | data = {'image':img, 'name': filename, 'target': target} 371 | return data 372 | # image = {'image': img, 'name': filename} 373 | # return image, target 374 | # return (img, filename), target 375 | 376 | def __len__(self): 377 | return len(self.images) 378 | 379 | def get_number_classes(self): 380 | return len(self.classes) 381 | 382 | 383 | class VOC2012(Dataset): 384 | def __init__(self, root, phase, transform=None): 385 | self.root = os.path.abspath(root) 386 | self.path_devkit = os.path.join(self.root, 'VOCdevkit') 387 | self.path_images = os.path.join(self.root, 'VOCdevkit', 'VOC2012', 'JPEGImages') 388 | self.phase = phase 389 | self.transform = transform 390 | download_voc2012(self.root) 391 | 392 | # define path of csv file 393 | path_csv = os.path.join(self.root, 'files', 'VOC2012') 394 | # define filename of csv file 395 | file_csv = os.path.join(path_csv, 'classification_' + phase + '.csv') 396 | 397 | # create the csv file if necessary 398 | if not os.path.exists(file_csv): 399 | if not os.path.exists(path_csv): # create dir if necessary 400 | os.makedirs(path_csv) 401 | # generate csv file 402 | labeled_data = read_object_labels(self.root, 'VOC2012', self.phase) 403 | # write csv file 404 | write_object_labels_csv(file_csv, labeled_data) 405 | 406 | self.classes = object_categories 407 | self.images = read_object_labels_csv(file_csv) 408 | print('[dataset] VOC 2012 classification phase={} number of classes={} number of images={}'.format(phase, len(self.classes), len(self.images))) 409 | 410 | def __getitem__(self, index): 411 | filename, target = self.images[index] 412 | img = Image.open(os.path.join(self.path_images, filename + '.jpg')).convert('RGB') 413 | if self.transform is not None: 414 | img = self.transform(img) 415 | 416 | data = {'image':img, 'name': filename, 'target': target} 417 | return data 418 | # image = {'image': img, 'name': filename} 419 | # return image, target 420 | # return (img, filename), target 421 | 422 | def __len__(self): 423 | return len(self.images) 424 | 425 | def get_number_classes(self): 426 | return len(self.classes) 427 | -------------------------------------------------------------------------------- /figs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iCVTEAM/TDRG/6ea0c9dedf84e38efbf24f0e481bce8c2e028323/figs/.DS_Store -------------------------------------------------------------------------------- /figs/motivation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iCVTEAM/TDRG/6ea0c9dedf84e38efbf24f0e481bce8c2e028323/figs/motivation.png -------------------------------------------------------------------------------- /figs/vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iCVTEAM/TDRG/6ea0c9dedf84e38efbf24f0e481bce8c2e028323/figs/vis.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os, sys, pdb 
2 | import argparse 3 | from models import get_model 4 | from data import make_data_loader 5 | import warnings 6 | from trainer import Trainer 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import random 10 | 11 | 12 | parser = argparse.ArgumentParser(description='PyTorch Training for Multi-label Image Classification') 13 | 14 | ''' Fixed in general ''' 15 | parser.add_argument('--data_root_dir', default='./datasets/', type=str, help='root directory of the datasets') 16 | parser.add_argument('--image-size', '-i', default=448, type=int) 17 | parser.add_argument('--epochs', default=50, type=int) 18 | parser.add_argument('--epoch_step', default=[40], type=int, nargs='+', help='epochs at which the learning rate is decayed') 19 | # parser.add_argument('--device_ids', default=[0], type=int, nargs='+', help='ids of the GPUs to use') 20 | parser.add_argument('-b', '--batch-size', default=32, type=int) 21 | parser.add_argument('-j', '--num_workers', default=8, type=int, metavar='INT', help='number of data loading workers (default: 8)') 22 | parser.add_argument('--display_interval', default=800, type=int, metavar='M', help='display_interval') 23 | parser.add_argument('--lr', '--learning-rate', default=0.03, type=float) 24 | parser.add_argument('--lrp', '--learning-rate-pretrained', default=0.1, type=float, metavar='LRP', help='learning rate for pre-trained layers') 25 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') 26 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)') 27 | parser.add_argument('--max_clip_grad_norm', default=10.0, type=float, metavar='M', help='max_clip_grad_norm') 28 | parser.add_argument('--seed', default=1, type=int, help='seed for initializing training') 29 | 30 | ''' Train setting ''' 31 | parser.add_argument('--data', metavar='NAME', help='dataset name (e.g. COCO2014)') 32 | parser.add_argument('--model_name', type=str, default='TDRG') 33 | parser.add_argument('--save_dir', default='./checkpoint/VOC2012/', type=str, help='save path') 34 | 35 | ''' Val or Test setting ''' 36 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set') 37 | parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') 38 | 39 | 40 | def main(args): 41 | 42 | if args.seed is not None: 43 | print ('* absolute seed: {}'.format(args.seed)) 44 | random.seed(args.seed) 45 | torch.manual_seed(args.seed) 46 | torch.cuda.manual_seed(args.seed) 47 | cudnn.deterministic = True 48 | warnings.warn('You have chosen to seed training. ' 49 | 'This will turn on the CUDNN deterministic setting, ' 50 | 'which can slow down your training considerably!
' 51 | 'You may see unexpected behavior when restarting ' 52 | 'from checkpoints.') 53 | 54 | is_train = True if not args.evaluate else False 55 | train_loader, val_loader, num_classes = make_data_loader(args, is_train=is_train) 56 | 57 | model = get_model(num_classes, args) 58 | 59 | criterion = torch.nn.MultiLabelSoftMarginLoss() 60 | 61 | trainer = Trainer(model, criterion, train_loader, val_loader, args) 62 | 63 | if is_train: 64 | trainer.train() 65 | else: 66 | trainer.validate() 67 | 68 | 69 | if __name__ == "__main__": 70 | args = parser.parse_args() 71 | main(args) 72 | -------------------------------------------------------------------------------- /models/TDRG.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .trans_utils.position_encoding import build_position_encoding 5 | from .trans_utils.transformer import build_transformer 6 | 7 | 8 | class TopKMaxPooling(nn.Module): 9 | def __init__(self, kmax=1.0): 10 | super(TopKMaxPooling, self).__init__() 11 | self.kmax = kmax 12 | 13 | @staticmethod 14 | def get_positive_k(k, n): 15 | if k <= 0: 16 | return 0 17 | elif k < 1: 18 | return round(k * n) 19 | elif k > n: 20 | return int(n) 21 | else: 22 | return int(k) 23 | 24 | def forward(self, input): 25 | batch_size = input.size(0) 26 | num_channels = input.size(1) 27 | h = input.size(2) 28 | w = input.size(3) 29 | n = h * w # number of regions 30 | kmax = self.get_positive_k(self.kmax, n) 31 | sorted, indices = torch.sort(input.view(batch_size, num_channels, n), dim=2, descending=True) 32 | region_max = sorted.narrow(2, 0, kmax) 33 | output = region_max.sum(2).div_(kmax) 34 | return output.view(batch_size, num_channels) 35 | 36 | def __repr__(self): 37 | return self.__class__.__name__ + ' (kmax=' + str(self.kmax) + ')' 38 | 39 | 40 | class GraphConvolution(nn.Module): 41 | def __init__(self, in_dim, out_dim): 42 | super(GraphConvolution, self).__init__() 43 | self.relu = nn.LeakyReLU(0.2) 44 | self.weight = nn.Conv1d(in_dim, out_dim, 1) 45 | 46 | def forward(self, adj, nodes): 47 | nodes = torch.matmul(nodes, adj) 48 | nodes = self.relu(nodes) 49 | nodes = self.weight(nodes) 50 | nodes = self.relu(nodes) 51 | return nodes 52 | 53 | 54 | class TDRG(nn.Module): 55 | def __init__(self, model, num_classes): 56 | super(TDRG, self).__init__() 57 | # backbone 58 | self.layer1 = nn.Sequential( 59 | model.conv1, 60 | model.bn1, 61 | model.relu, 62 | model.maxpool, 63 | model.layer1, 64 | # model.layer2, 65 | # model.layer3, 66 | # model.layer4, 67 | ) 68 | self.layer2 = model.layer2 69 | self.layer3 = model.layer3 70 | self.layer4 = model.layer4 71 | self.backbone = nn.ModuleList([self.layer1, self.layer2, self.layer3, self.layer4]) 72 | 73 | # hyper-parameters 74 | self.num_classes = num_classes 75 | self.in_planes = 2048 76 | self.transformer_dim = 512 77 | self.gcn_dim = 512 78 | self.num_queries = 1 79 | self.n_head = 4 80 | self.num_encoder_layers = 3 81 | self.num_decoder_layers = 0 82 | 83 | # transformer 84 | self.transform_14 = nn.Conv2d(self.in_planes, self.transformer_dim, 1) 85 | self.transform_28 = nn.Conv2d(self.in_planes // 2, self.transformer_dim, 1) 86 | self.transform_7 = nn.Conv2d(self.in_planes, self.transformer_dim, 3, stride=2) 87 | 88 | self.query_embed = nn.Embedding(self.num_queries, self.transformer_dim) 89 | self.positional_embedding = build_position_encoding(hidden_dim=self.transformer_dim, mode='learned') 90 | self.transformer = 
build_transformer(d_model=self.transformer_dim, nhead=self.n_head, 91 | num_encoder_layers=self.num_encoder_layers, 92 | num_decoder_layers=self.num_decoder_layers) 93 | 94 | self.kmp = TopKMaxPooling(kmax=0.05) 95 | self.GMP = nn.AdaptiveMaxPool2d(1) 96 | self.GAP = nn.AdaptiveAvgPool2d(1) 97 | self.GAP1d = nn.AdaptiveAvgPool1d(1) 98 | 99 | self.trans_classifier = nn.Linear(self.transformer_dim * 3, self.num_classes) 100 | 101 | # GCN 102 | self.constraint_classifier = nn.Conv2d(self.in_planes, num_classes, (1, 1), bias=False) 103 | 104 | self.guidance_transform = nn.Conv1d(self.transformer_dim, self.transformer_dim, 1) 105 | self.guidance_conv = nn.Conv1d(self.transformer_dim * 3, self.transformer_dim * 3, 1) 106 | self.guidance_bn = nn.BatchNorm1d(self.transformer_dim * 3) 107 | self.relu = nn.LeakyReLU(0.2) 108 | self.gcn_dim_transform = nn.Conv2d(self.in_planes, self.gcn_dim, (1, 1)) 109 | 110 | self.matrix_transform = nn.Conv1d(self.gcn_dim + self.transformer_dim * 4, self.num_classes, 1) 111 | 112 | self.forward_gcn = GraphConvolution(self.transformer_dim+self.gcn_dim, self.transformer_dim+self.gcn_dim) 113 | 114 | self.mask_mat = nn.Parameter(torch.eye(self.num_classes).float()) 115 | self.gcn_classifier = nn.Conv1d(self.transformer_dim + self.gcn_dim, self.num_classes, 1) 116 | 117 | def forward_backbone(self, x): 118 | x1 = self.layer1(x) 119 | x2 = self.layer2(x1) 120 | x3 = self.layer3(x2) 121 | x4 = self.layer4(x3) 122 | return x2, x3, x4 123 | 124 | @staticmethod 125 | def cross_scale_attention(x3, x4, x5): 126 | h3, h4, h5 = x3.shape[2], x4.shape[2], x5.shape[2] 127 | h_max = max(h3, h4, h5) 128 | x3 = F.interpolate(x3, size=(h_max, h_max), mode='bilinear', align_corners=True) 129 | x4 = F.interpolate(x4, size=(h_max, h_max), mode='bilinear', align_corners=True) 130 | x5 = F.interpolate(x5, size=(h_max, h_max), mode='bilinear', align_corners=True) 131 | 132 | mul = x3 * x4 * x5 133 | x3 = x3 + mul 134 | x4 = x4 + mul 135 | x5 = x5 + mul 136 | 137 | x3 = F.interpolate(x3, size=(h3, h3), mode='bilinear', align_corners=True) 138 | x4 = F.interpolate(x4, size=(h4, h4), mode='bilinear', align_corners=True) 139 | x5 = F.interpolate(x5, size=(h5, h5), mode='bilinear', align_corners=True) 140 | return x3, x4, x5 141 | 142 | def forward_transformer(self, x3, x4): 143 | # cross scale attention 144 | x5 = self.transform_7(x4) 145 | x4 = self.transform_14(x4) 146 | x3 = self.transform_28(x3) 147 | 148 | x3, x4, x5 = self.cross_scale_attention(x3, x4, x5) 149 | 150 | # transformer encoder 151 | mask3 = torch.zeros_like(x3[:, 0, :, :], dtype=torch.bool).cuda() 152 | mask4 = torch.zeros_like(x4[:, 0, :, :], dtype=torch.bool).cuda() 153 | mask5 = torch.zeros_like(x5[:, 0, :, :], dtype=torch.bool).cuda() 154 | 155 | pos3 = self.positional_embedding(x3) 156 | pos4 = self.positional_embedding(x4) 157 | pos5 = self.positional_embedding(x5) 158 | 159 | _, feat3 = self.transformer(x3, mask3, self.query_embed.weight, pos3) 160 | _, feat4 = self.transformer(x4, mask4, self.query_embed.weight, pos4) 161 | _, feat5 = self.transformer(x5, mask5, self.query_embed.weight, pos5) 162 | 163 | # f3 f4 f5: structural guidance 164 | f3 = feat3.view(feat3.shape[0], feat3.shape[1], -1).detach() 165 | f4 = feat4.view(feat4.shape[0], feat4.shape[1], -1).detach() 166 | f5 = feat5.view(feat5.shape[0], feat5.shape[1], -1).detach() 167 | 168 | feat3 = self.GMP(feat3).view(feat3.shape[0], -1) 169 | feat4 = self.GMP(feat4).view(feat4.shape[0], -1) 170 | feat5 = self.GMP(feat5).view(feat5.shape[0], -1) 171 | 172 | feat 
= torch.cat((feat3, feat4, feat5), dim=1) 173 | feat = self.trans_classifier(feat) 174 | 175 | return f3, f4, f5, feat 176 | 177 | def forward_constraint(self, x): 178 | activations = self.constraint_classifier(x) 179 | out = self.kmp(activations) 180 | return out 181 | 182 | def build_nodes(self, x, f4): 183 | mask = self.constraint_classifier(x) 184 | mask = mask.view(mask.size(0), mask.size(1), -1) 185 | mask = torch.sigmoid(mask) 186 | mask = mask.transpose(1, 2) 187 | 188 | x = self.gcn_dim_transform(x) 189 | x = x.view(x.size(0), x.size(1), -1) 190 | v_g = torch.matmul(x, mask) 191 | 192 | v_t = torch.matmul(f4, mask) 193 | v_t = v_t.detach() 194 | v_t = self.guidance_transform(v_t) 195 | nodes = torch.cat((v_g, v_t), dim=1) 196 | return nodes 197 | 198 | def build_joint_correlation_matrix(self, f3, f4, f5, x): 199 | f4 = self.GAP1d(f4) 200 | f3 = self.GAP1d(f3) 201 | f5 = self.GAP1d(f5) 202 | trans_guid = torch.cat((f3, f4, f5), dim=1) 203 | 204 | trans_guid = self.guidance_conv(trans_guid) 205 | trans_guid = self.guidance_bn(trans_guid) 206 | trans_guid = self.relu(trans_guid) 207 | trans_guid = trans_guid.expand(trans_guid.size(0), trans_guid.size(1), x.size(2)) 208 | 209 | x = torch.cat((trans_guid, x), dim=1) 210 | joint_correlation = self.matrix_transform(x) 211 | joint_correlation = torch.sigmoid(joint_correlation) 212 | return joint_correlation 213 | 214 | def forward(self, x): 215 | x2, x3, x4 = self.forward_backbone(x) 216 | 217 | # structural relation 218 | f3, f4, f5, out_trans = self.forward_transformer(x3, x4) 219 | 220 | # semantic relation 221 | # semantic-aware constraints 222 | out_sac = self.forward_constraint(x4) 223 | # graph nodes 224 | V = self.build_nodes(x4, f4) 225 | # print('V', V.shape) 226 | # joint correlation 227 | A_s = self.build_joint_correlation_matrix(f3, f4, f5, V) 228 | G = self.forward_gcn(A_s, V) + V 229 | out_gcn = self.gcn_classifier(G) 230 | mask_mat = self.mask_mat.detach() 231 | out_gcn = (out_gcn * mask_mat).sum(-1) 232 | 233 | return out_trans, out_gcn, out_sac 234 | 235 | def get_config_optim(self, lr, lrp): 236 | small_lr_layers = list(map(id, self.backbone.parameters())) 237 | large_lr_layers = filter(lambda p:id(p) not in small_lr_layers, self.parameters()) 238 | return [ 239 | {'params': self.backbone.parameters(), 'lr': lr * lrp}, 240 | {'params': large_lr_layers, 'lr': lr}, 241 | ] 242 | 243 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from .TDRG import TDRG 3 | 4 | model_dict = {'TDRG': TDRG} 5 | 6 | def get_model(num_classes, args): 7 | res101 = torchvision.models.resnet101(pretrained=True) 8 | model = model_dict[args.model_name](res101, num_classes) 9 | return model 10 | -------------------------------------------------------------------------------- /models/trans_utils/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the trans_utils. 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | 9 | #from util.misc import NestedTensor 10 | 11 | 12 | class PositionEmbeddingSine(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 
16 | """ 17 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 18 | super(PositionEmbeddingSine, self).__init__() 19 | self.num_pos_feats = num_pos_feats 20 | self.temperature = temperature 21 | self.normalize = normalize 22 | if scale is not None and normalize is False: 23 | raise ValueError("normalize should be True if scale is passed") 24 | if scale is None: 25 | scale = 2 * math.pi 26 | self.scale = scale 27 | 28 | def forward(self, x, mask): 29 | #x = tensor_list.tensors 30 | #mask = tensor_list.mask 31 | assert mask is not None 32 | not_mask = ~mask 33 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 34 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 35 | if self.normalize: 36 | eps = 1e-6 37 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 38 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 39 | 40 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 41 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 42 | 43 | pos_x = x_embed[:, :, :, None] / dim_t 44 | pos_y = y_embed[:, :, :, None] / dim_t 45 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 46 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 47 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 48 | return pos 49 | 50 | 51 | class PositionEmbeddingLearned(nn.Module): 52 | """ 53 | Absolute pos embedding, learned. 54 | """ 55 | def __init__(self, num_pos_feats=256): 56 | super(PositionEmbeddingLearned, self).__init__() 57 | self.row_embed = nn.Embedding(50, num_pos_feats) 58 | self.col_embed = nn.Embedding(50, num_pos_feats) 59 | self.reset_parameters() 60 | 61 | def reset_parameters(self): 62 | nn.init.uniform_(self.row_embed.weight) 63 | nn.init.uniform_(self.col_embed.weight) 64 | 65 | def forward(self, x): 66 | #x = tensor_list.tensors 67 | h, w = x.shape[-2:] 68 | i = torch.arange(w, device=x.device) 69 | j = torch.arange(h, device=x.device) 70 | x_emb = self.col_embed(i) 71 | y_emb = self.row_embed(j) 72 | pos = torch.cat([ 73 | x_emb.unsqueeze(0).repeat(h, 1, 1), 74 | y_emb.unsqueeze(1).repeat(1, w, 1), 75 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 76 | return pos 77 | 78 | 79 | def build_position_encoding(hidden_dim, mode = 'sine'): 80 | N_steps = hidden_dim // 2 81 | if mode in ('v2', 'sine'): 82 | # TODO find a better way of exposing other arguments 83 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 84 | elif mode in ('v3', 'learned'): 85 | position_embedding = PositionEmbeddingLearned(N_steps) 86 | else: 87 | raise ValueError("not supported {args.position_embedding}") 88 | 89 | return position_embedding 90 | -------------------------------------------------------------------------------- /models/trans_utils/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | DETR Transformer class. 
4 | 5 | Copy-paste from torch.nn.Transformer with modifications: 6 | * positional encodings are passed in MHattention 7 | * extra LN at the end of encoder is removed 8 | * decoder returns a stack of activations from all decoding layers 9 | """ 10 | import copy 11 | from typing import Optional, List 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch import nn, Tensor 16 | 17 | 18 | class Transformer(nn.Module): 19 | 20 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 21 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 22 | activation="relu", normalize_before=False, 23 | return_intermediate_dec=False): 24 | super(Transformer, self).__init__() 25 | 26 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 27 | dropout, activation, normalize_before) 28 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 29 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 30 | 31 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 32 | dropout, activation, normalize_before) 33 | decoder_norm = nn.LayerNorm(d_model) 34 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, 35 | return_intermediate=return_intermediate_dec) 36 | 37 | self._reset_parameters() 38 | 39 | self.d_model = d_model 40 | self.nhead = nhead 41 | 42 | def _reset_parameters(self): 43 | for p in self.parameters(): 44 | if p.dim() > 1: 45 | nn.init.xavier_uniform_(p) 46 | 47 | def forward(self, src, mask, query_embed, pos_embed): 48 | # flatten NxCxHxW to HWxNxC 49 | bs, c, h, w = src.shape 50 | src = src.flatten(2).permute(0, 2, 1)#(2, 0, 1) 51 | pos_embed = pos_embed.flatten(2).permute(0, 2, 1) 52 | query_embed = query_embed.unsqueeze(1).repeat(bs, 1, 1) 53 | mask = mask.flatten(1).permute(1,0) 54 | #print('mask', mask.shape) 55 | tgt = torch.zeros_like(query_embed) 56 | 57 | memory = self.encoder(src, src_key_padding_mask=mask, pos=None) 58 | hs = memory # for fast inference 59 | #hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, 60 | # pos=pos_embed, query_pos=query_embed) 61 | return hs.transpose(1, 2), memory.permute(0, 2, 1).view(bs, c, h, w) 62 | 63 | 64 | class TransformerEncoder(nn.Module): 65 | 66 | def __init__(self, encoder_layer, num_layers, norm=None): 67 | super(TransformerEncoder, self).__init__() 68 | self.layers = _get_clones(encoder_layer, num_layers) 69 | self.num_layers = num_layers 70 | self.norm = norm 71 | 72 | def forward(self, src, 73 | mask = None, 74 | src_key_padding_mask = None, 75 | pos= None): 76 | output = src 77 | 78 | for layer in self.layers: 79 | output = layer(output, src_mask=mask, 80 | src_key_padding_mask=src_key_padding_mask, pos=pos) 81 | 82 | if self.norm is not None: 83 | output = self.norm(output) 84 | 85 | return output 86 | 87 | 88 | class TransformerDecoder(nn.Module): 89 | 90 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 91 | super(TransformerDecoder, self).__init__() 92 | self.layers = _get_clones(decoder_layer, num_layers) 93 | self.num_layers = num_layers 94 | self.norm = norm 95 | self.return_intermediate = return_intermediate 96 | 97 | def forward(self, tgt, memory, 98 | tgt_mask = None, 99 | memory_mask = None, 100 | tgt_key_padding_mask = None, 101 | memory_key_padding_mask= None, 102 | pos = None, 103 | query_pos = None): 104 | output = tgt 105 | 106 | intermediate = [] 107 | 108 | for layer in self.layers: 109 | output = layer(output, memory, tgt_mask=tgt_mask, 
110 | memory_mask=memory_mask, 111 | tgt_key_padding_mask=tgt_key_padding_mask, 112 | memory_key_padding_mask=memory_key_padding_mask, 113 | pos=pos, query_pos=query_pos) 114 | if self.return_intermediate: 115 | intermediate.append(self.norm(output)) 116 | 117 | if self.norm is not None: 118 | output = self.norm(output) 119 | if self.return_intermediate: 120 | intermediate.pop() 121 | intermediate.append(output) 122 | 123 | if self.return_intermediate: 124 | return torch.stack(intermediate) 125 | 126 | return output.unsqueeze(0) 127 | 128 | 129 | class TransformerEncoderLayer(nn.Module): 130 | 131 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 132 | activation="relu", normalize_before=False): 133 | super(TransformerEncoderLayer, self).__init__() 134 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 135 | # Implementation of Feedforward model 136 | self.linear1 = nn.Linear(d_model, dim_feedforward) 137 | self.dropout = nn.Dropout(dropout) 138 | self.linear2 = nn.Linear(dim_feedforward, d_model) 139 | 140 | self.norm1 = nn.LayerNorm(d_model) 141 | self.norm2 = nn.LayerNorm(d_model) 142 | self.dropout1 = nn.Dropout(dropout) 143 | self.dropout2 = nn.Dropout(dropout) 144 | 145 | self.activation = _get_activation_fn(activation) 146 | self.normalize_before = normalize_before 147 | 148 | def with_pos_embed(self, tensor, pos=None): 149 | return tensor if pos is None else tensor + pos 150 | 151 | def without_pos_embed(self, tensor, pos=None): 152 | return tensor 153 | 154 | def forward_post(self, 155 | src, 156 | src_mask= None, 157 | src_key_padding_mask = None, 158 | pos= None): 159 | q = k = self.with_pos_embed(src, pos) 160 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 161 | key_padding_mask=src_key_padding_mask)[0] 162 | src = src + self.dropout1(src2) 163 | src = self.norm1(src) 164 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 165 | src = src + self.dropout2(src2) 166 | src = self.norm2(src) 167 | return src 168 | 169 | def forward_pre(self, src, 170 | src_mask= None, 171 | src_key_padding_mask = None, 172 | pos = None): 173 | src2 = self.norm1(src) 174 | q = k = self.with_pos_embed(src2, pos) 175 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 176 | key_padding_mask=src_key_padding_mask)[0] 177 | src = src + self.dropout1(src2) 178 | src2 = self.norm2(src) 179 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 180 | src = src + self.dropout2(src2) 181 | return src 182 | 183 | def forward(self, src, 184 | src_mask = None, 185 | src_key_padding_mask = None, 186 | pos = None): 187 | if self.normalize_before: 188 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 189 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 190 | 191 | 192 | class TransformerDecoderLayer(nn.Module): 193 | 194 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 195 | activation="relu", normalize_before=False): 196 | super(TransformerDecoderLayer, self).__init__() 197 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 198 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 199 | # Implementation of Feedforward model 200 | self.linear1 = nn.Linear(d_model, dim_feedforward) 201 | self.dropout = nn.Dropout(dropout) 202 | self.linear2 = nn.Linear(dim_feedforward, d_model) 203 | 204 | self.norm1 = nn.LayerNorm(d_model) 205 | self.norm2 = nn.LayerNorm(d_model) 206 | self.norm3 = nn.LayerNorm(d_model) 
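        # norm1/dropout1, norm2/dropout2 and norm3/dropout3 wrap the three sub-blocks
        # below: self-attention, encoder-decoder cross-attention and the feed-forward
        # network, respectively (see forward_post / forward_pre)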
207 | self.dropout1 = nn.Dropout(dropout) 208 | self.dropout2 = nn.Dropout(dropout) 209 | self.dropout3 = nn.Dropout(dropout) 210 | 211 | self.activation = _get_activation_fn(activation) 212 | self.normalize_before = normalize_before 213 | 214 | def with_pos_embed(self, tensor, pos): 215 | return tensor if pos is None else tensor + pos 216 | 217 | def forward_post(self, tgt, memory, 218 | tgt_mask = None, 219 | memory_mask= None, 220 | tgt_key_padding_mask = None, 221 | memory_key_padding_mask = None, 222 | pos = None, 223 | query_pos = None): 224 | q = k = self.with_pos_embed(tgt, query_pos) 225 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 226 | key_padding_mask=tgt_key_padding_mask)[0] 227 | tgt = tgt + self.dropout1(tgt2) 228 | tgt = self.norm1(tgt) 229 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 230 | key=self.with_pos_embed(memory, pos), 231 | value=memory, attn_mask=memory_mask, 232 | key_padding_mask=memory_key_padding_mask)[0] 233 | tgt = tgt + self.dropout2(tgt2) 234 | tgt = self.norm2(tgt) 235 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 236 | tgt = tgt + self.dropout3(tgt2) 237 | tgt = self.norm3(tgt) 238 | return tgt 239 | 240 | def forward_pre(self, tgt, memory, 241 | tgt_mask = None, 242 | memory_mask = None, 243 | tgt_key_padding_mask = None, 244 | memory_key_padding_mask = None, 245 | pos = None, 246 | query_pos = None): 247 | tgt2 = self.norm1(tgt) 248 | q = k = self.with_pos_embed(tgt2, query_pos) 249 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 250 | key_padding_mask=tgt_key_padding_mask)[0] 251 | tgt = tgt + self.dropout1(tgt2) 252 | tgt2 = self.norm2(tgt) 253 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 254 | key=self.with_pos_embed(memory, pos), 255 | value=memory, attn_mask=memory_mask, 256 | key_padding_mask=memory_key_padding_mask)[0] 257 | tgt = tgt + self.dropout2(tgt2) 258 | tgt2 = self.norm3(tgt) 259 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 260 | tgt = tgt + self.dropout3(tgt2) 261 | return tgt 262 | 263 | def forward(self, tgt, memory, 264 | tgt_mask = None, 265 | memory_mask = None, 266 | tgt_key_padding_mask = None, 267 | memory_key_padding_mask = None, 268 | pos = None, 269 | query_pos = None): 270 | if self.normalize_before: 271 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 272 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 273 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 274 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 275 | 276 | 277 | def _get_clones(module, N): 278 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 279 | 280 | 281 | def build_transformer(d_model, nhead, num_encoder_layers, num_decoder_layers): 282 | return Transformer( 283 | d_model=d_model, nhead=nhead, num_encoder_layers=num_encoder_layers, 284 | num_decoder_layers=num_decoder_layers, dim_feedforward=2048, dropout=0, 285 | activation="gelu", normalize_before=False, 286 | return_intermediate_dec=True) 287 | ''' 288 | return Transformer( 289 | d_model=cfg.hidden_dim, 290 | dropout=cfg.dropout, 291 | nhead=cfg.nheads, 292 | dim_feedforward=cfg.dim_feedforward, 293 | num_encoder_layers=cfg.enc_layers, 294 | num_decoder_layers=cfg.dec_layers, 295 | normalize_before=cfg.pre_norm, 296 | return_intermediate_dec=True, 297 | ) 298 | ''' 299 | 300 | def _get_activation_fn(activation): 301 | """Return an activation function given a string""" 302 | if activation == 
"relu": 303 | return F.relu 304 | if activation == "gelu": 305 | return F.gelu 306 | if activation == "glu": 307 | return F.glu 308 | raise RuntimeError("activation should be relu/gelu, not {activation}.") 309 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import time 4 | from datetime import datetime 5 | 6 | import numpy as np 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | from torch.autograd import Variable 10 | 11 | from util import AverageMeter, AveragePrecisionMeter 12 | 13 | 14 | class Trainer(object): 15 | def __init__(self, model, criterion, train_loader, val_loader, args): 16 | self.model = model 17 | self.criterion = criterion 18 | self.train_loader = train_loader 19 | self.val_loader = val_loader 20 | self.args = args 21 | # pprint (self.args) 22 | print('--------Args Items----------') 23 | for k, v in vars(self.args).items(): 24 | print('{}: {}'.format(k, v)) 25 | print('--------Args Items----------\n') 26 | 27 | def initialize_optimizer_and_scheduler(self): 28 | self.optimizer = torch.optim.SGD(self.model.get_config_optim(self.args.lr, self.args.lrp), 29 | lr=self.args.lr, 30 | momentum=self.args.momentum, 31 | weight_decay=self.args.weight_decay) 32 | # self.lr_scheduler = lr_scheduler.MultiStepLR(self.optimizer, self.args.epoch_step, gamma=0.1) 33 | 34 | def initialize_meters(self): 35 | self.meters = {} 36 | # meters 37 | self.meters['loss'] = AverageMeter('loss') 38 | self.meters['ap_meter'] = AveragePrecisionMeter() 39 | # time measure 40 | self.meters['batch_time'] = AverageMeter('batch_time') 41 | self.meters['data_time'] = AverageMeter('data_time') 42 | 43 | def initialization(self, is_train=False): 44 | """ initialize self.model and self.criterion here """ 45 | 46 | if is_train: 47 | self.start_epoch = 0 48 | self.epoch = 0 49 | self.end_epoch = self.args.epochs 50 | self.best_score = 0. 51 | self.lr_now = self.args.lr 52 | 53 | # initialize some settings 54 | self.initialize_optimizer_and_scheduler() 55 | 56 | self.initialize_meters() 57 | 58 | # load checkpoint if args.resume is a valid checkpint file. 59 | if os.path.isfile(self.args.resume) and self.args.resume.endswith('pth'): 60 | self.load_checkpoint() 61 | 62 | if torch.cuda.is_available(): 63 | cudnn.benchmark = True 64 | self.model = torch.nn.DataParallel(self.model).cuda() 65 | self.criterion = self.criterion.cuda() 66 | # self.train_loader.pin_memory = True 67 | # self.val_loader.pin_memory = True 68 | 69 | def reset_meters(self): 70 | for k, v in self.meters.items(): 71 | self.meters[k].reset() 72 | 73 | def on_start_epoch(self): 74 | self.reset_meters() 75 | 76 | def on_end_epoch(self, is_train=False): 77 | 78 | if is_train: 79 | # maybe you can do something like 'print the training results' here. 
80 | return 81 | else: 82 | # map = self.meters['ap_meter'].value().mean() 83 | ap = self.meters['ap_meter'].value() 84 | print(ap) 85 | map = ap.mean() 86 | loss = self.meters['loss'].average() 87 | data_time = self.meters['data_time'].average() 88 | batch_time = self.meters['batch_time'].average() 89 | 90 | OP, OR, OF1, CP, CR, CF1 = self.meters['ap_meter'].overall() 91 | OP_k, OR_k, OF1_k, CP_k, CR_k, CF1_k = self.meters['ap_meter'].overall_topk(3) 92 | 93 | print('* Test\nLoss: {loss:.4f}\t mAP: {map:.4f}\t' 94 | 'Data_time: {data_time:.4f}\t Batch_time: {batch_time:.4f}'.format( 95 | loss=loss, map=map, data_time=data_time, batch_time=batch_time)) 96 | print('OP: {OP:.3f}\t OR: {OR:.3f}\t OF1: {OF1:.3f}\t' 97 | 'CP: {CP:.3f}\t CR: {CR:.3f}\t CF1: {CF1:.3f}'.format( 98 | OP=OP, OR=OR, OF1=OF1, CP=CP, CR=CR, CF1=CF1)) 99 | print('OP_3: {OP:.3f}\t OR_3: {OR:.3f}\t OF1_3: {OF1:.3f}\t' 100 | 'CP_3: {CP:.3f}\t CR_3: {CR:.3f}\t CF1_3: {CF1:.3f}'.format( 101 | OP=OP_k, OR=OR_k, OF1=OF1_k, CP=CP_k, CR=CR_k, CF1=CF1_k)) 102 | 103 | return map 104 | 105 | def on_forward(self, inputs, targets, is_train): 106 | inputs = Variable(inputs).float() 107 | targets = Variable(targets).float() 108 | 109 | if not is_train: 110 | with torch.no_grad(): 111 | out_trans, out_gcn, out_sac = self.model(inputs) 112 | else: 113 | out_trans, out_gcn, out_sac = self.model(inputs) 114 | outputs = (0.7 * out_trans + 0.3 * out_gcn) 115 | 116 | loss = self.criterion(outputs, targets) + \ 117 | self.criterion(out_trans, targets) + \ 118 | self.criterion(out_gcn, targets) + \ 119 | self.criterion(out_sac, targets) 120 | self.meters['loss'].update(loss.item(), inputs.size(0)) 121 | 122 | if is_train: 123 | self.optimizer.zero_grad() 124 | loss.backward() 125 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.args.max_clip_grad_norm) 126 | self.optimizer.step() 127 | 128 | return outputs 129 | 130 | def adjust_learning_rate(self): 131 | """ Sets learning rate if it is needed """ 132 | lr_list = [] 133 | decay = 0.1 if sum(self.epoch == np.array(self.args.epoch_step)) > 0 else 1.0 134 | for param_group in self.optimizer.param_groups: 135 | param_group['lr'] = param_group['lr'] * decay 136 | lr_list.append(param_group['lr']) 137 | 138 | return np.unique(lr_list) 139 | 140 | def train(self): 141 | self.initialization(is_train=True) 142 | 143 | for epoch in range(self.start_epoch, self.end_epoch): 144 | self.lr_now = self.adjust_learning_rate() 145 | print('Lr: {}'.format(self.lr_now)) 146 | 147 | self.epoch = epoch 148 | # train for one epoch 149 | self.run_iteration(self.train_loader, is_train=True) 150 | 151 | # evaluate on validation set 152 | score = self.run_iteration(self.val_loader, is_train=False) 153 | 154 | # record best score, save checkpoint and result 155 | is_best = score > self.best_score 156 | self.best_score = max(score, self.best_score) 157 | checkpoint = { 158 | 'epoch': epoch + 1, 159 | 'model_name': self.args.model_name, 160 | 'state_dict': self.model.module.state_dict() if torch.cuda.is_available() else self.model.state_dict(), 161 | 'best_score': self.best_score 162 | } 163 | model_dir = self.args.save_dir 164 | # assert os.path.exists(model_dir) == True 165 | self.save_checkpoint(checkpoint, model_dir, is_best) 166 | self.save_result(model_dir, is_best) 167 | 168 | print(' * best mAP={best:.4f}'.format(best=self.best_score)) 169 | 170 | return self.best_score 171 | 172 | def run_iteration(self, data_loader, is_train=True): 173 | self.on_start_epoch() 174 | 175 | if not is_train: 176 
| # data_loader = tqdm(data_loader, desc='Validate') 177 | self.model.eval() 178 | else: 179 | self.model.train() 180 | 181 | st_time = time.time() 182 | for i, data in enumerate(data_loader): 183 | 184 | # measure data loading time 185 | data_time = time.time() - st_time 186 | self.meters['data_time'].update(data_time) 187 | 188 | # inputs, targets, targets_gt, filenames = self.on_start_batch(data) 189 | inputs = data['image'] 190 | targets = data['target'] 191 | 192 | # for voc 193 | labels = targets.clone() 194 | targets[targets == 0] = 1 195 | targets[targets == -1] = 0 196 | 197 | if torch.cuda.is_available(): 198 | inputs = inputs.cuda() 199 | targets = targets.cuda() 200 | 201 | outputs = self.on_forward(inputs, targets, is_train=is_train) 202 | 203 | # measure elapsed time 204 | batch_time = time.time() - st_time 205 | self.meters['batch_time'].update(batch_time) 206 | 207 | self.meters['ap_meter'].add(outputs.data, labels.data, data['name']) 208 | st_time = time.time() 209 | 210 | if is_train and i % self.args.display_interval == 0: 211 | print('{}, {} Epoch, {} Iter, Loss: {:.4f}, Data time: {:.4f}, Batch time: {:.4f}'.format( 212 | datetime.now().strftime('%Y-%m-%d %H:%M:%S'), self.epoch + 1, i, 213 | self.meters['loss'].value(), self.meters['data_time'].value(), 214 | self.meters['batch_time'].value())) 215 | 216 | return self.on_end_epoch(is_train=is_train) 217 | 218 | def validate(self): 219 | self.initialization(is_train=False) 220 | 221 | map = self.run_iteration(self.val_loader, is_train=False) 222 | 223 | model_dir = os.path.dirname(self.args.resume) 224 | assert os.path.exists(model_dir) == True 225 | self.save_result(model_dir, is_best=False) 226 | 227 | return map 228 | 229 | def load_checkpoint(self): 230 | print("* Loading checkpoint '{}'".format(self.args.resume)) 231 | checkpoint = torch.load(self.args.resume) 232 | self.start_epoch = checkpoint['epoch'] 233 | self.best_score = checkpoint['best_score'] 234 | model_dict = self.model.state_dict() 235 | for k, v in checkpoint['state_dict'].items(): 236 | if k in model_dict and v.shape == model_dict[k].shape: 237 | model_dict[k] = v 238 | else: 239 | print('\tMismatched layers: {}'.format(k)) 240 | self.model.load_state_dict(model_dict) 241 | 242 | # only for original pretrained model 243 | def load_origin_checkpoint(self): 244 | print("* Loading checkpoint '{}'".format(self.args.resume)) 245 | checkpoint = torch.load(self.args.resume) 246 | self.start_epoch = checkpoint['epoch'] 247 | self.best_score = checkpoint['best_score'] 248 | model_dict = self.model.state_dict() 249 | for k, v in checkpoint['state_dict'].items(): 250 | if 'features.' in k: 251 | model_dict[k.replace('features.', 'layer1.')] = v 252 | elif 'bottleneck.' in k or 'classifier_global.' in k or 'bn_position' in k or 'gcn.static_' in k: 253 | pass 254 | elif 'classifier.' in k: 255 | model_dict[k.replace('classifier.', 'trans_classifier.')] = v 256 | elif 'conv_position.' in k: 257 | model_dict[k.replace('conv_position.', 'guidance_transform.')] = v 258 | elif 'fc.' 
in k: 259 | model_dict[k.replace('fc.', 'constraint_classifier.')] = v 260 | elif 'conv_transform' in k: 261 | model_dict[k.replace('conv_transform', 'gcn_dim_transform')] = v 262 | elif 'gcn.conv_global' in k: 263 | model_dict[k.replace('gcn.conv_global', 'guidance_conv')] = v 264 | elif 'gcn.bn_global' in k: 265 | model_dict[k.replace('gcn.bn_global', 'guidance_bn')] = v 266 | elif 'gcn.conv_create_co_mat' in k: 267 | model_dict[k.replace('gcn.conv_create_co_mat', 'matrix_transform')] = v 268 | elif 'gcn.dynamic_weight' in k: 269 | model_dict[k.replace('gcn.dynamic_weight', 'forward_gcn.weight')] = v 270 | elif 'last_linear' in k: 271 | model_dict[k.replace('last_linear', 'gcn_classifier')] = v 272 | elif k in model_dict and v.shape == model_dict[k].shape: 273 | model_dict[k] = v 274 | else: 275 | print('\tMismatched layers: {}'.format(k)) 276 | self.model.load_state_dict(model_dict) 277 | 278 | 279 | def save_checkpoint(self, checkpoint, model_dir, is_best=False): 280 | if not os.path.exists(model_dir): 281 | os.makedirs(model_dir) 282 | 283 | # filename = 'Epoch-{}.pth'.format(self.epoch) 284 | filename = 'checkpoint.pth' 285 | res_path = os.path.join(model_dir, filename) 286 | print('Save checkpoint to {}'.format(res_path)) 287 | torch.save(checkpoint, res_path) 288 | if is_best: 289 | filename_best = 'checkpoint_best.pth' 290 | res_path_best = os.path.join(model_dir, filename_best) 291 | shutil.copyfile(res_path, res_path_best) 292 | 293 | def save_result(self, model_dir, is_best=False): 294 | if not os.path.exists(model_dir): 295 | os.makedirs(model_dir) 296 | 297 | # filename = 'results.csv' if not is_best else 'best_results.csv' 298 | filename = 'results.csv' 299 | res_path = os.path.join(model_dir, filename) 300 | print('Save results to {}'.format(res_path)) 301 | with open(res_path, 'w') as fid: 302 | for i in range(self.meters['ap_meter'].scores.shape[0]): 303 | fid.write('{},{},{}\n'.format(self.meters['ap_meter'].filenames[i], 304 | ','.join(map(str, self.meters['ap_meter'].scores[i].numpy())), 305 | ','.join(map(str, self.meters['ap_meter'].targets[i].numpy())))) 306 | 307 | if is_best: 308 | filename_best = 'output_best.csv' 309 | res_path_best = os.path.join(model_dir, filename_best) 310 | shutil.copyfile(res_path, res_path_best) 311 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import os, sys, pdb 2 | import math 3 | import torch 4 | from PIL import Image 5 | import numpy as np 6 | import random 7 | 8 | class AverageMeter(object): 9 | """Computes and stores the average and current value""" 10 | def __init__(self, name, fmt=':f'): 11 | self.name = name 12 | self.fmt = fmt 13 | self.reset() 14 | 15 | def reset(self): 16 | self.val = 0 17 | self.avg = 0 18 | self.sum = 0 19 | self.count = 0 20 | 21 | def update(self, val, n=1): 22 | self.val = val 23 | self.sum += val * n 24 | self.count += n 25 | self.avg = self.sum / self.count 26 | 27 | def average(self): 28 | return self.avg 29 | 30 | def value(self): 31 | return self.val 32 | 33 | def __str__(self): 34 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 35 | return fmtstr.format(**self.__dict__) 36 | 37 | class AveragePrecisionMeter(object): 38 | """ 39 | The APMeter measures the average precision per class. 
40 | The APMeter is designed to operate on `NxK` Tensors `output` and 41 | `target`, and optionally a `Nx1` Tensor weight where (1) the `output` 42 | contains model output scores for `N` examples and `K` classes that ought to 43 | be higher when the model is more convinced that the example should be 44 | positively labeled, and smaller when the model believes the example should 45 | be negatively labeled (for instance, the output of a sigmoid function); (2) 46 | the `target` contains only values 0 (for negative examples) and 1 47 | (for positive examples); and (3) the `weight` ( > 0) represents weight for 48 | each sample. 49 | """ 50 | 51 | def __init__(self, difficult_examples=True): 52 | super(AveragePrecisionMeter, self).__init__() 53 | self.reset() 54 | self.difficult_examples = difficult_examples 55 | 56 | def reset(self): 57 | """Resets the meter with empty member variables""" 58 | self.scores = torch.FloatTensor(torch.FloatStorage()) 59 | self.targets = torch.LongTensor(torch.LongStorage()) 60 | self.filenames = [] 61 | 62 | def add(self, output, target, filename): 63 | """ 64 | Args: 65 | output (Tensor): NxK tensor that for each of the N examples 66 | indicates the probability of the example belonging to each of 67 | the K classes, according to the model. The probabilities should 68 | sum to one over all classes 69 | target (Tensor): binary NxK tensort that encodes which of the K 70 | classes are associated with the N-th input 71 | (eg: a row [0, 1, 0, 1] indicates that the example is 72 | associated with classes 2 and 4) 73 | weight (optional, Tensor): Nx1 tensor representing the weight for 74 | each example (each weight > 0) 75 | """ 76 | if not torch.is_tensor(output): 77 | output = torch.from_numpy(output) 78 | if not torch.is_tensor(target): 79 | target = torch.from_numpy(target) 80 | 81 | if output.dim() == 1: 82 | output = output.view(-1, 1) 83 | else: 84 | assert output.dim() == 2, \ 85 | 'wrong output size (should be 1D or 2D with one column \ 86 | per class)' 87 | if target.dim() == 1: 88 | target = target.view(-1, 1) 89 | else: 90 | assert target.dim() == 2, \ 91 | 'wrong target size (should be 1D or 2D with one column \ 92 | per class)' 93 | if self.scores.numel() > 0: 94 | assert target.size(1) == self.targets.size(1), \ 95 | 'dimensions for output should match previously added examples.' 
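        # add() is called once per batch (see Trainer.run_iteration); the backing
        # storage below is grown geometrically so repeated appends stay cheap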
96 | 97 | # make sure storage is of sufficient size 98 | if self.scores.storage().size() < self.scores.numel() + output.numel(): 99 | new_size = math.ceil(self.scores.storage().size() * 1.5) 100 | self.scores.storage().resize_(int(new_size + output.numel())) 101 | self.targets.storage().resize_(int(new_size + output.numel())) 102 | 103 | # store scores and targets 104 | offset = self.scores.size(0) if self.scores.dim() > 0 else 0 105 | self.scores.resize_(offset + output.size(0), output.size(1)) 106 | self.targets.resize_(offset + target.size(0), target.size(1)) 107 | self.scores.narrow(0, offset, output.size(0)).copy_(output) 108 | self.targets.narrow(0, offset, target.size(0)).copy_(target) 109 | 110 | self.filenames += filename # record filenames 111 | 112 | def value(self): 113 | """Returns the model's average precision for each class 114 | Return: 115 | ap (FloatTensor): 1xK tensor, with avg precision for each class k 116 | """ 117 | 118 | if self.scores.numel() == 0: 119 | return 0 120 | ap = torch.zeros(self.scores.size(1)) 121 | rg = torch.arange(1, self.scores.size(0)).float() 122 | # compute average precision for each class 123 | for k in range(self.scores.size(1)): 124 | # sort scores 125 | scores = self.scores[:, k] 126 | targets = self.targets[:, k] 127 | # compute average precision 128 | ap[k] = AveragePrecisionMeter.average_precision(scores, targets, self.difficult_examples) 129 | return ap 130 | 131 | @staticmethod 132 | def average_precision(output, target, difficult_examples=True): 133 | 134 | # sort examples 135 | sorted, indices = torch.sort(output, dim=0, descending=True) 136 | 137 | # Computes prec@i 138 | pos_count = 0. 139 | total_count = 0. 140 | precision_at_i = 0. 141 | for i in indices: 142 | label = target[i] 143 | if difficult_examples and label == 0: 144 | continue 145 | if label == 1: 146 | pos_count += 1 147 | total_count += 1 148 | if label == 1: 149 | precision_at_i += pos_count / total_count 150 | precision_at_i /= pos_count 151 | return precision_at_i 152 | 153 | def overall(self): 154 | if self.scores.numel() == 0: 155 | return 0 156 | scores = self.scores.cpu().numpy() 157 | targets = self.targets.clone().cpu().numpy() 158 | targets[targets == -1] = 0 159 | return self.evaluation(scores, targets) 160 | 161 | def overall_topk(self, k): 162 | targets = self.targets.clone().cpu().numpy() 163 | targets[targets == -1] = 0 164 | n, c = self.scores.size() 165 | scores = np.zeros((n, c)) - 1 166 | index = self.scores.topk(k, 1, True, True)[1].cpu().numpy() 167 | tmp = self.scores.cpu().numpy() 168 | for i in range(n): 169 | for ind in index[i]: 170 | scores[i, ind] = 1 if tmp[i, ind] >= 0 else -1 171 | return self.evaluation(scores, targets) 172 | 173 | def evaluation(self, scores_, targets_): 174 | n, n_class = scores_.shape 175 | Nc, Np, Ng = np.zeros(n_class), np.zeros(n_class), np.zeros(n_class) 176 | for k in range(n_class): 177 | scores = scores_[:, k] 178 | targets = targets_[:, k] 179 | targets[targets == -1] = 0 180 | Ng[k] = np.sum(targets == 1) 181 | Np[k] = np.sum(scores >= 0) 182 | Nc[k] = np.sum(targets * (scores >= 0)) 183 | Np[Np == 0] = 1 184 | OP = np.sum(Nc) / np.sum(Np) 185 | OR = np.sum(Nc) / np.sum(Ng) 186 | OF1 = (2 * OP * OR) / (OP + OR) 187 | 188 | CP = np.sum(Nc / Np) / n_class 189 | CR = np.sum(Nc / Ng) / n_class 190 | CF1 = (2 * CP * CR) / (CP + CR) 191 | return OP, OR, OF1, CP, CR, CF1 192 | --------------------------------------------------------------------------------
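A minimal usage sketch of the AveragePrecisionMeter defined above, assuming util.py from this repository is importable (it relies on the legacy torch.*Storage API). The scores, targets and filenames are illustrative values, not taken from the repository; difficult_examples=False is used so that negative labels are counted rather than skipped, and the logit-like scores make the >= 0 threshold used by overall() correspond to "predicted positive".

import torch
from util import AveragePrecisionMeter

# toy batch: 4 images, 3 classes (values chosen for illustration only)
scores = torch.tensor([[ 2.1, -1.3, -0.4],
                       [-0.7,  1.5, -2.0],
                       [ 1.2,  0.3, -1.1],
                       [-1.9, -0.2,  2.4]])
targets = torch.tensor([[1, 0, 0],
                        [0, 1, 1],
                        [1, 1, 0],
                        [0, 0, 1]])

meter = AveragePrecisionMeter(difficult_examples=False)
meter.add(scores, targets, ['a.jpg', 'b.jpg', 'c.jpg', 'd.jpg'])

ap = meter.value()                            # per-class average precision (size-K tensor)
print('mAP: {:.4f}'.format(ap.mean().item()))
OP, OR, OF1, CP, CR, CF1 = meter.overall()    # overall / per-class precision, recall, F1

In the training loop this is exactly how the meter is driven: Trainer.run_iteration calls add() with the fused model outputs and the original labels each batch, and on_end_epoch reads value(), overall() and overall_topk(3) to report mAP and the OP/OR/OF1/CP/CR/CF1 summaries.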