├── .gitignore
├── Jupyter Notebooks
├── Active Learning.ipynb
├── Get Results.ipynb
└── Train the Model.ipynb
├── README.md
├── chexpert_approximator
├── .gitignore
├── __init__.py
├── data_processor.py
├── model.py
├── reload_and_get_logits.py
└── run_classifier.py
└── env.yml
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # out
107 | out/
108 |
109 | # tags
110 | tags
111 |
--------------------------------------------------------------------------------
/Jupyter Notebooks/Train the Model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Imports"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "%load_ext autoreload\n",
17 | "import sys\n",
18 | "\n",
19 | "sys.path.append('..')"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 2,
25 | "metadata": {},
26 | "outputs": [
27 | {
28 | "name": "stdout",
29 | "output_type": "stream",
30 | "text": [
31 | "Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.\n"
32 | ]
33 | }
34 | ],
35 | "source": [
36 | "%autoreload\n",
37 | "\n",
38 | "import itertools, os, pickle, pandas as pd\n",
39 | "from collections import Counter\n",
40 | "\n",
41 | "from chexpert_approximator.data_processor import *\n",
42 | "from chexpert_approximator.run_classifier import *\n",
43 | "\n",
44 | "from chexpert_approximator.reload_and_get_logits import *"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 3,
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "# We can't display MIMIC-CXR Output:\n",
54 | "\n",
55 | "DO_BLIND = True\n",
56 | "def blind_display(df):\n",
57 | " if DO_BLIND: \n",
58 | " df = df.copy()\n",
59 | " index_levels = df.index.names\n",
60 | " df.reset_index('rad_id', inplace=True)\n",
61 | " df['rad_id'] = [0 for _ in df['rad_id']]\n",
62 | " df.set_index('rad_id', append=True, inplace=True)\n",
63 | " df = df.reorder_levels(index_levels, axis=0)\n",
64 | "\n",
65 | " for c in df.columns:\n",
66 | " if pd.api.types.is_string_dtype(df[c]): df[c] = ['SAMPLE' for _ in df[c]]\n",
67 | " else: df[c] = np.NaN\n",
68 | "\n",
69 | " display(df.head())"
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "metadata": {},
75 | "source": [
76 | "# Load the Data"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": 4,
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "DATA_DIR = '/scratch/chexpert_approximator/processed_data/' # INSERT YOUR DATA DIR HERE!\n",
86 | "# DATA MUST BE STORED IN A FILE `inputs.hdf` under key `with_folds`.\n",
87 | "INPUT_PATH, INPUT_KEY = os.path.join(DATA_DIR, 'inputs.hdf'), 'with_folds'\n",
88 | "\n",
89 | "# YOUR CLINICAL BERT MODEL GOES HERE\n",
90 | "BERT_MODEL_PATH = (\n",
91 | " '/data/medg/misc/clinical_BERT/cliniBERT/models/pretrained_bert_tf/bert_pretrain_output_all_notes_300000/'\n",
92 | ")\n",
93 | "\n",
94 | "# THIS IS WHERE YOUR PRE-TRAINED CHEXPERT++ MODEL WILL BE WRITTEN\n",
95 | "OUT_CXPPP_DIR = '../out/run_1'\n",
96 | "\n",
97 | "# DON'T MODIFY THESE\n",
98 | "FOLD = 'Fold'\n",
99 | "\n",
100 | "KEY = {\n",
101 | " 0: 'No Mention',\n",
102 | " 1: 'Uncertain Mention',\n",
103 | " 2: 'Negative Mention',\n",
104 | " 3: 'Positive Mention',\n",
105 | "}\n",
106 | "INV_KEY = {v: k for k, v in KEY.items()}"
107 | ]
108 | },
109 | {
110 | "cell_type": "code",
111 | "execution_count": 5,
112 | "metadata": {},
113 | "outputs": [],
114 | "source": [
115 | "inputs = pd.read_hdf(INPUT_PATH, INPUT_KEY)\n",
116 | "label_cols = [col for col in inputs.index.names if col not in ('rad_id', FOLD)]"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": 6,
122 | "metadata": {},
123 | "outputs": [
124 | {
125 | "data": {
126 | "text/html": [
127 | "
\n",
128 | "\n",
141 | "
\n",
142 | " \n",
143 | " \n",
144 | " | \n",
145 | " | \n",
146 | " | \n",
147 | " | \n",
148 | " | \n",
149 | " | \n",
150 | " | \n",
151 | " | \n",
152 | " | \n",
153 | " | \n",
154 | " | \n",
155 | " | \n",
156 | " | \n",
157 | " | \n",
158 | " | \n",
159 | " | \n",
160 | " sentence | \n",
161 | "
\n",
162 | " \n",
163 | " rad_id | \n",
164 | " No Finding | \n",
165 | " Enlarged Cardiomediastinum | \n",
166 | " Cardiomegaly | \n",
167 | " Lung Lesion | \n",
168 | " Airspace Opacity | \n",
169 | " Edema | \n",
170 | " Consolidation | \n",
171 | " Pneumonia | \n",
172 | " Atelectasis | \n",
173 | " Pneumothorax | \n",
174 | " Pleural Effusion | \n",
175 | " Pleural Other | \n",
176 | " Fracture | \n",
177 | " Support Devices | \n",
178 | " Fold | \n",
179 | " | \n",
180 | "
\n",
181 | " \n",
182 | " \n",
183 | " \n",
184 | " 0 | \n",
185 | " Positive Mention | \n",
186 | " No Mention | \n",
187 | " No Mention | \n",
188 | " No Mention | \n",
189 | " No Mention | \n",
190 | " No Mention | \n",
191 | " No Mention | \n",
192 | " No Mention | \n",
193 | " No Mention | \n",
194 | " No Mention | \n",
195 | " No Mention | \n",
196 | " No Mention | \n",
197 | " No Mention | \n",
198 | " No Mention | \n",
199 | " 6 | \n",
200 | " SAMPLE | \n",
201 | "
\n",
202 | " \n",
203 | " Negative Mention | \n",
204 | " No Mention | \n",
205 | " No Mention | \n",
206 | " No Mention | \n",
207 | " No Mention | \n",
208 | " No Mention | \n",
209 | " No Mention | \n",
210 | " No Mention | \n",
211 | " No Mention | \n",
212 | " No Mention | \n",
213 | " No Mention | \n",
214 | " No Mention | \n",
215 | " No Mention | \n",
216 | " 6 | \n",
217 | " SAMPLE | \n",
218 | "
\n",
219 | " \n",
220 | " No Mention | \n",
221 | " No Mention | \n",
222 | " No Mention | \n",
223 | " No Mention | \n",
224 | " No Mention | \n",
225 | " No Mention | \n",
226 | " No Mention | \n",
227 | " No Mention | \n",
228 | " No Mention | \n",
229 | " No Mention | \n",
230 | " No Mention | \n",
231 | " No Mention | \n",
232 | " No Mention | \n",
233 | " 6 | \n",
234 | " SAMPLE | \n",
235 | "
\n",
236 | " \n",
237 | " 4 | \n",
238 | " SAMPLE | \n",
239 | "
\n",
240 | " \n",
241 | " No Mention | \n",
242 | " No Mention | \n",
243 | " No Mention | \n",
244 | " No Mention | \n",
245 | " Positive Mention | \n",
246 | " No Mention | \n",
247 | " No Mention | \n",
248 | " No Mention | \n",
249 | " No Mention | \n",
250 | " No Mention | \n",
251 | " No Mention | \n",
252 | " No Mention | \n",
253 | " No Mention | \n",
254 | " No Mention | \n",
255 | " 4 | \n",
256 | " SAMPLE | \n",
257 | "
\n",
258 | " \n",
259 | "
\n",
260 | "
"
261 | ],
262 | "text/plain": [
263 | " sentence\n",
264 | "rad_id No Finding Enlarged Cardiomediastinum Cardiomegaly Lung Lesion Airspace Opacity Edema Consolidation Pneumonia Atelectasis Pneumothorax Pleural Effusion Pleural Other Fracture Support Devices Fold \n",
265 | "0 Positive Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention 6 SAMPLE\n",
266 | " Negative Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention 6 SAMPLE\n",
267 | " No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention 6 SAMPLE\n",
268 | " 4 SAMPLE\n",
269 | " No Mention No Mention No Mention No Mention Positive Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention No Mention 4 SAMPLE"
270 | ]
271 | },
272 | "metadata": {},
273 | "output_type": "display_data"
274 | }
275 | ],
276 | "source": [
277 | "blind_display(inputs)"
278 | ]
279 | },
280 | {
281 | "cell_type": "markdown",
282 | "metadata": {},
283 | "source": [
284 | "# Model\n",
285 | "## Data Processor"
286 | ]
287 | },
288 | {
289 | "cell_type": "code",
290 | "execution_count": 44,
291 | "metadata": {},
292 | "outputs": [],
293 | "source": [
294 | "class CheXpertProcessor(DataProcessor):\n",
295 | " def __init__(self, tuning_fold, held_out_fold):\n",
296 | " super().__init__()\n",
297 | " self.tuning_fold, self.held_out_fold = tuning_fold, held_out_fold\n",
298 | " \n",
299 | " \"\"\"Processor for the CheXpert approximator.\n",
300 | " Honestly this is kind of silly, as it never stores internal state.\"\"\"\n",
301 | " def get_train_examples(self, df): return self._create_examples(\n",
302 | " df, set([f for f in range(K) if f not in (self.tuning_fold, self.held_out_fold)])\n",
303 | " )\n",
304 | " def get_dev_examples(self, df): return self._create_examples(df, set([self.tuning_fold]))\n",
305 | " def get_examples(self, df, folds=[]): return self._create_examples(df, set(folds))\n",
306 | " \n",
307 | " def get_labels(self): return {label: list(range(4)) for label in label_cols}\n",
308 | "\n",
309 | " def _create_examples(self, df, folds):\n",
310 | " \"\"\"Creates examples for the training and dev sets.\"\"\"\n",
311 | " df = df[df.index.get_level_values(FOLD).isin(folds)]\n",
312 | " lmap = {l: i for i, l in enumerate(df.index.names)}\n",
313 | " \n",
314 | " examples = []\n",
315 | " for idx, r in df.iterrows():\n",
316 | " labels = {l: INV_KEY[idx[lmap[l]]] for l in label_cols}\n",
317 | " \n",
318 | " examples.append(InputExample(\n",
319 | " guid=str(idx[lmap['rad_id']]), text_a=r.sentence, text_b=None, label=labels\n",
320 | " ))\n",
321 | " return examples"
322 | ]
323 | },
324 | {
325 | "cell_type": "markdown",
326 | "metadata": {},
327 | "source": [
328 | "## Running the Model"
329 | ]
330 | },
331 | {
332 | "cell_type": "code",
333 | "execution_count": 45,
334 | "metadata": {},
335 | "outputs": [],
336 | "source": [
337 | "processor = CheXpertProcessor(8, 9)\n",
338 | "# train_examples = processor.get_train_examples(inputs)\n",
339 | "# dev_examples = processor.get_dev_examples(inputs)"
340 | ]
341 | },
342 | {
343 | "cell_type": "code",
344 | "execution_count": 13,
345 | "metadata": {},
346 | "outputs": [],
347 | "source": [
348 | "os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'"
349 | ]
350 | },
351 | {
352 | "cell_type": "code",
353 | "execution_count": null,
354 | "metadata": {
355 | "scrolled": false
356 | },
357 | "outputs": [
358 | {
359 | "name": "stdout",
360 | "output_type": "stream",
361 | "text": [
362 | "/data/medg/misc/clinical_BERT/cliniBERT/models/pretrained_bert_tf/bert_pretrain_output_all_notes_300000/\n",
363 | "Max Sequence Length: 112\n"
364 | ]
365 | },
366 | {
367 | "data": {
368 | "application/vnd.jupyter.widget-view+json": {
369 | "model_id": "",
370 | "version_major": 2,
371 | "version_minor": 0
372 | },
373 | "text/plain": [
374 | "HBox(children=(IntProgress(value=0, description='Epoch', max=5, style=ProgressStyle(description_width='initial…"
375 | ]
376 | },
377 | "metadata": {},
378 | "output_type": "display_data"
379 | },
380 | {
381 | "data": {
382 | "application/vnd.jupyter.widget-view+json": {
383 | "model_id": "",
384 | "version_major": 2,
385 | "version_minor": 0
386 | },
387 | "text/plain": [
388 | "HBox(children=(IntProgress(value=0, description='Iteration', max=18840, style=ProgressStyle(description_width=…"
389 | ]
390 | },
391 | "metadata": {},
392 | "output_type": "display_data"
393 | },
394 | {
395 | "name": "stderr",
396 | "output_type": "stream",
397 | "text": [
398 | "/scratch/conda_envs/chexpert_approximator/lib/python3.7/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
399 | " warnings.warn('Was asked to gather along dimension 0, but all '\n"
400 | ]
401 | },
402 | {
403 | "data": {
404 | "application/vnd.jupyter.widget-view+json": {
405 | "model_id": "",
406 | "version_major": 2,
407 | "version_minor": 0
408 | },
409 | "text/plain": [
410 | "HBox(children=(IntProgress(value=0, description='Iteration', max=18840, style=ProgressStyle(description_width=…"
411 | ]
412 | },
413 | "metadata": {},
414 | "output_type": "display_data"
415 | },
416 | {
417 | "data": {
418 | "application/vnd.jupyter.widget-view+json": {
419 | "model_id": "",
420 | "version_major": 2,
421 | "version_minor": 0
422 | },
423 | "text/plain": [
424 | "HBox(children=(IntProgress(value=0, description='Iteration', max=18840, style=ProgressStyle(description_width=…"
425 | ]
426 | },
427 | "metadata": {},
428 | "output_type": "display_data"
429 | }
430 | ],
431 | "source": [
432 | "out = build_and_train(\n",
433 | " inputs,\n",
434 | " bert_model = BERT_MODEL_PATH,\n",
435 | " processor = processor,\n",
436 | " task_dimensions = {l: 4 for l in label_cols},\n",
437 | " output_dir = OUT_CXPPP_DIR,\n",
438 | " gradient_accumulation_steps = 1,\n",
439 | " gpu = '0,1,2',\n",
440 | " do_train = True,\n",
441 | " do_eval = True,\n",
442 | " seed = 42,\n",
443 | " do_lower_case = False,\n",
444 | " max_seq_length = 128,\n",
445 | " train_batch_size = 32,\n",
446 | " eval_batch_size = 8,\n",
447 | " learning_rate = 5e-5,\n",
448 | " num_train_epochs = 5,\n",
449 | " warmup_proportion = 0.1,\n",
450 | ")"
451 | ]
452 | },
453 | {
454 | "cell_type": "code",
455 | "execution_count": 16,
456 | "metadata": {},
457 | "outputs": [
458 | {
459 | "data": {
460 | "text/plain": [
461 | "{'eval_loss': 0.04080065189753456,\n",
462 | " 'eval_accuracy': 0.9992550486952994,\n",
463 | " 'global_step': 94200,\n",
464 | " 'loss': 0.0014591591281715949}"
465 | ]
466 | },
467 | "execution_count": 16,
468 | "metadata": {},
469 | "output_type": "execute_result"
470 | }
471 | ],
472 | "source": [
473 | "out"
474 | ]
475 | }
476 | ],
477 | "metadata": {
478 | "kernelspec": {
479 | "display_name": "Python 3",
480 | "language": "python",
481 | "name": "python3"
482 | },
483 | "language_info": {
484 | "codemirror_mode": {
485 | "name": "ipython",
486 | "version": 3
487 | },
488 | "file_extension": ".py",
489 | "mimetype": "text/x-python",
490 | "name": "python",
491 | "nbconvert_exporter": "python",
492 | "pygments_lexer": "ipython3",
493 | "version": "3.7.3"
494 | }
495 | },
496 | "nbformat": 4,
497 | "nbformat_minor": 2
498 | }
499 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # `Chexpert++`
2 | ## Description
3 | Source implementation and pointer to pre-trained models for `chexpert++` (arXiv link forthcoming), a BERT-based
4 | approximation to CheXpert for radiology report labeling. Note that a compelling, co-discovered alternative is
5 | [1], which features a more full-fledged annotation effort featuring two board-certified radiologists and a
6 | more robust error resolution system. This paper is accessible [here](https://arxiv.org/pdf/2004.09167.pdf).
7 |
8 | ## Obtaining our Pre-trained Model
9 | Our pre-trained BERT model will soon be available via PhysioNet. In the meantime, it is accessible on Google Cloud Platform (GCP) to users who are credentialed for accessing the MIMIC-CXR GCP bucket via PhysioNet. Our bucket link and instructions to gain access through PhysioNet are included below; please email
10 | [`mmd@mit.edu`](mailto:mmd@mit.edu) if you have any questions.
11 |
12 | ### Our Bucket
13 | https://console.cloud.google.com/storage/browser/chexpertplusplus
14 |
15 | ### Instructions for getting physionet MIMIC-CXR GCP Access
16 | 1. First, follow the PhysioNet instructions to add Google Cloud access, here: https://mimic.physionet.org/gettingstarted/cloud/
17 | 2. Next, get access to MIMIC-CXR in general on Physionet: https://physionet.org/content/mimic-cxr/2.0.0/ (go to the bottom of the page and follow the steps listed under "Files", including becoming a credentialed user and signing the data use agreement)
18 | 3. Finally, request access to MIMIC-CXR via GCP on Physionet: https://physionet.org/projects/mimic-cxr/2.0.0/request_access/3
19 |
20 | ## Installation
21 | To install a conda environment suitable for reproducing this work, use the environment spec available in
22 | `env.yml`, via, e.g.
23 | ```
24 | conda env create -f env.yml -n [ENVIRONMENT NAME]
25 | ```
26 |
27 | Additionally, you must download the [MIMIC-CXR dataset](https://physionet.org/content/mimic-cxr/2.0.0/) and
28 | split the reports into sentences, then label each of these with the CheXpert labeler (code/splits not
29 | provided). You must also download the Clinical BERT model, available
30 | [here](https://github.com/EmilyAlsentzer/clinicalBERT).
31 |
32 | ## Usage Instructions
33 | Main model source code is available in `./chexpert_approximator`. Model training, evaluation, and active
34 | learning proof-of-concept are all available in `Jupyter Notebooks/`.
35 |
36 | ## Citation
37 | *This Work:*
38 | Matthew B.A. McDermott, Tzu Ming Harry Hsu, Wei-Hung Weng, Marzyeh Ghassemi, and Peter Szolovits.
39 | "`Chexpert++`: Approximating the CheXpert labeler for Speed, Differentiability, and Probabilistic Output."
40 | Machine Learning for Health Care (2020) _(in press; link TBA)_.
41 |
42 | *[1]*
43 | Akshay Smit, Saahil Jain, Pranav Rajpurkar, Anuj Pareek, Andrew Y. Ng, and Matthew P. Lungren. "CheXbert:
44 | Combining Automatic Labelers and Expert Annotations for Accurate Radiology Report Labeling Using BERT." arXiv
45 | preprint arXiv:2004.09167 (2020). [https://arxiv.org/pdf/2004.09167.pdf](https://arxiv.org/pdf/2004.09167.pdf)
46 |
--------------------------------------------------------------------------------
/chexpert_approximator/.gitignore:
--------------------------------------------------------------------------------
1 | tags
2 |
--------------------------------------------------------------------------------
/chexpert_approximator/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmcdermott/chexpertplusplus/b773053f7ad340b43608660d3adee0401087572b/chexpert_approximator/__init__.py
--------------------------------------------------------------------------------
/chexpert_approximator/data_processor.py:
--------------------------------------------------------------------------------
1 | import csv, sys
2 |
class InputFeatures(object):
    """Model-ready features for a single example.

    Modified from the single-label original to support multi-label targets:
    `label_ids` holds the label ids for the example rather than a single id.
    """

    def __init__(self, input_ids, input_mask, segment_ids, label_ids):
        # All four values are stored verbatim; nothing is copied or validated.
        (
            self.input_ids,
            self.input_mask,
            self.segment_ids,
            self.label_ids,
        ) = (input_ids, input_mask, segment_ids, label_ids)
12 |
13 |
class InputExample(object):
    """One raw (untokenized) example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Build an example from its id, its text and an optional label.

        Args:
          guid: Unique id for the example.
          text_a: string. Untokenized text of the first sequence. This is the
            only text required for single-sequence tasks.
          text_b: (Optional) string. Untokenized text of the second sequence;
            only needed for sequence-pair tasks.
          label: (Optional) The label of the example, stored as given. Expected
            for train and dev examples, not for test examples.
        """
        self.guid, self.text_a = guid, text_a
        self.text_b, self.label = text_b, label
33 |
class DataProcessor(object):
    """Base class for data converters for sequence classification data sets.

    The ``get_*`` methods are hooks for subclasses; every one of them raises
    ``NotImplementedError`` here.
    """

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_test_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the test set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file.

        Args:
          input_file: path of the file to read; it is decoded as UTF-8.
          quotechar: passed straight through to `csv.reader`.

        Returns:
          A list of rows, each row being a list of `str` cells.
        """
        # newline="" hands newline handling to the csv module (as its
        # documentation requires), so line breaks embedded in quoted cells
        # survive unchanged. The encoding is pinned to UTF-8 instead of the
        # platform default; the former Python 2 branch decoded cells as UTF-8
        # as well, and is dropped because the package only runs on Python 3.
        with open(input_file, "r", encoding="utf-8", newline="") as f:
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))
64 |
def convert_examples_to_features(examples, task_list, max_seq_length, tokenizer):
    """Turn `InputExample`s into fixed-length `InputFeatures`.

    Every example is tokenized with `tokenizer`, truncated so that it fits in
    `max_seq_length` once the special tokens are added, converted to ids and
    zero-padded to exactly `max_seq_length` entries.

    Args:
      examples: iterable of examples exposing `text_a`, `text_b` and `label`.
      task_list: task names; `example.label[t]` is read for each one, in this
        order, to build `label_ids`.
      max_seq_length: length of every returned id / mask / segment list.
      tokenizer: object providing `tokenize` and `convert_tokens_to_ids`.

    Returns:
      A list holding one `InputFeatures` per example.

    Side effect: prints the largest token count seen *before* truncation.
    """
    converted = []
    longest = 0
    for example in examples:
        first = tokenizer.tokenize(example.text_a)
        second = None

        if example.text_b:
            second = tokenizer.tokenize(example.text_b)
            untruncated = len(first) + len(second)
            # Trim both lists in place, reserving room for [CLS], [SEP], [SEP].
            _truncate_seq_pair(first, second, max_seq_length - 3)
        else:
            untruncated = len(first)
            # Single sequence: reserve room for [CLS] and [SEP] only.
            budget = max_seq_length - 2
            if len(first) > budget:
                first = first[:budget]

        longest = max(longest, untruncated)

        # BERT input layout:
        #   pair:    [CLS] a ... [SEP] b ... [SEP]   segment ids 0...0 1...1
        #   single:  [CLS] a ... [SEP]               segment ids 0...0
        # The segment ids tell the model which sequence a token belongs to.
        # For classification, the vector at [CLS] serves as the sentence
        # vector, which only makes sense because the whole model is fine-tuned.
        tokens = ["[CLS]"] + first + ["[SEP]"]
        segment_ids = [0] * len(tokens)
        if second:
            tokens = tokens + second + ["[SEP]"]
            segment_ids = segment_ids + [1] * (len(second) + 1)

        token_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask is 1 on real tokens and 0 on padding; all three lists are
        # zero-padded up to max_seq_length.
        zeros = [0] * (max_seq_length - len(token_ids))
        input_ids = token_ids + zeros
        input_mask = [1] * len(token_ids) + zeros
        segment_ids = segment_ids + zeros

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        converted.append(InputFeatures(
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=segment_ids,
            label_ids=[example.label[task] for task in task_list],
        ))

    print('Max Sequence Length: %d' % longest)

    return converted
141 |
142 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
143 | """Truncates a sequence pair in place to the maximum length."""
144 |
145 | # This is a simple heuristic which will always truncate the longer sequence
146 | # one token at a time. This makes more sense than truncating an equal percent
147 | # of tokens from each, since if one sequence is very short then each token
148 | # that's truncated likely contains more information than a longer sequence.
149 | while True:
150 | total_length = len(tokens_a) + len(tokens_b)
151 | if total_length <= max_length:
152 | break
153 | if len(tokens_a) > len(tokens_b):
154 | tokens_a.pop()
155 | else:
156 | tokens_b.pop()
--------------------------------------------------------------------------------
/chexpert_approximator/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from pytorch_pretrained_bert.modeling import BertPreTrainedModel, BertModel
4 | from torch import nn
5 |
def multi_task_accuracy(out_by_task, labels_by_task):
    """Average number of correct predictions per task.

    For every task in `out_by_task`, the predicted class is the argmax of the
    task's scores along axis 1; predictions equal to the task's labels are
    counted. The counts are summed over tasks and divided by the number of
    tasks, so the result is a count of correct examples, not a fraction.

    Args:
      out_by_task: dict mapping task name to an array of per-class scores.
      labels_by_task: dict mapping task name to the array of true labels.

    Returns:
      The total number of correct predictions divided by the number of tasks.
    """
    correct_per_task = [
        np.sum(np.argmax(scores, axis=1) == labels_by_task[task])
        for task, scores in out_by_task.items()
    ]
    return sum(correct_per_task) / len(out_by_task)
13 |
14 | # TODO(mmd): should probably not do via a dictionary. Possible to do all as tensor.
class BertForMultitaskSequenceClassification(BertPreTrainedModel):
    """BERT model for multitask sequence classification.
    This module is composed of the BERT model with a linear layer per task on top of
    the pooled output.
    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.
        `num_labels_per_task`: A dictionary from task_name to the number of classes for the classifier.
    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary. Items in the batch should begin with the
            special "CLS" token.
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `labels_per_task`: An optional dictionary of task name to labels for the classification output:
            {task: torch.LongTensor of shape [batch_size] with indices selected in
            [0, ..., num_labels_per_task[task] - 1]}. Omitting it (or passing `None`) is the same as
            passing an empty dictionary.
    Outputs:
        A tuple `(logits, losses)`:
        `logits`: a dictionary {task: classifier output for that task}, with one entry for every task in
            `num_labels_per_task`.
        `losses`: a dictionary {task: CrossEntropy loss}, with one entry for every task present in
            `labels_per_task` (empty when no labels are given).
    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    num_labels_per_task = {'task_a': 2, 'task_b': 4}
    model = BertForMultitaskSequenceClassification(config, num_labels_per_task)
    logits, losses = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config, num_labels_per_task):
        super().__init__(config)
        self.num_labels_per_task = num_labels_per_task
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # One independent linear head per task, all fed by the same pooled output.
        self.classifiers = nn.ModuleDict(
            {task_name: nn.Linear(config.hidden_size, task_dim) for task_name, task_dim in num_labels_per_task.items()}
        )
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels_per_task=None):
        # `None` replaces the former mutable `{}` default; both mean "no labels".
        if labels_per_task is None:
            labels_per_task = {}

        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = {task: layer(pooled_output) for task, layer in self.classifiers.items()}

        # The loss module holds no state, so a single instance serves every task.
        loss_fn = nn.CrossEntropyLoss()
        losses = {
            task: loss_fn(logits[task].view(-1, self.num_labels_per_task[task]), labels.view(-1))
            for task, labels in labels_per_task.items()
        }

        return logits, losses
72 |
--------------------------------------------------------------------------------
/chexpert_approximator/reload_and_get_logits.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | import random
18 |
19 | import numpy as np
20 | import torch
21 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
22 | from torch.utils.data.distributed import DistributedSampler
23 | from tqdm import tqdm, tqdm_notebook
24 |
25 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
26 | from pytorch_pretrained_bert.modeling import BertConfig, WEIGHTS_NAME, CONFIG_NAME
27 | from pytorch_pretrained_bert.tokenization import BertTokenizer
28 | from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
29 |
30 | #added
31 | import json
32 | from random import shuffle
33 | from sklearn.model_selection import KFold
34 | import math
35 |
36 | from .data_processor import *
37 | from .model import *
38 |
def reload_and_get_logits(
    df,
    bert_model,
    processor,
    task_dimensions,
    output_dir,
    processor_args = None,
    gpu = None,
    seed = 42,
    do_lower_case = False,
    max_seq_length = 128,
    batch_size = 8,
    learning_rate = 5e-5,
    num_train_epochs = 3,
    cache_dir = None,
    tqdm = tqdm_notebook,
    model = None,
):
    """Reload a fine-tuned multitask classifier and collect its logits over `df`.

    Unless `model` is supplied, the model is rebuilt from the config file
    (`CONFIG_NAME`) and weights file (`WEIGHTS_NAME`) found in `output_dir`.
    Every example produced by `processor.get_examples(df, **processor_args)`
    is then run through the model in evaluation mode, in order.

    Args:
        df: Data handed to `processor.get_examples`.
        bert_model: Name or path given to `BertTokenizer.from_pretrained`.
        processor: Object providing `get_labels()` (a mapping from task name to
            that task's labels) and `get_examples(df, **processor_args)`.
        task_dimensions: Not used by this function; kept for call compatibility.
        output_dir: Directory holding the saved weights/config. It must exist
            and be non-empty. `eval_results.txt` is written into it.
        processor_args: Optional extra keyword arguments for
            `processor.get_examples`. `None` means no extra arguments.
        gpu: Value for `CUDA_VISIBLE_DEVICES` (e.g. `'0'`). `None` runs on CPU.
        seed: Seed applied to `random`, `numpy` and `torch`.
        do_lower_case: Forwarded to the tokenizer.
        max_seq_length: Forwarded to `convert_examples_to_features`.
        batch_size: Evaluation batch size.
        learning_rate: Not used by this function; kept for call compatibility.
        num_train_epochs: Not used by this function; kept for call compatibility.
        cache_dir: Not used by this function; kept for call compatibility.
        tqdm: Progress-bar factory wrapped around the evaluation dataloader.
        model: Optional already-loaded model. When given, nothing is read from
            the weights/config files and the model is used as is.

    Returns:
        A tuple `(model, all_logits, result)` where `all_logits` maps each task
        name to a list with one logits array per example, and `result` is a
        dict with the keys `'eval_loss'` and `'eval_accuracy'`.

    Raises:
        ValueError: If `output_dir` is missing or empty, or if the processor
            yields no examples to evaluate.
    """
    print(bert_model)

    # The variable has to be exported *before* CUDA is first probed, otherwise
    # the device restriction may be ignored. str() lets callers pass an int id.
    if gpu is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
    if gpu is not None and torch.cuda.is_available():
        device = torch.device("cuda")
        n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cpu")
        n_gpu = 0

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    # Reloading needs previously saved artifacts, so an empty directory is an error.
    if not os.path.isdir(output_dir) or not os.listdir(output_dir):
        raise ValueError("Output directory ({}) doesn't exist or is empty.".format(output_dir))

    processor_args = {} if processor_args is None else processor_args

    label_map = processor.get_labels()
    task_list = list(label_map.keys())
    tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=do_lower_case)

    # Prepare model
    ### TODO(mmd): this is where to reload the model properly.
    if model is None:
        # Load a trained model and config that you have fine-tuned
        output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(output_dir, CONFIG_NAME)
        config = BertConfig(output_config_file)
        model = BertForMultitaskSequenceClassification(
            config,
            num_labels_per_task = {task: len(labels) for task, labels in label_map.items()},
        )
        # map_location keeps GPU-saved checkpoints loadable on CPU-only hosts;
        # the model is moved to the target device right below.
        model.load_state_dict(torch.load(output_model_file, map_location='cpu'))

    model.to(device)

    eval_examples = processor.get_examples(df, **processor_args)
    eval_features = convert_examples_to_features(
        eval_examples, task_list, max_seq_length, tokenizer
    )

    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_ids for f in eval_features], dtype=torch.long)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
    if len(eval_data) == 0:
        # Without this the averages below would fail with ZeroDivisionError.
        raise ValueError("No examples to evaluate: the processor returned an empty set.")

    # Run prediction for full data (sequential, so logits line up with the examples)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=batch_size)

    model.eval()
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0

    all_logits = {t: [] for t in task_list}
    for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)
        label_ids = label_ids.to(device)
        # Column i of the label tensor holds the labels of the i-th task.
        label_ids = {t: label_ids[:, i] for i, t in enumerate(task_list)}

        with torch.no_grad():
            logits, tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)

        logits = {t: lt.detach().cpu().numpy() for t, lt in logits.items()}
        for t in task_list:
            # Iterating the batch array appends one logits row per example.
            all_logits[t].extend(logits[t])
        label_ids = {t: l.to('cpu').numpy() for t, l in label_ids.items()}

        # NOTE(review): dividing the running total by the example count below
        # assumes multi_task_accuracy returns a per-batch sum — confirm.
        tmp_eval_accuracy = multi_task_accuracy(logits, label_ids)

        tmp_eval_loss = sum(tmp_eval_loss.values())
        eval_loss += tmp_eval_loss.mean().item()
        eval_accuracy += tmp_eval_accuracy

        nb_eval_examples += input_ids.size(0)
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    eval_accuracy = eval_accuracy / nb_eval_examples
    result = {'eval_loss': eval_loss,
              'eval_accuracy': eval_accuracy}

    output_eval_file = os.path.join(output_dir, "eval_results.txt")
    with open(output_eval_file, "w") as writer:
        for key in sorted(result.keys()):
            writer.write("%s = %s\n" % (key, str(result[key])))
    return model, all_logits, result
