├── .gitignore ├── LICENSE ├── README.md ├── img └── fig1.png ├── mgp.ipynb ├── mnist-example ├── MNIST_D_absolute.pkl ├── MNIST_D_visual.pkl ├── free_proto_model.pth.tar ├── pre-trained-D_absolute │ ├── fix_proto_dito_model.pth.tar │ └── guided_proto_disto_model.pth.tar ├── pre-trained-D_visual │ ├── fix_proto_dito_model.pth.tar │ └── guided_proto_disto_model.pth.tar └── xe_model.pth.tar ├── setup.py └── torch_prototypes ├── __init__.py ├── metrics ├── __init__.py ├── cost.py ├── distortion.py ├── hyperspherical.py └── rank.py └── modules ├── __init__.py ├── hierarchical_inference.py ├── metric_embedding.py ├── prototypical_network.py └── softlabels.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.swp 6 | # C extensions 7 | *.so 8 | .idea/ 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | 30 | .DS_Store 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Vivien Sainte Fare Garnot 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Metric-Guided Prototype Learning 2 | PyTorch implementation of [Metric-Guided Prototype Learning](https://arxiv.org/abs/2007.03047) for hierarchical classification. 3 | The modules implemented in this repo can be applied to any classification task where a metric can be defined on the class set, *i.e.* when not all misclassifications have the same cost. Such a metric can easily be derived from the hierarchical structure of most class sets. 4 | 5 | 6 | ![](img/fig1.png) 7 | 8 | ### Example usage on MNIST 9 | We show how to use the code to reproduce Figure 1 of the paper in the notebook `mgp.ipynb`. 10 | The notebook can also be directly run on [this google colab](https://colab.research.google.com/drive/1VoQfBx5q5lWFev0cwxLZ0qQOZU7Rlmb_#offline=true&sandboxMode=true). 11 | 12 | ## Installation 13 | 14 | ### Requirements 15 | The `torch_prototypes` package only requires an environment with PyTorch installed (only tested with version 1.5.0). 16 | For the DeepNCM, and Hierarchical Inference module [torch_scatter](https://github.com/rusty1s/pytorch_scatter) is also required. 17 | The installation of torch_scatter can be challenging, if you do not need the two latter modules please use the [no_scatter](https://github.com/VSainteuf/metric-guided-prototypes-pytorch/tree/no_scatter) branch of this repo. 18 | 19 | | | PyTorch | torch_scatter | 20 | |---------------------------------|:-------:|:-------------:| 21 | | Learnt Prototypes | x | | 22 | | Hyperspherical Prototypes | x | | 23 | | DeepNCM | x | x | 24 | | Hierarchical Inference (Yolov2) | x | x | 25 | 26 | ### Install 27 | All the methods presented in the paper are implemented as PyTorch modules packaged in torch_prototypes. 28 | To install the package, run `pip install -e .` inside the main folder. 
29 | 30 | 31 | ## Code 32 | The package torch_prototypes contains different methods shown in the paper, implemented as `torch.nn` modules: 33 | - Learnt prototypes 34 | - Metric-guided prototypes (learnt and fixed) 35 | - Distortion and *scale free* distortion 36 | - Distortion loss, Rank loss 37 | - Hyperspherical prototypes and associated loss 38 | - Hierarchical inference (YOLOv2 hierarchical classification) 39 | - Soft labels 40 | 41 | 42 | 43 | ### Generic usage 44 | 45 | #### Model definition 46 | The modules in the `torch_prototypes` package are applicable to any classification problem. 47 | They all follow the same paradigm of being wrapper modules around a backbone neural architecture that maps the samples `X` of a dataset to embeddings `E`. 48 | 49 | For example, the following lines define a learnt-prototypes classification model with ResNet18 backbone for image embedding. 50 | Backward gradient computations will propagate both to the prototypes and to the backbone's weights. 51 | ```python 52 | from torch_prototypes.modules import prototypical_network 53 | from torchvision.models.resnet import resnet18 54 | 55 | backbone = resnet18() 56 | emb_dim = backbone.fc.out_features 57 | num_classes = 100 58 | 59 | model = prototypical_network.LearntPrototypes(backbone, n_prototypes=num_classes, embedding_dim=emb_dim) 60 | 61 | ``` 62 | 63 | #### Model output 64 | The output of the wrapper model can be treated as regular classification logits: 65 | ```python 66 | import torch.nn as nn 67 | xe = nn.CrossEntropyLoss() 68 | 69 | logits = model(X) 70 | loss = xe(logits, Y) 71 | prediction = logits.argmax(dim=-1) 72 | ``` 73 | 74 | #### Metric-guided regularization 75 | The `torch_prototypes` package also contains the loss functions (DistortionLoss and RankLoss) to implement metric-guided regularization. 76 | The metric needs to be given in the form of a tensor `D` of shape `num_classes x num_classes` that defines the pairwise misclassification costs. 
77 | Once defined, these losses can be applied to the model's prototypes to guide the learning process: 78 | 79 | ```python 80 | import torch 81 | import torch.nn as nn 82 | from torch_prototypes.metrics.distortion import DistortionLoss 83 | 84 | xe = nn.CrossEntropyLoss() 85 | 86 | D = torch.rand((num_classes,num_classes)) #Dummy cost tensor 87 | disto_loss = DistortionLoss(D=D) 88 | 89 | logits = model(X) 90 | loss = xe(logits, Y) + disto_loss(model.prototypes) 91 | 92 | 93 | ``` 94 | 95 | By default `DistortionLoss` computes our *scale-free* definition of the distortion (see paper). 96 | 97 | 98 | ## References 99 | 100 | Please include a reference to the following paper if you are using any of the learnt-prototype based methods: 101 | 102 | ``` 103 | @article{garnot2020mgp, 104 | title={Metric-Guided Prototype Learning}, 105 | author={Sainte Fare Garnot, Vivien and Landrieu, Loic}, 106 | journal={arXiv preprint arXiv:2007.03047}, 107 | year={2020} 108 | } 109 | 110 | ``` 111 | For the hyperspherical prototypes, DeepNCM and Yolov2, respectively: 112 | - *Hyperspherical Prototype Network*, Mettes Pascal and van der Pol Elise and Snoek, NeurIPS 2019 113 | - *DeepNCM: deep nearest class mean classifiers*, Guerriero Samantha and Caputo Barbara and Mensink Thomas, ICLR Workshop 2018 114 | - *YOLO9000: better, faster, stronger*, Redmon Joseph and Farhadi Ali, CVPR 2017 -------------------------------------------------------------------------------- /img/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VSainteuf/metric-guided-prototypes-pytorch/fae90235b57a71f3ab682917eaf5517799fbffed/img/fig1.png -------------------------------------------------------------------------------- /mnist-example/MNIST_D_absolute.pkl: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/VSainteuf/metric-guided-prototypes-pytorch/fae90235b57a71f3ab682917eaf5517799fbffed/mnist-example/MNIST_D_absolute.pkl -------------------------------------------------------------------------------- /mnist-example/MNIST_D_visual.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VSainteuf/metric-guided-prototypes-pytorch/fae90235b57a71f3ab682917eaf5517799fbffed/mnist-example/MNIST_D_visual.pkl -------------------------------------------------------------------------------- /mnist-example/free_proto_model.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VSainteuf/metric-guided-prototypes-pytorch/fae90235b57a71f3ab682917eaf5517799fbffed/mnist-example/free_proto_model.pth.tar -------------------------------------------------------------------------------- /mnist-example/pre-trained-D_absolute/fix_proto_dito_model.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VSainteuf/metric-guided-prototypes-pytorch/fae90235b57a71f3ab682917eaf5517799fbffed/mnist-example/pre-trained-D_absolute/fix_proto_dito_model.pth.tar -------------------------------------------------------------------------------- /mnist-example/pre-trained-D_absolute/guided_proto_disto_model.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VSainteuf/metric-guided-prototypes-pytorch/fae90235b57a71f3ab682917eaf5517799fbffed/mnist-example/pre-trained-D_absolute/guided_proto_disto_model.pth.tar -------------------------------------------------------------------------------- /mnist-example/pre-trained-D_visual/fix_proto_dito_model.pth.tar: -------------------------------------------------------------------------------- 
# /setup.py
"""Packaging script for the ``torch_prototypes`` package.

Install in development mode with ``pip install -e .`` from the repository root.
"""
# ``distutils`` is deprecated (PEP 632) and removed from the standard library in
# Python 3.12; ``setuptools`` provides the same ``setup`` entry point.
from setuptools import find_packages, setup

