├── .gitignore
├── ADR_NLP
├── __init__.py
├── data.py
├── data_helpers.py
├── metrics.py
├── utils.py
└── visualize_ner.py
├── LICENSE
├── README.md
├── finetune.py
├── pretrain.py
└── testing.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # jsonl data files
132 | *.jsonl
133 | data/
134 |
135 | # Training runs & WANDB logs
136 | wandb/
137 | cross-validation-finetuned/
138 | adr-ner-finetuned/
--------------------------------------------------------------------------------
/ADR_NLP/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AustinMOS/adr-nlp/d7f1e8e53b274c2403f72afd43f2839dea8589fe/ADR_NLP/__init__.py
--------------------------------------------------------------------------------
/ADR_NLP/data.py:
--------------------------------------------------------------------------------
1 | """
2 | This script is used to create our final tokenized data set
3 | using the HuggingFace Datasets library.
4 |
5 | The helper functions from data_helpers.py are used to process annotated
6 | text, generated using the Prodigy NLP tool.
7 |
8 | The final dataset will be chunked according to the maximum tokens allowed
9 | in our model (512)
10 | """
11 | import datasets
12 | from datasets import Features, Sequence
13 | from datasets.features import Value, ClassLabel
14 | from ADR_NLP.data_helpers import (
15 | split_text, tokenize_and_align_labels,
16 | label_table, split)
17 | import random
18 |
19 | # A helper function that returns True if 'jsonl' is found in a string, otherwise False
def is_jsonl(string):
    """Report whether the substring ``'jsonl'`` occurs anywhere in ``string``."""
    return 'jsonl' in string
25 |
26 | """
27 | A class to hold the data and labels for the model.
28 | Can be initialized from a JSONL file and processed, or from a preprocessed Dataset that has been saved.
29 | """
class NERdataset():
    """Load, tokenize, chunk and split an annotated NER dataset.

    ``data_file`` is treated as a raw JSONL annotation file when its path
    contains ``'jsonl'``; otherwise it is loaded with ``datasets.load_from_disk``.

    After construction the instance exposes:
      * ``label2id`` / ``id2label``: mappings between tag names and class ids.
      * ``dset``: when ``folds > 1`` a dict ``{'fold0': DatasetDict, ...}`` with
        ``'train'`` and ``'test'`` entries per fold, otherwise the result of a
        single 80/20 ``train_test_split``.

    Args:
        data_file: Path to the JSONL file or to a saved dataset directory.
        text_col: Name of the column that holds the raw text.
        tokenizer: Tokenizer used for sub-word tokenization and chunk length.
        folds: Number of cross-validation folds (``<= 1`` means a single split).
        seed: Seed used for the fold shuffle and the single train/test split.
        save: Optional directory to save the encoded (pre-tokenization) dataset to.
    """

    def __init__(self, data_file, text_col, tokenizer, folds, seed, save = None):
        self.tokenizer = tokenizer
        self.text_col = text_col
        self.folds = folds
        self.seed = seed
        self.data_file = data_file
        self.save = save
        self.ta = tokenize_and_align_labels(tokenizer)

        # Load dataset
        self.load()
        # Process the data (and save if save is not None)
        self.process()

    def load(self):
        """Populate ``self.dataset`` from JSONL or from a saved dataset; returns ``self``."""
        if is_jsonl(self.data_file):
            self.dataset = datasets.Dataset.from_json(self.data_file)
        else:
            self.dataset = datasets.load_from_disk(self.data_file)
        return self

    def process_jsonl(self):
        """Convert raw annotations into ``tokens`` / ``ner_tags`` columns with ClassLabel tags."""
        keep = {self.text_col, 'tokens', 'ner_tags'}
        rm_cols = [col for col in self.dataset.column_names if col not in keep]
        # NOTE(review): label_table reads the 'text' and 'spans' columns by name,
        # so a text_col other than 'text' looks unsupported here - confirm.
        self.dataset = self.dataset.map(label_table, remove_columns=rm_cols)

        # Sort the tag names: iterating a bare set gives an order that can differ
        # between processes, which would silently change the label <-> id mapping.
        ner_names = sorted({tag for tags in self.dataset['ner_tags'] for tag in tags})
        features = Features(
            {self.text_col: Value(dtype='string', id=None),
             'tokens': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
             'ner_tags': Sequence(feature=ClassLabel(names=ner_names), length=-1, id=None)}
        )
        self.dataset = self.dataset.map(features.encode_example, features=features)
        if self.save:
            self.dataset.save_to_disk(self.save)

    def process_dataset(self):
        """Tokenize, chunk to the tokenizer's maximum length, and build the splits."""
        label_list = self.dataset.features['ner_tags'].feature.names
        self.label2id = {label: idx for idx, label in enumerate(label_list)}
        self.id2label = {idx: label for idx, label in enumerate(label_list)}

        tokenized_dataset = self.dataset.map(self.ta.tokenize_align, batched=True)

        # Each cell becomes a list of chunks; explode() then yields one row per
        # chunk, and reset_index() keeps the source document number in 'index'.
        split_dataset = datasets.Dataset.from_pandas(
            tokenized_dataset
            .remove_columns(['ner_tags', 'tokens', self.text_col])
            .to_pandas()
            .applymap(split_text, max_length=self.tokenizer.model_max_length)
            .explode(['attention_mask', 'input_ids', 'labels', 'token_type_ids'])
            .reset_index())

        if self.folds > 1:
            # Folds are formed over documents, not chunks, so that all chunks of
            # one document land on the same side of a split.
            indices = list(dict.fromkeys(split_dataset['index']))
            # Shuffle with the configured seed so the folds are reproducible.
            random.Random(self.seed).shuffle(indices)

            self.dset = dict()
            for i, held_out in enumerate(split(indices, n=self.folds)):
                test_ids = set(held_out)  # set membership instead of a list scan per row
                self.dset[f'fold{i}'] = datasets.DatasetDict({
                    'train': split_dataset.filter(lambda example: example['index'] not in test_ids),
                    'test': split_dataset.filter(lambda example: example['index'] in test_ids)
                })
        else:
            self.dset = split_dataset.train_test_split(test_size=0.2, seed=self.seed)

    def process(self):
        """Run the JSONL conversion when needed, then tokenization and splitting."""
        if is_jsonl(self.data_file):
            self.process_jsonl()
        self.process_dataset()

    def __repr__(self):
        return f'Data(data_file={self.data_file}, text_col={self.text_col}, tokenizer={self.tokenizer}, folds={self.folds}, seed={self.seed})'
