├── .gitignore
├── LICENSE
├── README.md
├── config.py
├── datasets
│   ├── __init__.py
│   ├── data_loader.py
│   ├── data_manager.py
│   └── samplers.py
├── main_reid.py
├── models
│   ├── __init__.py
│   ├── networks.py
│   └── resnet.py
├── requirements.txt
├── trainers
│   ├── __init__.py
│   ├── evaluator.py
│   ├── re_ranking.py
│   └── trainer.py
└── utils
    ├── DistWeightDevianceLoss.py
    ├── LiftedStructure.py
    ├── __init__.py
    ├── loss.py
    ├── meters.py
    ├── random_erasing.py
    ├── serialization.py
    ├── transforms.py
    └── validation_metrics.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | **__pycache__
3 | .DS_Store
4 | data/
5 | pytorch-ckpt/
6 | .vscode/
7 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 Dai Zuozhuo
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Batch DropBlock Network for Person Re-identification and Beyond
2 | Official source code of the paper https://arxiv.org/abs/1811.07130
3 | 
4 | ## Update on 2019.3.15
5 | Update CUHK03 results.
6 | 
7 | ## Update on 2019.1.29
8 | Training scripts are released. The best Market1501 result is 95.3%! Please look at the training section of README.md.
9 | 
10 | ## Update on 2019.1.23
11 | The In-Shop Clothes Retrieval dataset and pretrained model are released! The rank-1 result is 89.5, which is a little bit higher than reported in the paper.
12 | 
13 | ## This paper is accepted by ICCV 2019. Please cite it if you use this code in your research.
14 | 
15 | ```
16 | @article{dai2018batch,
17 |   title={Batch DropBlock Network for Person Re-identification and Beyond},
18 |   author={Dai, Zuozhuo and Chen, Mingqiang and Gu, Xiaodong and Zhu, Siyu and Tan, Ping},
19 |   journal={arXiv preprint arXiv:1811.07130},
20 |   year={2018}
21 | }
22 | ```
23 | 
24 | ## Setup running environment
25 | This project requires python3, cython, torch, torchvision, scikit-learn, tensorboardX and fire.
26 | The baseline source code is borrowed from https://github.com/L1aoXingyu/reid_baseline.
27 | 
28 | ## Prepare dataset
29 | 
30 | Create a directory to store reid datasets under this repo via
31 | ```bash
32 | cd reid
33 | mkdir data
34 | ```
35 | 
36 | For the Market1501 dataset:
37 | 1. Download the Market1501 dataset to `data/` from http://www.liangzheng.org/Project/project_reid.html
38 | 2. Extract the dataset and rename it to `market1501`. The data structure should look like:
39 | ```
40 | market1501/
41 |     bounding_box_test/
42 |     bounding_box_train/
43 |     query/
44 | ```
45 | 
46 | For the CUHK03 dataset:
47 | 1. Download the CUHK03-NP dataset from https://github.com/zhunzhong07/person-re-ranking/tree/master/CUHK03-NP
48 | 2. Extract the dataset and rename the folders inside it to `cuhk-detect` and `cuhk-label`.
49 | 
50 | For the DukeMTMC-reID dataset, download it from https://github.com/layumi/DukeMTMC-reID_evaluation
51 | 
52 | For the In-Shop Clothes dataset:
53 | 1. Download the clothes dataset from http://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/bfe_models/clothes.tar
54 | 2. Extract the dataset and put it in the `data/` folder.
55 | 
56 | ## Results
57 | 
58 | Dataset | CUHK03-Label | CUHK03-Detect | DukeMTMC re-ID | Market1501 | In-Shop Clothes |
59 | --------|--------------|---------------|----------------|------------|-----------------|
60 | Rank-1  | 79.4         | 76.4          | 88.9           | 95.3       | 89.5            |
61 | mAP     | 76.7         | 73.5          | 75.9           | 86.2       | 72.3            |
62 | model   | [aliyun](http://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/bfe_models/cuhk-label-794.pth.tar) | [aliyun](http://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/bfe_models/cuhk-detect-764.pth.tar) | [aliyun](http://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/bfe_models/duke_887.pth.tar) | [aliyun](http://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/bfe_models/market_953.pth.tar) | [aliyun](http://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/bfe_models/clothes_895.pth.tar) |
63 | 
64 | You can download the pre-trained models from the table above and evaluate them on the person re-ID datasets.
65 | For example, to evaluate on the CUHK03-Label dataset, download the model to the './pytorch-ckpt/cuhk_label_bfe' directory and run the following commands.
66 | 
67 | ### Evaluate Market1501
68 | ```bash
69 | python3 main_reid.py train --save_dir='./pytorch-ckpt/market_bfe' --model_name=bfe --train_batch=32 --test_batch=32 --dataset=market1501 --pretrained_model='./pytorch-ckpt/market_bfe/944.pth.tar' --evaluate
70 | ```
71 | ### Evaluate CUHK03-Label
72 | ```bash
73 | python3 main_reid.py train --save_dir='./pytorch-ckpt/cuhk_label_bfe' --model_name=bfe --train_batch=32 --test_batch=32 --dataset=cuhk-label --pretrained_model='./pytorch-ckpt/cuhk_label_bfe/750.pth.tar' --evaluate
74 | ```
75 | ### Evaluate In-Shop Clothes
76 | ```bash
77 | python main_reid.py train --save_dir='./pytorch-ckpt/clothes_bfe' --model_name=bfe --pretrained_model='./pytorch-ckpt/clothes_bfe/clothes_895.pth.tar' --test_batch=32 --dataset=clothes --evaluate
78 | ```
79 | 
80 | ## Training
81 | 
82 | ### Training Market1501
83 | ```bash
84 | python main_reid.py train --save_dir='./pytorch-ckpt/market-bfe' --max_epoch=400 --eval_step=30 --dataset=market1501 --test_batch=128 --train_batch=128 --optim=adam --adjust_lr
85 | ```
86 | This training command was tested on 4 GTX 1080 GPUs. Here is the [training log](http://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/bfe_models/market_953.txt). You should get a result around 95%.
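## About the Batch DropBlock layer

The part branch of the BFE network gets its name from the `BatchDrop` module in `models/networks.py`: during training it zeroes out the *same* randomly chosen region of the feature map for every sample in a batch, so the metric-learning losses must compare features of the remaining, shared regions. Below is a condensed, self-contained restatement of that module. The ratios match the person re-ID setting in `main_reid.py` (`BFE(num_classes, 1.0, 0.33)` drops a full-width strip covering a third of the height); the feature-map shape in the demo is illustrative only.

```python
import random

import torch
from torch import nn


class BatchDrop(nn.Module):
    """Zero out one randomly placed block, shared by the whole batch."""

    def __init__(self, h_ratio, w_ratio):
        super(BatchDrop, self).__init__()
        self.h_ratio = h_ratio  # fraction of feature-map height to drop
        self.w_ratio = w_ratio  # fraction of feature-map width to drop

    def forward(self, x):
        if self.training:  # identity mapping at test time
            h, w = x.size()[-2:]
            rh = round(self.h_ratio * h)
            rw = round(self.w_ratio * w)
            sx = random.randint(0, h - rh)  # top edge of the dropped block
            sy = random.randint(0, w - rw)  # left edge of the dropped block
            mask = x.new_ones(x.size())
            mask[:, :, sx:sx + rh, sy:sy + rw] = 0  # same block for every sample
            x = x * mask
        return x


drop = BatchDrop(h_ratio=0.33, w_ratio=1.0)
drop.train()
feat = torch.randn(4, 2048, 24, 8)  # (N, C, H, W); shape is illustrative
print(drop(feat).shape)  # torch.Size([4, 2048, 24, 8])
```

Because the erased block is shared across the batch (unlike random erasing, which is sampled per image), every sample in the batch is missing the same region, which pushes the part branch to learn discriminative features from the body parts that remain.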
87 | 
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import warnings
3 | import numpy as np
4 | 
5 | 
6 | class DefaultConfig(object):
7 |     seed = 0
8 | 
9 |     # dataset options
10 |     dataset = 'market1501'
11 |     datatype = 'person'
12 |     mode = 'retrieval'
13 |     # optimization options
14 |     loss = 'triplet'
15 |     optim = 'adam'
16 |     max_epoch = 60
17 |     train_batch = 32
18 |     test_batch = 32
19 |     adjust_lr = False
20 |     lr = 0.0001
21 |     gamma = 0.1
22 |     weight_decay = 5e-4
23 |     momentum = 0.9
24 |     random_crop = False
25 |     margin = None
26 |     num_instances = 4
27 |     num_gpu = 1
28 |     evaluate = False
29 |     savefig = None
30 |     re_ranking = False
31 | 
32 |     # model options
33 |     model_name = 'bfe'  # triplet, softmax_triplet, bfe, ide
34 |     last_stride = 1
35 |     pretrained_model = None
36 | 
37 |     # miscs
38 |     print_freq = 30
39 |     eval_step = 50
40 |     save_dir = './pytorch-ckpt/market'
41 |     workers = 10
42 |     start_epoch = 0
43 |     best_rank = -np.inf
44 | 
45 |     def _parse(self, kwargs):
46 |         for k, v in kwargs.items():
47 |             if not hasattr(self, k):
48 |                 warnings.warn("Warning: opt has no attribute %s" % k)
49 |             setattr(self, k, v)
50 |         if 'cls' in self.dataset:
51 |             self.mode = 'class'
52 |         if 'market' in self.dataset or 'cuhk' in self.dataset or 'duke' in self.dataset:
53 |             self.datatype = 'person'
54 |         elif 'cub' in self.dataset:
55 |             self.datatype = 'cub'
56 |         elif 'car' in self.dataset:
57 |             self.datatype = 'car'
58 |         elif 'clothes' in self.dataset:
59 |             self.datatype = 'clothes'
60 |         elif 'product' in self.dataset:
61 |             self.datatype = 'product'
62 | 
63 |     def _state_dict(self):
64 |         return {k: getattr(self, k) for k, _ in DefaultConfig.__dict__.items()
65 |                 if not k.startswith('_')}
66 | 
67 | opt = DefaultConfig()
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daizuozhuo/batch-dropblock-network/21c99abb8d85cfb29d56fc57d09c1ecdfe6b6be5/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/data_loader.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | 
3 | from PIL import Image
4 | from torch.utils.data import Dataset
5 | 
6 | 
7 | def read_image(img_path):
8 |     """Keep reading the image until it succeeds.
9 |     This can avoid the IOError incurred by a heavy IO process."""
10 |     got_img = False
11 |     while not got_img:
12 |         try:
13 |             img = Image.open(img_path).convert('RGB')
14 |             got_img = True
15 |         except IOError:
16 |             print("IOError incurred when reading '{}'. Will redo. Don't worry. 
Just chill.".format(img_path)) 17 | pass 18 | return img 19 | 20 | class ImageData(Dataset): 21 | def __init__(self, dataset, transform): 22 | self.dataset = dataset 23 | self.transform = transform 24 | 25 | def __getitem__(self, item): 26 | img, pid, camid = self.dataset[item] 27 | img = read_image(img) 28 | if self.transform is not None: 29 | img = self.transform(img) 30 | return img, pid, camid 31 | 32 | def __len__(self): 33 | return len(self.dataset) 34 | -------------------------------------------------------------------------------- /datasets/data_manager.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import glob 4 | import re 5 | from os import path as osp 6 | import os 7 | 8 | """Dataset classes""" 9 | 10 | 11 | class Market1501(object): 12 | """ 13 | Market1501 14 | Reference: 15 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 16 | URL: http://www.liangzheng.org/Project/project_reid.html 17 | 18 | Dataset statistics: 19 | # identities: 1501 (+1 for background) 20 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 21 | """ 22 | def __init__(self, dataset_dir, mode, root='data'): 23 | self.dataset_dir = dataset_dir 24 | self.dataset_dir = osp.join(root, self.dataset_dir) 25 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 26 | self.query_dir = osp.join(self.dataset_dir, 'query') 27 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 28 | 29 | self._check_before_run() 30 | train_relabel = (mode == 'retrieval') 31 | train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=train_relabel) 32 | query, num_query_pids, num_query_imgs = self._process_dir(self.query_dir, relabel=False) 33 | gallery, num_gallery_pids, num_gallery_imgs = self._process_dir(self.gallery_dir, relabel=False) 34 | num_total_pids = num_train_pids + num_query_pids 35 | num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs 36 | 37 | print("=> Market1501 loaded") 38 | print("Dataset statistics:") 39 | print(" ------------------------------") 40 | print(" subset | # ids | # images") 41 | print(" ------------------------------") 42 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs)) 43 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs)) 44 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs)) 45 | print(" ------------------------------") 46 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs)) 47 | print(" ------------------------------") 48 | 49 | self.train = train 50 | self.query = query 51 | self.gallery = gallery 52 | 53 | self.num_train_pids = num_train_pids 54 | self.num_query_pids = num_query_pids 55 | self.num_gallery_pids = num_gallery_pids 56 | 57 | def _check_before_run(self): 58 | """Check if all files are available before going deeper""" 59 | if not osp.exists(self.dataset_dir): 60 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 61 | if not osp.exists(self.train_dir): 62 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 63 | if not osp.exists(self.query_dir): 64 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 65 | if not osp.exists(self.gallery_dir): 66 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 67 | 68 | def _process_dir(self, dir_path, relabel=False): 69 | img_names = os.listdir(dir_path) 70 | img_paths = 
[os.path.join(dir_path, img_name) for img_name in img_names \ 71 | if img_name.endswith('jpg') or img_name.endswith('png')] 72 | pattern = re.compile(r'([-\d]+)_c([-\d]+)') 73 | 74 | pid_container = set() 75 | for img_path in img_paths: 76 | pid, _ = map(int, pattern.search(img_path).groups()) 77 | if pid == -1: continue # junk images are just ignored 78 | pid_container.add(pid) 79 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 80 | 81 | dataset = [] 82 | for img_path in img_paths: 83 | pid, camid = map(int, pattern.search(img_path).groups()) 84 | if pid == -1: 85 | continue # junk images are just ignored 86 | #assert 0 <= pid <= 1501 # pid == 0 means background 87 | #assert 1 <= camid <= 6 88 | camid -= 1 # index starts from 0 89 | if relabel: pid = pid2label[pid] 90 | dataset.append((img_path, pid, camid)) 91 | 92 | num_pids = len(pid_container) 93 | num_imgs = len(dataset) 94 | return dataset, num_pids, num_imgs 95 | 96 | def init_dataset(name, mode): 97 | return Market1501(name, mode) 98 | -------------------------------------------------------------------------------- /datasets/samplers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | import torch 7 | import random 8 | from torch.utils.data.sampler import Sampler 9 | 10 | 11 | class RandomIdentitySampler(Sampler): 12 | def __init__(self, data_source, num_instances=4): 13 | self.data_source = data_source 14 | self.num_instances = num_instances 15 | self.index_dic = defaultdict(list) 16 | for index, (_, pid, _) in enumerate(data_source): 17 | self.index_dic[pid].append(index) 18 | self.pids = list(self.index_dic.keys()) 19 | self.num_identities = len(self.pids) 20 | 21 | def __iter__(self): 22 | indices = torch.randperm(self.num_identities) 23 | ret = [] 24 | for i in indices: 25 | pid = self.pids[i] 26 | t = self.index_dic[pid] 27 | replace = False if len(t) >= self.num_instances else True 28 | t = np.random.choice(t, size=self.num_instances, replace=replace) 29 | ret.extend(t) 30 | return iter(ret) 31 | 32 | def __len__(self): 33 | return self.num_identities * self.num_instances 34 | -------------------------------------------------------------------------------- /main_reid.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import os 3 | import sys 4 | from os import path as osp 5 | from pprint import pprint 6 | 7 | import numpy as np 8 | import torch 9 | from tensorboardX import SummaryWriter 10 | from torch import nn 11 | from torch.backends import cudnn 12 | from torch.utils.data import DataLoader 13 | 14 | from config import opt 15 | from datasets import data_manager 16 | from datasets.data_loader import ImageData 17 | from datasets.samplers import RandomIdentitySampler 18 | from models.networks import ResNetBuilder, IDE, Resnet, BFE 19 | from trainers.evaluator import ResNetEvaluator 20 | from trainers.trainer import cls_tripletTrainer 21 | from utils.loss import CrossEntropyLabelSmooth, TripletLoss, Margin 22 | from utils.LiftedStructure import LiftedStructureLoss 23 | from utils.DistWeightDevianceLoss import DistWeightBinDevianceLoss 24 | from utils.serialization import Logger, save_checkpoint 25 | from utils.transforms import TestTransform, TrainTransform 26 | 27 | 28 | def train(**kwargs): 29 | opt._parse(kwargs) 30 | 31 | # set random seed and cudnn benchmark 32 | torch.manual_seed(opt.seed) 
33 | os.makedirs(opt.save_dir, exist_ok=True) 34 | use_gpu = torch.cuda.is_available() 35 | sys.stdout = Logger(osp.join(opt.save_dir, 'log_train.txt')) 36 | 37 | print('=========user config==========') 38 | pprint(opt._state_dict()) 39 | print('============end===============') 40 | 41 | if use_gpu: 42 | print('currently using GPU') 43 | cudnn.benchmark = True 44 | torch.cuda.manual_seed_all(opt.seed) 45 | else: 46 | print('currently using cpu') 47 | 48 | print('initializing dataset {}'.format(opt.dataset)) 49 | dataset = data_manager.init_dataset(name=opt.dataset, mode=opt.mode) 50 | 51 | pin_memory = True if use_gpu else False 52 | 53 | summary_writer = SummaryWriter(osp.join(opt.save_dir, 'tensorboard_log')) 54 | 55 | trainloader = DataLoader( 56 | ImageData(dataset.train, TrainTransform(opt.datatype)), 57 | sampler=RandomIdentitySampler(dataset.train, opt.num_instances), 58 | batch_size=opt.train_batch, num_workers=opt.workers, 59 | pin_memory=pin_memory, drop_last=True 60 | ) 61 | 62 | queryloader = DataLoader( 63 | ImageData(dataset.query, TestTransform(opt.datatype)), 64 | batch_size=opt.test_batch, num_workers=opt.workers, 65 | pin_memory=pin_memory 66 | ) 67 | 68 | galleryloader = DataLoader( 69 | ImageData(dataset.gallery, TestTransform(opt.datatype)), 70 | batch_size=opt.test_batch, num_workers=opt.workers, 71 | pin_memory=pin_memory 72 | ) 73 | queryFliploader = DataLoader( 74 | ImageData(dataset.query, TestTransform(opt.datatype, True)), 75 | batch_size=opt.test_batch, num_workers=opt.workers, 76 | pin_memory=pin_memory 77 | ) 78 | 79 | galleryFliploader = DataLoader( 80 | ImageData(dataset.gallery, TestTransform(opt.datatype, True)), 81 | batch_size=opt.test_batch, num_workers=opt.workers, 82 | pin_memory=pin_memory 83 | ) 84 | 85 | print('initializing model ...') 86 | if opt.model_name == 'softmax' or opt.model_name == 'softmax_triplet': 87 | model = ResNetBuilder(dataset.num_train_pids, 1, True) 88 | elif opt.model_name == 'triplet': 89 | model = ResNetBuilder(None, 1, True) 90 | elif opt.model_name == 'bfe': 91 | if opt.datatype == "person": 92 | model = BFE(dataset.num_train_pids, 1.0, 0.33) 93 | else: 94 | model = BFE(dataset.num_train_pids, 0.5, 0.5) 95 | elif opt.model_name == 'ide': 96 | model = IDE(dataset.num_train_pids) 97 | elif opt.model_name == 'resnet': 98 | model = Resnet(dataset.num_train_pids) 99 | 100 | optim_policy = model.get_optim_policy() 101 | 102 | if opt.pretrained_model: 103 | state_dict = torch.load(opt.pretrained_model)['state_dict'] 104 | #state_dict = {k: v for k, v in state_dict.items() \ 105 | # if not ('reduction' in k or 'softmax' in k)} 106 | model.load_state_dict(state_dict, False) 107 | print('load pretrained model ' + opt.pretrained_model) 108 | print('model size: {:.5f}M'.format(sum(p.numel() for p in model.parameters()) / 1e6)) 109 | 110 | if use_gpu: 111 | model = nn.DataParallel(model).cuda() 112 | reid_evaluator = ResNetEvaluator(model) 113 | 114 | if opt.evaluate: 115 | reid_evaluator.evaluate(queryloader, galleryloader, 116 | queryFliploader, galleryFliploader, re_ranking=opt.re_ranking, savefig=opt.savefig) 117 | return 118 | 119 | #xent_criterion = nn.CrossEntropyLoss() 120 | xent_criterion = CrossEntropyLabelSmooth(dataset.num_train_pids) 121 | 122 | if opt.loss == 'triplet': 123 | embedding_criterion = TripletLoss(opt.margin) 124 | elif opt.loss == 'lifted': 125 | embedding_criterion = LiftedStructureLoss(hard_mining=True) 126 | elif opt.loss == 'weight': 127 | embedding_criterion = Margin() 128 | 129 | def 
criterion(triplet_y, softmax_y, labels):
130 |         losses = [embedding_criterion(output, labels)[0] for output in triplet_y] + \
131 |                  [xent_criterion(output, labels) for output in softmax_y]
132 |         loss = sum(losses)
133 |         return loss
134 | 
135 |     # get optimizer
136 |     if opt.optim == "sgd":
137 |         optimizer = torch.optim.SGD(optim_policy, lr=opt.lr, momentum=0.9, weight_decay=opt.weight_decay)
138 |     else:
139 |         optimizer = torch.optim.Adam(optim_policy, lr=opt.lr, weight_decay=opt.weight_decay)
140 | 
141 | 
142 |     start_epoch = opt.start_epoch
143 |     # get trainer and evaluator
144 |     reid_trainer = cls_tripletTrainer(opt, model, optimizer, criterion, summary_writer)
145 | 
146 |     def adjust_lr(optimizer, ep):
147 |         if ep < 50:
148 |             lr = 1e-4 * (ep // 5 + 1)
149 |         elif ep < 200:
150 |             lr = 1e-3
151 |         elif ep < 300:
152 |             lr = 1e-4
153 |         else:
154 |             lr = 1e-5
155 |         for p in optimizer.param_groups:
156 |             p['lr'] = lr
157 | 
158 |     # start training
159 |     best_rank1 = opt.best_rank
160 |     best_epoch = 0
161 |     for epoch in range(start_epoch, opt.max_epoch):
162 |         if opt.adjust_lr:
163 |             adjust_lr(optimizer, epoch + 1)
164 |         reid_trainer.train(epoch, trainloader)
165 | 
166 |         # evaluate (and save a checkpoint) every eval_step epochs and at the last epoch
167 |         if opt.eval_step > 0 and (epoch + 1) % opt.eval_step == 0 or (epoch + 1) == opt.max_epoch:
168 |             if opt.mode == 'class':
169 |                 rank1 = test(model, queryloader)
170 |             else:
171 |                 rank1 = reid_evaluator.evaluate(queryloader, galleryloader, queryFliploader, galleryFliploader)
172 |             is_best = rank1 > best_rank1
173 |             if is_best:
174 |                 best_rank1 = rank1
175 |                 best_epoch = epoch + 1
176 | 
177 |             if use_gpu:
178 |                 state_dict = model.module.state_dict()
179 |             else:
180 |                 state_dict = model.state_dict()
181 |             save_checkpoint({'state_dict': state_dict, 'epoch': epoch + 1},
182 |                             is_best=is_best, save_dir=opt.save_dir,
183 |                             filename='checkpoint_ep' + str(epoch + 1) + '.pth.tar')
184 | 
185 |     print('Best rank-1 {:.1%}, achieved at epoch {}'.format(best_rank1, best_epoch))
186 | 
187 | def test(model, queryloader):
188 |     model.eval()
189 |     correct = 0
190 |     with torch.no_grad():
191 |         for data, target, _ in queryloader:
192 |             output = model(data).cpu()
193 |             # get the index of the max log-probability
194 |             pred = output.max(1, keepdim=True)[1]
195 |             correct += pred.eq(target.view_as(pred)).sum().item()
196 | 
197 |     rank1 = 100. * correct / len(queryloader.dataset)
198 |     print('\nTest set: Accuracy: {}/{} ({:.2f}%)\n'.format(correct, len(queryloader.dataset), rank1))
199 |     return rank1
200 | 
201 | if __name__ == '__main__':
202 |     import fire
203 |     fire.Fire()
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daizuozhuo/batch-dropblock-network/21c99abb8d85cfb29d56fc57d09c1ecdfe6b6be5/models/__init__.py
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import copy
3 | import itertools
4 | 
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | import torch.utils.model_zoo as model_zoo
9 | import random
10 | from scipy.spatial.distance import cdist
11 | from sklearn.preprocessing import normalize
12 | from torch import nn, optim
13 | from torch.utils.data import dataloader
14 | from torchvision import transforms
15 | from torchvision.models.resnet import Bottleneck, resnet50
16 | from torchvision.transforms import functional
17 | 
18 | from .resnet import ResNet
19 | 
20 | def weights_init_kaiming(m):
21 |     classname = m.__class__.__name__
22 |     if classname.find('Linear') != -1:
23 |         nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
24 |         nn.init.constant_(m.bias, 0.0)
25 |     elif classname.find('Conv') != -1:
26 |         nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
27 |         if m.bias is not None:
28 |             nn.init.constant_(m.bias, 0.0)
29 |     elif classname.find('BatchNorm') != -1:
30 |         if m.affine:
31 |             nn.init.normal_(m.weight, 1.0, 0.02)
32 |             nn.init.constant_(m.bias, 0.0)
33 | 
34 | 
35 | def weights_init_classifier(m):
36 |     classname = m.__class__.__name__
37 |     if classname.find('Linear') != -1:
38 |         nn.init.normal_(m.weight, std=0.001)
39 |         if m.bias is not None:  # check against None: truth-testing a multi-element bias tensor raises an error
40 |             nn.init.constant_(m.bias, 0.0)
41 | 
42 | class SELayer(nn.Module):
43 |     def __init__(self, channel, reduction=16):
44 |         super(SELayer, self).__init__()
45 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)
46 |         self.fc = nn.Sequential(
47 |             nn.Linear(channel, channel // reduction),
48 |             nn.ReLU(inplace=True),
49 |             nn.Linear(channel // reduction, channel),
50 |             nn.Sigmoid()
51 |         )
52 | 
53 |     def forward(self, x):
54 |         b, c, _, _ = x.size()
55 |         y = self.avg_pool(x).view(b, c)
56 |         y = self.fc(y).view(b, c, 1, 1)
57 |         return x * y
58 | 
59 | class BatchDrop(nn.Module):
60 |     def __init__(self, h_ratio, w_ratio):
61 |         super(BatchDrop, self).__init__()
62 |         self.h_ratio = h_ratio
63 |         self.w_ratio = w_ratio
64 | 
65 |     def forward(self, x):
66 |         if self.training:
67 |             h, w = x.size()[-2:]
68 |             rh = round(self.h_ratio * h)
69 |             rw = round(self.w_ratio * w)
70 |             sx = random.randint(0, h - rh)
71 |             sy = random.randint(0, w - rw)
72 |             mask = x.new_ones(x.size())
73 |             mask[:, :, sx:sx + rh, sy:sy + rw] = 0
74 |             x = x * mask
75 |         return x
76 | 
77 | class BatchCrop(nn.Module):
78 |     def __init__(self, ratio):
79 |         super(BatchCrop, self).__init__()
80 |         self.ratio = ratio
81 | 
82 |     def forward(self, x):
83 |         if self.training:
84 |             h, w = x.size()[-2:]
85 |             rw = int(self.ratio * w)
86 |             start = random.randint(0, h - 1)
87 |             if start + rw > h:
88 |                 select = list(range(0, start + rw - h)) + list(range(start, h))
89 |             else:
90 |                 select = list(range(start, start + rw))
91 |             mask = x.new_zeros(x.size())
92 |             mask[:, :, select, :] = 1
93 |             x = x * mask
94 |         return x
95 
| 96 | class ResNetBuilder(nn.Module): 97 | in_planes = 2048 98 | 99 | def __init__(self, num_classes=None, last_stride=1, pretrained=False): 100 | super().__init__() 101 | self.base = ResNet(last_stride) 102 | if pretrained: 103 | model_url = 'https://download.pytorch.org/models/resnet50-19c8e357.pth' 104 | self.base.load_param(model_zoo.load_url(model_url)) 105 | 106 | self.num_classes = num_classes 107 | if num_classes is not None: 108 | self.bottleneck = nn.Sequential( 109 | nn.Linear(self.in_planes, 512), 110 | nn.BatchNorm1d(512), 111 | nn.LeakyReLU(0.1), 112 | nn.Dropout(p=0.5) 113 | ) 114 | self.bottleneck.apply(weights_init_kaiming) 115 | self.classifier = nn.Linear(512, self.num_classes) 116 | self.classifier.apply(weights_init_classifier) 117 | 118 | def forward(self, x): 119 | global_feat = self.base(x) 120 | global_feat = F.avg_pool2d(global_feat, global_feat.shape[2:]) # (b, 2048, 1, 1) 121 | global_feat = global_feat.view(global_feat.shape[0], -1) 122 | if self.training and self.num_classes is not None: 123 | feat = self.bottleneck(global_feat) 124 | cls_score = self.classifier(feat) 125 | return [global_feat], [cls_score] 126 | else: 127 | return global_feat 128 | 129 | def get_optim_policy(self): 130 | base_param_group = self.base.parameters() 131 | if self.num_classes is not None: 132 | add_param_group = itertools.chain(self.bottleneck.parameters(), self.classifier.parameters()) 133 | return [ 134 | {'params': base_param_group}, 135 | {'params': add_param_group} 136 | ] 137 | else: 138 | return [ 139 | {'params': base_param_group} 140 | ] 141 | 142 | class BFE(nn.Module): 143 | def __init__(self, num_classes, width_ratio=0.5, height_ratio=0.5): 144 | super(BFE, self).__init__() 145 | resnet = resnet50(pretrained=True) 146 | self.backbone = nn.Sequential( 147 | resnet.conv1, 148 | resnet.bn1, 149 | resnet.relu, 150 | resnet.maxpool, 151 | resnet.layer1, # res_conv2 152 | resnet.layer2, # res_conv3 153 | resnet.layer3, # res_conv4 154 | ) 155 | self.res_part = nn.Sequential( 156 | Bottleneck(1024, 512, stride=1, downsample=nn.Sequential( 157 | nn.Conv2d(1024, 2048, kernel_size=1, stride=1, bias=False), 158 | nn.BatchNorm2d(2048), 159 | )), 160 | Bottleneck(2048, 512), 161 | Bottleneck(2048, 512), 162 | ) 163 | self.res_part.load_state_dict(resnet.layer4.state_dict()) 164 | reduction = nn.Sequential( 165 | nn.Conv2d(2048, 512, 1), 166 | nn.BatchNorm2d(512), 167 | nn.ReLU() 168 | ) 169 | # global branch 170 | self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1)) 171 | self.global_softmax = nn.Linear(512, num_classes) 172 | self.global_softmax.apply(weights_init_kaiming) 173 | self.global_reduction = copy.deepcopy(reduction) 174 | self.global_reduction.apply(weights_init_kaiming) 175 | 176 | # part branch 177 | self.res_part2 = Bottleneck(2048, 512) 178 | 179 | self.part_maxpool = nn.AdaptiveMaxPool2d((1,1)) 180 | self.batch_crop = BatchDrop(height_ratio, width_ratio) 181 | self.reduction = nn.Sequential( 182 | nn.Linear(2048, 1024, 1), 183 | nn.BatchNorm1d(1024), 184 | nn.ReLU() 185 | ) 186 | self.reduction.apply(weights_init_kaiming) 187 | self.softmax = nn.Linear(1024, num_classes) 188 | self.softmax.apply(weights_init_kaiming) 189 | 190 | def forward(self, x): 191 | """ 192 | :param x: input image tensor of (N, C, H, W) 193 | :return: (prediction, triplet_losses, softmax_losses) 194 | """ 195 | x = self.backbone(x) 196 | x = self.res_part(x) 197 | 198 | predict = [] 199 | triplet_features = [] 200 | softmax_features = [] 201 | 202 | #global branch 203 | glob = 
self.global_avgpool(x) 204 | global_triplet_feature = self.global_reduction(glob).squeeze() 205 | global_softmax_class = self.global_softmax(global_triplet_feature) 206 | softmax_features.append(global_softmax_class) 207 | triplet_features.append(global_triplet_feature) 208 | predict.append(global_triplet_feature) 209 | 210 | #part branch 211 | x = self.res_part2(x) 212 | 213 | x = self.batch_crop(x) 214 | triplet_feature = self.part_maxpool(x).squeeze() 215 | feature = self.reduction(triplet_feature) 216 | softmax_feature = self.softmax(feature) 217 | triplet_features.append(feature) 218 | softmax_features.append(softmax_feature) 219 | predict.append(feature) 220 | 221 | if self.training: 222 | return triplet_features, softmax_features 223 | else: 224 | return torch.cat(predict, 1) 225 | 226 | def get_optim_policy(self): 227 | params = [ 228 | {'params': self.backbone.parameters()}, 229 | {'params': self.res_part.parameters()}, 230 | {'params': self.global_reduction.parameters()}, 231 | {'params': self.global_softmax.parameters()}, 232 | {'params': self.res_part2.parameters()}, 233 | {'params': self.reduction.parameters()}, 234 | {'params': self.softmax.parameters()}, 235 | ] 236 | return params 237 | 238 | class Resnet(nn.Module): 239 | def __init__(self, num_classes, resnet=None): 240 | super(Resnet, self).__init__() 241 | if not resnet: 242 | resnet = resnet50(pretrained=True) 243 | self.backbone = nn.Sequential( 244 | resnet.conv1, 245 | resnet.bn1, 246 | resnet.relu, 247 | resnet.maxpool, 248 | resnet.layer1, # res_conv2 249 | resnet.layer2, # res_conv3 250 | resnet.layer3, # res_conv4 251 | resnet.layer4 252 | ) 253 | self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1)) 254 | self.softmax = nn.Linear(2048, num_classes) 255 | 256 | def forward(self, x): 257 | """ 258 | :param x: input image tensor of (N, C, H, W) 259 | :return: (prediction, triplet_losses, softmax_losses) 260 | """ 261 | x = self.backbone(x) 262 | 263 | x = self.global_avgpool(x).squeeze() 264 | feature = self.softmax(x) 265 | if self.training: 266 | return [], [feature] 267 | else: 268 | return feature 269 | 270 | def get_optim_policy(self): 271 | return self.parameters() 272 | 273 | class IDE(nn.Module): 274 | def __init__(self, num_classes, resnet=None): 275 | super(IDE, self).__init__() 276 | if not resnet: 277 | resnet = resnet50(pretrained=True) 278 | self.backbone = nn.Sequential( 279 | resnet.conv1, 280 | resnet.bn1, 281 | resnet.relu, 282 | resnet.maxpool, 283 | resnet.layer1, # res_conv2 284 | resnet.layer2, # res_conv3 285 | resnet.layer3, # res_conv4 286 | resnet.layer4 287 | ) 288 | self.global_avgpool = nn.AvgPool2d(kernel_size=(12, 4)) 289 | 290 | def forward(self, x): 291 | """ 292 | :param x: input image tensor of (N, C, H, W) 293 | :return: (prediction, triplet_losses, softmax_losses) 294 | """ 295 | x = self.backbone(x) 296 | 297 | feature = self.global_avgpool(x).squeeze() 298 | if self.training: 299 | return [feature], [] 300 | else: 301 | return feature 302 | 303 | def get_optim_policy(self): 304 | return self.parameters() -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import math 3 | 4 | import torch as th 5 | import torch 6 | from torch import nn 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | expansion = 4 11 | 12 | def __init__(self, inplanes, planes, stride=1, downsample=None): 13 | super(Bottleneck, self).__init__() 14 | self.conv1 = 
nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 15 | self.bn1 = nn.BatchNorm2d(planes) 16 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | self.bn2 = nn.BatchNorm2d(planes) 19 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 20 | self.bn3 = nn.BatchNorm2d(planes * 4) 21 | self.relu = nn.ReLU(inplace=True) 22 | self.downsample = downsample 23 | self.stride = stride 24 | 25 | def forward(self, x): 26 | residual = x 27 | 28 | out = self.conv1(x) 29 | out = self.bn1(out) 30 | out = self.relu(out) 31 | 32 | out = self.conv2(out) 33 | out = self.bn2(out) 34 | out = self.relu(out) 35 | 36 | out = self.conv3(out) 37 | out = self.bn3(out) 38 | 39 | if self.downsample is not None: 40 | residual = self.downsample(x) 41 | 42 | out += residual 43 | out = self.relu(out) 44 | 45 | return out 46 | 47 | class CBAM_Module(nn.Module): 48 | 49 | def __init__(self, channels, reduction): 50 | super(CBAM_Module, self).__init__() 51 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 52 | self.max_pool = nn.AdaptiveMaxPool2d(1) 53 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, 54 | padding=0) 55 | self.relu = nn.ReLU(inplace=True) 56 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, 57 | padding=0) 58 | self.sigmoid_channel = nn.Sigmoid() 59 | self.conv_after_concat = nn.Conv2d(2, 1, kernel_size = 3, stride=1, padding = 1) 60 | self.sigmoid_spatial = nn.Sigmoid() 61 | 62 | def forward(self, x): 63 | #channel attention 64 | module_input = x 65 | avg = self.avg_pool(x) 66 | mx = self.max_pool(x) 67 | avg = self.fc1(avg) 68 | mx = self.fc1(mx) 69 | avg = self.relu(avg) 70 | mx = self.relu(mx) 71 | avg = self.fc2(avg) 72 | mx = self.fc2(mx) 73 | x = avg + mx 74 | x = self.sigmoid_channel(x) 75 | x = module_input * x 76 | #spatial attention 77 | module_input = x 78 | avg = torch.mean(x, 1, True) 79 | mx, _ = torch.max(x, 1, True) 80 | x = torch.cat((avg, mx), 1) 81 | x = self.conv_after_concat(x) 82 | x = self.sigmoid_spatial(x) 83 | x = module_input * x 84 | return x 85 | 86 | class CBAMBottleneck(nn.Module): 87 | expansion = 4 88 | 89 | def __init__(self, inplanes, planes, stride=1, downsample=None): 90 | super(CBAMBottleneck, self).__init__() 91 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 92 | self.bn1 = nn.BatchNorm2d(planes) 93 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 94 | padding=1, bias=False) 95 | self.bn2 = nn.BatchNorm2d(planes) 96 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 97 | self.bn3 = nn.BatchNorm2d(planes * 4) 98 | self.relu = nn.ReLU(inplace=True) 99 | self.cbam = CBAM_Module(planes * 4, reduction=16) 100 | self.downsample = downsample 101 | self.stride = stride 102 | 103 | def forward(self, x): 104 | residual = x 105 | 106 | out = self.conv1(x) 107 | out = self.bn1(out) 108 | out = self.relu(out) 109 | 110 | out = self.conv2(out) 111 | out = self.bn2(out) 112 | out = self.relu(out) 113 | 114 | out = self.conv3(out) 115 | out = self.bn3(out) 116 | out = self.cbam(out) 117 | if self.downsample is not None: 118 | residual = self.downsample(x) 119 | 120 | out += residual 121 | out = self.relu(out) 122 | 123 | return out 124 | 125 | def cbam_resnet50(): 126 | return ResNet(last_stride=1, block=CBAMBottleneck) 127 | 128 | 129 | class ResNet(nn.Module): 130 | def __init__(self, last_stride=2, block=Bottleneck, layers=[3, 4, 6, 3]): 131 | self.inplanes = 64 132 | super().__init__() 133 | self.conv1 = 
nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 134 | bias=False) 135 | self.bn1 = nn.BatchNorm2d(64) 136 | self.relu = nn.ReLU(inplace=True) 137 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 138 | self.layer1 = self._make_layer(block, 64, layers[0]) 139 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 140 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 141 | self.layer4 = self._make_layer( 142 | block, 512, layers[3], stride=last_stride) 143 | 144 | def _make_layer(self, block, planes, blocks, stride=1): 145 | downsample = None 146 | if stride != 1 or self.inplanes != planes * block.expansion: 147 | downsample = nn.Sequential( 148 | nn.Conv2d(self.inplanes, planes * block.expansion, 149 | kernel_size=1, stride=stride, bias=False), 150 | nn.BatchNorm2d(planes * block.expansion), 151 | ) 152 | 153 | layers = [] 154 | layers.append(block(self.inplanes, planes, stride, downsample)) 155 | self.inplanes = planes * block.expansion 156 | for i in range(1, blocks): 157 | layers.append(block(self.inplanes, planes)) 158 | 159 | return nn.Sequential(*layers) 160 | 161 | def forward(self, x): 162 | x = self.conv1(x) 163 | x = self.bn1(x) 164 | x = self.relu(x) 165 | x = self.maxpool(x) 166 | 167 | x = self.layer1(x) 168 | x = self.layer2(x) 169 | x = self.layer3(x) 170 | x = self.layer4(x) 171 | 172 | return x 173 | 174 | def load_param(self, param_dict): 175 | for i in param_dict: 176 | if 'fc' in i: 177 | continue 178 | self.state_dict()[i].copy_(param_dict[i]) 179 | 180 | def random_init(self): 181 | for m in self.modules(): 182 | if isinstance(m, nn.Conv2d): 183 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 184 | m.weight.data.normal_(0, math.sqrt(2. / n)) 185 | elif isinstance(m, nn.BatchNorm2d): 186 | m.weight.data.fill_(1) 187 | m.bias.data.zero_() 188 | 189 | 190 | if __name__ == "__main__": 191 | net = ResNet(last_stride=2) 192 | import torch 193 | 194 | x = net(torch.zeros(1, 3, 256, 128)) 195 | print(x.shape) 196 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorboardX 2 | fire -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daizuozhuo/batch-dropblock-network/21c99abb8d85cfb29d56fc57d09c1ecdfe6b6be5/trainers/__init__.py -------------------------------------------------------------------------------- /trainers/evaluator.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import numpy as np 3 | import os 4 | import torch 5 | from PIL import Image 6 | import matplotlib.pyplot as plt 7 | 8 | from trainers.re_ranking import re_ranking as re_ranking_func 9 | 10 | class ResNetEvaluator: 11 | def __init__(self, model): 12 | self.model = model 13 | 14 | def save_incorrect_pairs(self, distmat, queryloader, galleryloader, 15 | g_pids, q_pids, g_camids, q_camids, savefig): 16 | os.makedirs(savefig, exist_ok=True) 17 | self.model.eval() 18 | m = distmat.shape[0] 19 | indices = np.argsort(distmat, axis=1) 20 | for i in range(m): 21 | for j in range(10): 22 | index = indices[i][j] 23 | if g_camids[index] == q_camids[i] and g_pids[index] == q_pids[i]: 24 | continue 25 | else: 26 | break 27 | if g_pids[index] == q_pids[i]: 28 | continue 29 | fig, axes =plt.subplots(1, 
11, figsize=(12, 8))
30 |             img = queryloader.dataset.dataset[i][0]
31 |             img = Image.open(img).convert('RGB')
32 |             axes[0].set_title(q_pids[i])
33 |             axes[0].imshow(img)
34 |             axes[0].set_axis_off()
35 |             for j in range(10):
36 |                 gallery_index = indices[i][j]
37 |                 img = galleryloader.dataset.dataset[gallery_index][0]
38 |                 img = Image.open(img).convert('RGB')
39 |                 axes[j+1].set_title(g_pids[gallery_index])
40 |                 axes[j+1].set_axis_off()
41 |                 axes[j+1].imshow(img)
42 |             fig.savefig(os.path.join(savefig, '%d.png' % q_pids[i]))
43 |             plt.close(fig)
44 | 
45 |     def evaluate(self, queryloader, galleryloader, queryFliploader, galleryFliploader,
46 |                  ranks=[1, 2, 4, 5, 8, 10, 16, 20], eval_flip=False, re_ranking=False, savefig=False):
47 |         self.model.eval()
48 |         qf, q_pids, q_camids = [], [], []
49 |         for inputs0, inputs1 in zip(queryloader, queryFliploader):
50 |             inputs, pids, camids = self._parse_data(inputs0)
51 |             feature0 = self._forward(inputs)
52 |             if eval_flip:
53 |                 inputs, pids, camids = self._parse_data(inputs1)
54 |                 feature1 = self._forward(inputs)
55 |                 qf.append((feature0 + feature1) / 2.0)
56 |             else:
57 |                 qf.append(feature0)
58 | 
59 |             q_pids.extend(pids)
60 |             q_camids.extend(camids)
61 |         qf = torch.cat(qf, 0)
62 |         q_pids = torch.Tensor(q_pids)
63 |         q_camids = torch.Tensor(q_camids)
64 | 
65 |         print("Extracted features for query set: {} x {}".format(qf.size(0), qf.size(1)))
66 | 
67 |         gf, g_pids, g_camids = [], [], []
68 |         for inputs0, inputs1 in zip(galleryloader, galleryFliploader):
69 |             inputs, pids, camids = self._parse_data(inputs0)
70 |             feature0 = self._forward(inputs)
71 |             if eval_flip:
72 |                 inputs, pids, camids = self._parse_data(inputs1)
73 |                 feature1 = self._forward(inputs)
74 |                 gf.append((feature0 + feature1) / 2.0)
75 |             else:
76 |                 gf.append(feature0)
77 | 
78 |             g_pids.extend(pids)
79 |             g_camids.extend(camids)
80 |         gf = torch.cat(gf, 0)
81 |         g_pids = torch.Tensor(g_pids)
82 |         g_camids = torch.Tensor(g_camids)
83 | 
84 |         print("Extracted features for gallery set: {} x {}".format(gf.size(0), gf.size(1)))
85 | 
86 |         print("Computing distance matrix")
87 | 
88 |         m, n = qf.size(0), gf.size(0)
89 |         q_g_dist = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
90 |                    torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
91 |         q_g_dist.addmm_(1, -2, qf, gf.t())
92 | 
93 |         if re_ranking:
94 |             q_q_dist = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, m) + \
95 |                        torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, m).t()
96 |             q_q_dist.addmm_(1, -2, qf, qf.t())
97 | 
98 |             g_g_dist = torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, n) + \
99 |                        torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, n).t()
100 |             g_g_dist.addmm_(1, -2, gf, gf.t())
101 | 
102 |             q_g_dist = q_g_dist.numpy()
103 |             q_g_dist[q_g_dist < 0] = 0
104 |             q_g_dist = np.sqrt(q_g_dist)
105 | 
106 |             q_q_dist = q_q_dist.numpy()
107 |             q_q_dist[q_q_dist < 0] = 0
108 |             q_q_dist = np.sqrt(q_q_dist)
109 | 
110 |             g_g_dist = g_g_dist.numpy()
111 |             g_g_dist[g_g_dist < 0] = 0
112 |             g_g_dist = np.sqrt(g_g_dist)
113 | 
114 |             distmat = torch.Tensor(re_ranking_func(q_g_dist, q_q_dist, g_g_dist))
115 |         else:
116 |             distmat = q_g_dist
117 | 
118 |         if savefig:
119 |             print("Saving figure")
120 |             self.save_incorrect_pairs(distmat.numpy(), queryloader, galleryloader,
121 |                                       g_pids.numpy(), q_pids.numpy(), g_camids.numpy(), q_camids.numpy(), savefig)
122 | 
123 |         print("Computing CMC and mAP")
124 |         cmc, mAP = self.eval_func_gpu(distmat, q_pids, g_pids, q_camids, g_camids)
125 | 
126 |         print("Results ----------")
127 |         print("mAP: {:.1%}".format(mAP))
128 | 
print("CMC curve") 129 | for r in ranks: 130 | print("Rank-{:<3}: {:.1%}".format(r, cmc[r - 1])) 131 | print("------------------") 132 | 133 | return cmc[0] 134 | 135 | def _parse_data(self, inputs): 136 | imgs, pids, camids = inputs 137 | return imgs.cuda(), pids, camids 138 | 139 | def _forward(self, inputs): 140 | with torch.no_grad(): 141 | feature = self.model(inputs) 142 | return feature.cpu() 143 | 144 | def eval_func_gpu(self, distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50): 145 | num_q, num_g = distmat.size() 146 | if num_g < max_rank: 147 | max_rank = num_g 148 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 149 | _, indices = torch.sort(distmat, dim=1) 150 | matches = g_pids[indices] == q_pids.view([num_q, -1]) 151 | keep = ~((g_pids[indices] == q_pids.view([num_q, -1])) & (g_camids[indices] == q_camids.view([num_q, -1]))) 152 | #keep = g_camids[indices] != q_camids.view([num_q, -1]) 153 | 154 | results = [] 155 | num_rel = [] 156 | for i in range(num_q): 157 | m = matches[i][keep[i]] 158 | if m.any(): 159 | num_rel.append(m.sum()) 160 | results.append(m[:max_rank].unsqueeze(0)) 161 | matches = torch.cat(results, dim=0).float() 162 | num_rel = torch.Tensor(num_rel) 163 | 164 | cmc = matches.cumsum(dim=1) 165 | cmc[cmc > 1] = 1 166 | all_cmc = cmc.sum(dim=0) / cmc.size(0) 167 | 168 | pos = torch.Tensor(range(1, max_rank+1)) 169 | temp_cmc = matches.cumsum(dim=1) / pos * matches 170 | AP = temp_cmc.sum(dim=1) / num_rel 171 | mAP = AP.sum() / AP.size(0) 172 | return all_cmc.numpy(), mAP.item() 173 | 174 | def eval_func(self, distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50): 175 | """Evaluation with market1501 metric 176 | Key: for each query identity, its gallery images from the same camera view are discarded. 177 | """ 178 | num_q, num_g = distmat.shape 179 | if num_g < max_rank: 180 | max_rank = num_g 181 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 182 | indices = np.argsort(distmat, axis=1) 183 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 184 | 185 | # compute cmc curve for each query 186 | all_cmc = [] 187 | all_AP = [] 188 | num_valid_q = 0. # number of valid query 189 | for q_idx in range(num_q): 190 | # get query pid and camid 191 | q_pid = q_pids[q_idx] 192 | q_camid = q_camids[q_idx] 193 | 194 | # remove gallery samples that have the same pid and camid with query 195 | order = indices[q_idx] 196 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 197 | keep = np.invert(remove) 198 | 199 | # compute cmc curve 200 | # binary vector, positions with value 1 are correct matches 201 | orig_cmc = matches[q_idx][keep] 202 | if not np.any(orig_cmc): 203 | # this condition is true when query identity does not appear in gallery 204 | continue 205 | 206 | cmc = orig_cmc.cumsum() 207 | cmc[cmc > 1] = 1 208 | 209 | all_cmc.append(cmc[:max_rank]) 210 | num_valid_q += 1. 211 | 212 | # compute average precision 213 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 214 | num_rel = orig_cmc.sum() 215 | tmp_cmc = orig_cmc.cumsum() 216 | tmp_cmc = [x / (i + 1.) 
for i, x in enumerate(tmp_cmc)] 217 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 218 | AP = tmp_cmc.sum() / num_rel 219 | all_AP.append(AP) 220 | 221 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 222 | 223 | all_cmc = np.asarray(all_cmc).astype(np.float32) 224 | all_cmc = all_cmc.sum(0) / num_valid_q 225 | mAP = np.mean(all_AP) 226 | 227 | return all_cmc, mAP 228 | -------------------------------------------------------------------------------- /trainers/re_ranking.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2/python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon Jun 26 14:46:56 2017 5 | @author: luohao 6 | Modified by Houjing Huang, 2017-12-22. 7 | - This version accepts distance matrix instead of raw features. 8 | - The difference of `/` division between python 2 and 3 is handled. 9 | - numpy.float16 is replaced by numpy.float32 for numerical precision. 10 | 11 | Modified by Zhedong Zheng, 2018-1-12. 12 | - replace sort with topK, which save about 30s. 13 | """ 14 | 15 | """ 16 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 17 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 18 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 19 | """ 20 | 21 | """ 22 | API 23 | q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery] 24 | q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query] 25 | g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery] 26 | k1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3) 27 | Returns: 28 | final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery] 29 | """ 30 | 31 | 32 | import numpy as np 33 | def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3): 34 | 35 | # The following naming, e.g. gallery_num, is different from outer scope. 36 | # Don't care about it. 37 | 38 | original_dist = np.concatenate( 39 | [np.concatenate([q_q_dist, q_g_dist], axis=1), 40 | np.concatenate([q_g_dist.T, g_g_dist], axis=1)], 41 | axis=0) 42 | original_dist = np.power(original_dist, 2).astype(np.float32) 43 | original_dist = np.transpose(1. 
* original_dist/np.max(original_dist,axis = 0)) 44 | V = np.zeros_like(original_dist).astype(np.float32) 45 | initial_rank = np.argsort(original_dist).astype(np.int32) 46 | 47 | query_num = q_g_dist.shape[0] 48 | gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1] 49 | all_num = gallery_num 50 | 51 | for i in range(all_num): 52 | # k-reciprocal neighbors 53 | forward_k_neigh_index = initial_rank[i,:k1+1] 54 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 55 | fi = np.where(backward_k_neigh_index==i)[0] 56 | k_reciprocal_index = forward_k_neigh_index[fi] 57 | k_reciprocal_expansion_index = k_reciprocal_index 58 | for j in range(len(k_reciprocal_index)): 59 | candidate = k_reciprocal_index[j] 60 | candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2.))+1] 61 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2.))+1] 62 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 63 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 64 | if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index): 65 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 66 | 67 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 68 | weight = np.exp(-original_dist[i,k_reciprocal_expansion_index]) 69 | V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight) 70 | original_dist = original_dist[:query_num,] 71 | if k2 != 1: 72 | V_qe = np.zeros_like(V,dtype=np.float32) 73 | for i in range(all_num): 74 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0) 75 | V = V_qe 76 | del V_qe 77 | del initial_rank 78 | invIndex = [] 79 | for i in range(gallery_num): 80 | invIndex.append(np.where(V[:,i] != 0)[0]) 81 | 82 | jaccard_dist = np.zeros_like(original_dist,dtype = np.float32) 83 | 84 | 85 | for i in range(query_num): 86 | temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float32) 87 | indNonZero = np.where(V[i,:] != 0)[0] 88 | indImages = [] 89 | indImages = [invIndex[ind] for ind in indNonZero] 90 | for j in range(len(indNonZero)): 91 | temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 92 | jaccard_dist[i] = 1-temp_min/(2.-temp_min) 93 | 94 | final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value 95 | del original_dist 96 | del V 97 | del jaccard_dist 98 | final_dist = final_dist[:query_num,query_num:] 99 | return final_dist 100 | 101 | def k_reciprocal_neigh( initial_rank, i, k1): 102 | forward_k_neigh_index = initial_rank[i,:k1+1] 103 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 104 | fi = np.where(backward_k_neigh_index==i)[0] 105 | return forward_k_neigh_index[fi] 106 | 107 | def re_ranking_new(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3): 108 | # The following naming, e.g. gallery_num, is different from outer scope. 109 | # Don't care about it. 110 | original_dist = np.concatenate( 111 | [np.concatenate([q_q_dist, q_g_dist], axis=1), 112 | np.concatenate([q_g_dist.T, g_g_dist], axis=1)], 113 | axis=0) 114 | original_dist = 2. - 2 * original_dist #np.power(original_dist, 2).astype(np.float32) 115 | original_dist = np.transpose(1. 
* original_dist / np.max(original_dist, axis=0))
116 |     V = np.zeros_like(original_dist).astype(np.float32)
117 |     #initial_rank = np.argsort(original_dist).astype(np.int32)
118 |     # top K1+1
119 |     initial_rank = np.argpartition(original_dist, range(1, k1 + 1))
120 | 
121 |     query_num = q_g_dist.shape[0]
122 |     all_num = original_dist.shape[0]
123 | 
124 |     for i in range(all_num):
125 |         # k-reciprocal neighbors
126 |         k_reciprocal_index = k_reciprocal_neigh(initial_rank, i, k1)
127 |         k_reciprocal_expansion_index = k_reciprocal_index
128 |         for j in range(len(k_reciprocal_index)):
129 |             candidate = k_reciprocal_index[j]
130 |             candidate_k_reciprocal_index = k_reciprocal_neigh(initial_rank, candidate, int(np.around(k1 / 2)))
131 |             if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2. / 3 * len(candidate_k_reciprocal_index):
132 |                 k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)
133 | 
134 |         k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
135 |         weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
136 |         V[i, k_reciprocal_expansion_index] = 1. * weight / np.sum(weight)
137 | 
138 |     original_dist = original_dist[:query_num, ]
139 |     if k2 != 1:
140 |         V_qe = np.zeros_like(V, dtype=np.float32)
141 |         for i in range(all_num):
142 |             V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
143 |         V = V_qe
144 |         del V_qe
145 |     del initial_rank
146 |     invIndex = []
147 |     for i in range(all_num):
148 |         invIndex.append(np.where(V[:, i] != 0)[0])
149 | 
150 |     jaccard_dist = np.zeros_like(original_dist, dtype=np.float32)
151 | 
152 |     for i in range(query_num):
153 |         temp_min = np.zeros(shape=[1, all_num], dtype=np.float32)
154 |         indNonZero = np.where(V[i, :] != 0)[0]
155 |         indImages = [invIndex[ind] for ind in indNonZero]
156 |         for j in range(len(indNonZero)):
157 |             temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]], V[indImages[j], indNonZero[j]])
158 |         jaccard_dist[i] = 1 - temp_min / (2. - temp_min)
159 | 
160 |     final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
161 |     del original_dist
162 |     del V
163 |     del jaccard_dist
164 |     final_dist = final_dist[:query_num, query_num:]
165 |     return final_dist
166 | 
--------------------------------------------------------------------------------
/trainers/trainer.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import math
3 | import time
4 | import numpy as np
5 | import random
6 | import torch
7 | from torch import nn
8 | from torch.utils.data import DataLoader
9 | from utils.loss import euclidean_dist, hard_example_mining
10 | from utils.meters import AverageMeter
11 | 
12 | 
13 | class cls_tripletTrainer:
14 |     def __init__(self, opt, model, optimizer, criterion, summary_writer):
15 |         self.opt = opt
16 |         self.model = model
17 |         self.optimizer = optimizer
18 |         self.criterion = criterion
19 |         self.summary_writer = summary_writer
20 | 
21 |     def train(self, epoch, data_loader):
22 |         self.model.train()
23 | 
24 |         batch_time = AverageMeter()
25 |         data_time = AverageMeter()
26 |         losses = AverageMeter()
27 | 
28 |         start = time.time()
29 |         for i, inputs in enumerate(data_loader):
30 |             data_time.update(time.time() - start)
31 | 
32 |             # model optimizer
33 |             self._parse_data(inputs)
34 |             self._forward()
35 |             self.optimizer.zero_grad()
36 |             self._backward()
37 |             self.optimizer.step()
38 | 
39 |             batch_time.update(time.time() - start)
40 |             losses.update(self.loss.item())
41 | 
42 |             # tensorboard 
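# Note: global_step below counts optimizer updates across all epochs, so the
# 'loss' and 'lr' scalars logged to tensorboardX share one continuous x-axis
# instead of restarting at every epoch.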
43 | global_step = epoch * len(data_loader) + i 44 | self.summary_writer.add_scalar('loss', self.loss.item(), global_step) 45 | self.summary_writer.add_scalar('lr', self.optimizer.param_groups[0]['lr'], global_step) 46 | 47 | start = time.time() 48 | 49 | if (i + 1) % self.opt.print_freq == 0: 50 | print('Epoch: [{}][{}/{}]\t' 51 | 'Batch Time {:.3f} ({:.3f})\t' 52 | 'Data Time {:.3f} ({:.3f})\t' 53 | 'Loss {:.3f} ({:.3f})\t' 54 | .format(epoch, i + 1, len(data_loader), 55 | batch_time.val, batch_time.mean, 56 | data_time.val, data_time.mean, 57 | losses.val, losses.mean)) 58 | param_group = self.optimizer.param_groups 59 | print('Epoch: [{}]\tEpoch Time {:.3f} s\tLoss {:.3f}\t' 60 | 'Lr {:.2e}' 61 | .format(epoch, batch_time.sum, losses.mean, param_group[0]['lr'])) 62 | print() 63 | 64 | def _parse_data(self, inputs): 65 | imgs, pids, _ = inputs 66 | if self.opt.random_crop and random.random() > 0.3: 67 | h, w = imgs.size()[-2:] 68 | start = int((h-2*w)*random.random()) 69 | mask = imgs.new_zeros(imgs.size()) 70 | mask[:, :, start:start+2*w, :] = 1 71 | imgs = imgs * mask 72 | ''' 73 | if random.random() > 0.5: 74 | h, w = imgs.size()[-2:] 75 | for attempt in range(100): 76 | area = h * w 77 | target_area = random.uniform(0.02, 0.4) * area 78 | aspect_ratio = random.uniform(0.3, 3.33) 79 | ch = int(round(math.sqrt(target_area * aspect_ratio))) 80 | cw = int(round(math.sqrt(target_area / aspect_ratio))) 81 | if cw < w and ch < h: 82 | x1 = random.randint(0, h - ch) 83 | y1 = random.randint(0, w - cw) 84 | imgs[:, :, x1:x1+h, y1:y1+w] = 0 85 | break 86 | ''' 87 | self.data = imgs.cuda() 88 | self.target = pids.cuda() 89 | 90 | def _forward(self): 91 | score, feat = self.model(self.data) 92 | self.loss = self.criterion(score, feat, self.target) 93 | 94 | def _backward(self): 95 | self.loss.backward() 96 | -------------------------------------------------------------------------------- /utils/DistWeightDevianceLoss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | import numpy as np 7 | 8 | 9 | def similarity(inputs_): 10 | # Compute similarity mat of deep feature 11 | # n = inputs_.size(0) 12 | sim = torch.matmul(inputs_, inputs_.t()) 13 | return sim 14 | 15 | 16 | def GaussDistribution(data): 17 | """ 18 | :param data: 19 | :return: 20 | """ 21 | mean_value = torch.mean(data) 22 | diff = data - mean_value 23 | std = torch.sqrt(torch.mean(torch.pow(diff, 2))) 24 | return mean_value, std 25 | 26 | 27 | class DistWeightBinDevianceLoss(nn.Module): 28 | def __init__(self, margin=0.5): 29 | super(DistWeightBinDevianceLoss, self).__init__() 30 | self.margin = margin 31 | 32 | def forward(self, inputs, targets): 33 | n = inputs.size(0) 34 | # Compute similarity matrix 35 | sim_mat = similarity(inputs) 36 | # print(sim_mat) 37 | targets = targets.cuda() 38 | # split the positive and negative pairs 39 | eyes_ = Variable(torch.eye(n, n)).cuda() 40 | # eyes_ = Variable(torch.eye(n, n)) 41 | pos_mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 42 | neg_mask = eyes_.eq(eyes_) - pos_mask 43 | pos_mask = pos_mask - eyes_.eq(1) 44 | 45 | pos_sim = torch.masked_select(sim_mat, pos_mask) 46 | neg_sim = torch.masked_select(sim_mat, neg_mask) 47 | 48 | num_instances = len(pos_sim)//n + 1 49 | num_neg_instances = n - num_instances 50 | 51 | pos_sim = pos_sim.resize(len(pos_sim)//(num_instances-1), num_instances-1) 52 | neg_sim = 
53 |             len(neg_sim) // num_neg_instances, num_neg_instances)
54 | 
55 |         # a straightforward way to compute the loss, one anchor row at a time
56 |         loss = list()
57 |         c = 0
58 | 
59 |         for i, pos_pair in enumerate(pos_sim):
60 |             # print(i)
61 |             pos_pair = torch.sort(pos_pair)[0]
62 |             neg_pair = torch.sort(neg_sim[i])[0]
63 | 
64 |             neg_mean, neg_std = GaussDistribution(neg_pair)
65 |             prob = torch.exp(torch.pow(neg_pair - neg_mean, 2) / (2*torch.pow(neg_std, 2)))
66 |             neg_index = torch.multinomial(prob, num_instances - 1, replacement=False)
67 | 
68 |             neg_pair = neg_pair[neg_index]
69 | 
70 |             if len(neg_pair) < 1:
71 |                 c += 1
72 |                 continue
73 |             if pos_pair[-1].item() > neg_pair[-1].item() + 0.05:
74 |                 c += 1
75 | 
76 |             neg_pair = torch.sort(neg_pair)[0]
77 | 
78 |             if i == 1 and np.random.randint(256) == 1:
79 |                 print('neg_pair is ---------', neg_pair)
80 |                 print('pos_pair is ---------', pos_pair.data)
81 | 
82 |             pos_loss = torch.mean(torch.log(1 + torch.exp(-2*(pos_pair - self.margin))))
83 |             neg_loss = 0.04*torch.mean(torch.log(1 + torch.exp(50*(neg_pair - self.margin))))
84 |             loss.append(pos_loss + neg_loss)
85 |         loss = [torch.unsqueeze(l, 0) for l in loss]
86 |         loss = torch.sum(torch.cat(loss))/n
87 | 
88 |         prec = float(c)/n
89 |         neg_d = torch.mean(neg_sim).item()
90 |         pos_d = torch.mean(pos_sim).item()
91 | 
92 |         return loss, prec, pos_d, neg_d
93 | 
94 | 
95 | def main():
96 |     data_size = 32
97 |     input_dim = 3
98 |     output_dim = 2
99 |     num_class = 4
100 |     # margin = 0.5
101 |     x = Variable(torch.rand(data_size, input_dim), requires_grad=False)
102 |     # print(x)
103 |     w = Variable(torch.rand(input_dim, output_dim), requires_grad=True)
104 |     inputs = x.mm(w)
105 |     y_ = 8*list(range(num_class))
106 |     targets = Variable(torch.IntTensor(y_))
107 | 
108 |     print(DistWeightBinDevianceLoss()(inputs, targets))
109 | 
110 | 
111 | if __name__ == '__main__':
112 |     main()
113 |     print('Congratulations to you!')
114 | 
115 | 
116 | 
--------------------------------------------------------------------------------
/utils/LiftedStructure.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | import torch
4 | from torch import nn
5 | from torch.autograd import Variable
6 | import torch.nn.functional as F
7 | import numpy as np
8 | 
9 | 
10 | def similarity(inputs_):
11 |     # Compute the similarity matrix of deep features
12 |     # n = inputs_.size(0)
13 |     sim = torch.matmul(inputs_, inputs_.t())
14 |     return sim
15 | 
16 | def pdist(A, squared=False, eps=1e-4):
17 |     prod = torch.mm(A, A.t())
18 |     norm = prod.diag().unsqueeze(1).expand_as(prod)
19 |     res = (norm + norm.t() - 2 * prod).clamp(min=0)
20 |     return res if squared else res.clamp(min=eps).sqrt()
21 | 
22 | class LiftedStructureLoss(nn.Module):
23 |     def __init__(self, alpha=10, beta=2, margin=0.5, hard_mining=None, **kwargs):
24 |         super(LiftedStructureLoss, self).__init__()
25 |         self.margin = margin
26 |         self.alpha = alpha
27 |         self.beta = beta
28 |         self.hard_mining = hard_mining
29 | 
30 |     def forward(self, embeddings, labels):
31 |         '''
32 |         score = embeddings
33 |         target = labels
34 |         loss = 0
35 |         counter = 0
36 |         bsz = score.size(0)
37 |         mag = (score ** 2).sum(1).expand(bsz, bsz)
38 |         sim = score.mm(score.transpose(0, 1))
39 |         dist = (mag + mag.transpose(0, 1) - 2 * sim)
40 |         dist = torch.nn.functional.relu(dist).sqrt()
41 | 
42 |         for i in range(bsz):
43 |             t_i = target[i].item()
44 |             for j in range(i + 1, bsz):
45 |                 t_j = target[j].item()
46 |                 if t_i == t_j:
47 |                     # Negative component
48 |                     # !! Could do other things (like softmax that weights closer negatives)
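                    # (note: the vectorized return statement after this commented-out
                    #  loop computes the same lifted-structure objective in one pass)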
49 |                     l_ni = (self.margin - dist[i][target != t_i]).exp().sum()
50 |                     l_nj = (self.margin - dist[j][target != t_j]).exp().sum()
51 |                     l_n = (l_ni + l_nj).log()
52 |                     # Positive component
53 |                     l_p = dist[i, j]
54 |                     loss += torch.nn.functional.relu(l_n + l_p) ** 2
55 |                     counter += 1
56 |         return loss / (2 * counter), 0
57 |         '''
58 |         margin = 1.0
59 |         eps = 1e-4
60 |         d = pdist(embeddings, squared=False, eps=eps)
61 |         pos = torch.eq(*[labels.unsqueeze(dim).expand_as(d) for dim in [0, 1]]).type_as(d)
62 |         neg_i = torch.mul((margin - d).exp(), 1 - pos).sum(1).expand_as(d)
63 |         return torch.sum(F.relu(pos.triu(1) * ((neg_i + neg_i.t()).log() + d)).pow(2)) / (pos.sum() - len(d)), 0
64 | 
65 | def main():
66 |     data_size = 32
67 |     input_dim = 3
68 |     output_dim = 2
69 |     num_class = 4
70 |     # margin = 0.5
71 |     x = Variable(torch.rand(data_size, input_dim), requires_grad=False)
72 |     # print(x)
73 |     w = Variable(torch.rand(input_dim, output_dim), requires_grad=True)
74 |     inputs = x.mm(w)
75 |     y_ = 8*list(range(num_class))
76 |     targets = Variable(torch.IntTensor(y_))
77 | 
78 |     print(LiftedStructureLoss()(inputs, targets))
79 | 
80 | 
81 | if __name__ == '__main__':
82 |     main()
83 |     print('Congratulations to you!')
84 | 
85 | 
86 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daizuozhuo/batch-dropblock-network/21c99abb8d85cfb29d56fc57d09c1ecdfe6b6be5/utils/__init__.py
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import random
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 | 
7 | def topk_mask(input, dim, K=10, **kwargs):
8 |     index = input.topk(max(1, min(K, input.size(dim))), dim=dim, **kwargs)[1]
9 |     return torch.autograd.Variable(torch.zeros_like(input.data)).scatter(dim, index, 1.0)
10 | 
11 | def pdist(A, squared=False, eps=1e-4):
12 |     prod = torch.mm(A, A.t())
13 |     norm = prod.diag().unsqueeze(1).expand_as(prod)
14 |     res = (norm + norm.t() - 2 * prod).clamp(min=0)
15 |     return res if squared else res.clamp(min=eps).sqrt()
16 | 
17 | 
18 | def normalize(x, axis=-1):
19 |     """Normalize to unit length along the specified dimension.
20 |     Args:
21 |         x: pytorch Variable
22 |     Returns:
23 |         x: pytorch Variable, same shape as input
24 |     """
25 |     x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
26 |     return x
27 | 
28 | 
29 | def euclidean_dist(x, y):
30 |     """
31 |     Args:
32 |         x: pytorch Variable, with shape [m, d]
33 |         y: pytorch Variable, with shape [n, d]
34 |     Returns:
35 |         dist: pytorch Variable, with shape [m, n]
36 |     """
37 |     m, n = x.size(0), y.size(0)
38 |     xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
39 |     yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
40 |     dist = xx + yy
41 |     dist.addmm_(1, -2, x, y.t())
42 |     dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
43 |     return dist
44 | 
45 | 
46 | def hard_example_mining(dist_mat, labels, margin, return_inds=False):
47 |     """For each anchor, find the hardest positive and negative sample.
48 |     Args:
49 |         dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
50 |         labels: pytorch LongTensor, with shape [N]
51 |         return_inds: whether to return the indices; leaving it `False` saves computation.
52 |     Returns:
53 |         dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
54 |         dist_an: pytorch Variable, distance(anchor, negative); shape [N]
55 |         p_inds: pytorch LongTensor, with shape [N];
56 |             indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
57 |         n_inds: pytorch LongTensor, with shape [N];
58 |             indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
59 |     NOTE: Only the case in which all labels have the same number of samples is
60 |         considered, so all anchors can be processed in parallel.
61 |     """
62 | 
63 |     torch.set_printoptions(threshold=5000)
64 |     assert len(dist_mat.size()) == 2
65 |     assert dist_mat.size(0) == dist_mat.size(1)
66 |     N = dist_mat.size(0)
67 | 
68 |     # shape [N, N]
69 |     is_pos = labels.expand(N, N).eq(labels.expand(N, N).t())
70 |     is_neg = labels.expand(N, N).ne(labels.expand(N, N).t())
71 |     # `dist_ap` means distance(anchor, positive)
72 |     # both `dist_ap` and `relative_p_inds` with shape [N, 1]
73 |     dist_ap, relative_p_inds = torch.max(
74 |         dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True)
75 |     # `dist_an` means distance(anchor, negative)
76 |     # both `dist_an` and `relative_n_inds` with shape [N, 1]
77 |     dist_an, relative_n_inds = torch.min(
78 |         dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True)
79 |     # shape [N]
80 |     dist_ap = dist_ap.squeeze(1)
81 |     dist_an = dist_an.squeeze(1)
82 | 
83 |     if return_inds:
84 |         # shape [N, N]
85 |         ind = (labels.new().resize_as_(labels)
86 |                .copy_(torch.arange(0, N).long())
87 |                .unsqueeze(0).expand(N, N))
88 |         # shape [N, 1]
89 |         p_inds = torch.gather(
90 |             ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data)
91 |         n_inds = torch.gather(
92 |             ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data)
93 |         # shape [N]
94 |         p_inds = p_inds.squeeze(1)
95 |         n_inds = n_inds.squeeze(1)
96 |         return dist_ap, dist_an, p_inds, n_inds
97 | 
98 |     return dist_ap, dist_an
99 | 
100 | 
101 | class TripletLoss(object):
102 |     """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
103 |     Related Triplet Loss theory can be found in the paper 'In Defense of the
104 |     Triplet Loss for Person Re-Identification'."""
105 | 
106 |     def __init__(self, margin=None):
107 |         self.margin = margin
108 |         if margin is not None:
109 |             self.ranking_loss = nn.MarginRankingLoss(margin=margin)
110 |         else:
111 |             self.ranking_loss = nn.SoftMarginLoss()
112 | 
113 |     def __call__(self, global_feat, labels, normalize_feature=False):
114 |         if normalize_feature:
115 |             global_feat = normalize(global_feat, axis=-1)
116 |         dist_mat = euclidean_dist(global_feat, global_feat)
117 |         dist_ap, dist_an = hard_example_mining(dist_mat, labels, self.margin)
118 |         y = dist_an.new().resize_as_(dist_an).fill_(1)
119 |         if self.margin is not None:
120 |             loss = self.ranking_loss(dist_an, dist_ap, y)
121 |         else:
122 |             loss = self.ranking_loss(dist_an - dist_ap, y)
123 |         return loss, dist_ap, dist_an
124 | 
125 | 
126 | class CrossEntropyLabelSmooth(nn.Module):
127 |     """Cross entropy loss with label smoothing regularizer.
128 |     Reference:
129 |         Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
130 |     Equation: y = (1 - epsilon) * y + epsilon / K.
131 |     Args:
132 |         num_classes (int): number of classes.
133 |         epsilon (float): smoothing weight.
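    Example (illustrative sketch, not part of the original code; the batch
    size and class count below are arbitrary):
        >>> criterion = CrossEntropyLabelSmooth(num_classes=751, use_gpu=False)
        >>> logits = torch.randn(32, 751)           # raw classifier scores
        >>> labels = torch.randint(0, 751, (32,))   # ground-truth identity labels
        >>> loss = criterion(logits, labels)        # scalar smoothed cross entropy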
134 |     """
135 | 
136 |     def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
137 |         super(CrossEntropyLabelSmooth, self).__init__()
138 |         self.num_classes = num_classes
139 |         self.epsilon = epsilon
140 |         self.use_gpu = use_gpu
141 |         self.logsoftmax = nn.LogSoftmax(dim=1)
142 | 
143 |     def forward(self, inputs, targets):
144 |         """
145 |         Args:
146 |             inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
147 |             targets: ground truth labels with shape (batch_size)
148 |         """
149 |         log_probs = self.logsoftmax(inputs)
150 |         targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1)
151 |         if self.use_gpu: targets = targets.cuda()
152 |         targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
153 |         loss = (- targets * log_probs).mean(0).sum()
154 |         return loss
155 | 
156 | class Margin:
157 |     def __call__(self, embeddings, labels):
158 |         embeddings = F.normalize(embeddings)
159 |         alpha = 0.2
160 |         beta = 1.2
161 |         distance_threshold = 0.5
162 |         inf = 1e6
163 |         eps = 1e-6
164 |         distance_weighted_sampling = True
165 |         d = pdist(embeddings)
166 |         pos = torch.eq(*[labels.unsqueeze(dim).expand_as(d) for dim in [0, 1]]).type_as(d) - torch.autograd.Variable(torch.eye(len(d))).type_as(d)
167 |         num_neg = int(pos.data.sum() / len(pos))
168 |         if distance_weighted_sampling:
169 |             '''
170 |             dim = embeddings.size(-1)
171 |             distance = d.data.clamp(min = distance_threshold)
172 |             distribution = distance.pow(dim - 2) * ((1 - distance.pow(2) / 4).pow(0.5 * (dim - 3)))
173 |             weights = distribution.reciprocal().masked_fill_(pos.data + torch.eye(len(d)).type_as(d.data) > 0, eps)
174 |             samples = torch.multinomial(weights, replacement = False, num_samples = num_neg)
175 |             neg = torch.autograd.Variable(torch.zeros_like(pos.data).scatter_(1, samples, 1))
176 |             '''
177 |             neg = torch.autograd.Variable(torch.zeros_like(pos.data).scatter_(1, torch.multinomial((d.data.clamp(min = distance_threshold).pow(embeddings.size(-1) - 2) * (1 - d.data.clamp(min = distance_threshold).pow(2) / 4).pow(0.5 * (embeddings.size(-1) - 3))).reciprocal().masked_fill_(pos.data + torch.eye(len(d)).type_as(d.data) > 0, eps), replacement = False, num_samples = num_neg), 1))  # one-line equivalent of the commented-out steps above
178 |         else:
179 |             neg = topk_mask(d + inf * ((pos > 0) + (d < distance_threshold)).type_as(d), dim=1, largest=False, K=num_neg)
180 |         L = F.relu(alpha + (pos * 2 - 1) * (d - beta))
181 |         M = ((pos + neg > 0) * (L > 0)).float()
182 |         return (M * L).sum() / M.sum(), 0
183 | 
184 | 
--------------------------------------------------------------------------------
/utils/meters.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import math
3 | 
4 | import numpy as np
5 | 
6 | 
7 | class AverageMeter(object):
8 |     def __init__(self):
9 |         self.n = 0
10 |         self.sum = 0.0
11 |         self.var = 0.0
12 |         self.val = 0.0
13 |         self.mean = np.nan
14 |         self.std = np.nan
15 | 
16 |     def update(self, value, n=1):
17 |         self.val = value
18 |         self.sum += value * n  # weight by n so batched updates average correctly
19 |         self.var += value * value * n
20 |         self.n += n
21 | 
22 |         if self.n == 0:
23 |             self.mean, self.std = np.nan, np.nan
24 |         elif self.n == 1:
25 |             self.mean, self.std = self.sum, np.inf
26 |         else:
27 |             self.mean = self.sum / self.n
28 |             self.std = math.sqrt(
29 |                 (self.var - self.n * self.mean * self.mean) / (self.n - 1.0))
30 | 
31 |     def value(self):
32 |         return self.mean, self.std
33 | 
34 |     def reset(self):
35 |         self.n = 0
36 |         self.sum = 0.0
37 |         self.var = 0.0
38 |         self.val = 0.0
39 |         self.mean = np.nan
40 |         self.std = np.nan
41 | 
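
# Minimal usage sketch (added for illustration, mirroring the `main()` demos in
# the other utils modules; not part of the original file):
def main():
    meter = AverageMeter()
    for v in [1.0, 2.0, 3.0]:
        meter.update(v)       # one scalar per call; n defaults to 1
    print(meter.value())      # -> (2.0, 1.0): running mean and sample std


if __name__ == '__main__':
    main()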
--------------------------------------------------------------------------------
/utils/random_erasing.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | from torchvision.transforms import *
4 | 
5 | from PIL import Image
6 | import random
7 | import math
8 | import numpy as np
9 | import torch
10 | 
11 | class Cutout(object):
12 |     def __init__(self, probability=0.5, size=64, mean=[0.4914, 0.4822, 0.4465]):
13 |         self.probability = probability
14 |         self.mean = mean
15 |         self.size = size
16 | 
17 |     def __call__(self, img):
18 | 
19 |         if random.uniform(0, 1) > self.probability:
20 |             return img
21 | 
22 |         h = self.size
23 |         w = self.size
24 |         for attempt in range(100):
25 |             area = img.size()[1] * img.size()[2]
26 |             if w < img.size()[2] and h < img.size()[1]:
27 |                 x1 = random.randint(0, img.size()[1] - h)
28 |                 y1 = random.randint(0, img.size()[2] - w)
29 |                 if img.size()[0] == 3:
30 |                     img[0, x1:x1+h, y1:y1+w] = self.mean[0]
31 |                     img[1, x1:x1+h, y1:y1+w] = self.mean[1]
32 |                     img[2, x1:x1+h, y1:y1+w] = self.mean[2]
33 |                 else:
34 |                     img[0, x1:x1+h, y1:y1+w] = self.mean[0]
35 |                 return img
36 |         return img
37 | 
38 | class RandomErasing(object):
39 |     """ Randomly selects a rectangle region in an image and erases its pixels.
40 |         'Random Erasing Data Augmentation' by Zhong et al.
41 |         See https://arxiv.org/pdf/1708.04896.pdf
42 |     Args:
43 |         probability: The probability that the Random Erasing operation will be performed.
44 |         sl: Minimum proportion of erased area against input image.
45 |         sh: Maximum proportion of erased area against input image.
46 |         r1: Minimum aspect ratio of erased area.
47 |         mean: Erasing value.
48 |     """
49 | 
50 |     def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]):
51 |         self.probability = probability
52 |         self.mean = mean
53 |         self.sl = sl
54 |         self.sh = sh
55 |         self.r1 = r1
56 | 
57 |     def __call__(self, img):
58 | 
59 |         if random.uniform(0, 1) > self.probability:
60 |             return img
61 | 
62 |         for attempt in range(100):
63 |             area = img.size()[1] * img.size()[2]
64 | 
65 |             target_area = random.uniform(self.sl, self.sh) * area
66 |             aspect_ratio = random.uniform(self.r1, 1/self.r1)
67 | 
68 |             h = int(round(math.sqrt(target_area * aspect_ratio)))
69 |             w = int(round(math.sqrt(target_area / aspect_ratio)))
70 | 
71 |             if w < img.size()[2] and h < img.size()[1]:
72 |                 x1 = random.randint(0, img.size()[1] - h)
73 |                 y1 = random.randint(0, img.size()[2] - w)
74 |                 if img.size()[0] == 3:
75 |                     img[0, x1:x1+h, y1:y1+w] = self.mean[0]
76 |                     img[1, x1:x1+h, y1:y1+w] = self.mean[1]
77 |                     img[2, x1:x1+h, y1:y1+w] = self.mean[2]
78 |                 else:
79 |                     img[0, x1:x1+h, y1:y1+w] = self.mean[0]
80 |                 return img
81 | 
82 |         return img
83 | 
--------------------------------------------------------------------------------
/utils/serialization.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import errno
3 | import os
4 | import shutil
5 | import sys
6 | 
7 | import os.path as osp
8 | import torch
9 | 
10 | 
11 | class Logger(object):
12 |     """
13 |     Write console output to external text file.
14 |     Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
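    Illustrative usage (the log path here is an example, not fixed by this repo):
        sys.stdout = Logger(osp.join('./pytorch-ckpt/market_bfe', 'log_train.txt'))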
15 |     """
16 | 
17 |     def __init__(self, fpath=None):
18 |         self.console = sys.stdout
19 |         self.file = None
20 |         if fpath is not None:
21 |             mkdir_if_missing(os.path.dirname(fpath))
22 |             self.file = open(fpath, 'w')
23 | 
24 |     def __del__(self):
25 |         self.close()
26 | 
27 |     def __enter__(self):
28 |         pass
29 | 
30 |     def __exit__(self, *args):
31 |         self.close()
32 | 
33 |     def write(self, msg):
34 |         self.console.write(msg)
35 |         if self.file is not None:
36 |             self.file.write(msg)
37 | 
38 |     def flush(self):
39 |         self.console.flush()
40 |         if self.file is not None:
41 |             self.file.flush()
42 |             os.fsync(self.file.fileno())
43 | 
44 |     def close(self):
45 |         self.console.close()
46 |         if self.file is not None:
47 |             self.file.close()
48 | 
49 | 
50 | def mkdir_if_missing(dir_path):
51 |     try:
52 |         os.makedirs(dir_path)
53 |     except OSError as e:
54 |         if e.errno != errno.EEXIST:
55 |             raise
56 | 
57 | 
58 | def save_checkpoint(state, is_best, save_dir, filename):
59 |     fpath = osp.join(save_dir, filename)
60 |     mkdir_if_missing(save_dir)
61 |     torch.save(state, fpath)
62 |     if is_best:
63 |         shutil.copy(fpath, osp.join(save_dir, 'model_best.pth.tar'))
--------------------------------------------------------------------------------
/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | from PIL import Image
3 | from torchvision import transforms as T
4 | from utils.random_erasing import RandomErasing, Cutout
5 | import random
6 | 
7 | 
8 | class Random2DTranslation(object):
9 |     """
10 |     With probability p, resize the image directly to (width, height); otherwise enlarge it by a factor of 1.125 and take a random crop of the target size.
11 | 
12 |     Args:
13 |         height (int): target height.
14 |         width (int): target width.
15 |         p (float): probability of skipping the random crop and resizing directly to the target size. Default: 0.5.
16 |     """
17 | 
18 |     def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):
19 |         self.height = height
20 |         self.width = width
21 |         self.p = p
22 |         self.interpolation = interpolation
23 | 
24 |     def __call__(self, img):
25 |         """
26 |         Args:
27 |             img (PIL Image): Image to be cropped.
28 | 
29 |         Returns:
30 |             PIL Image: Cropped image.
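        Example (illustrative; the input size is arbitrary):
            >>> t = Random2DTranslation(height=384, width=128)
            >>> out = t(Image.new('RGB', (64, 192)))
            >>> out.size  # PIL reports (width, height)
            (128, 384)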
31 |         """
32 |         if random.random() < self.p:
33 |             return img.resize((self.width, self.height), self.interpolation)
34 |         new_width, new_height = int(
35 |             round(self.width * 1.125)), int(round(self.height * 1.125))
36 |         resized_img = img.resize((new_width, new_height), self.interpolation)
37 |         x_maxrange = new_width - self.width
38 |         y_maxrange = new_height - self.height
39 |         x1 = int(round(random.uniform(0, x_maxrange)))
40 |         y1 = int(round(random.uniform(0, y_maxrange)))
41 |         cropped_img = resized_img.crop(
42 |             (x1, y1, x1 + self.width, y1 + self.height))
43 |         return cropped_img
44 | 
45 | def pad_shorter(x):
46 |     w, h = x.size  # PIL size is (width, height)
47 |     s = max(h, w)
48 |     new_im = Image.new("RGB", (s, s))
49 |     new_im.paste(x, ((s-w)//2, (s-h)//2))
50 |     return new_im
51 | 
52 | class TrainTransform(object):
53 |     def __init__(self, data):
54 |         self.data = data
55 | 
56 |     def __call__(self, x):
57 |         if self.data == 'person':
58 |             x = T.Resize((384, 128))(x)
59 |         elif self.data == 'car':
60 |             x = pad_shorter(x)
61 |             x = T.Resize((256, 256))(x)
62 |             x = T.RandomCrop((224, 224))(x)
63 |         elif self.data == 'cub':
64 |             x = pad_shorter(x)
65 |             x = T.Resize((256, 256))(x)
66 |             x = T.RandomCrop((224, 224))(x)
67 |         elif self.data == 'clothes':
68 |             x = pad_shorter(x)
69 |             x = T.Resize((256, 256))(x)
70 |             x = T.RandomCrop((224, 224))(x)
71 |         elif self.data == 'product':
72 |             x = pad_shorter(x)
73 |             x = T.Resize((256, 256))(x)
74 |             x = T.RandomCrop((224, 224))(x)
75 |         elif self.data == 'cifar':
76 |             x = T.Resize((40, 40))(x)
77 |             x = T.RandomCrop((32, 32))(x)
78 |         x = T.RandomHorizontalFlip()(x)
79 |         x = T.ToTensor()(x)
80 |         x = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(x)
81 |         if self.data == 'person':
82 |             x = Cutout(probability=0.5, size=64, mean=[0.0, 0.0, 0.0])(x)
83 |         else:
84 |             x = RandomErasing(probability=0.5, mean=[0.0, 0.0, 0.0])(x)
85 |         return x
86 | 
87 | 
88 | class TestTransform(object):
89 |     def __init__(self, data, flip=False):
90 |         self.data = data
91 |         self.flip = flip
92 | 
93 |     def __call__(self, x=None):
94 |         if self.data == 'cub':
95 |             x = pad_shorter(x)
96 |             x = T.Resize((256, 256))(x)
97 |         elif self.data == 'car':
98 |             #x = pad_shorter(x)
99 |             x = T.Resize((256, 256))(x)
100 |         elif self.data == 'clothes':
101 |             x = pad_shorter(x)
102 |             x = T.Resize((256, 256))(x)
103 |         elif self.data == 'product':
104 |             x = pad_shorter(x)
105 |             x = T.Resize((224, 224))(x)
106 |         elif self.data == 'person':
107 |             x = T.Resize((384, 128))(x)
108 | 
109 |         if self.flip:
110 |             x = T.functional.hflip(x)
111 |         x = T.ToTensor()(x)
112 |         x = T.Normalize(mean=[0.485, 0.456, 0.406],
113 |                         std=[0.229, 0.224, 0.225])(x)
114 |         return x
--------------------------------------------------------------------------------
/utils/validation_metrics.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | def accuracy(score, target, topk=(1,)):
3 |     maxk = max(topk)
4 |     batch_size = target.size(0)
5 | 
6 |     _, pred = score.topk(maxk, 1, True, True)
7 |     pred = pred.t()
8 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
9 | 
10 |     ret = []
11 |     for k in topk:
12 |         correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
13 |         ret.append(correct_k.mul_(1. / batch_size))
14 |     return ret
15 | 
--------------------------------------------------------------------------------