├── .gitignore
├── README.md
├── experiments
│   ├── newsgroups_with_cuda.json
│   └── newsgroups_without_cuda.json
├── newsgroups
│   ├── __init__.py
│   ├── dataset_readers
│   │   ├── __init__.py
│   │   └── fetch_newsgroups.py
│   └── models
│       ├── __init__.py
│       └── newsgroups_classifier.py
├── requirements.txt
└── run.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.pyc
.cache/
.coverage
doc/_build/
.ipynb_checkpoints

# build artifacts
allennlp.egg-info/
.eggs/
build/
dist/
.mypy_cache

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Training a deep learning model with AllenNLP

In this tutorial we'll classify posts from three categories of the 20 Newsgroups dataset provided by scikit-learn. For more details check out [this article](https://medium.com/swlh/deep-learning-for-text-made-easy-with-allennlp-62bc79d41f31) 😉

With AllenNLP we define the model architecture in a JSON file ([experiments/newsgroups_without_cuda.json](https://github.com/dmesquita/easy-deep-learning-with-AllenNLP/blob/master/experiments/newsgroups_without_cuda.json)). This ``Model`` performs text classification on newsgroup posts. The basic model structure: we embed the text and encode it with a Seq2VecEncoder. We then pass the result through a feedforward network, the output of which we use as our scores for each label.

## 1 — Data inputs
To specify the input dataset and how to read it, we use the ``'dataset_reader'`` key in the JSON file. The reading logic is defined [here](https://github.com/dmesquita/easy-deep-learning-with-AllenNLP/blob/2b2cf0176404346f7713d72fd34f78f645f6d7cf/newsgroups/dataset_readers/fetch_newsgroups.py#L55) by creating a custom [``DatasetReader`` class](https://github.com/dmesquita/easy-deep-learning-with-AllenNLP/blob/master/newsgroups/dataset_readers/fetch_newsgroups.py).

## 2 — The model
To specify the model we set the ``'model'`` key. It holds three further parameters: ``'model_text_field_embedder'``, ``'internal_text_encoder'`` and ``'classifier_feedforward'``. The internals of the model are defined in the [``Fetch20NewsgroupsClassifier`` class](https://github.com/dmesquita/easy-deep-learning-with-AllenNLP/blob/master/newsgroups/models/newsgroups_classifier.py).

## 3 — The data iterator
AllenNLP provides an iterator called ``BucketIterator`` that makes computation more efficient: it sorts the instances by the number of tokens in each text, so every batch only has to be padded up to the length of its longest sequence. We set these parameters in the ``'iterator'`` key of the JSON file.

## 4 — Training the model
The trainer uses the AdaGrad optimizer for 30 epochs, stopping early if validation accuracy has not increased in the last 3 epochs. This is also specified in the [JSON file](https://github.com/dmesquita/easy-deep-learning-with-AllenNLP/blob/master/experiments/newsgroups_without_cuda.json).
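The relevant ``'trainer'`` snippet from the config (the CUDA variant is identical except that ``cuda_device`` is ``0`` instead of ``-1``):

```json
"trainer": {
    "num_epochs": 30,
    "patience": 3,
    "cuda_device": -1,
    "grad_clipping": 5.0,
    "validation_metric": "+accuracy",
    "optimizer": {
        "type": "adagrad"
    }
}
```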
To train the model locally we run:

```bash
python3 run.py train experiments/newsgroups_without_cuda.json --include-package newsgroups.dataset_readers --include-package newsgroups.models
```

### Train the model: Colaboratory notebook
[https://colab.research.google.com/drive/1q3b5HAkcjYsVd6yhrwnxL2ByqGK08jhQ](https://colab.research.google.com/drive/1q3b5HAkcjYsVd6yhrwnxL2ByqGK08jhQ)

--------------------------------------------------------------------------------
/experiments/newsgroups_with_cuda.json:
--------------------------------------------------------------------------------
{
  "dataset_reader": {
    "type": "20newsgroups"
  },
  "train_data_path": "train",
  "test_data_path": "test",
  "evaluate_on_test": true,
  "model": {
    "type": "20newsgroups_classifier",
    "model_text_field_embedder": {
      "tokens": {
        "type": "embedding",
        "pretrained_file": "https://s3-us-west-2.amazonaws.com/allennlp/datasets/glove/glove.6B.100d.txt.gz",
        "embedding_dim": 100,
        "trainable": false
      }
    },
    "internal_text_encoder": {
      "type": "lstm",
      "bidirectional": true,
      "input_size": 100,
      "hidden_size": 100,
      "num_layers": 1,
      "dropout": 0.2
    },
    "classifier_feedforward": {
      "input_dim": 200,
      "num_layers": 2,
      "hidden_dims": [200, 3],
      "activations": ["relu", "linear"],
      "dropout": [0.2, 0.0]
    }
  },
  "iterator": {
    "type": "bucket",
    "sorting_keys": [["text", "num_tokens"]],
    "batch_size": 64
  },
  "trainer": {
    "num_epochs": 30,
    "patience": 3,
    "cuda_device": 0,
    "grad_clipping": 5.0,
    "validation_metric": "+accuracy",
    "optimizer": {
      "type": "adagrad"
    }
  }
}
--------------------------------------------------------------------------------
/experiments/newsgroups_without_cuda.json:
--------------------------------------------------------------------------------
{
  "dataset_reader": {
    "type": "20newsgroups"
  },
  "train_data_path": "train",
  "test_data_path": "test",
  "evaluate_on_test": true,
  "model": {
    "type": "20newsgroups_classifier",
    "model_text_field_embedder": {
      "tokens": {
        "type": "embedding",
        "pretrained_file": "https://s3-us-west-2.amazonaws.com/allennlp/datasets/glove/glove.6B.100d.txt.gz",
        "embedding_dim": 100,
        "trainable": false
      }
    },
    "internal_text_encoder": {
      "type": "lstm",
      "bidirectional": true,
      "input_size": 100,
      "hidden_size": 100,
      "num_layers": 1,
      "dropout": 0.2
    },
    "classifier_feedforward": {
      "input_dim": 200,
      "num_layers": 2,
      "hidden_dims": [200, 3],
      "activations": ["relu", "linear"],
      "dropout": [0.2, 0.0]
    }
  },
  "iterator": {
    "type": "bucket",
    "sorting_keys": [["text", "num_tokens"]],
    "batch_size": 64
  },
  "trainer": {
    "num_epochs": 30,
    "patience": 3,
    "cuda_device": -1,
    "grad_clipping": 5.0,
    "validation_metric": "+accuracy",
    "optimizer": {
      "type": "adagrad"
    }
  }
}
--------------------------------------------------------------------------------
/newsgroups/__init__.py:
--------------------------------------------------------------------------------
from newsgroups.dataset_readers import *
from newsgroups.models import *
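
# Importing these submodules runs the @DatasetReader.register("20newsgroups") and
# @Model.register("20newsgroups_classifier") decorators, which is how the "type"
# values in the experiment JSON files resolve to these classes (and why the train
# command in the README passes the --include-package flags).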
--------------------------------------------------------------------------------
/newsgroups/dataset_readers/__init__.py:
--------------------------------------------------------------------------------
from newsgroups.dataset_readers.fetch_newsgroups import NewsgroupsDatasetReader
--------------------------------------------------------------------------------
/newsgroups/dataset_readers/fetch_newsgroups.py:
--------------------------------------------------------------------------------
from typing import Dict
import logging

from overrides import overrides

from allennlp.common.checks import ConfigurationError
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import LabelField, TextField
from allennlp.data.instance import Instance
from allennlp.data.tokenizers import Tokenizer, WordTokenizer
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer

from sklearn.datasets import fetch_20newsgroups

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


@DatasetReader.register("20newsgroups")
class NewsgroupsDatasetReader(DatasetReader):
    """
    Fetches the 20 Newsgroups dataset via scikit-learn's ``fetch_20newsgroups``,
    restricted to three categories (``comp.graphics``, ``sci.space`` and
    ``rec.sport.baseball``), and creates a dataset suitable for document
    classification.

    Instead of a file path, ``read`` expects the name of the scikit-learn subset
    to load: ``"train"`` or ``"test"``.

    The output of ``read`` is a list of ``Instance`` s with the fields:
        text: ``TextField``
        label: ``LabelField``

    where the ``label`` is the integer index of the newsgroup the post belongs to.

    Parameters
    ----------
    tokenizer : ``Tokenizer``, optional
        Tokenizer to use to split the text into words or other kinds of tokens.
        Defaults to ``WordTokenizer()``.
    token_indexers : ``Dict[str, TokenIndexer]``, optional
        Indexers used to define input token representations. Defaults to ``{"tokens":
        SingleIdTokenIndexer()}``.
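
    Example ``dataset_reader`` entry from the experiment configs:

        "dataset_reader": {
            "type": "20newsgroups"
        }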
    """
    def __init__(self,
                 tokenizer: Tokenizer = None,
                 token_indexers: Dict[str, TokenIndexer] = None) -> None:
        super(NewsgroupsDatasetReader, self).__init__(lazy=False)
        self._tokenizer = tokenizer or WordTokenizer()
        self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}

    @overrides
    def _read(self, file_path):
        # ``file_path`` is not a real path here: it names the scikit-learn
        # subset to fetch.
        if file_path not in ("train", "test"):
            raise ConfigurationError("file_path must be 'train' or 'test', got '%s'" % file_path)
        logger.info("Reading instances from: %s", file_path)
        categories = ["comp.graphics", "sci.space", "rec.sport.baseball"]
        newsgroups_data = fetch_20newsgroups(subset=file_path, categories=categories)

        for text, target in zip(newsgroups_data.data, newsgroups_data.target):
            yield self.text_to_instance(text, target)

    @overrides
    def text_to_instance(self, text: str, target: int = None) -> Instance:  # type: ignore
        # pylint: disable=arguments-differ
        tokenized_text = self._tokenizer.tokenize(text)
        text_field = TextField(tokenized_text, self._token_indexers)
        fields = {'text': text_field}
        if target is not None:
            # scikit-learn targets are already integer class indices, so we
            # skip label indexing.
            fields['label'] = LabelField(int(target), skip_indexing=True)
        return Instance(fields)
--------------------------------------------------------------------------------
/newsgroups/models/__init__.py:
--------------------------------------------------------------------------------
from newsgroups.models.newsgroups_classifier import Fetch20NewsgroupsClassifier
--------------------------------------------------------------------------------
/newsgroups/models/newsgroups_classifier.py:
--------------------------------------------------------------------------------
from typing import Dict, Optional

import numpy
from overrides import overrides
import torch
import torch.nn.functional as F

from allennlp.common import Params
from allennlp.common.checks import ConfigurationError
from allennlp.data import Vocabulary
from allennlp.modules import FeedForward, Seq2VecEncoder, TextFieldEmbedder
from allennlp.models.model import Model
from allennlp.nn import InitializerApplicator, RegularizerApplicator
from allennlp.nn import util
from allennlp.training.metrics import CategoricalAccuracy


@Model.register("20newsgroups_classifier")
class Fetch20NewsgroupsClassifier(Model):
    """
    This ``Model`` performs text classification for a newsgroup post. We assume
    we're given a text and we predict some output label.

    The basic model structure: we'll embed the text and encode it with a
    Seq2VecEncoder, getting a single vector representing the content. We'll then
    pass the result through a feedforward network, the output of which we'll use
    as our scores for each label.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    model_text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the ``tokens`` ``TextField`` we get as input to the model.
    internal_text_encoder : ``Seq2VecEncoder``
        The encoder that we will use to convert the input text to a vector.
    classifier_feedforward : ``FeedForward``
        The feedforward network that maps the encoded text to one score per label.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self, vocab: Vocabulary,
                 model_text_field_embedder: TextFieldEmbedder,
                 internal_text_encoder: Seq2VecEncoder,
                 classifier_feedforward: FeedForward,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(Fetch20NewsgroupsClassifier, self).__init__(vocab, regularizer)

        self.model_text_field_embedder = model_text_field_embedder
        self.num_classes = self.vocab.get_vocab_size("labels")
        self.internal_text_encoder = internal_text_encoder
        self.classifier_feedforward = classifier_feedforward

        if model_text_field_embedder.get_output_dim() != internal_text_encoder.get_input_dim():
            raise ConfigurationError("The output dimension of the model_text_field_embedder must match the "
                                     "input dimension of the internal_text_encoder. Found {} and {}, "
                                     "respectively.".format(model_text_field_embedder.get_output_dim(),
                                                            internal_text_encoder.get_input_dim()))
        self.metrics = {
            "accuracy": CategoricalAccuracy(),
            "accuracy3": CategoricalAccuracy(top_k=3)
        }
        self.loss = torch.nn.CrossEntropyLoss()

        initializer(self)

    @overrides
    def forward(self,  # type: ignore
                text: Dict[str, torch.LongTensor],
                label: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        text : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``.
        label : torch.LongTensor, optional (default = None)
            A variable representing the label for each instance in the batch.

        Returns
        -------
        An output dictionary consisting of:
        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, num_classes)`` with unnormalised scores
            for the label classes of each instance (``decode`` turns these into a
            probability distribution).
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_text = self.model_text_field_embedder(text)
        text_mask = util.get_text_field_mask(text)
        encoded_text = self.internal_text_encoder(embedded_text, text_mask)

        logits = self.classifier_feedforward(encoded_text)
        output_dict = {'logits': logits}
        if label is not None:
            loss = self.loss(logits, label.squeeze(-1))
            for metric in self.metrics.values():
                metric(logits, label.squeeze(-1))
            output_dict["loss"] = loss

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Computes class probabilities from the logits, does a simple argmax over
        them, converts indices to string labels, and adds a ``"label"`` key to
        the dictionary with the result.
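
        This runs at prediction time, after ``forward``; it is not needed for
        computing the training loss.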
        """
        class_probabilities = F.softmax(output_dict['logits'], dim=-1)
        output_dict['class_probabilities'] = class_probabilities

        predictions = class_probabilities.cpu().data.numpy()
        argmax_indices = numpy.argmax(predictions, axis=-1)
        labels = [self.vocab.get_token_from_index(x, namespace="labels")
                  for x in argmax_indices]
        output_dict['label'] = labels
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {metric_name: metric.get_metric(reset) for metric_name, metric in self.metrics.items()}

    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'Fetch20NewsgroupsClassifier':
        embedder_params = params.pop("model_text_field_embedder")
        model_text_field_embedder = TextFieldEmbedder.from_params(embedder_params, vocab=vocab)
        internal_text_encoder = Seq2VecEncoder.from_params(params.pop("internal_text_encoder"))
        classifier_feedforward = FeedForward.from_params(params.pop("classifier_feedforward"))

        initializer = InitializerApplicator.from_params(params.pop('initializer', []))
        regularizer = RegularizerApplicator.from_params(params.pop('regularizer', []))

        return cls(vocab=vocab,
                   model_text_field_embedder=model_text_field_embedder,
                   internal_text_encoder=internal_text_encoder,
                   classifier_feedforward=classifier_feedforward,
                   initializer=initializer,
                   regularizer=regularizer)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
allennlp==0.6.0
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
import logging
import os
import sys

sys.path.insert(0, os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir))))
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    level=logging.INFO)

from newsgroups import *

from allennlp.commands import main  # pylint: disable=wrong-import-position

if __name__ == "__main__":
    # ``main`` is AllenNLP's command-line dispatcher; it handles subcommands
    # such as ``train`` and ``evaluate``.
    main(prog="python run.py")
--------------------------------------------------------------------------------