├── README.md
├── data_utils.py
├── model.py
├── resnet.py
├── results
│   ├── car_cub.png
│   ├── sop_isc.png
│   └── structure.png
├── test.py
├── train.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
# CGD
A PyTorch implementation of CGD based on the paper [Combination of Multiple Global Descriptors for Image Retrieval](https://arxiv.org/abs/1903.10663v3).

![Network Architecture](results/structure.png)

## Requirements
- [Anaconda](https://www.anaconda.com/download/)
- [PyTorch](https://pytorch.org)
```
conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
```
- thop
```
pip install thop
```

## Datasets
[CARS196](http://ai.stanford.edu/~jkrause/cars/car_dataset.html), [CUB200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html),
[Stanford Online Products](http://cvgl.stanford.edu/projects/lifted_struct/) and
[In-shop Clothes](http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html) are used in this repo.

Download these datasets yourself and extract them into the `${data_path}` directory, making sure the directory names
are `car`, `cub`, `sop` and `isc`. Then run `data_utils.py` to preprocess them.
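
For example, with the datasets extracted under the default `--data_path`:
```
python data_utils.py --data_path /home/data
```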

## Usage
### Train CGD
```
python train.py --feature_dim 512 --gd_config SM
optional arguments:
--data_path datasets path [default value is '/home/data']
--data_name dataset name [default value is 'car'](choices=['car', 'cub', 'sop', 'isc'])
--crop_type crop data or not, it only works for car or cub dataset [default value is 'uncropped'](choices=['uncropped', 'cropped'])
--backbone_type backbone network type [default value is 'resnet50'](choices=['resnet50', 'resnext50'])
--gd_config global descriptors config [default value is 'SG'](choices=['S', 'M', 'G', 'SM', 'MS', 'SG', 'GS', 'MG', 'GM', 'SMG', 'MSG', 'GSM'])
--feature_dim feature dim [default value is 1536]
--smoothing smoothing value for label smoothing [default value is 0.1]
--temperature temperature scaling used in softmax cross-entropy loss [default value is 0.5]
--margin margin m for triplet loss [default value is 0.1]
--recalls selected recall [default value is '1,2,4,8']
--batch_size train batch size [default value is 128]
--num_epochs train epoch number [default value is 20]
```
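
Checkpoints, the best-epoch retrieval database and per-epoch statistics are written to `results/`, named after the
hyper-parameters (`save_name_pre` in `train.py`); for example, the default configuration produces:
```
results/car_uncropped_resnet50_SG_1536_0.1_0.5_0.1_128_model.pth
results/car_uncropped_resnet50_SG_1536_0.1_0.5_0.1_128_data_base.pth
results/car_uncropped_resnet50_SG_1536_0.1_0.5_0.1_128_statistics.csv
```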

### Test CGD
```
python test.py --retrieval_num 10
optional arguments:
--query_img_name query image name [default value is '/home/data/car/uncropped/008055.jpg']
--data_base queried database [default value is 'car_uncropped_resnet50_SG_1536_0.1_0.5_0.1_128_data_base.pth']
--retrieval_num retrieval number [default value is 8]
```
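
Retrieval results are saved to `results/<query image stem>/` (see `test.py`): the query image is written there, and
each retrieved image is saved with its rank and distance in the file name, outlined in green when its label matches
the query and in red otherwise. For the default query this looks like:
```
results/008055/query_img.jpg
results/008055/retrieval_img_1_<distance>.jpg
results/008055/retrieval_img_2_<distance>.jpg
...
```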

## Benchmarks
The models are trained on one NVIDIA Tesla V100 (32G) GPU for 20 epochs;
the learning rate is decayed by a factor of 10 at the 12th and 16th epochs.
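
This matches the `MultiStepLR` schedule configured in `train.py`, which places milestones at 60% and 80% of
`num_epochs`; a minimal sketch of the equivalent scheduler (assuming `num_epochs = 20` and an already-built
`optimizer`):
```
from torch.optim.lr_scheduler import MultiStepLR

# int(0.6 * 20) = 12 and int(0.8 * 20) = 16, with a decay factor of 0.1
lr_scheduler = MultiStepLR(optimizer, milestones=[12, 16], gamma=0.1)
```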

### Model Parameters and FLOPs (Params | FLOPs)

| Backbone  | CARS196          | CUB200           | SOP              | In-shop          |
|-----------|------------------|------------------|------------------|------------------|
| ResNet50  | 26.86M \| 10.64G | 26.86M \| 10.64G | 49.85M \| 10.69G | 34.85M \| 10.66G |
| ResNeXt50 | 26.33M \| 10.84G | 26.33M \| 10.84G | 49.32M \| 10.89G | 34.32M \| 10.86G |

### CARS196 (Uncropped | Cropped)

| Backbone      | R@1            | R@2            | R@4            | R@8            | Download Link |
|---------------|----------------|----------------|----------------|----------------|---------------|
| ResNet50(SG)  | 86.4% \| 92.4% | 92.1% \| 96.1% | 95.6% \| 97.8% | 97.5% \| 98.7% | r3sn \| sf5s  |
| ResNeXt50(SG) | 86.4% \| 91.7% | 92.0% \| 95.4% | 95.4% \| 97.3% | 97.6% \| 98.6% | dsdx \| fh72  |

### CUB200 (Uncropped | Cropped)

| Backbone      | R@1            | R@2            | R@4            | R@8            | Download Link |
|---------------|----------------|----------------|----------------|----------------|---------------|
| ResNet50(MG)  | 66.0% \| 73.9% | 76.4% \| 83.1% | 84.8% \| 89.6% | 90.7% \| 94.0% | 2cfi \| pi4q  |
| ResNeXt50(MG) | 66.1% \| 73.7% | 76.3% \| 82.6% | 84.0% \| 89.0% | 90.1% \| 93.3% | nm9h \| 6mkf  |

### SOP

| Backbone      | R@1   | R@10  | R@100 | R@1000 | Download Link |
|---------------|-------|-------|-------|--------|---------------|
| ResNet50(SG)  | 79.3% | 90.6% | 95.8% | 98.6%  | qgsn          |
| ResNeXt50(SG) | 71.0% | 85.3% | 93.5% | 97.9%  | uexd          |

### In-shop

| Backbone      | R@1   | R@10  | R@20  | R@30  | R@40  | R@50  | Download Link |
|---------------|-------|-------|-------|-------|-------|-------|---------------|
| ResNet50(GS)  | 83.6% | 95.7% | 97.1% | 97.7% | 98.1% | 98.4% | 8jmp          |
| ResNeXt50(GS) | 85.0% | 96.1% | 97.3% | 97.9% | 98.2% | 98.4% | wdq5          |

## Results

### CAR/CUB (Uncropped | Cropped)

![CAR/CUB](results/car_cub.png)

### SOP/ISC

![SOP/ISC](results/sop_isc.png)
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
import argparse
import os

import torch
from PIL import Image
from scipy.io import loadmat
from tqdm import tqdm


def read_txt(path, data_num):
    data = {}
    for line in open(path, 'r', encoding='utf-8'):
        if data_num == 2:
            data_1, data_2 = line.split()
        else:
            data_1, data_2, data_3, data_4, data_5 = line.split()
            data_2 = [data_2, data_3, data_4, data_5]
        data[data_1] = data_2
    return data


def process_car_data(data_path, data_type):
    if not os.path.exists('{}/{}'.format(data_path, data_type)):
        os.mkdir('{}/{}'.format(data_path, data_type))
    train_images, test_images = {}, {}
    annotations = loadmat('{}/cars_annos.mat'.format(data_path))['annotations'][0]
    for img in tqdm(annotations, desc='process {} data for car dataset'.format(data_type)):
        img_name, img_label = str(img[0][0]), str(img[5][0][0])
        if data_type == 'uncropped':
            img = Image.open('{}/{}'.format(data_path, img_name)).convert('RGB')
        else:
            x1, y1, x2, y2 = int(img[1][0][0]), int(img[2][0][0]), int(img[3][0][0]), int(img[4][0][0])
            img = Image.open('{}/{}'.format(data_path, img_name)).convert('RGB').crop((x1, y1, x2, y2))
        save_name = '{}/{}/{}'.format(data_path, data_type, os.path.basename(img_name))
        img.save(save_name)
        # the first 98 classes are used for training, the rest for testing
        if int(img_label) < 99:
            if img_label in train_images:
                train_images[img_label].append(save_name)
            else:
                train_images[img_label] = [save_name]
        else:
            if img_label in test_images:
                test_images[img_label].append(save_name)
            else:
                test_images[img_label] = [save_name]
    torch.save({'train': train_images, 'test': test_images}, '{}/{}_data_dicts.pth'.format(data_path, data_type))


def process_cub_data(data_path, data_type):
    if not os.path.exists('{}/{}'.format(data_path, data_type)):
        os.mkdir('{}/{}'.format(data_path, data_type))
    images = read_txt('{}/images.txt'.format(data_path), 2)
    labels = read_txt('{}/image_class_labels.txt'.format(data_path), 2)
    bounding_boxes = read_txt('{}/bounding_boxes.txt'.format(data_path), 5)
    train_images, test_images = {}, {}
    for img_id, img_name in tqdm(images.items(), desc='process {} data for cub dataset'.format(data_type)):
        if data_type == 'uncropped':
            img = Image.open('{}/images/{}'.format(data_path, img_name)).convert('RGB')
        else:
            x1, y1 = int(float(bounding_boxes[img_id][0])), int(float(bounding_boxes[img_id][1]))
            x2, y2 = x1 + int(float(bounding_boxes[img_id][2])), y1 + int(float(bounding_boxes[img_id][3]))
            img = Image.open('{}/images/{}'.format(data_path, img_name)).convert('RGB').crop((x1, y1, x2, y2))
        save_name = '{}/{}/{}'.format(data_path, data_type, os.path.basename(img_name))
        img.save(save_name)
        # the first 100 classes are used for training, the rest for testing
        if int(labels[img_id]) < 101:
            if labels[img_id] in train_images:
                train_images[labels[img_id]].append(save_name)
            else:
                train_images[labels[img_id]] = [save_name]
        else:
            if labels[img_id] in test_images:
                test_images[labels[img_id]].append(save_name)
            else:
                test_images[labels[img_id]] = [save_name]
    torch.save({'train': train_images, 'test': test_images}, '{}/{}_data_dicts.pth'.format(data_path, data_type))


def process_sop_data(data_path):
    if not os.path.exists('{}/uncropped'.format(data_path)):
        os.mkdir('{}/uncropped'.format(data_path))
    train_images, test_images = {}, {}
    data_tuple = {'train': train_images, 'test': test_images}
    for data_type, image_list in data_tuple.items():
        for index, line in enumerate(open('{}/Ebay_{}.txt'.format(data_path, data_type), 'r', encoding='utf-8')):
            # skip the header line
            if index != 0:
                _, label, _, img_name = line.split()
                img = Image.open('{}/{}'.format(data_path, img_name)).convert('RGB')
                save_name = '{}/uncropped/{}'.format(data_path, os.path.basename(img_name))
                img.save(save_name)
                if label in image_list:
                    image_list[label].append(save_name)
                else:
                    image_list[label] = [save_name]
    torch.save({'train': train_images, 'test': test_images}, '{}/uncropped_data_dicts.pth'.format(data_path))


def process_isc_data(data_path):
    if not os.path.exists('{}/uncropped'.format(data_path)):
        os.mkdir('{}/uncropped'.format(data_path))
    train_images, query_images, gallery_images = {}, {}, {}
    for index, line in enumerate(open('{}/Eval/list_eval_partition.txt'.format(data_path), 'r', encoding='utf-8')):
        # skip the two header lines
        if index > 1:
            img_name, label, status = line.split()
            img = Image.open('{}/Img/{}'.format(data_path, img_name)).convert('RGB')
            save_name = '{}/uncropped/{}_{}'.format(data_path, img_name.split('/')[-2], os.path.basename(img_name))
            img.save(save_name)
            if status == 'train':
                if label in train_images:
                    train_images[label].append(save_name)
                else:
                    train_images[label] = [save_name]
            elif status == 'query':
                if label in query_images:
                    query_images[label].append(save_name)
                else:
                    query_images[label] = [save_name]
            elif status == 'gallery':
                if label in gallery_images:
                    gallery_images[label].append(save_name)
                else:
                    gallery_images[label] = [save_name]

    torch.save({'train': train_images, 'query': query_images, 'gallery': gallery_images},
               '{}/uncropped_data_dicts.pth'.format(data_path))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Process datasets')
    parser.add_argument('--data_path', default='/home/data', type=str, help='datasets path')

    opt = parser.parse_args()

    process_car_data('{}/car'.format(opt.data_path), 'uncropped')
    process_car_data('{}/car'.format(opt.data_path), 'cropped')
    process_cub_data('{}/cub'.format(opt.data_path), 'uncropped')
    process_cub_data('{}/cub'.format(opt.data_path), 'cropped')
    print('processing sop dataset')
    process_sop_data('{}/sop'.format(opt.data_path))
    print('processing isc dataset')
    process_isc_data('{}/isc'.format(opt.data_path))
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import functional as F

from resnet import resnet50, resnext50_32x4d


def set_bn_eval(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm2d') != -1:
        m.eval()


class GlobalDescriptor(nn.Module):
    def __init__(self, p=1):
        super().__init__()
        self.p = p

    def forward(self, x):
        assert x.dim() == 4, 'the input tensor of GlobalDescriptor must be the shape of [B, C, H, W]'
        if self.p == 1:
            return x.mean(dim=[-1, -2])
        elif self.p == float('inf'):
            return torch.flatten(F.adaptive_max_pool2d(x, output_size=(1, 1)), start_dim=1)
        else:
            sum_value = x.pow(self.p).mean(dim=[-1, -2])
            return torch.sign(sum_value) * (torch.abs(sum_value).pow(1.0 / self.p))

    def extra_repr(self):
        return 'p={}'.format(self.p)
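

# The three exponents used below follow the descriptor naming of the CGD paper:
#   'S' -> GlobalDescriptor(p=1)            -> SPoC (average pooling)
#   'M' -> GlobalDescriptor(p=float('inf')) -> MAC (max pooling)
#   'G' -> GlobalDescriptor(p=3)            -> GeM (generalized-mean pooling)
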

class L2Norm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        assert x.dim() == 2, 'the input tensor of L2Norm must be the shape of [B, C]'
        return F.normalize(x, p=2, dim=-1)


class Model(nn.Module):
    def __init__(self, backbone_type, gd_config, feature_dim, num_classes):
        super().__init__()

        # Backbone Network
        backbone = resnet50(pretrained=True) if backbone_type == 'resnet50' else resnext50_32x4d(pretrained=True)
        self.features = []
        for name, module in backbone.named_children():
            if isinstance(module, nn.AdaptiveAvgPool2d) or isinstance(module, nn.Linear):
                continue
            self.features.append(module)
        self.features = nn.Sequential(*self.features)

        # Main Module
        n = len(gd_config)
        k = feature_dim // n
        assert feature_dim % n == 0, 'the feature dim should be divisible by the number of global descriptors'

        self.global_descriptors, self.main_modules = [], []
        for i in range(n):
            if gd_config[i] == 'S':
                p = 1
            elif gd_config[i] == 'M':
                p = float('inf')
            else:
                p = 3
            self.global_descriptors.append(GlobalDescriptor(p=p))
            self.main_modules.append(nn.Sequential(nn.Linear(2048, k, bias=False), L2Norm()))
        self.global_descriptors = nn.ModuleList(self.global_descriptors)
        self.main_modules = nn.ModuleList(self.main_modules)

        # Auxiliary Module
        self.auxiliary_module = nn.Sequential(nn.BatchNorm1d(2048), nn.Linear(2048, num_classes, bias=True))

    def forward(self, x):
        shared = self.features(x)
        global_descriptors = []
        for i in range(len(self.global_descriptors)):
            global_descriptor = self.global_descriptors[i](shared)
            if i == 0:
                # the auxiliary classification branch only uses the first global descriptor
                classes = self.auxiliary_module(global_descriptor)
            global_descriptor = self.main_modules[i](global_descriptor)
            global_descriptors.append(global_descriptor)
        global_descriptors = F.normalize(torch.cat(global_descriptors, dim=-1), dim=-1)
        return global_descriptors, classes
--------------------------------------------------------------------------------
/resnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    # newer torchvision releases moved this helper to torch.hub
    from torch.hub import load_state_dict_from_url

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        # remove down sample for stage3 (stride 1 instead of 2), so the backbone keeps a larger feature map
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)
--------------------------------------------------------------------------------
/results/car_cub.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CGD/c9cd98fcfe4296875509c316bc8536da71ed22d0/results/car_cub.png
--------------------------------------------------------------------------------
/results/sop_isc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CGD/c9cd98fcfe4296875509c316bc8536da71ed22d0/results/sop_isc.png
--------------------------------------------------------------------------------
/results/structure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CGD/c9cd98fcfe4296875509c316bc8536da71ed22d0/results/structure.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import argparse
import os
import shutil

import torch
from PIL import Image, ImageDraw

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Test CGD')
    parser.add_argument('--query_img_name', default='/home/data/car/uncropped/008055.jpg', type=str,
                        help='query image name')
    parser.add_argument('--data_base', default='car_uncropped_resnet50_SG_1536_0.1_0.5_0.1_128_data_base.pth',
                        type=str, help='queried database')
    parser.add_argument('--retrieval_num', default=8, type=int, help='retrieval number')

    opt = parser.parse_args()

    query_img_name, data_base_name, retrieval_num = opt.query_img_name, opt.data_base, opt.retrieval_num
    data_name = data_base_name.split('_')[0]

    data_base = torch.load('results/{}'.format(data_base_name))

    if query_img_name not in data_base['test_images']:
        raise FileNotFoundError('{} not found'.format(query_img_name))
    query_index = data_base['test_images'].index(query_img_name)
    query_image = Image.open(query_img_name).convert('RGB').resize((224, 224), resample=Image.BILINEAR)
    query_label = torch.tensor(data_base['test_labels'][query_index])
    query_feature = data_base['test_features'][query_index]

    gallery_images = data_base['{}_images'.format('test' if data_name != 'isc' else 'gallery')]
    gallery_labels = torch.tensor(data_base['{}_labels'.format('test' if data_name != 'isc' else 'gallery')])
    gallery_features = data_base['{}_features'.format('test' if data_name != 'isc' else 'gallery')]

    dist_matrix = torch.cdist(query_feature.unsqueeze(0).unsqueeze(0), gallery_features.unsqueeze(0)).squeeze()
    if data_name != 'isc':
        # exclude the query image itself from its own retrieval results
        dist_matrix[query_index] = float('inf')
    idx = dist_matrix.topk(k=retrieval_num, dim=-1, largest=False)[1]

    result_path = 'results/{}'.format(query_img_name.split('/')[-1].split('.')[0])
    if os.path.exists(result_path):
        shutil.rmtree(result_path)
    os.mkdir(result_path)
    query_image.save('{}/query_img.jpg'.format(result_path))
    for num, index in enumerate(idx):
        retrieval_image = Image.open(gallery_images[index.item()]).convert('RGB') \
            .resize((224, 224), resample=Image.BILINEAR)
        draw = ImageDraw.Draw(retrieval_image)
        retrieval_label = gallery_labels[index.item()]
        retrieval_status = (retrieval_label == query_label).item()
        retrieval_dist = dist_matrix[index.item()].item()
        # draw a green border for a correct retrieval and a red border for an incorrect one
        if retrieval_status:
            draw.rectangle((0, 0, 223, 223), outline='green', width=8)
        else:
            draw.rectangle((0, 0, 223, 223), outline='red', width=8)
        retrieval_image.save('{}/retrieval_img_{}_{}.jpg'.format(result_path, num + 1, '%.4f' % retrieval_dist))
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse

import pandas as pd
import torch
from thop import profile, clever_format
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from tqdm import tqdm

from model import Model, set_bn_eval
from utils import recall, LabelSmoothingCrossEntropyLoss, BatchHardTripletLoss, ImageReader, MPerClassSampler


def train(net, optim):
    net.train()
    # fix bn on backbone network
    net.apply(set_bn_eval)
    total_loss, total_correct, total_num, data_bar = 0, 0, 0, tqdm(train_data_loader)
    for inputs, labels in data_bar:
        inputs, labels = inputs.cuda(), labels.cuda()
        features, classes = net(inputs)
        class_loss = class_criterion(classes, labels)
        feature_loss = feature_criterion(features, labels)
        loss = class_loss + feature_loss
        optim.zero_grad()
        loss.backward()
        optim.step()
        pred = torch.argmax(classes, dim=-1)
        total_loss += loss.item() * inputs.size(0)
        total_correct += torch.sum(pred == labels).item()
        total_num += inputs.size(0)
        data_bar.set_description('Train Epoch {}/{} - Loss:{:.4f} - Acc:{:.2f}%'
                                 .format(epoch, num_epochs, total_loss / total_num, total_correct / total_num * 100))

    return total_loss / total_num, total_correct / total_num * 100


def test(net, recall_ids):
    net.eval()
    with torch.no_grad():
        # obtain feature vectors for all data
        for key in eval_dict.keys():
            eval_dict[key]['features'] = []
            for inputs, labels in tqdm(eval_dict[key]['data_loader'], desc='processing {} data'.format(key)):
                inputs, labels = inputs.cuda(), labels.cuda()
                features, classes = net(inputs)
                eval_dict[key]['features'].append(features)
            eval_dict[key]['features'] = torch.cat(eval_dict[key]['features'], dim=0)

        # compute recall metric
        if data_name == 'isc':
            acc_list = recall(eval_dict['test']['features'], test_data_set.labels, recall_ids,
                              eval_dict['gallery']['features'], gallery_data_set.labels)
        else:
            acc_list = recall(eval_dict['test']['features'], test_data_set.labels, recall_ids)
    desc = 'Test Epoch {}/{} '.format(epoch, num_epochs)
    for index, rank_id in enumerate(recall_ids):
        desc += 'R@{}:{:.2f}% '.format(rank_id, acc_list[index] * 100)
        results['test_recall@{}'.format(rank_id)].append(acc_list[index] * 100)
    print(desc)
    return acc_list[0]


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train CGD')
    parser.add_argument('--data_path', default='/home/data', type=str, help='datasets path')
    parser.add_argument('--data_name', default='car', type=str, choices=['car', 'cub', 'sop', 'isc'],
                        help='dataset name')
    parser.add_argument('--crop_type', default='uncropped', type=str, choices=['uncropped', 'cropped'],
                        help='crop data or not, it only works for car or cub dataset')
    parser.add_argument('--backbone_type', default='resnet50', type=str, choices=['resnet50', 'resnext50'],
                        help='backbone network type')
    parser.add_argument('--gd_config', default='SG', type=str,
                        choices=['S', 'M', 'G', 'SM', 'MS', 'SG', 'GS', 'MG', 'GM', 'SMG', 'MSG', 'GSM'],
                        help='global descriptors config')
    parser.add_argument('--feature_dim', default=1536, type=int, help='feature dim')
    parser.add_argument('--smoothing', default=0.1, type=float, help='smoothing value for label smoothing')
    parser.add_argument('--temperature', default=0.5, type=float,
                        help='temperature scaling used in softmax cross-entropy loss')
    parser.add_argument('--margin', default=0.1, type=float, help='margin of m for triplet loss')
    parser.add_argument('--recalls', default='1,2,4,8', type=str, help='selected recall')
    parser.add_argument('--batch_size', default=128, type=int, help='train batch size')
    parser.add_argument('--num_epochs', default=20, type=int, help='train epoch number')

    opt = parser.parse_args()
    # args parse
    data_path, data_name, crop_type, backbone_type = opt.data_path, opt.data_name, opt.crop_type, opt.backbone_type
    gd_config, feature_dim, smoothing, temperature = opt.gd_config, opt.feature_dim, opt.smoothing, opt.temperature
    margin, recalls, batch_size = opt.margin, [int(k) for k in opt.recalls.split(',')], opt.batch_size
    num_epochs = opt.num_epochs
    save_name_pre = '{}_{}_{}_{}_{}_{}_{}_{}_{}'.format(data_name, crop_type, backbone_type, gd_config, feature_dim,
                                                        smoothing, temperature, margin, batch_size)

    results = {'train_loss': [], 'train_accuracy': []}
    for recall_id in recalls:
        results['test_recall@{}'.format(recall_id)] = []

    # dataset loader
    train_data_set = ImageReader(data_path, data_name, 'train', crop_type)
    train_sample = MPerClassSampler(train_data_set.labels, batch_size)
    train_data_loader = DataLoader(train_data_set, batch_sampler=train_sample, num_workers=8)
    test_data_set = ImageReader(data_path, data_name, 'query' if data_name == 'isc' else 'test', crop_type)
    test_data_loader = DataLoader(test_data_set, batch_size, shuffle=False, num_workers=8)
    eval_dict = {'test': {'data_loader': test_data_loader}}
    if data_name == 'isc':
        gallery_data_set = ImageReader(data_path, data_name, 'gallery', crop_type)
        gallery_data_loader = DataLoader(gallery_data_set, batch_size, shuffle=False, num_workers=8)
        eval_dict['gallery'] = {'data_loader': gallery_data_loader}

    # model setup, model profile, optimizer config and loss definition
    model = Model(backbone_type, gd_config, feature_dim, num_classes=len(train_data_set.class_to_idx)).cuda()
    flops, params = profile(model, inputs=(torch.randn(1, 3, 224, 224).cuda(),))
    flops, params = clever_format([flops, params])
    print('# Model Params: {} FLOPs: {}'.format(params, flops))
    optimizer = Adam(model.parameters(), lr=1e-4)
    lr_scheduler = MultiStepLR(optimizer, milestones=[int(0.6 * num_epochs), int(0.8 * num_epochs)], gamma=0.1)
    class_criterion = LabelSmoothingCrossEntropyLoss(smoothing=smoothing, temperature=temperature)
    feature_criterion = BatchHardTripletLoss(margin=margin)

    best_recall = 0.0
    for epoch in range(1, num_epochs + 1):
        train_loss, train_accuracy = train(model, optimizer)
        results['train_loss'].append(train_loss)
        results['train_accuracy'].append(train_accuracy)
        rank = test(model, recalls)
        lr_scheduler.step()

        # save statistics
        data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
        data_frame.to_csv('results/{}_statistics.csv'.format(save_name_pre), index_label='epoch')
        # save database and model
        data_base = {}
        if rank > best_recall:
            best_recall = rank
            data_base['test_images'] = test_data_set.images
            data_base['test_labels'] = test_data_set.labels
            data_base['test_features'] = eval_dict['test']['features']
            if data_name == 'isc':
                data_base['gallery_images'] = gallery_data_set.images
                data_base['gallery_labels'] = gallery_data_set.labels
                data_base['gallery_features'] = eval_dict['gallery']['features']
            torch.save(model.state_dict(), 'results/{}_model.pth'.format(save_name_pre))
            torch.save(data_base, 'results/{}_data_base.pth'.format(save_name_pre))
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from PIL import Image
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.utils.data.sampler import Sampler
from torchvision import transforms


class ImageReader(Dataset):

    def __init__(self, data_path, data_name, data_type, crop_type):
        if crop_type == 'cropped' and data_name not in ['car', 'cub']:
            raise NotImplementedError('cropped data only works for car or cub dataset')

        data_dict = torch.load('{}/{}/{}_data_dicts.pth'.format(data_path, data_name, crop_type))[data_type]
        self.class_to_idx = dict(zip(sorted(data_dict), range(len(data_dict))))
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        if data_type == 'train':
            self.transform = transforms.Compose([transforms.Resize((252, 252)), transforms.RandomCrop(224),
                                                 transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize])
        else:
            self.transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), normalize])
        self.images, self.labels = [], []
        for label, image_list in data_dict.items():
            self.images += image_list
            self.labels += [self.class_to_idx[label]] * len(image_list)

    def __getitem__(self, index):
        path, target = self.images[index], self.labels[index]
        img = Image.open(path).convert('RGB')
        img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.images)


def recall(feature_vectors, feature_labels, rank, gallery_vectors=None, gallery_labels=None):
    num_features = len(feature_labels)
    feature_labels = torch.tensor(feature_labels, device=feature_vectors.device)
    gallery_vectors = feature_vectors if gallery_vectors is None else gallery_vectors

    dist_matrix = torch.cdist(feature_vectors.unsqueeze(0), gallery_vectors.unsqueeze(0)).squeeze(0)

    if gallery_labels is None:
        # the gallery is the query set itself, so exclude self-matches
        dist_matrix.fill_diagonal_(float('inf'))
        gallery_labels = feature_labels
    else:
        gallery_labels = torch.tensor(gallery_labels, device=feature_vectors.device)

    idx = dist_matrix.topk(k=rank[-1], dim=-1, largest=False)[1]
    acc_list = []
    for r in rank:
        correct = (gallery_labels[idx[:, 0:r]] == feature_labels.unsqueeze(dim=-1)).any(dim=-1).float()
        acc_list.append((torch.sum(correct) / num_features).item())
    return acc_list


class LabelSmoothingCrossEntropyLoss(nn.Module):
    def __init__(self, smoothing=0.1, temperature=1.0):
        super().__init__()
        self.smoothing = smoothing
        self.temperature = temperature

    def forward(self, x, target):
        log_probs = F.log_softmax(x / self.temperature, dim=-1)
        nll_loss = -log_probs.gather(dim=-1, index=target.unsqueeze(dim=-1)).squeeze(dim=-1)
        smooth_loss = -log_probs.mean(dim=-1)
        loss = (1.0 - self.smoothing) * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
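

# Note: since smooth_loss is the mean of -log_probs over the C classes, this
# equals cross-entropy against the smoothed target distribution
# (1 - smoothing) * one_hot + smoothing / C, computed on logits scaled by 1 / temperature.
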

class BatchHardTripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    @staticmethod
    def get_anchor_positive_triplet_mask(target):
        mask = torch.eq(target.unsqueeze(0), target.unsqueeze(1))
        mask.fill_diagonal_(False)
        return mask

    @staticmethod
    def get_anchor_negative_triplet_mask(target):
        labels_equal = torch.eq(target.unsqueeze(0), target.unsqueeze(1))
        mask = ~labels_equal
        return mask

    def forward(self, x, target):
        pairwise_dist = torch.cdist(x.unsqueeze(0), x.unsqueeze(0)).squeeze(0)

        mask_anchor_positive = self.get_anchor_positive_triplet_mask(target)
        anchor_positive_dist = mask_anchor_positive.float() * pairwise_dist
        hardest_positive_dist = anchor_positive_dist.max(1, True)[0]

        mask_anchor_negative = self.get_anchor_negative_triplet_mask(target)
        # add the row-wise max distance to positive (and self) entries so they
        # are never selected as the hardest negative
        max_anchor_negative_dist = pairwise_dist.max(1, True)[0]
        anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative.float())
        hardest_negative_dist = anchor_negative_dist.min(1, True)[0]

        loss = F.relu(hardest_positive_dist - hardest_negative_dist + self.margin)
        return loss.mean()


class MPerClassSampler(Sampler):
    def __init__(self, labels, batch_size, m=4):
        self.labels = np.array(labels)
        self.labels_unique = np.unique(labels)
        self.batch_size = batch_size
        self.m = m
        assert batch_size % m == 0, 'batch size must be divisible by m'

    def __len__(self):
        return len(self.labels) // self.batch_size

    def __iter__(self):
        for _ in range(self.__len__()):
            labels_in_batch = set()
            inds = np.array([], dtype=int)

            while inds.shape[0] < self.batch_size:
                sample_label = np.random.choice(self.labels_unique)
                if sample_label in labels_in_batch:
                    continue

                labels_in_batch.add(sample_label)
                sample_label_ids = np.argwhere(np.in1d(self.labels, sample_label)).reshape(-1)
                subsample = np.random.permutation(sample_label_ids)[:self.m]
                inds = np.append(inds, subsample)

            inds = inds[:self.batch_size]
            inds = np.random.permutation(inds)
            yield list(inds)
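
# With the defaults used in train.py (batch_size=128, m=4), each batch draws 4
# images (fewer if a class has fewer than m images) from each of 32 distinct
# classes, which is how BatchHardTripletLoss finds positive pairs within a batch.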
--------------------------------------------------------------------------------