├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── setup.py
├── tests
│   ├── test_encoder.py
│   ├── test_ff.py
│   ├── test_layer.py
│   ├── test_ln.py
│   ├── test_mha.py
│   ├── test_optim.py
│   └── test_pe.py
└── transformer_encoder
    ├── __init__.py
    ├── encoder.py
    ├── encoder_layer.py
    ├── feed_forward.py
    ├── layer_norm.py
    ├── multi_head_attention.py
    └── utils
        ├── __init__.py
        ├── positional_encoding.py
        └── warmup_optimizer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 | .idea/
106 | test.py
107 | release.sh
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | dist: xenial
2 | language: python
3 | python:
4 | - "3.5"
5 | - "3.6"
6 | - "3.7"
7 | before_install:
8 | - if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp35-cp35m-linux_x86_64.whl; fi
9 | - if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp36-cp36m-linux_x86_64.whl; fi
10 | - if [[ $TRAVIS_PYTHON_VERSION == 3.7 ]]; then pip install https://download.pytorch.org/whl/cpu/torch-1.0.0-cp37-cp37m-linux_x86_64.whl; fi
11 | install: python setup.py install
12 | script: pytest
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 guocheng2018
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Transformer Encoder
2 |
3 |
4 |
5 |
6 |
7 | This repository provides a pytorch implementation of the encoder of [Transformer](http://papers.nips.cc/paper/7181-attention-is-all-you-need/).
8 |
9 |
10 |
11 |
12 |
13 | ## Getting started
14 |
15 | Build a transformer encoder
16 | ```python
17 | from transformer_encoder import TransformerEncoder
18 |
19 | encoder = TransformerEncoder(d_model=512, d_ff=2048, n_heads=8, n_layers=6, dropout=0.1)
20 |
21 | input_seqs = ...
22 | mask = ...
23 | out = encoder(input_seqs, mask)
24 | ```
25 |
26 | Add positional encoding to input embeddings
27 | ```python
28 | import torch.nn as nn
29 | from transformer_encoder.utils import PositionalEncoding
30 |
31 | input_layer = nn.Sequential(
32 | nn.Embedding(num_embeddings=10000, embedding_dim=512),
33 | PositionalEncoding(d_model=512, dropout=0.1, max_len=5000)
34 | )
35 | ```
36 |
37 | Optimize the model with the warm-up strategy
38 | ```python
39 | import torch.optim as optim
40 | from transformer_encoder.utils import WarmupOptimizer
41 |
42 | model = ...
43 |
44 | base_optimizer = optim.Adam(model.parameters(), lr=1e-3)
45 | optimizer = WarmupOptimizer(base_optimizer, d_model=512, scale_factor=1, warmup_steps=100)
46 |
47 | optimizer.zero_grad()
48 | loss = ...
49 | loss.backward()
50 | optimizer.step()
51 | ```
52 |
53 | ## API Reference
54 |
55 | *transformer_encoder.TransformerEncoder(d_model, d_ff, n_heads=1, n_layers=1, dropout=0.1)*
56 |
57 | - `d_model`: dimension of each word vector
58 | - `d_ff`: hidden dimension of feed forward layer
59 | - `n_heads`: number of heads in self-attention (defaults to 1)
60 | - `n_layers`: number of stacked layers of encoder (defaults to 1)
61 | - `dropout`: dropout rate (defaults to 0.1)
62 |
63 | *transformer_encoder.TransformerEncoder.forward(x, mask)*
64 |
65 | - `x (~torch.FloatTensor)`: shape *(batch_size, max_seq_len, d_model)*
66 | - `mask (~torch.ByteTensor)`: shape *(batch_size, max_seq_len)*
67 |
68 | *transformer_encoder.utils.PositionalEncoding(d_model, dropout=0.1, max_len=5000)*
69 |
70 | - `d_model`: same as TransformerEncoder
71 | - `dropout`: dropout rate (defaults to 0.1)
72 | - `max_len`: max sequence length (defaults to 5000)
73 |
74 | *transformer_encoder.utils.PositionalEncoding.forward(x)*
75 |
76 | - `x (~torch.FloatTensor)`: shape *(batch_size, max_seq_len, d_model)*
77 |
78 | *transformer_encoder.utils.WarmupOptimizer(base_optimizer, d_model, scale_factor, warmup_steps)*
79 |
80 | - `base_optimizer (~torch.optim.Optimizer)`: e.g. an Adam optimizer
81 | - `d_model`: equals d_model in TransformerEncoder
82 | - `scale_factor`: scale factor of the learning rate
83 | - `warmup_steps`: number of warm-up steps
84 |
85 |
86 | ## Installation
87 | Requires `python 3.5+`, `pytorch 1.0.0+`
88 | ```
89 | pip install transformer_encoder
90 | ```
91 |
--------------------------------------------------------------------------------
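Taken together, the snippets in this README describe one pipeline: embed and positionally encode the tokens, run the encoder over them with a padding mask, and train with the warm-up optimizer. A minimal end-to-end sketch with assumed toy values (the vocabulary size, padding index, batch shape, and stand-in loss are illustrative only, not part of the repository):

```python
import torch
import torch.nn as nn
import torch.optim as optim

from transformer_encoder import TransformerEncoder
from transformer_encoder.utils import PositionalEncoding, WarmupOptimizer

# Hypothetical sizes, chosen only for illustration.
vocab_size, d_model, pad_idx = 10000, 512, 0

input_layer = nn.Sequential(
    nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model),
    PositionalEncoding(d_model=d_model, dropout=0.1, max_len=5000),
)
encoder = TransformerEncoder(d_model=d_model, d_ff=2048, n_heads=8, n_layers=6, dropout=0.1)

base_optimizer = optim.Adam(list(input_layer.parameters()) + list(encoder.parameters()), lr=1e-3)
optimizer = WarmupOptimizer(base_optimizer, d_model=d_model, scale_factor=1, warmup_steps=100)

# Toy batch of padded token ids, shape (batch_size, max_seq_len).
tokens = torch.randint(1, vocab_size, (64, 100))
tokens[:, 90:] = pad_idx
# Non-zero mask entries mark positions that may be attended to; this matches
# the masked_fill(mask.eq(0), -1e9) convention in multi_head_attention.py.
mask = tokens != pad_idx

out = encoder(input_layer(tokens), mask)  # (64, 100, 512)
loss = out.mean()                         # stand-in loss, illustration only
optimizer.zero_grad()
loss.backward()
optimizer.step()
```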
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | with open("README.md", "r") as fh:
4 | long_description = fh.read()
5 |
6 | setuptools.setup(
7 | name="transformer_encoder",
8 | version="0.0.3",
9 | author="Cheng Guo",
10 | author_email="guocheng672@gmail.com",
11 | description="A pytorch implementation of transformer encoder",
12 | long_description=long_description,
13 | long_description_content_type="text/markdown",
14 | url="https://github.com/guocheng2018/transformer-encoder",
15 | packages=setuptools.find_packages(),
16 | python_requires=">=3.5",
17 | classifiers=[
18 | "Programming Language :: Python :: 3",
19 | "License :: OSI Approved :: MIT License",
20 | "Operating System :: OS Independent",
21 | ],
22 | )
23 |
--------------------------------------------------------------------------------
/tests/test_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from transformer_encoder.encoder import TransformerEncoder
4 |
5 | d_model = 512
6 | n_heads = 8
7 | batch_size = 64
8 | max_len = 100
9 | d_ff = 2048
10 | dropout = 0.1
11 | n_layers = 6
12 |
13 |
14 | def test_encoder():
15 | enc = TransformerEncoder(d_model, d_ff, n_heads=n_heads, n_layers=n_layers, dropout=dropout)
16 | x = torch.randn(batch_size, max_len, d_model)
17 | mask = torch.randn(batch_size, max_len).ge(0)
18 | out = enc(x, mask)
19 | assert x.size() == out.size()
20 |
--------------------------------------------------------------------------------
/tests/test_ff.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformer_encoder.feed_forward import FeedForward
3 |
4 | batch_size = 64
5 | max_len = 100
6 | d_model = 512
7 | d_ff = 2048
8 | dropout = 0.1
9 |
10 |
11 | def test_ff():
12 | ff = FeedForward(d_model, d_ff, dropout)
13 | x = torch.randn(batch_size, max_len, d_model)
14 | out = ff(x)
15 | assert x.size() == out.size()
16 |
--------------------------------------------------------------------------------
/tests/test_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from transformer_encoder.encoder_layer import EncoderLayer
4 | from transformer_encoder.feed_forward import FeedForward
5 | from transformer_encoder.multi_head_attention import MultiHeadAttention
6 |
7 | d_model = 512
8 | n_heads = 8
9 | batch_size = 64
10 | max_len = 100
11 | d_ff = 2048
12 | dropout = 0.1
13 |
14 |
15 | def test_enclayer():
16 | # Components
17 | mha = MultiHeadAttention(n_heads, d_model)
18 | ff = FeedForward(d_model, d_ff, dropout)
19 | enclayer = EncoderLayer(d_model, mha, ff, dropout)
20 | # Input
21 | x = torch.randn(batch_size, max_len, d_model)
22 | mask = torch.randn(batch_size, max_len).ge(0)
23 | out = enclayer(x, mask)
24 | assert x.size() == out.size()
25 |
26 |
--------------------------------------------------------------------------------
/tests/test_ln.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from transformer_encoder.layer_norm import LayerNorm
4 |
5 | d_model = 512
6 | eps = 1e-6
7 | batch_size = 64
8 | max_len = 100
9 |
10 |
11 | def test_ln():
12 | LN = LayerNorm(d_model, eps)
13 | x = torch.randn(batch_size, max_len, d_model)
14 | out = LN(x)
15 | assert x.size() == out.size()
16 |
--------------------------------------------------------------------------------
/tests/test_mha.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformer_encoder.multi_head_attention import MultiHeadAttention
3 |
4 |
5 | d_model = 512
6 | n_heads = 8
7 | batch_size = 64
8 | max_len = 100
9 |
10 |
11 | def test_mha():
12 | mha = MultiHeadAttention(n_heads, d_model)
13 | x = torch.randn(batch_size, max_len, d_model)
14 | out = mha(x, x, x)
15 | assert out.size() == x.size()
16 |
17 |
18 | def test_masked_mha():
19 | mha = MultiHeadAttention(n_heads, d_model)
20 | x = torch.randn(batch_size, max_len, d_model)
21 | mask = torch.randn(batch_size, max_len).ge(0)
22 | out = mha(x, x, x, mask)
23 | assert out.size() == x.size()
24 |
--------------------------------------------------------------------------------
/tests/test_optim.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 |
3 | from transformer_encoder.encoder import TransformerEncoder
4 | from transformer_encoder.utils import WarmupOptimizer
5 |
6 | d_model = 512
7 | n_heads = 8
8 | d_ff = 2048
9 | dropout = 0.1
10 | n_layers = 6
11 |
12 | scale_factor = 1
13 | warmup_steps = 20
14 |
15 |
16 | def test_optim():
17 | enc = TransformerEncoder(d_model, d_ff, n_heads=n_heads, n_layers=n_layers, dropout=dropout)
18 | opt = WarmupOptimizer(optim.Adam(enc.parameters()), d_model, scale_factor, warmup_steps)
19 | assert type(opt.rate(step=1)) is float # step starts from 1
20 | opt.step()
21 |
--------------------------------------------------------------------------------
/tests/test_pe.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from transformer_encoder.utils import PositionalEncoding
4 |
5 | d_model = 512
6 | dropout = 0.1
7 | max_len = 100
8 | batch_size = 64
9 |
10 |
11 | def test_pe():
12 | PE = PositionalEncoding(d_model, dropout=dropout, max_len=max_len)
13 | embeds = torch.randn(batch_size, max_len, d_model) # (batch_size, max_len, d_model)
14 | out = PE(embeds)
15 | assert embeds.size() == out.size()
16 |
--------------------------------------------------------------------------------
/transformer_encoder/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Transformer Encoder
3 |
4 | For more information: https://github.com/guocheng2018/transformer-encoder
5 | """
6 | __version__ = "0.0.3"
7 |
8 | from .encoder import TransformerEncoder
9 |
--------------------------------------------------------------------------------
/transformer_encoder/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .encoder_layer import EncoderLayer
5 | from .feed_forward import FeedForward
6 | from .layer_norm import LayerNorm
7 | from .multi_head_attention import MultiHeadAttention
8 | from .utils import clones
9 |
10 |
11 | class Encoder(nn.Module):
12 | """Core encoder is a stack of N layers"""
13 |
14 | def __init__(self, layer: EncoderLayer, N: int):
15 | super(Encoder, self).__init__()
16 | self.layers = clones(layer, N)
17 | self.norm = LayerNorm(layer.size)
18 |
19 | def forward(self, x: torch.FloatTensor, mask: torch.ByteTensor) -> torch.FloatTensor:
20 | """Pass the input (and mask) through each layer in turn."""
21 | for layer in self.layers:
22 | x = layer(x, mask)
23 | return self.norm(x)
24 |
25 |
26 | class TransformerEncoder(nn.Module):
27 |     """The encoder of the transformer
28 | 
29 |     Args:
30 |        `d_model`: model dimension
31 |        `d_ff`: hidden dimension of the feed forward layer
32 |        `n_heads`: number of self-attention heads, default 1
33 |        `n_layers`: number of stacked encoder layers, default 1
34 |        `dropout`: dropout rate, default 0.1
35 |     """
36 |
37 | def __init__(self, d_model: int, d_ff: int, n_heads: int = 1, n_layers: int = 1,
38 | dropout: float = 0.1):
39 | super(TransformerEncoder, self).__init__()
40 | self.multi_headed_attention = MultiHeadAttention(n_heads, d_model, dropout)
41 | self.feed_forward = FeedForward(d_model, d_ff, dropout)
42 | self.encoder_layer = EncoderLayer(d_model, self.multi_headed_attention, self.feed_forward, dropout)
43 | self.encoder = Encoder(self.encoder_layer, n_layers)
44 | self.reset_parameters()
45 |
46 | def reset_parameters(self):
47 | for p in self.parameters():
48 | if p.dim() > 1:
49 | nn.init.xavier_uniform_(p)
50 |
51 | def forward(self, x: torch.FloatTensor, mask: torch.ByteTensor) -> torch.FloatTensor:
52 | return self.encoder(x, mask)
53 |
--------------------------------------------------------------------------------
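A small sketch, with assumed toy sizes, of how the padding mask propagates through `Encoder`: after a forward pass, the first layer's cached attention matrix puts (effectively) zero weight on masked key positions.

```python
import torch
from transformer_encoder import TransformerEncoder

# Toy configuration for illustration only; dropout is disabled so the check is deterministic.
enc = TransformerEncoder(d_model=16, d_ff=32, n_heads=4, n_layers=2, dropout=0.0).eval()

x = torch.randn(2, 10, 16)
mask = torch.ones(2, 10, dtype=torch.bool)
mask[:, 7:] = False  # pretend the last three positions are padding

out = enc(x, mask)
# MultiHeadAttention caches its attention weights as `.attn` after a forward pass.
attn = enc.encoder.layers[0].self_attn.attn  # (batch, heads, query_len, key_len)
assert torch.all(attn[..., 7:] < 1e-6)       # no weight on masked keys
assert out.shape == x.shape
```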
/transformer_encoder/encoder_layer.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from .feed_forward import FeedForward
7 | from .layer_norm import LayerNorm
8 | from .multi_head_attention import MultiHeadAttention
9 | from .utils import clones
10 |
11 |
12 | class EncoderLayer(nn.Module):
13 |     """An encoder layer is made up of a self-attention sublayer and a feed-forward sublayer"""
14 |
15 | def __init__(self, size: int, self_attn: MultiHeadAttention, feed_forward: FeedForward, dropout: float):
16 | super(EncoderLayer, self).__init__()
17 | self.self_attn = self_attn
18 | self.feed_forward = feed_forward
19 | self.sublayer = clones(SublayerConnection(size, dropout), 2)
20 | self.size = size
21 |
22 | def forward(self, x: torch.FloatTensor, mask: torch.ByteTensor) -> torch.FloatTensor:
23 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
24 | return self.sublayer[1](x, self.feed_forward)
25 |
26 |
27 | class SublayerConnection(nn.Module):
28 |     """
29 |     A residual connection around a sublayer, combined with a layer norm.
30 |     Note that for code simplicity the norm is applied first (pre-norm) rather than last.
31 |     """
32 |
33 | def __init__(self, size: int, dropout: float):
34 | super(SublayerConnection, self).__init__()
35 | self.norm = LayerNorm(size)
36 | self.dropout = nn.Dropout(dropout)
37 |
38 | def forward(self, x: torch.FloatTensor, sublayer: Union[MultiHeadAttention, FeedForward]) -> torch.FloatTensor:
39 | """Apply residual connection to any sublayer with the same size."""
40 | return x + self.dropout(sublayer(self.norm(x)))
41 |
--------------------------------------------------------------------------------
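For clarity, a tiny sketch of the pre-norm residual wrapper above, using a stand-in sublayer (a plain doubling function, not part of the repository) and dropout disabled:

```python
import torch
from transformer_encoder.encoder_layer import SublayerConnection

sub = SublayerConnection(size=8, dropout=0.0).eval()

x = torch.randn(2, 5, 8)
double = lambda t: 2 * t              # stand-in sublayer for illustration
out = sub(x, double)
# forward computes x + dropout(sublayer(norm(x))); with dropout off that is:
expected = x + 2 * sub.norm(x)
assert torch.allclose(out, expected, atol=1e-6)
```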
/transformer_encoder/feed_forward.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FeedForward(nn.Module):
7 | def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
8 | super(FeedForward, self).__init__()
9 | self.w_1 = nn.Linear(d_model, d_ff)
10 | self.w_2 = nn.Linear(d_ff, d_model)
11 | self.dropout = nn.Dropout(dropout)
12 |
13 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
14 | """
15 | Args:
16 | `x`: shape (batch_size, max_len, d_model)
17 |
18 | Returns:
19 | same shape as input x
20 | """
21 | return self.w_2(self.dropout(F.relu(self.w_1(x))))
22 |
--------------------------------------------------------------------------------
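A quick check, with assumed toy sizes and dropout switched off, that this block is position-wise: the same two linear layers act on every position independently, so slicing positions before or after the module gives the same result.

```python
import torch
from transformer_encoder.feed_forward import FeedForward

ff = FeedForward(d_model=8, d_ff=32, dropout=0.1).eval()  # eval() disables dropout

x = torch.randn(2, 5, 8)
# Running a single position through the module matches the corresponding slice
# of the full output.
assert torch.allclose(ff(x)[:, 2:3, :], ff(x[:, 2:3, :]), atol=1e-6)
```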
/transformer_encoder/layer_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class LayerNorm(nn.Module):
6 | def __init__(self, features: int, eps: float = 1e-6):
7 | # features = d_model
8 | super(LayerNorm, self).__init__()
9 | self.a = nn.Parameter(torch.ones(features))
10 | self.b = nn.Parameter(torch.zeros(features))
11 | self.eps = eps
12 |
13 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
14 | mean = x.mean(-1, keepdim=True)
15 | std = x.std(-1, keepdim=True)
16 | return self.a * (x - mean) / (std + self.eps) + self.b
17 |
--------------------------------------------------------------------------------
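A small numeric check of the normalization above, with assumed toy sizes: before the learned scale `a` and shift `b` (initialized to ones and zeros) have been trained, each position comes out with roughly zero mean and unit spread over the feature dimension.

```python
import torch
from transformer_encoder.layer_norm import LayerNorm

ln = LayerNorm(features=8)

x = torch.randn(2, 5, 8) * 3 + 7      # arbitrary scale and offset
out = ln(x)
assert torch.allclose(out.mean(-1), torch.zeros(2, 5), atol=1e-5)
assert torch.allclose(out.std(-1), torch.ones(2, 5), atol=1e-2)
```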
/transformer_encoder/multi_head_attention.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Optional, Tuple, Any
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from .utils import clones
9 |
10 |
11 | class ScaledDotProductAttention(nn.Module):
12 | def __init__(self):
13 | super(ScaledDotProductAttention, self).__init__()
14 |
15 | def forward(self, query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor,
16 | mask: Optional[torch.ByteTensor] = None, dropout: Optional[nn.Dropout] = None) -> Tuple[
17 | torch.Tensor, Any]:
18 | """
19 | Args:
20 | `query`: shape (batch_size, n_heads, max_len, d_q)
21 | `key`: shape (batch_size, n_heads, max_len, d_k)
22 | `value`: shape (batch_size, n_heads, max_len, d_v)
23 | `mask`: shape (batch_size, 1, 1, max_len)
24 | `dropout`: nn.Dropout
25 |
26 | Returns:
27 | `weighted value`: shape (batch_size, n_heads, max_len, d_v)
28 | `weight matrix`: shape (batch_size, n_heads, max_len, max_len)
29 | """
30 | d_k = query.size(-1) # d_k = d_model / n_heads
31 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) # B*H*L*L
32 | if mask is not None:
33 | scores = scores.masked_fill(mask.eq(0), -1e9)
34 | p_attn = F.softmax(scores, dim=-1) # B*H*L*L
35 | if dropout is not None:
36 | p_attn = dropout(p_attn)
37 | return torch.matmul(p_attn, value), p_attn
38 |
39 |
40 | class MultiHeadAttention(nn.Module):
41 | def __init__(self, n_heads: int, d_model: int, dropout: float = 0.1):
42 | super(MultiHeadAttention, self).__init__()
43 | assert d_model % n_heads == 0
44 | # We assume d_v always equals d_k
45 | self.d_k = d_model // n_heads
46 | self.h = n_heads
47 | self.linears = clones(nn.Linear(d_model, d_model), 4)
48 | self.sdpa = ScaledDotProductAttention()
49 | self.attn = None
50 | self.dropout = nn.Dropout(p=dropout)
51 |
52 | def forward(self, query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor,
53 | mask: Optional[torch.ByteTensor] = None) -> torch.FloatTensor:
54 | """
55 | Args:
56 | `query`: shape (batch_size, max_len, d_model)
57 | `key`: shape (batch_size, max_len, d_model)
58 | `value`: shape (batch_size, max_len, d_model)
59 | `mask`: shape (batch_size, max_len)
60 |
61 | Returns:
62 | shape (batch_size, max_len, d_model)
63 | """
64 | if mask is not None:
65 | # Same mask applied to all h heads. B*1*1*L
66 | mask = mask.unsqueeze(1).unsqueeze(1)
67 | batch_size = query.size(0)
68 |
69 | # 1) Do all the linear projections in batch from d_model => h x d_k
70 | query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2) for l, x in
71 | zip(self.linears, (query, key, value))]
72 |
73 | # 2) Apply attention on all the projected vectors in batch.
74 | # x: B x H x L x D_v
75 | x, self.attn = self.sdpa(query, key, value, mask=mask, dropout=self.dropout)
76 |
77 | # 3) "Concat" using a view and apply a final linear.
78 | x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
79 | return self.linears[-1](x)
80 |
--------------------------------------------------------------------------------
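A shape-only sketch (assumed toy sizes) of the head split described in the docstrings: `d_model` is divided into `n_heads` heads of size `d_k = d_model // n_heads`, each attention row is a distribution over the keys, and the heads are concatenated back to `d_model`.

```python
import torch
from transformer_encoder.multi_head_attention import MultiHeadAttention, ScaledDotProductAttention

batch_size, n_heads, max_len, d_model = 2, 4, 6, 16
d_k = d_model // n_heads

sdpa = ScaledDotProductAttention()
q = k = v = torch.randn(batch_size, n_heads, max_len, d_k)
weighted, p_attn = sdpa(q, k, v)
assert weighted.shape == (batch_size, n_heads, max_len, d_k)
assert p_attn.shape == (batch_size, n_heads, max_len, max_len)
# Each row of the attention matrix is a distribution over the keys.
assert torch.allclose(p_attn.sum(-1), torch.ones(batch_size, n_heads, max_len), atol=1e-5)

mha = MultiHeadAttention(n_heads, d_model)
x = torch.randn(batch_size, max_len, d_model)
assert mha(x, x, x).shape == x.shape   # heads are concatenated back to d_model
```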
/transformer_encoder/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch.nn as nn
4 |
5 | from .warmup_optimizer import WarmupOptimizer
6 | from .positional_encoding import PositionalEncoding
7 |
8 |
9 | def clones(module, N):
10 | """Produce N identical layers."""
11 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
12 |
--------------------------------------------------------------------------------
/transformer_encoder/utils/positional_encoding.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class PositionalEncoding(nn.Module):
8 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
9 | super(PositionalEncoding, self).__init__()
10 | self.dropout = nn.Dropout(p=dropout)
11 |
12 | # Compute the positional encodings once in log space.
13 | pe = torch.zeros(max_len, d_model)
14 | position = torch.arange(0, max_len).unsqueeze(1).float()
15 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
16 | pe[:, 0::2] = torch.sin(position * div_term)
17 | pe[:, 1::2] = torch.cos(position * div_term)
18 | pe = pe.unsqueeze(0)
19 | self.register_buffer("pe", pe)
20 |
21 | def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
22 | """
23 | Args:
24 | x: `embeddings`, shape (batch, max_len, d_model)
25 |
26 | Returns:
27 | `encoder input`, shape (batch, max_len, d_model)
28 | """
29 | x = x + self.pe[:, : x.size(1)]
30 | return self.dropout(x)
31 |
--------------------------------------------------------------------------------
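A small check, with an assumed toy `d_model` and an arbitrary position/channel pair, that the table built above follows the sinusoid formula: even channels hold `sin(pos / 10000^(2i/d_model))` and odd channels the matching cosine.

```python
import math
import torch
from transformer_encoder.utils import PositionalEncoding

pe_module = PositionalEncoding(d_model=16, dropout=0.0, max_len=50)
pe = pe_module.pe[0]               # registered buffer, shape (max_len, d_model)

pos, i = 7, 3                      # arbitrary position and channel-pair index
freq = math.exp(-(math.log(10000.0) / 16) * (2 * i))
assert torch.isclose(pe[pos, 2 * i], torch.tensor(math.sin(pos * freq)), atol=1e-5)
assert torch.isclose(pe[pos, 2 * i + 1], torch.tensor(math.cos(pos * freq)), atol=1e-5)
```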
/transformer_encoder/utils/warmup_optimizer.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch.optim as optim
4 |
5 |
6 | class WarmupOptimizer:
7 |     """Optimizer wrapper that implements the warm-up learning rate schedule."""
8 |
9 | def __init__(self, base_optimizer: optim.Optimizer, d_model: int, scale_factor: float, warmup_steps: int):
10 | self.base_optimizer = base_optimizer
11 | self.warmup_steps = warmup_steps
12 | self.scale_factor = scale_factor
13 | self.d_model = d_model
14 | self._step = 0
15 | self._rate = 0
16 |
17 | def step(self):
18 | """Update parameters and rate"""
19 | self._step += 1
20 | self._rate = self.rate()
21 | for p in self.base_optimizer.param_groups:
22 | p["lr"] = self._rate
23 | self.base_optimizer.step()
24 |
25 | def zero_grad(self):
26 | self.base_optimizer.zero_grad()
27 |
28 | def rate(self, step: Optional[int] = None) -> float:
29 |         """Learning rate: scale_factor * d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)"""
30 | if step is None:
31 | step = self._step
32 | return self.scale_factor * self.d_model ** (-0.5) * min(step ** (-0.5), step * self.warmup_steps ** (-1.5))
33 |
--------------------------------------------------------------------------------
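A brief sketch of the schedule implemented by `rate`, using a throwaway `nn.Linear` model and assumed hyperparameters: the learning rate grows roughly linearly for `warmup_steps` steps, peaks at `scale_factor * d_model**-0.5 * warmup_steps**-0.5`, and then decays as `step**-0.5`.

```python
import math
import torch.nn as nn
import torch.optim as optim
from transformer_encoder.utils import WarmupOptimizer

model = nn.Linear(4, 4)  # throwaway model, only needed to build an optimizer
opt = WarmupOptimizer(optim.Adam(model.parameters()), d_model=512, scale_factor=1.0, warmup_steps=100)

rates = [opt.rate(step=s) for s in (1, 50, 100, 200, 400)]
assert rates[0] < rates[1] < rates[2]          # warm-up phase: increasing
assert rates[2] > rates[3] > rates[4]          # afterwards: decaying
assert math.isclose(rates[2], 512 ** -0.5 * 100 ** -0.5)   # peak at step == warmup_steps
```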