├── .gitignore
├── Jupyter Notebooks
│   ├── Active Learning.ipynb
│   ├── Get Results.ipynb
│   └── Train the Model.ipynb
├── README.md
├── chexpert_approximator
│   ├── .gitignore
│   ├── __init__.py
│   ├── data_processor.py
│   ├── model.py
│   ├── reload_and_get_logits.py
│   └── run_classifier.py
└── env.yml
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # out 107 | out/ 108 | 109 | # tags 110 | tags 111 | -------------------------------------------------------------------------------- /Jupyter Notebooks/Train the Model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Imports" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "%load_ext autoreload\n", 17 | "import sys\n", 18 | "\n", 19 | "sys.path.append('..')" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stdout", 29 | "output_type": "stream", 30 | "text": [ 31 | "Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.\n" 32 | ] 33 | } 34 | ], 35 | "source": [ 36 | "%autoreload\n", 37 | "\n", 38 | "import itertools, os, pickle, pandas as pd\n", 39 | "from collections import Counter\n", 40 | "\n", 41 | "from chexpert_approximator.data_processor import *\n", 42 | "from chexpert_approximator.run_classifier import *\n", 43 | "\n", 44 | "from chexpert_approximator.reload_and_get_logits import *" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "# We can't display MIMIC-CXR Output:\n", 54 | "\n", 55 | "DO_BLIND = True\n", 56
| "def blind_display(df):\n", 57 | " if DO_BLIND: \n", 58 | " df = df.copy()\n", 59 | " index_levels = df.index.names\n", 60 | " df.reset_index('rad_id', inplace=True)\n", 61 | " df['rad_id'] = [0 for _ in df['rad_id']]\n", 62 | " df.set_index('rad_id', append=True, inplace=True)\n", 63 | " df = df.reorder_levels(index_levels, axis=0)\n", 64 | "\n", 65 | " for c in df.columns:\n", 66 | " if pd.api.types.is_string_dtype(df[c]): df[c] = ['SAMPLE' for _ in df[c]]\n", 67 | " else: df[c] = np.NaN\n", 68 | "\n", 69 | " display(df.head())" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "# Load the Data" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 4, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "DATA_DIR = '/scratch/chexpert_approximator/processed_data/' # INSERT YOUR DATA DIR HERE!\n", 86 | "# DATA MUST BE STORED IN A FILE `inputs.hdf` under key `with_folds`.\n", 87 | "INPUT_PATH, INPUT_KEY = os.path.join(DATA_DIR, 'inputs.hdf'), 'with_folds'\n", 88 | "\n", 89 | "# YOUR CLINICAL BERT MODEL GOES HERE\n", 90 | "BERT_MODEL_PATH = (\n", 91 | " '/data/medg/misc/clinical_BERT/cliniBERT/models/pretrained_bert_tf/bert_pretrain_output_all_notes_300000/'\n", 92 | ")\n", 93 | "\n", 94 | "# THIS IS WHERE YOUR PRE-TRAINED CHEXPERT++ MODEL WILL BE WRITTEN\n", 95 | "OUT_CXPPP_DIR = '../out/run_1'\n", 96 | "\n", 97 | "# DON'T MODIFY THESE\n", 98 | "FOLD = 'Fold'\n", 99 | "\n", 100 | "KEY = {\n", 101 | " 0: 'No Mention',\n", 102 | " 1: 'Uncertain Mention',\n", 103 | " 2: 'Negative Mention',\n", 104 | " 3: 'Positive Mention',\n", 105 | "}\n", 106 | "INV_KEY = {v: k for k, v in KEY.items()}" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 5, 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "inputs = pd.read_hdf(INPUT_PATH, INPUT_KEY)\n", 116 | "label_cols = [col for col in inputs.index.names if col not in ('rad_id', FOLD)]" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 6, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "data": { 126 | "text/html": [ 127 | "
\n", 128 | "\n", 141 | "\n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | "
sentence
rad_idNo FindingEnlarged CardiomediastinumCardiomegalyLung LesionAirspace OpacityEdemaConsolidationPneumoniaAtelectasisPneumothoraxPleural EffusionPleural OtherFractureSupport DevicesFold
0Positive MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo Mention6SAMPLE
Negative MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo Mention6SAMPLE
No MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo Mention6SAMPLE
4SAMPLE
No MentionNo MentionNo MentionNo MentionPositive MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo MentionNo Mention4SAMPLE
\n", 260 | "
" 261 | ], 262 | "text/plain": [ 263 | " sentence\n", 264 | "rad_id No Finding Enlarged Cardiomediastinum Cardiomegaly Lung Lesion Airspace Opacity Edema Consolidation Pneumonia Atelectasis Pneumothorax Pleural Effusion Pleural Other Fracture Support Devices Fold \n", 265 | "0 Positive Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention 6 SAMPLE\n", 266 | " Negative Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention 6 SAMPLE\n", 267 | " No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention 6 SAMPLE\n", 268 | " 4 SAMPLE\n", 269 | " No Mention No Mention No Mention No Mention Positive Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention 4 SAMPLE" 270 | ] 271 | }, 272 | "metadata": {}, 273 | "output_type": "display_data" 274 | } 275 | ], 276 | "source": [ 277 | "blind_display(inputs)" 278 | ] 279 | }, 280 | { 281 | "cell_type": "markdown", 282 | "metadata": {}, 283 | "source": [ 284 | "# Model\n", 285 | "## Data Processor" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": 44, 291 | "metadata": {}, 292 | "outputs": [], 293 | "source": [ 294 | "class CheXpertProcessor(DataProcessor):\n", 295 | " def __init__(self, tuning_fold, held_out_fold):\n", 296 | " super().__init__()\n", 297 | " self.tuning_fold, self.held_out_fold = tuning_fold, held_out_fold\n", 298 | " \n", 299 | " \"\"\"Processor for the CheXpert approximator.\n", 300 | " Honestly this is kind of silly, as it never stores internal state.\"\"\"\n", 301 | " def get_train_examples(self, df): return self._create_examples(\n", 302 | " df, set([f for f in range(K) if f not in (self.tuning_fold, self.held_out_fold)])\n", 303 | " )\n", 304 | " def get_dev_examples(self, df): return self._create_examples(df, set([self.tuning_fold]))\n", 305 | " def get_examples(self, df, folds=[]): return self._create_examples(df, set(folds))\n", 306 | " \n", 307 | " def get_labels(self): return {label: list(range(4)) for label in label_cols}\n", 308 | "\n", 309 | " def _create_examples(self, df, folds):\n", 310 | " \"\"\"Creates examples for the training and dev sets.\"\"\"\n", 311 | " df = df[df.index.get_level_values(FOLD).isin(folds)]\n", 312 | " lmap = {l: i for i, l in enumerate(df.index.names)}\n", 313 | " \n", 314 | " examples = []\n", 315 | " for idx, r in df.iterrows():\n", 316 | " labels = {l: INV_KEY[idx[lmap[l]]] for l in label_cols}\n", 317 | " \n", 318 | " examples.append(InputExample(\n", 319 | " guid=str(idx[lmap['rad_id']]), text_a=r.sentence, text_b=None, label=labels\n", 320 | " ))\n", 321 | " return examples" 322 | ] 323 | }, 324 | { 325 | "cell_type": "markdown", 326 | "metadata": {}, 327 | "source": [ 328 | "## Running the Model" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": 45, 334 | "metadata": {}, 335 | "outputs": [], 336 | "source": [ 337 | "processor = CheXpertProcessor(8, 9)\n", 338 | "# train_examples = processor.get_train_examples(inputs)\n", 339 | "# dev_examples = processor.get_dev_examples(inputs)" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 13, 345 | "metadata": {}, 346 | "outputs": [], 347 | "source": [ 348 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'" 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 
| "execution_count": null, 354 | "metadata": { 355 | "scrolled": false 356 | }, 357 | "outputs": [ 358 | { 359 | "name": "stdout", 360 | "output_type": "stream", 361 | "text": [ 362 | "/data/medg/misc/clinical_BERT/cliniBERT/models/pretrained_bert_tf/bert_pretrain_output_all_notes_300000/\n", 363 | "Max Sequence Length: 112\n" 364 | ] 365 | }, 366 | { 367 | "data": { 368 | "application/vnd.jupyter.widget-view+json": { 369 | "model_id": "", 370 | "version_major": 2, 371 | "version_minor": 0 372 | }, 373 | "text/plain": [ 374 | "HBox(children=(IntProgress(value=0, description='Epoch', max=5, style=ProgressStyle(description_width='initial…" 375 | ] 376 | }, 377 | "metadata": {}, 378 | "output_type": "display_data" 379 | }, 380 | { 381 | "data": { 382 | "application/vnd.jupyter.widget-view+json": { 383 | "model_id": "", 384 | "version_major": 2, 385 | "version_minor": 0 386 | }, 387 | "text/plain": [ 388 | "HBox(children=(IntProgress(value=0, description='Iteration', max=18840, style=ProgressStyle(description_width=…" 389 | ] 390 | }, 391 | "metadata": {}, 392 | "output_type": "display_data" 393 | }, 394 | { 395 | "name": "stderr", 396 | "output_type": "stream", 397 | "text": [ 398 | "/scratch/conda_envs/chexpert_approximator/lib/python3.7/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", 399 | " warnings.warn('Was asked to gather along dimension 0, but all '\n" 400 | ] 401 | }, 402 | { 403 | "data": { 404 | "application/vnd.jupyter.widget-view+json": { 405 | "model_id": "", 406 | "version_major": 2, 407 | "version_minor": 0 408 | }, 409 | "text/plain": [ 410 | "HBox(children=(IntProgress(value=0, description='Iteration', max=18840, style=ProgressStyle(description_width=…" 411 | ] 412 | }, 413 | "metadata": {}, 414 | "output_type": "display_data" 415 | }, 416 | { 417 | "data": { 418 | "application/vnd.jupyter.widget-view+json": { 419 | "model_id": "", 420 | "version_major": 2, 421 | "version_minor": 0 422 | }, 423 | "text/plain": [ 424 | "HBox(children=(IntProgress(value=0, description='Iteration', max=18840, style=ProgressStyle(description_width=…" 425 | ] 426 | }, 427 | "metadata": {}, 428 | "output_type": "display_data" 429 | } 430 | ], 431 | "source": [ 432 | "out = build_and_train(\n", 433 | " inputs,\n", 434 | " bert_model = BERT_MODEL_PATH,\n", 435 | " processor = processor,\n", 436 | " task_dimensions = {l: 4 for l in label_cols},\n", 437 | " output_dir = OUT_CXPPP_DIR,\n", 438 | " gradient_accumulation_steps = 1,\n", 439 | " gpu = '0,1,2',\n", 440 | " do_train = True,\n", 441 | " do_eval = True,\n", 442 | " seed = 42,\n", 443 | " do_lower_case = False,\n", 444 | " max_seq_length = 128,\n", 445 | " train_batch_size = 32,\n", 446 | " eval_batch_size = 8,\n", 447 | " learning_rate = 5e-5,\n", 448 | " num_train_epochs = 5,\n", 449 | " warmup_proportion = 0.1,\n", 450 | ")" 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": 16, 456 | "metadata": {}, 457 | "outputs": [ 458 | { 459 | "data": { 460 | "text/plain": [ 461 | "{'eval_loss': 0.04080065189753456,\n", 462 | " 'eval_accuracy': 0.9992550486952994,\n", 463 | " 'global_step': 94200,\n", 464 | " 'loss': 0.0014591591281715949}" 465 | ] 466 | }, 467 | "execution_count": 16, 468 | "metadata": {}, 469 | "output_type": "execute_result" 470 | } 471 | ], 472 | "source": [ 473 | "out" 474 | ] 475 | } 476 | ], 477 | "metadata": { 478 | "kernelspec": { 479 | "display_name": "Python 3", 
480 | "language": "python", 481 | "name": "python3" 482 | }, 483 | "language_info": { 484 | "codemirror_mode": { 485 | "name": "ipython", 486 | "version": 3 487 | }, 488 | "file_extension": ".py", 489 | "mimetype": "text/x-python", 490 | "name": "python", 491 | "nbconvert_exporter": "python", 492 | "pygments_lexer": "ipython3", 493 | "version": "3.7.3" 494 | } 495 | }, 496 | "nbformat": 4, 497 | "nbformat_minor": 2 498 | } 499 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `Chexpert++` 2 | ## Description 3 | Source implementation and pointer to pre-trained models for `chexpert++` (arxiv link forthcoming) a BERT-based 4 | approximation to CheXpert for radiology report labeling. Note that a compelling, co-discovered alternative is 5 | [1], which features a more full-fledged annotation effort featuring two board-certified radiologists and a 6 | more robust error resolution system. This paper is accessible [here](://arxiv.org/pdf/2004.09167.pdf). 7 | 8 | ## Obtaining our Pre-trained Model 9 | Our Pre-trained BERT model is soon to be available via PhysioNet. In the meantime, it is accessible on google cloud platform (GCP) to users who are credentialed for accessing the MIMIC-CXR GCP bucket via PhysioNet. Our bucket link and instructions to gain access through PhysioNet are included below, and please email 10 | [`mmd@mit.edu`](mailto:mmd@mit.edu) if you have any questions. 11 | 12 | ### Our Bucket 13 | https://console.cloud.google.com/storage/browser/chexpertplusplus 14 | 15 | ### Instructions for getting physionet MIMIC-CXR GCP Access 16 | 1. First, follow the physionet instructions to add google cloud access, here: https://mimic.physionet.org/gettingstarted/cloud/Next, 17 | 2. Next, get access to MIMIC-CXR in general on Physionet: https://physionet.org/content/mimic-cxr/2.0.0/ (go to the bottom of the page and follow the steps listed under "Files", including becoming a credentialed user and signing the data use agreement) 18 | 3. Finally, request access to MIMIC-CXR via GCP on Physionet: https://physionet.org/projects/mimic-cxr/2.0.0/request_access/3  19 | 20 | ## Installation 21 | To install a conda environment suitable for reproducing this work, use the environment spec available in 22 | `env.yml`, via, e.g. 23 | ``` 24 | conda env create -f env.yml -n [ENVIRONMENT NAME] 25 | ``` 26 | 27 | Additionally, you must download the [MIMIC-CXR dataset](https://physionet.org/content/mimic-cxr/2.0.0/) and 28 | split the reports into sentences, then label each of these with the CheXpert labeler (code/splits not 29 | provided). You must also download the Clinical BERT model, available 30 | [here](https://github.com/EmilyAlsentzer/clinicalBERT). 31 | 32 | ## Usage Instructions 33 | Main model source code is available in `./chexpert_approximator`. Model training, evaluation, and active 34 | learning proof-of-concept are all available in `Jupyter Notebooks/`. 35 | 36 | ## Citation 37 | *This Work:* 38 | Matthew B.A. McDermott, Tzu Ming Harry Hsu, Wei-Hung Weng, Marzyeh Ghassemi, and Peter Szolovits. 39 | "`Chexpert++`: Approximating the CheXpert labeler for Speed, Differentiability, and Probabilistic Output." 40 | Machine Learning for Health Care (2020) _(in press; link TBA)_. 41 | 42 | *[1]* 43 | Akshay Smit, Saahil Jain, Pranav Rajpurkar, Anuj Pareek, Andrew Y. Ng, and Matthew P. Lungren. 
"CheXbert: 44 | Combining Automatic Labelers and Expert Annotations for Accurate Radiology Report Labeling Using BERT." arXiv 45 | preprint arXiv:2004.09167 (2020). [https://arxiv.org/pdf/2004.09167.pdf](https://arxiv.org/pdf/2004.09167.pdf) 46 | -------------------------------------------------------------------------------- /chexpert_approximator/.gitignore: -------------------------------------------------------------------------------- 1 | tags 2 | -------------------------------------------------------------------------------- /chexpert_approximator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mmcdermott/chexpertplusplus/b773053f7ad340b43608660d3adee0401087572b/chexpert_approximator/__init__.py -------------------------------------------------------------------------------- /chexpert_approximator/data_processor.py: -------------------------------------------------------------------------------- 1 | import csv, sys 2 | 3 | class InputFeatures(object): 4 | """A single set of features of data.""" 5 | ### MODIFIED TO SUPPORT MULTI-LABELS 6 | 7 | def __init__(self, input_ids, input_mask, segment_ids, label_ids): 8 | self.input_ids = input_ids 9 | self.input_mask = input_mask 10 | self.segment_ids = segment_ids 11 | self.label_ids = label_ids 12 | 13 | 14 | class InputExample(object): 15 | """A single training/test example for simple sequence classification.""" 16 | 17 | def __init__(self, guid, text_a, text_b=None, label=None): 18 | """Constructs a InputExample. 19 | 20 | Args: 21 | guid: Unique id for the example. 22 | text_a: string. The untokenized text of the first sequence. For single 23 | sequence tasks, only this sequence must be specified. 24 | text_b: (Optional) string. The untokenized text of the second sequence. 25 | Only must be specified for sequence pair tasks. 26 | label: (Optional) string. The label of the example. This should be 27 | specified for train and dev examples, but not for test examples. 
28 | """ 29 | self.guid = guid 30 | self.text_a = text_a 31 | self.text_b = text_b 32 | self.label = label 33 | 34 | class DataProcessor(object): 35 | """Base class for data converters for sequence classification data sets.""" 36 | 37 | def get_train_examples(self, data_dir): 38 | """Gets a collection of `InputExample`s for the train set.""" 39 | raise NotImplementedError() 40 | 41 | def get_dev_examples(self, data_dir): 42 | """Gets a collection of `InputExample`s for the dev set.""" 43 | raise NotImplementedError() 44 | 45 | def get_test_examples(self, data_dir): 46 | """Gets a collection of `InputExample`s for the test set.""" 47 | raise NotImplementedError() 48 | 49 | def get_labels(self): 50 | """Gets the list of labels for this data set.""" 51 | raise NotImplementedError() 52 | 53 | @classmethod 54 | def _read_tsv(cls, input_file, quotechar=None): 55 | """Reads a tab separated value file.""" 56 | with open(input_file, "r") as f: 57 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 58 | lines = [] 59 | for line in reader: 60 | if sys.version_info[0] == 2: 61 | line = list(unicode(cell, 'utf-8') for cell in line) 62 | lines.append(line) 63 | return lines 64 | 65 | def convert_examples_to_features(examples, task_list, max_seq_length, tokenizer): 66 | """Loads a data file into a list of `InputBatch`s.""" 67 | 68 | features = [] 69 | max_len = 0 70 | for (ex_index, example) in enumerate(examples): 71 | tokens_a = tokenizer.tokenize(example.text_a) 72 | 73 | tokens_b = None 74 | if example.text_b: 75 | tokens_b = tokenizer.tokenize(example.text_b) 76 | seq_len = len(tokens_a) + len(tokens_b) 77 | 78 | # Modifies `tokens_a` and `tokens_b` in place so that the total 79 | # length is less than the specified length. 80 | # Account for [CLS], [SEP], [SEP] with "- 3" 81 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 82 | else: 83 | seq_len = len(tokens_a) 84 | # Account for [CLS] and [SEP] with "- 2" 85 | if len(tokens_a) > max_seq_length - 2: 86 | tokens_a = tokens_a[:(max_seq_length - 2)] 87 | 88 | if seq_len > max_len: 89 | max_len = seq_len 90 | # The convention in BERT is: 91 | # (a) For sequence pairs: 92 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 93 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 94 | # (b) For single sequences: 95 | # tokens: [CLS] the dog is hairy . [SEP] 96 | # type_ids: 0 0 0 0 0 0 0 97 | # 98 | # Where "type_ids" are used to indicate whether this is the first 99 | # sequence or the second sequence. The embedding vectors for `type=0` and 100 | # `type=1` were learned during pre-training and are added to the wordpiece 101 | # embedding vector (and position vector). This is not *strictly* necessary 102 | # since the [SEP] token unambigiously separates the sequences, but it makes 103 | # it easier for the model to learn the concept of sequences. 104 | # 105 | # For classification tasks, the first vector (corresponding to [CLS]) is 106 | # used as as the "sentence vector". Note that this only makes sense because 107 | # the entire model is fine-tuned. 108 | tokens = ["[CLS]"] + tokens_a + ["[SEP]"] 109 | segment_ids = [0] * len(tokens) 110 | 111 | if tokens_b: 112 | tokens += tokens_b + ["[SEP]"] 113 | segment_ids += [1] * (len(tokens_b) + 1) 114 | 115 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 116 | 117 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 118 | # tokens are attended to. 119 | input_mask = [1] * len(input_ids) 120 | 121 | # Zero-pad up to the sequence length. 
122 | padding = [0] * (max_seq_length - len(input_ids)) 123 | input_ids += padding 124 | input_mask += padding 125 | segment_ids += padding 126 | 127 | assert len(input_ids) == max_seq_length 128 | assert len(input_mask) == max_seq_length 129 | assert len(segment_ids) == max_seq_length 130 | 131 | features.append(InputFeatures( 132 | input_ids=input_ids, 133 | input_mask=input_mask, 134 | segment_ids=segment_ids, 135 | label_ids=[example.label[t] for t in task_list], 136 | )) 137 | 138 | print('Max Sequence Length: %d' %max_len) 139 | 140 | return features 141 | 142 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 143 | """Truncates a sequence pair in place to the maximum length.""" 144 | 145 | # This is a simple heuristic which will always truncate the longer sequence 146 | # one token at a time. This makes more sense than truncating an equal percent 147 | # of tokens from each, since if one sequence is very short then each token 148 | # that's truncated likely contains more information than a longer sequence. 149 | while True: 150 | total_length = len(tokens_a) + len(tokens_b) 151 | if total_length <= max_length: 152 | break 153 | if len(tokens_a) > len(tokens_b): 154 | tokens_a.pop() 155 | else: 156 | tokens_b.pop() -------------------------------------------------------------------------------- /chexpert_approximator/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pytorch_pretrained_bert.modeling import BertPreTrainedModel, BertModel 4 | from torch import nn 5 | 6 | def multi_task_accuracy(out_by_task, labels_by_task): 7 | acc = 0 8 | for task in out_by_task.keys(): 9 | out, labels = out_by_task[task], labels_by_task[task] 10 | outputs = np.argmax(out, axis=1) 11 | acc += np.sum(outputs == labels) 12 | return acc / len(out_by_task) 13 | 14 | # TODO(mmd): should probably not do via a dictionary. Possible to do all as tensor. 15 | class BertForMultitaskSequenceClassification(BertPreTrainedModel): 16 | """BERT model for multitask sequence classification. 17 | This module is composed of the BERT model with a linear layer per task on top of 18 | the pooled output. 19 | Params: 20 | `config`: a BertConfig class instance with the configuration to build a new model. 21 | `num_labels_per_task`: A dictionary from task_name to the number of classes for the classifier. 22 | Inputs: 23 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 24 | with the word token indices in the vocabulary. Items in the batch should begin with the special "CLS" token. (see the tokens preprocessing logic in the scripts 25 | `extract_features.py`, `run_classifier.py` and `run_squad.py`) 26 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 27 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 28 | a `sentence B` token (see BERT paper for more details). 29 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 30 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 31 | input sequence length in the current batch. It's the mask that we typically use for attention when 32 | a batch has varying length sentences. 33 | `labels_per_task`: A dictionary of task name to labels for the classification output: 34 | {task: torch.LongTensor of shape [batch_size] with indices selected in [0, ..., num_labels_per_task[task]]}. 
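            For example (illustrative): labels_per_task = {'Cardiomegaly':
            torch.LongTensor([3, 0]), 'Edema': torch.LongTensor([0, 0])} for a
            batch of two sentences.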
35 |     Outputs:
36 |         TODO(mmd): update outputs & example usage.
37 |         Outputs the classification logits per task as a dictionary of shape {task: [batch_size, num_labels]}.
38 |         Outputs the CrossEntropy classification loss as a dictionary of shape {task: loss} for tasks with provided labels.
39 |     Example usage:
40 |     ```python
41 |     # Already been converted into WordPiece token ids
42 |     input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
43 |     input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
44 |     token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
45 |     config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
46 |         num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
47 |     num_labels_per_task = {'Cardiomegaly': 4, 'Edema': 4}
48 |     model = BertForMultitaskSequenceClassification(config, num_labels_per_task)
49 |     logits, losses = model(input_ids, token_type_ids, input_mask)
50 |     ```
51 |     """
52 |     def __init__(self, config, num_labels_per_task):
53 |         super().__init__(config)
54 |         self.num_labels_per_task = num_labels_per_task
55 |         self.bert = BertModel(config)
56 |         self.dropout = nn.Dropout(config.hidden_dropout_prob)
57 |         self.classifiers = nn.ModuleDict(
58 |             {task_name: nn.Linear(config.hidden_size, task_dim) for task_name, task_dim in num_labels_per_task.items()}
59 |         )
60 |         self.apply(self.init_bert_weights)
61 |
62 |     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels_per_task={}):
63 |         _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
64 |         pooled_output = self.dropout(pooled_output)
65 |         logits = {task: layer(pooled_output) for task, layer in self.classifiers.items()}
66 |
67 |         losses = {}
68 |         for task, labels in labels_per_task.items():
69 |             losses[task] = nn.CrossEntropyLoss()(logits[task].view(-1, self.num_labels_per_task[task]), labels.view(-1))
70 |
71 |         return logits, losses
72 |
--------------------------------------------------------------------------------
/chexpert_approximator/reload_and_get_logits.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
 3 | #
 4 | # Licensed under the Apache License, Version 2.0 (the "License");
 5 | # you may not use this file except in compliance with the License.
 6 | # You may obtain a copy of the License at
 7 | #
 8 | #     http://www.apache.org/licenses/LICENSE-2.0
 9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
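# This module reloads a fine-tuned chexpert++ checkpoint (the WEIGHTS_NAME /
# CONFIG_NAME pair written by build_and_train in run_classifier.py) and runs it
# over a DataFrame of sentences, returning the model, the per-task logits, and
# summary eval metrics.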
15 |
16 | import os
17 | import random
18 |
19 | import numpy as np
20 | import torch
21 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
22 | from torch.utils.data.distributed import DistributedSampler
23 | from tqdm import tqdm, tqdm_notebook
24 |
25 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
26 | from pytorch_pretrained_bert.modeling import BertConfig, WEIGHTS_NAME, CONFIG_NAME
27 | from pytorch_pretrained_bert.tokenization import BertTokenizer
28 | from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
29 |
30 | #added
31 | import json
32 | from random import shuffle
33 | from sklearn.model_selection import KFold
34 | import math
35 |
36 | from .data_processor import *
37 | from .model import *
38 |
39 | def reload_and_get_logits(
40 |     df,
41 |     bert_model,
42 |     processor,
43 |     task_dimensions,
44 |     output_dir,
45 |     processor_args = {},
46 |     gpu = None,
47 |     seed = 42,
48 |     do_lower_case = False,
49 |     max_seq_length = 128,
50 |     batch_size = 8,
51 |     learning_rate = 5e-5,
52 |     num_train_epochs = 3,
53 |     cache_dir = None,
54 |     tqdm = tqdm_notebook,
55 |     model = None,
56 | ):
57 |     print(bert_model)
58 |
59 |     if gpu is not None and torch.cuda.is_available():
60 |         os.environ['CUDA_VISIBLE_DEVICES'] = gpu
61 |         device = torch.device("cuda")
62 |         n_gpu = torch.cuda.device_count()
63 |     else:
64 |         device = torch.device("cpu")
65 |         n_gpu = 0
66 |
67 |     random.seed(seed)
68 |     np.random.seed(seed)
69 |     torch.manual_seed(seed)
70 |     if n_gpu > 0:
71 |         torch.cuda.manual_seed_all(seed)
72 |
73 |     if not os.path.exists(output_dir) or len(os.listdir(output_dir)) == 0:
74 |         raise ValueError("Output directory ({}) doesn't exist or is empty.".format(output_dir))
75 |
76 |     label_map = processor.get_labels()
77 |     task_list = list(label_map.keys())
78 |     num_tasks = len(label_map)
79 |     tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=do_lower_case)
80 |
81 |     # Prepare model
82 |     ### TODO(mmd): this is where to reload the model properly.
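    # Reload flow implemented below: read the fine-tuned config JSON into a
    # BertConfig, rebuild the multitask model with one linear head per task,
    # then restore the trained weights with load_state_dict.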
83 | cache_dir = cache_dir if cache_dir else os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_-1') 84 | # Load a trained model and config that you have fine-tuned 85 | output_model_file = os.path.join(output_dir, WEIGHTS_NAME) 86 | output_config_file = os.path.join(output_dir, CONFIG_NAME) 87 | config = BertConfig(output_config_file) 88 | 89 | if model is None: 90 | model = BertForMultitaskSequenceClassification( 91 | config, 92 | num_labels_per_task = {task: len(labels) for task, labels in label_map.items()}, 93 | ) 94 | model.load_state_dict(torch.load(output_model_file)) 95 | 96 | model.to(device) 97 | 98 | eval_examples = processor.get_examples(df, **processor_args) 99 | eval_features = convert_examples_to_features( 100 | eval_examples, task_list, max_seq_length, tokenizer 101 | ) 102 | 103 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 104 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 105 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 106 | all_label_ids = torch.tensor([f.label_ids for f in eval_features], dtype=torch.long) 107 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 108 | # Run prediction for full data 109 | eval_sampler = SequentialSampler(eval_data) 110 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=batch_size) 111 | 112 | model.eval() 113 | eval_loss, eval_accuracy = 0, 0 114 | nb_eval_steps, nb_eval_examples = 0, 0 115 | 116 | all_logits = {t: [] for t in task_list} 117 | for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"): 118 | input_ids = input_ids.to(device) 119 | input_mask = input_mask.to(device) 120 | segment_ids = segment_ids.to(device) 121 | label_ids = label_ids.to(device) 122 | label_ids = {t: label_ids[:, i] for i, t in enumerate(task_list)} 123 | 124 | with torch.no_grad(): 125 | logits, tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids) 126 | # logits = model(input_ids, segment_ids, input_mask) 127 | 128 | logits = {t: lt.detach().cpu().numpy() for t, lt in logits.items()} 129 | for t in task_list: all_logits[t].extend(list(logits[t])) 130 | label_ids = {t: l.to('cpu').numpy() for t, l in label_ids.items()} 131 | 132 | tmp_eval_accuracy = multi_task_accuracy(logits, label_ids) 133 | 134 | tmp_eval_loss = sum(tmp_eval_loss.values()) 135 | eval_loss += tmp_eval_loss.mean().item() 136 | eval_accuracy += tmp_eval_accuracy 137 | 138 | nb_eval_examples += input_ids.size(0) 139 | nb_eval_steps += 1 140 | 141 | eval_loss = eval_loss / nb_eval_steps 142 | eval_accuracy = eval_accuracy / nb_eval_examples 143 | result = {'eval_loss': eval_loss, 144 | 'eval_accuracy': eval_accuracy} 145 | 146 | output_eval_file = os.path.join(output_dir, "eval_results.txt") 147 | with open(output_eval_file, "w") as writer: 148 | for key in sorted(result.keys()): writer.write("%s = %s\n" % (key, str(result[key]))) 149 | return model, all_logits, result 150 | -------------------------------------------------------------------------------- /chexpert_approximator/run_classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | import random 18 | 19 | import numpy as np 20 | import torch 21 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset) 22 | from torch.utils.data.distributed import DistributedSampler 23 | from tqdm import tqdm, tqdm_notebook 24 | 25 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE 26 | from pytorch_pretrained_bert.modeling import BertConfig, WEIGHTS_NAME, CONFIG_NAME 27 | from pytorch_pretrained_bert.tokenization import BertTokenizer 28 | from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear 29 | 30 | #added 31 | import json 32 | from random import shuffle 33 | from sklearn.model_selection import KFold 34 | import math 35 | 36 | from .data_processor import * 37 | from .model import * 38 | 39 | def build_and_train( 40 | df, 41 | bert_model, 42 | processor, 43 | task_dimensions, 44 | output_dir, 45 | gradient_accumulation_steps = 1, 46 | gpu = None, 47 | do_train = True, 48 | do_eval = False, 49 | seed = 42, 50 | do_lower_case = False, 51 | max_seq_length = 128, 52 | train_batch_size = 32, 53 | eval_batch_size = 8, 54 | learning_rate = 5e-5, 55 | num_train_epochs = 3, 56 | warmup_proportion = 0.1, 57 | cache_dir = None, 58 | tqdm = tqdm_notebook, 59 | model = None, 60 | ): 61 | print(bert_model) 62 | 63 | if gpu is not None and torch.cuda.is_available(): 64 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu 65 | device = torch.device("cuda") 66 | n_gpu = torch.cuda.device_count() 67 | else: 68 | device = torch.device("cpu") 69 | n_gpu = 0 70 | 71 | if gradient_accumulation_steps < 1: 72 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( 73 | gradient_accumulation_steps)) 74 | 75 | train_batch_size = train_batch_size // gradient_accumulation_steps 76 | 77 | random.seed(seed) 78 | np.random.seed(seed) 79 | torch.manual_seed(seed) 80 | if n_gpu > 0: 81 | torch.cuda.manual_seed_all(seed) 82 | 83 | if not do_train and not do_eval: 84 | raise ValueError("At least one of `do_train` or `do_eval` must be True.") 85 | 86 | if os.path.exists(output_dir) and os.listdir(output_dir) and do_train: 87 | raise ValueError("Output directory ({}) already exists and is not empty.".format(output_dir)) 88 | if not os.path.exists(output_dir): 89 | os.makedirs(output_dir) 90 | 91 | label_map = processor.get_labels() 92 | task_list = list(label_map.keys()) 93 | num_tasks = len(label_map) 94 | tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=do_lower_case) 95 | 96 | #TESTING 97 | #splits = processor.create_cv_examples(data_dir) 98 | 99 | 100 | train_examples = None 101 | num_train_optimization_steps = None 102 | if do_train: 103 | train_examples = processor.get_train_examples(df) 104 | num_train_optimization_steps = int( 105 | len(train_examples) / train_batch_size / gradient_accumulation_steps) * num_train_epochs 106 | 107 | # Prepare model 108 | cache_dir = cache_dir 
if cache_dir else os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_-1') 109 | 110 | if model is None: 111 | model = BertForMultitaskSequenceClassification.from_pretrained( 112 | bert_model, 113 | cache_dir=cache_dir, 114 | num_labels_per_task = {task: len(labels) for task, labels in label_map.items()}, 115 | ) 116 | 117 | model.to(device) 118 | if n_gpu > 1: model = torch.nn.DataParallel(model) 119 | 120 | # Prepare optimizer 121 | param_optimizer = list(model.named_parameters()) 122 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 123 | optimizer_grouped_parameters = [ 124 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 125 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 126 | ] 127 | 128 | optimizer = BertAdam( 129 | optimizer_grouped_parameters, 130 | lr=learning_rate, 131 | warmup=warmup_proportion, 132 | t_total=num_train_optimization_steps 133 | ) 134 | 135 | global_step = 0 136 | nb_tr_steps = 0 137 | tr_loss = 0 138 | if do_train: 139 | train_features = convert_examples_to_features( 140 | train_examples, task_list, max_seq_length, tokenizer 141 | ) 142 | 143 | all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) 144 | all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) 145 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) 146 | all_label_ids = torch.tensor([f.label_ids for f in train_features], dtype=torch.long) 147 | 148 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 149 | train_sampler = RandomSampler(train_data) 150 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=train_batch_size) 151 | 152 | model.train() 153 | for _ in tqdm(range(int(num_train_epochs)), desc="Epoch"): 154 | tr_loss = 0 155 | nb_tr_examples, nb_tr_steps = 0, 0 156 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): 157 | batch = tuple(t.to(device) for t in batch) 158 | input_ids, input_mask, segment_ids, label_ids = batch 159 | 160 | label_ids = {t: label_ids[:, i] for i, t in enumerate(task_list)} 161 | 162 | _, losses = model(input_ids, segment_ids, input_mask, label_ids) 163 | loss = sum(losses.values()) / num_tasks 164 | 165 | if n_gpu > 1: 166 | loss = loss.mean() # mean() to average on multi-gpu. 167 | if gradient_accumulation_steps > 1: 168 | loss = loss / gradient_accumulation_steps 169 | 170 | loss.backward() 171 | 172 | tr_loss += loss.item() 173 | nb_tr_examples += input_ids.size(0) 174 | nb_tr_steps += 1 175 | if (step + 1) % gradient_accumulation_steps == 0: 176 | optimizer.step() 177 | optimizer.zero_grad() 178 | global_step += 1 179 | #### HERE NEED TO MODIFY. 
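    # The checkpoint written below (the WEIGHTS_NAME + CONFIG_NAME pair in
    # output_dir) is the same file pair that reload_and_get_logits() later
    # reads back to rebuild the model.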
180 | 181 | if do_train: 182 | # Save a trained model and the associated configuration 183 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self 184 | output_model_file = os.path.join(output_dir, WEIGHTS_NAME) 185 | torch.save(model_to_save.state_dict(), output_model_file) 186 | output_config_file = os.path.join(output_dir, CONFIG_NAME) 187 | with open(output_config_file, 'w') as f: 188 | f.write(model_to_save.config.to_json_string()) 189 | 190 | # Load a trained model and config that you have fine-tuned 191 | config = BertConfig(output_config_file) 192 | model = BertForMultitaskSequenceClassification( 193 | config, 194 | num_labels_per_task = {task: len(labels) for task, labels in label_map.items()}, 195 | ) 196 | model.load_state_dict(torch.load(output_model_file)) 197 | else: 198 | model = BertForMultitaskSequenceClassification.from_pretrained( 199 | bert_model, 200 | num_labels_per_task = {task: len(labels) for task, labels in label_map.items()}, 201 | ) 202 | model.to(device) 203 | 204 | if do_eval: 205 | eval_examples = processor.get_dev_examples(df) 206 | eval_features = convert_examples_to_features( 207 | eval_examples, task_list, max_seq_length, tokenizer 208 | ) 209 | 210 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long) 211 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long) 212 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long) 213 | all_label_ids = torch.tensor([f.label_ids for f in eval_features], dtype=torch.long) 214 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 215 | # Run prediction for full data 216 | eval_sampler = SequentialSampler(eval_data) 217 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=eval_batch_size) 218 | 219 | model.eval() 220 | eval_loss, eval_accuracy = 0, 0 221 | nb_eval_steps, nb_eval_examples = 0, 0 222 | 223 | for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"): 224 | input_ids = input_ids.to(device) 225 | input_mask = input_mask.to(device) 226 | segment_ids = segment_ids.to(device) 227 | label_ids = label_ids.to(device) 228 | 229 | label_ids = {t: label_ids[:, i] for i, t in enumerate(task_list)} 230 | 231 | with torch.no_grad(): 232 | logits, tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids) 233 | # logits = model(input_ids, segment_ids, input_mask) 234 | 235 | logits = {t: lt.detach().cpu().numpy() for t, lt in logits.items()} 236 | label_ids = {t: l.to('cpu').numpy() for t, l in label_ids.items()} 237 | 238 | tmp_eval_accuracy = multi_task_accuracy(logits, label_ids) 239 | 240 | tmp_eval_loss = sum(tmp_eval_loss.values()) 241 | eval_loss += tmp_eval_loss.mean().item() 242 | eval_accuracy += tmp_eval_accuracy 243 | 244 | nb_eval_examples += input_ids.size(0) 245 | nb_eval_steps += 1 246 | 247 | eval_loss = eval_loss / nb_eval_steps 248 | eval_accuracy = eval_accuracy / nb_eval_examples 249 | loss = tr_loss/nb_tr_steps if do_train else None 250 | result = {'eval_loss': eval_loss, 251 | 'eval_accuracy': eval_accuracy, 252 | 'global_step': global_step, 253 | 'loss': loss} 254 | 255 | output_eval_file = os.path.join(output_dir, "eval_results.txt") 256 | with open(output_eval_file, "w") as writer: 257 | for key in sorted(result.keys()): writer.write("%s = %s\n" % (key, str(result[key]))) 258 | return result 259 | 
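# Example end-to-end usage (a minimal sketch, not part of the original module;
# `inputs` and `label_cols` come from the training notebook, the processor is
# the notebook's CheXpertProcessor, and all paths are illustrative):
#
#     processor = CheXpertProcessor(tuning_fold=8, held_out_fold=9)
#     result = build_and_train(
#         inputs,
#         bert_model='/path/to/clinical_bert',
#         processor=processor,
#         task_dimensions={l: 4 for l in label_cols},
#         output_dir='out/run_1',
#         do_train=True,
#         do_eval=True,
#     )
#     model, all_logits, metrics = reload_and_get_logits(
#         inputs,
#         bert_model='/path/to/clinical_bert',
#         processor=processor,
#         task_dimensions={l: 4 for l in label_cols},
#         output_dir='out/run_1',
#         processor_args={'folds': [9]},  # score the held-out fold
#     )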
-------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - attrs=19.1.0=py_0 8 | - backcall=0.1.0=py_0 9 | - bleach=3.1.0=py_0 10 | - bzip2=1.0.6=h14c3975_1002 11 | - ca-certificates=2019.3.9=hecc5488_0 12 | - certifi=2019.3.9=py37_0 13 | - cffi=1.12.3=py37h8022711_0 14 | - cycler=0.10.0=py_1 15 | - decorator=4.4.0=py_0 16 | - defusedxml=0.5.0=py_1 17 | - entrypoints=0.3=py37_1000 18 | - expat=2.2.5=hf484d3e_1002 19 | - fontconfig=2.13.1=he4413a7_1000 20 | - freetype=2.10.0=he983fc9_0 21 | - future=0.17.1=py37_1000 22 | - gettext=0.19.8.1=hc5be6a0_1002 23 | - glib=2.56.2=had28632_1001 24 | - hyperopt=0.1.2=py_0 25 | - icu=58.2=hf484d3e_1000 26 | - ipykernel=5.1.0=py37h24bf2e0_1002 27 | - ipython=7.4.0=py37h24bf2e0_0 28 | - ipython_genutils=0.2.0=py_1 29 | - ipywidgets=7.4.2=py_0 30 | - jedi=0.13.3=py37_0 31 | - jinja2=2.10.1=py_0 32 | - jpeg=9c=h14c3975_1001 33 | - jsonschema=3.0.1=py37_0 34 | - jupyter=1.0.0=py_2 35 | - jupyter_client=5.2.4=py_3 36 | - jupyter_console=6.0.0=py_0 37 | - jupyter_contrib_core=0.3.3=py_2 38 | - jupyter_contrib_nbextensions=0.5.1=py37_0 39 | - jupyter_core=4.4.0=py_0 40 | - jupyter_highlight_selected_word=0.2.0=py37_1000 41 | - jupyter_latex_envs=1.4.4=py37_1000 42 | - jupyter_nbextensions_configurator=0.4.1=py37_0 43 | - kiwisolver=1.0.1=py37h6bb024c_1002 44 | - libblas=3.8.0=5_openblas 45 | - libcblas=3.8.0=5_openblas 46 | - libffi=3.2.1=he1b5a44_1006 47 | - libiconv=1.15=h516909a_1005 48 | - liblapack=3.8.0=5_openblas 49 | - libpng=1.6.37=hed695b0_0 50 | - libsodium=1.0.16=h14c3975_1001 51 | - libtiff=4.0.10=h648cc4a_1001 52 | - libuuid=2.32.1=h14c3975_1000 53 | - libxcb=1.13=h14c3975_1002 54 | - libxml2=2.9.9=h13577e0_0 55 | - libxslt=1.1.32=h4785a14_1002 56 | - lxml=4.3.3=py37h7ec2d77_0 57 | - markupsafe=1.1.1=py37h14c3975_0 58 | - matplotlib=3.0.3=py37_1 59 | - matplotlib-base=3.0.3=py37h5f35d83_1 60 | - mistune=0.8.4=py37h14c3975_1000 61 | - nbconvert=5.4.1=py_2 62 | - nbformat=4.4.0=py_1 63 | - ncurses=6.1=hf484d3e_1002 64 | - networkx=2.3=py_0 65 | - ninja=1.9.0=h6bb024c_0 66 | - notebook=5.7.8=py37_0 67 | - numpy=1.16.3=py37he5ce36f_0 68 | - olefile=0.46=py_0 69 | - openblas=0.3.5=h9ac9557_1001 70 | - openssl=1.1.1b=h14c3975_1 71 | - pandas=0.24.2=py37hf484d3e_0 72 | - pandoc=2.7.2=0 73 | - pandocfilters=1.4.2=py_1 74 | - parso=0.4.0=py_0 75 | - pexpect=4.7.0=py37_0 76 | - pickleshare=0.7.5=py37_1000 77 | - pillow=6.0.0=py37he7afcd5_0 78 | - pip=19.0.3=py37_0 79 | - prometheus_client=0.6.0=py_0 80 | - prompt_toolkit=2.0.9=py_0 81 | - pthread-stubs=0.4=h14c3975_1001 82 | - ptyprocess=0.6.0=py_1001 83 | - pycparser=2.19=py37_1 84 | - pygments=2.3.1=py_0 85 | - pymongo=3.7.2=py37hf484d3e_0 86 | - pyparsing=2.4.0=py_0 87 | - pyqt=5.6.0=py37h13b7fb3_1008 88 | - pyrsistent=0.14.11=py37h14c3975_0 89 | - python=3.7.3=h5b0a415_0 90 | - python-dateutil=2.8.0=py_0 91 | - pytz=2019.1=py_0 92 | - pyyaml=5.1=py37h14c3975_0 93 | - pyzmq=18.0.1=py37hc4ba49a_1 94 | - qtconsole=4.4.3=py_0 95 | - readline=7.0=hf8c457e_1001 96 | - scikit-learn=0.20.3=py37ha8026db_1 97 | - scipy=1.2.1=py37h09a28d5_1 98 | - send2trash=1.5.0=py_0 99 | - setuptools=41.0.1=py37_0 100 | - sip=4.18.1=py37hf484d3e_1000 101 | - six=1.12.0=py37_1000 102 | - sqlite=3.26.0=h67949de_1001 103 | - terminado=0.8.2=py37_0 104 | - testpath=0.4.2=py_1001 105 | - tk=8.6.9=h84994c4_1001 106 | - 
tornado=6.0.2=py37h516909a_0 107 | - tqdm=4.31.1=py_0 108 | - traitlets=4.3.2=py37_1000 109 | - wcwidth=0.1.7=py_1 110 | - webencodings=0.5.1=py_1 111 | - wheel=0.33.1=py37_0 112 | - widgetsnbextension=3.4.2=py37_1000 113 | - xorg-libxau=1.0.9=h14c3975_0 114 | - xorg-libxdmcp=1.1.3=h516909a_0 115 | - xz=5.2.4=h14c3975_1001 116 | - yaml=0.1.7=h14c3975_1001 117 | - zeromq=4.3.1=hf484d3e_1000 118 | - zlib=1.2.11=h14c3975_1004 119 | - cudatoolkit=9.0=h13b8566_0 120 | - dbus=1.13.2=h714fa37_1 121 | - gst-plugins-base=1.14.0=hbbd80ab_1 122 | - gstreamer=1.14.0=hb453b48_1 123 | - intel-openmp=2019.3=199 124 | - libgcc-ng=8.2.0=hdf63c60_1 125 | - libgfortran-ng=7.3.0=hdf63c60_0 126 | - libstdcxx-ng=8.2.0=hdf63c60_1 127 | - mkl=2019.3=199 128 | - pcre=8.43=he6710b0_0 129 | - qt=5.6.3=h8bf5577_3 130 | - pytorch=1.0.1=py3.7_cuda9.0.176_cudnn7.4.2_2 131 | - torchvision=0.2.2=py_3 132 | - pip: 133 | - boto3==1.9.134 134 | - botocore==1.12.134 135 | - chardet==3.0.4 136 | - docutils==0.14 137 | - idna==2.8 138 | - jmespath==0.9.4 139 | - pytorch-pretrained-bert==0.6.1 140 | - regex==2019.4.14 141 | - requests==2.21.0 142 | - s3transfer==0.2.0 143 | - torch==1.0.1.post2 144 | - urllib3==1.24.2 145 | prefix: /scratch/conda_envs/chexpert_approximator 146 | --------------------------------------------------------------------------------
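For reference: the training notebook expects `inputs.hdf` (key `with_folds`) to hold a DataFrame with a single `sentence` column and a MultiIndex over `rad_id`, the fourteen CheXpert observation columns (holding the string values from `KEY`), and a `Fold` level. A minimal sketch of building such a frame from already-labeled sentences follows; the fourteen column names, the `KEY` strings, and the ten-fold split mirror the notebook, while the toy `labeled` frame and the output path are illustrative assumptions (`to_hdf` also requires PyTables).
```python
import pandas as pd

KEY = {0: 'No Mention', 1: 'Uncertain Mention', 2: 'Negative Mention', 3: 'Positive Mention'}
LABEL_COLS = [
    'No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Lesion',
    'Airspace Opacity', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis',
    'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture', 'Support Devices',
]

# `labeled`: one row per report sentence, with integer CheXpert codes (0-3) per
# observation column -- an illustrative stand-in for real labeler output.
labeled = pd.DataFrame({
    'rad_id': ['r0', 'r1'],
    'sentence': ['No acute cardiopulmonary process.', 'Mild cardiomegaly.'],
    **{c: [0, 0] for c in LABEL_COLS},
})
labeled.loc[1, 'Cardiomegaly'] = 3

# The notebook's index stores the human-readable KEY strings, not the raw codes.
for c in LABEL_COLS:
    labeled[c] = labeled[c].map(KEY)

labeled['Fold'] = labeled.index % 10  # fold ids 0-9, matching tuning_fold=8 / held_out_fold=9

inputs = labeled.set_index(['rad_id'] + LABEL_COLS + ['Fold'])[['sentence']]
inputs.to_hdf('processed_data/inputs.hdf', key='with_folds')
```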