115 |
--------------------------------------------------------------------------------
/ADR_NLP/data_helpers.py:
--------------------------------------------------------------------------------
1 | from spacy.training import offsets_to_biluo_tags
2 | import spacy
3 |
4 |
5 | # we will create a function to convert offset formatted labels to BILUO tags
def label_table(dataset):
    """Add plain token texts and BILUO NER tags to one annotated example.

    Reads the ``'tokens'``, ``'text'`` and ``'spans'`` entries of ``dataset``
    (a single example), replaces ``'tokens'`` with the list of token texts,
    stores the tags produced by ``offsets_to_biluo_tags`` under ``'ner_tags'``
    and returns the same mapping. A falsy ``'spans'`` value is treated as
    "no entities".
    """
    annotated_spans = dataset['spans'] or []
    entity_offsets = [
        (span['start'], span['end'], span['label']) for span in annotated_spans
    ]
    doc = spacy.blank("en")(dataset['text'])

    dataset['tokens'] = [word['text'] for word in dataset['tokens']]
    dataset['ner_tags'] = offsets_to_biluo_tags(doc, entity_offsets)
    return dataset
15 |
16 | # this class holds a function for tokenizing and aligning labels
class tokenize_and_align_labels():
    """Tokenize pre-split words and spread word-level labels over sub-word tokens.

    Args:
        tokenizer: Callable tokenizer whose output offers ``word_ids(batch_index=...)``.
        label_all_tokens: When true, every sub-word token of a word receives the
            word's label; otherwise only the first one does and the rest get -100.
    """

    def __init__(self, tokenizer, label_all_tokens=True):
        self.tokenizer = tokenizer
        self.label_all_tokens = label_all_tokens

    def _align(self, word_ids, word_labels):
        """Return one label per token for a single example."""
        aligned = []
        last_word = None
        for word in word_ids:
            if word is None:
                # Tokens without a word id (special tokens) are marked -100.
                aligned.append(-100)
            elif word == last_word and not self.label_all_tokens:
                # Continuation piece of a word that should not be labelled.
                aligned.append(-100)
            else:
                aligned.append(word_labels[word])
            last_word = word
        return aligned

    def tokenize_align(self, examples):
        """Tokenize a batch and attach the aligned ``"labels"`` column to the encoding."""
        encoded = self.tokenizer(examples["tokens"], truncation=False, is_split_into_words=True)
        encoded["labels"] = [
            self._align(encoded.word_ids(batch_index=row), word_labels)
            for row, word_labels in enumerate(examples["ner_tags"])
        ]
        return encoded