setup(
    name="torch_prototypes",
    version="0.1dev",
    packages=find_packages(),
    license="MIT",
)
# /torch_prototypes/metrics/cost.py
import torch
import torch.nn as nn


def _flatten_spatial(input, y_true):
    """Flatten dense 4D predictions and their labels to one row per pixel.

    Args:
        input (tensor): predictions, either already 2D or 4D (b x c x h x w)
        y_true (tensor): labels matching ``input``
    Returns:
        (input, y_true): for 4D input, tensors of shape (b*h*w x c) and (b*h*w);
        any other input is returned untouched.
    """
    if input.dim() == 4:
        c = input.shape[1]
        # Channels last, then merge batch and spatial dimensions. ``reshape``
        # (instead of ``view``) also accepts non-contiguous tensors.
        input = input.permute(0, 2, 3, 1).reshape(-1, c)
        y_true = y_true.reshape(-1)
    return input, y_true


class AverageCost(nn.Module):
    """Average Cost of predictions, according to given cost matrix D"""

    def __init__(self, D, ignore_index=None):
        """
        Args:
            D (tensor): 2D cost matrix (n_classes x n_classes)
            ignore_index (int): index of label to ignore (if any)
        """
        super(AverageCost, self).__init__()
        self.D = D
        self.ignore_index = ignore_index

    def forward(self, input, y_true):
        """
        Args:
            input (tensor): class scores, shape (n x n_classes) or (b x n_classes x h x w)
            y_true (tensor): true labels, shape (n) or (b x h x w)
        Returns:
            float: mean of D[true label, predicted label] over the non-ignored samples
            (nan if no sample is left)
        """
        input, y_true = _flatten_spatial(input, y_true)
        y_true = y_true.long()

        # Drop ignored samples BEFORE indexing D: the ignore label (e.g. 255) is
        # usually not a valid row of the cost matrix.
        if self.ignore_index is not None:
            keep = y_true != self.ignore_index
            input, y_true = input[keep], y_true[keep]

        # The arg-max of the scores is the arg-max of their softmax, so the cost of
        # each prediction can be read directly from D without a one-hot mask.
        y_pred = input.argmax(dim=-1)
        costs = self.D.to(input.device)[y_true, y_pred]
        if not costs.is_floating_point():
            costs = costs.float()  # mean() is undefined on integer tensors
        return float(costs.mean().detach().cpu())


