├── .gitignore ├── LICENSE ├── README.md ├── backbone.py ├── balanced_sampler.py ├── evaluate.py ├── get_market1501.sh ├── market1501.py ├── random_erasing.py ├── sphere_loss.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | 107 | ## Coin 108 | dataset/ 109 | dataset 110 | res/ 111 | adj.md 112 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 CoinCheung 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SphereReID 2 | 3 | This is my implementation of [SphereReID](https://arxiv.org/abs/1807.00537). 4 | 5 | My working environment is python 3.5.2 with pytorch 0.4.0. If things do not work on your system, please check your environment first. 6 | 7 | I only implement *network-D* from the paper, which is claimed to have the highest performance of the four networks the authors propose. 8 | 9 | 10 | ### Get Market1501 dataset 11 | Execute the script from the command line: 12 | ``` 13 | $ sh get_market1501.sh 14 | ``` 15 | 16 | 17 | ### Train and Evaluate 18 | * To train the model, just run the training script: 19 | ``` 20 | $ python train.py 21 | ``` 22 | This will train the model and save the parameters to the ```res/``` directory. 23 | 24 | * To embed the gallery and query sets with the trained model and compute the accuracy, run: 25 | ``` 26 | $ python evaluate.py 27 | ``` 28 | This will embed the gallery and query sets, and then compute cmc and mAP. 29 | 30 | 31 | ### Notes: 32 | Sadly, I was not able to reproduce the result with only the method described in the paper, so I added a few tricks beyond the paper that help to boost performance. These tricks include: 33 | 34 | * During the training phase, use the [random erasing](https://arxiv.org/abs/1708.04896) augmentation method. 35 | 36 | * During the embedding phase, average the embeddings of the original pictures with those of their horizontally flipped counterparts, as done in [MGN](https://arxiv.org/pdf/1804.01438.pdf) (a minimal sketch of this trick follows these notes). 37 | 38 | * Change the stride of the last stage of the resnet50 backbone from 2 to 1. 39 | 40 | * Train for 150 epochs in total, and drop the learning rate by a factor of 0.1 at epochs 90 and 130. 41 | 42 | With these tricks, the rank-1 cmc and mAP of my implementation reach 93.08 and 83.01.
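For reference, the flip-averaging trick boils down to the sketch below. It is a standalone illustration with a dummy batch rather than real crops (the actual logic lives in `market1501.py` and `evaluate.py`); note that constructing `Network_D` downloads the ImageNet resnet50 weights, and that `torch.flip` requires pytorch >= 0.4.1:

```
import torch
from backbone import Network_D

net = Network_D().eval()  # eval mode: forward() returns l2-normalized embeddings

imgs = torch.randn(4, 3, 256, 128)      # dummy batch standing in for 256x128 crops
imgs_flip = torch.flip(imgs, dims=[3])  # horizontal flip along the width axis

with torch.no_grad():
    embds = (net(imgs) + net(imgs_flip)) / 2  # average original and flipped embeddings
print(embds.shape)  # torch.Size([4, 1024])
```

As in `evaluate.py`, the averaged embedding is used as-is, without re-normalizing.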
43 | -------------------------------------------------------------------------------- /backbone.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchvision 9 | import torch.utils.model_zoo as model_zoo 10 | 11 | 12 | resnet50_url = 'https://download.pytorch.org/models/resnet50-19c8e357.pth' 13 | 14 | 15 | class Network_D(nn.Module): 16 | def __init__(self, *args, **kwargs): 17 | super(Network_D, self).__init__(*args, **kwargs) 18 | resnet50 = torchvision.models.resnet50() 19 | 20 | self.conv1 = resnet50.conv1 21 | self.bn1 = resnet50.bn1 22 | self.relu = resnet50.relu 23 | self.maxpool = resnet50.maxpool 24 | self.layer1 = create_layer(64, 64, 3, stride=1) 25 | self.layer2 = create_layer(256, 128, 4, stride=2) 26 | self.layer3 = create_layer(512, 256, 6, stride=2) 27 | self.layer4 = create_layer(1024, 512, 3, stride=1)  # stride 1 instead of 2: the last-stage trick from the README 28 | self.bn2 = nn.BatchNorm1d(2048) 29 | self.dp = nn.Dropout(0.5) 30 | self.fc = nn.Linear(in_features=2048, out_features=1024, bias=True) 31 | self.bn3 = nn.BatchNorm1d(1024) 32 | 33 | # load pretrained weights and initialize the added fc layer 34 | pretrained_state = model_zoo.load_url(resnet50_url) 35 | state_dict = self.state_dict() 36 | for k, v in pretrained_state.items(): 37 | if 'fc' in k:  # skip the imagenet classifier weights 38 | continue 39 | state_dict.update({k: v}) 40 | self.load_state_dict(state_dict) 41 | nn.init.kaiming_normal_(self.fc.weight, a=1) 42 | nn.init.constant_(self.fc.bias, 0) 43 | 44 | def forward(self, x): 45 | x = self.conv1(x) 46 | x = self.bn1(x) 47 | x = self.relu(x) 48 | x = self.maxpool(x) 49 | x = self.layer1(x) 50 | x = self.layer2(x) 51 | x = self.layer3(x) 52 | x = self.layer4(x) 53 | x = F.avg_pool2d(x, x.size()[2:]).view(x.size()[:2]) 54 | x = self.bn2(x) 55 | x = self.dp(x) 56 | x = self.fc(x) 57 | embd = self.bn3(x) 58 | if not self.training: 59 | embd_norm = torch.norm(embd, 2, 1, True).clamp(min=1e-12).expand_as(embd) 60 | embd = embd / embd_norm 61 | return embd 62 | 63 | 64 | class Bottleneck(nn.Module): 65 | def __init__(self, in_chan, mid_chan, stride=1, stride_at_1x1=False, *args, **kwargs): 66 | super(Bottleneck, self).__init__(*args, **kwargs) 67 | 68 | stride1x1, stride3x3 = (stride, 1) if stride_at_1x1 else (1, stride) 69 | 70 | out_chan = 4 * mid_chan 71 | self.conv1 = nn.Conv2d(in_chan, mid_chan, kernel_size=1, stride=stride1x1, 72 | bias=False) 73 | self.bn1 = nn.BatchNorm2d(mid_chan) 74 | self.conv2 = nn.Conv2d(mid_chan, mid_chan, kernel_size=3, stride=stride3x3, 75 | padding=1, bias=False) 76 | self.bn2 = nn.BatchNorm2d(mid_chan) 77 | self.conv3 = nn.Conv2d(mid_chan, out_chan, kernel_size=1, bias=False) 78 | self.bn3 = nn.BatchNorm2d(out_chan) 79 | self.relu = nn.ReLU(inplace=True) 80 | 81 | self.downsample = None 82 | if in_chan != out_chan or stride != 1: 83 | self.downsample = nn.Sequential( 84 | nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False), 85 | nn.BatchNorm2d(out_chan)) 86 | 87 | def forward(self, x): 88 | out = self.conv1(x) 89 | out = self.bn1(out) 90 | out = self.relu(out) 91 | out = self.conv2(out) 92 | out = self.bn2(out) 93 | out = self.relu(out) 94 | out = self.conv3(out) 95 | out = self.bn3(out) 96 | 97 | if self.downsample is None: 98 | residual = x 99 | else: 100 | residual = self.downsample(x) 101 | 102 | out += residual 103 | out = self.relu(out) 104 | 105 | return out 106 | 107 | 108 | def create_layer(in_chan, mid_chan, b_num,
stride): 109 | out_chan = mid_chan * 4 110 | blocks = [Bottleneck(in_chan, mid_chan, stride=stride),] 111 | for i in range(1, b_num): 112 | blocks.append(Bottleneck(out_chan, mid_chan, stride=1)) 113 | return nn.Sequential(*blocks) 114 | 115 | 116 | 117 | if __name__ == '__main__': 118 | intensor = torch.randn(10, 3, 256, 128) 119 | net = Network_D() 120 | out = net(intensor) 121 | print(out.shape) 122 | 123 | params = list(net.parameters()) 124 | optim = torch.optim.Adam(params, lr = 1e-3, weight_decay = 5e-4) 125 | lr = 3 126 | optim.defaults['lr'] = 4 127 | for param_group in optim.param_groups: 128 | param_group['lr'] = lr 129 | print(param_group.keys()) 130 | print(param_group['lr']) 131 | print(optim.defaults['lr']) 132 | print(optim.defaults.keys()) 133 | print(net) 134 | -------------------------------------------------------------------------------- /balanced_sampler.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import random 6 | import numpy as np 7 | import torch 8 | from torch.utils.data.sampler import Sampler 9 | 10 | 11 | class BalancedSampler(Sampler): 12 | def __init__(self, data_source, P, K, *args, **kwargs): 13 | super(BalancedSampler, self).__init__(data_source, *args, **kwargs) 14 | 15 | self.data_source = data_source 16 | self.P, self.K = P, K 17 | self.person_infos = data_source.person_infos 18 | self.persons = list(data_source.person_infos.keys()) 19 | random.shuffle(self.persons) 20 | self.iter_num = len(self.persons) // P 21 | 22 | 23 | def __iter__(self): 24 | random.shuffle(self.persons) 25 | curr_p = 0 26 | for it in range(self.iter_num): 27 | pids = self.persons[curr_p: curr_p + self.P] 28 | curr_p += self.P 29 | ids = [] 30 | for pid in pids: 31 | if len(self.person_infos[pid]) >= self.K: 32 | id_sam = np.random.choice(self.person_infos[pid], self.K, False) 33 | ids.extend(id_sam.tolist()) 34 | else: 35 | id_sam = np.random.choice(self.person_infos[pid], self.K, True) 36 | ids.extend(id_sam.tolist()) 37 | yield ids 38 | 39 | 40 | def __len__(self): 41 | return self.iter_num 42 | 43 | 44 | 45 | if __name__ == "__main__": 46 | from torch.utils.data import DataLoader 47 | from market1501 import Market1501 48 | import cv2 49 | ds = Market1501('./dataset/Market-1501-v15.09.15/bounding_box_train') 50 | sampler1 = BalancedSampler(ds, 2, 4) 51 | sampler2 = BalancedSampler(ds, 2, 4) 52 | dl1 = DataLoader(ds, batch_sampler = sampler1, num_workers = 1) 53 | dl2 = DataLoader(ds, batch_sampler = sampler2, num_workers = 1) 54 | 55 | for jj in range(2): 56 | for i, ((imgs1, lbs1, ids1), (imgs2, lbs2, ids2)) in enumerate(zip(dl1, dl2)): 57 | print(i) 58 | print(lbs1) 59 | print(lbs2) 60 | 61 | if i == 4: break 62 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import logging 9 | import pickle 10 | from tqdm import tqdm 11 | import numpy as np 12 | import torch 13 | from backbone import Network_D 14 | from torch.utils.data import DataLoader 15 | from market1501 import Market1501 16 | 17 | 18 | 19 | FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s' 20 | logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout) 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | def embed(): 25 | ## load checkpoint 
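# model_final.pkl is written by train.py; with net.eval() below, Network_D.forward()
# returns l2-normalized embeddings (see backbone.py), so dot products are cosine similarities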
26 | res_pth = './res' 27 | mod_pth = osp.join(res_pth, 'model_final.pkl') 28 | net = Network_D() 29 | net.load_state_dict(torch.load(mod_pth)) 30 | net.cuda() 31 | net.eval() 32 | 33 | ## data loader 34 | query_set = Market1501('./dataset/Market-1501-v15.09.15/query', 35 | is_train = False) 36 | gallery_set = Market1501('./dataset/Market-1501-v15.09.15/bounding_box_test', 37 | is_train = False) 38 | query_loader = DataLoader(query_set, 39 | batch_size = 32, 40 | num_workers = 4, 41 | drop_last = False) 42 | gallery_loader = DataLoader(gallery_set, 43 | batch_size = 32, 44 | num_workers = 4, 45 | drop_last = False) 46 | 47 | ## embed 48 | logger.info('embedding query set ...') 49 | query_pids = [] 50 | query_camids = [] 51 | query_embds = [] 52 | for i, (im, _, ids) in enumerate(tqdm(query_loader)): 53 | embds = [] 54 | for crop in im:  # im is [no-flip crop, h-flipped crop] at test time 55 | crop = crop.cuda() 56 | embds.append(net(crop).detach().cpu().numpy()) 57 | embd_avg = sum(embds) / len(embds)  # average the no-flip and flipped embeddings 58 | pid = ids[0].numpy() 59 | camid = ids[1].numpy() 60 | query_embds.append(embd_avg) 61 | query_pids.extend(pid) 62 | query_camids.extend(camid) 63 | query_embds = np.vstack(query_embds) 64 | query_pids = np.array(query_pids) 65 | query_camids = np.array(query_camids) 66 | 67 | logger.info('embedding gallery set ...') 68 | gallery_pids = [] 69 | gallery_camids = [] 70 | gallery_embds = [] 71 | for i, (im, _, ids) in enumerate(tqdm(gallery_loader)): 72 | embds = [] 73 | for crop in im: 74 | crop = crop.cuda() 75 | embds.append(net(crop).detach().cpu().numpy()) 76 | embd_avg = sum(embds) / len(embds) 77 | pid = ids[0].numpy() 78 | camid = ids[1].numpy() 79 | gallery_embds.append(embd_avg) 80 | gallery_pids.extend(pid) 81 | gallery_camids.extend(camid) 82 | gallery_embds = np.vstack(gallery_embds) 83 | gallery_pids = np.array(gallery_pids) 84 | gallery_camids = np.array(gallery_camids) 85 | 86 | ## dump embedding results 87 | embd_res = (query_embds, query_pids, query_camids, gallery_embds, gallery_pids, gallery_camids) 88 | with open('./res/embds.pkl', 'wb') as fw: 89 | pickle.dump(embd_res, fw) 90 | logger.info('embedding done, dump to: ./res/embds.pkl') 91 | 92 | return embd_res 93 | 94 | 95 | def evaluate(embd_res, cmc_max_rank = 1): 96 | query_embds, query_pids, query_camids, gallery_embds, gallery_pids, gallery_camids = embd_res 97 | 98 | ## compute distance matrix 99 | logger.info('compute distance matrix') 100 | dist_mtx = np.matmul(query_embds, gallery_embds.T) 101 | dist_mtx = 1.0 / (dist_mtx + 1)  # map cosine similarity to a distance (monotonically decreasing), so ascending argsort ranks most similar first 102 | n_q, n_g = dist_mtx.shape 103 | 104 | logger.info('start evaluating ...') 105 | indices = np.argsort(dist_mtx, axis = 1) 106 | matches = gallery_pids[indices] == query_pids[:, np.newaxis] 107 | matches = matches.astype(np.int32) 108 | all_aps = [] 109 | all_cmcs = [] 110 | for query_idx in tqdm(range(n_q)): 111 | query_pid = query_pids[query_idx] 112 | query_camid = query_camids[query_idx] 113 | 114 | ## exclude same-id same-camera gallery pictures and junk images (pid == -1) 115 | order = indices[query_idx] 116 | pid_diff = gallery_pids[order] != query_pid 117 | camid_diff = gallery_camids[order] != query_camid 118 | useful = gallery_pids[order] != -1 119 | keep = np.logical_or(pid_diff, camid_diff) 120 | keep = np.logical_and(keep, useful) 121 | match = matches[query_idx][keep] 122 | 123 | if not np.any(match): continue 124 | 125 | ## compute cmc 126 | cmc = match.cumsum() 127 | cmc[cmc > 1] = 1 128 | all_cmcs.append(cmc[:cmc_max_rank]) 129 | 130 | ## compute map 131 | num_real = match.sum() 132 | match_cum = match.cumsum() 133 | match_cum = [el / (1.0 + i) for i, el in
enumerate(match_cum)] 134 | match_cum = np.array(match_cum) * match 135 | ap = match_cum.sum() / num_real 136 | all_aps.append(ap) 137 | 138 | assert len(all_aps) > 0, "NO QUERY APPEARS IN THE GALLERY" 139 | mAP = sum(all_aps) / len(all_aps) 140 | all_cmcs = np.array(all_cmcs, dtype = np.float32) 141 | cmc = np.mean(all_cmcs, axis = 0) 142 | 143 | return cmc, mAP 144 | 145 | 146 | if __name__ == '__main__': 147 | embd_res = embed() 148 | with open('./res/embds.pkl', 'rb') as fr:  # re-read the freshly dumped embeddings from disk 149 | embd_res = pickle.load(fr) 150 | 151 | cmc, mAP = evaluate(embd_res) 152 | print('cmc is: {}, map is: {}'.format(cmc, mAP)) 153 | -------------------------------------------------------------------------------- /get_market1501.sh: -------------------------------------------------------------------------------- 1 | 2 | mkdir -p dataset 3 | cd dataset 4 | wget -c http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip 5 | unzip Market-1501-v15.09.15.zip 6 | -------------------------------------------------------------------------------- /market1501.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import os 6 | import os.path as osp 7 | import numpy as np 8 | import cv2 9 | import random 10 | import math 11 | from PIL import Image 12 | import torch 13 | import torchvision.transforms as transforms 14 | from torch.utils.data import Dataset 15 | 16 | from random_erasing import RandomErasing 17 | 18 | 19 | class Market1501(Dataset): 20 | def __init__(self, data_pth, is_train = True, *args, **kwargs): 21 | super(Market1501, self).__init__(*args, **kwargs) 22 | 23 | ## parse image names to generate image ids 24 | imgs = os.listdir(data_pth) 25 | imgs = [im for im in imgs if osp.splitext(im)[-1] == '.jpg'] 26 | self.is_train = is_train 27 | self.im_pths = [osp.join(data_pth, im) for im in imgs] 28 | self.im_infos = {} 29 | self.person_infos = {} 30 | for i, im in enumerate(imgs): 31 | tokens = im.split('_') 32 | im_pth = self.im_pths[i] 33 | pid = int(tokens[0]) 34 | cam = int(tokens[1][1]) 35 | self.im_infos.update({im_pth: (pid, cam)}) 36 | if pid in self.person_infos: 37 | self.person_infos[pid].append(i) 38 | else: 39 | self.person_infos[pid] = [i, ] 40 | 41 | self.pid_label_map = {} 42 | for i, (pid, ids) in enumerate(self.person_infos.items()): 43 | self.person_infos[pid] = np.array(ids, dtype = np.int32) 44 | self.pid_label_map[pid] = i 45 | 46 | ## preprocessing 47 | self.trans_train = transforms.Compose([ 48 | transforms.Resize((288, 144)), 49 | transforms.RandomCrop((256, 128)), 50 | transforms.RandomHorizontalFlip(), 51 | transforms.ToTensor(), 52 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 53 | RandomErasing(0.5, mean=[0.0, 0.0, 0.0]) 54 | ]) 55 | ## H-Flip 56 | self.trans_no_train_flip = transforms.Compose([ 57 | transforms.Resize((288, 144)), 58 | transforms.RandomHorizontalFlip(1), 59 | transforms.ToTensor(), 60 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 61 | ]) 62 | self.trans_no_train_noflip = transforms.Compose([ 63 | transforms.Resize((288, 144)), 64 | transforms.ToTensor(), 65 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 66 | ]) 67 | 68 | 69 | def __getitem__(self, idx): 70 | im_pth = self.im_pths[idx] 71 | pid = self.im_infos[im_pth][0] 72 | im = Image.open(im_pth) 73 | 74 | if self.is_train: 75 | im = self.trans_train(im) 76 | else: 77 | im_noflip = self.trans_no_train_noflip(im) 78 | im_flip =
self.trans_no_train_flip(im) 79 | im = [im_noflip, im_flip] 80 | return im, self.pid_label_map[pid], self.im_infos[im_pth] 81 | 82 | def __len__(self): 83 | return len(self.im_pths) 84 | 85 | def get_num_classes(self): 86 | return len(list(self.person_infos.keys())) 87 | 88 | 89 | if __name__ == "__main__": 90 | ds_train = Market1501('./dataset/Market-1501-v15.09.15/bounding_box_train') 91 | ds_test = Market1501('./dataset/Market-1501-v15.09.15/bounding_box_test', is_train = False) 92 | im, lb, _ = ds_train[10] 93 | -------------------------------------------------------------------------------- /random_erasing.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | 5 | from PIL import Image 6 | import random 7 | import math 8 | import numpy as np 9 | import torch 10 | 11 | class RandomErasing(object): 12 | """ Randomly selects a rectangle region in an image and erases its pixels. 13 | 'Random Erasing Data Augmentation' by Zhong et al. 14 | See https://arxiv.org/pdf/1708.04896.pdf 15 | Args: 16 | probability: The probability that the Random Erasing operation will be performed. 17 | sl: Minimum proportion of erased area against input image. 18 | sh: Maximum proportion of erased area against input image. 19 | r1: Minimum aspect ratio of erased area. 20 | mean: Erasing value. 21 | """ 22 | 23 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]): 24 | self.probability = probability 25 | self.mean = mean 26 | self.sl = sl 27 | self.sh = sh 28 | self.r1 = r1 29 | 30 | def __call__(self, img): 31 | 32 | if random.uniform(0, 1) > self.probability: 33 | return img 34 | 35 | for attempt in range(100): 36 | area = img.size()[1] * img.size()[2] 37 | 38 | target_area = random.uniform(self.sl, self.sh) * area 39 | aspect_ratio = random.uniform(self.r1, 1/self.r1) 40 | 41 | h = int(round(math.sqrt(target_area * aspect_ratio))) 42 | w = int(round(math.sqrt(target_area / aspect_ratio))) 43 | 44 | if w < img.size()[2] and h < img.size()[1]: 45 | x1 = random.randint(0, img.size()[1] - h) 46 | y1 = random.randint(0, img.size()[2] - w) 47 | if img.size()[0] == 3: 48 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 49 | img[1, x1:x1+h, y1:y1+w] = self.mean[1] 50 | img[2, x1:x1+h, y1:y1+w] = self.mean[2] 51 | else: 52 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 53 | return img 54 | 55 | return img 56 | -------------------------------------------------------------------------------- /sphere_loss.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class OhemSphereLoss(nn.Module): 10 | def __init__(self, in_feats, n_classes, thresh=0.7, scale=14, *args, **kwargs): 11 | super(OhemSphereLoss, self).__init__(*args, **kwargs) 12 | self.thresh = thresh 13 | self.scale = scale 14 | self.cross_entropy = nn.CrossEntropyLoss(reduction='none') 15 | self.W = torch.nn.Parameter(torch.randn(in_feats, n_classes), 16 | requires_grad = True) 17 | # nn.init.kaiming_normal_(self.W, a=1) 18 | nn.init.xavier_normal_(self.W, gain=1) 19 | 20 | def forward(self, x, label): 21 | n_examples = x.size()[0] 22 | n_pick = int(n_examples*self.thresh) 23 | x_norm = torch.norm(x, 2, 1, True).clamp(min = 1e-12).expand_as(x) 24 | x_norm = x / x_norm 25 | w_norm = torch.norm(self.W, 2, 0, True).clamp(min = 1e-12).expand_as(self.W) 26 | w_norm = self.W / w_norm 27 | cos_th = torch.mm(x_norm, w_norm) 28 | s_cos_th = self.scale * cos_th 29 | loss = self.cross_entropy(s_cos_th, label) 30 | loss, _ = torch.sort(loss, descending=True) 31 | loss = torch.mean(loss[:n_pick]) 32 | return loss 33 | 34 | 35 | class SphereLoss(nn.Module): 36 | def __init__(self, in_feats, n_classes, scale = 14, *args, **kwargs): 37 | super(SphereLoss, self).__init__(*args, **kwargs) 38 | self.scale = scale 39 | self.cross_entropy = nn.CrossEntropyLoss() 40 | self.W = torch.nn.Parameter(torch.randn(in_feats, n_classes), 41 | requires_grad = True) 42 | # nn.init.kaiming_normal_(self.W, a=1) 43 | nn.init.xavier_normal_(self.W, gain=1) 44 | 45 | def forward(self, x, label): 46 | x_norm = torch.norm(x, 2, 1, True).clamp(min = 1e-12).expand_as(x) 47 | x_norm = x / x_norm 48 | w_norm = torch.norm(self.W, 2, 0, True).clamp(min = 1e-12).expand_as(self.W) 49 | w_norm = self.W / w_norm 50 | cos_th = torch.mm(x_norm, w_norm) 51 | s_cos_th = self.scale * cos_th 52 | loss = self.cross_entropy(s_cos_th, label) 53 | return loss 54 | 55 | 56 | if __name__ == '__main__': 57 | Loss = SphereLoss(1024, 10) 58 | a = torch.randn(20, 1024) 59 | lb = torch.ones(20, dtype = torch.long) 60 | loss = Loss(a, lb) 61 | loss.backward() 62 | print(loss.detach().numpy()) 63 | print(list(Loss.parameters())[0].shape) 64 | print(type(next(Loss.parameters()))) 65 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import time 5 | import logging 6 | import os 7 | import sys 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import DataLoader 11 | import numpy as np 12 | 13 | from backbone import Network_D 14 | from sphere_loss import SphereLoss, OhemSphereLoss 15 | from market1501 import Market1501 16 | from balanced_sampler import BalancedSampler 17 | 18 | 19 | ## logging 20 | if not os.path.exists('./res/'): os.makedirs('./res/') 21 | logfile = 'sphere_reid-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S')) 22 | logfile = os.path.join('res', logfile) 23 | FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s' 24 | logging.basicConfig(level=logging.INFO, format=FORMAT, filename=logfile) 25 | logger = logging.getLogger(__name__) 26 | logger.addHandler(logging.StreamHandler()) 27 | 28 | 29 | # start_lr = 1e-2 30 | 31 | def lr_scheduler(epoch, optimizer): 32 | warmup_epoch = 20 33 | warmup_lr = 1e-5 34 | lr_steps = [90, 130] 35 | start_lr = 1e-3 36 | lr_factor = 0.1 37 | 38 | if epoch <= warmup_epoch: # lr warmup 39 | 
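# exponential warmup: choose a per-epoch multiplier such that
# warmup_lr * warmup_scale ** warmup_epoch == start_lr, so the lr
# climbs geometrically from 1e-5 to 1e-3 over the first 20 epochs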
warmup_scale = (start_lr / warmup_lr) ** (1.0 / warmup_epoch) 40 | lr = warmup_lr * (warmup_scale ** epoch) 41 | for param_group in optimizer.param_groups: 42 | param_group['lr'] = lr 43 | optimizer.defaults['lr'] = lr 44 | else: # lr jump 45 | for i, el in enumerate(lr_steps): 46 | if epoch == el: 47 | lr = start_lr * (lr_factor ** (i + 1)) 48 | logger.info('====> LR is set to: {}'.format(lr)) 49 | for param_group in optimizer.param_groups: 50 | param_group['lr'] = lr 51 | optimizer.defaults['lr'] = lr 52 | lrs = [round(el['lr'], 6) for el in optimizer.param_groups] 53 | return optimizer, lrs 54 | 55 | 56 | def train(): 57 | ## data 58 | logger.info('creating dataloader') 59 | dataset = Market1501('./dataset/Market-1501-v15.09.15/bounding_box_train', 60 | is_train = True) 61 | num_classes = dataset.get_num_classes() 62 | sampler = BalancedSampler(dataset, 16, 4) 63 | dl = DataLoader(dataset, 64 | batch_sampler = sampler, 65 | num_workers = 8) 66 | 67 | ## network and loss 68 | logger.info('setup model and loss') 69 | # criteria = SphereLoss(1024, num_classes) 70 | criteria = OhemSphereLoss(1024, num_classes, thresh=0.8) 71 | criteria.cuda() 72 | net = Network_D() 73 | net.train() 74 | net.cuda() 75 | 76 | ## optimizer 77 | logger.info('creating optimizer') 78 | params = list(net.parameters()) 79 | params += list(criteria.parameters()) 80 | optim = torch.optim.Adam(params, lr = 1e-3) 81 | 82 | ## training 83 | logger.info('start training') 84 | t_start = time.time() 85 | loss_it = [] 86 | for ep in range(150): 87 | optim, lrs = lr_scheduler(ep, optim) 88 | for it, (imgs, lbs, ids) in enumerate(dl): 89 | imgs = imgs.cuda() 90 | lbs = lbs.cuda() 91 | 92 | embs = net(imgs) 93 | loss = criteria(embs, lbs) 94 | optim.zero_grad() 95 | loss.backward() 96 | optim.step() 97 | 98 | loss_it.append(loss.detach().cpu().numpy()) 99 | ## print log 100 | t_end = time.time() 101 | t_interval = t_end - t_start 102 | log_loss = sum(loss_it) / len(loss_it) 103 | msg = 'epoch: {}, iter: {}, loss: {:.4f}, lr: {}, time: {:.4f}'.format(ep, 104 | it, log_loss, lrs, t_interval) 105 | logger.info(msg) 106 | loss_it = [] 107 | t_start = t_end 108 | 109 | ## save model 110 | torch.save(net.state_dict(), './res/model_final.pkl') 111 | logger.info('\nTraining done, model saved to {}\n\n'.format('./res/model_final.pkl')) 112 | 113 | 114 | if __name__ == '__main__': 115 | train() 116 | --------------------------------------------------------------------------------
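A quick usage sketch, not part of the repository: assuming `evaluate.py` has been run once so that `./res/embds.pkl` exists (and that every query keeps at least `cmc_max_rank` gallery candidates, which holds for Market1501), the `evaluate()` helper above can be reused on the cached embeddings to read CMC beyond rank-1:

```
import pickle

from evaluate import evaluate

# load the cached query/gallery embeddings dumped by embed()
with open('./res/embds.pkl', 'rb') as fr:
    embd_res = pickle.load(fr)

# cmc_max_rank=5 keeps the first 5 entries of each per-query cmc curve
cmc, mAP = evaluate(embd_res, cmc_max_rank=5)
print('rank-1..5 cmc: {}, mAP: {:.4f}'.format(cmc, mAP))
```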