├── .gitignore
├── README.md
├── check.sh
├── data
│   ├── debug.txt
│   ├── reid.data
│   └── test_debug.txt
├── dataset
│   ├── __init__.py
│   ├── dataset.py
│   ├── huaijin.jpg
│   ├── person-1.jpg
│   └── transforms.py
├── module
│   ├── TripletLoss.py
│   ├── __init__.py
│   ├── l2normalize.py
│   └── symbols.py
├── tools
│   ├── __init__.py
│   ├── test.py
│   ├── train.py
│   └── utils.py
└── train_val.py
-------------------------------------------------------------------------------- /.gitignore: --------------------------------------------------------------------------------
*.o
*.dSYM
*.csv
*.out
*.png
*.jpg
caffe/
grasp/
images/
opencv/
convnet/
decaf/
submission/
cfg/
darknet
.fuse*
models/
*.pyc
*log.txt
*train.txt
*val.txt
*.weights
data/*
*.txt

# OS Generated #
.DS_Store*
ehthumbs.db
Icon?
Thumbs.db
*.swp
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
# Triplet Margin Loss for Person Re-identification

This project implements person re-identification with the [triplet loss](https://arxiv.org/abs/1503.03832) in PyTorch.

- [x] triplet margin loss
- [x] load weights from a darknet weights file
- [x] save weights file in the darknet format
- [x] find the best threshold for the test set
- [ ] more network structures
- [ ] more tricks used in ReID
- [x] faster training on multiple GPUs
- [ ] load and save caffemodel



## Training and validation

1. Create a triplet list file and put it in `data/` (see the format example below).
2. Set `image_root` to `your/images/path`.

```
python train_val.py --gpus 0,1,2,3
```
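The expected list format is inferred from `dataset/dataset.py`, which splits each line into three whitespace-separated image paths relative to `image_root` — anchor, positive, negative. The file names below are only placeholders:

```
person_0001/a.jpg person_0001/b.jpg person_0977/c.jpg
person_0002/d.jpg person_0002/e.jpg person_0341/f.jpg
```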
## Time Cost

Time cost on 4 TITAN X GPUs:

batch size | cost (ms) / 1 triplet sample
--- | ---
64 | 1.31
128 | 0.96
256 | 0.35
512 | 0.20
1024 | 0.18

## Reference

- FaceNet: A Unified Embedding for Face Recognition and Clustering

-------------------------------------------------------------------------------- /check.sh: --------------------------------------------------------------------------------
#!/bin/bash
#########################################################################
# Author: Chao CHEN
# Created On: 2017-06-16
#########################################################################
# pick out the 'Test-T-A' log lines and sort them by field 7 (the accuracy)
grep A log.txt | sort -k7
-------------------------------------------------------------------------------- /data/reid.data: --------------------------------------------------------------------------------
train=train.txt
val=val.txt
gpus=0,1,2,3
-------------------------------------------------------------------------------- /dataset/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/huaijin-chen/pytorch-PersonReID/719bc9422ffd6e7787432a0db1474d0b0b3eadec/dataset/__init__.py
-------------------------------------------------------------------------------- /dataset/dataset.py: --------------------------------------------------------------------------------
#!/usr/bin/python
# encoding: utf-8

import os, random
import os.path as osp

import torch
from torch.utils.data import Dataset
from PIL import Image


class listDataset(Dataset):
    """Triplet dataset: each line of `filename` holds the paths of an
    anchor, a positive and a negative image, relative to `root`."""

    def __init__(self, root, filename, shuffle=True, transform=None,
                 target_transform=None, is_visualization=False):
        self.root = root
        with open(filename, 'r') as file:
            print(filename)
            self.lines = file.readlines()

        if shuffle:
            random.shuffle(self.lines)

        self.is_visualization = is_visualization
        self.nSamples = len(self.lines)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        assert index < len(self), 'index range error'
        anchor, pos, neg = self.lines[index].split()
        imga = Image.open(os.path.join(self.root, anchor)).convert('RGB')
        imgp = Image.open(os.path.join(self.root, pos)).convert('RGB')
        imgn = Image.open(os.path.join(self.root, neg)).convert('RGB')
        if self.is_visualization:
            anchor_path = osp.join(self.root, anchor)
            pos_path = osp.join(self.root, pos)
            neg_path = osp.join(self.root, neg)
            image_paths = (anchor_path, pos_path, neg_path)

        if self.transform is not None:
            imga = self.transform(imga)
            imgp = self.transform(imgp)
            imgn = self.transform(imgn)
        label = torch.LongTensor([1, 0])
        if self.target_transform is not None:
            label = self.target_transform(label)
        if not self.is_visualization:
            return imga, imgp, imgn, label
        else:
            return imga, imgp, imgn, label, image_paths
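# A minimal usage sketch (the root and list file are placeholders;
# tools/utils.data_loader wraps the same construction with the
# project's crop/scale transforms):
if __name__ == '__main__':
    from torchvision import transforms
    ds = listDataset('/your/images/path', 'data/val.txt',
                     shuffle=False, transform=transforms.ToTensor())
    imga, imgp, imgn, label = ds[0]
    print(imga.size(), label.size())  # e.g. (3, H, W) and (2,)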
32 | """ 33 | w, h = img.size 34 | if self.crop_type == 0: 35 | th, tw = self.size 36 | th = max(th, h-1) 37 | tw = max(tw, w-1) 38 | x1 = 0 39 | y1 = 0 40 | return img.crop((x1, y1, x1 + tw, y1 + th)) 41 | elif self.crop_type == 1: 42 | r_w, r_h = self.ratio 43 | tw = int(w*r_w) 44 | th = int(h*r_h) 45 | return img.crop((0, 0, tw, th)) 46 | else: 47 | print('crop_type error.') 48 | 49 | class scale(object): 50 | def __init__(self, size, interpolation=Image.BILINEAR): 51 | self.size = size; 52 | self.interpolation = interpolation 53 | 54 | def __call__(self, img): 55 | img = img.resize(self.size, self.interpolation) 56 | #print(img.size) 57 | return img 58 | 59 | if '__main__' == __name__: 60 | img_path = '/data/chenchao/darknet/data/person.jpg' 61 | img = Image.open(img_path).convert('RGB') 62 | pc = person_crop() 63 | img_crop = pc(img) 64 | img_crop.save('person-1.jpg') 65 | 66 | -------------------------------------------------------------------------------- /module/TripletLoss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # -------------------------------------------------------- 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Chao CHEN (chaochancs@gmail.com) 6 | # Created On: 2017-06-08 7 | # -------------------------------------------------------- 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | def cosine_similarity(x1, x2, dim=1, eps=1e-8): 13 | r"""Returns cosine similarity between x1 and x2, computed along dim. 14 | 15 | Args: 16 | x1 (Variable): First input. 17 | x2 (Variable): Second input (of size matching x1). 18 | dim (int, optional): Dimension of vectors. Default: 1 19 | eps (float, optional): Small value to avoid division by zero. Default: 1e-8 20 | 21 | Shape: 22 | - Input: :math:`(\ast_1, D, \ast_2)` where D is at position `dim`. 23 | - Output: :math:`(\ast_1, \ast_2)` where 1 is at position `dim`. 
24 | """ 25 | w12 = torch.sum(x1 * x2, dim) 26 | w1 = torch.norm(x1, 2, dim) 27 | w2 = torch.norm(x2, 2, dim) 28 | return (w12 / (w1 * w2).clamp(min=eps)).squeeze() 29 | 30 | def cos_distance(self, a, b): 31 | return torch.dot(a, b)/(torch.norm(a)*torch.norm(b)) 32 | 33 | class TripletMarginLoss(nn.Module): 34 | def __init__(self, margin, use_ohem=False, ohem_bs=128, dist_type = 0): 35 | super(TripletMarginLoss, self).__init__() 36 | self.margin = margin 37 | self.dist_type = dist_type 38 | self.use_ohem = use_ohem 39 | self.ohem_bs = ohem_bs 40 | #print('Use_OHEM : ',self.use_ohem) 41 | 42 | def forward(self, anchor, positive, negative): 43 | #eucl distance 44 | #dist = torch.sum( (anchor - positive) ** 2 - (anchor - negative) ** 2, dim=1)\ 45 | # + self.margin 46 | 47 | if self.dist_type == 0: 48 | dist_p = F.pairwise_distance(anchor ,positive) 49 | dist_n = F.pairwise_distance(anchor ,negative) 50 | if self.dist_type == 1: 51 | dist_p = cosine_similarity(anchor, positive) 52 | disp_n = cosine_similarity(anchor, negative) 53 | 54 | 55 | dist_hinge = torch.clamp(dist_p - dist_n + self.margin, min=0.0) 56 | if self.use_ohem: 57 | v, idx = torch.sort(dist_hinge,descending=True) 58 | loss = torch.mean(v[0:self.ohem_bs]) 59 | else: 60 | loss = torch.mean(dist_hinge) 61 | 62 | return loss 63 | -------------------------------------------------------------------------------- /module/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # -------------------------------------------------------- 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Chao CHEN (chaochancs@gmail.com) 6 | # Created On: 2017-06-08 7 | # -------------------------------------------------------- 8 | -------------------------------------------------------------------------------- /module/l2normalize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # -------------------------------------------------------- 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Chao CHEN (chaochancs@gmail.com) 6 | # Created On: 2017-06-20 7 | # -------------------------------------------------------- 8 | import torch 9 | import torch.nn as nn 10 | 11 | class L2Normalize(nn.Module): 12 | def __init__(self): 13 | super(L2Normalize, self).__init__() 14 | 15 | def forward(self, data): 16 | norms = data.norm(2, 1) 17 | #print norms.size() 18 | batch_size = data.size()[0] 19 | norms = norms.view(-1, 1).repeat(1, data.size()[1]) 20 | #print norms 21 | #print norms.size() 22 | x = data / norms 23 | return x 24 | 25 | if __name__ == '__main__': 26 | import torch 27 | from torch.autograd import Variable 28 | 29 | data = torch.randn(2, 10) 30 | print data.size() 31 | print 'data: ',data 32 | l2 = L2Normalize() 33 | l2.eval() 34 | x = l2(Variable(data)) 35 | print x 36 | print '-------------' 37 | 38 | -------------------------------------------------------------------------------- /module/symbols.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # -------------------------------------------------------- 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Chao CHEN (chaochancs@gmail.com) 6 | # Created On: 2017-06-08 7 | # -------------------------------------------------------- 8 | import torch 9 
-------------------------------------------------------------------------------- /module/__init__.py: --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# --------------------------------------------------------
# Licensed under The MIT License [see LICENSE for details]
# Written by Chao CHEN (chaochancs@gmail.com)
# Created On: 2017-06-08
# --------------------------------------------------------
-------------------------------------------------------------------------------- /module/l2normalize.py: --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# --------------------------------------------------------
# Licensed under The MIT License [see LICENSE for details]
# Written by Chao CHEN (chaochancs@gmail.com)
# Created On: 2017-06-20
# --------------------------------------------------------
import torch
import torch.nn as nn


class L2Normalize(nn.Module):
    """Scale each row of a (batch, dim) tensor to unit L2 norm."""

    def __init__(self):
        super(L2Normalize, self).__init__()

    def forward(self, data):
        norms = data.norm(2, 1)
        # expand the per-row norms back to the full (batch, dim) shape
        norms = norms.view(-1, 1).repeat(1, data.size()[1])
        return data / norms


if __name__ == '__main__':
    from torch.autograd import Variable

    data = torch.randn(2, 10)
    print(data.size())
    print('data: ', data)
    l2 = L2Normalize()
    l2.eval()
    x = l2(Variable(data))
    print(x)
    print('-------------')
-------------------------------------------------------------------------------- /module/symbols.py: --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# --------------------------------------------------------
# Licensed under The MIT License [see LICENSE for details]
# Written by Chao CHEN (chaochancs@gmail.com)
# Created On: 2017-06-08
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from collections import OrderedDict
from .l2normalize import L2Normalize
import numpy as np


class View(nn.Module):
    """Reshape the input to (B, N)."""
    def __init__(self, B, N):
        super(View, self).__init__()
        self.B = B
        self.N = N

    def forward(self, x):
        return x.view(self.B, self.N)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.feat_model = nn.Sequential(OrderedDict([
            # conv1
            ('conv1', nn.Conv2d(  3,  16, 3, 1, 1, bias=False)),
            ('bn1',   nn.BatchNorm2d(16)),
            ('relu1', nn.ReLU()),
            ('pool1', nn.MaxPool2d(2, 2)),
            # conv2
            ('conv2', nn.Conv2d( 16,  32, 3, 1, 1, bias=False)),
            ('bn2',   nn.BatchNorm2d(32)),
            ('relu2', nn.ReLU()),
            ('pool2', nn.MaxPool2d(2, 2)),
            # conv3
            ('conv3', nn.Conv2d( 32,  64, 3, 1, 1, bias=False)),
            ('bn3',   nn.BatchNorm2d(64)),
            ('relu3', nn.ReLU()),
            ('pool3', nn.MaxPool2d(2, 2)),
            # conv4
            ('conv4', nn.Conv2d( 64, 128, 3, 1, 1, bias=False)),
            ('bn4',   nn.BatchNorm2d(128)),
            ('relu4', nn.ReLU()),
            ('pool4', nn.MaxPool2d(2, 2)),
            # conv5
            ('conv5', nn.Conv2d(128, 256, 3, 1, 1, bias=False)),
            ('bn5',   nn.BatchNorm2d(256)),
            ('relu5', nn.ReLU()),
            ('pool5', nn.MaxPool2d(2, 2)),
            # conv6
            #('conv6', nn.Conv2d(256, 512, 3, 1, 1, bias=False)),
            #('bn6',   nn.BatchNorm2d(512)),
            #('relu6', nn.ReLU()),
            #('pool6', nn.MaxPool2d(2, 2)),

            # a 128x64 input is halved by each of the five poolings,
            # leaving a 256x4x2 feature map, i.e. 2048 values
            ('reshape', View(-1, 2*4*256)),
            ('fc1', nn.Linear(2048, 1024)),  # fc head on top of the conv features
            ('relu6', nn.ReLU()),
            ('fc2', nn.Linear(1024, 512)),
            # L2-normalize the embedding, as in FaceNet
            ('l2normal', L2Normalize())
            #('relu7', nn.ReLU()),
        ]))

    def load_weights(self, path):
        """Load only the conv/bn part from a darknet .weights file."""
        buf = np.fromfile(path, dtype=np.float32)
        start = 4  # skip the 4-int32 darknet header
        start = load_conv_bn(buf, start, self.feat_model[0],  self.feat_model[1])
        start = load_conv_bn(buf, start, self.feat_model[4],  self.feat_model[5])
        start = load_conv_bn(buf, start, self.feat_model[8],  self.feat_model[9])
        start = load_conv_bn(buf, start, self.feat_model[12], self.feat_model[13])
        start = load_conv_bn(buf, start, self.feat_model[16], self.feat_model[17])

    def load_full_weights(self, model_path):
        """Load the conv/bn part and the two fc layers."""
        buf = np.fromfile(model_path, dtype=np.float32)
        start = 4  # skip the 4-int32 darknet header
        start = load_conv_bn(buf, start, self.feat_model[0],  self.feat_model[1])
        start = load_conv_bn(buf, start, self.feat_model[4],  self.feat_model[5])
        start = load_conv_bn(buf, start, self.feat_model[8],  self.feat_model[9])
        start = load_conv_bn(buf, start, self.feat_model[12], self.feat_model[13])
        start = load_conv_bn(buf, start, self.feat_model[16], self.feat_model[17])

        # fc1 (index 21): darknet stores a connected layer as bias, then weights
        num_w = self.feat_model[21].weight.numel()
        num_b = self.feat_model[21].bias.numel()
        self.feat_model[21].bias.data.copy_(
            torch.from_numpy(buf[start:start+num_b]))
        start = start + num_b
        self.feat_model[21].weight.data.copy_(
            torch.from_numpy(buf[start:start+num_w]))
        start = start + num_w
        # fc2 (index 23), same layout: bias, then weights
        num_w = self.feat_model[23].weight.numel()
        num_b = self.feat_model[23].bias.numel()
        self.feat_model[23].bias.data.copy_(
            torch.from_numpy(buf[start:start+num_b]))
        start = start + num_b
        self.feat_model[23].weight.data.copy_(
            torch.from_numpy(buf[start:start+num_w]))
        start = start + num_w

    def forward(self, x):
        return self.feat_model(x)


# trained for classification
class Net_cls(nn.Module):
    def __init__(self):
        super(Net_cls, self).__init__()
        self.cls_model = nn.Sequential(OrderedDict([
            ('fc4', nn.Linear(1024, 512)),
            ('relu', nn.ReLU()),
            ('fc5', nn.Linear(512, 2)),
            ('log_softmax', nn.LogSoftmax()),
        ]))

    def forward(self, x):
        return self.cls_model(x)


def load_conv_bn(buf, start, conv_model, bn_model):
    """Copy one conv+bn block from `buf` in darknet order:
    bn bias, bn weight, running mean, running var, conv weight."""
    num_w = conv_model.weight.numel()
    num_b = bn_model.bias.numel()
    bn_model.bias.data.copy_(torch.from_numpy(buf[start:start+num_b]))
    start = start + num_b
    bn_model.weight.data.copy_(torch.from_numpy(buf[start:start+num_b]))
    start = start + num_b
    bn_model.running_mean.copy_(torch.from_numpy(buf[start:start+num_b]))
    start = start + num_b
    bn_model.running_var.copy_(torch.from_numpy(buf[start:start+num_b]))
    start = start + num_b
    conv_model.weight.data.copy_(torch.from_numpy(buf[start:start+num_w]))
    start = start + num_w
    return start


def save_conv_bn(fp, conv_model, bn_model):
    bn_model.bias.data.cpu().numpy().tofile(fp)
    bn_model.weight.data.cpu().numpy().tofile(fp)
    bn_model.running_mean.cpu().numpy().tofile(fp)
    bn_model.running_var.cpu().numpy().tofile(fp)
    conv_model.weight.data.cpu().numpy().tofile(fp)


def save_conv(fp, conv_model):
    # also used for nn.Linear: bias first, then weights, as darknet expects
    conv_model.bias.data.cpu().numpy().tofile(fp)
    conv_model.weight.data.cpu().numpy().tofile(fp)


def save_weights(selfmodel, outfile, cutoff):
    fp = open(outfile, 'wb')
    header = torch.IntTensor([0, 0, 0, 0])  # 4-int32 darknet header
    header.numpy().tofile(fp)
    for blockId in range(0, cutoff):
        block = selfmodel[blockId]
        if block.__class__.__name__ == 'Conv2d':
            if selfmodel[blockId+1].__class__.__name__ == 'BatchNorm2d':
                save_conv_bn(fp, selfmodel[blockId], selfmodel[blockId+1])
            else:
                save_conv(fp, selfmodel[blockId])
        elif block.__class__.__name__ == 'Linear':
            save_conv(fp, selfmodel[blockId])
        # MaxPool2d, ReLU, View and L2Normalize have no weights to save
    fp.close()


if __name__ == '__main__':
    model = Net()
    model.load_full_weights('data/tiny-yolo.weights')
    save_weights(model.feat_model, 'cc.weights', len(model.feat_model._modules))
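# Sketch of the weight-file layout implied by the loaders above (an
# illustrative sanity check, not part of the project): a 4-int32 header,
# then float32 parameters — each conv block as bn.bias, bn.weight,
# running_mean, running_var, conv.weight, and each linear layer as bias
# then weight. The expected size of 'cc.weights' would therefore be:
#
#   n_params = sum(p.numel() for p in Net().feat_model.parameters())
#   n_bn_stats = 2 * (16 + 32 + 64 + 128 + 256)  # running mean/var per bn
#   expected_bytes = 4 * (4 + n_params + n_bn_stats)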
-------------------------------------------------------------------------------- /tools/__init__.py: --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# --------------------------------------------------------
# Licensed under The MIT License [see LICENSE for details]
# Written by Chao CHEN (chaochancs@gmail.com)
# Created On: 2017-06-08
# --------------------------------------------------------
-------------------------------------------------------------------------------- /tools/test.py: --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# --------------------------------------------------------
# Licensed under The MIT License [see LICENSE for details]
# Written by Chao CHEN (chaochancs@gmail.com)
# Created On: 2017-06-13
# --------------------------------------------------------
import sys
sys.path.insert(0, '../')
import dataset.dataset as dataset
from module.symbols import Net, Net_cls, save_weights
from utils import AverageMeter, data_loader
from torch.autograd import Variable
import torch.nn.functional as F
import torch
import cv2
import uuid
import numpy as np
import os.path as osp


def combine_and_save(img1_path, img2_path, dist, output_dir, class_flag):
    """Save the two images side by side, with the distance and a uuid in
    the file name, for visual inspection of failure cases."""
    img1 = cv2.imread(img1_path)
    img2 = cv2.imread(img2_path)
    h, w, c = img1.shape
    out = np.zeros((h, w*2, c))
    out[:, 0:w, :] = img1
    out[:, w:2*w, :] = img2
    uid = uuid.uuid4()
    out_name = osp.join(output_dir, '{:s}_{:.03f}_{:s}.jpg'.format(class_flag, dist.data[0], str(uid)))
    cv2.imwrite(out_name, out)


def test(model, test_loader, epoch, margin, threshold, is_cuda=True, log_interval=1000):
    model.eval()
    test_loss = AverageMeter()
    accuracy = 0
    num_p = 0
    total_num = 0
    batch_num = len(test_loader)
    for batch_idx, (data_a, data_p, data_n, target) in enumerate(test_loader):
        if is_cuda:
            data_a = data_a.cuda()
            data_p = data_p.cuda()
            data_n = data_n.cuda()
            target = target.cuda()

        data_a = Variable(data_a, volatile=True)
        data_p = Variable(data_p, volatile=True)
        data_n = Variable(data_n, volatile=True)
        target = Variable(target)

        out_a = model(data_a)
        out_p = model(data_p)
        out_n = model(data_n)

        loss = F.triplet_margin_loss(out_a, out_p, out_n, margin)

        dist1 = F.pairwise_distance(out_a, out_p)
        dist2 = F.pairwise_distance(out_a, out_n)

        # a positive pair counts as correct below the threshold,
        # a negative pair above it
        num = ((dist1 < threshold).float().sum() + (dist2 > threshold).float().sum()).data[0]
        num_p += num
        num_p = 1.0 * num_p  # keep num_p a float so the division below is not integer division
        total_num += data_a.size()[0] * 2
        test_loss.update(loss.data[0])
        if (batch_idx + 1) % log_interval == 0:
            accuracy_tmp = num_p / total_num
            print('Test- Epoch {:04d}\tbatch:{:06d}/{:06d}\tAccuracy:{:.04f}\tloss:{:06f}'
                  .format(epoch, batch_idx+1, batch_num, accuracy_tmp, test_loss.avg))
            test_loss.reset()

    accuracy = num_p / total_num
    return accuracy
def test_vis(model, test_loader, model_path, threshold,
             margin=1.0, is_cuda=True, output_dir='output', is_visualization=True):
    if model_path is not None:
        model.load_full_weights(model_path)
        print('loaded model file: {:s}'.format(model_path))
    if is_cuda:
        model = model.cuda()
    model.eval()
    test_loss = AverageMeter()
    accuracy = 0
    num_p = 0
    total_num = 0
    batch_num = len(test_loader)
    for batch_idx, (data_a, data_p, data_n, target, img_paths) in enumerate(test_loader):
        if is_cuda:
            data_a = data_a.cuda()
            data_p = data_p.cuda()
            data_n = data_n.cuda()
            target = target.cuda()

        data_a = Variable(data_a, volatile=True)
        data_p = Variable(data_p, volatile=True)
        data_n = Variable(data_n, volatile=True)
        target = Variable(target)

        out_a = model(data_a)
        out_p = model(data_p)
        out_n = model(data_n)

        loss = F.triplet_margin_loss(out_a, out_p, out_n, margin)

        dist1 = F.pairwise_distance(out_a, out_p)
        dist2 = F.pairwise_distance(out_a, out_n)
        batch_size = data_a.size()[0]
        pos_flag = (dist1 <= threshold).float()
        neg_flag = (dist2 > threshold).float()
        if is_visualization:
            # save the misclassified pairs side by side for inspection
            for k in torch.arange(0, batch_size):
                k = int(k)
                if pos_flag[k].data[0] == 0:
                    combine_and_save(img_paths[0][k], img_paths[1][k], dist1[k], output_dir, '1-1')
                if neg_flag[k].data[0] == 0:
                    combine_and_save(img_paths[0][k], img_paths[2][k], dist2[k], output_dir, '1-0')

        num = (pos_flag.sum() + neg_flag.sum()).data[0]
        print('{:f}, {:f}, {:f}'.format(num, pos_flag.sum().data[0], neg_flag.sum().data[0]))
        num_p += num
        total_num += data_a.size()[0] * 2
        print('num_p = {:f}, total = {:f}'.format(num_p, total_num))
        print('dist1 = {:f}, dist2 = {:f}'.format(dist1[0].data[0], dist2[0].data[0]))

    accuracy = num_p / total_num
    return accuracy


def best_test(model, _loader, model_path=None, is_cuda=True):
    """Collect all anchor-positive and anchor-negative distances, then sweep
    thresholds between their means and return the best-performing one."""
    if model_path is not None:
        model.load_full_weights(model_path)
        print('loaded model file: {:s}'.format(model_path))
    if is_cuda:
        model = model.cuda()
    model.eval()
    total_num = 0
    batch_num = len(_loader)
    for batch_idx, (data_a, data_p, data_n, target) in enumerate(_loader):
        if is_cuda:
            data_a = data_a.cuda()
            data_p = data_p.cuda()
            data_n = data_n.cuda()
            target = target.cuda()

        data_a = Variable(data_a, volatile=True)
        data_p = Variable(data_p, volatile=True)
        data_n = Variable(data_n, volatile=True)
        target = Variable(target)

        out_a = model(data_a)
        out_p = model(data_p)
        out_n = model(data_n)
        current_d_a_p = F.pairwise_distance(out_a, out_p)
        current_d_a_n = F.pairwise_distance(out_a, out_n)
        if batch_idx == 0:
            d_a_p = current_d_a_p
            d_a_n = current_d_a_n
        else:
            d_a_p = torch.cat((d_a_p, current_d_a_p), 0)
            d_a_n = torch.cat((d_a_n, current_d_a_n), 0)
        total_num += 2 * data_a.size()[0]

    mean_d_a_p = d_a_p.mean().data[0]
    mean_d_a_n = d_a_n.mean().data[0]
    start = min(mean_d_a_n, mean_d_a_p)
    end = max(mean_d_a_n, mean_d_a_p)
    best_thre = 0
    best_num = 0
    thre_step = 0.05

    for val in torch.arange(start, end + thre_step, thre_step):
        # a positive pair is correct below the threshold, a negative pair above it
        num = ((d_a_p <= val).float().sum() + (d_a_n > val).float().sum()).data[0]
        if num > best_num:
            best_num = num
            best_thre = val
    return best_thre, best_num / total_num, mean_d_a_p, mean_d_a_n


def visualization():
    pass


def evaluation():
    model_path = '/data/chenchao/reid_train/exp11.2/models/epoch_0004-000500_feat.weights'
    model = Net()
    model = model.cuda()
    model.load_full_weights(model_path)
    model.eval()

    test_list = '../data/val.txt'
    image_root = '/home/chenchao/ReID/'
    test_loader = data_loader(image_root, test_list, shuffle=True, batch_size=256)
    margin = 1.0
    #accuracy = test(model, test_loader, 0, margin, threshold=0.18, log_interval=10)
    thre, acc, mean_1, mean_2 = best_test(model, test_loader, model_path)
    print('accuracy = {:f}'.format(acc))
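# Worked example of the sweep in best_test (illustrative numbers only):
# with d_a_p = [0.2, 0.4, 0.5] and d_a_n = [0.7, 0.9, 1.2] the means are
# 0.367 and 0.933, so candidate thresholds run from 0.367 to ~0.983 in
# steps of 0.05; the first value classifying all six distances correctly
# is ~0.517 (all d_a_p <= 0.517, all d_a_n > 0.517), so best_test would
# return that threshold with accuracy 6/6.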
def test_offline():
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    model = Net()
    margin = 1.0
    threshold = 1.0
    model_path = '/data/chenchao/reid_train/exp11.2/models/epoch_0004-000500_feat.weights'
    if not os.path.exists(model_path):
        print('model file not found: {:s}'.format(model_path))
    test_list = '../data/val.txt'
    image_root = '/home/chenchao/ReID/'
    test_loader = data_loader(image_root, test_list, shuffle=False, batch_size=512, is_visualization=True)
    acc = test_vis(model, test_loader, model_path, threshold, margin, is_cuda=True, is_visualization=True)
    print('threshold: {:.03f}, accuracy: {:.03f}'.format(threshold, acc))


if __name__ == '__main__':
    test_offline()
    #evaluation()
-------------------------------------------------------------------------------- /tools/train.py: --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# --------------------------------------------------------
# Licensed under The MIT License [see LICENSE for details]
# Written by Chao CHEN (chaochancs@gmail.com)
# Created On: 2017-06-15
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from module.symbols import Net, save_weights
from tools.utils import AverageMeter, logging
from tools.test import best_test
from module.TripletLoss import TripletMarginLoss


def train_val(model, optimizer, train_loader, test_loader,
              epoch, margin=1.0, use_ohem=False, log_interval=100, test_interval=2000, is_cuda=True):
    loss = AverageMeter()
    batch_num = len(train_loader)
    triploss_layer = TripletMarginLoss(margin, use_ohem=use_ohem)
    for batch_idx, (data_a, data_p, data_n, target) in enumerate(train_loader):
        model.train()
        if is_cuda:
            data_a = data_a.cuda()
            data_p = data_p.cuda()
            data_n = data_n.cuda()
        data_a = Variable(data_a)
        data_p = Variable(data_p)
        data_n = Variable(data_n)
        target = Variable(target)

        optimizer.zero_grad()
        out_a = model(data_a)
        out_p = model(data_p)
        out_n = model(data_n)

        trip_loss = triploss_layer(out_a, out_p, out_n)
        trip_loss.backward()
        optimizer.step()

        loss.update(trip_loss.data[0])
        if (batch_idx+1) % log_interval == 0:
            logging('Train-Epoch:{:04d}\tbatch:{:06d}/{:06d}\tloss:{:.04f}'
                    .format(epoch, batch_idx+1, batch_num, trip_loss.data[0]))
        if (batch_idx+1) % test_interval == 0:
            # periodically search the best threshold on the validation set
            threshold, accuracy, mean_d_a_p, mean_d_a_n = best_test(model, test_loader)
            logging('Test-T-A Epoch {:04d}-{:06d} accuracy: {:.04f} threshold: {:.05} ap_mean: {:.04f} an_mean: {:.04f}'
                    .format(epoch, batch_idx+1, accuracy, threshold, mean_d_a_p, mean_d_a_n))
            # the model is wrapped in DataParallel, hence model.module
            cutoff = len(model.module.feat_model._modules)
            model_name = 'models/epoch_{:04d}-{:06d}_feat.weights'.format(epoch, batch_idx+1)
            save_weights(model.module.feat_model, model_name, cutoff)
            logging('save model: {:s}'.format(model_name))
-------------------------------------------------------------------------------- /tools/utils.py: --------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# --------------------------------------------------------
# Licensed under The MIT License [see LICENSE for details]
# Written by Chao CHEN (chaochancs@gmail.com)
# Created On: 2017-06-08
# --------------------------------------------------------
import torch
from torchvision import transforms
import dataset.dataset as dataset
import dataset.transforms as trans
import time


def logging(message):
    print('%s %s' % (time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), message))


class AverageMeter(object):
    """Computes and stores the average and current value."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.pos = 0
        self.neg = 0

    def update(self, val, n=1, pos=0.0, neg=0.0):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        self.pos += pos
        self.neg += neg


def data_loader(image_root, data_list, shuffle=True, batch_size=64, workers=20, is_cuda=True, is_visualization=False):
    kwargs = {'num_workers': workers, 'pin_memory': True} if is_cuda else {}
    # keep the top 75% of each image, then resize to 64x128 (w x h)
    transform = transforms.Compose([
        trans.person_crop(ratio=(1, 0.75), crop_type=1),
        trans.scale(size=(64, 128)),
        transforms.ToTensor()
    ])
    preid = dataset.listDataset(
        image_root,
        data_list,
        shuffle,
        transform=transform,
        is_visualization=is_visualization)

    data_loader = torch.utils.data.DataLoader(preid,
                                              batch_size=batch_size,
                                              shuffle=shuffle,
                                              **kwargs)

    return data_loader


def read_data_cfg(datacfg):
    """Parse a key=value config file (e.g. data/reid.data) into a dict."""
    options = dict()
    with open(datacfg, 'r') as fp:
        lines = fp.readlines()

    for line in lines:
        line = line.strip()
        if line == '':
            continue
        key, value = line.split('=')
        options[key.strip()] = value.strip()
    return options


def adjust_learning_rate(optimizer, epoch, epoch_step, learning_rate):
    """Decay the learning rate by 10x every `epoch_step` epochs,
    e.g. lr=0.02, epoch_step=5: epochs 0-4 -> 0.02, epochs 5-9 -> 0.002."""
    lr_step = 0.1
    lr = learning_rate * (lr_step ** (epoch // epoch_step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    if epoch % epoch_step == 0:
        logging('lr = %f' % lr)
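# Usage sketch for read_data_cfg with the shipped data/reid.data (the
# training script currently hard-codes its paths instead of reading this):
if __name__ == '__main__':
    opts = read_data_cfg('data/reid.data')
    print(opts['train'], opts['val'], opts['gpus'])
    # -> train.txt val.txt 0,1,2,3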
-------------------------------------------------------------------------------- /train_val.py: --------------------------------------------------------------------------------
from __future__ import print_function
import os, sys
import argparse
import torch
import torch.optim as optim
import dataset.dataset as dataset
from module.symbols import Net
from tools.utils import data_loader, logging, adjust_learning_rate
from tools.train import train_val

# Training settings
parser = argparse.ArgumentParser(description='Person Re-Identify')
parser.add_argument('--gpus', type=str, default='0,1,2,3',
                    help='gpus split with , (default: 0,1,2,3)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')

args = parser.parse_args()
print(args)
DEBUG = False
is_cuda = True
margin = 1.0
lr = 0.02
momentum = 0.9
epoch_step = 5
batch_size = 256
models_dir = 'models'
if not os.path.exists(models_dir):
    os.makedirs(models_dir)
#######################################
model = Net()
model.load_weights('data/reid_96.86.weights')
print(model)
if is_cuda:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    print('GPU ID: {:s}'.format(os.environ['CUDA_VISIBLE_DEVICES']))
    torch.cuda.manual_seed(args.seed)
    model = torch.nn.DataParallel(model).cuda()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

#######################################
image_root = '/home/chenchao/ReID/'
#image_root = '/data/chenchao/reid_train/pereid-master/data'
if DEBUG:
    train_list = 'data/tmp.txt'
    test_list = 'data/test_debug.txt'
    val_interval = 100
    log_interval = 50
else:
    train_list = 'data/train.txt'
    test_list = 'data/val.txt'
    val_interval = 1000
    log_interval = 100
train_loader = data_loader(image_root, train_list, shuffle=True, batch_size=batch_size)
test_loader = data_loader(image_root, test_list, shuffle=True, batch_size=batch_size)

for epoch in range(1, 20):
    adjust_learning_rate(optimizer, epoch, epoch_step=epoch_step, learning_rate=lr)
    train_val(model, optimizer, train_loader, test_loader, epoch, margin, use_ohem=False,
              log_interval=log_interval, test_interval=val_interval)
--------------------------------------------------------------------------------