150 |
--------------------------------------------------------------------------------
/chexpert_approximator/run_classifier.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | import random
18 |
19 | import numpy as np
20 | import torch
21 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
22 | from torch.utils.data.distributed import DistributedSampler
23 | from tqdm import tqdm, tqdm_notebook
24 |
25 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
26 | from pytorch_pretrained_bert.modeling import BertConfig, WEIGHTS_NAME, CONFIG_NAME
27 | from pytorch_pretrained_bert.tokenization import BertTokenizer
28 | from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
29 |
30 | #added
31 | import json
32 | from random import shuffle
33 | from sklearn.model_selection import KFold
34 | import math
35 |
36 | from .data_processor import *
37 | from .model import *
38 |
def build_and_train(
    df,
    bert_model,
    processor,
    task_dimensions,
    output_dir,
    gradient_accumulation_steps = 1,
    gpu = None,
    do_train = True,
    do_eval = False,
    seed = 42,
    do_lower_case = False,
    max_seq_length = 128,
    train_batch_size = 32,
    eval_batch_size = 8,
    learning_rate = 5e-5,
    num_train_epochs = 3,
    warmup_proportion = 0.1,
    cache_dir = None,
    tqdm = tqdm_notebook,
    model = None,
):
    """Fine-tune and/or evaluate a multitask BERT sequence classifier.

    When `do_train` is set, the model is trained on
    `processor.get_train_examples(df)`, its weights and config are written to
    `output_dir`, and the saved checkpoint is loaded back into a fresh model.
    When `do_eval` is set, the resulting model (or, without training, the
    supplied/pretrained model) is scored on `processor.get_dev_examples(df)`
    and the metrics are written to `eval_results.txt` in `output_dir`.

    Args:
        df: Data handed to the processor's example getters.
        bert_model: Name or path given to `BertTokenizer.from_pretrained` and,
            when `model` is None, to the model's `from_pretrained`.
        processor: Object providing `get_labels()` (a mapping from task name to
            that task's labels), `get_train_examples(df)` and
            `get_dev_examples(df)`.
        task_dimensions: Not used by this function; kept for call compatibility.
        output_dir: Where the checkpoint and `eval_results.txt` are written.
            Created if missing. Must be empty (or absent) when training.
        gradient_accumulation_steps: Number of batches whose gradients are
            accumulated per optimizer step. `train_batch_size` is divided by it.
        gpu: Value for `CUDA_VISIBLE_DEVICES` (e.g. `'0'`). `None` runs on CPU.
        do_train: Whether to run training.
        do_eval: Whether to run evaluation.
        seed: Seed applied to `random`, `numpy` and `torch`.
        do_lower_case: Forwarded to the tokenizer.
        max_seq_length: Forwarded to `convert_examples_to_features`.
        train_batch_size: Effective training batch size (before division by
            `gradient_accumulation_steps`).
        eval_batch_size: Evaluation batch size.
        learning_rate: Learning rate for `BertAdam`.
        num_train_epochs: Number of passes over the training data.
        warmup_proportion: `warmup` argument for `BertAdam`.
        cache_dir: Cache directory for `from_pretrained`. Defaults to a
            sub-directory of `PYTORCH_PRETRAINED_BERT_CACHE`.
        tqdm: Progress-bar factory used for the epoch and batch loops.
        model: Optional pre-built model to train and/or evaluate instead of
            loading pretrained weights for `bert_model`.

    Returns:
        When `do_eval` is set, a dict with the keys `'eval_loss'`,
        `'eval_accuracy'`, `'global_step'` and `'loss'` (mean training loss of
        the last epoch, or None if no training step ran). Otherwise None.

    Raises:
        ValueError: On an invalid `gradient_accumulation_steps`, a per-step
            batch size of zero, neither `do_train` nor `do_eval` being set, a
            non-empty `output_dir` when training, or an empty evaluation set.
    """
    print(bert_model)

    # The variable has to be exported *before* CUDA is first probed, otherwise
    # the device restriction may be ignored. str() lets callers pass an int id.
    if gpu is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
    if gpu is not None and torch.cuda.is_available():
        device = torch.device("cuda")
        n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cpu")
        n_gpu = 0

    if gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            gradient_accumulation_steps))

    # Each forward/backward pass sees only a slice of the effective batch.
    train_batch_size = train_batch_size // gradient_accumulation_steps

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    if not do_train and not do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if do_train and train_batch_size < 1:
        # Would otherwise surface later as an unrelated ZeroDivisionError.
        raise ValueError(
            "train_batch_size must be at least gradient_accumulation_steps ({}).".format(
                gradient_accumulation_steps))

    if os.path.exists(output_dir) and os.listdir(output_dir) and do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(output_dir))
    os.makedirs(output_dir, exist_ok=True)

    label_map = processor.get_labels()
    task_list = list(label_map.keys())
    num_tasks = len(label_map)
    num_labels_per_task = {task: len(labels) for task, labels in label_map.items()}
    tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=do_lower_case)

    # Prepare model
    cache_dir = cache_dir if cache_dir else os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_-1')

    if model is None:
        model = BertForMultitaskSequenceClassification.from_pretrained(
            bert_model,
            cache_dir=cache_dir,
            num_labels_per_task = num_labels_per_task,
        )

    model.to(device)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if do_train:
        train_examples = processor.get_train_examples(df)
        num_train_optimization_steps = int(
            len(train_examples) / train_batch_size / gradient_accumulation_steps) * num_train_epochs

        # Multi-GPU wrapping is only needed for training; evaluation below
        # always runs on a plain (unwrapped) model.
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Prepare optimizer: biases and LayerNorm parameters get no weight decay.
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

        optimizer = BertAdam(
            optimizer_grouped_parameters,
            lr=learning_rate,
            warmup=warmup_proportion,
            t_total=num_train_optimization_steps
        )

        train_features = convert_examples_to_features(
            train_examples, task_list, max_seq_length, tokenizer
        )

        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_ids for f in train_features], dtype=torch.long)

        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=train_batch_size)

        model.train()
        for _ in tqdm(range(int(num_train_epochs)), desc="Epoch"):
            # Running totals are reset per epoch, so the reported training
            # loss reflects the last epoch only.
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch

                # Column i of the label tensor holds the labels of the i-th task.
                label_ids = {t: label_ids[:, i] for i, t in enumerate(task_list)}

                _, losses = model(input_ids, segment_ids, input_mask, label_ids)
                loss = sum(losses.values()) / num_tasks

                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if gradient_accumulation_steps > 1:
                    loss = loss / gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
                #### HERE NEED TO MODIFY.

        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
        output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        model = BertForMultitaskSequenceClassification(
            config,
            num_labels_per_task = num_labels_per_task,
        )
        model.load_state_dict(torch.load(output_model_file, map_location='cpu'))
        model.to(device)
    # Without training, the model prepared above (caller-supplied or
    # pretrained) is evaluated as is rather than being replaced by a fresh one.

    result = None
    if do_eval:
        eval_examples = processor.get_dev_examples(df)
        eval_features = convert_examples_to_features(
            eval_examples, task_list, max_seq_length, tokenizer
        )

        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_ids for f in eval_features], dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        if len(eval_data) == 0:
            # Without this the averages below would fail with ZeroDivisionError.
            raise ValueError("No examples to evaluate: the processor returned an empty dev set.")

        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0

        for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            label_ids = {t: label_ids[:, i] for i, t in enumerate(task_list)}

            with torch.no_grad():
                logits, tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)

            logits = {t: lt.detach().cpu().numpy() for t, lt in logits.items()}
            label_ids = {t: l.to('cpu').numpy() for t, l in label_ids.items()}

            # NOTE(review): dividing the running total by the example count
            # below assumes multi_task_accuracy returns a per-batch sum — confirm.
            tmp_eval_accuracy = multi_task_accuracy(logits, label_ids)

            tmp_eval_loss = sum(tmp_eval_loss.values())
            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        # nb_tr_steps stays 0 when training saw no batches (e.g. zero epochs).
        loss = tr_loss / nb_tr_steps if do_train and nb_tr_steps else None
        result = {'eval_loss': eval_loss,
                  'eval_accuracy': eval_accuracy,
                  'global_step': global_step,
                  'loss': loss}

        output_eval_file = os.path.join(output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            for key in sorted(result.keys()):
                writer.write("%s = %s\n" % (key, str(result[key])))
    return result
259 |
--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
1 | name:
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - attrs=19.1.0=py_0
8 | - backcall=0.1.0=py_0
9 | - bleach=3.1.0=py_0
10 | - bzip2=1.0.6=h14c3975_1002
11 | - ca-certificates=2019.3.9=hecc5488_0
12 | - certifi=2019.3.9=py37_0
13 | - cffi=1.12.3=py37h8022711_0
14 | - cycler=0.10.0=py_1
15 | - decorator=4.4.0=py_0
16 | - defusedxml=0.5.0=py_1
17 | - entrypoints=0.3=py37_1000
18 | - expat=2.2.5=hf484d3e_1002
19 | - fontconfig=2.13.1=he4413a7_1000
20 | - freetype=2.10.0=he983fc9_0
21 | - future=0.17.1=py37_1000
22 | - gettext=0.19.8.1=hc5be6a0_1002
23 | - glib=2.56.2=had28632_1001
24 | - hyperopt=0.1.2=py_0
25 | - icu=58.2=hf484d3e_1000
26 | - ipykernel=5.1.0=py37h24bf2e0_1002
27 | - ipython=7.4.0=py37h24bf2e0_0
28 | - ipython_genutils=0.2.0=py_1
29 | - ipywidgets=7.4.2=py_0
30 | - jedi=0.13.3=py37_0
31 | - jinja2=2.10.1=py_0
32 | - jpeg=9c=h14c3975_1001
33 | - jsonschema=3.0.1=py37_0
34 | - jupyter=1.0.0=py_2
35 | - jupyter_client=5.2.4=py_3
36 | - jupyter_console=6.0.0=py_0
37 | - jupyter_contrib_core=0.3.3=py_2
38 | - jupyter_contrib_nbextensions=0.5.1=py37_0
39 | - jupyter_core=4.4.0=py_0
40 | - jupyter_highlight_selected_word=0.2.0=py37_1000
41 | - jupyter_latex_envs=1.4.4=py37_1000
42 | - jupyter_nbextensions_configurator=0.4.1=py37_0
43 | - kiwisolver=1.0.1=py37h6bb024c_1002
44 | - libblas=3.8.0=5_openblas
45 | - libcblas=3.8.0=5_openblas
46 | - libffi=3.2.1=he1b5a44_1006
47 | - libiconv=1.15=h516909a_1005
48 | - liblapack=3.8.0=5_openblas
49 | - libpng=1.6.37=hed695b0_0
50 | - libsodium=1.0.16=h14c3975_1001
51 | - libtiff=4.0.10=h648cc4a_1001
52 | - libuuid=2.32.1=h14c3975_1000
53 | - libxcb=1.13=h14c3975_1002
54 | - libxml2=2.9.9=h13577e0_0
55 | - libxslt=1.1.32=h4785a14_1002
56 | - lxml=4.3.3=py37h7ec2d77_0
57 | - markupsafe=1.1.1=py37h14c3975_0
58 | - matplotlib=3.0.3=py37_1
59 | - matplotlib-base=3.0.3=py37h5f35d83_1
60 | - mistune=0.8.4=py37h14c3975_1000
61 | - nbconvert=5.4.1=py_2
62 | - nbformat=4.4.0=py_1
63 | - ncurses=6.1=hf484d3e_1002
64 | - networkx=2.3=py_0
65 | - ninja=1.9.0=h6bb024c_0
66 | - notebook=5.7.8=py37_0
67 | - numpy=1.16.3=py37he5ce36f_0
68 | - olefile=0.46=py_0
69 | - openblas=0.3.5=h9ac9557_1001
70 | - openssl=1.1.1b=h14c3975_1
71 | - pandas=0.24.2=py37hf484d3e_0
72 | - pandoc=2.7.2=0
73 | - pandocfilters=1.4.2=py_1
74 | - parso=0.4.0=py_0
75 | - pexpect=4.7.0=py37_0
76 | - pickleshare=0.7.5=py37_1000
77 | - pillow=6.0.0=py37he7afcd5_0
78 | - pip=19.0.3=py37_0
79 | - prometheus_client=0.6.0=py_0
80 | - prompt_toolkit=2.0.9=py_0
81 | - pthread-stubs=0.4=h14c3975_1001
82 | - ptyprocess=0.6.0=py_1001
83 | - pycparser=2.19=py37_1
84 | - pygments=2.3.1=py_0
85 | - pymongo=3.7.2=py37hf484d3e_0
86 | - pyparsing=2.4.0=py_0
87 | - pyqt=5.6.0=py37h13b7fb3_1008
88 | - pyrsistent=0.14.11=py37h14c3975_0
89 | - python=3.7.3=h5b0a415_0
90 | - python-dateutil=2.8.0=py_0
91 | - pytz=2019.1=py_0
92 | - pyyaml=5.1=py37h14c3975_0
93 | - pyzmq=18.0.1=py37hc4ba49a_1
94 | - qtconsole=4.4.3=py_0
95 | - readline=7.0=hf8c457e_1001
96 | - scikit-learn=0.20.3=py37ha8026db_1
97 | - scipy=1.2.1=py37h09a28d5_1
98 | - send2trash=1.5.0=py_0
99 | - setuptools=41.0.1=py37_0
100 | - sip=4.18.1=py37hf484d3e_1000
101 | - six=1.12.0=py37_1000
102 | - sqlite=3.26.0=h67949de_1001
103 | - terminado=0.8.2=py37_0
104 | - testpath=0.4.2=py_1001
105 | - tk=8.6.9=h84994c4_1001
106 | - tornado=6.0.2=py37h516909a_0
107 | - tqdm=4.31.1=py_0
108 | - traitlets=4.3.2=py37_1000
109 | - wcwidth=0.1.7=py_1
110 | - webencodings=0.5.1=py_1
111 | - wheel=0.33.1=py37_0
112 | - widgetsnbextension=3.4.2=py37_1000
113 | - xorg-libxau=1.0.9=h14c3975_0
114 | - xorg-libxdmcp=1.1.3=h516909a_0
115 | - xz=5.2.4=h14c3975_1001
116 | - yaml=0.1.7=h14c3975_1001
117 | - zeromq=4.3.1=hf484d3e_1000
118 | - zlib=1.2.11=h14c3975_1004
119 | - cudatoolkit=9.0=h13b8566_0
120 | - dbus=1.13.2=h714fa37_1
121 | - gst-plugins-base=1.14.0=hbbd80ab_1
122 | - gstreamer=1.14.0=hb453b48_1
123 | - intel-openmp=2019.3=199
124 | - libgcc-ng=8.2.0=hdf63c60_1
125 | - libgfortran-ng=7.3.0=hdf63c60_0
126 | - libstdcxx-ng=8.2.0=hdf63c60_1
127 | - mkl=2019.3=199
128 | - pcre=8.43=he6710b0_0
129 | - qt=5.6.3=h8bf5577_3
130 | - pytorch=1.0.1=py3.7_cuda9.0.176_cudnn7.4.2_2
131 | - torchvision=0.2.2=py_3
132 | - pip:
133 | - boto3==1.9.134
134 | - botocore==1.12.134
135 | - chardet==3.0.4
136 | - docutils==0.14
137 | - idna==2.8
138 | - jmespath==0.9.4
139 | - pytorch-pretrained-bert==0.6.1
140 | - regex==2019.4.14
141 | - requests==2.21.0
142 | - s3transfer==0.2.0
143 | - torch==1.0.1.post2
144 | - urllib3==1.24.2
145 | prefix: /scratch/conda_envs/chexpert_approximator
146 |
--------------------------------------------------------------------------------