├── .coveragerc ├── .editorconfig ├── .github └── ISSUE_TEMPLATE.md ├── .gitignore ├── .idea ├── .gitignore ├── dictionaries │ └── laksh.xml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── sparsemax.iml └── vcs.xml ├── .travis.yml ├── AUTHORS.rst ├── CONTRIBUTING.rst ├── HISTORY.rst ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.rst ├── coverage.svg ├── docs ├── Makefile ├── authors.rst ├── conf.py ├── contributing.rst ├── history.rst ├── index.rst ├── installation.rst ├── make.bat ├── readme.rst └── usage.rst ├── requirements_dev.txt ├── setup.cfg ├── setup.py ├── sparsemax ├── __init__.py ├── sparsemax.py └── utils.py ├── tests ├── __init__.py └── test_sparsemax.py └── tox.ini /.coveragerc: -------------------------------------------------------------------------------- 1 | [paths] 2 | source = 3 | sparsemax/ 4 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | insert_final_newline = false 19 | 20 | [Makefile] 21 | indent_style = tab 22 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | * sparsemax version: 2 | * Python version: 3 | * Operating System: 4 | 5 | ### Description 6 | 7 | Describe what you were trying to get done. 8 | Tell us what happened, what went wrong, and what you expected to happen. 9 | 10 | ### What I Did 11 | 12 | ``` 13 | Paste the command(s) you ran and the output. 14 | If there was a crash, please include the traceback here. 15 | ``` 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | 104 | # IDE settings 105 | .vscode/ -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.idea/dictionaries/laksh.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | sparsemax 5 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 54 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/sparsemax.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | 12 | 16 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - 3.8 4 | - 3.7 5 | - 3.6 6 | install: pip install -U tox-travis 7 | script: tox 8 | deploy: 9 | provider: pypi 10 | distributions: sdist bdist_wheel 11 | user: "__token__" 12 | password: 13 | secure: gO8xL1D62JB0mJvtJF5+qAI13CFf2fmnHCXK8lhxddzPd9IcnCD4CQE2PWVz677PyErpYgPAeZTwuRURYhhtdCplNCxy09Uc/80sqMiuJiZeULuK+sAncxALtQS5qc2TX08UiiIauKTk+bEuLm/WNLfaNdZIpfwAWgJeXKjOJPSjqyagCmXlVy/qQsV0UKNJE8QbV35BKTHonPMV6Ss8GzdaSdEUhq3dci+cEOABXW8GHUqO1Pa5TON5aNNQDWPgxJUtyKqYgzRVBrd+7jy443/5TiWpNxjpVEJRpL8Yuhb//xe2yyr7xHr47/s0u9pvRIevmc0O45UlmIraee5bgxfmSKfSgOkyLN2lA9XEPZuIihuqpPWufkqXqITURzOKPRnma9o9lzp47wahHfUbbGk/wFViWfRqMC3IwibvNP/kuLtHy/u2htJvKL78jBLiHKcW7B3mLOgMJXwuWEUWijdsZgDrmnpE27k2S5G/OKrkvnvWiIKelOIkizD7iNWVSyl8lvvjm21Q8a7x3FGT4ekHErP8htKYmD6rbXgwi8pY19bNZ3BuaMDh8W9ggRM2+tFDqnYOM+87cWZmrc27LBvuV/qu5pE6vZmZMlJ90+Zaqrto5GCNZCPzHcq723p/FHLITaPzu1ziP44Y1v8C71IqTXGaUwZYeF0T7T9SvBM= 14 | on: 15 | tags: true 16 | repo: aced125/sparsemax 17 | python: 3.8 18 | -------------------------------------------------------------------------------- /AUTHORS.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | Credits 3 | ======= 4 | 5 | Development Lead 6 | ---------------- 7 | 8 | * Laksh Aithani 9 | 10 | Contributors 11 | ------------ 12 | 13 | None yet. Why not be the first? 14 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | .. highlight:: shell 2 | 3 | ============ 4 | Contributing 5 | ============ 6 | 7 | Contributions are welcome, and they are greatly appreciated! Every little bit 8 | helps, and credit will always be given. 9 | 10 | You can contribute in many ways: 11 | 12 | Types of Contributions 13 | ---------------------- 14 | 15 | Report Bugs 16 | ~~~~~~~~~~~ 17 | 18 | Report bugs at https://github.com/aced125/sparsemax/issues. 19 | 20 | If you are reporting a bug, please include: 21 | 22 | * Your operating system name and version. 23 | * Any details about your local setup that might be helpful in troubleshooting. 24 | * Detailed steps to reproduce the bug. 25 | 26 | Fix Bugs 27 | ~~~~~~~~ 28 | 29 | Look through the GitHub issues for bugs. Anything tagged with "bug" and "help 30 | wanted" is open to whoever wants to implement it. 31 | 32 | Implement Features 33 | ~~~~~~~~~~~~~~~~~~ 34 | 35 | Look through the GitHub issues for features. Anything tagged with "enhancement" 36 | and "help wanted" is open to whoever wants to implement it. 37 | 38 | Write Documentation 39 | ~~~~~~~~~~~~~~~~~~~ 40 | 41 | sparsemax could always use more documentation, whether as part of the 42 | official sparsemax docs, in docstrings, or even on the web in blog posts, 43 | articles, and such. 44 | 45 | Submit Feedback 46 | ~~~~~~~~~~~~~~~ 47 | 48 | The best way to send feedback is to file an issue at https://github.com/aced125/sparsemax/issues. 49 | 50 | If you are proposing a feature: 51 | 52 | * Explain in detail how it would work. 53 | * Keep the scope as narrow as possible, to make it easier to implement. 54 | * Remember that this is a volunteer-driven project, and that contributions 55 | are welcome :) 56 | 57 | Get Started! 58 | ------------ 59 | 60 | Ready to contribute? Here's how to set up `sparsemax` for local development. 61 | 62 | 1. Fork the `sparsemax` repo on GitHub. 63 | 2. Clone your fork locally:: 64 | 65 | $ git clone git@github.com:your_name_here/sparsemax.git 66 | 67 | 3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development:: 68 | 69 | $ mkvirtualenv sparsemax 70 | $ cd sparsemax/ 71 | $ python setup.py develop 72 | 73 | 4. Create a branch for local development:: 74 | 75 | $ git checkout -b name-of-your-bugfix-or-feature 76 | 77 | Now you can make your changes locally. 78 | 79 | 5. When you're done making changes, check that your changes pass flake8 and the 80 | tests, including testing other Python versions with tox:: 81 | 82 | $ flake8 sparsemax tests 83 | $ python setup.py test or pytest 84 | $ tox 85 | 86 | To get flake8 and tox, just pip install them into your virtualenv. 87 | 88 | 6. Commit your changes and push your branch to GitHub:: 89 | 90 | $ git add . 91 | $ git commit -m "Your detailed description of your changes." 92 | $ git push origin name-of-your-bugfix-or-feature 93 | 94 | 7. Submit a pull request through the GitHub website. 95 | 96 | Pull Request Guidelines 97 | ----------------------- 98 | 99 | Before you submit a pull request, check that it meets these guidelines: 100 | 101 | 1. The pull request should include tests. 102 | 2. If the pull request adds functionality, the docs should be updated. Put 103 | your new functionality into a function with a docstring, and add the 104 | feature to the list in README.rst. 105 | 3. The pull request should work for Python 3.5, 3.6, 3.7 and 3.8, and for PyPy. Check 106 | https://travis-ci.com/aced125/sparsemax/pull_requests 107 | and make sure that the tests pass for all supported Python versions. 108 | 109 | Tips 110 | ---- 111 | 112 | To run a subset of tests:: 113 | 114 | $ pytest tests.test_sparsemax 115 | 116 | 117 | Deploying 118 | --------- 119 | 120 | A reminder for the maintainers on how to deploy. 121 | Make sure all your changes are committed (including an entry in HISTORY.rst). 122 | Then run:: 123 | 124 | $ bump2version patch # possible: major / minor / patch 125 | $ git push 126 | $ git push --tags 127 | 128 | Travis will then deploy to PyPI if tests pass. 129 | -------------------------------------------------------------------------------- /HISTORY.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | History 3 | ======= 4 | 5 | 0.1.0 (2020-05-25) 6 | ------------------ 7 | 8 | * First release on PyPI. 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020, Laksh Aithani 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include AUTHORS.rst 2 | include CONTRIBUTING.rst 3 | include HISTORY.rst 4 | include LICENSE 5 | include README.rst 6 | 7 | recursive-include tests * 8 | recursive-exclude * __pycache__ 9 | recursive-exclude * *.py[co] 10 | 11 | recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif 12 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: clean clean-test clean-pyc clean-build docs help 2 | .DEFAULT_GOAL := help 3 | 4 | define BROWSER_PYSCRIPT 5 | import os, webbrowser, sys 6 | 7 | from urllib.request import pathname2url 8 | 9 | webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) 10 | endef 11 | export BROWSER_PYSCRIPT 12 | 13 | define PRINT_HELP_PYSCRIPT 14 | import re, sys 15 | 16 | for line in sys.stdin: 17 | match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) 18 | if match: 19 | target, help = match.groups() 20 | print("%-20s %s" % (target, help)) 21 | endef 22 | export PRINT_HELP_PYSCRIPT 23 | 24 | BROWSER := python -c "$$BROWSER_PYSCRIPT" 25 | 26 | help: 27 | @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) 28 | 29 | clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts 30 | 31 | clean-build: ## remove build artifacts 32 | rm -fr build/ 33 | rm -fr dist/ 34 | rm -fr .eggs/ 35 | find . -name '*.egg-info' -exec rm -fr {} + 36 | find . -name '*.egg' -exec rm -f {} + 37 | 38 | clean-pyc: ## remove Python file artifacts 39 | find . -name '*.pyc' -exec rm -f {} + 40 | find . -name '*.pyo' -exec rm -f {} + 41 | find . -name '*~' -exec rm -f {} + 42 | find . -name '__pycache__' -exec rm -fr {} + 43 | 44 | clean-test: ## remove test and coverage artifacts 45 | rm -fr .tox/ 46 | rm -f .coverage 47 | rm -fr htmlcov/ 48 | rm -fr .pytest_cache 49 | 50 | lint: ## check style with flake8 51 | flake8 sparsemax tests 52 | 53 | test: ## run tests quickly with the default Python 54 | pytest 55 | 56 | test-all: ## run tests on every Python version with tox 57 | tox 58 | 59 | coverage: ## check code coverage quickly with the default Python 60 | coverage run --source sparsemax -m pytest 61 | coverage report -m 62 | coverage html 63 | $(BROWSER) htmlcov/index.html 64 | 65 | docs: ## generate Sphinx HTML documentation, including API docs 66 | rm -f docs/sparsemax.rst 67 | rm -f docs/modules.rst 68 | sphinx-apidoc -o docs/ sparsemax 69 | $(MAKE) -C docs clean 70 | $(MAKE) -C docs html 71 | $(BROWSER) docs/_build/html/index.html 72 | 73 | servedocs: docs ## compile the docs watching for changes 74 | watchmedo shell-command -p '*.rst' -c '$(MAKE) -C docs html' -R -D . 75 | 76 | release: dist ## package and upload a release 77 | twine upload dist/* 78 | 79 | dist: clean ## builds source and wheel package 80 | python setup.py sdist 81 | python setup.py bdist_wheel 82 | ls -l dist 83 | 84 | install: clean ## install the package to the active Python's site-packages 85 | python setup.py install 86 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | sparsemax 3 | ========= 4 | 5 | 6 | .. image:: https://img.shields.io/pypi/v/sparsemax.svg 7 | :target: https://pypi.python.org/pypi/sparsemax 8 | 9 | .. image:: https://img.shields.io/travis/aced125/sparsemax.svg 10 | :target: https://travis-ci.com/aced125/sparsemax 11 | 12 | .. image:: https://readthedocs.org/projects/sparsemax/badge/?version=latest 13 | :target: https://sparsemax.readthedocs.io/en/latest/?badge=latest 14 | :alt: Documentation Status 15 | 16 | 17 | .. image:: https://pyup.io/repos/github/aced125/sparsemax/shield.svg 18 | :target: https://pyup.io/repos/github/aced125/sparsemax/ 19 | :alt: Updates 20 | 21 | .. image:: coverage.svg 22 | 23 | 24 | 25 | 26 | A PyTorch implementation of SparseMax (https://arxiv.org/pdf/1602.02068.pdf) with gradients checked and tested 27 | 28 | Sparsemax is an alternative to softmax when one wants to generate 29 | hard probability distributions. It has been used to great effect in recent papers like 30 | ProtoAttend (https://arxiv.org/pdf/1902.06292v4.pdf). 31 | 32 | Installation 33 | ------------ 34 | 35 | .. code-block:: bash 36 | 37 | pip install -U sparsemax 38 | 39 | 40 | Usage 41 | ----- 42 | 43 | Use as if it was :code:`nn.Softmax()`! Nice and simple. 44 | 45 | .. code-block:: python 46 | 47 | from sparsemax import Sparsemax 48 | import torch 49 | import torch.nn as nn 50 | 51 | sparsemax = Sparsemax(dim=-1) 52 | softmax = torch.nn.Softmax(dim=-1) 53 | 54 | logits = torch.randn(2, 3, 5) 55 | logits.requires_grad = True 56 | print("\nLogits") 57 | print(logits) 58 | 59 | softmax_probs = softmax(logits) 60 | print("\nSoftmax probabilities") 61 | print(softmax_probs) 62 | 63 | sparsemax_probs = sparsemax(logits) 64 | print("\nSparsemax probabilities") 65 | print(sparsemax_probs) 66 | 67 | 68 | Advantages over existing implementations 69 | ---------------------------------------- 70 | This repo borrows heavily from: https://github.com/KrisKorrel/sparsemax-pytorch 71 | 72 | However, there are a few key advantages: 73 | 74 | 1. Backward pass equations implemented natively as a :code:`torch.autograd.Function`, **resulting in 30% speedup**, compared to the above repository. 75 | 2. The package is **easily pip-installable** (no need to copy the code). 76 | 3. The package works for **multi-dimensional tensors, operating over any axis**. 77 | 4. The operator **forward and backward passes are tested** (backward-pass check due to :code:`torch.autograd.gradcheck` 78 | 79 | 80 | Check that gradients are computed correctly 81 | ------------------------------------------- 82 | 83 | .. code-block:: python 84 | 85 | from torch.autograd import gradcheck 86 | from sparsemax import Sparsemax 87 | 88 | input = (torch.randn(6, 3, 20,dtype=torch.double,requires_grad=True)) 89 | test = gradcheck(sparsemax, input, eps=1e-6, atol=1e-4) 90 | print(test) 91 | 92 | 93 | 94 | Credits 95 | ------- 96 | 97 | This package was created with Cookiecutter_ and the `audreyr/cookiecutter-pypackage`_ project template. 98 | 99 | .. _Cookiecutter: https://github.com/audreyr/cookiecutter 100 | .. _`audreyr/cookiecutter-pypackage`: https://github.com/audreyr/cookiecutter-pypackage 101 | -------------------------------------------------------------------------------- /coverage.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | coverage 17 | coverage 18 | 79% 19 | 79% 20 | 21 | 22 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = sparsemax 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/authors.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../AUTHORS.rst 2 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # sparsemax documentation build configuration file, created by 4 | # sphinx-quickstart on Fri Jun 9 13:47:02 2017. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | # If extensions (or modules to document with autodoc) are in another 16 | # directory, add these directories to sys.path here. If the directory is 17 | # relative to the documentation root, use os.path.abspath to make it 18 | # absolute, like shown here. 19 | # 20 | import os 21 | import sys 22 | 23 | sys.path.insert(0, os.path.abspath("..")) 24 | 25 | import sparsemax 26 | 27 | # -- General configuration --------------------------------------------- 28 | 29 | # If your documentation needs a minimal Sphinx version, state it here. 30 | # 31 | # needs_sphinx = '1.0' 32 | 33 | # Add any Sphinx extension module names here, as strings. They can be 34 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 35 | extensions = ["sphinx.ext.autodoc", "sphinx.ext.viewcode", "sphinxemoji.sphinxemoji"] 36 | 37 | # Add any paths that contain templates here, relative to this directory. 38 | templates_path = ["_templates"] 39 | 40 | # The suffix(es) of source filenames. 41 | # You can specify multiple suffix as a list of string: 42 | # 43 | # source_suffix = ['.rst', '.md'] 44 | source_suffix = ".rst" 45 | 46 | # The master toctree document. 47 | master_doc = "index" 48 | 49 | # General information about the project. 50 | project = "sparsemax" 51 | copyright = "2020, Laksh Aithani" 52 | author = "Laksh Aithani" 53 | 54 | # The version info for the project you're documenting, acts as replacement 55 | # for |version| and |release|, also used in various other places throughout 56 | # the built documents. 57 | # 58 | # The short X.Y version. 59 | version = sparsemax.__version__ 60 | # The full version, including alpha/beta/rc tags. 61 | release = sparsemax.__version__ 62 | 63 | # The language for content autogenerated by Sphinx. Refer to documentation 64 | # for a list of supported languages. 65 | # 66 | # This is also used if you do content translation via gettext catalogs. 67 | # Usually you set "language" from the command line for these cases. 68 | language = None 69 | 70 | # List of patterns, relative to source directory, that match files and 71 | # directories to ignore when looking for source files. 72 | # This patterns also effect to html_static_path and html_extra_path 73 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 74 | 75 | # The name of the Pygments (syntax highlighting) style to use. 76 | pygments_style = "sphinx" 77 | 78 | # If true, `todo` and `todoList` produce output, else they produce nothing. 79 | todo_include_todos = False 80 | 81 | 82 | # -- Options for HTML output ------------------------------------------- 83 | 84 | # The theme to use for HTML and HTML Help pages. See the documentation for 85 | # a list of builtin themes. 86 | # 87 | html_theme = "alabaster" 88 | 89 | # Theme options are theme-specific and customize the look and feel of a 90 | # theme further. For a list of options available for each theme, see the 91 | # documentation. 92 | # 93 | # html_theme_options = {} 94 | 95 | # Add any paths that contain custom static files (such as style sheets) here, 96 | # relative to this directory. They are copied after the builtin static files, 97 | # so a file named "default.css" will overwrite the builtin "default.css". 98 | html_static_path = ["_static"] 99 | 100 | 101 | # -- Options for HTMLHelp output --------------------------------------- 102 | 103 | # Output file base name for HTML help builder. 104 | htmlhelp_basename = "sparsemaxdoc" 105 | 106 | 107 | # -- Options for LaTeX output ------------------------------------------ 108 | 109 | latex_elements = { 110 | # The paper size ('letterpaper' or 'a4paper'). 111 | # 112 | # 'papersize': 'letterpaper', 113 | # The font size ('10pt', '11pt' or '12pt'). 114 | # 115 | # 'pointsize': '10pt', 116 | # Additional stuff for the LaTeX preamble. 117 | # 118 | # 'preamble': '', 119 | # Latex figure (float) alignment 120 | # 121 | # 'figure_align': 'htbp', 122 | } 123 | 124 | # Grouping the document tree into LaTeX files. List of tuples 125 | # (source start file, target name, title, author, documentclass 126 | # [howto, manual, or own class]). 127 | latex_documents = [ 128 | (master_doc, "sparsemax.tex", "sparsemax Documentation", "Laksh Aithani", "manual"), 129 | ] 130 | 131 | 132 | # -- Options for manual page output ------------------------------------ 133 | 134 | # One entry per manual page. List of tuples 135 | # (source start file, name, description, authors, manual section). 136 | man_pages = [(master_doc, "sparsemax", "sparsemax Documentation", [author], 1)] 137 | 138 | 139 | # -- Options for Texinfo output ---------------------------------------- 140 | 141 | # Grouping the document tree into Texinfo files. List of tuples 142 | # (source start file, target name, title, author, 143 | # dir menu entry, description, category) 144 | texinfo_documents = [ 145 | ( 146 | master_doc, 147 | "sparsemax", 148 | "sparsemax Documentation", 149 | author, 150 | "sparsemax", 151 | "One line description of project.", 152 | "Miscellaneous", 153 | ), 154 | ] 155 | -------------------------------------------------------------------------------- /docs/contributing.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../CONTRIBUTING.rst 2 | -------------------------------------------------------------------------------- /docs/history.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../HISTORY.rst 2 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to sparsemax's documentation! 2 | ====================================== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: Contents: 7 | 8 | readme 9 | installation 10 | usage 11 | modules 12 | contributing 13 | authors 14 | history 15 | 16 | Indices and tables 17 | ================== 18 | * :ref:`genindex` 19 | * :ref:`modindex` 20 | * :ref:`search` 21 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | .. highlight:: shell 2 | 3 | ============ 4 | Installation 5 | ============ 6 | 7 | 8 | Stable release 9 | -------------- 10 | 11 | To install sparsemax, run this command in your terminal: 12 | 13 | .. code-block:: console 14 | 15 | $ pip install sparsemax 16 | 17 | This is the preferred method to install sparsemax, as it will always install the most recent stable release. 18 | 19 | If you don't have `pip`_ installed, this `Python installation guide`_ can guide 20 | you through the process. 21 | 22 | .. _pip: https://pip.pypa.io 23 | .. _Python installation guide: http://docs.python-guide.org/en/latest/starting/installation/ 24 | 25 | 26 | From sources 27 | ------------ 28 | 29 | The sources for sparsemax can be downloaded from the `Github repo`_. 30 | 31 | You can either clone the public repository: 32 | 33 | .. code-block:: console 34 | 35 | $ git clone git://github.com/aced125/sparsemax 36 | 37 | Or download the `tarball`_: 38 | 39 | .. code-block:: console 40 | 41 | $ curl -OJL https://github.com/aced125/sparsemax/tarball/master 42 | 43 | Once you have a copy of the source, you can install it with: 44 | 45 | .. code-block:: console 46 | 47 | $ python setup.py install 48 | 49 | 50 | .. _Github repo: https://github.com/aced125/sparsemax 51 | .. _tarball: https://github.com/aced125/sparsemax/tarball/master 52 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=sparsemax 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/readme.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../README.rst 2 | -------------------------------------------------------------------------------- /docs/usage.rst: -------------------------------------------------------------------------------- 1 | ===== 2 | Usage 3 | ===== 4 | 5 | To use sparsemax in a project:: 6 | 7 | import sparsemax 8 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | pip==19.2.3 2 | bump2version==0.5.11 3 | wheel==0.33.6 4 | watchdog==0.9.0 5 | flake8==3.7.8 6 | tox==3.14.0 7 | coverage==4.5.4 8 | Sphinx==1.8.5 9 | twine==1.14.0 10 | 11 | torch 12 | 13 | pytest==4.6.5 14 | pytest-runner==5.1 15 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.1.8 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:setup.py] 7 | search = version='{current_version}' 8 | replace = version='{new_version}' 9 | 10 | [bumpversion:file:sparsemax/__init__.py] 11 | search = __version__ = '{current_version}' 12 | replace = __version__ = '{new_version}' 13 | 14 | [bdist_wheel] 15 | universal = 1 16 | 17 | [flake8] 18 | exclude = docs 19 | 20 | [aliases] 21 | test = pytest 22 | 23 | [tool:pytest] 24 | collect_ignore = ['setup.py'] 25 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """The setup script.""" 4 | 5 | from setuptools import setup, find_packages 6 | 7 | with open("README.rst") as readme_file: 8 | readme = readme_file.read() 9 | 10 | with open("HISTORY.rst") as history_file: 11 | history = history_file.read() 12 | 13 | requirements = ["torch"] 14 | 15 | setup_requirements = [ 16 | "pytest-runner", 17 | ] 18 | 19 | test_requirements = [ 20 | "pytest>=3", 21 | ] 22 | 23 | # fmt: off 24 | setup( 25 | author="Laksh Aithani", 26 | author_email="lakshaithanii@gmail.com", 27 | python_requires=">=3.5", 28 | classifiers=[ 29 | "Development Status :: 2 - Pre-Alpha", 30 | "Intended Audience :: Developers", 31 | "License :: OSI Approved :: MIT License", 32 | "Natural Language :: English", 33 | "Programming Language :: Python :: 3", 34 | "Programming Language :: Python :: 3.5", 35 | "Programming Language :: Python :: 3.6", 36 | "Programming Language :: Python :: 3.7", 37 | "Programming Language :: Python :: 3.8", 38 | ], 39 | description="Sparsemax pytorch", 40 | install_requires=requirements, 41 | license="MIT license", 42 | long_description=readme + "\n\n" + history, 43 | include_package_data=True, 44 | keywords="sparsemax", 45 | name="sparsemax", 46 | packages=find_packages(include=["sparsemax", "sparsemax.*"]), 47 | setup_requires=setup_requirements, 48 | test_suite="tests", 49 | tests_require=test_requirements, 50 | url="https://github.com/aced125/sparsemax", 51 | version='0.1.8', 52 | zip_safe=False, 53 | ) 54 | # fmt: on 55 | -------------------------------------------------------------------------------- /sparsemax/__init__.py: -------------------------------------------------------------------------------- 1 | """Top-level package for sparsemax.""" 2 | # fmt: off 3 | __author__ = """Laksh Aithani""" 4 | __email__ = "lakshaithanii@gmail.com" 5 | __version__ = '0.1.8' 6 | 7 | from sparsemax.sparsemax import Sparsemax 8 | # fmt: on 9 | -------------------------------------------------------------------------------- /sparsemax/sparsemax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from sparsemax.utils import flatten_all_but_nth_dim, unflatten_all_but_nth_dim 4 | 5 | 6 | class Sparsemax(nn.Module): 7 | __constants__ = ["dim"] 8 | 9 | def __init__(self, dim=-1): 10 | """ 11 | Sparsemax class as seen in https://arxiv.org/pdf/1602.02068.pdf 12 | Parameters 13 | ---------- 14 | dim: The dimension we want to cast the operation over. Default -1 15 | """ 16 | super(Sparsemax, self).__init__() 17 | self.dim = dim 18 | 19 | def __setstate__(self, state): 20 | self.__dict__.update(state) 21 | if not hasattr(self, "dim"): 22 | self.dim = None 23 | 24 | def forward(self, input): 25 | return SparsemaxFunction.apply(input, self.dim) 26 | 27 | def extra_repr(self): 28 | return f"dim={self.dim}" 29 | 30 | 31 | class SparsemaxFunction(torch.autograd.Function): 32 | @staticmethod 33 | def forward(ctx, input: torch.Tensor, dim: int = -1): 34 | input_dim = input.dim() 35 | if input_dim <= dim or dim < -input_dim: 36 | raise IndexError( 37 | f"Dimension out of range (expected to be in range of [-{input_dim}, {input_dim - 1}], but got {dim})" 38 | ) 39 | 40 | # Save operating dimension to context 41 | ctx.needs_reshaping = input_dim > 2 42 | ctx.dim = dim 43 | 44 | if ctx.needs_reshaping: 45 | ctx, input = flatten_all_but_nth_dim(ctx, input) 46 | 47 | # Translate by max for numerical stability 48 | input = input - input.max(-1, keepdim=True).values.expand_as(input) 49 | 50 | zs = input.sort(-1, descending=True).values 51 | range = torch.arange(1, input.size()[-1] + 1) 52 | range = range.expand_as(input).to(input) 53 | 54 | # Determine sparsity of projection 55 | bound = 1 + range * zs 56 | is_gt = bound.gt(zs.cumsum(-1)).type(input.dtype) 57 | k = (is_gt * range).max(-1, keepdim=True).values 58 | 59 | # Compute threshold 60 | zs_sparse = is_gt * zs 61 | 62 | # Compute taus 63 | taus = (zs_sparse.sum(-1, keepdim=True) - 1) / k 64 | taus = taus.expand_as(input) 65 | 66 | output = torch.max(torch.zeros_like(input), input - taus) 67 | 68 | # Save context 69 | ctx.save_for_backward(output) 70 | 71 | # Reshape back to original shape 72 | if ctx.needs_reshaping: 73 | ctx, output = unflatten_all_but_nth_dim(ctx, output) 74 | 75 | return output 76 | 77 | @staticmethod 78 | def backward(ctx, grad_output): 79 | output, *_ = ctx.saved_tensors 80 | 81 | # Reshape if needed 82 | if ctx.needs_reshaping: 83 | ctx, grad_output = flatten_all_but_nth_dim(ctx, grad_output) 84 | 85 | # Compute gradient 86 | nonzeros = torch.ne(output, 0) 87 | num_nonzeros = nonzeros.sum(-1, keepdim=True) 88 | sum = (grad_output * nonzeros).sum(-1, keepdim=True) / num_nonzeros 89 | grad_input = nonzeros * (grad_output - sum.expand_as(grad_output)) 90 | 91 | # Reshape back to original shape 92 | if ctx.needs_reshaping: 93 | ctx, grad_input = unflatten_all_but_nth_dim(ctx, grad_input) 94 | 95 | return grad_input, None 96 | -------------------------------------------------------------------------------- /sparsemax/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def flatten_all_but_nth_dim(ctx, x: torch.Tensor): 5 | """ 6 | Flattens tensor in all but 1 chosen dimension. 7 | Saves necessary context for backward pass and unflattening. 8 | """ 9 | 10 | # transpose batch and nth dim 11 | x = x.transpose(0, ctx.dim) 12 | 13 | # Get and save original size in context for backward pass 14 | original_size = x.size() 15 | ctx.original_size = original_size 16 | 17 | # Flatten all dimensions except nth dim 18 | x = x.reshape(x.size(0), -1) 19 | 20 | # Transpose flattened dimensions to 0th dim, nth dim to last dim 21 | return ctx, x.transpose(0, -1) 22 | 23 | 24 | def unflatten_all_but_nth_dim(ctx, x: torch.Tensor): 25 | """ 26 | Unflattens tensor using necessary context 27 | """ 28 | # Tranpose flattened dim to last dim, nth dim to 0th dim 29 | x = x.transpose(0, 1) 30 | 31 | # Reshape to original size 32 | x = x.reshape(ctx.original_size) 33 | 34 | # Swap batch dim and nth dim 35 | return ctx, x.transpose(0, ctx.dim) 36 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Unit test package for sparsemax.""" 2 | -------------------------------------------------------------------------------- /tests/test_sparsemax.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """Tests for `sparsemax` package.""" 4 | 5 | import pytest 6 | from sparsemax import Sparsemax 7 | from torch.autograd import gradcheck 8 | import torch 9 | 10 | 11 | @pytest.mark.parametrize("dimension", [-4, -3, -2, -1, 0, 1, 2, 3]) 12 | def test_sparsemax(dimension): 13 | sparsemax = Sparsemax(dimension) 14 | input = torch.randn(6, 3, 5, 4, dtype=torch.double, requires_grad=True) 15 | assert gradcheck(sparsemax, input, eps=1e-6, atol=1e-4) 16 | 17 | 18 | def test_sparsemax_invalid_dimension(): 19 | sparsemax = Sparsemax(-7) 20 | input = torch.randn(6, 3, 5, 4, dtype=torch.double, requires_grad=True) 21 | with pytest.raises(IndexError): 22 | gradcheck(sparsemax, input, eps=1e-6, atol=1e-4) 23 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py35, py36, py37, py38, flake8 3 | 4 | [travis] 5 | python = 6 | 3.8: py38 7 | 3.7: py37 8 | 3.6: py36 9 | 3.5: py35 10 | 11 | [testenv:flake8] 12 | basepython = python 13 | deps = flake8 14 | commands = flake8 sparsemax tests 15 | 16 | [testenv] 17 | setenv = 18 | PYTHONPATH = {toxinidir} 19 | deps = 20 | -r{toxinidir}/requirements_dev.txt 21 | ; If you want to make tox run the tests with the same versions, create a 22 | ; requirements.txt with the pinned versions and uncomment the following line: 23 | ; -r{toxinidir}/requirements.txt 24 | commands = 25 | pip install -U pip 26 | pytest --basetemp={envtmpdir} 27 | 28 | --------------------------------------------------------------------------------