├── LICENSE
├── README.md
├── attention
│   ├── __init__.py
│   └── attention.py
├── examples
│   └── Pointer-Network-Argmin-Argmax.ipynb
├── setup.py
├── svgs
│   ├── 190083ef7a1625fbc75f243cffb9c96d.svg
│   ├── 1eb39a281b1e66935a51005b6beb9dbe.svg
│   ├── 28e6b84adb66aca59d04ec9e227bfd3f.svg
│   ├── 39c9d05724010ea29be9eb321b1422ec.svg
│   ├── 39d2a848a943a7f5ec27272dad27c784.svg
│   ├── 3cf4fbd05970446973fc3d9fa3fe3c41.svg
│   ├── 5397f1268e113895a997a61e51165ffc.svg
│   ├── a5a09669219f681bb51e176b190b0e4a.svg
│   ├── a5d4c0a87edcc90e9dc7bb8a1845e86a.svg
│   ├── da2cf8b162672dc46adcace06ec2740a.svg
│   └── e73485aa867794d51ccd8725055d03a3.svg
└── test
    └── test_attention.py


/LICENSE:
--------------------------------------------------------------------------------
BSD 2-Clause License

Copyright (c) 2017, thom lake
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
```python
def attend(query, context, value=None, score='dot', normalize='softmax',
           context_sizes=None, context_mask=None, return_weight=False
           ):
    """Attend to value (or context) by scoring each query and context.

    Args
    ----
    query: Variable of size (B, M, D1)
        Batch of M query vectors.
    context: Variable of size (B, N, D2)
        Batch of N context vectors.
    value: Variable of size (B, N, P), default=None
        If given, the output vectors will be weighted
        combinations of the value vectors.
        Otherwise, the context vectors will be used.
    score: str or callable, default='dot'
        If score == 'dot', scores are computed
        as the dot product between context and
        query vectors. This requires D1 == D2.
        Otherwise, score should be a callable:
             query    context     score
            (B,M,D1) (B,N,D2) -> (B,M,N)
    normalize: str, default='softmax'
        One of 'softmax', 'sigmoid', or 'identity'.
        Name of function used to map scores to weights.
    context_mask: Tensor of size (B, M, N), default=None
        A Tensor used to mask context. Masked
        and unmasked entries should be filled
        appropriately for the normalization function.
    context_sizes: list[int], default=None
        List giving the size of context for each item
        in the batch and used to compute a context_mask.
        If both are given, context_mask takes precedence.
        If neither context_mask nor context_sizes is given,
        context is assumed to have fixed size.
    return_weight: bool, default=False
        If True, return the attention weight Tensor.

    Returns
    -------
    output: Variable of size (B, M, P)
        If return_weight is False.
    weight, output: Variable of size (B, M, N), Variable of size (B, M, P)
        If return_weight is True.
    """
```

Install
-------
```bash
python setup.py install
```

Test
----
```bash
python -m pytest
```
Tested with PyTorch 1.0.0.

About
-----
Attention is used to focus processing on a particular region of input.
The `attend` function provided by this package implements the most
common attention mechanism [[1](#1), [2](#2), [3](#3), [4](#4)], which produces
an output by taking a weighted combination of value vectors with weights
from a scoring function operating over pairs of query and context vectors.

Given query vector `q`, context vectors `c_1,...,c_n`, and value vectors
`v_1,...,v_n`, the attention score of `q` with `c_i` is given by

```
s_i = f(q, c_i)
```

Frequently `f` takes the form of a dot product between query and context vectors.

```
s_i = q^T c_i
```

The scores are passed through a normalization function `g` (normally the softmax function).

```
w_i = g(s_1,...,s_n)_i
```

Finally, the output is computed as a weighted sum of the value vectors.

```
z = \sum_{i=1}^n w_i * v_i
```

In many applications [[1](#1), [4](#4), [5](#5)] attention is applied
to the context vectors themselves, `v_i = c_i`.
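
The four equations above can be traced for a single query in plain Python. The sketch below is illustrative only: `attend_single` is a hypothetical helper, not part of this package, and it assumes the dot-product score and softmax normalization (the `attend` defaults) while working on Python lists rather than batched PyTorch tensors.

```python
import math


def attend_single(q, contexts, values=None):
    """Hypothetical single-query sketch of dot-product softmax attention.

    q: query vector (list of floats)
    contexts: context vectors c_1..c_n, each with len(q) entries
    values: optional value vectors v_1..v_n; defaults to the contexts
    Returns (weights, output).
    """
    if values is None:
        values = contexts  # v_i = c_i
    # Scores: s_i = q^T c_i
    scores = [sum(qj * cj for qj, cj in zip(q, c)) for c in contexts]
    # Weights: w_i = softmax(s)_i, shifted by max(s) for numerical stability
    shift = max(scores)
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Output: z = sum_i w_i * v_i
    dim = len(values[0])
    output = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]
    return weights, output


weights, z = attend_single([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]])
print(weights, z)
```

Here the first context vector scores 1 and the second scores 0, so the weights are roughly `[0.731, 0.269]`; because `values` defaults to the contexts, which are one-hot in this example, the output `z` equals the weight vector.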

Sizes
-----
The `attend` function accepts
batches of size `B` containing
`M` query vectors of dimension `D1`,
`N` context vectors of dimension `D2`,
and optionally `N` value vectors of dimension `P`.

Variable Length
---------------
If the number of context vectors varies within a batch, a context vector
can be ignored by forcing its weight to be zero.

In the case of the softmax, this can be achieved by adding negative
infinity to the corresponding score before normalization.
Similarly, for elementwise normalization functions (`sigmoid` and `identity`),
the weights can be multiplied by an appropriate {0, 1} mask after normalization.

To facilitate the above behavior, a context mask, with entries
in `{-inf, 0}` or `{0, 1}` depending on the normalization function,
can be passed to `attend` as `context_mask`. The mask should have size `(B, M, N)`.

Alternatively, `context_sizes` can be passed as a list giving the size of the
context for each item in the batch, and the appropriate mask is created from it.

Note that the size of the output does not depend on the number of context vectors.
Masked context positions receive zero weight, so they contribute nothing to the output.
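
The two masking strategies can be checked numerically on a single row of scores. This is a plain-Python sketch, not the package's implementation: `make_mask`, `softmax`, and `sigmoid` are hypothetical stand-ins written for illustration (in the package, `fill_context_mask` builds the full `(B, M, N)` mask from `context_sizes`). It assumes each context size is at least 1, so every row keeps a finite score.

```python
import math


def make_mask(n_context, size, v_mask, v_unmask):
    """Hypothetical helper: one row of a context mask (valid positions first)."""
    return [v_unmask if k < size else v_mask for k in range(n_context)]


def softmax(scores):
    shift = max(scores)  # assumes at least one finite score, i.e. size >= 1
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


scores = [0.5, 2.0, -1.0, 3.0]  # scores for 4 (padded) context positions
size = 2                        # only the first 2 context vectors are real

# Softmax: add a {-inf, 0} mask to the scores before normalizing.
additive = make_mask(len(scores), size, float('-inf'), 0.0)
w_softmax = softmax([s + a for s, a in zip(scores, additive)])

# Sigmoid (elementwise): multiply the weights by a {0, 1} mask after normalizing.
multiplicative = make_mask(len(scores), size, 0.0, 1.0)
w_sigmoid = [sigmoid(s) * m for s, m in zip(scores, multiplicative)]

print(w_softmax)  # the last two entries are exactly 0.0
print(w_sigmoid)  # the last two entries are exactly 0.0
```

In both cases the padded positions get zero weight, and the softmax weights on the two real positions match a softmax over `scores[:2]` alone. As long as the padding values are finite, the output therefore does not depend on what the padded context vectors contain.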

References
----------
#### [[1]](https://arxiv.org/abs/1409.0473)

    @article{bahdanau2014neural,
      title={Neural machine translation by jointly learning to align and translate},
      author={Bahdanau, Dzmitry and Cho, Kyunghyun and Bengio, Yoshua},
      journal={arXiv preprint arXiv:1409.0473},
      year={2014}
    }

#### [[2]](https://arxiv.org/abs/1410.5401)

    @article{graves2014neural,
      title={Neural turing machines},
      author={Graves, Alex and Wayne, Greg and Danihelka, Ivo},
      journal={arXiv preprint arXiv:1410.5401},
      year={2014}
    }

#### [[3]](https://arxiv.org/abs/1503.08895)

    @inproceedings{sukhbaatar2015end,
      title={End-to-end memory networks},
      author={Sukhbaatar, Sainbayar and Weston, Jason and Fergus, Rob and others},
      booktitle={Advances in neural information processing systems},
      pages={2440--2448},
      year={2015}
    }

#### [[4]](https://distill.pub/2016/augmented-rnns/)

    @article{olah2016attention,
      title={Attention and augmented recurrent neural networks},
      author={Olah, Chris and Carter, Shan},
      journal={Distill},
      volume={1},
      number={9},
      pages={e1},
      year={2016}
    }

#### [[5]](https://arxiv.org/abs/1506.03134)

    @inproceedings{vinyals2015pointer,
      title={Pointer networks},
      author={Vinyals, Oriol and Fortunato, Meire and Jaitly, Navdeep},
      booktitle={Advances in Neural Information Processing Systems},
      pages={2692--2700},
      year={2015}
    }


--------------------------------------------------------------------------------
/attention/__init__.py:
--------------------------------------------------------------------------------
from .attention import attend


--------------------------------------------------------------------------------
/attention/attention.py:
--------------------------------------------------------------------------------
# torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid;
# the unused FloatTensor and Variable imports are dropped.
from torch import sigmoid
from torch.nn.functional import softmax


def mask3d(value, sizes):
    """Mask entries in value with 0 based on sizes.

    Args
    ----
    value: Tensor of size (B, N, D)
        Tensor to be masked.
    sizes: list of int
        List giving the number of valid values for each item
        in the batch. Positions beyond each size will be masked.

    Returns
    -------
    value:
        Masked value.
    """
    v_mask = 0
    v_unmask = 1
    mask = value.data.new(value.size()).fill_(v_unmask)
    n = mask.size(1)
    for i, size in enumerate(sizes):
        if size < n:
            mask[i,size:,:] = v_mask
    # Since PyTorch 0.4, plain tensors track gradients; no Variable wrapper needed.
    return mask * value


def fill_context_mask(mask, sizes, v_mask, v_unmask):
    """Fill attention mask inplace for a variable length context.

    Args
    ----
    mask: Tensor of size (B, M, N)
        Tensor to fill with mask values.
    sizes: list[int]
        List giving the size of the context for each item in
        the batch. Positions beyond each size will be masked.
    v_mask: float
        Value to use for masked positions.
    v_unmask: float
        Value to use for unmasked positions.

    Returns
    -------
    mask:
        Filled with values in {v_mask, v_unmask}
    """
    mask.fill_(v_unmask)
    n_context = mask.size(2)
    for i, size in enumerate(sizes):
        if size < n_context:
            mask[i,:,size:] = v_mask
    return mask


def dot(a, b):
    """Compute the dot product between pairs of vectors in 3D Variables.

    Args
    ----
    a: Variable of size (B, M, D)
    b: Variable of size (B, N, D)

    Returns
    -------
    c: Variable of size (B, M, N)
        c[i,j,k] = dot(a[i,j], b[i,k])
    """
    return a.bmm(b.transpose(1, 2))


def attend(query, context, value=None, score='dot', normalize='softmax',
           context_sizes=None, context_mask=None, return_weight=False
           ):
    """Attend to value (or context) by scoring each query and context.

    Args
    ----
    query: Variable of size (B, M, D1)
        Batch of M query vectors.
    context: Variable of size (B, N, D2)
        Batch of N context vectors.
    value: Variable of size (B, N, P), default=None
        If given, the output vectors will be weighted
        combinations of the value vectors.
        Otherwise, the context vectors will be used.
    score: str or callable, default='dot'
        If score == 'dot', scores are computed
        as the dot product between context and
        query vectors. This requires D1 == D2.
        Otherwise, score should be a callable:
             query    context     score
            (B,M,D1) (B,N,D2) -> (B,M,N)
    normalize: str, default='softmax'
        One of 'softmax', 'sigmoid', or 'identity'.
        Name of function used to map scores to weights.
    context_mask: Tensor of size (B, M, N), default=None
        A Tensor used to mask context. Masked
        and unmasked entries should be filled
        appropriately for the normalization function.
    context_sizes: list[int], default=None
        List giving the size of context for each item
        in the batch and used to compute a context_mask.
        If both are given, context_mask takes precedence.
        If neither context_mask nor context_sizes is given,
        context is assumed to have fixed size.
    return_weight: bool, default=False
        If True, return the attention weight Tensor.

    Returns
    -------
    output: Variable of size (B, M, P)
        If return_weight is False.
    weight, output: Variable of size (B, M, N), Variable of size (B, M, P)
        If return_weight is True.


    About
    -----
    Attention is used to focus processing on a particular region of input.
    This function implements the most common attention mechanism [1, 2, 3],
    which produces an output by taking a weighted combination of value vectors
    with weights from a scoring function operating over pairs of query and
    context vectors.

    Given query vector `q`, context vectors `c_1,...,c_n`, and value vectors
    `v_1,...,v_n`, the attention score of `q` with `c_i` is given by

        s_i = f(q, c_i)

    Frequently, `f` is given by the dot product between query and context vectors.

        s_i = q^T c_i

    The scores are passed through a normalization function g.
    This is normally the softmax function.

        w_i = g(s_1,...,s_n)_i

    Finally, the output is computed as a weighted
    combination of the values with the normalized scores.

        z = sum_{i=1}^n w_i * v_i

    In many applications [4, 5] the context and value vectors are the same, `v_i = c_i`.

    Sizes
    -----
    This function accepts batches of size `B` containing
    `M` query vectors of dimension `D1`,
    `N` context vectors of dimension `D2`,
    and optionally `N` value vectors of dimension `P`.

    Variable Length Contexts
    ------------------------
    If the number of context vectors varies within a batch, a context vector
    can be ignored by forcing its weight to be zero.

    In the case of the softmax, this can be achieved by adding negative
    infinity to the corresponding score before normalization.
    Similarly, for elementwise normalization functions the weights can
    be multiplied by an appropriate {0, 1} mask after normalization.

    To facilitate the above behavior, a context mask, with entries
    in `{-inf, 0}` or `{0, 1}` depending on the normalization function,
    can be passed to this function. The mask should have size `(B, M, N)`.

    Alternatively, a list can be passed giving the size of the context for
    each item in the batch. The appropriate mask is created from this list.

    Note that the size of the output does not depend on the number of context vectors.
    Masked context positions receive zero weight, so they contribute nothing to the output.

    References
    ----------
    [1](https://arxiv.org/abs/1410.5401)

        @article{graves2014neural,
          title={Neural turing machines},
          author={Graves, Alex and Wayne, Greg and Danihelka, Ivo},
          journal={arXiv preprint arXiv:1410.5401},
          year={2014}
        }

    [2](https://arxiv.org/abs/1503.08895)

        @inproceedings{sukhbaatar2015end,
          title={End-to-end memory networks},
          author={Sukhbaatar, Sainbayar and Weston, Jason and Fergus, Rob and others},
          booktitle={Advances in neural information processing systems},
          pages={2440--2448},
          year={2015}
        }

    [3](https://distill.pub/2016/augmented-rnns/)

        @article{olah2016attention,
          title={Attention and augmented recurrent neural networks},
          author={Olah, Chris and Carter, Shan},
          journal={Distill},
          volume={1},
          number={9},
          pages={e1},
          year={2016}
        }

    [4](https://arxiv.org/abs/1409.0473)

        @article{bahdanau2014neural,
          title={Neural machine translation by jointly learning to align and translate},
          author={Bahdanau, Dzmitry and Cho, Kyunghyun and Bengio, Yoshua},
          journal={arXiv preprint arXiv:1409.0473},
          year={2014}
        }

    [5](https://arxiv.org/abs/1506.03134)

        @inproceedings{vinyals2015pointer,
          title={Pointer networks},
          author={Vinyals, Oriol and Fortunato, Meire and Jaitly, Navdeep},
          booktitle={Advances in Neural Information Processing Systems},
          pages={2692--2700},
          year={2015}
        }
    """
    q, c, v = query, context, value
    if v is None:
        v = c

    batch_size_q, n_q, dim_q = q.size()
    batch_size_c, n_c, dim_c = c.size()
    batch_size_v, n_v, dim_v = v.size()

    if not (batch_size_q == batch_size_c == batch_size_v):
        msg = 'batch size mismatch (query: {}, context: {}, value: {})'
        raise ValueError(msg.format(q.size(), c.size(), v.size()))

    batch_size = batch_size_q

    # Compute scores
    if score == 'dot':
        s = dot(q, c)
    elif callable(score):
        s = score(q, c)
    else:
        raise ValueError(f'unknown score function: {score}')

    # Normalize scores and mask contexts
    if normalize == 'softmax':
        if context_mask is not None:
            s = context_mask + s

        elif context_sizes is not None:
            context_mask = s.data.new(batch_size, n_q, n_c)
            context_mask = fill_context_mask(context_mask,
                                             sizes=context_sizes,
                                             v_mask=float('-inf'),
                                             v_unmask=0
                                             )
            s = context_mask + s

        # Softmax over the context dimension; unlike .view(), this also
        # works if a custom score callable returns a non-contiguous tensor.
        w = softmax(s, dim=2)

    elif normalize == 'sigmoid' or normalize == 'identity':
        w = sigmoid(s) if normalize == 'sigmoid' else s
        if context_mask is not None:
            w = context_mask * w
        elif context_sizes is not None:
            context_mask = s.data.new(batch_size, n_q, n_c)
            context_mask = fill_context_mask(context_mask,
                                             sizes=context_sizes,
                                             v_mask=0,
                                             v_unmask=1
                                             )
            w = context_mask * w

    else:
        raise ValueError(f'unknown normalize function: {normalize}')

    # Combine
    z = w.bmm(v)
    if return_weight:
        return w, z
    return z


--------------------------------------------------------------------------------
/examples/Pointer-Network-Argmin-Argmax.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pointer Network Attention Demo\n",
    "\n",
    "The code below trains a [pointer network](https://arxiv.org/abs/1506.03134)-like architecture that takes a sequence of vectors as input and outputs the vector with the minimum or maximum value along a particular coordinate; in other words, a neural network version of argmin/argmax.\n",
    "\n",
    "### Setup\n",
    "\n",
    "Let $\\{\\mathbf{c}_i\\}_{i=1}^n = (\\mathbf{c}_1, \\ldots, \\mathbf{c}_n)$ be a sequence of $n$ vectors with each $\\mathbf{c}_i \\in \\mathbb{R}^d$. The minimum and maximum target positions, $i_\\min$ and $i_\\max$, for the sequence are given by\n",
    "\n",
    "$$\n",
    "\\begin{align*}\n",
    "  i_\\min &= \\text{argmin}_i \\left\\{ c_{i, k_\\min}\\right\\} \\\\\n",
    "  i_\\max &= \\text{argmax}_i \\left\\{ c_{i, k_\\max}\\right\\} \\\\\n",
    "\\end{align*}\n",
    "$$\n",
    "\n",
    "where $1 \\leq k_\\min \\neq k_\\max \\leq d$ are coordinates, chosen a priori, along which to compute the minimum or maximum.\n",
    "\n",
    "### Model\n",
    "\n",
    "The model has the following form\n",
    "\n",
    "$$\n",
    "\\begin{align*}\n",
    "  \\mathbf{u}_i &= A \\mathbf{c}_i & \\; i = 1, \\ldots, n \\\\\n",
    "  \\mathbf{v} &= B \\mathbf{q} \\\\\n",
    "  \\mathbf{p} &= \\text{softmax}_i(\\mathbf{v}^T \\mathbf{u}_i) \\\\\n",
    "  \\mathbf{z} &= \\sum_i p_i \\mathbf{c}_i\n",
    "\\end{align*}\n",
    "$$\n",
    "\n",
    "where $A, B \\in \\mathbb{R}^{p \\times d}$ and $\\mathbf{q} \\in \\{\\mathbf{q}_\\min, \\mathbf{q}_\\max\\} \\subseteq \\mathbb{R}^d$ is a query vector indicating whether the model should output the minimum or maximum.\n",
    "\n",
    "The loss is defined as the mean squared error between the output and target vector.\n",
    "\n",
    "$$\n",
    "\\begin{align*}\n",
    "  l &= \\frac{1}{d}\\sum_{j=1}^d (z_j - c_{t,j})^2\n",
    "\\end{align*}\n",
    "$$\n",
    "\n",
    "where $t$ is either $i_\\min$ or $i_\\max$.\n",
    "\n",
    "### Details\n",
    "The vectors $\\mathbf{q}_\\min$ and $\\mathbf{q}_\\max$ are initialized to random values sampled from $N(\\mathbf{0}, \\mathbf{I}_d)$ and held constant throughout training. The code below uses 10-dimensional context and query vectors and 7-dimensional hidden representations. The model is optimized over 1,600 training instances using RMSProp with minibatches of size 8. Each training instance contains between 5 and 14 context vectors. To better assess generalization, validation instances are longer, containing between 15 and 24 context vectors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Seed: 1273\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "\n",
    "import torch\n",
    "from torch.nn import Linear, Module\n",
    "\n",
    "from attention import attend\n",
    "\n",
    "\n",
    "seed = sum(map(ord, 'les bons mots'))\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "print(f'Seed: {seed}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Data(object):\n",
    "    dim = 10\n",
    "    min_position, max_position = 3, 7\n",
    "    q_min = np.random.normal(0, 1, dim).astype(np.float32)\n",
    "    q_max = np.random.normal(0, 1, dim).astype(np.float32)\n",
    "    query = np.vstack([q_min, q_max])\n",
    "\n",
    "    @staticmethod\n",
    "    def create_minibatches(n, m, min_length, max_length):\n",
    "        assert 0 < min_length <= max_length\n",
    "\n",
    "        minibatches = []\n",
    "        for i in range(n):\n",
    "            lengths = np.random.randint(min_length, max_length + 1, m)\n",
    "            context = np.zeros((m, lengths.max(), Data.dim), dtype=np.float32)\n",
    "            target = np.zeros((m, 2, Data.dim), dtype=np.float32)\n",
    "            target_indices = []\n",
    "\n",
    "            for j, length in enumerate(lengths):\n",
    "                c = np.random.normal(0, 1, (length, Data.dim))\n",
    "                k_min = np.argmin(c[:,Data.min_position])\n",
    "                k_max = np.argmax(c[:,Data.max_position])\n",
    "                target_min = c[k_min]\n",
    "                target_max = c[k_max]\n",
    "                context[j,:length] = c\n",
    "                target[j,0] = target_min\n",
    "                target[j,1] = target_max\n",
    "                target_indices.append((k_min, k_max))\n",
    "\n",
    "            query = torch.from_numpy(np.tile(Data.query, (m, 1, 1)))\n",
    "            context = torch.from_numpy(context)\n",
    "            target = torch.from_numpy(target)\n",
    "            minibatches.append((query, context, target, lengths, target_indices))\n",
    "        return minibatches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PointerNet(Module):\n",
    "    def __init__(self, n_hidden):\n",
    "        super().__init__()\n",
    "        self.n_hidden = n_hidden\n",
    "        self.f = Linear(Data.dim, n_hidden)\n",
    "        self.g = Linear(Data.dim, n_hidden)\n",
    "\n",
    "    def forward(self, q, x, lengths=None, **kwargs):\n",
    "        batch_size_q, n_queries, dim_q = q.size()\n",
    "        batch_size_x, n_inputs, dim_x = x.size()\n",
    "        assert batch_size_q == batch_size_x\n",
    "        assert dim_q == dim_x\n",
    "        batch_size = batch_size_q\n",
    "        dim = dim_q\n",
    "\n",
    "        q_flat = q.view(batch_size*n_queries, dim)\n",
    "        u_flat = self.f(q_flat)\n",
    "        u = u_flat.view(batch_size, n_queries, self.n_hidden)\n",
    "\n",
    "        x_flat = x.view(batch_size*n_inputs, dim)\n",
    "        v_flat = self.g(x_flat)\n",
    "        v = v_flat.view(batch_size, n_inputs, self.n_hidden)\n",
    "\n",
    "        return attend(u, v, value=x, context_sizes=lengths, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 8\n",
    "\n",
    "n_train = 200\n",
    "min_length_train, max_length_train = 5, 14\n",
    "train_batches = Data.create_minibatches(n_train, batch_size, min_length_train, max_length_train)\n",
    "\n",
    "n_valid = 100\n",
    "min_length_valid, max_length_valid = 15, 24\n",
    "valid_batches = Data.create_minibatches(n_valid, batch_size, min_length_valid, max_length_valid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = PointerNet(7)\n",
    "opt = torch.optim.RMSprop(net.parameters(), lr=0.001)\n",
    "mse = torch.nn.MSELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 1] 0.647\n",
      "[ 2] 0.264\n",
      "[ 3] 0.162\n",
      "[ 4] 0.120\n",
      "[ 5] 0.096\n",
      "[ 6] 0.081\n",
      "[ 7] 0.070\n",
      "[ 8] 0.063\n",
      "[ 9] 0.057\n",
      "[10] 0.052\n"
     ]
    }
   ],
   "source": [
    "epoch = 0\n",
    "max_epochs = 10\n",
    "while epoch < max_epochs:\n",
    "    sum_loss = 0\n",
    "    for query, context, target, lengths, target_indices in train_batches:\n",
    "        net.zero_grad()\n",
    "        output = net(query, context, lengths=lengths)\n",
    "        loss = mse(output, target)\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        sum_loss += loss.item()\n",
    "    epoch += 1\n",
    "    print('[{:2d}] {:5.3f}'.format(epoch, sum_loss / n_train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "valid loss: 0.059\n",
      "valid error min: 0.018\n",
      "valid error max: 0.015\n"
     ]
    }
   ],
   "source": [
    "sum_loss = 0\n",
    "sum_error_min = 0\n",
    "sum_error_max = 0\n",
    "\n",
    "for query, context, target, lengths, target_indices in valid_batches:\n",
    "    with torch.no_grad():\n",
    "        weight, output = net(query, context, lengths=lengths, return_weight=True)\n",
    "        loss = mse(output, target)\n",
    "\n",
    "    sum_loss += loss.item()\n",
    "    weight = weight.data.numpy()\n",
    "\n",
    "    for i, (i_min_true, i_max_true) in enumerate(target_indices):\n",
    "        w_min, w_max = weight[i]\n",
    "        i_min_pred = w_min.argmax()\n",
    "        i_max_pred = w_max.argmax()\n",
    "        sum_error_min += int(i_min_true != i_min_pred)\n",
    "        sum_error_max += int(i_max_true != i_max_pred)\n",
    "\n",
    "print('valid loss: {:5.3f}'.format(sum_loss / n_valid))\n",
    "print('valid error min: {:5.3f}'.format(sum_error_min / (n_valid * batch_size)))\n",
    "print('valid error max: {:5.3f}'.format(sum_error_max / (n_valid * batch_size)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       ""
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "query, context, target, lengths, target_indices = valid_batches[0]\n",
    "\n",
    "with torch.no_grad():\n",
    "    weight, output = net(query, context, lengths=lengths, return_weight=True)\n",
    "\n",
    "context = context.numpy()\n",
    "weight = weight.data.numpy()\n",
    "\n",
    "colors = sns.color_palette('husl', 3)\n",
    "\n",
    "fig, axs = plt.subplots(batch_size, 2, figsize=(10, 10), sharex=True, sharey=True)\n",
    "sns.despine(fig=fig)\n",
    "axs_min, axs_max = axs[:,0], axs[:,1]\n",
    "\n",
    "for i, (i_min, i_max) in enumerate(target_indices):\n",
    "    length = lengths[i]\n",
    "    w_min, w_max = weight[i]\n",
    "    c_min = [context[i,j,Data.min_position] for j in range(length)]\n",
    "    c_max = [context[i,j,Data.max_position] for j in range(length)]\n",
    "\n",
    "    axs_min[i].axvline(i_min, zorder=1, color=colors[2], label='min value')\n",
    "    axs_min[i].bar(np.arange(length) - 0.4, c_min, zorder=2, color=colors[1], lw=0, label='values')\n",
    "    axs_min[i].plot(np.arange(length), w_min[:length], zorder=3, color=colors[0], lw=4, label='attention')\n",
    "\n",
    "    axs_max[i].axvline(i_max, zorder=1, color=colors[2], label='max value')\n",
    "    axs_max[i].bar(np.arange(length) - 0.4, c_max, zorder=2, color=colors[1], lw=0, label='values')\n",
    "    axs_max[i].plot(np.arange(length), w_max[:length], zorder=3, color=colors[0], lw=4, label='attention')\n",
    "\n",
    "axs_min[0].set_title('Minimum')\n",
    "axs_max[0].set_title('Maximum')\n",
    "axs_min[0].legend(loc='best')\n",
    "axs_max[0].legend(loc='best')\n",
    "axs_min[-1].set_xlabel('position')\n",
    "axs_max[-1].set_xlabel('position')\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}


--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
# setuptools replaces distutils, which was removed in Python 3.12.
from setuptools import setup

setup(
    name='attention',
    version='0.1.0',
    author='tllake',
    author_email='thom.l.lake@gmail.com',
    packages=['attention'],
    description='An attention function for PyTorch.',
    long_description=open('README.md').read())
/svgs/39c9d05724010ea29be9eb321b1422ec.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /svgs/39d2a848a943a7f5ec27272dad27c784.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /svgs/3cf4fbd05970446973fc3d9fa3fe3c41.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /svgs/5397f1268e113895a997a61e51165ffc.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /svgs/a5a09669219f681bb51e176b190b0e4a.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /svgs/a5d4c0a87edcc90e9dc7bb8a1845e86a.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /svgs/da2cf8b162672dc46adcace06ec2740a.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 
--------------------------------------------------------------------------------
/test/test_attention.py:
--------------------------------------------------------------------------------
```python
import numpy as np
import pytest

import torch
from attention import attention


def test_apply_mask_3d():
    batch_size, m, n = 3, 4, 5
    sizes = [4, 3, 2]
    values = torch.randn(batch_size, m, n)
    masked = attention.mask3d(values, sizes=sizes).data
    assert values.size() == masked.size() == (batch_size, m, n)
    for i in range(batch_size):
        for j in range(m):
            for k in range(n):
                if j < sizes[i]:
                    assert masked[i,j,k] == values[i,j,k]
                else:
                    assert masked[i,j,k] == 0


@pytest.mark.parametrize('v_mask, v_unmask', [(0, 1), (float('-inf'), 0)])
def test_fill_context_mask(v_mask, v_unmask):
    batch_size, n_q, n_c = 3, 4, 5
    query_sizes = [4, 3, 2]
    context_sizes = [3, 2, 5]
    mask = torch.randn(batch_size, n_q, n_c)
    mask = attention.fill_context_mask(
        mask, sizes=context_sizes,
        v_mask=v_mask, v_unmask=v_unmask)

    for i in range(batch_size):
        for j in range(n_q):
            for k in range(n_c):
                if k < context_sizes[i]:
                    assert mask[i,j,k] == v_unmask
                else:
                    assert mask[i,j,k] == v_mask


def test_dot():
    batch_size, n_q, n_c, d = 31, 18, 15, 22
    q = np.random.normal(0, 1, (batch_size, n_q, d))
    c = np.random.normal(0, 1, (batch_size, n_c, d))

    s = attention.dot(torch.from_numpy(q),
                      torch.from_numpy(c)
                      )
    s = s.data.numpy()

    assert s.shape == (batch_size, n_q, n_c)

    for i in range(batch_size):
        for j in range(n_q):
            for k in range(n_c):
                assert np.allclose(np.dot(q[i,j], c[i,k]), s[i,j,k])


@pytest.mark.parametrize(
    'batch_size,n_q,n_c,d', [
        (1, 1, 6, 11),
        (20, 1, 10, 3),
        (3, 10, 15, 5)])
def test_attention(batch_size, n_q, n_c, d):
    q = np.random.normal(0, 1, (batch_size, n_q, d))
    c = np.random.normal(0, 1, (batch_size, n_c, d))

    w_out, z_out = attention.attend(torch.from_numpy(q),
                                    torch.from_numpy(c),
                                    return_weight=True
                                    )
    w_out = w_out.data.numpy()
    z_out = z_out.data.numpy()

    assert w_out.shape == (batch_size, n_q, n_c)
    assert z_out.shape == (batch_size, n_q, d)

    for i in range(batch_size):
        for j in range(n_q):
            s = [np.dot(q[i,j], c[i,k]) for k in range(n_c)]
            max_s = max(s)
            exp_s = [np.exp(si - max_s) for si in s]
            sum_exp_s = sum(exp_s)

            w_ref = [ei / sum_exp_s for ei in exp_s]
            assert np.allclose(w_ref, w_out[i,j])

            z_ref = sum(w_ref[k] * c[i,k] for k in range(n_c))
            assert np.allclose(z_ref, z_out[i,j])


@pytest.mark.parametrize(
    'batch_size,n_q,n_c,d,p', [
        (1, 1, 6, 11, 5),
        (20, 1, 10, 3, 14),
        (3, 10, 15, 5, 9)])
def test_attention_values(batch_size, n_q, n_c, d, p):
    q = np.random.normal(0, 1, (batch_size, n_q, d))
    c = np.random.normal(0, 1, (batch_size, n_c, d))
    v = np.random.normal(0, 1, (batch_size, n_c, p))

    w_out, z_out = attention.attend(torch.from_numpy(q),
                                    torch.from_numpy(c),
                                    value=torch.from_numpy(v),
                                    return_weight=True
                                    )
    w_out = w_out.data.numpy()
    z_out = z_out.data.numpy()

    assert w_out.shape == (batch_size, n_q, n_c)
    assert z_out.shape == (batch_size, n_q, p)

    for i in range(batch_size):
        for j in range(n_q):
            s = [np.dot(q[i,j], c[i,k]) for k in range(n_c)]
            max_s = max(s)
            exp_s = [np.exp(si - max_s) for si in s]
            sum_exp_s = sum(exp_s)

            w_ref = [ei / sum_exp_s for ei in exp_s]
            assert np.allclose(w_ref, w_out[i,j])

            z_ref = sum(w_ref[k] * v[i,k] for k in range(n_c))
            assert np.allclose(z_ref, z_out[i,j])


@pytest.mark.parametrize(
    'batch_size,n_q,n_c,d,context_sizes', [
        (1, 1, 6, 11, [3]),
        (4, 1, 10, 3, [7, 5, 10, 9])])
def test_attention_masked(batch_size, n_q, n_c, d, context_sizes):
    q = np.random.normal(0, 1, (batch_size, n_q, d))
    c = np.random.normal(0, 1, (batch_size, n_c, d))

    w_out, z_out = attention.attend(torch.from_numpy(q),
                                    torch.from_numpy(c),
                                    context_sizes=context_sizes,
                                    return_weight=True
                                    )
    w_out = w_out.data.numpy()
    z_out = z_out.data.numpy()

    assert w_out.shape == (batch_size, n_q, n_c)
    assert z_out.shape == (batch_size, n_q, d)

    w_checked = np.zeros((batch_size, n_q, n_c), dtype=int)
    z_checked = np.zeros((batch_size, n_q, d), dtype=int)

    for i in range(batch_size):
        for j in range(n_q):
            n = context_sizes[i] if context_sizes is not None else n_c

            s = [np.dot(q[i,j], c[i,k]) for k in range(n)]
            max_s = max(s)
            exp_s = [np.exp(sk - max_s) for sk in s]
            sum_exp_s = sum(exp_s)

            w_ref = [ek / sum_exp_s for ek in exp_s]
            for k in range(n_c):
                if k < n:
                    assert np.allclose(w_ref[k], w_out[i,j,k])
                    w_checked[i,j,k] = 1
                else:
                    assert np.allclose(0, w_out[i,j,k])
                    w_checked[i,j,k] = 1

            z_ref = sum(w_ref[k] * c[i,k] for k in range(n))
            for k in range(d):
                assert np.allclose(z_ref[k], z_out[i,j,k])
                z_checked[i,j,k] = 1

    assert np.all(w_checked == 1)
    assert np.all(z_checked == 1)
```
--------------------------------------------------------------------------------
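The tests in `test_attention.py` check `attention.attend` against an explicit per-query loop: dot-product scores, a max-subtracted softmax over the first `context_sizes[i]` context vectors, zero weight on the remaining positions, and an output that is the weighted sum of `value` (or of `context` when no value is given). The same reference computation can also be written in vectorized NumPy, which may be handy for cross-checking larger inputs. This is a sketch under stated assumptions, not the library's implementation: the name `reference_attend` is hypothetical, inputs are assumed to be float arrays shaped (B, M, D), (B, N, D) and optionally (B, N, P) as in the tests, and every `context_sizes[b]` is assumed to be at least 1.

```python
import numpy as np


def reference_attend(q, c, v=None, context_sizes=None):
    """Hypothetical NumPy reference for dot-product softmax attention.

    q: (B, M, D) queries; c: (B, N, D) context; v: optional (B, N, P) values.
    context_sizes: optional sequence of B ints; context positions
    k >= context_sizes[b] receive zero weight. Each size is assumed >= 1,
    otherwise a row would be all -inf and the softmax would produce NaN.
    Returns (weights of shape (B, M, N), outputs of shape (B, M, P or D)).
    """
    if v is None:
        v = c  # no values given: attend over the context vectors themselves
    # scores[b, m, n] = dot(q[b, m], c[b, n])
    scores = np.einsum('bmd,bnd->bmn', q, c)
    if context_sizes is not None:
        n = c.shape[1]
        # valid[b, k] is True where k < context_sizes[b]
        valid = np.arange(n)[None, :] < np.asarray(context_sizes)[:, None]
        # broadcast (B, 1, N) over the M query rows; masked scores become -inf
        scores = np.where(valid[:, None, :], scores, -np.inf)
    # subtract the row max before exponentiating, as the test reference does
    scores = scores - scores.max(axis=-1, keepdims=True)
    exp_s = np.exp(scores)  # exp(-inf) == 0.0, so masked positions get 0 weight
    weights = exp_s / exp_s.sum(axis=-1, keepdims=True)
    # outputs[b, m] = sum_n weights[b, m, n] * v[b, n]
    outputs = np.einsum('bmn,bnp->bmp', weights, v)
    return weights, outputs
```

For example, with `q` of shape (4, 1, 3), `c` of shape (4, 10, 3) and `context_sizes=[7, 5, 10, 9]` (one of the parametrizations of `test_attention_masked`), `weights` has shape (4, 1, 10), `outputs` has shape (4, 1, 3), and `weights[1, 0, 5:]` is exactly zero. Masking with `-inf` before the softmax, rather than zeroing weights afterwards, keeps each row summing to 1 over the valid positions only.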