├── .coveragerc
├── .editorconfig
├── .github
└── ISSUE_TEMPLATE.md
├── .gitignore
├── .idea
├── .gitignore
├── dictionaries
│ └── laksh.xml
├── inspectionProfiles
│ ├── Project_Default.xml
│ └── profiles_settings.xml
├── misc.xml
├── modules.xml
├── sparsemax.iml
└── vcs.xml
├── .travis.yml
├── AUTHORS.rst
├── CONTRIBUTING.rst
├── HISTORY.rst
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.rst
├── coverage.svg
├── docs
├── Makefile
├── authors.rst
├── conf.py
├── contributing.rst
├── history.rst
├── index.rst
├── installation.rst
├── make.bat
├── readme.rst
└── usage.rst
├── requirements_dev.txt
├── setup.cfg
├── setup.py
├── sparsemax
├── __init__.py
├── sparsemax.py
└── utils.py
├── tests
├── __init__.py
└── test_sparsemax.py
└── tox.ini
/.coveragerc:
--------------------------------------------------------------------------------
1 | [paths]
2 | source =
3 | sparsemax/
4 |
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | # http://editorconfig.org
2 |
3 | root = true
4 |
5 | [*]
6 | indent_style = space
7 | indent_size = 4
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 | charset = utf-8
11 | end_of_line = lf
12 |
13 | [*.bat]
14 | indent_style = tab
15 | end_of_line = crlf
16 |
17 | [LICENSE]
18 | insert_final_newline = false
19 |
20 | [Makefile]
21 | indent_style = tab
22 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | * sparsemax version:
2 | * Python version:
3 | * Operating System:
4 |
5 | ### Description
6 |
7 | Describe what you were trying to get done.
8 | Tell us what happened, what went wrong, and what you expected to happen.
9 |
10 | ### What I Did
11 |
12 | ```
13 | Paste the command(s) you ran and the output.
14 | If there was a crash, please include the traceback here.
15 | ```
16 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 |
58 | # Flask stuff:
59 | instance/
60 | .webassets-cache
61 |
62 | # Scrapy stuff:
63 | .scrapy
64 |
65 | # Sphinx documentation
66 | docs/_build/
67 |
68 | # PyBuilder
69 | target/
70 |
71 | # Jupyter Notebook
72 | .ipynb_checkpoints
73 |
74 | # pyenv
75 | .python-version
76 |
77 | # celery beat schedule file
78 | celerybeat-schedule
79 |
80 | # SageMath parsed files
81 | *.sage.py
82 |
83 | # dotenv
84 | .env
85 |
86 | # virtualenv
87 | .venv
88 | venv/
89 | ENV/
90 |
91 | # Spyder project settings
92 | .spyderproject
93 | .spyproject
94 |
95 | # Rope project settings
96 | .ropeproject
97 |
98 | # mkdocs documentation
99 | /site
100 |
101 | # mypy
102 | .mypy_cache/
103 |
104 | # IDE settings
105 | .vscode/
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/dictionaries/laksh.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | sparsemax
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
19 |
20 |
21 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/sparsemax.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 | - 3.8
4 | - 3.7
5 | - 3.6
6 | install: pip install -U tox-travis
7 | script: tox
8 | deploy:
9 | provider: pypi
10 | distributions: sdist bdist_wheel
11 | user: "__token__"
12 | password:
13 | secure: gO8xL1D62JB0mJvtJF5+qAI13CFf2fmnHCXK8lhxddzPd9IcnCD4CQE2PWVz677PyErpYgPAeZTwuRURYhhtdCplNCxy09Uc/80sqMiuJiZeULuK+sAncxALtQS5qc2TX08UiiIauKTk+bEuLm/WNLfaNdZIpfwAWgJeXKjOJPSjqyagCmXlVy/qQsV0UKNJE8QbV35BKTHonPMV6Ss8GzdaSdEUhq3dci+cEOABXW8GHUqO1Pa5TON5aNNQDWPgxJUtyKqYgzRVBrd+7jy443/5TiWpNxjpVEJRpL8Yuhb//xe2yyr7xHr47/s0u9pvRIevmc0O45UlmIraee5bgxfmSKfSgOkyLN2lA9XEPZuIihuqpPWufkqXqITURzOKPRnma9o9lzp47wahHfUbbGk/wFViWfRqMC3IwibvNP/kuLtHy/u2htJvKL78jBLiHKcW7B3mLOgMJXwuWEUWijdsZgDrmnpE27k2S5G/OKrkvnvWiIKelOIkizD7iNWVSyl8lvvjm21Q8a7x3FGT4ekHErP8htKYmD6rbXgwi8pY19bNZ3BuaMDh8W9ggRM2+tFDqnYOM+87cWZmrc27LBvuV/qu5pE6vZmZMlJ90+Zaqrto5GCNZCPzHcq723p/FHLITaPzu1ziP44Y1v8C71IqTXGaUwZYeF0T7T9SvBM=
14 | on:
15 | tags: true
16 | repo: aced125/sparsemax
17 | python: 3.8
18 |
--------------------------------------------------------------------------------
/AUTHORS.rst:
--------------------------------------------------------------------------------
1 | =======
2 | Credits
3 | =======
4 |
5 | Development Lead
6 | ----------------
7 |
8 | * Laksh Aithani
9 |
10 | Contributors
11 | ------------
12 |
13 | None yet. Why not be the first?
14 |
--------------------------------------------------------------------------------
/CONTRIBUTING.rst:
--------------------------------------------------------------------------------
1 | .. highlight:: shell
2 |
3 | ============
4 | Contributing
5 | ============
6 |
7 | Contributions are welcome, and they are greatly appreciated! Every little bit
8 | helps, and credit will always be given.
9 |
10 | You can contribute in many ways:
11 |
12 | Types of Contributions
13 | ----------------------
14 |
15 | Report Bugs
16 | ~~~~~~~~~~~
17 |
18 | Report bugs at https://github.com/aced125/sparsemax/issues.
19 |
20 | If you are reporting a bug, please include:
21 |
22 | * Your operating system name and version.
23 | * Any details about your local setup that might be helpful in troubleshooting.
24 | * Detailed steps to reproduce the bug.
25 |
26 | Fix Bugs
27 | ~~~~~~~~
28 |
29 | Look through the GitHub issues for bugs. Anything tagged with "bug" and "help
30 | wanted" is open to whoever wants to implement it.
31 |
32 | Implement Features
33 | ~~~~~~~~~~~~~~~~~~
34 |
35 | Look through the GitHub issues for features. Anything tagged with "enhancement"
36 | and "help wanted" is open to whoever wants to implement it.
37 |
38 | Write Documentation
39 | ~~~~~~~~~~~~~~~~~~~
40 |
41 | sparsemax could always use more documentation, whether as part of the
42 | official sparsemax docs, in docstrings, or even on the web in blog posts,
43 | articles, and such.
44 |
45 | Submit Feedback
46 | ~~~~~~~~~~~~~~~
47 |
48 | The best way to send feedback is to file an issue at https://github.com/aced125/sparsemax/issues.
49 |
50 | If you are proposing a feature:
51 |
52 | * Explain in detail how it would work.
53 | * Keep the scope as narrow as possible, to make it easier to implement.
54 | * Remember that this is a volunteer-driven project, and that contributions
55 | are welcome :)
56 |
57 | Get Started!
58 | ------------
59 |
60 | Ready to contribute? Here's how to set up `sparsemax` for local development.
61 |
62 | 1. Fork the `sparsemax` repo on GitHub.
63 | 2. Clone your fork locally::
64 |
65 | $ git clone git@github.com:your_name_here/sparsemax.git
66 |
67 | 3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development::
68 |
69 | $ mkvirtualenv sparsemax
70 | $ cd sparsemax/
71 | $ python setup.py develop
72 |
73 | 4. Create a branch for local development::
74 |
75 | $ git checkout -b name-of-your-bugfix-or-feature
76 |
77 | Now you can make your changes locally.
78 |
79 | 5. When you're done making changes, check that your changes pass flake8 and the
80 | tests, including testing other Python versions with tox::
81 |
82 | $ flake8 sparsemax tests
83 | $ pytest  # or: python setup.py test
84 | $ tox
85 |
86 | To get flake8 and tox, just pip install them into your virtualenv.
87 |
88 | 6. Commit your changes and push your branch to GitHub::
89 |
90 | $ git add .
91 | $ git commit -m "Your detailed description of your changes."
92 | $ git push origin name-of-your-bugfix-or-feature
93 |
94 | 7. Submit a pull request through the GitHub website.
95 |
96 | Pull Request Guidelines
97 | -----------------------
98 |
99 | Before you submit a pull request, check that it meets these guidelines:
100 |
101 | 1. The pull request should include tests.
102 | 2. If the pull request adds functionality, the docs should be updated. Put
103 | your new functionality into a function with a docstring, and add the
104 | feature to the list in README.rst.
105 | 3. The pull request should work for Python 3.5, 3.6, 3.7 and 3.8, and for PyPy. Check
106 | https://travis-ci.com/aced125/sparsemax/pull_requests
107 | and make sure that the tests pass for all supported Python versions.
108 |
109 | Tips
110 | ----
111 |
112 | To run a subset of tests::
113 |
114 | $ pytest tests/test_sparsemax.py
115 |
116 |
117 | Deploying
118 | ---------
119 |
120 | A reminder for the maintainers on how to deploy.
121 | Make sure all your changes are committed (including an entry in HISTORY.rst).
122 | Then run::
123 |
124 | $ bump2version patch # possible: major / minor / patch
125 | $ git push
126 | $ git push --tags
127 |
128 | Travis will then deploy to PyPI if tests pass.
129 |
--------------------------------------------------------------------------------
/HISTORY.rst:
--------------------------------------------------------------------------------
1 | =======
2 | History
3 | =======
4 |
5 | 0.1.0 (2020-05-25)
6 | ------------------
7 |
8 | * First release on PyPI.
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020, Laksh Aithani
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include AUTHORS.rst
2 | include CONTRIBUTING.rst
3 | include HISTORY.rst
4 | include LICENSE
5 | include README.rst
6 |
7 | recursive-include tests *
8 | recursive-exclude * __pycache__
9 | recursive-exclude * *.py[co]
10 |
11 | recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif
12 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: clean clean-test clean-pyc clean-build docs help
2 | .DEFAULT_GOAL := help
3 |
4 | define BROWSER_PYSCRIPT
5 | import os, webbrowser, sys
6 |
7 | from urllib.request import pathname2url
8 |
9 | webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1])))
10 | endef
11 | export BROWSER_PYSCRIPT
12 |
13 | define PRINT_HELP_PYSCRIPT
14 | import re, sys
15 |
16 | for line in sys.stdin:
17 | match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line)
18 | if match:
19 | target, help = match.groups()
20 | print("%-20s %s" % (target, help))
21 | endef
22 | export PRINT_HELP_PYSCRIPT
23 |
24 | BROWSER := python -c "$$BROWSER_PYSCRIPT"
25 |
26 | help:
27 | @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST)
28 |
29 | clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts
30 |
31 | clean-build: ## remove build artifacts
32 | rm -fr build/
33 | rm -fr dist/
34 | rm -fr .eggs/
35 | find . -name '*.egg-info' -exec rm -fr {} +
36 | find . -name '*.egg' -exec rm -f {} +
37 |
38 | clean-pyc: ## remove Python file artifacts
39 | find . -name '*.pyc' -exec rm -f {} +
40 | find . -name '*.pyo' -exec rm -f {} +
41 | find . -name '*~' -exec rm -f {} +
42 | find . -name '__pycache__' -exec rm -fr {} +
43 |
44 | clean-test: ## remove test and coverage artifacts
45 | rm -fr .tox/
46 | rm -f .coverage
47 | rm -fr htmlcov/
48 | rm -fr .pytest_cache
49 |
50 | lint: ## check style with flake8
51 | flake8 sparsemax tests
52 |
53 | test: ## run tests quickly with the default Python
54 | pytest
55 |
56 | test-all: ## run tests on every Python version with tox
57 | tox
58 |
59 | coverage: ## check code coverage quickly with the default Python
60 | coverage run --source sparsemax -m pytest
61 | coverage report -m
62 | coverage html
63 | $(BROWSER) htmlcov/index.html
64 |
65 | docs: ## generate Sphinx HTML documentation, including API docs
66 | rm -f docs/sparsemax.rst
67 | rm -f docs/modules.rst
68 | sphinx-apidoc -o docs/ sparsemax
69 | $(MAKE) -C docs clean
70 | $(MAKE) -C docs html
71 | $(BROWSER) docs/_build/html/index.html
72 |
73 | servedocs: docs ## compile the docs watching for changes
74 | watchmedo shell-command -p '*.rst' -c '$(MAKE) -C docs html' -R -D .
75 |
76 | release: dist ## package and upload a release
77 | twine upload dist/*
78 |
79 | dist: clean ## builds source and wheel package
80 | python setup.py sdist
81 | python setup.py bdist_wheel
82 | ls -l dist
83 |
84 | install: clean ## install the package to the active Python's site-packages
85 | python setup.py install
86 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | =========
2 | sparsemax
3 | =========
4 |
5 |
6 | .. image:: https://img.shields.io/pypi/v/sparsemax.svg
7 | :target: https://pypi.python.org/pypi/sparsemax
8 |
9 | .. image:: https://img.shields.io/travis/aced125/sparsemax.svg
10 | :target: https://travis-ci.com/aced125/sparsemax
11 |
12 | .. image:: https://readthedocs.org/projects/sparsemax/badge/?version=latest
13 | :target: https://sparsemax.readthedocs.io/en/latest/?badge=latest
14 | :alt: Documentation Status
15 |
16 |
17 | .. image:: https://pyup.io/repos/github/aced125/sparsemax/shield.svg
18 | :target: https://pyup.io/repos/github/aced125/sparsemax/
19 | :alt: Updates
20 |
21 | .. image:: coverage.svg
22 |
23 |
24 |
25 |
26 | A PyTorch implementation of SparseMax (https://arxiv.org/pdf/1602.02068.pdf) with gradients checked and tested
27 |
28 | Sparsemax is an alternative to softmax when one wants to generate
29 | sparse ("hard") probability distributions, in which some entries are exactly zero. It has been used to great effect in recent papers like
30 | ProtoAttend (https://arxiv.org/pdf/1902.06292v4.pdf).
31 |
32 | Installation
33 | ------------
34 |
35 | .. code-block:: bash
36 |
37 | pip install -U sparsemax
38 |
39 |
40 | Usage
41 | -----
42 |
43 | Use it as if it were :code:`nn.Softmax()`! Nice and simple.
44 |
45 | .. code-block:: python
46 |
47 | from sparsemax import Sparsemax
48 | import torch
49 | import torch.nn as nn
50 |
51 | sparsemax = Sparsemax(dim=-1)
52 | softmax = torch.nn.Softmax(dim=-1)
53 |
54 | logits = torch.randn(2, 3, 5)
55 | logits.requires_grad = True
56 | print("\nLogits")
57 | print(logits)
58 |
59 | softmax_probs = softmax(logits)
60 | print("\nSoftmax probabilities")
61 | print(softmax_probs)
62 |
63 | sparsemax_probs = sparsemax(logits)
64 | print("\nSparsemax probabilities")
65 | print(sparsemax_probs)
66 |
67 |
68 | Advantages over existing implementations
69 | ----------------------------------------
70 | This repo borrows heavily from: https://github.com/KrisKorrel/sparsemax-pytorch
71 |
72 | However, there are a few key advantages:
73 |
74 | 1. Backward pass equations implemented natively as a :code:`torch.autograd.Function`, **resulting in a 30% speedup** compared to the above repository.
75 | 2. The package is **easily pip-installable** (no need to copy the code).
76 | 3. The package works for **multi-dimensional tensors, operating over any axis**.
77 | 4. The operator **forward and backward passes are tested** (the backward pass is checked with :code:`torch.autograd.gradcheck`).
78 |
79 |
80 | Check that gradients are computed correctly
81 | -------------------------------------------
82 |
83 | .. code-block:: python
84 |
85 | import torch
86 | from torch.autograd import gradcheck
87 | from sparsemax import Sparsemax
88 |
89 | sparsemax = Sparsemax(dim=-1)
90 | inputs = torch.randn(6, 3, 20, dtype=torch.double, requires_grad=True)
91 | test = gradcheck(sparsemax, inputs, eps=1e-6, atol=1e-4)
92 | print(test)
93 |
94 | Credits
95 | -------
96 |
97 | This package was created with Cookiecutter_ and the `audreyr/cookiecutter-pypackage`_ project template.
98 |
99 | .. _Cookiecutter: https://github.com/audreyr/cookiecutter
100 | .. _`audreyr/cookiecutter-pypackage`: https://github.com/audreyr/cookiecutter-pypackage
101 |
--------------------------------------------------------------------------------
/coverage.svg:
--------------------------------------------------------------------------------
1 |
2 |
22 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = python -msphinx
7 | SPHINXPROJ = sparsemax
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/authors.rst:
--------------------------------------------------------------------------------
1 | .. include:: ../AUTHORS.rst
2 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # sparsemax documentation build configuration file, created by
4 | # sphinx-quickstart on Fri Jun 9 13:47:02 2017.
5 | #
6 | # This file is execfile()d with the current directory set to its
7 | # containing dir.
8 | #
9 | # Note that not all possible configuration values are present in this
10 | # autogenerated file.
11 | #
12 | # All configuration values have a default; values that are commented out
13 | # serve to show the default.
14 |
15 | # If extensions (or modules to document with autodoc) are in another
16 | # directory, add these directories to sys.path here. If the directory is
17 | # relative to the documentation root, use os.path.abspath to make it
18 | # absolute, like shown here.
19 | #
20 | import os
21 | import sys
22 |
23 | sys.path.insert(0, os.path.abspath(".."))
24 |
25 | import sparsemax
26 |
27 | # -- General configuration ---------------------------------------------
28 |
29 | # If your documentation needs a minimal Sphinx version, state it here.
30 | #
31 | # needs_sphinx = '1.0'
32 |
33 | # Add any Sphinx extension module names here, as strings. They can be
34 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
35 | extensions = ["sphinx.ext.autodoc", "sphinx.ext.viewcode", "sphinxemoji.sphinxemoji"]
36 |
37 | # Add any paths that contain templates here, relative to this directory.
38 | templates_path = ["_templates"]
39 |
40 | # The suffix(es) of source filenames.
41 | # You can specify multiple suffix as a list of string:
42 | #
43 | # source_suffix = ['.rst', '.md']
44 | source_suffix = ".rst"
45 |
46 | # The master toctree document.
47 | master_doc = "index"
48 |
49 | # General information about the project.
50 | project = "sparsemax"
51 | copyright = "2020, Laksh Aithani"
52 | author = "Laksh Aithani"
53 |
54 | # The version info for the project you're documenting, acts as replacement
55 | # for |version| and |release|, also used in various other places throughout
56 | # the built documents.
57 | #
58 | # The short X.Y version.
59 | version = sparsemax.__version__
60 | # The full version, including alpha/beta/rc tags.
61 | release = sparsemax.__version__
62 |
63 | # The language for content autogenerated by Sphinx. Refer to documentation
64 | # for a list of supported languages.
65 | #
66 | # This is also used if you do content translation via gettext catalogs.
67 | # Usually you set "language" from the command line for these cases.
68 | language = None
69 |
70 | # List of patterns, relative to source directory, that match files and
71 | # directories to ignore when looking for source files.
72 | # These patterns also affect html_static_path and html_extra_path.
73 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
74 |
75 | # The name of the Pygments (syntax highlighting) style to use.
76 | pygments_style = "sphinx"
77 |
78 | # If true, `todo` and `todoList` produce output, else they produce nothing.
79 | todo_include_todos = False
80 |
81 |
82 | # -- Options for HTML output -------------------------------------------
83 |
84 | # The theme to use for HTML and HTML Help pages. See the documentation for
85 | # a list of builtin themes.
86 | #
87 | html_theme = "alabaster"
88 |
89 | # Theme options are theme-specific and customize the look and feel of a
90 | # theme further. For a list of options available for each theme, see the
91 | # documentation.
92 | #
93 | # html_theme_options = {}
94 |
95 | # Add any paths that contain custom static files (such as style sheets) here,
96 | # relative to this directory. They are copied after the builtin static files,
97 | # so a file named "default.css" will overwrite the builtin "default.css".
98 | html_static_path = ["_static"]
99 |
100 |
101 | # -- Options for HTMLHelp output ---------------------------------------
102 |
103 | # Output file base name for HTML help builder.
104 | htmlhelp_basename = "sparsemaxdoc"
105 |
106 |
107 | # -- Options for LaTeX output ------------------------------------------
108 |
109 | latex_elements = {
110 | # The paper size ('letterpaper' or 'a4paper').
111 | #
112 | # 'papersize': 'letterpaper',
113 | # The font size ('10pt', '11pt' or '12pt').
114 | #
115 | # 'pointsize': '10pt',
116 | # Additional stuff for the LaTeX preamble.
117 | #
118 | # 'preamble': '',
119 | # Latex figure (float) alignment
120 | #
121 | # 'figure_align': 'htbp',
122 | }
123 |
124 | # Grouping the document tree into LaTeX files. List of tuples
125 | # (source start file, target name, title, author, documentclass
126 | # [howto, manual, or own class]).
127 | latex_documents = [
128 | (master_doc, "sparsemax.tex", "sparsemax Documentation", "Laksh Aithani", "manual"),
129 | ]
130 |
131 |
132 | # -- Options for manual page output ------------------------------------
133 |
134 | # One entry per manual page. List of tuples
135 | # (source start file, name, description, authors, manual section).
136 | man_pages = [(master_doc, "sparsemax", "sparsemax Documentation", [author], 1)]
137 |
138 |
139 | # -- Options for Texinfo output ----------------------------------------
140 |
141 | # Grouping the document tree into Texinfo files. List of tuples
142 | # (source start file, target name, title, author,
143 | # dir menu entry, description, category)
144 | texinfo_documents = [
145 | (
146 | master_doc,
147 | "sparsemax",
148 | "sparsemax Documentation",
149 | author,
150 | "sparsemax",
151 | "One line description of project.",
152 | "Miscellaneous",
153 | ),
154 | ]
155 |
--------------------------------------------------------------------------------
/docs/contributing.rst:
--------------------------------------------------------------------------------
1 | .. include:: ../CONTRIBUTING.rst
2 |
--------------------------------------------------------------------------------
/docs/history.rst:
--------------------------------------------------------------------------------
1 | .. include:: ../HISTORY.rst
2 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | Welcome to sparsemax's documentation!
2 | ======================================
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 | :caption: Contents:
7 |
8 | readme
9 | installation
10 | usage
11 | modules
12 | contributing
13 | authors
14 | history
15 |
16 | Indices and tables
17 | ==================
18 | * :ref:`genindex`
19 | * :ref:`modindex`
20 | * :ref:`search`
21 |
--------------------------------------------------------------------------------
/docs/installation.rst:
--------------------------------------------------------------------------------
1 | .. highlight:: shell
2 |
3 | ============
4 | Installation
5 | ============
6 |
7 |
8 | Stable release
9 | --------------
10 |
11 | To install sparsemax, run this command in your terminal:
12 |
13 | .. code-block:: console
14 |
15 | $ pip install sparsemax
16 |
17 | This is the preferred method to install sparsemax, as it will always install the most recent stable release.
18 |
19 | If you don't have `pip`_ installed, this `Python installation guide`_ can guide
20 | you through the process.
21 |
22 | .. _pip: https://pip.pypa.io
23 | .. _Python installation guide: http://docs.python-guide.org/en/latest/starting/installation/
24 |
25 |
26 | From sources
27 | ------------
28 |
29 | The sources for sparsemax can be downloaded from the `Github repo`_.
30 |
31 | You can either clone the public repository:
32 |
33 | .. code-block:: console
34 |
35 | $ git clone https://github.com/aced125/sparsemax
36 |
37 | Or download the `tarball`_:
38 |
39 | .. code-block:: console
40 |
41 | $ curl -OJL https://github.com/aced125/sparsemax/tarball/master
42 |
43 | Once you have a copy of the source, you can install it with:
44 |
45 | .. code-block:: console
46 |
47 | $ python setup.py install
48 |
49 |
50 | .. _Github repo: https://github.com/aced125/sparsemax
51 | .. _tarball: https://github.com/aced125/sparsemax/tarball/master
52 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=python -msphinx
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 | set SPHINXPROJ=sparsemax
13 |
14 | if "%1" == "" goto help
15 |
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | echo.
19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed,
20 | echo.then set the SPHINXBUILD environment variable to point to the full
21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the
22 | echo.Sphinx directory to PATH.
23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from
25 | echo.http://sphinx-doc.org/
26 | exit /b 1
27 | )
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/docs/readme.rst:
--------------------------------------------------------------------------------
1 | .. include:: ../README.rst
2 |
--------------------------------------------------------------------------------
/docs/usage.rst:
--------------------------------------------------------------------------------
1 | =====
2 | Usage
3 | =====
4 |
5 | To use sparsemax in a project::
6 |
7 | import sparsemax
8 |
--------------------------------------------------------------------------------
/requirements_dev.txt:
--------------------------------------------------------------------------------
1 | pip==19.2.3
2 | bump2version==0.5.11
3 | wheel==0.33.6
4 | watchdog==0.9.0
5 | flake8==3.7.8
6 | tox==3.14.0
7 | coverage==4.5.4
8 | Sphinx==1.8.5
9 | twine==1.14.0
10 |
11 | torch
12 |
13 | pytest==4.6.5
14 | pytest-runner==5.1
15 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [bumpversion]
2 | current_version = 0.1.8
3 | commit = True
4 | tag = True
5 |
6 | [bumpversion:file:setup.py]
7 | search = version='{current_version}'
8 | replace = version='{new_version}'
9 |
10 | [bumpversion:file:sparsemax/__init__.py]
11 | search = __version__ = '{current_version}'
12 | replace = __version__ = '{new_version}'
13 |
14 | [bdist_wheel]
15 | universal = 1
16 |
17 | [flake8]
18 | exclude = docs
19 |
20 | [aliases]
21 | test = pytest
22 |
23 | [tool:pytest]
24 | collect_ignore = ['setup.py']
25 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | """The setup script."""
4 |
5 | from setuptools import setup, find_packages
6 |
# Read the long-description sources as UTF-8 explicitly so that building the
# package does not depend on the platform's default locale encoding.
with open("README.rst", encoding="utf-8") as readme_file:
    readme = readme_file.read()

with open("HISTORY.rst", encoding="utf-8") as history_file:
    history = history_file.read()

# Runtime dependencies.
requirements = ["torch"]

# Needed at setup time so that `python setup.py test` can dispatch to pytest.
setup_requirements = [
    "pytest-runner",
]

# Needed only to run the test suite.
test_requirements = [
    "pytest>=3",
]

# fmt: off
setup(
    author="Laksh Aithani",
    author_email="lakshaithanii@gmail.com",
    # The package source uses f-strings, which require Python 3.6 or newer,
    # so 3.5 must not be advertised as supported.
    python_requires=">=3.6",
    classifiers=[
        "Development Status :: 2 - Pre-Alpha",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: MIT License",
        "Natural Language :: English",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
    ],
    description="Sparsemax pytorch",
    install_requires=requirements,
    license="MIT license",
    long_description=readme + "\n\n" + history,
    include_package_data=True,
    keywords="sparsemax",
    name="sparsemax",
    packages=find_packages(include=["sparsemax", "sparsemax.*"]),
    setup_requires=setup_requirements,
    test_suite="tests",
    tests_require=test_requirements,
    url="https://github.com/aced125/sparsemax",
    # Keep this line single-quoted: the bumpversion configuration in setup.cfg
    # searches for exactly version='<current_version>'.
    version='0.1.8',
    zip_safe=False,
)
# fmt: on
55 |
--------------------------------------------------------------------------------
/sparsemax/__init__.py:
--------------------------------------------------------------------------------
1 | """Top-level package for sparsemax."""
# fmt: off
# Formatting is disabled for this section: the bumpversion configuration in
# setup.cfg searches for the single-quoted ``__version__ = '...'`` line, so
# its quoting must not be rewritten by a formatter.
__author__ = """Laksh Aithani"""
__email__ = "lakshaithanii@gmail.com"
__version__ = '0.1.8'

# Re-export the module class so callers can write `from sparsemax import Sparsemax`.
from sparsemax.sparsemax import Sparsemax
# fmt: on
9 |
--------------------------------------------------------------------------------
/sparsemax/sparsemax.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from sparsemax.utils import flatten_all_but_nth_dim, unflatten_all_but_nth_dim
4 |
5 |
class Sparsemax(nn.Module):
    """Sparsemax activation as described in https://arxiv.org/pdf/1602.02068.pdf

    Thin ``nn.Module`` wrapper that forwards its input, together with the
    configured dimension, to ``SparsemaxFunction``.
    """

    __constants__ = ["dim"]

    def __init__(self, dim=-1):
        """
        Sparsemax class as seen in https://arxiv.org/pdf/1602.02068.pdf
        Parameters
        ----------
        dim: The dimension we want to cast the operation over. Default -1
        """
        super(Sparsemax, self).__init__()
        self.dim = dim

    def __setstate__(self, state):
        """Restore pickled state, back-filling ``dim`` when it is absent."""
        self.__dict__.update(state)
        if not hasattr(self, "dim"):
            # Fall back to the constructor default. A ``None`` here would make
            # ``forward`` fail with a TypeError in the dimension range check.
            self.dim = -1

    def forward(self, input):
        """Apply sparsemax to ``input`` along ``self.dim``."""
        return SparsemaxFunction.apply(input, self.dim)

    def extra_repr(self):
        """Return the extra text shown inside the module's ``repr``."""
        return f"dim={self.dim}"
29 |
30 |
class SparsemaxFunction(torch.autograd.Function):
    """Autograd function computing sparsemax along one dimension.

    The requested dimension is swapped into the last position, the projection
    is computed along that last axis (all leading axes are treated as batch
    axes), and the result is swapped back. This is done for inputs of every
    rank, so 1-D and 2-D inputs honour ``dim`` just like higher-rank inputs.
    """

    @staticmethod
    def forward(ctx, input: torch.Tensor, dim: int = -1):
        """Compute sparsemax of ``input`` along ``dim``.

        Parameters
        ----------
        ctx: autograd context; receives ``dim`` and the saved output.
        input: tensor to project.
        dim: dimension to apply the operation over. Default -1

        Returns
        -------
        Tensor with the same shape as ``input``.

        Raises
        ------
        IndexError: if ``dim`` is outside ``[-input.dim(), input.dim() - 1]``.
        """
        input_dim = input.dim()
        if input_dim <= dim or dim < -input_dim:
            raise IndexError(
                f"Dimension out of range (expected to be in range of [-{input_dim}, {input_dim - 1}], but got {dim})"
            )

        # Save operating dimension to context for the backward pass
        ctx.dim = dim

        # Bring the operating dimension to the last axis. Every reduction
        # below works on axis -1 and broadcasts over the leading axes.
        z = input.transpose(dim, -1)

        # Translate by max for numerical stability
        z = z - z.max(-1, keepdim=True).values

        zs = z.sort(-1, descending=True).values
        # 1..K, created directly in the input's dtype and on its device;
        # broadcasts against the trailing axis of ``zs``.
        ks = torch.arange(1, z.size(-1) + 1, dtype=z.dtype, device=z.device)

        # Determine sparsity of projection
        bound = 1 + ks * zs
        is_gt = bound.gt(zs.cumsum(-1)).type(z.dtype)
        k = (is_gt * ks).max(-1, keepdim=True).values

        # Compute threshold from the sorted entries that are in the support
        zs_sparse = is_gt * zs
        taus = (zs_sparse.sum(-1, keepdim=True) - 1) / k

        output = torch.max(torch.zeros_like(z), z - taus)

        # Saved with the operating dimension last; backward uses this layout
        ctx.save_for_backward(output)

        # Restore the original axis order
        return output.transpose(dim, -1)

    @staticmethod
    def backward(ctx, grad_output):
        """Back-propagate through sparsemax.

        Returns the gradient with respect to ``input`` and ``None`` for the
        non-differentiable ``dim`` argument.
        """
        (output,) = ctx.saved_tensors

        # Match the layout of the saved output (operating dimension last)
        grad = grad_output.transpose(ctx.dim, -1)

        # Inside the support the gradient is centred by its mean over the
        # support; outside the support it is zero.
        support = torch.ne(output, 0)
        support_size = support.sum(-1, keepdim=True)
        mean_grad = (grad * support).sum(-1, keepdim=True) / support_size
        grad_input = support * (grad - mean_grad)

        # Restore the original axis order
        return grad_input.transpose(ctx.dim, -1), None
96 |
--------------------------------------------------------------------------------
/sparsemax/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def flatten_all_but_nth_dim(ctx, x: torch.Tensor):
    """
    Collapse every dimension except ``ctx.dim`` into a single leading axis.

    The size of the tensor after ``ctx.dim`` has been swapped to the front is
    stored on ``ctx.original_size`` so the operation can be undone later.
    Returns ``(ctx, flattened)`` where ``flattened`` is 2-D with the chosen
    dimension as its last axis.
    """
    # Bring the chosen dimension to the front
    swapped = x.transpose(0, ctx.dim)

    # Remember the swapped shape for the backward pass / unflattening
    ctx.original_size = swapped.size()

    # Merge all trailing dimensions into one
    leading = swapped.size(0)
    merged = swapped.reshape(leading, -1)

    # Chosen dimension becomes the last axis, merged dimensions the first
    return ctx, merged.transpose(0, -1)
22 |
23 |
def unflatten_all_but_nth_dim(ctx, x: torch.Tensor):
    """
    Undo a flattening using ``ctx.original_size`` and ``ctx.dim``.

    Returns ``(ctx, restored)`` where ``restored`` has the chosen dimension
    back at position ``ctx.dim``.
    """
    # Chosen dimension back to the front, merged dimensions last
    front_first = x.transpose(0, 1)

    # Split the merged dimensions back into the saved shape
    restored = front_first.reshape(ctx.original_size)

    # Swap the chosen dimension back into its original position
    return ctx, restored.transpose(0, ctx.dim)
36 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """Unit test package for sparsemax."""
2 |
--------------------------------------------------------------------------------
/tests/test_sparsemax.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | """Tests for `sparsemax` package."""
4 |
5 | import pytest
6 | from sparsemax import Sparsemax
7 | from torch.autograd import gradcheck
8 | import torch
9 |
10 |
@pytest.mark.parametrize("dimension", [-4, -3, -2, -1, 0, 1, 2, 3])
def test_sparsemax(dimension):
    """Gradient-check Sparsemax along each valid axis of a random 4-D double tensor."""
    layer = Sparsemax(dimension)
    sample = torch.randn(6, 3, 5, 4, dtype=torch.double, requires_grad=True)
    gradients_match = gradcheck(layer, sample, eps=1e-6, atol=1e-4)
    assert gradients_match
16 |
17 |
def test_sparsemax_invalid_dimension():
    """An axis outside the tensor's rank must surface as an IndexError."""
    layer = Sparsemax(-7)
    sample = torch.randn(6, 3, 5, 4, dtype=torch.double, requires_grad=True)
    with pytest.raises(IndexError):
        gradcheck(layer, sample, eps=1e-6, atol=1e-4)
23 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | envlist = py35, py36, py37, py38, flake8
3 |
4 | [travis]
5 | python =
6 | 3.8: py38
7 | 3.7: py37
8 | 3.6: py36
9 | 3.5: py35
10 |
11 | [testenv:flake8]
12 | basepython = python
13 | deps = flake8
14 | commands = flake8 sparsemax tests
15 |
16 | [testenv]
17 | setenv =
18 | PYTHONPATH = {toxinidir}
19 | deps =
20 | -r{toxinidir}/requirements_dev.txt
21 | ; If you want to make tox run the tests with the same versions, create a
22 | ; requirements.txt with the pinned versions and uncomment the following line:
23 | ; -r{toxinidir}/requirements.txt
24 | commands =
25 | pip install -U pip
26 | pytest --basetemp={envtmpdir}
27 |
28 |
--------------------------------------------------------------------------------