class EMDLoss(nn.Module):
    """Squared Earth Mover regularization"""

    def __init__(self, l, mu, D):
        """
        Args:
            l (float): regularization coefficient
            mu (float): offset
            D (ground distance matrix): 2D cost matrix (n_classes x n_classes)
        """
        super(EMDLoss, self).__init__()
        self.l = l
        self.mu = mu
        self.D = D

    def forward(self, input, y_true):
        """
        Args:
            input (tensor): class scores, shape (n x n_classes) or (b x n_classes x h x w)
            y_true (tensor): true labels, shape (n) or (b x h x w)
        Returns:
            tensor: scalar, l * mean_n( sum_k p_nk^2 * (D[y_n, k] - mu) )
        """
        input, y_true = _flatten_spatial(input, y_true)

        probas = torch.softmax(input, dim=-1)
        ground_dist = self.D.to(probas.device)[y_true.long()]
        per_sample = (probas ** 2 * (ground_dist - self.mu)).sum(dim=-1)

        return self.l * per_sample.mean()
# /torch_prototypes/metrics/distortion.py
import torch
import torch.nn as nn
import numpy as np


class Eucl_Mat(nn.Module):
    """Pairwise Euclidean distance"""

    def __init__(self):
        super(Eucl_Mat, self).__init__()

    def forward(self, mapping):
        """
        Args:
            mapping (tensor): Tensor of shape N_vectors x Embedding_dimension
        Returns:
            distances: Tensor of shape N_vectors x N_vectors giving the pairwise Euclidean distances

        """
        return torch.norm(mapping[:, None, :] - mapping[None, :, :], dim=-1)


class Cosine_Mat(nn.Module):
    """Pairwise Cosine distance"""

    def __init__(self):
        super(Cosine_Mat, self).__init__()

    def forward(self, mapping):
        """
        Args:
            mapping (tensor): Tensor of shape N_vectors x Embedding_dimension
        Returns:
            distances: Tensor of shape N_vectors x N_vectors giving the pairwise Cosine distances

        """
        return 1 - nn.CosineSimilarity(dim=-1)(mapping[:, None, :], mapping[None, :, :])


class Pseudo_Huber(nn.Module):
    """Pseudo-Huber function: delta * (sqrt(1 + (x / delta)^2) - 1)"""

    def __init__(self, delta=1):
        super(Pseudo_Huber, self).__init__()
        self.delta = delta

    def forward(self, input):
        out = (input / self.delta) ** 2
        out = torch.sqrt(out + 1)
        out = self.delta * (out - 1)
        return out


def _build_distance(dist):
    """Return the pairwise-distance module designated by ``dist``.

    Raises:
        ValueError: if ``dist`` is not a known distance name. Failing here avoids
        a late AttributeError on ``self.dist`` at the first forward pass.
    """
    if dist in ("euclidian", "euclidean", "eucl"):
        return Eucl_Mat()
    if dist in ("cosine", "cos"):
        return Cosine_Mat()
    raise ValueError(
        "Unknown distance '{}' (expected 'euclidian' or 'cosine')".format(dist)
    )


def _offdiag_safe(D):
    """Return D with ones added on its diagonal, so that dividing by it is safe there."""
    return D + torch.eye(D.shape[0], device=D.device)


class Distortion(nn.Module):
    """Distortion measure of the embedding of finite metric given by matrix D into another metric space"""

    def __init__(self, D, dist="euclidian"):
        """
        Args:
            D (tensor): 2D cost matrix of the finite metric, shape (NxN)
            dist: Distance to use in the target embedding space ('euclidian' or 'cosine';
                the spelling 'euclidean' is accepted as well)
        """
        super(Distortion, self).__init__()
        self.D = D
        self.dist = _build_distance(dist)

    def forward(self, mapping, idxs=None):
        """
        Args:
            mapping (tensor): Tensor of shape (N x Embedding_dimension) giving the mapping to the target metric space
            idxs: unused, kept for interface compatibility
        Returns:
            tensor: scalar, mean over the N^2 - N off-diagonal pairs of |d_ij - D_ij| / D_ij
        """
        d = self.dist(mapping)
        d = (d - self.D).abs() / _offdiag_safe(self.D)
        d = d.sum() / (d.shape[0] ** 2 - d.shape[0])
        return d


class ScaleFreeDistortion(nn.Module):
    """Distortion of the prototypes after rescaling them by the scale that minimises it"""

    def __init__(self, D):
        """
        Args:
            D (tensor): 2D cost matrix of the finite metric, shape (NxN)
        """
        super(ScaleFreeDistortion, self).__init__()
        self.D = D
        self.disto = Distortion(D)
        self.dist = Eucl_Mat()

    def forward(self, prototypes):
        """
        Args:
            prototypes (tensor): Tensor of shape (N x Embedding_dimension)
        Returns:
            tensor: scalar distortion of ``scale * prototypes``; the scale is computed
            outside of the autograd graph (through numpy) and acts as a constant.
        """
        # Compute distance ratios
        d = self.dist(prototypes)
        d = d / _offdiag_safe(self.D)

        # Get sorted list of (strictly positive) ratios
        alpha = np.sort(d[d > 0].detach().cpu().numpy())
        if alpha.size == 0:
            # All prototypes coincide: no ratio to rescale, use the raw distortion.
            return self.disto(prototypes)

        # Find optimal scaling: first ratio whose cumulated mass reaches half of the total
        cumul = np.cumsum(alpha)
        a_i = alpha[np.where(cumul >= alpha.sum() - cumul)[0].min()]
        # Plain python float, so that the product below is a torch tensor
        scale = 1.0 / float(a_i)

        return self.disto(scale * prototypes)


