├── .gitignore ├── LICENSE ├── README.md ├── images │   └── codeless.png ├── models │   ├── __init__.py │   ├── bert_classifier.py │   ├── ner.py │   └── utils.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .venv 86 | env/ 87 | venv/ 88 | ENV/ 89 | env.bak/ 90 | venv.bak/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | .spyproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | 99 | # mkdocs documentation 100 | /site 101 | 102 | # mypy 103 | .mypy_cache/ 104 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Label Studio for Hugging Face's Transformers 2 | 3 | [Website](https://labelstud.io/) • [Docs](https://labelstud.io/guide) • [Twitter](https://twitter.com/heartexlabs) • [Join Slack Community](https://slack.labelstud.io/?source=github-1) 4 | 5 |
6 | 7 | **Transfer learning for NLP models by annotating your textual data without any additional coding.** 8 | 9 | This package provides a ready-to-use container that links together: 10 | 11 | - [Label Studio](https://github.com/heartexlabs/label-studio) as the annotation frontend 12 | - [Hugging Face's transformers](https://github.com/huggingface/transformers) as the machine learning backend for NLP 13 | 14 |
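Under the hood, the two bundled models follow the same small contract: each is a `LabelStudioMLBase` subclass whose `predict()` and `fit()` methods Label Studio calls through the ML backend server. A minimal sketch of that contract (the class name here is illustrative; see `models/bert_classifier.py` and `models/ner.py` for the real implementations):

```python
from label_studio_ml.model import LabelStudioMLBase

class SketchBackend(LabelStudioMLBase):
    def predict(self, tasks, **kwargs):
        # Return one prediction dict per task, in Label Studio's JSON format.
        return [{'result': [], 'score': 0.0} for _ in tasks]

    def fit(self, completions, workdir=None, **kwargs):
        # Train on annotated tasks; the returned dict is handed back to the
        # next instance as self.train_output, which load() uses to restore state.
        return {'model_path': workdir}
```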
15 | 16 | [![codeless](images/codeless.png)](https://github.com/heartexlabs/label-studio-transformers) 17 | 18 | ### Quick Usage 19 | 20 | #### Install Label Studio and other dependencies 21 | 22 | ```bash 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | ##### Create ML backend with BERT classifier 27 | ```bash 28 | label-studio-ml init my-ml-backend --script models/bert_classifier.py 29 | cp models/utils.py my-ml-backend/utils.py 30 | 31 | # Start ML backend at http://localhost:9090 32 | label-studio-ml start my-ml-backend 33 | 34 | # Start Label Studio in a new terminal with the same Python environment 35 | label-studio start 36 | ``` 37 | 38 | 1. Create a project with `Choices` and `Text` tags in the labeling config. 39 | 2. Connect the ML backend in the Project settings using `http://localhost:9090`. 40 | 41 | ##### Create ML backend with BERT named entity recognizer 42 | ```bash 43 | label-studio-ml init my-ml-backend --script models/ner.py 44 | cp models/utils.py my-ml-backend/utils.py 45 | 46 | # Start ML backend at http://localhost:9090 47 | label-studio-ml start my-ml-backend 48 | 49 | # Start Label Studio in a new terminal with the same Python environment 50 | label-studio start 51 | ``` 52 | 53 | 1. Create a project with `Labels` and `Text` tags in the labeling config. 54 | 2. Connect the ML backend in the Project settings using `http://localhost:9090`. 55 | 56 | #### Training and inference 57 | 58 | The browser opens at `http://localhost:8080`. Upload your data on the **Import** page, then annotate it on the **Labeling** page. 59 | Once you've annotated a sufficient amount of data, go to the **Model** page and press the **Start Training** button. When training finishes, the model automatically starts serving predictions to Label Studio, and you'll find all model checkpoints inside the `my-ml-backend/` directory. 60 | 61 | [Click here](https://labelstud.io/guide/ml.html) to read more about how to use the Machine Learning backend and build Human-in-the-Loop pipelines with Label Studio. 62 |
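For reference, here is roughly what the payloads exchanged with the classifier backend look like. The field values below are illustrative; the exact result structure is built in `BertClassifier.predict()` in `models/bert_classifier.py`:

```python
# A task as consumed by predict(): the data key ('text') must match the
# value="$text" variable from your labeling config.
task = {'data': {'text': 'This movie was great!'}}

# The prediction produced for that task -- Label Studio's standard
# 'choices' result format.
prediction = {
    'result': [{
        'from_name': 'sentiment',   # name of the <Choices> tag (illustrative)
        'to_name': 'text',          # name of the <Text> tag (illustrative)
        'type': 'choices',
        'value': {'choices': ['Positive']},
    }],
    'score': 3.7,  # max raw logit of the predicted class, not a probability
}
```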
63 | ## License 64 | 65 | This software is licensed under the [Apache 2.0 LICENSE](/LICENSE) © 2020 [Heartex](https://www.heartex.com/). 66 | 67 | 68 | -------------------------------------------------------------------------------- /images/codeless.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanSignal/label-studio-transformers/de267cbd903ebf907d4449323fb81e57e3d38257/images/codeless.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanSignal/label-studio-transformers/de267cbd903ebf907d4449323fb81e57e3d38257/models/__init__.py -------------------------------------------------------------------------------- /models/bert_classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | 5 | from torch.utils.data import SequentialSampler 6 | from tqdm import tqdm, trange 7 | from collections import deque 8 | from tensorboardX import SummaryWriter 9 | from transformers import BertTokenizer, BertForSequenceClassification 10 | from transformers import AdamW, get_linear_schedule_with_warmup 11 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler 12 | 13 | from label_studio_ml.model import LabelStudioMLBase 14 | 15 | from utils import prepare_texts, calc_slope 16 | 17 | 18 | if torch.cuda.is_available(): 19 | device = torch.device("cuda") 20 | print('There are %d GPU(s) available.' % torch.cuda.device_count()) 21 | print('We will use the GPU:', torch.cuda.get_device_name(0)) 22 | else: 23 | print('No GPU available, using the CPU instead.') 24 | device = torch.device("cpu") 25 | 26 | 27 | class BertClassifier(LabelStudioMLBase): 28 | 29 | def __init__( 30 | self, pretrained_model='bert-base-multilingual-cased', maxlen=64, 31 | batch_size=32, num_epochs=100, logging_steps=1, train_logs=None, **kwargs 32 | ): 33 | super(BertClassifier, self).__init__(**kwargs) 34 | self.pretrained_model = pretrained_model 35 | self.maxlen = maxlen 36 | self.batch_size = batch_size 37 | self.num_epochs = num_epochs 38 | self.logging_steps = logging_steps 39 | self.train_logs = train_logs 40 | 41 | # collect all keys from the config that are used to extract data from tasks and to form predictions 42 | # the parsed label config must contain exactly one output of type Choices 43 | assert len(self.parsed_label_config) == 1 44 | self.from_name, self.info = list(self.parsed_label_config.items())[0] 45 | assert self.info['type'] == 'Choices' 46 | 47 | # the model has only one textual input 48 | assert len(self.info['to_name']) == 1 49 | assert len(self.info['inputs']) == 1 50 | assert self.info['inputs'][0]['type'] == 'Text' 51 | self.to_name = self.info['to_name'][0] 52 | self.value = self.info['inputs'][0]['value'] 53 | 54 | if not self.train_output: 55 | self.labels = self.info['labels'] 56 | self.model = self.reset_model(self.pretrained_model, cache_dir=None, device=device)  # keep the freshly initialized model 57 | print('Initialized with from_name={from_name}, to_name={to_name}, labels={labels}'.format( 58 | from_name=self.from_name, to_name=self.to_name, labels=str(self.labels) 59 | )) 60 | else: 61 | self.load(self.train_output) 62 | print('Loaded from train output with from_name={from_name}, to_name={to_name}, labels={labels}'.format( 63 | from_name=self.from_name, to_name=self.to_name, labels=str(self.labels) 64 | )) 65 | 66 | def reset_model(self, pretrained_model, cache_dir, device): 67 | model = BertForSequenceClassification.from_pretrained( 68 |
pretrained_model, 69 | num_labels=len(self.labels), 70 | output_attentions=False, 71 | output_hidden_states=False, 72 | cache_dir=cache_dir 73 | ) 74 | model.to(device) 75 | return model 76 | 77 | def load(self, train_output): 78 | pretrained_model = train_output['model_path'] 79 | self.tokenizer = BertTokenizer.from_pretrained(pretrained_model) 80 | self.model = BertForSequenceClassification.from_pretrained(pretrained_model) 81 | self.model.to(device) 82 | self.model.eval() 83 | self.batch_size = train_output['batch_size'] 84 | self.labels = train_output['labels'] 85 | self.maxlen = train_output['maxlen'] 86 | 87 | @property 88 | def not_trained(self): 89 | return not hasattr(self, 'tokenizer') 90 | 91 | def predict(self, tasks, **kwargs): 92 | if self.not_trained: 93 | print('Can\'t get prediction because model is not trained yet.') 94 | return [] 95 | 96 | texts = [task['data'][self.value] for task in tasks] 97 | predict_dataloader = prepare_texts(texts, self.tokenizer, self.maxlen, SequentialSampler, self.batch_size) 98 | 99 | pred_labels, pred_scores = [], [] 100 | for batch in predict_dataloader: 101 | batch = tuple(t.to(device) for t in batch) 102 | inputs = { 103 | 'input_ids': batch[0], 104 | 'attention_mask': batch[1] 105 | } 106 | with torch.no_grad(): 107 | outputs = self.model(**inputs) 108 | logits = outputs[0] 109 | 110 | batch_preds = logits.detach().cpu().numpy() 111 | 112 | argmax_batch_preds = np.argmax(batch_preds, axis=-1) 113 | pred_labels.extend(str(self.labels[i]) for i in argmax_batch_preds) 114 | 115 | max_batch_preds = np.max(batch_preds, axis=-1) 116 | pred_scores.extend(float(s) for s in max_batch_preds) 117 | 118 | predictions = [] 119 | for predicted_label, score in zip(pred_labels, pred_scores): 120 | result = [{ 121 | 'from_name': self.from_name, 122 | 'to_name': self.to_name, 123 | 'type': 'choices', 124 | 'value': {'choices': [predicted_label]} 125 | }] 126 | 127 | predictions.append({'result': result, 'score': score}) 128 | return predictions 129 | 130 | def fit(self, completions, workdir=None, cache_dir=None, **kwargs): 131 | input_texts = [] 132 | output_labels, output_labels_idx = [], [] 133 | label2idx = {l: i for i, l in enumerate(self.labels)} 134 | for completion in completions: 135 | # get input text from task data 136 | 137 | if completion['annotations'][0].get('skipped'): 138 | continue 139 | 140 | input_text = completion['data'][self.value] 141 | input_texts.append(input_text) 142 | 143 | # get an annotation 144 | output_label = completion['annotations'][0]['result'][0]['value']['choices'][0] 145 | output_labels.append(output_label) 146 | output_label_idx = label2idx[output_label] 147 | output_labels_idx.append(output_label_idx) 148 | 149 | new_labels = set(output_labels) 150 | added_labels = new_labels - set(self.labels) 151 | if len(added_labels) > 0: 152 | print('Label set has been changed. 
Added ones: ' + str(list(added_labels))) 153 | self.labels = list(sorted(new_labels)) 154 | label2idx = {l: i for i, l in enumerate(self.labels)} 155 | output_labels_idx = [label2idx[label] for label in output_labels] 156 | 157 | tokenizer = BertTokenizer.from_pretrained(self.pretrained_model, cache_dir=cache_dir) 158 | 159 | train_dataloader = prepare_texts(input_texts, tokenizer, self.maxlen, RandomSampler, self.batch_size, output_labels_idx) 160 | model = self.reset_model(self.pretrained_model, cache_dir, device) 161 | 162 | total_steps = len(train_dataloader) * self.num_epochs 163 | optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8) 164 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps) 165 | global_step = 0 166 | total_loss, logging_loss = 0.0, 0.0 167 | model.zero_grad() 168 | train_iterator = trange(self.num_epochs, desc='Epoch') 169 | if self.train_logs: 170 | tb_writer = SummaryWriter(logdir=os.path.join(self.train_logs, os.path.basename(workdir))) 171 | else: 172 | tb_writer = None 173 | loss_queue = deque(maxlen=10) 174 | for epoch in train_iterator: 175 | epoch_iterator = tqdm(train_dataloader, desc='Iteration') 176 | for step, batch in enumerate(epoch_iterator): 177 | model.train() 178 | batch = tuple(t.to(device) for t in batch) 179 | inputs = {'input_ids': batch[0], 180 | 'attention_mask': batch[1], 181 | 'labels': batch[2]} 182 | outputs = model(**inputs) 183 | loss = outputs[0] 184 | loss.backward() 185 | total_loss += loss.item() 186 | 187 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 188 | optimizer.step() 189 | scheduler.step() 190 | model.zero_grad() 191 | global_step += 1 192 | if global_step % self.logging_steps == 0: 193 | last_loss = (total_loss - logging_loss) / self.logging_steps 194 | loss_queue.append(last_loss) 195 | if tb_writer: 196 | tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step) 197 | tb_writer.add_scalar('loss', last_loss, global_step) 198 | logging_loss = total_loss 199 | 200 | # slope-based early stopping 201 | if len(loss_queue) == loss_queue.maxlen: 202 | slope = calc_slope(loss_queue) 203 | if tb_writer: 204 | tb_writer.add_scalar('slope', slope, global_step) 205 | if abs(slope) < 1e-2: 206 | break 207 | 208 | if tb_writer: 209 | tb_writer.close() 210 | 211 | model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training # noqa 212 | model_to_save.save_pretrained(workdir) 213 | tokenizer.save_pretrained(workdir) 214 | 215 | return { 216 | 'model_path': workdir, 217 | 'batch_size': self.batch_size, 218 | 'maxlen': self.maxlen, 219 | 'pretrained_model': self.pretrained_model, 220 | 'labels': self.labels 221 | } 222 | -------------------------------------------------------------------------------- /models/ner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import re 4 | import os 5 | import io 6 | import logging 7 | 8 | from functools import partial 9 | from itertools import groupby 10 | from operator import itemgetter 11 | from torch.nn import CrossEntropyLoss 12 | from torch.utils.data import Dataset, DataLoader 13 | from tqdm import tqdm, trange 14 | from tensorboardX import SummaryWriter 15 | from collections import deque 16 | 17 | from transformers import ( 18 | BertTokenizer, BertForTokenClassification, BertConfig, 19 | RobertaConfig, RobertaForTokenClassification, RobertaTokenizer, 20 | DistilBertConfig, 
DistilBertForTokenClassification, DistilBertTokenizer, 21 | CamembertConfig, CamembertForTokenClassification, CamembertTokenizer, 22 | AutoConfig, AutoModelForTokenClassification, AutoTokenizer 23 | ) 24 | from transformers import AdamW, get_linear_schedule_with_warmup 25 | 26 | from label_studio_ml.model import LabelStudioMLBase 27 | from utils import calc_slope 28 | 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | 33 | # pretrained_config_archive_map was removed from config classes in newer transformers releases, so fall back to an empty map 34 | ALL_MODELS = sum([list(getattr(conf, 'pretrained_config_archive_map', {}).keys()) 35 |                   for conf in (BertConfig, RobertaConfig, DistilBertConfig)], []) 36 | 37 | MODEL_CLASSES = { 38 | 'bert': (BertConfig, BertForTokenClassification, BertTokenizer), 39 | 'roberta': (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer), 40 | 'distilbert': (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer), 41 | 'camembert': (CamembertConfig, CamembertForTokenClassification, CamembertTokenizer), 42 | } 43 | 44 | 45 | class SpanLabeledTextDataset(Dataset): 46 | 47 | def __init__( 48 | self, list_of_strings, list_of_spans=None, tokenizer=None, tag_idx_map=None, 49 | cls_token='[CLS]', sep_token='[SEP]', pad_token_label_id=-1, max_seq_length=128, sep_token_extra=False, 50 | cls_token_at_end=False, sequence_a_segment_id=0, cls_token_segment_id=1, mask_padding_with_zero=True 51 | ): 52 | self.list_of_strings = list_of_strings 53 | self.list_of_spans = list_of_spans or [[] for _ in list_of_strings]  # one (possibly empty) span list per string 54 | self.tokenizer = tokenizer 55 | self.cls_token = cls_token 56 | self.sep_token = sep_token 57 | self.pad_token_label_id = pad_token_label_id 58 | self.max_seq_length = max_seq_length 59 | self.sep_token_extra = sep_token_extra 60 | self.cls_token_at_end = cls_token_at_end 61 | self.sequence_a_segment_id = sequence_a_segment_id 62 | self.cls_token_segment_id = cls_token_segment_id 63 | self.mask_padding_with_zero = mask_padding_with_zero 64 | 65 | (self.original_list_of_tokens, self.original_list_of_tags, tag_idx_map_, 66 | original_list_of_tokens_start_map) = self._prepare_data() 67 | 68 | if tag_idx_map is None: 69 | self.tag_idx_map = tag_idx_map_ 70 | else: 71 | self.tag_idx_map = tag_idx_map 72 | 73 | (self.list_of_tokens, self.list_of_token_ids, self.list_of_labels, self.list_of_label_ids, 74 | self.list_of_segment_ids, self.list_of_token_start_map) = [], [], [], [], [], [] 75 | 76 | for original_tokens, original_tags, original_token_start_map in zip( 77 | self.original_list_of_tokens, 78 | self.original_list_of_tags, 79 | original_list_of_tokens_start_map 80 | ): 81 | tokens, token_ids, labels, label_ids, segment_ids, token_start_map = self._convert_to_features( 82 | original_tokens, original_tags, self.tag_idx_map, original_token_start_map) 83 | self.list_of_token_ids.append(token_ids) 84 | self.list_of_tokens.append(tokens) 85 | self.list_of_labels.append(labels) 86 | self.list_of_segment_ids.append(segment_ids) 87 | self.list_of_label_ids.append(label_ids) 88 | self.list_of_token_start_map.append(token_start_map) 89 | 90 | def get_params_dict(self): 91 | return { 92 | 'cls_token': self.cls_token, 93 | 'sep_token': self.sep_token, 94 | 'pad_token_label_id': self.pad_token_label_id, 95 | 'max_seq_length': self.max_seq_length, 96 | 'sep_token_extra': self.sep_token_extra, 97 | 'cls_token_at_end': self.cls_token_at_end, 98 | 'sequence_a_segment_id': self.sequence_a_segment_id, 99 | 'cls_token_segment_id': self.cls_token_segment_id, 100 | 'mask_padding_with_zero': self.mask_padding_with_zero 101 | } 102 | 103 | def dump(self, output_file): 104 | with 
io.open(output_file, mode='w') as f: 105 | for tokens, labels in zip(self.list_of_tokens, self.list_of_labels): 106 | for token, label in zip(tokens, labels): 107 | f.write(f'{token} {label}\n') 108 | f.write('\n') 109 | 110 | def _convert_to_features(self, words, labels, label_map, list_token_start_map): 111 | tokens, out_labels, label_ids, tokens_idx_map = [], [], [], [] 112 | for i, (word, label, token_start) in enumerate(zip(words, labels, list_token_start_map)): 113 | word_tokens = self.tokenizer.tokenize(word) 114 | tokens.extend(word_tokens) 115 | tokens_idx_map.extend([token_start] * len(word_tokens)) 116 | # Use the real label id for the first token of the word, and padding ids for the remaining tokens 117 | label_ids.extend([label_map[label]] + [self.pad_token_label_id] * (len(word_tokens) - 1)) 118 | out_labels.extend([label] + ['X'] * (len(word_tokens) - 1)) 119 | 120 | # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa. 121 | special_tokens_count = 3 if self.sep_token_extra else 2 122 | if len(tokens) > self.max_seq_length - special_tokens_count: 123 | tokens = tokens[:(self.max_seq_length - special_tokens_count)] 124 | label_ids = label_ids[:(self.max_seq_length - special_tokens_count)] 125 | out_labels = out_labels[:(self.max_seq_length - special_tokens_count)] 126 | tokens_idx_map = tokens_idx_map[:(self.max_seq_length - special_tokens_count)] 127 | 128 | # The convention in BERT is: 129 | # (a) For sequence pairs: 130 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 131 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 132 | # (b) For single sequences: 133 | # tokens: [CLS] the dog is hairy . [SEP] 134 | # type_ids: 0 0 0 0 0 0 0 135 | # 136 | # Where "type_ids" are used to indicate whether this is the first 137 | # sequence or the second sequence. The embedding vectors for `type=0` and 138 | # `type=1` were learned during pre-training and are added to the wordpiece 139 | # embedding vector (and position vector). This is not *strictly* necessary 140 | # since the [SEP] token unambiguously separates the sequences, but it makes 141 | # it easier for the model to learn the concept of sequences. 142 | # 143 | # For classification tasks, the first vector (corresponding to [CLS]) is 144 | # used as the "sentence vector". Note that this only makes sense because 145 | # the entire model is fine-tuned. 
146 | tokens += [self.sep_token] 147 | label_ids += [self.pad_token_label_id] 148 | out_labels += ['X'] 149 | tokens_idx_map += [-1] 150 | if self.sep_token_extra: 151 | # roberta uses an extra separator b/w pairs of sentences 152 | tokens += [self.sep_token] 153 | label_ids += [self.pad_token_label_id] 154 | out_labels += ['X'] 155 | tokens_idx_map += [-1] 156 | segment_ids = [self.sequence_a_segment_id] * len(tokens) 157 | if self.cls_token_at_end: 158 | tokens += [self.cls_token] 159 | label_ids += [self.pad_token_label_id] 160 | out_labels += ['X'] 161 | segment_ids += [self.cls_token_segment_id] 162 | tokens_idx_map += [-1] 163 | else: 164 | tokens = [self.cls_token] + tokens 165 | label_ids = [self.pad_token_label_id] + label_ids 166 | out_labels = ['X'] + out_labels 167 | segment_ids = [self.cls_token_segment_id] + segment_ids 168 | tokens_idx_map = [-1] + tokens_idx_map 169 | 170 | token_ids = self.tokenizer.convert_tokens_to_ids(tokens) 171 | 172 | return tokens, token_ids, out_labels, label_ids, segment_ids, tokens_idx_map 173 | 174 | def _apply_tokenizer(self, original_tokens, original_tags): 175 | out_tokens, out_tags, out_maps = [], [], [] 176 | for i, (original_token, original_tag) in enumerate(zip(original_tokens, original_tags)): 177 | tokens = self.tokenizer.tokenize(original_token) 178 | out_tokens.extend(tokens) 179 | out_maps.extend([i] * len(tokens)) 180 | start_tag = original_tag.startswith('B-') 181 | for j in range(len(tokens)): 182 | if (j == 0 and start_tag) or original_tag == 'O': 183 | out_tags.append(original_tag) 184 | else: 185 | out_tags.append(f'I-{original_tag[2:]}') 186 | return out_tokens, out_tags, out_maps 187 | 188 | def _prepare_data(self): 189 | list_of_tokens, list_of_tags, list_of_token_idx_maps = [], [], [] 190 | tag_idx_map = {'O': 0} 191 | for text, spans in zip(self.list_of_strings, self.list_of_spans): 192 | if not text: 193 | continue 194 | 195 | tokens = [] 196 | start = 0 197 | for t in text.split(): 198 | tokens.append((t, start)) 199 | start += len(t) + 1 200 | 201 | if spans: 202 | spans = list(sorted(spans, key=itemgetter('start'))) 203 | span = spans.pop(0) 204 | prefix = 'B-' 205 | tags = [] 206 | for token, token_start in tokens: 207 | token_end = token_start + len(token) - 1 208 | 209 | # token precedes current span 210 | if not span or token_end < span['start']: 211 | tags.append('O') 212 | continue 213 | 214 | # token jumps over the span (this can happen 215 | # when the prev label ends with whitespace, e.g. 
"cat " "too" or span created for whitespace) 216 | if token_start > span['end']: 217 | 218 | prefix = 'B-' 219 | no_more_spans = False 220 | while token_start > span['end']: 221 | if not len(spans): 222 | no_more_spans = True 223 | break 224 | span = spans.pop(0) 225 | 226 | if no_more_spans: 227 | tags.append('O') 228 | span = None 229 | continue 230 | 231 | if token_end < span['start']: 232 | tags.append('O') 233 | continue 234 | 235 | label = span['label'] 236 | if label.startswith(prefix): 237 | tag = label 238 | else: 239 | tag = f'{prefix}{label}' 240 | tags.append(tag) 241 | if tag not in tag_idx_map: 242 | tag_idx_map[tag] = len(tag_idx_map) 243 | if span['end'] > token_end: 244 | prefix = 'I-' 245 | elif len(spans): 246 | span = spans.pop(0) 247 | prefix = 'B-' 248 | else: 249 | span = None 250 | else: 251 | tags = ['O'] * len(tokens) 252 | 253 | list_of_tokens.append([t[0] for t in tokens]) 254 | list_of_token_idx_maps.append([t[1] for t in tokens]) 255 | list_of_tags.append(tags) 256 | 257 | return list_of_tokens, list_of_tags, tag_idx_map, list_of_token_idx_maps 258 | 259 | def __len__(self): 260 | return len(self.list_of_token_ids) 261 | 262 | def __getitem__(self, idx): 263 | return { 264 | 'tokens': self.list_of_token_ids[idx], 265 | 'labels': self.list_of_label_ids[idx], 266 | 'segments': self.list_of_segment_ids[idx], 267 | 'token_start_map': self.list_of_token_start_map[idx], 268 | 'string': self.list_of_strings[idx] 269 | } 270 | 271 | @property 272 | def num_labels(self): 273 | return len(self.tag_idx_map) 274 | 275 | @classmethod 276 | def pad_sequences(cls, batch, mask_padding_with_zero, pad_on_left, pad_token, pad_token_segment_id, pad_token_label_id): 277 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 278 | # tokens are attended to. 279 | max_seq_length = max(map(len, (sample['tokens'] for sample in batch))) 280 | batch_input_ids, batch_label_ids, batch_segment_ids, batch_input_mask, batch_token_start_map = [], [], [], [], [] 281 | batch_strings = [] 282 | for sample in batch: 283 | input_ids = sample['tokens'] 284 | label_ids = sample['labels'] 285 | segment_ids = sample['segments'] 286 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 287 | # tokens are attended to. 
288 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 289 | padding_length = max_seq_length - len(input_ids) 290 | if pad_on_left: 291 | input_ids = ([pad_token] * padding_length) + input_ids 292 | input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask 293 | segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids 294 | label_ids = ([pad_token_label_id] * padding_length) + label_ids 295 | else: 296 | input_ids += ([pad_token] * padding_length) 297 | input_mask += ([0 if mask_padding_with_zero else 1] * padding_length) 298 | segment_ids += ([pad_token_segment_id] * padding_length) 299 | label_ids += ([pad_token_label_id] * padding_length) 300 | batch_input_ids.append(input_ids) 301 | batch_label_ids.append(label_ids) 302 | batch_segment_ids.append(segment_ids) 303 | batch_input_mask.append(input_mask) 304 | batch_token_start_map.append(sample['token_start_map']) 305 | batch_strings.append(sample['string']) 306 | 307 | return { 308 | 'input_ids': torch.tensor(batch_input_ids, dtype=torch.long), 309 | 'label_ids': torch.tensor(batch_label_ids, dtype=torch.long), 310 | 'segment_ids': torch.tensor(batch_segment_ids, dtype=torch.long), 311 | 'input_mask': torch.tensor(batch_input_mask, dtype=torch.long), 312 | 'token_start_map': batch_token_start_map, 313 | 'strings': batch_strings 314 | } 315 | 316 | @classmethod 317 | def get_padding_function(cls, model_type, tokenizer, pad_token_label_id): 318 | return partial( 319 | cls.pad_sequences, 320 | mask_padding_with_zero=True, 321 | pad_on_left=model_type in ['xlnet'], 322 | pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0], 323 | pad_token_segment_id=4 if model_type in ['xlnet'] else 0, 324 | pad_token_label_id=pad_token_label_id 325 | ) 326 | 327 | 328 | class TransformersBasedTagger(LabelStudioMLBase): 329 | 330 | def __init__(self, **kwargs): 331 | super(TransformersBasedTagger, self).__init__(**kwargs) 332 | 333 | assert len(self.parsed_label_config) == 1 334 | self.from_name, self.info = list(self.parsed_label_config.items())[0] 335 | assert self.info['type'] == 'Labels' 336 | 337 | # the model has only one textual input 338 | assert len(self.info['to_name']) == 1 339 | assert len(self.info['inputs']) == 1 340 | assert self.info['inputs'][0]['type'] == 'Text' 341 | self.to_name = self.info['to_name'][0] 342 | self.value = self.info['inputs'][0]['value'] 343 | 344 | if not self.train_output: 345 | self.labels = self.info['labels'] 346 | else: 347 | self.load(self.train_output) 348 | 349 | def load(self, train_output): 350 | pretrained_model = train_output['model_path'] 351 | self._model_type = train_output['model_type'] 352 | _, model_class, tokenizer_class = MODEL_CLASSES[train_output['model_type']] 353 | 354 | self._tokenizer = AutoTokenizer.from_pretrained(pretrained_model) 355 | self._model = AutoModelForTokenClassification.from_pretrained(pretrained_model) 356 | self._batch_size = train_output['batch_size'] 357 | self._pad_token = self._tokenizer.convert_tokens_to_ids([self._tokenizer.pad_token])[0] 358 | self._pad_token_label_id = train_output['pad_token_label_id'] 359 | self._label_map = train_output['label_map'] 360 | self._mask_padding_with_zero = True 361 | self._dataset_params_dict = train_output['dataset_params_dict'] 362 | 363 | self._batch_padding = SpanLabeledTextDataset.get_padding_function( 364 | self._model_type, self._tokenizer, self._pad_token_label_id) 365 | 366 | def predict(self, tasks, **kwargs): 367 | texts = [task['data'][self.value] for task 
in tasks] 368 | predict_set = SpanLabeledTextDataset(texts, tokenizer=self._tokenizer, **self._dataset_params_dict) 369 | from_name = self.from_name 370 | to_name = self.to_name 371 | predict_loader = DataLoader( 372 | dataset=predict_set, 373 | batch_size=self._batch_size, 374 | collate_fn=self._batch_padding 375 | ) 376 | 377 | results = [] 378 | for batch in tqdm(predict_loader, desc='Prediction'): 379 | inputs = { 380 | 'input_ids': batch['input_ids'], 381 | 'attention_mask': batch['input_mask'], 382 | 'token_type_ids': batch['segment_ids'] 383 | } 384 | if self._model_type == 'distilbert': 385 | inputs.pop('token_type_ids') 386 | with torch.no_grad(): 387 | model_output = self._model(**inputs) 388 | logits = model_output[0] 389 | 390 | batch_preds = logits.detach().cpu().numpy() 391 | argmax_batch_preds = np.argmax(batch_preds, axis=-1) 392 | max_batch_preds = np.max(batch_preds, axis=-1) 393 | input_mask = batch['input_mask'].detach().cpu().numpy() 394 | batch_token_start_map = batch['token_start_map'] 395 | batch_strings = batch['strings'] 396 | 397 | for max_preds, argmax_preds, mask_tokens, token_start_map, string in zip( 398 | max_batch_preds, argmax_batch_preds, input_mask, batch_token_start_map, batch_strings 399 | ): 400 | preds, scores, starts = [], [], [] 401 | for max_pred, argmax_pred, mask_token, token_start in zip(max_preds, argmax_preds, mask_tokens, token_start_map): 402 | if token_start != -1: 403 | preds.append(self._label_map[str(argmax_pred)]) 404 | scores.append(max_pred) 405 | starts.append(token_start) 406 | mean_score = np.mean(scores) if len(scores) > 0 else 0 407 | 408 | result = [] 409 | 410 | for label, group in groupby(zip(preds, starts, scores), key=lambda i: re.sub('^(B-|I-)', '', i[0])): 411 | _, group_start, _ = list(group)[0] 412 | if len(result) > 0: 413 | if group_start == 0: 414 | result.pop(-1) 415 | else: 416 | result[-1]['value']['end'] = group_start - 1 417 | if label != 'O': 418 | result.append({ 419 | 'from_name': from_name, 420 | 'to_name': to_name, 421 | 'type': 'labels', 422 | 'value': { 423 | 'labels': [label], 424 | 'start': group_start, 425 | 'end': None, 426 | 'text': '...' 
427 | } 428 | }) 429 | if result and result[-1]['value']['end'] is None: 430 | result[-1]['value']['end'] = len(string) 431 | results.append({ 432 | 'result': result, 433 | 'score': float(mean_score), 434 | 'cluster': None 435 | }) 436 | return results 437 | 438 | def get_spans(self, completion): 439 | spans = [] 440 | for r in completion['result']: 441 | if r['from_name'] == self.from_name and r['to_name'] == self.to_name: 442 | labels = r['value'].get('labels') 443 | if not isinstance(labels, list) or len(labels) == 0: 444 | logger.warning(f'Error while parsing {r}: list type expected for "labels"') 445 | continue 446 | label = labels[0] 447 | start, end = r['value'].get('start'), r['value'].get('end') 448 | if start is None or end is None: 449 | logger.warning(f'Error while parsing {r}: "value" should contain "start" and "end" fields'); continue 450 | spans.append({ 451 | 'label': label, 452 | 'start': start, 453 | 'end': end 454 | }) 455 | return spans 456 | 457 | def fit( 458 | self, completions, workdir=None, model_type='bert', pretrained_model='bert-base-uncased', 459 | batch_size=32, learning_rate=5e-5, adam_epsilon=1e-8, num_train_epochs=100, weight_decay=0.0, logging_steps=1, 460 | warmup_steps=0, save_steps=50, dump_dataset=True, cache_dir='~/.heartex/cache', train_logs=None, 461 | **kwargs 462 | ): 463 | train_logs = train_logs or os.path.join(workdir, 'train_logs') 464 | os.makedirs(train_logs, exist_ok=True) 465 | logger.debug('Prepare models') 466 | cache_dir = os.path.expanduser(cache_dir) 467 | os.makedirs(cache_dir, exist_ok=True) 468 | 469 | model_type = model_type.lower() 470 | # assert model_type in MODEL_CLASSES.keys(), f'Input model type {model_type} not in {MODEL_CLASSES.keys()}' 471 | # assert pretrained_model in ALL_MODELS, f'Pretrained model {pretrained_model} not in {ALL_MODELS}' 472 | 473 | tokenizer = AutoTokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir) 474 | 475 | logger.debug('Read data') 476 | # read input data stream 477 | texts, list_of_spans = [], [] 478 | for item in completions: 479 | texts.append(item['data'][self.value]) 480 | list_of_spans.append(self.get_spans(item['annotations'][0])) 481 | 482 | logger.debug('Prepare dataset') 483 | pad_token_label_id = CrossEntropyLoss().ignore_index 484 | train_set = SpanLabeledTextDataset( 485 | texts, list_of_spans, tokenizer, 486 | cls_token_at_end=model_type in ['xlnet'], 487 | cls_token_segment_id=2 if model_type in ['xlnet'] else 0, 488 | sep_token_extra=model_type in ['roberta'], 489 | pad_token_label_id=pad_token_label_id 490 | ) 491 | 492 | if dump_dataset: 493 | dataset_file = os.path.join(workdir, 'train_set.txt') 494 | train_set.dump(dataset_file) 495 | 496 | # config = config_class.from_pretrained(pretrained_model, num_labels=train_set.num_labels, cache_dir=cache_dir) 497 | config = AutoConfig.from_pretrained(pretrained_model, num_labels=train_set.num_labels, cache_dir=cache_dir) 498 | # model = model_class.from_pretrained(pretrained_model, config=config, cache_dir=cache_dir) 499 | model = AutoModelForTokenClassification.from_pretrained(pretrained_model, config=config, cache_dir=cache_dir) 500 | 501 | batch_padding = SpanLabeledTextDataset.get_padding_function(model_type, tokenizer, pad_token_label_id) 502 | 503 | train_loader = DataLoader( 504 | dataset=train_set, 505 | batch_size=batch_size, 506 | shuffle=True, 507 | collate_fn=batch_padding 508 | ) 509 | 510 | no_decay = ['bias', 'LayerNorm.weight'] 511 | optimizer_grouped_parameters = [ 512 | {'params': [p for n, p in 
model.named_parameters() if not any(nd in n for nd in no_decay)], 513 | 'weight_decay': weight_decay}, 514 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 515 | ] 516 | 517 | num_training_steps = len(train_loader) * num_train_epochs 518 | optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_epsilon) 519 | scheduler = get_linear_schedule_with_warmup( 520 | optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps) 521 | 522 | tr_loss, logging_loss = 0, 0 523 | global_step = 0 524 | if train_logs: 525 | tb_writer = SummaryWriter(logdir=os.path.join(train_logs, os.path.basename(workdir))) 526 | epoch_iterator = trange(num_train_epochs, desc='Epoch') 527 | loss_queue = deque(maxlen=10) 528 | for _ in epoch_iterator: 529 | batch_iterator = tqdm(train_loader, desc='Batch') 530 | for step, batch in enumerate(batch_iterator): 531 | 532 | model.train() 533 | inputs = { 534 | 'input_ids': batch['input_ids'], 535 | 'attention_mask': batch['input_mask'], 536 | 'labels': batch['label_ids'], 537 | 'token_type_ids': batch['segment_ids'] 538 | } 539 | if model_type == 'distilbert': 540 | inputs.pop('token_type_ids') 541 | 542 | model_output = model(**inputs) 543 | loss = model_output[0] 544 | loss.backward() 545 | tr_loss += loss.item() 546 | optimizer.step() 547 | scheduler.step() 548 | model.zero_grad() 549 | global_step += 1 550 | if global_step % logging_steps == 0: 551 | last_loss = (tr_loss - logging_loss) / logging_steps 552 | loss_queue.append(last_loss) 553 | if train_logs: 554 | tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step) 555 | tb_writer.add_scalar('loss', last_loss, global_step) 556 | logging_loss = tr_loss 557 | 558 | # slope-based early stopping 559 | if len(loss_queue) == loss_queue.maxlen: 560 | slope = calc_slope(loss_queue) 561 | if train_logs: 562 | tb_writer.add_scalar('slope', slope, global_step) 563 | if abs(slope) < 1e-2: 564 | break 565 | 566 | if train_logs: 567 | tb_writer.close() 568 | 569 | model_to_save = model.module if hasattr(model, "module") else model # Take care of distributed/parallel training 570 | model_to_save.save_pretrained(workdir) 571 | tokenizer.save_pretrained(workdir) 572 | label_map = {i: t for t, i in train_set.tag_idx_map.items()} 573 | 574 | return { 575 | 'model_path': workdir, 576 | 'batch_size': batch_size, 577 | 'pad_token_label_id': pad_token_label_id, 578 | 'dataset_params_dict': train_set.get_params_dict(), 579 | 'model_type': model_type, 580 | 'pretrained_model': pretrained_model, 581 | 'label_map': label_map 582 | } 583 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from torch.utils.data import TensorDataset, DataLoader 5 | 6 | 7 | def pad_sequences(input_ids, maxlen): 8 | padded_ids = [] 9 | for ids in input_ids: 10 | nonpad = min(len(ids), maxlen) 11 | pids = [ids[i] for i in range(nonpad)] 12 | for i in range(nonpad, maxlen): 13 | pids.append(0) 14 | padded_ids.append(pids) 15 | return padded_ids 16 | 17 | 18 | def prepare_texts(texts, tokenizer, maxlen, sampler_class, batch_size, choices_ids=None): 19 | # create input token indices 20 | input_ids = [] 21 | for text in texts: 22 | input_ids.append(tokenizer.encode(text, add_special_tokens=True)) 23 | # input_ids = pad_sequences(input_ids, maxlen=maxlen, dtype='long', value=0, 
truncating='post', padding='post') 24 | input_ids = pad_sequences(input_ids, maxlen) 25 | # Create attention masks 26 | attention_masks = [] 27 | for sent in input_ids: 28 | attention_masks.append([int(token_id > 0) for token_id in sent]) 29 | 30 | if choices_ids is not None: 31 | dataset = TensorDataset(torch.tensor(input_ids, dtype=torch.long), torch.tensor(attention_masks, dtype=torch.long), torch.tensor(choices_ids, dtype=torch.long)) 32 | else: 33 | dataset = TensorDataset(torch.tensor(input_ids, dtype=torch.long), torch.tensor(attention_masks, dtype=torch.long)) 34 | sampler = sampler_class(dataset) 35 | dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size) 36 | return dataloader 37 | 38 | 39 | def calc_slope(y): 40 | # least-squares slope of y regressed on x = 1..n 41 | n = len(y) 42 | if n == 1: 43 | raise ValueError('Can\'t compute slope for array of length=1') 44 | x_mean = (n + 1) / 2 45 | x2_mean = (n + 1) * (2 * n + 1) / 6 46 | xy_mean = np.average(y, weights=np.arange(1, n + 1)) * (n + 1) / 2  # rescale: np.average divides by sum(weights), not n 47 | y_mean = np.mean(y) 48 | slope = (xy_mean - x_mean * y_mean) / (x2_mean - x_mean * x_mean) 49 | return slope -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.1 2 | transformers==4.4.2 3 | tensorboardX==1.9 4 | label-studio>=1.0.0 5 | git+https://github.com/heartexlabs/label-studio-ml-backend@master#egg=label-studio-ml --------------------------------------------------------------------------------
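As a footnote on the training loops above: both `fit()` implementations stop early once the least-squares slope of the last ten logged losses flattens out (`abs(slope) < 1e-2`), using `calc_slope` from `models/utils.py`. A quick sanity check of that criterion, with illustrative loss values:

```python
from models.utils import calc_slope

falling = [0.50 - 0.02 * i for i in range(10)]    # steadily decreasing loss
plateau = [0.30 - 0.0005 * i for i in range(10)]  # nearly flat loss

print(calc_slope(falling))  # -0.02   -> |slope| >= 1e-2, training continues
print(calc_slope(plateau))  # -0.0005 -> |slope| <  1e-2, early stop triggers
```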