├── .gitattributes ├── LICENSE ├── README.md ├── assets └── figure.png ├── data.py ├── eval_gen.py ├── extract_texts.py ├── gen_span_detection.py ├── main.py ├── model.py ├── options.py ├── preprocess_phenos.py └── requirements.txt /.gitattributes: -------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bz2 filter=lfs diff=lfs merge=lfs -text 5 | *.ckpt filter=lfs diff=lfs merge=lfs -text 6 | *.ftz filter=lfs diff=lfs merge=lfs -text 7 | *.gz filter=lfs diff=lfs merge=lfs -text 8 | *.h5 filter=lfs diff=lfs merge=lfs -text 9 | *.joblib filter=lfs diff=lfs merge=lfs -text 10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text 12 | *.model filter=lfs diff=lfs merge=lfs -text 13 | *.msgpack filter=lfs diff=lfs merge=lfs -text 14 | *.npy filter=lfs diff=lfs merge=lfs -text 15 | *.npz filter=lfs diff=lfs merge=lfs -text 16 | *.onnx filter=lfs diff=lfs merge=lfs -text 17 | *.ot filter=lfs diff=lfs merge=lfs -text 18 | *.parquet filter=lfs diff=lfs merge=lfs -text 19 | *.pb filter=lfs diff=lfs merge=lfs -text 20 | *.pickle filter=lfs diff=lfs merge=lfs -text 21 | *.pkl filter=lfs diff=lfs merge=lfs -text 22 | *.pt filter=lfs diff=lfs merge=lfs -text 23 | *.pth filter=lfs diff=lfs merge=lfs -text 24 | *.rar filter=lfs diff=lfs merge=lfs -text 25 | *.safetensors filter=lfs diff=lfs merge=lfs -text 26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 27 | *.tar.* filter=lfs diff=lfs merge=lfs -text 28 | *.tflite filter=lfs diff=lfs merge=lfs -text 29 | *.tgz filter=lfs diff=lfs merge=lfs -text 30 | *.wasm filter=lfs diff=lfs merge=lfs -text 31 | *.xz filter=lfs diff=lfs merge=lfs -text 32 | *.zip filter=lfs diff=lfs merge=lfs -text 33 | *.zst filter=lfs diff=lfs merge=lfs -text 34 | *tfevents* filter=lfs diff=lfs merge=lfs -text 35 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Computational Language Understanding Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MedDec: A Dataset for Extracting Medical Decisions from Discharge Summaries 2 | 3 | ![MedDec](assets/figure.png) 4 | 5 | This is the code and dataset described in **[MedDec (Elgaar et al., Findings of ACL: ACL 2024)](https://aclanthology.org/2024.findings-acl.975/)**. 6 | 7 | MedDec is the first dataset specifically developed for extracting and classifying medical decisions from clinical notes. 
It includes 451 expert-annotated discharge summaries from the MIMIC-III dataset, offering a valuable resource for understanding and facilitating clinical decision-making. 8 | 9 | # Dataset 10 | 11 | > [!TIP] 12 | > The dataset has been released as of October 16, 2024. 13 | 14 | The dataset is available through this link: **[https://physionet.org/content/meddec/1.0.0/](https://physionet.org/content/meddec/1.0.0/)**. The user must sign a data usage agreement before accessing the dataset. 15 | 16 | ### Phenotype Annotations 17 | 18 | The phenotype annotations used in the paper are available here: [https://physionet.org/content/phenotype-annotations-mimic/1.20.03/](https://physionet.org/content/phenotype-annotations-mimic/1.20.03/). 19 | 20 | # Prerequisites 21 | 22 | ### Requirements 23 | 24 | Install the required packages using the following command: 25 | ``` 26 | pip install -r requirements.txt 27 | ``` 28 | 29 | ### Extract Notes Text from MIMIC-III 30 | 31 | To extract the notes text from the MIMIC-III dataset, run the following command: 32 | ``` 33 | python extract_texts.py <data_dir> <notes_path> 34 | ``` 35 | - `data_dir`: Directory where the dataset is stored. 36 | - `notes_path`: Path to the `NOTEEVENTS.csv` file. The texts will be written to the `data_dir/raw_text` directory. 37 | 38 | ### Aggregate Phenotype Annotations 39 | 40 | To preprocess the phenotype annotations, run the following command: 41 | ``` 42 | python preprocess_phenos.py 43 | ``` 44 | - `phenotypes_path`: Path to the phenotype annotations. The aggregated annotations will be written to `phenos.csv` in the same directory as the input file. 45 | 46 | # Running the Baselines 47 | 48 | The code expects `data`, `raw_text`, `phenos.csv`, `stats.csv`, and a `splits` directory (containing `train.txt`, `val.txt`, and `test.txt`) to be in the `data_dir` directory.
49 | 50 | To train the baselines, run the following command: 51 | ``` 52 | python main.py --data_dir <data_dir> --label_encoding multiclass --model_name google/electra-base-discriminator --total_steps 5000 --lr 4e-5 53 | ``` 54 | 55 | To evaluate the baselines, run the following command: 56 | ``` 57 | python main.py --data_dir <data_dir> --eval_only --ckpt ./checkpoints/[datetime]-[model_name] 58 | ``` 59 | 60 | ## Arguments 61 | 62 | - `data_dir`: The directory where the dataset is stored. The default is `./data/`. 63 | - `pheno_path`: The path to the phenotype annotations. The default is `./ACTdb102003.csv`. 64 | - `task`: `token` is the token classification task (decision extraction), and `seq` is the sequence classification task (phenotype prediction). The default is `token`. 65 | - `eval_only`: Whether to evaluate the model only. `--ckpt` should be provided. The default is `False`. 66 | - `label_encoding`: `multiclass`, `bo` (beginning inside outside), or `boe` (beginning inside end). The default is `multiclass`. 67 | - `truncate_train`: Truncate the training sequences to a maximum length. Otherwise, the sequences are randomly chunked at training time. The default is `False`. 68 | - `truncate_eval`: Truncate the evaluation sequences to a maximum length. The default is `False`. 69 | - `use_crf`: Whether to use a CRF layer. The default is `False`.
70 | - `model_name`: The name of the model from Hugging Face Transformers 71 | - `total_steps`: The number of training steps 72 | - `lr`: The learning rate 73 | - `batch_size`: The batch size 74 | - `seed`: The random seed 75 | 76 | 77 | 78 | # Citation 79 | 80 | If you use this dataset or code, please consider citing the following paper: 81 | ```bibtex 82 | @inproceedings{elgaar-etal-2024-meddec, 83 | title = "{M}ed{D}ec: A Dataset for Extracting Medical Decisions from Discharge Summaries", 84 | author = "Elgaar, Mohamed and Cheng, Jiali and Vakil, Nidhi and Amiri, Hadi and Celi, Leo Anthony", 85 | editor = "Ku, Lun-Wei and Martins, Andre and Srikumar, Vivek", 86 | booktitle = "Findings of the Association for Computational Linguistics ACL 2024", 87 | month = aug, 88 | year = "2024", 89 | address = "Bangkok, Thailand and virtual meeting", 90 | publisher = "Association for Computational Linguistics", 91 | url = "https://aclanthology.org/2024.findings-acl.975", 92 | pages = "16442--16455", 93 | } 94 | ``` 95 | 96 | Additionally, please cite the dataset as follows: 97 | ```bibtex 98 | @misc{elgaar2024meddec, 99 | title = "MedDec: Medical Decisions for Discharge Summaries in the MIMIC-III Database", 100 | author = "Elgaar, Mohamed and Cheng, Jiali and Vakil, Nidhi and Amiri, Hadi and Celi, Leo Anthony", 101 | year = "2024", 102 | version = "1.0.0", 103 | publisher = "PhysioNet", 104 | url = "https://doi.org/10.13026/nqnw-7d62" 105 | } 106 | ``` 107 | -------------------------------------------------------------------------------- /assets/figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CLU-UML/MedDec/40348abb40121e71b07ee4cdd1eb72b195e7c852/assets/figure.png -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import os 4 | import pandas 
as pd 5 | import numpy as np 6 | from torch.utils.data import Dataset, DataLoader 7 | from transformers import AutoTokenizer 8 | from glob import glob 9 | from collections.abc import Iterable 10 | from collections import defaultdict 11 | 12 | 13 | pheno_map = {'alcohol.abuse': 0, 14 | 'advanced.lung.disease': 1, 15 | 'advanced.heart.disease': 2, 16 | 'chronic.pain.fibromyalgia': 3, 17 | 'other.substance.abuse': 4, 18 | 'psychiatric.disorders': 5, 19 | 'obesity': 6, 20 | 'depression': 7, 21 | 'advanced.cancer': 8, 22 | 'chronic.neurological.dystrophies': 9, 23 | 'none': -1 24 | } 25 | rev_pheno_map = {v: k for k,v in pheno_map.items()} 26 | valid_cats = range(0,9) 27 | 28 | def gen_splits(args, phenos): 29 | if args.unseen_pheno is None: 30 | splits_dir = os.path.join(args.data_dir, 'splits') 31 | train_files = open(os.path.join(splits_dir, 'train.txt')).read().splitlines() 32 | val_files = open(os.path.join(splits_dir, 'val.txt')).read().splitlines() 33 | test_files = open(os.path.join(splits_dir, 'test.txt')).read().splitlines() 34 | return train_files, val_files, test_files 35 | 36 | 37 | np.random.seed(0) 38 | if args.task == 'token': 39 | files = glob(os.path.join(args.data_dir, 'data/**/*')) 40 | files = ["/".join(x.split('/')[-2:]) for x in files] 41 | subjects = np.unique([os.path.basename(x).split('_')[0] for x in files]) 42 | elif phenos is not None: 43 | subjects = phenos['subject_id'].unique() 44 | else: 45 | raise ValueError 46 | 47 | phenos['phenotype_label'] = phenos['phenotype_label'].apply(lambda x: x.lower()) 48 | 49 | n = len(subjects) 50 | train_count = int(0.8*n) 51 | val_count = max(0, int(0.9*n) - train_count) 52 | test_count = n - train_count - val_count 53 | 54 | train, val, test = [], [], [] 55 | np.random.shuffle(subjects) 56 | subjects = list(subjects) 57 | pheno_list = set(np.unique(list(pheno_map.keys())).tolist()) 58 | # pheno_list = set(pheno_map.keys()) 59 | if args.unseen_pheno is not None: 60 | test_phenos = 
{rev_pheno_map[args.unseen_pheno]} 61 | unseen_pheno = rev_pheno_map[args.unseen_pheno] 62 | train_phenos = pheno_list - test_phenos 63 | else: 64 | test_phenos = pheno_list 65 | train_phenos = pheno_list 66 | unseen_pheno = 'null' 67 | while len(subjects) > 0: 68 | if len(pheno_list) > 0: 69 | for pheno in pheno_list: 70 | if len(train) < train_count and pheno in train_phenos: 71 | el = None 72 | for i, subj in enumerate(subjects): 73 | row = phenos[phenos.subject_id == subj] 74 | if row['phenotype_label'].apply(lambda x: pheno in x and not unseen_pheno in x).any(): 75 | el = subjects.pop(i) 76 | break 77 | if el is not None: 78 | train.append(el) 79 | elif el is None: 80 | pheno_list.remove(pheno) 81 | break 82 | if len(val) < val_count and (not args.pheno_id or len(val) <= (0.5*val_count)): 83 | el = None 84 | for i, subj in enumerate(subjects): 85 | row = phenos[phenos.subject_id == subj] 86 | if row['phenotype_label'].apply(lambda x: pheno in x).any(): 87 | el = subjects.pop(i) 88 | break 89 | if el is not None: 90 | val.append(el) 91 | elif el is None: 92 | pheno_list.remove(pheno) 93 | break 94 | if len(test) < test_count or (args.unseen_pheno is not None and pheno in test_phenos): 95 | el = None 96 | for i, subj in enumerate(subjects): 97 | row = phenos[phenos.subject_id == subj] 98 | if row['phenotype_label'].apply(lambda x: pheno in x).any(): 99 | el = subjects.pop(i) 100 | break 101 | if el is not None: 102 | test.append(el) 103 | elif el is None: 104 | pheno_list.remove(pheno) 105 | break 106 | else: 107 | if len(train) < train_count: 108 | el = subjects.pop() 109 | if el is not None: 110 | train.append(el) 111 | if len(val) < val_count: 112 | el = subjects.pop() 113 | if el is not None: 114 | val.append(el) 115 | if len(test) < test_count: 116 | el = subjects.pop() 117 | if el is not None: 118 | test.append(el) 119 | 120 | if args.task == 'token': 121 | train = [x for x in files if os.path.basename(x).split('_')[0] in train] 122 | val = [x for x in 
files if os.path.basename(x).split('_')[0] in val] 123 | test = [x for x in files if os.path.basename(x).split('_')[0] in test] 124 | elif phenos is not None: 125 | train = phenos[phenos.subject_id.isin(train)] 126 | val = phenos[phenos.subject_id.isin(val)] 127 | test = phenos[phenos.subject_id.isin(test)] 128 | return train, val, test 129 | 130 | class MyDataset(Dataset): 131 | def __init__(self, args, tokenizer, data_source, phenos, train = False): 132 | super().__init__() 133 | self.tokenizer = tokenizer 134 | self.data = [] 135 | self.train = train 136 | self.pheno_ids = defaultdict(list) 137 | self.dec_ids = {k: [] for k in pheno_map.keys()} 138 | self.meddec_stats = pd.read_csv(os.path.join(args.data_dir, 'stats.csv')).set_index(['SUBJECT_ID', 'HADM_ID', 'ROW_ID']) 139 | self.stats = defaultdict(list) 140 | 141 | if args.task == 'seq': 142 | for i, row in data_source.iterrows(): 143 | sample = self.load_phenos(args, row, i) 144 | self.data.append(sample) 145 | else: 146 | for i, fn in enumerate(data_source): 147 | sample = self.load_decisions(args, fn, i, phenos) 148 | self.data.append(sample) 149 | 150 | def get_col(self, col): 151 | return np.array([x[col] for x in self.data]) 152 | 153 | def load_phenos(self, args, row, idx): 154 | txt_path = os.path.join(args.data_dir, f'raw_text/{row["subject_id"]}_{row["hadm_id"]}_{row["row_id"]}.txt') 155 | text = open(txt_path).read() 156 | encoding = self.tokenizer.encode_plus(text, 157 | truncation=args.truncate_train if self.train else args.truncate_eval) 158 | ids = None 159 | 160 | labels = np.zeros(args.num_phenos) 161 | 162 | sample_phenos = row['phenotype_label'] 163 | if sample_phenos != 'none': 164 | for pheno in sample_phenos.split(','): 165 | labels[pheno_map[pheno.lower()]] = 1 166 | 167 | if args.pheno_id is not None: 168 | if args.pheno_id == -1: 169 | labels = [0.0 if any(labels) else 1.0] 170 | else: 171 | labels = [labels[args.pheno_id]] 172 | 173 | return encoding['input_ids'], labels, ids 174 | 
175 | def load_decisions(self, args, fn, idx, phenos): 176 | basename = os.path.splitext(os.path.basename(fn))[0] 177 | file_dir = os.path.join(args.data_dir, 'data', fn) 178 | 179 | sid, hadm, rid = map(int, basename.split('_')[:3]) 180 | txt_path = os.path.join(args.data_dir, f'raw_text/{basename}.txt') 181 | text = open(txt_path).read() 182 | encoding = self.tokenizer.encode_plus(text, 183 | max_length=args.max_len, 184 | truncation=args.truncate_train if self.train else args.truncate_eval, 185 | padding = 'max_length', 186 | ) 187 | if (sid, hadm, rid) in phenos.index: 188 | sample_phenos = phenos.loc[sid, hadm, rid]['phenotype_label'] 189 | for pheno in sample_phenos.split(','): 190 | self.pheno_ids[pheno].append(idx) 191 | 192 | 193 | with open(file_dir) as f: 194 | data = json.load(f, strict=False) 195 | annots = data['annotations'] 196 | 197 | if args.label_encoding == 'multiclass': 198 | labels = np.full(len(encoding['input_ids']), args.num_labels-1, dtype=int) 199 | else: 200 | labels = np.zeros((len(encoding['input_ids']), args.num_labels)) 201 | if not self.train: 202 | token_mask = np.ones(len(encoding['input_ids'])) 203 | all_spans = [] 204 | for annot in annots: 205 | start = int(annot['start_offset']) 206 | 207 | enc_start = encoding.char_to_token(start) 208 | i = 1 209 | while enc_start is None and i < 10: 210 | enc_start = encoding.char_to_token(start+i) 211 | i += 1 212 | if i == 10: 213 | break 214 | 215 | end = int(annot['end_offset']) 216 | enc_end = encoding.char_to_token(end) 217 | j = 1 218 | while enc_end is None and j < 10: 219 | enc_end = encoding.char_to_token(end+j) 220 | j += 1 221 | if j == 10: 222 | enc_end = len(encoding.input_ids) 223 | 224 | if enc_end == enc_start: 225 | enc_end += 1 226 | 227 | if enc_start is None or enc_end is None: 228 | raise ValueError 229 | 230 | cat = parse_cat(annot['category']) 231 | if cat: 232 | cat -= 1 233 | if cat is None or cat not in valid_cats: 234 | if annot['category'] == 'TBD' and not 
self.train: 235 | token_mask[enc_start:enc_end] = 0 236 | continue 237 | 238 | if args.label_encoding == 'multiclass': 239 | cat1 = cat * 2 240 | cat2 = cat * 2 + 1 241 | if not any([x in [2*y for y in range(args.num_labels//2)] for x in labels[enc_start:enc_end]]): 242 | labels[enc_start] = cat1 243 | if enc_end > enc_start + 1: 244 | labels[enc_start+1:enc_end] = cat2 245 | if not self.train: 246 | all_spans.append({'token_start': enc_start, 'token_end': enc_end-1, 'label': cat, 'text_start': start, 'text_end': end}) 247 | elif args.label_encoding == 'bo': 248 | cat1 = cat * 2 249 | cat2 = cat * 2 + 1 250 | labels[enc_start, cat1] = 1 251 | labels[enc_start+1:enc_end, cat2] = 1 252 | elif args.label_encoding == 'boe': 253 | cat1 = cat * 3 254 | cat2 = cat * 3 + 1 255 | cat3 = cat * 3 + 2 256 | labels[enc_start, cat1] = 1 257 | labels[enc_start+1:enc_end-1, cat2] = 1 258 | labels[enc_end-1, cat3] = 1 259 | else: 260 | labels[enc_start:enc_end, cat] = 1 261 | 262 | row = self.meddec_stats.loc[sid, hadm, rid] 263 | 264 | self.stats['gender'].append(row.GENDER) 265 | self.stats['ethnicity'].append(row.ETHNICITY) 266 | self.stats['language'].append(row.LANGUAGE) 267 | 268 | results = { 269 | 'input_ids': encoding['input_ids'], 270 | 'labels': labels, 271 | 't2c': encoding.token_to_chars, 272 | } 273 | if not self.train: 274 | results['all_spans'] = all_spans, 275 | results['file_name'] = fn 276 | results['token_mask'] = token_mask 277 | return results 278 | 279 | def __getitem__(self, idx): 280 | return self.data[idx] 281 | 282 | def __len__(self): 283 | return len(self.data) 284 | 285 | def parse_cat(cat): 286 | for i,c in enumerate(cat): 287 | if c.isnumeric(): 288 | if cat[i+1].isnumeric(): 289 | return int(cat[i:i+2]) 290 | return int(c) 291 | return None 292 | 293 | 294 | def load_phenos(args): 295 | phenos = pd.read_csv(os.path.join(args.data_dir, 'phenos.csv')) 296 | phenos.rename({'Ham_ID': 'HADM_ID'}, inplace=True, axis=1) 297 | phenos = 
phenos[phenos.phenotype_label != '?'] 298 | phenos.rename(lambda k: k.lower(), inplace=True, axis = 1) 299 | return phenos 300 | 301 | def downsample(dataset): 302 | data = dataset.data 303 | class0 = [x for x in data if x[1][0] == 0] 304 | class1 = [x for x in data if x[1][0] == 1] 305 | 306 | if len(class0) > len(class1): 307 | class0 = resample(class0, replace=False, n_samples=len(class1), random_state=0) 308 | else: 309 | class1 = resample(class1, replace=False, n_samples=len(class0), random_state=0) 310 | dataset.data = class0 + class1 311 | 312 | def upsample(dataset): 313 | data = dataset.data 314 | class0 = [x for x in data if x[1][0] == 0] 315 | class1 = [x for x in data if x[1][0] == 1] 316 | 317 | if len(class0) > len(class1): 318 | class1 = resample(class1, replace=True, n_samples=len(class0), random_state=0) 319 | else: 320 | class0 = resample(class0, replace=True, n_samples=len(class1), random_state=0) 321 | dataset.data = class0 + class1 322 | 323 | def load_tokenizer(name): 324 | return AutoTokenizer.from_pretrained(name) 325 | 326 | def load_data(args): 327 | from sklearn.utils import resample 328 | def collate_segment(batch): 329 | xs = [] 330 | ys = [] 331 | t2cs = [] 332 | has_ids = 'ids' in batch[0] 333 | if has_ids: 334 | idss = [] 335 | else: 336 | ids = None 337 | masks = [] 338 | for i in range(len(batch)): 339 | x = batch[i]['input_ids'] 340 | y = batch[i]['labels'] 341 | if has_ids: 342 | ids = batch[i]['ids'] 343 | n = len(x) 344 | if n > args.max_len: 345 | start = np.random.randint(0, n - args.max_len + 1) 346 | x = x[start:start + args.max_len] 347 | if args.task == 'token': 348 | y = y[start:start + args.max_len] 349 | if has_ids: 350 | new_ids = [] 351 | ids = [x[start:start + args.max_len] for x in ids] 352 | for subids in ids: 353 | subids = [idx for idx, x in enumerate(subids) if x] 354 | new_ids.append(subids) 355 | all_ids = set([y for x in new_ids for y in x]) 356 | nones = set(range(args.max_len)) - all_ids 357 | 
new_ids.append(list(nones)) 358 | mask = [1] * args.max_len 359 | elif n < args.max_len: 360 | x = np.pad(x, (0, args.max_len - n)) 361 | if args.task == 'token': 362 | y = np.pad(y, ((0, args.max_len - n), (0, 0))) 363 | mask = [1] * n + [0] * (args.max_len - n) 364 | else: 365 | mask = [1] * n 366 | xs.append(x) 367 | ys.append(y) 368 | t2cs.append(batch[i]['t2c']) 369 | if has_ids: 370 | idss.append(new_ids) 371 | masks.append(mask) 372 | 373 | xs = torch.tensor(xs) 374 | ys = torch.tensor(ys) 375 | masks = torch.tensor(masks) 376 | return {'input_ids': xs, 'labels': ys, 'ids': ids, 'mask': masks, 't2c': t2cs} 377 | 378 | def collate_full(batch): 379 | lens = [len(x['input_ids']) for x in batch] 380 | max_len = max(args.max_len, max(lens)) 381 | for i in range(len(batch)): 382 | batch[i]['input_ids'] = np.pad(batch[i]['input_ids'], (0, max_len - lens[i])) 383 | if args.task == 'token': 384 | if args.label_encoding == 'multiclass': 385 | batch[i]['labels'] = np.pad(batch[i]['labels'], (0, max_len - lens[i]), constant_values=-100) 386 | else: 387 | batch[i]['labels'] = np.pad(batch[i]['labels'], ((0, max_len - lens[i]), (0, 0))) 388 | mask = [1] * lens[i] + [0] * (max_len - lens[i]) 389 | batch[i]['mask'] = mask 390 | 391 | new_batch = {} 392 | for k in batch[0].keys(): 393 | collated = [sample[k] for sample in batch] 394 | if k in ['all_spans', 'file_name']: 395 | new_batch[k] = collated 396 | elif isinstance(batch[0][k], Iterable): 397 | new_batch[k] = torch.tensor(np.array(collated)) 398 | else: 399 | new_batch[k] = collated 400 | return new_batch 401 | 402 | tokenizer = load_tokenizer(args.model_name) 403 | args.vocab_size = tokenizer.vocab_size 404 | args.max_length = min(tokenizer.model_max_length, 512) 405 | 406 | phenos = load_phenos(args) 407 | train_files, val_files, test_files = gen_splits(args, phenos) 408 | phenos.set_index(['subject_id', 'hadm_id', 'row_id'], inplace=True) 409 | 410 | train_dataset = MyDataset(args, tokenizer, train_files, phenos, 
train=True) 411 | val_dataset = MyDataset(args, tokenizer, val_files, phenos) 412 | test_dataset = MyDataset(args, tokenizer, test_files, phenos) 413 | 414 | if args.resample == 'down': 415 | downsample(train_dataset) 416 | elif args.resample == 'up': 417 | upsample(train_dataset) 418 | 419 | print('Train dataset:', len(train_dataset)) 420 | print('Val dataset:', len(val_dataset)) 421 | print('Test dataset:', len(test_dataset)) 422 | 423 | train_ns = DataLoader(train_dataset, 1, False, 424 | collate_fn=collate_full, 425 | ) 426 | train_dataloader = DataLoader(train_dataset, args.batch_size, True, 427 | collate_fn=collate_segment, 428 | ) 429 | val_dataloader = DataLoader(val_dataset, 1, False, collate_fn=collate_full) 430 | test_dataloader = DataLoader(test_dataset, 1, False, collate_fn=collate_full) 431 | 432 | return train_dataloader, val_dataloader, test_dataloader, train_ns 433 | -------------------------------------------------------------------------------- /eval_gen.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from os import path 4 | 5 | # Non-exact span-match (plus-minus 50 chars) 6 | def get_tp(ys, preds, method='em'): 7 | c = 0 8 | for pred in all_preds: 9 | pred_sample, pred_cat, pred_dec = pred 10 | for sample, cat, dec in all_labels: 11 | if method == 'em': 12 | if dec == pred_dec \ 13 | and pred_cat == cat \ 14 | and pred_sample == sample: 15 | c+= 1 16 | elif method == 'approx-m': 17 | if (dec in pred_dec or pred_dec in dec) \ 18 | and pred_cat == cat \ 19 | and pred_sample == sample \ 20 | and abs(len(pred_dec.split()) - len(dec.split())) <= 10: 21 | c+= 1 22 | return c 23 | 24 | 25 | def f1_score(ys, preds, method='em'): 26 | # tp = len(preds & ys) 27 | tp = get_tp(ys, preds, method) 28 | fn = len(ys) - tp 29 | fp = len(preds) - tp 30 | f1 = (2 * tp / (2 * tp + fp + fn)) * 100 if tp + fp + fn > 0 else 0 31 | return f1 32 | 33 | def process_labels(labels, sample): 34 | output 
= set() 35 | for cat, decs in labels.items(): 36 | for dec in decs: 37 | output.add((sample, int(cat), dec.strip())) 38 | return output 39 | 40 | def process_preds(preds, sample, cat): 41 | output = set() 42 | for pred in preds: 43 | output.add((sample, cat, pred.strip())) 44 | return output 45 | 46 | all_labels = set() 47 | all_preds = set() 48 | for sample in os.listdir('gens/one'): 49 | labels = process_labels(json.load(open(path.join('gens/one', sample, 'labels.json'))), int(sample)) 50 | all_labels |= labels 51 | for cat in range(1, 10): 52 | preds = [x.strip() for x in open(path.join('gens/one', sample, 'cat_%d'%cat))] 53 | preds = process_preds(preds, int(sample), cat) 54 | all_preds |= preds 55 | 56 | # method = 'approx-m' 57 | method = 'em' 58 | print(f1_score(all_labels, all_preds, method)) 59 | 60 | all_labels = set() 61 | all_preds = set() 62 | for sample in os.listdir('gens/zero'): 63 | labels = process_labels(json.load(open(path.join('gens/zero', sample, 'labels.json'))), int(sample)) 64 | all_labels |= labels 65 | for cat in range(1, 10): 66 | preds = [x.strip() for x in open(path.join('gens/zero', sample, 'cat_%d'%cat))] 67 | preds = process_preds(preds, int(sample), cat) 68 | all_preds |= preds 69 | 70 | print(f1_score(all_labels, all_preds, method)) 71 | -------------------------------------------------------------------------------- /extract_texts.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import pandas as pd 3 | import os, sys 4 | 5 | if len(sys.argv) != 3: 6 | print('Usage: python extract_texts.py ') 7 | sys.exit(1) 8 | 9 | data_dir = sys.argv[1] 10 | notes_path = sys.argv[2] 11 | text_dir = os.path.join(data_dir, 'raw_text') 12 | 13 | files = glob(os.path.join(data_dir, 'data/*.json')) 14 | 15 | notes = pd.read_csv(notes_path).set_index(['SUBJECT_ID', 'HADM_ID', 'ROW_ID']) 16 | 17 | os.makedirs(text_dir, exist_ok=True) 18 | for fn in files: 19 | sid, hadm, rid = map(int, 
os.path.splitext(os.path.basename(fn))[0].split('_')) 20 | note = notes.loc[sid, hadm, rid] 21 | out_fn = f'{sid}_{hadm}_{rid}.txt' 22 | with open(os.path.join(text_dir, out_fn), 'w') as f: 23 | f.write(note.TEXT) 24 | -------------------------------------------------------------------------------- /gen_span_detection.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import numpy as np 5 | from os import path 6 | from glob import glob 7 | from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed 8 | from collections import defaultdict 9 | set_seed(0) 10 | 11 | 12 | categories = [ 13 | "Contact related: Decision regarding admittance or discharge from hospital, scheduling of control and referral to other parts of the healthcare system", 14 | "Gathering additional information: Decision to obtain information from other sources than patient interview, physical examination and patient chart", 15 | "Defining problem: Complex, interpretative assessments that define what the problem is and reflect a medically informed conclusion", 16 | "Treatment goal: Decision to set defined goal for treatment and thereby being more specific than giving advice", 17 | "Drug: Decision to start, refrain from, stop, alter or maintain a drug regimen", 18 | "Therapeutic procedure related: Decision to intervene on a medical problem, plan, perform or refrain from therapeutic procedures of a medical nature", 19 | "Evaluating test result: Simple, normative assessments of clinical findings and tests", 20 | "Deferment: Decision to actively delay decision or a rejection to decide on a problem presented by a patient", 21 | "Advice and precaution: Decision to give the patient advice or precaution, thereby transferring responsibility for action from the provider to the patient", 22 | "Legal and insurance related: Medical decision concerning the patient, which is based on or restricted by legal regulations or 
financial arrangements", 23 | ] 24 | 25 | 26 | data_dir = '/data/mohamed/data/mimic_decisions/' 27 | test_samples = [x.strip() for x in open(path.join(data_dir, 'test.txt'))] 28 | np.random.shuffle(test_samples) 29 | 30 | def resolve_src(fn): 31 | basename = path.basename(fn).split("-")[0] 32 | txt_candidates = glob(os.path.join(data_dir, 33 | f'raw_text/{basename}*.txt')) 34 | text = open(txt_candidates[0]).read() 35 | return text 36 | 37 | model_id = "meta-llama/Meta-Llama-3-8B-Instruct" 38 | 39 | tokenizer = AutoTokenizer.from_pretrained(model_id) 40 | model = AutoModelForCausalLM.from_pretrained( 41 | model_id, 42 | torch_dtype=torch.bfloat16, 43 | device_map="auto", 44 | ) 45 | 46 | def prompt(note, cat): 47 | messages = [ 48 | {'role': 'system', 'content': f'Extract all sub-strings from the following Clinical Note that contain medical decisions within the specified category.\nPrint each sub-string in a new line.\nIf no such sub-string exists, output \"None\".\n[Clinical Note]: {note}'}, 49 | {"role": "user", "content": f"[Category]: {categories[cat-1]}"}, 50 | # {"role": "user", "content": f"Extract all sub-strings from the following Clinical Note that contain medical decisions within the specified category.\nPrint each sub-string in a new line.\nIf no such sub-string exists, output \"None\".\n[Clinical Note]: {note}\n\n[Category]: {categories[cat-1]}"}, 51 | ] 52 | 53 | input_ids = tokenizer.apply_chat_template( 54 | messages, 55 | add_generation_prompt=True, 56 | return_tensors="pt" 57 | ).to(model.device) 58 | 59 | terminators = [ 60 | tokenizer.eos_token_id, 61 | tokenizer.convert_tokens_to_ids("<|eot_id|>") 62 | ] 63 | 64 | outputs = model.generate( 65 | input_ids, 66 | max_new_tokens=1024, 67 | eos_token_id=terminators, 68 | do_sample=True, 69 | temperature=0.6, 70 | top_p=0.9, 71 | ) 72 | response = outputs[0][input_ids.shape[-1]:] 73 | return tokenizer.decode(response, skip_special_tokens=True) 74 | 75 | def prompt_oneshot(note, cat, demo_cat, demos): 
76 | messages = [ 77 | {'role': 'system', 'content': f'Extract all sub-strings from the following Clinical Note that contain medical decisions within the specified category.\nPrint each sub-string in a new line.\nIf no such sub-string exists, output \"None\".\n[Clinical Note]: {note}'}, 78 | {"role": "user", "content": f"[Category]: {categories[demo_cat-1]}"}, 79 | {"role": "assistant", "content": f"{demos}"}, 80 | {"role": "user", "content": f"[Category]: {categories[cat-1]}"}, 81 | ] 82 | 83 | input_ids = tokenizer.apply_chat_template( 84 | messages, 85 | add_generation_prompt=True, 86 | return_tensors="pt" 87 | ).to(model.device) 88 | 89 | terminators = [ 90 | tokenizer.eos_token_id, 91 | tokenizer.convert_tokens_to_ids("<|eot_id|>") 92 | ] 93 | 94 | outputs = model.generate( 95 | input_ids, 96 | max_new_tokens=1024, 97 | eos_token_id=terminators, 98 | do_sample=True, 99 | temperature=0.6, 100 | top_p=0.9, 101 | ) 102 | response = outputs[0][input_ids.shape[-1]:] 103 | return tokenizer.decode(response, skip_special_tokens=True) 104 | 105 | def zeroshot(): 106 | for i, fn in enumerate(test_samples[:10]): 107 | print(i) 108 | try: 109 | annots = json.load(open(path.join(data_dir, 'data', fn)))[0]['annotations'] 110 | annots = group_annots(annots) 111 | out_dir = path.join('gens', 'zero', str(i)) 112 | os.makedirs(out_dir, exist_ok=True) 113 | json.dump(annots, open(path.join(out_dir, 'labels.json'), 'w')) 114 | for cat in range(1, 10): 115 | text = resolve_src(fn) 116 | response = prompt(text, cat) 117 | with open(path.join(out_dir, f'cat_{cat}'), 'w') as f: 118 | f.write(response) 119 | except torch.cuda.OutOfMemoryError: 120 | print('OOM') 121 | continue 122 | 123 | def parse_cat(cat): 124 | for i,c in enumerate(cat): 125 | if c.isnumeric(): 126 | if cat[i+1].isnumeric(): 127 | return int(cat[i:i+2]) 128 | return int(c) 129 | return None 130 | 131 | def group_annots(annots): 132 | new_annots = defaultdict(list) 133 | for ann in annots: 134 | cat = 
parse_cat(ann['category']) 135 | if cat is None: 136 | continue 137 | dec = ann['decision'] 138 | new_annots[cat].append(dec) 139 | return new_annots 140 | 141 | def get_demos(annots, cat): 142 | lens = {k: len(v) for k,v in annots.items() if k != cat} 143 | max_annot = max(lens.keys(), key=lens.get) 144 | lines = [f'* "{x}"' for x in annots[max_annot]] 145 | # lines = ['[START]'] + lines + ['[END]'] 146 | demos = '\n'.join(lines) 147 | return max_annot, demos 148 | 149 | 150 | def oneshot(): 151 | for i, fn in enumerate(test_samples[:10]): 152 | print(i) 153 | try: 154 | annots = json.load(open(path.join(data_dir, 'data', fn)))[0]['annotations'] 155 | annots = group_annots(annots) 156 | out_dir = path.join('gens', 'one', str(i)) 157 | os.makedirs(out_dir, exist_ok=True) 158 | json.dump(annots, open(path.join(out_dir, 'labels.json'), 'w')) 159 | for cat in range(1, 10): 160 | demo_cat, demos = get_demos(annots, cat) 161 | text = resolve_src(fn) 162 | response = prompt_oneshot(text, cat, demo_cat, demos) 163 | # print(response) 164 | with open(path.join(out_dir, f'cat_{cat}'), 'w') as f: 165 | f.write(response) 166 | except torch.cuda.OutOfMemoryError: 167 | print('OOM') 168 | continue 169 | 170 | zeroshot() 171 | # oneshot() 172 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import aim 3 | import json 4 | import torch 5 | import warnings 6 | import numpy as np 7 | import pandas as pd 8 | warnings.filterwarnings("ignore") 9 | 10 | from tqdm import tqdm 11 | from data import load_data 12 | from model import load_model 13 | from options import get_args 14 | 15 | mean = lambda l: sum(l)/len(l) if len(l) > 0 else .0 16 | 17 | args = get_args() 18 | 19 | device = 'cuda:%s'%args.gpu 20 | all_losses = {'train': [], 'val': [], 'test': []} 21 | 22 | 23 | def indicators_to_spans(labels, idx = None): 24 | def add_span(idx, c, start, 
end): 25 | span = (idx, c, start, end) 26 | spans.add(span) 27 | 28 | spans = set() 29 | if args.label_encoding == 'multiclass': 30 | num_tokens = len(labels) 31 | num_classes = args.num_labels // 2 32 | start = None 33 | cat = -1 34 | for t in range(num_tokens): 35 | prev_tag = labels[t-1] if t > 0 else args.num_labels -1 36 | cur_tag = labels[t] 37 | 38 | if start is not None and cur_tag == cat + 1: 39 | continue 40 | elif start is not None: 41 | add_span(idx, cat // 2, start, t - 1) 42 | start = None 43 | 44 | if start is None and (cur_tag in [2*x for x in range(num_classes)] 45 | or (prev_tag == (args.num_labels - 1) 46 | and cur_tag != (args.num_labels - 1))): 47 | start = t 48 | cat = int(cur_tag) // 2 * 2 49 | else: 50 | num_tokens, num_classes = labels.shape 51 | if args.label_encoding == 'bo': 52 | num_classes //= 2 53 | elif args.label_encoding == 'boe': 54 | num_classes //= 3 55 | 56 | for c in range(num_classes): 57 | start = None 58 | for t in range(num_tokens): 59 | if args.label_encoding == 'bo': 60 | if start and (labels[t, 2 * c] == 1 or labels[t, 2 * c + 1] == 0): 61 | add_span(idx, c, start, t - 1) 62 | start = None 63 | elif start and labels[t, 2 * c + 1] == 1: 64 | continue 65 | if labels[t, 2 * c] == 1: 66 | start = t 67 | elif args.label_encoding == 'boe': 68 | if not start and labels[t, 3 * c] == 1: 69 | start = t 70 | elif start and labels[t, 3 * c + 2] == 1: 71 | add_span(idx, c, start, t - 1) 72 | start = None 73 | else: 74 | if start and labels[t,c] == 1: 75 | continue 76 | elif start and labels[t,c] == 0: 77 | add_span(idx, c, start, t - 1) 78 | start = None 79 | elif labels[t,c] == 1 and t == (num_tokens - 1): 80 | span = (idx, c, -1, -1) 81 | spans.add(span) 82 | elif labels[t,c] == 1: 83 | start = t 84 | return spans 85 | 86 | 87 | def id_to_label(labels): 88 | new_labels = [] 89 | for l in labels: 90 | if l == (args.num_labels - 1): 91 | new_l = 'O' 92 | elif l % 2 == 0: 93 | new_l = 'B-%d'% (l // 2) 94 | else: 95 | new_l = 'I-%d'% 
(l // 2) 96 | new_labels.append(new_l) 97 | return new_labels 98 | 99 | def f1_score(ys, preds): 100 | tp = len(preds & ys) 101 | fn = len(ys) - tp 102 | fp = len(preds) - tp 103 | f1 = (2 * tp / (2 * tp + fp + fn)) * 100 if tp + fp + fn > 0 else 0 104 | return f1 105 | 106 | def recall_score(ys, preds): 107 | tp = len(preds & ys) 108 | fn = len(ys) - tp 109 | recall = tp / (tp + fn) 110 | return recall 111 | 112 | def calc_metrics_spans(ys, preds, span_ys = None): 113 | all_preds = [] 114 | all_ys = [] 115 | for i, (y, pred) in enumerate(zip(ys, preds)): 116 | pred_spans = indicators_to_spans(pred, idx = i) 117 | all_preds.append(pred_spans) 118 | if span_ys is None: 119 | y = y.squeeze() 120 | y_spans = indicators_to_spans(y, idx = i) 121 | all_ys.append(y_spans) 122 | 123 | all_preds = set().union(*all_preds) 124 | if span_ys is None: 125 | all_ys = set().union(*all_ys) 126 | else: 127 | all_ys = set(span_ys) 128 | f1 = f1_score(all_ys, all_preds) 129 | 130 | perclass = {} 131 | for c in range(args.num_decs): 132 | sub_ys = {x for x in all_ys if x[1] == c} 133 | sub_preds = {x for x in all_preds if x[1] == c} 134 | perclass[c] = f1_score(sub_ys, sub_preds) 135 | 136 | return f1, all_preds, all_ys, perclass 137 | 138 | def save_losses(model, crit, train_dataloader, val_dataloader, test_dataloader): 139 | train_losses = evaluate(model, train_dataloader, crit, return_losses = True) 140 | all_losses['train'].append(train_losses) 141 | val_losses = evaluate(model, val_dataloader, crit, return_losses = True) 142 | all_losses['val'].append(val_losses) 143 | test_losses = evaluate(model, test_dataloader, crit, return_losses = True) 144 | all_losses['test'].append(test_losses) 145 | 146 | def evaluate(model, dataloader, crit, return_losses = False, return_preds = False): 147 | model.eval() 148 | outs, ys = [], [] 149 | lens = [] 150 | token_masks = [] 151 | for batch in tqdm(dataloader, desc='Evaluation'): 152 | x = batch['input_ids'] 153 | y = batch['labels'] 154 | mask 
= batch['mask'] 155 | if args.task == 'seq': 156 | ids = batch['ids'] 157 | 158 | with torch.no_grad(): 159 | logits = model.generate(x, mask) 160 | 161 | outs.append(logits) 162 | lens.extend([x.shape[0] for x in logits]) 163 | ys.append(y) 164 | 165 | if 'token_mask' in batch: 166 | token_masks.append(batch['token_mask']) 167 | 168 | if args.label_encoding == 'multiclass': 169 | outs_stack = torch.cat([x.view(-1, args.num_labels) for x in outs], 0) 170 | ys_stack = torch.cat([x.view(-1) for x in ys], 0).to(device) 171 | preds = [x.squeeze() for x in outs] 172 | 173 | if args.use_crf: 174 | padded_outs = torch.nn.utils.rnn.pad_sequence(preds, batch_first=True) 175 | outs_mask = ~(padded_outs[:,:,0] == 0) 176 | preds = crit.decode(padded_outs, mask=outs_mask) 177 | preds_stack = torch.tensor([x for pred in preds for x in pred]).to(device) 178 | padded_ys = torch.nn.utils.rnn.pad_sequence([x.squeeze() for x in ys], batch_first=True) 179 | loss = -1 * crit(padded_outs, padded_ys, mask=outs_mask, reduction='mean') 180 | 181 | else: 182 | preds = [x.argmax(-1) for x in preds] 183 | preds_stack = outs_stack.argmax(-1) 184 | loss = crit(outs_stack, ys_stack) 185 | else: 186 | outs_stack = torch.cat(outs, 1) 187 | ys_stack = torch.cat(ys, 1).to(device) 188 | loss = crit(outs_stack, ys_stack) 189 | 190 | losses = [] 191 | offset = 0 192 | if return_losses: 193 | for ln in lens: 194 | sub_losses = loss[offset:offset+ln] 195 | offset += ln 196 | losses.append(sub_losses.mean().item()) 197 | return losses 198 | 199 | 200 | loss = loss.mean() 201 | 202 | y = torch.cat(ys, 1).squeeze() 203 | 204 | if len(token_masks) > 0: 205 | token_masks = torch.cat(token_masks, 1).squeeze().to(device) 206 | acc = ((ys_stack == preds_stack).float() * token_masks).sum() / token_masks.sum() * 100 207 | else: 208 | acc = (ys_stack == preds_stack).float().mean() * 100 209 | 210 | if 'all_spans' in dataloader.dataset.data[0]: 211 | all_spans = [x['all_spans'] for x in dataloader.dataset.data] 212 
| span_ys = [(i, s['label'], s['token_start'], s['token_end']) for i, spans in enumerate(all_spans) for s in spans[0]] 213 | else: 214 | span_ys = None 215 | f1, span_preds, span_ys, perclass = calc_metrics_spans(ys, preds, span_ys) 216 | if return_preds: 217 | return span_preds, span_ys 218 | metrics_out = {} 219 | metrics_out['f1'] = f1 220 | metrics_out['acc'] = acc 221 | model.train() 222 | 223 | # genders = dataloader.dataset.stats['gender'] 224 | # for g in set(genders): 225 | # ids = [i for i,x in enumerate(genders) if x==g] 226 | # sub_ys = torch.cat([x for i,x in enumerate(ys) if i in ids], 1).squeeze().cpu() 227 | # sub_preds = torch.cat([x for i,x in enumerate(preds) if i in ids]).cpu() 228 | # sub_acc = (sub_ys == sub_preds).float().mean() * 100 229 | # print(g, sub_acc) 230 | 231 | # ethnicities = dataloader.dataset.stats['ethnicity'] 232 | # for e in set(ethnicities): 233 | # ids = [i for i,x in enumerate(ethnicities) if x==e] 234 | # sub_ys = torch.cat([x for i,x in enumerate(ys) if i in ids], 1).squeeze().cpu() 235 | # sub_preds = torch.cat([x for i,x in enumerate(preds) if i in ids]).cpu() 236 | # sub_acc = (sub_ys == sub_preds).float().mean() * 100 237 | # print(e, sub_acc) 238 | 239 | # langs = dataloader.dataset.stats['language'] 240 | # for l in set(langs): 241 | # ids = [i for i,x in enumerate(langs) if x==l] 242 | # sub_ys = torch.cat([x for i,x in enumerate(ys) if i in ids], 1).squeeze().cpu() 243 | # sub_preds = torch.cat([x for i,x in enumerate(preds) if i in ids]).cpu() 244 | # sub_acc = (sub_ys == sub_preds).float().mean() * 100 245 | # print(l, sub_acc) 246 | 247 | if args.task == 'token': 248 | pheno_results = {} 249 | for pheno, ids in dataloader.dataset.pheno_ids.items(): 250 | sub_ys = [x for i,x in enumerate(ys) if i in ids] 251 | sub_preds = [x for i,x in enumerate(preds) if i in ids] 252 | f1, span_preds, span_ys, _ = calc_metrics_spans(sub_ys, sub_preds) 253 | pheno_results[pheno] = f1 254 | else: 255 | pheno_results = None 256 
| return metrics_out, pheno_results, loss, perclass 257 | 258 | def process(sample, model, tokenizer, out_dir): 259 | hadm = sample['HADM_ID'] 260 | fn = f"{sample['SUBJECT_ID']}_{int(hadm) if not pd.isnull(hadm) else 'NaN'}_{sample['ROW_ID']}.json" 261 | out_file = out_dir + fn 262 | if not os.path.exists(out_file): 263 | encoding = tokenizer.encode_plus(sample['TEXT']) 264 | x = torch.tensor(encoding['input_ids']).unsqueeze(0).to(device) 265 | mask = torch.tensor(encoding['attention_mask']).unsqueeze(0).to(device) 266 | 267 | with torch.no_grad(): 268 | out = model.generate(x, mask) 269 | 270 | if args.label_encoding == 'multiclass': 271 | pred = out.argmax(-1) 272 | else: 273 | pred = torch.where(out > 0, 1, 0) 274 | pred = pred.squeeze() 275 | spans = indicators_to_spans(pred) 276 | all_spans = [] 277 | for _, cat, start, end in spans: 278 | span_dict = {} 279 | span_dict['decision'] = sample['TEXT'][start:end] 280 | span_dict['category'] = 'Category %d'%(cat+1) 281 | span_dict['start_offset'] = start 282 | span_dict['end_offset'] = end 283 | all_spans.append(span_dict) 284 | with open(out_file, 'w') as f: 285 | json.dump(all_spans, f) 286 | 287 | def predict_mimic(model, data, tokenizer): 288 | 289 | model.eval() 290 | outs = [] 291 | out_dir = './all_mimic_decisions/' 292 | kwargs = {'model': model, 'tokenizer': tokenizer, 'out_dir': out_dir} 293 | import multiprocessing as mp 294 | from functools import partial 295 | mp.set_start_method('spawn', force=True) 296 | process_ = partial(process, **kwargs) 297 | pool = mp.Pool(3) 298 | for _ in tqdm(pool.imap_unordered(process_, data, chunksize=20000), total=len(data)): 299 | pass 300 | pool.close() 301 | # for sample in tqdm(data): 302 | # process(sample) 303 | 304 | def train(args, model, crit, optimizer, lr_scheduler, 305 | train_dataloader, val_dataloader, verbose=True, train_ns=None, test_dataloader=None): 306 | writer = aim.Run(experiment=args.aim_exp, repo=args.aim_repo, 307 | system_tracking_interval=0) if 
not args.debug else None 308 | if writer is not None: 309 | writer['hparams'] = args.__dict__ 310 | 311 | step = 0 312 | best_f1 = -1 313 | best_acc = 0 314 | best_step = 0 315 | best_pheno = None 316 | best_perclass = None 317 | train_iter = iter(train_dataloader) 318 | losses = [] 319 | while step < args.total_steps: 320 | batch = next(train_iter, None) 321 | if batch is None: 322 | train_iter = iter(train_dataloader) 323 | continue 324 | x = batch['input_ids'] 325 | y = batch['labels'] 326 | mask = batch['mask'] 327 | 328 | y = y.to(device) 329 | if args.task == 'seq': 330 | out, _ = model.phenos(x, mask) 331 | logits = out[1] 332 | elif args.task == 'token': 333 | out, logits = model.decisions(x, mask) 334 | 335 | if args.label_encoding == 'multiclass': 336 | if args.use_crf: 337 | loss = -1 * crit(logits, y, reduction='mean') 338 | else: 339 | loss = crit(logits.view(-1, args.num_labels), y.view(-1)).mean() 340 | else: 341 | loss = crit(logits, y).mean() 342 | total_loss = loss 343 | 344 | 345 | losses.append(loss.item()) 346 | total_loss /= args.grad_accumulation 347 | total_loss.backward(retain_graph=True) 348 | 349 | if (step+1) % args.grad_accumulation == 0: 350 | optimizer.step() 351 | optimizer.zero_grad() 352 | lr_scheduler.step() 353 | 354 | if step % (args.train_log*args.grad_accumulation) == 0: 355 | avg_loss = np.mean(losses) 356 | if verbose: 357 | print('step %d - training loss: %.3f'%(step, avg_loss)) 358 | if writer is not None: 359 | writer.track(avg_loss, name='bce_loss', context={'split': 'train'}, step = step) 360 | losses = [] 361 | 362 | if len(val_dataloader) > 0 and step % (args.val_log*args.grad_accumulation) == 0: 363 | if args.save_losses: 364 | save_losses(model, crit, train_ns, val_dataloader, test_dataloader) 365 | metrics_out, pheno_results, loss, perclass = evaluate(model, val_dataloader, crit) 366 | f1, acc = metrics_out['f1'], metrics_out['acc'] 367 | if verbose: 368 | print('[step: {:5d}] f1: {:.1f}, acc: {:.1f}, loss: {:.3f}' 
369 | .format(step, f1, acc, loss)) 370 | if writer is not None: 371 | writer.track(loss, name='bce_loss', context={'split': 'val'}, step = step) 372 | writer.track(f1, name='f1', step = step) 373 | # writer.track(prec, name='precision', step = step) 374 | # writer.track(rec, name='recall', step = step) 375 | if f1 > best_f1: 376 | best_f1 = f1 377 | best_acc = acc 378 | best_step = step 379 | best_pheno = pheno_results 380 | # best_perclass = metrics_out[4:6] 381 | if not args.debug: 382 | torch.save(model.state_dict(), args.ckpt_dir) 383 | step += 1 384 | if writer is not None: 385 | writer.track(best_f1, name = 'best_f1') 386 | writer.track(best_step, name = 'best_step') 387 | if best_pheno is not None: 388 | for pheno, f1 in best_pheno.items(): 389 | writer.track(f1, name='best_f1', context={'group': pheno}) 390 | if args.task == 'token': 391 | f1s = best_perclass 392 | for i in range(len(f1s)): 393 | writer.track(f1s[i], name='best_f1', context={'decision': i}) 394 | return best_f1, best_acc, best_step 395 | 396 | def main(args): 397 | f1s = [] 398 | for seed in args.seed: 399 | torch.manual_seed(seed) 400 | np.random.seed(seed) 401 | args.seed = seed 402 | train_dataloader, val_dataloader, test_dataloader, train_ns = load_data(args) 403 | model, crit, optimizer, lr_scheduler = load_model(args, device) 404 | 405 | if not args.eval_only: 406 | f1, acc, step = train(args, model, crit, 407 | optimizer, lr_scheduler, train_dataloader, 408 | val_dataloader, args.verbose, train_ns, test_dataloader) 409 | f1s.append(f1) 410 | print('seed: %d, F1: %.1f, Acc: %.1f'%(seed, f1, acc)) 411 | # Test 412 | metrics_out, pheno_results, loss, perclass = evaluate(model, test_dataloader, crit) 413 | f1, acc = metrics_out['f1'], metrics_out['acc'] 414 | print('[Test] f1: {:.1f}, acc: {:.1f}, loss: {:.3f}' 415 | .format(f1, acc, loss)) 416 | # print(pheno_results) 417 | print(perclass) 418 | else: 419 | model.eval() 420 | # Train 421 | # metrics_out, pheno_results, loss = 
evaluate(model, train_ns, crit) 422 | # f1, acc = metrics_out['f1'], metrics_out['acc'] 423 | # print('[Train] f1: {:.1f}, acc: {:.1f}, loss: {:.3f}' 424 | # .format(f1, acc, loss)) 425 | 426 | # Val 427 | # metrics_out, pheno_results, loss, perclass = evaluate(model, val_dataloader, crit) 428 | # f1, acc = metrics_out['f1'], metrics_out['acc'] 429 | # print('[Val] f1: {:.1f}, acc: {:.1f}, loss: {:.3f}' 430 | # .format(f1, acc, loss)) 431 | 432 | # Test 433 | metrics_out, pheno_results, loss, perclass = evaluate(model, test_dataloader, crit) 434 | f1, acc = metrics_out['f1'], metrics_out['acc'] 435 | print('[Test] f1: {:.1f}, acc: {:.1f}, loss: {:.3f}' 436 | .format(f1, acc, loss)) 437 | # print(pheno_results) 438 | print(perclass) 439 | 440 | # predict_mimic(model, data, tokenizer) 441 | if args.save_losses: 442 | np.savez('losses_%d.npz'%seed, train=all_losses['train'], val=all_losses['val'], test=all_losses['test']) 443 | return np.mean(f1s) 444 | 445 | if __name__ == '__main__': 446 | main(args) 447 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers import AutoModel 4 | from torch.optim import AdamW 5 | from transformers import get_linear_schedule_with_warmup 6 | 7 | class MyModel(nn.Module): 8 | def __init__(self, args, backbone): 9 | super().__init__() 10 | self.args = args 11 | self.backbone = backbone 12 | self.cls_id = 0 13 | hidden_dim = self.backbone.config.hidden_size 14 | self.classifier = nn.Sequential( 15 | nn.Dropout(0.1), 16 | nn.Linear(hidden_dim, args.num_labels) 17 | ) 18 | 19 | def forward(self, x, mask): 20 | x = x.to(self.backbone.device) 21 | mask = mask.to(self.backbone.device) 22 | out = self.backbone(x, attention_mask = mask, output_attentions=True) 23 | return out, self.classifier(out.last_hidden_state) 24 | 25 | def decisions(self, x, mask): 26 | x = 
x.to(self.backbone.device) 27 | mask = mask.to(self.backbone.device) 28 | out = self.backbone(x, attention_mask = mask, output_attentions=False) 29 | return out, self.classifier(out.last_hidden_state) 30 | 31 | def phenos(self, x, mask): 32 | x = x.to(self.backbone.device) 33 | mask = mask.to(self.backbone.device) 34 | out = self.backbone(x, attention_mask = mask, output_attentions=True) 35 | return out, self.classifier(out.pooler_output) 36 | 37 | def generate(self, x, mask, choice=None): 38 | outs = [] 39 | if self.args.task == 'seq' or choice == 'seq': 40 | for i, offset in enumerate(range(0, x.shape[1], self.args.max_len-1)): 41 | if i == 0: 42 | segment = x[:, offset:offset + self.args.max_len-1] 43 | segment_mask = mask[:, offset:offset + self.args.max_len-1] 44 | else: 45 | segment = torch.cat((torch.ones((x.shape[0], 1), dtype=int).to(x.device)\ 46 | *self.cls_id, 47 | x[:, offset:offset + self.args.max_len-1]), axis=1) 48 | segment_mask = torch.cat((torch.ones((mask.shape[0], 1)).to(mask.device), 49 | mask[:, offset:offset + self.args.max_len-1]), axis=1) 50 | logits = self.phenos(segment, segment_mask)[1] 51 | outs.append(logits) 52 | 53 | return torch.max(torch.stack(outs, 1), 1).values 54 | elif self.args.task == 'token': 55 | for i, offset in enumerate(range(0, x.shape[1], self.args.max_len)): 56 | segment = x[:, offset:offset + self.args.max_len] 57 | segment_mask = mask[:, offset:offset + self.args.max_len] 58 | h = self.decisions(segment, segment_mask)[0].last_hidden_state 59 | outs.append(h) 60 | h = torch.cat(outs, 1) 61 | return self.classifier(h) 62 | 63 | 64 | def load_model(args, device): 65 | if args.model == 'lstm': 66 | model = LSTM(args).to(device) 67 | model.device = device 68 | elif args.model == 'cnn': 69 | model = CNN(args).to(device) 70 | model.device = device 71 | else: 72 | model = MyModel(args, AutoModel.from_pretrained(args.model_name)).to(device) 73 | if args.ckpt: 74 | model.load_state_dict(torch.load(args.ckpt, 
map_location=device), strict=True) 75 | if args.label_encoding == 'multiclass': 76 | if args.use_crf: 77 | from torchcrf import CRF 78 | crit = CRF(args.num_labels, batch_first = True).to(device) 79 | else: 80 | crit = nn.CrossEntropyLoss(reduction='none') 81 | else: 82 | crit = nn.BCEWithLogitsLoss( 83 | pos_weight=torch.ones(args.num_labels).to(device)*args.pos_weight, 84 | reduction='none' 85 | ) 86 | optimizer = AdamW(model.parameters(), lr=args.lr) 87 | lr_scheduler = get_linear_schedule_with_warmup(optimizer, 88 | int(0.1*args.total_steps), args.total_steps) 89 | 90 | return model, crit, optimizer, lr_scheduler 91 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from datetime import datetime 4 | 5 | def get_args(): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--data_dir', default='./data/') 8 | parser.add_argument('--ckpt') 9 | parser.add_argument('--aim_repo', default='.') 10 | parser.add_argument('--aim_exp', default='mimic-decisions-1215') 11 | parser.add_argument('--label_encoding', default='multiclass') 12 | parser.add_argument('--debug', action='store_true') 13 | parser.add_argument('--save_losses', action='store_true') 14 | parser.add_argument('--task', default='token', choices=['seq', 'token']) 15 | parser.add_argument('--max_len', type=int, default=512) 16 | parser.add_argument('--model', default='roberta-base',) 17 | parser.add_argument('--model_name', default='google/electra-base-discriminator',) 18 | parser.add_argument('--gpu', default='0') 19 | parser.add_argument('--grad_accumulation', default=2, type=int) 20 | parser.add_argument('--pheno_id', type=int) 21 | parser.add_argument('--unseen_pheno', type=int) 22 | parser.add_argument('--total_steps', type=int, default=5000) 23 | parser.add_argument('--train_log', type=int, default=500) 24 | 
parser.add_argument('--val_log', type=int, default=1000) 25 | parser.add_argument('--seed', default = '0') 26 | parser.add_argument('--num_phenos', type=int, default=10) 27 | parser.add_argument('--num_decs', type=int, default=9) 28 | parser.add_argument('--batch_size', type=int, default=8) 29 | parser.add_argument('--pos_weight', type=float, default=1.0) 30 | parser.add_argument('--truncate_train', action='store_true') 31 | parser.add_argument('--truncate_eval', action='store_true') 32 | parser.add_argument('--load_ckpt', action='store_true') 33 | parser.add_argument('--eval_only', action='store_true') 34 | parser.add_argument('--lr', type=float, default=4e-5) 35 | parser.add_argument('--resample', default='') 36 | parser.add_argument('--verbose', type=bool, default=True) 37 | parser.add_argument('--use_crf', type=bool) 38 | 39 | 40 | args = parser.parse_args() 41 | 42 | curtime = datetime.now().strftime('%m%d_%H-%M-%S') 43 | args.ckpt_dir = './checkpoints/%s-%s-%s'%(curtime, os.path.basename(args.model_name), args.model) 44 | args.seed = [int(x) for x in args.seed.split(',')] 45 | 46 | if args.task == 'seq' and args.pheno_id is not None: 47 | args.num_labels = 1 48 | elif args.task == 'seq': 49 | args.num_labels = args.num_phenos 50 | elif args.task == 'token': 51 | args.num_labels = args.num_decs 52 | if args.label_encoding == 'multiclass': 53 | args.num_labels = args.num_labels * 2 + 1 54 | elif args.label_encoding == 'bo': 55 | args.num_labels *= 2 56 | elif args.label_encoding == 'boe': 57 | args.num_labels *= 3 58 | 59 | return args 60 | -------------------------------------------------------------------------------- /preprocess_phenos.py: -------------------------------------------------------------------------------- 1 | # Aggregate the annotations of the same SUBJECT_ID, HADM_ID, ROW_ID into a single phenotype label. 2 | # If there is only one annotation, use that as the gold data. 
3 | # If there are multiple annotations, prioritize annotators based on the following order: 4 | # DAG or PAT are prioritized over all other annotators. 5 | # JF or JTW are prioritized over ETM and JW. 6 | # If there are multiple annotations from the same prioritized annotator, use the sum of the annotations. 7 | # If one of the annotations is NONE or UNSURE, the phenotype label is '?'. 8 | 9 | import pandas as pd 10 | import os, sys 11 | 12 | def aggregate_annotations(df): 13 | def process_group(group): 14 | if len(group) == 1: 15 | return process_single_annotation(group.iloc[0]) 16 | else: 17 | return process_multiple_annotations(group) 18 | 19 | def process_single_annotation(row): 20 | phenotype_label = [col if col != 'UNSURE' else '?' for col in PHENOTYPE_COLUMNS if row[col] > 0] 21 | return pd.Series({'phenotype_label': ','.join(phenotype_label) if phenotype_label else '?'}) 22 | 23 | def process_multiple_annotations(group): 24 | priority_operators = [['DAG', 'PAT'], ['JTW', 'JF'], ['ETM', 'JW']] 25 | for operator_group in priority_operators: 26 | if group['OPERATOR'].isin(operator_group).any(): 27 | selected_rows = group[group['OPERATOR'].isin(operator_group)] 28 | break 29 | else: 30 | selected_rows = group 31 | 32 | selected_rows_unique = selected_rows.drop(['BATCH.ID', 'OPERATOR'], axis=1).drop_duplicates() 33 | 34 | if len(selected_rows_unique) > 1 and selected_rows_unique['NONE'].sum() > 0: 35 | return pd.Series({ 36 | 'phenotype_label': '?', 37 | 'OPERATOR': ','.join(selected_rows['OPERATOR'].unique()) 38 | }) 39 | 40 | # Sum over phenotype_columns and keep other unchanged 41 | selected_rows_unique = selected_rows_unique.sum() 42 | selected_rows_unique['OPERATOR'] = ','.join(selected_rows['OPERATOR'].unique()) 43 | 44 | # selected_rows = selected_rows_unique.sum() 45 | phenotype_label = [col if col != 'UNSURE' else '?' 
for col in PHENOTYPE_COLUMNS if selected_rows_unique[col] > 0] 46 | return pd.Series({'phenotype_label': ','.join(sorted(phenotype_label)) if phenotype_label else '?', 47 | 'OPERATOR': ",".join(selected_rows['OPERATOR'].unique())}) 48 | 49 | return df.groupby(['SUBJECT_ID', 'HADM_ID', 'ROW_ID']).apply(process_group).reset_index() 50 | 51 | 52 | 53 | 54 | # Constants 55 | PHENOTYPE_COLUMNS = ['ADVANCED.CANCER', 'ADVANCED.HEART.DISEASE', 'ADVANCED.LUNG.DISEASE', 'ALCOHOL.ABUSE', 56 | 'CHRONIC.NEUROLOGICAL.DYSTROPHIES', 'CHRONIC.PAIN.FIBROMYALGIA', 'DEPRESSION', 'OBESITY', 57 | 'OTHER.SUBSTANCE.ABUSE', 'PSYCHIATRIC.DISORDERS', 'NONE'] 58 | 59 | # Paths 60 | if len(sys.argv) != 2: 61 | print('Usage: python preprocess_phenos.py ') 62 | sys.exit(1) 63 | INPUT_FILE = sys.argv[1] 64 | OUTPUT_FILE = os.path.join(os.path.dirname(INPUT_FILE), 'phenos.csv') 65 | 66 | # Main execution 67 | if __name__ == "__main__": 68 | # Read and preprocess data 69 | df = pd.read_csv(INPUT_FILE) 70 | df['PSYCHIATRIC.DISORDERS'] = df['DEMENTIA'] | df['DEVELOPMENTAL.DELAY.RETARDATION'] | df['SCHIZOPHRENIA.AND.OTHER.PSYCHIATRIC.DISORDERS'] 71 | df = df[['SUBJECT_ID', 'HADM_ID', 'ROW_ID'] + PHENOTYPE_COLUMNS + ['OPERATOR', 'BATCH.ID']] 72 | 73 | # Aggregate annotations 74 | result_df = aggregate_annotations(df) 75 | result_df.to_csv(OUTPUT_FILE, index=False) 76 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aim==3.19.3 2 | datasets==2.19.0 3 | nltk==3.8.1 4 | numpy==2.1.0 5 | pandas==2.2.2 6 | quickumls==1.4.2 7 | scikit_learn==1.4.2 8 | seqeval==1.2.2 9 | spacy==3.3.0 10 | torch==2.2.1 11 | tqdm==4.65.0 12 | transformers==4.39.3 13 | --------------------------------------------------------------------------------