class DistortionLoss(nn.Module):
    """Scale-free squared distortion regularizer"""

    def __init__(self, D, dist="euclidian", scale_free=True):
        """
        Args:
            D (tensor): 2D cost matrix of the finite metric, shape (NxN)
            dist: Distance to use in the target embedding space ('euclidian' or 'cosine';
                the spelling 'euclidean' is accepted as well)
            scale_free (bool): if True the distances are rescaled by a closed-form factor
                before being compared to D
        """
        super(DistortionLoss, self).__init__()
        self.D = D
        self.scale_free = scale_free
        self.dist = _build_distance(dist)

    def forward(self, mapping, idxs=None):
        """
        Args:
            mapping (tensor): Tensor of shape (N x Embedding_dimension)
            idxs: unused, kept for interface compatibility
        Returns:
            tensor: scalar, mean over the N^2 - N off-diagonal pairs of
            ((scaling * d_ij - D_ij) / D_ij)^2
        """
        d = self.dist(mapping)
        denom = _offdiag_safe(self.D)

        if self.scale_free:
            a = d / denom
            scaling = a.sum() / torch.pow(a, 2).sum()
        else:
            scaling = 1.0

        d = (scaling * d - self.D) ** 2 / denom ** 2
        d = d.sum() / (d.shape[0] ** 2 - d.shape[0])
        return d
# /torch_prototypes/metrics/hyperspherical.py
"""
Implementation of the methods proposed in Hyperspherical Prototype Network, Mettes et al., NeurIPS 2019
https://arxiv.org/abs/1901.10514
"""

import torch
import torch.nn as nn


class SeparationLoss(nn.Module):
    """Large margin separation between hyperspherical prototypes"""

    def __init__(self):
        super(SeparationLoss, self).__init__()

    def forward(self, protos):
        """
        Args:
            protos (tensor): (N_prototypes x Embedding_dimension)
        Returns:
            tensor: scalar, mean over the prototypes of the largest dot product with
            another prototype
        """
        # Subtracting 2 on the diagonal removes each prototype's product with itself
        # from the row-wise maximum.
        M = torch.matmul(protos, protos.transpose(0, 1)) - 2 * torch.eye(
            protos.shape[0]
        ).to(protos.device)
        return M.max(dim=1)[0].mean()


class HypersphericalLoss(nn.Module):
    """Training loss, minimizes the cosine distance between a samples embedding and its class prototype"""

    def __init__(self, ignore_label=None, class_weights=None):
        """
        Args:
            ignore_label (int): (optional) label to leave out of the loss
            class_weights (tensor): (optional) per-class weights, indexed by label
        """
        super(HypersphericalLoss, self).__init__()
        self.ignore_label = ignore_label
        self.class_weights = class_weights

    def forward(self, scores, y):
        """
        Args:
            scores (tensor): per-class scores, shape (n x n_classes) or (b x n_classes x h x w)
            y (tensor): true labels, shape (n) or (b x h x w)
        Returns:
            tensor: scalar, sum over the non-ignored samples of minus the (optionally
            class-weighted) score of the true class
        """
        if scores.dim() == 4:  # Flatten 2D data
            c = scores.shape[1]
            scores = scores.permute(0, 2, 3, 1).reshape(-1, c)
        y_lab = y.reshape(-1).long()

        # Filter ignored samples BEFORE gathering: the ignore label (e.g. 255) is
        # usually not a valid class index.
        if self.ignore_label is not None:
            keep = y_lab != self.ignore_label
            scores, y_lab = scores[keep], y_lab[keep]

        # squeeze(1) rather than squeeze(): a one-sample batch must stay 1-dimensional.
        loss = -scores.gather(dim=1, index=y_lab.view(-1, 1)).squeeze(1)

        if self.class_weights is not None:
            W = torch.as_tensor(self.class_weights, device=loss.device)[y_lab]
            loss = loss * W
        return loss.sum()
# /torch_prototypes/metrics/rank.py
import torch
import torch.nn as nn
from torch_prototypes.metrics.distortion import Eucl_Mat, Cosine_Mat


class RankLoss(nn.Module):
    """Rank preserving loss for finite metric embedding"""

    def __init__(self, D, n_triplets, dist="eucl", ignore_index=None):
        """
        Args:
            D (tensor): 2D cost matrix of the finite metric to embed
            n_triplets (int): Number of triplets over which to compute the loss at each iteration
            dist (str): Euclidean or Cosine distance for the target embedding space (eucl/cos)
            ignore_index (int): index of the label to be ignored (if any)
        Raises:
            ValueError: if ``dist`` is neither 'eucl' nor 'cos'
        """
        super(RankLoss, self).__init__()
        self.D = D
        self.n_triplets = n_triplets
        if dist == "eucl":
            self.dist = Eucl_Mat()
        elif dist == "cos":
            self.dist = Cosine_Mat()
        else:
            # Fail at construction time instead of with a missing attribute in forward
            raise ValueError(
                "Unknown distance '{}' (expected 'eucl' or 'cos')".format(dist)
            )
        self.ignore_index = ignore_index
        if self.ignore_index is not None:
            # Indices of the classes that may appear in a triplet
            self.idxs = torch.tensor(
                [i for i in range(D.shape[0]) if i != ignore_index], device=D.device
            ).long()

    def _sample_triplets(self, n_items):
        """Draw ``n_triplets`` triplets of pairwise distinct indices in [0, n_items).

        Sorting uniform noise row-wise yields one random permutation per row; its
        first three entries form a triplet. This replaces one ``randperm`` call per
        triplet by a single batched operation.
        """
        if n_items < 3:
            raise ValueError("At least 3 classes are needed to sample triplets")
        perms = torch.rand(self.n_triplets, n_items, device=self.D.device).argsort(
            dim=1
        )
        return perms[:, 0], perms[:, 1], perms[:, 2]

    def forward(self, prototypes, idxs=None):
        """
        Args:
            prototypes (tensor): embedding of each class, shape (N x Embedding_dimension)
            idxs: unused, kept for interface compatibility
        Returns:
            tensor: scalar, binary cross-entropy between the ordering of D[i, j] vs D[i, k]
            and the sigmoid of the embedded distance difference, averaged over random triplets
        """
        if self.ignore_index is None:
            i, j, k = self._sample_triplets(self.D.shape[0])
        else:
            i, j, k = (
                self.idxs[t] for t in self._sample_triplets(self.idxs.shape[0])
            )

        S_hat_ijk = (self.D[i, j] > self.D[i, k]).float()

        dists = self.dist(prototypes)
        diff = dists[i, j] - dists[i, k]
        # logsigmoid(x) == -log(1 + exp(-x)) but does not overflow for large |x|;
        # the naive form produces inf, and 0 * inf = NaN in the sum below.
        log_Sijk = nn.functional.logsigmoid(diff)
        log_1_Sijk = nn.functional.logsigmoid(-diff)

        l = S_hat_ijk * log_Sijk + (1 - S_hat_ijk) * log_1_Sijk
        l = -l.mean()
        return l
