├── .gitignore
├── LICENSE
├── README.md
├── images
│   └── codeless.png
├── models
│   ├── __init__.py
│   ├── bert_classifier.py
│   ├── ner.py
│   └── utils.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .venv
86 | env/
87 | venv/
88 | ENV/
89 | env.bak/
90 | venv.bak/
91 |
92 | # Spyder project settings
93 | .spyderproject
94 | .spyproject
95 |
96 | # Rope project settings
97 | .ropeproject
98 |
99 | # mkdocs documentation
100 | /site
101 |
102 | # mypy
103 | .mypy_cache/
104 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Label Studio for Hugging Face's Transformers
2 |
3 | [Website](https://labelstud.io/) • [Docs](https://labelstud.io/guide) • [Twitter](https://twitter.com/heartexlabs) • [Join Slack Community](https://slack.labelstud.io/?source=github-1)
4 |
5 |
6 |
7 | **Transfer learning for NLP models by annotating your textual data without any additional coding.**
8 |
9 | This package provides a ready-to-use container that links together:
10 |
11 | - [Label Studio](https://github.com/heartexlabs/label-studio) as the annotation frontend
12 | - [Hugging Face's transformers](https://github.com/huggingface/transformers) as the machine learning backend for NLP
13 |
14 |
15 |
16 | [![codeless](images/codeless.png)](https://github.com/heartexlabs/label-studio-transformers)
17 |
18 | ### Quick Usage
19 |
20 | #### Install Label Studio and other dependencies
21 |
22 | ```bash
23 | pip install -r requirements.txt
24 | ```
25 |
26 | ##### Create ML backend with BERT classifier
27 | ```bash
28 | label-studio-ml init my-ml-backend --script models/bert_classifier.py
29 | cp models/utils.py my-ml-backend/utils.py
30 |
31 | # Start ML backend at http://localhost:9090
32 | label-studio-ml start my-ml-backend
33 |
34 | # Start Label Studio in a new terminal with the same Python environment
35 | label-studio start
36 | ```
37 |
38 | 1. Create a project with `Choices` and `Text` tags in the labeling config; the backend's predictions reference these tag names (see the example payload below).
39 | 2. Connect the ML backend in the Project settings with `http://localhost:9090`
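
For reference, each prediction this classifier backend sends back to Label Studio has roughly the shape sketched below (see `models/bert_classifier.py`). The tag names and label value are placeholders and must match your labeling config:

```python
# Sketch of one prediction produced by models/bert_classifier.py.
# 'sentiment', 'text' and 'Positive' are placeholders: they come from your
# labeling config (the Choices and Text tag names) and from your label set.
prediction = {
    'result': [{
        'from_name': 'sentiment',            # name of the <Choices> tag
        'to_name': 'text',                   # name of the <Text> tag
        'type': 'choices',
        'value': {'choices': ['Positive']}
    }],
    'score': 3.2                             # max logit of the predicted class (not a probability)
}
```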
40 |
41 | ##### Create ML backend with BERT named entity recognizer
42 | ```bash
43 | label-studio-ml init my-ml-backend --script models/ner.py
44 | cp models/utils.py my-ml-backend/utils.py
45 |
46 | # Start ML backend at http://localhost:9090
47 | label-studio-ml start my-ml-backend
48 |
49 | # Start Label Studio in a new terminal with the same Python environment
50 | label-studio start
51 | ```
52 |
53 | 1. Create a project with `Labels` and `Text` tags in the labeling config; the backend's predictions reference these tag names (see the example payload below).
54 | 2. Connect the ML backend in the Project settings with `http://localhost:9090`
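
For reference, each prediction the NER backend sends back has roughly the shape sketched below (see `models/ner.py`); tag names and labels are again placeholders that must match your labeling config:

```python
# Sketch of one prediction produced by models/ner.py.
# 'label', 'text' and 'PER' are placeholders: they come from your labeling
# config (the Labels and Text tag names) and from your label set.
prediction = {
    'result': [{
        'from_name': 'label',                # name of the <Labels> tag
        'to_name': 'text',                   # name of the <Text> tag
        'type': 'labels',
        'value': {
            'labels': ['PER'],
            'start': 0,                      # character offsets into the task text
            'end': 11,
            'text': '...'                    # the backend stores a placeholder here, not the span text
        }
    }],
    'score': 4.1,                            # mean of the max logits over predicted tokens
    'cluster': None
}
```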
55 |
56 | #### Training and inference
57 |
58 | The browser opens at `http://localhost:8080`. Upload your data on the **Import** page, then annotate it on the **Labeling** page.
59 | Once you've annotated a sufficient amount of data, go to the **Model** page and press the **Start Training** button. When training finishes, the model automatically starts serving predictions to Label Studio, and you'll find all model checkpoints inside the `my-ml-backend//` directory.
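
If you want to reuse a trained classifier checkpoint outside Label Studio, it can be loaded directly with `transformers`. A minimal sketch, assuming a hypothetical checkpoint directory (point it at the actual folder created under `my-ml-backend/`):

```python
# Minimal sketch, not part of this repo: reload a checkpoint saved by the
# classifier backend. The directory below is a hypothetical example.
import torch
from transformers import BertTokenizer, BertForSequenceClassification

checkpoint_dir = 'my-ml-backend/example-checkpoint'  # hypothetical path
tokenizer = BertTokenizer.from_pretrained(checkpoint_dir)
model = BertForSequenceClassification.from_pretrained(checkpoint_dir)
model.eval()

inputs = tokenizer('Your text here', return_tensors='pt', truncation=True, max_length=64)
with torch.no_grad():
    logits = model(**inputs).logits
print('Predicted class index:', int(logits.argmax(dim=-1)))
```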
60 |
61 | [Click here](https://labelstud.io/guide/ml.html) to read more about how to use the Machine Learning backend and build Human-in-the-Loop pipelines with Label Studio.
62 |
63 | ## License
64 |
65 | This software is licensed under the [Apache 2.0 LICENSE](/LICENSE) © 2020 [Heartex](https://www.heartex.com/).
66 |
67 |
68 |
--------------------------------------------------------------------------------
/images/codeless.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanSignal/label-studio-transformers/de267cbd903ebf907d4449323fb81e57e3d38257/images/codeless.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanSignal/label-studio-transformers/de267cbd903ebf907d4449323fb81e57e3d38257/models/__init__.py
--------------------------------------------------------------------------------
/models/bert_classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 |
5 | from torch.utils.data import SequentialSampler
6 | from tqdm import tqdm, trange
7 | from collections import deque
8 | from tensorboardX import SummaryWriter
9 | from transformers import BertTokenizer, BertForSequenceClassification
10 | from transformers import AdamW, get_linear_schedule_with_warmup
11 | from torch.utils.data import TensorDataset, DataLoader, RandomSampler
12 |
13 | from label_studio_ml.model import LabelStudioMLBase
14 |
15 | from utils import prepare_texts, calc_slope
16 |
17 |
18 | if torch.cuda.is_available():
19 | device = torch.device("cuda")
20 | print('There are %d GPU(s) available.' % torch.cuda.device_count())
21 | print('We will use the GPU:', torch.cuda.get_device_name(0))
22 | else:
23 | print('No GPU available, using the CPU instead.')
24 | device = torch.device("cpu")
25 |
26 |
27 | class BertClassifier(LabelStudioMLBase):
28 |
29 | def __init__(
30 | self, pretrained_model='bert-base-multilingual-cased', maxlen=64,
31 | batch_size=32, num_epochs=100, logging_steps=1, train_logs=None, **kwargs
32 | ):
33 | super(BertClassifier, self).__init__(**kwargs)
34 | self.pretrained_model = pretrained_model
35 | self.maxlen = maxlen
36 | self.batch_size = batch_size
37 | self.num_epochs = num_epochs
38 | self.logging_steps = logging_steps
39 | self.train_logs = train_logs
40 |
41 |         # Collect keys from the label config that are used to extract task data and to build predictions.
42 |         # The parsed label config must contain exactly one output tag, of type Choices.
43 | assert len(self.parsed_label_config) == 1
44 | self.from_name, self.info = list(self.parsed_label_config.items())[0]
45 | assert self.info['type'] == 'Choices'
46 |
47 | # the model has only one textual input
48 | assert len(self.info['to_name']) == 1
49 | assert len(self.info['inputs']) == 1
50 | assert self.info['inputs'][0]['type'] == 'Text'
51 | self.to_name = self.info['to_name'][0]
52 | self.value = self.info['inputs'][0]['value']
53 |
54 | if not self.train_output:
55 | self.labels = self.info['labels']
56 |             self.model = self.reset_model(self.pretrained_model, cache_dir=None, device=device)
57 | print('Initialized with from_name={from_name}, to_name={to_name}, labels={labels}'.format(
58 | from_name=self.from_name, to_name=self.to_name, labels=str(self.labels)
59 | ))
60 | else:
61 | self.load(self.train_output)
62 | print('Loaded from train output with from_name={from_name}, to_name={to_name}, labels={labels}'.format(
63 | from_name=self.from_name, to_name=self.to_name, labels=str(self.labels)
64 | ))
65 |
66 | def reset_model(self, pretrained_model, cache_dir, device):
67 | model = BertForSequenceClassification.from_pretrained(
68 | pretrained_model,
69 | num_labels=len(self.labels),
70 | output_attentions=False,
71 | output_hidden_states=False,
72 | cache_dir=cache_dir
73 | )
74 | model.to(device)
75 | return model
76 |
77 | def load(self, train_output):
78 | pretrained_model = train_output['model_path']
79 | self.tokenizer = BertTokenizer.from_pretrained(pretrained_model)
80 | self.model = BertForSequenceClassification.from_pretrained(pretrained_model)
81 | self.model.to(device)
82 | self.model.eval()
83 | self.batch_size = train_output['batch_size']
84 | self.labels = train_output['labels']
85 | self.maxlen = train_output['maxlen']
86 |
87 | @property
88 | def not_trained(self):
89 | return not hasattr(self, 'tokenizer')
90 |
91 | def predict(self, tasks, **kwargs):
92 | if self.not_trained:
93 | print('Can\'t get prediction because model is not trained yet.')
94 | return []
95 |
96 | texts = [task['data'][self.value] for task in tasks]
97 | predict_dataloader = prepare_texts(texts, self.tokenizer, self.maxlen, SequentialSampler, self.batch_size)
98 |
99 | pred_labels, pred_scores = [], []
100 | for batch in predict_dataloader:
101 | batch = tuple(t.to(device) for t in batch)
102 | inputs = {
103 | 'input_ids': batch[0],
104 | 'attention_mask': batch[1]
105 | }
106 | with torch.no_grad():
107 | outputs = self.model(**inputs)
108 | logits = outputs[0]
109 |
110 | batch_preds = logits.detach().cpu().numpy()
111 |
112 | argmax_batch_preds = np.argmax(batch_preds, axis=-1)
113 | pred_labels.extend(str(self.labels[i]) for i in argmax_batch_preds)
114 |
115 | max_batch_preds = np.max(batch_preds, axis=-1)
116 | pred_scores.extend(float(s) for s in max_batch_preds)
117 |
118 | predictions = []
119 | for predicted_label, score in zip(pred_labels, pred_scores):
120 | result = [{
121 | 'from_name': self.from_name,
122 | 'to_name': self.to_name,
123 | 'type': 'choices',
124 | 'value': {'choices': [predicted_label]}
125 | }]
126 |
127 | predictions.append({'result': result, 'score': score})
128 | return predictions
129 |
130 | def fit(self, completions, workdir=None, cache_dir=None, **kwargs):
131 | input_texts = []
132 | output_labels, output_labels_idx = [], []
133 | label2idx = {l: i for i, l in enumerate(self.labels)}
134 | for completion in completions:
135 | # get input text from task data
136 |
137 | if completion['annotations'][0].get('skipped'):
138 | continue
139 |
140 | input_text = completion['data'][self.value]
141 | input_texts.append(input_text)
142 |
143 | # get an annotation
144 | output_label = completion['annotations'][0]['result'][0]['value']['choices'][0]
145 | output_labels.append(output_label)
146 | output_label_idx = label2idx[output_label]
147 | output_labels_idx.append(output_label_idx)
148 |
149 | new_labels = set(output_labels)
150 | added_labels = new_labels - set(self.labels)
151 | if len(added_labels) > 0:
152 | print('Label set has been changed. Added ones: ' + str(list(added_labels)))
153 | self.labels = list(sorted(new_labels))
154 | label2idx = {l: i for i, l in enumerate(self.labels)}
155 | output_labels_idx = [label2idx[label] for label in output_labels]
156 |
157 | tokenizer = BertTokenizer.from_pretrained(self.pretrained_model, cache_dir=cache_dir)
158 |
159 | train_dataloader = prepare_texts(input_texts, tokenizer, self.maxlen, RandomSampler, self.batch_size, output_labels_idx)
160 | model = self.reset_model(self.pretrained_model, cache_dir, device)
161 |
162 | total_steps = len(train_dataloader) * self.num_epochs
163 | optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)
164 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
165 | global_step = 0
166 | total_loss, logging_loss = 0.0, 0.0
167 | model.zero_grad()
168 | train_iterator = trange(self.num_epochs, desc='Epoch')
169 | if self.train_logs:
170 |             tb_writer = SummaryWriter(logdir=os.path.join(self.train_logs, os.path.basename(workdir)))
171 | else:
172 | tb_writer = None
173 | loss_queue = deque(maxlen=10)
174 | for epoch in train_iterator:
175 | epoch_iterator = tqdm(train_dataloader, desc='Iteration')
176 | for step, batch in enumerate(epoch_iterator):
177 | model.train()
178 | batch = tuple(t.to(device) for t in batch)
179 | inputs = {'input_ids': batch[0],
180 | 'attention_mask': batch[1],
181 | 'labels': batch[2]}
182 | outputs = model(**inputs)
183 | loss = outputs[0]
184 | loss.backward()
185 | total_loss += loss.item()
186 |
187 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
188 | optimizer.step()
189 | scheduler.step()
190 | model.zero_grad()
191 | global_step += 1
192 | if global_step % self.logging_steps == 0:
193 | last_loss = (total_loss - logging_loss) / self.logging_steps
194 | loss_queue.append(last_loss)
195 | if tb_writer:
196 | tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
197 | tb_writer.add_scalar('loss', last_loss, global_step)
198 | logging_loss = total_loss
199 |
200 | # slope-based early stopping
201 | if len(loss_queue) == loss_queue.maxlen:
202 | slope = calc_slope(loss_queue)
203 | if tb_writer:
204 | tb_writer.add_scalar('slope', slope, global_step)
205 | if abs(slope) < 1e-2:
206 | break
207 |
208 | if tb_writer:
209 | tb_writer.close()
210 |
211 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training # noqa
212 | model_to_save.save_pretrained(workdir)
213 | tokenizer.save_pretrained(workdir)
214 |
215 | return {
216 | 'model_path': workdir,
217 | 'batch_size': self.batch_size,
218 | 'maxlen': self.maxlen,
219 | 'pretrained_model': self.pretrained_model,
220 | 'labels': self.labels
221 | }
222 |
--------------------------------------------------------------------------------
/models/ner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import re
4 | import os
5 | import io
6 | import logging
7 |
8 | from functools import partial
9 | from itertools import groupby
10 | from operator import itemgetter
11 | from torch.nn import CrossEntropyLoss
12 | from torch.utils.data import Dataset, DataLoader
13 | from tqdm import tqdm, trange
14 | from tensorboardX import SummaryWriter
15 | from collections import deque
16 |
17 | from transformers import (
18 | BertTokenizer, BertForTokenClassification, BertConfig,
19 | RobertaConfig, RobertaForTokenClassification, RobertaTokenizer,
20 | DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer,
21 | CamembertConfig, CamembertForTokenClassification, CamembertTokenizer,
22 | AutoConfig, AutoModelForTokenClassification, AutoTokenizer
23 | )
24 | from transformers import AdamW, get_linear_schedule_with_warmup
25 |
26 | from label_studio_ml.model import LabelStudioMLBase
27 | from utils import calc_slope
28 |
29 |
30 | logger = logging.getLogger(__name__)
31 |
32 |
33 | try:  # recent transformers versions removed `pretrained_config_archive_map` from config classes
34 |     ALL_MODELS = sum([list(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)], [])
35 | except AttributeError:
36 |     ALL_MODELS = []  # ALL_MODELS only feeds the commented-out sanity check in fit(), so an empty list is fine
37 | MODEL_CLASSES = {
38 | 'bert': (BertConfig, BertForTokenClassification, BertTokenizer),
39 | 'roberta': (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
40 | 'distilbert': (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
41 | 'camembert': (CamembertConfig, CamembertForTokenClassification, CamembertTokenizer),
42 | }
43 |
44 |
45 | class SpanLabeledTextDataset(Dataset):
46 |
47 | def __init__(
48 | self, list_of_strings, list_of_spans=None, tokenizer=None, tag_idx_map=None,
49 | cls_token='[CLS]', sep_token='[SEP]', pad_token_label_id=-1, max_seq_length=128, sep_token_extra=False,
50 | cls_token_at_end=False, sequence_a_segment_id=0, cls_token_segment_id=1, mask_padding_with_zero=True
51 | ):
52 | self.list_of_strings = list_of_strings
53 |         self.list_of_spans = list_of_spans or [[] for _ in list_of_strings]
54 | self.tokenizer = tokenizer
55 | self.cls_token = cls_token
56 | self.sep_token = sep_token
57 | self.pad_token_label_id = pad_token_label_id
58 | self.max_seq_length = max_seq_length
59 | self.sep_token_extra = sep_token_extra
60 | self.cls_token_at_end = cls_token_at_end
61 | self.sequence_a_segment_id = sequence_a_segment_id
62 | self.cls_token_segment_id = cls_token_segment_id
63 | self.mask_padding_with_zero = mask_padding_with_zero
64 |
65 | (self.original_list_of_tokens, self.original_list_of_tags, tag_idx_map_,
66 | original_list_of_tokens_start_map) = self._prepare_data()
67 |
68 | if tag_idx_map is None:
69 | self.tag_idx_map = tag_idx_map_
70 | else:
71 | self.tag_idx_map = tag_idx_map
72 |
73 | (self.list_of_tokens, self.list_of_token_ids, self.list_of_labels, self.list_of_label_ids,
74 | self.list_of_segment_ids, self.list_of_token_start_map) = [], [], [], [], [], []
75 |
76 | for original_tokens, original_tags, original_token_start_map in zip(
77 | self.original_list_of_tokens,
78 | self.original_list_of_tags,
79 | original_list_of_tokens_start_map
80 | ):
81 | tokens, token_ids, labels, label_ids, segment_ids, token_start_map = self._convert_to_features(
82 | original_tokens, original_tags, self.tag_idx_map, original_token_start_map)
83 | self.list_of_token_ids.append(token_ids)
84 | self.list_of_tokens.append(tokens)
85 | self.list_of_labels.append(labels)
86 | self.list_of_segment_ids.append(segment_ids)
87 | self.list_of_label_ids.append(label_ids)
88 | self.list_of_token_start_map.append(token_start_map)
89 |
90 | def get_params_dict(self):
91 | return {
92 | 'cls_token': self.cls_token,
93 | 'sep_token': self.sep_token,
94 | 'pad_token_label_id': self.pad_token_label_id,
95 | 'max_seq_length': self.max_seq_length,
96 | 'sep_token_extra': self.sep_token_extra,
97 | 'cls_token_at_end': self.cls_token_at_end,
98 | 'sequence_a_segment_id': self.sequence_a_segment_id,
99 | 'cls_token_segment_id': self.cls_token_segment_id,
100 | 'mask_padding_with_zero': self.mask_padding_with_zero
101 | }
102 |
103 | def dump(self, output_file):
104 | with io.open(output_file, mode='w') as f:
105 | for tokens, labels in zip(self.list_of_tokens, self.list_of_labels):
106 | for token, label in zip(tokens, labels):
107 | f.write(f'{token} {label}\n')
108 | f.write('\n')
109 |
110 | def _convert_to_features(self, words, labels, label_map, list_token_start_map):
111 | tokens, out_labels, label_ids, tokens_idx_map = [], [], [], []
112 | for i, (word, label, token_start) in enumerate(zip(words, labels, list_token_start_map)):
113 | word_tokens = self.tokenizer.tokenize(word)
114 | tokens.extend(word_tokens)
115 | tokens_idx_map.extend([token_start] * len(word_tokens))
116 | # Use the real label id for the first token of the word, and padding ids for the remaining tokens
117 | label_ids.extend([label_map[label]] + [self.pad_token_label_id] * (len(word_tokens) - 1))
118 | out_labels.extend([label] + ['X'] * (len(word_tokens) - 1))
119 |
120 | # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
121 | special_tokens_count = 3 if self.sep_token_extra else 2
122 | if len(tokens) > self.max_seq_length - special_tokens_count:
123 | tokens = tokens[:(self.max_seq_length - special_tokens_count)]
124 | label_ids = label_ids[:(self.max_seq_length - special_tokens_count)]
125 | out_labels = out_labels[:(self.max_seq_length - special_tokens_count)]
126 | tokens_idx_map = tokens_idx_map[:(self.max_seq_length - special_tokens_count)]
127 |
128 | # The convention in BERT is:
129 | # (a) For sequence pairs:
130 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
131 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
132 | # (b) For single sequences:
133 | # tokens: [CLS] the dog is hairy . [SEP]
134 | # type_ids: 0 0 0 0 0 0 0
135 | #
136 | # Where "type_ids" are used to indicate whether this is the first
137 | # sequence or the second sequence. The embedding vectors for `type=0` and
138 | # `type=1` were learned during pre-training and are added to the wordpiece
139 | # embedding vector (and position vector). This is not *strictly* necessary
140 | # since the [SEP] token unambiguously separates the sequences, but it makes
141 | # it easier for the model to learn the concept of sequences.
142 | #
143 | # For classification tasks, the first vector (corresponding to [CLS]) is
144 | # used as as the "sentence vector". Note that this only makes sense because
145 | # the entire model is fine-tuned.
146 | tokens += [self.sep_token]
147 | label_ids += [self.pad_token_label_id]
148 | out_labels += ['X']
149 | tokens_idx_map += [-1]
150 | if self.sep_token_extra:
151 | # roberta uses an extra separator b/w pairs of sentences
152 | tokens += [self.sep_token]
153 | label_ids += [self.pad_token_label_id]
154 | out_labels += ['X']
155 | tokens_idx_map += [-1]
156 | segment_ids = [self.sequence_a_segment_id] * len(tokens)
157 | if self.cls_token_at_end:
158 | tokens += [self.cls_token]
159 | label_ids += [self.pad_token_label_id]
160 | out_labels += ['X']
161 | segment_ids += [self.cls_token_segment_id]
162 | tokens_idx_map += [-1]
163 | else:
164 | tokens = [self.cls_token] + tokens
165 | label_ids = [self.pad_token_label_id] + label_ids
166 | out_labels = ['X'] + out_labels
167 | segment_ids = [self.cls_token_segment_id] + segment_ids
168 | tokens_idx_map = [-1] + tokens_idx_map
169 |
170 | token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
171 |
172 | return tokens, token_ids, out_labels, label_ids, segment_ids, tokens_idx_map
173 |
174 | def _apply_tokenizer(self, original_tokens, original_tags):
175 | out_tokens, out_tags, out_maps = [], [], []
176 | for i, (original_token, original_tag) in enumerate(zip(original_tokens, original_tags)):
177 | tokens = self.tokenizer.tokenize(original_token)
178 | out_tokens.extend(tokens)
179 | out_maps.extend([i] * len(tokens))
180 | start_tag = original_tag.startswith('B-')
181 | for j in range(len(tokens)):
182 | if (j == 0 and start_tag) or original_tag == 'O':
183 | out_tags.append(original_tag)
184 | else:
185 | out_tags.append(f'I-{original_tag[2:]}')
186 | return out_tokens, out_tags, out_maps
187 |
188 | def _prepare_data(self):
189 | list_of_tokens, list_of_tags, list_of_token_idx_maps = [], [], []
190 | tag_idx_map = {'O': 0}
191 | for text, spans in zip(self.list_of_strings, self.list_of_spans):
192 | if not text:
193 | continue
194 |
195 | tokens = []
196 | start = 0
197 | for t in text.split():
198 | tokens.append((t, start))
199 | start += len(t) + 1
200 |
201 | if spans:
202 | spans = list(sorted(spans, key=itemgetter('start')))
203 | span = spans.pop(0)
204 | prefix = 'B-'
205 | tags = []
206 | for token, token_start in tokens:
207 | token_end = token_start + len(token) - 1
208 |
209 | # token precedes current span
210 | if not span or token_end < span['start']:
211 | tags.append('O')
212 | continue
213 |
214 |                     # token jumps over the span (this can happen when the previous
215 |                     # label ends with whitespace, e.g. "cat " "too", or when a span was created for whitespace)
216 | if token_start > span['end']:
217 |
218 | prefix = 'B-'
219 | no_more_spans = False
220 | while token_start > span['end']:
221 | if not len(spans):
222 | no_more_spans = True
223 | break
224 | span = spans.pop(0)
225 |
226 | if no_more_spans:
227 | tags.append('O')
228 | span = None
229 | continue
230 |
231 | if token_end < span['start']:
232 | tags.append('O')
233 | continue
234 |
235 | label = span['label']
236 | if label.startswith(prefix):
237 | tag = label
238 | else:
239 | tag = f'{prefix}{label}'
240 | tags.append(tag)
241 | if tag not in tag_idx_map:
242 | tag_idx_map[tag] = len(tag_idx_map)
243 | if span['end'] > token_end:
244 | prefix = 'I-'
245 | elif len(spans):
246 | span = spans.pop(0)
247 | prefix = 'B-'
248 | else:
249 | span = None
250 | else:
251 | tags = ['O'] * len(tokens)
252 |
253 | list_of_tokens.append([t[0] for t in tokens])
254 | list_of_token_idx_maps.append([t[1] for t in tokens])
255 | list_of_tags.append(tags)
256 |
257 | return list_of_tokens, list_of_tags, tag_idx_map, list_of_token_idx_maps
258 |
259 | def __len__(self):
260 | return len(self.list_of_token_ids)
261 |
262 | def __getitem__(self, idx):
263 | return {
264 | 'tokens': self.list_of_token_ids[idx],
265 | 'labels': self.list_of_label_ids[idx],
266 | 'segments': self.list_of_segment_ids[idx],
267 | 'token_start_map': self.list_of_token_start_map[idx],
268 | 'string': self.list_of_strings[idx]
269 | }
270 |
271 | @property
272 | def num_labels(self):
273 | return len(self.tag_idx_map)
274 |
275 | @classmethod
276 | def pad_sequences(cls, batch, mask_padding_with_zero, pad_on_left, pad_token, pad_token_segment_id, pad_token_label_id):
277 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
278 | # tokens are attended to.
279 | max_seq_length = max(map(len, (sample['tokens'] for sample in batch)))
280 | batch_input_ids, batch_label_ids, batch_segment_ids, batch_input_mask, batch_token_start_map = [], [], [], [], []
281 | batch_strings = []
282 | for sample in batch:
283 | input_ids = sample['tokens']
284 | label_ids = sample['labels']
285 | segment_ids = sample['segments']
286 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
287 | # tokens are attended to.
288 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
289 | padding_length = max_seq_length - len(input_ids)
290 | if pad_on_left:
291 | input_ids = ([pad_token] * padding_length) + input_ids
292 | input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
293 | segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
294 | label_ids = ([pad_token_label_id] * padding_length) + label_ids
295 | else:
296 | input_ids += ([pad_token] * padding_length)
297 | input_mask += ([0 if mask_padding_with_zero else 1] * padding_length)
298 | segment_ids += ([pad_token_segment_id] * padding_length)
299 | label_ids += ([pad_token_label_id] * padding_length)
300 | batch_input_ids.append(input_ids)
301 | batch_label_ids.append(label_ids)
302 | batch_segment_ids.append(segment_ids)
303 | batch_input_mask.append(input_mask)
304 | batch_token_start_map.append(sample['token_start_map'])
305 | batch_strings.append(sample['string'])
306 |
307 | return {
308 | 'input_ids': torch.tensor(batch_input_ids, dtype=torch.long),
309 | 'label_ids': torch.tensor(batch_label_ids, dtype=torch.long),
310 | 'segment_ids': torch.tensor(batch_segment_ids, dtype=torch.long),
311 | 'input_mask': torch.tensor(batch_input_mask, dtype=torch.long),
312 | 'token_start_map': batch_token_start_map,
313 | 'strings': batch_strings
314 | }
315 |
316 | @classmethod
317 | def get_padding_function(cls, model_type, tokenizer, pad_token_label_id):
318 | return partial(
319 | cls.pad_sequences,
320 | mask_padding_with_zero=True,
321 | pad_on_left=model_type in ['xlnet'],
322 | pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
323 | pad_token_segment_id=4 if model_type in ['xlnet'] else 0,
324 | pad_token_label_id=pad_token_label_id
325 | )
326 |
327 |
328 | class TransformersBasedTagger(LabelStudioMLBase):
329 |
330 | def __init__(self, **kwargs):
331 | super(TransformersBasedTagger, self).__init__(**kwargs)
332 |
333 | assert len(self.parsed_label_config) == 1
334 | self.from_name, self.info = list(self.parsed_label_config.items())[0]
335 | assert self.info['type'] == 'Labels'
336 |
337 | # the model has only one textual input
338 | assert len(self.info['to_name']) == 1
339 | assert len(self.info['inputs']) == 1
340 | assert self.info['inputs'][0]['type'] == 'Text'
341 | self.to_name = self.info['to_name'][0]
342 | self.value = self.info['inputs'][0]['value']
343 |
344 | if not self.train_output:
345 | self.labels = self.info['labels']
346 | else:
347 | self.load(self.train_output)
348 |
349 | def load(self, train_output):
350 | pretrained_model = train_output['model_path']
351 | self._model_type = train_output['model_type']
352 | _, model_class, tokenizer_class = MODEL_CLASSES[train_output['model_type']]
353 |
354 | self._tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
355 | self._model = AutoModelForTokenClassification.from_pretrained(pretrained_model)
356 | self._batch_size = train_output['batch_size']
357 | self._pad_token = self._tokenizer.convert_tokens_to_ids([self._tokenizer.pad_token])[0]
358 | self._pad_token_label_id = train_output['pad_token_label_id']
359 | self._label_map = train_output['label_map']
360 | self._mask_padding_with_zero = True
361 | self._dataset_params_dict = train_output['dataset_params_dict']
362 |
363 | self._batch_padding = SpanLabeledTextDataset.get_padding_function(
364 | self._model_type, self._tokenizer, self._pad_token_label_id)
365 |
366 | def predict(self, tasks, **kwargs):
367 | texts = [task['data'][self.value] for task in tasks]
368 | predict_set = SpanLabeledTextDataset(texts, tokenizer=self._tokenizer, **self._dataset_params_dict)
369 | from_name = self.from_name
370 | to_name = self.to_name
371 | predict_loader = DataLoader(
372 | dataset=predict_set,
373 | batch_size=self._batch_size,
374 | collate_fn=self._batch_padding
375 | )
376 |
377 | results = []
378 | for batch in tqdm(predict_loader, desc='Prediction'):
379 | inputs = {
380 | 'input_ids': batch['input_ids'],
381 | 'attention_mask': batch['input_mask'],
382 | 'token_type_ids': batch['segment_ids']
383 | }
384 | if self._model_type == 'distilbert':
385 | inputs.pop('token_type_ids')
386 | with torch.no_grad():
387 | model_output = self._model(**inputs)
388 | logits = model_output[0]
389 |
390 | batch_preds = logits.detach().cpu().numpy()
391 | argmax_batch_preds = np.argmax(batch_preds, axis=-1)
392 | max_batch_preds = np.max(batch_preds, axis=-1)
393 | input_mask = batch['input_mask'].detach().cpu().numpy()
394 | batch_token_start_map = batch['token_start_map']
395 | batch_strings = batch['strings']
396 |
397 | for max_preds, argmax_preds, mask_tokens, token_start_map, string in zip(
398 | max_batch_preds, argmax_batch_preds, input_mask, batch_token_start_map, batch_strings
399 | ):
400 | preds, scores, starts = [], [], []
401 | for max_pred, argmax_pred, mask_token, token_start in zip(max_preds, argmax_preds, mask_tokens, token_start_map):
402 | if token_start != -1:
403 | preds.append(self._label_map[str(argmax_pred)])
404 | scores.append(max_pred)
405 | starts.append(token_start)
406 | mean_score = np.mean(scores) if len(scores) > 0 else 0
407 |
408 | result = []
409 |
410 | for label, group in groupby(zip(preds, starts, scores), key=lambda i: re.sub('^(B-|I-)', '', i[0])):
411 | _, group_start, _ = list(group)[0]
412 | if len(result) > 0:
413 | if group_start == 0:
414 | result.pop(-1)
415 | else:
416 | result[-1]['value']['end'] = group_start - 1
417 | if label != 'O':
418 | result.append({
419 | 'from_name': from_name,
420 | 'to_name': to_name,
421 | 'type': 'labels',
422 | 'value': {
423 | 'labels': [label],
424 | 'start': group_start,
425 | 'end': None,
426 | 'text': '...'
427 | }
428 | })
429 | if result and result[-1]['value']['end'] is None:
430 | result[-1]['value']['end'] = len(string)
431 | results.append({
432 | 'result': result,
433 | 'score': float(mean_score),
434 | 'cluster': None
435 | })
436 | return results
437 |
438 | def get_spans(self, completion):
439 | spans = []
440 | for r in completion['result']:
441 | if r['from_name'] == self.from_name and r['to_name'] == self.to_name:
442 | labels = r['value'].get('labels')
443 | if not isinstance(labels, list) or len(labels) == 0:
444 | logger.warning(f'Error while parsing {r}: list type expected for "labels"')
445 | continue
446 | label = labels[0]
447 | start, end = r['value'].get('start'), r['value'].get('end')
448 |                     if start is None or end is None:
449 |                         logger.warning(f'Error while parsing {r}: "value" should contain "start" and "end" fields')
450 |                         continue
451 |                     spans.append({
452 |                         'label': label,
453 |                         'start': start,
454 |                         'end': end})
455 | return spans
456 |
457 | def fit(
458 | self, completions, workdir=None, model_type='bert', pretrained_model='bert-base-uncased',
459 | batch_size=32, learning_rate=5e-5, adam_epsilon=1e-8, num_train_epochs=100, weight_decay=0.0, logging_steps=1,
460 | warmup_steps=0, save_steps=50, dump_dataset=True, cache_dir='~/.heartex/cache', train_logs=None,
461 | **kwargs
462 | ):
463 | train_logs = train_logs or os.path.join(workdir, 'train_logs')
464 | os.makedirs(train_logs, exist_ok=True)
465 | logger.debug('Prepare models')
466 | cache_dir = os.path.expanduser(cache_dir)
467 | os.makedirs(cache_dir, exist_ok=True)
468 |
469 | model_type = model_type.lower()
470 | # assert model_type in MODEL_CLASSES.keys(), f'Input model type {model_type} not in {MODEL_CLASSES.keys()}'
471 | # assert pretrained_model in ALL_MODELS, f'Pretrained model {pretrained_model} not in {ALL_MODELS}'
472 |
473 | tokenizer = AutoTokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
474 |
475 | logger.debug('Read data')
476 | # read input data stream
477 | texts, list_of_spans = [], []
478 | for item in completions:
479 | texts.append(item['data'][self.value])
480 | list_of_spans.append(self.get_spans(item['annotations'][0]))
481 |
482 | logger.debug('Prepare dataset')
483 | pad_token_label_id = CrossEntropyLoss().ignore_index
484 | train_set = SpanLabeledTextDataset(
485 | texts, list_of_spans, tokenizer,
486 | cls_token_at_end=model_type in ['xlnet'],
487 | cls_token_segment_id=2 if model_type in ['xlnet'] else 0,
488 | sep_token_extra=model_type in ['roberta'],
489 | pad_token_label_id=pad_token_label_id
490 | )
491 |
492 | if dump_dataset:
493 | dataset_file = os.path.join(workdir, 'train_set.txt')
494 | train_set.dump(dataset_file)
495 |
496 | # config = config_class.from_pretrained(pretrained_model, num_labels=train_set.num_labels, cache_dir=cache_dir)
497 | config = AutoConfig.from_pretrained(pretrained_model, num_labels=train_set.num_labels, cache_dir=cache_dir)
498 | # model = model_class.from_pretrained(pretrained_model, config=config, cache_dir=cache_dir)
499 | model = AutoModelForTokenClassification.from_pretrained(pretrained_model, config=config, cache_dir=cache_dir)
500 |
501 | batch_padding = SpanLabeledTextDataset.get_padding_function(model_type, tokenizer, pad_token_label_id)
502 |
503 | train_loader = DataLoader(
504 | dataset=train_set,
505 | batch_size=batch_size,
506 | shuffle=True,
507 | collate_fn=batch_padding
508 | )
509 |
510 | no_decay = ['bias', 'LayerNorm.weight']
511 | optimizer_grouped_parameters = [
512 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
513 | 'weight_decay': weight_decay},
514 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
515 | ]
516 |
517 | num_training_steps = len(train_loader) * num_train_epochs
518 | optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_epsilon)
519 | scheduler = get_linear_schedule_with_warmup(
520 | optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
521 |
522 | tr_loss, logging_loss = 0, 0
523 | global_step = 0
524 | if train_logs:
525 | tb_writer = SummaryWriter(logdir=os.path.join(train_logs, os.path.basename(workdir)))
526 | epoch_iterator = trange(num_train_epochs, desc='Epoch')
527 | loss_queue = deque(maxlen=10)
528 | for _ in epoch_iterator:
529 | batch_iterator = tqdm(train_loader, desc='Batch')
530 | for step, batch in enumerate(batch_iterator):
531 |
532 | model.train()
533 | inputs = {
534 | 'input_ids': batch['input_ids'],
535 | 'attention_mask': batch['input_mask'],
536 | 'labels': batch['label_ids'],
537 | 'token_type_ids': batch['segment_ids']
538 | }
539 | if model_type == 'distilbert':
540 | inputs.pop('token_type_ids')
541 |
542 | model_output = model(**inputs)
543 | loss = model_output[0]
544 | loss.backward()
545 | tr_loss += loss.item()
546 | optimizer.step()
547 | scheduler.step()
548 | model.zero_grad()
549 | global_step += 1
550 | if global_step % logging_steps == 0:
551 | last_loss = (tr_loss - logging_loss) / logging_steps
552 | loss_queue.append(last_loss)
553 | if train_logs:
554 | tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
555 | tb_writer.add_scalar('loss', last_loss, global_step)
556 | logging_loss = tr_loss
557 |
558 | # slope-based early stopping
559 | if len(loss_queue) == loss_queue.maxlen:
560 | slope = calc_slope(loss_queue)
561 | if train_logs:
562 | tb_writer.add_scalar('slope', slope, global_step)
563 | if abs(slope) < 1e-2:
564 | break
565 |
566 | if train_logs:
567 | tb_writer.close()
568 |
569 | model_to_save = model.module if hasattr(model, "module") else model # Take care of distributed/parallel training
570 | model_to_save.save_pretrained(workdir)
571 | tokenizer.save_pretrained(workdir)
572 | label_map = {i: t for t, i in train_set.tag_idx_map.items()}
573 |
574 | return {
575 | 'model_path': workdir,
576 | 'batch_size': batch_size,
577 | 'pad_token_label_id': pad_token_label_id,
578 | 'dataset_params_dict': train_set.get_params_dict(),
579 | 'model_type': model_type,
580 | 'pretrained_model': pretrained_model,
581 | 'label_map': label_map
582 | }
583 |
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from torch.utils.data import TensorDataset, DataLoader
5 |
6 |
7 | def pad_sequences(input_ids, maxlen):
8 | padded_ids = []
9 | for ids in input_ids:
10 | nonpad = min(len(ids), maxlen)
11 | pids = [ids[i] for i in range(nonpad)]
12 | for i in range(nonpad, maxlen):
13 | pids.append(0)
14 | padded_ids.append(pids)
15 | return padded_ids
16 |
17 |
18 | def prepare_texts(texts, tokenizer, maxlen, sampler_class, batch_size, choices_ids=None):
19 | # create input token indices
20 | input_ids = []
21 | for text in texts:
22 | input_ids.append(tokenizer.encode(text, add_special_tokens=True))
23 | # input_ids = pad_sequences(input_ids, maxlen=maxlen, dtype='long', value=0, truncating='post', padding='post')
24 | input_ids = pad_sequences(input_ids, maxlen)
25 | # Create attention masks
26 | attention_masks = []
27 | for sent in input_ids:
28 | attention_masks.append([int(token_id > 0) for token_id in sent])
29 |
30 | if choices_ids is not None:
31 | dataset = TensorDataset(torch.tensor(input_ids, dtype=torch.long), torch.tensor(attention_masks, dtype=torch.long), torch.tensor(choices_ids, dtype=torch.long))
32 | else:
33 | dataset = TensorDataset(torch.tensor(input_ids, dtype=torch.long), torch.tensor(attention_masks, dtype=torch.long))
34 | sampler = sampler_class(dataset)
35 | dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
36 | return dataloader
37 |
38 |
39 | def calc_slope(y):
40 |     """Ordinary least-squares slope of y against x = 1..n (used for slope-based early stopping)."""
41 |     n = len(y)
42 |     if n == 1:
43 |         raise ValueError('Can\'t compute slope for array of length=1')
44 |     x = np.arange(1, n + 1)
45 |     x_mean = (n + 1) / 2
46 |     x2_mean = (n + 1) * (2 * n + 1) / 6
47 |     xy_mean = np.mean(x * np.asarray(y))
48 |     y_mean = np.mean(y)
49 |     slope = (xy_mean - x_mean * y_mean) / (x2_mean - x_mean * x_mean)
50 |     return slope
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.13.1
2 | transformers==4.4.2
3 | tensorboardX==1.9
4 | label-studio>=1.0.0
5 | git+https://github.com/heartexlabs/label-studio-ml-backend.git@master#egg=label-studio-ml
6 |
--------------------------------------------------------------------------------