├── sequence_aligner
    ├── __init__.py
    ├── labelset.py
    ├── containers.py
    ├── alignment.py
    └── dataset.py
├── LICENSE
├── .gitignore
├── README.md
└── notebooks
    ├── Start Here.ipynb
    └── how-to-align-notebook.ipynb

/sequence_aligner/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/sequence_aligner/labelset.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from typing import List
3 |
4 | from .alignment import align_tokens_and_annotations_bilou
5 |
6 |
7 | class LabelSet:
8 |     def __init__(self, labels: List[str]):
9 |         self.labels_to_id = {}
10 |         self.ids_to_label = {}
11 |         self.labels_to_id["O"] = 0
12 |         self.ids_to_label[0] = "O"
13 |         num = 0  # in case there are no labels
14 |         # Writing BILU will give us incremental ids for the labels
15 |         for _num, (label, s) in enumerate(itertools.product(labels, "BILU")):
16 |             num = _num + 1  # skip 0
17 |             l = f"{s}-{label}"
18 |             self.labels_to_id[l] = num
19 |             self.ids_to_label[num] = l
20 |
21 |
22 |     def get_aligned_label_ids_from_annotations(self, tokenized_text, annotations):
23 |         raw_labels = align_tokens_and_annotations_bilou(tokenized_text, annotations)
24 |         return list(map(self.labels_to_id.get, raw_labels))
25 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 LightTag
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/sequence_aligner/containers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from dataclasses import dataclass
3 | from typing import List, Any
4 |
5 | IntList = List[int]  # A list of token_ids
6 | IntListList = List[IntList]  # A List of List of token_ids, e.g. a Batch
7 |
8 |
9 | @dataclass
10 | class TrainingExample:
11 |     input_ids: IntList
12 |     attention_masks: IntList
13 |     labels: IntList
14 |
15 |
16 |
17 |
18 |
19 | class TraingingBatch:
20 |     def __getitem__(self, item):
21 |         return getattr(self, item)
22 |
23 |     def __init__(self, examples: List[TrainingExample]):
24 |         self.input_ids: torch.Tensor
25 |         self.attention_masks: torch.Tensor
26 |         self.labels: torch.Tensor
27 |         input_ids: IntListList = []
28 |         masks: IntListList = []
29 |         labels: IntListList = []
30 |         for ex in examples:
31 |             input_ids.append(ex.input_ids)
32 |             masks.append(ex.attention_masks)
33 |             labels.append(ex.labels)
34 |         self.input_ids = torch.LongTensor(input_ids)
35 |         self.attention_masks = torch.LongTensor(masks)
36 |         self.labels = torch.LongTensor(labels)
37 |
--------------------------------------------------------------------------------
/sequence_aligner/alignment.py:
--------------------------------------------------------------------------------
1 | from tokenizers import Encoding
2 | def align_tokens_and_annotations_bilou(tokenized: Encoding, annotations):
3 |     tokens = tokenized.tokens
4 |     aligned_labels = ["O"] * len(
5 |         tokens
6 |     )  # Make a list to store our labels the same length as our tokens
7 |     for anno in annotations:
8 |         annotation_token_ix_set = (
9 |             set()
10 |         )  # A set that stores the token indices of the annotation
11 |         for char_ix in range(anno["start"], anno["end"]):
12 |
13 |             token_ix = tokenized.char_to_token(char_ix)
14 |             if token_ix is not None:
15 |                 annotation_token_ix_set.add(token_ix)
16 |         if len(annotation_token_ix_set) == 1:
17 |             # If there is only one token
18 |             token_ix = annotation_token_ix_set.pop()
19 |             prefix = (
20 |                 "U"  # This annotation spans one token so is prefixed with U for unique
21 |             )
22 |             aligned_labels[token_ix] = f"{prefix}-{anno['label']}"
23 |
24 |         else:
25 |
26 |             last_token_in_anno_ix = len(annotation_token_ix_set) - 1
27 |             for num, token_ix in enumerate(sorted(annotation_token_ix_set)):
28 |                 if num == 0:
29 |                     prefix = "B"
30 |                 elif num == last_token_in_anno_ix:
31 |                     prefix = "L"  # It's the last token
32 |                 else:
33 |                     prefix = "I"  # We're inside of a multi token annotation
34 |                 aligned_labels[token_ix] = f"{prefix}-{anno['label']}"
35 |     return aligned_labels
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 | .idea/
--------------------------------------------------------------------------------
/sequence_aligner/dataset.py:
--------------------------------------------------------------------------------
1 | from typing import List, Any
2 |
3 | from torch.utils.data import Dataset
4 | from transformers import PreTrainedTokenizerFast
5 | from typing_extensions import TypedDict
6 |
7 | from .containers import TrainingExample
8 | from .labelset import LabelSet
9 | class ExpectedAnnotationShape(TypedDict):
10 |     start: int
11 |     end: int
12 |     label: str
13 |
14 | class ExpectedDataItemShape(TypedDict):
15 |     content: str  # The text to be annotated
16 |     annotations: List[ExpectedAnnotationShape]
17 |
18 | class TrainingDataset(Dataset):
19 |     '''Aligns annotations to tokens and windows them into fixed length, padded training examples.'''
20 |     def __init__(
21 |         self,
22 |         data: Any,
23 |         label_set: LabelSet,
24 |         tokenizer: PreTrainedTokenizerFast,
25 |         tokens_per_batch=32,
26 |         window_stride=None,
27 |     ):
28 |         self.label_set = label_set
29 |         # Default to non-overlapping windows when no stride is given
30 |         self.window_stride = tokens_per_batch if window_stride is None else window_stride
31 |         self.tokenizer = tokenizer
32 |         self.texts = []
33 |         self.annotations = []
34 |
35 |         for example in data:
36 |             self.texts.append(example["content"])
37 |             self.annotations.append(example["annotations"])
38 |         ###TOKENIZE ALL THE DATA
39 |         tokenized_batch = self.tokenizer(self.texts, add_special_tokens=False)
40 |         ###ALIGN LABELS ONE EXAMPLE AT A TIME
41 |         aligned_labels = []
42 |         for ix in range(len(tokenized_batch.encodings)):
43 |             encoding = tokenized_batch.encodings[ix]
44 |             raw_annotations = self.annotations[ix]
45 |             aligned = label_set.get_aligned_label_ids_from_annotations(
46 |                 encoding, raw_annotations
47 |             )
48 |             aligned_labels.append(aligned)
49 |         ###END OF LABEL ALIGNMENT
50 |
51 |         ###MAKE A LIST OF TRAINING EXAMPLES. (This is where we add padding)
52 |         self.training_examples: List[TrainingExample] = []
53 |         empty_label_id = "O"
54 |         for encoding, label in zip(tokenized_batch.encodings, aligned_labels):
55 |             length = len(label)  # How long is this sequence
56 |             for start in range(0, length, self.window_stride):
57 |
58 |                 end = min(start + tokens_per_batch, length)
59 |
60 |                 # How much padding do we need ?
61 |                 padding_to_add = max(0, tokens_per_batch - end + start)
62 |                 self.training_examples.append(
63 |                     TrainingExample(
64 |                         # Record the tokens
65 |                         input_ids=encoding.ids[start:end]  # The ids of the tokens
66 |                         + [self.tokenizer.pad_token_id]
67 |                         * padding_to_add,  # padding if needed
68 |                         labels=(
69 |                             label[start:end]
70 |                             + [-100] * padding_to_add  # padding if needed
71 |                         ),  # -100 is ignored by the loss, so it marks padded label positions
72 |                         attention_masks=(
73 |                             encoding.attention_mask[start:end]
74 |                             + [0]
75 |                             * padding_to_add  # 0'd attention masks where we added padding
76 |                         ),
77 |                     )
78 |                 )
79 |
80 |     def __len__(self):
81 |         return len(self.training_examples)
82 |
83 |     def __getitem__(self, idx) -> TrainingExample:
84 |
85 |         return self.training_examples[idx]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Preparing Sequence Labeling Data for Transformers Is Hard
2 |
3 | This repo has some utilities to align offset annotations (start, end) to tokenizer outputs,
4 | and to create PyTorch datasets and dataloaders that handle padding and batching.
5 |
6 | The impetus for this repo is this [github issue](https://github.com/huggingface/transformers/issues/7019).
7 | A blog post explains our thinking around how to [best prepare sequence labeling data for use with pre-trained transformers](https://www.lighttag.io/blog/sequence-labeling-with-transformers/),
8 | and another post [derives the implementation in this repo](https://lighttag.io/blog/sequence-labeling-with-transformers/example).
9 |
10 | This is a POC and maybe a work in progress. Issues, PRs and contributions welcome.
11 | The code is optimized for readability and clarity of thought. There is plenty of room for performance improvement,
12 | but not much of a case for it because compute time and memory are dominated by training.
13 |
14 | ## Quick Example
15 | If we have annotated data like this:
16 | ```python
17 | [{'annotations': [],
18 |   'content': 'No formal drug interaction studies of Aranesp? have been '
19 |              'performed.',
20 |   'metadata': {'original_id': 'DrugDDI.d390.s0'}},
21 |  {'annotations': [{'end': 13, 'label': 'drug', 'start': 6, 'tag': 'drug'},
22 |                   {'end': 60, 'label': 'drug', 'start': 43, 'tag': 'drug'},
23 |                   {'end': 112, 'label': 'drug', 'start': 105, 'tag': 'drug'},
24 |                   {'end': 177, 'label': 'drug', 'start': 164, 'tag': 'drug'},
25 |                   {'end': 194, 'label': 'drug', 'start': 181, 'tag': 'drug'},
26 |                   {'end': 219, 'label': 'drug', 'start': 211, 'tag': 'drug'},
27 |                   {'end': 238, 'label': 'drug', 'start': 227, 'tag': 'drug'}],
28 |   'content': 'Since PLETAL is extensively metabolized by cytochrome P-450 '
29 |              'isoenzymes, caution should be exercised when PLETAL is '
30 |              'coadministered with inhibitors of C.P.A. such as ketoconazole '
31 |              'and erythromycin or inhibitors of CYP2C19 such as omeprazole.',
32 |   'metadata': {'original_id': 'DrugDDI.d452.s0'}},
33 |  {'annotations': [{'end': 58, 'label': 'drug', 'start': 47, 'tag': 'drug'},
34 |                   {'end': 75, 'label': 'drug', 'start': 62, 'tag': 'drug'},
35 |                   {'end': 135, 'label': 'drug', 'start': 124, 'tag': 'drug'},
36 |                   {'end': 164, 'label': 'drug', 'start': 152, 'tag': 'drug'}],
37 |   'content': 'Pharmacokinetic studies have demonstrated that omeprazole and '
38 |              'erythromycin significantly increased the systemic exposure of '
39 |              'cilostazol and/or its major metabolites.',
40 |   'metadata': {'original_id': 'DrugDDI.d452.s1'}}]
41 | ```
42 | We can do this:
43 | ```python
44 | from sequence_aligner.labelset import LabelSet
45 | from sequence_aligner.dataset import TrainingDataset
46 | from sequence_aligner.containers import TraingingBatch
47 | import json
48 | raw = json.load(open('./data/ddi_train.json'))
49 | for example in raw:
50 |     for annotation in example['annotations']:
51 |         # We expect the key to be label but the data has tag
52 |         annotation['label'] = annotation['tag']
53 |
54 | from transformers import BertTokenizerFast
55 | tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
56 | label_set = LabelSet(labels=["drug"])  # Only one label in this dataset
57 | dataset = TrainingDataset(data=raw, tokenizer=tokenizer, label_set=label_set)
58 |
59 | from torch.utils.data import DataLoader
60 | from transformers import BertForTokenClassification, AdamW
61 | model = BertForTokenClassification.from_pretrained(
62 |     "bert-base-cased", num_labels=len(dataset.label_set.ids_to_label.values())
63 | )
64 | optimizer = AdamW(model.parameters(), lr=5e-6)
65 |
66 | dataloader = DataLoader(
67 |     dataset,
68 |     collate_fn=TraingingBatch,
69 |     batch_size=4,
70 |     shuffle=True,
71 | )
72 | for num, batch in enumerate(dataloader):
73 |     loss, logits = model(
74 |         input_ids=batch.input_ids,
75 |         attention_mask=batch.attention_masks,
76 |         labels=batch.labels,
77 |     )
78 |     loss.backward()
79 |     optimizer.step()
80 | ```
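81 |
82 | ## Mapping Predictions Back To Labels
83 | The `LabelSet` built for the dataset keeps the reverse mapping in `ids_to_label`, so decoding model output into tag names is a dictionary lookup per token.
84 | The snippet below is a minimal sketch rather than part of the library: it assumes the `model`, `dataset` and `batch` objects from the example above, and the tuple-style `(logits,)` return used there.
85 | ```python
86 | import torch
87 |
88 | model.eval()
89 | with torch.no_grad():
90 |     # Without labels the model returns (logits,); logits has shape (batch, seq_len, num_labels)
91 |     logits = model(
92 |         input_ids=batch.input_ids,
93 |         attention_mask=batch.attention_masks,
94 |     )[0]
95 | predicted_ids = logits.argmax(dim=-1)
96 | for sequence in predicted_ids.tolist():
97 |     print([dataset.label_set.ids_to_label[i] for i in sequence])
98 | ```
99 | In a real evaluation loop you would also mask out padded positions (where `batch.attention_masks` is 0, or the aligned label is -100) before scoring.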
--------------------------------------------------------------------------------
/notebooks/Start Here.ipynb:
--------------------------------------------------------------------------------
1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# What Is This\n", 8 | "This notebook shows how to use the utilities in the repo to quickly start training a sequence labeling model. \n", 9 | "The utilities take care of alignment, padding, batching and windowing. \n", 10 | "For a walk-through of the utilities see our [tutorial on sequence labeling with transformers](https://lighttag.io/blog/sequence-labeling-with-transformers/example). 
For the reasoning behind it see our semi-essay on the considerations of [aligning span annotations to Huggingface tokenizer outputs](https://www.lighttag.io/blog/sequence-labeling-with-transformers/) " 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "cd .." 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 8, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "from sequence_aligner.labelset import LabelSet\n", 29 | "from sequence_aligner.dataset import TrainingDataset\n", 30 | "from sequence_aligner.containers import TraingingBatch\n", 31 | "import json\n" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 5, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "## Load The Raw Data\n", 41 | "raw = json.load(open('./data/ddi_train.json'))\n", 42 | "for example in raw:\n", 43 | " for annotation in example['annotations']:\n", 44 | " #We expect the key of label to be label but the data has tag\n", 45 | " annotation['label'] = annotation['tag']" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 7, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "from transformers import BertTokenizerFast\n", 55 | "tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')\n", 56 | "label_set = LabelSet(labels=[\"drug\"]) #Only one label in this dataset\n", 57 | "dataset = TrainingDataset(data=raw,tokenizer=tokenizer,label_set=label_set)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 12, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "name": "stderr", 67 | "output_type": "stream", 68 | "text": [ 69 | "Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n", 70 | "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPretraining model).\n", 71 | "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 72 | "Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.weight', 'classifier.bias']\n", 73 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 74 | ] 75 | }, 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "tensor(1.6381, grad_fn=)\n", 81 | "tensor(1.4514, grad_fn=)\n", 82 | "tensor(1.5203, grad_fn=)\n", 83 | "tensor(1.3982, grad_fn=)\n", 84 | "tensor(1.2953, grad_fn=)\n" 85 | ] 86 | } 87 | ], 88 | "source": [ 89 | "from torch.utils.data import DataLoader\n", 90 | "from transformers import BertForTokenClassification,AdamW\n", 91 | "model = BertForTokenClassification.from_pretrained(\n", 92 | " \"bert-base-cased\", num_labels=len(dataset.label_set.ids_to_label.values())\n", 93 | ")\n", 94 | "optimizer = AdamW(model.parameters(), lr=5e-6)\n", 95 | "\n", 96 | "dataloader = DataLoader(\n", 97 | " dataset,\n", 98 | " collate_fn=TraingingBatch,\n", 99 | " batch_size=4,\n", 100 | " shuffle=True,\n", 101 | ")\n", 102 | "for num, batch in enumerate(dataloader):\n", 103 | " loss, logits = model(\n", 104 | " input_ids=batch.input_ids,\n", 105 | " attention_mask=batch.attention_masks,\n", 106 | " labels=batch.labels,\n", 107 | " )\n", 108 | " loss.backward()\n", 109 | " optimizer.step()\n", 110 | " print(loss)\n", 111 | " if num > 3:\n", 112 | " break" 113 | ] 114 | } 115 | ], 116 | "metadata": { 117 | "kernelspec": { 118 | "display_name": "Python 3", 119 | "language": "python", 120 | "name": "python3" 121 | }, 122 | "language_info": { 123 | "codemirror_mode": { 124 | "name": "ipython", 125 | "version": 3 126 | }, 127 | "file_extension": ".py", 128 | "mimetype": "text/x-python", 129 | "name": "python", 130 | "nbconvert_exporter": "python", 131 | "pygments_lexer": "ipython3", 132 | "version": "3.6.9" 133 | } 134 | }, 135 | "nbformat": 4, 136 | "nbformat_minor": 2 137 | } 138 | -------------------------------------------------------------------------------- /notebooks/how-to-align-notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Our previous post on aligning span annotations to Hugginface's tokenizer outputs discussed the the various tradeoffs one needs to consider, and concluded that a windowing strategy over the tokenized text and labels is optimal for our use cases. \n", 8 | "\n", 9 | "This post demonstrates an end to end implementation of token alignment and windowing. We'll start by implementing utility classes that make programming a little easier, then implement the alignment functionality which aligns offset annotations to the out of a tokenizer. Finnaly we'll implement a PyTorch Dataset that stores our aligned tokens and labels as windows, a Collator to implement batching and a simple DataLoader to be used in training. \n", 10 | "\n", 11 | "We'll show and end to end flow on the DDI Corpus, recognizing pharmacological entities with BERT." 
12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "## Utility Classes For Convenient APIs\n", 19 | "We'll start by defining some types and utility classes that will make our work more convient" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 1, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "import warnings\n", 29 | "warnings.filterwarnings('ignore')\n", 30 | "\n", 31 | "from typing_extensions import TypedDict\n", 32 | "from typing import List,Any\n", 33 | "IntList = List[int] # A list of token_ids\n", 34 | "IntListList = List[IntList] # A List of List of token_ids, e.g. a Batch" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "## The Alignment Algorithm\n", 49 | "\n", 50 | "### FastTokenizers Simplify Alignment\n", 51 | "Recent versions of Hugginface's tokenizers library include variants of Tokenizers that end with Fast and inherit from [PreTrainedTokenizerFast](https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.PreTrainedTokenizerFast) such as [BertTokenizerFast](https://huggingface.co/transformers/model_doc/bert.html#berttokenizerfast) and [GPT2TokenizerFast](https://huggingface.co/transformers/model_doc/gpt2.html#gpt2tokenizerfast). \n", 52 | "\n", 53 | "Per the tokenizer's documentation\n", 54 | "> When the tokenizer is a “Fast” tokenizer (i.e., backed by HuggingFace tokenizers library), [the output] provides in addition several advanced alignment methods which can be used to map between the original string (character and words) and the token space (e.g., getting the index of the token comprising a given character or the span of characters corresponding to a given token).\n", 55 | "\n", 56 | "Notably, the output provides the methods [token_to_chars](https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.BatchEncoding.token_to_chars) and [char_to_token](https://huggingface.co/transformers/main_classes/tokenizer.html#transformers.BatchEncoding.char_to_token) which do exactly what their name implies, provide mappings between tokens and charachter offsets in the original text. That's exactly what we need to align annotations in offset format with tokens.\n", 57 | "\n", 58 | "\n" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "## A warmup implementation\n", 66 | "Our final implementation will use the BIOUL scheme we mentioned before. 
But before we do that, let's try a simple alignment to see what it feels like" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 2, 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "name": "stdout", 76 | "output_type": "stream", 77 | "text": [ 78 | "Tal Perry Person\n", 79 | "founder Title\n", 80 | "LightTag Org\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "text = \"I am Tal Perry, founder of LightTag\"\n", 86 | "annotations = [\n", 87 | " dict(start=5,end=14,text=\"Tal Perry\",label=\"Person\"),\n", 88 | " dict(start=16,end=23,text=\"founder\",label=\"Title\"),\n", 89 | " dict(start=27,end=35,text=\"LightTag\",label=\"Org\"),\n", 90 | " \n", 91 | " ]\n", 92 | "for anno in annotations:\n", 93 | " # Show our annotations\n", 94 | " print (text[anno['start']:anno['end']],anno['label'])\n", 95 | " \n" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 3, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "from transformers import BertTokenizerFast, BatchEncoding\n", 105 | "from tokenizers import Encoding\n", 106 | "tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased') # Load a pre-trained tokenizer\n", 107 | "tokenized_batch : BatchEncoding = tokenizer(text)\n", 108 | "tokenized_text :Encoding =tokenized_batch[0]\n", 109 | "\n" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 4, 115 | "metadata": {}, 116 | "outputs": [ 117 | { 118 | "name": "stdout", 119 | "output_type": "stream", 120 | "text": [ 121 | "[CLS] - O\n", 122 | "I - O\n", 123 | "am - O\n", 124 | "Ta - Person\n", 125 | "##l - Person\n", 126 | "Perry - Person\n", 127 | ", - O\n", 128 | "founder - Title\n", 129 | "of - O\n", 130 | "Light - Org\n", 131 | "##T - Org\n", 132 | "##ag - Org\n", 133 | "[SEP] - O\n" 134 | ] 135 | } 136 | ], 137 | "source": [ 138 | "tokens = tokenized_text.tokens\n", 139 | "aligned_labels = [\"O\"]*len(tokens) # Make a list to store our labels the same length as our tokens\n", 140 | "for anno in (annotations):\n", 141 | " for char_ix in range(anno['start'],anno['end']):\n", 142 | " token_ix = tokenized_text.char_to_token(char_ix)\n", 143 | " if token_ix is not None: # White spaces have no token and will return None\n", 144 | " aligned_labels[token_ix] = anno['label']\n", 145 | "for token,label in zip(tokens,aligned_labels):\n", 146 | " print (token,\"-\",label)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": {}, 152 | "source": [ 153 | "### Accounting For Multi Token Annotations\n", 154 | "In the above example, some of our annotations spanned multiple tokens. For instance \"Tal Perry\" spanned \"Ta\", \"##l\" and \"Perry\". Clearly by themeselves none of those tokens are a Person, and so our current alignment scheme isn't as useful as it could be. \n", 155 | "To overcome that, we'll use the previously mentioned BIOLU scheme, which will indicate if a token is the begining, inside, last token in an annotation or if it is not part of an annotation or if it is perfectly aligned with an annotation." 
156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 5, 161 | "metadata": {}, 162 | "outputs": [ 163 | { 164 | "name": "stdout", 165 | "output_type": "stream", 166 | "text": [ 167 | "[CLS] - O\n", 168 | "I - O\n", 169 | "am - O\n", 170 | "Ta - B-Person\n", 171 | "##l - I-Person\n", 172 | "Perry - L-Person\n", 173 | ", - O\n", 174 | "founder - U-Title\n", 175 | "of - O\n", 176 | "Light - B-Org\n", 177 | "##T - I-Org\n", 178 | "##ag - L-Org\n", 179 | "[SEP] - O\n" 180 | ] 181 | } 182 | ], 183 | "source": [ 184 | "def align_tokens_and_annotations_bilou(tokenized: Encoding, annotations):\n", 185 | " tokens = tokenized.tokens\n", 186 | " aligned_labels = [\"O\"] * len(\n", 187 | " tokens\n", 188 | " ) # Make a list to store our labels the same length as our tokens\n", 189 | " for anno in annotations:\n", 190 | " annotation_token_ix_set = (\n", 191 | " set()\n", 192 | " ) # A set that stores the token indices of the annotation\n", 193 | " for char_ix in range(anno[\"start\"], anno[\"end\"]):\n", 194 | "\n", 195 | " token_ix = tokenized.char_to_token(char_ix)\n", 196 | " if token_ix is not None:\n", 197 | " annotation_token_ix_set.add(token_ix)\n", 198 | " if len(annotation_token_ix_set) == 1:\n", 199 | " # If there is only one token\n", 200 | " token_ix = annotation_token_ix_set.pop()\n", 201 | " prefix = (\n", 202 | " \"U\" # This annotation spans one token so is prefixed with U for unique\n", 203 | " )\n", 204 | " aligned_labels[token_ix] = f\"{prefix}-{anno['label']}\"\n", 205 | "\n", 206 | " else:\n", 207 | "\n", 208 | " last_token_in_anno_ix = len(annotation_token_ix_set) - 1\n", 209 | " for num, token_ix in enumerate(sorted(annotation_token_ix_set)):\n", 210 | " if num == 0:\n", 211 | " prefix = \"B\"\n", 212 | " elif num == last_token_in_anno_ix:\n", 213 | " prefix = \"L\" # Its the last token\n", 214 | " else:\n", 215 | " prefix = \"I\" # We're inside of a multi token annotation\n", 216 | " aligned_labels[token_ix] = f\"{prefix}-{anno['label']}\"\n", 217 | " return aligned_labels\n", 218 | "\n", 219 | "\n", 220 | "labels = align_tokens_and_annotations_bilou(tokenized_text, annotations)\n", 221 | "for token, label in zip(tokens, labels):\n", 222 | " print(token, \"-\", label)" 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "metadata": {}, 228 | "source": [ 229 | "Notice how **founder** above has a **U** prefix and the other annotations now follow a BIL scheme.\n", 230 | "\n", 231 | "**Note** In production, you'll convert the labels to ids, using the LabelSet we defined above. I'm going to skip that for now for the sake of readability" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": {}, 237 | "source": [ 238 | "### Mapping Labels To Ids\n", 239 | "It's great that we have our annotations aligned, but we need the labels as integer ids for training. During inference, we'll also need a way to map predicted ids back to labels.\n", 240 | "I'm going to make a custom class that handles that, called a LabelSet. 
" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 6, 246 | "metadata": {}, 247 | "outputs": [ 248 | { 249 | "name": "stdout", 250 | "output_type": "stream", 251 | "text": [ 252 | "[CLS] - 0\n", 253 | "I - 0\n", 254 | "am - 0\n", 255 | "Ta - 1\n", 256 | "##l - 2\n", 257 | "Perry - 3\n", 258 | ", - 0\n", 259 | "founder - 12\n", 260 | "of - 0\n", 261 | "Light - 5\n", 262 | "##T - 6\n", 263 | "##ag - 7\n", 264 | "[SEP] - 0\n" 265 | ] 266 | } 267 | ], 268 | "source": [ 269 | "import itertools\n", 270 | "\n", 271 | "\n", 272 | "class LabelSet:\n", 273 | " def __init__(self, labels: List[str]):\n", 274 | " self.labels_to_id = {}\n", 275 | " self.ids_to_label = {}\n", 276 | " self.labels_to_id[\"O\"] = 0\n", 277 | " self.ids_to_label[0] = \"O\"\n", 278 | " num = 0 # in case there are no labels\n", 279 | " # Writing BILU will give us incremntal ids for the labels\n", 280 | " for _num, (label, s) in enumerate(itertools.product(labels, \"BILU\")):\n", 281 | " num = _num + 1 # skip 0\n", 282 | " l = f\"{s}-{label}\"\n", 283 | " self.labels_to_id[l] = num\n", 284 | " self.ids_to_label[num] = l\n", 285 | " # Add the OUTSIDE label - no label for the token\n", 286 | "\n", 287 | " def get_aligned_label_ids_from_annotations(self, tokenized_text, annotations):\n", 288 | " raw_labels = align_tokens_and_annotations_bilou(tokenized_text, annotations)\n", 289 | " return list(map(self.labels_to_id.get, raw_labels))\n", 290 | "\n", 291 | "\n", 292 | "example_label_set = LabelSet(labels=[\"Person\", \"Org\", \"Title\"])\n", 293 | "aligned_label_ids = example_label_set.get_aligned_label_ids_from_annotations(\n", 294 | " tokenized_text, annotations\n", 295 | ")\n", 296 | "\n", 297 | "for token, label in zip(tokens, aligned_label_ids):\n", 298 | " print(token, \"-\", label)" 299 | ] 300 | }, 301 | { 302 | "cell_type": "markdown", 303 | "metadata": {}, 304 | "source": [ 305 | "# Batching\n", 306 | "Now that we have alignment logic in place, we need to figure out how to load, batch and pad the data. We also need to habdle the case where our text is longer than we can feed our model. Below we show an implementation of a particular strategy, windowing over uniform length segments of the text. This isn't the only strategy, or even necasarily the best, but it fits our use case well. You can read more about why [we use windowing when training ner models with BERT here](https://www.lighttag.io/blog/sequence-labeling-with-transformers/). Below we'll just show how to do that.\n", 307 | "\n", 308 | "## The Raw Dataset\n", 309 | "We'll be using the [DDI Corpus](https://www.sciencedirect.com/science/article/pii/S1532046413001123). 
This notebook will pull the files locally but you can download them as [JSON here](https://github.com/LightTag/DDICorpus).\n", 310 | "Let's take a quick look at it\n" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 7, 316 | "metadata": {}, 317 | "outputs": [ 318 | { 319 | "name": "stdout", 320 | "output_type": "stream", 321 | "text": [ 322 | "{'annotations': [{'end': 58, 'label': 'drug', 'start': 47, 'tag': 'drug'},\n", 323 | " {'end': 75, 'label': 'drug', 'start': 62, 'tag': 'drug'},\n", 324 | " {'end': 135, 'label': 'drug', 'start': 124, 'tag': 'drug'},\n", 325 | " {'end': 164, 'label': 'drug', 'start': 152, 'tag': 'drug'}],\n", 326 | " 'content': 'Pharmacokinetic studies have demonstrated that omeprazole and '\n", 327 | " 'erythromycin significantly increased the systemic exposure of '\n", 328 | " 'cilostazol and/or its major metabolites.',\n", 329 | " 'metadata': {'original_id': 'DrugDDI.d452.s1'}}\n" 330 | ] 331 | } 332 | ], 333 | "source": [ 334 | "import json\n", 335 | "from pprint import pprint\n", 336 | "\n", 337 | "raw = json.load(open(\"./ddi_train.json\"))\n", 338 | "for example in raw:\n", 339 | " # our simple implementation expects the label to be called label, so we adjust the original data\n", 340 | " for anno in example[\"annotations\"]:\n", 341 | " anno[\"label\"] = anno[\"tag\"]\n", 342 | "pprint(raw[2])" 343 | ] 344 | }, 345 | { 346 | "cell_type": "markdown", 347 | "metadata": {}, 348 | "source": [ 349 | "Lets take a look at that tokenized\n" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": 8, 355 | "metadata": {}, 356 | "outputs": [ 357 | { 358 | "name": "stdout", 359 | "output_type": "stream", 360 | "text": [ 361 | "[CLS] - O\n", 362 | "Ph - O\n", 363 | "##arma - O\n", 364 | "##co - O\n", 365 | "##kin - O\n", 366 | "##etic - O\n", 367 | "studies - O\n", 368 | "have - O\n", 369 | "demonstrated - O\n", 370 | "that - O\n", 371 | "o - B-drug\n", 372 | "##me - I-drug\n", 373 | "##pra - I-drug\n", 374 | "##zo - I-drug\n", 375 | "##le - L-drug\n", 376 | "and - O\n", 377 | "er - B-drug\n", 378 | "##yt - I-drug\n", 379 | "##hr - I-drug\n", 380 | "##omy - I-drug\n", 381 | "##cin - L-drug\n", 382 | "significantly - O\n", 383 | "increased - O\n", 384 | "the - O\n", 385 | "systemic - O\n", 386 | "exposure - O\n", 387 | "of - O\n", 388 | "c - B-drug\n", 389 | "##ilo - I-drug\n", 390 | "##sta - I-drug\n", 391 | "##zo - I-drug\n", 392 | "##l - L-drug\n", 393 | "and - O\n", 394 | "/ - O\n", 395 | "or - O\n", 396 | "its - O\n", 397 | "major - O\n", 398 | "meta - B-drug\n", 399 | "##bol - I-drug\n", 400 | "##ites - I-drug\n", 401 | ". - L-drug\n", 402 | "[SEP] - O\n" 403 | ] 404 | } 405 | ], 406 | "source": [ 407 | "example = raw[2]\n", 408 | "tokenized_batch = tokenizer(example[\"content\"])\n", 409 | "tokenized_text = tokenized_batch[0]\n", 410 | "labels = align_tokens_and_annotations_bilou(tokenized_text, example[\"annotations\"])\n", 411 | "for token, label in zip(tokenized_text.tokens, labels):\n", 412 | " print(token, \"-\", label)" 413 | ] 414 | }, 415 | { 416 | "cell_type": "markdown", 417 | "metadata": {}, 418 | "source": [ 419 | "## Padding and Windowing in a Dataset\n", 420 | "Our dataset is conveniently split into sentences. We still need to batch it and pad the examples. More commonly, data is not split into sentences, and so we will window over fixed sized parts of it. 
The windowing, padding and alignment logic will be done in a pytorch Dataset and we'll get to batching in a moment" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": 9, 426 | "metadata": {}, 427 | "outputs": [], 428 | "source": [ 429 | "from dataclasses import dataclass\n", 430 | "from torch.utils.data import Dataset\n", 431 | "from transformers import PreTrainedTokenizerFast" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": 10, 437 | "metadata": {}, 438 | "outputs": [], 439 | "source": [ 440 | "@dataclass\n", 441 | "class TrainingExample:\n", 442 | " input_ids: IntList\n", 443 | " attention_masks: IntList\n", 444 | " labels: IntList\n", 445 | "\n", 446 | "\n", 447 | "class TraingDataset(Dataset):\n", 448 | " def __init__(\n", 449 | " self,\n", 450 | " data: Any,\n", 451 | " label_set: LabelSet,\n", 452 | " tokenizer: PreTrainedTokenizerFast,\n", 453 | " tokens_per_batch=32,\n", 454 | " window_stride=None,\n", 455 | " ):\n", 456 | " self.label_set = label_set\n", 457 | " if window_stride is None:\n", 458 | " self.window_stride = tokens_per_batch\n", 459 | " self.tokenizer = tokenizer\n", 460 | " for example in data:\n", 461 | " # changes tag key to label\n", 462 | " for a in example[\"annotations\"]:\n", 463 | " a[\"label\"] = a[\"tag\"]\n", 464 | " self.texts = []\n", 465 | " self.annotations = []\n", 466 | "\n", 467 | " for example in data:\n", 468 | " self.texts.append(example[\"content\"])\n", 469 | " self.annotations.append(example[\"annotations\"])\n", 470 | " ###TOKENIZE All THE DATA\n", 471 | " tokenized_batch = self.tokenizer(self.texts, add_special_tokens=False)\n", 472 | " ###ALIGN LABELS ONE EXAMPLE AT A TIME\n", 473 | " aligned_labels = []\n", 474 | " for ix in range(len(tokenized_batch.encodings)):\n", 475 | " encoding = tokenized_batch.encodings[ix]\n", 476 | " raw_annotations = self.annotations[ix]\n", 477 | " aligned = label_set.get_aligned_label_ids_from_annotations(\n", 478 | " encoding, raw_annotations\n", 479 | " )\n", 480 | " aligned_labels.append(aligned)\n", 481 | " ###END OF LABEL ALIGNMENT\n", 482 | "\n", 483 | " ###MAKE A LIST OF TRAINING EXAMPLES. 
(This is where we add padding)\n", 484 | " self.training_examples: List[TrainingExample] = []\n", 485 | " empty_label_id = \"O\"\n", 486 | " for encoding, label in zip(tokenized_batch.encodings, aligned_labels):\n", 487 | " length = len(label) # How long is this sequence\n", 488 | " for start in range(0, length, self.window_stride):\n", 489 | "\n", 490 | " end = min(start + tokens_per_batch, length)\n", 491 | "\n", 492 | " # How much padding do we need ?\n", 493 | " padding_to_add = max(0, tokens_per_batch - end + start)\n", 494 | " self.training_examples.append(\n", 495 | " TrainingExample(\n", 496 | " # Record the tokens\n", 497 | " input_ids=encoding.ids[start:end] # The ids of the tokens\n", 498 | " + [self.tokenizer.pad_token_id]\n", 499 | " * padding_to_add, # padding if needed\n", 500 | " labels=(\n", 501 | " label[start:end]\n", 502 | " + [-100] * padding_to_add # padding if needed\n", 503 | " ), # -100 is a special token for padding of labels,\n", 504 | " attention_masks=(\n", 505 | " encoding.attention_mask[start:end]\n", 506 | " + [0]\n", 507 | " * padding_to_add # 0'd attenetion masks where we added padding\n", 508 | " ),\n", 509 | " )\n", 510 | " )\n", 511 | "\n", 512 | " def __len__(self):\n", 513 | " return len(self.training_examples)\n", 514 | "\n", 515 | " def __getitem__(self, idx) -> TrainingExample:\n", 516 | "\n", 517 | " return self.training_examples[idx]" 518 | ] 519 | }, 520 | { 521 | "cell_type": "markdown", 522 | "metadata": {}, 523 | "source": [ 524 | "### Let's See what comes out\n", 525 | "Below we'll create a dataset instance.\n", 526 | "We first create a label_set, in this case there is only one label, **drug**. \n", 527 | "We then instantiate our Dataset by passing the raw data, the tokenizer and the label_set.\n", 528 | "We get back **TrainingExample** instances with the windowed and padded input_ids and label_ids as well as attention_masks. " 529 | ] 530 | }, 531 | { 532 | "cell_type": "code", 533 | "execution_count": 11, 534 | "metadata": {}, 535 | "outputs": [ 536 | { 537 | "name": "stdout", 538 | "output_type": "stream", 539 | "text": [ 540 | "TrainingExample(input_ids=[1233, 1621, 4420, 18061, 5165, 1114, 4267, 6066, 1465, 3171, 1306, 117, 1126, 27558, 1104, 140], attention_masks=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], labels=[3, 0, 0, 0, 0, 0, 1, 2, 2, 2, 2, 3, 0, 0, 0, 0])\n" 541 | ] 542 | } 543 | ], 544 | "source": [ 545 | "label_set = LabelSet(labels=[\"drug\"])\n", 546 | "ds = TraingDataset(\n", 547 | " data=raw, tokenizer=tokenizer, label_set=label_set, tokens_per_batch=16\n", 548 | ")\n", 549 | "ex = ds[10]\n", 550 | "pprint(ex)" 551 | ] 552 | }, 553 | { 554 | "cell_type": "markdown", 555 | "metadata": {}, 556 | "source": [ 557 | "### Batching\n", 558 | "We still need a way batch these examples. We can't feed a list of TraingExamples to a model, we need to make tensors out of the input_ids and labels. This is easily achieved with a collating function. A collating function gets a list of items from our dataset (in our case a list of TraingExamples) and returns a batched tensors. 
\n", 559 | "\n", 560 | "We'll simplify things, by making a **TraingBatch** class whose constructor is the collating function" 561 | ] 562 | }, 563 | { 564 | "cell_type": "code", 565 | "execution_count": 12, 566 | "metadata": {}, 567 | "outputs": [], 568 | "source": [ 569 | "import torch\n", 570 | "\n", 571 | "\n", 572 | "class TraingingBatch:\n", 573 | " def __getitem__(self, item):\n", 574 | " return getattr(self, item)\n", 575 | "\n", 576 | " def __init__(self, examples: List[TrainingExample]):\n", 577 | " self.input_ids: torch.Tensor\n", 578 | " self.attention_masks: torch.Tensor\n", 579 | " self.labels: torch.Tensor\n", 580 | " input_ids: IntListList = []\n", 581 | " masks: IntListList = []\n", 582 | " labels: IntListList = []\n", 583 | " for ex in examples:\n", 584 | " input_ids.append(ex.input_ids)\n", 585 | " masks.append(ex.attention_masks)\n", 586 | " labels.append(ex.labels)\n", 587 | " self.input_ids = torch.LongTensor(input_ids)\n", 588 | " self.attention_masks = torch.LongTensor(masks)\n", 589 | " self.labels = torch.LongTensor(labels)" 590 | ] 591 | }, 592 | { 593 | "cell_type": "markdown", 594 | "metadata": {}, 595 | "source": [ 596 | "# Traing Our Model\n", 597 | "With our batching ready, let's use a pre trained model and show how to fine tune it on our new dataset. " 598 | ] 599 | }, 600 | { 601 | "cell_type": "code", 602 | "execution_count": 13, 603 | "metadata": {}, 604 | "outputs": [ 605 | { 606 | "name": "stderr", 607 | "output_type": "stream", 608 | "text": [ 609 | "Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n", 610 | "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPretraining model).\n", 611 | "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 612 | "Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.weight', 'classifier.bias']\n", 613 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 614 | ] 615 | }, 616 | { 617 | "name": "stdout", 618 | "output_type": "stream", 619 | "text": [ 620 | "tensor(1.6987, grad_fn=)\n", 621 | "tensor(1.6388, grad_fn=)\n", 622 | "tensor(1.6135, grad_fn=)\n", 623 | "tensor(1.4385, grad_fn=)\n", 624 | "tensor(1.5159, grad_fn=)\n", 625 | "tensor(1.4509, grad_fn=)\n", 626 | "tensor(1.3011, grad_fn=)\n", 627 | "tensor(1.2812, grad_fn=)\n", 628 | "tensor(1.1388, grad_fn=)\n", 629 | "tensor(1.4184, grad_fn=)\n", 630 | "tensor(1.3591, grad_fn=)\n", 631 | "tensor(1.2249, grad_fn=)\n", 632 | "tensor(0.9483, grad_fn=)\n", 633 | "tensor(1.2650, grad_fn=)\n", 634 | "tensor(1.1502, grad_fn=)\n", 635 | "tensor(0.5125, grad_fn=)\n", 636 | "tensor(0.9448, grad_fn=)\n", 637 | "tensor(1.2908, grad_fn=)\n", 638 | "tensor(0.8918, grad_fn=)\n", 639 | "tensor(1.0335, grad_fn=)\n", 640 | "tensor(1.2265, grad_fn=)\n", 641 | "tensor(1.5571, grad_fn=)\n" 642 | ] 643 | } 644 | ], 645 | "source": [ 646 | "from torch.utils.data.dataloader import DataLoader\n", 647 | "from transformers import BertForTokenClassification, AdamW\n", 648 | "\n", 649 | "model = BertForTokenClassification.from_pretrained(\n", 650 | " \"bert-base-cased\", num_labels=len(ds.label_set.ids_to_label.values())\n", 651 | ")\n", 652 | "optimizer = AdamW(model.parameters(), lr=5e-6)\n", 653 | "\n", 654 | "dataloader = DataLoader(\n", 655 | " ds,\n", 656 | " collate_fn=TraingingBatch,\n", 657 | " batch_size=4,\n", 658 | " shuffle=True,\n", 659 | ")\n", 660 | "for num, batch in enumerate(dataloader):\n", 661 | " loss, logits = model(\n", 662 | " input_ids=batch.input_ids,\n", 663 | " attention_mask=batch.attention_masks,\n", 664 | " labels=batch.labels,\n", 665 | " )\n", 666 | " loss.backward()\n", 667 | " optimizer.step()\n", 668 | " print(loss)\n", 669 | " if num > 20:\n", 670 | " break" 671 | ] 672 | } 673 | ], 674 | "metadata": { 675 | "kernelspec": { 676 | "display_name": "Python 3", 677 | "language": "python", 678 | "name": "python3" 679 | }, 680 | "language_info": { 681 | "codemirror_mode": { 682 | "name": "ipython", 683 | "version": 3 684 | }, 685 | "file_extension": ".py", 686 | "mimetype": "text/x-python", 687 | "name": "python", 688 | "nbconvert_exporter": "python", 689 | "pygments_lexer": "ipython3", 690 | "version": "3.6.9" 691 | } 692 | }, 693 | "nbformat": 4, 694 | "nbformat_minor": 2 695 | } 696 | --------------------------------------------------------------------------------