# /torch_prototypes/modules/hierarchical_inference.py
import torch
import torch.nn as nn

from torch_scatter import scatter_softmax, scatter_sum, scatter_logsumexp


class HierarchicalInference(nn.Module):
    """Tree-graph hierarchical inference module"""

    def __init__(self, model, path_matrix, sibling_mask):
        """

        Args:
            model (nn.Module): backbone feature extracting network, should return an embedding for all nodes of the tree
            (n_nodes = num_classes (or leaf nodes) + internal nodes)
            path_matrix (tensor): 2D tensor of shape (depth x n_nodes) that specifies the parent-child relationships in
            the tree. For i in [1,depth] and for k in [1,n_nodes], path_matrix[d-i, k] gives the i-th parent of node k
            or 0 if the root has already been reached.
            sibling_mask (tensor): 1D mask over the nodes, that specifies the sibling groups (i.e. nodes with the same
            first parent). This mask is used to compute the softmax restricted to each sibling group.
        """
        super(HierarchicalInference, self).__init__()
        self.model = model
        self.path_matrix = path_matrix
        self.sibling_mask = sibling_mask

    def forward(self, *input):
        """Returns the log-probability of the marginal probability of each node in the graph"""
        edge_logits = self.model(*input)
        two_dim_data = len(edge_logits.shape) == 4
        if two_dim_data:  # Flatten 2D data
            b, c, h, w = edge_logits.shape
            edge_logits = (
                edge_logits.view(b, c, h * w)
                .transpose(1, 2)
                .contiguous()
                .view(b * h * w, c)
            )

        # Log-softmax restricted to each sibling group
        lse = scatter_logsumexp(edge_logits, self.sibling_mask, dim=1)
        scaled_logits = edge_logits - lse[:, self.sibling_mask]

        # Add the conditional log-probability of every ancestor along the path to the root
        marginal_logits = scaled_logits.clone()
        depth = self.path_matrix.shape[0]
        for d in range(depth):
            # Advanced indexing returns a copy, so the in-place masking below does not
            # modify scaled_logits.
            parent_logits = scaled_logits[:, self.path_matrix[d]]
            parent_logits[:, self.path_matrix[d] == 0] = 0  # root already reached
            marginal_logits = marginal_logits + parent_logits

        if two_dim_data:  # Un-flatten 2D data
            _, n_out = marginal_logits.shape
            marginal_logits = (
                marginal_logits.view(b, h * w, n_out)
                .transpose(1, 2)
                .contiguous()
                .view(b, n_out, h, w)
            )

        return marginal_logits


class HierarchicalCrossEntropy(nn.Module):
    def __init__(
        self,
        path_matrix,
        alpha=1,
        class_weights=None,
        ignore_label=None,
        focal_gamma=None,
        eps=0.000001,
    ):
        """
        Hierarchical Cross-Entropy
        Args:
            path_matrix(tensor): 2D tensor of shape (depth x n_nodes) that specifies the parent-child relationships in
            the tree. For i in [1,depth] and for k in [1,n_nodes], path_matrix[d-i, k] gives the i-th parent of node k
            or 0 if the root has already been reached.
            alpha (float): discounting strength parameter
            class_weights (tensor): (optional) class weighting values, indexed by label
            ignore_label (int): (optional) label to ignore in the loss
            focal_gamma (float): (optional) focal loss parameter. If provided the log-probabilities are multiplied by
            a factor (1 - p)^gamma .
            eps (float): default 1e-6, for numerical stability.
        """
        super(HierarchicalCrossEntropy, self).__init__()
        self.alpha = alpha
        self.class_weights = class_weights
        self.ignore_label = ignore_label
        # No hard-coded .cuda(): the matrix is moved to the device of the input in
        # forward, so the loss also works on CPU-only machines.
        self.M = torch.as_tensor(path_matrix).long()
        self.eps = eps
        self.gamma = focal_gamma

    def forward(self, input, y_true):
        """
        Args:
            input (tensor): class scores, shape (n x n_classes) or (b x n_classes x h x w)
            y_true (tensor): true labels, shape (n) or (b x h x w)
        Returns:
            tensor: scalar, mean over the non-ignored samples of the discounted sum of
            -log conditional probabilities along the path of the true class
        """
        if len(input.shape) == 4:  # Flatten 2D data
            b, c, h, w = input.shape
            input = (
                input.view(b, c, h * w).transpose(1, 2).contiguous().view(b * h * w, c)
            )
        y_true = y_true.reshape(-1).long()

        M = self.M.to(input.device)
        n_nodes = int(M.max()) + 1

        # Posterior class probabilities
        p = torch.softmax(input, dim=-1)

        # Compute cumulative probabilities of internal nodes
        # NOTE(review): the accumulator starts at one (not zero) and level 0 is not
        # visited, exactly as in the previous implementation - TODO confirm intended.
        cum_proba = torch.ones((p.shape[0], n_nodes), device=p.device)
        for d in range(M.shape[0] - 1, 0, -1):
            cum_proba = cum_proba + scatter_sum(p, M[d, :], dim=1, dim_size=n_nodes)

        cond_proba = (
            torch.cat([cum_proba[:, M[1:, :]], p[:, None, :]], dim=1) + self.eps
        ) / (cum_proba[:, M] + self.eps)

        # Discounting coefficients, one per level (kept 1D even for a depth of one)
        c = self.alpha * torch.arange(M.shape[0], 0, -1, device=input.device).float()
        c = torch.exp(-c)

        # Focal or simple cross-entropy
        if self.gamma is not None:
            out = (
                c[None, :, None]
                * (1 - cond_proba) ** self.gamma
                * torch.log(cond_proba)
            )
        else:
            out = c[None, :, None] * torch.log(cond_proba)

        # Combine levels; no bare squeeze() so that a batch of one stays 2D
        out = -out.sum(dim=1)

        # Drop ignored samples before gathering: the ignore label is usually not a
        # valid class index.
        if self.ignore_label is not None:
            keep = y_true != self.ignore_label
            out, y_true = out[keep], y_true[keep]
        out = out.gather(dim=1, index=y_true.view(-1, 1)).squeeze(1)

        if self.class_weights is not None:
            W = self.class_weights[y_true]
            out = out * W

        return out.mean()
