├── gtrxl_torch
│   ├── __init__.py
│   └── gtrxl_torch.py
├── .gitignore
├── setup.py
├── LICENSE
└── README.md

--------------------------------------------------------------------------------
/gtrxl_torch/__init__.py:
--------------------------------------------------------------------------------
from gtrxl_torch.gtrxl_torch import GTrXL

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
dist/
gtrxl_torch.egg-info/
build/
__pycache__/

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(name="gtrxl-torch",
      version='0.1.7',
      license="MIT",
      description="Gated Transformer-XL - PyTorch",
      author="Alan Tessier",
      author_email="alantessier97@gmail.com",
      url="https://github.com/alantess/gtrxl-torch",
      keywords=[
          "transformer", "computer vision", "deep learning",
          "artificial intelligence"
      ],
      # torch>=1.9 is required: TransformerEncoderLayer only accepts the
      # layer_norm_eps and batch_first arguments from 1.9 onwards.
      install_requires=["torch>=1.9"],
      classifiers=[
          'Development Status :: 4 - Beta',
          'Intended Audience :: Developers',
          'Topic :: Scientific/Engineering :: Artificial Intelligence',
          'License :: OSI Approved :: MIT License',
          'Programming Language :: Python :: 3.6',
      ],
      packages=find_packages(exclude=["examples"]))

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Alan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Gated Transformer Model for Computer Vision

## Install
```bash
$ pip install gtrxl-torch
```

## Implementation
```python
from gtrxl_torch.gtrxl_torch import GTrXL
import torch

model = GTrXL(
    d_model=512,
    nheads=4,
    transformer_layers=1
)
x = torch.randn(32, 16, 512)  # (sequence, batch, d_model)
output = model(x)
```

### Output Dimensions
Dimension ➯ [**Sequence**, **Batch**, **Embedding** (`d_model`)], i.e. the same shape as the input when `batch_first=False` (the default).
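As a quick sanity check (an editor's illustrative snippet, not part of the repository), the forward pass preserves the input shape:
```python
import torch
from gtrxl_torch.gtrxl_torch import GTrXL

model = GTrXL(d_model=512, nheads=4, transformer_layers=1)
x = torch.randn(32, 16, 512)  # (sequence, batch, d_model)
assert model(x).shape == (32, 16, 512)
```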
### Saving Model
```python
model.save()
```
### Loading Model
```python
model.load()
```

## Parameters
- `d_model`: int.
The number of expected features in the encoder/decoder inputs.
- `nheads`: int.
The number of heads in the multi-head attention models.
- `transformer_layers`: int.
Number of Transformer blocks.
- `hidden_dims`: int.
Number of hidden neurons for the position-wise MLP.
- `n_layers`: int.
Number of layers in each gating GRU.
- `layer_norm_eps`: float, default `1e-5`.
The eps value in layer normalization components.
- `batch_first`: bool, default `False`.
If `True`, inputs are expected as (Batch, Sequence, Embedding).
- `chkpt_dir`: str, default `models`.
Directory where the model is saved.
- `activation`: str, default `relu`.
Activation function for the MLP.
- `network_name`: str, default `network.pt`.
Filename of the saved model.


## Resources
- *Alterations to the transformer model (**GTrXL**)* ➱ [Click Here](https://arxiv.org/abs/1910.06764)

--------------------------------------------------------------------------------
/gtrxl_torch/gtrxl_torch.py:
--------------------------------------------------------------------------------
import os
import math
from typing import Optional

import torch as T
import torch.nn as nn
from torch import Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
'''
Positional Encoding: takes a 3d tensor (sequence, batch, d_model) and returns a
tensor of the same shape.
Injects information about the position of each element in the sequence.
'''


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=1024):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Precompute the sinusoidal encoding table, shape (max_len, 1, d_model)
        pe = T.zeros(max_len, d_model)
        position = T.arange(0, max_len, dtype=T.float).unsqueeze(1)
        div_term = T.exp(
            T.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = T.sin(position * div_term)
        pe[:, 1::2] = T.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Add the encodings for the first x.size(0) positions,
        # broadcast over the batch dimension
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


'''
Recreates the transformer layer used in the following paper (GTrXL):
https://arxiv.org/pdf/1910.06764.pdf
'''


class TEL(TransformerEncoderLayer):
    def __init__(self,
                 d_model,
                 nhead,
                 n_layers=1,
                 dim_feedforward=256,
                 activation="relu",
                 dropout=0,
                 layer_norm_eps=1e-5,
                 batch_first=False):
        super().__init__(d_model, nhead, dim_feedforward, dropout, activation,
                         layer_norm_eps, batch_first)
        # Two GRUs are needed: one gating the self-attention block,
        # one gating the position-wise MLP at the end
        self.gru_1 = nn.GRU(d_model,
                            d_model,
                            num_layers=n_layers,
                            batch_first=True)
        self.gru_2 = nn.GRU(input_size=d_model,
                            hidden_size=d_model,
                            num_layers=n_layers,
                            batch_first=True)

    def forward(self,
                src: Tensor,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        # Initial GRU hidden state, summed over the batch dimension
        h = src.sum(dim=1).unsqueeze(dim=0)
        src = self.norm1(src)
        out = self.self_attn(src,
                             src,
                             src,
                             attn_mask=src_mask,
                             key_padding_mask=src_key_padding_mask)[0]

        # First gate: GRU in place of the residual connection around attention
        out, h = self.gru_1(out, h)
        out = self.norm2(out)
        # Position-wise MLP (linear1/linear2 come from TransformerEncoderLayer)
        out = self.activation(self.linear1(out))
        out = self.activation(self.linear2(out))
        # Second gate: GRU in place of the residual connection around the MLP
        out, h = self.gru_2(out, h)
        return out
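
# Editor's note (illustrative, not part of the library): the GTrXL paper swaps
# the encoder layer's residual connections for GRU-style gates. TEL realises
# this by reusing nn.GRU: with batch_first=True, the (S, B, E) activations are
# read as S independent length-B sequences, and the hidden state h carries the
# stream between the two gates. A minimal shape check:
#
#   layer = TEL(d_model=512, nhead=4)
#   x = T.randn(32, 16, 512)
#   assert layer(x).shape == (32, 16, 512)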

'''
Implementation of the transformer model using GRU gates
'''


class GTrXL(nn.Module):
    def __init__(self,
                 d_model,
                 nheads,
                 transformer_layers,
                 hidden_dims=256,
                 n_layers=1,
                 layer_norm_eps=1e-5,
                 batch_first=False,
                 chkpt_dir="models",
                 activation='relu',
                 network_name='network.pt'):
        super(GTrXL, self).__init__()
        # Module layers
        self.embed = PositionalEncoding(d_model)
        encoder_layer = TEL(d_model,
                            nheads,
                            n_layers,
                            dim_feedforward=hidden_dims,
                            activation=activation,
                            layer_norm_eps=layer_norm_eps,
                            batch_first=batch_first)
        self.transformer = TransformerEncoder(encoder_layer,
                                              transformer_layers)
        self.file = os.path.join(chkpt_dir, network_name)

    def forward(self, x):
        x = self.embed(x)
        x = self.transformer(x)
        return x

    def save(self):
        # Create the checkpoint directory if it does not exist yet
        os.makedirs(os.path.dirname(self.file), exist_ok=True)
        T.save(self.state_dict(), self.file)

    def load(self):
        self.load_state_dict(T.load(self.file))
--------------------------------------------------------------------------------
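A minimal end-to-end sketch (an editor's illustration, not a file from the repository) tying the pieces together: a forward pass followed by a save/load round-trip. It assumes the package was installed with `pip install gtrxl-torch`:

```python
import torch
from gtrxl_torch import GTrXL  # re-exported by gtrxl_torch/__init__.py

model = GTrXL(d_model=512, nheads=4, transformer_layers=2)
x = torch.randn(32, 16, 512)      # (sequence, batch, d_model)
out = model(x)                    # same shape as the input
print(out.shape)                  # torch.Size([32, 16, 512])

model.save()                      # writes models/network.pt by default

restored = GTrXL(d_model=512, nheads=4, transformer_layers=2)
restored.load()                   # loads models/network.pt back
```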