├── README.md
├── config
├── datasets.py
├── evaluation.py
├── evaluation_old.py
├── layers.py
├── losses.py
├── main.py
├── samplers.py
└── visualizer.py

/README.md:
--------------------------------------------------------------------------------
# Histogram Loss

This is a PyTorch implementation of the paper [Learning Deep Embeddings with Histogram Loss](https://arxiv.org/pdf/1611.00822.pdf).

See the original code [here](https://github.com/madkn/HistogramLoss).

## Implementation details

A pretrained ResNet-34 is used as the backbone, with a fully connected layer of 512 neurons added at the end of the net.

Features should be [L2-normalized](https://github.com/valerystrizh/pytorch-histogram-loss/blob/master/layers.py#L30) before being fed to the histogram loss.

The [Market-1501 dataset](http://www.liangzheng.org/Project/project_reid.html) is used for training and testing.

Loss, rank-1 and mAP metrics are visualized using [visdom](https://github.com/facebookresearch/visdom).
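
A minimal sketch of how main.py assembles this embedding head (the dropout value comes from the config):

```python
import torch.nn as nn
from torchvision import models
from layers import L2Normalization

model = models.resnet34(pretrained=True)
model.fc = nn.Sequential(
    nn.Dropout(0.7),       # dropout_prob from the config
    nn.Linear(512, 512),   # resnet34's fc in_features is 512
    L2Normalization(),     # unit-norm embeddings for the histogram loss
)
```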
## Quality
rank-1: 77.02

mAP: 54.71

## Usage
Edit the [config file](https://github.com/valerystrizh/pytorch-histogram-loss/blob/master/config) to set your parameters:

```
--dataroot DATAROOT   path to dataset
--batch_size BATCH_SIZE
                      batch size for train, default=128
--batch_size_test BATCH_SIZE_TEST
                      batch size for test and query dataloaders for the Market
                      dataset, default=64
--checkpoints_path CHECKPOINTS_PATH
                      folder to output model checkpoints, default="."
--cuda                enables cuda
--dropout_prob DROPOUT_PROB
                      probability of dropout, default=0.7
--lr LR               learning rate, default=1e-4
--lr_fc LR_FC         learning rate to train the fc layer, default=1e-1
--manual_seed MANUAL_SEED
                      manual seed
--market              calculate rank-1 and mAP on the Market dataset; dataroot
                      should contain the folders "bounding_box_train",
                      "bounding_box_test", "query"
--nbins NBINS         number of bins in histograms, default=150
--nepoch NEPOCH       number of epochs to train, default=150
--nepoch_fc NEPOCH_FC
                      number of epochs to train the fc layer, default=0
--nworkers NWORKERS   number of data loading workers, default=10
--visdom_port VISDOM_PORT
                      port for visdom visualization
```

```
$ # start the visdom server
$ python -m visdom.server -port 8099
$ python main.py
```
--------------------------------------------------------------------------------
/config:
--------------------------------------------------------------------------------
{
    "dataroot": "../data/Market-1501-v15.09.15",
    "batch_size": 128,
    "batch_size_test": 64,
    "checkpoints_path": "histogram",
    "cuda": true,
    "dropout_prob": 0.7,
    "lr": 0.0001,
    "lr_fc": 0.1,
    "manual_seed": 18,
    "market": true,
    "nbins": 151,
    "nepoch": 150,
    "nepoch_fc": 0,
    "nworkers": 10,
    "visdom_port": 8099
}
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
import numpy as np
import torch

from PIL import Image
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self, paths, transform, is_train, labels=None):
        self.img_paths = np.array(paths)
        self.transform = transform
        self.is_train = is_train
        if self.is_train:
            self.labels = LabelEncoder().fit_transform(labels)
            # np.float64 spelled out: np.float is a removed numpy alias for float
            self.labels = torch.from_numpy(np.array(self.labels, dtype=np.float64))

    def __getitem__(self, index):
        img_path = self.img_paths[index]
        img = self.transform(Image.open(img_path))
        if self.is_train:
            label = self.labels[index]
            return img, label
        else:
            return img

    def __len__(self):
        return len(self.img_paths)
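
# A quick, hypothetical smoke test (a sketch; the paths below are made up, so
# only construction and __len__ are exercised; __getitem__ needs real images):
if __name__ == '__main__':
    from torchvision import transforms
    tf = transforms.Compose([transforms.Resize([256, 256]), transforms.ToTensor()])
    ds = ImageDataset(['0002_c1s1_000451_03.jpg', '0002_c1s1_000551_01.jpg'],
                      transform=tf, is_train=True, labels=[2, 2])
    print(len(ds))  # 2; ds[0] would return an (image, label) pair for a real file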
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

class Evaluation():
    def __init__(self, df_test, df_query, dataloader_test, dataloader_query, cuda):
        self.test_labels = np.array(df_test['label'])
        self.test_cameras = np.array(df_test['camera'])
        self.query_labels = np.expand_dims(np.array(df_query['label']), 1)
        self.query_cameras = np.expand_dims(np.array(df_query['camera']), 1)
        # Market-1501 convention: label 0 marks distractors, label -1 marks junk
        self.distractors = self.test_labels == 0
        self.junk = self.test_labels == -1

        self.test_labels = torch.Tensor(self.test_labels)
        self.test_cameras = torch.Tensor(self.test_cameras)
        self.query_labels = torch.Tensor(self.query_labels)
        self.query_cameras = torch.Tensor(self.query_cameras)
        self.distractors = torch.Tensor(self.distractors.astype(int))
        self.junk = torch.Tensor(self.junk.astype(int))

        if cuda:
            self.test_labels = self.test_labels.cuda()
            self.test_cameras = self.test_cameras.cuda()
            self.query_labels = self.query_labels.cuda()
            self.query_cameras = self.query_cameras.cuda()
            self.distractors = self.distractors.cuda()
            self.junk = self.junk.cuda()

        self.dataloader_test = dataloader_test
        self.dataloader_query = dataloader_query
        self.cuda = cuda

    def ranks_map(self, model, maxrank, remove_fc=False, features_normalized=True):
        if remove_fc:
            model = nn.Sequential(*list(model.children())[:-1])
        test_descriptors = self.descriptors(self.dataloader_test, model)
        query_descriptors = self.descriptors(self.dataloader_query, model)

        # cosine distances between query and test descriptors
        if features_normalized:
            dists = 1 - torch.mm(query_descriptors, test_descriptors.transpose(1, 0))
        else:
            dists = torch.mm(query_descriptors, test_descriptors.transpose(1, 0))
            dists = dists / torch.norm(query_descriptors, 2, 1).unsqueeze(1)
            dists = dists / torch.norm(test_descriptors, 2, 1)
            dists = 1 - dists

        dists_sorted, dists_sorted_inds = torch.sort(dists)

        # sort test data by the indices that sort the distances
        def sort_by_dists_inds(data):
            return torch.gather(data.repeat(self.query_labels.shape[0], 1), 1, dists_sorted_inds)

        test_sorted_labels = sort_by_dists_inds(self.test_labels)
        test_sorted_cameras = sort_by_dists_inds(self.test_cameras)
        sorted_distractors = sort_by_dists_inds(self.distractors).byte()
        sorted_junk = sort_by_dists_inds(self.junk).byte()

        # junk images (including same-person same-camera matches) are skipped entirely,
        # unlike distractors, so a cumulative sum of junk positions is kept to shift ranks later
        sorted_junk = (sorted_junk |
                       (test_sorted_labels == self.query_labels) &
                       (test_sorted_cameras == self.query_cameras))
        junk_cumsum = torch.cumsum(sorted_junk.int(), 1)

        # indices where query labels equal test labels, excluding distractors and junk
        eq_inds = torch.nonzero(~sorted_distractors & ~sorted_junk & (self.query_labels == test_sorted_labels))
        eq_inds_rows = eq_inds[:, 0].long()
        eq_inds_cols = eq_inds[:, 1].long()
        eq_inds_first = np.unique(eq_inds_rows.cpu().numpy(), return_index=True)[1]
        # subtract the junk cumulative sum from the match columns
        eq_inds_cols_nojunk = (eq_inds_cols - junk_cumsum[eq_inds_rows, eq_inds_cols]).cpu().numpy()

        ranks = self.ranks(maxrank, eq_inds_first, eq_inds_cols_nojunk)
        mAP = self.mAP(eq_inds_first, eq_inds_cols_nojunk)

        return ranks, mAP

    def descriptors(self, dataloader, model):
        result = torch.FloatTensor()
        if self.cuda:
            result = result.cuda()
        for data in dataloader:
            if self.cuda:
                data = data.cuda()
            inputs = Variable(data)
            outputs = model(inputs)
            result = torch.cat((result, outputs.data), 0)

        return result

    def ranks(self, maxrank, eq_inds_first, eq_inds_cols):
        eq_inds_cols_first = eq_inds_cols[eq_inds_first]
        eq_inds_cols_first_maxrank = eq_inds_cols_first[eq_inds_cols_first < maxrank]
        ranks = np.zeros(maxrank)
        np.add.at(ranks, eq_inds_cols_first_maxrank, 1)
        ranks = np.cumsum(ranks)

        return ranks / self.query_labels.shape[0]

    def mAP(self, eq_inds_first, eq_inds_cols):
        labels_count = np.append(eq_inds_first[1:], eq_inds_cols.shape[0]) - eq_inds_first
        inds_start_repeat = np.repeat(eq_inds_first, labels_count)
        labels_count_repeat = np.repeat(labels_count, labels_count)
        average_precision = np.sum((np.arange(eq_inds_cols.shape[0]) - inds_start_repeat + 1) /
                                   (eq_inds_cols + 1) /
                                   labels_count_repeat)

        return average_precision / self.query_labels.shape[0]
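
# How main.py drives this class (a sketch, not standalone; the dataframes and
# dataloaders are built in main.py):
#
#     evaluation = Evaluation(dfs_test['test'], dfs_test['query'],
#                             dataloaders_test['test'], dataloaders_test['query'],
#                             opt['cuda'])
#     ranks, mAP = evaluation.ranks_map(model, 2)  # main.py logs ranks[1] as Rank1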
--------------------------------------------------------------------------------
/evaluation_old.py:
--------------------------------------------------------------------------------
import numpy as np
import os
import torch
import torch.nn as nn

from scipy.spatial.distance import cdist
from torch.autograd import Variable

class Evaluation():
    def __init__(self, df_test, df_query, dataloader_test, dataloader_query, cuda):
        self.test_labels, self.test_camera_labels, self.test_image_paths, self.test_image_names = get_info(df_test)
        self.query_labels, self.query_camera_labels, self.query_image_paths, self.query_image_names = get_info(df_query)
        self.Index = make_index(self.test_image_names, self.query_image_names)
        self.dataloader_test = dataloader_test
        self.dataloader_query = dataloader_query
        self.cuda = cuda

    def ranks_map(self, model, nr, remove_fc=False):
        if remove_fc:
            model = nn.Sequential(*list(model.children())[:-1])

        descr_gallery = get_descr(self.dataloader_test, model, self.cuda)
        descr_query = get_descr(self.dataloader_query, model, self.cuda)

        ranks, img_names_sorted, places = ranking('cosine', descr_query, self.query_image_names, descr_gallery, self.test_image_names, nr, self.Index)
        ranks_w = ranks / len(descr_query) * 1.0
        mAP_ = mAP(descr_query, self.query_image_names, descr_gallery, self.test_image_names, nr, self.Index)

        return ranks_w, mAP_

def get_descr(dataloader, model, use_gpu):
    result = []

    for data in dataloader:
        if use_gpu:
            inputs = Variable(data.cuda())
        else:
            inputs = Variable(data)

        out = model(inputs)
        result.extend(out.data.squeeze().cpu().numpy())

    return np.array(result)

def get_info(df):
    labels = np.array(df['label'])
    camera_labels = np.array(df['camera'])
    image_paths = np.array(df['path'])
    image_names = np.array(df['name'])

    return labels, camera_labels, image_paths, image_names

def make_index(test_image_names, query_image_names):
    Index = dict()
    Index['junk'] = set()
    Index['distractor'] = set()
    for name in test_image_names:
        if ifJunk(name):
            Index['junk'].add(name)
        elif ifDistractor(name):
            Index['distractor'].add(name)
    for query in query_image_names:
        Index[query] = dict()
        Index[query]['pos'] = set()
        Index[query]['junk'] = set()

        person, camera = parse_market_1501_name(query)
        for name in test_image_names:
            if ifJunk(name) or ifDistractor(name):
                continue
            person_, camera_ = parse_market_1501_name(name)
            if person == person_ and camera != camera_:
                Index[query]['pos'].add(name)
            # same person seen by the same camera counts as junk
            elif person == person_ and camera == camera_:
                Index[query]['junk'].add(name)

    return Index

def parse_market_1501_name(full_name):
    name_ar = full_name.split('/')
    name = name_ar[len(name_ar) - 1]

    person = int(name.split('_')[0])
    camera = int(name.split('_')[1].split('s')[0].split('c')[1])

    return person, camera
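
# Worked example on a hypothetical Market-1501-style filename:
#     parse_market_1501_name('0001_c1s1_001051_00.jpg') -> (1, 1)
# '0001' is the person id and 'c1' the camera id.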

def parseMarket1501(path):
    person_label = []
    camera_label = []
    image_path = []
    image_name = []

    for file in sorted(os.listdir(path)):
        if file.endswith(".jpg"):
            person, camera = parse_market_1501_name(file)

            person_label.append(person)
            camera_label.append(camera)
            image_path.append(os.path.join(path, file))
            image_name.append(file)

    return person_label, camera_label, image_path, image_name

def ifJunk(filename):
    return filename.startswith("-1")

def ifDistractor(filename):
    return filename.startswith("0000")

def getPlace(query, sorted_gallery_filenames, Index):
    place = 0
    for i in range(len(sorted_gallery_filenames)):
        if sorted_gallery_filenames[i] in Index['junk'] or sorted_gallery_filenames[i] in Index[query]['junk']:
            continue
        elif sorted_gallery_filenames[i] in Index['distractor']:
            place += 1
        elif sorted_gallery_filenames[i] in Index[query]['pos']:
            return place
        else:
            place += 1

    return place

def ranking(metric, descrs_query, query_image_names, descrs_gallery, test_image_names, maxrank, Index):
    ranks = np.zeros(maxrank + 1)
    places = dict()
    all_dist = cdist(descrs_query, descrs_gallery, metric)
    np_test_image_names = np.array(test_image_names)
    img_names_sorted = dict()

    all_gallery_names_sorted = np_test_image_names[np.argsort(all_dist).astype(np.uint32)]
    for qind in range(len(descrs_query)):
        gallery_names_sorted = all_gallery_names_sorted[qind]

        place = getPlace(query_image_names[qind], gallery_names_sorted, Index)
        img_names_sorted[qind] = all_gallery_names_sorted[qind]
        places[qind] = place

        ranks[place + 1:maxrank + 1] += 1

    return ranks, img_names_sorted, places

def cos_dist(x, y):
    xy = np.dot(x, y)
    xx = np.dot(x, x)
    yy = np.dot(y, y)

    return -xy * 1.0 / np.sqrt(xx * yy)

def getDistances(descr_query, descrs_gallery):
    dist = list()

    for i in range(len(descrs_gallery)):
        dist.append(cos_dist(descr_query, descrs_gallery[i]))

    return dist

def getAveragePrecision(query, sorted_gallery_filenames, Index):
    ap = 0
    tp = 0
    k = 0

    for i in range(len(sorted_gallery_filenames)):
        if sorted_gallery_filenames[i] in Index['junk'] or sorted_gallery_filenames[i] in Index[query]['junk']:
            continue
        elif sorted_gallery_filenames[i] in Index['distractor']:
            k += 1
            deltaR = 0
        elif sorted_gallery_filenames[i] in Index[query]['pos']:
            tp += 1
            k += 1
            deltaR = 1.0 / len(Index[query]['pos'])
        else:
            k += 1
            deltaR = 0

        precision = tp * 1.0 / k * deltaR
        ap += precision
        if tp == len(Index[query]['pos']):
            return ap

    return ap

def mAP(descrs_query, query_image_names, descrs_gallery, test_image_names, maxrank, Index):
    ap_sum = 0

    for qind in range(len(descrs_query)):
        dist = getDistances(descrs_query[qind], descrs_gallery)
        dist_zip = sorted(zip(dist, test_image_names))
        gallery_names_sorted = [x for (y, x) in dist_zip]

        ap = getAveragePrecision(query_image_names[qind], gallery_names_sorted, Index)
        ap_sum += ap

    return ap_sum * 1.0 / len(descrs_query)
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

class DropoutShared(nn.Module):
    """Dropout that zeroes the same feature indices for every sample in the batch."""

    def __init__(self, p=0.5, use_gpu=True):
        super(DropoutShared, self).__init__()
        if p < 0 or p > 1:
            raise ValueError("dropout probability has to be between 0 and 1, "
                             "but got {}".format(p))
        self.p = p
        self.use_gpu = use_gpu

    def forward(self, input):
        if self.training:
            # feature indices dropped with probability p, shared across the batch
            index = torch.arange(0, input.size()[1])[torch.Tensor(input.size()[1]).uniform_(0, 1).le(self.p)].long()
            input_cloned = input.clone()
            if self.use_gpu:
                input_cloned[:, index.cuda()] = 0
            else:
                input_cloned[:, index] = 0
            return input_cloned / (1 - self.p)
        else:
            return input

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
            + 'p=' + str(self.p) + ')'


class L2Normalization(nn.Module):
    def __init__(self):
        super(L2Normalization, self).__init__()

    def forward(self, input):
        input = input.squeeze()
        return input.div(torch.norm(input, dim=1).view(-1, 1))

    def __repr__(self):
        return self.__class__.__name__
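
# A minimal check of L2Normalization (a sketch; the shapes are made up):
if __name__ == '__main__':
    x = torch.randn(4, 512, 1, 1)    # e.g. pooled CNN features
    y = L2Normalization()(x)         # squeezed to (4, 512), rows divided by their norms
    print(torch.norm(y, dim=1))      # ~tensor([1., 1., 1., 1.])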
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
import torch

from numpy.testing import assert_almost_equal

class HistogramLoss(torch.nn.Module):
    def __init__(self, num_steps, cuda=True):
        super(HistogramLoss, self).__init__()
        self.step = 2 / (num_steps - 1)
        self.eps = 1 / num_steps
        self.cuda = cuda
        # bin nodes t spanning [-1, 1], the range of cosine similarities
        self.t = torch.arange(-1, 1 + self.step, self.step).view(-1, 1)
        self.tsize = self.t.size()[0]
        if self.cuda:
            self.t = self.t.cuda()

    def forward(self, features, classes):
        def histogram(inds, size):
            s_repeat_ = s_repeat.clone()
            indsa = (s_repeat_floor - (self.t - self.step) > -self.eps) & (s_repeat_floor - (self.t - self.step) < self.eps) & inds
            assert indsa.nonzero().size()[0] == size, 'Another number of bins should be used'
            zeros = torch.zeros((1, indsa.size()[1])).byte()
            if self.cuda:
                zeros = zeros.cuda()
            indsb = torch.cat((indsa, zeros))[1:, :]
            s_repeat_[~(indsb | indsa)] = 0
            # indsa corresponds to the first condition of the second equation of the paper
            s_repeat_[indsa] = (s_repeat_ - self.t + self.step)[indsa] / self.step
            # indsb corresponds to the second condition of the second equation of the paper
            s_repeat_[indsb] = (-s_repeat_ + self.t + self.step)[indsb] / self.step

            return s_repeat_.sum(1) / size

        classes_size = classes.size()[0]
        classes_eq = (classes.repeat(classes_size, 1) == classes.view(-1, 1).repeat(1, classes_size)).data
        dists = torch.mm(features, features.transpose(0, 1))
        assert ((dists > 1 + self.eps).sum().item() + (dists < -1 - self.eps).sum().item()) == 0, 'L2 normalization should be used'
        # the upper triangle selects each unordered pair once, excluding the diagonal
        s_inds = torch.triu(torch.ones(classes_eq.size()), 1).byte()
        if self.cuda:
            s_inds = s_inds.cuda()
        pos_inds = classes_eq[s_inds].repeat(self.tsize, 1)
        neg_inds = ~classes_eq[s_inds].repeat(self.tsize, 1)
        pos_size = classes_eq[s_inds].sum().item()
        neg_size = (~classes_eq[s_inds]).sum().item()
        s = dists[s_inds].view(1, -1)
        s_repeat = s.repeat(self.tsize, 1)
        s_repeat_floor = (torch.floor(s_repeat.data / self.step) * self.step).float()

        histogram_pos = histogram(pos_inds, pos_size)
        assert_almost_equal(histogram_pos.sum().item(), 1, decimal=1,
                            err_msg='Not good positive histogram', verbose=True)
        histogram_neg = histogram(neg_inds, neg_size)
        assert_almost_equal(histogram_neg.sum().item(), 1, decimal=1,
                            err_msg='Not good negative histogram', verbose=True)
        histogram_pos_repeat = histogram_pos.view(-1, 1).repeat(1, histogram_pos.size()[0])
        histogram_pos_inds = torch.tril(torch.ones(histogram_pos_repeat.size()), -1).byte()
        if self.cuda:
            histogram_pos_inds = histogram_pos_inds.cuda()
        histogram_pos_repeat[histogram_pos_inds] = 0
        # CDF of the positive histogram; the loss integrates it against the negative histogram
        histogram_pos_cdf = histogram_pos_repeat.sum(0)
        loss = torch.sum(histogram_neg * histogram_pos_cdf)

        return loss
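
# A quick CPU smoke test (a sketch; the sizes are made up, and it assumes a
# PyTorch version that still accepts the uint8 masks used above):
if __name__ == '__main__':
    import torch.nn.functional as F
    torch.manual_seed(0)
    features = F.normalize(torch.randn(32, 16), dim=1)  # the loss expects L2-normalized rows
    classes = torch.randint(0, 4, (32,)).float()        # a few identities, as in training
    criterion = HistogramLoss(num_steps=151, cuda=False)
    # scalar estimating the probability that a negative pair is more similar than a positive pair
    print(criterion(features, classes))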
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import json
import os
import pandas as pd
import random
import torch
import torch.nn as nn
import torch.optim as optim

from glob import glob
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision import models, transforms

from datasets import ImageDataset
from evaluation import Evaluation
from layers import L2Normalization
from losses import HistogramLoss
from samplers import MarketSampler
from visualizer import Visualizer

with open('config') as json_file:
    opt = json.load(json_file)
print(opt)

try:
    os.makedirs(opt['checkpoints_path'])
except OSError:
    pass

if opt['manual_seed'] is None:
    opt['manual_seed'] = random.randint(1, 10000)
print("Random Seed: ", opt['manual_seed'])
random.seed(opt['manual_seed'])
torch.manual_seed(opt['manual_seed'])
if opt['cuda']:
    torch.cuda.manual_seed_all(opt['manual_seed'])

if torch.cuda.is_available() and not opt['cuda']:
    print("WARNING: You have a CUDA device, so you should probably run with cuda enabled")

vis = Visualizer(opt['checkpoints_path'], opt['visdom_port'])

def create_df(dataroot, size=-1):
    df_paths = glob(dataroot)
    df = pd.DataFrame({'path': df_paths})
    df['label'] = df.path.apply(lambda x: int(x.split('/')[-1].split('_')[0]))
    return df[:size]

if not opt['market']:
    df_train = create_df(os.path.join(opt['dataroot'], '*.jpg'))
else:
    def create_market_df(x):
        df = create_df(os.path.join(opt['dataroot'], paths[x]))
        df['camera'] = df.path.apply(lambda x: int(x.split('/')[-1].split('_')[1].split('s')[0].split('c')[1]))
        df['name'] = df.path.apply(lambda x: x.split('/')[-1])
        return df

    paths = {
        'train': 'bounding_box_train/*.jpg',
        'test': 'bounding_box_test/*.jpg',
        'query': 'query/*.jpg',
    }

    df_train = create_market_df('train')
    dfs_test = {
        x: create_market_df(x) for x in ['test', 'query']
    }

    data_transform_test = transforms.Compose([
        transforms.Resize([256, 256]),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    datasets_test = {
        x: ImageDataset(
            dfs_test[x]['path'],
            transform=data_transform_test,
            is_train=False
        ) for x in ['test', 'query']
    }

    dataloaders_test = {
        x: DataLoader(
            datasets_test[x],
            batch_size=opt['batch_size_test'],
            shuffle=False,
            num_workers=opt['nworkers']
        ) for x in datasets_test.keys()
    }

    evaluation = Evaluation(dfs_test['test'], dfs_test['query'], dataloaders_test['test'], dataloaders_test['query'], opt['cuda'])

data_transform = transforms.Compose([
    transforms.Resize([256, 256]),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

dataset = ImageDataset(df_train['path'], data_transform, True, df_train['label'])
sampler = MarketSampler(df_train['label'], opt['batch_size'])
dataloader = DataLoader(dataset, batch_sampler=sampler, num_workers=opt['nworkers'])

def train(optimizer, criterion, scheduler, epoch_start, epoch_end):
    for epoch in range(epoch_start, epoch_end):
        scheduler.step()
        model.train(True)
        running_loss = .0

        for data in dataloader:
            inputs, labels = data
            inputs, labels = inputs.squeeze(), labels.squeeze()

            if opt['cuda']:
                inputs, labels = inputs.cuda(), labels.cuda()

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()
            running_loss += loss.data.item()

        epoch_loss = running_loss / len(dataloader)

        vis.quality('Loss', {'Loss': epoch_loss}, epoch, opt['nepoch'])

        if opt['market']:
            if epoch % 5 == 0:
                model.train(False)
                ranks, mAP = evaluation.ranks_map(model, 2)
                vis.quality('Rank1 and mAP', {'Rank1': ranks[1], 'mAP': mAP}, epoch, opt['nepoch'])

        if epoch % 10 == 0:
            torch.save(model, '{}/finetuned_histogram_e{}.pt'.format(opt['checkpoints_path'], epoch))
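
# The code below trains in two phases: if nepoch_fc > 0, only the new fc head
# is trained first (backbone frozen, lr_fc); then all layers are unfrozen and
# fine-tuned at lr until nepoch.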
model = models.resnet34(pretrained=True)
for param in model.parameters():
    param.requires_grad = False

num_ftrs = model.fc.in_features
model.fc = torch.nn.Sequential()
if opt['dropout_prob'] > 0:
    model.fc.add_module('dropout', nn.Dropout(opt['dropout_prob']))
model.fc.add_module('fc', nn.Linear(num_ftrs, 512))
model.fc.add_module('l2normalization', L2Normalization())
if opt['cuda']:
    model = model.cuda()
print(model)

criterion = HistogramLoss(num_steps=opt['nbins'], cuda=opt['cuda'])

if opt['nepoch_fc'] > 0:
    print('\nTrain fc layer\n')
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=opt['lr_fc'])
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    train(optimizer, criterion, scheduler, 1, opt['nepoch_fc'] + 1)

print('\nTrain all layers\n')
for param in model.parameters():
    param.requires_grad = True

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=opt['lr'])
scheduler = lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)
train(optimizer, criterion, scheduler, opt['nepoch_fc'] + 1, opt['nepoch'] + 1)
--------------------------------------------------------------------------------
/samplers.py:
--------------------------------------------------------------------------------
import numpy as np

from torch.utils.data.sampler import Sampler

class MarketSampler(Sampler):
    def __init__(self, labels, batch_size):
        self.labels = np.array(labels)
        self.labels_unique = np.unique(labels)
        self.batch_size = batch_size

    def __iter__(self):
        for i in range(self.__len__()):
            labels_in_batch = set()
            inds = np.array([], dtype=int)  # plain int: np.int is a removed numpy alias

            while inds.shape[0] < self.batch_size:
                sample_label = np.random.choice(self.labels_unique)
                if sample_label in labels_in_batch:
                    continue

                labels_in_batch.add(sample_label)
                # take 5-10 images of the chosen identity per batch
                subsample_size = np.random.choice(range(5, 11))
                sample_label_ids = np.argwhere(np.in1d(self.labels, sample_label)).reshape(-1)
                subsample = np.random.permutation(sample_label_ids)[:subsample_size]
                inds = np.append(inds, subsample)

            inds = inds[:self.batch_size]
            yield list(inds)

    def __len__(self):
        return len(self.labels) // self.batch_size
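
# A quick look at one batch of indices (a sketch; the label layout is made up):
if __name__ == '__main__':
    labels = np.repeat(np.arange(10), 20)          # 10 identities x 20 images each
    sampler = MarketSampler(labels, batch_size=32)
    batch = next(iter(sampler))
    print(len(batch))                              # 32 indices, grouped by identity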
--------------------------------------------------------------------------------
/visualizer.py:
--------------------------------------------------------------------------------
import numpy as np
import os
import time

class Visualizer():
    def __init__(self, checkpoints_path, visdom_port=None):
        if visdom_port is not None:
            import visdom
            self.vis = visdom.Visdom(port=visdom_port)
            self.use_vis = True
        else:
            self.use_vis = False

        self.checkpoints_path = checkpoints_path
        self.log_name = os.path.join(checkpoints_path, 'loss_log.txt')
        now = time.strftime('%c')
        self.start_time = time.time()

        with open(self.log_name, 'a') as log_file:
            log_file.write('Training {} \n'.format(now))

    def plot_quality(self, name, quality, epoch):
        if not hasattr(self, 'plot_data'):
            self.plot_data = {}
        if name not in self.plot_data:
            self.plot_data[name] = {'X': [], 'Y': [], 'legend': list(quality.keys())}

        self.plot_data[name]['X'].append(epoch)
        self.plot_data[name]['Y'].append([quality[k] for k in self.plot_data[name]['legend']])
        self.vis.line(
            X=np.stack([np.array(self.plot_data[name]['X'])] * len(self.plot_data[name]['legend']), 1),
            Y=np.array(self.plot_data[name]['Y']),
            opts={
                'title': name,
                'legend': self.plot_data[name]['legend'],
                'xlabel': 'epoch'},
            win=name)

    def print_quality(self, quality, epoch, epochs):
        message = '[Epoch {}/{}] Time elapsed: {:.2f}; '.format(epoch, epochs, time.time() - self.start_time)
        for k, v in quality.items():
            message += '{}: {:.4f}; '.format(k, v)

        print(message)
        with open(self.log_name, 'a') as log_file:
            log_file.write('{}\n'.format(message))

    def quality(self, name, quality, epoch, epochs):
        self.print_quality(quality, epoch, epochs)
        if self.use_vis:
            self.plot_quality(name, quality, epoch)
--------------------------------------------------------------------------------