# /torch_prototypes/modules/metric_embedding.py
import torch
from torch_prototypes.metrics.rank import RankLoss
from torch_prototypes.metrics.hyperspherical import SeparationLoss
from torch_prototypes.metrics.distortion import DistortionLoss


def embed_nomenclature(
    D,
    embedding_dimension,
    loss="rank",
    n_steps=1000,
    lr=10,
    momentum=0.9,
    weight_decay=1e-4,
    ignore_index=None,
):
    """
    Embed a finite metric into a target embedding space
    Args:
        D (tensor): 2D-cost matrix of the finite metric
        embedding_dimension (int): dimension of the target embedding space
        loss (str): embedding loss to use distortion base (loss='disto') or rank based (loss='rank')
        n_steps (int): number of gradient iterations
        lr (float): learning rate
        momentum (float): momentum
        weight_decay (float): weight decay
        ignore_index (int): (optional) index of a class to leave out of the rank triplets,
            only used when loss='rank'

    Returns:
        embedding (tensor): embedding of each vertex of the finite metric space, shape n_vertex x embedding_dimension

    Raises:
        ValueError: if ``loss`` is neither 'rank' nor 'disto'
    """
    n_vertex = D.shape[0]
    mapping = torch.rand(
        (n_vertex, embedding_dimension), requires_grad=True, device=D.device
    )

    if loss == "rank":
        # ignore_index used to be accepted by this function but never forwarded
        crit = RankLoss(D, n_triplets=1000, ignore_index=ignore_index)
    elif loss == "disto":
        crit = DistortionLoss(D, scale_free=False)
    else:
        raise ValueError(
            "Unknown loss '{}' (expected 'rank' or 'disto')".format(loss)
        )

    optimizer = torch.optim.SGD(
        [mapping], lr=lr, momentum=momentum, weight_decay=weight_decay
    )

    print("Embedding nomenclature . . .")
    for i in range(n_steps):
        # Distinct name: do not shadow the ``loss`` string argument
        step_loss = crit(mapping)
        optimizer.zero_grad()
        step_loss.backward()
        optimizer.step()
        # end="\r" belongs to print (it was passed to str.format and ignored), so
        # that the progress line is overwritten in place.
        print("Step {}: loss {:.4f} ".format(i + 1, step_loss.item()), end="\r")
    print("Final loss {:.4f}".format(crit(mapping).item()))
    return mapping.detach()


def embed_on_sphere(
    D, embedding_dimension, lr=0.1, momentum=0.9, n_steps=1000, wd=1e-4
):
    """
    Embed finite metric on the hypersphere
    Args:
        D (tensor): 2D-cost matrix of the finite metric
        embedding_dimension (int): dimension of the target embedding space
        lr (float): learning rate
        momentum (float): momentum
        n_steps (int): number of gradient iterations
        wd (float): weight decay

    Returns:
        embedding (tensor): embedding of each vertex of the finite metric space, shape n_vertex x embedding_dimension

    """
    n_vertex = D.shape[0]
    mapping = torch.rand(
        (n_vertex, embedding_dimension), requires_grad=True, device=D.device
    )
    optimizer = torch.optim.SGD([mapping], lr=lr, momentum=momentum, weight_decay=wd)

    L_hp = SeparationLoss()
    L_pi = RankLoss(D, n_triplets=1000, dist="cos")

    print("Embedding nomenclature . . .")
    for i in range(n_steps):
        # Project the embeddings back onto the unit sphere before each step
        with torch.no_grad():
            mapping.div_(torch.norm(mapping, dim=1, keepdim=True))
        loss = L_hp(mapping) + L_pi(mapping)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # end="\r" belongs to print (it was passed to str.format and ignored)
        print("Step {}: loss {:.4f} ".format(i + 1, loss.item()), end="\r")
    with torch.no_grad():
        mapping.div_(torch.norm(mapping, dim=1, keepdim=True))
    loss = L_hp(mapping) + L_pi(mapping)
    print("Final loss {:.4f} ".format(loss.item()))
    return mapping.detach()
