├── README.md
├── data_utils.py
├── model.py
├── resnet.py
├── results
│   ├── car_cub.png
│   ├── sop_isc.png
│   └── structure.png
├── test.py
├── train.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # CGD
2 | A PyTorch implementation of CGD based on the paper [Combination of Multiple Global Descriptors for Image Retrieval](https://arxiv.org/abs/1903.10663v3).
3 |
4 | ![Network Architecture image from the paper](results/structure.png)
5 |
6 | ## Requirements
7 | - [Anaconda](https://www.anaconda.com/download/)
8 | - [PyTorch](https://pytorch.org)
9 | ```
10 | conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
11 | ```
12 | - thop
13 | ```
14 | pip install thop
15 | ```
16 |
17 | ## Datasets
18 | [CARS196](http://ai.stanford.edu/~jkrause/cars/car_dataset.html), [CUB200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html),
19 | [Stanford Online Products](http://cvgl.stanford.edu/projects/lifted_struct/) and
20 | [In-shop Clothes](http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html) are used in this repo.
21 |
22 | Download these datasets yourself and extract them into the `${data_path}` directory, making sure the directory names are
23 | `car`, `cub`, `sop` and `isc`. Then run `data_utils.py` to preprocess them.
24 |
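For reference, `data_utils.py` reads the following files out of each extracted archive, so the layout under `${data_path}` should look roughly like this (names taken from what the preprocessing code opens):

```
${data_path}/
├── car/   # CARS196: cars_annos.mat plus the extracted image archive
├── cub/   # CUB200-2011: images/, images.txt, image_class_labels.txt, bounding_boxes.txt
├── sop/   # Stanford Online Products: Ebay_train.txt, Ebay_test.txt plus the image folders
└── isc/   # In-shop Clothes: Img/ plus Eval/list_eval_partition.txt
```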
25 | ## Usage
26 | ### Train CGD
27 | ```
28 | python train.py --feature_dim 512 --gd_config SM
29 | optional arguments:
30 | --data_path                   datasets path [default value is '/home/data']
31 | --data_name                   dataset name [default value is 'car'](choices=['car', 'cub', 'sop', 'isc'])
32 | --crop_type                   crop data or not (only works for the car and cub datasets) [default value is 'uncropped'](choices=['uncropped', 'cropped'])
33 | --backbone_type               backbone network type [default value is 'resnet50'](choices=['resnet50', 'resnext50'])
34 | --gd_config                   global descriptors config [default value is 'SG'](choices=['S', 'M', 'G', 'SM', 'MS', 'SG', 'GS', 'MG', 'GM', 'SMG', 'MSG', 'GSM'])
35 | --feature_dim                 feature dim [default value is 1536]
36 | --smoothing                   smoothing value for label smoothing [default value is 0.1]
37 | --temperature                 temperature scaling used in softmax cross-entropy loss [default value is 0.5]
38 | --margin                      margin m for triplet loss [default value is 0.1]
39 | --recalls                     selected recall ids [default value is '1,2,4,8']
40 | --batch_size                  train batch size [default value is 128]
41 | --num_epochs                  train epoch number [default value is 20]
42 | ```
43 |
44 | ### Test CGD
45 | ```
46 | python test.py --retrieval_num 10
47 | optional arguments:
48 | --query_img_name              query image name [default value is '/home/data/car/uncropped/008055.jpg']
49 | --data_base                   queried database [default value is 'car_uncropped_resnet50_SG_1536_0.1_0.5_0.1_128_data_base.pth']
50 | --retrieval_num               retrieval number [default value is 8]
51 | ```
52 |
53 | ## Benchmarks
54 | The models are trained on one NVIDIA Tesla V100 (32G) GPU for 20 epochs,
55 | with the learning rate decayed by a factor of 10 at the 12th and 16th epochs.
56 |
57 | ### Model Parameters and FLOPs (Params | FLOPs)
58 | <table>
59 |   <thead><tr><th>Backbone</th><th>CARS196</th><th>CUB200</th><th>SOP</th><th>In-shop</th></tr></thead>
60 |   <tbody>
61 |     <tr><td>ResNet50</td><td>26.86M | 10.64G</td><td>26.86M | 10.64G</td><td>49.85M | 10.69G</td><td>34.85M | 10.66G</td></tr>
62 |     <tr><td>ResNeXt50</td><td>26.33M | 10.84G</td><td>26.33M | 10.84G</td><td>49.32M | 10.89G</td><td>34.32M | 10.86G</td></tr>
63 |   </tbody>
64 | </table>
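The numbers above are what `train.py` prints at startup via [thop](https://github.com/Lyken17/pytorch-OpCounter), profiling a single 3×224×224 input. A minimal standalone sketch of the same measurement; `num_classes=98` is an assumption matching the CARS196 train split produced by `data_utils.py`:

```
import torch
from thop import clever_format, profile

from model import Model

# profile one 224x224 image through the network, as train.py does (on CPU here)
model = Model('resnet50', 'SG', feature_dim=1536, num_classes=98)
flops, params = profile(model, inputs=(torch.randn(1, 3, 224, 224),))
flops, params = clever_format([flops, params])
print('# Model Params: {} FLOPs: {}'.format(params, flops))
```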
85 |
86 | ### CARS196 (Uncropped | Cropped)
87 |
88 | <table>
89 |   <thead><tr><th>Backbone</th><th>R@1</th><th>R@2</th><th>R@4</th><th>R@8</th><th>Download Link</th></tr></thead>
90 |   <tbody>
91 |     <tr><td>ResNet50(SG)</td><td>86.4% | 92.4%</td><td>92.1% | 96.1%</td><td>95.6% | 97.8%</td><td>97.5% | 98.7%</td><td>r3sn | sf5s</td></tr>
92 |     <tr><td>ResNeXt50(SG)</td><td>86.4% | 91.7%</td><td>92.0% | 95.4%</td><td>95.4% | 97.3%</td><td>97.6% | 98.6%</td><td>dsdx | fh72</td></tr>
93 |   </tbody>
94 | </table>
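The linked checkpoints are the plain `state_dict`s that `train.py` saves as `results/{save_name_pre}_model.pth`. A sketch of restoring one for inference — the file name below follows that naming pattern and mirrors the `test.py` default, but where you place the download is up to you:

```
import torch

from model import Model

# num_classes=98 matches the CARS196 train split (classes 1-98) this checkpoint was trained on
model = Model('resnet50', 'SG', feature_dim=1536, num_classes=98)
model.load_state_dict(torch.load('results/car_uncropped_resnet50_SG_1536_0.1_0.5_0.1_128_model.pth',
                                 map_location='cpu'))
model.eval()
```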
117 |
118 | ### CUB200 (Uncropped | Cropped)
119 |
120 | <table>
121 |   <thead><tr><th>Backbone</th><th>R@1</th><th>R@2</th><th>R@4</th><th>R@8</th><th>Download Link</th></tr></thead>
122 |   <tbody>
123 |     <tr><td>ResNet50(MG)</td><td>66.0% | 73.9%</td><td>76.4% | 83.1%</td><td>84.8% | 89.6%</td><td>90.7% | 94.0%</td><td>2cfi | pi4q</td></tr>
124 |     <tr><td>ResNeXt50(MG)</td><td>66.1% | 73.7%</td><td>76.3% | 82.6%</td><td>84.0% | 89.0%</td><td>90.1% | 93.3%</td><td>nm9h | 6mkf</td></tr>
125 |   </tbody>
126 | </table>
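Once a checkpoint is restored as above, embedding a single image only needs the eval-time transform from `utils.py`. A sketch — the image path and checkpoint name are hypothetical, and `num_classes=100` is an assumption matching the CUB200 train split (classes 1-100):

```
import torch
from PIL import Image
from torchvision import transforms

from model import Model

# eval-time preprocessing copied from ImageReader in utils.py
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

model = Model('resnet50', 'MG', feature_dim=1536, num_classes=100)
model.load_state_dict(torch.load('results/cub_uncropped_resnet50_MG_1536_0.1_0.5_0.1_128_model.pth',
                                 map_location='cpu'))
model.eval()

img = transform(Image.open('bird.jpg').convert('RGB')).unsqueeze(0)
with torch.no_grad():
    feature, _ = model(img)  # [1, 1536] L2-normalised descriptor
```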
149 |
150 | ### SOP
151 |
152 | <table>
153 |   <thead><tr><th>Backbone</th><th>R@1</th><th>R@10</th><th>R@100</th><th>R@1000</th><th>Download Link</th></tr></thead>
154 |   <tbody>
155 |     <tr><td>ResNet50(SG)</td><td>79.3%</td><td>90.6%</td><td>95.8%</td><td>98.6%</td><td>qgsn</td></tr>
156 |     <tr><td>ResNeXt50(SG)</td><td>71.0%</td><td>85.3%</td><td>93.5%</td><td>97.9%</td><td>uexd</td></tr>
157 |   </tbody>
158 | </table>
181 |
182 | ### In-shop
183 |
184 | <table>
185 |   <thead><tr><th>Backbone</th><th>R@1</th><th>R@10</th><th>R@20</th><th>R@30</th><th>R@40</th><th>R@50</th><th>Download Link</th></tr></thead>
186 |   <tbody>
187 |     <tr><td>ResNet50(GS)</td><td>83.6%</td><td>95.7%</td><td>97.1%</td><td>97.7%</td><td>98.1%</td><td>98.4%</td><td>8jmp</td></tr>
188 |     <tr><td>ResNeXt50(GS)</td><td>85.0%</td><td>96.1%</td><td>97.3%</td><td>97.9%</td><td>98.2%</td><td>98.4%</td><td>wdq5</td></tr>
189 |   </tbody>
190 | </table>
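Every table above reports Recall@K as computed by `recall` in `utils.py`: a query counts as a hit if any of its K nearest gallery neighbours (Euclidean distance between the L2-normalised descriptors) shares its label; for car/cub/sop the test set queries itself with the self-match masked out. A condensed sketch of that computation (features and labels as 2-D and 1-D tensors):

```
import torch

def recall_at_k(query_feats, query_labels, ks, gallery_feats=None, gallery_labels=None):
    if gallery_feats is None:  # car/cub/sop: the test set is its own gallery
        gallery_feats, gallery_labels = query_feats, query_labels
        dist = torch.cdist(query_feats, gallery_feats)
        dist.fill_diagonal_(float('inf'))  # never retrieve the query itself
    else:  # isc: separate query and gallery splits
        dist = torch.cdist(query_feats, gallery_feats)
    idx = dist.topk(k=max(ks), dim=-1, largest=False)[1]
    hits = gallery_labels[idx] == query_labels.unsqueeze(-1)
    return [hits[:, :k].any(dim=-1).float().mean().item() for k in ks]
```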
219 | 220 | ## Results 221 | 222 | ### CAR/CUB (Uncropped | Cropped) 223 | 224 | ![CAR/CUB](results/car_cub.png) 225 | 226 | ### SOP/ISC 227 | 228 | ![SOP/ISC](results/sop_isc.png) -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | from PIL import Image 6 | from scipy.io import loadmat 7 | from tqdm import tqdm 8 | 9 | 10 | def read_txt(path, data_num): 11 | data = {} 12 | for line in open(path, 'r', encoding='utf-8'): 13 | if data_num == 2: 14 | data_1, data_2 = line.split() 15 | else: 16 | data_1, data_2, data_3, data_4, data_5 = line.split() 17 | data_2 = [data_2, data_3, data_4, data_5] 18 | data[data_1] = data_2 19 | return data 20 | 21 | 22 | def process_car_data(data_path, data_type): 23 | if not os.path.exists('{}/{}'.format(data_path, data_type)): 24 | os.mkdir('{}/{}'.format(data_path, data_type)) 25 | train_images, test_images = {}, {} 26 | annotations = loadmat('{}/cars_annos.mat'.format(data_path))['annotations'][0] 27 | for img in tqdm(annotations, desc='process {} data for car dataset'.format(data_type)): 28 | img_name, img_label = str(img[0][0]), str(img[5][0][0]) 29 | if data_type == 'uncropped': 30 | img = Image.open('{}/{}'.format(data_path, img_name)).convert('RGB') 31 | else: 32 | x1, y1, x2, y2 = int(img[1][0][0]), int(img[2][0][0]), int(img[3][0][0]), int(img[4][0][0]) 33 | img = Image.open('{}/{}'.format(data_path, img_name)).convert('RGB').crop((x1, y1, x2, y2)) 34 | save_name = '{}/{}/{}'.format(data_path, data_type, os.path.basename(img_name)) 35 | img.save(save_name) 36 | if int(img_label) < 99: 37 | if img_label in train_images: 38 | train_images[img_label].append(save_name) 39 | else: 40 | train_images[img_label] = [save_name] 41 | else: 42 | if img_label in test_images: 43 | test_images[img_label].append(save_name) 44 | else: 45 | test_images[img_label] = [save_name] 46 | torch.save({'train': train_images, 'test': test_images}, '{}/{}_data_dicts.pth'.format(data_path, data_type)) 47 | 48 | 49 | def process_cub_data(data_path, data_type): 50 | if not os.path.exists('{}/{}'.format(data_path, data_type)): 51 | os.mkdir('{}/{}'.format(data_path, data_type)) 52 | images = read_txt('{}/images.txt'.format(data_path), 2) 53 | labels = read_txt('{}/image_class_labels.txt'.format(data_path), 2) 54 | bounding_boxes = read_txt('{}/bounding_boxes.txt'.format(data_path), 5) 55 | train_images, test_images = {}, {} 56 | for img_id, img_name in tqdm(images.items(), desc='process {} data for cub dataset'.format(data_type)): 57 | if data_type == 'uncropped': 58 | img = Image.open('{}/images/{}'.format(data_path, img_name)).convert('RGB') 59 | else: 60 | x1, y1 = int(float(bounding_boxes[img_id][0])), int(float(bounding_boxes[img_id][1])) 61 | x2, y2 = x1 + int(float(bounding_boxes[img_id][2])), y1 + int(float(bounding_boxes[img_id][3])) 62 | img = Image.open('{}/images/{}'.format(data_path, img_name)).convert('RGB').crop((x1, y1, x2, y2)) 63 | save_name = '{}/{}/{}'.format(data_path, data_type, os.path.basename(img_name)) 64 | img.save(save_name) 65 | if int(labels[img_id]) < 101: 66 | if labels[img_id] in train_images: 67 | train_images[labels[img_id]].append(save_name) 68 | else: 69 | train_images[labels[img_id]] = [save_name] 70 | else: 71 | if labels[img_id] in test_images: 72 | test_images[labels[img_id]].append(save_name) 73 | else: 74 | test_images[labels[img_id]] = [save_name] 75 | 
torch.save({'train': train_images, 'test': test_images}, '{}/{}_data_dicts.pth'.format(data_path, data_type)) 76 | 77 | 78 | def process_sop_data(data_path): 79 | if not os.path.exists('{}/uncropped'.format(data_path)): 80 | os.mkdir('{}/uncropped'.format(data_path)) 81 | train_images, test_images = {}, {} 82 | data_tuple = {'train': train_images, 'test': test_images} 83 | for data_type, image_list in data_tuple.items(): 84 | for index, line in enumerate(open('{}/Ebay_{}.txt'.format(data_path, data_type), 'r', encoding='utf-8')): 85 | if index != 0: 86 | _, label, _, img_name = line.split() 87 | img = Image.open('{}/{}'.format(data_path, img_name)).convert('RGB') 88 | save_name = '{}/uncropped/{}'.format(data_path, os.path.basename(img_name)) 89 | img.save(save_name) 90 | if label in image_list: 91 | image_list[label].append(save_name) 92 | else: 93 | image_list[label] = [save_name] 94 | torch.save({'train': train_images, 'test': test_images}, '{}/uncropped_data_dicts.pth'.format(data_path)) 95 | 96 | 97 | def process_isc_data(data_path): 98 | if not os.path.exists('{}/uncropped'.format(data_path)): 99 | os.mkdir('{}/uncropped'.format(data_path)) 100 | train_images, query_images, gallery_images = {}, {}, {} 101 | for index, line in enumerate(open('{}/Eval/list_eval_partition.txt'.format(data_path), 'r', encoding='utf-8')): 102 | if index > 1: 103 | img_name, label, status = line.split() 104 | img = Image.open('{}/Img/{}'.format(data_path, img_name)).convert('RGB') 105 | save_name = '{}/uncropped/{}_{}'.format(data_path, img_name.split('/')[-2], os.path.basename(img_name)) 106 | img.save(save_name) 107 | if status == 'train': 108 | if label in train_images: 109 | train_images[label].append(save_name) 110 | else: 111 | train_images[label] = [save_name] 112 | elif status == 'query': 113 | if label in query_images: 114 | query_images[label].append(save_name) 115 | else: 116 | query_images[label] = [save_name] 117 | elif status == 'gallery': 118 | if label in gallery_images: 119 | gallery_images[label].append(save_name) 120 | else: 121 | gallery_images[label] = [save_name] 122 | 123 | torch.save({'train': train_images, 'query': query_images, 'gallery': gallery_images}, 124 | '{}/uncropped_data_dicts.pth'.format(data_path)) 125 | 126 | 127 | if __name__ == '__main__': 128 | parser = argparse.ArgumentParser(description='Process datasets') 129 | parser.add_argument('--data_path', default='/home/data', type=str, help='datasets path') 130 | 131 | opt = parser.parse_args() 132 | 133 | process_car_data('{}/car'.format(opt.data_path), 'uncropped') 134 | process_car_data('{}/car'.format(opt.data_path), 'cropped') 135 | process_cub_data('{}/cub'.format(opt.data_path), 'uncropped') 136 | process_cub_data('{}/cub'.format(opt.data_path), 'cropped') 137 | print('processing sop dataset') 138 | process_sop_data('{}/sop'.format(opt.data_path)) 139 | print('processing isc dataset') 140 | process_isc_data('{}/isc'.format(opt.data_path)) 141 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from resnet import resnet50, resnext50_32x4d 6 | 7 | 8 | def set_bn_eval(m): 9 | classname = m.__class__.__name__ 10 | if classname.find('BatchNorm2d') != -1: 11 | m.eval() 12 | 13 | 14 | class GlobalDescriptor(nn.Module): 15 | def __init__(self, p=1): 16 | super().__init__() 17 | self.p = p 18 | 19 | def 
forward(self, x):
20 | assert x.dim() == 4, 'the input tensor of GlobalDescriptor must have shape [B, C, H, W]'
21 | if self.p == 1:  # p = 1 -> average pooling (SPoC)
22 | return x.mean(dim=[-1, -2])
23 | elif self.p == float('inf'):  # p = inf -> max pooling (MAC)
24 | return torch.flatten(F.adaptive_max_pool2d(x, output_size=(1, 1)), start_dim=1)
25 | else:  # otherwise generalized-mean (GeM) pooling
26 | sum_value = x.pow(self.p).mean(dim=[-1, -2])
27 | return torch.sign(sum_value) * (torch.abs(sum_value).pow(1.0 / self.p))
28 |
29 | def extra_repr(self):
30 | return 'p={}'.format(self.p)
31 |
32 |
33 | class L2Norm(nn.Module):
34 | def __init__(self):
35 | super().__init__()
36 |
37 | def forward(self, x):
38 | assert x.dim() == 2, 'the input tensor of L2Norm must have shape [B, C]'
39 | return F.normalize(x, p=2, dim=-1)
40 |
41 |
42 | class Model(nn.Module):
43 | def __init__(self, backbone_type, gd_config, feature_dim, num_classes):
44 | super().__init__()
45 |
46 | # Backbone Network
47 | backbone = resnet50(pretrained=True) if backbone_type == 'resnet50' else resnext50_32x4d(pretrained=True)
48 | self.features = []
49 | for name, module in backbone.named_children():
50 | if isinstance(module, nn.AdaptiveAvgPool2d) or isinstance(module, nn.Linear):
51 | continue
52 | self.features.append(module)
53 | self.features = nn.Sequential(*self.features)
54 |
55 | # Main Module
56 | n = len(gd_config)
57 | assert feature_dim % n == 0, 'feature_dim must be divisible by the number of global descriptors'
58 | k = feature_dim // n
59 |
60 | self.global_descriptors, self.main_modules = [], []
61 | for i in range(n):
62 | if gd_config[i] == 'S':  # SPoC
63 | p = 1
64 | elif gd_config[i] == 'M':  # MAC
65 | p = float('inf')
66 | else:  # 'G' -> GeM
67 | p = 3
68 | self.global_descriptors.append(GlobalDescriptor(p=p))
69 | self.main_modules.append(nn.Sequential(nn.Linear(2048, k, bias=False), L2Norm()))
70 | self.global_descriptors = nn.ModuleList(self.global_descriptors)
71 | self.main_modules = nn.ModuleList(self.main_modules)
72 |
73 | # Auxiliary Module
74 | self.auxiliary_module = nn.Sequential(nn.BatchNorm1d(2048), nn.Linear(2048, num_classes, bias=True))
75 |
76 | def forward(self, x):
77 | shared = self.features(x)
78 | global_descriptors = []
79 | for i in range(len(self.global_descriptors)):
80 | global_descriptor = self.global_descriptors[i](shared)
81 | if i == 0:  # only the first global descriptor feeds the auxiliary classification head
82 | classes = self.auxiliary_module(global_descriptor)
83 | global_descriptor = self.main_modules[i](global_descriptor)
84 | global_descriptors.append(global_descriptor)
85 | global_descriptors = F.normalize(torch.cat(global_descriptors, dim=-1), dim=-1)
86 | return global_descriptors, classes
87 | -------------------------------------------------------------------------------- /resnet.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision.models.utils import load_state_dict_from_url
4 |
5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
6 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
7 | 'wide_resnet50_2', 'wide_resnet101_2']
8 |
9 | model_urls = {
10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
15 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
16
| 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 17 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 18 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 19 | } 20 | 21 | 22 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 23 | """3x3 convolution with padding""" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 25 | padding=dilation, groups=groups, bias=False, dilation=dilation) 26 | 27 | 28 | def conv1x1(in_planes, out_planes, stride=1): 29 | """1x1 convolution""" 30 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 31 | 32 | 33 | class BasicBlock(nn.Module): 34 | expansion = 1 35 | __constants__ = ['downsample'] 36 | 37 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 38 | base_width=64, dilation=1, norm_layer=None): 39 | super(BasicBlock, self).__init__() 40 | if norm_layer is None: 41 | norm_layer = nn.BatchNorm2d 42 | if groups != 1 or base_width != 64: 43 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 44 | if dilation > 1: 45 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 46 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 47 | self.conv1 = conv3x3(inplanes, planes, stride) 48 | self.bn1 = norm_layer(planes) 49 | self.relu = nn.ReLU(inplace=True) 50 | self.conv2 = conv3x3(planes, planes) 51 | self.bn2 = norm_layer(planes) 52 | self.downsample = downsample 53 | self.stride = stride 54 | 55 | def forward(self, x): 56 | identity = x 57 | 58 | out = self.conv1(x) 59 | out = self.bn1(out) 60 | out = self.relu(out) 61 | 62 | out = self.conv2(out) 63 | out = self.bn2(out) 64 | 65 | if self.downsample is not None: 66 | identity = self.downsample(x) 67 | 68 | out += identity 69 | out = self.relu(out) 70 | 71 | return out 72 | 73 | 74 | class Bottleneck(nn.Module): 75 | expansion = 4 76 | __constants__ = ['downsample'] 77 | 78 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 79 | base_width=64, dilation=1, norm_layer=None): 80 | super(Bottleneck, self).__init__() 81 | if norm_layer is None: 82 | norm_layer = nn.BatchNorm2d 83 | width = int(planes * (base_width / 64.)) * groups 84 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 85 | self.conv1 = conv1x1(inplanes, width) 86 | self.bn1 = norm_layer(width) 87 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 88 | self.bn2 = norm_layer(width) 89 | self.conv3 = conv1x1(width, planes * self.expansion) 90 | self.bn3 = norm_layer(planes * self.expansion) 91 | self.relu = nn.ReLU(inplace=True) 92 | self.downsample = downsample 93 | self.stride = stride 94 | 95 | def forward(self, x): 96 | identity = x 97 | 98 | out = self.conv1(x) 99 | out = self.bn1(out) 100 | out = self.relu(out) 101 | 102 | out = self.conv2(out) 103 | out = self.bn2(out) 104 | out = self.relu(out) 105 | 106 | out = self.conv3(out) 107 | out = self.bn3(out) 108 | 109 | if self.downsample is not None: 110 | identity = self.downsample(x) 111 | 112 | out += identity 113 | out = self.relu(out) 114 | 115 | return out 116 | 117 | 118 | class ResNet(nn.Module): 119 | 120 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 121 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 122 | norm_layer=None): 123 | super(ResNet, self).__init__() 124 | if norm_layer is 
None: 125 | norm_layer = nn.BatchNorm2d 126 | self._norm_layer = norm_layer 127 | 128 | self.inplanes = 64 129 | self.dilation = 1 130 | if replace_stride_with_dilation is None: 131 | # each element in the tuple indicates if we should replace 132 | # the 2x2 stride with a dilated convolution instead 133 | replace_stride_with_dilation = [False, False, False] 134 | if len(replace_stride_with_dilation) != 3: 135 | raise ValueError("replace_stride_with_dilation should be None " 136 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 137 | self.groups = groups 138 | self.base_width = width_per_group 139 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 140 | bias=False) 141 | self.bn1 = norm_layer(self.inplanes) 142 | self.relu = nn.ReLU(inplace=True) 143 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 144 | self.layer1 = self._make_layer(block, 64, layers[0]) 145 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 146 | dilate=replace_stride_with_dilation[0]) 147 | # remove down sample for stage3 148 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, 149 | dilate=replace_stride_with_dilation[1]) 150 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 151 | dilate=replace_stride_with_dilation[2]) 152 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 153 | self.fc = nn.Linear(512 * block.expansion, num_classes) 154 | 155 | for m in self.modules(): 156 | if isinstance(m, nn.Conv2d): 157 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 158 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 159 | nn.init.constant_(m.weight, 1) 160 | nn.init.constant_(m.bias, 0) 161 | 162 | # Zero-initialize the last BN in each residual branch, 163 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
164 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 165 | if zero_init_residual: 166 | for m in self.modules(): 167 | if isinstance(m, Bottleneck): 168 | nn.init.constant_(m.bn3.weight, 0) 169 | elif isinstance(m, BasicBlock): 170 | nn.init.constant_(m.bn2.weight, 0) 171 | 172 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 173 | norm_layer = self._norm_layer 174 | downsample = None 175 | previous_dilation = self.dilation 176 | if dilate: 177 | self.dilation *= stride 178 | stride = 1 179 | if stride != 1 or self.inplanes != planes * block.expansion: 180 | downsample = nn.Sequential( 181 | conv1x1(self.inplanes, planes * block.expansion, stride), 182 | norm_layer(planes * block.expansion), 183 | ) 184 | 185 | layers = [] 186 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 187 | self.base_width, previous_dilation, norm_layer)) 188 | self.inplanes = planes * block.expansion 189 | for _ in range(1, blocks): 190 | layers.append(block(self.inplanes, planes, groups=self.groups, 191 | base_width=self.base_width, dilation=self.dilation, 192 | norm_layer=norm_layer)) 193 | 194 | return nn.Sequential(*layers) 195 | 196 | def _forward_impl(self, x): 197 | # See note [TorchScript super()] 198 | x = self.conv1(x) 199 | x = self.bn1(x) 200 | x = self.relu(x) 201 | x = self.maxpool(x) 202 | 203 | x = self.layer1(x) 204 | x = self.layer2(x) 205 | x = self.layer3(x) 206 | x = self.layer4(x) 207 | 208 | x = self.avgpool(x) 209 | x = torch.flatten(x, 1) 210 | x = self.fc(x) 211 | 212 | return x 213 | 214 | def forward(self, x): 215 | return self._forward_impl(x) 216 | 217 | 218 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 219 | model = ResNet(block, layers, **kwargs) 220 | if pretrained: 221 | state_dict = load_state_dict_from_url(model_urls[arch], 222 | progress=progress) 223 | model.load_state_dict(state_dict) 224 | return model 225 | 226 | 227 | def resnet18(pretrained=False, progress=True, **kwargs): 228 | r"""ResNet-18 model from 229 | `"Deep Residual Learning for Image Recognition" `_ 230 | 231 | Args: 232 | pretrained (bool): If True, returns a model pre-trained on ImageNet 233 | progress (bool): If True, displays a progress bar of the download to stderr 234 | """ 235 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 236 | **kwargs) 237 | 238 | 239 | def resnet34(pretrained=False, progress=True, **kwargs): 240 | r"""ResNet-34 model from 241 | `"Deep Residual Learning for Image Recognition" `_ 242 | 243 | Args: 244 | pretrained (bool): If True, returns a model pre-trained on ImageNet 245 | progress (bool): If True, displays a progress bar of the download to stderr 246 | """ 247 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 248 | **kwargs) 249 | 250 | 251 | def resnet50(pretrained=False, progress=True, **kwargs): 252 | r"""ResNet-50 model from 253 | `"Deep Residual Learning for Image Recognition" `_ 254 | 255 | Args: 256 | pretrained (bool): If True, returns a model pre-trained on ImageNet 257 | progress (bool): If True, displays a progress bar of the download to stderr 258 | """ 259 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 260 | **kwargs) 261 | 262 | 263 | def resnet101(pretrained=False, progress=True, **kwargs): 264 | r"""ResNet-101 model from 265 | `"Deep Residual Learning for Image Recognition" `_ 266 | 267 | Args: 268 | pretrained (bool): If True, returns a model pre-trained on ImageNet 269 | 
progress (bool): If True, displays a progress bar of the download to stderr 270 | """ 271 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 272 | **kwargs) 273 | 274 | 275 | def resnet152(pretrained=False, progress=True, **kwargs): 276 | r"""ResNet-152 model from 277 | `"Deep Residual Learning for Image Recognition" `_ 278 | 279 | Args: 280 | pretrained (bool): If True, returns a model pre-trained on ImageNet 281 | progress (bool): If True, displays a progress bar of the download to stderr 282 | """ 283 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 284 | **kwargs) 285 | 286 | 287 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 288 | r"""ResNeXt-50 32x4d model from 289 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 290 | 291 | Args: 292 | pretrained (bool): If True, returns a model pre-trained on ImageNet 293 | progress (bool): If True, displays a progress bar of the download to stderr 294 | """ 295 | kwargs['groups'] = 32 296 | kwargs['width_per_group'] = 4 297 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 298 | pretrained, progress, **kwargs) 299 | 300 | 301 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 302 | r"""ResNeXt-101 32x8d model from 303 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 304 | 305 | Args: 306 | pretrained (bool): If True, returns a model pre-trained on ImageNet 307 | progress (bool): If True, displays a progress bar of the download to stderr 308 | """ 309 | kwargs['groups'] = 32 310 | kwargs['width_per_group'] = 8 311 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 312 | pretrained, progress, **kwargs) 313 | 314 | 315 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 316 | r"""Wide ResNet-50-2 model from 317 | `"Wide Residual Networks" `_ 318 | 319 | The model is the same as ResNet except for the bottleneck number of channels 320 | which is twice larger in every block. The number of channels in outer 1x1 321 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 322 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 323 | 324 | Args: 325 | pretrained (bool): If True, returns a model pre-trained on ImageNet 326 | progress (bool): If True, displays a progress bar of the download to stderr 327 | """ 328 | kwargs['width_per_group'] = 64 * 2 329 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 330 | pretrained, progress, **kwargs) 331 | 332 | 333 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 334 | r"""Wide ResNet-101-2 model from 335 | `"Wide Residual Networks" `_ 336 | 337 | The model is the same as ResNet except for the bottleneck number of channels 338 | which is twice larger in every block. The number of channels in outer 1x1 339 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 340 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
341 | 342 | Args: 343 | pretrained (bool): If True, returns a model pre-trained on ImageNet 344 | progress (bool): If True, displays a progress bar of the download to stderr 345 | """ 346 | kwargs['width_per_group'] = 64 * 2 347 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 348 | pretrained, progress, **kwargs) 349 | -------------------------------------------------------------------------------- /results/car_cub.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CGD/c9cd98fcfe4296875509c316bc8536da71ed22d0/results/car_cub.png -------------------------------------------------------------------------------- /results/sop_isc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CGD/c9cd98fcfe4296875509c316bc8536da71ed22d0/results/sop_isc.png -------------------------------------------------------------------------------- /results/structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CGD/c9cd98fcfe4296875509c316bc8536da71ed22d0/results/structure.png -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | 5 | import torch 6 | from PIL import Image, ImageDraw 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser(description='Test CGD') 10 | parser.add_argument('--query_img_name', default='/home/data/car/uncropped/008055.jpg', type=str, 11 | help='query image name') 12 | parser.add_argument('--data_base', default='car_uncropped_resnet50_SG_1536_0.1_0.5_0.1_128_data_base.pth', 13 | type=str, help='queried database') 14 | parser.add_argument('--retrieval_num', default=8, type=int, help='retrieval number') 15 | 16 | opt = parser.parse_args() 17 | 18 | query_img_name, data_base_name, retrieval_num = opt.query_img_name, opt.data_base, opt.retrieval_num 19 | data_name = data_base_name.split('_')[0] 20 | 21 | data_base = torch.load('results/{}'.format(data_base_name)) 22 | 23 | if query_img_name not in data_base['test_images']: 24 | raise FileNotFoundError('{} not found'.format(query_img_name)) 25 | query_index = data_base['test_images'].index(query_img_name) 26 | query_image = Image.open(query_img_name).convert('RGB').resize((224, 224), resample=Image.BILINEAR) 27 | query_label = torch.tensor(data_base['test_labels'][query_index]) 28 | query_feature = data_base['test_features'][query_index] 29 | 30 | gallery_images = data_base['{}_images'.format('test' if data_name != 'isc' else 'gallery')] 31 | gallery_labels = torch.tensor(data_base['{}_labels'.format('test' if data_name != 'isc' else 'gallery')]) 32 | gallery_features = data_base['{}_features'.format('test' if data_name != 'isc' else 'gallery')] 33 | 34 | dist_matrix = torch.cdist(query_feature.unsqueeze(0).unsqueeze(0), gallery_features.unsqueeze(0)).squeeze() 35 | if data_name != 'isc': 36 | dist_matrix[query_index] = float('inf') 37 | idx = dist_matrix.topk(k=retrieval_num, dim=-1, largest=False)[1] 38 | 39 | result_path = 'results/{}'.format(query_img_name.split('/')[-1].split('.')[0]) 40 | if os.path.exists(result_path): 41 | shutil.rmtree(result_path) 42 | os.mkdir(result_path) 43 | query_image.save('{}/query_img.jpg'.format(result_path)) 44 | for num, index in enumerate(idx): 45 | retrieval_image = 
Image.open(gallery_images[index.item()]).convert('RGB') \ 46 | .resize((224, 224), resample=Image.BILINEAR) 47 | draw = ImageDraw.Draw(retrieval_image) 48 | retrieval_label = gallery_labels[index.item()] 49 | retrieval_status = (retrieval_label == query_label).item() 50 | retrieval_dist = dist_matrix[index.item()].item() 51 | if retrieval_status: 52 | draw.rectangle((0, 0, 223, 223), outline='green', width=8) 53 | else: 54 | draw.rectangle((0, 0, 223, 223), outline='red', width=8) 55 | retrieval_image.save('{}/retrieval_img_{}_{}.jpg'.format(result_path, num + 1, '%.4f' % retrieval_dist)) 56 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import pandas as pd 4 | import torch 5 | from thop import profile, clever_format 6 | from torch.optim import Adam 7 | from torch.optim.lr_scheduler import MultiStepLR 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | 11 | from model import Model, set_bn_eval 12 | from utils import recall, LabelSmoothingCrossEntropyLoss, BatchHardTripletLoss, ImageReader, MPerClassSampler 13 | 14 | 15 | def train(net, optim): 16 | net.train() 17 | # fix bn on backbone network 18 | net.apply(set_bn_eval) 19 | total_loss, total_correct, total_num, data_bar = 0, 0, 0, tqdm(train_data_loader) 20 | for inputs, labels in data_bar: 21 | inputs, labels = inputs.cuda(), labels.cuda() 22 | features, classes = net(inputs) 23 | class_loss = class_criterion(classes, labels) 24 | feature_loss = feature_criterion(features, labels) 25 | loss = class_loss + feature_loss 26 | optim.zero_grad() 27 | loss.backward() 28 | optim.step() 29 | pred = torch.argmax(classes, dim=-1) 30 | total_loss += loss.item() * inputs.size(0) 31 | total_correct += torch.sum(pred == labels).item() 32 | total_num += inputs.size(0) 33 | data_bar.set_description('Train Epoch {}/{} - Loss:{:.4f} - Acc:{:.2f}%' 34 | .format(epoch, num_epochs, total_loss / total_num, total_correct / total_num * 100)) 35 | 36 | return total_loss / total_num, total_correct / total_num * 100 37 | 38 | 39 | def test(net, recall_ids): 40 | net.eval() 41 | with torch.no_grad(): 42 | # obtain feature vectors for all data 43 | for key in eval_dict.keys(): 44 | eval_dict[key]['features'] = [] 45 | for inputs, labels in tqdm(eval_dict[key]['data_loader'], desc='processing {} data'.format(key)): 46 | inputs, labels = inputs.cuda(), labels.cuda() 47 | features, classes = net(inputs) 48 | eval_dict[key]['features'].append(features) 49 | eval_dict[key]['features'] = torch.cat(eval_dict[key]['features'], dim=0) 50 | 51 | # compute recall metric 52 | if data_name == 'isc': 53 | acc_list = recall(eval_dict['test']['features'], test_data_set.labels, recall_ids, 54 | eval_dict['gallery']['features'], gallery_data_set.labels) 55 | else: 56 | acc_list = recall(eval_dict['test']['features'], test_data_set.labels, recall_ids) 57 | desc = 'Test Epoch {}/{} '.format(epoch, num_epochs) 58 | for index, rank_id in enumerate(recall_ids): 59 | desc += 'R@{}:{:.2f}% '.format(rank_id, acc_list[index] * 100) 60 | results['test_recall@{}'.format(rank_id)].append(acc_list[index] * 100) 61 | print(desc) 62 | return acc_list[0] 63 | 64 | 65 | if __name__ == '__main__': 66 | parser = argparse.ArgumentParser(description='Train CGD') 67 | parser.add_argument('--data_path', default='/home/data', type=str, help='datasets path') 68 | parser.add_argument('--data_name', default='car', type=str, 
choices=['car', 'cub', 'sop', 'isc'], 69 | help='dataset name') 70 | parser.add_argument('--crop_type', default='uncropped', type=str, choices=['uncropped', 'cropped'], 71 | help='crop data or not, it only works for car or cub dataset') 72 | parser.add_argument('--backbone_type', default='resnet50', type=str, choices=['resnet50', 'resnext50'], 73 | help='backbone network type') 74 | parser.add_argument('--gd_config', default='SG', type=str, 75 | choices=['S', 'M', 'G', 'SM', 'MS', 'SG', 'GS', 'MG', 'GM', 'SMG', 'MSG', 'GSM'], 76 | help='global descriptors config') 77 | parser.add_argument('--feature_dim', default=1536, type=int, help='feature dim') 78 | parser.add_argument('--smoothing', default=0.1, type=float, help='smoothing value for label smoothing') 79 | parser.add_argument('--temperature', default=0.5, type=float, 80 | help='temperature scaling used in softmax cross-entropy loss') 81 | parser.add_argument('--margin', default=0.1, type=float, help='margin of m for triplet loss') 82 | parser.add_argument('--recalls', default='1,2,4,8', type=str, help='selected recall') 83 | parser.add_argument('--batch_size', default=128, type=int, help='train batch size') 84 | parser.add_argument('--num_epochs', default=20, type=int, help='train epoch number') 85 | 86 | opt = parser.parse_args() 87 | # args parse 88 | data_path, data_name, crop_type, backbone_type = opt.data_path, opt.data_name, opt.crop_type, opt.backbone_type 89 | gd_config, feature_dim, smoothing, temperature = opt.gd_config, opt.feature_dim, opt.smoothing, opt.temperature 90 | margin, recalls, batch_size = opt.margin, [int(k) for k in opt.recalls.split(',')], opt.batch_size 91 | num_epochs = opt.num_epochs 92 | save_name_pre = '{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(data_name, crop_type, backbone_type, gd_config, feature_dim, 93 | smoothing, temperature, margin, batch_size) 94 | 95 | results = {'train_loss': [], 'train_accuracy': []} 96 | for recall_id in recalls: 97 | results['test_recall@{}'.format(recall_id)] = [] 98 | 99 | # dataset loader 100 | train_data_set = ImageReader(data_path, data_name, 'train', crop_type) 101 | train_sample = MPerClassSampler(train_data_set.labels, batch_size) 102 | train_data_loader = DataLoader(train_data_set, batch_sampler=train_sample, num_workers=8) 103 | test_data_set = ImageReader(data_path, data_name, 'query' if data_name == 'isc' else 'test', crop_type) 104 | test_data_loader = DataLoader(test_data_set, batch_size, shuffle=False, num_workers=8) 105 | eval_dict = {'test': {'data_loader': test_data_loader}} 106 | if data_name == 'isc': 107 | gallery_data_set = ImageReader(data_path, data_name, 'gallery', crop_type) 108 | gallery_data_loader = DataLoader(gallery_data_set, batch_size, shuffle=False, num_workers=8) 109 | eval_dict['gallery'] = {'data_loader': gallery_data_loader} 110 | 111 | # model setup, model profile, optimizer config and loss definition 112 | model = Model(backbone_type, gd_config, feature_dim, num_classes=len(train_data_set.class_to_idx)).cuda() 113 | flops, params = profile(model, inputs=(torch.randn(1, 3, 224, 224).cuda(),)) 114 | flops, params = clever_format([flops, params]) 115 | print('# Model Params: {} FLOPs: {}'.format(params, flops)) 116 | optimizer = Adam(model.parameters(), lr=1e-4) 117 | lr_scheduler = MultiStepLR(optimizer, milestones=[int(0.6 * num_epochs), int(0.8 * num_epochs)], gamma=0.1) 118 | class_criterion = LabelSmoothingCrossEntropyLoss(smoothing=smoothing, temperature=temperature) 119 | feature_criterion = BatchHardTripletLoss(margin=margin) 120 | 121 | 
best_recall = 0.0 122 | for epoch in range(1, num_epochs + 1): 123 | train_loss, train_accuracy = train(model, optimizer) 124 | results['train_loss'].append(train_loss) 125 | results['train_accuracy'].append(train_accuracy) 126 | rank = test(model, recalls) 127 | lr_scheduler.step() 128 | 129 | # save statistics 130 | data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1)) 131 | data_frame.to_csv('results/{}_statistics.csv'.format(save_name_pre), index_label='epoch') 132 | # save database and model 133 | data_base = {} 134 | if rank > best_recall: 135 | best_recall = rank 136 | data_base['test_images'] = test_data_set.images 137 | data_base['test_labels'] = test_data_set.labels 138 | data_base['test_features'] = eval_dict['test']['features'] 139 | if data_name == 'isc': 140 | data_base['gallery_images'] = gallery_data_set.images 141 | data_base['gallery_labels'] = gallery_data_set.labels 142 | data_base['gallery_features'] = eval_dict['gallery']['features'] 143 | torch.save(model.state_dict(), 'results/{}_model.pth'.format(save_name_pre)) 144 | torch.save(data_base, 'results/{}_data_base.pth'.format(save_name_pre)) 145 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.utils.data import Dataset 7 | from torch.utils.data.sampler import Sampler 8 | from torchvision import transforms 9 | 10 | 11 | class ImageReader(Dataset): 12 | 13 | def __init__(self, data_path, data_name, data_type, crop_type): 14 | if crop_type == 'cropped' and data_name not in ['car', 'cub']: 15 | raise NotImplementedError('cropped data only works for car or cub dataset') 16 | 17 | data_dict = torch.load('{}/{}/{}_data_dicts.pth'.format(data_path, data_name, crop_type))[data_type] 18 | self.class_to_idx = dict(zip(sorted(data_dict), range(len(data_dict)))) 19 | normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 20 | if data_type == 'train': 21 | self.transform = transforms.Compose([transforms.Resize((252, 252)), transforms.RandomCrop(224), 22 | transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize]) 23 | else: 24 | self.transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), normalize]) 25 | self.images, self.labels = [], [] 26 | for label, image_list in data_dict.items(): 27 | self.images += image_list 28 | self.labels += [self.class_to_idx[label]] * len(image_list) 29 | 30 | def __getitem__(self, index): 31 | path, target = self.images[index], self.labels[index] 32 | img = Image.open(path).convert('RGB') 33 | img = self.transform(img) 34 | return img, target 35 | 36 | def __len__(self): 37 | return len(self.images) 38 | 39 | 40 | def recall(feature_vectors, feature_labels, rank, gallery_vectors=None, gallery_labels=None): 41 | num_features = len(feature_labels) 42 | feature_labels = torch.tensor(feature_labels, device=feature_vectors.device) 43 | gallery_vectors = feature_vectors if gallery_vectors is None else gallery_vectors 44 | 45 | dist_matrix = torch.cdist(feature_vectors.unsqueeze(0), gallery_vectors.unsqueeze(0)).squeeze(0) 46 | 47 | if gallery_labels is None: 48 | dist_matrix.fill_diagonal_(float('inf')) 49 | gallery_labels = feature_labels 50 | else: 51 | gallery_labels = torch.tensor(gallery_labels, device=feature_vectors.device) 52 | 53 | idx = 
dist_matrix.topk(k=rank[-1], dim=-1, largest=False)[1]
54 | acc_list = []
55 | for r in rank:
56 | correct = (gallery_labels[idx[:, 0:r]] == feature_labels.unsqueeze(dim=-1)).any(dim=-1).float()
57 | acc_list.append((torch.sum(correct) / num_features).item())
58 | return acc_list
59 |
60 |
61 | class LabelSmoothingCrossEntropyLoss(nn.Module):
62 | def __init__(self, smoothing=0.1, temperature=1.0):
63 | super().__init__()
64 | self.smoothing = smoothing
65 | self.temperature = temperature
66 |
67 | def forward(self, x, target):
68 | log_probs = F.log_softmax(x / self.temperature, dim=-1)
69 | nll_loss = -log_probs.gather(dim=-1, index=target.unsqueeze(dim=-1)).squeeze(dim=-1)
70 | smooth_loss = -log_probs.mean(dim=-1)
71 | loss = (1.0 - self.smoothing) * nll_loss + self.smoothing * smooth_loss
72 | return loss.mean()
73 |
74 |
75 | class BatchHardTripletLoss(nn.Module):
76 | def __init__(self, margin=1.0):
77 | super().__init__()
78 | self.margin = margin
79 |
80 | @staticmethod
81 | def get_anchor_positive_triplet_mask(target):
82 | mask = torch.eq(target.unsqueeze(0), target.unsqueeze(1))
83 | mask.fill_diagonal_(False)
84 | return mask
85 |
86 | @staticmethod
87 | def get_anchor_negative_triplet_mask(target):
88 | labels_equal = torch.eq(target.unsqueeze(0), target.unsqueeze(1))
89 | mask = ~ labels_equal
90 | return mask
91 |
92 | def forward(self, x, target):
93 | pairwise_dist = torch.cdist(x.unsqueeze(0), x.unsqueeze(0)).squeeze(0)
94 |
95 | mask_anchor_positive = self.get_anchor_positive_triplet_mask(target)
96 | anchor_positive_dist = mask_anchor_positive.float() * pairwise_dist
97 | hardest_positive_dist = anchor_positive_dist.max(1, True)[0]
98 |
99 | mask_anchor_negative = self.get_anchor_negative_triplet_mask(target)
100 | # push positives and self-pairs up to the row maximum so they are never picked as the hardest negative
101 | max_anchor_negative_dist = pairwise_dist.max(1, True)[0]
102 | anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative.float())
103 | hardest_negative_dist = anchor_negative_dist.min(1, True)[0]
104 |
105 | loss = F.relu(hardest_positive_dist - hardest_negative_dist + self.margin)
106 | return loss.mean()
107 |
108 |
109 | class MPerClassSampler(Sampler):
110 | def __init__(self, labels, batch_size, m=4):
111 | self.labels = np.array(labels)
112 | self.labels_unique = np.unique(labels)
113 | self.batch_size = batch_size
114 | self.m = m
115 | assert batch_size % m == 0, 'batch size must be divisible by m'
116 |
117 | def __len__(self):
118 | return len(self.labels) // self.batch_size
119 |
120 | def __iter__(self):
121 | for _ in range(self.__len__()):
122 | labels_in_batch = set()
123 | inds = np.array([], dtype=np.int64)  # np.int was removed in recent NumPy releases
124 |
125 | while inds.shape[0] < self.batch_size:
126 | sample_label = np.random.choice(self.labels_unique)
127 | if sample_label in labels_in_batch:
128 | continue
129 |
130 | labels_in_batch.add(sample_label)
131 | sample_label_ids = np.argwhere(np.isin(self.labels, sample_label)).reshape(-1)
132 | subsample = np.random.permutation(sample_label_ids)[:self.m]
133 | inds = np.append(inds, subsample)
134 |
135 | inds = inds[:self.batch_size]
136 | inds = np.random.permutation(inds)
137 | yield list(inds)
138 |
--------------------------------------------------------------------------------