├── .gitignore
├── LICENSE.txt
├── README.md
└── sparsemax.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Kris Korrel

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Sparsemax

Implementation of the Sparsemax activation function in Pytorch from the paper:
[From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification](https://arxiv.org/abs/1602.02068) by André F. T. Martins and Ramón Fernandez Astudillo

Tested in Pytorch 0.4.0

Example usage

```python
import torch
from sparsemax import Sparsemax

sparsemax = Sparsemax(dim=1)
softmax = torch.nn.Softmax(dim=1)

logits = torch.randn(2, 5)
print("\nLogits")
print(logits)

softmax_probs = softmax(logits)
print("\nSoftmax probabilities")
print(softmax_probs)

sparsemax_probs = sparsemax(logits)
print("\nSparsemax probabilities")
print(sparsemax_probs)
```
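Unlike softmax, which assigns nonzero probability to every class, sparsemax projects the logits onto the probability simplex and can produce exact zeros. As a quick check, continuing the example above (the exact values depend on the random logits):

```python
# Each sparsemax row still sums to 1 (up to float error)...
print(sparsemax_probs.sum(dim=1))
# ...but some entries are exactly zero, unlike with softmax.
print((sparsemax_probs == 0).sum(dim=1))  # number of pruned classes per row
```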
Please add an issue if you have questions or suggestions.

# References
[From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification](https://arxiv.org/abs/1602.02068)

Note that this is a Pytorch port of an existing implementation: https://github.com/gokceneraslan/SparseMax.torch/

DOI for this particular repository: [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.3860669.svg)](https://doi.org/10.5281/zenodo.3860669)

This implementation was used for [Transcoding Compositionally: Using Attention to Find More Generalizable Solutions](https://arxiv.org/abs/1906.01234)

--------------------------------------------------------------------------------
/sparsemax.py:
--------------------------------------------------------------------------------
"""Sparsemax activation function.

Pytorch implementation of the Sparsemax function from:
-- "From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification"
-- André F. T. Martins, Ramón Fernandez Astudillo (http://arxiv.org/abs/1602.02068)
"""

from __future__ import division

import torch
import torch.nn as nn


class Sparsemax(nn.Module):
    """Sparsemax function."""

    def __init__(self, dim=None):
        """Initialize sparsemax activation.

        Args:
            dim (int, optional): The dimension over which to apply the sparsemax
                function. Defaults to the last dimension.
        """
        super(Sparsemax, self).__init__()

        self.dim = -1 if dim is None else dim

    def forward(self, input):
        """Forward function.

        Args:
            input (torch.Tensor): Input tensor. The first dimension should be the batch size.

        Returns:
            torch.Tensor: Output tensor with the same shape as the input.
        """
        # Sparsemax currently only handles 2-dim tensors,
        # so we reshape to a convenient shape and reshape back after sparsemax.
        input = input.transpose(0, self.dim)
        original_size = input.size()
        input = input.reshape(input.size(0), -1)
        input = input.transpose(0, 1)
        dim = 1

        number_of_logits = input.size(dim)

        # Translate input by max for numerical stability.
        # (Sparsemax is invariant to adding a constant to all logits.)
        input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input)

        # Sort input in descending order.
        # (NOTE: Can be replaced with linear time selection method described here:
        # http://stanford.edu/~jduchi/projects/DuchiShSiCh08.html)
        zs = torch.sort(input=input, dim=dim, descending=True)[0]
        # 1-based rank of each sorted logit, created on the input's device so the
        # module works wherever the input tensor lives.
        ranks = torch.arange(start=1, end=number_of_logits + 1, step=1,
                             device=input.device, dtype=input.dtype).view(1, -1)
        ranks = ranks.expand_as(zs)

        # Determine the support size: k(z) = max{k : 1 + k * z_(k) > cumsum_k(z)}.
        bound = 1 + ranks * zs
        cumulative_sum_zs = torch.cumsum(zs, dim)
        is_gt = torch.gt(bound, cumulative_sum_zs).to(input.dtype)
        k = torch.max(is_gt * ranks, dim, keepdim=True)[0]

        # Zero out the logits outside the support.
        zs_sparse = is_gt * zs

        # Compute the thresholds: tau(z) = (sum of supported logits - 1) / k(z).
        taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
        taus = taus.expand_as(input)

        # Sparsemax: p = max(0, z - tau(z)).
        # The output is cached for the reference backward pass below.
        self.output = torch.max(torch.zeros_like(input), input - taus)

        # Reshape back to the original shape.
        output = self.output
        output = output.transpose(0, 1)
        output = output.reshape(original_size)
        output = output.transpose(0, self.dim)

        return output

    def backward(self, grad_output):
        """Backward function.

        Implements the closed-form Jacobian-vector product from the paper: the
        gradient is non-zero only on the support, shifted by the mean of the
        incoming gradient over the support. Note that autograd never calls
        nn.Module.backward; gradients already flow through the differentiable
        operations in forward, so this method serves as a reference and for
        manual use.
        """
        dim = 1

        nonzeros = torch.ne(self.output, 0).to(grad_output.dtype)
        grad_sum = (torch.sum(grad_output * nonzeros, dim=dim, keepdim=True)
                    / torch.sum(nonzeros, dim=dim, keepdim=True))
        self.grad_input = nonzeros * (grad_output - grad_sum.expand_as(grad_output))

        return self.grad_input

--------------------------------------------------------------------------------
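As a sanity check for the module above, the snippet below is a minimal sketch (not part of the repository) that re-implements sparsemax row by row, directly following the closed-form projection in the paper, and compares it against the batched module. The helper name `sparsemax_reference` is made up for illustration; run it from the repository root so `sparsemax.py` is importable.

```python
import torch
from sparsemax import Sparsemax


def sparsemax_reference(z):
    """Row-wise sparsemax, following the closed-form solution in the paper."""
    z_sorted, _ = torch.sort(z, descending=True)
    ks = torch.arange(1, z.numel() + 1, dtype=z.dtype)
    cumsum = torch.cumsum(z_sorted, dim=0)
    support = 1 + ks * z_sorted > cumsum   # which sorted logits are in the support
    k_z = support.sum()                    # k(z): size of the support
    tau = (cumsum[k_z - 1] - 1) / k_z      # threshold tau(z)
    return torch.clamp(z - tau, min=0)


torch.manual_seed(0)
logits = torch.randn(4, 10)
out = Sparsemax(dim=1)(logits)

# Outputs lie on the probability simplex...
assert torch.all(out >= 0)
assert torch.allclose(out.sum(dim=1), torch.ones(4))
# ...and agree with the row-wise reference implementation.
for z_row, p_row in zip(logits, out):
    assert torch.allclose(sparsemax_reference(z_row), p_row, atol=1e-6)
print("sparsemax checks passed")
```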