├── LICENSE
├── README.md
├── data_list
│   ├── pa-100k
│   │   ├── test.txt
│   │   └── train_val.txt
│   ├── peta
│   │   ├── PETA_test_list.txt
│   │   └── PETA_train_list.txt
│   └── rap
│       ├── test.txt
│       └── train.txt
├── main.py
├── model
│   ├── __init__.py
│   └── inception_iccv.py
└── utils
    └── datasets.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Improving Pedestrian Attribute Recognition With Weakly-Supervised Multi-Scale Attribute-Specific Localization 2 | 3 | Code for the paper "Improving Pedestrian Attribute Recognition With Weakly-Supervised Multi-Scale Attribute-Specific Localization", ICCV 2019, Seoul. 
4 | 5 | [[Paper]](https://arxiv.org/abs/1910.04562) [[Poster]](https://chufengt.github.io/publication/pedestrian-attribute/iccv_poster_id2029.pdf) 6 | 7 | Contact: chufeng.t@foxmail.com or tcf18@mails.tsinghua.edu.cn 8 | 9 | ## Environment 10 | 11 | - Python 3.6+ 12 | - PyTorch 0.4+ 13 | 14 | ## Datasets 15 | 16 | - RAP: http://rap.idealtest.org/ 17 | - PETA: http://mmlab.ie.cuhk.edu.hk/projects/PETA.html 18 | - PA-100K: https://github.com/xh-liu/HydraPlus-Net 19 | 20 | The original datasets should be processed to match the format expected by the DataLoader. 21 | 22 | We provide the label lists for training and testing. 23 | 24 | ## Training and Testing 25 | Training: 26 | ``` 27 | python main.py --approach=inception_iccv --experiment=rap 28 | ``` 29 | Testing (evaluate a trained model): 30 | ``` 31 | python main.py --approach=inception_iccv --experiment=rap -e --resume='model_path' 32 | ``` 33 | 34 | ## Pretrained Models 35 | 36 | We provide the pretrained models for reference; the results may differ slightly from the values reported in our paper. 37 | 38 | | Dataset | mA | Link | 39 | | ------- | ----- | ------------------------------------------------------------ | 40 | | PETA | 86.34 | [Model](https://drive.google.com/file/d/1cvX43Qn_vydzT_jnmgwYUUe9hIA161PH/view?usp=sharing) | 41 | | RAP | 81.86 | [Model](https://drive.google.com/file/d/15paMK0-rKDsuzptDPK5kH2JuL8QO0HyS/view?usp=sharing) | 42 | | PA-100K | 80.45 | [Model](https://drive.google.com/file/d/1xIw3jpvE1pDC3U464kcFJ58iSKCRNQ63/view?usp=sharing) | 43 | 44 | ## Reference 45 | 46 | If this work is useful to your research, please cite: 47 | 48 | ``` 49 | @inproceedings{tang2019improving, 50 | title={Improving Pedestrian Attribute Recognition With Weakly-Supervised Multi-Scale Attribute-Specific Localization}, 51 | author={Tang, Chufeng and Sheng, Lu and Zhang, Zhaoxiang and Hu, Xiaolin}, 52 | booktitle={Proceedings of the IEEE International Conference on Computer Vision}, 53 | pages={4997--5006}, 54 | year={2019} 55 | } 56 | ``` 57 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | import sys 6 | import numpy as np 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.optim 13 | import model as models 14 | 15 | from utils.datasets import Get_Dataset 16 | 17 | parser = argparse.ArgumentParser(description='Pedestrian Attribute Framework') 18 | parser.add_argument('--experiment', default='rap', type=str, required=True, help='(default=%(default)s)') 19 | parser.add_argument('--approach', default='inception_iccv', type=str, required=True, help='(default=%(default)s)') 20 | parser.add_argument('--epochs', default=60, type=int, required=False, help='(default=%(default)d)') 21 | parser.add_argument('--batch_size', default=32, type=int, required=False, help='(default=%(default)d)') 22 | parser.add_argument('--lr', '--learning-rate', default=0.0001, type=float, required=False, help='(default=%(default)f)') 23 | parser.add_argument('--optimizer', default='adam', type=str, required=False, help='(default=%(default)s)') 24 | parser.add_argument('--momentum', default=0.9, type=float, required=False, help='(default=%(default)f)') 25 | parser.add_argument('--weight_decay', default=0.0005, type=float, required=False, help='(default=%(default)f)') 26 | parser.add_argument('--start-epoch', default=0, type=int, required=False,
help='(default=%(default)d)') 27 | parser.add_argument('--print_freq', default=100, type=int, required=False, help='(default=%(default)d)') 28 | parser.add_argument('--save_freq', default=10, type=int, required=False, help='(default=%(default)d)') 29 | parser.add_argument('--resume', default='', type=str, required=False, help='(default=%(default)s)') 30 | parser.add_argument('--decay_epoch', default=(20,40), type=eval, required=False, help='(default=%(default)d)') 31 | parser.add_argument('--prefix', default='', type=str, required=False, help='(default=%(default)s)') 32 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', required=False, help='evaluate model on validation set') 33 | 34 | # Seed 35 | np.random.seed(1) 36 | torch.manual_seed(1) 37 | if torch.cuda.is_available(): torch.cuda.manual_seed(1) 38 | else: print('[CUDA unavailable]'); sys.exit() 39 | best_accu = 0 40 | EPS = 1e-12 41 | 42 | ##################################################################################################### 43 | 44 | 45 | def main(): 46 | global args, best_accu 47 | args = parser.parse_args() 48 | 49 | print('=' * 100) 50 | print('Arguments = ') 51 | for arg in vars(args): 52 | print('\t' + arg + ':', getattr(args, arg)) 53 | print('=' * 100) 54 | 55 | # Data loading code 56 | train_dataset, val_dataset, attr_num, description = Get_Dataset(args.experiment, args.approach) 57 | 58 | train_loader = torch.utils.data.DataLoader( 59 | train_dataset, 60 | batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True) 61 | 62 | val_loader = torch.utils.data.DataLoader( 63 | val_dataset, 64 | batch_size=32, shuffle=False, num_workers=4, pin_memory=True) 65 | 66 | # create model 67 | model = models.__dict__[args.approach](pretrained=True, num_classes=attr_num) 68 | 69 | # get the number of model parameters 70 | print('Number of model parameters: {}'.format( 71 | sum([p.data.nelement() for p in model.parameters()]))) 72 | print('') 73 | 74 | # for training on multiple GPUs. 
75 | # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use 76 | model = torch.nn.DataParallel(model).cuda() 77 | 78 | # optionally resume from a checkpoint 79 | if args.resume: 80 | if os.path.isfile(args.resume): 81 | print("=> loading checkpoint '{}'".format(args.resume)) 82 | checkpoint = torch.load(args.resume) 83 | args.start_epoch = checkpoint['epoch'] 84 | best_accu = checkpoint['best_accu'] 85 | model.load_state_dict(checkpoint['state_dict']) 86 | print("=> loaded checkpoint '{}' (epoch {})" 87 | .format(args.resume, checkpoint['epoch'])) 88 | else: 89 | print("=> no checkpoint found at '{}'".format(args.resume)) 90 | 91 | cudnn.benchmark = False 92 | cudnn.deterministic = True 93 | 94 | # define loss function 95 | criterion = Weighted_BCELoss(args.experiment) 96 | 97 | if args.optimizer == 'adam': 98 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, 99 | betas=(0.9, 0.999), 100 | weight_decay=args.weight_decay) 101 | else: 102 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 103 | momentum=args.momentum, 104 | weight_decay=args.weight_decay) 105 | 106 | 107 | if args.evaluate: 108 | test(val_loader, model, attr_num, description) 109 | return 110 | 111 | for epoch in range(args.start_epoch, args.epochs): 112 | adjust_learning_rate(optimizer, epoch, args.decay_epoch) 113 | 114 | # train for one epoch 115 | train(train_loader, model, criterion, optimizer, epoch) 116 | 117 | # evaluate on validation set 118 | accu = validate(val_loader, model, criterion, epoch) 119 | 120 | test(val_loader, model, attr_num, description) 121 | 122 | # remember best Accu and save checkpoint 123 | is_best = accu > best_accu 124 | best_accu = max(accu, best_accu) 125 | 126 | if epoch in args.decay_epoch: 127 | save_checkpoint({ 128 | 'epoch': epoch + 1, 129 | 'state_dict': model.state_dict(), 130 | 'best_accu': best_accu, 131 | }, epoch+1, args.prefix) 132 | 133 | def train(train_loader, model, criterion, optimizer, epoch): 134 | """Train for one epoch on the training set""" 135 | batch_time = AverageMeter() 136 | losses = AverageMeter() 137 | top1 = AverageMeter() 138 | model.train() 139 | 140 | end = time.time() 141 | for i, _ in enumerate(train_loader): 142 | input, target = _ 143 | target = target.cuda(non_blocking=True) 144 | input = input.cuda(non_blocking=True) 145 | output = model(input) 146 | 147 | bs = target.size(0) 148 | 149 | if type(output) == type(()) or type(output) == type([]): 150 | loss_list = [] 151 | # deep supervision 152 | for k in range(len(output)): 153 | out = output[k] 154 | loss_list.append(criterion.forward(torch.sigmoid(out), target, epoch)) 155 | loss = sum(loss_list) 156 | # maximum voting 157 | output = torch.max(torch.max(torch.max(output[0],output[1]),output[2]),output[3]) 158 | else: 159 | loss = criterion.forward(torch.sigmoid(output), target, epoch) 160 | 161 | # measure accuracy and record loss 162 | accu = accuracy(output.data, target) 163 | losses.update(loss.data, bs) 164 | top1.update(accu, bs) 165 | 166 | # compute gradient and do SGD step 167 | optimizer.zero_grad() 168 | loss.backward() 169 | optimizer.step() 170 | 171 | # measure elapsed time 172 | batch_time.update(time.time() - end) 173 | end = time.time() 174 | 175 | if i % args.print_freq == 0: 176 | print('Epoch: [{0}][{1}/{2}]\t' 177 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 178 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 179 | 'Accu {top1.val:.3f} ({top1.avg:.3f})'.format( 180 | epoch, i, len(train_loader), batch_time=batch_time, 181 | loss=losses, top1=top1)) 
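Both train() above and validate()/test() below fuse the four heads returned by the model (pred_3b, pred_4d, pred_5b, main_pred) by element-wise maximum voting before computing metrics. The following is a minimal standalone sketch of that fusion, not part of main.py; the batch size, attribute count, and random logits are made up purely for illustration:

```
import torch

# four hypothetical prediction heads, each of shape (batch_size, attr_num)
heads = [torch.randn(8, 51) for _ in range(4)]

# element-wise maximum over the heads: each attribute keeps its largest logit
voted = heads[0]
for h in heads[1:]:
    voted = torch.max(voted, h)

probs = torch.sigmoid(voted)        # per-attribute probabilities
preds = (probs > 0.5).long()        # binary predictions at threshold 0.5
```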
182 | 183 | 184 | def validate(val_loader, model, criterion, epoch): 185 | """Perform validation on the validation set""" 186 | batch_time = AverageMeter() 187 | losses = AverageMeter() 188 | top1 = AverageMeter() 189 | model.eval() 190 | 191 | end = time.time() 192 | for i, _ in enumerate(val_loader): 193 | input, target = _ 194 | target = target.cuda(non_blocking=True) 195 | input = input.cuda(non_blocking=True) 196 | output = model(input) 197 | 198 | bs = target.size(0) 199 | 200 | if type(output) == type(()) or type(output) == type([]): 201 | loss_list = [] 202 | # deep supervision 203 | for k in range(len(output)): 204 | out = output[k] 205 | loss_list.append(criterion.forward(torch.sigmoid(out), target, epoch)) 206 | loss = sum(loss_list) 207 | # maximum voting 208 | output = torch.max(torch.max(torch.max(output[0],output[1]),output[2]),output[3]) 209 | else: 210 | loss = criterion.forward(torch.sigmoid(output), target, epoch) 211 | 212 | # measure accuracy and record loss 213 | accu = accuracy(output.data, target) 214 | losses.update(loss.data, bs) 215 | top1.update(accu, bs) 216 | 217 | # measure elapsed time 218 | batch_time.update(time.time() - end) 219 | end = time.time() 220 | 221 | if i % args.print_freq == 0: 222 | print('Test: [{0}/{1}]\t' 223 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 224 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 225 | 'Accu {top1.val:.3f} ({top1.avg:.3f})'.format( 226 | i, len(val_loader), batch_time=batch_time, loss=losses, 227 | top1=top1)) 228 | 229 | print(' * Accu {top1.avg:.3f}'.format(top1=top1)) 230 | return top1.avg 231 | 232 | 233 | def test(val_loader, model, attr_num, description): 234 | model.eval() 235 | 236 | pos_cnt = [] 237 | pos_tol = [] 238 | neg_cnt = [] 239 | neg_tol = [] 240 | 241 | accu = 0.0 242 | prec = 0.0 243 | recall = 0.0 244 | tol = 0 245 | 246 | for it in range(attr_num): 247 | pos_cnt.append(0) 248 | pos_tol.append(0) 249 | neg_cnt.append(0) 250 | neg_tol.append(0) 251 | 252 | for i, _ in enumerate(val_loader): 253 | input, target = _ 254 | target = target.cuda(non_blocking=True) 255 | input = input.cuda(non_blocking=True) 256 | output = model(input) 257 | bs = target.size(0) 258 | 259 | # maximum voting 260 | if type(output) == type(()) or type(output) == type([]): 261 | output = torch.max(torch.max(torch.max(output[0],output[1]),output[2]),output[3]) 262 | 263 | 264 | batch_size = target.size(0) 265 | tol = tol + batch_size 266 | output = torch.sigmoid(output.data).cpu().numpy() 267 | output = np.where(output > 0.5, 1, 0) 268 | target = target.cpu().numpy() 269 | 270 | for it in range(attr_num): 271 | for jt in range(batch_size): 272 | if target[jt][it] == 1: 273 | pos_tol[it] = pos_tol[it] + 1 274 | if output[jt][it] == 1: 275 | pos_cnt[it] = pos_cnt[it] + 1 276 | if target[jt][it] == 0: 277 | neg_tol[it] = neg_tol[it] + 1 278 | if output[jt][it] == 0: 279 | neg_cnt[it] = neg_cnt[it] + 1 280 | 281 | if attr_num == 1: 282 | continue 283 | for jt in range(batch_size): 284 | tp = 0 285 | fn = 0 286 | fp = 0 287 | for it in range(attr_num): 288 | if output[jt][it] == 1 and target[jt][it] == 1: 289 | tp = tp + 1 290 | elif output[jt][it] == 0 and target[jt][it] == 1: 291 | fn = fn + 1 292 | elif output[jt][it] == 1 and target[jt][it] == 0: 293 | fp = fp + 1 294 | if tp + fn + fp != 0: 295 | accu = accu + 1.0 * tp / (tp + fn + fp) 296 | if tp + fp != 0: 297 | prec = prec + 1.0 * tp / (tp + fp) 298 | if tp + fn != 0: 299 | recall = recall + 1.0 * tp / (tp + fn) 300 | 301 | print('=' * 100) 302 | print('\t Attr 
\tp_true/n_true\tp_tol/n_tol\tp_pred/n_pred\tcur_mA') 303 | mA = 0.0 304 | for it in range(attr_num): 305 | cur_mA = ((1.0*pos_cnt[it]/pos_tol[it]) + (1.0*neg_cnt[it]/neg_tol[it])) / 2.0 306 | mA = mA + cur_mA 307 | print('\t#{:2}: {:18}\t{:4}\{:4}\t{:4}\{:4}\t{:4}\{:4}\t{:.5f}'.format(it,description[it],pos_cnt[it],neg_cnt[it],pos_tol[it],neg_tol[it],(pos_cnt[it]+neg_tol[it]-neg_cnt[it]),(neg_cnt[it]+pos_tol[it]-pos_cnt[it]),cur_mA)) 308 | mA = mA / attr_num 309 | print('\t' + 'mA: '+str(mA)) 310 | 311 | if attr_num != 1: 312 | accu = accu / tol 313 | prec = prec / tol 314 | recall = recall / tol 315 | f1 = 2.0 * prec * recall / (prec + recall) 316 | print('\t' + 'Accuracy: '+str(accu)) 317 | print('\t' + 'Precision: '+str(prec)) 318 | print('\t' + 'Recall: '+str(recall)) 319 | print('\t' + 'F1_Score: '+str(f1)) 320 | print('=' * 100) 321 | 322 | 323 | def save_checkpoint(state, epoch, prefix, filename='.pth.tar'): 324 | """Saves checkpoint to disk""" 325 | directory = "your_path" + args.experiment + '/' + args.approach + '/' 326 | if not os.path.exists(directory): 327 | os.makedirs(directory) 328 | if prefix == '': 329 | filename = directory + str(epoch) + filename 330 | else: 331 | filename = directory + prefix + '_' + str(epoch) + filename 332 | torch.save(state, filename) 333 | 334 | class AverageMeter(object): 335 | """Computes and stores the average and current value""" 336 | def __init__(self): 337 | self.reset() 338 | 339 | def reset(self): 340 | self.val = 0 341 | self.avg = 0 342 | self.sum = 0 343 | self.count = 0 344 | 345 | def update(self, val, n=1): 346 | self.val = val 347 | self.sum += val * n 348 | self.count += n 349 | self.avg = self.sum / self.count 350 | 351 | def adjust_learning_rate(optimizer, epoch, decay_epoch): 352 | lr = args.lr 353 | for epc in decay_epoch: 354 | if epoch >= epc: 355 | lr = lr * 0.1 356 | else: 357 | break 358 | print() 359 | print('Learning Rate:', lr) 360 | print() 361 | for param_group in optimizer.param_groups: 362 | param_group['lr'] = lr 363 | 364 | 365 | def accuracy(output, target): 366 | batch_size = target.size(0) 367 | attr_num = target.size(1) 368 | 369 | output = torch.sigmoid(output).cpu().numpy() 370 | output = np.where(output > 0.5, 1, 0) 371 | pred = torch.from_numpy(output).long() 372 | target = target.cpu().long() 373 | correct = pred.eq(target) 374 | correct = correct.numpy() 375 | 376 | res = [] 377 | for k in range(attr_num): 378 | res.append(1.0*sum(correct[:,k]) / batch_size) 379 | return sum(res) / attr_num 380 | 381 | 382 | class Weighted_BCELoss(object): 383 | """ 384 | Weighted_BCELoss was proposed in "Multi-attribute learning for pedestrian attribute recognition in surveillance scenarios"[13]. 
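    With the per-attribute positive example ratios stored in self.weights below, cur_weights in forward() evaluates to exp(1 - w) for a positive label and exp(w) for a negative one, so positive examples of rare attributes receive the largest weights.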
385 | """ 386 | def __init__(self, experiment): 387 | super(Weighted_BCELoss, self).__init__() 388 | self.weights = None 389 | if experiment == 'pa100k': 390 | self.weights = torch.Tensor([0.460444444444, 391 | 0.0134555555556, 392 | 0.924377777778, 393 | 0.0621666666667, 394 | 0.352666666667, 395 | 0.294622222222, 396 | 0.352711111111, 397 | 0.0435444444444, 398 | 0.179977777778, 399 | 0.185, 400 | 0.192733333333, 401 | 0.1601, 402 | 0.00952222222222, 403 | 0.5834, 404 | 0.4166, 405 | 0.0494777777778, 406 | 0.151044444444, 407 | 0.107755555556, 408 | 0.0419111111111, 409 | 0.00472222222222, 410 | 0.0168888888889, 411 | 0.0324111111111, 412 | 0.711711111111, 413 | 0.173444444444, 414 | 0.114844444444, 415 | 0.006]).cuda() 416 | elif experiment == 'rap': 417 | self.weights = torch.Tensor([0.311434, 418 | 0.009980, 419 | 0.430011, 420 | 0.560010, 421 | 0.144932, 422 | 0.742479, 423 | 0.097728, 424 | 0.946303, 425 | 0.048287, 426 | 0.004328, 427 | 0.189323, 428 | 0.944764, 429 | 0.016713, 430 | 0.072959, 431 | 0.010461, 432 | 0.221186, 433 | 0.123434, 434 | 0.057785, 435 | 0.228857, 436 | 0.172779, 437 | 0.315186, 438 | 0.022147, 439 | 0.030299, 440 | 0.017843, 441 | 0.560346, 442 | 0.000553, 443 | 0.027991, 444 | 0.036624, 445 | 0.268342, 446 | 0.133317, 447 | 0.302465, 448 | 0.270891, 449 | 0.124059, 450 | 0.012432, 451 | 0.157340, 452 | 0.018132, 453 | 0.064182, 454 | 0.028111, 455 | 0.042155, 456 | 0.027558, 457 | 0.012649, 458 | 0.024504, 459 | 0.294601, 460 | 0.034099, 461 | 0.032800, 462 | 0.091812, 463 | 0.024552, 464 | 0.010388, 465 | 0.017603, 466 | 0.023446, 467 | 0.128917]).cuda() 468 | elif experiment == 'peta': 469 | self.weights = torch.Tensor([0.5016, 470 | 0.3275, 471 | 0.1023, 472 | 0.0597, 473 | 0.1986, 474 | 0.2011, 475 | 0.8643, 476 | 0.8559, 477 | 0.1342, 478 | 0.1297, 479 | 0.1014, 480 | 0.0685, 481 | 0.314, 482 | 0.2932, 483 | 0.04, 484 | 0.2346, 485 | 0.5473, 486 | 0.2974, 487 | 0.0849, 488 | 0.7523, 489 | 0.2717, 490 | 0.0282, 491 | 0.0749, 492 | 0.0191, 493 | 0.3633, 494 | 0.0359, 495 | 0.1425, 496 | 0.0454, 497 | 0.2201, 498 | 0.0178, 499 | 0.0285, 500 | 0.5125, 501 | 0.0838, 502 | 0.4605, 503 | 0.0124]).cuda() 504 | #self.weights = None 505 | 506 | def forward(self, output, target, epoch): 507 | if self.weights is not None: 508 | cur_weights = torch.exp(target + (1 - target * 2) * self.weights) 509 | loss = cur_weights * (target * torch.log(output + EPS)) + ((1 - target) * torch.log(1 - output + EPS)) 510 | else: 511 | loss = target * torch.log(output + EPS) + (1 - target) * torch.log(1 - output + EPS) 512 | return torch.neg(torch.mean(loss)) 513 | 514 | if __name__ == '__main__': 515 | main() 516 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .inception_iccv import * -------------------------------------------------------------------------------- /model/inception_iccv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.tensor as tensor 4 | from torch.nn import functional as F 5 | 6 | __all__ = ['inception_iccv'] 7 | 8 | def inception_iccv(pretrained=True, debug=False, **kwargs): 9 | model = InceptionNet(**kwargs) 10 | """ 11 | Pretrained model: 'https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/bninception.py' 12 | Initializing with basedline models (trained BN-Inception) can obtain better 
results. 13 | """ 14 | if pretrained: 15 | pretrained_dict = torch.load('model/bn_inception-52deb4733.pth') 16 | model_dict = model.state_dict() 17 | new_dict = {} 18 | for k,_ in model_dict.items(): 19 | raw_name = k.replace('main_branch.', '') 20 | if raw_name in pretrained_dict: 21 | new_dict[k] = pretrained_dict[raw_name] 22 | model_dict.update(new_dict) 23 | model.load_state_dict(model_dict) 24 | return model 25 | 26 | 27 | class ChannelAttn(nn.Module): 28 | def __init__(self, in_channels, reduction_rate=16): 29 | super(ChannelAttn, self).__init__() 30 | assert in_channels%reduction_rate == 0 31 | self.conv1 = nn.Conv2d(in_channels, in_channels // reduction_rate, kernel_size=1, stride=1, padding=0) 32 | self.conv2 = nn.Conv2d(in_channels // reduction_rate, in_channels, kernel_size=1, stride=1, padding=0) 33 | 34 | def forward(self, x): 35 | # squeeze operation (global average pooling) 36 | x = F.avg_pool2d(x, x.size()[2:]) 37 | # excitation operation (2 conv layers) 38 | x = F.relu(self.conv1(x)) 39 | x = self.conv2(x) 40 | return torch.sigmoid(x) 41 | 42 | 43 | class SpatialTransformBlock(nn.Module): 44 | def __init__(self, num_classes, pooling_size, channels): 45 | super(SpatialTransformBlock, self).__init__() 46 | self.num_classes = num_classes 47 | self.spatial = pooling_size 48 | 49 | self.global_pool = nn.AvgPool2d((pooling_size, pooling_size//2), stride=1, padding=0, ceil_mode=True, count_include_pad=True) 50 | 51 | self.gap_list = nn.ModuleList() 52 | self.fc_list = nn.ModuleList() 53 | self.att_list = nn.ModuleList() 54 | self.stn_list = nn.ModuleList() 55 | for i in range(self.num_classes): 56 | self.gap_list.append(nn.AvgPool2d((pooling_size, pooling_size//2), stride=1, padding=0, ceil_mode=True, count_include_pad=True)) 57 | self.fc_list.append(nn.Linear(channels, 1)) 58 | self.att_list.append(ChannelAttn(channels)) 59 | self.stn_list.append(nn.Linear(channels, 4)) 60 | 61 | def stn(self, x, theta): 62 | grid = F.affine_grid(theta, x.size()) 63 | x = F.grid_sample(x, grid, padding_mode='border') 64 | return x.cuda() 65 | 66 | def transform_theta(self, theta_i, region_idx): 67 | theta = torch.zeros(theta_i.size(0), 2, 3) 68 | theta[:,0,0] = torch.sigmoid(theta_i[:,0]) 69 | theta[:,1,1] = torch.sigmoid(theta_i[:,1]) 70 | theta[:,0,2] = torch.tanh(theta_i[:,2]) 71 | theta[:,1,2] = torch.tanh(theta_i[:,3]) 72 | theta = theta.cuda() 73 | return theta 74 | 75 | def forward(self, features): 76 | pred_list = [] 77 | bs = features.size(0) 78 | for i in range(self.num_classes): 79 | stn_feature = features * self.att_list[i](features) + features 80 | 81 | theta_i = self.stn_list[i](F.avg_pool2d(stn_feature, stn_feature.size()[2:]).view(bs,-1)).view(-1,4) 82 | theta_i = self.transform_theta(theta_i, i) 83 | 84 | sub_feature = self.stn(stn_feature, theta_i) 85 | pred = self.gap_list[i](sub_feature).view(bs,-1) 86 | pred = self.fc_list[i](pred) 87 | pred_list.append(pred) 88 | pred = torch.cat(pred_list, 1) 89 | return pred 90 | 91 | 92 | class InceptionNet(nn.Module): 93 | def __init__(self, num_classes=51): 94 | super(InceptionNet, self).__init__() 95 | self.num_classes = num_classes 96 | self.main_branch = BNInception() 97 | self.global_pool = nn.AvgPool2d((8,4), stride=1, padding=0, ceil_mode=True, count_include_pad=True) 98 | self.finalfc = nn.Linear(1024, num_classes) 99 | 100 | self.st_3b = SpatialTransformBlock(num_classes, 32, 256*3) 101 | self.st_4d = SpatialTransformBlock(num_classes, 16, 256*2) 102 | self.st_5b = SpatialTransformBlock(num_classes, 8, 256) 103 | 104 | # 
Lateral layers 105 | self.latlayer_3b = nn.Conv2d(320, 256, kernel_size=1, stride=1, padding=0) 106 | self.latlayer_4d = nn.Conv2d(608, 256, kernel_size=1, stride=1, padding=0) 107 | self.latlayer_5b = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0) 108 | 109 | def _upsample_add(self, x, y): 110 | _,_,H,W = y.size() 111 | up_feat = F.interpolate(x, (H, W), mode='bilinear', align_corners=False) 112 | return torch.cat([up_feat,y], 1) 113 | 114 | def forward(self, input): 115 | bs = input.size(0) 116 | feat_3b, feat_4d, feat_5b = self.main_branch(input) 117 | main_feat = self.global_pool(feat_5b).view(bs,-1) 118 | main_pred = self.finalfc(main_feat) 119 | 120 | fusion_5b = self.latlayer_5b(feat_5b) 121 | fusion_4d = self._upsample_add(fusion_5b, self.latlayer_4d(feat_4d)) 122 | fusion_3b = self._upsample_add(fusion_4d, self.latlayer_3b(feat_3b)) 123 | 124 | pred_3b = self.st_3b(fusion_3b) 125 | pred_4d = self.st_4d(fusion_4d) 126 | pred_5b = self.st_5b(fusion_5b) 127 | 128 | return pred_3b, pred_4d, pred_5b, main_pred 129 | 130 | class BNInception(nn.Module): 131 | """ 132 | Copy from 'https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/bninception.py' 133 | """ 134 | def __init__(self): 135 | super(BNInception, self).__init__() 136 | inplace = True 137 | self.conv1_7x7_s2 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)) 138 | self.conv1_7x7_s2_bn = nn.BatchNorm2d(64, affine=True) 139 | self.conv1_relu_7x7 = nn.ReLU(inplace) 140 | self.pool1_3x3_s2 = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) 141 | self.conv2_3x3_reduce = nn.Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1)) 142 | self.conv2_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) 143 | self.conv2_relu_3x3_reduce = nn.ReLU(inplace) 144 | self.conv2_3x3 = nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 145 | self.conv2_3x3_bn = nn.BatchNorm2d(192, affine=True) 146 | self.conv2_relu_3x3 = nn.ReLU(inplace) 147 | self.pool2_3x3_s2 = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) 148 | self.inception_3a_1x1 = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) 149 | self.inception_3a_1x1_bn = nn.BatchNorm2d(64, affine=True) 150 | self.inception_3a_relu_1x1 = nn.ReLU(inplace) 151 | self.inception_3a_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) 152 | self.inception_3a_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) 153 | self.inception_3a_relu_3x3_reduce = nn.ReLU(inplace) 154 | self.inception_3a_3x3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 155 | self.inception_3a_3x3_bn = nn.BatchNorm2d(64, affine=True) 156 | self.inception_3a_relu_3x3 = nn.ReLU(inplace) 157 | self.inception_3a_double_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1)) 158 | self.inception_3a_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) 159 | self.inception_3a_relu_double_3x3_reduce = nn.ReLU(inplace) 160 | self.inception_3a_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 161 | self.inception_3a_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True) 162 | self.inception_3a_relu_double_3x3_1 = nn.ReLU(inplace) 163 | self.inception_3a_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 164 | self.inception_3a_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True) 165 | self.inception_3a_relu_double_3x3_2 = nn.ReLU(inplace) 166 | self.inception_3a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, 
count_include_pad=True) 167 | self.inception_3a_pool_proj = nn.Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1)) 168 | self.inception_3a_pool_proj_bn = nn.BatchNorm2d(32, affine=True) 169 | self.inception_3a_relu_pool_proj = nn.ReLU(inplace) 170 | self.inception_3b_1x1 = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) 171 | self.inception_3b_1x1_bn = nn.BatchNorm2d(64, affine=True) 172 | self.inception_3b_relu_1x1 = nn.ReLU(inplace) 173 | self.inception_3b_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) 174 | self.inception_3b_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) 175 | self.inception_3b_relu_3x3_reduce = nn.ReLU(inplace) 176 | self.inception_3b_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 177 | self.inception_3b_3x3_bn = nn.BatchNorm2d(96, affine=True) 178 | self.inception_3b_relu_3x3 = nn.ReLU(inplace) 179 | self.inception_3b_double_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) 180 | self.inception_3b_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) 181 | self.inception_3b_relu_double_3x3_reduce = nn.ReLU(inplace) 182 | self.inception_3b_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 183 | self.inception_3b_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True) 184 | self.inception_3b_relu_double_3x3_1 = nn.ReLU(inplace) 185 | self.inception_3b_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 186 | self.inception_3b_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True) 187 | self.inception_3b_relu_double_3x3_2 = nn.ReLU(inplace) 188 | self.inception_3b_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 189 | self.inception_3b_pool_proj = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1)) 190 | self.inception_3b_pool_proj_bn = nn.BatchNorm2d(64, affine=True) 191 | self.inception_3b_relu_pool_proj = nn.ReLU(inplace) 192 | self.inception_3c_3x3_reduce = nn.Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1)) 193 | self.inception_3c_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True) 194 | self.inception_3c_relu_3x3_reduce = nn.ReLU(inplace) 195 | self.inception_3c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 196 | self.inception_3c_3x3_bn = nn.BatchNorm2d(160, affine=True) 197 | self.inception_3c_relu_3x3 = nn.ReLU(inplace) 198 | self.inception_3c_double_3x3_reduce = nn.Conv2d(320, 64, kernel_size=(1, 1), stride=(1, 1)) 199 | self.inception_3c_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True) 200 | self.inception_3c_relu_double_3x3_reduce = nn.ReLU(inplace) 201 | self.inception_3c_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 202 | self.inception_3c_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True) 203 | self.inception_3c_relu_double_3x3_1 = nn.ReLU(inplace) 204 | self.inception_3c_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 205 | self.inception_3c_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True) 206 | self.inception_3c_relu_double_3x3_2 = nn.ReLU(inplace) 207 | self.inception_3c_pool = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) 208 | self.inception_4a_1x1 = nn.Conv2d(576, 224, kernel_size=(1, 1), stride=(1, 1)) 209 | self.inception_4a_1x1_bn = nn.BatchNorm2d(224, affine=True) 210 | self.inception_4a_relu_1x1 = nn.ReLU(inplace) 211 | self.inception_4a_3x3_reduce = nn.Conv2d(576, 64, kernel_size=(1, 1), stride=(1, 1)) 212 | self.inception_4a_3x3_reduce_bn = 
nn.BatchNorm2d(64, affine=True) 213 | self.inception_4a_relu_3x3_reduce = nn.ReLU(inplace) 214 | self.inception_4a_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 215 | self.inception_4a_3x3_bn = nn.BatchNorm2d(96, affine=True) 216 | self.inception_4a_relu_3x3 = nn.ReLU(inplace) 217 | self.inception_4a_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) 218 | self.inception_4a_double_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True) 219 | self.inception_4a_relu_double_3x3_reduce = nn.ReLU(inplace) 220 | self.inception_4a_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 221 | self.inception_4a_double_3x3_1_bn = nn.BatchNorm2d(128, affine=True) 222 | self.inception_4a_relu_double_3x3_1 = nn.ReLU(inplace) 223 | self.inception_4a_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 224 | self.inception_4a_double_3x3_2_bn = nn.BatchNorm2d(128, affine=True) 225 | self.inception_4a_relu_double_3x3_2 = nn.ReLU(inplace) 226 | self.inception_4a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 227 | self.inception_4a_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 228 | self.inception_4a_pool_proj_bn = nn.BatchNorm2d(128, affine=True) 229 | self.inception_4a_relu_pool_proj = nn.ReLU(inplace) 230 | self.inception_4b_1x1 = nn.Conv2d(576, 192, kernel_size=(1, 1), stride=(1, 1)) 231 | self.inception_4b_1x1_bn = nn.BatchNorm2d(192, affine=True) 232 | self.inception_4b_relu_1x1 = nn.ReLU(inplace) 233 | self.inception_4b_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) 234 | self.inception_4b_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True) 235 | self.inception_4b_relu_3x3_reduce = nn.ReLU(inplace) 236 | self.inception_4b_3x3 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 237 | self.inception_4b_3x3_bn = nn.BatchNorm2d(128, affine=True) 238 | self.inception_4b_relu_3x3 = nn.ReLU(inplace) 239 | self.inception_4b_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1)) 240 | self.inception_4b_double_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True) 241 | self.inception_4b_relu_double_3x3_reduce = nn.ReLU(inplace) 242 | self.inception_4b_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 243 | self.inception_4b_double_3x3_1_bn = nn.BatchNorm2d(128, affine=True) 244 | self.inception_4b_relu_double_3x3_1 = nn.ReLU(inplace) 245 | self.inception_4b_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 246 | self.inception_4b_double_3x3_2_bn = nn.BatchNorm2d(128, affine=True) 247 | self.inception_4b_relu_double_3x3_2 = nn.ReLU(inplace) 248 | self.inception_4b_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 249 | self.inception_4b_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 250 | self.inception_4b_pool_proj_bn = nn.BatchNorm2d(128, affine=True) 251 | self.inception_4b_relu_pool_proj = nn.ReLU(inplace) 252 | self.inception_4c_1x1 = nn.Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1)) 253 | self.inception_4c_1x1_bn = nn.BatchNorm2d(160, affine=True) 254 | self.inception_4c_relu_1x1 = nn.ReLU(inplace) 255 | self.inception_4c_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 256 | self.inception_4c_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True) 257 | self.inception_4c_relu_3x3_reduce = nn.ReLU(inplace) 258 | self.inception_4c_3x3 = nn.Conv2d(128, 160, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 259 | self.inception_4c_3x3_bn = nn.BatchNorm2d(160, affine=True) 260 | self.inception_4c_relu_3x3 = nn.ReLU(inplace) 261 | self.inception_4c_double_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 262 | self.inception_4c_double_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True) 263 | self.inception_4c_relu_double_3x3_reduce = nn.ReLU(inplace) 264 | self.inception_4c_double_3x3_1 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 265 | self.inception_4c_double_3x3_1_bn = nn.BatchNorm2d(160, affine=True) 266 | self.inception_4c_relu_double_3x3_1 = nn.ReLU(inplace) 267 | self.inception_4c_double_3x3_2 = nn.Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 268 | self.inception_4c_double_3x3_2_bn = nn.BatchNorm2d(160, affine=True) 269 | self.inception_4c_relu_double_3x3_2 = nn.ReLU(inplace) 270 | self.inception_4c_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 271 | self.inception_4c_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1)) 272 | self.inception_4c_pool_proj_bn = nn.BatchNorm2d(128, affine=True) 273 | self.inception_4c_relu_pool_proj = nn.ReLU(inplace) 274 | self.inception_4d_1x1 = nn.Conv2d(608, 96, kernel_size=(1, 1), stride=(1, 1)) 275 | self.inception_4d_1x1_bn = nn.BatchNorm2d(96, affine=True) 276 | self.inception_4d_relu_1x1 = nn.ReLU(inplace) 277 | self.inception_4d_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) 278 | self.inception_4d_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True) 279 | self.inception_4d_relu_3x3_reduce = nn.ReLU(inplace) 280 | self.inception_4d_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 281 | self.inception_4d_3x3_bn = nn.BatchNorm2d(192, affine=True) 282 | self.inception_4d_relu_3x3 = nn.ReLU(inplace) 283 | self.inception_4d_double_3x3_reduce = nn.Conv2d(608, 160, kernel_size=(1, 1), stride=(1, 1)) 284 | self.inception_4d_double_3x3_reduce_bn = nn.BatchNorm2d(160, affine=True) 285 | self.inception_4d_relu_double_3x3_reduce = nn.ReLU(inplace) 286 | self.inception_4d_double_3x3_1 = nn.Conv2d(160, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 287 | self.inception_4d_double_3x3_1_bn = nn.BatchNorm2d(192, affine=True) 288 | self.inception_4d_relu_double_3x3_1 = nn.ReLU(inplace) 289 | self.inception_4d_double_3x3_2 = nn.Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 290 | self.inception_4d_double_3x3_2_bn = nn.BatchNorm2d(192, affine=True) 291 | self.inception_4d_relu_double_3x3_2 = nn.ReLU(inplace) 292 | self.inception_4d_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 293 | self.inception_4d_pool_proj = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) 294 | self.inception_4d_pool_proj_bn = nn.BatchNorm2d(128, affine=True) 295 | self.inception_4d_relu_pool_proj = nn.ReLU(inplace) 296 | self.inception_4e_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1)) 297 | self.inception_4e_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True) 298 | self.inception_4e_relu_3x3_reduce = nn.ReLU(inplace) 299 | self.inception_4e_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 300 | self.inception_4e_3x3_bn = nn.BatchNorm2d(192, affine=True) 301 | self.inception_4e_relu_3x3 = nn.ReLU(inplace) 302 | self.inception_4e_double_3x3_reduce = nn.Conv2d(608, 192, kernel_size=(1, 1), stride=(1, 1)) 303 | self.inception_4e_double_3x3_reduce_bn = nn.BatchNorm2d(192, 
affine=True) 304 | self.inception_4e_relu_double_3x3_reduce = nn.ReLU(inplace) 305 | self.inception_4e_double_3x3_1 = nn.Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 306 | self.inception_4e_double_3x3_1_bn = nn.BatchNorm2d(256, affine=True) 307 | self.inception_4e_relu_double_3x3_1 = nn.ReLU(inplace) 308 | self.inception_4e_double_3x3_2 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 309 | self.inception_4e_double_3x3_2_bn = nn.BatchNorm2d(256, affine=True) 310 | self.inception_4e_relu_double_3x3_2 = nn.ReLU(inplace) 311 | self.inception_4e_pool = nn.MaxPool2d((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True) 312 | self.inception_5a_1x1 = nn.Conv2d(1056, 352, kernel_size=(1, 1), stride=(1, 1)) 313 | self.inception_5a_1x1_bn = nn.BatchNorm2d(352, affine=True) 314 | self.inception_5a_relu_1x1 = nn.ReLU(inplace) 315 | self.inception_5a_3x3_reduce = nn.Conv2d(1056, 192, kernel_size=(1, 1), stride=(1, 1)) 316 | self.inception_5a_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True) 317 | self.inception_5a_relu_3x3_reduce = nn.ReLU(inplace) 318 | self.inception_5a_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 319 | self.inception_5a_3x3_bn = nn.BatchNorm2d(320, affine=True) 320 | self.inception_5a_relu_3x3 = nn.ReLU(inplace) 321 | self.inception_5a_double_3x3_reduce = nn.Conv2d(1056, 160, kernel_size=(1, 1), stride=(1, 1)) 322 | self.inception_5a_double_3x3_reduce_bn = nn.BatchNorm2d(160, affine=True) 323 | self.inception_5a_relu_double_3x3_reduce = nn.ReLU(inplace) 324 | self.inception_5a_double_3x3_1 = nn.Conv2d(160, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 325 | self.inception_5a_double_3x3_1_bn = nn.BatchNorm2d(224, affine=True) 326 | self.inception_5a_relu_double_3x3_1 = nn.ReLU(inplace) 327 | self.inception_5a_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 328 | self.inception_5a_double_3x3_2_bn = nn.BatchNorm2d(224, affine=True) 329 | self.inception_5a_relu_double_3x3_2 = nn.ReLU(inplace) 330 | self.inception_5a_pool = nn.AvgPool2d(3, stride=1, padding=1, ceil_mode=True, count_include_pad=True) 331 | self.inception_5a_pool_proj = nn.Conv2d(1056, 128, kernel_size=(1, 1), stride=(1, 1)) 332 | self.inception_5a_pool_proj_bn = nn.BatchNorm2d(128, affine=True) 333 | self.inception_5a_relu_pool_proj = nn.ReLU(inplace) 334 | self.inception_5b_1x1 = nn.Conv2d(1024, 352, kernel_size=(1, 1), stride=(1, 1)) 335 | self.inception_5b_1x1_bn = nn.BatchNorm2d(352, affine=True) 336 | self.inception_5b_relu_1x1 = nn.ReLU(inplace) 337 | self.inception_5b_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1)) 338 | self.inception_5b_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True) 339 | self.inception_5b_relu_3x3_reduce = nn.ReLU(inplace) 340 | self.inception_5b_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 341 | self.inception_5b_3x3_bn = nn.BatchNorm2d(320, affine=True) 342 | self.inception_5b_relu_3x3 = nn.ReLU(inplace) 343 | self.inception_5b_double_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1)) 344 | self.inception_5b_double_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True) 345 | self.inception_5b_relu_double_3x3_reduce = nn.ReLU(inplace) 346 | self.inception_5b_double_3x3_1 = nn.Conv2d(192, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 347 | self.inception_5b_double_3x3_1_bn = nn.BatchNorm2d(224, affine=True) 348 | self.inception_5b_relu_double_3x3_1 = nn.ReLU(inplace) 349 | 
self.inception_5b_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 350 | self.inception_5b_double_3x3_2_bn = nn.BatchNorm2d(224, affine=True) 351 | self.inception_5b_relu_double_3x3_2 = nn.ReLU(inplace) 352 | self.inception_5b_pool = nn.MaxPool2d((3, 3), stride=(1, 1), padding=(1, 1), dilation=(1, 1), ceil_mode=True) 353 | self.inception_5b_pool_proj = nn.Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1)) 354 | self.inception_5b_pool_proj_bn = nn.BatchNorm2d(128, affine=True) 355 | self.inception_5b_relu_pool_proj = nn.ReLU(inplace) 356 | 357 | def features(self, input): 358 | conv1_7x7_s2_out = self.conv1_7x7_s2(input) 359 | conv1_7x7_s2_bn_out = self.conv1_7x7_s2_bn(conv1_7x7_s2_out) 360 | conv1_relu_7x7_out = self.conv1_relu_7x7(conv1_7x7_s2_bn_out) 361 | pool1_3x3_s2_out = self.pool1_3x3_s2(conv1_relu_7x7_out) 362 | conv2_3x3_reduce_out = self.conv2_3x3_reduce(pool1_3x3_s2_out) 363 | conv2_3x3_reduce_bn_out = self.conv2_3x3_reduce_bn(conv2_3x3_reduce_out) 364 | conv2_relu_3x3_reduce_out = self.conv2_relu_3x3_reduce(conv2_3x3_reduce_bn_out) 365 | conv2_3x3_out = self.conv2_3x3(conv2_relu_3x3_reduce_out) 366 | conv2_3x3_bn_out = self.conv2_3x3_bn(conv2_3x3_out) 367 | conv2_relu_3x3_out = self.conv2_relu_3x3(conv2_3x3_bn_out) 368 | pool2_3x3_s2_out = self.pool2_3x3_s2(conv2_relu_3x3_out) 369 | inception_3a_1x1_out = self.inception_3a_1x1(pool2_3x3_s2_out) 370 | inception_3a_1x1_bn_out = self.inception_3a_1x1_bn(inception_3a_1x1_out) 371 | inception_3a_relu_1x1_out = self.inception_3a_relu_1x1(inception_3a_1x1_bn_out) 372 | inception_3a_3x3_reduce_out = self.inception_3a_3x3_reduce(pool2_3x3_s2_out) 373 | inception_3a_3x3_reduce_bn_out = self.inception_3a_3x3_reduce_bn(inception_3a_3x3_reduce_out) 374 | inception_3a_relu_3x3_reduce_out = self.inception_3a_relu_3x3_reduce(inception_3a_3x3_reduce_bn_out) 375 | inception_3a_3x3_out = self.inception_3a_3x3(inception_3a_relu_3x3_reduce_out) 376 | inception_3a_3x3_bn_out = self.inception_3a_3x3_bn(inception_3a_3x3_out) 377 | inception_3a_relu_3x3_out = self.inception_3a_relu_3x3(inception_3a_3x3_bn_out) 378 | inception_3a_double_3x3_reduce_out = self.inception_3a_double_3x3_reduce(pool2_3x3_s2_out) 379 | inception_3a_double_3x3_reduce_bn_out = self.inception_3a_double_3x3_reduce_bn(inception_3a_double_3x3_reduce_out) 380 | inception_3a_relu_double_3x3_reduce_out = self.inception_3a_relu_double_3x3_reduce(inception_3a_double_3x3_reduce_bn_out) 381 | inception_3a_double_3x3_1_out = self.inception_3a_double_3x3_1(inception_3a_relu_double_3x3_reduce_out) 382 | inception_3a_double_3x3_1_bn_out = self.inception_3a_double_3x3_1_bn(inception_3a_double_3x3_1_out) 383 | inception_3a_relu_double_3x3_1_out = self.inception_3a_relu_double_3x3_1(inception_3a_double_3x3_1_bn_out) 384 | inception_3a_double_3x3_2_out = self.inception_3a_double_3x3_2(inception_3a_relu_double_3x3_1_out) 385 | inception_3a_double_3x3_2_bn_out = self.inception_3a_double_3x3_2_bn(inception_3a_double_3x3_2_out) 386 | inception_3a_relu_double_3x3_2_out = self.inception_3a_relu_double_3x3_2(inception_3a_double_3x3_2_bn_out) 387 | inception_3a_pool_out = self.inception_3a_pool(pool2_3x3_s2_out) 388 | inception_3a_pool_proj_out = self.inception_3a_pool_proj(inception_3a_pool_out) 389 | inception_3a_pool_proj_bn_out = self.inception_3a_pool_proj_bn(inception_3a_pool_proj_out) 390 | inception_3a_relu_pool_proj_out = self.inception_3a_relu_pool_proj(inception_3a_pool_proj_bn_out) 391 | inception_3a_output_out = 
torch.cat([inception_3a_relu_1x1_out,inception_3a_relu_3x3_out,inception_3a_relu_double_3x3_2_out ,inception_3a_relu_pool_proj_out], 1) 392 | inception_3b_1x1_out = self.inception_3b_1x1(inception_3a_output_out) 393 | inception_3b_1x1_bn_out = self.inception_3b_1x1_bn(inception_3b_1x1_out) 394 | inception_3b_relu_1x1_out = self.inception_3b_relu_1x1(inception_3b_1x1_bn_out) 395 | inception_3b_3x3_reduce_out = self.inception_3b_3x3_reduce(inception_3a_output_out) 396 | inception_3b_3x3_reduce_bn_out = self.inception_3b_3x3_reduce_bn(inception_3b_3x3_reduce_out) 397 | inception_3b_relu_3x3_reduce_out = self.inception_3b_relu_3x3_reduce(inception_3b_3x3_reduce_bn_out) 398 | inception_3b_3x3_out = self.inception_3b_3x3(inception_3b_relu_3x3_reduce_out) 399 | inception_3b_3x3_bn_out = self.inception_3b_3x3_bn(inception_3b_3x3_out) 400 | inception_3b_relu_3x3_out = self.inception_3b_relu_3x3(inception_3b_3x3_bn_out) 401 | inception_3b_double_3x3_reduce_out = self.inception_3b_double_3x3_reduce(inception_3a_output_out) 402 | inception_3b_double_3x3_reduce_bn_out = self.inception_3b_double_3x3_reduce_bn(inception_3b_double_3x3_reduce_out) 403 | inception_3b_relu_double_3x3_reduce_out = self.inception_3b_relu_double_3x3_reduce(inception_3b_double_3x3_reduce_bn_out) 404 | inception_3b_double_3x3_1_out = self.inception_3b_double_3x3_1(inception_3b_relu_double_3x3_reduce_out) 405 | inception_3b_double_3x3_1_bn_out = self.inception_3b_double_3x3_1_bn(inception_3b_double_3x3_1_out) 406 | inception_3b_relu_double_3x3_1_out = self.inception_3b_relu_double_3x3_1(inception_3b_double_3x3_1_bn_out) 407 | inception_3b_double_3x3_2_out = self.inception_3b_double_3x3_2(inception_3b_relu_double_3x3_1_out) 408 | inception_3b_double_3x3_2_bn_out = self.inception_3b_double_3x3_2_bn(inception_3b_double_3x3_2_out) 409 | inception_3b_relu_double_3x3_2_out = self.inception_3b_relu_double_3x3_2(inception_3b_double_3x3_2_bn_out) 410 | inception_3b_pool_out = self.inception_3b_pool(inception_3a_output_out) 411 | inception_3b_pool_proj_out = self.inception_3b_pool_proj(inception_3b_pool_out) 412 | inception_3b_pool_proj_bn_out = self.inception_3b_pool_proj_bn(inception_3b_pool_proj_out) 413 | inception_3b_relu_pool_proj_out = self.inception_3b_relu_pool_proj(inception_3b_pool_proj_bn_out) 414 | inception_3b_output_out = torch.cat([inception_3b_relu_1x1_out,inception_3b_relu_3x3_out,inception_3b_relu_double_3x3_2_out,inception_3b_relu_pool_proj_out], 1) 415 | inception_3c_3x3_reduce_out = self.inception_3c_3x3_reduce(inception_3b_output_out) 416 | inception_3c_3x3_reduce_bn_out = self.inception_3c_3x3_reduce_bn(inception_3c_3x3_reduce_out) 417 | inception_3c_relu_3x3_reduce_out = self.inception_3c_relu_3x3_reduce(inception_3c_3x3_reduce_bn_out) 418 | inception_3c_3x3_out = self.inception_3c_3x3(inception_3c_relu_3x3_reduce_out) 419 | inception_3c_3x3_bn_out = self.inception_3c_3x3_bn(inception_3c_3x3_out) 420 | inception_3c_relu_3x3_out = self.inception_3c_relu_3x3(inception_3c_3x3_bn_out) 421 | inception_3c_double_3x3_reduce_out = self.inception_3c_double_3x3_reduce(inception_3b_output_out) 422 | inception_3c_double_3x3_reduce_bn_out = self.inception_3c_double_3x3_reduce_bn(inception_3c_double_3x3_reduce_out) 423 | inception_3c_relu_double_3x3_reduce_out = self.inception_3c_relu_double_3x3_reduce(inception_3c_double_3x3_reduce_bn_out) 424 | inception_3c_double_3x3_1_out = self.inception_3c_double_3x3_1(inception_3c_relu_double_3x3_reduce_out) 425 | inception_3c_double_3x3_1_bn_out = 
self.inception_3c_double_3x3_1_bn(inception_3c_double_3x3_1_out) 426 | inception_3c_relu_double_3x3_1_out = self.inception_3c_relu_double_3x3_1(inception_3c_double_3x3_1_bn_out) 427 | inception_3c_double_3x3_2_out = self.inception_3c_double_3x3_2(inception_3c_relu_double_3x3_1_out) 428 | inception_3c_double_3x3_2_bn_out = self.inception_3c_double_3x3_2_bn(inception_3c_double_3x3_2_out) 429 | inception_3c_relu_double_3x3_2_out = self.inception_3c_relu_double_3x3_2(inception_3c_double_3x3_2_bn_out) 430 | inception_3c_pool_out = self.inception_3c_pool(inception_3b_output_out) 431 | inception_3c_output_out = torch.cat([inception_3c_relu_3x3_out,inception_3c_relu_double_3x3_2_out,inception_3c_pool_out], 1) 432 | inception_4a_1x1_out = self.inception_4a_1x1(inception_3c_output_out) 433 | inception_4a_1x1_bn_out = self.inception_4a_1x1_bn(inception_4a_1x1_out) 434 | inception_4a_relu_1x1_out = self.inception_4a_relu_1x1(inception_4a_1x1_bn_out) 435 | inception_4a_3x3_reduce_out = self.inception_4a_3x3_reduce(inception_3c_output_out) 436 | inception_4a_3x3_reduce_bn_out = self.inception_4a_3x3_reduce_bn(inception_4a_3x3_reduce_out) 437 | inception_4a_relu_3x3_reduce_out = self.inception_4a_relu_3x3_reduce(inception_4a_3x3_reduce_bn_out) 438 | inception_4a_3x3_out = self.inception_4a_3x3(inception_4a_relu_3x3_reduce_out) 439 | inception_4a_3x3_bn_out = self.inception_4a_3x3_bn(inception_4a_3x3_out) 440 | inception_4a_relu_3x3_out = self.inception_4a_relu_3x3(inception_4a_3x3_bn_out) 441 | inception_4a_double_3x3_reduce_out = self.inception_4a_double_3x3_reduce(inception_3c_output_out) 442 | inception_4a_double_3x3_reduce_bn_out = self.inception_4a_double_3x3_reduce_bn(inception_4a_double_3x3_reduce_out) 443 | inception_4a_relu_double_3x3_reduce_out = self.inception_4a_relu_double_3x3_reduce(inception_4a_double_3x3_reduce_bn_out) 444 | inception_4a_double_3x3_1_out = self.inception_4a_double_3x3_1(inception_4a_relu_double_3x3_reduce_out) 445 | inception_4a_double_3x3_1_bn_out = self.inception_4a_double_3x3_1_bn(inception_4a_double_3x3_1_out) 446 | inception_4a_relu_double_3x3_1_out = self.inception_4a_relu_double_3x3_1(inception_4a_double_3x3_1_bn_out) 447 | inception_4a_double_3x3_2_out = self.inception_4a_double_3x3_2(inception_4a_relu_double_3x3_1_out) 448 | inception_4a_double_3x3_2_bn_out = self.inception_4a_double_3x3_2_bn(inception_4a_double_3x3_2_out) 449 | inception_4a_relu_double_3x3_2_out = self.inception_4a_relu_double_3x3_2(inception_4a_double_3x3_2_bn_out) 450 | inception_4a_pool_out = self.inception_4a_pool(inception_3c_output_out) 451 | inception_4a_pool_proj_out = self.inception_4a_pool_proj(inception_4a_pool_out) 452 | inception_4a_pool_proj_bn_out = self.inception_4a_pool_proj_bn(inception_4a_pool_proj_out) 453 | inception_4a_relu_pool_proj_out = self.inception_4a_relu_pool_proj(inception_4a_pool_proj_bn_out) 454 | inception_4a_output_out = torch.cat([inception_4a_relu_1x1_out,inception_4a_relu_3x3_out,inception_4a_relu_double_3x3_2_out,inception_4a_relu_pool_proj_out], 1) 455 | inception_4b_1x1_out = self.inception_4b_1x1(inception_4a_output_out) 456 | inception_4b_1x1_bn_out = self.inception_4b_1x1_bn(inception_4b_1x1_out) 457 | inception_4b_relu_1x1_out = self.inception_4b_relu_1x1(inception_4b_1x1_bn_out) 458 | inception_4b_3x3_reduce_out = self.inception_4b_3x3_reduce(inception_4a_output_out) 459 | inception_4b_3x3_reduce_bn_out = self.inception_4b_3x3_reduce_bn(inception_4b_3x3_reduce_out) 460 | inception_4b_relu_3x3_reduce_out = 
self.inception_4b_relu_3x3_reduce(inception_4b_3x3_reduce_bn_out) 461 | inception_4b_3x3_out = self.inception_4b_3x3(inception_4b_relu_3x3_reduce_out) 462 | inception_4b_3x3_bn_out = self.inception_4b_3x3_bn(inception_4b_3x3_out) 463 | inception_4b_relu_3x3_out = self.inception_4b_relu_3x3(inception_4b_3x3_bn_out) 464 | inception_4b_double_3x3_reduce_out = self.inception_4b_double_3x3_reduce(inception_4a_output_out) 465 | inception_4b_double_3x3_reduce_bn_out = self.inception_4b_double_3x3_reduce_bn(inception_4b_double_3x3_reduce_out) 466 | inception_4b_relu_double_3x3_reduce_out = self.inception_4b_relu_double_3x3_reduce(inception_4b_double_3x3_reduce_bn_out) 467 | inception_4b_double_3x3_1_out = self.inception_4b_double_3x3_1(inception_4b_relu_double_3x3_reduce_out) 468 | inception_4b_double_3x3_1_bn_out = self.inception_4b_double_3x3_1_bn(inception_4b_double_3x3_1_out) 469 | inception_4b_relu_double_3x3_1_out = self.inception_4b_relu_double_3x3_1(inception_4b_double_3x3_1_bn_out) 470 | inception_4b_double_3x3_2_out = self.inception_4b_double_3x3_2(inception_4b_relu_double_3x3_1_out) 471 | inception_4b_double_3x3_2_bn_out = self.inception_4b_double_3x3_2_bn(inception_4b_double_3x3_2_out) 472 | inception_4b_relu_double_3x3_2_out = self.inception_4b_relu_double_3x3_2(inception_4b_double_3x3_2_bn_out) 473 | inception_4b_pool_out = self.inception_4b_pool(inception_4a_output_out) 474 | inception_4b_pool_proj_out = self.inception_4b_pool_proj(inception_4b_pool_out) 475 | inception_4b_pool_proj_bn_out = self.inception_4b_pool_proj_bn(inception_4b_pool_proj_out) 476 | inception_4b_relu_pool_proj_out = self.inception_4b_relu_pool_proj(inception_4b_pool_proj_bn_out) 477 | inception_4b_output_out = torch.cat([inception_4b_relu_1x1_out,inception_4b_relu_3x3_out,inception_4b_relu_double_3x3_2_out,inception_4b_relu_pool_proj_out], 1) 478 | inception_4c_1x1_out = self.inception_4c_1x1(inception_4b_output_out) 479 | inception_4c_1x1_bn_out = self.inception_4c_1x1_bn(inception_4c_1x1_out) 480 | inception_4c_relu_1x1_out = self.inception_4c_relu_1x1(inception_4c_1x1_bn_out) 481 | inception_4c_3x3_reduce_out = self.inception_4c_3x3_reduce(inception_4b_output_out) 482 | inception_4c_3x3_reduce_bn_out = self.inception_4c_3x3_reduce_bn(inception_4c_3x3_reduce_out) 483 | inception_4c_relu_3x3_reduce_out = self.inception_4c_relu_3x3_reduce(inception_4c_3x3_reduce_bn_out) 484 | inception_4c_3x3_out = self.inception_4c_3x3(inception_4c_relu_3x3_reduce_out) 485 | inception_4c_3x3_bn_out = self.inception_4c_3x3_bn(inception_4c_3x3_out) 486 | inception_4c_relu_3x3_out = self.inception_4c_relu_3x3(inception_4c_3x3_bn_out) 487 | inception_4c_double_3x3_reduce_out = self.inception_4c_double_3x3_reduce(inception_4b_output_out) 488 | inception_4c_double_3x3_reduce_bn_out = self.inception_4c_double_3x3_reduce_bn(inception_4c_double_3x3_reduce_out) 489 | inception_4c_relu_double_3x3_reduce_out = self.inception_4c_relu_double_3x3_reduce(inception_4c_double_3x3_reduce_bn_out) 490 | inception_4c_double_3x3_1_out = self.inception_4c_double_3x3_1(inception_4c_relu_double_3x3_reduce_out) 491 | inception_4c_double_3x3_1_bn_out = self.inception_4c_double_3x3_1_bn(inception_4c_double_3x3_1_out) 492 | inception_4c_relu_double_3x3_1_out = self.inception_4c_relu_double_3x3_1(inception_4c_double_3x3_1_bn_out) 493 | inception_4c_double_3x3_2_out = self.inception_4c_double_3x3_2(inception_4c_relu_double_3x3_1_out) 494 | inception_4c_double_3x3_2_bn_out = self.inception_4c_double_3x3_2_bn(inception_4c_double_3x3_2_out) 495 | 
inception_4c_relu_double_3x3_2_out = self.inception_4c_relu_double_3x3_2(inception_4c_double_3x3_2_bn_out) 496 | inception_4c_pool_out = self.inception_4c_pool(inception_4b_output_out) 497 | inception_4c_pool_proj_out = self.inception_4c_pool_proj(inception_4c_pool_out) 498 | inception_4c_pool_proj_bn_out = self.inception_4c_pool_proj_bn(inception_4c_pool_proj_out) 499 | inception_4c_relu_pool_proj_out = self.inception_4c_relu_pool_proj(inception_4c_pool_proj_bn_out) 500 | inception_4c_output_out = torch.cat([inception_4c_relu_1x1_out,inception_4c_relu_3x3_out,inception_4c_relu_double_3x3_2_out,inception_4c_relu_pool_proj_out], 1) 501 | inception_4d_1x1_out = self.inception_4d_1x1(inception_4c_output_out) 502 | inception_4d_1x1_bn_out = self.inception_4d_1x1_bn(inception_4d_1x1_out) 503 | inception_4d_relu_1x1_out = self.inception_4d_relu_1x1(inception_4d_1x1_bn_out) 504 | inception_4d_3x3_reduce_out = self.inception_4d_3x3_reduce(inception_4c_output_out) 505 | inception_4d_3x3_reduce_bn_out = self.inception_4d_3x3_reduce_bn(inception_4d_3x3_reduce_out) 506 | inception_4d_relu_3x3_reduce_out = self.inception_4d_relu_3x3_reduce(inception_4d_3x3_reduce_bn_out) 507 | inception_4d_3x3_out = self.inception_4d_3x3(inception_4d_relu_3x3_reduce_out) 508 | inception_4d_3x3_bn_out = self.inception_4d_3x3_bn(inception_4d_3x3_out) 509 | inception_4d_relu_3x3_out = self.inception_4d_relu_3x3(inception_4d_3x3_bn_out) 510 | inception_4d_double_3x3_reduce_out = self.inception_4d_double_3x3_reduce(inception_4c_output_out) 511 | inception_4d_double_3x3_reduce_bn_out = self.inception_4d_double_3x3_reduce_bn(inception_4d_double_3x3_reduce_out) 512 | inception_4d_relu_double_3x3_reduce_out = self.inception_4d_relu_double_3x3_reduce(inception_4d_double_3x3_reduce_bn_out) 513 | inception_4d_double_3x3_1_out = self.inception_4d_double_3x3_1(inception_4d_relu_double_3x3_reduce_out) 514 | inception_4d_double_3x3_1_bn_out = self.inception_4d_double_3x3_1_bn(inception_4d_double_3x3_1_out) 515 | inception_4d_relu_double_3x3_1_out = self.inception_4d_relu_double_3x3_1(inception_4d_double_3x3_1_bn_out) 516 | inception_4d_double_3x3_2_out = self.inception_4d_double_3x3_2(inception_4d_relu_double_3x3_1_out) 517 | inception_4d_double_3x3_2_bn_out = self.inception_4d_double_3x3_2_bn(inception_4d_double_3x3_2_out) 518 | inception_4d_relu_double_3x3_2_out = self.inception_4d_relu_double_3x3_2(inception_4d_double_3x3_2_bn_out) 519 | inception_4d_pool_out = self.inception_4d_pool(inception_4c_output_out) 520 | inception_4d_pool_proj_out = self.inception_4d_pool_proj(inception_4d_pool_out) 521 | inception_4d_pool_proj_bn_out = self.inception_4d_pool_proj_bn(inception_4d_pool_proj_out) 522 | inception_4d_relu_pool_proj_out = self.inception_4d_relu_pool_proj(inception_4d_pool_proj_bn_out) 523 | inception_4d_output_out = torch.cat([inception_4d_relu_1x1_out,inception_4d_relu_3x3_out,inception_4d_relu_double_3x3_2_out,inception_4d_relu_pool_proj_out], 1) 524 | inception_4e_3x3_reduce_out = self.inception_4e_3x3_reduce(inception_4d_output_out) 525 | inception_4e_3x3_reduce_bn_out = self.inception_4e_3x3_reduce_bn(inception_4e_3x3_reduce_out) 526 | inception_4e_relu_3x3_reduce_out = self.inception_4e_relu_3x3_reduce(inception_4e_3x3_reduce_bn_out) 527 | inception_4e_3x3_out = self.inception_4e_3x3(inception_4e_relu_3x3_reduce_out) 528 | inception_4e_3x3_bn_out = self.inception_4e_3x3_bn(inception_4e_3x3_out) 529 | inception_4e_relu_3x3_out = self.inception_4e_relu_3x3(inception_4e_3x3_bn_out) 530 | inception_4e_double_3x3_reduce_out = 
self.inception_4e_double_3x3_reduce(inception_4d_output_out) 531 | inception_4e_double_3x3_reduce_bn_out = self.inception_4e_double_3x3_reduce_bn(inception_4e_double_3x3_reduce_out) 532 | inception_4e_relu_double_3x3_reduce_out = self.inception_4e_relu_double_3x3_reduce(inception_4e_double_3x3_reduce_bn_out) 533 | inception_4e_double_3x3_1_out = self.inception_4e_double_3x3_1(inception_4e_relu_double_3x3_reduce_out) 534 | inception_4e_double_3x3_1_bn_out = self.inception_4e_double_3x3_1_bn(inception_4e_double_3x3_1_out) 535 | inception_4e_relu_double_3x3_1_out = self.inception_4e_relu_double_3x3_1(inception_4e_double_3x3_1_bn_out) 536 | inception_4e_double_3x3_2_out = self.inception_4e_double_3x3_2(inception_4e_relu_double_3x3_1_out) 537 | inception_4e_double_3x3_2_bn_out = self.inception_4e_double_3x3_2_bn(inception_4e_double_3x3_2_out) 538 | inception_4e_relu_double_3x3_2_out = self.inception_4e_relu_double_3x3_2(inception_4e_double_3x3_2_bn_out) 539 | inception_4e_pool_out = self.inception_4e_pool(inception_4d_output_out) 540 | inception_4e_output_out = torch.cat([inception_4e_relu_3x3_out,inception_4e_relu_double_3x3_2_out,inception_4e_pool_out], 1) 541 | inception_5a_1x1_out = self.inception_5a_1x1(inception_4e_output_out) 542 | inception_5a_1x1_bn_out = self.inception_5a_1x1_bn(inception_5a_1x1_out) 543 | inception_5a_relu_1x1_out = self.inception_5a_relu_1x1(inception_5a_1x1_bn_out) 544 | inception_5a_3x3_reduce_out = self.inception_5a_3x3_reduce(inception_4e_output_out) 545 | inception_5a_3x3_reduce_bn_out = self.inception_5a_3x3_reduce_bn(inception_5a_3x3_reduce_out) 546 | inception_5a_relu_3x3_reduce_out = self.inception_5a_relu_3x3_reduce(inception_5a_3x3_reduce_bn_out) 547 | inception_5a_3x3_out = self.inception_5a_3x3(inception_5a_relu_3x3_reduce_out) 548 | inception_5a_3x3_bn_out = self.inception_5a_3x3_bn(inception_5a_3x3_out) 549 | inception_5a_relu_3x3_out = self.inception_5a_relu_3x3(inception_5a_3x3_bn_out) 550 | inception_5a_double_3x3_reduce_out = self.inception_5a_double_3x3_reduce(inception_4e_output_out) 551 | inception_5a_double_3x3_reduce_bn_out = self.inception_5a_double_3x3_reduce_bn(inception_5a_double_3x3_reduce_out) 552 | inception_5a_relu_double_3x3_reduce_out = self.inception_5a_relu_double_3x3_reduce(inception_5a_double_3x3_reduce_bn_out) 553 | inception_5a_double_3x3_1_out = self.inception_5a_double_3x3_1(inception_5a_relu_double_3x3_reduce_out) 554 | inception_5a_double_3x3_1_bn_out = self.inception_5a_double_3x3_1_bn(inception_5a_double_3x3_1_out) 555 | inception_5a_relu_double_3x3_1_out = self.inception_5a_relu_double_3x3_1(inception_5a_double_3x3_1_bn_out) 556 | inception_5a_double_3x3_2_out = self.inception_5a_double_3x3_2(inception_5a_relu_double_3x3_1_out) 557 | inception_5a_double_3x3_2_bn_out = self.inception_5a_double_3x3_2_bn(inception_5a_double_3x3_2_out) 558 | inception_5a_relu_double_3x3_2_out = self.inception_5a_relu_double_3x3_2(inception_5a_double_3x3_2_bn_out) 559 | inception_5a_pool_out = self.inception_5a_pool(inception_4e_output_out) 560 | inception_5a_pool_proj_out = self.inception_5a_pool_proj(inception_5a_pool_out) 561 | inception_5a_pool_proj_bn_out = self.inception_5a_pool_proj_bn(inception_5a_pool_proj_out) 562 | inception_5a_relu_pool_proj_out = self.inception_5a_relu_pool_proj(inception_5a_pool_proj_bn_out) 563 | inception_5a_output_out = torch.cat([inception_5a_relu_1x1_out,inception_5a_relu_3x3_out,inception_5a_relu_double_3x3_2_out,inception_5a_relu_pool_proj_out], 1) 564 | inception_5b_1x1_out = 
self.inception_5b_1x1(inception_5a_output_out)
565 |         inception_5b_1x1_bn_out = self.inception_5b_1x1_bn(inception_5b_1x1_out)
566 |         inception_5b_relu_1x1_out = self.inception_5b_relu_1x1(inception_5b_1x1_bn_out)
567 |         inception_5b_3x3_reduce_out = self.inception_5b_3x3_reduce(inception_5a_output_out)
568 |         inception_5b_3x3_reduce_bn_out = self.inception_5b_3x3_reduce_bn(inception_5b_3x3_reduce_out)
569 |         inception_5b_relu_3x3_reduce_out = self.inception_5b_relu_3x3_reduce(inception_5b_3x3_reduce_bn_out)
570 |         inception_5b_3x3_out = self.inception_5b_3x3(inception_5b_relu_3x3_reduce_out)
571 |         inception_5b_3x3_bn_out = self.inception_5b_3x3_bn(inception_5b_3x3_out)
572 |         inception_5b_relu_3x3_out = self.inception_5b_relu_3x3(inception_5b_3x3_bn_out)
573 |         inception_5b_double_3x3_reduce_out = self.inception_5b_double_3x3_reduce(inception_5a_output_out)
574 |         inception_5b_double_3x3_reduce_bn_out = self.inception_5b_double_3x3_reduce_bn(inception_5b_double_3x3_reduce_out)
575 |         inception_5b_relu_double_3x3_reduce_out = self.inception_5b_relu_double_3x3_reduce(inception_5b_double_3x3_reduce_bn_out)
576 |         inception_5b_double_3x3_1_out = self.inception_5b_double_3x3_1(inception_5b_relu_double_3x3_reduce_out)
577 |         inception_5b_double_3x3_1_bn_out = self.inception_5b_double_3x3_1_bn(inception_5b_double_3x3_1_out)
578 |         inception_5b_relu_double_3x3_1_out = self.inception_5b_relu_double_3x3_1(inception_5b_double_3x3_1_bn_out)
579 |         inception_5b_double_3x3_2_out = self.inception_5b_double_3x3_2(inception_5b_relu_double_3x3_1_out)
580 |         inception_5b_double_3x3_2_bn_out = self.inception_5b_double_3x3_2_bn(inception_5b_double_3x3_2_out)
581 |         inception_5b_relu_double_3x3_2_out = self.inception_5b_relu_double_3x3_2(inception_5b_double_3x3_2_bn_out)
582 |         inception_5b_pool_out = self.inception_5b_pool(inception_5a_output_out)
583 |         inception_5b_pool_proj_out = self.inception_5b_pool_proj(inception_5b_pool_out)
584 |         inception_5b_pool_proj_bn_out = self.inception_5b_pool_proj_bn(inception_5b_pool_proj_out)
585 |         inception_5b_relu_pool_proj_out = self.inception_5b_relu_pool_proj(inception_5b_pool_proj_bn_out)
586 |         inception_5b_output_out = torch.cat([inception_5b_relu_1x1_out,inception_5b_relu_3x3_out,inception_5b_relu_double_3x3_2_out,inception_5b_relu_pool_proj_out], 1)
587 |         return inception_3b_output_out,inception_4d_output_out,inception_5b_output_out
588 | 
589 |     def forward(self, input):
590 |         return self.features(input)
591 | 
--------------------------------------------------------------------------------
/utils/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from PIL import Image
4 | import torch
5 | import numpy as np
6 | import torch.utils.data as data
7 | import torchvision.transforms as transforms
8 | import torchvision.datasets as datasets
9 | 
10 | def default_loader(path):
11 |     return Image.open(path).convert('RGB')
12 | class MultiLabelDataset(data.Dataset):
13 |     def __init__(self, root, label, transform = None, loader = default_loader):
14 |         images = []
15 |         labels = open(label).readlines()
16 |         for line in labels:
17 |             items = line.split()
18 |             img_name = items.pop(0)
19 |             if os.path.isfile(os.path.join(root, img_name)):
20 |                 cur_label = tuple([int(v) for v in items])
21 |                 images.append((img_name, cur_label))
22 |             else:
23 |                 print(os.path.join(root, img_name) + ' Not Found.')
24 |         self.root = root
25 |         self.images = images
26 |         self.transform = transform
27 |         self.loader = loader
28 | 
29 |     def __getitem__(self, index):
30 |         img_name, label = self.images[index]
31 |         img = self.loader(os.path.join(self.root, img_name))
32 |         raw_img = img.copy()
33 |         if self.transform is not None:
34 |             img = self.transform(img)
35 |         return img, torch.Tensor(label)
36 | 
37 |     def __len__(self):
38 |         return len(self.images)
39 | 
40 | 
41 | attr_nums = {}
42 | attr_nums['pa100k'] = 26
43 | attr_nums['rap'] = 51
44 | attr_nums['peta'] = 35
45 | 
46 | description = {}
47 | description['pa100k'] = ['Female',
48 |                          'AgeOver60',
49 |                          'Age18-60',
50 |                          'AgeLess18',
51 |                          'Front',
52 |                          'Side',
53 |                          'Back',
54 |                          'Hat',
55 |                          'Glasses',
56 |                          'HandBag',
57 |                          'ShoulderBag',
58 |                          'Backpack',
59 |                          'HoldObjectsInFront',
60 |                          'ShortSleeve',
61 |                          'LongSleeve',
62 |                          'UpperStride',
63 |                          'UpperLogo',
64 |                          'UpperPlaid',
65 |                          'UpperSplice',
66 |                          'LowerStripe',
67 |                          'LowerPattern',
68 |                          'LongCoat',
69 |                          'Trousers',
70 |                          'Shorts',
71 |                          'Skirt&Dress',
72 |                          'boots']
73 | 
74 | description['peta'] = ['Age16-30',
75 |                        'Age31-45',
76 |                        'Age46-60',
77 |                        'AgeAbove61',
78 |                        'Backpack',
79 |                        'CarryingOther',
80 |                        'Casual lower',
81 |                        'Casual upper',
82 |                        'Formal lower',
83 |                        'Formal upper',
84 |                        'Hat',
85 |                        'Jacket',
86 |                        'Jeans',
87 |                        'Leather Shoes',
88 |                        'Logo',
89 |                        'Long hair',
90 |                        'Male',
91 |                        'Messenger Bag',
92 |                        'Muffler',
93 |                        'No accessory',
94 |                        'No carrying',
95 |                        'Plaid',
96 |                        'PlasticBags',
97 |                        'Sandals',
98 |                        'Shoes',
99 |                        'Shorts',
100 |                        'Short Sleeve',
101 |                        'Skirt',
102 |                        'Sneaker',
103 |                        'Stripes',
104 |                        'Sunglasses',
105 |                        'Trousers',
106 |                        'Tshirt',
107 |                        'UpperOther',
108 |                        'V-Neck']
109 | 
110 | description['rap'] = ['Female',
111 |                       'AgeLess16',
112 |                       'Age17-30',
113 |                       'Age31-45',
114 |                       'BodyFat',
115 |                       'BodyNormal',
116 |                       'BodyThin',
117 |                       'Customer',
118 |                       'Clerk',
119 |                       'BaldHead',
120 |                       'LongHair',
121 |                       'BlackHair',
122 |                       'Hat',
123 |                       'Glasses',
124 |                       'Muffler',
125 |                       'Shirt',
126 |                       'Sweater',
127 |                       'Vest',
128 |                       'TShirt',
129 |                       'Cotton',
130 |                       'Jacket',
131 |                       'Suit-Up',
132 |                       'Tight',
133 |                       'ShortSleeve',
134 |                       'LongTrousers',
135 |                       'Skirt',
136 |                       'ShortSkirt',
137 |                       'Dress',
138 |                       'Jeans',
139 |                       'TightTrousers',
140 |                       'LeatherShoes',
141 |                       'SportShoes',
142 |                       'Boots',
143 |                       'ClothShoes',
144 |                       'CasualShoes',
145 |                       'Backpack',
146 |                       'SSBag',
147 |                       'HandBag',
148 |                       'Box',
149 |                       'PlasticBag',
150 |                       'PaperBag',
151 |                       'HandTrunk',
152 |                       'OtherAttachment',
153 |                       'Calling',
154 |                       'Talking',
155 |                       'Gathering',
156 |                       'Holding',
157 |                       'Pushing',
158 |                       'Pulling',
159 |                       'CarryingbyArm',
160 |                       'CarryingbyHand']
161 | 
162 | 
163 | 
164 | 
165 | def Get_Dataset(experiment, approach):
166 | 
167 |     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
168 |     transform_train = transforms.Compose([
169 |         transforms.Resize(size=(256, 128)),
170 |         transforms.RandomHorizontalFlip(),
171 |         # transforms.ColorJitter(hue=.05, saturation=.05),
172 |         # transforms.RandomRotation(20, resample=Image.BILINEAR),
173 |         transforms.ToTensor(),
174 |         normalize
175 |         ])
176 |     transform_test = transforms.Compose([
177 |         transforms.Resize(size=(256, 128)),
178 |         transforms.ToTensor(),
179 |         normalize
180 |         ])
181 | 
182 |     if experiment == 'pa100k':
183 |         train_dataset = MultiLabelDataset(root='data_path',
184 |                                           label='train_list_path', transform=transform_train)
185 |         val_dataset = MultiLabelDataset(root='data_path',
186 |                                         label='val_list_path', transform=transform_test)
187 |         return train_dataset, val_dataset, attr_nums['pa100k'], description['pa100k']
188 |     elif experiment == 'rap':
189 |         train_dataset = MultiLabelDataset(root='data_path',
190 |                                           label='train_list_path', transform=transform_train)
191 |         val_dataset = MultiLabelDataset(root='data_path',
192 |                                         label='val_list_path', transform=transform_test)
193 |         return train_dataset, val_dataset, attr_nums['rap'], description['rap']
194 |     elif experiment == 'peta':
195 |         train_dataset = MultiLabelDataset(root='data_path',
196 |                                           label='train_list_path', transform=transform_train)
197 |         val_dataset = MultiLabelDataset(root='data_path',
198 |                                         label='val_list_path', transform=transform_test)
199 |         return train_dataset, val_dataset, attr_nums['peta'], description['peta']
200 | 
--------------------------------------------------------------------------------
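
Usage note for utils/datasets.py: Get_Dataset maps the chosen benchmark ('pa100k', 'rap', or 'peta') to a train/validation MultiLabelDataset pair plus its attribute count and attribute names, but the root and label arguments above are placeholders ('data_path', 'train_list_path', 'val_list_path') that must be replaced with real image folders and annotation lists before anything will load. The sketch below is a minimal, illustrative example of iterating the result with a standard PyTorch DataLoader; the batch size, worker count, and the BCEWithLogitsLoss pairing are assumptions made for illustration, not settings taken from this repository's main.py.

# Illustrative sketch only: fill in the paths inside Get_Dataset first;
# loss/loader settings here are assumptions, not the repository's defaults.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from utils.datasets import Get_Dataset

train_set, val_set, attr_num, attr_names = Get_Dataset('pa100k', approach=None)  # `approach` is unused in the code above
train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)

criterion = nn.BCEWithLogitsLoss()  # one independent binary target per attribute

for images, targets in train_loader:
    # images:  (B, 3, 256, 128) after the Resize/ToTensor/Normalize pipeline
    # targets: (B, attr_num) tensor of 0/1 attribute labels (26 for PA-100K)
    # A model such as the BN-Inception variant in model/inception_iccv.py would
    # map `images` to per-attribute logits here, e.g. loss = criterion(logits, targets).
    assert targets.shape == (images.size(0), attr_num)
    break

On the backbone side, note that the features() method above returns the inception_3b, inception_4d, and inception_5b outputs as a tuple, so downstream attribute heads in this repository can tap feature maps at three scales rather than only the final one.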