├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── fusedmax.png └── pytorch ├── MANIFEST.in ├── setup.py └── torchsparseattn ├── __init__.py ├── _fused.pyx ├── _fused_jv.pyx ├── _isotonic.pyx ├── base.py ├── fused.py ├── isotonic.py ├── oscar.py ├── sparsemax.py ├── test_attention.py ├── test_fused.py ├── test_oscar.py └── test_sparsemax.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | language: python 3 | dist: xenial 4 | 5 | python: 6 | - "2.7" 7 | - "3.7" 8 | 9 | env: 10 | - TORCH_VERSION=1.0.1 11 | - TORCH_VERSION=0.4.1 12 | 13 | cache: 14 | apt: true 15 | directories: 16 | - $HOME/.cache/pip 17 | 18 | install: 19 | 20 | - wget http://repo.continuum.io/miniconda/Miniconda-3.6.0-Linux-x86_64.sh -O miniconda.sh 21 | - bash miniconda.sh -b -p $HOME/miniconda 22 | - export PATH="$HOME/miniconda/bin:$PATH" 23 | - conda config --set always_yes yes --set changeps1 no 24 | - conda update -q conda 25 | - hash -r 26 | - conda info -a 27 | - conda create -q -n testenv python=$TRAVIS_PYTHON_VERSION 28 | - source activate testenv 29 | - conda install -c pytorch pytorch-cpu=$TORCH_VERSION numpy scipy pytest cython 30 | 31 | # install package 32 | - cd pytorch 33 | - pip install . 34 | 35 | script: 36 | - mkdir empty_dir 37 | - pytest -vs --pyargs torchsparseattn 38 | - cd .. 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, Vlad Niculae 4 | All rights reserved.
5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sparse and structured attention mechanisms 2 | [![Build Status](https://travis-ci.org/vene/sparse-structured-attention.svg?branch=master)](https://travis-ci.org/vene/sparse-structured-attention) 3 | [![PyPI version](https://badge.fury.io/py/torchsparseattn.svg)](https://badge.fury.io/py/torchsparseattn) 4 | 5 |
![fusedmax](fusedmax.png)
6 | 7 | -------------------------------------------------------------------------------- 8 | 9 | Efficient implementation of structured sparsity inducing 10 | attention mechanisms: fusedmax, oscarmax and sparsemax. 11 | 12 | **Note**: If you are just looking for sparsemax, I recommend the implementation in the [entmax](https://github.com/deep-spin/entmax). 13 | 14 | Currently available for pytorch >= 0.4.1. (For older versions, use a previous 15 | release of this package.) Requires python >= 2.7, cython, numpy, scipy. 16 | 17 | Usage example: 18 | 19 | ```python 20 | 21 | In [1]: import torch 22 | In [2]: import torchsparseattn 23 | In [3]: a = torch.tensor([1, 2.1, 1.9], dtype=torch.double) 24 | In [4]: lengths = torch.tensor([3]) 25 | In [5]: fusedmax = torchsparseattn.Fusedmax(alpha=.1) 26 | In [6]: fusedmax(a, lengths) 27 | Out[6]: tensor([0.0000, 0.5000, 0.5000], dtype=torch.float64) 28 | ``` 29 | 30 | For details, check out our paper: 31 | 32 | > Vlad Niculae and Mathieu Blondel 33 | > A Regularized Framework for Sparse and Structured Neural Attention 34 | > In: Proceedings of NIPS, 2017. 35 | > https://arxiv.org/abs/1705.07704 36 | 37 | See also: 38 | 39 | > André F. T. Martins and Ramón Fernandez Astudillo 40 | > From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification 41 | > In: Proceedings of ICML, 2016 42 | > https://arxiv.org/abs/1602.02068 43 | 44 | > X. Zeng and M. Figueiredo, 45 | > The ordered weighted L1 norm: Atomic formulation, dual norm, and projections. 46 | > eprint http://arxiv.org/abs/1409.4271 47 | 48 | -------------------------------------------------------------------------------- /fusedmax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vene/sparse-structured-attention/7003b3deaa513c034c3317a44d5b601f95e9e2c6/fusedmax.png -------------------------------------------------------------------------------- /pytorch/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include MANIFEST.in 2 | recursive-include torchsparseattn *.c *.h *.cpp *.pyx *.pxd 3 | -------------------------------------------------------------------------------- /pytorch/setup.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from setuptools import setup, find_packages, Extension 3 | 4 | from Cython.Build import cythonize 5 | 6 | extensions = [ 7 | Extension('torchsparseattn._isotonic', 8 | ["torchsparseattn/_isotonic.pyx"], 9 | include_dirs=[numpy.get_include()]), 10 | Extension('torchsparseattn._fused', 11 | ["torchsparseattn/_fused.pyx"], 12 | include_dirs=[numpy.get_include()]), 13 | Extension('torchsparseattn._fused_jv', 14 | ["torchsparseattn/_fused_jv.pyx"]), 15 | ] 16 | 17 | extensions = cythonize(extensions) 18 | 19 | 20 | setup(name="torchsparseattn", 21 | version="0.3.dev0", 22 | description="Sparse structured attention mechanisms for pytorch", 23 | author="Vlad Niculae", 24 | author_email="vlad@vene.ro", 25 | license="BSD 3-clause", 26 | packages=find_packages(), 27 | ext_modules=extensions, 28 | install_requires=['numpy'], 29 | zip_safe=False, 30 | classifiers=[ 31 | 'Intended Audience :: Science/Research', 32 | 'Intended Audience :: Developers', 'License :: OSI Approved', 33 | 'Programming Language :: C', 'Programming Language :: Python', 34 | 'Topic :: Software Development', 35 | 'Topic :: Scientific/Engineering', 36 | 'Operating System :: Microsoft :: Windows', 37 | 'Operating 
System :: POSIX', 'Operating System :: Unix', 38 | 'Operating System :: MacOS'] 39 | ) 40 | -------------------------------------------------------------------------------- /pytorch/torchsparseattn/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused import Fusedmax, FusedProxFunction 2 | from .oscar import Oscarmax, OscarProxFunction 3 | from .sparsemax import Sparsemax, SparsemaxFunction 4 | 5 | __version__ = __VERSION__ = '0.3.dev0' 6 | -------------------------------------------------------------------------------- /pytorch/torchsparseattn/_fused.pyx: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | # cython: cdivision=True 3 | # cython: boundscheck=False 4 | # cython: wraparound=False 5 | # 6 | # Authors: Fabian Pedregosa 7 | # Bundled file from lightning library 8 | 9 | """ 10 | These are some helper functions to compute the proximal operator of some common penalties 11 | """ 12 | 13 | cimport numpy as np 14 | from cython cimport floating 15 | 16 | cpdef prox_tv1d(np.ndarray[ndim=1, dtype=floating] w, floating stepsize): 17 | """ 18 | Computes the proximal operator of the 1-dimensional total variation operator. 19 | 20 | This solves a problem of the form 21 | 22 | argmin_x TV(x) + (1/(2 stepsize)) ||x - w||^2 23 | 24 | where TV(x) is the one-dimensional total variation 25 | 26 | Parameters 27 | ---------- 28 | w: array 29 | vector of coefficieents 30 | stepsize: float 31 | step size (sometimes denoted gamma) in proximal objective function 32 | 33 | References 34 | ---------- 35 | Condat, Laurent. "A direct algorithm for 1D total variation denoising." 36 | IEEE Signal Processing Letters (2013) 37 | """ 38 | cdef long width, k, k0, kplus, kminus 39 | cdef floating umin, umax, vmin, vmax, twolambda, minlambda 40 | width = w.size 41 | 42 | # /to avoid invalid memory access to input[0] and invalid lambda values 43 | if width > 0 and stepsize >= 0: 44 | k, k0 = 0, 0 # k: current sample location, k0: beginning of current segment 45 | umin = stepsize # u is the dual variable 46 | umax = - stepsize 47 | vmin = w[0] - stepsize 48 | vmax = w[0] + stepsize # bounds for the segment's value 49 | kplus = 0 50 | kminus = 0 # last positions where umax=-lambda, umin=lambda, respectively 51 | twolambda = 2.0 * stepsize # auxiliary variable 52 | minlambda = -stepsize # auxiliary variable 53 | while True: # simple loop, the exit test is inside 54 | while k >= width-1: # we use the right boundary condition 55 | if umin < 0.0: # vmin is too high -> negative jump necessary 56 | while True: 57 | w[k0] = vmin 58 | k0 += 1 59 | if k0 > kminus: 60 | break 61 | k = k0 62 | kminus = k 63 | vmin = w[kminus] 64 | umin = stepsize 65 | umax = vmin + umin - vmax 66 | elif umax > 0.0: # vmax is too low -> positive jump necessary 67 | while True: 68 | w[k0] = vmax 69 | k0 += 1 70 | if k0 > kplus: 71 | break 72 | k = k0 73 | kplus = k 74 | vmax = w[kplus] 75 | umax = minlambda 76 | umin = vmax + umax -vmin 77 | else: 78 | vmin += umin / (k-k0+1) 79 | while True: 80 | w[k0] = vmin 81 | k0 += 1 82 | if k0 > k: 83 | break 84 | return 85 | umin += w[k + 1] - vmin 86 | if umin < minlambda: # negative jump necessary 87 | while True: 88 | w[k0] = vmin 89 | k0 += 1 90 | if k0 > kminus: 91 | break 92 | k = k0 93 | kminus = k 94 | kplus = kminus 95 | vmin = w[kplus] 96 | vmax = vmin + twolambda 97 | umin = stepsize 98 | umax = minlambda 99 | else: 100 | umax += w[k + 1] - vmax 101 | if umax > stepsize: 102 | while 
True: 103 | w[k0] = vmax 104 | k0 += 1 105 | if k0 > kplus: 106 | break 107 | k = k0 108 | kminus = k 109 | kplus = kminus 110 | vmax = w[kplus] 111 | vmin = vmax - twolambda 112 | umin = stepsize 113 | umax = minlambda 114 | else: # no jump necessary, we continue 115 | k += 1 116 | if umin >= stepsize: # update of vmin 117 | kminus = k 118 | vmin += (umin - stepsize) / (kminus - k0 + 1) 119 | umin = stepsize 120 | if umax <= minlambda: # update of vmax 121 | kplus = k 122 | vmax += (umax + stepsize) / (kplus - k0 + 1) 123 | umax = minlambda 124 | -------------------------------------------------------------------------------- /pytorch/torchsparseattn/_fused_jv.pyx: -------------------------------------------------------------------------------- 1 | cimport cython 2 | from cython cimport floating 3 | 4 | 5 | @cython.boundscheck(False) 6 | @cython.wraparound(False) 7 | @cython.cdivision(True) 8 | def _inplace_fused_prox_jv(floating[::1] y_hat, floating[::1] dout): 9 | cdef Py_ssize_t n_features = dout.shape[0] 10 | cdef Py_ssize_t i, last_ix 11 | cdef unsigned int n 12 | cdef floating acc 13 | for i in range(n_features + 1): 14 | if i in (0, n_features) or y_hat[i] != y_hat[i - 1]: 15 | if i > 0: 16 | dout[last_ix:i] = acc / n 17 | 18 | if i < n_features: 19 | last_ix = i 20 | acc = dout[i] 21 | n = 1 22 | 23 | else: 24 | acc += dout[i] 25 | n += 1 26 | return dout 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /pytorch/torchsparseattn/_isotonic.pyx: -------------------------------------------------------------------------------- 1 | # Author: Nelle Varoquaux, Andrew Tulloch, Antony Lee 2 | 3 | # Uses the pool adjacent violators algorithm (PAVA), with the 4 | # enhancement of searching for the longest decreasing subsequence to 5 | # pool at each step. 6 | 7 | import numpy as np 8 | cimport numpy as np 9 | cimport cython 10 | from cython cimport floating 11 | 12 | 13 | @cython.boundscheck(False) 14 | @cython.wraparound(False) 15 | @cython.cdivision(True) 16 | def _inplace_contiguous_isotonic_regression(floating[::1] y, floating[::1] w): 17 | cdef: 18 | Py_ssize_t n = y.shape[0], i, k 19 | floating prev_y, sum_wy, sum_w 20 | Py_ssize_t[::1] target = np.arange(n, dtype=np.intp) 21 | 22 | # target describes a list of blocks. At any time, if [i..j] (inclusive) is 23 | # an active block, then target[i] := j and target[j] := i. 24 | 25 | # For "active" indices (block starts): 26 | # w[i] := sum{w_orig[j], j=[i..target[i]]} 27 | # y[i] := sum{y_orig[j]*w_orig[j], j=[i..target[i]]} / w[i] 28 | 29 | with nogil: 30 | i = 0 31 | while i < n: 32 | k = target[i] + 1 33 | if k == n: 34 | break 35 | if y[i] < y[k]: 36 | i = k 37 | continue 38 | sum_wy = w[i] * y[i] 39 | sum_w = w[i] 40 | while True: 41 | # We are within a decreasing subsequence. 42 | prev_y = y[k] 43 | sum_wy += w[k] * y[k] 44 | sum_w += w[k] 45 | k = target[k] + 1 46 | if k == n or prev_y < y[k]: 47 | # Non-singleton decreasing subsequence is finished, 48 | # update first entry. 49 | y[i] = sum_wy / sum_w 50 | w[i] = sum_w 51 | target[i] = k - 1 52 | target[k - 1] = i 53 | if i > 0: 54 | # Backtrack if we can. This makes the algorithm 55 | # single-pass and ensures O(n) complexity. 56 | i = target[i - 1] 57 | # Otherwise, restart from the same point. 58 | break 59 | # Reconstruct the solution. 
60 | i = 0 61 | while i < n: 62 | k = target[i] + 1 63 | y[i + 1 : k] = y[i] 64 | i = k 65 | 66 | 67 | @cython.boundscheck(False) 68 | @cython.wraparound(False) 69 | @cython.cdivision(True) 70 | def _make_unique(np.ndarray[dtype=floating] X, 71 | np.ndarray[dtype=floating] y, 72 | np.ndarray[dtype=floating] sample_weights): 73 | """Average targets for duplicate X, drop duplicates. 74 | 75 | Aggregates duplicate X values into a single X value where 76 | the target y is a (sample_weighted) average of the individual 77 | targets. 78 | 79 | Assumes that X is ordered, so that all duplicates follow each other. 80 | """ 81 | unique_values = len(np.unique(X)) 82 | if unique_values == len(X): 83 | return X, y, sample_weights 84 | cdef np.ndarray[dtype=floating] y_out = np.empty(unique_values) 85 | cdef np.ndarray[dtype=floating] x_out = np.empty(unique_values) 86 | cdef np.ndarray[dtype=floating] weights_out = np.empty(unique_values) 87 | 88 | cdef floating current_x = X[0] 89 | cdef floating current_y = 0 90 | cdef floating current_weight = 0 91 | cdef floating y_old = 0 92 | cdef int i = 0 93 | cdef int current_count = 0 94 | cdef int j 95 | cdef floating x 96 | cdef int n_samples = len(X) 97 | for j in range(n_samples): 98 | x = X[j] 99 | if x != current_x: 100 | # next unique value 101 | x_out[i] = current_x 102 | weights_out[i] = current_weight / current_count 103 | y_out[i] = current_y / current_weight 104 | i += 1 105 | current_x = x 106 | current_weight = sample_weights[j] 107 | current_y = y[j] * sample_weights[j] 108 | current_count = 1 109 | else: 110 | current_weight += sample_weights[j] 111 | current_y += y[j] * sample_weights[j] 112 | current_count += 1 113 | 114 | x_out[i] = current_x 115 | weights_out[i] = current_weight / current_count 116 | y_out[i] = current_y / current_weight 117 | return x_out, y_out, weights_out 118 | -------------------------------------------------------------------------------- /pytorch/torchsparseattn/base.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch import autograd as ta 3 | 4 | class _BaseBatchProjection(ta.Function): 5 | """Applies a sample-wise normalizing projection over a batch.""" 6 | 7 | def forward(self, x, lengths=None): 8 | 9 | requires_squeeze = False 10 | if x.dim() == 1: 11 | x = x.unsqueeze(0) 12 | requires_squeeze = True 13 | 14 | n_samples, max_dim = x.size() 15 | 16 | has_lengths = True 17 | if lengths is None: 18 | has_lengths = False 19 | lengths = [max_dim] * n_samples 20 | 21 | y_star = x.new() 22 | y_star.resize_as_(x) 23 | y_star.zero_() 24 | 25 | for i in range(n_samples): 26 | y_star[i, :lengths[i]] = self.project(x[i, :lengths[i]]) 27 | 28 | if requires_squeeze: 29 | y_star = y_star.squeeze() 30 | 31 | 32 | if has_lengths: 33 | self.mark_non_differentiable(lengths) 34 | self.save_for_backward(y_star, lengths) 35 | else: 36 | self.save_for_backward(y_star) 37 | 38 | return y_star 39 | 40 | def backward(self, dout): 41 | 42 | if not self.needs_input_grad[0]: 43 | return None 44 | 45 | if len(self.needs_input_grad) > 1 and self.needs_input_grad[1]: 46 | raise ValueError("Cannot differentiate {} w.r.t.
the " 47 | "sequence lengths".format(self.__name__)) 48 | 49 | saved = self.saved_tensors 50 | if len(saved) == 2: 51 | y_star, lengths = saved 52 | else: 53 | y_star, = saved 54 | lengths = None 55 | 56 | requires_squeeze = False 57 | if y_star.dim() == 1: 58 | y_star = y_star.unsqueeze(0) 59 | dout = dout.unsqueeze(0) 60 | requires_squeeze = True 61 | 62 | n_samples, max_dim = y_star.size() 63 | din = dout.new() 64 | din.resize_as_(y_star) 65 | din.zero_() 66 | 67 | if lengths is None: 68 | lengths = [max_dim] * n_samples 69 | 70 | for i in range(n_samples): 71 | din[i, :lengths[i]] = self.project_jv(dout[i, :lengths[i]], 72 | y_star[i, :lengths[i]]) 73 | 74 | if requires_squeeze: 75 | din = din.squeeze() 76 | 77 | return din, None 78 | -------------------------------------------------------------------------------- /pytorch/torchsparseattn/fused.py: -------------------------------------------------------------------------------- 1 | """Fusedmax attention 2 | 3 | Clusters neighboring attention weights into groups with equal weight. 4 | 5 | A Regularized Framework for Sparse and Structured Neural Attention 6 | Vlad Niculae, Mathieu Blondel 7 | https://arxiv.org/abs/1705.07704 8 | """ 9 | 10 | from __future__ import division 11 | 12 | import torch 13 | from torch import nn 14 | from torch import autograd as ta 15 | import warnings 16 | 17 | from .base import _BaseBatchProjection 18 | from .sparsemax import SparsemaxFunction 19 | from ._fused import prox_tv1d 20 | 21 | 22 | def _inplace_fused_prox_jv_slow(y_hat, dout): 23 | """not efficient in python for long seqs, but template for a cython impl""" 24 | 25 | n_features = len(dout) 26 | 27 | for i in range(n_features + 1): 28 | if i in (0, n_features) or y_hat[i] != y_hat[i - 1]: 29 | if i > 0: 30 | dout[last_ix:i] = acc / n 31 | 32 | if i < n_features: 33 | last_ix = i 34 | acc = dout[i] 35 | n = 1 36 | else: 37 | acc += dout[i] 38 | n += 1 39 | return dout 40 | 41 | 42 | try: 43 | from ._fused_jv import _inplace_fused_prox_jv 44 | except ImportError: 45 | warnings.warn("Could not import cython implementation of fused backward " 46 | "pass. 
Slow implementation used instead.") 47 | _inplace_fused_prox_jv = _inplace_fused_prox_jv_slow 48 | 49 | 50 | def fused_prox_jv_slow(y_hat, dout): 51 | dout = dout.clone() 52 | _inplace_fused_prox_jv_slow(y_hat, dout) 53 | return dout 54 | 55 | 56 | def fused_prox_jv_fast(y_hat, dout): 57 | dout = dout.clone() 58 | _inplace_fused_prox_jv(y_hat.detach().numpy(), dout.numpy()) 59 | return dout 60 | 61 | 62 | class FusedProxFunction(_BaseBatchProjection): 63 | 64 | def __init__(self, alpha=1): 65 | self.alpha = alpha 66 | 67 | def project(self, x): 68 | x_np = x.detach().numpy().copy() 69 | prox_tv1d(x_np, self.alpha) 70 | y_hat = torch.from_numpy(x_np) 71 | return y_hat 72 | 73 | def project_jv(self, dout, y_hat): 74 | dout = dout.clone() 75 | _inplace_fused_prox_jv(y_hat.detach().numpy(), dout.numpy()) 76 | return dout 77 | 78 | 79 | class Fusedmax(nn.Module): 80 | def __init__(self, alpha=1): 81 | self.alpha = alpha 82 | super(Fusedmax, self).__init__() 83 | 84 | def forward(self, x, lengths=None): 85 | fused_prox = FusedProxFunction(self.alpha) 86 | sparsemax = SparsemaxFunction() 87 | return sparsemax(fused_prox(x, lengths), lengths) 88 | 89 | 90 | if __name__ == '__main__': 91 | from timeit import timeit 92 | torch.manual_seed(1) 93 | 94 | for dim in (5, 10, 50, 100, 500, 1000): 95 | 96 | x = torch.randn(dim) 97 | x_var = ta.Variable(x, requires_grad=True) 98 | y_hat = FusedProxFunction()(x_var).data 99 | dout = torch.arange(0, dim) 100 | print("dimension={}".format(dim)) 101 | print("slow", timeit("fused_prox_jv_slow(y_hat, dout)", 102 | globals=globals(), 103 | number=10000)) 104 | print("fast", timeit("fused_prox_jv_fast(y_hat, dout)", 105 | globals=globals(), 106 | number=10000)) 107 | -------------------------------------------------------------------------------- /pytorch/torchsparseattn/isotonic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Isotonic Regression that preserves 32bit inputs. 3 | 4 | backported from scikit-learn pull request 5 | https://github.com/scikit-learn/scikit-learn/pull/9106""" 6 | 7 | import numpy as np 8 | 9 | from ._isotonic import _inplace_contiguous_isotonic_regression 10 | 11 | 12 | def isotonic_regression(y, sample_weight=None, y_min=None, y_max=None, 13 | increasing=True): 14 | """Solve the isotonic regression model:: 15 | 16 | min sum w[i] (y[i] - y_[i]) ** 2 17 | 18 | subject to y_min = y_[1] <= y_[2] ... <= y_[n] = y_max 19 | 20 | where: 21 | - y[i] are inputs (real numbers) 22 | - y_[i] are fitted 23 | - w[i] are optional strictly positive weights (default to 1.0) 24 | 25 | Read more in the :ref:`User Guide `. 26 | 27 | Parameters 28 | ---------- 29 | y : iterable of floating-point values 30 | The data. 31 | 32 | sample_weight : iterable of floating-point values, optional, default: None 33 | Weights on each point of the regression. 34 | If None, weight is set to 1 (equal weights). 35 | 36 | y_min : optional, default: None 37 | If not None, set the lowest value of the fit to y_min. 38 | 39 | y_max : optional, default: None 40 | If not None, set the highest value of the fit to y_max. 41 | 42 | increasing : boolean, optional, default: True 43 | Whether to compute ``y_`` is increasing (if set to True) or decreasing 44 | (if set to False) 45 | 46 | Returns 47 | ------- 48 | y_ : list of floating-point values 49 | Isotonic fit of y. 50 | 51 | References 52 | ---------- 53 | "Active set algorithms for isotonic regression; A unifying framework" 54 | by Michael J. Best and Nilotpal Chakravarti, section 3. 
55 | """ 56 | order = np.s_[:] if increasing else np.s_[::-1] 57 | # y = as_float_array(y) # avoid sklearn dependency; we always pass arrays 58 | y = np.array(y[order], dtype=y.dtype) 59 | if sample_weight is None: 60 | sample_weight = np.ones(len(y), dtype=y.dtype) 61 | else: 62 | sample_weight = np.array(sample_weight[order], dtype=y.dtype) 63 | 64 | _inplace_contiguous_isotonic_regression(y, sample_weight) 65 | if y_min is not None or y_max is not None: 66 | # Older versions of np.clip don't accept None as a bound, so use np.inf 67 | if y_min is None: 68 | y_min = -np.inf 69 | if y_max is None: 70 | y_max = np.inf 71 | np.clip(y, y_min, y_max, y) 72 | return y[order] 73 | -------------------------------------------------------------------------------- /pytorch/torchsparseattn/oscar.py: -------------------------------------------------------------------------------- 1 | """Oscarmax attention 2 | 3 | Clusters attention weights into groups with equal weight, regardless of index. 4 | 5 | A Regularized Framework for Sparse and Structured Neural Attention 6 | Vlad Niculae, Mathieu Blondel 7 | https://arxiv.org/abs/1705.07704 8 | """ 9 | 10 | import numpy as np 11 | import torch 12 | from torch import nn 13 | from torch import autograd as ta 14 | 15 | from .isotonic import isotonic_regression 16 | from .base import _BaseBatchProjection 17 | from .sparsemax import SparsemaxFunction 18 | 19 | 20 | def oscar_prox_jv(y_hat, dout): 21 | y_hat = y_hat.detach().numpy() 22 | din = dout.clone().zero_() 23 | dout = dout.numpy() 24 | din_np = din.numpy() 25 | 26 | sign = np.sign(y_hat) 27 | y_hat = np.abs(y_hat) 28 | 29 | uniq, inv, counts = np.unique(y_hat, return_inverse=True, 30 | return_counts=True) 31 | n_unique = len(uniq) 32 | tmp = np.zeros((n_unique,), dtype=y_hat.dtype) 33 | np.add.at(tmp, inv, dout * sign) 34 | tmp /= counts 35 | tmp.take(inv, mode='clip', out=din_np) 36 | din_np *= sign 37 | return din 38 | 39 | 40 | def prox_owl(v, w): 41 | """Proximal operator of the OWL norm dot(w, reversed(sort(v))) 42 | 43 | Follows description and notation from: 44 | X. Zeng, M. Figueiredo, 45 | The ordered weighted L1 norm: Atomic formulation, dual norm, 46 | and projections. 47 | eprint http://arxiv.org/abs/1409.4271 48 | """ 49 | 50 | # wlog operate on absolute values 51 | v_abs = np.abs(v) 52 | ix = np.argsort(v_abs)[::-1] 53 | v_abs = v_abs[ix] 54 | # project to K+ (monotone non-negative decreasing cone) 55 | v_abs = isotonic_regression(v_abs - w, y_min=0, increasing=False) 56 | 57 | # undo the sorting 58 | inv_ix = np.zeros_like(ix) 59 | inv_ix[ix] = np.arange(len(v)) 60 | v_abs = v_abs[inv_ix] 61 | 62 | return np.sign(v) * v_abs 63 | 64 | 65 | def _oscar_weights(alpha, beta, size): 66 | w = np.arange(size - 1, -1, -1, dtype=np.float32) 67 | w *= beta 68 | w += alpha 69 | return w 70 | 71 | 72 | class OscarProxFunction(_BaseBatchProjection): 73 | """Proximal operator of the OSCAR regularizer. 
74 | 75 | ||w||_oscar = alpha ||w||_1 + beta * sum_i 0 22 | rho = ind.masked_select(cond)[-1] 23 | tau = cssv.masked_select(cond)[-1] / rho 24 | w = torch.clamp(v - tau, min=0) 25 | return w 26 | 27 | 28 | def sparsemax_grad(dout, w_star): 29 | supp = w_star > 0 30 | masked = dout.masked_select(supp) 31 | nnz = supp.to(dtype=dout.dtype).sum() 32 | masked -= masked.sum() / nnz 33 | out = dout.new(dout.size()).zero_() 34 | out[supp] = masked 35 | return(out) 36 | 37 | 38 | class SparsemaxFunction(_BaseBatchProjection): 39 | 40 | def project(self, x): 41 | return project_simplex(x) 42 | 43 | def project_jv(self, dout, y_star): 44 | return sparsemax_grad(dout, y_star) 45 | 46 | 47 | class Sparsemax(nn.Module): 48 | 49 | def forward(self, x, lengths=None): 50 | sparsemax = SparsemaxFunction() 51 | return sparsemax(x, lengths) 52 | 53 | -------------------------------------------------------------------------------- /pytorch/torchsparseattn/test_attention.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | 7 | from . import Sparsemax, Fusedmax, Oscarmax 8 | 9 | 10 | class AttentionRegressor(nn.Module): 11 | 12 | def __init__(self, projection, n_features=100): 13 | super(AttentionRegressor, self).__init__() 14 | self.projection = projection 15 | self.attn_template = nn.Parameter(torch.Tensor(n_features)) 16 | self.attn_template.data.uniform_(-0.1, 0.1) 17 | 18 | def forward(self, X, lengths): 19 | 20 | # compute scores for each input word 21 | scores = torch.matmul(X, self.attn_template) 22 | weights = self.projection(scores, lengths) 23 | weighted_avg = torch.bmm(X.transpose(1, 2), 24 | weights.unsqueeze(-1)).squeeze(-1) 25 | pred = weighted_avg.sum(dim=1) # very simple prediction rule 26 | return pred 27 | 28 | 29 | @pytest.mark.parametrize('projection', [Sparsemax(), 30 | Fusedmax(0.1), 31 | Oscarmax(0.01)]) 32 | def test_attention(projection): 33 | n_samples = 20 34 | max_len = 10 35 | torch.manual_seed(1) 36 | n_features = 50 37 | 38 | X = torch.zeros(n_samples, max_len, n_features) 39 | 40 | # generate lengths in [1, max_len] 41 | lengths = 1 + (torch.rand(n_samples) * max_len).long() 42 | 43 | for i in range(n_samples): 44 | X[i, :lengths[i], :] = torch.randn(lengths[i], n_features) 45 | 46 | X = Variable(X) 47 | lengths = Variable(lengths) 48 | targets = Variable(torch.randn(n_samples)) 49 | 50 | regr = AttentionRegressor(projection, n_features=n_features) 51 | loss_func = nn.MSELoss() 52 | optim = torch.optim.SGD(regr.parameters(), lr=0.0001) 53 | 54 | pred = regr(X, lengths) 55 | 56 | init_obj = loss_func(pred, targets) 57 | 58 | for it in range(50): 59 | optim.zero_grad() 60 | pred = regr(X, lengths) 61 | obj = loss_func(pred, targets) 62 | obj.backward() 63 | optim.step() 64 | 65 | final_obj = obj 66 | assert final_obj < init_obj 67 | assert regr.attn_template.grad.size() == (n_features,) 68 | -------------------------------------------------------------------------------- /pytorch/torchsparseattn/test_fused.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import pytest 4 | from numpy.testing import assert_allclose 5 | import torch 6 | from torch.autograd import gradcheck, Variable 7 | 8 | from .fused import fused_prox_jv_slow, fused_prox_jv_fast 9 | from .fused import FusedProxFunction 10 | 11 | 12 | def _fused_prox_jacobian(y_hat, dout=None): 13 | """reference naive 
implementation: construct the jacobian""" 14 | dim = y_hat.shape[0] 15 | groups = torch.zeros(dim) 16 | J = torch.zeros(dim, dim) 17 | current_group = 0 18 | 19 | for i in range(1, dim): 20 | if y_hat[i] == y_hat[i - 1]: 21 | groups[i] = groups[i - 1] 22 | else: 23 | current_group += 1 24 | groups[i] = current_group 25 | 26 | for i in range(dim): 27 | for j in range(dim): 28 | if groups[i] == groups[j]: 29 | n_fused = (groups == groups[i]).sum() 30 | J[i, j] = 1 / n_fused.to(y_hat.dtype) 31 | 32 | if dout is not None: 33 | return torch.mv(J, dout) 34 | else: 35 | return J 36 | 37 | 38 | @pytest.mark.parametrize('alpha', [0.001, 0.01, 0.1, 1]) 39 | def test_jv(alpha): 40 | 41 | torch.manual_seed(1) 42 | torch.set_default_tensor_type('torch.DoubleTensor') 43 | 44 | for _ in range(30): 45 | x = Variable(torch.randn(15)) 46 | dout = torch.randn(15) 47 | 48 | y_hat = FusedProxFunction(alpha=alpha)(x).data 49 | 50 | 51 | ref = _fused_prox_jacobian(y_hat, dout) 52 | din_slow = fused_prox_jv_slow(y_hat, dout) 53 | din_fast = fused_prox_jv_fast(y_hat, dout) 54 | assert_allclose(ref.numpy(), din_slow.numpy(), atol=1e-5) 55 | assert_allclose(ref.numpy(), din_fast.numpy(), atol=1e-5) 56 | 57 | 58 | @pytest.mark.parametrize('alpha', [0.001, 0.01, 0.1, 1]) 59 | def test_finite_diff(alpha): 60 | torch.manual_seed(1) 61 | torch.set_default_tensor_type('torch.DoubleTensor') 62 | 63 | for _ in range(30): 64 | x = Variable(torch.randn(20), requires_grad=True) 65 | func = FusedProxFunction(alpha=alpha) 66 | assert gradcheck(func, (x,), eps=1e-4, atol=1e-3) 67 | -------------------------------------------------------------------------------- /pytorch/torchsparseattn/test_oscar.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import pytest 4 | from numpy.testing import assert_allclose 5 | import numpy as np 6 | import torch 7 | from torch.autograd import gradcheck, Variable 8 | 9 | from .oscar import OscarProxFunction, oscar_prox_jv 10 | 11 | 12 | def _oscar_prox_jacobian(y_star, dout=None): 13 | y_star = y_star.numpy() 14 | dim = y_star.shape[0] 15 | J = torch.zeros(dim, dim) 16 | 17 | _, inv, counts = np.unique(np.abs(y_star), 18 | return_inverse=True, 19 | return_counts=True) 20 | 21 | for i in range(dim): 22 | for j in range(dim): 23 | if (inv[i] == inv[j] and 24 | y_star[i] != 0): 25 | J[i, j] = (np.sign(y_star[i]) * np.sign(y_star[j]) 26 | / counts[inv[i]]) 27 | if dout is not None: 28 | return torch.mv(J, dout) 29 | else: 30 | return J 31 | 32 | 33 | @pytest.mark.parametrize('alpha', [0.001, 0.01, 0.1, 1]) 34 | @pytest.mark.parametrize('beta', [0.001, 0.01, 0.1, 1]) 35 | def test_jv(alpha, beta): 36 | 37 | torch.manual_seed(1) 38 | torch.set_default_tensor_type('torch.DoubleTensor') 39 | 40 | for _ in range(30): 41 | x = Variable(torch.randn(15)) 42 | dout = torch.randn(15) 43 | y_hat = OscarProxFunction(alpha=alpha, beta=beta)(x).data 44 | 45 | ref = _oscar_prox_jacobian(y_hat, dout) 46 | din = oscar_prox_jv(y_hat, dout) 47 | assert_allclose(ref.numpy(), din.numpy(), atol=1e-5) 48 | 49 | 50 | @pytest.mark.parametrize('alpha', [0.001, 0.01, 0.1, 1]) 51 | @pytest.mark.parametrize('beta', [0.001, 0.01, 0.1, 1]) 52 | def test_finite_diff(alpha, beta): 53 | torch.manual_seed(1) 54 | torch.set_default_tensor_type('torch.DoubleTensor') 55 | 56 | for _ in range(30): 57 | x = Variable(torch.randn(20), requires_grad=True) 58 | func = OscarProxFunction(alpha, beta=beta) 59 | assert gradcheck(func, (x,), eps=1e-5, atol=1e-3) 60 | 
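As a complement to the reference Jacobian used in the OSCAR tests above, the short sketch below illustrates the property those tests check. It is an illustrative snippet, not a file from the package: it assumes the compiled package is importable under one of the PyTorch versions exercised in CI (0.4.1 / 1.0.1, where these legacy autograd functions are callable), and the input values and alpha/beta settings are arbitrary. The key behaviour is that `oscar_prox_jv` averages the incoming gradient within every group of entries that share an absolute value after the prox, then restores the signs.

```python
import torch

from torchsparseattn.oscar import OscarProxFunction, oscar_prox_jv

# Arbitrary input; with these (arbitrary) alpha/beta values the OSCAR prox
# tends to fuse coefficients of similar magnitude so that they share one
# absolute value, while their signs are preserved.
x = torch.tensor([1.0, 2.0, 1.9, -2.1], dtype=torch.double)
y_hat = OscarProxFunction(alpha=0.01, beta=0.1)(x)

# Backward: every entry in a fused group receives the same gradient
# magnitude, namely the group average of `dout` (with signs restored) --
# exactly what the dense Jacobian built in _oscar_prox_jacobian encodes.
dout = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.double)
din = oscar_prox_jv(y_hat, dout)

print(y_hat)  # fused magnitudes
print(din)    # group-averaged, signed gradient
```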
-------------------------------------------------------------------------------- /pytorch/torchsparseattn/test_sparsemax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import gradcheck, Variable 3 | from .sparsemax import SparsemaxFunction 4 | 5 | 6 | def test_sparsemax(): 7 | 8 | torch.manual_seed(1) 9 | torch.set_default_tensor_type('torch.DoubleTensor') 10 | 11 | for _ in range(30): 12 | func = SparsemaxFunction() 13 | x = Variable(torch.randn(20), requires_grad=True) 14 | assert gradcheck(func, (x,), eps=1e-4, atol=1e-3) 15 | --------------------------------------------------------------------------------
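To close, a batched usage sketch complementing the single-vector example in the README. This is illustrative only and not part of the repository files: it assumes the package has been built and installed (`pip install .` from the `pytorch/` directory) and is run under one of the PyTorch versions exercised in CI (0.4.1 / 1.0.1); the scores and lengths below are arbitrary.

```python
import torch

import torchsparseattn

torch.manual_seed(0)

# A batch of three score vectors padded to a common length of 5;
# `lengths` records the true length of each row.
scores = torch.randn(3, 5, dtype=torch.double)
lengths = torch.tensor([5, 3, 4])

projections = [
    torchsparseattn.Sparsemax(),
    torchsparseattn.Fusedmax(alpha=0.1),
    torchsparseattn.Oscarmax(0.01),
]

for proj in projections:
    weights = proj(scores, lengths)
    # Each row is a sparse probability distribution over its first
    # lengths[i] positions; padded positions stay exactly zero.
    print(type(proj).__name__)
    print(weights)
```

Because each projection returns valid (sparse) probability weights, the result can be dropped into an attention layer in place of softmax, as the small regressor in `test_attention.py` does.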