47 |
48 | # function for splitting text into appropriate sized chunks
def split_text(x, max_length=512):
    """Cut the sequence ``x`` into consecutive pieces of at most ``max_length`` items.

    A sequence that already fits is returned as the single-element list ``[x]``.
    """
    length = len(x)
    if length <= max_length:
        return [x]

    full_chunks, leftover = divmod(length, max_length)
    pieces = [x[k * max_length:(k + 1) * max_length] for k in range(full_chunks)]
    if leftover > 0:
        pieces.append(x[full_chunks * max_length:length])
    return pieces
61 |
def split(a, n):
    """Lazily yield ``n`` consecutive slices of ``a`` whose lengths differ by at most one.

    The first ``len(a) % n`` slices carry the extra element.
    """
    quotient, remainder = divmod(len(a), n)

    def _bounds(part):
        start = part * quotient + min(part, remainder)
        stop = (part + 1) * quotient + min(part + 1, remainder)
        return start, stop

    return (a[slice(*_bounds(part))] for part in range(n))
--------------------------------------------------------------------------------
/ADR_NLP/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from datasets import load_metric
3 | from scipy.special import softmax
4 | from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve
5 | import pandas as pd
6 | import re
7 | import wandb
8 |
class CompMetrics():
    """Token-level seqeval metrics for a fixed list of label names.

    ``label_list[i]`` is the tag name of class id ``i``.
    """

    def __init__(self, label_list):
        self.label_list = label_list
        self.metric = load_metric("seqeval")

    def compute_metrics(self, p):
        """Return overall precision, recall, f1 and accuracy for ``p = (predictions, labels)``.

        The predicted class is the argmax over axis 2 of ``predictions``;
        positions whose label is -100 are left out of both sequences.
        """
        scores_per_token, gold = p
        predicted = np.argmax(scores_per_token, axis=2)

        true_predictions, true_labels = [], []
        for predicted_row, gold_row in zip(predicted, gold):
            kept = [(guess, truth) for guess, truth in zip(predicted_row, gold_row) if truth != -100]
            true_predictions.append([self.label_list[guess] for guess, _ in kept])
            true_labels.append([self.label_list[truth] for _, truth in kept])

        results = self.metric.compute(predictions=true_predictions, references=true_labels)
        summary = {}
        summary["precision"] = results["overall_precision"]
        summary["recall"] = results["overall_recall"]
        summary["f1"] = results["overall_f1"]
        summary["accuracy"] = results["overall_accuracy"]
        return summary
