├── .gitignore ├── LICENSE ├── README.md └── t5_pytorch ├── __init__.py └── t5_pytorch.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Henry Shippole 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## T5 - PyTorch (WIP)
2 | A PyTorch implementation of [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683). You can find the official T5x repository by Google [here](https://github.com/google-research/t5x).
3 | 
4 | ### There is a small dimensionality bug that still needs to be resolved; PRs are welcome. I will not be able to get to it personally until later.
5 | 
6 | ## Acknowledgement
7 | 
8 | Phil Wang (lucidrains) advised on and reviewed this implementation. [Please be sure to follow and support his work](https://github.com/lucidrains?tab=repositories).
9 | 
10 | ## Usage
11 | 
12 | ```python
13 | import torch
14 | from t5_pytorch import T5
15 | 
16 | model = T5(
17 |     dim = 768,
18 |     enc_num_tokens = 512,
19 |     enc_depth = 6,
20 |     enc_heads = 12,
21 |     enc_dim_head = 64,
22 |     enc_mlp_mult = 4,
23 |     dec_num_tokens = 512,
24 |     dec_depth = 6,
25 |     dec_heads = 12,
26 |     dec_dim_head = 64,
27 |     dec_mlp_mult = 4,
28 |     dropout = 0.,
29 |     tie_token_emb = True
30 | )
31 | 
32 | src = torch.randint(0, 512, (1, 1024))
33 | src_mask = torch.ones_like(src).bool()
34 | tgt = torch.randint(0, 512, (1, 1024))
35 | 
36 | output = model(src, tgt, mask = src_mask)
37 | 
38 | print(output.shape) # torch.Size([1, 1024, 512])
39 | ```
40 | 
41 | ## Abstract
42 | 
43 | Transfer learning, where a model is first pre-trained on a data-rich task before being fine-tuned on a downstream task, has emerged as a powerful technique in natural language processing (NLP). The effectiveness of transfer learning has given rise to a diversity of approaches, methodology, and practice. In this paper, we explore the landscape of transfer learning techniques for NLP by introducing a unified framework that converts all text-based language problems into a text-to-text format. Our systematic study compares pre-training objectives, architectures, unlabeled data sets, transfer approaches, and other factors on dozens of language understanding tasks. By combining the insights from our exploration with scale and our new "Colossal Clean Crawled Corpus", we achieve state-of-the-art results on many benchmarks covering summarization, question answering, text classification, and more. To facilitate future work on transfer learning for NLP, we release our data set, pre-trained models, and code.
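## Training loss (sketch)

`T5.forward` returns raw per-token logits over the decoder vocabulary, so a training step is just cross-entropy between those logits and the target ids. The snippet below is a minimal sketch rather than part of this repository: the right-shift of the decoder input and the use of token id `0` as a start-of-sequence marker are assumptions you would adapt to your own tokenizer.

```python
import torch
import torch.nn.functional as F
from t5_pytorch import T5

model = T5(
    dim = 768,
    enc_num_tokens = 512,
    enc_depth = 6,
    enc_heads = 12,
    enc_dim_head = 64,
    enc_mlp_mult = 4,
    dec_num_tokens = 512,
    dec_depth = 6,
    dec_heads = 12,
    dec_dim_head = 64,
    dec_mlp_mult = 4,
    dropout = 0.,
    tie_token_emb = True
)

src = torch.randint(0, 512, (1, 1024))
src_mask = torch.ones_like(src).bool()
tgt = torch.randint(0, 512, (1, 1024))

# teacher forcing: feed the target shifted right by one position,
# with 0 standing in for a start-of-sequence id (an assumption, not defined by this repo)
dec_input = F.pad(tgt[:, :-1], (1, 0), value = 0)

logits = model(src, dec_input, mask = src_mask)  # (1, 1024, 512)

# cross_entropy expects (batch, classes, seq), so move the vocab dimension forward
loss = F.cross_entropy(logits.transpose(1, 2), tgt)
loss.backward()
```

Note that this keeps the batch size at 1 and the source and target lengths equal, matching the Usage example above; other shapes currently run into the dimensionality issue noted at the top of this README.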
44 | 45 | 46 | ## Citations 47 | 48 | ```bibtex 49 | @misc{https://doi.org/10.48550/arxiv.1910.10683, 50 | doi = {10.48550/ARXIV.1910.10683}, 51 | 52 | url = {https://arxiv.org/abs/1910.10683}, 53 | 54 | author = {Raffel, Colin and Shazeer, Noam and Roberts, Adam and Lee, Katherine and Narang, Sharan and Matena, Michael and Zhou, Yanqi and Li, Wei and Liu, Peter J.}, 55 | 56 | keywords = {Machine Learning (cs.LG), Computation and Language (cs.CL), Machine Learning (stat.ML), FOS: Computer and information sciences, FOS: Computer and information sciences}, 57 | 58 | title = {Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer}, 59 | 60 | publisher = {arXiv}, 61 | 62 | year = {2019}, 63 | 64 | copyright = {arXiv.org perpetual, non-exclusive license} 65 | } 66 | ``` 67 | -------------------------------------------------------------------------------- /t5_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from t5_pytorch.t5_pytorch import T5 -------------------------------------------------------------------------------- /t5_pytorch/t5_pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | import math 6 | 7 | from einops import rearrange 8 | 9 | def exists(val): 10 | return val is not None 11 | 12 | def default(val, d): 13 | return val if exists(val) else d 14 | 15 | # residual wrapper 16 | 17 | class Residual(nn.Module): 18 | def __init__(self, fn): 19 | super().__init__() 20 | self.fn = fn 21 | 22 | def forward(self, x, **kwargs): 23 | return self.fn(x, **kwargs) + x 24 | 25 | # pre-normalization wrapper 26 | # they use layernorm without bias 27 | 28 | class T5LayerNorm(nn.Module): 29 | def __init__(self, dim): 30 | super().__init__() 31 | self.gamma = nn.Parameter(torch.ones(dim)) 32 | self.register_buffer("beta", torch.zeros(dim)) 33 | 34 | def forward(self, x): 35 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) 36 | 37 | class PreNorm(nn.Module): 38 | def __init__(self, dim, fn): 39 | super().__init__() 40 | self.norm = T5LayerNorm(dim) 41 | self.fn = fn 42 | 43 | def forward(self, x, **kwargs): 44 | return self.fn(self.norm(x), **kwargs) 45 | 46 | # feedforward layer 47 | 48 | class FeedForward(nn.Module): 49 | def __init__(self, dim, mult = 4, dropout = 0.): 50 | super().__init__() 51 | inner_dim = int(dim * mult) 52 | self.net = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.ReLU(), 55 | nn.Dropout(dropout), # optional dropout 56 | nn.Linear(inner_dim, dim) 57 | ) 58 | 59 | def forward(self, x): 60 | return self.net(x) 61 | 62 | # T5 relative positional bias 63 | 64 | class T5RelativePositionBias(nn.Module): 65 | def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 12): 66 | super().__init__() 67 | self.scale = scale 68 | self.causal = causal 69 | self.num_buckets = num_buckets 70 | self.max_distance = max_distance 71 | self.relative_attention_bias = nn.Embedding(num_buckets, heads) 72 | 73 | @staticmethod 74 | def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128): 75 | ret = 0 76 | n = -relative_position 77 | if not causal: 78 | num_buckets //= 2 79 | ret += (n < 0).long() * num_buckets 80 | n = torch.abs(n) 81 | else: 82 | n = torch.max(n, torch.zeros_like(n)) 83 | 84 | max_exact = num_buckets // 2 85 | is_small = n < max_exact 86 | 87 | val_if_large = max_exact + ( 88 | 
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) 89 | ).long() 90 | val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) 91 | 92 | ret += torch.where(is_small, n, val_if_large) 93 | return ret 94 | 95 | def forward(self, qk_dots): 96 | i, j, device = *qk_dots.shape[-2:], qk_dots.device 97 | q_pos = torch.arange(j - i, j, dtype = torch.long, device = device) 98 | k_pos = torch.arange(j, dtype = torch.long, device = device) 99 | rel_pos = k_pos[None, :] - q_pos[:, None] 100 | rp_bucket = self._relative_position_bucket( 101 | rel_pos, 102 | causal = self.causal, 103 | num_buckets = self.num_buckets, 104 | max_distance = self.max_distance 105 | ) 106 | values = self.relative_attention_bias(rp_bucket) 107 | bias = rearrange(values, 'i j h -> h i j') 108 | return qk_dots + (bias * self.scale) 109 | 110 | # T5 Self Attention 111 | 112 | class T5SelfAttention(nn.Module): 113 | def __init__( 114 | self, 115 | *, 116 | dim, 117 | heads = 12, 118 | dim_head = 64, 119 | causal = False, 120 | dropout = 0. 121 | ): 122 | super().__init__() 123 | inner_dim = dim_head * heads 124 | self.heads = heads 125 | self.scale = dim_head ** -0.5 126 | self.causal = causal 127 | 128 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 129 | self.to_k = nn.Linear(dim, inner_dim, bias = False) 130 | self.to_v = nn.Linear(dim, inner_dim, bias = False) 131 | self.to_out = nn.Linear(inner_dim, dim) 132 | 133 | self.relative_position_bias = T5RelativePositionBias( 134 | scale = dim_head ** -0.5, 135 | causal = causal, 136 | heads = heads 137 | ) 138 | 139 | self.dropout = nn.Dropout(dropout) 140 | 141 | def forward(self, x, mask = None): 142 | b, n, _, h = *x.shape, self.heads 143 | q, k, v = self.to_q(x), self.to_k(x), self.to_v(x) 144 | 145 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 146 | 147 | q = q * self.scale 148 | 149 | sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) 150 | 151 | sim = self.relative_position_bias(sim) 152 | 153 | # mask 154 | 155 | mask_value = -torch.finfo(sim.dtype).max 156 | 157 | if mask is not None: 158 | sim = sim.masked_fill_(~mask, mask_value) 159 | 160 | if self.causal: 161 | i, j = sim.shape[-2:] 162 | causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1) 163 | sim = sim.masked_fill(causal_mask, mask_value) 164 | 165 | # attention 166 | 167 | attn = sim.softmax(dim = -1) 168 | attn = self.dropout(attn) 169 | 170 | # aggregate 171 | 172 | out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) 173 | 174 | # merge heads 175 | 176 | out = rearrange(out, 'b h n d -> b n (h d)') 177 | 178 | # combine heads and linear output 179 | 180 | return self.to_out(out) 181 | 182 | # T5 Cross Attention 183 | 184 | class T5CrossAttention(nn.Module): 185 | def __init__( 186 | self, 187 | *, 188 | dim, 189 | context_dim = None, 190 | heads = 12, 191 | dim_head = 64, 192 | dropout = 0. 
193 | ): 194 | super().__init__() 195 | inner_dim = dim_head * heads 196 | context_dim = default(context_dim, dim) 197 | 198 | self.heads = heads 199 | self.scale = dim_head ** -0.5 200 | 201 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 202 | self.to_k = nn.Linear(context_dim, inner_dim, bias = False) 203 | self.to_v = nn.Linear(context_dim, inner_dim, bias = False) 204 | self.to_out = nn.Linear(inner_dim, dim) 205 | 206 | # self.relative_position_bias = T5RelativePositionBias( 207 | # scale = dim_head ** -0.5, 208 | # causal = False, 209 | # heads = heads 210 | # ) 211 | 212 | self.dropout = nn.Dropout(dropout) 213 | 214 | def forward(self, x, context, mask = None, context_mask = None): 215 | b, n, _, h = *x.shape, self.heads 216 | 217 | kv_input = default(context, x) 218 | 219 | q, k, v = self.to_q(x), self.to_k(kv_input), self.to_v(kv_input) 220 | 221 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 222 | 223 | q = q * self.scale 224 | 225 | sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) 226 | 227 | #sim = self.relative_position_bias(sim) 228 | 229 | # mask 230 | 231 | mask_value = -torch.finfo(sim.dtype).max 232 | 233 | if mask is not None: 234 | sim = sim.masked_fill_(~mask, mask_value) 235 | 236 | if context_mask is not None: 237 | sim = sim.masked_fill_(~context_mask[:, None, :], mask_value) 238 | 239 | # attention 240 | 241 | attn = sim.softmax(dim = -1) 242 | attn = self.dropout(attn) 243 | 244 | # aggregate 245 | 246 | out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) 247 | 248 | # merge heads 249 | 250 | out = rearrange(out, 'b h n d -> b n (h d)') 251 | 252 | # combine heads and linear output 253 | 254 | return self.to_out(out) 255 | 256 | # T5 Encoder 257 | 258 | class T5Encoder(nn.Module): 259 | def __init__( 260 | self, 261 | *, 262 | dim, 263 | num_tokens, 264 | #max_seq_len, 265 | depth, 266 | heads = 12, 267 | dim_head = 64, 268 | causal = False, 269 | mlp_mult = 4, 270 | dropout = 0. 271 | ): 272 | super().__init__() 273 | self.token_emb = nn.Embedding(num_tokens, dim) 274 | #self.pos_emb = nn.Embedding(max_seq_len, dim) 275 | 276 | self.layer = nn.ModuleList([]) 277 | for _ in range(depth): 278 | self.layer.append(nn.ModuleList([ 279 | Residual(PreNorm(dim, T5SelfAttention(dim = dim, heads = heads, dim_head = dim_head, causal = causal, dropout = dropout))), 280 | Residual(PreNorm(dim, FeedForward(dim = dim, mult = mlp_mult, dropout = dropout))), 281 | ])) 282 | 283 | self.final_norm = T5LayerNorm(dim) 284 | 285 | def forward(self, x, mask = None): 286 | x = self.token_emb(x) 287 | #x = x + self.pos_emb(torch.arange(x.shape[1], device = x.device)) 288 | 289 | for attn, mlp in self.layer: 290 | x = attn(x, mask = mask) 291 | x = mlp(x) 292 | 293 | x = self.final_norm(x) 294 | 295 | return x 296 | 297 | # T5 Decoder 298 | 299 | class T5Decoder(nn.Module): 300 | def __init__( 301 | self, 302 | *, 303 | dim, 304 | num_tokens, 305 | #max_seq_len, 306 | depth, 307 | heads = 12, 308 | dim_head = 64, 309 | causal = True, 310 | mlp_mult = 4, 311 | dropout = 0. 
312 |     ):
313 |         super().__init__()
314 |         self.token_emb = nn.Embedding(num_tokens, dim)
315 |         #self.pos_emb = nn.Embedding(max_seq_len, dim)
316 | 
317 |         self.layer = nn.ModuleList([])
318 |         for _ in range(depth):
319 |             self.layer.append(nn.ModuleList([
320 |                 Residual(PreNorm(dim, T5SelfAttention(dim = dim, heads = heads, dim_head = dim_head, causal = causal, dropout = dropout))),
321 |                 Residual(PreNorm(dim, T5CrossAttention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout))),
322 |                 Residual(PreNorm(dim, FeedForward(dim = dim, mult = mlp_mult, dropout = dropout))),
323 |             ]))
324 | 
325 |         self.final_norm = T5LayerNorm(dim)
326 | 
327 |     def forward(self, x, context, mask = None, context_mask = None):
328 |         x = self.token_emb(x)
329 |         #x = x + self.pos_emb(torch.arange(x.shape[1], device = x.device))
330 | 
331 |         for attn, cross_attn, mlp in self.layer:
332 |             x = attn(x, mask = mask)
333 |             x = cross_attn(x, context = context, mask = mask, context_mask = context_mask)
334 |             x = mlp(x)
335 | 
336 |         x = self.final_norm(x)
337 | 
338 |         return x
339 | 
340 | # T5
341 | 
342 | class T5(nn.Module):
343 |     def __init__(
344 |         self,
345 |         *,
346 |         dim,
347 |         #max_seq_len,
348 |         enc_num_tokens,
349 |         enc_depth,
350 |         enc_heads,
351 |         enc_dim_head,
352 |         enc_mlp_mult,
353 |         dec_num_tokens,
354 |         dec_depth,
355 |         dec_heads,
356 |         dec_dim_head,
357 |         dec_mlp_mult,
358 |         dropout = 0.,
359 |         tie_token_emb = True
360 |     ):
361 |         super().__init__()
362 | 
363 |         #self.embedding = nn.Embedding(enc_num_tokens, dim) # unused: the encoder and decoder own their token embeddings
364 |         #self.pos_emb = nn.Embedding(max_seq_len, dim)
365 | 
366 |         self.encoder = T5Encoder(
367 |             dim = dim,
368 |             #max_seq_len = max_seq_len,
369 |             num_tokens = enc_num_tokens,
370 |             depth = enc_depth,
371 |             heads = enc_heads,
372 |             dim_head = enc_dim_head,
373 |             mlp_mult = enc_mlp_mult,
374 |             dropout = dropout
375 |         )
376 | 
377 |         self.decoder = T5Decoder(
378 |             dim = dim,
379 |             #max_seq_len = max_seq_len,
380 |             num_tokens = dec_num_tokens,
381 |             depth = dec_depth,
382 |             heads = dec_heads,
383 |             dim_head = dec_dim_head,
384 |             mlp_mult = dec_mlp_mult,
385 |             dropout = dropout
386 |         )
387 | 
388 |         self.to_logits = nn.Linear(dim, dec_num_tokens)
389 | 
390 |         # tie weights
391 |         if tie_token_emb:
392 |             self.encoder.token_emb.weight = self.decoder.token_emb.weight
393 | 
394 |     def forward(self, src, tgt, mask = None, context_mask = None):
395 |         # token embedding happens inside the encoder and decoder, so raw token ids are passed straight through
396 |         #x = x + self.pos_emb(torch.arange(x.shape[1], device = x.device))
397 |         x = self.encoder(src, mask = mask)
398 |         x = self.decoder(tgt, x, mask = mask, context_mask = context_mask)
399 |         x = self.to_logits(x)
400 |         return x
401 | 
402 | 
403 | # if __name__ == '__main__':
404 | 
405 | #     model = T5(
406 | #         dim = 768,
407 | #         #max_seq_len = 1024,
408 | #         enc_num_tokens = 512,
409 | #         enc_depth = 6,
410 | #         enc_heads = 12,
411 | #         enc_dim_head = 64,
412 | #         enc_mlp_mult = 4,
413 | #         dec_num_tokens = 512,
414 | #         dec_depth = 6,
415 | #         dec_heads = 12,
416 | #         dec_dim_head = 64,
417 | #         dec_mlp_mult = 4,
418 | #         dropout = 0.,
419 | #         tie_token_emb = True
420 | #     )
421 | 
422 | #     src = torch.randint(0, 512, (1, 1024))
423 | #     src_mask = torch.ones_like(src).bool()
424 | #     tgt = torch.randint(0, 512, (1, 1024))
425 | 
426 | #     logits = model(src, tgt, mask = src_mask)
427 | 
428 | #     print(logits.shape) # torch.Size([1, 1024, 512])
--------------------------------------------------------------------------------
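A minimal greedy-decoding sketch built on the `T5` module above, in the spirit of the commented-out demo. It is an illustration under stated assumptions rather than part of the repository: token id `0` is treated as the start-of-sequence symbol, decoding runs for a fixed number of steps, and the source mask is omitted because passing it alongside a growing target currently trips the dimensionality issue noted in the README.

```python
import torch
from t5_pytorch import T5

@torch.no_grad()
def greedy_decode(model, src, steps = 16, start_id = 0):
    # start_id is assumed to act as a start-of-sequence token; the repo defines no special tokens
    model.eval()
    out = torch.full((src.shape[0], 1), start_id, dtype = torch.long, device = src.device)
    for _ in range(steps):
        logits = model(src, out)  # re-encodes src every step; fine for a short sketch
        next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
        out = torch.cat((out, next_token), dim = -1)
    return out[:, 1:]  # drop the start token

model = T5(
    dim = 768,
    enc_num_tokens = 512,
    enc_depth = 6,
    enc_heads = 12,
    enc_dim_head = 64,
    enc_mlp_mult = 4,
    dec_num_tokens = 512,
    dec_depth = 6,
    dec_heads = 12,
    dec_dim_head = 64,
    dec_mlp_mult = 4,
    dropout = 0.,
    tie_token_emb = True
)

src = torch.randint(0, 512, (1, 1024))
print(greedy_decode(model, src, steps = 8).shape)  # torch.Size([1, 8])
```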