├── .gitignore
├── README.md
└── hist_loss.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
test.py
*.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Histogram Loss
A fast implementation of the histogram loss in PyTorch. The original paper can be found here:
* [Learning Deep Embeddings with Histogram Loss](https://arxiv.org/pdf/1611.00822.pdf) - Evgeniya Ustinova, Victor Lempitsky

## Getting started
Both the forward and backward passes are implemented, so the module can be used directly as a loss function in your own work. The loss estimates the probability that a randomly sampled negative pair is more similar than a randomly sampled positive pair, computed from soft histograms of the two similarity distributions. It runs on both CPU and GPU, and no errors surfaced during testing on either.

### Implementation
This implementation builds on two pieces of the PyTorch documentation:
* [torch.bincount](https://pytorch.org/docs/stable/torch.html?highlight=bincount#torch.bincount) - the very fast `bincount` function, used to accumulate the soft histograms
* [Extending PyTorch](https://pytorch.org/docs/stable/notes/extending.html) - writing a customised layer with both forward and backward functions

### Prerequisites
```
pytorch >= v0.4.1
```

### Running
Import the loss into Python
```
from hist_loss import HistogramLoss
```
Initialise an instance of the loss
```
func_loss = HistogramLoss()
```
Forward computation, where `sim_pos` and `sim_neg` hold the similarities of positive and negative pairs (flattened internally), `n_bins` sets the histogram resolution, and the optional `w_pos`/`w_neg` assign a weight to each pair
```
loss = func_loss(sim_pos, sim_neg, n_bins, w_pos, w_neg)
```
Backward computation
```
loss.backward()
```
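Putting the pieces together, here is a minimal end-to-end sketch. Everything around the loss call is illustrative rather than part of the library: the random embeddings, the cosine similarities, and the names `emb`, `labels`, `same`, and `triu` are all assumptions made for this example.
```
import torch
import torch.nn.functional as F
from hist_loss import HistogramLoss

# toy batch: 32 L2-normalised embeddings of dimension 128 from 4 classes
emb = F.normalize(torch.randn(32, 128, requires_grad=True), dim=1)
labels = torch.randint(0, 4, (32,))

# cosine similarities, plus masks selecting positive (same-class) and
# negative (different-class) pairs; the strict upper triangle keeps each
# unordered pair once and drops self-similarities
sim = emb @ emb.t()
same = labels.unsqueeze(0) == labels.unsqueeze(1)
triu = torch.triu(torch.ones(32, 32), diagonal=1) > 0

func_loss = HistogramLoss()
loss = func_loss(sim[same & triu], sim[~same & triu], n_bins=100)
loss.backward()
```
Using only the upper triangle is a design choice for this sketch; any way of producing two flat tensors of positive and negative similarities works.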

### Contact
* [shuaitang93@ucsd.edu](mailto:shuaitang93@ucsd.edu) - Email
* [@Shuai93Tang](https://twitter.com/Shuai93Tang) - Twitter
* [Shuai Tang](http://shuaitang.github.io/) - Homepage
--------------------------------------------------------------------------------
/hist_loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class Hist(torch.autograd.Function):
    """Soft histogram over [0, 1] with handwritten forward and backward."""

    @staticmethod
    def forward(ctx, sim, n_bins, w):

        # similarities are assumed to lie in [0, 1], so the histogram nodes
        # sit 1 / n_bins apart
        step = 1. / n_bins
        idx = sim / step

        # each sample splits its weight linearly between the two nearest nodes
        lower = idx.floor()
        upper = idx.ceil()

        delta_u = idx - lower
        delta_l = upper - idx

        # when idx is an exact integer, floor == ceil and both deltas are
        # zero, so the sample would vanish from the histogram (after min-max
        # rescaling this always hits the extreme samples); route its full
        # weight through the lower-node term instead, which keeps the
        # histogram continuous in sim
        delta_l = delta_l + (upper == lower).to(delta_l.dtype)

        lower = lower.long()
        upper = upper.long()

        hist = (torch.bincount(upper, delta_u * w, n_bins + 1)
                + torch.bincount(lower, delta_l * w, n_bins + 1))
        w_sum = w.sum()
        hist = hist / w_sum

        ctx.save_for_backward(upper, lower, w, w_sum)
        ctx.step = step

        return hist

    @staticmethod
    def backward(ctx, grad_hist):
        upper, lower, w, w_sum = ctx.saved_tensors

        # d(idx)/d(sim) = 1 / step, so the chain rule contributes a 1 / step
        # factor on top of the 1 / w_sum normalisation of the forward pass
        grad_hist = grad_hist / (w_sum * ctx.step)
        grad_sim = (grad_hist[upper] - grad_hist[lower]) * w

        # n_bins and w are treated as constants
        return grad_sim, None, None


class HistogramLoss(nn.Module):
    def __init__(self):
        super(HistogramLoss, self).__init__()

        self.hist = Hist.apply

    def forward(self, sim_pos, sim_neg, n_bins, w_pos=None, w_neg=None):

        sim_pos = sim_pos.flatten()
        sim_neg = sim_neg.flatten()

        # linearly transform similarity values to the range between 0 and 1;
        # the extrema are detached, so no gradient flows through them
        max_pos, min_pos = torch.max(sim_pos.detach()), torch.min(sim_pos.detach())
        max_neg, min_neg = torch.max(sim_neg.detach()), torch.min(sim_neg.detach())

        max_ = max_pos if max_pos >= max_neg else max_neg
        min_ = min_pos if min_pos <= min_neg else min_neg

        sim_pos = (sim_pos - min_) / (max_ - min_)
        sim_neg = (sim_neg - min_) / (max_ - min_)

        if w_pos is not None:
            w_pos = w_pos.detach().flatten()
            assert sim_pos.size() == w_pos.size(), "Please make sure the size of the similarity tensor matches that of the weight tensor."
        else:
            w_pos = torch.ones_like(sim_pos)

        if w_neg is not None:
            w_neg = w_neg.detach().flatten()
            assert sim_neg.size() == w_neg.size(), "Please make sure the size of the similarity tensor matches that of the weight tensor."
        else:
            w_neg = torch.ones_like(sim_neg)

        pdf_pos = self.hist(sim_pos, n_bins, w_pos)
        pdf_neg = self.hist(sim_neg, n_bins, w_neg)

        # cdf_pos[b] is the probability that a positive pair is no more
        # similar than node b, so the sum below estimates the probability
        # that a negative pair is more similar than a positive pair
        cdf_pos = torch.cumsum(pdf_pos, dim=0)
        loss = (cdf_pos * pdf_neg).sum()

        return loss
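

if __name__ == "__main__":
    # Illustrative sanity check, not part of the original module: compare the
    # handwritten backward against finite differences with
    # torch.autograd.gradcheck. gradcheck needs double precision, and a sample
    # sitting within the perturbation eps of a bin edge can still fail, since
    # the soft histogram is only piecewise differentiable.
    n_bins = 10
    sim = (torch.rand(20, dtype=torch.double) * 0.98 + 0.01).requires_grad_()
    w = torch.ones(20, dtype=torch.double)
    print(torch.autograd.gradcheck(Hist.apply, (sim, n_bins, w)))

--------------------------------------------------------------------------------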