34 |
35 |
def doc_level_metrics(trainer, dataset, label_list, metric_labels = ('ADR',), wandb_log = False):
    """Score the model at document level for each requested label.

    For every chunk, the score is the largest softmax probability assigned to
    any class whose name matches the label; chunks that share an ``'index'``
    value are then reduced to one document by taking the maximum.

    Args:
        trainer: Object whose ``predict(dataset)`` returns a 3-tuple with the
            raw scores first.
        dataset: Chunked dataset with ``'index'``, ``'labels'``,
            ``'attention_mask'``, ``'input_ids'`` and ``'token_type_ids'`` columns.
        label_list: Tag names, positioned by class id.
        metric_labels: Patterns passed to ``re.search`` against each tag name.
        wandb_log: When true, log the metrics and the ROC table to wandb.

    Returns:
        dict mapping each label to ``{"AUC": ..., "PR": ...}``.
    """
    probs, _, _ = trainer.predict(dataset)
    # The softmax does not depend on the label being scored: compute it once.
    token_probs = softmax(probs, axis=2)
    # NOTE(review): the max below runs over every sequence position, which
    # presumably includes padded / special-token positions - confirm that is intended.

    doc_metrics = dict()
    for label in metric_labels:
        # Class ids whose tag name matches the label (e.g. every BILUO variant of it).
        label_loc = [i for i, item in enumerate(label_list) if re.search(label, item)]
        probs_soft = token_probs[:, :, label_loc].max(axis=(1, 2))

        doc_level = dataset.map(
            lambda example:
                {'Truth': any(item in example['labels'] for item in label_loc)},
            remove_columns = ['attention_mask', 'input_ids', 'labels', 'token_type_ids']).to_pandas()
        doc_level['Prob'] = probs_soft

        # A document is positive if any chunk is; its score is the best chunk score.
        per_document = doc_level.groupby('index')
        truth = per_document['Truth'].max().values
        estimate = per_document['Prob'].max().values

        pr = average_precision_score(truth, estimate)
        auc = roc_auc_score(truth, estimate)
        fpr, tpr, threshold = roc_curve(truth, estimate)
        roc_table = pd.DataFrame({'fpr': fpr, 'tpr': tpr, 'threshold': threshold})
        doc_metrics[label] = {"AUC": auc, "PR": pr}
        if wandb_log:
            wandb.log({label + " AUC": auc, label + " PR": pr})
            wandb.log({"roc" : wandb.Table(dataframe=roc_table)})

    return doc_metrics
--------------------------------------------------------------------------------
/ADR_NLP/utils.py:
--------------------------------------------------------------------------------
def none_or_str(value):
    """Map the literal string ``'None'`` to ``None``; return anything else untouched."""
    return None if value == 'None' else value
--------------------------------------------------------------------------------
/ADR_NLP/visualize_ner.py:
--------------------------------------------------------------------------------
1 | import matplotlib.cm as cm
2 | import html
3 | from IPython.display import display, HTML
4 | import torch
5 | import numpy as np
6 | from transformers import pipeline
7 |
def value2rgba(x, cmap=cm.RdYlGn, alpha_mult=1.0):
    "Convert a value `x` from 0 to 1 (inclusive) to an RGBA tuple according to `cmap` times transparency `alpha_mult`."
    # The colormap output ends with the alpha channel; everything before it is colour.
    *channels, alpha = cmap(x)
    rgb_255 = (np.array(channels) * 255).astype(int)
    return tuple(rgb_255.tolist() + [alpha * alpha_mult])
14 |
15 |
def piece_prob_html(pieces, prob, sep=' ', **kwargs):
    """Build an HTML snippet in which each text piece is shaded by its score.

    Every piece is HTML-escaped and wrapped in a ``<span>`` whose background
    colour comes from ``value2rgba(score, alpha_mult=0.5, **kwargs)`` and whose
    tooltip shows the score to three decimals. The spans are joined with
    ``sep`` inside one monospace container.

    Args:
        pieces: Iterable of strings to display.
        prob: Iterable of scores, paired with ``pieces`` by position.
        sep: Separator placed between the rendered pieces.
        **kwargs: Forwarded to ``value2rgba`` (for example ``cmap``).

    Returns:
        The HTML markup as a single string.
    """
    spans = []
    for piece, score in zip(pieces, prob):
        # The colour was previously computed and then dropped, so the output
        # carried no highlighting at all; apply it as the span background.
        colour = str(value2rgba(score, alpha_mult=0.5, **kwargs))
        spans.append(
            f'<span title="{score:.3f}" style="background-color: rgba{colour};">'
            f'{html.escape(piece)}</span>')
    return f'<span style="font-family: monospace;">{sep.join(spans)}</span>'
25 |
def show_piece_attn(*args, **kwargs):
    """Display the markup returned by ``piece_prob_html(*args, **kwargs)`` via IPython."""
    from IPython.display import display, HTML
    markup = piece_prob_html(*args, **kwargs)
    display(HTML(markup))
29 |
def split_text(x, max_length):
    """Cut ``x`` into consecutive pieces of at most ``max_length`` items, as tensors.

    Every piece is returned as a tensor with a leading dimension of size one,
    i.e. shape ``(1, piece_length)``. Previously an input that already fit in
    ``max_length`` came back without that leading dimension, unlike the pieces
    of a longer input; both cases now have the same shape.

    Args:
        x: Sequence of numbers (for example token ids).
        max_length: Maximum number of items per piece.

    Returns:
        list of tensors, one per piece, in order.
    """
    length = len(x)
    if length <= max_length:
        return [torch.tensor([x])]
    return [
        torch.tensor([x[start:start + max_length]])
        for start in range(0, length, max_length)
    ]
