├── README.md
├── data_utils.py
├── model.py
├── results
│   └── result.png
├── test.py
├── train.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# ProxyAnchor

A PyTorch implementation of Proxy Anchor Loss based on the CVPR 2020
paper [Proxy Anchor Loss for Deep Metric Learning](https://arxiv.org/abs/2003.13911).

## Requirements

- [Anaconda](https://www.anaconda.com/download/)
- [PyTorch](https://pytorch.org)

```
conda install pytorch torchvision cudatoolkit=11.0 -c pytorch
```

- pretrainedmodels

```
pip install pretrainedmodels
```

## Datasets

[CARS196](http://ai.stanford.edu/~jkrause/cars/car_dataset.html)
and [CUB200-2011](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html)
are used in this repo. Download these datasets yourself, extract them into the `${data_path}` directory,
and make sure the directory names are `car` and `cub`. Then run `data_utils.py` to preprocess them.
The preprocessing assigns the first 98 classes of CARS196 and the first 100 classes of CUB200-2011 to the
training set, and the remaining classes to the test set.
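
The split is class-disjoint: no class seen during training appears at test time. A minimal sketch of the assignment rule, assuming the `{'train': {label: [paths]}, 'test': {label: [paths]}}` layout that `data_utils.py` saves; `split_by_class` and the sample paths are hypothetical and not part of the repo:

```python
from collections import defaultdict


def split_by_class(samples, num_train_classes):
    """Group (image_path, label) pairs into class-disjoint train/test dicts.

    Hypothetical helper mirroring the rule in data_utils.py: labels are
    1-based strings, and a sample goes to the train split when
    int(label) <= num_train_classes (98 for CARS196, 100 for CUB200-2011),
    otherwise to the test split.
    """
    train_images, test_images = defaultdict(list), defaultdict(list)
    for path, label in samples:
        target = train_images if int(label) <= num_train_classes else test_images
        target[label].append(path)
    return dict(train_images), dict(test_images)


# Hypothetical toy input: three CARS196-style samples.
samples = [('car/uncropped/000001.jpg', '1'),
           ('car/uncropped/000002.jpg', '98'),
           ('car/uncropped/000003.jpg', '99')]
train_split, test_split = split_by_class(samples, num_train_classes=98)
print(sorted(train_split), sorted(test_split))  # ['1', '98'] ['99']
```

With `num_train_classes=100` the same rule reproduces the CUB200-2011 split.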

## Usage
### Train Model

```
python train.py --data_name cub --backbone_type inception --feature_dim 256
optional arguments:
--data_path                   datasets path [default value is '/home/data']
--data_name                   dataset name [default value is 'car'](choices=['car', 'cub'])
--backbone_type               backbone network type [default value is 'resnet50'](choices=['resnet50', 'inception', 'googlenet'])
--feature_dim                 feature dimension [default value is 512]
--batch_size                  training batch size [default value is 64]
--num_epochs                  number of training epochs [default value is 20]
--warm_up                     number of warm-up epochs with a frozen backbone [default value is 2]
--recalls                     recall@K values to report [default value is '1,2,4,8']
```

### Test Model

```
python test.py --retrieval_num 10
optional arguments:
--query_img_name              query image name [default value is '/home/data/car/uncropped/008055.jpg']
--data_base                   queried database, read from results/ [default value is 'car_resnet50_512_data_base.pth']
--retrieval_num               number of retrieved images [default value is 8]
```

## Benchmarks

The models are trained on a single NVIDIA GeForce GTX 1070 (8 GB) GPU. `AdamW` is used to optimize the model, with an `lr` of `1e-2`
for the `proxies` and `1e-4` for all other parameters; the `lr` is halved every `5` epochs.
The `weight decay` is `1e-4`, `scale` is `32` and `margin` is `0.1`; all other hyper-parameters keep their default values.

### CARS196

| Backbone  | R@1   | R@2   | R@4   | R@8   | Download |
|:---------:|:-----:|:-----:|:-----:|:-----:|:--------:|
| ResNet50  | 87.2% | 92.4% | 95.5% | 97.4% | 5bww     |
| Inception | 85.1% | 91.1% | 94.5% | 96.9% | r6e7     |
| GoogLeNet | 78.2% | 85.5% | 91.1% | 94.5% | espu     |

### CUB200
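
| Backbone  | R@1   | R@2   | R@4   | R@8   | Download |
|:---------:|:-----:|:-----:|:-----:|:-----:|:--------:|
| ResNet50  | 67.0% | 77.3% | 85.1% | 90.8% | 73h5     |
| Inception | 67.6% | 78.2% | 86.3% | 91.4% | u5b9     |
| GoogLeNet | 62.8% | 73.9% | 82.4% | 89.4% | anbq     |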
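
The R@K columns report Recall@K on the held-out test classes: every test image queries all other test images, and it counts as a hit if at least one of its K most similar images shares its label. A minimal NumPy sketch of that computation, mirroring `recall` in `utils.py` but not taken from the repo (`recall_at_k` is a hypothetical name), assuming L2-normalized embeddings as produced by `Model.forward`:

```python
import numpy as np


def recall_at_k(features, labels, ks):
    """Recall@K where each sample queries every other sample.

    features: (N, D) array of L2-normalized embeddings (dot product = cosine similarity).
    labels:   (N,) sequence of class indices.
    ks:       ascending list of K values, e.g. [1, 2, 4, 8].
    """
    labels = np.asarray(labels)
    sim = features @ features.T
    np.fill_diagonal(sim, -np.inf)                # a query must not retrieve itself
    order = np.argsort(-sim, axis=1)[:, :ks[-1]]  # indices of the top max(K) neighbours
    hits = labels[order] == labels[:, None]       # (N, max K) label matches
    # A query is a hit at K if any of its first K neighbours shares its label.
    return [float(hits[:, :k].any(axis=1).mean()) for k in ks]


feats = np.array([[1.0, 0.0], [0.8, 0.6], [0.0, 1.0], [-0.6, 0.8]])
print(recall_at_k(feats, [0, 1, 0, 1], [1, 2, 3]))  # [0.0, 0.5, 1.0]
```

In this toy case the nearest neighbour of every unit vector belongs to the other class, so R@1 is 0.0, while every query finds a same-class neighbour by rank 3.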

## Results

![vis](results/result.png)

--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
import argparse
import os

import torch
from PIL import Image
from scipy.io import loadmat
from tqdm import tqdm


def read_txt(path, data_num):
    data = {}
    for line in open(path, 'r', encoding='utf-8'):
        if data_num == 2:
            data_1, data_2 = line.split()
        else:
            data_1, data_2, data_3, data_4, data_5 = line.split()
            data_2 = [data_2, data_3, data_4, data_5]
        data[data_1] = data_2
    return data


def process_car_data(data_path, data_type):
    if not os.path.exists('{}/{}'.format(data_path, data_type)):
        os.mkdir('{}/{}'.format(data_path, data_type))
    train_images, test_images = {}, {}
    annotations = loadmat('{}/cars_annos.mat'.format(data_path))['annotations'][0]
    for img in tqdm(annotations, desc='process {} data for car dataset'.format(data_type), dynamic_ncols=True):
        img_name, img_label = str(img[0][0]), str(img[5][0][0])
        if data_type == 'uncropped':
            img = Image.open('{}/{}'.format(data_path, img_name)).convert('RGB')
        else:
            x1, y1, x2, y2 = int(img[1][0][0]), int(img[2][0][0]), int(img[3][0][0]), int(img[4][0][0])
            img = Image.open('{}/{}'.format(data_path, img_name)).convert('RGB').crop((x1, y1, x2, y2))
        save_name = '{}/{}/{}'.format(data_path, data_type, os.path.basename(img_name))
        img.save(save_name)
        if int(img_label) < 99:
            if img_label in train_images:
                train_images[img_label].append(save_name)
            else:
                train_images[img_label] = [save_name]
        else:
            if img_label in test_images:
                test_images[img_label].append(save_name)
            else:
                test_images[img_label] = [save_name]
    torch.save({'train': train_images, 'test': test_images}, '{}/{}_data_dicts.pth'.format(data_path, data_type))


def process_cub_data(data_path, data_type):
    if not os.path.exists('{}/{}'.format(data_path, data_type)):
        os.mkdir('{}/{}'.format(data_path, data_type))
    images = read_txt('{}/images.txt'.format(data_path), 2)
    labels = read_txt('{}/image_class_labels.txt'.format(data_path), 2)
    bounding_boxes = read_txt('{}/bounding_boxes.txt'.format(data_path), 5)
    train_images, test_images = {}, {}
    for img_id, img_name in tqdm(images.items(), desc='process {} data for cub dataset'.format(data_type),
                                 dynamic_ncols=True):
        if data_type == 'uncropped':
            img = Image.open('{}/images/{}'.format(data_path, img_name)).convert('RGB')
        else:
            x1, y1 = int(float(bounding_boxes[img_id][0])), int(float(bounding_boxes[img_id][1]))
            x2, y2 = x1 + int(float(bounding_boxes[img_id][2])), y1 + int(float(bounding_boxes[img_id][3]))
            img = Image.open('{}/images/{}'.format(data_path, img_name)).convert('RGB').crop((x1, y1, x2, y2))
        save_name = '{}/{}/{}'.format(data_path, data_type, os.path.basename(img_name))
        img.save(save_name)
        if int(labels[img_id]) < 101:
            if labels[img_id] in train_images:
                train_images[labels[img_id]].append(save_name)
            else:
                train_images[labels[img_id]] = [save_name]
        else:
            if labels[img_id] in test_images:
                test_images[labels[img_id]].append(save_name)
            else:
                test_images[labels[img_id]] = [save_name]
    torch.save({'train': train_images, 'test': test_images}, '{}/{}_data_dicts.pth'.format(data_path, data_type))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Process datasets')
    parser.add_argument('--data_path', default='/home/data', type=str, help='datasets path')

    opt = parser.parse_args()

    process_car_data('{}/car'.format(opt.data_path), 'uncropped')
    process_cub_data('{}/cub'.format(opt.data_path), 'uncropped')
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from pretrainedmodels import bninception
from torch import nn
from torchvision.models import resnet50, googlenet


class ProxyLinear(nn.Module):
    def __init__(self, num_proxy, in_features):
        super(ProxyLinear, self).__init__()
        self.num_proxy = num_proxy
        self.in_features = in_features
        # init proxy vector as unit random vector
        self.weight = nn.Parameter(F.normalize(torch.randn(num_proxy, in_features), dim=-1))

    def forward(self, x):
        normalized_weight = F.normalize(self.weight, dim=-1)
        output = x.mm(normalized_weight.t())
        return output

    def extra_repr(self):
        return 'num_proxy={}, in_features={}'.format(self.num_proxy, self.in_features)


class AvgMaxPool(nn.Module):
    def __init__(self):
        super(AvgMaxPool, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        return self.avg_pool(x) + self.max_pool(x)


class Model(nn.Module):
    def __init__(self, backbone_type, feature_dim, num_classes):
        super().__init__()

        # Backbone Network
        backbones = {'resnet50': (resnet50, 2048), 'inception': (bninception, 1024), 'googlenet': (googlenet, 1024)}
        backbone, middle_dim = backbones[backbone_type]
        backbone = backbone(pretrained='imagenet' if backbone_type == 'inception' else True)
        if backbone_type == 'inception':
            backbone.global_pool = AvgMaxPool()
            backbone.last_linear = nn.Identity()
        else:
            backbone.avgpool = AvgMaxPool()
            backbone.fc = nn.Identity()
        self.backbone = backbone

        # Refactor Layer
        self.refactor = nn.Linear(middle_dim, feature_dim, bias=False)
        self.fc = ProxyLinear(num_classes, feature_dim)

    def forward(self, x):
        features = self.backbone(x)
        features = F.normalize(self.refactor(features), dim=-1)
        classes = self.fc(features)
        return features, classes
--------------------------------------------------------------------------------
/results/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/ProxyAnchor/16812c88e39a6887718a8bdc0a2e93bbd2595544/results/result.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import argparse
import os
import shutil

import numpy as np
import torch
from PIL import Image, ImageDraw

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Test Model')
    parser.add_argument('--query_img_name', default='/home/data/car/uncropped/008055.jpg', type=str,
                        help='query image name')
    parser.add_argument('--data_base', default='car_resnet50_512_data_base.pth', type=str, help='queried database')
    parser.add_argument('--retrieval_num', default=8, type=int, help='retrieval number')

    opt = parser.parse_args()

    query_img_name, data_base_name, retrieval_num = opt.query_img_name, opt.data_base, opt.retrieval_num
    data_name = data_base_name.split('_')[0]

    data_base = torch.load('results/{}'.format(data_base_name))

    if query_img_name not in data_base['test_images']:
        raise FileNotFoundError('{} not found'.format(query_img_name))
    query_index = data_base['test_images'].index(query_img_name)
    query_image = Image.open(query_img_name).convert('RGB').resize((224, 224), resample=Image.BILINEAR)
    query_label = torch.tensor(data_base['test_labels'][query_index])
    query_feature = data_base['test_features'][query_index]

    gallery_images = data_base['test_images']
    gallery_labels = torch.tensor(data_base['test_labels'])
    gallery_features = data_base['test_features']

    sim_matrix = query_feature.unsqueeze(0).mm(gallery_features.t()).squeeze()
    sim_matrix[query_index] = -np.inf
    idx = sim_matrix.topk(k=retrieval_num, dim=-1)[1]

    result_path = 'results/{}'.format(query_img_name.split('/')[-1].split('.')[0])
    if os.path.exists(result_path):
        shutil.rmtree(result_path)
    os.mkdir(result_path)
    query_image.save('{}/query_img.jpg'.format(result_path))
    for num, index in enumerate(idx):
        retrieval_image = Image.open(gallery_images[index.item()]).convert('RGB') \
            .resize((224, 224), resample=Image.BILINEAR)
        draw = ImageDraw.Draw(retrieval_image)
        retrieval_label = gallery_labels[index.item()]
        retrieval_status = torch.equal(retrieval_label, query_label)
        retrieval_sim = sim_matrix[index.item()].item()
        if retrieval_status:
            draw.rectangle((0, 0, 223, 223), outline='green', width=8)
        else:
            draw.rectangle((0, 0, 223, 223), outline='red', width=8)
        retrieval_image.save('{}/retrieval_img_{}_{}.jpg'.format(result_path, num + 1, '%.4f' % retrieval_sim))
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse

import numpy as np
import pandas as pd
import torch
from torch.backends import cudnn
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from tqdm import tqdm

from model import Model
from utils import recall, ImageReader, set_bn_eval, ProxyAnchorLoss

# for reproducibility
np.random.seed(1)
torch.manual_seed(1)
cudnn.deterministic = True
cudnn.benchmark = False


def train(net, optim):
    net.train()
    # fix bn on backbone network
    net.backbone.apply(set_bn_eval)
    total_loss, total_correct, total_num = 0.0, 0.0, 0
    data_bar = tqdm(train_data_loader, dynamic_ncols=True)
    for inputs, labels in data_bar:
        inputs, labels = inputs.cuda(), labels.cuda()
        feature, output = net(inputs)
        loss = loss_criterion(output, labels)
        optim.zero_grad()
        loss.backward()
        optim.step()

        with torch.no_grad():
            pred = torch.argmax(output, dim=-1)
            total_loss += loss.item() * inputs.size(0)
            total_correct += torch.sum(torch.eq(pred, labels)).item()
            total_num += inputs.size(0)
            data_bar.set_description('Train Epoch {}/{} - Loss:{:.4f} - Acc:{:.2f}%'
                                     .format(epoch, num_epochs, total_loss / total_num,
                                             total_correct / total_num * 100))
    return total_loss / total_num, total_correct / total_num * 100


def test(net, recall_ids):
    net.eval()
    # obtain feature vectors for all data
    with torch.no_grad():
        features = []
        for inputs, labels in tqdm(test_data_loader, desc='processing test data', dynamic_ncols=True):
            feature, _ = net(inputs.cuda())
            features.append(feature)
        features = torch.cat(features, dim=0)
        # compute recall metric
        acc_list = recall(features, test_data_set.labels, recall_ids)
    desc = 'Test Epoch {}/{} '.format(epoch, num_epochs)
    for index, rank_id in enumerate(recall_ids):
        desc += 'R@{}:{:.2f}% '.format(rank_id, acc_list[index] * 100)
        results['test_recall@{}'.format(rank_id)].append(acc_list[index] * 100)
    print(desc)
    data_base['test_features'] = features
    return acc_list[0]


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Model')
    parser.add_argument('--data_path', default='/home/data', type=str, help='datasets path')
    parser.add_argument('--data_name', default='car', type=str, choices=['car', 'cub'], help='dataset name')
    parser.add_argument('--backbone_type', default='resnet50', type=str, choices=['resnet50', 'inception', 'googlenet'],
                        help='backbone network type')
    parser.add_argument('--feature_dim', default=512, type=int, help='feature dim')
    parser.add_argument('--batch_size', default=64, type=int, help='training batch size')
    parser.add_argument('--num_epochs', default=20, type=int, help='training epoch number')
    parser.add_argument('--warm_up', default=2, type=int, help='warm up number')
    parser.add_argument('--recalls', default='1,2,4,8', type=str, help='selected recall')

    opt = parser.parse_args()
    # args parse
    data_path, data_name, backbone_type = opt.data_path, opt.data_name, opt.backbone_type
    feature_dim, batch_size, num_epochs = opt.feature_dim, opt.batch_size, opt.num_epochs
    warm_up, recalls = opt.warm_up, [int(k) for k in opt.recalls.split(',')]
    save_name_pre = '{}_{}_{}'.format(data_name, backbone_type, feature_dim)

    results = {'train_loss': [], 'train_accuracy': []}
    for recall_id in recalls:
        results['test_recall@{}'.format(recall_id)] = []

    # dataset loader
    train_data_set = ImageReader(data_path, data_name, 'train', backbone_type)
    train_data_loader = DataLoader(train_data_set, batch_size, shuffle=True, num_workers=8)
    test_data_set = ImageReader(data_path, data_name, 'test', backbone_type)
    test_data_loader = DataLoader(test_data_set, batch_size, shuffle=False, num_workers=8)

    # model setup, optimizer config and loss definition
    model = Model(backbone_type, feature_dim, len(train_data_set.class_to_idx)).cuda()
    optimizer = AdamW([{'params': model.backbone.parameters()}, {'params': model.refactor.parameters()},
                       {'params': model.fc.parameters(), 'lr': 1e-2}], lr=1e-4, weight_decay=1e-4)
    lr_scheduler = StepLR(optimizer, step_size=5, gamma=0.5)
    loss_criterion = ProxyAnchorLoss()

    data_base = {'test_images': test_data_set.images, 'test_labels': test_data_set.labels}
    best_recall = 0.0
    for epoch in range(1, num_epochs + 1):

        # warm-up: keep the backbone parameters frozen for the first `warm_up` epochs
        for param in model.backbone.parameters():
            param.requires_grad = epoch > warm_up

        train_loss, train_accuracy = train(model, optimizer)
        results['train_loss'].append(train_loss)
        results['train_accuracy'].append(train_accuracy)
        rank = test(model, recalls)
        lr_scheduler.step()
        # save statistics
        data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
        data_frame.to_csv('results/{}_statistics.csv'.format(save_name_pre), index_label='epoch')

        if rank > best_recall:
            best_recall = rank
            # save database and model
            torch.save(model.state_dict(), 'results/{}_model.pth'.format(save_name_pre))
            torch.save(data_base, 'results/{}_data_base.pth'.format(save_name_pre))
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torch.utils.data import Dataset
from torchvision import transforms


class Identity(object):
    def __call__(self, im):
        return im


class RGBToBGR(object):
    def __call__(self, im):
        assert im.mode == 'RGB'
        r, g, b = [im.getchannel(i) for i in range(3)]
        im = Image.merge('RGB', [b, g, r])
        return im


class ScaleIntensities(object):
    def __init__(self, in_range, out_range):
        """ Scales intensities. For example [-1, 1] -> [0, 255]."""
        self.in_range = in_range
        self.out_range = out_range

    def __call__(self, tensor):
        tensor = (tensor - self.in_range[0]) / (self.in_range[1] - self.in_range[0]) * (
                self.out_range[1] - self.out_range[0]) + self.out_range[0]
        return tensor


class ImageReader(Dataset):

    def __init__(self, data_path, data_name, data_type, backbone_type):
        data_dict = torch.load('{}/{}/uncropped_data_dicts.pth'.format(data_path, data_name))[data_type]
        self.class_to_idx = dict(zip(sorted(data_dict), range(len(data_dict))))
        if backbone_type == 'inception':
            normalize = transforms.Normalize([104, 117, 128], [1, 1, 1])
        else:
            normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        if data_type == 'train':
            self.transform = transforms.Compose([
                RGBToBGR() if backbone_type == 'inception' else Identity(),
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                ScaleIntensities([0, 1], [0, 255]) if backbone_type == 'inception' else Identity(),
                normalize])
        else:
            self.transform = transforms.Compose([
                RGBToBGR() if backbone_type == 'inception' else Identity(),
                transforms.Resize(256), transforms.CenterCrop(224),
                transforms.ToTensor(),
                ScaleIntensities([0, 1], [0, 255]) if backbone_type == 'inception' else Identity(),
                normalize])
        self.images, self.labels = [], []
        for label, image_list in data_dict.items():
            self.images += image_list
            self.labels += [self.class_to_idx[label]] * len(image_list)

    def __getitem__(self, index):
        path, target = self.images[index], self.labels[index]
        img = Image.open(path).convert('RGB')
        img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.images)


def set_bn_eval(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm2d') != -1:
        m.eval()


def recall(feature_vectors, feature_labels, rank):
    feature_labels = torch.tensor(feature_labels, device=feature_vectors.device)
    sim_matrix = feature_vectors.mm(feature_vectors.t())
    sim_matrix.fill_diagonal_(-np.inf)

    idx = sim_matrix.topk(k=rank[-1], dim=-1, largest=True)[1]
    acc_list = []
    for r in rank:
        correct = (torch.eq(feature_labels[idx[:, 0:r]], feature_labels.unsqueeze(dim=-1))).any(dim=-1)
        acc_list.append((torch.sum(correct) / correct.size(0)).item())
    return acc_list


class ProxyAnchorLoss(nn.Module):
    def __init__(self, scale=32, margin=0.1):
        super(ProxyAnchorLoss, self).__init__()
        self.scale = scale
        self.margin = margin

    def forward(self, output, label):
        pos_label = F.one_hot(label, num_classes=output.size(-1))
        neg_label = 1 - pos_label
        pos_num = torch.sum(torch.ne(pos_label.sum(dim=0), 0))
        pos_output = torch.exp(-self.scale * (output - self.margin))
        neg_output = torch.exp(self.scale * (output + self.margin))
        pos_output = (torch.where(torch.eq(pos_label, 1), pos_output, torch.zeros_like(pos_output))).sum(dim=0)
        neg_output = (torch.where(torch.eq(neg_label, 1), neg_output, torch.zeros_like(neg_output))).sum(dim=0)
        pos_loss = torch.sum(torch.log(pos_output + 1)) / pos_num
        neg_loss = torch.sum(torch.log(neg_output + 1)) / output.size(-1)
        loss = pos_loss + neg_loss
        return loss
--------------------------------------------------------------------------------
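
`ProxyAnchorLoss` in `utils.py` follows the paper's formulation. For every proxy with at least one positive in the batch it adds log(1 + sum of exp(-scale * (s - margin)) over those positives), averaged over such proxies. For every proxy it adds log(1 + sum of exp(scale * (s + margin)) over its negatives), averaged over all proxies. Here s is the cosine similarity produced by `ProxyLinear`. A NumPy restatement can help sanity-check the torch code on small inputs. The sketch below assumes the same defaults (`scale=32`, `margin=0.1`), and `proxy_anchor_loss` is a hypothetical helper rather than part of the repo:

```python
import numpy as np


def proxy_anchor_loss(sim, labels, scale=32.0, margin=0.1):
    """NumPy re-statement of ProxyAnchorLoss.forward (hypothetical helper).

    sim:    (B, C) cosine similarities between B embeddings and C proxies.
    labels: (B,) integer class indices in [0, C).
    """
    batch, num_classes = sim.shape
    pos_mask = np.zeros_like(sim, dtype=bool)
    pos_mask[np.arange(batch), labels] = True
    neg_mask = ~pos_mask
    # Per proxy: sum of exp terms over its positive / negative embeddings.
    pos_sum = np.where(pos_mask, np.exp(-scale * (sim - margin)), 0.0).sum(axis=0)
    neg_sum = np.where(neg_mask, np.exp(scale * (sim + margin)), 0.0).sum(axis=0)
    num_pos_proxies = np.count_nonzero(pos_mask.any(axis=0))
    pos_loss = np.log1p(pos_sum).sum() / num_pos_proxies  # averaged over proxies with positives
    neg_loss = np.log1p(neg_sum).sum() / num_classes      # averaged over all proxies
    return pos_loss + neg_loss


print(proxy_anchor_loss(np.zeros((2, 3)), [0, 1]))
```

On this all-zero similarity matrix the result is log(1 + e^3.2) for the positive term plus (2 log(1 + e^3.2) + log(1 + 2e^3.2)) / 3 for the negative term. The third proxy has no positives in the batch, so it is left out of the positive average but still contributes to the negative one.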