├── .flake8
├── .github
│   └── workflows
│       └── ci.yml
├── .gitignore
├── README.md
├── my_project
│   ├── __init__.py
│   ├── dataset_reader.py
│   ├── model.py
│   └── train.py
├── pytest.ini
├── requirements.txt
├── run.py
└── tests
    ├── __init__.py
    ├── fixtures
    │   └── toy_data.tsv
    ├── test_dataset_reader.py
    └── test_model.py

/.flake8:
--------------------------------------------------------------------------------
[flake8]
max-line-length = 115

ignore =
    E203  # whitespace before :
    W503  # line break before binary operator
    F401  # unused imports

per-file-ignores =
    # __init__.py files are allowed to have unused imports and lines-too-long
    */__init__.py:F401,F403

--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
name: CI

on:
  pull_request:
    branches:
      - master
  push:
    branches:
      - master

jobs:
  checks:
    name: Checks
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python: ['3.8']

    steps:
      - uses: actions/checkout@v2

      - name: Setup Python
        uses: actions/setup-python@v1
        with:
          python-version: ${{ matrix.python }}

      - uses: actions/cache@v2
        with:
          path: ${{ env.pythonLocation }}
          key: ${{ runner.os }}-pydeps-${{ env.pythonLocation }}-${{ hashFiles('requirements.txt') }}

      - name: Install requirements
        run: |
          pip install -r requirements.txt

      - name: Debug info
        run: |
          pip freeze

      - name: Black
        if: always()
        run: |
          black . --check

      - name: Flake8
        if: always()
        run: |
          flake8 .

      - name: MyPy
        if: always()
        run: |
          mypy .

      - name: PyTest
        if: always()
        run: |
          pytest -v .

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.pytest_cache
__pycache__

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# AllenNLP Template Project using your own python script

A template for starting a new allennlp project using your own python script, instead of config files
and `allennlp train`. For simple projects, all you need to do is get your model code into the class
in `my_project/model.py`, your data loading code into the `DatasetReader` code in
`my_project/dataset_reader.py`, and configuration code in `my_project/train.py` (the `build_*`
methods), and that's it: you can train your model with `python run.py`. We recommend also making
appropriate changes to the test code, and using that for development, but that's optional.

See the [AllenNLP Guide](https://guide.allennlp.org/your-first-model) for a quick start on how to
use what's in this example project. We're grabbing the model and dataset reader classes from that
guide. You can replace those classes with a model and dataset reader for whatever you want
(including copying code from our [model library](https://github.com/allenai/allennlp-models) as a
starting point). The very brief version of what's in here:

* A `Model` class in `my_project/model.py`.
* A `DatasetReader` class in `my_project/dataset_reader.py`.
* Tests for both of these classes in `tests`, including a small toy dataset that can be read. We
  strongly recommend that you use a toy dataset with tests like this during model development, for
  quick debug cycles. To run the tests, just run `pytest` from the base directory of this project.
* A script to configure the model, dataset reader, and other training loop objects, in
  `my_project/train.py`. The `build_*` methods are meant to be changed according to however you
  want to set up your training run. You probably don't need to change the `run_training_loop`
  function. To train the model, just run `python run.py` after doing `pip install allennlp`.

--------------------------------------------------------------------------------
/my_project/__init__.py:
--------------------------------------------------------------------------------
# These imports are important for making the configuration files find the classes that you wrote.
# If you don't have these, you'll get errors about allennlp not being able to find
# "simple_classifier", or whatever name you registered your model with. These imports and the
# contents of .allennlp_plugins make it so you can just use `allennlp train`, and we will find your
# classes and use them. If you change the name of `my_project`, you'll also need to change it in
# the same way in the .allennlp_plugins file.
from my_project.model import *
from my_project.dataset_reader import *

--------------------------------------------------------------------------------
/my_project/dataset_reader.py:
--------------------------------------------------------------------------------
from typing import Dict, Iterable, List

from allennlp.data import DatasetReader, Instance, Field
from allennlp.data.fields import LabelField, TextField
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token, Tokenizer, WhitespaceTokenizer


@DatasetReader.register("classification-tsv")
class ClassificationTsvReader(DatasetReader):
    def __init__(
        self,
        tokenizer: Tokenizer = None,
        token_indexers: Dict[str, TokenIndexer] = None,
        max_tokens: int = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.tokenizer = tokenizer or WhitespaceTokenizer()
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self.max_tokens = max_tokens

    def text_to_instance(self, text: str, label: str = None) -> Instance:  # type: ignore
        tokens = self.tokenizer.tokenize(text)
        if self.max_tokens:
            tokens = tokens[: self.max_tokens]
        text_field = TextField(tokens, self.token_indexers)
        fields: Dict[str, Field] = {"text": text_field}
        if label:
            fields["label"] = LabelField(label)
        return Instance(fields)

    def _read(self, file_path: str) -> Iterable[Instance]:
        with open(file_path, "r") as lines:
            for line in lines:
                text, sentiment = line.strip().split("\t")
                yield self.text_to_instance(text, sentiment)

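The test in `tests/test_dataset_reader.py` covers this reader, but for quick, interactive debugging it can also help to poke at it directly. A minimal sketch (not part of the repository), assuming it is run from the project root so the fixture path resolves; `max_tokens=64` is an arbitrary illustrative value:

```python
# Quick manual check of ClassificationTsvReader on the toy fixture.
from my_project.dataset_reader import ClassificationTsvReader

reader = ClassificationTsvReader(max_tokens=64)
for instance in reader.read("tests/fixtures/toy_data.tsv"):
    tokens = [t.text for t in instance.fields["text"].tokens]
    print(instance.fields["label"].label, len(tokens), tokens[:8])
```
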
--------------------------------------------------------------------------------
/my_project/model.py:
--------------------------------------------------------------------------------
from typing import Dict

import torch
from allennlp.data import Vocabulary, TextFieldTensors
from allennlp.models import Model
from allennlp.modules import TextFieldEmbedder, Seq2VecEncoder
from allennlp.nn import util
from allennlp.training.metrics import CategoricalAccuracy


@Model.register("simple_classifier")
class SimpleClassifier(Model):
    def __init__(
        self, vocab: Vocabulary, embedder: TextFieldEmbedder, encoder: Seq2VecEncoder
    ):
        super().__init__(vocab)
        self.embedder = embedder
        self.encoder = encoder
        num_labels = vocab.get_vocab_size("labels")
        self.classifier = torch.nn.Linear(encoder.get_output_dim(), num_labels)
        self.accuracy = CategoricalAccuracy()

    def forward(  # type: ignore
        self, text: TextFieldTensors, label: torch.Tensor = None
    ) -> Dict[str, torch.Tensor]:
        # Shape: (batch_size, num_tokens, embedding_dim)
        embedded_text = self.embedder(text)
        # Shape: (batch_size, num_tokens)
        mask = util.get_text_field_mask(text)
        # Shape: (batch_size, encoding_dim)
        encoded_text = self.encoder(embedded_text, mask)
        # Shape: (batch_size, num_labels)
        logits = self.classifier(encoded_text)
        # Shape: (batch_size, num_labels)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        output = {"probs": probs}
        if label is not None:
            self.accuracy(logits, label)
            # Shape: (1,)
            output["loss"] = torch.nn.functional.cross_entropy(logits, label)
        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {"accuracy": self.accuracy.get_metric(reset)}

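Before running the full training loop below, it can be worth checking that the reader, vocabulary, and model plug together. A hedged smoke-test sketch (not part of the repository; the 10-dimensional embedder and bag-of-embeddings encoder mirror `build_model` in `my_project/train.py`):

```python
# Run the untrained model over the toy fixture as a shape/smoke check.
from allennlp.data import Vocabulary
from allennlp.modules.seq2vec_encoders import BagOfEmbeddingsEncoder
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding

from my_project.dataset_reader import ClassificationTsvReader
from my_project.model import SimpleClassifier

reader = ClassificationTsvReader()
instances = list(reader.read("tests/fixtures/toy_data.tsv"))
vocab = Vocabulary.from_instances(instances)

embedder = BasicTextFieldEmbedder(
    {"tokens": Embedding(embedding_dim=10, num_embeddings=vocab.get_vocab_size("tokens"))}
)
model = SimpleClassifier(vocab, embedder, BagOfEmbeddingsEncoder(embedding_dim=10))

# forward_on_instances indexes the instances with the model's vocabulary and calls forward();
# it may warn that the scalar loss cannot be split per instance, which is harmless here.
outputs = model.forward_on_instances(instances)
print(outputs[0]["probs"])  # one probability per label in the toy data
```
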
--------------------------------------------------------------------------------
/my_project/train.py:
--------------------------------------------------------------------------------
# This file contains a bunch of build_* methods that configure objects however you want, and a
# run_training_loop method that calls these methods and runs the trainer.

from itertools import chain
from typing import Iterable, Tuple

import allennlp
import torch
from allennlp.data import DataLoader, DatasetReader, Instance, Vocabulary
from allennlp.data.data_loaders import MultiProcessDataLoader
from allennlp.models import Model
from allennlp.modules.seq2vec_encoders import BagOfEmbeddingsEncoder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.training.trainer import GradientDescentTrainer, Trainer
from allennlp.training.optimizers import AdamOptimizer

from my_project.dataset_reader import ClassificationTsvReader
from my_project.model import SimpleClassifier


def build_dataset_reader() -> DatasetReader:
    return ClassificationTsvReader()


def build_vocab(train_loader, dev_loader) -> Vocabulary:
    print("Building the vocabulary")
    return Vocabulary.from_instances(
        chain(train_loader.iter_instances(), dev_loader.iter_instances())
    )


def build_model(vocab: Vocabulary) -> Model:
    print("Building the model")
    vocab_size = vocab.get_vocab_size("tokens")
    embedder = BasicTextFieldEmbedder(
        {"tokens": Embedding(embedding_dim=10, num_embeddings=vocab_size)}
    )
    encoder = BagOfEmbeddingsEncoder(embedding_dim=10)
    return SimpleClassifier(vocab, embedder, encoder)


def build_data_loaders(
    reader,
    train_data_path: str,
    validation_data_path: str,
) -> Tuple[DataLoader, DataLoader]:
    train_loader = MultiProcessDataLoader(
        reader, train_data_path, batch_size=8, shuffle=True
    )
    dev_loader = MultiProcessDataLoader(
        reader, validation_data_path, batch_size=8, shuffle=False
    )
    return train_loader, dev_loader


def build_trainer(
    model: Model,
    serialization_dir: str,
    train_loader: DataLoader,
    dev_loader: DataLoader,
) -> Trainer:
    parameters = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    optimizer = AdamOptimizer(parameters)  # type: ignore
    # There are a *lot* of other things you could configure with the trainer. See
    # http://docs.allennlp.org/master/api/training/trainer/#gradientdescenttrainer-objects for more
    # information.

    trainer = GradientDescentTrainer(
        model=model,
        serialization_dir=serialization_dir,
        data_loader=train_loader,
        validation_data_loader=dev_loader,
        num_epochs=5,
        optimizer=optimizer,
        validation_metric="+accuracy",
    )
    return trainer


def run_training_loop(serialization_dir: str):
    reader = build_dataset_reader()

    train_loader, dev_loader = build_data_loaders(
        reader, "/path/to/your/training/data", "/path/to/your/validation/data"
    )

    vocab = build_vocab(train_loader, dev_loader)
    model = build_model(vocab)

    # This is the allennlp-specific functionality in the Dataset object;
    # we need to be able to convert strings in the data to integers, and this
    # is how we do it.
    train_loader.index_with(vocab)
    dev_loader.index_with(vocab)

    trainer = build_trainer(model, serialization_dir, train_loader, dev_loader)

    # NOTE: Training using multiple GPUs is hard in this setting. If you want multi-GPU training,
    # we recommend using our config file template instead, which handles this case better, as well
    # as saving the model in a way that it can be easily loaded later. If you really want to use
    # your own python script with distributed training, have a look at the code for the allennlp
    # train command (https://github.com/allenai/allennlp/blob/master/allennlp/commands/train.py),
    # which is where we handle distributed training. Also, let us know on github that you want
    # this; we could refactor things to make this usage much easier, if there's enough interest.

    print("Starting training")
    trainer.train()
    print("Finished training")

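The NOTE above points out that the config-file template saves the model in a more readily loadable form; this script only leaves the trainer's checkpoints in `serialization_dir`. If you want to reload the trained model in another script, one simple option is to save the vocabulary and weights explicitly. A hedged sketch (not part of the repository; `save_artifacts`/`load_artifacts` are illustrative names, and it reuses `build_model` from `my_project/train.py` so the architecture matches):

```python
# Manually persist and reload the trained model, e.g. at the end of run_training_loop.
import os

import torch
from allennlp.data import Vocabulary

from my_project.train import build_model


def save_artifacts(model, vocab, serialization_dir: str) -> None:
    vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))
    torch.save(model.state_dict(), os.path.join(serialization_dir, "model.th"))


def load_artifacts(serialization_dir: str):
    vocab = Vocabulary.from_files(os.path.join(serialization_dir, "vocabulary"))
    model = build_model(vocab)  # rebuild with the same configuration used for training
    model.load_state_dict(torch.load(os.path.join(serialization_dir, "model.th")))
    model.eval()
    return vocab, model
```
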
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
[pytest]
filterwarnings =
    ignore::DeprecationWarning
    ignore::PendingDeprecationWarning

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
allennlp>=2.0,<3.0
pytest
flake8
mypy==0.800
black==20.8b1

--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
# There are lots of ways to set up this training script. We're putting the bulk of the code inside
# the my_project module, with a simple run script in the base directory. If you prefer, you could
# just take train.py and move it to the top-level directory and use that as your run.py. Do
# whatever you're most comfortable with.

from my_project.train import run_training_loop

run_training_loop(serialization_dir="results/")

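If you launch more than one run, a small command-line wrapper saves editing the file between runs. A hedged sketch of an alternative `run.py` (not part of the repository; the `--serialization-dir` flag is illustrative, and the data paths still live inside `run_training_loop`):

```python
# Alternative run script: pick the output directory on the command line.
import argparse

from my_project.train import run_training_loop

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train the simple classifier.")
    parser.add_argument(
        "--serialization-dir",
        default="results/",
        help="Directory where the trainer writes checkpoints and metrics.",
    )
    args = parser.parse_args()
    run_training_loop(serialization_dir=args.serialization_dir)
```
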
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/allennlp-template-python-script/e5c66e9b90d5855919f645899a8ee3900ae500c5/tests/__init__.py

--------------------------------------------------------------------------------
/tests/fixtures/toy_data.tsv:
--------------------------------------------------------------------------------
it is movies like these that make a jaded movie viewer thankful for the invention of the timex indiglo watch . of course , the ads make it seem like so much more . sounds like a cool movie , does it not ? after the first fifteen minutes , it quickly becomes apparent that it is not . it is clear that the film is nothing more than an attempt to cash in on the teenage spending dollar , judging from the rash of really awful teen-flicks that we've been seeing as of late . avoid this film at all costs .	neg
the music is well-chosen and scored . in terms of overall quality , i would compare the truman show to niccol's gattaca . both films are well made with interesting stories set in interesting worlds . but neither film really felt like it capitalized on all the great ideas ; neither film " clicked " and became an instant classic . nevertheless , i look forward to niccol's next film , whatever it may be .	pos

--------------------------------------------------------------------------------
/tests/test_dataset_reader.py:
--------------------------------------------------------------------------------
from allennlp.common.testing import AllenNlpTestCase

from my_project.dataset_reader import ClassificationTsvReader


class TestClassificationTsvReader(AllenNlpTestCase):
    def test_read_from_toy_data(self):
        reader = ClassificationTsvReader()
        data_path = "tests/fixtures/toy_data.tsv"
        instances = list(reader.read(data_path))

        assert len(instances) == 2

        fields = instances[0].fields
        expected_tokens = ["it", "is", "movies", "like", "these"]
        assert [t.text for t in fields["text"].tokens][:5] == expected_tokens
        assert fields["label"].label == "neg"

        fields = instances[1].fields
        expected_tokens = ["the", "music", "is", "well-chosen", "and"]
        assert [t.text for t in fields["text"].tokens][:5] == expected_tokens
        assert fields["label"].label == "pos"

--------------------------------------------------------------------------------
/tests/test_model.py:
--------------------------------------------------------------------------------
import tempfile

from allennlp.common.testing import ModelTestCase

from my_project.train import (
    build_dataset_reader,
    build_vocab,
    build_model,
    build_data_loaders,
    build_trainer,
)


class TestSimpleClassifier(ModelTestCase):
    def test_model_can_train(self):
        with tempfile.TemporaryDirectory() as serialization_dir:
            reader = build_dataset_reader()
            train_loader, dev_loader = build_data_loaders(
                reader, "tests/fixtures/toy_data.tsv", "tests/fixtures/toy_data.tsv"
            )
            vocab = build_vocab(train_loader, dev_loader)
            train_loader.index_with(vocab)
            dev_loader.index_with(vocab)
            # Ideally you'd want to build a tiny toy model here, instead of calling the full
            # build_model function, like we do with the data above.
            model = build_model(vocab)
            trainer = build_trainer(
                model, serialization_dir, train_loader, dev_loader
            )
            # This built-in test makes sure that your data can load, that it gets passed to the
            # model correctly, that your model computes a loss in a way that we can get gradients
            # from it, and that all of your parameters get non-zero gradient updates.
            self.ensure_model_can_train(trainer)

--------------------------------------------------------------------------------
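The reader's `max_tokens` truncation is never exercised by the tests above. A hedged sketch of an extra test you could add to `tests/test_dataset_reader.py` (the test name and the 5-token limit are illustrative):

```python
# Additional test sketch: max_tokens should cap the number of tokens in the text field.
from my_project.dataset_reader import ClassificationTsvReader


def test_max_tokens_truncates_text():
    reader = ClassificationTsvReader(max_tokens=5)
    instance = reader.text_to_instance("a very long sentence that should be cut short", "pos")
    assert len(instance.fields["text"].tokens) == 5
    assert instance.fields["label"].label == "pos"
```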