42 |
def nothing_ent(i, word):
    """Build a placeholder 'O' entity record (score 0, empty span) for token ``word`` at position ``i``."""
    placeholder = dict(entity='O', score=0, index=i, word=word)
    placeholder['start'] = 0
    placeholder['end'] = 0
    return placeholder
52 |
def generate_highlighted_text(model, tokenizer, text):
    """Run token classification on ``text`` and return it as score-shaded HTML.

    Tokens for which the pipeline reports no entity are filled in with
    placeholder 'O' entities so the whole text is rendered. The entities are
    then grouped with the pipeline's ``group_entities`` and passed to
    ``piece_prob_html``.

    Args:
        model: Token-classification model handed to the pipeline.
        tokenizer: Tokenizer handed to the pipeline.
        text: Text to annotate.

    Returns:
        HTML markup string.
    """
    ner_model = pipeline(
        'token-classification',
        model=model,
        tokenizer=tokenizer,
        ignore_labels=None,
        # Use the first GPU when there is one; otherwise run on the CPU (-1)
        # instead of failing on a machine without CUDA.
        device=0 if torch.cuda.is_available() else -1)
    result = ner_model(text)
    tokens = ner_model.tokenizer.tokenize(text)

    # Pipeline 'index' values are shifted by one relative to tokenize() positions
    # (presumably because of the leading special token - confirm).
    # setdefault keeps the first result for a position, as list.index() did.
    by_position = {}
    for ent in result:
        by_position.setdefault(ent['index'] - 1, ent)

    entities = [
        by_position[i] if i in by_position else nothing_ent(i, word)
        for i, word in enumerate(tokens)
    ]
    entities = ner_model.group_entities(entities)
    spans = [e['word'] for e in entities]
    probs = [e['score'] for e in entities]
    return piece_prob_html(spans, probs, sep=' ')
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Medicines Optimisation Service - Austin Health
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Natural Language Processing for Adverse Drug Reaction (ADR) Detection
4 |
5 | This repo contains code from a project to identify ADRs in discharge summaries at Austin Health. The model uses the HuggingFace Transformers library, beginning with the pretrained DeBERTa model. Further MLM pre-training is performed on a large corpus of unannotated discharge summaries. Finally, fine-tuning is performed on a corpus of annotated discharge summaries (annotated using [Prodigy](https://prodi.gy)). The model performs NER, but final performance is measured at the document level using the maximum token-level score.
6 |
7 | We used [Weights and Biases](https://wandb.ai) for experiment tracking.
8 |
9 | The *pretrain* script takes a folder containing discharge summaries stored in CSV files, tokenizes them, and continues MLM training on [deberta-base](https://huggingface.co/microsoft/deberta-base).
10 |
11 | Fine-tuning can then be performed with the *finetune* script using CLI commands. This script assumes the data is either a JSONL file of annotated text exported from Prodigy (`--datafile example.jsonl`), or a saved HuggingFace Dataset. If you run this script once on a JSONL file of annotations, you can choose to save the Dataset into a folder (`--save_data_dir "save_to_here"`) and use this for subsequent training runs (`--datafile "save_to_here"`).
12 |
13 | Example usage:
14 | ```bash
15 | python .\finetune.py --folds 5 --epochs 15 --lr 5e-5 --wandb_on --hub_off --project 'CLI Tests' --run_name cross-validation --datafile 'data'
16 | ```
17 |
18 | ---
19 | **Note:** you might find that your exported annotations (JSONL file) are not encoded using UTF-8, which will prevent this code from working. There are various methods to change the encoding, and these can all be found with a quick Google search. On a Windows machine, for example, adapt the following PowerShell command:
20 | ```powershell
21 | Get-Content .\name_of_file.jsonl -Encoding Unicode | Set-Content -Encoding UTF8 .\name_of_new_file.jsonl
22 | ```
23 |
--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from ADR_NLP.utils import none_or_str
3 | from ADR_NLP.data import NERdataset
4 | from ADR_NLP.metrics import CompMetrics, doc_level_metrics
5 | from ADR_NLP.visualize_ner import generate_highlighted_text
6 | from transformers import (
7 | AutoTokenizer, PreTrainedTokenizerFast, DataCollatorForTokenClassification,
8 | AutoModelForTokenClassification, Trainer, TrainingArguments)
9 | from pathlib import Path
10 |
11 |
def _train_and_report(
    model_name, dataset, split, train_args, data_collator, tokenizer,
    compute_metrics, labels, text, wb, project
):
    """Fine-tune a fresh model on one train/test split, report its metrics, return the Trainer."""
    if wb:
        import wandb
        # wandb.finish() at the end of the previous fold closed the run opened
        # with wandb.init(project=...). Open a new one in the same project so
        # this fold's logs are not left without a run in `project`.
        if wandb.run is None:
            wandb.init(project=project, name=train_args.run_name)

    model = AutoModelForTokenClassification.from_pretrained(
        model_name,
        label2id = dataset.label2id,
        id2label = dataset.id2label,
        num_labels = len(dataset.label2id))

    trainer = Trainer(
        model,
        train_args,
        train_dataset=split["train"],
        eval_dataset=split["test"],
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )

    trainer.train()

    doc_metrics = doc_level_metrics(
        trainer,
        split["test"],
        label_list = labels,
        metric_labels = ['ADR'],
        wandb_log=wb)

    pred_html = generate_highlighted_text(model, tokenizer, text)

    if wb:
        wandb.log({"NER": wandb.Html(pred_html)})
        wandb.finish()

    print(doc_metrics)
    return trainer


def main(
    model_name, push_to_hub, datafile, epochs, batch_size, lr,
    weight_decay, seed, wb, run_name, project, text_file,
    save_data_dir, text_col, tokenizer_name, folds, hub_id
):
    """Fine-tune a token-classification model, with cross-validation when ``folds > 1``.

    Args:
        model_name: Model to start from (passed to ``from_pretrained``).
        push_to_hub: Push the trained model to the HF hub.
        datafile: JSONL annotations file or saved dataset directory.
        epochs, batch_size, lr, weight_decay, seed: Training hyper-parameters.
        wb: Log to Weights and Biases.
        run_name: Run name; also the prefix of the output directory.
        project: wandb project name.
        text_file: File with a sample text that is annotated after training.
        save_data_dir: Where to save the processed dataset (``None`` to skip).
        text_col: Name of the text column in ``datafile``.
        tokenizer_name: Tokenizer to load; must resolve to a fast tokenizer.
        folds: Number of cross-validation folds.
        hub_id: Model id to use when pushing to the hub.
    """
    if wb:
        import wandb
        wandb.init(project=project)
    text = Path(text_file).read_text()

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, add_prefix_space=True)
    assert isinstance(tokenizer, PreTrainedTokenizerFast)
    data_collator = DataCollatorForTokenClassification(tokenizer)

    dataset = NERdataset(
        datafile,
        text_col,
        tokenizer,
        folds,
        seed,
        save_data_dir)

    labels = list(dataset.label2id.keys())
    cm = CompMetrics(labels)

    train_args = TrainingArguments(
        f"{run_name}-finetuned",
        evaluation_strategy = "epoch",
        learning_rate=lr,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=epochs,
        weight_decay=weight_decay,
        push_to_hub=push_to_hub,
        seed=seed,
        hub_model_id=hub_id,
        hub_strategy="end",
    )

    if folds > 1:
        # One independent model per fold; each fold gets its own run name.
        for fold, data in dataset.dset.items():
            train_args.run_name = f'{run_name}-{fold}'
            trainer = _train_and_report(
                model_name, dataset, data, train_args, data_collator, tokenizer,
                cm.compute_metrics, labels, text, wb, project)
    else:
        train_args.push_to_hub = push_to_hub
        train_args.run_name = run_name
        trainer = _train_and_report(
            model_name, dataset, dataset.dset, train_args, data_collator, tokenizer,
            cm.compute_metrics, labels, text, wb, project)

    # With several folds, only the trainer of the last fold is pushed.
    if push_to_hub:
        trainer.push_to_hub()
133 |
134 |
if __name__ == "__main__":
    # Command-line entry point: every option maps one-to-one onto a parameter of main().
    parser = argparse.ArgumentParser(description='Train an NLP model')
    parser.add_argument('--model_name', type=str, default='austin/Austin-MeDeBERTa', help='Choose a model from the HF hub')

    # Paired on/off switches that write to the same destination.
    parser.add_argument('--hub_on', dest='push_to_hub', action='store_true', help='Push the model to the HF hub')
    parser.add_argument('--hub_off', dest='push_to_hub', action='store_false', help='Do not push the model to the HF hub')
    parser.set_defaults(push_to_hub=False)

    parser.add_argument('--datafile', type=str, default='annotations.jsonl', help='Path to the data file')
    parser.add_argument('--epochs', type=int, default=10, help='Number of epochs to train')
    parser.add_argument('--batch_size', type=int, default=3, help='Batch size')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.01, help='Weight decay')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')

    parser.add_argument('--wandb_on', dest='wb', action='store_true', help='Use wandb for logging')
    parser.add_argument('--wandb_off', dest='wb', action='store_false', help='No wandb logging')
    parser.set_defaults(wb=False)

    parser.add_argument('--run_name', type=str, default='adr_nlp', help='Name of the run for wandb')
    parser.add_argument('--project', type=str, default='adr_nlp', help='Name of the project for wandb')
    parser.add_argument('--text_file', type=str, default='testing.txt', help='If logging with wandb, a file containing a text to test NER on')

    parser.add_argument('--save_data_dir', type=none_or_str, nargs='?', default=None, help='Save processed dataset to this directory')
    parser.add_argument('--text_col', type=str, default='text', help='Name of column in datafile that contains the text')
    parser.add_argument(
        '--tokenizer_name',
        type=none_or_str,
        nargs='?',
        default=None,
        help='Name of the tokenizer if different to model. Must be a fast tokenizer.')
    parser.add_argument('--folds', default=5, type=int, help='Number of folds to split data into.')
    parser.add_argument('--hub_id', type=str, default=None, help='If pushing to hub, use this id')
    args = parser.parse_args()

    # Without an explicit tokenizer, use the one that belongs to the model.
    if args.tokenizer_name is None:
        args.tokenizer_name = args.model_name

    main(**vars(args))
--------------------------------------------------------------------------------
/pretrain.py:
--------------------------------------------------------------------------------
1 | import datasets
2 | from transformers import (AutoTokenizer, DataCollatorForLanguageModeling, TrainingArguments,
3 | Trainer, AutoModelForMaskedLM, PreTrainedTokenizerFast)
4 |
5 | from datasets import load_dataset, concatenate_datasets
6 | from transformers import AutoTokenizer, PreTrainedTokenizerFast
7 | import re
8 | import glob
9 |
# Run configuration: edit these values before launching the script.
text_path = "../pretraining_data" # Path to documents to use for pre-training (stored in CSV files)
base_model = "microsoft/deberta-base" # Base pre-trained model (used for both the tokenizer and the MLM model)
save_to_hub = "organization/model_name" # Model id on the HF hub that the trained model is pushed to
report = "wandb" # Passed to TrainingArguments(report_to=...): report to Weights and Biases
run_name = "run_name" # Name of the run for reporting; also used as the local output directory
other_data_sources = False # False for this project. Set to True if also using other texts (e.g. radiology reports, pathology reports etc.)
16 |
17 |
18 | # Function to remove discharge summary "bloat" (e.g. long copy-and-paste sections like medication list and pathology results)
def dcsumm_body(text):
    """Reduce a discharge summary to the section that follows 'PRINCIPAL DIAGNOSIS'.

    The kept text runs from just after ``PRINCIPAL DIAGNOSIS`` up to the first
    following ``DISCHARGE`` or ``PRESCRIBED MEDICATION`` marker. ``text['TEXT']``
    is set to the empty string when it is empty/None or when no such section
    is found.

    Args:
        text: Mapping with a ``'TEXT'`` entry; it is updated in place.

    Returns:
        The same mapping.
    """
    body = text['TEXT']
    section = None
    if body:
        # The end markers are grouped: an ungrouped `A(.*?)B|C` also matches a
        # bare C (leaving group 1 as None) and never lets C terminate the section.
        section = re.search(
            r'PRINCIPAL DIAGNOSIS(.*?)(?:DISCHARGE|PRESCRIBED MEDICATION)',
            body, flags=re.S)
    text['TEXT'] = section.group(1) if section else ''
    return text
29 |
30 |
def tokenize_function(examples, tokenizer):
    """Apply ``tokenizer`` to the ``"TEXT"`` entry of ``examples`` and return its output."""
    texts = examples["TEXT"]
    return tokenizer(texts)
33 |
34 |
def group_texts(examples, block_size=512):
    """Concatenate a batch of tokenized texts and re-cut it into fixed-size blocks.

    All lists under each key are joined end to end, the tail that does not fill
    a complete block is dropped, and the rest is cut into blocks of
    ``block_size`` items. ``"labels"`` is added as a copy of the ``"input_ids"``
    blocks.

    Args:
        examples: Mapping from column name to a list of per-text lists; must
            contain ``"input_ids"``.
        block_size: Length of every output block (defaults to the previous
            fixed value of 512).

    Returns:
        dict with the same keys plus ``"labels"``, each a list of blocks.
    """
    from itertools import chain

    # chain.from_iterable is linear; sum(list_of_lists, []) copies the growing
    # result on every addition.
    concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Keep only as many tokens as fill whole blocks.
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result
46 |
47 |
# ---- Data preparation ----
datafiles = glob.glob(f"{text_path}/*")
tokenizer = AutoTokenizer.from_pretrained(base_model)
assert isinstance(tokenizer, PreTrainedTokenizerFast)

# Files with 'summaries' in their path are discharge summaries; the rest are other sources.
dcsumm_datafiles = [datafile for datafile in datafiles if 'summaries' in datafile]
other_datafiles = [datafile for datafile in datafiles if 'summaries' not in datafile]

# load_dataset returns a dict of splits; take the 'train' split right away so
# that `dataset` is a single Dataset in both configurations. Left as a dict of
# splits, the train_test_split() call further down is not available.
dataset = load_dataset("csv", data_files=dcsumm_datafiles)['train'] \
    .map(dcsumm_body) \
    .filter(lambda example: (example['TEXT'] is not None) & (example['TEXT'] != '')) \
    .remove_columns(['EPISODE_ID', 'PATIENT_ID', 'DOC_ID', 'START_DTTM'])

if other_data_sources:
    other_dataset = load_dataset("csv", data_files=other_datafiles)['train'] \
        .filter(lambda example: (example['TEXT'] is not None) & (example['TEXT'] != ''))

    dataset = concatenate_datasets([dataset, other_dataset])

tokenised = dataset.map(tokenize_function, fn_kwargs={'tokenizer': tokenizer},
                        batched=True, batch_size=5000,
                        remove_columns=["TEXT"])

# Re-cut the tokenized texts into fixed-size blocks for masked-language modelling.
tokenised = tokenised.map(
    group_texts,
    batched=True,
    num_proc=10
)

tokenised.save_to_disk("tokenized-texts")

# ---- Training ----
# The tokenizer loaded above is reused; it was previously loaded a second time here.
model = AutoModelForMaskedLM.from_pretrained(base_model)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

train_test = datasets.load_from_disk('tokenized-texts')
train_test = train_test.train_test_split(test_size=0.05)

training_args = TrainingArguments(
    run_name,
    evaluation_strategy = "steps",
    eval_steps = 40_000,
    learning_rate=5e-5,
    weight_decay=0.01,
    per_device_train_batch_size=3,
    per_device_eval_batch_size=3,
    report_to=report,
    run_name=run_name,
    num_train_epochs=5,
    save_steps=40_000,
    hub_model_id=save_to_hub,
    push_to_hub=True
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_test["train"],
    eval_dataset=train_test["test"],
    data_collator=data_collator,
)

trainer.train()
trainer.push_to_hub()
--------------------------------------------------------------------------------
/testing.txt:
--------------------------------------------------------------------------------
1 | # Pancreatitis
2 | - Lipase: 535 -> 154 -> 145
3 | - Managed with NBM, IV fluids
4 | - CT AP and abdo USS: normal
5 | - Likely secondary to Azathioprine - ceased, never to be used again.
6 |
--------------------------------------------------------------------------------