├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── setup.py ├── tests ├── test_encoder.py ├── test_ff.py ├── test_layer.py ├── test_ln.py ├── test_mha.py ├── test_optim.py └── test_pe.py └── transformer_encoder ├── __init__.py ├── encoder.py ├── encoder_layer.py ├── feed_forward.py ├── layer_norm.py ├── multi_head_attention.py └── utils ├── __init__.py ├── positional_encoding.py └── warmup_optimizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | .idea/ 106 | test.py 107 | release.sh -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | dist: xenial 2 | language: python 3 | python: 4 | - "3.5" 5 | - "3.6" 6 | - "3.7" 7 | before_install: 8 | - if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp35-cp35m-linux_x86_64.whl; fi 9 | - if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp36-cp36m-linux_x86_64.whl; fi 10 | - if [[ $TRAVIS_PYTHON_VERSION == 3.7 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp37-cp37m-linux_x86_64.whl; fi 11 | install: python setup.py install 12 | script: pytest 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 guocheng2018 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, 
distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Transformer Encoder 2 |

3 | 4 | 5 | 6 | 7 | This repository provides a PyTorch implementation of the encoder of [Transformer](http://papers.nips.cc/paper/7181-attention-is-all-you-need/). 8 | 9 | 10 | [figure: encoder] 11 |

12 | 13 | ## Getting started 14 | 15 | Build a transformer encoder 16 | ```python 17 | from transformer_encoder import TransformerEncoder 18 | 19 | encoder = TransformerEncoder(d_model=512, d_ff=2048, n_heads=8, n_layers=6, dropout=0.1) 20 | 21 | input_seqs = ... 22 | mask = ... 23 | out = encoder(input_seqs, mask) 24 | ``` 25 | 26 | Add positional encoding to input embeddings 27 | ```python 28 | import torch.nn as nn 29 | from transformer_encoder.utils import PositionalEncoding 30 | 31 | input_layer = nn.Sequential( 32 | nn.Embedding(num_embeddings=10000, embedding_dim=512), 33 | PositionalEncoding(d_model=512, dropout=0.1, max_len=5000) 34 | ) 35 | ``` 36 | 37 | Optimize the model with the warm-up strategy 38 | ```python 39 | import torch.optim as optim 40 | from transformer_encoder.utils import WarmupOptimizer 41 | 42 | model = ... 43 | 44 | base_optimizer = optim.Adam(model.parameters(), lr=1e-3) 45 | optimizer = WarmupOptimizer(base_optimizer, d_model=512, scale_factor=1, warmup_steps=100) 46 | 47 | optimizer.zero_grad() 48 | loss = ... 49 | loss.backward() 50 | optimizer.step() 51 | ``` 52 | 53 | ## API Reference 54 | 55 | *transformer_encoder.TransformerEncoder(d_model, d_ff, n_heads=1, n_layers=1, dropout=0.1)* 56 | 57 | - `d_model`: dimension of each word vector 58 | - `d_ff`: hidden dimension of the feed-forward layer 59 | - `n_heads`: number of self-attention heads (defaults to 1) 60 | - `n_layers`: number of stacked encoder layers (defaults to 1) 61 | - `dropout`: dropout rate (defaults to 0.1) 62 | 63 | *transformer_encoder.TransformerEncoder.forward(x, mask)* 64 | 65 | - `x (~torch.FloatTensor)`: shape *(batch_size, max_seq_len, d_model)* 66 | - `mask (~torch.ByteTensor)`: shape *(batch_size, max_seq_len)* 67 | 68 | *transformer_encoder.utils.PositionalEncoding(d_model, dropout=0.1, max_len=5000)* 69 | 70 | - `d_model`: same as in TransformerEncoder 71 | - `dropout`: dropout rate (defaults to 0.1) 72 | - `max_len`: max sequence length (defaults to 5000) 73 | 74 | *transformer_encoder.utils.PositionalEncoding.forward(x)* 75 | 76 | - `x (~torch.FloatTensor)`: shape *(batch_size, max_seq_len, d_model)* 77 | 78 | *transformer_encoder.utils.WarmupOptimizer(base_optimizer, d_model, scale_factor, warmup_steps)* 79 | 80 | - `base_optimizer (~torch.optim.Optimizer)`: e.g.
an Adam optimizer 81 | - `d_model`: equals d_model in TransformerEncoder 82 | - `scale_factor`: scaling factor for the learning rate 83 | - `warmup_steps`: number of warm-up steps 84 | 85 | 86 | ## Installation 87 | Requires `python 3.5+` and `pytorch 1.0.0+` 88 | ``` 89 | pip install transformer_encoder 90 | ``` 91 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="transformer_encoder", 8 | version="0.0.3", 9 | author="Cheng Guo", 10 | author_email="guocheng672@gmail.com", 11 | description="A pytorch implementation of transformer encoder", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/guocheng2018/transformer-encoder", 15 | packages=setuptools.find_packages(), 16 | python_requires=">=3.5", 17 | classifiers=[ 18 | "Programming Language :: Python :: 3", 19 | "License :: OSI Approved :: MIT License", 20 | "Operating System :: OS Independent", 21 | ], 22 | ) 23 | -------------------------------------------------------------------------------- /tests/test_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformer_encoder.encoder import TransformerEncoder 4 | 5 | d_model = 512 6 | n_heads = 8 7 | batch_size = 64 8 | max_len = 100 9 | d_ff = 2048 10 | dropout = 0.1 11 | n_layers = 6 12 | 13 | 14 | def test_encoder(): 15 | enc = TransformerEncoder(d_model, d_ff, n_heads=n_heads, n_layers=n_layers, dropout=dropout) 16 | x = torch.randn(batch_size, max_len, d_model) 17 | mask = torch.randn(batch_size, max_len).ge(0) 18 | out = enc(x, mask) 19 | assert x.size() == out.size() 20 | -------------------------------------------------------------------------------- /tests/test_ff.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformer_encoder.feed_forward import FeedForward 3 | 4 | batch_size = 64 5 | max_len = 100 6 | d_model = 512 7 | d_ff = 2048 8 | dropout = 0.1 9 | 10 | 11 | def test_ff(): 12 | ff = FeedForward(d_model, d_ff, dropout) 13 | x = torch.randn(batch_size, max_len, d_model) 14 | out = ff(x) 15 | assert x.size() == out.size() 16 | -------------------------------------------------------------------------------- /tests/test_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformer_encoder.encoder_layer import EncoderLayer 4 | from transformer_encoder.feed_forward import FeedForward 5 | from transformer_encoder.multi_head_attention import MultiHeadAttention 6 | 7 | d_model = 512 8 | n_heads = 8 9 | batch_size = 64 10 | max_len = 100 11 | d_ff = 2048 12 | dropout = 0.1 13 | 14 | 15 | def test_enclayer(): 16 | # Components 17 | mha = MultiHeadAttention(n_heads, d_model) 18 | ff = FeedForward(d_model, d_ff, dropout) 19 | enclayer = EncoderLayer(d_model, mha, ff, dropout) 20 | # Input 21 | x = torch.randn(batch_size, max_len, d_model) 22 | mask = torch.randn(batch_size, max_len).ge(0) 23 | out = enclayer(x, mask) 24 | assert x.size() == out.size() 25 | 26 | -------------------------------------------------------------------------------- /tests/test_ln.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from
transformer_encoder.layer_norm import LayerNorm 4 | 5 | d_model = 512 6 | eps = 1e-6 7 | batch_size = 64 8 | max_len = 100 9 | 10 | 11 | def test_ln(): 12 | LN = LayerNorm(d_model, eps) 13 | x = torch.randn(batch_size, max_len, d_model) 14 | out = LN(x) 15 | assert x.size() == out.size() 16 | -------------------------------------------------------------------------------- /tests/test_mha.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformer_encoder.multi_head_attention import MultiHeadAttention 3 | 4 | 5 | d_model = 512 6 | n_heads = 8 7 | batch_size = 64 8 | max_len = 100 9 | 10 | 11 | def test_mha(): 12 | mha = MultiHeadAttention(n_heads, d_model) 13 | x = torch.randn(batch_size, max_len, d_model) 14 | out = mha(x, x, x) 15 | assert out.size() == x.size() 16 | 17 | 18 | def test_masked_mha(): 19 | mha = MultiHeadAttention(n_heads, d_model) 20 | x = torch.randn(batch_size, max_len, d_model) 21 | mask = torch.randn(batch_size, max_len).ge(0) 22 | out = mha(x, x, x, mask) 23 | assert out.size() == x.size() 24 | -------------------------------------------------------------------------------- /tests/test_optim.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | from transformer_encoder.encoder import TransformerEncoder 4 | from transformer_encoder.utils import WarmupOptimizer 5 | 6 | d_model = 512 7 | n_heads = 8 8 | d_ff = 2048 9 | dropout = 0.1 10 | n_layers = 6 11 | 12 | scale_factor = 1 13 | warmup_steps = 20 14 | 15 | 16 | def test_optim(): 17 | enc = TransformerEncoder(d_model, d_ff, n_heads=n_heads, n_layers=n_layers, dropout=dropout) 18 | opt = WarmupOptimizer(optim.Adam(enc.parameters()), d_model, scale_factor, warmup_steps) 19 | assert type(opt.rate(step=1)) is float # step starts from 1 20 | opt.step() 21 | -------------------------------------------------------------------------------- /tests/test_pe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformer_encoder.utils import PositionalEncoding 4 | 5 | d_model = 512 6 | dropout = 0.1 7 | max_len = 100 8 | batch_size = 64 9 | 10 | 11 | def test_pe(): 12 | PE = PositionalEncoding(d_model, dropout=dropout, max_len=max_len) 13 | embeds = torch.randn(batch_size, max_len, d_model) # (batch_size, max_len, d_model) 14 | out = PE(embeds) 15 | assert embeds.size() == out.size() 16 | -------------------------------------------------------------------------------- /transformer_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transformer Encoder 3 | 4 | For more information: https://github.com/guocheng2018/transformer-encoder 5 | """ 6 | __version__ = "0.0.3" 7 | 8 | from .encoder import TransformerEncoder 9 | -------------------------------------------------------------------------------- /transformer_encoder/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .encoder_layer import EncoderLayer 5 | from .feed_forward import FeedForward 6 | from .layer_norm import LayerNorm 7 | from .multi_head_attention import MultiHeadAttention 8 | from .utils import clones 9 | 10 | 11 | class Encoder(nn.Module): 12 | """Core encoder is a stack of N layers""" 13 | 14 | def __init__(self, layer: EncoderLayer, N: int): 15 | super(Encoder, self).__init__() 16 | self.layers = clones(layer, N) 
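# clones() builds the stack with copy.deepcopy (see utils.clones), so the N layers do not share parameters.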
17 | self.norm = LayerNorm(layer.size) 18 | 19 | def forward(self, x: torch.FloatTensor, mask: torch.ByteTensor) -> torch.FloatTensor: 20 | """Pass the input (and mask) through each layer in turn.""" 21 | for layer in self.layers: 22 | x = layer(x, mask) 23 | return self.norm(x) 24 | 25 | 26 | class TransformerEncoder(nn.Module): 27 | """The encoder of transformer 28 | 29 | Args: 30 | `n_layers`: number of stacked encoder layers 31 | `d_model`: model dimension 32 | `d_ff`: hidden dimension of feed forward layer 33 | `n_heads`: number of heads of self-attention 34 | `dropout`: dropout rate, default 0.1 35 | """ 36 | 37 | def __init__(self, d_model: int, d_ff: int, n_heads: int = 1, n_layers: int = 1, 38 | dropout: float = 0.1): 39 | super(TransformerEncoder, self).__init__() 40 | self.multi_headed_attention = MultiHeadAttention(n_heads, d_model, dropout) 41 | self.feed_forward = FeedForward(d_model, d_ff, dropout) 42 | self.encoder_layer = EncoderLayer(d_model, self.multi_headed_attention, self.feed_forward, dropout) 43 | self.encoder = Encoder(self.encoder_layer, n_layers) 44 | self.reset_parameters() 45 | 46 | def reset_parameters(self): 47 | for p in self.parameters(): 48 | if p.dim() > 1: 49 | nn.init.xavier_uniform_(p) 50 | 51 | def forward(self, x: torch.FloatTensor, mask: torch.ByteTensor) -> torch.FloatTensor: 52 | return self.encoder(x, mask) 53 | -------------------------------------------------------------------------------- /transformer_encoder/encoder_layer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .feed_forward import FeedForward 7 | from .layer_norm import LayerNorm 8 | from .multi_head_attention import MultiHeadAttention 9 | from .utils import clones 10 | 11 | 12 | class EncoderLayer(nn.Module): 13 | """Encoder is made up of self-attn and feed forward""" 14 | 15 | def __init__(self, size: int, self_attn: MultiHeadAttention, feed_forward: FeedForward, dropout: float): 16 | super(EncoderLayer, self).__init__() 17 | self.self_attn = self_attn 18 | self.feed_forward = feed_forward 19 | self.sublayer = clones(SublayerConnection(size, dropout), 2) 20 | self.size = size 21 | 22 | def forward(self, x: torch.FloatTensor, mask: torch.ByteTensor) -> torch.FloatTensor: 23 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) 24 | return self.sublayer[1](x, self.feed_forward) 25 | 26 | 27 | class SublayerConnection(nn.Module): 28 | """ 29 | A residual connection followed by a layer norm. 30 | Note for code simplicity the norm is first as opposed to last. 
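In other words this is a pre-norm residual block, out = x + dropout(sublayer(norm(x))), rather than the post-norm form norm(x + dropout(sublayer(x))) of the original paper.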
31 | """ 32 | 33 | def __init__(self, size: int, dropout: float): 34 | super(SublayerConnection, self).__init__() 35 | self.norm = LayerNorm(size) 36 | self.dropout = nn.Dropout(dropout) 37 | 38 | def forward(self, x: torch.FloatTensor, sublayer: Union[MultiHeadAttention, FeedForward]) -> torch.FloatTensor: 39 | """Apply residual connection to any sublayer with the same size.""" 40 | return x + self.dropout(sublayer(self.norm(x))) 41 | -------------------------------------------------------------------------------- /transformer_encoder/feed_forward.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FeedForward(nn.Module): 7 | def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1): 8 | super(FeedForward, self).__init__() 9 | self.w_1 = nn.Linear(d_model, d_ff) 10 | self.w_2 = nn.Linear(d_ff, d_model) 11 | self.dropout = nn.Dropout(dropout) 12 | 13 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: 14 | """ 15 | Args: 16 | `x`: shape (batch_size, max_len, d_model) 17 | 18 | Returns: 19 | same shape as input x 20 | """ 21 | return self.w_2(self.dropout(F.relu(self.w_1(x)))) 22 | -------------------------------------------------------------------------------- /transformer_encoder/layer_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class LayerNorm(nn.Module): 6 | def __init__(self, features: int, eps: float = 1e-6): 7 | # features = d_model 8 | super(LayerNorm, self).__init__() 9 | self.a = nn.Parameter(torch.ones(features)) 10 | self.b = nn.Parameter(torch.zeros(features)) 11 | self.eps = eps 12 | 13 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: 14 | mean = x.mean(-1, keepdim=True) 15 | std = x.std(-1, keepdim=True) 16 | return self.a * (x - mean) / (std + self.eps) + self.b 17 | -------------------------------------------------------------------------------- /transformer_encoder/multi_head_attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Tuple, Any 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from .utils import clones 9 | 10 | 11 | class ScaledDotProductAttention(nn.Module): 12 | def __init__(self): 13 | super(ScaledDotProductAttention, self).__init__() 14 | 15 | def forward(self, query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, 16 | mask: Optional[torch.ByteTensor] = None, dropout: Optional[nn.Dropout] = None) -> Tuple[ 17 | torch.Tensor, Any]: 18 | """ 19 | Args: 20 | `query`: shape (batch_size, n_heads, max_len, d_q) 21 | `key`: shape (batch_size, n_heads, max_len, d_k) 22 | `value`: shape (batch_size, n_heads, max_len, d_v) 23 | `mask`: shape (batch_size, 1, 1, max_len) 24 | `dropout`: nn.Dropout 25 | 26 | Returns: 27 | `weighted value`: shape (batch_size, n_heads, max_len, d_v) 28 | `weight matrix`: shape (batch_size, n_heads, max_len, max_len) 29 | """ 30 | d_k = query.size(-1) # d_k = d_model / n_heads 31 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) # B*H*L*L 32 | if mask is not None: 33 | scores = scores.masked_fill(mask.eq(0), -1e9) 34 | p_attn = F.softmax(scores, dim=-1) # B*H*L*L 35 | if dropout is not None: 36 | p_attn = dropout(p_attn) 37 | return torch.matmul(p_attn, value), p_attn 38 | 39 | 40 | class 
MultiHeadAttention(nn.Module): 41 | def __init__(self, n_heads: int, d_model: int, dropout: float = 0.1): 42 | super(MultiHeadAttention, self).__init__() 43 | assert d_model % n_heads == 0 44 | # We assume d_v always equals d_k 45 | self.d_k = d_model // n_heads 46 | self.h = n_heads 47 | self.linears = clones(nn.Linear(d_model, d_model), 4) 48 | self.sdpa = ScaledDotProductAttention() 49 | self.attn = None 50 | self.dropout = nn.Dropout(p=dropout) 51 | 52 | def forward(self, query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, 53 | mask: Optional[torch.ByteTensor] = None) -> torch.FloatTensor: 54 | """ 55 | Args: 56 | `query`: shape (batch_size, max_len, d_model) 57 | `key`: shape (batch_size, max_len, d_model) 58 | `value`: shape (batch_size, max_len, d_model) 59 | `mask`: shape (batch_size, max_len) 60 | 61 | Returns: 62 | shape (batch_size, max_len, d_model) 63 | """ 64 | if mask is not None: 65 | # Same mask applied to all h heads. B*1*1*L 66 | mask = mask.unsqueeze(1).unsqueeze(1) 67 | batch_size = query.size(0) 68 | 69 | # 1) Do all the linear projections in batch from d_model => h x d_k 70 | query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2) for l, x in 71 | zip(self.linears, (query, key, value))] 72 | 73 | # 2) Apply attention on all the projected vectors in batch. 74 | # x: B x H x L x D_v 75 | x, self.attn = self.sdpa(query, key, value, mask=mask, dropout=self.dropout) 76 | 77 | # 3) "Concat" using a view and apply a final linear. 78 | x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k) 79 | return self.linears[-1](x) 80 | -------------------------------------------------------------------------------- /transformer_encoder/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch.nn as nn 4 | 5 | from .warmup_optimizer import WarmupOptimizer 6 | from .positional_encoding import PositionalEncoding 7 | 8 | 9 | def clones(module, N): 10 | """Produce N identical layers.""" 11 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 12 | -------------------------------------------------------------------------------- /transformer_encoder/utils/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class PositionalEncoding(nn.Module): 8 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): 9 | super(PositionalEncoding, self).__init__() 10 | self.dropout = nn.Dropout(p=dropout) 11 | 12 | # Compute the positional encodings once in log space. 
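# PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)); div_term below equals 10000^(-2i/d_model), computed as exp(-2i * log(10000) / d_model) for numerical stability.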
13 | pe = torch.zeros(max_len, d_model) 14 | position = torch.arange(0, max_len).unsqueeze(1).float() 15 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)) 16 | pe[:, 0::2] = torch.sin(position * div_term) 17 | pe[:, 1::2] = torch.cos(position * div_term) 18 | pe = pe.unsqueeze(0) 19 | self.register_buffer("pe", pe) 20 | 21 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: 22 | """ 23 | Args: 24 | x: `embeddings`, shape (batch, max_len, d_model) 25 | 26 | Returns: 27 | `encoder input`, shape (batch, max_len, d_model) 28 | """ 29 | x = x + self.pe[:, : x.size(1)] 30 | return self.dropout(x) 31 | -------------------------------------------------------------------------------- /transformer_encoder/utils/warmup_optimizer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch.optim as optim 4 | 5 | 6 | class WarmupOptimizer: 7 | """Optim wrapper that implements rate.""" 8 | 9 | def __init__(self, base_optimizer: optim.Optimizer, d_model: int, scale_factor: float, warmup_steps: int): 10 | self.base_optimizer = base_optimizer 11 | self.warmup_steps = warmup_steps 12 | self.scale_factor = scale_factor 13 | self.d_model = d_model 14 | self._step = 0 15 | self._rate = 0 16 | 17 | def step(self): 18 | """Update parameters and rate""" 19 | self._step += 1 20 | self._rate = self.rate() 21 | for p in self.base_optimizer.param_groups: 22 | p["lr"] = self._rate 23 | self.base_optimizer.step() 24 | 25 | def zero_grad(self): 26 | self.base_optimizer.zero_grad() 27 | 28 | def rate(self, step: Optional[int] = None) -> float: 29 | """Implement `lrate` above""" 30 | if step is None: 31 | step = self._step 32 | return self.scale_factor * self.d_model ** (-0.5) * min(step ** (-0.5), step * self.warmup_steps ** (-1.5)) 33 | --------------------------------------------------------------------------------
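Putting the pieces together: the sketch below is not part of the package itself; it wires an `nn.Embedding`, `PositionalEncoding`, `TransformerEncoder`, and `WarmupOptimizer` into a single training step, mirroring the README usage. The token batch, the all-ones padding mask, and the squared-activation loss are placeholders chosen purely for illustration.

```python
import torch
import torch.nn as nn
import torch.optim as optim

from transformer_encoder import TransformerEncoder
from transformer_encoder.utils import PositionalEncoding, WarmupOptimizer

d_model, d_ff, n_heads, n_layers = 512, 2048, 8, 6
vocab_size, batch_size, max_len = 10000, 64, 100

# Token ids -> embeddings -> sinusoidal positional encoding (with dropout)
input_layer = nn.Sequential(
    nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model),
    PositionalEncoding(d_model=d_model, dropout=0.1, max_len=5000),
)
encoder = TransformerEncoder(d_model=d_model, d_ff=d_ff, n_heads=n_heads, n_layers=n_layers, dropout=0.1)

# WarmupOptimizer overwrites the learning rate on every step(), then delegates to Adam
params = list(input_layer.parameters()) + list(encoder.parameters())
optimizer = WarmupOptimizer(optim.Adam(params), d_model=d_model, scale_factor=1, warmup_steps=100)

# Placeholder batch: random token ids and an all-True mask (True = attend, False = padding)
tokens = torch.randint(0, vocab_size, (batch_size, max_len))
mask = torch.ones(batch_size, max_len, dtype=torch.bool)

optimizer.zero_grad()
out = encoder(input_layer(tokens), mask)  # (batch_size, max_len, d_model)
loss = out.pow(2).mean()                  # placeholder loss, only to make backward() runnable
loss.backward()
optimizer.step()                          # sets lr from the warm-up schedule, then steps Adam
```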