├── .gitignore
├── README.md
├── allennlp-book-r-and-d.jpg
├── allennlp-book.jpeg
├── classifier-model
│   ├── configs
│   │   └── experiment.jsonnet
│   ├── datasets
│   │   └── download.sh
│   └── src
│       ├── __init__.py
│       ├── data
│       │   ├── __init__.py
│       │   └── dataset_readers
│       │       ├── __init__.py
│       │       └── ag_news_reader.py
│       └── models
│           ├── __init__.py
│           └── text_classifier.py
├── jp-classifier-model
│   ├── configs
│   │   └── experiment.jsonnet
│   ├── datasets
│   │   └── download.sh
│   └── src
│       ├── __init__.py
│       ├── data
│       │   ├── __init__.py
│       │   ├── dataset_readers
│       │   │   ├── __init__.py
│       │   │   └── livedoor_news_reader.py
│       │   └── tokenizers
│       │       ├── __init__.py
│       │       └── janome_tokenizer.py
│       └── models
│           ├── __init__.py
│           └── text_classifier.py
├── mlflow
│   ├── .gitignore
│   ├── MLproject
│   ├── README.md
│   ├── conda.yaml
│   ├── configs
│   │   ├── ner.jsonnet
│   │   └── sequence_tagging.jsonnet
│   ├── data
│   │   └── ner_data.json
│   ├── scripts
│   │   └── train.py
│   └── src
│       ├── __init__.py
│       └── training
│           ├── __init__.py
│           └── callbacks
│               ├── __init__.py
│               └── mlflow_metrics.py
├── ner-model
│   ├── configs
│   │   ├── bert-experiment.jsonnet
│   │   └── experiment.jsonnet
│   ├── datasets
│   │   └── download.sh
│   └── src
│       ├── __init__.py
│       ├── data
│       │   ├── __init__.py
│       │   └── dataset_readers
│       │       ├── __init__.py
│       │       └── conll_2003_reader.py
│       ├── models
│       │   ├── __init__.py
│       │   └── ner_tagger.py
│       ├── predictors
│       │   ├── __init__.py
│       │   └── conll_2003_predictor.py
│       └── tests
│           ├── __init__.py
│           ├── data
│           │   ├── __init__.py
│           │   └── dataset_readers
│           │       ├── __init__.py
│           │       └── conll_2003_reader_test.py
│           ├── fixtures
│           │   ├── configs
│           │   │   └── experiment.jsonnet
│           │   └── data
│           │       └── conll2003.txt
│           └── models
│               ├── __init__.py
│               └── ner_tagger_test.py
├── nli
│   ├── .gitignore
│   ├── configs
│   │   ├── esim.jsonnet
│   │   ├── san.jsonnet
│   │   └── san_test.jsonnet
│   ├── data
│   │   └── snli_test.jsonl
│   └── src
│       ├── __init__.py
│       ├── models
│       │   ├── __init__.py
│       │   └── san.py
│       └── modules
│           ├── __init__.py
│           └── full_layer_lstm.py
├── requirements.txt
└── seq2seq
    ├── .gitignore
    ├── configs
    │   ├── common.jsonnet
    │   ├── composed_seq2seq.jsonnet
    │   └── simple_seq2seq.jsonnet
    ├── data
    │   └── dataset.py
    └── decoder.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .venv_book
2 | .vscode
3 | __pycache__
4 | .DS_Store
5 | .pytest_cache
6 | 
7 | */datasets/*
8 | !*/datasets/download.sh
9 | */tmp
10 | pretrain_bert/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AllenNLP入門
2 | 
3 | This repository hosts the source code for "AllenNLP入門" (Introduction to AllenNLP) by AnnanAI. The book is on sale at [Amazon](https://www.amazon.co.jp/dp/B08GLG39DF/ref=cm_sw_em_r_mt_dp_2pGsFbYJDFYJT) and [BOOTH](https://annan-ai.booth.pm/items/1881126).
4 | 

6 | 7 |

8 | 9 |

10 | 11 |

12 | 
13 | ## Table of Contents
14 | 
15 | - [Chapter 1: AllenNLP Tutorial](./ner-model)
16 | - [Chapter 2: Document Classification](./classifier-model)
17 | - [Chapter 3: Seq2Seq](./seq2seq)
18 | - [Chapter 4: Natural Language Inference](./nli)
19 | - [Chapter 5: Pretrained BERT](./ner-model)
20 | - [Chapter 6: Handling Japanese with AllenNLP](./jp-classifier-model)
21 | - [Chapter 7: MLflow Integration](./mlflow)
22 | 
23 | ## Authors
24 | 
25 | [@altescy](https://github.com/altescy), [@kajyuuen](https://github.com/kajyuuen)
--------------------------------------------------------------------------------
/allennlp-book-r-and-d.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/allennlp-book-r-and-d.jpg
--------------------------------------------------------------------------------
/allennlp-book.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/allennlp-book.jpeg
--------------------------------------------------------------------------------
/classifier-model/configs/experiment.jsonnet:
--------------------------------------------------------------------------------
1 | local embedding_dim = 10;
2 | local num_filters = 8;
3 | local output_dim = 16;
4 | local num_epochs = 1;
5 | local batch_size = 2;
6 | local learning_rate = 0.1;
7 | 
8 | {
9 |   dataset_reader: {
10 |     type: 'ag_news_reader',
11 |     token_indexers: {
12 |       tokens: {
13 |         type: 'single_id',
14 |         token_min_padding_length: num_filters
15 |       },
16 |     },
17 |   },
18 |   train_data_path: 'datasets/train.csv',
19 |   test_data_path: 'datasets/test.csv',
20 |   model: {
21 |     type: 'text_classifier',
22 |     word_embeddings: {
23 |       tokens: {
24 |         type: 'embedding',
25 |         embedding_dim: embedding_dim,
26 |         trainable: true
27 |       }
28 |     },
29 |     encoder: {
30 |       type: 'cnn',
31 |       num_filters: num_filters,
32 |       embedding_dim: embedding_dim,
33 |       output_dim: output_dim
34 |     }
35 |   },
36 |   iterator: {
37 |     type: 'bucket',
38 |     batch_size: batch_size,
39 |     sorting_keys: [['sentence', 'num_tokens']]
40 |   },
41 |   trainer: {
42 |     num_epochs: num_epochs,
43 |     optimizer: {
44 |       type: 'sgd',
45 |       lr: learning_rate
46 |     }
47 |   }
48 | }
--------------------------------------------------------------------------------
/classifier-model/datasets/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | curl -OL https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/train.csv
4 | curl -OL https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/test.csv
5 | 
--------------------------------------------------------------------------------
/classifier-model/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/classifier-model/src/__init__.py
--------------------------------------------------------------------------------
/classifier-model/src/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/classifier-model/src/data/__init__.py
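--------------------------------------------------------------------------------
NOTE: A minimal sketch of how the classifier-model pieces above fit together,
assuming AllenNLP 0.9 and that the commands are run from classifier-model/
(the serialization directory name tmp/classifier is arbitrary): download the
AG News CSVs with datasets/download.sh, then train from the config, registering
the custom reader and model via --include-package:

  $ (cd datasets && bash download.sh)
  $ allennlp train configs/experiment.jsonnet -s tmp/classifier --include-package src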
--------------------------------------------------------------------------------
/classifier-model/src/data/dataset_readers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/classifier-model/src/data/dataset_readers/__init__.py
--------------------------------------------------------------------------------
/classifier-model/src/data/dataset_readers/ag_news_reader.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Iterator
2 | from overrides import overrides
3 | 
4 | import csv
5 | 
6 | from allennlp.data import Instance
7 | from allennlp.data.tokenizers import Token
8 | from allennlp.data.dataset_readers import DatasetReader
9 | from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
10 | from allennlp.data.fields import TextField, LabelField
11 | 
12 | @DatasetReader.register('ag_news_reader')
13 | class AgNewsReader(DatasetReader):
14 |     def __init__(self, token_indexers: Dict[str, TokenIndexer] = None) -> None:
15 |         super().__init__(lazy=False)
16 |         self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
17 |         self.classes = ["World", "Sports", "Business", "Sci/Tech"]
18 | 
19 |     @overrides
20 |     def text_to_instance(self,
21 |                          tokens: List[Token],
22 |                          label: str = None) -> Instance:
23 |         sentence_field = TextField(tokens, self.token_indexers)
24 |         fields = {"sentence": sentence_field}
25 | 
26 |         if label:
27 |             label_field = LabelField(label)
28 |             fields["label"] = label_field
29 | 
30 |         return Instance(fields)
31 | 
32 |     @overrides
33 |     def _read(self, file_path: str) -> Iterator[Instance]:
34 |         with open(file_path, 'r') as f:
35 |             reader = csv.reader(f)
36 |             for row in reader:
37 |                 # Each AG News row is (class index 1-4, title, description)
38 |                 label, sentence = self.classes[int(row[0]) - 1], row[2]
39 |                 yield self.text_to_instance([Token(word) for word in sentence.split(" ")], label)
40 | 
41 | if __name__ == "__main__":
42 |     r = AgNewsReader()
43 |     dataset = r.read("datasets/test.csv")
44 |     print(dataset[0])
--------------------------------------------------------------------------------
/classifier-model/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/classifier-model/src/models/__init__.py
--------------------------------------------------------------------------------
/classifier-model/src/models/text_classifier.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | from overrides import overrides
3 | 
4 | import torch
5 | 
6 | from allennlp.models import Model
7 | 
8 | from allennlp.data.vocabulary import Vocabulary
9 | 
10 | from allennlp.modules.text_field_embedders import TextFieldEmbedder
11 | from allennlp.modules.seq2vec_encoders import Seq2VecEncoder
12 | 
13 | from allennlp.nn.util import get_text_field_mask
14 | 
15 | from allennlp.training.metrics import CategoricalAccuracy
16 | 
17 | @Model.register('text_classifier')
18 | class TextClassifier(Model):
19 |     def __init__(self,
20 |                  word_embeddings: TextFieldEmbedder,
21 |                  encoder: Seq2VecEncoder,
22 |                  vocab: Vocabulary) -> None:
23 |         super().__init__(vocab)
24 |         self.word_embeddings = word_embeddings
25 |         self.encoder = encoder
26 | 
27 |         self.classification_layer = torch.nn.Linear(in_features=encoder.get_output_dim(),
28 |                                                     out_features=vocab.get_vocab_size('labels'))
29 |         self._loss = torch.nn.CrossEntropyLoss()
30 |         self._accuracy = CategoricalAccuracy()
31 | 
32 |     @overrides
33 |     def forward(self,
34 | sentence: Dict[str, torch.Tensor], 35 | label: torch.IntTensor = None) -> Dict[str, torch.Tensor]: 36 | mask = get_text_field_mask(sentence) 37 | 38 | embeddings = self.word_embeddings(sentence) 39 | encoder_out = self.encoder(embeddings, mask) 40 | 41 | label_logit = self.classification_layer(encoder_out) 42 | output = {} 43 | 44 | if label is not None: 45 | self._accuracy(label_logit, label) 46 | output["loss"] = self._loss(label_logit, label.long()) 47 | 48 | return output 49 | 50 | @overrides 51 | def get_metrics(self, 52 | reset: bool = False) -> Dict[str, float]: 53 | return {"accuracy": self._accuracy.get_metric(reset)} 54 | 55 | 56 | -------------------------------------------------------------------------------- /jp-classifier-model/configs/experiment.jsonnet: -------------------------------------------------------------------------------- 1 | local embedding_dim = 10; 2 | local num_filters = 8; 3 | local output_dim = 16; 4 | local num_epochs = 100; 5 | local batch_size = 2; 6 | local learning_rate = 0.1; 7 | 8 | { 9 | dataset_reader: { 10 | type: 'livedoor_news_reader', 11 | token_indexers: { 12 | tokens: { 13 | type: 'single_id', 14 | token_min_padding_length: num_filters 15 | }, 16 | }, 17 | tokenizer: { 18 | type: 'janome', 19 | }, 20 | }, 21 | train_data_path: 'datasets/text/', 22 | model: { 23 | type: 'text_classifier', 24 | word_embeddings: { 25 | tokens: { 26 | type: 'embedding', 27 | embedding_dim: embedding_dim, 28 | trainable: true 29 | } 30 | }, 31 | encoder: { 32 | type: 'cnn', 33 | num_filters: num_filters, 34 | embedding_dim: embedding_dim, 35 | output_dim: output_dim 36 | } 37 | }, 38 | iterator: { 39 | type: 'bucket', 40 | batch_size: batch_size, 41 | sorting_keys: [['sentence', 'num_tokens']] 42 | }, 43 | trainer: { 44 | num_epochs: num_epochs, 45 | optimizer: { 46 | type: 'sgd', 47 | lr: learning_rate 48 | } 49 | } 50 | } -------------------------------------------------------------------------------- /jp-classifier-model/datasets/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | curl -OL https://www.rondhuit.com/download/ldcc-20140209.tar.gz 4 | tar -zxvf ldcc-20140209.tar.gz -------------------------------------------------------------------------------- /jp-classifier-model/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/jp-classifier-model/src/__init__.py -------------------------------------------------------------------------------- /jp-classifier-model/src/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/jp-classifier-model/src/data/__init__.py -------------------------------------------------------------------------------- /jp-classifier-model/src/data/dataset_readers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/jp-classifier-model/src/data/dataset_readers/__init__.py -------------------------------------------------------------------------------- /jp-classifier-model/src/data/dataset_readers/livedoor_news_reader.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Iterator 2 | from 
overrides import overrides
3 | 
4 | import os
5 | 
6 | from allennlp.data import Instance
7 | from allennlp.data.tokenizers import Token, Tokenizer
8 | from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
9 | from allennlp.data.dataset_readers import DatasetReader
10 | from allennlp.data.fields import TextField, LabelField
11 | 
12 | @DatasetReader.register("livedoor_news_reader")
13 | class LivedoorNewsReader(DatasetReader):
14 |     def __init__(self,
15 |                  tokenizer: Tokenizer,
16 |                  token_indexers: Dict[str, TokenIndexer] = None) -> None:
17 |         super().__init__(lazy=False)
18 |         self.tokenizer = tokenizer
19 |         self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
20 | 
21 |     @overrides
22 |     def text_to_instance(self,
23 |                          tokens: List[Token],
24 |                          label: str = None) -> Instance:
25 |         sentence_field = TextField(tokens, self.token_indexers)
26 |         fields = {"sentence": sentence_field}
27 | 
28 |         if label:
29 |             label_field = LabelField(label)
30 |             fields["label"] = label_field
31 | 
32 |         return Instance(fields)
33 | 
34 |     @overrides
35 |     def _read(self, path: str) -> Iterator[Instance]:
36 |         dirs_path = os.listdir(path)
37 |         category_dirs = [f for f in dirs_path if os.path.isdir(os.path.join(path, f))]
38 |         for category_dir in category_dirs:
39 |             file_dir_path = os.path.join(path, category_dir)
40 |             files = os.listdir(file_dir_path)
41 |             for i, file_name in enumerate(files):
42 |                 # Read up to 10 documents per category
43 |                 if i == 10:
44 |                     break
45 |                 with open(os.path.join(file_dir_path, file_name)) as f:
46 |                     text = f.read()
47 |                 label = category_dir
48 |                 # Tokenize with the injected tokenizer (e.g. janome) instead of
49 |                 # iterating over the raw string character by character
50 |                 yield self.text_to_instance(self.tokenizer.tokenize(text), label)
--------------------------------------------------------------------------------
/jp-classifier-model/src/data/tokenizers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/jp-classifier-model/src/data/tokenizers/__init__.py
--------------------------------------------------------------------------------
/jp-classifier-model/src/data/tokenizers/janome_tokenizer.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | 
3 | from overrides import overrides
4 | from janome.tokenizer import Tokenizer as JTokenizer
5 | 
6 | from allennlp.data.tokenizers.token import Token
7 | from allennlp.data.tokenizers.tokenizer import Tokenizer
8 | 
9 | @Tokenizer.register("janome")
10 | class JanomeTokenizer(Tokenizer):
11 |     def __init__(self) -> None:
12 |         self.tokenizer = JTokenizer()
13 |         super().__init__()
14 | 
15 |     @overrides
16 |     def tokenize(self, text: str) -> List[Token]:
17 |         # wakati=True yields surface strings, so wrap them in AllenNLP Token objects
18 |         return [Token(surface) for surface in self.tokenizer.tokenize(text, wakati=True)]
19 | 
20 | if __name__ == "__main__":
21 |     jt = JanomeTokenizer()
22 |     print(jt.tokenize("僕は古明地こいしちゃん!"))
--------------------------------------------------------------------------------
/jp-classifier-model/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/jp-classifier-model/src/models/__init__.py
--------------------------------------------------------------------------------
/jp-classifier-model/src/models/text_classifier.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | from overrides import overrides
3 | 
4 | import torch
5 
| 6 | from allennlp.models import Model 7 | 8 | from allennlp.data.vocabulary import Vocabulary 9 | 10 | from allennlp.modules.text_field_embedders import TextFieldEmbedder 11 | from allennlp.modules.seq2vec_encoders import Seq2VecEncoder 12 | 13 | from allennlp.nn.util import get_text_field_mask 14 | 15 | from allennlp.training.metrics import CategoricalAccuracy 16 | 17 | @Model.register('text_classifier') 18 | class TextClassifier(Model): 19 | def __init__(self, 20 | word_embeddings: TextFieldEmbedder, 21 | encoder: Seq2VecEncoder, 22 | vocab: Vocabulary) -> None: 23 | super().__init__(vocab) 24 | self.word_embeddings = word_embeddings 25 | self.encoder = encoder 26 | 27 | self.classification_layer = torch.nn.Linear(in_features = encoder.get_output_dim(), 28 | out_features = vocab.get_vocab_size('labels')) 29 | self._loss = torch.nn.CrossEntropyLoss() 30 | self._accuracy = CategoricalAccuracy() 31 | 32 | @overrides 33 | def forward(self, 34 | sentence: Dict[str, torch.Tensor], 35 | label: torch.IntTensor = None) -> Dict[str, torch.Tensor]: 36 | mask = get_text_field_mask(sentence) 37 | 38 | embeddings = self.word_embeddings(sentence) 39 | encoder_out = self.encoder(embeddings, mask) 40 | 41 | label_logit = self.classification_layer(encoder_out) 42 | output = {} 43 | 44 | if label is not None: 45 | self._accuracy(label_logit, label) 46 | output["loss"] = self._loss(label_logit, label.long()) 47 | 48 | return output 49 | 50 | @overrides 51 | def get_metrics(self, 52 | reset: bool = False) -> Dict[str, float]: 53 | return {"accuracy": self._accuracy.get_metric(reset)} 54 | 55 | 56 | -------------------------------------------------------------------------------- /mlflow/.gitignore: -------------------------------------------------------------------------------- 1 | mlruns 2 | -------------------------------------------------------------------------------- /mlflow/MLproject: -------------------------------------------------------------------------------- 1 | name: allennlp-mlflow 2 | 3 | conda_env: conda.yaml 4 | 5 | entry_points: 6 | train: 7 | parameters: 8 | config: path 9 | command: "PYTHONPATH=`pwd` python scripts/train.py --include-package src {config}" 10 | -------------------------------------------------------------------------------- /mlflow/README.md: -------------------------------------------------------------------------------- 1 | ### Dependencies 2 | 3 | - conda 4 | - mlflow 5 | - awscli 6 | - boto3 7 | -------------------------------------------------------------------------------- /mlflow/conda.yaml: -------------------------------------------------------------------------------- 1 | name: allennlp-mlflow 2 | channels: 3 | - defaults 4 | - anaconda 5 | dependencies: 6 | - python==3.7 7 | - pip: 8 | - allennlp==0.9.0 9 | - mlflow>=1.0 10 | - cloudpickle==1.2.2 11 | -------------------------------------------------------------------------------- /mlflow/configs/ner.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "conll2003", 4 | "tag_label": "ner", 5 | "coding_scheme": "BIOUL", 6 | "token_indexers": { 7 | "tokens": { 8 | "type": "single_id", 9 | "lowercase_tokens": true 10 | }, 11 | "token_characters": { 12 | "type": "characters", 13 | "min_padding_length": 3 14 | } 15 | } 16 | }, 17 | "train_data_path": "https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.train", 18 | "validation_data_path": 
"https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testa", 19 | "test_data_path": "https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testb", 20 | "model": { 21 | "type": "crf_tagger", 22 | "label_encoding": "BIOUL", 23 | "constrain_crf_decoding": true, 24 | "calculate_span_f1": true, 25 | "dropout": 0.5, 26 | "include_start_end_transitions": false, 27 | "text_field_embedder": { 28 | "token_embedders": { 29 | "tokens": { 30 | "type": "embedding", 31 | "embedding_dim": 50, 32 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.6B.50d.txt.gz", 33 | "trainable": true 34 | }, 35 | "token_characters": { 36 | "type": "character_encoding", 37 | "embedding": { 38 | "embedding_dim": 16 39 | }, 40 | "encoder": { 41 | "type": "cnn", 42 | "embedding_dim": 16, 43 | "num_filters": 128, 44 | "ngram_filter_sizes": [3], 45 | "conv_layer_activation": "relu" 46 | } 47 | } 48 | }, 49 | }, 50 | "encoder": { 51 | "type": "lstm", 52 | "input_size": 50 + 128, 53 | "hidden_size": 200, 54 | "num_layers": 2, 55 | "dropout": 0.5, 56 | "bidirectional": true 57 | }, 58 | }, 59 | "iterator": { 60 | "type": "basic", 61 | "batch_size": 64 62 | }, 63 | "trainer": { 64 | "type": "callback", 65 | "callbacks": [ 66 | "checkpoint", 67 | "validate", 68 | {"type": "track_metrics", "patience": 25, "validation_metric": "+f1-measure-overall"}, 69 | {"type": "gradient_norm_and_clip", "grad_norm": 5.0}, 70 | {"type": "mlflow_metrics", "should_log_learning_rate": true}, 71 | ], 72 | "optimizer": { 73 | "type": "adam", 74 | "lr": 0.001 75 | }, 76 | "num_epochs": 75, 77 | "cuda_device": 0 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /mlflow/configs/sequence_tagging.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": {"type": "sequence_tagging"}, 3 | "train_data_path": "https://raw.githubusercontent.com/allenai/allennlp/master/allennlp/tests/fixtures/data/sequence_tagging.tsv", 4 | "validation_data_path": "https://raw.githubusercontent.com/allenai/allennlp/master/allennlp/tests/fixtures/data/sequence_tagging.tsv", 5 | "model": { 6 | "type": "simple_tagger", 7 | "text_field_embedder": { 8 | "token_embedders": {"tokens": {"type": "embedding", "embedding_dim": 5}} 9 | }, 10 | "encoder": {"type": "lstm", "input_size": 5, "hidden_size": 7, "num_layers": 2} 11 | }, 12 | "trainer": { 13 | "type": "callback", 14 | "optimizer": {"type": "sgd", "lr": 0.01, "momentum": 0.9}, 15 | "num_epochs": 2, 16 | "callbacks": [ 17 | "checkpoint", 18 | "track_metrics", 19 | "validate", 20 | "mlflow_metrics" 21 | ] 22 | }, 23 | "iterator": {"type": "basic", "batch_size": 2} 24 | } 25 | -------------------------------------------------------------------------------- /mlflow/data/ner_data.json: -------------------------------------------------------------------------------- 1 | { 2 | "columns": ["sentence"], 3 | "data":[ 4 | ["Jim bought 300 shares of Acme Corp. 
in 2006."], 5 | ["They marched from the Houses of Parliament to a rally in Hyde Park."] 6 | ] 7 | } 8 | -------------------------------------------------------------------------------- /mlflow/scripts/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from urllib.parse import urlparse 4 | 5 | import allennlp 6 | from allennlp.commands.train import train_model 7 | from allennlp.common.params import Params 8 | from allennlp.common.util import import_submodules 9 | import mlflow 10 | import mlflow.pyfunc 11 | from mlflow.utils.file_utils import yaml 12 | 13 | 14 | class AllennlpPredictorWrapper(mlflow.pyfunc.PythonModel): 15 | def __init__(self, predictor_name: str = None): 16 | self._predictor_name = predictor_name 17 | 18 | def load_context(self, context): 19 | from allennlp.predictors import Predictor 20 | self.predictor = Predictor.from_path( 21 | context.artifacts["model_archive"], 22 | predictor_name=self._predictor_name, 23 | ) 24 | 25 | def predict(self, context, model_input): 26 | inputs = model_input.to_dict(orient="records") 27 | return self.predictor.predict_batch_json(inputs) 28 | 29 | 30 | if __name__ == "__main__": 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("param_path", type=str) 33 | parser.add_argument("-o", "--overrides", type=str, default="") 34 | parser.add_argument("--include-package", type=str, action="append", default=[]) 35 | parser.add_argument("--predictor", type=str, default=None) 36 | parser.add_argument("--model-path", type=str, default="allennlp_model") 37 | parser.add_argument("--conda-env", type=str, default="conda.yaml") 38 | args = parser.parse_args() 39 | 40 | for package_name in args.include_package: 41 | import_submodules(package_name) 42 | 43 | params = Params.from_file(args.param_path, args.overrides) 44 | 45 | with mlflow.start_run(): 46 | artifact_uri = urlparse(mlflow.get_artifact_uri()) 47 | if artifact_uri.scheme != "file": 48 | raise RuntimeError("scheme not supported: {artifact_uri.scheme}") 49 | 50 | serialization_dir = artifact_uri.path 51 | _model = train_model(params, serialization_dir) 52 | 53 | artifacts = { 54 | "model_archive": mlflow.get_artifact_uri("model.tar.gz") 55 | } 56 | with open(args.conda_env, "r") as f: 57 | conda_env = yaml.safe_load(f) 58 | 59 | mlflow.pyfunc.log_model( 60 | artifact_path=args.model_path, 61 | python_model=AllennlpPredictorWrapper(args.predictor), 62 | artifacts=artifacts, 63 | conda_env=conda_env, 64 | ) 65 | -------------------------------------------------------------------------------- /mlflow/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/mlflow/src/__init__.py -------------------------------------------------------------------------------- /mlflow/src/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/mlflow/src/training/__init__.py -------------------------------------------------------------------------------- /mlflow/src/training/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/mlflow/src/training/callbacks/__init__.py 
--------------------------------------------------------------------------------
/mlflow/src/training/callbacks/mlflow_metrics.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import time
4 | 
5 | import torch
6 | import mlflow
7 | from allennlp.common.checks import ConfigurationError
8 | from allennlp.models import Model
9 | from allennlp.training.callbacks.callback import Callback, handle_event
10 | from allennlp.training.callback_trainer import CallbackTrainer
11 | from allennlp.training.callbacks.events import Events
12 | from allennlp.training.metric_tracker import MetricTracker
13 | 
14 | 
15 | logger = logging.getLogger(__name__)
16 | 
17 | 
18 | @Callback.register("mlflow_metrics")
19 | class MlflowMetrics(Callback):
20 |     def __init__(self, should_log_learning_rate: bool = False) -> None:
21 |         self._should_log_learning_rate = should_log_learning_rate
22 | 
23 |     @staticmethod
24 |     def log_metric(key, value, step=None):
25 |         if mlflow.active_run() is None:
26 |             logger.warning("A new mlflow active run will be created.")
27 |         mlflow.log_metric(key, value, step)
28 | 
29 |     @handle_event(Events.BATCH_END, priority=100)
30 |     def end_of_batch(self, trainer: CallbackTrainer):
31 |         step = trainer.batch_num_total
32 |         # train_metrics already includes "loss", so this one loop logs everything
33 |         for key, value in trainer.train_metrics.items():
34 |             self.log_metric("batch/training_" + key, value, step)
35 | 
36 |         if self._should_log_learning_rate:
37 |             names = {param: name for name, param in trainer.model.named_parameters()}
38 |             for group in trainer.optimizer.param_groups:
39 |                 if "lr" not in group:
40 |                     continue
41 |                 rate = group["lr"]
42 |                 for param in group["params"]:
43 |                     effective_rate = rate * float(param.requires_grad)
44 |                     self.log_metric(
45 |                         "batch/learning_rate."
+ names[param], 46 | effective_rate, step 47 | ) 48 | 49 | @handle_event(Events.EPOCH_END, priority=110) 50 | def end_of_epoch(self, trainer: CallbackTrainer): 51 | epoch = trainer.epoch_number + 1 52 | training_elapsed_time = time.time() - trainer.training_start_time 53 | self.log_metric("training_duration_seconds", training_elapsed_time, epoch) 54 | 55 | for key, value in trainer.train_metrics.items(): 56 | self.log_metric("epoch/training_" + key, value, epoch) 57 | for key, value in trainer.val_metrics.items(): 58 | self.log_metric("epoch/validation_" + key, value, epoch) 59 | -------------------------------------------------------------------------------- /ner-model/configs/bert-experiment.jsonnet: -------------------------------------------------------------------------------- 1 | local embedding_dim = 768; 2 | local hidden_dim = 2; 3 | local num_epochs = 1; 4 | local batch_size = 2; 5 | local learning_rate = 0.1; 6 | local bert_path = "./pretrain_bert"; 7 | 8 | { 9 | dataset_reader: { 10 | type: "conll_2003_reader", 11 | token_indexers: { 12 | bert: { 13 | type: "bert-pretrained", 14 | pretrained_model: bert_path, 15 | do_lowercase: false 16 | }, 17 | }, 18 | }, 19 | train_data_path: "datasets/eng.train", 20 | validation_data_path: "datasets/eng.testa", 21 | model: { 22 | type: "ner_tagger", 23 | word_embeddings: { 24 | allow_unmatched_keys: true, 25 | bert: { 26 | type: "bert-pretrained", 27 | pretrained_model: bert_path, 28 | } 29 | }, 30 | encoder: { 31 | type: "lstm", 32 | input_size: embedding_dim, 33 | hidden_size: hidden_dim 34 | } 35 | }, 36 | iterator: { 37 | type: "bucket", 38 | batch_size: batch_size, 39 | sorting_keys: [["sentence", "num_tokens"]] 40 | }, 41 | trainer: { 42 | num_epochs: num_epochs, 43 | optimizer: { 44 | type: "sgd", 45 | lr: learning_rate 46 | } 47 | } 48 | } -------------------------------------------------------------------------------- /ner-model/configs/experiment.jsonnet: -------------------------------------------------------------------------------- 1 | local embedding_dim = 6; 2 | local hidden_dim = 2; 3 | local num_epochs = 1; 4 | local batch_size = 2; 5 | local learning_rate = 0.1; 6 | 7 | { 8 | dataset_reader: { 9 | type: 'conll_2003_reader', 10 | }, 11 | train_data_path: 'datasets/eng.train', 12 | validation_data_path: 'datasets/eng.testa', 13 | model: { 14 | type: 'ner_tagger', 15 | word_embeddings: { 16 | tokens: { 17 | type: 'embedding', 18 | embedding_dim: embedding_dim 19 | } 20 | }, 21 | encoder: { 22 | type: 'lstm', 23 | input_size: embedding_dim, 24 | hidden_size: hidden_dim 25 | } 26 | }, 27 | iterator: { 28 | type: 'bucket', 29 | batch_size: batch_size, 30 | sorting_keys: [['sentence', 'num_tokens']] 31 | }, 32 | trainer: { 33 | num_epochs: num_epochs, 34 | optimizer: { 35 | type: 'sgd', 36 | lr: learning_rate 37 | } 38 | } 39 | } -------------------------------------------------------------------------------- /ner-model/datasets/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | curl -OL https://github.com/synalp/NER/raw/master/corpus/CoNLL-2003/eng.train 4 | curl -OL https://github.com/synalp/NER/raw/master/corpus/CoNLL-2003/eng.testa 5 | curl -OL https://github.com/synalp/NER/raw/master/corpus/CoNLL-2003/eng.testb 6 | -------------------------------------------------------------------------------- /ner-model/src/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/ner-model/src/__init__.py -------------------------------------------------------------------------------- /ner-model/src/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/ner-model/src/data/__init__.py -------------------------------------------------------------------------------- /ner-model/src/data/dataset_readers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/ner-model/src/data/dataset_readers/__init__.py -------------------------------------------------------------------------------- /ner-model/src/data/dataset_readers/conll_2003_reader.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Iterator 2 | from overrides import overrides 3 | 4 | from allennlp.data import Instance 5 | from allennlp.data.tokenizers import Token 6 | from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer 7 | from allennlp.data.dataset_readers import DatasetReader 8 | from allennlp.data.fields import TextField, SequenceLabelField 9 | 10 | @DatasetReader.register("conll_2003_reader") 11 | class Conll2003Reader(DatasetReader): 12 | def __init__(self, token_indexers: Dict[str, TokenIndexer] = None) -> None: 13 | super().__init__(lazy=False) 14 | self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()} 15 | 16 | @overrides 17 | def text_to_instance(self, 18 | tokens: List[Token], 19 | tags: List[str] = None) -> Instance: 20 | sentence_field = TextField(tokens, self.token_indexers) 21 | fields = {"sentence": sentence_field} 22 | 23 | if tags: 24 | label_field = SequenceLabelField(labels = tags, 25 | sequence_field = sentence_field) 26 | fields["labels"] = label_field 27 | 28 | return Instance(fields) 29 | 30 | @overrides 31 | def _read(self, file_path: str) -> Iterator[Instance]: 32 | with open(file_path) as f: 33 | sentence, tags = [], [] 34 | for line in f: 35 | rows = line.strip().split() 36 | if len(rows) == 0: 37 | if len(sentence) > 0: 38 | yield self.text_to_instance([Token(word) for word in sentence], tags) 39 | sentence, tags = [], [] 40 | continue 41 | word, tag = rows[0], rows[3] 42 | sentence.append(word) 43 | tags.append(tag) -------------------------------------------------------------------------------- /ner-model/src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/ner-model/src/models/__init__.py -------------------------------------------------------------------------------- /ner-model/src/models/ner_tagger.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from overrides import overrides 3 | 4 | import torch 5 | 6 | from allennlp.models import Model 7 | 8 | from allennlp.data.vocabulary import Vocabulary 9 | 10 | from allennlp.modules.text_field_embedders import TextFieldEmbedder 11 | from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder 12 | 13 | from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits 14 | 15 | from allennlp.training.metrics import 
SpanBasedF1Measure 16 | 17 | @Model.register('ner_tagger') 18 | class NERTagger(Model): 19 | def __init__(self, 20 | word_embeddings: TextFieldEmbedder, 21 | encoder: Seq2SeqEncoder, 22 | vocab: Vocabulary) -> None: 23 | super().__init__(vocab) 24 | self.word_embeddings = word_embeddings 25 | self.encoder = encoder 26 | 27 | self.hidden2tag = torch.nn.Linear(in_features = encoder.get_output_dim(), 28 | out_features = vocab.get_vocab_size('labels')) 29 | 30 | self._f1_metric = SpanBasedF1Measure(vocab, 'labels') 31 | 32 | @overrides 33 | def forward(self, 34 | sentence: Dict[str, torch.Tensor], 35 | labels: torch.Tensor = None) -> Dict[str, torch.Tensor]: 36 | mask = get_text_field_mask(sentence) 37 | 38 | embeddings = self.word_embeddings(sentence) 39 | 40 | encoder_out = self.encoder(embeddings, mask) 41 | 42 | tag_logits = self.hidden2tag(encoder_out) 43 | output = {"tag_logits": tag_logits} 44 | 45 | if labels is not None: 46 | self._f1_metric(tag_logits, labels, mask) 47 | output["loss"] = sequence_cross_entropy_with_logits(tag_logits, labels, mask) 48 | 49 | return output 50 | 51 | @overrides 52 | def get_metrics(self, 53 | reset: bool = False) -> Dict[str, float]: 54 | return self._f1_metric.get_metric(reset=reset) 55 | 56 | -------------------------------------------------------------------------------- /ner-model/src/predictors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/ner-model/src/predictors/__init__.py -------------------------------------------------------------------------------- /ner-model/src/predictors/conll_2003_predictor.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | 3 | import numpy as np 4 | 5 | from allennlp.common.util import JsonDict, sanitize 6 | from allennlp.data import DatasetReader, Instance 7 | from allennlp.predictors.predictor import Predictor 8 | from allennlp.models import Model 9 | 10 | @Predictor.register("conll_2003_predictor") 11 | class CoNLL2003Predictor(Predictor): 12 | def __init__(self, model:Model, dataset_reader: DatasetReader) -> None: 13 | super().__init__(model, dataset_reader) 14 | 15 | @overrides 16 | def predict_instance(self, instance: Instance) -> JsonDict: 17 | outputs = self._model.forward_on_instance(instance) 18 | 19 | outputs["predicted_labels"] = [self._model.vocab.get_token_from_index(i, 'labels') for i in np.argmax(outputs["tag_logits"], axis=-1)] 20 | del outputs["tag_logits"] 21 | 22 | return sanitize(outputs) 23 | -------------------------------------------------------------------------------- /ner-model/src/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/ner-model/src/tests/__init__.py -------------------------------------------------------------------------------- /ner-model/src/tests/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/ner-model/src/tests/data/__init__.py -------------------------------------------------------------------------------- /ner-model/src/tests/data/dataset_readers/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/ner-model/src/tests/data/dataset_readers/__init__.py -------------------------------------------------------------------------------- /ner-model/src/tests/data/dataset_readers/conll_2003_reader_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from src.data.dataset_readers.conll_2003_reader import Conll2003Reader 4 | from allennlp.common.util import ensure_list 5 | 6 | class TestConll2003Reader: 7 | def test_read_from_file(self): 8 | conll_reader = Conll2003Reader() 9 | instances = conll_reader.read("./src/tests/fixtures/data/conll2003.txt") 10 | instances = ensure_list(instances) 11 | 12 | fields = instances[0].fields 13 | tokens = [t.text for t in fields["sentence"].tokens] 14 | assert tokens == ["U.N.", "official", "Ekeus", "heads", "for", "Baghdad", "."] 15 | assert fields["labels"].labels == ["I-ORG", "O", "I-PER", "O", "O", "I-LOC", "O"] 16 | 17 | -------------------------------------------------------------------------------- /ner-model/src/tests/fixtures/configs/experiment.jsonnet: -------------------------------------------------------------------------------- 1 | local embedding_dim = 6; 2 | local hidden_dim = 2; 3 | local num_epochs = 1; 4 | local batch_size = 2; 5 | local learning_rate = 0.1; 6 | 7 | { 8 | dataset_reader: { 9 | type: 'conll_2003_reader', 10 | }, 11 | train_data_path: './src/tests/fixtures/data/conll2003.txt', 12 | validation_data_path: './src/tests/fixtures/data/conll2003.txt', 13 | model: { 14 | type: 'ner_tagger', 15 | word_embeddings: { 16 | tokens: { 17 | type: 'embedding', 18 | embedding_dim: embedding_dim 19 | } 20 | }, 21 | encoder: { 22 | type: 'lstm', 23 | input_size: embedding_dim, 24 | hidden_size: hidden_dim 25 | } 26 | }, 27 | iterator: { 28 | type: 'bucket', 29 | batch_size: batch_size, 30 | sorting_keys: [['sentence', 'num_tokens']] 31 | }, 32 | trainer: { 33 | num_epochs: num_epochs, 34 | optimizer: { 35 | type: 'sgd', 36 | lr: learning_rate 37 | } 38 | } 39 | } -------------------------------------------------------------------------------- /ner-model/src/tests/fixtures/data/conll2003.txt: -------------------------------------------------------------------------------- 1 | U.N. NNP I-NP I-ORG 2 | official NN I-NP O 3 | Ekeus NNP I-NP I-PER 4 | heads VBZ I-VP O 5 | for IN I-PP O 6 | Baghdad NNP I-NP I-LOC 7 | . . 
O O 8 | 9 | -------------------------------------------------------------------------------- /ner-model/src/tests/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/ner-model/src/tests/models/__init__.py -------------------------------------------------------------------------------- /ner-model/src/tests/models/ner_tagger_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from src.models.ner_tagger import NERTagger 3 | 4 | from allennlp.models import Model 5 | from allennlp.common.testing import ModelTestCase 6 | 7 | class TestNERTagger(ModelTestCase): 8 | def setUp(self): 9 | super().setUp() 10 | self.set_up_model( 11 | "./src/tests/fixtures/configs/experiment.jsonnet", 12 | "./src/tests/fixtures/data/conll2003.txt" 13 | ) 14 | 15 | def test_train(self): 16 | self.ensure_model_can_train_save_and_load(self.param_file) 17 | -------------------------------------------------------------------------------- /nli/.gitignore: -------------------------------------------------------------------------------- 1 | /tmp 2 | -------------------------------------------------------------------------------- /nli/configs/esim.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "snli", 4 | "token_indexers": { 5 | "tokens": { 6 | "type": "single_id", 7 | "lowercase_tokens": true 8 | } 9 | } 10 | }, 11 | "train_data_path": "https://allennlp.s3.amazonaws.com/datasets/snli/snli_1.0_train.jsonl", 12 | "validation_data_path": "https://allennlp.s3.amazonaws.com/datasets/snli/snli_1.0_dev.jsonl", 13 | "model": { 14 | "type": "esim", 15 | "dropout": 0.5, 16 | "text_field_embedder": { 17 | "token_embedders": { 18 | "tokens": { 19 | "type": "embedding", 20 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 21 | "embedding_dim": 300, 22 | "trainable": true 23 | } 24 | } 25 | }, 26 | "encoder": { 27 | "type": "lstm", 28 | "input_size": 300, 29 | "hidden_size": 300, 30 | "num_layers": 1, 31 | "bidirectional": true 32 | }, 33 | "similarity_function": { 34 | "type": "dot_product" 35 | }, 36 | "projection_feedforward": { 37 | "input_dim": 2400, 38 | "hidden_dims": 300, 39 | "num_layers": 1, 40 | "activations": "relu" 41 | }, 42 | "inference_encoder": { 43 | "type": "lstm", 44 | "input_size": 300, 45 | "hidden_size": 300, 46 | "num_layers": 1, 47 | "bidirectional": true 48 | }, 49 | "output_feedforward": { 50 | "input_dim": 2400, 51 | "num_layers": 1, 52 | "hidden_dims": 300, 53 | "activations": "relu", 54 | "dropout": 0.5 55 | }, 56 | "output_logit": { 57 | "input_dim": 300, 58 | "num_layers": 1, 59 | "hidden_dims": 3, 60 | "activations": "linear" 61 | }, 62 | "initializer": [ 63 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 64 | [".*linear_layers.*bias", {"type": "zero"}], 65 | [".*weight_ih.*", {"type": "xavier_uniform"}], 66 | [".*weight_hh.*", {"type": "orthogonal"}], 67 | [".*bias_ih.*", {"type": "zero"}], 68 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 69 | ] 70 | }, 71 | "iterator": { 72 | "type": "bucket", 73 | "sorting_keys": [["premise", "num_tokens"], 74 | ["hypothesis", "num_tokens"]], 75 | "batch_size": 32 76 | }, 77 | "trainer": { 78 | "optimizer": { 79 | "type": "adam", 80 | "lr": 0.0004 81 | }, 82 | "validation_metric": "+accuracy", 83 | 
"num_serialized_models_to_keep": 2, 84 | "num_epochs": 75, 85 | "grad_norm": 10.0, 86 | "patience": 5, 87 | "cuda_device": 0, 88 | "learning_rate_scheduler": { 89 | "type": "reduce_on_plateau", 90 | "factor": 0.5, 91 | "mode": "max", 92 | "patience": 0 93 | } 94 | } 95 | } 96 | 97 | -------------------------------------------------------------------------------- /nli/configs/san.jsonnet: -------------------------------------------------------------------------------- 1 | local embedding_dim = 300; 2 | local hidden_dim = 300; 3 | local bidirectional = true; 4 | local num_layers = 2; 5 | local maxout = true; 6 | 7 | local num_directions = if bidirectional then 2 else 1; 8 | 9 | 10 | { 11 | "dataset_reader": { 12 | "type": "snli", 13 | "token_indexers": { 14 | "tokens": { 15 | "type": "single_id", 16 | "lowercase_tokens": true 17 | } 18 | } 19 | }, 20 | "train_data_path": "https://allennlp.s3.amazonaws.com/datasets/snli/snli_1.0_train.jsonl", 21 | "validation_data_path": "https://allennlp.s3.amazonaws.com/datasets/snli/snli_1.0_dev.jsonl", 22 | "model": { 23 | "type": "san", 24 | "dropout": 0.5, 25 | "text_field_embedder": { 26 | "token_embedders": { 27 | "tokens": { 28 | "type": "embedding", 29 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 30 | "embedding_dim": embedding_dim, 31 | "trainable": true, 32 | } 33 | } 34 | }, 35 | "lexical_feedforward": { 36 | "input_dim": embedding_dim, 37 | "hidden_dims": hidden_dim, 38 | "num_layers": 1, 39 | "activations": "relu" 40 | }, 41 | "contextual_encoder": { 42 | "type": "full-layer-lstm", 43 | "input_dim": hidden_dim, 44 | "hidden_dim": hidden_dim, 45 | "num_layers": num_layers, 46 | "bidirectional": bidirectional, 47 | "maxout": maxout, 48 | }, 49 | "attention_feedforward": { 50 | "input_dim": hidden_dim * num_layers * (if maxout then 1 else num_directions), 51 | "hidden_dims": hidden_dim, 52 | "num_layers": 1, 53 | "activations": "relu" 54 | }, 55 | "matrix_attention": { 56 | "type": "dot_product" 57 | }, 58 | "memory_encoder": { 59 | "type": "lstm", 60 | "input_size": 2 * hidden_dim * num_directions, 61 | "hidden_size": hidden_dim, 62 | "num_layers": 1, 63 | "bidirectional": bidirectional, 64 | }, 65 | "output_feedforward": { 66 | "input_dim": 4 * hidden_dim * num_directions, 67 | "num_layers": 1, 68 | "hidden_dims": hidden_dim, 69 | "activations": "relu", 70 | "dropout": 0.5 71 | }, 72 | "output_logit": { 73 | "input_dim": hidden_dim, 74 | "num_layers": 1, 75 | "hidden_dims": 3, 76 | "activations": "linear" 77 | }, 78 | "initializer": [ 79 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 80 | [".*linear_layers.*bias", {"type": "zero"}], 81 | [".*weight_ih.*", {"type": "xavier_uniform"}], 82 | [".*weight_hh.*", {"type": "orthogonal"}], 83 | [".*bias_ih.*", {"type": "zero"}], 84 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 85 | ] 86 | }, 87 | "iterator": { 88 | "type": "bucket", 89 | "sorting_keys": [["premise", "num_tokens"], 90 | ["hypothesis", "num_tokens"]], 91 | "batch_size": 32 92 | }, 93 | "trainer": { 94 | "optimizer": { 95 | "type": "adam", 96 | "lr": 0.0004 97 | }, 98 | "validation_metric": "+accuracy", 99 | "num_serialized_models_to_keep": 2, 100 | "num_epochs": 75, 101 | "grad_norm": 10.0, 102 | "patience": 5, 103 | "cuda_device": 0, 104 | "learning_rate_scheduler": { 105 | "type": "reduce_on_plateau", 106 | "factor": 0.5, 107 | "mode": "max", 108 | "patience": 0 109 | } 110 | } 111 | } 112 | -------------------------------------------------------------------------------- 
/nli/configs/san_test.jsonnet: -------------------------------------------------------------------------------- 1 | local embedding_dim = 300; 2 | local hidden_dim = 300; 3 | local bidirectional = true; 4 | 5 | local num_directions = if bidirectional then 2 else 1; 6 | 7 | { 8 | "random_seed": 1, 9 | "pytorch_seed": 1, 10 | "dataset_reader": { 11 | "type": "snli", 12 | "token_indexers": { 13 | "tokens": { 14 | "type": "single_id", 15 | "lowercase_tokens": true 16 | } 17 | } 18 | }, 19 | "train_data_path": "data/snli_test.jsonl", 20 | "validation_data_path": "data/snli_test.jsonl", 21 | "model": { 22 | "type": "san", 23 | "dropout": 0.5, 24 | "text_field_embedder": { 25 | "token_embedders": { 26 | "tokens": { 27 | "type": "embedding", 28 | "embedding_dim": embedding_dim, 29 | } 30 | } 31 | }, 32 | "lexical_feedforward": { 33 | "input_dim": embedding_dim, 34 | "hidden_dims": hidden_dim, 35 | "num_layers": 1, 36 | "activations": "relu" 37 | }, 38 | "contextual_encoder": { 39 | "type": "lstm", 40 | "input_size": hidden_dim, 41 | "hidden_size": hidden_dim, 42 | "num_layers": 2, 43 | "bidirectional": bidirectional, 44 | }, 45 | "attention_feedforward": { 46 | "input_dim": hidden_dim * num_directions, 47 | "hidden_dims": hidden_dim, 48 | "num_layers": 1, 49 | "activations": "relu" 50 | }, 51 | "matrix_attention": { 52 | "type": "dot_product" 53 | }, 54 | "memory_encoder": { 55 | "type": "lstm", 56 | "input_size": 2 * hidden_dim * num_directions, 57 | "hidden_size": hidden_dim, 58 | "num_layers": 1, 59 | "bidirectional": bidirectional, 60 | }, 61 | "output_feedforward": { 62 | "input_dim": 4 * hidden_dim * num_directions, 63 | "num_layers": 1, 64 | "hidden_dims": hidden_dim, 65 | "activations": "relu", 66 | "dropout": 0.5 67 | }, 68 | "output_logit": { 69 | "input_dim": hidden_dim, 70 | "num_layers": 1, 71 | "hidden_dims": 3, 72 | "activations": "linear" 73 | }, 74 | "initializer": [ 75 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 76 | [".*linear_layers.*bias", {"type": "zero"}], 77 | [".*weight_ih.*", {"type": "xavier_uniform"}], 78 | [".*weight_hh.*", {"type": "orthogonal"}], 79 | [".*bias_ih.*", {"type": "zero"}], 80 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 81 | ] 82 | }, 83 | "iterator": { 84 | "type": "bucket", 85 | "sorting_keys": [["premise", "num_tokens"], 86 | ["hypothesis", "num_tokens"]], 87 | "batch_size": 5 88 | }, 89 | "trainer": { 90 | "optimizer": { 91 | "type": "adam", 92 | "lr": 0.0004 93 | }, 94 | "validation_metric": "+accuracy", 95 | "num_serialized_models_to_keep": 2, 96 | "num_epochs": 5, 97 | "grad_norm": 10.0, 98 | "patience": 5, 99 | "cuda_device": 0, 100 | "learning_rate_scheduler": { 101 | "type": "reduce_on_plateau", 102 | "factor": 0.5, 103 | "mode": "max", 104 | "patience": 0 105 | } 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /nli/data/snli_test.jsonl: -------------------------------------------------------------------------------- 1 | {"annotator_labels": ["neutral", "entailment", "neutral", "neutral", "neutral"], "captionID": "4705552913.jpg#2", "gold_label": "neutral", "pairID": "4705552913.jpg#2r1n", "sentence1": "Two women are embracing while holding to go packages.", "sentence1_binary_parse": "( ( Two women ) ( ( are ( embracing ( while ( holding ( to ( go packages ) ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (CD Two) (NNS women)) (VP (VBP are) (VP (VBG embracing) (SBAR (IN while) (S (NP (VBG holding)) (VP (TO to) (VP (VB go) (NP (NNS packages)))))))) (. 
.)))", "sentence2": "The sisters are hugging goodbye while holding to go packages after just eating lunch.", "sentence2_binary_parse": "( ( The sisters ) ( ( are ( ( hugging goodbye ) ( while ( holding ( to ( ( go packages ) ( after ( just ( eating lunch ) ) ) ) ) ) ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT The) (NNS sisters)) (VP (VBP are) (VP (VBG hugging) (NP (UH goodbye)) (PP (IN while) (S (VP (VBG holding) (S (VP (TO to) (VP (VB go) (NP (NNS packages)) (PP (IN after) (S (ADVP (RB just)) (VP (VBG eating) (NP (NN lunch))))))))))))) (. .)))"} 2 | {"annotator_labels": ["entailment", "entailment", "entailment", "entailment", "entailment"], "captionID": "4705552913.jpg#2", "gold_label": "entailment", "pairID": "4705552913.jpg#2r1e", "sentence1": "Two women are embracing while holding to go packages.", "sentence1_binary_parse": "( ( Two women ) ( ( are ( embracing ( while ( holding ( to ( go packages ) ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (CD Two) (NNS women)) (VP (VBP are) (VP (VBG embracing) (SBAR (IN while) (S (NP (VBG holding)) (VP (TO to) (VP (VB go) (NP (NNS packages)))))))) (. .)))", "sentence2": "Two woman are holding packages.", "sentence2_binary_parse": "( ( Two woman ) ( ( are ( holding packages ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (CD Two) (NN woman)) (VP (VBP are) (VP (VBG holding) (NP (NNS packages)))) (. .)))"} 3 | {"annotator_labels": ["contradiction", "contradiction", "contradiction", "contradiction", "contradiction"], "captionID": "4705552913.jpg#2", "gold_label": "contradiction", "pairID": "4705552913.jpg#2r1c", "sentence1": "Two women are embracing while holding to go packages.", "sentence1_binary_parse": "( ( Two women ) ( ( are ( embracing ( while ( holding ( to ( go packages ) ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (CD Two) (NNS women)) (VP (VBP are) (VP (VBG embracing) (SBAR (IN while) (S (NP (VBG holding)) (VP (TO to) (VP (VB go) (NP (NNS packages)))))))) (. .)))", "sentence2": "The men are fighting outside a deli.", "sentence2_binary_parse": "( ( The men ) ( ( are ( fighting ( outside ( a deli ) ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT The) (NNS men)) (VP (VBP are) (VP (VBG fighting) (PP (IN outside) (NP (DT a) (NNS deli))))) (. .)))"} 4 | {"annotator_labels": ["entailment", "entailment", "entailment", "entailment", "entailment"], "captionID": "2407214681.jpg#0", "gold_label": "entailment", "pairID": "2407214681.jpg#0r1e", "sentence1": "Two young children in blue jerseys, one with the number 9 and one with the number 2 are standing on wooden steps in a bathroom and washing their hands in a sink.", "sentence1_binary_parse": "( ( ( Two ( young children ) ) ( in ( ( ( ( ( blue jerseys ) , ) ( one ( with ( the ( number 9 ) ) ) ) ) and ) ( one ( with ( the ( number 2 ) ) ) ) ) ) ) ( ( are ( ( ( standing ( on ( ( wooden steps ) ( in ( a bathroom ) ) ) ) ) and ) ( ( washing ( their hands ) ) ( in ( a sink ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (CD Two) (JJ young) (NNS children)) (PP (IN in) (NP (NP (JJ blue) (NNS jerseys)) (, ,) (NP (NP (CD one)) (PP (IN with) (NP (DT the) (NN number) (CD 9)))) (CC and) (NP (NP (CD one)) (PP (IN with) (NP (DT the) (NN number) (CD 2))))))) (VP (VBP are) (VP (VP (VBG standing) (PP (IN on) (NP (NP (JJ wooden) (NNS steps)) (PP (IN in) (NP (DT a) (NN bathroom)))))) (CC and) (VP (VBG washing) (NP (PRP$ their) (NNS hands)) (PP (IN in) (NP (DT a) (NN sink)))))) (. 
.)))", "sentence2": "Two kids in numbered jerseys wash their hands.", "sentence2_binary_parse": "( ( ( Two kids ) ( in ( numbered jerseys ) ) ) ( ( wash ( their hands ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (NP (CD Two) (NNS kids)) (PP (IN in) (NP (JJ numbered) (NNS jerseys)))) (VP (VBP wash) (NP (PRP$ their) (NNS hands))) (. .)))"} 5 | {"annotator_labels": ["neutral", "neutral", "neutral", "entailment", "entailment"], "captionID": "2407214681.jpg#0", "gold_label": "neutral", "pairID": "2407214681.jpg#0r1n", "sentence1": "Two young children in blue jerseys, one with the number 9 and one with the number 2 are standing on wooden steps in a bathroom and washing their hands in a sink.", "sentence1_binary_parse": "( ( ( Two ( young children ) ) ( in ( ( ( ( ( blue jerseys ) , ) ( one ( with ( the ( number 9 ) ) ) ) ) and ) ( one ( with ( the ( number 2 ) ) ) ) ) ) ) ( ( are ( ( ( standing ( on ( ( wooden steps ) ( in ( a bathroom ) ) ) ) ) and ) ( ( washing ( their hands ) ) ( in ( a sink ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (CD Two) (JJ young) (NNS children)) (PP (IN in) (NP (NP (JJ blue) (NNS jerseys)) (, ,) (NP (NP (CD one)) (PP (IN with) (NP (DT the) (NN number) (CD 9)))) (CC and) (NP (NP (CD one)) (PP (IN with) (NP (DT the) (NN number) (CD 2))))))) (VP (VBP are) (VP (VP (VBG standing) (PP (IN on) (NP (NP (JJ wooden) (NNS steps)) (PP (IN in) (NP (DT a) (NN bathroom)))))) (CC and) (VP (VBG washing) (NP (PRP$ their) (NNS hands)) (PP (IN in) (NP (DT a) (NN sink)))))) (. .)))", "sentence2": "Two kids at a ballgame wash their hands.", "sentence2_binary_parse": "( ( ( Two kids ) ( at ( a ballgame ) ) ) ( ( wash ( their hands ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (NP (CD Two) (NNS kids)) (PP (IN at) (NP (DT a) (NN ballgame)))) (VP (VBP wash) (NP (PRP$ their) (NNS hands))) (. .)))"} 6 | {"annotator_labels": ["contradiction", "contradiction", "contradiction", "contradiction", "contradiction"], "captionID": "2407214681.jpg#0", "gold_label": "contradiction", "pairID": "2407214681.jpg#0r1c", "sentence1": "Two young children in blue jerseys, one with the number 9 and one with the number 2 are standing on wooden steps in a bathroom and washing their hands in a sink.", "sentence1_binary_parse": "( ( ( Two ( young children ) ) ( in ( ( ( ( ( blue jerseys ) , ) ( one ( with ( the ( number 9 ) ) ) ) ) and ) ( one ( with ( the ( number 2 ) ) ) ) ) ) ) ( ( are ( ( ( standing ( on ( ( wooden steps ) ( in ( a bathroom ) ) ) ) ) and ) ( ( washing ( their hands ) ) ( in ( a sink ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (CD Two) (JJ young) (NNS children)) (PP (IN in) (NP (NP (JJ blue) (NNS jerseys)) (, ,) (NP (NP (CD one)) (PP (IN with) (NP (DT the) (NN number) (CD 9)))) (CC and) (NP (NP (CD one)) (PP (IN with) (NP (DT the) (NN number) (CD 2))))))) (VP (VBP are) (VP (VP (VBG standing) (PP (IN on) (NP (NP (JJ wooden) (NNS steps)) (PP (IN in) (NP (DT a) (NN bathroom)))))) (CC and) (VP (VBG washing) (NP (PRP$ their) (NNS hands)) (PP (IN in) (NP (DT a) (NN sink)))))) (. .)))", "sentence2": "Two kids in jackets walk to school.", "sentence2_binary_parse": "( ( ( Two kids ) ( in jackets ) ) ( ( walk ( to school ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (NP (CD Two) (NNS kids)) (PP (IN in) (NP (NNS jackets)))) (VP (VBP walk) (PP (TO to) (NP (NN school)))) (. 
.)))"} 7 | {"annotator_labels": ["contradiction", "contradiction", "contradiction", "contradiction", "contradiction"], "captionID": "4718146904.jpg#2", "gold_label": "contradiction", "pairID": "4718146904.jpg#2r1c", "sentence1": "A man selling donuts to a customer during a world exhibition event held in the city of Angeles", "sentence1_binary_parse": "( ( A ( man selling ) ) ( ( donuts ( to ( a customer ) ) ) ( during ( ( a ( world ( exhibition event ) ) ) ( held ( in ( ( the city ) ( of Angeles ) ) ) ) ) ) ) )", "sentence1_parse": "(ROOT (S (NP (DT A) (NN man) (NN selling)) (VP (VBZ donuts) (PP (TO to) (NP (DT a) (NN customer))) (PP (IN during) (NP (NP (DT a) (NN world) (NN exhibition) (NN event)) (VP (VBN held) (PP (IN in) (NP (NP (DT the) (NN city)) (PP (IN of) (NP (NNP Angeles)))))))))))", "sentence2": "A woman drinks her coffee in a small cafe.", "sentence2_binary_parse": "( ( A woman ) ( ( ( drinks ( her coffee ) ) ( in ( a ( small cafe ) ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN woman)) (VP (VBZ drinks) (NP (PRP$ her) (NN coffee)) (PP (IN in) (NP (DT a) (JJ small) (NN cafe)))) (. .)))"} 8 | {"annotator_labels": ["neutral", "entailment", "entailment", "neutral", "neutral"], "captionID": "4718146904.jpg#2", "gold_label": "neutral", "pairID": "4718146904.jpg#2r1n", "sentence1": "A man selling donuts to a customer during a world exhibition event held in the city of Angeles", "sentence1_binary_parse": "( ( A ( man selling ) ) ( ( donuts ( to ( a customer ) ) ) ( during ( ( a ( world ( exhibition event ) ) ) ( held ( in ( ( the city ) ( of Angeles ) ) ) ) ) ) ) )", "sentence1_parse": "(ROOT (S (NP (DT A) (NN man) (NN selling)) (VP (VBZ donuts) (PP (TO to) (NP (DT a) (NN customer))) (PP (IN during) (NP (NP (DT a) (NN world) (NN exhibition) (NN event)) (VP (VBN held) (PP (IN in) (NP (NP (DT the) (NN city)) (PP (IN of) (NP (NNP Angeles)))))))))))", "sentence2": "A man selling donuts to a customer during a world exhibition event while people wait in line behind him.", "sentence2_binary_parse": "( ( A ( man selling ) ) ( ( ( donuts ( to ( ( a customer ) ( during ( a ( world ( exhibition event ) ) ) ) ) ) ) ( while ( people ( ( wait ( in line ) ) ( behind him ) ) ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN man) (NN selling)) (VP (VBZ donuts) (PP (TO to) (NP (NP (DT a) (NN customer)) (PP (IN during) (NP (DT a) (NN world) (NN exhibition) (NN event))))) (SBAR (IN while) (S (NP (NNS people)) (VP (VBP wait) (PP (IN in) (NP (NN line))) (PP (IN behind) (NP (PRP him))))))) (. .)))"} 9 | {"annotator_labels": ["entailment", "neutral", "entailment", "entailment", "entailment"], "captionID": "4718146904.jpg#2", "gold_label": "entailment", "pairID": "4718146904.jpg#2r1e", "sentence1": "A man selling donuts to a customer during a world exhibition event held in the city of Angeles", "sentence1_binary_parse": "( ( A ( man selling ) ) ( ( donuts ( to ( a customer ) ) ) ( during ( ( a ( world ( exhibition event ) ) ) ( held ( in ( ( the city ) ( of Angeles ) ) ) ) ) ) ) )", "sentence1_parse": "(ROOT (S (NP (DT A) (NN man) (NN selling)) (VP (VBZ donuts) (PP (TO to) (NP (DT a) (NN customer))) (PP (IN during) (NP (NP (DT a) (NN world) (NN exhibition) (NN event)) (VP (VBN held) (PP (IN in) (NP (NP (DT the) (NN city)) (PP (IN of) (NP (NNP Angeles)))))))))))", "sentence2": "A man selling donuts to a customer.", "sentence2_binary_parse": "( ( A ( man selling ) ) ( ( donuts ( to ( a customer ) ) ) . 
) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN man) (NN selling)) (VP (VBZ donuts) (PP (TO to) (NP (DT a) (NN customer)))) (. .)))"} 10 | {"annotator_labels": ["entailment", "neutral", "entailment", "entailment", "neutral"], "captionID": "3980085662.jpg#0", "gold_label": "entailment", "pairID": "3980085662.jpg#0r1e", "sentence1": "Two young boys of opposing teams play football, while wearing full protection uniforms and helmets.", "sentence1_binary_parse": "( ( ( Two ( young boys ) ) ( of ( opposing teams ) ) ) ( ( ( ( play football ) , ) ( while ( wearing ( full ( protection ( ( uniforms and ) helmets ) ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (CD Two) (JJ young) (NNS boys)) (PP (IN of) (NP (VBG opposing) (NNS teams)))) (VP (VBP play) (NP (NN football)) (, ,) (PP (IN while) (S (VP (VBG wearing) (NP (JJ full) (NN protection) (NNS uniforms) (CC and) (NNS helmets)))))) (. .)))", "sentence2": "boys play football", "sentence2_binary_parse": "( boys ( play football ) )", "sentence2_parse": "(ROOT (S (NP (NNS boys)) (VP (VBP play) (NP (NN football)))))"} 11 | -------------------------------------------------------------------------------- /nli/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/nli/src/__init__.py -------------------------------------------------------------------------------- /nli/src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/nli/src/models/__init__.py -------------------------------------------------------------------------------- /nli/src/models/san.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, cast 2 | 3 | import torch 4 | from allennlp.common.checks import check_dimensions_match 5 | from allennlp.data import Vocabulary 6 | from allennlp.models import Model 7 | from allennlp.modules import ( 8 | FeedForward, MatrixAttention, Seq2SeqEncoder, TextFieldEmbedder, TimeDistributed 9 | ) 10 | from allennlp.modules.attention import BilinearAttention 11 | from allennlp.training.metrics import CategoricalAccuracy 12 | from allennlp.nn import InitializerApplicator, RegularizerApplicator 13 | from allennlp.nn import util 14 | 15 | 16 | @Model.register("san") 17 | class StochasticAnswerNetworks(Model): 18 | def __init__(self, 19 | vocab: Vocabulary, 20 | text_field_embedder: TextFieldEmbedder, 21 | lexical_feedforward: FeedForward, 22 | contextual_encoder: Seq2SeqEncoder, 23 | attention_feedforward: FeedForward, 24 | matrix_attention: MatrixAttention, 25 | memory_encoder: Seq2SeqEncoder, 26 | output_feedforward: FeedForward, 27 | output_logit: FeedForward, 28 | answer_steps: int = 5, 29 | dropout: float = 0.5, 30 | initializer: InitializerApplicator = InitializerApplicator(), 31 | regularizer: Optional[RegularizerApplicator] = None) -> None: 32 | super().__init__(vocab, regularizer) 33 | 34 | self._text_field_embedder = text_field_embedder 35 | self._lexical_feedforward = TimeDistributed(lexical_feedforward) 36 | self._contextual_encoder = contextual_encoder 37 | self._attention_feedforward = TimeDistributed(attention_feedforward) 38 | self._matrix_attention = matrix_attention 39 | self._memory_encoder = memory_encoder 40 | self._output_feedforward = output_feedforward 41 | 
self._output_logit = output_logit 42 | self._answer_steps = answer_steps 43 | self._answer_gru_cell = torch.nn.GRUCell( 44 | self._memory_encoder.get_output_dim(), 45 | self._memory_encoder.get_output_dim(), 46 | ) 47 | self._answer_attention = TimeDistributed( 48 | torch.nn.Linear(self._memory_encoder.get_output_dim(), 1) 49 | ) 50 | self._answer_bilinear = BilinearAttention( 51 | self._memory_encoder.get_output_dim(), 52 | self._memory_encoder.get_output_dim(), 53 | ) 54 | 55 | check_dimensions_match(text_field_embedder.get_output_dim(), lexical_feedforward.get_input_dim(), 56 | "text field embedding dim", "lexical feedforward input dim") 57 | check_dimensions_match(lexical_feedforward.get_output_dim(), contextual_encoder.get_input_dim(), 58 | "lexical feedforward output dim", "contextual layer input dim") 59 | check_dimensions_match(contextual_encoder.get_output_dim(), attention_feedforward.get_input_dim(), 60 | "contextual layer output dim", "attention feedforward input dim") 61 | check_dimensions_match(contextual_encoder.get_output_dim() * 2, memory_encoder.get_input_dim(), 62 | "contextual layer output dim * 2", "memory encoder input dim") 63 | check_dimensions_match(memory_encoder.get_output_dim() * 4, output_feedforward.get_input_dim(), 64 | "memory encoder output dim * 4", "output feedforward input dim") 65 | check_dimensions_match(output_feedforward.get_output_dim(), output_logit.get_input_dim(), 66 | "output feedforward output dim", "output logit input dim") 67 | 68 | self._dropout = torch.nn.Dropout(dropout) if dropout else None 69 | 70 | self._accuracy = CategoricalAccuracy() 71 | self._loss = torch.nn.NLLLoss() 72 | 73 | initializer(self) 74 | 75 | def forward(self, # type: ignore 76 | premise: Dict[str, torch.LongTensor], 77 | hypothesis: Dict[str, torch.LongTensor], 78 | label: torch.IntTensor = None, 79 | metadata: List[Dict[str, Any]] = None # pylint: disable=unused-argument 80 | ) -> Dict[str, torch.Tensor]: 81 | # pylint: disable=arguments-differ 82 | premise_embeddings = self._text_field_embedder(premise) 83 | hypothesis_embeddings = self._text_field_embedder(hypothesis) 84 | 85 | premise_mask = util.get_text_field_mask(premise).float() 86 | hypothesis_mask = util.get_text_field_mask(hypothesis).float() 87 | 88 | # Lexicon Encoding Layer 89 | premise_lexical_embeddings = self._lexical_feedforward(premise_embeddings) 90 | hypothesis_lexical_embeddings = self._lexical_feedforward(hypothesis_embeddings) 91 | 92 | # Contextual Encoding Layer 93 | encoded_premise = self._contextual_encoder( 94 | premise_lexical_embeddings, premise_mask 95 | ) 96 | encoded_hypothesis = self._contextual_encoder( 97 | hypothesis_lexical_embeddings, hypothesis_mask 98 | ) 99 | 100 | # Memory Layer 101 | premise_memory, hypothesis_memory = self._compute_memory( 102 | encoded_premise, encoded_hypothesis, 103 | premise_mask, hypothesis_mask, 104 | ) 105 | 106 | # Answer Module 107 | label_probs = self._compute_answer( 108 | premise_memory, hypothesis_memory, 109 | premise_mask, hypothesis_mask 110 | ) 111 | 112 | output_dict = {"label_probs": label_probs} 113 | 114 | if label is not None: 115 | label_log_probs = (label_probs + 1e-45).log() 116 | loss = self._loss(label_log_probs, label.long().view(-1)) 117 | self._accuracy(label_probs, label) 118 | output_dict["loss"] = loss 119 | 120 | return output_dict 121 | 122 | def get_metrics(self, reset: bool = False) -> Dict[str, float]: 123 | return {'accuracy': self._accuracy.get_metric(reset)} 124 | 125 | def _compute_memory( 126 | self, 127 | encoded_premise:
torch.Tensor, 128 | encoded_hypothesis: torch.Tensor, 129 | premise_mask: torch.Tensor, 130 | hypothesis_mask: torch.Tensor 131 | ) -> Tuple[torch.Tensor, torch.Tensor]: 132 | # Shape: (batch_size, premise_length, hypothesis_length) 133 | attention_matrix = self._matrix_attention( 134 | self._attention_feedforward(encoded_premise), 135 | self._attention_feedforward(encoded_hypothesis), 136 | ) 137 | 138 | if self._dropout: 139 | attention_matrix = self._dropout(attention_matrix) 140 | 141 | # Shape: (batch_size, premise_length, hypothesis_length) 142 | p2h_attention = util.masked_softmax(attention_matrix, hypothesis_mask) 143 | # Shape: (batch_size, premise_length, embedding_dim) 144 | attended_hypothesis = util.weighted_sum(encoded_hypothesis, p2h_attention) 145 | 146 | # Shape: (batch_size, hypothesis_length, premise_length) 147 | h2p_attention = util.masked_softmax( 148 | attention_matrix.transpose(1, 2).contiguous(), premise_mask) 149 | # Shape: (batch_size, hypothesis_length, embedding_dim) 150 | attended_premise = util.weighted_sum(encoded_premise, h2p_attention) 151 | 152 | premise_memory = self._memory_encoder( 153 | torch.cat([encoded_premise, attended_hypothesis], dim=-1), 154 | premise_mask, 155 | ) 156 | hypothesis_memory = self._memory_encoder( 157 | torch.cat([encoded_hypothesis, attended_premise], dim=-1), 158 | hypothesis_mask, 159 | ) 160 | 161 | return premise_memory, hypothesis_memory 162 | 163 | def _compute_answer(self, 164 | premise_memory: torch.Tensor, 165 | hypothesis_memory: torch.Tensor, 166 | premise_mask: torch.Tensor, 167 | hypothesis_mask: torch.Tensor) -> torch.Tensor: 168 | batch_size = premise_memory.size(0) 169 | num_labels = self._output_logit.get_output_dim() 170 | 171 | # Shape: (batch_size, hypothesis_length) 172 | hypothesis_attention = util.masked_softmax( 173 | self._answer_attention(hypothesis_memory).squeeze(-1), # squeeze only the trailing attention dim; a bare .squeeze() breaks when batch size or hypothesis length is 1 174 | hypothesis_mask, 175 | ) 176 | # Shape: (batch_size, embedding_dim) 177 | answer_state = util.weighted_sum(hypothesis_memory, hypothesis_attention) 178 | 179 | label_prob_steps: torch.Tensor = answer_state.new_zeros( 180 | (batch_size, num_labels, self._answer_steps) 181 | ) 182 | for step in range(self._answer_steps): 183 | # Shape: (batch_size, premise_length) 184 | premise_attention = self._answer_bilinear(answer_state, premise_memory, premise_mask) 185 | # Shape: (batch_size, embedding_dim) 186 | cell_input = util.weighted_sum(premise_memory, premise_attention) 187 | 188 | answer_state = self._answer_gru_cell(cell_input, answer_state) 189 | 190 | output_hidden = torch.cat([ 191 | answer_state, 192 | cell_input, 193 | (answer_state - cell_input).abs(), 194 | answer_state * cell_input, 195 | ], dim=-1) 196 | label_logits = self._output_logit(self._output_feedforward(output_hidden)) 197 | label_prob_steps[:, :, step] = label_logits.softmax(-1) 198 | 199 | if self.training and self._dropout: 200 | # stochastic prediction dropout 201 | binary_mask = ( 202 | torch.rand((batch_size, self._answer_steps)) > self._dropout.p 203 | ).to(label_prob_steps.device) 204 | label_probs = util.masked_mean( 205 | label_prob_steps, binary_mask.float().unsqueeze(1), dim=2 206 | ) 207 | label_probs = util.replace_masked_values( 208 | label_probs, binary_mask.sum(1, keepdim=True).bool().float(), 1.0 / num_labels 209 | ) 210 | else: 211 | label_probs = label_prob_steps.mean(2) 212 | 213 | return label_probs 214 | -------------------------------------------------------------------------------- /nli/src/modules/__init__.py:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kajyuuen/allennlp-book/e3df88ef284e24e5ae0d9cfa3527c255c05a25e6/nli/src/modules/__init__.py -------------------------------------------------------------------------------- /nli/src/modules/full_layer_lstm.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper, Seq2SeqEncoder 5 | 6 | 7 | @Seq2SeqEncoder.register("full-layer-lstm") 8 | class FullLayerLSTM(Seq2SeqEncoder): 9 | def __init__(self, 10 | input_dim: int, 11 | hidden_dim: int, 12 | num_layers: int = 2, 13 | bias: bool = True, 14 | dropout: float = 0.0, 15 | bidirectional: bool = False, 16 | maxout: bool = False) -> None: 17 | super().__init__() 18 | self._input_dim = input_dim 19 | self._hidden_dim = hidden_dim 20 | self._num_layers = num_layers 21 | self._maxout = maxout 22 | 23 | self._num_directions = 2 if bidirectional else 1 24 | 25 | self._lstm_layers = [ 26 | PytorchSeq2SeqWrapper(torch.nn.LSTM( 27 | input_dim, hidden_dim, num_layers=1, 28 | bias=bias, dropout=dropout, bidirectional=bidirectional, 29 | batch_first=True, 30 | )) 31 | ] 32 | if self._num_layers > 1: 33 | for _ in range(1, self._num_layers): 34 | self._lstm_layers.append( 35 | PytorchSeq2SeqWrapper(torch.nn.LSTM( 36 | self._num_directions * hidden_dim, hidden_dim, num_layers=1, 37 | bias=bias, dropout=dropout, bidirectional=bidirectional, 38 | batch_first=True, 39 | )) 40 | ) 41 | for i, lstm_layer in enumerate(self._lstm_layers): 42 | self.add_module('lstm_layer_%d' % i, lstm_layer) 43 | 44 | def forward(self, inputs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: 45 | # pylint: disable=arguments-differ 46 | batch_size, sequence_length, _embedding_dim = inputs.size() 47 | lstm_input = inputs 48 | 49 | lstm_outputs: List[torch.Tensor] = [] 50 | for lstm_layer in self._lstm_layers: 51 | lstm_output = lstm_layer(lstm_input, mask) 52 | lstm_outputs.append(lstm_output) 53 | lstm_input = lstm_output 54 | 55 | if self._maxout: 56 | for i, lstm_output in enumerate(lstm_outputs): 57 | lstm_outputs[i] = lstm_output.view( 58 | batch_size, sequence_length, self._hidden_dim, self._num_directions 59 | ).max(-1)[0] 60 | 61 | output = torch.cat(lstm_outputs, dim=-1) 62 | return output 63 | 64 | def get_input_dim(self) -> int: 65 | return self._input_dim 66 | 67 | def get_output_dim(self) -> int: 68 | if self._maxout: 69 | return self._hidden_dim * self._num_layers 70 | return self._hidden_dim * self._num_layers * self._num_directions 71 | 72 | def is_bidirectional(self) -> bool: 73 | return self._num_directions > 1 74 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | allennlp==0.9.0 2 | awscli>=1.17.0 3 | boto3 4 | Janome==0.3.10 5 | mlflow==1.6.0 6 | pytest 7 | -------------------------------------------------------------------------------- /seq2seq/.gitignore: -------------------------------------------------------------------------------- 1 | /tmp 2 | /data/* 3 | !/data/dataset.py 4 | -------------------------------------------------------------------------------- /seq2seq/configs/common.jsonnet: -------------------------------------------------------------------------------- 1 | local train_data_path = "data/train.tsv"; 2 | local validation_data_path = "data/valid.tsv"; 3 
| 4 | local dataset_reader = { 5 | "type": "seq2seq", 6 | "source_tokenizer": { 7 | "type": "character" 8 | }, 9 | "target_tokenizer": { 10 | "type": "character" 11 | }, 12 | "source_token_indexers": { 13 | "tokens": { 14 | "type": "single_id", 15 | "namespace": "source_tokens" 16 | } 17 | }, 18 | "target_token_indexers": { 19 | "tokens": { 20 | "namespace": "target_tokens" 21 | } 22 | } 23 | }; 24 | 25 | { 26 | "dataset_reader": dataset_reader, 27 | "train_data_path": train_data_path, 28 | "validation_data_path": validation_data_path, 29 | "iterator": { 30 | "type": "bucket", 31 | "batch_size" : 100, 32 | "sorting_keys": [["source_tokens", "num_tokens"]] 33 | }, 34 | "trainer": { 35 | "num_epochs": 100, 36 | "patience": 10, 37 | "cuda_device": 0, 38 | "optimizer": { 39 | "type": "adam", 40 | "lr": 0.01 41 | } 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /seq2seq/configs/composed_seq2seq.jsonnet: -------------------------------------------------------------------------------- 1 | local embedding_dim = 10; 2 | local hidden_dim = 16; 3 | local num_layers = 2; 4 | local num_attention_heads = 4; 5 | local projection_dim = hidden_dim; 6 | local feedforward_hidden_dim = hidden_dim * 2; 7 | 8 | local model = { 9 | "type": "composed_seq2seq", 10 | "source_text_embedder": { 11 | "token_embedders": { 12 | "tokens": { 13 | "type": "embedding", 14 | "vocab_namespace": "source_tokens", 15 | "embedding_dim": embedding_dim, 16 | "trainable": true 17 | } 18 | } 19 | }, 20 | "encoder": { 21 | "type": "stacked_self_attention", 22 | "input_dim": embedding_dim, 23 | "hidden_dim": hidden_dim, 24 | "projection_dim": projection_dim, 25 | "feedforward_hidden_dim": feedforward_hidden_dim, 26 | "num_layers": num_layers, 27 | "num_attention_heads": num_attention_heads, 28 | }, 29 | "decoder": { 30 | "type": "auto_regressive_seq_decoder", 31 | "target_namespace": "target_tokens", 32 | "target_embedder": { 33 | "vocab_namespace": "target_tokens", 34 | "embedding_dim": hidden_dim, 35 | }, 36 | "decoder_net": { 37 | "type": "stacked_self_attention", 38 | "target_embedding_dim": hidden_dim, 39 | "decoding_dim": hidden_dim, 40 | "feedforward_hidden_dim": feedforward_hidden_dim, 41 | "num_layers": num_layers, 42 | "num_attention_heads": num_attention_heads, 43 | "positional_encoding_max_steps": 10, 44 | }, 45 | "max_decoding_steps": 5, 46 | "beam_size": 5, 47 | "tensor_based_metric": {"type": "bleu"}, 48 | } 49 | }; 50 | 51 | local COMMON = import 'common.jsonnet'; 52 | { 53 | "random_seed": 1, 54 | "pytorch_seed": 1, 55 | "dataset_reader": COMMON['dataset_reader'], 56 | "train_data_path": COMMON['train_data_path'], 57 | "validation_data_path": COMMON['validation_data_path'], 58 | "model": model, 59 | "iterator": COMMON['iterator'], 60 | "trainer": COMMON['trainer'] 61 | } 62 | -------------------------------------------------------------------------------- /seq2seq/configs/simple_seq2seq.jsonnet: -------------------------------------------------------------------------------- 1 | local embedding_dim = 10; 2 | local hidden_dim = 10; 3 | local bidirectional = true; 4 | 5 | local model = { 6 | "type": "simple_seq2seq", 7 | "source_embedder": { 8 | "token_embedders": { 9 | "tokens": { 10 | "type": "embedding", 11 | "vocab_namespace": "source_tokens", 12 | "embedding_dim": embedding_dim, 13 | "trainable": true 14 | } 15 | } 16 | }, 17 | "encoder": { 18 | "type": "lstm", 19 | "input_size": embedding_dim, 20 | "hidden_size": hidden_dim, 21 | "num_layers": 2, 22 | "dropout": 
0.4, 23 | "bidirectional": bidirectional 24 | }, 25 | "max_decoding_steps": 5, 26 | "target_embedding_dim": embedding_dim, 27 | "target_namespace": "target_tokens", 28 | "attention": { 29 | "type": "dot_product" 30 | }, 31 | "beam_size": 5 32 | }; 33 | 34 | local COMMON = import 'common.jsonnet'; 35 | { 36 | "random_seed": 1, 37 | "pytorch_seed": 1, 38 | "dataset_reader": COMMON['dataset_reader'], 39 | "train_data_path": COMMON['train_data_path'], 40 | "validation_data_path": COMMON['validation_data_path'], 41 | "model": model, 42 | "iterator": COMMON['iterator'], 43 | "trainer": COMMON['trainer'] 44 | } 45 | -------------------------------------------------------------------------------- /seq2seq/data/dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import random 5 | 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--max-value", type=int, default=100) 10 | parser.add_argument("--valid-size", type=int, default=1000) 11 | parser.add_argument("--output-train", type=str, default="./data/train.tsv") 12 | parser.add_argument("--output-valid", type=str, default="./data/valid.tsv") 13 | parser.add_argument("--random-seed", type=int, default=1) 14 | args = parser.parse_args() 15 | 16 | assert args.valid_size < args.max_value ** 2 17 | 18 | random.seed(args.random_seed) 19 | 20 | dataset = [ 21 | f"{x} + {y}\t{x + y}" 22 | for x in range(args.max_value) 23 | for y in range(args.max_value) 24 | ] 25 | random.shuffle(dataset) 26 | 27 | trains = dataset[args.valid_size:] 28 | valids = dataset[:args.valid_size] 29 | 30 | with open(args.output_train, "w") as f: 31 | f.write("\n".join(trains)) 32 | 33 | with open(args.output_valid, "w") as f: 34 | f.write("\n".join(valids)) 35 | 36 | 37 | if __name__ == "__main__": 38 | main() 39 | -------------------------------------------------------------------------------- /seq2seq/decoder.py: -------------------------------------------------------------------------------- 1 | # This source code is from here: 2 | # https://github.com/allenai/allennlp/pull/3464 3 | 4 | from typing import Dict, List, Tuple, Optional 5 | from overrides import overrides 6 | 7 | import numpy 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.nn import Linear 11 | from copy import deepcopy 12 | 13 | from allennlp.common.checks import ConfigurationError 14 | from allennlp.common.util import END_SYMBOL, START_SYMBOL 15 | from allennlp.modules.seq2seq_decoders.seq_decoder import SeqDecoder 16 | from allennlp.data import Vocabulary 17 | from allennlp.modules import Embedding 18 | from allennlp.modules.seq2seq_decoders.decoder_net import DecoderNet 19 | from allennlp.nn import util 20 | from allennlp.nn.beam_search import BeamSearch 21 | from allennlp.training.metrics import Metric 22 | 23 | 24 | @SeqDecoder.register("auto_regressive_seq_decoder", exist_ok=True) 25 | class AutoRegressiveSeqDecoder(SeqDecoder): 26 | """ 27 | An autoregressive decoder that can be used for most seq2seq tasks. 28 | 29 | Parameters 30 | ---------- 31 | vocab : ``Vocabulary``, required 32 | Vocabulary containing source and target vocabularies. They may be under the same namespace 33 | (`tokens`) or the target tokens can have a different namespace, in which case it needs to 34 | be specified as `target_namespace`. 
35 | decoder_net : ``DecoderNet``, required 36 | Module that contains the implementation of the neural network for decoding output elements. 37 | max_decoding_steps : ``int`` 38 | Maximum length of decoded sequences. 39 | target_embedder : ``Embedding`` 40 | Embedder for target tokens. 41 | target_namespace : ``str``, optional (default = 'tokens') 42 | If the target side vocabulary is different from the source side's, you need to specify the 43 | target's namespace here. If not, we'll assume it is "tokens", which is also the default 44 | choice for the source side, and this might cause them to share vocabularies. 45 | beam_size : ``int``, optional (default = 4) 46 | Width of the beam for beam search. 47 | tensor_based_metric : ``Metric``, optional (default = None) 48 | A metric to track on validation data that takes raw tensors when it is called. 49 | This metric must accept two arguments when called: a batched tensor 50 | of predicted token indices, and a batched tensor of gold token indices. 51 | token_based_metric : ``Metric``, optional (default = None) 52 | A metric to track on validation data that takes lists of lists of tokens 53 | as input. This metric must accept two arguments when called, both 54 | of type `List[List[str]]`. The first is a predicted sequence for each item 55 | in the batch and the second is a gold sequence for each item in the batch. 56 | scheduled_sampling_ratio : ``float``, optional (default = 0) 57 | Defines the ratio between teacher-forced training and use of the model's own output. If it is zero 58 | (teacher forcing only) and `decoder_net` supports parallel decoding, we get the output 59 | predictions in a single forward pass of the `decoder_net`. 60 | """ 61 | 62 | def __init__( 63 | self, 64 | vocab: Vocabulary, 65 | decoder_net: DecoderNet, 66 | max_decoding_steps: int, 67 | target_embedder: Embedding, 68 | target_namespace: str = "tokens", 69 | tie_output_embedding: bool = False, 70 | scheduled_sampling_ratio: float = 0, 71 | label_smoothing_ratio: Optional[float] = None, 72 | beam_size: int = 4, 73 | tensor_based_metric: Metric = None, 74 | token_based_metric: Metric = None, 75 | ) -> None: 76 | super().__init__(target_embedder) 77 | 78 | self._vocab = vocab 79 | 80 | # Decodes the sequence of encoded hidden states into a new sequence of hidden states. 81 | self._decoder_net = decoder_net 82 | self._max_decoding_steps = max_decoding_steps 83 | self._target_namespace = target_namespace 84 | self._label_smoothing_ratio = label_smoothing_ratio 85 | 86 | # At prediction time, we use a beam search to find the most likely sequence of target tokens. 87 | # We need the start symbol to provide as the input at the first timestep of decoding, and 88 | # the end symbol as a way to indicate the end of the decoded sequence. 89 | self._start_index = self._vocab.get_token_index(START_SYMBOL, self._target_namespace) 90 | self._end_index = self._vocab.get_token_index(END_SYMBOL, self._target_namespace) 91 | self._beam_search = BeamSearch( 92 | self._end_index, max_steps=max_decoding_steps, beam_size=beam_size 93 | ) 94 | 95 | target_vocab_size = self._vocab.get_vocab_size(self._target_namespace) 96 | 97 | if self.target_embedder.get_output_dim() != self._decoder_net.target_embedding_dim: 98 | raise ConfigurationError( 99 | "Target Embedder output_dim doesn't match decoder module's input." 100 | ) 101 | 102 | # We project the hidden state from the decoder into the output vocabulary space 103 | # in order to get log probabilities of each target token, at each time step.
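# For example, with, say, a decoder output dim of 16 and a target vocabulary of
# 100 tokens, the projection below maps (group_size, 16) -> (group_size, 100);
# take_step() then applies log_softmax over the last axis to turn these scores
# into per-token log probabilities for the beam search.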
104 | self._output_projection_layer = Linear( 105 | self._decoder_net.get_output_dim(), target_vocab_size 106 | ) 107 | 108 | if tie_output_embedding: 109 | if self._output_projection_layer.weight.shape != self.target_embedder.weight.shape: 110 | raise ConfigurationError( 111 | "Can't tie embeddings with output linear layer, due to shape mismatch" 112 | ) 113 | self._output_projection_layer.weight = self.target_embedder.weight 114 | 115 | # These metrics will be updated during training and validation 116 | self._tensor_based_metric = tensor_based_metric 117 | self._token_based_metric = token_based_metric 118 | 119 | self._scheduled_sampling_ratio = scheduled_sampling_ratio 120 | 121 | def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 122 | """ 123 | Prepares inputs for beam search, runs the search, and returns its results. 124 | """ 125 | batch_size = state["source_mask"].size()[0] 126 | start_predictions = state["source_mask"].new_full( 127 | (batch_size,), fill_value=self._start_index 128 | ) 129 | 130 | # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps) 131 | # shape (log_probabilities): (batch_size, beam_size) 132 | all_top_k_predictions, log_probabilities = self._beam_search.search( 133 | start_predictions, state, self.take_step 134 | ) 135 | 136 | output_dict = { 137 | "class_log_probabilities": log_probabilities, 138 | "predictions": all_top_k_predictions, 139 | } 140 | return output_dict 141 | 142 | def _forward_loss( 143 | self, state: Dict[str, torch.Tensor], target_tokens: Dict[str, torch.LongTensor] 144 | ) -> Dict[str, torch.Tensor]: 145 | """ 146 | Make forward pass during training or do greedy search during prediction. 147 | 148 | Notes 149 | ----- 150 | We really only use the predictions from the method to test that beam search 151 | with a beam size of 1 gives the same results. 152 | """ 153 | # shape: (batch_size, max_input_sequence_length, encoder_output_dim) 154 | encoder_outputs = state["encoder_outputs"] 155 | 156 | # shape: (batch_size, max_input_sequence_length) 157 | source_mask = state["source_mask"] 158 | 159 | # shape: (batch_size, max_target_sequence_length) 160 | targets = target_tokens["tokens"] 161 | 162 | # Prepare embeddings for targets. They will be used as gold embeddings during decoder training 163 | # shape: (batch_size, max_target_sequence_length, embedding_dim) 164 | target_embedding = self.target_embedder(targets) 165 | 166 | # shape: (batch_size, max_target_sequence_length) 167 | target_mask = util.get_text_field_mask(target_tokens) 168 | 169 | if self._scheduled_sampling_ratio == 0 and self._decoder_net.decodes_parallel: 170 | _, decoder_output = self._decoder_net( 171 | previous_state=state, 172 | previous_steps_predictions=target_embedding[:, :-1, :], 173 | encoder_outputs=encoder_outputs, 174 | source_mask=source_mask, 175 | previous_steps_mask=target_mask[:, :-1], 176 | ) 177 | 178 | # shape: (group_size, max_target_sequence_length, num_classes) 179 | logits = self._output_projection_layer(decoder_output) 180 | else: 181 | batch_size = source_mask.size()[0] 182 | _, target_sequence_length = targets.size() 183 | 184 | # The last input from the target is either padding or the end symbol. 185 | # Either way, we don't have to process it. 186 | num_decoding_steps = target_sequence_length - 1 187 | 188 | # Initialize target predictions with the start index.
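# The decoding loop below implements scheduled sampling: with probability
# _scheduled_sampling_ratio, a training step is conditioned on the model's own
# previous argmax prediction (kept in `last_predictions`) instead of the gold token.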
189 | # shape: (batch_size,) 190 | last_predictions = source_mask.new_full((batch_size,), fill_value=self._start_index) 191 | 192 | # shape: (steps, batch_size, target_embedding_dim) 193 | steps_embeddings = torch.Tensor([]) 194 | 195 | step_logits: List[torch.Tensor] = [] 196 | 197 | for timestep in range(num_decoding_steps): 198 | if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio: 199 | # With probability _scheduled_sampling_ratio during training, condition this step 200 | # on the model's own previous predictions instead of the gold tokens. 201 | # shape: (batch_size, steps, target_embedding_dim) 202 | state["previous_steps_predictions"] = steps_embeddings 203 | 204 | # shape: (batch_size, ) 205 | effective_last_prediction = last_predictions 206 | else: 207 | # shape: (batch_size, ) 208 | effective_last_prediction = targets[:, timestep] 209 | 210 | if timestep == 0: 211 | state["previous_steps_predictions"] = torch.Tensor([]) 212 | else: 213 | # shape: (batch_size, steps, target_embedding_dim) 214 | state["previous_steps_predictions"] = target_embedding[:, :timestep] 215 | 216 | # shape: (batch_size, num_classes) 217 | output_projections, state = self._prepare_output_projections( 218 | effective_last_prediction, state 219 | ) 220 | 221 | # list of tensors, shape: (batch_size, 1, num_classes) 222 | step_logits.append(output_projections.unsqueeze(1)) 223 | 224 | # shape (predicted_classes): (batch_size,) 225 | _, predicted_classes = torch.max(output_projections, 1) 226 | 227 | # shape (predicted_classes): (batch_size,) 228 | last_predictions = predicted_classes 229 | 230 | # shape: (batch_size, 1, target_embedding_dim) 231 | last_predictions_embeddings = self.target_embedder(last_predictions).unsqueeze(1) 232 | 233 | # This step is required, since we want to keep two different prediction histories: gold and real 234 | if steps_embeddings.shape[-1] == 0: 235 | # There are no previous steps yet, except for the start vectors in ``last_predictions`` 236 | # shape: (group_size, 1, target_embedding_dim) 237 | steps_embeddings = last_predictions_embeddings 238 | else: 239 | # shape: (group_size, steps_count, target_embedding_dim) 240 | steps_embeddings = torch.cat([steps_embeddings, last_predictions_embeddings], 1) 241 | 242 | # shape: (batch_size, num_decoding_steps, num_classes) 243 | logits = torch.cat(step_logits, 1) 244 | 245 | # Compute loss. 246 | target_mask = util.get_text_field_mask(target_tokens) 247 | loss = self._get_loss(logits, targets, target_mask) 248 | 249 | # TODO: We will be using beam search to get predictions for validation, but if beam size is 1 250 | # we could consider taking the last_predictions here and building step_predictions 251 | # and using that instead of running beam search again, if performance in validation is taking a hit 252 | output_dict = {"loss": loss} 253 | 254 | return output_dict 255 | 256 | def _prepare_output_projections( 257 | self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor] 258 | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 259 | """ 260 | Decode current state and last prediction to produce projections 261 | into the target space, which can then be used to get probabilities of 262 | each target token for the next step. 263 | 264 | Inputs are the same as for `take_step()`.
265 | """ 266 | # shape: (group_size, max_input_sequence_length, encoder_output_dim) 267 | encoder_outputs = state["encoder_outputs"] 268 | 269 | # shape: (group_size, max_input_sequence_length) 270 | source_mask = state["source_mask"] 271 | 272 | # shape: (group_size, steps_count, decoder_output_dim) 273 | previous_steps_predictions = state.get("previous_steps_predictions") 274 | 275 | # shape: (batch_size, 1, target_embedding_dim) 276 | last_predictions_embeddings = self.target_embedder(last_predictions).unsqueeze(1) 277 | 278 | if previous_steps_predictions is None or previous_steps_predictions.shape[-1] == 0: 279 | # There is no previous steps, except for start vectors in ``last_predictions`` 280 | # shape: (group_size, 1, target_embedding_dim) 281 | previous_steps_predictions = last_predictions_embeddings 282 | else: 283 | # shape: (group_size, steps_count, target_embedding_dim) 284 | previous_steps_predictions = torch.cat( 285 | [previous_steps_predictions, last_predictions_embeddings], 1 286 | ) 287 | 288 | decoder_state, decoder_output = self._decoder_net( 289 | previous_state=state, 290 | encoder_outputs=encoder_outputs, 291 | source_mask=source_mask, 292 | previous_steps_predictions=previous_steps_predictions, 293 | ) 294 | state["previous_steps_predictions"] = previous_steps_predictions 295 | 296 | # Update state with new decoder state, override previous state 297 | state.update(decoder_state) 298 | 299 | if self._decoder_net.decodes_parallel: 300 | decoder_output = decoder_output[:, -1, :] 301 | 302 | # shape: (group_size, num_classes) 303 | output_projections = self._output_projection_layer(decoder_output) 304 | 305 | return output_projections, state 306 | 307 | def _get_loss( 308 | self, logits: torch.LongTensor, targets: torch.LongTensor, target_mask: torch.LongTensor 309 | ) -> torch.Tensor: 310 | """ 311 | Compute loss. 312 | 313 | Takes logits (unnormalized outputs from the decoder) of size (batch_size, 314 | num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1) 315 | and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross 316 | entropy loss while taking the mask into account. 317 | 318 | The length of ``targets`` is expected to be greater than that of ``logits`` because the 319 | decoder does not need to compute the output corresponding to the last timestep of 320 | ``targets``. This method aligns the inputs appropriately to compute the loss. 321 | 322 | During training, we want the logit corresponding to timestep i to be similar to the target 323 | token from timestep i + 1. That is, the targets should be shifted by one timestep for 324 | appropriate comparison. Consider a single example where the target has 3 words, and 325 | padding is to 7 tokens. 326 | The complete sequence would correspond to w1 w2 w3
327 | and the mask would be 1 1 1 1 1 0 0 328 | and let the logits be l1 l2 l3 l4 l5 l6 329 | We actually need to compare: 330 | the sequence w1 w2 w3 <E> <P> <P> 331 | with masks 1 1 1 1 0 0 332 | against l1 l2 l3 l4 l5 l6 333 | (where the input was) <S> w1 w2 w3 <E> <P>
334 | """ 335 | # shape: (batch_size, num_decoding_steps) 336 | relevant_targets = targets[:, 1:].contiguous() 337 | 338 | # shape: (batch_size, num_decoding_steps) 339 | relevant_mask = target_mask[:, 1:].contiguous() 340 | 341 | return util.sequence_cross_entropy_with_logits( 342 | logits, relevant_targets, relevant_mask, label_smoothing=self._label_smoothing_ratio 343 | ) 344 | 345 | def get_output_dim(self): 346 | return self._decoder_net.get_output_dim() 347 | 348 | def take_step( 349 | self, last_predictions: torch.Tensor, state: Dict[str, torch.Tensor] 350 | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 351 | """ 352 | Take a decoding step. This is called by the beam search class. 353 | 354 | Parameters 355 | ---------- 356 | last_predictions : ``torch.Tensor`` 357 | A tensor of shape ``(group_size,)``, which gives the indices of the predictions 358 | during the last time step. 359 | state : ``Dict[str, torch.Tensor]`` 360 | A dictionary of tensors that contain the current state information 361 | needed to predict the next step, which includes the encoder outputs, 362 | the source mask, and the decoder hidden state and context. Each of these 363 | tensors has shape ``(group_size, *)``, where ``*`` can be any other number 364 | of dimensions. 365 | 366 | Returns 367 | ------- 368 | Tuple[torch.Tensor, Dict[str, torch.Tensor]] 369 | A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities`` 370 | is a tensor of shape ``(group_size, num_classes)`` containing the predicted 371 | log probability of each class for the next step, for each item in the group, 372 | while ``updated_state`` is a dictionary of tensors containing the encoder outputs, 373 | source mask, and updated decoder hidden state and context. 374 | 375 | Notes 376 | ----- 377 | We treat the inputs as a batch, even though ``group_size`` is not necessarily 378 | equal to ``batch_size``, since the group may contain multiple states 379 | for each source sentence in the batch. 
380 | """ 381 | # shape: (group_size, num_classes) 382 | output_projections, state = self._prepare_output_projections(last_predictions, state) 383 | 384 | # shape: (group_size, num_classes) 385 | class_log_probabilities = F.log_softmax(output_projections, dim=-1) 386 | 387 | return class_log_probabilities, state 388 | 389 | @overrides 390 | def get_metrics(self, reset: bool = False) -> Dict[str, float]: 391 | all_metrics: Dict[str, float] = {} 392 | if not self.training: 393 | if self._tensor_based_metric is not None: 394 | all_metrics.update( 395 | self._tensor_based_metric.get_metric(reset=reset) # type: ignore 396 | ) 397 | if self._token_based_metric is not None: 398 | all_metrics.update(self._token_based_metric.get_metric(reset=reset)) # type: ignore 399 | return all_metrics 400 | 401 | @overrides 402 | def forward( 403 | self, 404 | encoder_out: Dict[str, torch.LongTensor], 405 | target_tokens: Dict[str, torch.LongTensor] = None, 406 | ) -> Dict[str, torch.Tensor]: 407 | state = encoder_out 408 | decoder_init_state = self._decoder_net.init_decoder_state(state) 409 | state.update(decoder_init_state) 410 | 411 | if target_tokens: 412 | state_forward_loss = state if self.training else {k: v.clone() for k, v in state.items()} 413 | output_dict = self._forward_loss(state_forward_loss, target_tokens) 414 | else: 415 | output_dict = {} 416 | 417 | if not self.training: 418 | predictions = self._forward_beam_search(state) 419 | output_dict.update(predictions) 420 | 421 | if target_tokens: 422 | if self._tensor_based_metric is not None: 423 | # shape: (batch_size, beam_size, max_sequence_length) 424 | top_k_predictions = output_dict["predictions"] 425 | # shape: (batch_size, max_predicted_sequence_length) 426 | best_predictions = top_k_predictions[:, 0, :] 427 | 428 | self._tensor_based_metric( # type: ignore 429 | best_predictions, target_tokens["tokens"] 430 | ) 431 | 432 | if self._token_based_metric is not None: 433 | output_dict = self.post_process(output_dict) 434 | predicted_tokens = output_dict["predicted_tokens"] 435 | 436 | self._token_based_metric( # type: ignore 437 | predicted_tokens, self.indices_to_tokens(target_tokens["tokens"][:, 1:]) 438 | ) 439 | 440 | return output_dict 441 | 442 | @overrides 443 | def post_process(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 444 | """ 445 | This method trims the output predictions to the first end symbol, replaces indices with 446 | corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``. 447 | """ 448 | predicted_indices = output_dict["predictions"] 449 | all_predicted_tokens = self.indices_to_tokens(predicted_indices) 450 | output_dict["predicted_tokens"] = all_predicted_tokens 451 | return output_dict 452 | 453 | def indices_to_tokens(self, batch_indeces: numpy.ndarray) -> List[List[str]]: 454 | 455 | if not isinstance(batch_indeces, numpy.ndarray): 456 | batch_indeces = batch_indeces.detach().cpu().numpy() 457 | 458 | all_tokens = [] 459 | for indices in batch_indeces: 460 | # Beam search gives us the top k results for each source sentence in the batch 461 | # but we just want the single best. 
462 | if len(indices.shape) > 1: 463 | indices = indices[0] 464 | indices = list(indices) 465 | # Collect indices till the first end_symbol 466 | if self._end_index in indices: 467 | indices = indices[: indices.index(self._end_index)] 468 | tokens = [ 469 | self._vocab.get_token_from_index(x, namespace=self._target_namespace) 470 | for x in indices 471 | ] 472 | all_tokens.append(tokens) 473 | 474 | return all_tokens 475 | --------------------------------------------------------------------------------
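The target/logit alignment documented in `_get_loss` above can be checked in isolation. Below is a minimal sketch in plain PyTorch (no AllenNLP required); the toy index assignments (0 for `<P>`, 1 for `<S>`, 2 for `<E>`) are hypothetical, and the masked token-level mean only approximates what `util.sequence_cross_entropy_with_logits` computes with its default batch averaging:

import torch

# Toy batch mirroring the docstring in _get_loss: one target sequence
# "<S> w1 w2 w3 <E> <P> <P>" with hypothetical indices 1=<S>, 2=<E>, 0=<P>.
targets = torch.tensor([[1, 3, 4, 5, 2, 0, 0]])             # shape: (1, 7)
target_mask = (targets != 0).long()                         # 1 1 1 1 1 0 0

num_classes = 6
logits = torch.randn(1, targets.size(1) - 1, num_classes)   # l1 .. l6

# Shift the targets left by one step: the logit at timestep i is scored
# against the gold token at timestep i + 1, so <S> is never a label.
relevant_targets = targets[:, 1:].contiguous()              # w1 w2 w3 <E> <P> <P>
relevant_mask = target_mask[:, 1:].contiguous()             # 1  1  1  1   0   0

log_probs = torch.log_softmax(logits, dim=-1)
nll = -log_probs.gather(2, relevant_targets.unsqueeze(2)).squeeze(2)
loss = (nll * relevant_mask).sum() / relevant_mask.sum()    # masked token-level mean
print(loss.item())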