├── torchreid ├── utils │ ├── __init__.py │ ├── avgmeter.py │ ├── iotools.py │ ├── loggers.py │ ├── torchtools.py │ └── reidtools.py ├── eval_cylib │ ├── __init__.py │ ├── Makefile │ ├── setup.py │ ├── test_cython.py │ └── eval_metrics_cy.pyx ├── __init__.py ├── losses │ ├── ring_loss.py │ ├── __init__.py │ ├── nll_loss.py │ ├── cross_entropy_loss.py │ ├── bce_loss.py │ ├── hard_mine_triplet_loss.py │ ├── center_loss.py │ └── bce_focal_loss.py ├── reid_dataset │ ├── __init__.py │ ├── import_CUHK01.py │ ├── gdrive_downloader.py │ ├── import_MarketDuke.py │ ├── import_MarketDuke_nodistractors.py │ ├── marketduke_to_hdf5.py │ ├── import_Market1501.py │ ├── import_CUHK03.py │ ├── cuhk03_to_image.py │ ├── reiddataset_downloader.py │ ├── import_PETA.py │ ├── pytorch_prepare.py │ └── import_Market1501Attribute.py ├── optimizers.py ├── models │ ├── __init__.py │ ├── mobilenetv2.py │ ├── shufflenet.py │ ├── mudeep.py │ ├── mlfn.py │ ├── squeezenet.py │ ├── resnext.py │ ├── pcb.py │ ├── xception.py │ ├── resnetmid.py │ ├── densenet.py │ └── resnet.py ├── transforms.py ├── samplers.py ├── eval_metrics.py └── data_manager.py ├── imgs ├── setting.png ├── full-model.PNG ├── ranked_results.png └── deep-person-reid-logo.png ├── .gitattributes ├── requirements.txt ├── scripts ├── market.sh ├── peta.sh └── market_alt.sh ├── LICENSE ├── .gitignore ├── README.md ├── models.py └── DATASETS.md /torchreid/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchreid/eval_cylib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imgs/setting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ycao5602/SAL/HEAD/imgs/setting.png -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /imgs/full-model.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ycao5602/SAL/HEAD/imgs/full-model.PNG -------------------------------------------------------------------------------- /imgs/ranked_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ycao5602/SAL/HEAD/imgs/ranked_results.png -------------------------------------------------------------------------------- /imgs/deep-person-reid-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ycao5602/SAL/HEAD/imgs/deep-person-reid-logo.png -------------------------------------------------------------------------------- /torchreid/eval_cylib/Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | python setup.py build_ext --inplace 3 | rm -rf build 4 | 5 | clean: 6 | rm -rf build 7 | rm -f eval_metrics_cy.c *.so -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py==2.8.0 2 | scipy==1.1.0 
3 | requests==2.20.1 4 | torchvision==0.2.1 5 | Cython==0.29.12 6 | numpy==1.13.3 7 | torch==0.4.1.post2 8 | Pillow>=8.2.0 -------------------------------------------------------------------------------- /torchreid/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | deep-person-reid 3 | == 4 | 5 | Description: PyTorch implementation of deep person re-identification models. 6 | 7 | Github page: https://github.com/KaiyangZhou/deep-person-reid 8 | """ 9 | 10 | __author__ = 'Kaiyang Zhou' 11 | __email__ = 'k.zhou@qmul.ac.uk' -------------------------------------------------------------------------------- /scripts/market.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # abort entire script on error 4 | set -e 5 | 6 | # train model 7 | cd .. 8 | python train.py \ 9 | -s Market-1501 \ 10 | --root dataset \ 11 | --optim adam \ 12 | --label-smooth \ 13 | --max-epoch-pt 100 \ 14 | --max-epoch-jt 100 \ 15 | --max-epoch-al 60 \ 16 | --stepsize 20 40 60 80 \ 17 | --stepsize-sal 20 40 \ 18 | --train-batch-size 128 \ 19 | --test-batch-size 100 \ 20 | -a resnet50 \ 21 | --save-dir log/market-results \ 22 | --eval-freq 10 \ 23 | --save-pt 20 \ 24 | --gpu-devices 0,1 25 | -------------------------------------------------------------------------------- /scripts/peta.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # abort entire script on error 4 | set -e 5 | 6 | # train model 7 | cd .. 8 | python train.py \ 9 | -s PETA \ 10 | --root dataset \ 11 | --optim adam \ 12 | --label-smooth \ 13 | --max-epoch-jt 200 \ 14 | --max-epoch-pt 200 \ 15 | --max-epoch-al 60 \ 16 | --stepsize 20 40 60 80 100 120 140 160 180 \ 17 | --stepsize-sal 20 30 40 50 \ 18 | --train-batch-size 128 \ 19 | --test-batch-size 100 \ 20 | -a resnet50 \ 21 | --save-dir log/peta-results \ 22 | --eval-freq 10 \ 23 | --save-pt 20 \ 24 | --gpu-devices 0,1 25 | -------------------------------------------------------------------------------- /torchreid/eval_cylib/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from distutils.extension import Extension 3 | from Cython.Build import cythonize 4 | import numpy as np 5 | 6 | 7 | def numpy_include(): 8 | try: 9 | numpy_include = np.get_include() 10 | except AttributeError: 11 | numpy_include = np.get_numpy_include() 12 | return numpy_include 13 | 14 | ext_modules = [ 15 | Extension('eval_metrics_cy', 16 | ['eval_metrics_cy.pyx'], 17 | include_dirs=[numpy_include()], 18 | ) 19 | ] 20 | 21 | setup( 22 | name='Cython-based reid evaluation code', 23 | ext_modules=cythonize(ext_modules) 24 | ) -------------------------------------------------------------------------------- /torchreid/utils/avgmeter.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value. 
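Example (hypothetical values): meter = AverageMeter(); call meter.update(loss.item(), n=batch_size) each step; meter.avg then holds the running mean.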
7 | 8 | Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 9 | """ 10 | def __init__(self): 11 | self.reset() 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /torchreid/losses/ring_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import warnings 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class RingLoss(nn.Module): 11 | """Ring loss. 12 | 13 | Reference: 14 | Zheng et al. Ring loss: Convex Feature Normalization for Face Recognition. CVPR 2018. 15 | """ 16 | def __init__(self): 17 | super(RingLoss, self).__init__() 18 | warnings.warn("This method is deprecated") 19 | self.radius = nn.Parameter(torch.ones(1, dtype=torch.float)) 20 | 21 | def forward(self, x): 22 | loss = ((x.norm(p=2, dim=1) - self.radius)**2).mean() 23 | return loss -------------------------------------------------------------------------------- /scripts/market_alt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # abort entire script on error 4 | set -e 5 | 6 | # train model 7 | cd .. 8 | python train.py \ 9 | -s Market-1501 \ 10 | --root dataset \ 11 | --optim adam \ 12 | --label-smooth \ 13 | --lr-al 5e-5 \ 14 | --lr-gf 5e-4 \ 15 | --lr-df 5e-4 \ 16 | --max-epoch-pt 100 \ 17 | --max-epoch-jt 100 \ 18 | --max-epoch-al 60 \ 19 | --stepsize 20 40 60 80 \ 20 | --stepsize-sal 20 40 \ 21 | --train-batch-size 128 \ 22 | --test-batch-size 100 \ 23 | -a resnet50 \ 24 | --save-dir log/market-results \ 25 | --eval-freq 10 \ 26 | --save-pt 20 \ 27 | --gpu-devices 0,1 28 | -------------------------------------------------------------------------------- /torchreid/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .cross_entropy_loss import CrossEntropyLoss 6 | from .hard_mine_triplet_loss import TripletLoss 7 | from .bce_loss import BCELoss 8 | from .center_loss import CenterLoss 9 | from .ring_loss import RingLoss 10 | from .nll_loss import NegativeLogLoss 11 | from .bce_focal_loss import FocalLoss 12 | 13 | 14 | def DeepSupervision(criterion, xs, y): 15 | """ 16 | Args: 17 | - criterion: loss function 18 | - xs: tuple of inputs 19 | - y: ground truth 20 | """ 21 | loss = 0. 
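# sum the criterion over each intermediate output in xs, then average, so every deeply-supervised branch contributes equally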
22 | for x in xs: 23 | loss += criterion(x, y) 24 | loss /= len(xs) 25 | return loss -------------------------------------------------------------------------------- /torchreid/reid_dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .reiddataset_downloader import reiddataset_downloader 2 | from .reiddataset_downloader import reiddataset_downloader_all 3 | from .import_CUHK01 import import_CUHK01 4 | from .import_CUHK03 import import_CUHK03 5 | from .import_Market1501 import import_Market1501 6 | from .import_Market1501Attribute import import_Market1501Attribute 7 | from .import_Market1501Attribute import import_Market1501Attribute_binary 8 | from .import_MarketDuke import import_MarketDuke 9 | from .import_MarketDuke_nodistractors import import_MarketDuke_nodistractors 10 | from .pytorch_prepare import pytorch_prepare 11 | from .pytorch_prepare import pytorch_prepare_all 12 | from .marketduke_to_hdf5 import marketduke_to_hdf5 13 | from .cuhk03_to_image import cuhk03_to_image 14 | from .import_PETA import import_peta 15 | -------------------------------------------------------------------------------- /torchreid/reid_dataset/import_CUHK01.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .reiddataset_downloader import * 3 | def import_CUHK01(dataset_dir): 4 | cuhk01_dir = os.path.join(dataset_dir,'CUHK01') 5 | 6 | if not os.path.exists(cuhk01_dir): 7 | print('Please Download the CUHK01 Dataset') 8 | 9 | file_list=os.listdir(cuhk01_dir) 10 | name_dict={} 11 | for name in file_list: 12 | if name[-3:]=='png': 13 | id = name[:4] 14 | if id not in name_dict: 15 | name_dict[id]=[] 16 | name_dict[id].append([]) 17 | name_dict[id].append([]) 18 | if int(name[-7:-4])<3: 19 | name_dict[id][0].append(os.path.join(cuhk01_dir,name)) 20 | else: 21 | name_dict[id][1].append(os.path.join(cuhk01_dir,name)) 22 | return name_dict -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Kaiyang Zhou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /torchreid/utils/iotools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import os.path as osp 5 | import errno 6 | import json 7 | import shutil 8 | 9 | import torch 10 | 11 | 12 | def mkdir_if_missing(directory): 13 | if not osp.exists(directory): 14 | try: 15 | os.makedirs(directory) 16 | except OSError as e: 17 | if e.errno != errno.EEXIST: 18 | raise 19 | 20 | 21 | def check_isfile(path): 22 | isfile = osp.isfile(path) 23 | if not isfile: 24 | print("=> Warning: no file found at '{}' (ignored)".format(path)) 25 | return isfile 26 | 27 | 28 | def read_json(fpath): 29 | with open(fpath, 'r') as f: 30 | obj = json.load(f) 31 | return obj 32 | 33 | 34 | def write_json(obj, fpath): 35 | mkdir_if_missing(osp.dirname(fpath)) 36 | with open(fpath, 'w') as f: 37 | json.dump(obj, f, indent=4, separators=(',', ': ')) 38 | 39 | 40 | def save_checkpoint(state, is_best=False, fpath='checkpoint.pth.tar'): 41 | if len(osp.dirname(fpath)) != 0: 42 | mkdir_if_missing(osp.dirname(fpath)) 43 | torch.save(state, fpath) 44 | if is_best: 45 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar')) -------------------------------------------------------------------------------- /torchreid/reid_dataset/gdrive_downloader.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | def gdrive_downloader(destination, id): 4 | URL = "https://docs.google.com/uc?export=download" 5 | 6 | session = requests.Session() 7 | 8 | response = session.get(URL, params = { 'id' : id }, stream = True) 9 | token = get_confirm_token(response) 10 | 11 | if token: 12 | params = { 'id' : id, 'confirm' : token } 13 | response = session.get(URL, params = params, stream = True) 14 | 15 | save_response_content(response, destination) 16 | 17 | def get_confirm_token(response): 18 | for key, value in response.cookies.items(): 19 | if key.startswith('download_warning'): 20 | return value 21 | 22 | return None 23 | 24 | def save_response_content(response, destination): 25 | CHUNK_SIZE = 32768 26 | 27 | with open(destination, "wb") as f: 28 | for chunk in response.iter_content(CHUNK_SIZE): 29 | if chunk: # filter out keep-alive new chunks 30 | f.write(chunk) 31 | 32 | if __name__ == "__main__": 33 | file_id = input("Please enter public file id : ") # input() works on Python 3 (raw_input is py2-only) 34 | destination = input("Please enter name with extension : ") 35 | 36 | # note the argument order: the function above takes (destination, id) 37 | gdrive_downloader(destination, file_id) 38 | -------------------------------------------------------------------------------- /torchreid/reid_dataset/import_MarketDuke.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .reiddataset_downloader import * 3 | 4 | 5 | def import_MarketDuke(data_dir, dataset_name): 6 | dataset_dir = os.path.join(data_dir,dataset_name) 7 | 8 | if not os.path.exists(dataset_dir): 9 | print('Please Download '+dataset_name+ ' Dataset') 10 | 11 | dataset_dir = os.path.join(data_dir,dataset_name) 12 | data_group = ['train','query','gallery'] 13 | for group in data_group: 14 | if group == 'train': 15 | name_dir = os.path.join(dataset_dir , 'bounding_box_train') 16 | elif group == 'query': 17 | name_dir = os.path.join(dataset_dir, 'query') 18 | else: 19 | name_dir = os.path.join(dataset_dir, 'bounding_box_test') 20 | file_list=sorted(os.listdir(name_dir))
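# Market-1501/DukeMTMC filenames look like 0002_c1s1_000451_03.jpg: the prefix before the first '_' is the person id and the 'c<N>' field is the camera index, which is what the parsing below relies on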
21 | globals()[group]={} 22 | globals()[group]['data']=[] 23 | globals()[group]['ids'] = [] 24 | for name in file_list: 25 | if name[-3:]=='jpg': 26 | id = name.split('_')[0] 27 | cam = int(name.split('_')[1][1]) 28 | images = os.path.join(name_dir,name) 29 | if id not in globals()[group]['ids']: 30 | globals()[group]['ids'].append(id) 31 | globals()[group]['data'].append([images,globals()[group]['ids'].index(id),id,cam,name.split('.')[0]]) 32 | return train,query,gallery -------------------------------------------------------------------------------- /torchreid/reid_dataset/import_MarketDuke_nodistractors.py: -------------------------------------------------------------------------------- 1 | import os 2 | #from .reiddataset_downloader import * 3 | 4 | 5 | def import_MarketDuke_nodistractors(data_dir, dataset_name): 6 | dataset_dir = os.path.join(data_dir,dataset_name) 7 | 8 | if not os.path.exists(dataset_dir): 9 | print('Please Download '+dataset_name+ ' Dataset') 10 | 11 | dataset_dir = os.path.join(data_dir,dataset_name) 12 | data_group = ['train','query','gallery'] 13 | for group in data_group: 14 | if group == 'train': 15 | name_dir = os.path.join(dataset_dir , 'bounding_box_train') 16 | elif group == 'query': 17 | name_dir = os.path.join(dataset_dir, 'query') 18 | else: 19 | name_dir = os.path.join(dataset_dir, 'bounding_box_test') 20 | file_list=sorted(os.listdir(name_dir)) 21 | globals()[group]={} 22 | globals()[group]['data']=[] 23 | globals()[group]['ids'] = [] 24 | for name in file_list: 25 | if name[-3:]=='jpg': 26 | id = name.split('_')[0] 27 | cam = name.split('_')[1][1] 28 | images = os.path.join(name_dir,name) 29 | if (id!='0000' and id !='-1'): 30 | if id not in globals()[group]['ids']: 31 | globals()[group]['ids'].append(id) 32 | globals()[group]['data'].append([images,globals()[group]['ids'].index(id),id,cam,name.split('.')[0]]) 33 | return train, query, gallery -------------------------------------------------------------------------------- /torchreid/reid_dataset/marketduke_to_hdf5.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings('ignore','.*conversion.*') 3 | 4 | import os 5 | import h5py 6 | import numpy as np 7 | from PIL import Image 8 | from .import_MarketDuke import import_MarketDuke 9 | 10 | def marketduke_to_hdf5(data_dir,dataset_name,save_dir=os.getcwd()): 11 | phase_list = ['train','query','gallery'] 12 | dataset = import_MarketDuke(data_dir,dataset_name) 13 | dt = h5py.special_dtype(vlen=str) 14 | 15 | f = h5py.File(os.path.join(save_dir,dataset_name+'.hdf5'),'w') 16 | for phase in phase_list: 17 | grp = f.create_group(phase) 18 | phase_dataset = dataset[phase_list.index(phase)] 19 | for i in range(len(phase_dataset['data'])): 20 | name = phase_dataset['data'][i][0].split('/')[-1].split('.')[0] 21 | temp = grp.create_group(name) 22 | temp.create_dataset('img',data=Image.open(phase_dataset['data'][i][0])) 23 | temp.create_dataset('index',data=int(phase_dataset['data'][i][1])) 24 | temp.create_dataset('id',data=phase_dataset['data'][i][2], dtype=dt) 25 | temp.create_dataset('cam',data=int(phase_dataset['data'][i][3])) 26 | 27 | ids = f.create_group('ids') 28 | ids.create_dataset('train',data=np.array(dataset[0]['ids'],'S4'),dtype=dt) 29 | ids.create_dataset('query',data=np.array(dataset[1]['ids'],'S4'),dtype=dt) 30 | ids.create_dataset('gallery',data=np.array(dataset[2]['ids'],'S4'),dtype=dt) 31 | 32 | f.close() 
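# Read-back sketch for the file written above (the sample key is hypothetical):
#   with h5py.File('Market-1501.hdf5', 'r') as f:
#       sample = f['train']['0002_c1s1_000451_03']
#       img = Image.fromarray(sample['img'][()])
#       index, pid, cam = sample['index'][()], sample['id'][()], sample['cam'][()]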
-------------------------------------------------------------------------------- /torchreid/optimizers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def init_optimizer(params, 7 | optim='adam', 8 | lr=0.003, 9 | weight_decay=5e-4, 10 | momentum=0.9, # momentum factor for sgd and rmsprop 11 | sgd_dampening=0, # sgd's dampening for momentum 12 | sgd_nesterov=False, # whether to enable sgd's Nesterov momentum 13 | rmsprop_alpha=0.99, # rmsprop's smoothing constant 14 | adam_beta1=0.9, # exponential decay rate for adam's first moment 15 | adam_beta2=0.999 # exponential decay rate for adam's second moment 16 | ): 17 | if optim == 'adam': 18 | return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay, 19 | betas=(adam_beta1, adam_beta2)) 20 | 21 | elif optim == 'amsgrad': 22 | return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay, 23 | betas=(adam_beta1, adam_beta2), amsgrad=True) 24 | 25 | elif optim == 'sgd': 26 | return torch.optim.SGD(params, lr=lr, momentum=momentum, weight_decay=weight_decay, 27 | dampening=sgd_dampening, nesterov=sgd_nesterov) 28 | 29 | elif optim == 'rmsprop': 30 | return torch.optim.RMSprop(params, lr=lr, momentum=momentum, weight_decay=weight_decay, 31 | alpha=rmsprop_alpha) 32 | 33 | else: 34 | raise ValueError("Unsupported optimizer: {}".format(optim)) -------------------------------------------------------------------------------- /torchreid/losses/nll_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class NegativeLogLoss(nn.Module): 9 | """Negative log-likelihood loss with label smoothing regularizer. 10 | 11 | Reference: 12 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 13 | 14 | Equation: y = (1 - epsilon) * y + epsilon / K.
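Here y is the one-hot ground truth vector, K = num_classes and epsilon is the smoothing weight.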
15 | 16 | Args: 17 | - num_classes (int): number of classes 18 | - epsilon (float): weight 19 | - use_gpu (bool): whether to use gpu devices 20 | - label_smooth (bool): whether to apply label smoothing, if False, epsilon = 0 21 | """ 22 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, label_smooth=True): 23 | super(NegativeLogLoss, self).__init__() 24 | self.num_classes = num_classes 25 | self.epsilon = epsilon if label_smooth else 0 26 | self.use_gpu = use_gpu 27 | self.softmax = nn.Softmax(dim=1) 28 | 29 | def forward(self, inputs, targets): 30 | """ 31 | Args: 32 | - inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 33 | - targets: ground truth labels with shape (batch_size) 34 | """ 35 | inputs = self.softmax(inputs) 36 | targets = torch.zeros(inputs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 37 | if self.use_gpu: targets = targets.cuda() 38 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 39 | loss = inputs - targets * torch.log(inputs + 1e-8) 40 | return loss.mean(0).sum() -------------------------------------------------------------------------------- /torchreid/reid_dataset/import_Market1501.py: -------------------------------------------------------------------------------- 1 | import os 2 | #from .reiddataset_downloader import * 3 | def import_Market1501(dataset_dir): 4 | market1501_dir = os.path.join(dataset_dir,'Market-1501') 5 | if not os.path.exists(market1501_dir): 6 | print('Please Download Market1501 Dataset') 7 | data_group = ['train','query','gallery'] 8 | for group in data_group: 9 | if group == 'train': 10 | name_dir = os.path.join(market1501_dir , 'bounding_box_train') 11 | elif group == 'query': 12 | name_dir = os.path.join(market1501_dir, 'query') 13 | else: 14 | name_dir = os.path.join(market1501_dir, 'bounding_box_test') 15 | file_list=os.listdir(name_dir) 16 | globals()[group]={} 17 | for name in file_list: 18 | if name[-3:]=='jpg': 19 | ''' 20 | store the directories of images by camera number. 21 | ''' 22 | id = name.split('_')[0] 23 | if id not in globals()[group]: 24 | ''' 25 | create a global dictionary: 26 | globals()[group][id][cam] lists the image paths of person id seen by camera cam
27 | ''' 28 | globals()[group][id]=[] 29 | globals()[group][id].append([]) 30 | globals()[group][id].append([]) 31 | globals()[group][id].append([]) 32 | globals()[group][id].append([]) 33 | globals()[group][id].append([]) 34 | globals()[group][id].append([]) 35 | cam_n = int(name.split('_')[1][1])-1 36 | globals()[group][id][cam_n].append(os.path.join(name_dir,name)) 37 | return train,query,gallery -------------------------------------------------------------------------------- /torchreid/reid_dataset/import_CUHK03.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .reiddataset_downloader import * 3 | def import_CUHK03(dataset_dir, detected = False): 4 | 5 | cuhk03_dir = os.path.join(dataset_dir,'CUHK03') 6 | 7 | if not os.path.exists(cuhk03_dir): 8 | print('Please Download the CUHK03 Dataset') 9 | 10 | if not detected: 11 | cuhk03_dir = os.path.join(cuhk03_dir , 'labeled') 12 | else: 13 | cuhk03_dir = os.path.join(cuhk03_dir , 'detected') 14 | 15 | campair_list = os.listdir(cuhk03_dir) 16 | #campair_list = ['P1','P2','P3'] 17 | name_dict={} 18 | for campair in campair_list: 19 | cam1_list = [] 20 | cam1_list=os.listdir(os.path.join(cuhk03_dir,campair,'cam1')) 21 | cam2_list=os.listdir(os.path.join(cuhk03_dir,campair,'cam2')) 22 | for file in cam1_list: 23 | id = campair[1:]+'-'+file.split('-')[0] 24 | if id not in name_dict: 25 | name_dict[id]=[] 26 | name_dict[id].append([]) 27 | name_dict[id].append([]) 28 | name_dict[id][0].append(os.path.join(cuhk03_dir,campair,'cam1',file)) 29 | for file in cam2_list: 30 | id = campair[1:]+'-'+file.split('-')[0] 31 | if id not in name_dict: 32 | name_dict[id]=[] 33 | name_dict[id].append([]) 34 | name_dict[id].append([]) 35 | name_dict[id][1].append(os.path.join(cuhk03_dir,campair,'cam2',file)) 36 | return name_dict 37 | 38 | def cuhk03_test(data_dir): 39 | CUHK03_dir = os.path.join(data_dir , 'CUHK03') 40 | f = h5py.File(os.path.join(CUHK03_dir,'cuhk-03.mat')) 41 | test = [] 42 | for i in range(20): 43 | test_set = (np.array(f[f['testsets'][0][i]],dtype='int').T).tolist() 44 | test.append(test_set) 45 | 46 | return test -------------------------------------------------------------------------------- /torchreid/reid_dataset/cuhk03_to_image.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings('ignore','.*conversion.*') 3 | 4 | import os 5 | import zipfile 6 | import shutil 7 | import requests 8 | import h5py 9 | import numpy as np 10 | from PIL import Image 11 | import argparse 12 | 13 | def cuhk03_to_image(CUHK03_dir): 14 | 15 | f = h5py.File(os.path.join(CUHK03_dir,'cuhk-03.mat')) 16 | 17 | detected_labeled = ['detected','labeled'] 18 | print('converting') 19 | for data_type in detected_labeled: 20 | 21 | datatype_dir = os.path.join(CUHK03_dir, data_type) 22 | if not os.path.exists(datatype_dir): 23 | os.makedirs(datatype_dir) 24 | 25 | for campair in range(len(f[data_type][0])): 26 | campair_dir = os.path.join(datatype_dir,'P%d'%(campair+1)) 27 | cam1_dir = os.path.join(campair_dir,'cam1') 28 | cam2_dir = os.path.join(campair_dir,'cam2') 29 | 30 | if not os.path.exists(campair_dir): 31 | os.makedirs(campair_dir) 32 | if not os.path.exists(cam1_dir): 33 | os.makedirs(cam1_dir) 34 | if not os.path.exists(cam2_dir): 35 | os.makedirs(cam2_dir) 36 | 37 | for img_no in range(f[f[data_type][0][campair]].shape[0]): 38 | if img_no < 5: 39 | cam_dir = 'cam1' 40 | else: 41 | cam_dir = 'cam2' 42 | for person_id in
range(f[f[data_type][0][campair]].shape[1]): 43 | img = np.array(f[f[f[data_type][0][campair]][img_no][person_id]]) 44 | if img.shape[0] !=2: 45 | img = np.transpose(img, (2,1,0)) 46 | im = Image.fromarray(img) 47 | im.save(os.path.join(campair_dir, cam_dir, "%d-%d.jpg"%(person_id+1,img_no+1))) -------------------------------------------------------------------------------- /torchreid/losses/cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class CrossEntropyLoss(nn.Module): 9 | """Cross entropy loss with label smoothing regularizer. 10 | 11 | Reference: 12 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 13 | 14 | Equation: y = (1 - epsilon) * y + epsilon / K. 15 | 16 | Args: 17 | - num_classes (int): number of classes 18 | - epsilon (float): weight 19 | - use_gpu (bool): whether to use gpu devices 20 | - label_smooth (bool): whether to apply label smoothing, if False, epsilon = 0 21 | """ 22 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, label_smooth=True): 23 | super(CrossEntropyLoss, self).__init__() 24 | self.num_classes = num_classes 25 | self.epsilon = epsilon if label_smooth else 0 26 | self.use_gpu = use_gpu 27 | self.logsoftmax = nn.LogSoftmax(dim=1) 28 | 29 | def forward(self, inputs, targets): 30 | """ 31 | Args: 32 | - inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 33 | - targets: ground truth labels with shape (batch_size) 34 | """ 35 | # print('inputs size: ',inputs.size()) 36 | log_probs = self.logsoftmax(inputs) 37 | # print('target unsqueeze size: ',targets.unsqueeze(1).data.cpu().size()) 38 | # print('log_probs: ',log_probs) 39 | # print('targets: ',targets) 40 | 41 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 42 | 43 | if self.use_gpu: targets = targets.cuda() 44 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 45 | loss = (- targets * log_probs).mean(0).sum() 46 | return loss -------------------------------------------------------------------------------- /torchreid/losses/bce_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class BCELoss(nn.Module): 9 | """Binary cross entropy loss with label smoothing regularizer. 10 | 11 | Reference: 12 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 13 | 14 | Equation: y = (1 - epsilon) * y + epsilon / K.
15 | 16 | Args: 17 | - num_classes (int): number of classes 18 | - epsilon (float): weight 19 | - use_gpu (bool): whether to use gpu devices 20 | - label_smooth (bool): whether to apply label smoothing, if False, epsilon = 0 21 | """ 22 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, label_smooth=True): 23 | super(BCELoss, self).__init__() 24 | self.num_classes = num_classes 25 | self.epsilon = epsilon if label_smooth else 0 26 | self.use_gpu = use_gpu 27 | self.sigmoid = nn.Sigmoid() 28 | 29 | def forward(self, inputs, targets): 30 | """ 31 | Args: 32 | - inputs: prediction matrix (before sigmoid) with shape (batch_size, num_classes) 33 | - targets: ground truth labels with shape (batch_size, num_classes) 34 | """ 35 | if self.use_gpu: 36 | targets = targets.cuda() 37 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 38 | 39 | ################################### 40 | ##### pytorch bce with logits###### 41 | ################################### 42 | 43 | max_val = (-inputs).clamp(min=0) 44 | loss = inputs - inputs * targets + max_val + ((-max_val).exp() + (-inputs - max_val).exp()).log() 45 | 46 | # loss = torch.max(inputs, torch.zeros_like(inputs)) - inputs * targets \ 47 | # + (torch.ones_like(inputs)+(-torch.abs(inputs)).exp()).log() 48 | 49 | return loss.mean(0).sum() -------------------------------------------------------------------------------- /torchreid/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .resnet import * 4 | from .resnetmid import * 5 | from .resnext import * 6 | from .senet import * 7 | from .densenet import * 8 | from .inceptionresnetv2 import * 9 | from .inceptionv4 import * 10 | from .xception import * 11 | 12 | from .nasnet import * 13 | from .mobilenetv2 import * 14 | from .shufflenet import * 15 | from .squeezenet import * 16 | 17 | from .mudeep import * 18 | from .hacnn import * 19 | from .pcb import * 20 | from .mlfn import * 21 | 22 | 23 | __model_factory = { 24 | # image classification models 25 | 'resnet50': resnet50, 26 | 'resnet50_fc512': resnet50_fc512, 27 | 'resnext50_32x4d': resnext50_32x4d, 28 | 'resnext50_32x4d_fc512': resnext50_32x4d_fc512, 29 | 'se_resnet50': se_resnet50, 30 | 'se_resnet50_fc512': se_resnet50_fc512, 31 | 'se_resnet101': se_resnet101, 32 | 'se_resnext50_32x4d': se_resnext50_32x4d, 33 | 'se_resnext101_32x4d': se_resnext101_32x4d, 34 | 'densenet121': densenet121, 35 | 'densenet121_fc512': densenet121_fc512, 36 | 'inceptionresnetv2': InceptionResNetV2, 37 | 'inceptionv4': inceptionv4, 38 | 'xception': xception, 39 | # lightweight models 40 | 'nasnsetmobile': nasnetamobile, 41 | 'mobilenetv2': MobileNetV2, 42 | 'shufflenet': ShuffleNet, 43 | 'squeezenet1_0': squeezenet1_0, 44 | 'squeezenet1_0_fc512': squeezenet1_0_fc512, 45 | 'squeezenet1_1': squeezenet1_1, 46 | # reid-specific models 47 | 'mudeep': MuDeep, 48 | 'resnet50mid': resnet50mid, 49 | 'hacnn': HACNN, 50 | 'pcb_p6': pcb_p6, 51 | 'pcb_p4': pcb_p4, 52 | 'mlfn': mlfn, 53 | } 54 | 55 | 56 | def get_names(): 57 | return list(__model_factory.keys()) 58 | 59 | 60 | def init_model(name, *args, **kwargs): 61 | if name not in list(__model_factory.keys()): 62 | raise KeyError("Unknown model: {}".format(name)) 63 | return __model_factory[name](*args, **kwargs) -------------------------------------------------------------------------------- /torchreid/losses/hard_mine_triplet_loss.py:
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class TripletLoss(nn.Module): 9 | """Triplet loss with hard positive/negative mining. 10 | 11 | Reference: 12 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 13 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py. 14 | 15 | Args: 16 | - margin (float): margin for triplet. 17 | """ 18 | def __init__(self, margin=0.3): 19 | super(TripletLoss, self).__init__() 20 | self.margin = margin 21 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 22 | 23 | def forward(self, inputs, targets): 24 | """ 25 | Args: 26 | - inputs: feature matrix with shape (batch_size, feat_dim) 27 | - targets: ground truth labels with shape (batch_size) 28 | """ 29 | n = inputs.size(0) 30 | 31 | # Compute pairwise distance, replace by the official when merged 32 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 33 | dist = dist + dist.t() 34 | dist.addmm_(1, -2, inputs, inputs.t()) 35 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 36 | 37 | # For each anchor, find the hardest positive and negative 38 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 39 | dist_ap, dist_an = [], [] 40 | for i in range(n): 41 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) 42 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) 43 | dist_ap = torch.cat(dist_ap) 44 | dist_an = torch.cat(dist_an) 45 | 46 | # Compute ranking hinge loss 47 | y = torch.ones_like(dist_an) 48 | loss = self.ranking_loss(dist_an, dist_ap, y) 49 | return loss -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | data/ 3 | log/ 4 | saved-models/ 5 | debug.py 6 | .idea/ 7 | 8 | # cython eval code 9 | torchreid/eval_cylib/eval_metrics_cy.c 10 | torchreid/eval_cylib/*.html 11 | 12 | # OS X 13 | .DS_Store 14 | .Spotlight-V100 15 | .Trashes 16 | ._* 17 | 18 | # Byte-compiled / optimized / DLL files 19 | __pycache__/ 20 | *.py[cod] 21 | *$py.class 22 | 23 | # C extensions 24 | *.so 25 | 26 | # Distribution / packaging 27 | .Python 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | *.egg-info/ 41 | .installed.cfg 42 | *.egg 43 | MANIFEST 44 | 45 | # PyInstaller 46 | # Usually these files are written by a python script from a template 47 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
48 | *.manifest 49 | *.spec 50 | 51 | # Installer logs 52 | pip-log.txt 53 | pip-delete-this-directory.txt 54 | 55 | # Unit test / coverage reports 56 | htmlcov/ 57 | .tox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *.cover 64 | .hypothesis/ 65 | .pytest_cache/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | .static_storage/ 74 | .media/ 75 | local_settings.py 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # celery beat schedule file 97 | celerybeat-schedule 98 | 99 | # SageMath parsed files 100 | *.sage.py 101 | 102 | # Environments 103 | .env 104 | .venv 105 | env/ 106 | venv/ 107 | ENV/ 108 | env.bak/ 109 | venv.bak/ 110 | 111 | # Spyder project settings 112 | .spyderproject 113 | .spyproject 114 | 115 | # Rope project settings 116 | .ropeproject 117 | 118 | # mkdocs documentation 119 | /site 120 | 121 | # mypy 122 | .mypy_cache/ 123 | -------------------------------------------------------------------------------- /torchreid/losses/center_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import warnings 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class CenterLoss(nn.Module): 11 | """Center loss. 12 | 13 | Reference: 14 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 15 | 16 | Args: 17 | - num_classes (int): number of classes. 18 | - feat_dim (int): feature dimension. 19 | """ 20 | def __init__(self, num_classes=10, feat_dim=2, use_gpu=True): 21 | super(CenterLoss, self).__init__() 22 | warnings.warn("This method is deprecated") 23 | self.num_classes = num_classes 24 | self.feat_dim = feat_dim 25 | self.use_gpu = use_gpu 26 | 27 | if self.use_gpu: 28 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 29 | else: 30 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 31 | 32 | def forward(self, x, labels): 33 | """ 34 | Args: 35 | - x: feature matrix with shape (batch_size, feat_dim). 36 | - labels: ground truth labels with shape (batch_size).
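The squared distances below use the expansion ||x - c||^2 = ||x||^2 + ||c||^2 - 2*x.c, evaluated batch-wise with addmm_.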
37 | """ 38 | batch_size = x.size(0) 39 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 40 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 41 | distmat.addmm_(1, -2, x, self.centers.t()) 42 | 43 | classes = torch.arange(self.num_classes).long() 44 | if self.use_gpu: classes = classes.cuda() 45 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 46 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 47 | 48 | dist = [] 49 | for i in range(batch_size): 50 | value = distmat[i][mask[i]] 51 | value = value.clamp(min=1e-12, max=1e+12) # for numerical stability 52 | dist.append(value) 53 | dist = torch.cat(dist) 54 | loss = dist.mean() 55 | 56 | return loss -------------------------------------------------------------------------------- /torchreid/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | from torchvision.transforms import * 5 | import torch 6 | 7 | from PIL import Image 8 | import random 9 | import numpy as np 10 | 11 | 12 | class Random2DTranslation(object): 13 | """ 14 | With a probability, first increase image size to (1 + 1/8), and then perform random crop. 15 | 16 | Args: 17 | - height (int): target image height. 18 | - width (int): target image width. 19 | - p (float): probability of performing this transformation. Default: 0.5. 20 | """ 21 | def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR): 22 | self.height = height 23 | self.width = width 24 | self.p = p 25 | self.interpolation = interpolation 26 | 27 | def __call__(self, img): 28 | """ 29 | Args: 30 | - img (PIL Image): Image to be cropped. 31 | """ 32 | if random.uniform(0, 1) > self.p: 33 | return img.resize((self.width, self.height), self.interpolation) 34 | 35 | new_width, new_height = int(round(self.width * 1.125)), int(round(self.height * 1.125)) 36 | resized_img = img.resize((new_width, new_height), self.interpolation) 37 | x_maxrange = new_width - self.width 38 | y_maxrange = new_height - self.height 39 | x1 = int(round(random.uniform(0, x_maxrange))) 40 | y1 = int(round(random.uniform(0, y_maxrange))) 41 | cropped_img = resized_img.crop((x1, y1, x1 + self.width, y1 + self.height)) 42 | return cropped_img 43 | 44 | 45 | def build_transforms(height, width, is_train, **kwargs): 46 | """Build transforms 47 | 48 | Args: 49 | - height (int): target image height. 50 | - width (int): target image width. 51 | - is_train (bool): train or test phase.
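Example (hypothetical sizes): build_transforms(256, 128, is_train=True) yields random crop + horizontal flip + normalize; is_train=False yields resize + normalize.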
52 | """ 53 | 54 | # use imagenet mean and std as default 55 | imagenet_mean = [0.485, 0.456, 0.406] 56 | imagenet_std = [0.229, 0.224, 0.225] 57 | normalize = Normalize(mean=imagenet_mean, std=imagenet_std) 58 | 59 | transforms = [] 60 | 61 | if is_train: 62 | transforms += [Random2DTranslation(height, width)] 63 | transforms += [RandomHorizontalFlip()] 64 | else: 65 | transforms += [Resize((height, width))] 66 | 67 | transforms += [ToTensor()] 68 | transforms += [normalize] 69 | 70 | transforms = Compose(transforms) 71 | 72 | return transforms -------------------------------------------------------------------------------- /torchreid/utils/loggers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import sys 4 | import os 5 | import os.path as osp 6 | 7 | from .iotools import mkdir_if_missing 8 | 9 | 10 | class Logger(object): 11 | """ 12 | Write console output to external text file. 13 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 14 | """ 15 | def __init__(self, fpath=None): 16 | self.console = sys.stdout 17 | self.file = None 18 | if fpath is not None: 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | self.file = open(fpath, 'w') 21 | 22 | def __del__(self): 23 | self.close() 24 | 25 | def __enter__(self): 26 | return self 27 | 28 | def __exit__(self, *args): 29 | self.close() 30 | 31 | def write(self, msg): 32 | self.console.write(msg) 33 | if self.file is not None: 34 | self.file.write(msg) 35 | 36 | def flush(self): 37 | self.console.flush() 38 | if self.file is not None: 39 | self.file.flush() 40 | os.fsync(self.file.fileno()) 41 | 42 | def close(self): 43 | self.console.close() 44 | if self.file is not None: 45 | self.file.close() 46 | 47 | 48 | class RankLogger(object): 49 | """ 50 | RankLogger records the rank1 matching accuracy obtained for each 51 | test dataset at specified evaluation steps and provides a function 52 | to show the summarized results, which are convenient for analysis. 53 | 54 | Args: 55 | - source_names (list): list of strings (names) of source datasets. 56 | - target_names (list): list of strings (names) of target datasets.
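Example (hypothetical names): logger = RankLogger(['market1501'], ['dukemtmc']); logger.write('dukemtmc', 10, 0.751); logger.show_summary().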
57 | """ 58 | def __init__(self, source_names, target_names): 59 | self.source_names = source_names 60 | self.target_names = target_names 61 | self.logger = {name: {'epoch': [], 'rank1': []} for name in self.target_names} 62 | 63 | def write(self, name, epoch, rank1): 64 | self.logger[name]['epoch'].append(epoch) 65 | self.logger[name]['rank1'].append(rank1) 66 | 67 | def show_summary(self): 68 | print("=> Show summary") 69 | for name in self.target_names: 70 | from_where = 'source' if name in self.source_names else 'target' 71 | print("{} ({})".format(name, from_where)) 72 | for epoch, rank1 in zip(self.logger[name]['epoch'], self.logger[name]['rank1']): 73 | print("- epoch {}\t rank1 {:.1%}".format(epoch, rank1)) -------------------------------------------------------------------------------- /torchreid/utils/torchtools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | def adjust_learning_rate(optimizer, base_lr, epoch, stepsize=20, gamma=0.1, 10 | linear_decay=False, final_lr=0, max_epoch=100): 11 | if linear_decay: 12 | # linearly decay learning rate from base_lr to final_lr 13 | frac_done = epoch / max_epoch 14 | lr = frac_done * final_lr + (1. - frac_done) * base_lr 15 | else: 16 | # decay learning rate by gamma for every stepsize 17 | lr = base_lr * (gamma ** (epoch // stepsize)) 18 | 19 | for param_group in optimizer.param_groups: 20 | param_group['lr'] = lr 21 | 22 | 23 | def set_bn_to_eval(m): 24 | # 1. no update for running mean and var 25 | # 2. scale and shift parameters are still trainable 26 | classname = m.__class__.__name__ 27 | if classname.find('BatchNorm') != -1: 28 | m.eval() 29 | 30 | 31 | def open_all_layers(model): 32 | """ 33 | Open all layers in model for training. 34 | 35 | Args: 36 | - model (nn.Module): neural net model. 37 | """ 38 | model.train() 39 | for p in model.parameters(): 40 | p.requires_grad = True 41 | 42 | 43 | def open_specified_layers(model, open_layers): 44 | """ 45 | Open specified layers in model for training while keeping 46 | other layers frozen. 47 | 48 | Args: 49 | - model (nn.Module): neural net model. 50 | - open_layers (list): list of layer names. 
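Example (hypothetical layer name): open_specified_layers(model, ['classifier']) trains only the classifier head; every other child module is frozen and switched to eval mode.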
51 | """ 52 | if isinstance(model, nn.DataParallel): 53 | model = model.module 54 | 55 | for layer in open_layers: 56 | assert hasattr(model, layer), "'{}' is not an attribute of the model, please provide the correct name".format(layer) 57 | 58 | for name, module in model.named_children(): 59 | if name in open_layers: 60 | module.train() 61 | for p in module.parameters(): 62 | p.requires_grad = True 63 | else: 64 | module.eval() 65 | for p in module.parameters(): 66 | p.requires_grad = False 67 | 68 | 69 | def count_num_param(model): 70 | num_param = sum(p.numel() for p in model.parameters()) / 1e+06 71 | 72 | if isinstance(model, nn.DataParallel): 73 | model = model.module 74 | 75 | if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Module): 76 | # we ignore the classifier because it is unused at test time 77 | num_param -= sum(p.numel() for p in model.classifier.parameters()) / 1e+06 78 | return num_param -------------------------------------------------------------------------------- /torchreid/eval_cylib/test_cython.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import sys 4 | import os.path as osp 5 | import timeit 6 | import numpy as np 7 | 8 | sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../..') 9 | 10 | from torchreid.eval_metrics import evaluate 11 | 12 | """ 13 | Test the speed of cython-based evaluation code. The speed improvements 14 | can be much bigger when using the real reid data, which contains a larger 15 | amount of query and gallery images. 16 | 17 | Note: you might encounter the following error: 18 | 'AssertionError: Error: all query identities do not appear in gallery'. 19 | This is normal because the inputs are random numbers. Just try again. 
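Build the extension first: running 'make' in this directory invokes 'python setup.py build_ext --inplace' (see the Makefile).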
20 | """ 21 | 22 | print("*** Compare running time ***") 23 | 24 | setup = ''' 25 | import sys 26 | import os.path as osp 27 | import numpy as np 28 | sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../..') 29 | from torchreid.eval_metrics import evaluate 30 | num_q = 30 31 | num_g = 300 32 | max_rank = 5 33 | distmat = np.random.rand(num_q, num_g) * 20 34 | q_pids = np.random.randint(0, num_q, size=num_q) 35 | g_pids = np.random.randint(0, num_g, size=num_g) 36 | q_camids = np.random.randint(0, 5, size=num_q) 37 | g_camids = np.random.randint(0, 5, size=num_g) 38 | ''' 39 | 40 | print("=> Using market1501's metric") 41 | pytime = timeit.timeit('evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)', setup=setup, number=20) 42 | cytime = timeit.timeit('evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)', setup=setup, number=20) 43 | print("Python time: {} s".format(pytime)) 44 | print("Cython time: {} s".format(cytime)) 45 | print("Cython is {} times faster than python\n".format(pytime / cytime)) 46 | 47 | print("=> Using cuhk03's metric") 48 | pytime = timeit.timeit('evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=False)', setup=setup, number=20) 49 | cytime = timeit.timeit('evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=True)', setup=setup, number=20) 50 | print("Python time: {} s".format(pytime)) 51 | print("Cython time: {} s".format(cytime)) 52 | print("Cython is {} times faster than python\n".format(pytime / cytime)) 53 | 54 | """ 55 | print("=> Check precision") 56 | 57 | num_q = 30 58 | num_g = 300 59 | max_rank = 5 60 | distmat = np.random.rand(num_q, num_g) * 20 61 | q_pids = np.random.randint(0, num_q, size=num_q) 62 | g_pids = np.random.randint(0, num_g, size=num_g) 63 | q_camids = np.random.randint(0, 5, size=num_q) 64 | g_camids = np.random.randint(0, 5, size=num_g) 65 | 66 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False) 67 | print("Python:\nmAP = {} \ncmc = {}\n".format(mAP, cmc)) 68 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True) 69 | print("Cython:\nmAP = {} \ncmc = {}\n".format(mAP, cmc)) 70 | """ -------------------------------------------------------------------------------- /torchreid/losses/bce_focal_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class FocalLoss(nn.Module): 9 | """Binary cross entropy focal loss with label smoothing regularizer. 10 | 11 | References: 12 | Lin et al. Focal Loss for Dense Object Detection. ICCV 2017 (focal loss); Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016 (label smoothing). 13 | 14 | Equation: y = (1 - epsilon) * y + epsilon / K.
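In the focal term below, (1 - sigmoid(x))**gamma scales the positive-label part and sigmoid(x)**gamma the negative-label part, down-weighting well-classified examples.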
15 | 16 | Args: 17 | - num_classes (int): number of classes 18 | - epsilon (float): weight 19 | - use_gpu (bool): whether to use gpu devices 20 | - label_smooth (bool): whether to apply label smoothing, if False, epsilon = 0 21 | - w: per-class weight vector (e.g. the prior class distribution) 22 | """ 23 | def __init__(self, num_classes, w, epsilon=0.1, use_gpu=True, label_smooth=True, gamma = 0.5): 24 | super(FocalLoss, self).__init__() 25 | self.num_classes = num_classes 26 | self.epsilon = epsilon if label_smooth else 0 27 | self.use_gpu = use_gpu 28 | self.sigmoid = nn.Sigmoid() 29 | self.gamma = gamma 30 | self.w = w 31 | 32 | def forward(self, inputs, targets): 33 | """ 34 | Args: 35 | - inputs: prediction matrix (before sigmoid) with shape (batch_size, num_classes) 36 | - targets: ground truth labels with shape (batch_size, num_classes) 37 | """ 38 | if self.use_gpu: 39 | targets = targets.cuda() 40 | w = torch.tensor(self.w).float().cuda() 41 | else: 42 | w = torch.tensor(self.w).float().cpu() 43 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 44 | 45 | # targets = targets*torch.log(sig_output+1e-5) + (1-targets)*torch.log(1-sig_output+1e-5) 46 | # loss = (- targets).mean(0).sum() 47 | 48 | # neg_abs = - inputs.abs() 49 | # loss = (inputs.clamp(min=0) - inputs * targets + (1 + neg_abs.exp()).log()).mean(0).sum() 50 | 51 | # max_val = (-inputs).clamp(min=0) 52 | # loss = inputs - inputs * targets + max_val + ((-max_val).exp() + (-inputs - max_val).exp()).log() 53 | 54 | inputs_pos = inputs.clamp(min=0) 55 | inputs_neg = inputs.clamp(max=0) 56 | 57 | # sigmoid(inputs)**self.gamma 58 | sigmoid_pos = (self.gamma*inputs_neg-self.gamma*(1+(-inputs.abs()).exp()).log()).exp() 59 | 60 | # (1 - sigmoid(inputs))**self.gamma 61 | sigmoid_neg = (-self.gamma*inputs_pos-self.gamma*(1+(-inputs.abs()).exp()).log()).exp() 62 | 63 | first_pos = -sigmoid_pos*inputs_pos*(1-targets) 64 | first_neg = sigmoid_neg*inputs_neg*targets 65 | # print('pos: ',inputs_pos.norm()) 66 | # print('neg: ',inputs_neg.norm()) 67 | 68 | loss = -(first_pos + first_neg - sigmoid_neg*(1+(-inputs.abs()).exp() 69 | ).log()*targets - sigmoid_pos*(1+(-inputs.abs()).exp()).log()*(1-targets)) 70 | 71 | return (w*loss.mean(0)).sum() -------------------------------------------------------------------------------- /torchreid/reid_dataset/reiddataset_downloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import warnings 3 | warnings.filterwarnings('ignore','.*conversion.*') 4 | 5 | import os 6 | import zipfile 7 | import shutil 8 | import requests 9 | import h5py 10 | import numpy as np 11 | from PIL import Image 12 | import argparse 13 | from .gdrive_downloader import gdrive_downloader 14 | from .cuhk03_to_image import cuhk03_to_image 15 | 16 | dataset = { 17 | 'CUHK01': '153IzD3vyQ0PqxxanQRlP9l89F1S5Vr47', 18 | 'CUHK02': '0B2FnquNgAXoneE5YamFXY3NjYWM', 19 | 'CUHK03': '1BO4G9gbOTJgtYIB0VNyHQpZb8Lcn-05m', 20 | 'VIPeR': '0B2FnquNgAXonZzJPQUtrcWJWbWc', 21 | 'Market1501': '0B2FnquNgAXonU3RTcE1jQlZ3X0E', 22 | 'Market1501Attribute' : '1YMgni5oz-RPkyKHzOKnYRR2H3IRKdsHO', 23 | 'DukeMTMC': '1qtFGJQ6eFu66Tt7WG85KBxtACSE8RBZ0', 24 | 'DukeMTMCAttribute' : '1eilPJFnk_EHECKj2glU_ZLLO7eR3JIiO' 25 | } 26 | 27 | dataset_hdf5 = { 28 | 'Market1501': '1ipvyt4qesVK6CUiGcQdwle2c2XYknKco', 29 | 'DukeMTMC': '1tP-fty5YE-W2F6B5rjnQNfE-NzNssGM2' 30 | } 31 | 32 | def reiddataset_downloader(data_dir, data_name, hdf5 = True): 33 | 34 | if not os.path.exists(data_dir): 35 |
os.makedirs(data_dir) 36 | 37 | if hdf5: 38 | dataset_dir = os.path.join(data_dir , data_name) 39 | if not os.path.exists(dataset_dir): 40 | os.makedirs(dataset_dir) 41 | destination = os.path.join(dataset_dir , data_name+'.hdf5') 42 | if not os.path.isfile(destination): 43 | id = dataset_hdf5[data_name] 44 | print("Downloading %s in HDF5 Format" %data_name) 45 | gdrive_downloader(destination, id) 46 | print("Done") 47 | else: 48 | print("Dataset Check Success: %s exists!" %data_name) 49 | else: 50 | data_dir_exist = os.path.join(data_dir , data_name) 51 | 52 | if not os.path.exists(data_dir_exist): 53 | temp_dir = os.path.join(data_dir , 'temp') 54 | 55 | if not os.path.exists(temp_dir): 56 | os.makedirs(temp_dir) 57 | 58 | destination = os.path.join(temp_dir , data_name) 59 | 60 | id = dataset[data_name] 61 | 62 | print("Downloading %s in Original Images" % data_name) 63 | gdrive_downloader(destination, id) 64 | 65 | zip_ref = zipfile.ZipFile(destination) 66 | print("Extracting %s" % data_name) 67 | zip_ref.extractall(data_dir) 68 | zip_ref.close() 69 | shutil.rmtree(temp_dir) 70 | print("Done") 71 | if data_name == 'CUHK03': 72 | print('Converting cuhk03.mat into images') 73 | cuhk03_to_image(os.path.join(data_dir,'CUHK03')) 74 | print('Done') 75 | else: 76 | print("Dataset Check Success: %s exists!" %data_name) 77 | 78 | def reiddataset_downloader_all(data_dir): 79 | for k,v in dataset.items(): 80 | reiddataset_downloader(data_dir,k) 81 | 82 | # For unit testing and external use 83 | if __name__ == "__main__": 84 | parser = argparse.ArgumentParser(description='Dataset Name and Dataset Directory') 85 | parser.add_argument(dest="data_dir", action="store", default="~/Datasets/",help="") 86 | parser.add_argument(dest="data_name", action="store", type=str,help="") 87 | args = parser.parse_args() 88 | reiddataset_downloader(args.data_dir,args.data_name) -------------------------------------------------------------------------------- /torchreid/reid_dataset/import_PETA.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import scipy.io as sio 4 | from collections import Counter 5 | 6 | 7 | def import_peta(dataset_dir): 8 | peta_dir = os.path.join(dataset_dir,'PETA') 9 | if not os.path.exists(peta_dir): 10 | print('Please Download PETA Dataset and check if the sub-folder name exists') 11 | 12 | file_list = os.listdir(peta_dir) 13 | 14 | for name in file_list: 15 | id = name.split('.')[0] 16 | globals()[id] = name 17 | 18 | f = sio.loadmat(os.path.join(dataset_dir, 'PETA.mat')) 19 | attr_dict = {} 20 | nrow, ncol = f['peta'][0][0][0].shape 21 | sem_dict={} 22 | sem_list=[] 23 | for id in range(nrow): 24 | attr_dict[str(id + 1).zfill(5)] = f['peta'][0][0][0][id][4:] 25 | sem_id = str(int("".join(map(str,attr_dict[str(id + 1).zfill(5)])), base=2)) 26 | if sem_id not in sem_list: 27 | sem_dict[sem_id] = [str(id + 1).zfill(5)] 28 | sem_list.append(sem_id) 29 | else: 30 | sem_dict[sem_id].append(str(id + 1).zfill(5)) 31 | attributes = [] # gives the names of attributes, label 32 | for i in range(ncol-4): 33 | attributes.append(f['peta'][0][0][1][i][0][0]) 34 | 35 | # PETA has 7769 distinct semantic ids (attribute vectors); below, ids with only one image are dropped, then 200 of the remaining ids are held out for testing and the rest are used for training.
36 | new_sem_dict = {}
37 | for sem_id in sem_dict.keys():
38 | id_list = sem_dict[sem_id]
39 | if len(id_list) != 1:
40 | new_sem_dict[sem_id]=id_list
41 | sem_dict = new_sem_dict.copy()
42 | sem_list = list(sem_dict.keys())
43 | np.random.seed(1)
44 | num_sem = len(sem_dict.keys())
45 | training_sems = np.random.choice(num_sem, num_sem-200, replace=False)
46 |
47 | globals()['train'] = {}
48 | globals()['query'] = {}
49 | globals()['gallery'] = {}
50 | globals()['train']['data'] = []
51 | globals()['train']['ids'] = []
52 | globals()['query']['data'] = []
53 | globals()['query']['ids'] = []
54 | globals()['gallery']['data'] = []
55 | globals()['gallery']['ids'] = []
56 | train_attribute = {}
57 | test_attribute = {}
58 | for sem_id in sem_dict.keys():
59 | id_list = sem_dict[sem_id]
60 | # use the same camid for all images; camid is not used in this project
61 | camid = np.int64(1)
62 | if sem_list.index(sem_id) in training_sems:
63 | for id in id_list:
64 | name = globals()[id]
65 | images = os.path.join(peta_dir,name)
66 | globals()['train']['ids'].append(id)
67 | globals()['train']['data'].append([images, np.int64(globals()['train']['ids'].index(id)), id, camid, name])
68 | train_attribute[id] = attr_dict[id]
69 | else:
70 | for id in id_list:
71 | name = globals()[id]
72 | images = os.path.join(peta_dir,name)
73 | globals()['query']['ids'].append(id)
74 | globals()['gallery']['ids'].append(id)
75 | globals()['query']['data'].append([images, np.int64(globals()['query']['ids'].index(id)), id, camid, name])
76 | globals()['gallery']['data'].append([images, np.int64(globals()['gallery']['ids'].index(id)), id, camid, name])
77 | test_attribute[id] = attr_dict[id]
78 | return train, query, gallery, train_attribute, test_attribute, attributes
79 |
80 |
--------------------------------------------------------------------------------
/torchreid/reid_dataset/pytorch_prepare.py:
--------------------------------------------------------------------------------
1 | import os
2 | from shutil import copyfile
3 |
4 | def pytorch_prepare(data_dir, dataset_name):
5 | dataset_dir = os.path.join(data_dir, dataset_name)
6 |
7 | if not os.path.isdir(dataset_dir):
8 | print('Dataset directory not found; please check data_dir and dataset_name.')
9 |
10 | pytorch_path = os.path.join(dataset_dir , 'pytorch')
11 |
12 | if not os.path.isdir(pytorch_path):
13 | os.mkdir(pytorch_path)
14 | #-----------------------------------------
15 | #query
16 | print('generating ' + dataset_name + ' query images.')
17 | query_dir = os.path.join(dataset_dir , 'query')
18 | query_save_dir = os.path.join(dataset_dir , 'pytorch', 'query')
19 | if not os.path.isdir(query_save_dir):
20 | os.mkdir(query_save_dir)
21 |
22 | for root, dirs, files in os.walk(query_dir, topdown=True):
23 | for name in files:
24 | if not name[-3:]=='jpg':
25 | continue
26 | ID = name.split('_')
27 | src_dir = os.path.join(query_dir , name)
28 | dst_dir = os.path.join(query_save_dir, ID[0])
29 | if not os.path.isdir(dst_dir):
30 | os.mkdir(dst_dir)
31 | copyfile(src_dir, os.path.join(dst_dir , name))
32 | #-----------------------------------------
33 | #gallery
34 | print('generating '+dataset_name+' gallery images.')
35 | gallery_dir = os.path.join(dataset_dir , 'bounding_box_test')
36 | gallery_save_dir = os.path.join(dataset_dir , 'pytorch' , 'gallery')
37 | if not os.path.isdir(gallery_save_dir):
38 | os.mkdir(gallery_save_dir)
39 |
40 | for root, dirs, files in os.walk(gallery_dir, topdown=True):
41 | for name in files:
42 | if not name[-3:]=='jpg':
43 | continue
44 | ID = name.split('_')
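# Market/Duke-style file names put the person id before the first underscore,
# e.g. (illustrative) '0002_c1s1_000451_03.jpg' -> ID[0] == '0002'.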
45 | src_dir = os.path.join(gallery_dir, name)
46 | dst_dir = os.path.join(gallery_save_dir, ID[0])
47 | if not os.path.isdir(dst_dir):
48 | os.mkdir(dst_dir)
49 | copyfile(src_dir, os.path.join(dst_dir,name))
50 | #---------------------------------------
51 | #train_all
52 | print('generating '+dataset_name + ' all training images.')
53 | train_dir = os.path.join( dataset_dir , 'bounding_box_train')
54 | train_save_all_dir = os.path.join( dataset_dir , 'pytorch', 'train_all')
55 | if not os.path.isdir(train_save_all_dir):
56 | os.mkdir(train_save_all_dir)
57 |
58 | for root, dirs, files in os.walk(train_dir, topdown=True):
59 | for name in files:
60 | if not name[-3:]=='jpg':
61 | continue
62 | ID = name.split('_')
63 | src_dir = os.path.join(train_dir , name)
64 | dst_dir = os.path.join(train_save_all_dir, ID[0])
65 | if not os.path.isdir(dst_dir):
66 | os.mkdir(dst_dir)
67 | copyfile(src_dir, os.path.join(dst_dir, name))
68 |
69 | #---------------------------------------
70 | #train_val
71 | print('generating '+ dataset_name+' training and validation images.')
72 | train_save_dir = os.path.join(dataset_dir, 'pytorch', 'train')
73 | val_save_dir = os.path.join(dataset_dir , 'pytorch' , 'val')
74 | if not os.path.isdir(train_save_dir):
75 | os.mkdir(train_save_dir)
76 | os.mkdir(val_save_dir)
77 |
78 | for root, dirs, files in os.walk(train_dir, topdown=True):
79 | for name in files:
80 | if not name[-3:]=='jpg':
81 | continue
82 | ID = name.split('_')
83 | src_dir = os.path.join(train_dir , name)
84 | dst_dir = os.path.join(train_save_dir , ID[0])
85 | if not os.path.isdir(dst_dir):
86 | os.mkdir(dst_dir)
87 | dst_dir = os.path.join(val_save_dir, ID[0]) # the first image of each id is used as the val image
88 | os.mkdir(dst_dir)
89 | copyfile(src_dir, os.path.join(dst_dir , name))
90 | print('Finished ' + dataset_name)
91 | else:
92 | print(dataset_name + ' pytorch directory exists!')
93 |
94 | def pytorch_prepare_all(data_dir):
95 | pytorch_prepare(data_dir, 'Market1501')
96 | pytorch_prepare(data_dir, 'DukeMTMC')
97 |
--------------------------------------------------------------------------------
/torchreid/models/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 |
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 | import torchvision
8 |
9 |
10 | __all__ = ['MobileNetV2']
11 |
12 |
13 | class ConvBlock(nn.Module):
14 | """Basic convolutional block:
15 | convolution (bias discarded) + batch normalization + relu6.
16 |
17 | Args (following http://pytorch.org/docs/master/nn.html#torch.nn.Conv2d):
18 | in_c (int): number of input channels.
19 | out_c (int): number of output channels.
20 | k (int or tuple): kernel size.
21 | s (int or tuple): stride.
22 | p (int or tuple): padding.
23 | g (int): number of blocked connections from input channels
24 | to output channels (default: 1).
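
Example (shapes are illustrative):
>>> block = ConvBlock(3, 32, k=3, s=2, p=1)
>>> out = block(torch.randn(1, 3, 256, 128))  # -> (1, 32, 128, 64)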
25 | """ 26 | def __init__(self, in_c, out_c, k, s=1, p=0, g=1): 27 | super(ConvBlock, self).__init__() 28 | self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p, bias=False, groups=g) 29 | self.bn = nn.BatchNorm2d(out_c) 30 | 31 | def forward(self, x): 32 | return F.relu6(self.bn(self.conv(x))) 33 | 34 | 35 | class Bottleneck(nn.Module): 36 | def __init__(self, in_channels, out_channels, expansion_factor, stride): 37 | super(Bottleneck, self).__init__() 38 | mid_channels = in_channels * expansion_factor 39 | self.use_residual = stride == 1 and in_channels == out_channels 40 | self.conv1 = ConvBlock(in_channels, mid_channels, 1) 41 | self.dwconv2 = ConvBlock(mid_channels, mid_channels, 3, stride, 1, g=mid_channels) 42 | self.conv3 = nn.Sequential( 43 | nn.Conv2d(mid_channels, out_channels, 1, bias=False), 44 | nn.BatchNorm2d(out_channels), 45 | ) 46 | 47 | def forward(self, x): 48 | m = self.conv1(x) 49 | m = self.dwconv2(m) 50 | m = self.conv3(m) 51 | if self.use_residual: 52 | return x + m 53 | else: 54 | return m 55 | 56 | 57 | 58 | class MobileNetV2(nn.Module): 59 | """ 60 | MobileNetV2 61 | 62 | Reference: 63 | Sandler et al. MobileNetV2: Inverted Residuals and Linear Bottlenecks. CVPR 2018. 64 | """ 65 | def __init__(self, num_classes, loss={'xent'}, **kwargs): 66 | super(MobileNetV2, self).__init__() 67 | self.loss = loss 68 | 69 | self.conv1 = ConvBlock(3, 32, 3, s=2, p=1) 70 | self.block2 = Bottleneck(32, 16, 1, 1) 71 | self.block3 = nn.Sequential( 72 | Bottleneck(16, 24, 6, 2), 73 | Bottleneck(24, 24, 6, 1), 74 | ) 75 | self.block4 = nn.Sequential( 76 | Bottleneck(24, 32, 6, 2), 77 | Bottleneck(32, 32, 6, 1), 78 | Bottleneck(32, 32, 6, 1), 79 | ) 80 | self.block5 = nn.Sequential( 81 | Bottleneck(32, 64, 6, 2), 82 | Bottleneck(64, 64, 6, 1), 83 | Bottleneck(64, 64, 6, 1), 84 | Bottleneck(64, 64, 6, 1), 85 | ) 86 | self.block6 = nn.Sequential( 87 | Bottleneck(64, 96, 6, 1), 88 | Bottleneck(96, 96, 6, 1), 89 | Bottleneck(96, 96, 6, 1), 90 | ) 91 | self.block7 = nn.Sequential( 92 | Bottleneck(96, 160, 6, 2), 93 | Bottleneck(160, 160, 6, 1), 94 | Bottleneck(160, 160, 6, 1), 95 | ) 96 | self.block8 = Bottleneck(160, 320, 6, 1) 97 | self.conv9 = ConvBlock(320, 1280, 1) 98 | self.classifier = nn.Linear(1280, num_classes) 99 | self.feat_dim = 1280 100 | 101 | def featuremaps(self, x): 102 | x = self.conv1(x) 103 | x = self.block2(x) 104 | x = self.block3(x) 105 | x = self.block4(x) 106 | x = self.block5(x) 107 | x = self.block6(x) 108 | x = self.block7(x) 109 | x = self.block8(x) 110 | x = self.conv9(x) 111 | return x 112 | 113 | def forward(self, x): 114 | x = self.featuremaps(x) 115 | x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), -1) 116 | x = F.dropout(x, training=self.training) 117 | 118 | if not self.training: 119 | return x 120 | 121 | y = self.classifier(x) 122 | 123 | if self.loss == {'xent'}: 124 | return y 125 | elif self.loss == {'xent', 'htri'}: 126 | return y, x 127 | else: 128 | raise KeyError("Unsupported loss: {}".format(self.loss)) -------------------------------------------------------------------------------- /torchreid/samplers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | from collections import defaultdict 5 | import numpy as np 6 | import copy 7 | import random 8 | import itertools 9 | 10 | import torch 11 | from torch.utils.data.sampler import Sampler 12 | 13 | 14 | class RandomIdentitySampler(Sampler): 15 | """ 16 | Randomly sample 
N identities, then for each identity,
17 | randomly sample K instances, therefore the batch size is N*K.
18 |
19 | Args:
20 | - data_source (list): list of tuples whose last element is the semantic (attribute) id used for grouping.
21 | - batch_size (int): number of examples in a batch.
22 | - num_instances (int): number of instances per identity in a batch.
23 | """
24 | def __init__(self, data_source, batch_size, num_instances):
25 | self.data_source = data_source
26 | self.batch_size = batch_size
27 | self.num_instances = num_instances
28 | self.num_pids_per_batch = self.batch_size // self.num_instances
29 | self.index_dic = defaultdict(list)
30 | for index, (_, _, _, _, _, _, sem) in enumerate(self.data_source):
31 | self.index_dic[sem].append(index)
32 | self.pids = list(self.index_dic.keys())
33 |
34 | # estimate number of examples in an epoch
35 | self.length = 0
36 | for pid in self.pids:
37 | idxs = self.index_dic[pid]
38 | num = len(idxs)
39 | if num < self.num_instances:
40 | num = self.num_instances
41 | self.length += num - num % self.num_instances
42 |
43 | def __iter__(self):
44 | batch_idxs_dict = defaultdict(list)
45 |
46 | for pid in self.pids:
47 | idxs = copy.deepcopy(self.index_dic[pid])
48 | if len(idxs) < self.num_instances:
49 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True)
50 | random.shuffle(idxs)
51 | batch_idxs = []
52 | for idx in idxs:
53 | batch_idxs.append(idx)
54 | if len(batch_idxs) == self.num_instances:
55 | batch_idxs_dict[pid].append(batch_idxs)
56 | batch_idxs = []
57 |
58 | avai_pids = copy.deepcopy(self.pids)
59 | final_idxs = []
60 |
61 | while len(avai_pids) >= self.num_pids_per_batch:
62 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch)
63 | for pid in selected_pids:
64 | batch_idxs = batch_idxs_dict[pid].pop(0)
65 | final_idxs.extend(batch_idxs)
66 | if len(batch_idxs_dict[pid]) == 0:
67 | avai_pids.remove(pid)
68 |
69 | return iter(final_idxs)
70 |
71 | def __len__(self):
72 | return self.length
73 |
74 | class TwoStreamBatchSampler(Sampler):
75 | """Iterate over two sets of indices.
76 | An 'epoch' is one iteration through the primary indices.
77 | During the epoch, the secondary indices are iterated through
78 | as many times as needed.
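
Example (sizes are illustrative): with batch_size=8 and
secondary_batch_size=3, every yielded batch concatenates 5 primary
indices with 3 secondary indices:
>>> sampler = TwoStreamBatchSampler(range(100), range(100, 160), 8, 3)
>>> batch = next(iter(sampler))  # tuple of 8 indices: 5 primary, then 3 secondary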
79 | """ 80 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size, unlabeled_size_limit=None): 81 | self.primary_indices = primary_indices 82 | self.secondary_indices = secondary_indices 83 | self.secondary_batch_size = secondary_batch_size 84 | self.primary_batch_size = batch_size - secondary_batch_size 85 | self.unlabeled_size_limit = unlabeled_size_limit 86 | 87 | assert len(self.primary_indices) >= self.primary_batch_size > 0 88 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0 89 | 90 | def __iter__(self): 91 | primary_iter = iterate_once(self.primary_indices, self.unlabeled_size_limit) 92 | secondary_iter = iterate_eternally(self.secondary_indices) 93 | return ( 94 | primary_batch + secondary_batch 95 | for (primary_batch, secondary_batch) 96 | in zip(grouper(primary_iter, self.primary_batch_size), 97 | grouper(secondary_iter, self.secondary_batch_size)) 98 | ) 99 | 100 | def __len__(self): 101 | if self.unlabeled_size_limit is None: 102 | return len(self.primary_indices) // self.primary_batch_size 103 | else: 104 | return self.unlabeled_size_limit // self.primary_batch_size 105 | 106 | 107 | def iterate_once(iterable, unlabeled_size_limit=None): 108 | if unlabeled_size_limit is None: 109 | return np.random.permutation(iterable) 110 | else: 111 | result = np.random.permutation(iterable)[:unlabeled_size_limit] 112 | return result 113 | 114 | def iterate_eternally(indices): 115 | def infinite_shuffles(): 116 | while True: 117 | yield np.random.permutation(indices) 118 | return itertools.chain.from_iterable(infinite_shuffles()) 119 | 120 | 121 | def grouper(iterable, n): 122 | "Collect data into fixed-length chunks or blocks" 123 | args = [iter(iterable)] * n 124 | return zip(*args) 125 | -------------------------------------------------------------------------------- /torchreid/models/shufflenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | import torchvision 8 | 9 | 10 | __all__ = ['ShuffleNet'] 11 | 12 | 13 | class ChannelShuffle(nn.Module): 14 | def __init__(self, num_groups): 15 | super(ChannelShuffle, self).__init__() 16 | self.g = num_groups 17 | 18 | def forward(self, x): 19 | b, c, h, w = x.size() 20 | n = c // self.g 21 | # reshape 22 | x = x.view(b, self.g, n, h, w) 23 | # transpose 24 | x = x.permute(0, 2, 1, 3, 4).contiguous() 25 | # flatten 26 | x = x.view(b, c, h, w) 27 | return x 28 | 29 | 30 | class Bottleneck(nn.Module): 31 | def __init__(self, in_channels, out_channels, stride, num_groups, group_conv1x1=True): 32 | super(Bottleneck, self).__init__() 33 | assert stride in [1, 2], "Warning: stride must be either 1 or 2" 34 | self.stride = stride 35 | mid_channels = out_channels // 4 36 | if stride == 2: out_channels -= in_channels 37 | # group conv is not applied to first conv1x1 at stage 2 38 | num_groups_conv1x1 = num_groups if group_conv1x1 else 1 39 | self.conv1 = nn.Conv2d(in_channels, mid_channels, 1, groups=num_groups_conv1x1, bias=False) 40 | self.bn1 = nn.BatchNorm2d(mid_channels) 41 | self.shuffle1 = ChannelShuffle(num_groups) 42 | self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, stride=stride, padding=1, groups=mid_channels, bias=False) 43 | self.bn2 = nn.BatchNorm2d(mid_channels) 44 | self.conv3 = nn.Conv2d(mid_channels, out_channels, 1, groups=num_groups, bias=False) 45 | self.bn3 = 
nn.BatchNorm2d(out_channels) 46 | if stride == 2: self.shortcut = nn.AvgPool2d(3, stride=2, padding=1) 47 | 48 | def forward(self, x): 49 | out = F.relu(self.bn1(self.conv1(x))) 50 | out = self.shuffle1(out) 51 | out = self.bn2(self.conv2(out)) 52 | out = self.bn3(self.conv3(out)) 53 | if self.stride == 2: 54 | res = self.shortcut(x) 55 | out = F.relu(torch.cat([res, out], 1)) 56 | else: 57 | out = F.relu(x + out) 58 | return out 59 | 60 | 61 | # configuration of (num_groups: #out_channels) based on Table 1 in the paper 62 | cfg = { 63 | 1: [144, 288, 576], 64 | 2: [200, 400, 800], 65 | 3: [240, 480, 960], 66 | 4: [272, 544, 1088], 67 | 8: [384, 768, 1536], 68 | } 69 | 70 | 71 | class ShuffleNet(nn.Module): 72 | """ 73 | ShuffleNet 74 | 75 | Reference: 76 | Zhang et al. ShuffleNet: An Extremely Efficient Convolutional Neural 77 | Network for Mobile Devices. CVPR 2018. 78 | """ 79 | def __init__(self, num_classes, loss={'xent'}, num_groups=3, **kwargs): 80 | super(ShuffleNet, self).__init__() 81 | self.loss = loss 82 | 83 | self.conv1 = nn.Sequential( 84 | nn.Conv2d(3, 24, 3, stride=2, padding=1, bias=False), 85 | nn.BatchNorm2d(24), 86 | nn.ReLU(), 87 | nn.MaxPool2d(3, stride=2, padding=1), 88 | ) 89 | 90 | self.stage2 = nn.Sequential( 91 | Bottleneck(24, cfg[num_groups][0], 2, num_groups, group_conv1x1=False), 92 | Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups), 93 | Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups), 94 | Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups), 95 | ) 96 | 97 | self.stage3 = nn.Sequential( 98 | Bottleneck(cfg[num_groups][0], cfg[num_groups][1], 2, num_groups), 99 | Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups), 100 | Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups), 101 | Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups), 102 | Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups), 103 | Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups), 104 | Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups), 105 | Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups), 106 | ) 107 | 108 | self.stage4 = nn.Sequential( 109 | Bottleneck(cfg[num_groups][1], cfg[num_groups][2], 2, num_groups), 110 | Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups), 111 | Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups), 112 | Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups), 113 | ) 114 | 115 | self.classifier = nn.Linear(cfg[num_groups][2], num_classes) 116 | self.feat_dim = cfg[num_groups][2] 117 | 118 | def forward(self, x): 119 | x = self.conv1(x) 120 | x = self.stage2(x) 121 | x = self.stage3(x) 122 | x = self.stage4(x) 123 | x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), -1) 124 | 125 | if not self.training: 126 | return x 127 | 128 | y = self.classifier(x) 129 | 130 | if self.loss == {'xent'}: 131 | return y 132 | elif self.loss == {'xent', 'htri'}: 133 | return y, x 134 | else: 135 | raise KeyError("Unsupported loss: {}".format(self.loss)) -------------------------------------------------------------------------------- /torchreid/utils/reidtools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import numpy as np 5 | import os 6 | import os.path as osp 7 | import shutil 8 | from PIL import Image,ImageDraw 9 | from .iotools import mkdir_if_missing 10 | 11 
|
12 | def visualize_ranked_results(label, distmat, dataset, save_dir='log/ranked_results', topk=20):
13 | """
14 | Visualize ranked results
15 |
16 | Support both imgreid and vidreid
17 |
18 | Args:
19 | - label: list of attribute names, drawn on the query card image.
20 | - distmat: distance matrix of shape (num_query, num_gallery).
21 | - dataset: a 2-tuple (query, gallery); each entry is a 6-tuple
22 | (data, attribute label, pid, camid, img_path, semantic id), as unpacked in the loop below.
23 | - save_dir: directory to save output images.
24 | - topk: int, denoting top-k images in the rank list to be visualized.
25 | """
26 | num_q, num_g = distmat.shape
27 |
28 | print("Visualizing top-{} ranks".format(topk))
29 | print("# query: {}\n# gallery: {}".format(num_q, num_g))
30 | print("Saving images to '{}'".format(save_dir))
31 |
32 | query, gallery = dataset
33 | assert num_q == len(query)
34 | assert num_g == len(gallery)
35 |
36 | indices = np.argsort(distmat, axis=1)
37 | mkdir_if_missing(save_dir)
38 |
39 | def _cp_img_to(src, dst, rank, prefix):
40 | """
41 | - src: image path or tuple (for vidreid)
42 | - dst: target directory
43 | - rank: int, denoting ranked position, starting from 1
44 | - prefix: string
45 | """
46 | if isinstance(src, tuple) or isinstance(src, list):
47 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3))
48 | mkdir_if_missing(dst)
49 | for img_path in src:
50 | shutil.copy(img_path, dst)
51 | else:
52 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3) + '_name_' + osp.basename(src))
53 | shutil.copy(src, dst)
54 |
55 | for q_idx in range(num_q):
56 | # data, label, id, camid, name, sem
57 | _, qlabel, qpid, qcamid, qimg_path, qsem = query[q_idx]
58 | qlabel = ''.join(map(str, qlabel))
59 | qdir = osp.join(save_dir, 'sem_'+str(qsem)+'_'+qlabel)
60 | mkdir_if_missing(qdir)
61 | img = Image.new('RGB', (130, 400), color='black')
62 | d = ImageDraw.Draw(img)
63 | test_label = label
64 | d.text((10, 10), 'Query Attributes', fill=(255, 255, 0))
65 | y = 30
66 | for i in range(len(test_label)):
67 | d.text((10, y), test_label[i].ljust(15) + ': '+qlabel[i], fill=(255, 255, 0))
68 | y += 12
69 | qimgdir = osp.join(qdir, 'query_top000_'+str(qsem)+'_exp_'+str(qpid)+'.png')
70 | img.save(qimgdir)
71 |
72 | rank_idx = 1
73 | for g_idx in indices[q_idx,:]:
74 | _, glabel, gpid, gcamid, gimg_path, gsem = gallery[g_idx]
75 | _cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery_'+str(gsem))
76 | rank_idx += 1
77 | if rank_idx > topk:
78 | break
79 |
80 | print("Done")
81 |
82 |
83 | def visualize_ranked_results_train(label, distmat, dataset, save_dir='log/ranked_results', topk=20):
84 | """
85 | Visualize ranked results
86 |
87 | Support both imgreid and vidreid
88 |
89 | Args:
90 | - label: list of attribute names, drawn on the query card image.
91 | - distmat: distance matrix of shape (num_query, num_gallery).
92 | - dataset: a 2-tuple (query, gallery); query entries are 7-tuples and
93 | gallery entries are 6-tuples, as unpacked in the loop below.
94 | - save_dir: directory to save output images.
95 | - topk: int, denoting top-k images in the rank list to be visualized.
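
The expected layouts (taken from the unpacking in the loop below) are:
query: (data, index, attribute label, pid, camid, img_path, semantic id);
gallery: (data, attribute label, pid, camid, img_path, semantic id).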
96 | """
97 | num_q, num_g = distmat.shape
98 |
99 | print("Visualizing top-{} ranks".format(topk))
100 | print("# query: {}\n# gallery: {}".format(num_q, num_g))
101 | print("Saving images to '{}'".format(save_dir))
102 |
103 | query, gallery = dataset
104 | assert num_q == len(query)
105 | assert num_g == len(gallery)
106 |
107 | indices = np.argsort(distmat, axis=1)
108 | mkdir_if_missing(save_dir)
109 |
110 | def _cp_img_to(src, dst, rank, prefix):
111 | """
112 | - src: image path or tuple (for vidreid)
113 | - dst: target directory
114 | - rank: int, denoting ranked position, starting from 1
115 | - prefix: string
116 | """
117 | if isinstance(src, tuple) or isinstance(src, list):
118 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3))
119 | mkdir_if_missing(dst)
120 | for img_path in src:
121 | shutil.copy(img_path, dst)
122 | else:
123 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3) + '_name_' + osp.basename(src))
124 | shutil.copy(src, dst)
125 |
126 | for q_idx in range(num_q):
127 | # data, i, label, id, cam, name, sem
128 | _, _, qlabel, qpid, qcamid, qimg_path, qsem = query[q_idx]
129 | qlabel = ''.join(map(str, qlabel))
130 | qdir = osp.join(save_dir, 'sem_' + str(qsem) + '_' + qlabel)
131 | mkdir_if_missing(qdir)
132 | img = Image.new('RGB', (130, 400), color='black')
133 | d = ImageDraw.Draw(img)
134 | test_label = label
135 | d.text((10, 10), 'Query Attributes', fill=(255, 255, 0))
136 | y = 30
137 | for i in range(len(test_label)):
138 | d.text((10, y), test_label[i].ljust(15) + ': ' + qlabel[i], fill=(255, 255, 0))
139 | y += 12
140 | qimgdir = osp.join(qdir, 'query_top000_' + str(qsem) + '_exp_' + str(qpid) + '.png')
141 | img.save(qimgdir)
142 |
143 | rank_idx = 1
144 | for g_idx in indices[q_idx, :]:
145 | # data, label, id, camid, img_path, sem
146 | _, glabel, gpid, gcamid, gimg_path, gsem = gallery[g_idx]
147 | _cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery_' + str(gsem))
148 | rank_idx += 1
149 | if rank_idx > topk:
150 | break
151 |
152 | print("Done")
153 |
--------------------------------------------------------------------------------
/torchreid/eval_metrics.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | from __future__ import division
4 |
5 | import numpy as np
6 | import copy
7 | from collections import defaultdict
8 | import sys
9 | import warnings
10 |
11 | try:
12 | from torchreid.eval_cylib.eval_metrics_cy import evaluate_cy
13 | IS_CYTHON_AVAI = True
14 | print("Using Cython evaluation code as the backend")
15 | except ImportError:
16 | IS_CYTHON_AVAI = False
17 | warnings.warn("Cython evaluation is UNAVAILABLE; building it (see torchreid/eval_cylib) is highly recommended")
18 |
19 |
20 | def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
21 | """Evaluation with cuhk03 metric
22 | Key: one image for each gallery identity is randomly sampled for each query identity.
23 | Random sampling is performed num_repeats times.
24 | """
25 | num_repeats = 10
26 | num_q, num_g = distmat.shape
27 |
28 | if num_g < max_rank:
29 | max_rank = num_g
30 | print("Note: number of gallery samples is quite small, got {}".format(num_g))
31 |
32 | indices = np.argsort(distmat, axis=1)
33 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
34 |
35 | # compute cmc curve for each query
36 | all_cmc = []
37 | all_AP = []
38 | num_valid_q = 0.
# number of valid query
39 |
40 | for q_idx in range(num_q):
41 | # get query pid and camid
42 | q_pid = q_pids[q_idx]
43 | q_camid = q_camids[q_idx]
44 |
45 | # remove gallery samples that have the same pid and camid with query
46 | order = indices[q_idx]
47 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
48 | keep = np.invert(remove)
49 |
50 | # compute cmc curve
51 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
52 | if not np.any(raw_cmc):
53 | # this condition is true when query identity does not appear in gallery
54 | continue
55 |
56 | kept_g_pids = g_pids[order][keep]
57 | g_pids_dict = defaultdict(list)
58 | for idx, pid in enumerate(kept_g_pids):
59 | g_pids_dict[pid].append(idx)
60 |
61 | cmc, AP = 0., 0.
62 | for repeat_idx in range(num_repeats):
63 | mask = np.zeros(len(raw_cmc), dtype=bool)
64 | for _, idxs in g_pids_dict.items():
65 | # randomly sample one image for each gallery person
66 | rnd_idx = np.random.choice(idxs)
67 | mask[rnd_idx] = True
68 | masked_raw_cmc = raw_cmc[mask]
69 | _cmc = masked_raw_cmc.cumsum()
70 | _cmc[_cmc > 1] = 1
71 | cmc += _cmc[:max_rank].astype(np.float32)
72 | # compute AP
73 | num_rel = masked_raw_cmc.sum()
74 | tmp_cmc = masked_raw_cmc.cumsum()
75 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
76 | tmp_cmc = np.asarray(tmp_cmc) * masked_raw_cmc
77 | AP += tmp_cmc.sum() / num_rel
78 |
79 | cmc /= num_repeats
80 | AP /= num_repeats
81 | all_cmc.append(cmc)
82 | all_AP.append(AP)
83 | num_valid_q += 1.
84 |
85 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
86 |
87 | all_cmc = np.asarray(all_cmc).astype(np.float32)
88 | all_cmc = all_cmc.sum(0) / num_valid_q
89 | mAP = np.mean(all_AP)
90 |
91 | return all_cmc, mAP
92 |
93 |
94 | def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
95 | """Evaluation with market1501 metric
96 | Key: for each query identity, its gallery images from the same camera view are discarded.
97 | """
98 | num_q, num_g = distmat.shape
99 |
100 | if num_g < max_rank:
101 | max_rank = num_g
102 | print("Note: number of gallery samples is quite small, got {}".format(num_g))
103 |
104 | indices = np.argsort(distmat, axis=1)
105 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
106 | # compute cmc curve for each query
107 | all_cmc = []
108 | all_AP = []
109 | num_valid_q = 0. # number of valid query
110 |
111 | for q_idx in range(num_q):
112 | # get query pid and camid
113 | q_pid = q_pids[q_idx]
114 | q_camid = q_camids[q_idx]
115 |
116 | # remove gallery samples that have the same pid and camid with query
117 | order = indices[q_idx]
118 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
119 | keep = np.invert(remove)
120 |
121 | # compute cmc curve
122 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
123 | if not np.any(raw_cmc):
124 | # this condition is true when query identity does not appear in gallery
125 | continue
126 |
127 | cmc = raw_cmc.cumsum()
128 | cmc[cmc > 1] = 1
129 |
130 | all_cmc.append(cmc[:max_rank])
131 | num_valid_q += 1.
132 |
133 | # compute average precision
134 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
135 | num_rel = raw_cmc.sum()
136 | tmp_cmc = raw_cmc.cumsum()
137 | tmp_cmc = [x / (i+1.)
for i, x in enumerate(tmp_cmc)] 138 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 139 | AP = tmp_cmc.sum() / num_rel 140 | all_AP.append(AP) 141 | 142 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 143 | 144 | all_cmc = np.asarray(all_cmc).astype(np.float32) 145 | all_cmc = all_cmc.sum(0) / num_valid_q 146 | mAP = np.mean(all_AP) 147 | 148 | return all_cmc, mAP 149 | 150 | 151 | def evaluate_py(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03): 152 | if use_metric_cuhk03: 153 | return eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) 154 | else: 155 | return eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) 156 | 157 | 158 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, use_metric_cuhk03=False, use_cython=True): 159 | if use_cython and IS_CYTHON_AVAI: 160 | return evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03) 161 | else: 162 | return evaluate_py(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03) -------------------------------------------------------------------------------- /torchreid/reid_dataset/import_Market1501Attribute.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .import_Market1501 import * 3 | import scipy.io 4 | 5 | 6 | def import_Market1501Attribute(dataset_dir): 7 | dataset_name = 'Market-1501/attribute' 8 | train,query,test = import_Market1501(dataset_dir) 9 | if not os.path.exists(os.path.join(dataset_dir,dataset_name)): 10 | print('Please Download the Market1501Attribute Dataset') 11 | train_label=['age', 12 | 'backpack', 13 | 'bag', 14 | 'handbag', 15 | 'downblack', 16 | 'downblue', 17 | 'downbrown', 18 | 'downgray', 19 | 'downgreen', 20 | 'downpink', 21 | 'downpurple', 22 | 'downwhite', 23 | 'downyellow', 24 | 'upblack', 25 | 'upblue', 26 | 'upgreen', 27 | 'upgray', 28 | 'uppurple', 29 | 'upred', 30 | 'upwhite', 31 | 'upyellow', 32 | 'clothes', 33 | 'down', 34 | 'up', 35 | 'hair', 36 | 'hat', 37 | 'gender'] 38 | 39 | test_label=['age', 40 | 'backpack', 41 | 'bag', 42 | 'handbag', 43 | 'clothes', 44 | 'down', 45 | 'up', 46 | 'hair', 47 | 'hat', 48 | 'gender', 49 | 'upblack', 50 | 'upwhite', 51 | 'upred', 52 | 'uppurple', 53 | 'upyellow', 54 | 'upgray', 55 | 'upblue', 56 | 'upgreen', 57 | 'downblack', 58 | 'downwhite', 59 | 'downpink', 60 | 'downpurple', 61 | 'downyellow', 62 | 'downgray', 63 | 'downblue', 64 | 'downgreen', 65 | 'downbrown' 66 | ] 67 | 68 | train_person_id = [] 69 | for personid in train: 70 | train_person_id.append(personid) 71 | train_person_id.sort(key=int) 72 | 73 | test_person_id = [] 74 | for personid in test: 75 | test_person_id.append(personid) 76 | test_person_id.sort(key=int) 77 | test_person_id.remove('-1') 78 | test_person_id.remove('0000') 79 | 80 | f = scipy.io.loadmat(os.path.join(dataset_dir,dataset_name,'market_attribute.mat')) 81 | 82 | test_attribute = {} 83 | train_attribute = {} 84 | for test_train in range(len(f['market_attribute'][0][0])): 85 | if test_train == 0: 86 | id_list_name = 'test_person_id' 87 | group_name = 'test_attribute' 88 | else: 89 | id_list_name = 'train_person_id' 90 | group_name = 'train_attribute' 91 | for attribute_id in range(len(f['market_attribute'][0][0][test_train][0][0])): 92 | for person_id in range(len(f['market_attribute'][0][0][test_train][0][0][attribute_id][0])): 93 | id = locals()[id_list_name][person_id] 94 | if id not in locals()[group_name]: 95 | 
locals()[group_name][id]=[]
96 | locals()[group_name][id].append(f['market_attribute'][0][0][test_train][0][0][attribute_id][0][person_id])
97 |
98 | unified_train_atr = {}
99 | for k,v in train_attribute.items():
100 | temp_atr = [0]*len(test_label)
101 | for i in range(len(test_label)):
102 | temp_atr[i]=v[train_label.index(test_label[i])]
103 | unified_train_atr[k] = temp_atr
104 |
105 | return unified_train_atr, test_attribute, test_label
106 |
107 |
108 | def _age_to_onehot(attr_dict):
109 | # Shift all labels from {1, 2} to {0, 1}, then replace the leading age label
110 | # (0..3 after the shift) with a 4-dim one-hot prefix: young/teenager/adult/old.
111 | for id in attr_dict:
112 | attr_dict[id][:] = [x - 1 for x in attr_dict[id]]
113 | age = attr_dict[id].pop(0)
114 | onehot = [0, 0, 0, 0]
115 | onehot[age] = 1
116 | attr_dict[id][:0] = onehot
117 |
118 |
119 | def import_Market1501Attribute_binary(dataset_dir):
120 | train_market_attr, test_market_attr, label = import_Market1501Attribute(dataset_dir)
121 |
122 | _age_to_onehot(train_market_attr)
123 | _age_to_onehot(test_market_attr)
124 |
125 | label.pop(0)
126 | label.insert(0,'young')
127 | label.insert(1,'teenager')
128 | label.insert(2,'adult')
129 | label.insert(3,'old')
130 |
131 | return train_market_attr, test_market_attr, label
--------------------------------------------------------------------------------
/torchreid/data_manager.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 |
4 | from torch.utils.data import DataLoader
5 | import os
6 | #from .datasets import init_imgreid_dataset, init_vidreid_dataset
7 | from .transforms import build_transforms
8 | from .samplers import RandomIdentitySampler
9 | from .folder import Train_Dataset, Test_Dataset, Base_Dataset
10 | from math import ceil
11 |
12 | class BaseDataManager(object):
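"""Shared interface for data managers: dataset statistics (pids, cams,
semantic ids, attributes) plus accessors for the train/test loaders."""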
13 |
14 | @property
15 | def num_train_pids(self):
16 | return self._num_train_pids
17 |
18 | @property
19 | def num_train_cams(self):
20 | return self._num_train_cams
21 |
22 | @property
23 | def num_train_sems(self):
24 | return self._num_train_sems
25 |
26 | @property
27 | def num_train_attrs(self):
28 | return self._num_train_attributes
29 |
30 |
31 | def return_dataloaders(self):
32 | """
33 | Return trainloader and testloader dictionary
34 | """
35 | return self.trainloader, self.testloader_dict
36 |
37 | def return_testdataset_by_name(self, name):
38 | """
39 | Return query and gallery, each containing a list of (img_path, pid, camid).
40 | """
41 | return self.testdataset_dict[name]['query'], self.testdataset_dict[name]['gallery']
42 |
52 |
53 | class ImageDataManager(BaseDataManager):
54 | """
55 | Image-ReID data manager
56 | """
57 |
58 | def __init__(self,
59 | use_gpu,
60 | source_names,
61 | target_names,
62 | root,
63 | num_train,
64 | seed,
65 | split_id=0,
66 | height=256,
67 | width=128,
68 | train_batch_size=128,
69 | test_batch_size=100,
70 | workers=4,
71 | train_sampler='',
72 | num_instances=4, # number of instances per identity (for RandomIdentitySampler)
73 | ):
74 | super(ImageDataManager, self).__init__()
75 | self.use_gpu = use_gpu
76 | self.source_names = source_names
77 | self.target_names = target_names
78 | self.root = root
79 | self.split_id = split_id
80 | self.height = height
81 | self.width = width
82 | self.train_batch_size = train_batch_size
83 | self.test_batch_size = test_batch_size
84 | self.workers = workers
85 | self.train_sampler = train_sampler
86 | self.num_instances = num_instances
87 | self.pin_memory = True if self.use_gpu else False
88 |
89 | # Build train and test transform functions
90 | transform_train = build_transforms(self.height, self.width, is_train=True)
91 | transform_test = build_transforms(self.height, self.width, is_train=False)
92 |
93 | print("=> Initializing TRAIN (source) datasets")
94 |
95 | self._num_train_pids = 0
96 | self._num_train_cams = 0
97 |
98 | name = self.source_names[0]
99 | base_dataset = Base_Dataset(self.root, dataset_name=name, num_train=num_train, seed=seed)
100 | train_dataset = Train_Dataset(base_dataset, train_val='train')
101 | self._num_train_pids = train_dataset.num_ids
102 | self._num_train_cams = train_dataset.num_cam
103 | self._num_train_images = train_dataset.num_images
104 | self._num_train_attributes = train_dataset.num_labels
105 | self._num_train_sems=base_dataset.num_sems
106 | self.total_sem = base_dataset.total_sem
107 | self.label = base_dataset.label
108 | self.w = base_dataset.w
109 | self.average = base_dataset.average
110 | self.trainloader = DataLoader(
111 | train_dataset,
112 | batch_size=self.train_batch_size, shuffle=True, num_workers=self.workers,
113 | pin_memory=self.pin_memory, drop_last=True
114 | )
115 |
116 |
117 |
118 | print("=> Initializing TEST (target) datasets")
119 | self.testloader_dict = {name: {'query': None, 'gallery': None} for name in self.target_names}
120 | self.testdataset_dict = {name: {'query': None, 'gallery': None}
for name in self.target_names}
121 |
122 | for name in self.target_names:
123 |
124 | test_dataset = Test_Dataset(base_dataset, query_gallery='all')
125 | query_dataset = Test_Dataset(base_dataset, query_gallery='query')
126 | self.testloader_dict[name]['query'] = DataLoader(
127 | query_dataset,
128 | batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers,
129 | pin_memory=self.pin_memory, drop_last=False
130 | )
131 |
132 | self.testloader_dict[name]['gallery'] = DataLoader(
133 | test_dataset,
134 | batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers,
135 | pin_memory=self.pin_memory, drop_last=False
136 | )
137 | self._num_test_images = test_dataset.test_num
138 | self._num_test_sems = base_dataset.test_sem_num
139 | self._num_query_images = train_dataset.num_query
140 | self._num_query_sems = base_dataset.num_query_sem
141 |
142 | # kept so that the test datasets can be visualised later (see utils/reidtools.py)
146 | self.testdataset_dict[name]['query'] = query_dataset
147 | self.testdataset_dict[name]['gallery'] = test_dataset
148 |
149 |
150 | print("\n")
151 | print(" **************** Summary ****************")
152 | print(" dataset name : {}".format(self.source_names))
153 | print(" # train images : {}".format(self._num_train_images))
154 | print(" # train attributes : {}".format(self._num_train_attributes))
155 | print(" # train categories : {}".format(self._num_train_sems))
156 | print(" # gallery images : {}".format(self._num_test_images))
157 | print(" # query/gallery categories : {}".format(self._num_query_sems))
158 | print(" categories in total : {}".format(self.total_sem))
159 | print(" # batch size : {}".format(self.train_batch_size))
160 | print(" *****************************************")
161 | print("\n")
162 |
--------------------------------------------------------------------------------
/torchreid/models/mudeep.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 |
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 | import torchvision
8 |
9 |
10 | __all__ = ['MuDeep']
11 |
12 |
13 | class ConvBlock(nn.Module):
14 | """Basic convolutional block:
15 | convolution + batch normalization + relu.
16 |
17 | Args (following http://pytorch.org/docs/master/nn.html#torch.nn.Conv2d):
18 | in_c (int): number of input channels.
19 | out_c (int): number of output channels.
20 | k (int or tuple): kernel size.
21 | s (int or tuple): stride.
22 | p (int or tuple): padding.
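
Example (shapes are illustrative):
>>> block = ConvBlock(3, 48, k=3, s=1, p=1)
>>> out = block(torch.randn(1, 3, 256, 128))  # -> (1, 48, 256, 128)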
23 | """ 24 | def __init__(self, in_c, out_c, k, s, p): 25 | super(ConvBlock, self).__init__() 26 | self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p) 27 | self.bn = nn.BatchNorm2d(out_c) 28 | 29 | def forward(self, x): 30 | return F.relu(self.bn(self.conv(x))) 31 | 32 | 33 | class ConvLayers(nn.Module): 34 | """Preprocessing layers.""" 35 | def __init__(self): 36 | super(ConvLayers, self).__init__() 37 | self.conv1 = ConvBlock(3, 48, k=3, s=1, p=1) 38 | self.conv2 = ConvBlock(48, 96, k=3, s=1, p=1) 39 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 40 | 41 | def forward(self, x): 42 | x = self.conv1(x) 43 | x = self.conv2(x) 44 | x = self.maxpool(x) 45 | return x 46 | 47 | 48 | class MultiScaleA(nn.Module): 49 | """Multi-scale stream layer A (Sec.3.1)""" 50 | def __init__(self): 51 | super(MultiScaleA, self).__init__() 52 | self.stream1 = nn.Sequential( 53 | ConvBlock(96, 96, k=1, s=1, p=0), 54 | ConvBlock(96, 24, k=3, s=1, p=1), 55 | ) 56 | self.stream2 = nn.Sequential( 57 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 58 | ConvBlock(96, 24, k=1, s=1, p=0), 59 | ) 60 | self.stream3 = ConvBlock(96, 24, k=1, s=1, p=0) 61 | self.stream4 = nn.Sequential( 62 | ConvBlock(96, 16, k=1, s=1, p=0), 63 | ConvBlock(16, 24, k=3, s=1, p=1), 64 | ConvBlock(24, 24, k=3, s=1, p=1), 65 | ) 66 | 67 | def forward(self, x): 68 | s1 = self.stream1(x) 69 | s2 = self.stream2(x) 70 | s3 = self.stream3(x) 71 | s4 = self.stream4(x) 72 | y = torch.cat([s1, s2, s3, s4], dim=1) 73 | return y 74 | 75 | 76 | class Reduction(nn.Module): 77 | """Reduction layer (Sec.3.1)""" 78 | def __init__(self): 79 | super(Reduction, self).__init__() 80 | self.stream1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 81 | self.stream2 = ConvBlock(96, 96, k=3, s=2, p=1) 82 | self.stream3 = nn.Sequential( 83 | ConvBlock(96, 48, k=1, s=1, p=0), 84 | ConvBlock(48, 56, k=3, s=1, p=1), 85 | ConvBlock(56, 64, k=3, s=2, p=1), 86 | ) 87 | 88 | def forward(self, x): 89 | s1 = self.stream1(x) 90 | s2 = self.stream2(x) 91 | s3 = self.stream3(x) 92 | y = torch.cat([s1, s2, s3], dim=1) 93 | return y 94 | 95 | 96 | class MultiScaleB(nn.Module): 97 | """Multi-scale stream layer B (Sec.3.1)""" 98 | def __init__(self): 99 | super(MultiScaleB, self).__init__() 100 | self.stream1 = nn.Sequential( 101 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 102 | ConvBlock(256, 256, k=1, s=1, p=0), 103 | ) 104 | self.stream2 = nn.Sequential( 105 | ConvBlock(256, 64, k=1, s=1, p=0), 106 | ConvBlock(64, 128, k=(1, 3), s=1, p=(0, 1)), 107 | ConvBlock(128, 256, k=(3, 1), s=1, p=(1, 0)), 108 | ) 109 | self.stream3 = ConvBlock(256, 256, k=1, s=1, p=0) 110 | self.stream4 = nn.Sequential( 111 | ConvBlock(256, 64, k=1, s=1, p=0), 112 | ConvBlock(64, 64, k=(1, 3), s=1, p=(0, 1)), 113 | ConvBlock(64, 128, k=(3, 1), s=1, p=(1, 0)), 114 | ConvBlock(128, 128, k=(1, 3), s=1, p=(0, 1)), 115 | ConvBlock(128, 256, k=(3, 1), s=1, p=(1, 0)), 116 | ) 117 | 118 | def forward(self, x): 119 | s1 = self.stream1(x) 120 | s2 = self.stream2(x) 121 | s3 = self.stream3(x) 122 | s4 = self.stream4(x) 123 | return s1, s2, s3, s4 124 | 125 | 126 | class Fusion(nn.Module): 127 | """Saliency-based learning fusion layer (Sec.3.2)""" 128 | def __init__(self): 129 | super(Fusion, self).__init__() 130 | self.a1 = nn.Parameter(torch.rand(1, 256, 1, 1)) 131 | self.a2 = nn.Parameter(torch.rand(1, 256, 1, 1)) 132 | self.a3 = nn.Parameter(torch.rand(1, 256, 1, 1)) 133 | self.a4 = nn.Parameter(torch.rand(1, 256, 1, 1)) 134 | 135 | # We add an average pooling layer to reduce 
the spatial dimension
136 | # of feature maps, which differs from the original paper.
137 | self.avgpool = nn.AvgPool2d(kernel_size=4, stride=4, padding=0)
138 |
139 | def forward(self, x1, x2, x3, x4):
140 | s1 = self.a1.expand_as(x1) * x1
141 | s2 = self.a2.expand_as(x2) * x2
142 | s3 = self.a3.expand_as(x3) * x3
143 | s4 = self.a4.expand_as(x4) * x4
144 | y = self.avgpool(s1 + s2 + s3 + s4)
145 | return y
146 |
147 |
148 | class MuDeep(nn.Module):
149 | """
150 | Multiscale deep neural network.
151 |
152 | Reference:
153 | Qian et al. Multi-scale Deep Learning Architectures for Person Re-identification. ICCV 2017.
154 | """
155 | def __init__(self, num_classes, loss={'xent'}, **kwargs):
156 | super(MuDeep, self).__init__()
157 | self.loss = loss
158 |
159 | self.block1 = ConvLayers()
160 | self.block2 = MultiScaleA()
161 | self.block3 = Reduction()
162 | self.block4 = MultiScaleB()
163 | self.block5 = Fusion()
164 |
165 | # Due to this fully connected layer, input image has to be fixed
166 | # in shape, i.e. (3, 256, 128), such that the last convolutional feature
167 | # maps are of shape (256, 16, 8). If input shape is changed,
168 | # the input dimension of this layer has to be changed accordingly.
169 | self.fc = nn.Sequential(
170 | nn.Linear(256*16*8, 4096),
171 | nn.BatchNorm1d(4096),
172 | nn.ReLU(),
173 | )
174 | self.classifier = nn.Linear(4096, num_classes)
175 | self.feat_dim = 4096
176 |
177 | def featuremaps(self, x):
178 | x = self.block1(x)
179 | x = self.block2(x)
180 | x = self.block3(x)
181 | x = self.block4(x)
182 | x = self.block5(*x)
183 | return x
184 |
185 | def forward(self, x):
186 | x = self.featuremaps(x)
187 | x = x.view(x.size(0), -1)
188 | x = self.fc(x)
189 | y = self.classifier(x)
190 |
191 | if self.loss == {'xent'}:
192 | return y
193 | elif self.loss == {'xent', 'htri'}:
194 | return y, x
195 | else:
196 | raise KeyError("Unsupported loss: {}".format(self.loss))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Symbiotic Adversarial Learning for Attribute-based Person Search
2 | PyTorch implementation of [Symbiotic Adversarial Learning for Attribute-based Person Search](https://arxiv.org/abs/2007.09609) (ECCV 2020).
3 |
4 | ## Update
5 | 01/08/21: Updated [requirements.txt](requirements.txt): Pillow is pinned to version 8.2.0 or later because of known issues in earlier Pillow releases.
6 |
7 | 28/02/21: Updated [train.py](train.py) and [scripts](scripts). Uploaded [trained models](https://hkustconnect-my.sharepoint.com/:f:/g/personal/ycaoaf_connect_ust_hk/EqOH1p24IvtCid8o914_ai8BrP1SuXZT56JQbEhoVP_IxA?e=hN5IWI).
8 |
9 | ## Problem setting
10 |
11 | (figure: problem setting, see imgs/setting.png)