├── README.md
├── data_utils.py
├── model.py
├── results
│   └── result.png
├── test.py
├── train.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
# ProxyAnchor

A PyTorch implementation of Proxy Anchor Loss, based on the CVPR 2020
paper [Proxy Anchor Loss for Deep Metric Learning](https://arxiv.org/abs/2003.13911).

## Requirements

- [Anaconda](https://www.anaconda.com/download/)
- [PyTorch](https://pytorch.org)

```
conda install pytorch torchvision cudatoolkit=11.0 -c pytorch
```

- pretrainedmodels

```
pip install pretrainedmodels
```

## Datasets

[CARS196](http://ai.stanford.edu/~jkrause/cars/car_dataset.html)
and [CUB200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html)
are used in this repo. Download these datasets yourself and extract them into the `${data_path}` directory,
making sure the directory names are `car` and `cub`. Then run `data_utils.py` to preprocess them:
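
```
python data_utils.py --data_path /home/data
```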

## Usage

### Train Model

```
python train.py --data_name cub --backbone_type inception --feature_dim 256
optional arguments:
--data_path                   datasets path [default value is '/home/data']
--data_name                   dataset name [default value is 'car'] (choices=['car', 'cub'])
--backbone_type               backbone network type [default value is 'resnet50'] (choices=['resnet50', 'inception', 'googlenet'])
--feature_dim                 feature dim [default value is 512]
--batch_size                  training batch size [default value is 64]
--num_epochs                  training epoch number [default value is 20]
--warm_up                     warm-up epoch number, during which the backbone is frozen [default value is 2]
--recalls                     selected recalls [default value is '1,2,4,8']
```
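
The trained model, the queried database and the training statistics are saved in the `results` directory, named with
the `{data_name}_{backbone_type}_{feature_dim}` prefix (e.g. `car_resnet50_512_model.pth`,
`car_resnet50_512_data_base.pth` and `car_resnet50_512_statistics.csv` for the default options).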

### Test Model

```
python test.py --retrieval_num 10
optional arguments:
--query_img_name              query image name [default value is '/home/data/car/uncropped/008055.jpg']
--data_base                   queried database [default value is 'car_resnet50_512_data_base.pth']
--retrieval_num               retrieval number [default value is 8]
```

## Benchmarks

The models are trained on one NVIDIA GeForce GTX 1070 (8GB) GPU. `AdamW` is used to optimize the model, with an `lr`
of `1e-2` for the parameters of `proxies` and `1e-4` for all other parameters; the `lr` is halved every `5` epochs.
A `weight decay` of `1e-4` is used, `scale` is `32` and `margin` is `0.1`, and the other hyper-parameters keep their
default values.
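
Here `scale` and `margin` correspond to $\alpha$ and $\delta$ in the Proxy Anchor loss as implemented in `utils.py`,
where $s(x, p)$ is the cosine similarity between embedding $x$ and proxy $p$, $P$ is the set of all proxies and
$P^{+}$ the proxies with at least one positive sample in the batch:

$$
\ell(X) = \frac{1}{|P^{+}|} \sum_{p \in P^{+}} \log\Big(1 + \sum_{x \in X_{p}^{+}} e^{-\alpha (s(x, p) - \delta)}\Big)
        + \frac{1}{|P|} \sum_{p \in P} \log\Big(1 + \sum_{x \in X_{p}^{-}} e^{\alpha (s(x, p) + \delta)}\Big)
$$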

### CARS196

| Backbone  | R@1   | R@2   | R@4   | R@8   | Download |
|-----------|-------|-------|-------|-------|----------|
| ResNet50  | 87.2% | 92.4% | 95.5% | 97.4% | 5bww     |
| Inception | 85.1% | 91.1% | 94.5% | 96.9% | r6e7     |
| GoogLeNet | 78.2% | 85.5% | 91.1% | 94.5% | espu     |

### CUB200

| Backbone  | R@1   | R@2   | R@4   | R@8   | Download |
|-----------|-------|-------|-------|-------|----------|
| ResNet50  | 67.0% | 77.3% | 85.1% | 90.8% | 73h5     |
| Inception | 67.6% | 78.2% | 86.3% | 91.4% | u5b9     |
| GoogLeNet | 62.8% | 73.9% | 82.4% | 89.4% | anbq     |

## Results

![result](results/result.png)
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
import argparse
import os

import torch
from PIL import Image
from scipy.io import loadmat
from tqdm import tqdm


def read_txt(path, data_num):
    # parse a whitespace-separated annotation file into a dict keyed by the first column
    data = {}
    for line in open(path, 'r', encoding='utf-8'):
        if data_num == 2:
            data_1, data_2 = line.split()
        else:
            data_1, data_2, data_3, data_4, data_5 = line.split()
            data_2 = [data_2, data_3, data_4, data_5]
        data[data_1] = data_2
    return data


def process_car_data(data_path, data_type):
    if not os.path.exists('{}/{}'.format(data_path, data_type)):
        os.mkdir('{}/{}'.format(data_path, data_type))
    train_images, test_images = {}, {}
    annotations = loadmat('{}/cars_annos.mat'.format(data_path))['annotations'][0]
    for img in tqdm(annotations, desc='process {} data for car dataset'.format(data_type), dynamic_ncols=True):
        img_name, img_label = str(img[0][0]), str(img[5][0][0])
        if data_type == 'uncropped':
            img = Image.open('{}/{}'.format(data_path, img_name)).convert('RGB')
        else:
            # crop the image to its annotated bounding box
            x1, y1, x2, y2 = int(img[1][0][0]), int(img[2][0][0]), int(img[3][0][0]), int(img[4][0][0])
            img = Image.open('{}/{}'.format(data_path, img_name)).convert('RGB').crop((x1, y1, x2, y2))
        save_name = '{}/{}/{}'.format(data_path, data_type, os.path.basename(img_name))
        img.save(save_name)
        # standard split: the first 98 classes for training, the remaining 98 for testing
        if int(img_label) < 99:
            if img_label in train_images:
                train_images[img_label].append(save_name)
            else:
                train_images[img_label] = [save_name]
        else:
            if img_label in test_images:
                test_images[img_label].append(save_name)
            else:
                test_images[img_label] = [save_name]
    torch.save({'train': train_images, 'test': test_images}, '{}/{}_data_dicts.pth'.format(data_path, data_type))


def process_cub_data(data_path, data_type):
    if not os.path.exists('{}/{}'.format(data_path, data_type)):
        os.mkdir('{}/{}'.format(data_path, data_type))
    images = read_txt('{}/images.txt'.format(data_path), 2)
    labels = read_txt('{}/image_class_labels.txt'.format(data_path), 2)
    bounding_boxes = read_txt('{}/bounding_boxes.txt'.format(data_path), 5)
    train_images, test_images = {}, {}
    for img_id, img_name in tqdm(images.items(), desc='process {} data for cub dataset'.format(data_type),
                                 dynamic_ncols=True):
        if data_type == 'uncropped':
            img = Image.open('{}/images/{}'.format(data_path, img_name)).convert('RGB')
        else:
            # bounding boxes are stored as (x, y, width, height)
            x1, y1 = int(float(bounding_boxes[img_id][0])), int(float(bounding_boxes[img_id][1]))
            x2, y2 = x1 + int(float(bounding_boxes[img_id][2])), y1 + int(float(bounding_boxes[img_id][3]))
            img = Image.open('{}/images/{}'.format(data_path, img_name)).convert('RGB').crop((x1, y1, x2, y2))
        save_name = '{}/{}/{}'.format(data_path, data_type, os.path.basename(img_name))
        img.save(save_name)
        # standard split: the first 100 classes for training, the remaining 100 for testing
        if int(labels[img_id]) < 101:
            if labels[img_id] in train_images:
                train_images[labels[img_id]].append(save_name)
            else:
                train_images[labels[img_id]] = [save_name]
        else:
            if labels[img_id] in test_images:
                test_images[labels[img_id]].append(save_name)
            else:
                test_images[labels[img_id]] = [save_name]
    torch.save({'train': train_images, 'test': test_images}, '{}/{}_data_dicts.pth'.format(data_path, data_type))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Process datasets')
    parser.add_argument('--data_path', default='/home/data', type=str, help='datasets path')

    opt = parser.parse_args()

    process_car_data('{}/car'.format(opt.data_path), 'uncropped')
    process_cub_data('{}/cub'.format(opt.data_path), 'uncropped')
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from pretrainedmodels import bninception
from torch import nn
from torchvision.models import resnet50, googlenet


class ProxyLinear(nn.Module):
    def __init__(self, num_proxy, in_features):
        super(ProxyLinear, self).__init__()
        self.num_proxy = num_proxy
        self.in_features = in_features
        # init proxy vector as unit random vector
        self.weight = nn.Parameter(F.normalize(torch.randn(num_proxy, in_features), dim=-1))

    def forward(self, x):
        # with normalized features and proxies, the output is the cosine similarity
        normalized_weight = F.normalize(self.weight, dim=-1)
        output = x.mm(normalized_weight.t())
        return output

    def extra_repr(self):
        return 'num_proxy={}, in_features={}'.format(self.num_proxy, self.in_features)


class AvgMaxPool(nn.Module):
    # sum of global average pooling and global max pooling
    def __init__(self):
        super(AvgMaxPool, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        return self.avg_pool(x) + self.max_pool(x)


class Model(nn.Module):
    def __init__(self, backbone_type, feature_dim, num_classes):
        super().__init__()

        # Backbone Network
        backbones = {'resnet50': (resnet50, 2048), 'inception': (bninception, 1024), 'googlenet': (googlenet, 1024)}
        backbone, middle_dim = backbones[backbone_type]
        backbone = backbone(pretrained='imagenet' if backbone_type == 'inception' else True)
        if backbone_type == 'inception':
            backbone.global_pool = AvgMaxPool()
            backbone.last_linear = nn.Identity()
        else:
            backbone.avgpool = AvgMaxPool()
            backbone.fc = nn.Identity()
        self.backbone = backbone

        # Refactor Layer
        self.refactor = nn.Linear(middle_dim, feature_dim, bias=False)
        self.fc = ProxyLinear(num_classes, feature_dim)

    def forward(self, x):
        features = self.backbone(x)
        # l2-normalized embedding
        features = F.normalize(self.refactor(features), dim=-1)
        classes = self.fc(features)
        return features, classes
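

if __name__ == '__main__':
    # minimal shape sanity check, not part of the original training pipeline;
    # constructing the model downloads the ImageNet weights on first use
    net = Model('resnet50', feature_dim=512, num_classes=98)
    feats, classes = net(torch.randn(2, 3, 224, 224))
    print(feats.shape, classes.shape)  # torch.Size([2, 512]) torch.Size([2, 98])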
--------------------------------------------------------------------------------
/results/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/ProxyAnchor/16812c88e39a6887718a8bdc0a2e93bbd2595544/results/result.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import argparse
import os
import shutil

import numpy as np
import torch
from PIL import Image, ImageDraw

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Test Model')
    parser.add_argument('--query_img_name', default='/home/data/car/uncropped/008055.jpg', type=str,
                        help='query image name')
    parser.add_argument('--data_base', default='car_resnet50_512_data_base.pth', type=str, help='queried database')
    parser.add_argument('--retrieval_num', default=8, type=int, help='retrieval number')

    opt = parser.parse_args()

    query_img_name, data_base_name, retrieval_num = opt.query_img_name, opt.data_base, opt.retrieval_num
    data_name = data_base_name.split('_')[0]

    data_base = torch.load('results/{}'.format(data_base_name))

    if query_img_name not in data_base['test_images']:
        raise FileNotFoundError('{} not found'.format(query_img_name))
    query_index = data_base['test_images'].index(query_img_name)
    query_image = Image.open(query_img_name).convert('RGB').resize((224, 224), resample=Image.BILINEAR)
    query_label = torch.tensor(data_base['test_labels'][query_index])
    query_feature = data_base['test_features'][query_index]

    gallery_images = data_base['test_images']
    gallery_labels = torch.tensor(data_base['test_labels'])
    gallery_features = data_base['test_features']

    # cosine similarity between the query and all gallery images; exclude the query itself
    sim_matrix = query_feature.unsqueeze(0).mm(gallery_features.t()).squeeze()
    sim_matrix[query_index] = -np.inf
    idx = sim_matrix.topk(k=retrieval_num, dim=-1)[1]

    result_path = 'results/{}'.format(query_img_name.split('/')[-1].split('.')[0])
    if os.path.exists(result_path):
        shutil.rmtree(result_path)
    os.mkdir(result_path)
    query_image.save('{}/query_img.jpg'.format(result_path))
    for num, index in enumerate(idx):
        retrieval_image = Image.open(gallery_images[index.item()]).convert('RGB') \
            .resize((224, 224), resample=Image.BILINEAR)
        draw = ImageDraw.Draw(retrieval_image)
        retrieval_label = gallery_labels[index.item()]
        retrieval_status = torch.equal(retrieval_label, query_label)
        retrieval_sim = sim_matrix[index.item()].item()
        # green border for a correct retrieval, red for a wrong one
        if retrieval_status:
            draw.rectangle((0, 0, 223, 223), outline='green', width=8)
        else:
            draw.rectangle((0, 0, 223, 223), outline='red', width=8)
        retrieval_image.save('{}/retrieval_img_{}_{}.jpg'.format(result_path, num + 1, '%.4f' % retrieval_sim))
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse

import numpy as np
import pandas as pd
import torch
from torch.backends import cudnn
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from tqdm import tqdm

from model import Model
from utils import recall, ImageReader, set_bn_eval, ProxyAnchorLoss

# for reproducibility
np.random.seed(1)
torch.manual_seed(1)
cudnn.deterministic = True
cudnn.benchmark = False


def train(net, optim):
    net.train()
    # fix bn on backbone network
    net.backbone.apply(set_bn_eval)
    total_loss, total_correct, total_num = 0.0, 0.0, 0
    data_bar = tqdm(train_data_loader, dynamic_ncols=True)
    for inputs, labels in data_bar:
        inputs, labels = inputs.cuda(), labels.cuda()
        feature, output = net(inputs)
        loss = loss_criterion(output, labels)
        optim.zero_grad()
        loss.backward()
        optim.step()

        with torch.no_grad():
            # the proxy with the highest cosine similarity gives the predicted class
            pred = torch.argmax(output, dim=-1)
            total_loss += loss.item() * inputs.size(0)
            total_correct += torch.sum(torch.eq(pred, labels)).item()
            total_num += inputs.size(0)
            data_bar.set_description('Train Epoch {}/{} - Loss:{:.4f} - Acc:{:.2f}%'
                                     .format(epoch, num_epochs, total_loss / total_num,
                                             total_correct / total_num * 100))
    return total_loss / total_num, total_correct / total_num * 100


def test(net, recall_ids):
    net.eval()
    # obtain feature vectors for all data
    with torch.no_grad():
        features = []
        for inputs, labels in tqdm(test_data_loader, desc='processing test data', dynamic_ncols=True):
            feature, _ = net(inputs.cuda())
            features.append(feature)
        features = torch.cat(features, dim=0)
        # compute recall metric
        acc_list = recall(features, test_data_set.labels, recall_ids)
        desc = 'Test Epoch {}/{} '.format(epoch, num_epochs)
        for index, rank_id in enumerate(recall_ids):
            desc += 'R@{}:{:.2f}% '.format(rank_id, acc_list[index] * 100)
            results['test_recall@{}'.format(rank_id)].append(acc_list[index] * 100)
        print(desc)
        data_base['test_features'] = features
    return acc_list[0]


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Model')
    parser.add_argument('--data_path', default='/home/data', type=str, help='datasets path')
    parser.add_argument('--data_name', default='car', type=str, choices=['car', 'cub'], help='dataset name')
    parser.add_argument('--backbone_type', default='resnet50', type=str, choices=['resnet50', 'inception', 'googlenet'],
                        help='backbone network type')
    parser.add_argument('--feature_dim', default=512, type=int, help='feature dim')
    parser.add_argument('--batch_size', default=64, type=int, help='training batch size')
    parser.add_argument('--num_epochs', default=20, type=int, help='training epoch number')
    parser.add_argument('--warm_up', default=2, type=int, help='warm up epoch number')
    parser.add_argument('--recalls', default='1,2,4,8', type=str, help='selected recall')

    opt = parser.parse_args()
    # args parse
    data_path, data_name, backbone_type = opt.data_path, opt.data_name, opt.backbone_type
    feature_dim, batch_size, num_epochs = opt.feature_dim, opt.batch_size, opt.num_epochs
    warm_up, recalls = opt.warm_up, [int(k) for k in opt.recalls.split(',')]
    save_name_pre = '{}_{}_{}'.format(data_name, backbone_type, feature_dim)

    results = {'train_loss': [], 'train_accuracy': []}
    for recall_id in recalls:
        results['test_recall@{}'.format(recall_id)] = []

    # dataset loader
    train_data_set = ImageReader(data_path, data_name, 'train', backbone_type)
    train_data_loader = DataLoader(train_data_set, batch_size, shuffle=True, num_workers=8)
    test_data_set = ImageReader(data_path, data_name, 'test', backbone_type)
    test_data_loader = DataLoader(test_data_set, batch_size, shuffle=False, num_workers=8)

    # model setup, optimizer config and loss definition
    model = Model(backbone_type, feature_dim, len(train_data_set.class_to_idx)).cuda()
    # a larger lr (1e-2) for the proxies, a smaller one (1e-4) for the backbone and the refactor layer
    optimizer = AdamW([{'params': model.backbone.parameters()}, {'params': model.refactor.parameters()},
                       {'params': model.fc.parameters(), 'lr': 1e-2}], lr=1e-4, weight_decay=1e-4)
    # halve the lr every 5 epochs
    lr_scheduler = StepLR(optimizer, step_size=5, gamma=0.5)
    loss_criterion = ProxyAnchorLoss()

    data_base = {'test_images': test_data_set.images, 'test_labels': test_data_set.labels}
    best_recall = 0.0
    for epoch in range(1, num_epochs + 1):

        # warm up: do not update the parameters of the backbone
        for param in model.backbone.parameters():
            param.requires_grad = False if epoch <= warm_up else True

        train_loss, train_accuracy = train(model, optimizer)
        results['train_loss'].append(train_loss)
        results['train_accuracy'].append(train_accuracy)
        rank = test(model, recalls)
        lr_scheduler.step()
        # save statistics
        data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
        data_frame.to_csv('results/{}_statistics.csv'.format(save_name_pre), index_label='epoch')

        if rank > best_recall:
            best_recall = rank
            # save database and model
            torch.save(model.state_dict(), 'results/{}_model.pth'.format(save_name_pre))
            torch.save(data_base, 'results/{}_data_base.pth'.format(save_name_pre))
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torch.utils.data import Dataset
from torchvision import transforms


class Identity(object):
    # no-op transform, used as a placeholder for non-inception backbones
    def __call__(self, im):
        return im


class RGBToBGR(object):
    # bninception was pretrained on BGR images
    def __call__(self, im):
        assert im.mode == 'RGB'
        r, g, b = [im.getchannel(i) for i in range(3)]
        im = Image.merge('RGB', [b, g, r])
        return im


class ScaleIntensities(object):
    def __init__(self, in_range, out_range):
        """ Scales intensities. For example [-1, 1] -> [0, 255]."""
        self.in_range = in_range
        self.out_range = out_range

    def __call__(self, tensor):
        tensor = (tensor - self.in_range[0]) / (self.in_range[1] - self.in_range[0]) * (
                self.out_range[1] - self.out_range[0]) + self.out_range[0]
        return tensor


class ImageReader(Dataset):

    def __init__(self, data_path, data_name, data_type, backbone_type):
        data_dict = torch.load('{}/{}/uncropped_data_dicts.pth'.format(data_path, data_name))[data_type]
        self.class_to_idx = dict(zip(sorted(data_dict), range(len(data_dict))))
        if backbone_type == 'inception':
            # bninception expects BGR inputs in [0, 255] with per-channel mean subtraction
            normalize = transforms.Normalize([104, 117, 128], [1, 1, 1])
        else:
            normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        if data_type == 'train':
            self.transform = transforms.Compose([
                RGBToBGR() if backbone_type == 'inception' else Identity(),
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                ScaleIntensities([0, 1], [0, 255]) if backbone_type == 'inception' else Identity(),
                normalize])
        else:
            self.transform = transforms.Compose([
                RGBToBGR() if backbone_type == 'inception' else Identity(),
                transforms.Resize(256), transforms.CenterCrop(224),
                transforms.ToTensor(),
                ScaleIntensities([0, 1], [0, 255]) if backbone_type == 'inception' else Identity(),
                normalize])
        self.images, self.labels = [], []
        for label, image_list in data_dict.items():
            self.images += image_list
            self.labels += [self.class_to_idx[label]] * len(image_list)

    def __getitem__(self, index):
        path, target = self.images[index], self.labels[index]
        img = Image.open(path).convert('RGB')
        img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.images)


def set_bn_eval(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm2d') != -1:
        m.eval()


def recall(feature_vectors, feature_labels, rank):
    # Recall@K: the fraction of queries whose K nearest neighbours (the query itself excluded)
    # contain at least one sample of the same class
    feature_labels = torch.tensor(feature_labels, device=feature_vectors.device)
    sim_matrix = feature_vectors.mm(feature_vectors.t())
    sim_matrix.fill_diagonal_(-np.inf)

    idx = sim_matrix.topk(k=rank[-1], dim=-1, largest=True)[1]
    acc_list = []
    for r in rank:
        correct = (torch.eq(feature_labels[idx[:, 0:r]], feature_labels.unsqueeze(dim=-1))).any(dim=-1)
        acc_list.append((torch.sum(correct) / correct.size(0)).item())
    return acc_list


class ProxyAnchorLoss(nn.Module):
    def __init__(self, scale=32, margin=0.1):
        super(ProxyAnchorLoss, self).__init__()
        # scale and margin correspond to alpha and delta in the paper
        self.scale = scale
        self.margin = margin

    def forward(self, output, label):
        # output contains the cosine similarities between embeddings and proxies
        pos_label = F.one_hot(label, num_classes=output.size(-1))
        neg_label = 1 - pos_label
        # number of proxies with at least one positive sample in the batch (|P+|)
        pos_num = torch.sum(torch.ne(pos_label.sum(dim=0), 0))
        pos_output = torch.exp(-self.scale * (output - self.margin))
        neg_output = torch.exp(self.scale * (output + self.margin))
        pos_output = (torch.where(torch.eq(pos_label, 1), pos_output, torch.zeros_like(pos_output))).sum(dim=0)
        neg_output = (torch.where(torch.eq(neg_label, 1), neg_output, torch.zeros_like(neg_output))).sum(dim=0)
        pos_loss = torch.sum(torch.log(pos_output + 1)) / pos_num
        # the negative term is averaged over all proxies (|P|)
        neg_loss = torch.sum(torch.log(neg_output + 1)) / output.size(-1)
        loss = pos_loss + neg_loss
        return loss
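

if __name__ == '__main__':
    # quick numeric sanity check of ProxyAnchorLoss, not part of the original repo:
    # random unit-norm embeddings and proxies give cosine similarities in [-1, 1]
    torch.manual_seed(0)
    embeddings = F.normalize(torch.randn(4, 16), dim=-1)
    proxies = F.normalize(torch.randn(10, 16), dim=-1)
    labels = torch.tensor([0, 1, 1, 3])
    print(ProxyAnchorLoss()(embeddings.mm(proxies.t()), labels).item())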
--------------------------------------------------------------------------------