.") 87 | for i in range(n_steps): 88 | with torch.no_grad(): 89 | mapping.div_(torch.norm(mapping, dim=1, keepdim=True)) 90 | loss = L_hp(mapping) + L_pi(mapping) 91 | optimizer.zero_grad() 92 | loss.backward() 93 | optimizer.step() 94 | print( 95 | "Step {}: loss {:.4f} ".format(i + 1, loss.cpu().detach().numpy(), end="\r") 96 | ) 97 | with torch.no_grad(): 98 | mapping.div_(torch.norm(mapping, dim=1, keepdim=True)) 99 | loss = L_hp(mapping) + L_pi(mapping) 100 | print("Final loss {:.4f} ".format(loss.cpu().detach().numpy())) 101 | return mapping.detach() 102 | -------------------------------------------------------------------------------- /torch_prototypes/modules/prototypical_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_prototypes.metrics.distortion import Pseudo_Huber 4 | 5 | from torch_scatter import scatter_mean 6 | 7 | 8 | class LearntPrototypes(nn.Module): 9 | """ 10 | Learnt Prototypes Module. Classification module based on learnt prototypes, to be wrapped around a backbone 11 | embedding network. 
class LearntPrototypes(nn.Module):
    """
    Learnt Prototypes Module.

    Classification head based on prototypes, meant to wrap a backbone embedding
    network: the forward pass returns, for every embedding, the negated distance
    to each prototype.
    """

    def __init__(
        self,
        model,
        n_prototypes,
        embedding_dim,
        prototypes=None,
        squarred=False,
        ph=None,
        dist="euclidean",
        device="cuda",
    ):
        """
        Args:
            model (nn.Module): feature extracting network
            n_prototypes (int): number of prototypes to use
            embedding_dim (int): dimension of the embedding space
            prototypes (tensor): optional prototype tensor of shape
                (n_prototypes x embedding_dim). When given it is stored frozen;
                otherwise prototypes are drawn uniformly in [0, 1) and trained.
            squarred (bool): whether to square the distances or not
            ph (float): if specified, the distance function is huberized with
                delta parameter equal to the specified value
            dist (str): 'cosine' selects the cosine distance; any other value
                (default 'euclidean') selects the Euclidean distance
            device (str): device on which the randomly initialised prototypes
                are declared (cpu/cuda); not used when `prototypes` is given
        """
        super().__init__()
        self.model = model
        if prototypes is None:
            initial = torch.rand((n_prototypes, embedding_dim), device=device)
            self.prototypes = nn.Parameter(initial).requires_grad_(True)
        else:
            self.prototypes = nn.Parameter(prototypes).requires_grad_(False)
        self.n_prototypes = n_prototypes
        self.squarred = squarred
        self.dist = dist
        self.ph = Pseudo_Huber(delta=ph) if ph is not None else None

    def forward(self, *input, **kwargs):
        """
        Run the backbone and score its embeddings against the prototypes.

        Returns:
            Negated distances, shape (n, n_prototypes) for 2D embeddings or
            (b, n_prototypes, h, w) when the backbone outputs (b, c, h, w).
        """
        features = self.model(*input, **kwargs)

        spatial = len(features.shape) == 4
        if spatial:
            # (b, c, h, w) -> (b * h * w, c): one row per pixel
            b, c, h, w = features.shape
            per_pixel = features.view(b, c, h * w).transpose(1, 2).contiguous()
            features = per_pixel.view(b * h * w, c)

        # Broadcast every embedding against every prototype.
        points = features[:, None, :]
        anchors = self.prototypes[None, :, :]
        if self.dist == "cosine":
            dists = 1 - nn.CosineSimilarity(dim=-1)(points, anchors)
        else:
            dists = torch.norm(points - anchors, dim=-1)

        # Huberization, when enabled, is applied before the optional squaring.
        if self.ph is not None:
            dists = self.ph(dists)
        if self.squarred:
            dists = dists ** 2

        if spatial:
            # (b * h * w, k) -> (b, k, h, w)
            per_class = (
                dists.view(b, h * w, self.n_prototypes).transpose(1, 2).contiguous()
            )
            dists = per_class.view(b, self.n_prototypes, h, w)

        return -dists
class HypersphericalProto(nn.Module):
    """
    Implementation of Hyperspherical Prototype Networks (Mettes et al., 2019).

    Scores each embedding of the backbone against fixed prototypes with the
    negated squared cosine distance.
    """

    def __init__(self, model, num_classes, prototypes):
        """
        Args:
            model (nn.Module): backbone feature extracting network
            num_classes (int): number of classes
            prototypes (tensor): pre-defined prototypes, tensor has shape
                (num_classes x embedding_dimension); stored frozen
        """
        super().__init__()
        self.model = model
        frozen = nn.Parameter(prototypes)
        frozen.requires_grad_(False)
        self.prototypes = frozen
        self.num_classes = num_classes

    def forward(self, *input, **kwargs):
        """
        Returns:
            Scores ``-(1 - cos)^2``, shape (n, num_classes) for 2D embeddings or
            (b, num_classes, h, w) when the backbone outputs (b, c, h, w).
        """
        features = self.model(*input, **kwargs)

        spatial = len(features.shape) == 4
        if spatial:
            # (b, c, h, w) -> (b * h * w, c): one row per pixel
            b, c, h, w = features.shape
            per_pixel = features.view(b, c, h * w).transpose(1, 2).contiguous()
            features = per_pixel.view(b * h * w, c)

        cosine = nn.CosineSimilarity(dim=-1)
        similarity = cosine(features[:, None, :], self.prototypes[None, :, :])
        scores = -(1 - similarity).pow(2)

        if not spatial:
            return scores

        # (b * h * w, k) -> (b, k, h, w)
        per_class = scores.view(b, h * w, self.num_classes).transpose(1, 2).contiguous()
        return per_class.view(b, self.num_classes, h, w)
