├── .gitignore ├── LICENSE ├── README.md ├── pyproject.toml ├── requirements.txt ├── setup.cfg ├── src └── pytorch_beam_search │ ├── __init__.py │ ├── autoregressive │ ├── __init__.py │ ├── index.py │ ├── models.py │ └── search_algorithms.py │ └── seq2seq │ ├── __init__.py │ ├── index.py │ ├── models.py │ └── search_algorithms.py └── tests ├── gpt.py └── transformer.py /.gitignore: -------------------------------------------------------------------------------- 1 | *ipynb_checkpoints* 2 | *pycache* 3 | *pt 4 | *arch 5 | *pkl 6 | dist* 7 | *egg* 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Juan Antonio Ramirez-Orta 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Beam Search 2 | 3 | This library implements fully vectorized Beam Search, Greedy Search and Sampling for sequence models written in PyTorch. 4 | This is especially useful for tasks in Natural Language Processing, but can also be used for anything that requires 5 | generating a sequence from a sequence model.
6 | 7 | ## Usage 8 | 9 | ### A GPT-like character-level language model 10 | 11 | ```python 12 | from pytorch_beam_search import autoregressive 13 | 14 | # Create vocabulary and examples 15 | # Tokenize the way you need 16 | 17 | corpus = list("abcdefghijklmnopqrstwxyz ") 18 | # len(corpus) == 25 19 | # An Index object represents a mapping from the vocabulary 20 | # to integers (indices) to feed into the models 21 | index = autoregressive.Index(corpus) 22 | n_gram_size = 17 # 16 with an offset of 1 23 | n_grams = [corpus[i:n_gram_size + i] for i in range(len(corpus))[:-n_gram_size + 1]] 24 | 25 | # Create tensor 26 | 27 | X = index.text2tensor(n_grams) 28 | # X.shape == (n_examples, len_examples) == (25 - 17 + 1 = 9, 17) 29 | 30 | # Create and train the model 31 | 32 | model = autoregressive.TransformerEncoder(index) # just a PyTorch model 33 | model.fit(X) # basic method included 34 | 35 | # Generate new predictions 36 | 37 | new_examples = [list("new first"), list("new second")] 38 | X_new = index.text2tensor(new_examples) 39 | loss, error_rate = model.evaluate(X_new) # basic method included 40 | predictions, log_probabilities = autoregressive.beam_search(model, X_new) 41 | # every element in predictions is the list of candidates for each example 42 | output = [index.tensor2text(p) for p in predictions] 43 | output 44 | ``` 45 | 46 | ### A Transformer character sequence-to-sequence model 47 | 48 | ```python 49 | from pytorch_beam_search import seq2seq 50 | 51 | # Create vocabularies 52 | # Tokenize the way you need 53 | 54 | source = [list("abcdefghijkl"), list("mnopqrstwxyz")] 55 | target = [list("ABCDEFGHIJKL"), list("MNOPQRSTWXYZ")] 56 | # An Index object represents a mapping from the vocabulary 57 | # to integers (indices) to feed into the models 58 | source_index = seq2seq.Index(source) 59 | target_index = seq2seq.Index(target) 60 | 61 | # Create tensors 62 | 63 | X = source_index.text2tensor(source) 64 | Y = target_index.text2tensor(target) 65 | # X.shape == (n_source_examples, len_source_examples) == (2, 14) 66 | # Y.shape == (n_target_examples, len_target_examples) == (2, 14), the start and end tokens add 2 to the length 67 | 68 | # Create and train the model 69 | 70 | model = seq2seq.Transformer(source_index, target_index) # just a PyTorch model 71 | model.fit(X, Y, epochs=100) # basic method included 72 | 73 | # Generate new predictions 74 | 75 | new_source = [list("new first in"), list("new second in")] 76 | new_target = [list("new first out"), list("new second out")] 77 | X_new = source_index.text2tensor(new_source) 78 | Y_new = target_index.text2tensor(new_target) 79 | loss, error_rate = model.evaluate(X_new, Y_new) # basic method included 80 | predictions, log_probabilities = seq2seq.beam_search(model, X_new) 81 | output = [target_index.tensor2text(p) for p in predictions] 82 | output 83 | ``` 84 | 85 | ## Features 86 | 87 | ### Algorithms 88 | 89 | - The **greedy_search** function implements Greedy Search, which simply picks the most likely token at every step. This 90 | is the fastest and simplest algorithm, but it can still work well if the model is properly trained. 91 | - The **sample** function implements sampling from a sequence model, using the learned distribution at every step to 92 | build the output token by token. This is very useful for inspecting what the model has learned (see the sketch after this list). 93 | - The **beam_search** function implements Beam Search, a form of pruned Breadth-First Search that expands a fixed number 94 | of the best candidates at every step. This is the slowest algorithm, but usually outperforms Greedy Search.
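The Usage section above already shows **beam_search** end to end. The following is a minimal sketch of how the other two algorithms are called, reusing the trained autoregressive `model`, `index` and `X_new` from the GPT-like example; the `predictions` and `temperature` values are arbitrary choices, not defaults.

```python
from pytorch_beam_search import autoregressive

# extend each sequence in X_new by 10 tokens, picking the most likely token at every step
sequences, log_probabilities = autoregressive.greedy_search(model, X_new, predictions=10)
# sample instead of taking the argmax; temperature > 1 flattens the learned distribution
samples, log_probabilities = autoregressive.sample(model, X_new, predictions=10, temperature=1.5)
# both return one extended sequence per example, so tensor2text can be applied directly
output = index.tensor2text(sequences)
```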
95 | 96 | ### Models 97 | 98 | - The **autoregressive** module implements the search algorithms and some architectures for unsupervised models that 99 | learn to predict the next token in a sequence. 100 | - **LSTM** is a simple baseline/sanity check. 101 | - **TransformerEncoder** is 102 | a [GPT](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf) 103 | -like model for state-of-the-art performance. 104 | - The **seq2seq** module implements the search algorithms and some architectures for supervised encoder-decoder models 105 | that learn how to map sequences to sequences. 106 | - **LSTM** is a sequence-to-sequence unidirectional LSTM model similar to the one 107 | in [Cho et al., 2014](https://arxiv.org/pdf/1406.1078.pdf), useful as a simple baseline/sanity check. 108 | - **ReversingLSTM** is a sequence-to-sequence unidirectional LSTM model that reverses the order of the tokens in the 109 | input, similar to the one in [Sutskever et al., 2014](https://arxiv.org/pdf/1409.3215.pdf). A bit more complex 110 | than LSTM but gives better performance. 111 | - **Transformer** is a standard [Transformer](https://arxiv.org/pdf/1706.03762.pdf) model for state-of-the-art 112 | performance. 113 | 114 | ## Installation 115 | 116 | ```shell 117 | pip install pytorch_beam_search 118 | ``` 119 | 120 | ## Contribute 121 | 122 | - [Issue Tracker](https://github.com/jarobyte91/pytorch_beam_search/issues) 123 | - [Pull Requests](https://github.com/jarobyte91/pytorch_beam_search/pulls) 124 | 125 | ## License 126 | 127 | The project is licensed under the MIT License. 128 | 129 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "wheel" 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2020.12.5 2 | click==8.0.3 3 | joblib==1.1.0 4 | nltk==3.6.5 5 | numpy==1.20.3 6 | pandas==1.2.4 7 | python-dateutil==2.8.1 8 | pytz==2021.1 9 | regex==2021.10.23 10 | six==1.16.0 11 | torch==1.8.1 12 | tqdm==4.61.0 13 | typing-extensions==3.10.0.0 14 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = pytorch_beam_search 3 | version = 1.2.2 4 | author = Juan Ramirez-Orta 5 | author_email = jarobyte91@gmail.com 6 | description = A simple library that implements search algorithms for sequence models written in PyTorch. 
7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | url = https://github.com/jarobyte91/pytorch_beam_search 10 | project_urls = 11 | Bug Tracker = https://github.com/jarobyte91/pytorch_beam_search/issues 12 | classifiers = 13 | Programming Language :: Python :: 3 14 | License :: OSI Approved :: MIT License 15 | Operating System :: OS Independent 16 | 17 | [options] 18 | package_dir = 19 | = src 20 | packages = find: 21 | 22 | install_requires = 23 | certifi>=2020.12.5 24 | numpy>=1.20.3 25 | pandas>=1.2.4 26 | python-dateutil>=2.8.1 27 | pytz>=2021.1 28 | six>=1.16.0 29 | torch>=1.8.1 30 | tqdm>=4.61.0 31 | typing-extensions>=3.10.0.0 32 | nltk>=3.6.5 33 | 34 | python_requires = >=3.6 35 | 36 | [options.packages.find] 37 | where = src 38 | -------------------------------------------------------------------------------- /src/pytorch_beam_search/__init__.py: -------------------------------------------------------------------------------- 1 | from . import autoregressive 2 | from . import seq2seq 3 | -------------------------------------------------------------------------------- /src/pytorch_beam_search/autoregressive/__init__.py: -------------------------------------------------------------------------------- 1 | from .index import * 2 | from .models import * 3 | from .search_algorithms import * -------------------------------------------------------------------------------- /src/pytorch_beam_search/autoregressive/index.py: -------------------------------------------------------------------------------- 1 | from tqdm.auto import tqdm 2 | import torch 3 | import torch.nn as nn 4 | from nltk import lm 5 | import re 6 | 7 | class Index(): 8 | def __init__(self, corpus, progress_bar = False): 9 | self.special_tokens = [""] 10 | if progress_bar: 11 | corpus = tqdm(corpus) 12 | self.vocabulary = lm.Vocabulary(corpus) 13 | tokens = self.special_tokens + list(sorted(self.vocabulary)) 14 | self.voc2idx = {c:i for i, c in enumerate(tokens)} 15 | self.idx2voc = {i:c for i, c in enumerate(tokens)} 16 | 17 | def __len__(self): 18 | return len(self.voc2idx) 19 | 20 | def __str__(self): 21 | return f"" 22 | 23 | def text2tensor( 24 | self, 25 | strings, 26 | progress_bar = False 27 | ): 28 | if progress_bar: 29 | iterator = tqdm(strings) 30 | else: 31 | iterator = strings 32 | m = max([len(s) for s in strings]) 33 | idx = [] 34 | for l in iterator: 35 | idx.append( 36 | [0 for i in range(m - len(l))] +\ 37 | [self.voc2idx[self.vocabulary.lookup(c)] for c in l] 38 | ) 39 | return torch.tensor(idx) 40 | 41 | def tensor2text( 42 | self, 43 | X, 44 | separator = "", 45 | end = "", 46 | progress_bar = False 47 | ): 48 | X = X.tolist() 49 | if progress_bar: 50 | X = tqdm(X) 51 | return [separator.join([self.idx2voc[i] for i in l]) for l in X] 52 | 53 | -------------------------------------------------------------------------------- /src/pytorch_beam_search/autoregressive/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.data as tud 4 | from timeit import default_timer as timer 5 | from tqdm.auto import tqdm 6 | import pandas as pd 7 | import warnings 8 | 9 | class Autoregressive(nn.Module): 10 | def print_architecture(self): 11 | """ 12 | Displays the information about the model in standard output. 
13 | """ 14 | for k in self.architecture.keys(): 15 | print(f"{k.replace('_', ' ').capitalize()}: {self.architecture[k]}") 16 | print(f"Trainable parameters: {sum([p.numel() for p in self.parameters()]):,}") 17 | print() 18 | 19 | def fit( 20 | self, 21 | X_train, 22 | X_dev = None, 23 | batch_size = 128, 24 | epochs = 5, 25 | learning_rate = 10**-4, 26 | progress_bar = 0, 27 | weight_decay = 0, 28 | save_path = None 29 | ): 30 | """ 31 | A generic training method with Adam and Cross Entropy. 32 | 33 | Parameters 34 | ---------- 35 | X_train: LongTensor of shape (train_examples, train_input_length) 36 | The input sequences of the training set. 37 | 38 | X_dev: LongTensor of shape (dev_examples, dev_input_length), optional 39 | The input sequences for the development set. 40 | 41 | batch_size: int 42 | The number of examples to process in each batch. 43 | 44 | epochs: int 45 | The number of epochs of the training process. 46 | 47 | learning_rate: float 48 | The learning rate to use with Adam in the training process. 49 | 50 | weight_decay: float 51 | The weight_decay parameter of Adam (L2 penalty), useful for regularizing models. For a deeper 52 | documentation, go to https://pytorch.org/docs/stable/_modules/torch/optim/adam.html#Adam 53 | 54 | progress_bar: int 55 | Shows a tqdm progress bar, useful for tracking progress with large tensors. 56 | If equal to 0, no progress bar is shown. 57 | If equal to 1, shows a bar with one step for every epoch. 58 | If equal to 2, shows the bar when equal to 1 and also shows a bar with one step per batch for every epoch. 59 | If equal to 3, shows the bars when equal to 2 and also shows a bar to track the progress of the evaluation 60 | in the development set. 61 | 62 | save_path: string, optional 63 | Path to save the .pt file containing the model parameters when the training ends. 64 | 65 | Returns 66 | ------- 67 | performance: Pandas DataFrame 68 | DataFrame with the following columns: epoch, train_loss, train_error_rate, (optionally dev_loss and 69 | dev_error_rate), minutes, learning_rate, weight_decay, model, encoder_embedding_dimension, 70 | decoder_embedding_dimension, encoder_hidden_units, encoder_layers, decoder_hidden_units, decoder_layers, 71 | dropout, parameters and one row for each of the epochs, containing information about the training process. 
72 | """ 73 | assert X_train.shape[1] > 1 74 | if X_dev is not None: 75 | dev = True 76 | else: 77 | dev = False 78 | train_dataset = tud.TensorDataset(X_train) 79 | train_loader = tud.DataLoader( 80 | train_dataset, 81 | batch_size = batch_size, 82 | shuffle = True 83 | ) 84 | criterion = nn.CrossEntropyLoss(ignore_index = 0) 85 | optimizer = torch.optim.Adam( 86 | self.parameters(), 87 | lr = learning_rate, 88 | weight_decay = weight_decay 89 | ) 90 | performance = [] 91 | start = timer() 92 | epochs_iterator = range(1, epochs + 1) 93 | if progress_bar > 0: 94 | epochs_iterator = tqdm(epochs_iterator) 95 | print("Training started") 96 | print("X_train.shape:", X_train.shape) 97 | if dev: 98 | print("X_dev.shape:", X_dev.shape) 99 | print(f"Epochs: {epochs:,}\nLearning rate: {learning_rate}\nWeight decay: {weight_decay}") 100 | header_1 = "Epoch | Train " 101 | header_2 = " | Loss | Error Rate" 102 | rule = "-" * 29 103 | if dev: 104 | header_1 += " | Development " 105 | header_2 += " | Loss | Error Rate" 106 | rule += "-" * 24 107 | header_1 += " | Minutes" 108 | header_2 += " |" 109 | rule += "-" * 10 110 | print(header_1, header_2, rule, sep = "\n") 111 | for e in epochs_iterator: 112 | self.train() 113 | losses = [] 114 | errors = [] 115 | sizes = [] 116 | train_iterator = train_loader 117 | if progress_bar > 1: 118 | train_iterator = tqdm(train_iterator) 119 | for (x, ) in train_iterator: 120 | # compute loss and backpropagate 121 | probabilities = self.forward(x[:, :-1]) 122 | y = x[:, 1:] 123 | loss = criterion(probabilities.transpose(1, 2), y) 124 | loss.backward() 125 | optimizer.step() 126 | optimizer.zero_grad() 127 | # compute accuracy 128 | predictions = probabilities.argmax(-1) 129 | batch_errors = (predictions != y) 130 | # append the results 131 | losses.append(loss.item()) 132 | errors.append(batch_errors.sum().item()) 133 | sizes.append(batch_errors.numel()) 134 | train_loss = sum(losses) / len(losses) 135 | train_error_rate = 100 * sum(errors) / sum(sizes) 136 | t = (timer() - start) / 60 137 | status_string = f"{e:>5} | {train_loss:>8.4f} | {train_error_rate:>10.3f}" 138 | status = {"epoch":e, 139 | "train_loss": train_loss, 140 | "train_error_rate": train_error_rate} 141 | if dev: 142 | dev_loss, dev_error_rate = self.evaluate( 143 | X_dev, 144 | batch_size = batch_size, 145 | progress_bar = progress_bar > 2, 146 | criterion = criterion 147 | ) 148 | status_string += f" | {dev_loss:>8.4f} | {dev_error_rate:>10.3f}" 149 | status.update( 150 | { 151 | "dev_loss": dev_loss, 152 | "dev_error_rate": dev_error_rate 153 | } 154 | ) 155 | status.update({"training_minutes": t, 156 | "learning_rate": learning_rate, 157 | "weight_decay": weight_decay}) 158 | performance.append(status) 159 | if save_path is not None: 160 | if (not dev) or (e < 2) or (dev_loss < min([p["dev_loss"]\ 161 | for p in performance[:-1]])): 162 | torch.save(self.state_dict(), save_path) 163 | status_string += f" | {t:>7.1f}" 164 | print(status_string) 165 | return pd.concat( 166 | ( 167 | pd.DataFrame(performance), 168 | pd.DataFrame([self.architecture for i in performance]) 169 | ), 170 | axis = 1 171 | ).drop(columns = "index") 172 | 173 | 174 | def evaluate( 175 | self, 176 | X, 177 | criterion = nn.CrossEntropyLoss(), 178 | batch_size = 128, 179 | progress_bar = False 180 | ): 181 | """ 182 | Evaluates the model on a dataset. 183 | 184 | Parameters 185 | ---------- 186 | X: LongTensor of shape (examples, input_length) 187 | The input sequences of the dataset. 
188 | 189 | Y: LongTensor of shape (examples, output_length) 190 | The output sequences of the dataset. 191 | 192 | criterion: PyTorch module 193 | The loss function to evalue the model on the dataset, has to be able to compare self.forward(X, Y) and Y 194 | to produce a real number. 195 | 196 | batch_size: int 197 | The batch size of the evaluation loop. 198 | 199 | progress_bar: bool 200 | Shows a tqdm progress bar, useful for tracking progress with large tensors. 201 | 202 | Returns 203 | ------- 204 | loss: float 205 | The average of criterion across the whole dataset. 206 | 207 | error_rate: float 208 | The step-by-step accuracy of the model across the whole dataset. Useful as a sanity check, as it should 209 | go to zero as the loss goes to zero. 210 | 211 | """ 212 | dataset = tud.TensorDataset(X) 213 | loader = tud.DataLoader(dataset, batch_size = batch_size) 214 | self.eval() 215 | losses = [] 216 | errors = [] 217 | sizes = [] 218 | with torch.no_grad(): 219 | iterator = iter(loader) 220 | if progress_bar: 221 | iterator = tqdm(iterator) 222 | for (x,) in iterator: 223 | # compute loss 224 | probabilities = self.forward(x[:, :-1]) 225 | y = x[:, 1:] 226 | loss = criterion(probabilities.transpose(1, 2), y) 227 | # compute accuracy 228 | predictions = probabilities.argmax(-1) 229 | batch_errors = (predictions != y) 230 | # append the results 231 | losses.append(loss.item()) 232 | errors.append(batch_errors.sum().item()) 233 | sizes.append(batch_errors.numel()) 234 | loss = sum(losses) / len(losses) 235 | error_rate = 100 * sum(errors) / sum(sizes) 236 | return loss, error_rate 237 | 238 | 239 | class LSTM(Autoregressive): 240 | def __init__(self, 241 | index, 242 | embedding_dimension = 32, 243 | hidden_units = 128, 244 | layers = 2, 245 | dropout = 0.0): 246 | """ 247 | A standard autoregressive model with an LSTM network. 248 | 249 | Parameters 250 | ---------- 251 | vocabulary: set-like 252 | Set-like containing the tokens of the model. 253 | 254 | embedding_dimension: int 255 | Dimension of the embeddings to feed into the model. 256 | 257 | hidden_units: int 258 | Hidden units of the model. 259 | 260 | layers: int 261 | Hidden layers of the model. 262 | 263 | dropout: float between 0.0 and 1.0 264 | Dropout rate to apply to whole model. 265 | """ 266 | super().__init__() 267 | self.index = index 268 | self.embeddings = nn.Embedding( 269 | len(index.voc2idx), 270 | embedding_dimension 271 | ) 272 | self.rnn = nn.LSTM( 273 | input_size = embedding_dimension, 274 | hidden_size = hidden_units, 275 | num_layers = layers, 276 | dropout = dropout 277 | ) 278 | self.output_layer = nn.Linear(hidden_units, len(index.idx2voc)) 279 | self.architecture = dict( 280 | model = "Autoregressive LSTM", 281 | index = index, 282 | embedding_dimension = embedding_dimension, 283 | hidden_units = hidden_units, 284 | layers = layers, 285 | dropout = dropout 286 | ) 287 | self.print_architecture() 288 | 289 | def forward(self, X): 290 | """ 291 | Forward method of the model. 292 | 293 | Parameters 294 | ---------- 295 | X: LongTensor of shape (batch_size, sequence_length) 296 | Tensor of integers containing the inputs for the model. 297 | 298 | Returns 299 | ------- 300 | output: FloatTensor of shape (batch_size, sequence_length, len(out_vocabulary)) 301 | Tensor of floats containing the inputs for the final Softmax layer (usually integrated into the loss function). 
302 | """ 303 | X = self.embeddings(X.T) 304 | rnn, (rnn_last_hidden, rnn_last_memory) = self.rnn(X) 305 | return self.output_layer(rnn.transpose(0, 1)) 306 | 307 | 308 | class TransformerEncoder(Autoregressive): 309 | def __init__( 310 | self, 311 | index, 312 | max_sequence_length = 16, 313 | embedding_dimension = 32, 314 | feedforward_dimension = 128, 315 | layers = 2, 316 | attention_heads = 2, 317 | activation = "relu", 318 | dropout = 0.0 319 | ): 320 | """ 321 | The standard PyTorch implementation of a Transformer Encoder. 322 | 323 | Parameters 324 | ---------- 325 | vocabulary: set-like 326 | Set-like containing the tokens of the model. 327 | 328 | max_sequence_length: int 329 | Maximum sequence length accepted by the model. 330 | 331 | embedding_dimension: int 332 | Dimension of the embeddings of the model. 333 | 334 | feedforward_dimension: int 335 | Dimension of the feedforward network inside the self-attention layers of the model. 336 | 337 | layers: int 338 | Hidden layers of the encoder. 339 | 340 | attention_heads: int 341 | Attention heads inside every self-attention layer of the model. 342 | 343 | activation: string 344 | Activation function of the feedforward network inside the self-attention layers of the model. Can 345 | be either 'relu' or 'gelu'. 346 | 347 | dropout: float between 0.0 and 1.0 348 | Dropout rate to apply to whole model. 349 | """ 350 | super().__init__() 351 | self.index = index 352 | self.embeddings = nn.Embedding( 353 | len(index.idx2voc), 354 | embedding_dimension 355 | ) 356 | self.positional_embeddings = nn.Embedding( 357 | max_sequence_length, 358 | embedding_dimension 359 | ) 360 | self.transformer_layer = nn.TransformerEncoderLayer( 361 | d_model = embedding_dimension, 362 | dim_feedforward = feedforward_dimension, 363 | nhead = attention_heads, 364 | activation = activation, 365 | dropout = dropout 366 | ) 367 | self.encoder = nn.TransformerEncoder( 368 | encoder_layer = self.transformer_layer, 369 | num_layers = layers 370 | ) 371 | self.output_layer = nn.Linear( 372 | embedding_dimension, 373 | len(index.idx2voc) 374 | ) 375 | self.architecture = dict( 376 | model = "Autoregressive Transformer Encoder", 377 | index = index, 378 | max_sequence_length = max_sequence_length, 379 | embedding_dimension = embedding_dimension, 380 | feedforward_dimension = feedforward_dimension, 381 | layers = layers, 382 | attention_heads = attention_heads, 383 | activation = activation, 384 | dropout = dropout 385 | ) 386 | self.print_architecture() 387 | 388 | def forward(self, X, warn_last_tokens = True): 389 | """ 390 | Forward method of the model. 391 | 392 | Parameters 393 | ---------- 394 | X: LongTensor of shape (batch_size, sequence_length) 395 | Tensor of integers containing the inputs for the model. 396 | 397 | Returns 398 | ------- 399 | output: FloatTensor of shape (batch_size, sequence_length, len(out_vocabulary)) 400 | Tensor of floats containing the inputs for the final Softmax layer (usually integrated into the loss function). 401 | """ 402 | if warn_last_tokens and X.shape[1] > self.architecture["max_sequence_length"]: 403 | warnings.warn(f"Max sequence length exceded, only using the last {self.architecture['max_sequence_length']} tokens of the input. 
You can disable this warning with the warn_last_tokens parameter of the forward method.", category = RuntimeWarning) 404 | X = X[:, -self.architecture["max_sequence_length"]:] 405 | X = self.embeddings(X) 406 | X_positional = torch.arange(X.shape[1], device = X.device)\ 407 | .repeat((X.shape[0], 1)) 408 | X_positional = self.positional_embeddings(X_positional) 409 | X = (X + X_positional).transpose(0, 1) 410 | mask = (torch.triu(torch.ones(X.shape[0], X.shape[0])) == 1)\ 411 | .transpose(0, 1).to(X.device) 412 | mask = mask.float().masked_fill(mask == 0, float('-inf'))\ 413 | .masked_fill(mask == 1, float(0.0)) 414 | output = self.encoder.forward(src = X, mask = mask).transpose(0, 1) 415 | return self.output_layer(output) 416 | -------------------------------------------------------------------------------- /src/pytorch_beam_search/autoregressive/search_algorithms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as tud 3 | from tqdm.auto import tqdm 4 | import warnings 5 | 6 | def greedy_search( 7 | model, 8 | X, 9 | predictions = 20, 10 | progress_bar = False 11 | ): 12 | """ 13 | Implements Greedy Search to extend the sequences given in X. The method can compute 14 | several outputs in parallel with the first dimension of X. 15 | 16 | Parameters 17 | ---------- 18 | X: LongTensor of shape (examples, length) 19 | The sequences to start the decoding process. 20 | 21 | predictions: int 22 | The number of tokens to append to X. 23 | 24 | progress_bar: bool 25 | Shows a tqdm progress bar, useful for tracking progress with large tensors. 26 | 27 | Returns 28 | ------- 29 | X: LongTensor of shape (examples, length + predictions) 30 | The sequences extended with the decoding process. 31 | 32 | probabilities: FloatTensor of length examples 33 | The estimated log-probabilities for the output sequences. They are computed by iteratively adding the 34 | probability of the next token at every step. 35 | """ 36 | with torch.no_grad(): 37 | probabilities = torch.zeros(X.shape[0])\ 38 | .to(next(model.parameters()).device) 39 | iterator = range(predictions) 40 | if progress_bar: 41 | iterator = tqdm(iterator) 42 | for i in iterator: 43 | next_probabilities = model.forward(X)[:, -1].log_softmax(-1) 44 | max_next_probabilities, next_chars = next_probabilities.max(-1) 45 | next_chars = next_chars.unsqueeze(-1) 46 | X = torch.cat((X, next_chars), axis = 1) 47 | probabilities += max_next_probabilities 48 | return X, probabilities 49 | 50 | def sample( 51 | model, 52 | X, 53 | predictions = 20, 54 | temperature = 1, 55 | progress_bar = False 56 | ): 57 | """ 58 | Samples the sequence distribution to extend the sequences given in X. The method can compute 59 | several outputs in parallel with the first dimension of X. 60 | 61 | Parameters 62 | ---------- 63 | X: LongTensor of shape (examples, length) 64 | The sequences to start the decoding process. 65 | 66 | predictions: int 67 | The number of tokens to append to X. 68 | 69 | temperature: positive float 70 | Parameter to control the freedom of the sampling. Higher values give more freedom. 71 | 72 | progress_bar: bool 73 | Shows a tqdm progress bar, useful for tracking progress with large tensors. 74 | 75 | Returns 76 | ------- 77 | X: LongTensor of shape (examples, length + predictions) 78 | The sequences extended with the decoding process. 79 | 80 | probabilities: FloatTensor of length examples 81 | The estimated log-probabilities for the output sequences. 
They are computed by iteratively adding the 82 | probability of the next token at every step. 83 | """ 84 | with torch.no_grad(): 85 | probabilities = torch.zeros(X.shape[0])\ 86 | .to(next(model.parameters()).device) 87 | iterator = range(predictions) 88 | if progress_bar: 89 | iterator = tqdm(iterator) 90 | for i in iterator: 91 | next_probabilities = model.forward(X)[:, -1] 92 | next_probabilities = (next_probabilities / temperature)\ 93 | .softmax(1) 94 | random = torch.rand((next_probabilities.shape[0], 1))\ 95 | .to(next(model.parameters()).device) 96 | next_chars = (next_probabilities.cumsum(1) < random)\ 97 | .sum(1, keepdims = True) 98 | probabilities += torch.gather( 99 | input = next_probabilities.log(), 100 | dim = 1, 101 | index = next_chars 102 | ).squeeze() 103 | X = torch.cat((X, next_chars), axis = 1) 104 | return X, probabilities 105 | 106 | def beam_search( 107 | model, 108 | X, 109 | predictions = 20, 110 | beam_width = 5, 111 | batch_size = 128, 112 | progress_bar = 0 113 | ): 114 | """ 115 | Implements Beam Search to extend the sequences given in X. The method can compute 116 | several outputs in parallel with the first dimension of X. 117 | 118 | Parameters 119 | ---------- 120 | X: LongTensor of shape (examples, length) 121 | The sequences to start the decoding process. 122 | 123 | predictions: int 124 | The number of tokens to append to X. 125 | 126 | beam_width: int 127 | The number of candidates to keep in the search. 128 | 129 | batch_size: int 130 | The batch size of the inner loop of the method, which relies on the beam width. 131 | 132 | progress_bar: bool 133 | Shows a tqdm progress bar, useful for tracking progress with large tensors. 134 | 135 | Returns 136 | ------- 137 | X: LongTensor of shape (examples, length + predictions) 138 | The sequences extended with the decoding process. 139 | 140 | probabilities: FloatTensor of length examples 141 | The estimated log-probabilities for the output sequences. They are computed by iteratively adding the 142 | probability of the next token at every step. 143 | """ 144 | with torch.no_grad(): 145 | # The next command can be a memory bottleneck, but can be controlled with the batch 146 | # size of the predict method. 147 | next_probabilities = model.forward(X)[:, -1, :] 148 | vocabulary_size = next_probabilities.shape[-1] 149 | probabilities, idx = next_probabilities.squeeze().log_softmax(-1)\ 150 | .topk(k = beam_width, axis = -1) 151 | X = X.repeat((beam_width, 1, 1)).transpose(0, 1)\ 152 | .flatten(end_dim = -2) 153 | next_chars = idx.reshape(-1, 1) 154 | X = torch.cat((X, next_chars), axis = -1) 155 | # This has to be minus one because we already produced a round 156 | # of predictions before the for loop. 
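        # At this point X has shape (examples * beam_width, length + 1) and
        # probabilities has shape (examples, beam_width), holding the running
        # log-probability of every candidate. Each iteration below scores all
        # candidates in batches, adds the next-token log-probabilities to the
        # running totals, keeps the beam_width best (candidate, token) pairs
        # per example, and rebuilds X from the surviving candidates.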
157 | predictions_iterator = range(predictions - 1) 158 | if progress_bar > 0: 159 | predictions_iterator = tqdm(predictions_iterator) 160 | for i in predictions_iterator: 161 | dataset = tud.TensorDataset(X) 162 | loader = tud.DataLoader(dataset, batch_size = batch_size) 163 | next_probabilities = [] 164 | iterator = iter(loader) 165 | if progress_bar > 1: 166 | iterator = tqdm(iterator) 167 | for (x,) in iterator: 168 | next_probabilities.append( 169 | model.forward(x)[:, -1, :].log_softmax(-1) 170 | ) 171 | next_probabilities = torch.cat(next_probabilities, axis = 0) 172 | next_probabilities = next_probabilities.reshape( 173 | (-1, beam_width, next_probabilities.shape[-1]) 174 | ) 175 | probabilities = probabilities.unsqueeze(-1) + next_probabilities 176 | probabilities = probabilities.flatten(start_dim = 1) 177 | probabilities, idx = probabilities.topk( 178 | k = beam_width, 179 | axis = -1 180 | ) 181 | next_chars = torch.remainder(idx, vocabulary_size).flatten()\ 182 | .unsqueeze(-1) 183 | best_candidates = (idx / vocabulary_size).long() 184 | best_candidates += torch.arange( 185 | X.shape[0] // beam_width, 186 | device = X.device 187 | ).unsqueeze(-1) * beam_width 188 | X = X[best_candidates].flatten(end_dim = -2) 189 | X = torch.cat((X, next_chars), axis = 1) 190 | return X.reshape(-1, beam_width, X.shape[-1]), probabilities 191 | -------------------------------------------------------------------------------- /src/pytorch_beam_search/seq2seq/__init__.py: -------------------------------------------------------------------------------- 1 | from .index import * 2 | from .models import * 3 | from .search_algorithms import * -------------------------------------------------------------------------------- /src/pytorch_beam_search/seq2seq/index.py: -------------------------------------------------------------------------------- 1 | from tqdm.auto import tqdm 2 | import torch 3 | import torch.nn as nn 4 | from nltk import lm 5 | import re 6 | 7 | class Index(): 8 | def __init__(self, corpus, progress_bar = False): 9 | self.special_tokens = ["", "", ""] 10 | if progress_bar: 11 | corpus = tqdm(corpus) 12 | self.vocabulary = lm.Vocabulary( 13 | [item for example in corpus for item in example] 14 | ) 15 | tokens = self.special_tokens + list(sorted(self.vocabulary)) 16 | self.voc2idx = {c:i for i, c in enumerate(tokens)} 17 | self.idx2voc = {i:c for i, c in enumerate(tokens)} 18 | 19 | def __len__(self): 20 | return len(self.voc2idx) 21 | 22 | def __str__(self): 23 | return f"" 24 | 25 | def text2tensor( 26 | self, 27 | strings, 28 | progress_bar = False 29 | ): 30 | if progress_bar: 31 | iterator = tqdm(strings) 32 | else: 33 | iterator = strings 34 | m = max([len(s) for s in strings]) 35 | idx = [] 36 | for l in iterator: 37 | idx.append( 38 | [1] +\ 39 | [self.voc2idx[self.vocabulary.lookup(c)] for c in l] +\ 40 | [2] +\ 41 | [0 for i in range(m - len(l))] 42 | ) 43 | return torch.tensor(idx) 44 | 45 | def tensor2text( 46 | self, 47 | X, 48 | separator = "", 49 | end = "" 50 | ): 51 | return [ 52 | separator.join([self.idx2voc[i] for i in l]) 53 | for l in X.tolist() 54 | ] 55 | 56 | -------------------------------------------------------------------------------- /src/pytorch_beam_search/seq2seq/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.data as tud 4 | from timeit import default_timer as timer 5 | from tqdm.auto import tqdm 6 | import pandas as pd 7 | import warnings 8 | 9 | 
class Seq2Seq(nn.Module): 10 | """ 11 | A generic sequence-to-sequence model. All other sequence-to-sequence models should extend this class 12 | with a __init__ and forward methods, in the same way as in normal PyTorch. 13 | """ 14 | def print_architecture(self): 15 | """ 16 | Displays the information about the model in standard output. 17 | """ 18 | for k in self.architecture.keys(): 19 | print(f"{k.replace('_', ' ').capitalize()}: {self.architecture[k]}") 20 | print(f"Trainable parameters: {sum([p.numel() for p in self.parameters()]):,}") 21 | print() 22 | 23 | def fit(self, 24 | X_train, 25 | Y_train, 26 | X_dev = None, 27 | Y_dev = None, 28 | batch_size = 100, 29 | epochs = 5, 30 | learning_rate = 10**-4, 31 | weight_decay = 0, 32 | progress_bar = 0, 33 | save_path = None): 34 | """ 35 | A generic training method with Adam and Cross Entropy. 36 | 37 | Parameters 38 | ---------- 39 | X_train: LongTensor of shape (train_examples, train_input_length) 40 | The input sequences of the training set. 41 | 42 | Y_train: LongTensor of shape (train_examples, train_output_length) 43 | The output sequences of the training set. 44 | 45 | X_dev: LongTensor of shape (dev_examples, dev_input_length), optional 46 | The input sequences for the development set. 47 | 48 | Y_train: LongTensor of shape (dev_examples, dev_output_length), optional 49 | The output sequences for the development set. 50 | 51 | batch_size: int 52 | The number of examples to process in each batch. 53 | 54 | epochs: int 55 | The number of epochs of the training process. 56 | 57 | learning_rate: float 58 | The learning rate to use with Adam in the training process. 59 | 60 | weight_decay: float 61 | The weight_decay parameter of Adam (L2 penalty), useful for regularizing models. For a deeper 62 | documentation, go to https://pytorch.org/docs/stable/_modules/torch/optim/adam.html#Adam 63 | 64 | progress_bar: int 65 | Shows a tqdm progress bar, useful for tracking progress with large tensors. 66 | If equal to 0, no progress bar is shown. 67 | If equal to 1, shows a bar with one step for every epoch. 68 | If equal to 2, shows the bar when equal to 1 and also shows a bar with one step per batch for every epoch. 69 | If equal to 3, shows the bars when equal to 2 and also shows a bar to track the progress of the evaluation 70 | in the development set. 71 | 72 | save_path: string, optional 73 | Path to save the .pt file containing the model parameters when the training ends. 74 | 75 | Returns 76 | ------- 77 | performance: Pandas DataFrame 78 | DataFrame with the following columns: epoch, train_loss, train_error_rate, (optionally dev_loss and 79 | dev_error_rate), minutes, learning_rate, weight_decay, model, encoder_embedding_dimension, 80 | decoder_embedding_dimension, encoder_hidden_units, encoder_layers, decoder_hidden_units, decoder_layers, 81 | dropout, parameters and one row for each of the epochs, containing information about the training process. 
82 | """ 83 | assert X_train.shape[0] == Y_train.shape[0] 84 | assert (X_dev is None and Y_dev is None) or (X_dev is not None and Y_dev is not None) 85 | if (X_dev is not None and Y_dev is not None): 86 | assert X_dev.shape[0] == Y_dev.shape[0] 87 | dev = True 88 | else: 89 | dev = False 90 | train_dataset = tud.TensorDataset(X_train, Y_train) 91 | train_loader = tud.DataLoader(train_dataset, batch_size = batch_size, shuffle = True) 92 | criterion = nn.CrossEntropyLoss(ignore_index = 0) 93 | optimizer = torch.optim.Adam(self.parameters(), lr = learning_rate, weight_decay = weight_decay) 94 | performance = [] 95 | start = timer() 96 | epochs_iterator = range(1, epochs + 1) 97 | if progress_bar > 0: 98 | epochs_iterator = tqdm(epochs_iterator) 99 | print("Training started") 100 | print("X_train.shape:", X_train.shape) 101 | print("Y_train.shape:", Y_train.shape) 102 | if dev: 103 | print("X_dev.shape:", X_dev.shape) 104 | print("Y_dev.shape:", Y_dev.shape) 105 | print(f"Epochs: {epochs:,}\nLearning rate: {learning_rate}\nWeight decay: {weight_decay}") 106 | header_1 = "Epoch | Train " 107 | header_2 = " | Loss | Error Rate" 108 | rule = "-" * 29 109 | if dev: 110 | header_1 += " | Development " 111 | header_2 += " | Loss | Error Rate" 112 | rule += "-" * 24 113 | header_1 += " | Minutes" 114 | header_2 += " |" 115 | rule += "-" * 10 116 | print(header_1, header_2, rule, sep = "\n") 117 | for e in epochs_iterator: 118 | self.train() 119 | losses = [] 120 | errors = [] 121 | sizes = [] 122 | train_iterator = train_loader 123 | if progress_bar > 1: 124 | train_iterator = tqdm(train_iterator) 125 | for x, y in train_iterator: 126 | # compute loss and backpropagate 127 | probabilities = self.forward(x, y).transpose(1, 2)[:, :, :-1] 128 | y = y[:, 1:] 129 | loss = criterion(probabilities, y) 130 | loss.backward() 131 | optimizer.step() 132 | optimizer.zero_grad() 133 | # compute accuracy 134 | predictions = probabilities.argmax(1) 135 | batch_errors = (predictions != y) 136 | # append the results 137 | losses.append(loss.item()) 138 | errors.append(batch_errors.sum().item()) 139 | sizes.append(batch_errors.numel()) 140 | train_loss = sum(losses) / len(losses) 141 | train_error_rate = 100 * sum(errors) / sum(sizes) 142 | t = (timer() - start) / 60 143 | status_string = f"{e:>5} | {train_loss:>8.4f} | {train_error_rate:>10.3f}" 144 | status = {"epoch":e, 145 | "train_loss": train_loss, 146 | "train_error_rate": train_error_rate} 147 | if dev: 148 | dev_loss, dev_error_rate = self.evaluate(X_dev, 149 | Y_dev, 150 | batch_size = batch_size, 151 | progress_bar = progress_bar > 2, 152 | criterion = criterion) 153 | status_string += f" | {dev_loss:>8.4f} | {dev_error_rate:>10.3f}" 154 | status.update({"dev_loss": dev_loss, "dev_error_rate": dev_error_rate}) 155 | status.update({"training_minutes": t, 156 | "learning_rate": learning_rate, 157 | "weight_decay": weight_decay}) 158 | performance.append(status) 159 | if save_path is not None: 160 | if (not dev) or (e < 2) or (dev_loss < min([p["dev_loss"] for p in performance[:-1]])): 161 | torch.save(self.state_dict(), save_path) 162 | status_string += f" | {t:>7.1f}" 163 | print(status_string) 164 | print() 165 | return pd.concat((pd.DataFrame(performance), 166 | pd.DataFrame([self.architecture for i in performance])), axis = 1)\ 167 | .drop(columns = ["source_index", "target_index"]) 168 | 169 | 170 | def evaluate(self, 171 | X, 172 | Y, 173 | criterion = nn.CrossEntropyLoss(), 174 | batch_size = 128, 175 | progress_bar = False): 176 | """ 177 | Evaluates 
the model on a dataset. 178 | 179 | Parameters 180 | ---------- 181 | X: LongTensor of shape (examples, input_length) 182 | The input sequences of the dataset. 183 | 184 | Y: LongTensor of shape (examples, output_length) 185 | The output sequences of the dataset. 186 | 187 | criterion: PyTorch module 188 | The loss function to evalue the model on the dataset, has to be able to compare self.forward(X, Y) and Y 189 | to produce a real number. 190 | 191 | batch_size: int 192 | The batch size of the evaluation loop. 193 | 194 | progress_bar: bool 195 | Shows a tqdm progress bar, useful for tracking progress with large tensors. 196 | 197 | Returns 198 | ------- 199 | loss: float 200 | The average of criterion across the whole dataset. 201 | 202 | error_rate: float 203 | The step-by-step accuracy of the model across the whole dataset. Useful as a sanity check, as it should 204 | go to zero as the loss goes to zero. 205 | 206 | """ 207 | dataset = tud.TensorDataset(X, Y) 208 | loader = tud.DataLoader(dataset, batch_size = batch_size) 209 | self.eval() 210 | losses = [] 211 | errors = [] 212 | sizes = [] 213 | with torch.no_grad(): 214 | iterator = iter(loader) 215 | if progress_bar: 216 | iterator = tqdm(iterator) 217 | for batch in iterator: 218 | x, y = batch 219 | # compute loss 220 | probabilities = self.forward(x, y).transpose(1, 2)[:, :, :-1] 221 | y = y[:, 1:] 222 | loss = criterion(probabilities, y) 223 | # compute accuracy 224 | predictions = probabilities.argmax(1) 225 | batch_errors = (predictions != y) 226 | # append the results 227 | losses.append(loss.item()) 228 | errors.append(batch_errors.sum().item()) 229 | sizes.append(batch_errors.numel()) 230 | loss = sum(losses) / len(losses) 231 | error_rate = 100 * sum(errors) / sum(sizes) 232 | return loss, error_rate 233 | 234 | class LSTM(Seq2Seq): 235 | def __init__(self, 236 | source_index, 237 | target_index, 238 | encoder_embedding_dimension = 32, 239 | decoder_embedding_dimension = 32, 240 | encoder_hidden_units = 128, 241 | encoder_layers = 2, 242 | decoder_hidden_units = 128, 243 | decoder_layers = 2, 244 | dropout = 0.0): 245 | """ 246 | A standard Seq2Seq LSTM model as in 'Learning Phrase Representations using RNN Encoder-Decoder 247 | for Statistical Machine Translation' by Cho et al. (2014). 248 | 249 | Parameters 250 | ---------- 251 | in_vocabulary: dictionary 252 | Vocabulary with the index:token pairs for the inputs of the model. 253 | 254 | out_vocabulary: dictionary 255 | Vocabulary with the token:index pairs for the outputs of the model. 256 | 257 | encoder_embedding_dimension: int 258 | Dimension of the embeddings to feed into the encoder. 259 | 260 | decoder_embedding_dimension: int 261 | Dimension of the embeddings to feed into the decoder. 262 | 263 | encoder_hidden_units: int 264 | Hidden size of the encoder. 265 | 266 | encoder_layers: int 267 | Hidden layers of the encoder. 268 | 269 | decoder_hidden_units: int 270 | Hidden units of the decoder. 271 | 272 | decoder_layers: int 273 | Hidden layers of the decoder. 274 | 275 | dropout: float between 0.0 and 1.0 276 | Dropout rate to apply to whole model. 
277 | """ 278 | self.source_index = source_index 279 | self.target_index = target_index 280 | super().__init__() 281 | self.source_embeddings = nn.Embedding(len(source_index), encoder_embedding_dimension) 282 | self.target_embeddings = nn.Embedding(len(target_index), decoder_embedding_dimension) 283 | self.encoder_rnn = nn.LSTM(input_size = encoder_embedding_dimension, 284 | hidden_size = encoder_hidden_units, 285 | num_layers = encoder_layers, 286 | dropout = dropout) 287 | self.decoder_rnn = nn.LSTM(input_size = encoder_layers * encoder_hidden_units + decoder_embedding_dimension, 288 | hidden_size = decoder_hidden_units, 289 | num_layers = decoder_layers, 290 | dropout = dropout) 291 | self.output_layer = nn.Linear(decoder_hidden_units, len(target_index)) 292 | self.architecture = dict(model = "Seq2Seq LSTM", 293 | source_index = source_index, 294 | target_index = target_index, 295 | encoder_embedding_dimension = encoder_embedding_dimension, 296 | decoder_embedding_dimension = decoder_embedding_dimension, 297 | encoder_hidden_units = encoder_hidden_units, 298 | encoder_layers = encoder_layers, 299 | decoder_hidden_units = decoder_hidden_units, 300 | decoder_layers = decoder_layers, 301 | dropout = dropout) 302 | self.print_architecture() 303 | 304 | def forward(self, X, Y): 305 | """ 306 | Forward method of the model. 307 | 308 | Parameters 309 | ---------- 310 | X: LongTensor of shape (batch_size, input_length) 311 | Tensor of integers containing the inputs for the model. 312 | 313 | Y: LongTensor of shape (batch_size, output_length) 314 | Tensor of integers containing the output produced so far. 315 | 316 | Returns 317 | ------- 318 | output: FloatTensor of shape (batch_size, output_length, len(out_vocabulary)) 319 | Tensor of floats containing the inputs for the final Softmax layer (usually integrated into the loss function). 320 | """ 321 | X = self.source_embeddings(X.T) 322 | encoder, (encoder_last_hidden, encoder_last_memory) = self.encoder_rnn(X) 323 | encoder_last_hidden = encoder_last_hidden.transpose(0, 1).flatten(start_dim = 1) 324 | encoder_last_hidden = encoder_last_hidden.repeat((Y.shape[1], 1, 1)) 325 | Y = self.target_embeddings(Y.T) 326 | Y = torch.cat((Y, encoder_last_hidden), axis = -1) 327 | decoder, (decoder_last_hidden, decoder_last_memory) = self.decoder_rnn(Y) 328 | output = self.output_layer(decoder.transpose(0, 1)) 329 | return output 330 | 331 | 332 | class ReversingLSTM(Seq2Seq): 333 | def __init__(self, 334 | source_index, 335 | target_index, 336 | encoder_embedding_dimension = 32, 337 | decoder_embedding_dimension = 32, 338 | encoder_hidden_units = 128, 339 | encoder_layers = 2, 340 | decoder_hidden_units = 128, 341 | decoder_layers = 2, 342 | dropout = 0.0): 343 | """ 344 | A standard Seq2Seq LSTM model that reverses the order of the input as in 345 | 'Sequence to sequence learning with Neural Networks' by Sutskever et al. (2014). 346 | 347 | Parameters 348 | ---------- 349 | in_vocabulary: dictionary 350 | Vocabulary with the index:token pairs for the inputs of the model. 351 | 352 | out_vocabulary: dictionary 353 | Vocabulary with the token:index pairs for the outputs of the model. 354 | 355 | encoder_embedding_dimension: int 356 | Dimension of the embeddings to feed into the encoder. 357 | 358 | decoder_embedding_dimension: int 359 | Dimension of the embeddings to feed into the decoder. 360 | 361 | encoder_hidden_units: int 362 | Hidden size of the encoder. 363 | 364 | encoder_layers: int 365 | Hidden layers of the encoder. 
366 | 367 | decoder_hidden_units: int 368 | Hidden units of the decoder. 369 | 370 | decoder_layers: int 371 | Hidden layers of the decoder. 372 | 373 | dropout: float between 0.0 and 1.0 374 | Dropout rate to apply to whole model. 375 | """ 376 | super().__init__() 377 | self.source_index = source_index 378 | self.target_index = target_index 379 | self.source_embeddings = nn.Embedding(len(source_index), encoder_embedding_dimension) 380 | self.target_embeddings = nn.Embedding(len(target_index), decoder_embedding_dimension) 381 | self.encoder_rnn = nn.LSTM(input_size = encoder_embedding_dimension, 382 | hidden_size = encoder_hidden_units, 383 | num_layers = encoder_layers, 384 | dropout = dropout) 385 | self.decoder_rnn = nn.LSTM(input_size = decoder_embedding_dimension, 386 | hidden_size = decoder_hidden_units, 387 | num_layers = decoder_layers, 388 | dropout = dropout) 389 | self.output_layer = nn.Linear(decoder_hidden_units, len(target_index)) 390 | self.enc2dec = nn.Linear(encoder_hidden_units * encoder_layers, decoder_hidden_units * decoder_layers) 391 | self.architecture = dict(model = "Seq2Seq Reversing LSTM", 392 | source_index = source_index, 393 | target_index = target_index, 394 | encoder_embedding_dimension = encoder_embedding_dimension, 395 | decoder_embedding_dimension = decoder_embedding_dimension, 396 | encoder_hidden_units = encoder_hidden_units, 397 | encoder_layers = encoder_layers, 398 | decoder_hidden_units = decoder_hidden_units, 399 | decoder_layers = decoder_layers, 400 | dropout = dropout) 401 | self.print_architecture() 402 | 403 | def forward(self, X, Y): 404 | """ 405 | Forward method of the model. 406 | 407 | Parameters 408 | ---------- 409 | X: LongTensor of shape (batch_size, input_length) 410 | Tensor of integers containing the inputs for the model. 411 | 412 | Y: LongTensor of shape (batch_size, output_length) 413 | Tensor of integers containing the output produced so far. 414 | 415 | Returns 416 | ------- 417 | output: FloatTensor of shape (batch_size, output_length, len(out_vocabulary)) 418 | Tensor of floats containing the inputs for the final Softmax layer (usually integrated into the loss function). 419 | """ 420 | X = self.source_embeddings(torch.flip(X.T, dims = (1, ))) 421 | encoder, (encoder_last_hidden, encoder_last_memory) = self.encoder_rnn(X) 422 | encoder_last_hidden = encoder_last_hidden.transpose(0, 1).flatten(start_dim = 1) 423 | enc2dec = self.enc2dec(encoder_last_hidden)\ 424 | .reshape(-1, self.decoder_rnn.num_layers, self.decoder_rnn.hidden_size)\ 425 | .transpose(0, 1)\ 426 | .contiguous() 427 | Y = self.target_embeddings(Y.T) 428 | decoder, (decoder_last_hidden, decoder_last_memory) = self.decoder_rnn(Y, (enc2dec, torch.zeros_like(enc2dec))) 429 | output = self.output_layer(decoder.transpose(0, 1)) 430 | return output 431 | 432 | 433 | class Transformer(Seq2Seq): 434 | def __init__(self, 435 | source_index, 436 | target_index, 437 | max_sequence_length = 32, 438 | embedding_dimension = 32, 439 | feedforward_dimension = 128, 440 | encoder_layers = 2, 441 | decoder_layers = 2, 442 | attention_heads = 2, 443 | activation = "relu", 444 | dropout = 0.0): 445 | """ 446 | The standard PyTorch implementation of a Transformer model. 447 | 448 | Parameters 449 | ---------- 450 | in_vocabulary: dictionary 451 | Vocabulary with the index:token pairs for the inputs of the model. 452 | 453 | out_vocabulary: dictionary 454 | Vocabulary with the token:index pairs for the outputs of the model. 
455 | 456 | max_sequence_length: int 457 | Maximum sequence length accepted by the model, both for the encoder and the decoder. 458 | 459 | embedding_dimension: int 460 | Dimension of the embeddings of the model. 461 | 462 | feedforward_dimension: int 463 | Dimension of the feedforward network inside the self-attention layers of the model. 464 | 465 | encoder_layers: int 466 | Hidden layers of the encoder. 467 | 468 | decoder_layers: int 469 | Hidden layers of the decoder. 470 | 471 | attention_heads: int 472 | Attention heads inside every self-attention layer of the model. 473 | 474 | activation: string 475 | Activation function of the feedforward network inside the self-attention layers of the model. Can 476 | be either 'relu' or 'gelu'. 477 | 478 | dropout: float between 0.0 and 1.0 479 | Dropout rate to apply to whole model. 480 | """ 481 | super().__init__() 482 | self.source_index = source_index 483 | self.target_index = target_index 484 | self.source_embeddings = nn.Embedding(len(source_index), embedding_dimension) 485 | self.target_embeddings = nn.Embedding(len(target_index), embedding_dimension) 486 | self.positional_embeddings = nn.Embedding(max_sequence_length, embedding_dimension) 487 | self.transformer = nn.Transformer(d_model = embedding_dimension, 488 | dim_feedforward = feedforward_dimension, 489 | nhead = attention_heads, 490 | num_encoder_layers = encoder_layers, 491 | num_decoder_layers = decoder_layers, 492 | activation = activation, 493 | dropout = dropout) 494 | self.output_layer = nn.Linear(embedding_dimension, len(target_index)) 495 | self.architecture = dict(model = "Seq2Seq Transformer", 496 | source_index = source_index, 497 | target_index = target_index, 498 | max_sequence_length = max_sequence_length, 499 | embedding_dimension = embedding_dimension, 500 | feedforward_dimension = feedforward_dimension, 501 | encoder_layers = encoder_layers, 502 | decoder_layers = decoder_layers, 503 | attention_heads = attention_heads, 504 | activation = activation, 505 | dropout = dropout) 506 | self.print_architecture() 507 | 508 | def forward(self, X, Y): 509 | """ 510 | Forward method of the model. 511 | 512 | Parameters 513 | ---------- 514 | X: LongTensor of shape (batch_size, input_length) 515 | Tensor of integers containing the inputs for the model. 516 | 517 | Y: LongTensor of shape (batch_size, output_length) 518 | Tensor of integers containing the output produced so far. 519 | 520 | Returns 521 | ------- 522 | output: FloatTensor of shape (batch_size, output_length, len(out_vocabulary)) 523 | Tensor of floats containing the inputs for the final Softmax layer (usually integrated in the loss function). 
524 | """ 525 | assert X.shape[1] <= self.architecture["max_sequence_length"] 526 | assert Y.shape[1] <= self.architecture["max_sequence_length"] 527 | X = self.source_embeddings(X) 528 | X_positional = torch.arange(X.shape[1], device = X.device).repeat((X.shape[0], 1)) 529 | X_positional = self.positional_embeddings(X_positional) 530 | X = (X + X_positional).transpose(0, 1) 531 | Y = self.target_embeddings(Y) 532 | Y_positional = torch.arange(Y.shape[1], device = Y.device).repeat((Y.shape[0], 1)) 533 | Y_positional = self.positional_embeddings(Y_positional) 534 | Y = (Y + Y_positional).transpose(0, 1) 535 | mask = self.transformer.generate_square_subsequent_mask(Y.shape[0]).to(Y.device) 536 | transformer_output = self.transformer.forward(src = X, 537 | tgt = Y, 538 | tgt_mask = mask) 539 | transformer_output = transformer_output.transpose(0, 1) 540 | return self.output_layer(transformer_output) 541 | -------------------------------------------------------------------------------- /src/pytorch_beam_search/seq2seq/search_algorithms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as tud 3 | from tqdm.auto import tqdm 4 | import warnings 5 | 6 | def greedy_search( 7 | model, 8 | X, 9 | predictions = 20, 10 | progress_bar = False 11 | ): 12 | """ 13 | Implements Greedy Search to compute the output with the sequences given in X. The method can compute 14 | several outputs in parallel with the first dimension of X. 15 | 16 | Parameters 17 | ---------- 18 | X: LongTensor of shape (examples, length) 19 | The sequences to start the decoding process. 20 | 21 | predictions: int 22 | The number of tokens to append to X. 23 | 24 | progress_bar: bool 25 | Shows a tqdm progress bar, useful for tracking progress with large tensors. 26 | 27 | Returns 28 | ------- 29 | Y: LongTensor of shape (examples, length + predictions) 30 | The output sequences. 31 | 32 | probabilities: FloatTensor of length examples 33 | The estimated log-probabilities for the output sequences. They are computed by iteratively adding the 34 | probability of the next token at every step. 35 | """ 36 | with torch.no_grad(): 37 | Y = torch.ones(X.shape[0], 1).long().to(next(model.parameters()).device) 38 | probabilities = torch.zeros(X.shape[0]).to(next(model.parameters()).device) 39 | iterator = range(predictions) 40 | if progress_bar: 41 | iterator = tqdm(iterator) 42 | for i in iterator: 43 | next_probabilities = model.forward(X, Y)[:, -1].log_softmax(-1) 44 | max_next_probabilities, next_chars = next_probabilities.max(-1) 45 | next_chars = next_chars.unsqueeze(-1) 46 | Y = torch.cat((Y, next_chars), axis = 1) 47 | probabilities += max_next_probabilities 48 | return Y, probabilities 49 | 50 | def sample( 51 | model, 52 | X, 53 | predictions = 20, 54 | temperature = 1.0, 55 | progress_bar = False 56 | ): 57 | """ 58 | Samples the sequence distribution to compute the output with the sequences given in X. The method can compute 59 | several outputs in parallel with the first dimension of X. 60 | 61 | Parameters 62 | ---------- 63 | X: LongTensor of shape (examples, length) 64 | The sequences to start the decoding process. 65 | 66 | predictions: int 67 | The number of tokens to append to X. 68 | 69 | temperature: positive float 70 | Parameter to control the freedom of the sampling. Higher values give more freedom. 71 | 72 | progress_bar: bool 73 | Shows a tqdm progress bar, useful for tracking progress with large tensors. 
74 | 
75 |     Returns
76 |     -------
77 |     Y: LongTensor of shape (examples, 1 + predictions)
78 |         The generated output sequences.
79 | 
80 |     probabilities: FloatTensor of length examples
81 |         The estimated log-probabilities for the output sequences. They are computed by iteratively adding the
82 |         log-probability of the sampled token at every step.
83 |     """
84 |     with torch.no_grad():
85 |         Y = torch.ones(X.shape[0], 1).long().to(next(model.parameters()).device)
86 |         probabilities = torch.zeros(X.shape[0]).to(next(model.parameters()).device)
87 |         iterator = range(predictions)
88 |         if progress_bar:
89 |             iterator = tqdm(iterator)
90 |         for i in iterator:
91 |             next_probabilities = model.forward(X, Y)[:, -1]
92 |             next_probabilities = (next_probabilities / temperature).softmax(1)
93 |             random = torch.rand((next_probabilities.shape[0], 1)).to(next(model.parameters()).device)
94 |             next_chars = ((next_probabilities.cumsum(1) < random).sum(1, keepdims = True))  # inverse CDF sampling of the next token
95 |             probabilities += torch.gather(input = next_probabilities.log(), dim = 1, index = next_chars).squeeze()
96 |             Y = torch.cat((Y, next_chars), axis = 1)
97 |     return Y, probabilities
98 | 
99 | def beam_search(
100 |     model,
101 |     X,
102 |     predictions = 20,
103 |     beam_width = 5,
104 |     batch_size = 50,
105 |     progress_bar = 0
106 | ):
107 |     """
108 |     Implements Beam Search to compute the output for the sequences given in X, keeping the beam_width most likely
109 |     candidates at every step. The method decodes several examples in parallel along the first dimension of X.
110 | 
111 |     Parameters
112 |     ----------
113 |     X: LongTensor of shape (examples, length)
114 |         The input (source) sequences that condition the decoding process.
115 | 
116 |     predictions: int
117 |         The number of tokens to generate for every example.
118 | 
119 |     beam_width: int
120 |         The number of candidates to keep in the search.
121 | 
122 |     batch_size: int
123 |         The batch size of the inner decoding loop; at every step the method scores examples * beam_width candidate sequences, so this controls memory usage.
124 | 
125 |     progress_bar: int
126 |         Verbosity of the tqdm progress bars, from 0 to 2: 0 shows no bars, 1 shows a bar over the decoding steps, 2 also shows a bar over the batches of every step.
127 | 
128 |     Returns
129 |     -------
130 |     Y: LongTensor of shape (examples, beam_width, 1 + predictions)
131 |         The beam_width best candidate output sequences for every example.
132 | 
133 |     probabilities: FloatTensor of shape (examples, beam_width)
134 |         The estimated log-probabilities of the candidate sequences. They are computed by iteratively adding the
135 |         log-probability of each selected token at every step.
136 |     """
137 |     with torch.no_grad():
138 |         Y = torch.ones(X.shape[0], 1).to(next(model.parameters()).device).long()
139 |         # The next command can be a memory bottleneck, because the first step runs the model on all the
140 |         # examples at once; the decoding loop below is batched with the batch_size parameter.
141 |         next_probabilities = model.forward(X, Y)[:, -1, :]
142 |         vocabulary_size = next_probabilities.shape[-1]
143 |         probabilities, next_chars = next_probabilities.squeeze().log_softmax(-1)\
144 |             .topk(k = beam_width, axis = -1)
145 |         Y = Y.repeat((beam_width, 1))
146 |         next_chars = next_chars.reshape(-1, 1)
147 |         Y = torch.cat((Y, next_chars), axis = -1)
148 |         # This has to be minus one because we already produced a round
149 |         # of predictions before the for loop.
150 |         predictions_iterator = range(predictions - 1)
151 |         if progress_bar > 0:
152 |             predictions_iterator = tqdm(predictions_iterator)
153 |         for i in predictions_iterator:
154 |             dataset = tud.TensorDataset(X.repeat((beam_width, 1, 1)).transpose(0, 1).flatten(end_dim = 1), Y)
155 |             loader = tud.DataLoader(dataset, batch_size = batch_size)
156 |             next_probabilities = []
157 |             iterator = iter(loader)
158 |             if progress_bar > 1:
159 |                 iterator = tqdm(iterator)
160 |             for x, y in iterator:
161 |                 next_probabilities.append(model.forward(x, y)[:, -1, :].log_softmax(-1))
162 |             next_probabilities = torch.cat(next_probabilities, axis = 0)
163 |             next_probabilities = next_probabilities.reshape((-1, beam_width, next_probabilities.shape[-1]))
164 |             probabilities = probabilities.unsqueeze(-1) + next_probabilities
165 |             probabilities = probabilities.flatten(start_dim = 1)
166 |             probabilities, idx = probabilities.topk(k = beam_width, axis = -1)
167 |             next_chars = torch.remainder(idx, vocabulary_size).flatten().unsqueeze(-1)  # token id of each candidate
168 |             best_candidates = (idx / vocabulary_size).long()  # beam each candidate came from
169 |             best_candidates += torch.arange(Y.shape[0] // beam_width, device = X.device).unsqueeze(-1) * beam_width
170 |             Y = Y[best_candidates].flatten(end_dim = -2)
171 |             Y = torch.cat((Y, next_chars), axis = 1)
172 |     return Y.reshape(-1, beam_width, Y.shape[-1]), probabilities
173 | 
--------------------------------------------------------------------------------
/tests/gpt.py:
--------------------------------------------------------------------------------
1 | from pprint import pprint
2 | from pytorch_beam_search import autoregressive
3 | 
4 | # Create vocabulary and examples
5 | # tokenize the way you need
6 | corpus = list("abcdefghijklmnopqrstwxyz ")
7 | print("corpus")
8 | print(corpus, "\n")
9 | # len(corpus) == 25
10 | # An Index object represents a mapping from the vocabulary
11 | # to integers (indices) to feed into the models
12 | print("creating index")
13 | index = autoregressive.Index(corpus)
14 | print(index, "\n")
15 | n_gram_size = 17 # 16 with an offset of 1
16 | n_grams = [
17 |     corpus[i:n_gram_size + i]
18 |     for i in range(len(corpus))[:-n_gram_size + 1]
19 | ]
20 | 
21 | # Create tensor
22 | print("creating tensor...")
23 | X = index.text2tensor(n_grams)
24 | # X.shape == (n_examples, len_examples) == (25 - 17 + 1 = 9, 17)
25 | print("X", X.shape, "\n")
26 | 
27 | # Create and train the model
28 | model = autoregressive.TransformerEncoder(index) # just a PyTorch model
29 | model.fit(X) # basic method included
30 | 
31 | # Generate new predictions
32 | print("\ntest data")
33 | new_examples = [list("new first"), list("new second")]
34 | pprint(new_examples)
35 | print("\ncreating tensor...")
36 | X_new = index.text2tensor(new_examples)
37 | print("X_new", X_new.shape, "\n")
38 | print("evaluating model...")
39 | loss, error_rate = model.evaluate(X_new) # basic method included
40 | print("loss:", loss)
41 | print("error_rate:", error_rate, "\n")
42 | print("beam_search...")
43 | predictions, log_probabilities = autoregressive.beam_search(
44 |     model,
45 |     X_new,
46 |     predictions = 5,
47 |     # progress_bar = 1
48 | )
49 | # every element in predictions is the list of candidates for each example
50 | output = [index.tensor2text(p) for p in predictions]
51 | print("\npredictions")
52 | pprint(output)
53 | 
--------------------------------------------------------------------------------
/tests/transformer.py:
--------------------------------------------------------------------------------
1 | from pprint import pprint
2 | from pytorch_beam_search import seq2seq
3 | 
4 | # Create vocabularies
5 | # Tokenize the way you need
6 | print("train data")
7 | source = [list("abcdefghijkl"), list("mnopqrstwxyz")]
8 | target = [list("ABCDEFGHIJKL"), list("MNOPQRSTWXYZ")]
9 | print("source")
10 | pprint(source)
11 | print("target")
12 | pprint(target)
13 | # An Index object represents a mapping from the vocabulary
14 | # to integers (indices) to feed into the models
15 | print("creating indexes...")
16 | source_index = seq2seq.Index(source)
17 | target_index = seq2seq.Index(target)
18 | 
19 | # Create tensors
20 | print("creating tensors...")
21 | X = source_index.text2tensor(source)
22 | Y = target_index.text2tensor(target)
23 | # X.shape == (n_source_examples, len_source_examples) == (2, 11)
24 | # Y.shape == (n_target_examples, len_target_examples) == (2, 12)
25 | 
26 | # Create and train the model
27 | print("creating model...")
28 | model = seq2seq.Transformer(source_index, target_index) # just a PyTorch model
29 | # model = seq2seq.LSTM(source_index, target_index) # just a PyTorch model
30 | print("training model...")
31 | model.fit(X, Y, epochs = 100) # basic method included
32 | 
33 | # Generate new predictions
34 | print("test data")
35 | new_source = [list("new first in"), list("new second in")]
36 | new_target = [list("new first out"), list("new second out")]
37 | print("new source")
38 | pprint(new_source)
39 | print("new target")
40 | pprint(new_target)
41 | print("creating tensors...")
42 | X_new = source_index.text2tensor(new_source)
43 | Y_new = target_index.text2tensor(new_target)
44 | print("evaluating model...")
45 | loss, error_rate = model.evaluate(X_new, Y_new) # basic method included
46 | print("beam search...")
47 | predictions, log_probabilities = seq2seq.beam_search(
48 |     model,
49 |     X_new,
50 |     # progress_bar = 1,
51 |     # predictions = 100
52 | )
53 | output = [target_index.tensor2text(p) for p in predictions]
54 | print("\npredictions")
55 | pprint(output)
56 | 
--------------------------------------------------------------------------------
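The test scripts above only exercise `beam_search`. As a complement, here is a minimal, hypothetical sketch of how the other two seq2seq decoding functions documented in `search_algorithms.py` (`greedy_search` and `sample`) would be called, mirroring `tests/transformer.py`. It is not part of the repository, and it assumes that both functions are exported from the `seq2seq` namespace in the same way as `beam_search`, and that `tensor2text` accepts the 2-D tensors they return.

```python
from pytorch_beam_search import seq2seq

# Toy character-level parallel corpus, as in tests/transformer.py
source = [list("abcdefghijkl"), list("mnopqrstwxyz")]
target = [list("ABCDEFGHIJKL"), list("MNOPQRSTWXYZ")]
source_index = seq2seq.Index(source)
target_index = seq2seq.Index(target)
X = source_index.text2tensor(source)
Y = target_index.text2tensor(target)

model = seq2seq.Transformer(source_index, target_index)
model.fit(X, Y, epochs = 100)

X_new = source_index.text2tensor([list("new first in"), list("new second in")])

# Greedy Search: keeps only the most likely token at every step
greedy, greedy_log_probs = seq2seq.greedy_search(model, X_new, predictions = 10)
print(target_index.tensor2text(greedy))  # assumes tensor2text handles (examples, length) tensors

# Sampling: draws every token from the predicted distribution;
# temperature > 1.0 gives more diverse samples, < 1.0 more conservative ones
sampled, sample_log_probs = seq2seq.sample(model, X_new, predictions = 10, temperature = 1.0)
print(target_index.tensor2text(sampled))
```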