├── .gitignore ├── ADR_NLP ├── __init__.py ├── data.py ├── data_helpers.py ├── metrics.py ├── utils.py └── visualize_ner.py ├── LICENSE ├── README.md ├── finetune.py ├── pretrain.py └── testing.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # jsonl data files 132 | *.jsonl 133 | data/ 134 | 135 | # Training runs & WANDB logs 136 | wandb/ 137 | cross-validation-finetuned/ 138 | adr-ner-finetuned/ -------------------------------------------------------------------------------- /ADR_NLP/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AustinMOS/adr-nlp/d7f1e8e53b274c2403f72afd43f2839dea8589fe/ADR_NLP/__init__.py -------------------------------------------------------------------------------- /ADR_NLP/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is used to create our final tokenized data set 3 | using the HuggingFace Datasets library. 4 | 5 | The helper functions from data_helpers.py are used to process annotated 6 | text, generated using the Prodigy NLP tool. 
7 | 8 | The final dataset will be chunked according to the maximum tokens allowed 9 | in our model (512) 10 | """ 11 | import datasets 12 | from datasets import Features, Sequence 13 | from datasets.features import Value, ClassLabel 14 | from ADR_NLP.data_helpers import ( 15 | split_text, tokenize_and_align_labels, 16 | label_table, split) 17 | import random 18 | 19 | # A helper function that returns True if 'jsonl' is found in a string, otherwise False 20 | def is_jsonl(string): 21 | if 'jsonl' in string: 22 | return True 23 | else: 24 | return False 25 | 26 | """ 27 | A class to hold the data and labels for the model. 28 | Can be initialized from a JSONL file and procesed, or from a preprocessed Dataset that has been saved. 29 | """ 30 | class NERdataset(): 31 | def __init__(self, data_file, text_col, tokenizer, folds, seed, save = None): 32 | self.tokenizer = tokenizer 33 | self.text_col = text_col 34 | self.folds = folds 35 | self.seed = seed 36 | self.data_file = data_file 37 | self.save = save 38 | self.ta = tokenize_and_align_labels(tokenizer) 39 | 40 | # Load dataset 41 | self.load() 42 | # Process the data (and save if save is not None) 43 | self.process() 44 | 45 | # Load using datasets.Dataset.from_json if the data is an unprocessed jsonl file, otherwise load from a preprocessed dataset using load_dataset 46 | def load(self): 47 | if is_jsonl(self.data_file): 48 | self.dataset = datasets.Dataset.from_json(self.data_file) 49 | else: 50 | self.dataset = datasets.load_from_disk(self.data_file) 51 | return self 52 | 53 | def process_jsonl(self): 54 | rm_cols = list(set(self.dataset.column_names) - set([self.text_col, 'tokens','ner_tags'])) 55 | self.dataset = self.dataset.map(label_table, remove_columns=rm_cols) 56 | 57 | ner_names = list(set([it for sl in self.dataset['ner_tags'] for it in sl])) 58 | features = Features( 59 | {self.text_col: Value(dtype='string', id=None), 60 | 'tokens': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 
61 | 'ner_tags': Sequence(feature=ClassLabel(names=ner_names), length=-1, id=None)} 62 | ) 63 | self.dataset = self.dataset.map(features.encode_example, features=features) 64 | if self.save: 65 | self.dataset.save_to_disk(self.save) 66 | 67 | def process_dataset(self): 68 | label_list = self.dataset.features['ner_tags'].feature.names 69 | ids = range(len(label_list)) 70 | self.label2id = dict(zip(label_list, ids)) 71 | self.id2label = dict(zip(ids, label_list)) 72 | 73 | tokenized_dataset = self.dataset.map(self.ta.tokenize_align, batched=True) 74 | 75 | split_dataset = datasets.Dataset.from_pandas( 76 | tokenized_dataset 77 | .remove_columns(['ner_tags', 'tokens', self.text_col]) 78 | .to_pandas() 79 | .applymap(split_text, max_length=self.tokenizer.model_max_length) 80 | .explode(['attention_mask', 'input_ids', 'labels', 'token_type_ids']) 81 | .reset_index()) 82 | 83 | if self.folds > 1: 84 | 85 | indices = list(dict.fromkeys(split_dataset['index'])) 86 | 87 | folds_list = list(split(indices, n=self.folds)) 88 | random.shuffle(indices) 89 | 90 | self.dset = dict() 91 | for i in range(self.folds): 92 | test = [indices[index] for index in folds_list[i]] 93 | train = list(set(indices) - set(test)) 94 | self.dset[f'fold{i}'] = datasets.DatasetDict({ 95 | 'train': split_dataset.filter(lambda example: example['index'] in train), 96 | 'test': split_dataset.filter(lambda example: example['index'] in test) 97 | }) 98 | 99 | else: 100 | self.dset = split_dataset.train_test_split(test_size=0.2, seed=self.seed) 101 | 102 | # Process the data 103 | def process(self): 104 | if is_jsonl(self.data_file): 105 | self.process_jsonl() 106 | self.process_dataset() 107 | else: 108 | self.process_dataset() 109 | 110 | 111 | 112 | 113 | def __repr__(self): 114 | return f'Data(data_file={self.data_file}, text_col={self.text_col}, tokenizer={self.tokenizer}, folds={self.folds}, seed={self.seed})' 115 | -------------------------------------------------------------------------------- 
/ADR_NLP/data_helpers.py: -------------------------------------------------------------------------------- 1 | from spacy.training import offsets_to_biluo_tags 2 | import spacy 3 | 4 | 5 | # we will create a function to convert offset formatted labels to BILUO tags 6 | def label_table(dataset): 7 | nlp = spacy.blank("en") 8 | tokens = [word['text'] for word in dataset['tokens']] 9 | dataset['tokens'] = tokens 10 | dataset['ner_tags'] = offsets_to_biluo_tags( 11 | nlp(dataset['text']), 12 | [(d['start'], d['end'], d['label']) for d in [d for d in (dataset['spans'] or [])]]) 13 | 14 | return dataset 15 | 16 | # this class holds a function for tokenizing and aligning labels 17 | class tokenize_and_align_labels(): 18 | def __init__(self, tokenizer, label_all_tokens=True): 19 | self.tokenizer = tokenizer 20 | self.label_all_tokens = label_all_tokens 21 | def tokenize_align(self, examples): 22 | tokenized_inputs = self.tokenizer(examples["tokens"], truncation=False, is_split_into_words=True) 23 | 24 | labels = [] 25 | for i, label in enumerate(examples["ner_tags"]): 26 | word_ids = tokenized_inputs.word_ids(batch_index=i) 27 | previous_word_idx = None 28 | label_ids = [] 29 | for word_idx in word_ids: 30 | # Special tokens have a word id that is None. We set the label to -100 so they are automatically 31 | # ignored in the loss function. 32 | if word_idx is None: 33 | label_ids.append(-100) 34 | # We set the label for the first token of each word. 35 | elif word_idx != previous_word_idx: 36 | label_ids.append(label[word_idx]) 37 | # For the other tokens in a word, we set the label to either the current label or -100, depending on 38 | # the label_all_tokens flag. 
39 | else: 40 | label_ids.append(label[word_idx] if self.label_all_tokens else -100) 41 | previous_word_idx = word_idx 42 | 43 | labels.append(label_ids) 44 | 45 | tokenized_inputs["labels"] = labels 46 | return tokenized_inputs 47 | 48 | # function for splitting text into appropriate sized chunks 49 | def split_text(x, max_length=512): 50 | length = len(x) 51 | if length > max_length: 52 | splits = length // max_length 53 | y = list() 54 | [y.append(x[i : i + max_length]) for i in range(0, splits*max_length, max_length)] 55 | if length % max_length > 0: 56 | y.append(x[splits*max_length : length]) 57 | else: 58 | y = list([x]) 59 | 60 | return y 61 | 62 | def split(a, n): 63 | k, m = divmod(len(a), n) 64 | return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n)) -------------------------------------------------------------------------------- /ADR_NLP/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from datasets import load_metric 3 | from scipy.special import softmax 4 | from sklearn.metrics import roc_auc_score, average_precision_score, roc_curve 5 | import pandas as pd 6 | import re 7 | import wandb 8 | 9 | class CompMetrics(): 10 | def __init__(self, label_list): 11 | self.label_list = label_list 12 | self.metric = load_metric("seqeval") 13 | def compute_metrics(self, p): 14 | predictions, labels = p 15 | predictions = np.argmax(predictions, axis=2) 16 | 17 | # Remove ignored index (special tokens) 18 | true_predictions = [ 19 | [self.label_list[p] for (p, l) in zip(prediction, label) if l != -100] 20 | for prediction, label in zip(predictions, labels) 21 | ] 22 | true_labels = [ 23 | [self.label_list[l] for (p, l) in zip(prediction, label) if l != -100] 24 | for prediction, label in zip(predictions, labels) 25 | ] 26 | 27 | results = self.metric.compute(predictions=true_predictions, references=true_labels) 28 | return { 29 | "precision": results["overall_precision"], 30 | "recall": 
results["overall_recall"], 31 | "f1": results["overall_f1"], 32 | "accuracy": results["overall_accuracy"], 33 | } 34 | 35 | 36 | def doc_level_metrics(trainer, dataset, label_list, metric_labels = ['ADR'], wandb_log = False): 37 | probs, _, _ = trainer.predict(dataset) 38 | doc_metrics = dict() 39 | for label in metric_labels: 40 | label_loc = [i for i, item in enumerate(label_list) if re.search(label, item)] 41 | probs_soft = softmax(probs, axis=2)[:,:,label_loc].max(axis=(1,2)) 42 | 43 | doc_level = dataset.map( 44 | lambda example: 45 | {'Truth': any(item in example['labels'] for item in label_loc)}, 46 | remove_columns = ['attention_mask', 'input_ids', 'labels', 'token_type_ids']).to_pandas() 47 | doc_level['Prob'] = probs_soft 48 | 49 | truth = doc_level.groupby('index')['Truth'].max().values 50 | estimate = doc_level.groupby('index')['Prob'].max().values 51 | 52 | 53 | pr = average_precision_score(truth, estimate) 54 | auc = roc_auc_score(truth, estimate) 55 | fpr, tpr, threshold = roc_curve(truth, estimate) 56 | roc_table = pd.DataFrame({'fpr': fpr, 'tpr': tpr, 'threshold': threshold}) 57 | doc_metrics[label] = {"AUC": auc, "PR": pr} 58 | if wandb_log: 59 | wandb.log({label + " AUC": auc, label + " PR": pr}) 60 | wandb.log({"roc" : wandb.Table(dataframe=roc_table)}) 61 | 62 | return doc_metrics -------------------------------------------------------------------------------- /ADR_NLP/utils.py: -------------------------------------------------------------------------------- 1 | def none_or_str(value): 2 | if value == 'None': 3 | return None 4 | return value -------------------------------------------------------------------------------- /ADR_NLP/visualize_ner.py: -------------------------------------------------------------------------------- 1 | import matplotlib.cm as cm 2 | import html 3 | from IPython.display import display, HTML 4 | import torch 5 | import numpy as np 6 | from transformers import pipeline 7 | 8 | def value2rgba(x, cmap=cm.RdYlGn, 
alpha_mult=1.0): 9 | "Convert a value `x` from 0 to 1 (inclusive) to an RGBA tuple according to `cmap` times transparency `alpha_mult`." 10 | c = cmap(x) 11 | rgb = (np.array(c[:-1]) * 255).astype(int) 12 | a = c[-1] * alpha_mult 13 | return tuple(rgb.tolist() + [a]) 14 | 15 | 16 | def piece_prob_html(pieces, prob, sep=' ', **kwargs): 17 | html_code,spans = [''], [] 18 | for p, a in zip(pieces, prob): 19 | p = html.escape(p) 20 | c = str(value2rgba(a, alpha_mult=0.5, **kwargs)) 21 | spans.append(f'{p}') 22 | html_code.append(sep.join(spans)) 23 | html_code.append('') 24 | return ''.join(html_code) 25 | 26 | def show_piece_attn(*args, **kwargs): 27 | from IPython.display import display, HTML 28 | display(HTML(piece_prob_html(*args, **kwargs))) 29 | 30 | def split_text(x, max_length): 31 | length = len(x) 32 | if length > max_length: 33 | splits = length // max_length 34 | y = list() 35 | [y.append(torch.tensor([x[i : i + max_length]])) for i in range(0, splits*max_length, max_length)] 36 | if length % max_length > 0: 37 | y.append(torch.tensor([x[splits*max_length : length]])) 38 | else: 39 | y = list(torch.tensor([x])) 40 | 41 | return y 42 | 43 | def nothing_ent(i, word): 44 | return { 45 | 'entity': 'O', 46 | 'score': 0, 47 | 'index': i, 48 | 'word': word, 49 | 'start': 0, 50 | 'end': 0 51 | } 52 | 53 | def generate_highlighted_text(model, tokenizer, text): 54 | ner_model = pipeline( 55 | 'token-classification', 56 | model=model, 57 | tokenizer=tokenizer, 58 | ignore_labels=None, 59 | device=0) 60 | result = ner_model(text) 61 | tokens = ner_model.tokenizer.tokenize(text) 62 | label_indeces = [i['index'] - 1 for i in result] 63 | 64 | entities = list() 65 | for i, word in enumerate(tokens): 66 | if i in label_indeces: 67 | entities.append(result[label_indeces.index(i)]) 68 | else: 69 | entities.append(nothing_ent(i, word)) 70 | entities = ner_model.group_entities(entities) 71 | spans = [e['word'] for e in entities] 72 | probs = [e['score'] for e in entities] 73 | 
return piece_prob_html(spans, probs, sep=' ') -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Medicines Optimisation Service - Austin Health 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Natural Language Processing for Adverse Drug Reaction (ADR) Detection 4 | 5 | This repo contains code from a project to identify ADRs in discharge summaries at Austin Health. The model uses the HuggingFace Transformers library, beginning with the pretrained DeBERTa model. Further MLM pre-training is performed on a large corpus of unannotated discharge summaries. 
Finally, fine-tuning is performed on a corpus of annotated discharge summaries (annotated using [Prodigy](https://prodi.gy)). The model performs NER, but final performance is measured at the document level using the maximum token-level score. 6 | 7 | We used [Weights and Biases](https://wandb.ai) for experiment tracking. 8 | 9 | The *pretrain* script takes a folder containing discharge summaries stored in CSV files, tokenizes them and continues MLM training on [deberta-base](https://huggingface.co/microsoft/deberta-base). 10 | 11 | Fine-tuning can then be performed with the *finetune* script using CLI commands. This script assumes the data is either a JSONL file of annotated text exported from Prodigy (`--datafile example.jsonl`), or a dataset previously saved with HuggingFace Datasets. If you run this script once on a JSONL file of annotations, you can choose to save the Dataset into a folder (`--save_data_dir "save_to_here"`) and use this for subsequent training runs (`--datafile "save_to_here"`). 12 | 13 | Example usage: 14 | ```bash 15 | python .\finetune.py --folds 5 --epochs 15 --lr 5e-5 --wandb_on --hub_off --project 'CLI Tests' --run_name cross-validation --datafile 'data' 16 | ``` 17 | 18 | --- 19 | **Note:** you might find that your exported annotations (JSONL file) are not encoded using UTF-8, which will prevent this code from working. There are various methods to change the encoding and these can all be found with a quick Google search.
On a windows machine, for example, modify the following in powershell: 20 | ```powershell 21 | Get-Content .\name_of_file.jsonl -Encoding Unicode | Set-Content -Encoding UTF8 .\name_of_new_file.jsonl 22 | ``` 23 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from ADR_NLP.utils import none_or_str 3 | from ADR_NLP.data import NERdataset 4 | from ADR_NLP.metrics import CompMetrics, doc_level_metrics 5 | from ADR_NLP.visualize_ner import generate_highlighted_text 6 | from transformers import ( 7 | AutoTokenizer, PreTrainedTokenizerFast, DataCollatorForTokenClassification, 8 | AutoModelForTokenClassification, Trainer, TrainingArguments) 9 | from pathlib import Path 10 | 11 | 12 | def main( 13 | model_name, push_to_hub, datafile, epochs, batch_size, lr, 14 | weight_decay, seed, wb, run_name, project, text_file, 15 | save_data_dir, text_col, tokenizer_name, folds, hub_id 16 | ): 17 | 18 | if wb: 19 | import wandb 20 | wandb.init(project=project) 21 | text = Path(text_file).read_text() 22 | 23 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, add_prefix_space=True) 24 | assert isinstance(tokenizer, PreTrainedTokenizerFast) 25 | data_collator = DataCollatorForTokenClassification(tokenizer) 26 | 27 | dataset = NERdataset( 28 | datafile, 29 | text_col, 30 | tokenizer, 31 | folds, 32 | seed, 33 | save_data_dir) 34 | 35 | labels = list(dataset.label2id.keys()) 36 | cm = CompMetrics(labels) 37 | 38 | 39 | train_args = TrainingArguments( 40 | f"{run_name}-finetuned", 41 | evaluation_strategy = "epoch", 42 | learning_rate=lr, 43 | per_device_train_batch_size=batch_size, 44 | per_device_eval_batch_size=batch_size, 45 | num_train_epochs=epochs, 46 | weight_decay=weight_decay, 47 | push_to_hub=push_to_hub, 48 | seed=seed, 49 | hub_model_id=hub_id, 50 | hub_strategy="end", 51 | ) 52 | 53 | if folds > 1: 54 | for fold, data 
in dataset.dset.items(): 55 | model = AutoModelForTokenClassification.from_pretrained( 56 | model_name, 57 | label2id = dataset.label2id, 58 | id2label = dataset.id2label, 59 | num_labels = len(dataset.label2id)) 60 | 61 | train_args.run_name = f'{run_name}-{fold}' 62 | 63 | trainer = Trainer( 64 | model, 65 | train_args, 66 | train_dataset=data["train"], 67 | eval_dataset=data["test"], 68 | data_collator=data_collator, 69 | tokenizer=tokenizer, 70 | compute_metrics=cm.compute_metrics 71 | ) 72 | 73 | trainer.train() 74 | 75 | doc_metrics = doc_level_metrics( 76 | trainer, 77 | data["test"], 78 | label_list = labels, 79 | metric_labels = ['ADR'], 80 | wandb_log=wb) 81 | 82 | pred_html = generate_highlighted_text(model, tokenizer, text) 83 | 84 | if wb: 85 | wandb.log({"NER": wandb.Html(pred_html)}) 86 | 87 | if wb: 88 | wandb.finish() 89 | 90 | print(doc_metrics) 91 | 92 | else: 93 | model = AutoModelForTokenClassification.from_pretrained( 94 | model_name, 95 | label2id = dataset.label2id, 96 | id2label = dataset.id2label, 97 | num_labels = len(dataset.label2id)) 98 | 99 | train_args.push_to_hub = push_to_hub 100 | train_args.run_name = run_name 101 | 102 | trainer = Trainer( 103 | model, 104 | train_args, 105 | train_dataset=dataset.dset["train"], 106 | eval_dataset=dataset.dset["test"], 107 | data_collator=data_collator, 108 | tokenizer=tokenizer, 109 | compute_metrics=cm.compute_metrics 110 | ) 111 | 112 | trainer.train() 113 | 114 | doc_metrics = doc_level_metrics( 115 | trainer, 116 | dataset.dset["test"], 117 | label_list = labels, 118 | metric_labels = ['ADR'], 119 | wandb_log=wb) 120 | 121 | pred_html = generate_highlighted_text(model, tokenizer, text) 122 | 123 | if wb: 124 | wandb.log({"NER": wandb.Html(pred_html)}) 125 | 126 | if wb: 127 | wandb.finish() 128 | 129 | print(doc_metrics) 130 | 131 | if push_to_hub: 132 | trainer.push_to_hub() 133 | 134 | 135 | if __name__ == "__main__": 136 | parser = argparse.ArgumentParser(description='Train an NLP 
model') 137 | parser.add_argument('--model_name', type=str, default='austin/Austin-MeDeBERTa', help='Choose a model from the HF hub') 138 | 139 | parser.add_argument('--hub_on', dest='push_to_hub', action='store_true', help='Push the model to the HF hub') 140 | parser.add_argument('--hub_off', dest='push_to_hub', action='store_false', help='Push the model to the HF hub') 141 | parser.set_defaults(push_to_hub=False) 142 | 143 | parser.add_argument('--datafile', type=str, default='annotations.jsonl', help='Path to the data file') 144 | parser.add_argument('--epochs', type=int, default=10, help='Number of epochs to train') 145 | parser.add_argument('--batch_size', type=int, default=3, help='Batch size') 146 | parser.add_argument('--lr', type=float, default=0.001, help='Learning rate') 147 | parser.add_argument('--weight_decay', type=float, default=0.01, help='Weight decay') 148 | parser.add_argument('--seed', type=int, default=42, help='Random seed') 149 | 150 | parser.add_argument('--wandb_on', dest='wb', action='store_true', help='Use wandb for logging') 151 | parser.add_argument('--wandb_off', dest='wb', action='store_false', help='No wandb logging') 152 | parser.set_defaults(wb=False) 153 | 154 | parser.add_argument('--run_name', type=str, default='adr_nlp', help='Name of the run for wandb') 155 | parser.add_argument('--project', type=str, default='adr_nlp', help='Name of the project for wandb') 156 | parser.add_argument('--text_file', type=str, default='testing.txt', help='If logging with wandb, a file containing a text to test NER on') 157 | 158 | parser.add_argument('--save_data_dir', type=none_or_str, nargs='?', default=None, help='Save processed dataset to this directory') 159 | parser.add_argument('--text_col', type=str, default='text', help='Name of column in datafile that contains the text') 160 | parser.add_argument( 161 | '--tokenizer_name', 162 | type=none_or_str, 163 | nargs='?', 164 | default=None, 165 | help='Name of the tokenizer if different to 
model. Must be a fast tokenizer.') 166 | parser.add_argument('--folds', default=5, type=int, help='Number of folds to split data into.') 167 | parser.add_argument('--hub_id', type=str, default=None, help='If pushing to hub, use this id') 168 | args = parser.parse_args() 169 | 170 | if args.tokenizer_name is None: 171 | args.tokenizer_name = args.model_name 172 | 173 | main(**vars(args)) -------------------------------------------------------------------------------- /pretrain.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | from transformers import (AutoTokenizer, DataCollatorForLanguageModeling, TrainingArguments, 3 | Trainer, AutoModelForMaskedLM, PreTrainedTokenizerFast) 4 | 5 | from datasets import load_dataset, concatenate_datasets 6 | from transformers import AutoTokenizer, PreTrainedTokenizerFast 7 | import re 8 | import glob 9 | 10 | text_path = "../pretraining_data" # Path to documents to use for pre-training (stored in CSV files) 11 | base_model = "microsoft/deberta-base" # Base pre-trained model 12 | save_to_hub = "organization/model_name" # Directory where the model will be saved in HF hub 13 | report = "wandb" # report to Weights and Biases 14 | run_name = "run_name" # Name of the run for reporting and local storage 15 | other_data_sources = False # False for this project. Set to True if also using other texts (e.g. radiology reports, pathology reports etc.) 16 | 17 | 18 | # Function to remove discharge summary "bloat" (e.g. 
long copy-and-paste sections like medication list and pathology results) 19 | def dcsumm_body(text): 20 | if text['TEXT']: 21 | reduced = re.search('PRINCIPAL DIAGNOSIS(.*?)DISCHARGE|PRESCRIBED MEDICATION', text['TEXT'], flags=re.S) 22 | if reduced: 23 | text['TEXT'] = reduced.group(1) 24 | else: 25 | text['TEXT'] = '' 26 | else: 27 | text['TEXT'] = '' 28 | return text 29 | 30 | 31 | def tokenize_function(examples, tokenizer): 32 | return tokenizer(examples["TEXT"]) 33 | 34 | 35 | def group_texts(examples): 36 | # Concatenate all texts. 37 | concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} 38 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 39 | total_length = (total_length // 512) * 512 40 | result = { 41 | k: [t[i: i + 512] for i in range(0, total_length, 512)] 42 | for k, t in concatenated_examples.items() 43 | } 44 | result["labels"] = result["input_ids"].copy() 45 | return result 46 | 47 | 48 | datafiles = glob.glob(f"{text_path}/*") 49 | tokenizer = AutoTokenizer.from_pretrained(base_model) 50 | assert isinstance(tokenizer, PreTrainedTokenizerFast) 51 | 52 | dcsumm_datafiles=[datafile for datafile in datafiles if 'summaries' in datafile] 53 | other_datafiles=[datafile for datafile in datafiles if 'summaries' not in datafile] 54 | 55 | dataset = load_dataset("csv", data_files=dcsumm_datafiles) \ 56 | .map(dcsumm_body) \ 57 | .filter(lambda example: (example['TEXT'] is not None) & (example['TEXT'] != '')) \ 58 | .remove_columns(['EPISODE_ID', 'PATIENT_ID', 'DOC_ID', 'START_DTTM']) 59 | 60 | if other_data_sources: 61 | other_dataset = load_dataset("csv", data_files=other_datafiles) \ 62 | .filter(lambda example: (example['TEXT'] is not None) & (example['TEXT'] != '')) 63 | 64 | dataset = concatenate_datasets([dataset['train'], other_dataset['train']]) 65 | 66 | tokenised = dataset.map(tokenize_function, fn_kwargs={'tokenizer': tokenizer}, 67 | batched=True, batch_size=5000, 68 | remove_columns=["TEXT"]) 69 | 70 | 
tokenised = tokenised.map( 71 | group_texts, 72 | batched=True, 73 | num_proc=10 74 | ) 75 | 76 | tokenised.save_to_disk("tokenized-texts") 77 | 78 | model = AutoModelForMaskedLM.from_pretrained(base_model) 79 | tokenizer = AutoTokenizer.from_pretrained(base_model) 80 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15) 81 | assert isinstance(tokenizer, PreTrainedTokenizerFast) 82 | 83 | train_test = datasets.load_from_disk('tokenized-texts') 84 | train_test = train_test.train_test_split(test_size=0.05) 85 | 86 | training_args = TrainingArguments( 87 | run_name, 88 | evaluation_strategy = "steps", 89 | eval_steps = 40_000, 90 | learning_rate=5e-5, 91 | weight_decay=0.01, 92 | per_device_train_batch_size=3, 93 | per_device_eval_batch_size=3, 94 | report_to=report, 95 | run_name=run_name, 96 | num_train_epochs=5, 97 | save_steps=40_000, 98 | hub_model_id=save_to_hub, 99 | push_to_hub=True 100 | ) 101 | 102 | trainer = Trainer( 103 | model=model, 104 | args=training_args, 105 | train_dataset=train_test["train"], 106 | eval_dataset=train_test["test"], 107 | data_collator=data_collator, 108 | ) 109 | 110 | trainer.train() 111 | trainer.push_to_hub() -------------------------------------------------------------------------------- /testing.txt: -------------------------------------------------------------------------------- 1 | # Pancreatitis 2 | - Lipase: 535 -> 154 -> 145 3 | - Managed with NBM, IV fluids 4 | - CT AP and abdo USS: normal 5 | - Likely secondary to Azathioprine - ceased, never to be used again. 6 | --------------------------------------------------------------------------------