class DeepNCM(nn.Module):
    """
    Implementation of Deep Nearest Class Mean classifiers (Guerriero et al., 2017).

    Class prototypes are not trained by gradient descent: in training mode each
    forward pass folds the per-class mean embedding of the current batch into a
    running average, one update per class per batch.
    """

    def __init__(self, model, num_classes, embedding_dim, device=None):
        """
        Args:
            model (nn.Module): backbone feature extracting network
            num_classes (int): number of classes
            embedding_dim (int): number of dimensions of the embedding space
            device (str or torch.device): device on which the prototypes are
                declared. Defaults to "cuda" when CUDA is available (the
                historical behaviour) and falls back to "cpu" otherwise.
        """
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = model
        self.prototypes = nn.Parameter(
            torch.rand((num_classes, embedding_dim), device=device)
        ).requires_grad_(False)
        self.num_classes = num_classes
        # Number of batch-level updates each prototype has received so far.
        # Plain attribute (not a buffer), so it is not part of the state_dict.
        self.counter = torch.zeros(num_classes)

    def forward(self, *input_target, **kwargs):
        """
        DeepNCM needs the target vector to update the class prototypes.

        Args:
            *input_target: tuple of tensors (*input, target). The target is
                always the last positional argument, including in eval mode
                where it is not used for any update.

        Returns:
            Negated squared Euclidean distances to the prototypes, shape
            (n, num_classes) for 2D embeddings or (b, num_classes, h, w) when
            the backbone outputs (b, c, h, w).
        """
        input = input_target[:-1]
        y_true = input_target[-1]
        embeddings = self.model(*input, **kwargs)

        # Keep the counter next to the data on every call, so moving the module
        # after a first forward pass does not leave it on a stale device.
        if self.counter.device != embeddings.device:
            self.counter = self.counter.to(embeddings.device)

        two_dim_data = len(embeddings.shape) == 4
        if two_dim_data:  # Flatten 2D data
            b, c, h, w = embeddings.shape
            embeddings = (
                embeddings.view(b, c, h * w)
                .transpose(1, 2)
                .contiguous()
                .view(b * h * w, c)
            )
            y_true = y_true.view(b * h * w)

        if self.training:
            self._update_prototypes(embeddings.detach(), y_true)

        dists = torch.norm(embeddings[:, None, :] - self.prototypes[None, :, :], dim=-1)
        if two_dim_data:  # Un-flatten 2D data
            dists = (
                dists.view(b, h * w, self.num_classes)
                .transpose(1, 2)
                .contiguous()
                .view(b, self.num_classes, h, w)
            )

        return -dists.pow(2)

    def _update_prototypes(self, embeddings, y_true):
        """Fold the per-class batch means into the running prototype averages."""
        with torch.no_grad():
            # NOTE(review): every label must lie in [0, num_classes); this
            # module does not filter out an "ignore" label - confirm at callers.
            labels = y_true.long()
            represented = torch.unique(labels)

            # Per-class mean of the batch embeddings, computed on-device.
            sums = embeddings.new_zeros((self.num_classes, embeddings.shape[1]))
            sums.index_add_(0, labels, embeddings)
            counts = torch.bincount(labels, minlength=self.num_classes)
            batch_means = sums[represented] / counts[represented].unsqueeze(1).to(
                sums.dtype
            )

            # Running average over batches: each batch mean weighs the same.
            seen = self.counter[represented].unsqueeze(1)
            self.prototypes[represented] = (
                seen * self.prototypes[represented] + batch_means
            ) / (seen + 1)
            self.counter[represented] = self.counter[represented] + 1
class SoftCrossEntropy(nn.Module):
    def __init__(
        self, D, beta=1, class_weights=None, ignore_label=None, focal_gamma=None
    ):
        """
        Single module to compute the soft labels and pass them to a
        KL-divergence loss (``nn.KLDivLoss`` with ``reduction="batchmean"``).

        Args:
            D (tensor): Cost matrix; it is divided by its maximum before use.
            beta (float): Inverse temperature of the softmax that turns a row of
                the cost matrix into a soft target.
            class_weights (list): (optional) class weighting values.
                NOTE: stored on the module but not applied in ``forward``.
            ignore_label (int): (optional) label to ignore in the loss
            focal_gamma (float): (optional) focal loss parameter. If provided the
                log-probabilities are multiplied by a factor (1 - p)^gamma.
        """
        super().__init__()
        self.D = D / D.max()
        self.beta = beta
        self.class_weights = class_weights
        self.ignore_label = ignore_label
        self.kldiv = nn.KLDivLoss(reduction="batchmean")
        self.gamma = focal_gamma

    def forward(self, input, y_true):
        """
        Args:
            input (tensor): class scores, shape (n, c) or (b, c, h, w)
            y_true (tensor): integer labels, one per row of the flattened input

        Returns:
            Scalar loss tensor. When no sample is left after removing
            ``ignore_label``, a zero that stays connected to ``input`` is
            returned instead of the NaN a batch-mean over zero samples gives.
        """
        if len(input.shape) == 4:  # Flatten 2D data
            b, c, h, w = input.shape
            input = (
                input.view(b, c, h * w).transpose(1, 2).contiguous().view(b * h * w, c)
            )
            y_true = y_true.view(b * h * w)

        if self.ignore_label is not None:
            keep = y_true != self.ignore_label
            input = input[keep]
            y_true = y_true[keep]

        if y_true.numel() == 0:
            # Nothing to supervise: keep the graph alive but contribute nothing.
            return input.sum() * 0.0

        # D is a plain attribute, so it does not follow module.to(device);
        # align it with the data before it is indexed by the labels.
        if self.D.device != input.device:
            self.D = self.D.to(input.device)

        out = torch.nn.functional.log_softmax(input, dim=1)
        # Soft target: low-cost classes w.r.t. the true label get more mass.
        soft_target = torch.softmax(-self.beta * self.D[y_true.long()], dim=-1)

        if self.gamma is not None:
            # Focal modulation of the log-probabilities.
            out = out * (1 - torch.exp(out)) ** self.gamma

        return self.kldiv(out, soft_target)