├── .gitignore
├── downstream_tasks
│   ├── i2b2_preprocessing
│   │   ├── i2b2_2006_deid
│   │   │   ├── split_train_dev.py
│   │   │   └── to_conll.py
│   │   ├── i2b2_2012
│   │   │   └── Reformat.ipynb
│   │   └── i2b2_2014_deid_hf_risk
│   │       └── Reformat.ipynb
│   ├── run_classifier.sh
│   ├── ner_eval
│   │   ├── score_i2b2.py
│   │   ├── ner_detokenize.py
│   │   ├── format_for_i2b2_eval.py
│   │   └── conlleval.pl
│   ├── run_i2b2.sh
│   ├── run_classifier.py
│   └── run_ner.py
├── lm_pretraining
│   ├── create_pretrain_data.sh
│   ├── finetune_lm_tf.sh
│   ├── format_mimic_for_BERT.py
│   ├── heuristic_tokenize.py
│   ├── create_pretraining_data.py
│   └── run_pretraining.py
├── LICENSE
├── README.md
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
--------------------------------------------------------------------------------
/downstream_tasks/i2b2_preprocessing/i2b2_2006_deid/split_train_dev.py:
--------------------------------------------------------------------------------
1 | with open('train_dev.conll', 'r') as f:
2 |     sents = f.read().strip().split('\n\n')
3 |
4 | print(len(sents))
5 | import random
6 |
7 | random.seed(555)
8 |
9 | n = len(sents)
10 | ind = int(0.7*n)
11 |
12 | train_sents = sents[:ind]
13 | dev_sents = sents[ind:]
14 |
15 | print(len(train_sents))
16 | with open('train.conll', 'w') as f:
17 |     f.write('\n\n'.join(train_sents))
18 |
19 | print(len(dev_sents))
20 | with open('dev.conll', 'w') as f:
21 |     f.write('\n\n'.join(dev_sents))
22 |
--------------------------------------------------------------------------------
/lm_pretraining/create_pretrain_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | BERT_BASE_DIR=/PATH/TO/BERT/VOCAB/FILE #modify this to bert or biobert folder containing a vocab.txt file
4 | DATA_DIR=/PATH/TO/TOKENIZED/NOTES #modify this to be the path to the tokenized data
5 | OUTPUT_DIR=/PATH/TO/OUTPUT/DIR # modify this to be your output directory path
6 |
7 |
8 | #modify this to be the note type that you want to create pretraining data for - e.g. ecg, echo, radiology, physician, nursing, etc.
9 | # Note that you can also specify multiple input files & output files below
10 | DATA_FILE=nursing_other
11 |
12 |
13 | # Note that create_pretraining_data.py is unmodified from the script in the original BERT repo.
14 | # Refer to the BERT repo for the most up to date version of this code.
15 | python create_pretraining_data.py \
16 | --input_file=$DATA_DIR/$DATA_FILE.txt \
17 | --output_file=$OUTPUT_DIR/$DATA_FILE.tfrecord \
18 | --vocab_file=$BERT_BASE_DIR/vocab.txt \
19 | --do_lower_case=False \
20 | --max_seq_length=128 \
21 | --max_predictions_per_seq=22 \
22 | --masked_lm_prob=0.15 \
23 | --random_seed=12345 \
24 | --dupe_factor=5
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Emily Alsentzer
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/downstream_tasks/run_classifier.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #Example script for running run_classifier.py
4 |
5 |
6 | for EPOCHS in 3 4 5 ; do
7 | for LR in 2e-5 3e-5 5e-5; do
8 | for BATCH_SZ in 16 32 ; do
9 | MAX_SEQ_LEN=150
10 |
11 | DATA_DIR=PATH/TO/MEDNLI/DATA/ #Modify this to be the path to the MedNLI data
12 | OUTPUT_DIR=PATH/TO/OUTPUT/DIR/ #Modify this to be the path to your output directory
13 | CLINICAL_BERT_LOC=PATH/TO/CLINICAL/BERT/MODEL #Modify this to be the path to the clinical BERT model
14 |
15 | echo $OUTPUT_DIR
16 |
17 | BERT_MODEL=clinical_bert # You can change this to biobert or bert-base-cased
18 |
19 | mkdir -p $OUTPUT_DIR
20 |
21 | python run_classifier.py \
22 | --data_dir=$DATA_DIR \
23 | --bert_model=$BERT_MODEL \
24 | --model_loc $CLINICAL_BERT_LOC \
25 | --task_name mednli \
26 | --do_train \
27 | --do_eval \
28 | --do_test \
29 | --output_dir=$OUTPUT_DIR \
30 | --num_train_epochs $EPOCHS \
31 | --learning_rate $LR \
32 | --train_batch_size $BATCH_SZ \
33 | --max_seq_length $MAX_SEQ_LEN \
34 | --gradient_accumulation_steps 2
35 | done
36 | done
37 | done
38 |
39 |
--------------------------------------------------------------------------------
/lm_pretraining/finetune_lm_tf.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # location of bert or biobert model
4 | # Note that if you use biobert as your base model, you'll need to change init_checkpoint to be biobert_model.ckpt
5 | BERT_BASE_DIR=/PATH/TO/BERT/MODEL
6 |
7 | # folder where you want to save your clinical BERT model
8 | OUTPUT_DIR=/PATH/TO/CLINICAL/BERT/OUTPUT/DIR
9 |
10 | # folder that contains the tfrecords - this will be the output directory from create_pretrain_data.sh
11 | INPUT_FILES_DIR=/PATH/TO/TFRECORDS
12 |
13 | NUM_TRAIN_STEPS=100000
14 | NUM_WARMUP_STEPS=10000
15 | LR=5e-5
16 |
17 | # This example illustrates the training of Bio+Discharge Summary BERT. To train Bio+Clinical BERT instead, change the input_file
18 | # to the tfrecords for all MIMIC sections - e.g.
19 | # --input_file=../data/tf_records/discharge_summary.tfrecord,../data/tf_records/physician.tfrecord,../data/tf_records/nursing.tfrecord,../data/tf_records/nursing_other.tfrecord,../data/tf_records/radiology.tfrecord,../data/tf_records/general.tfrecord,../data/tf_records/respiratory.tfrecord,../data/tf_records/consult.tfrecord,../data/tf_records/nutrition.tfrecord,../data/tf_records/case_management.tfrecord,../data/tf_records/pharmacy.tfrecord,../data/tf_records/rehab_services.tfrecord,../data/tf_records/social_work.tfrecord,../data/tf_records/ecg.tfrecord,../data/tf_records/echo.tfrecord \
20 |
21 | python run_pretraining.py \
22 | --output_dir=$OUTPUT_DIR \
23 | --input_file=$INPUT_FILES_DIR/discharge_summary.tfrecord \
24 | --do_train=True \
25 | --do_eval=True \
26 | --bert_config_file=$BERT_BASE_DIR/bert_config.json \
27 | --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
28 | --train_batch_size=32 \
29 | --max_seq_length=128 \
30 | --max_predictions_per_seq=20 \
31 | --num_train_steps=$NUM_TRAIN_STEPS \
32 | --num_warmup_steps=$NUM_WARMUP_STEPS \
33 | --learning_rate=$LR \
34 | --save_checkpoints_steps=50000 \
35 | --keep_checkpoint_max=15
36 |
37 |
--------------------------------------------------------------------------------
/downstream_tasks/i2b2_preprocessing/i2b2_2006_deid/to_conll.py:
--------------------------------------------------------------------------------
1 | import re
2 | import sys
3 | import xml.etree.ElementTree as ET
4 |
5 |
6 | i = 0
7 | with open(sys.argv[1]) as f:
8 |     for line in f.readlines():
9 |         # filter out non-text lines (XML markup; lines containing inline PHI tags are kept)
10 |         if line.startswith('<') and \
11 |            not line.startswith('<PHI'):
12 |             continue
13 | 
14 |         # Parse PHI type/tokens
15 |         regex = '(<PHI TYPE="(.*?)">(.*?)</PHI>)'
16 |         phi_tags = re.findall(regex, line)
17 |         for tag in phi_tags:
18 |             line = line.replace(tag[0], '__phi__').strip()
19 | 
20 |         # Walk through sentence
21 |         phi_ind = 0
22 |         for w in line.split():
23 |             if w == '__phi__':
24 |                 phi = phi_tags[phi_ind]
25 |                 tag = phi[1]
26 |                 toks = phi[2].split()
27 |                 print(toks[0], 'B-%s'%tag)
28 |                 for t in toks[1:]:
29 |                     print(t, 'I-%s'%tag)
30 |                 phi_ind += 1
31 |             # Two elif statements check for edge cases with Dates
32 |             elif w.startswith('__phi__'):
33 |                 # examples like following format:
34 |                 # 01/01/1995 or 01-01-95
35 |                 phi = phi_tags[phi_ind]
36 |                 tag = phi[1]
37 |                 toks = phi[2].split()
38 |                 print(toks[0], 'B-%s'%tag)
39 |                 if w[7:8] == '/' or w[7:8] == '-':
40 |                     print(w[8:], 'O') # remove the / or - in the year
41 |                 else:
42 |                     print(w[7:], 'O')
43 |                 phi_ind += 1
44 |             elif w.endswith('__phi__'):
45 |                 # 19950101
46 |                 phi = phi_tags[phi_ind]
47 |                 tag = phi[1]
48 |                 toks = phi[2].split()
49 |                 print(w[:-7], 'O')
50 |                 print(toks[0], 'B-%s'%tag)
51 |                 phi_ind += 1
52 |             else:
53 |                 print(w, 'O')
54 |         print()
55 |         i+=1
56 |
57 |
58 |
--------------------------------------------------------------------------------
/downstream_tasks/ner_eval/score_i2b2.py:
--------------------------------------------------------------------------------
1 |
2 | import argparse
3 | import os
4 | import re
5 |
6 |
7 | parser = argparse.ArgumentParser(description='')
8 | parser.add_argument('--input_pred_dir', type=str, help='Location of input i2b2 prediction formatted files')
9 | parser.add_argument('--input_gold_dir', type=str, help='Location of input i2b2 gold formatted files')
10 | parser.add_argument('--output_dir', type=str, help='Location of output files')
11 |
12 | args = parser.parse_args()
13 |
14 | def format_line(line):
15 | text, s_line, s_word, e_line, e_word, label = re.findall('c="(.*)" ([0-9]+):([0-9]+) ([0-9]+):([0-9]+)\|\|t="(.*)"', line)[0]
16 | return ({'text':text, 's_line':s_line, 's_word':s_word, 'e_line':e_line, 'e_word':e_word, 'label':label})
17 | # Example format: c="cortical-type symptoms" 820:6 820:7||t="Problem"
18 |
19 |
20 | pred_lines = []
21 | with open(os.path.join(args.input_pred_dir, "i2b2.con"),'r') as pred_f:
22 | for line in pred_f:
23 | pred_lines.append(format_line(line))
24 |
25 | gold_lines = []
26 | with open(os.path.join(args.input_gold_dir, "i2b2.con"),'r') as gold_f:
27 | for line in gold_f:
28 | gold_lines.append(format_line(line))
29 |
30 | def in_gold(pred, gold_lines):
31 | for gold in gold_lines:
32 | if gold == pred:
33 | return True
34 | return False
35 |
36 | def in_pred(gold, pred_lines):
37 | for pred in pred_lines:
38 | if gold == pred:
39 | return True
40 | return False
41 |
42 |
43 | precCount = 0
44 | for pred in pred_lines:
45 | if in_gold(pred, gold_lines):
46 | precCount += 1
47 |
48 |
49 | recallCount = 0
50 | for gold in gold_lines:
51 | if in_pred(gold, pred_lines):
52 | recallCount += 1
53 |
54 | # systemEventCount / goldEventCount: total number of events in the pred / gold file
55 | # recallCount: number of events in the gold file that are also found in the pred file
56 | # precCount: number of events in the pred file that are also found in the gold file
57 |
58 | systemEventCount = len(pred_lines)
59 | goldEventCount = len(gold_lines)
60 | print('Predicted events: %d, Gold events: %d' %(systemEventCount, goldEventCount))
61 | precision=float(precCount)/systemEventCount
62 | recall=float(recallCount)/goldEventCount
63 | fScore=2*(precision*recall)/(precision+recall)
64 | print('Exact Precision: %0.5f, Recall: %0.5f, F1: %0.5f' %(precision, recall, fScore))
65 |
66 | with open(os.path.join(args.output_dir, "final_results.txt"),'w') as writer:
67 | writer.write('Precision\tRecall\tF1\n')
68 | writer.write('%0.5f\t%0.5f\t%0.5f\n' %(precision, recall, fScore))
69 |
70 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # clinicalBERT
2 | Repository for [Publicly Available Clinical BERT Embeddings](https://www.aclweb.org/anthology/W19-1909/) (NAACL Clinical NLP Workshop 2019)
3 |
4 | ## Using Clinical BERT
5 |
6 | UPDATE: You can now use ClinicalBERT directly through the [transformers](https://github.com/huggingface/transformers) library. Check out the [Bio+Clinical BERT](https://huggingface.co/emilyalsentzer/Bio_ClinicalBERT) and [Bio+Discharge Summary BERT](https://huggingface.co/emilyalsentzer/Bio_Discharge_Summary_BERT) model pages for instructions on how to use the models within the Transformers library.
7 |
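For example, loading the Bio+Clinical BERT weights from the model hub looks roughly like the following (a minimal sketch; the example sentence is made up, and any downstream fine-tuning setup is up to you):

```
from transformers import AutoTokenizer, AutoModel

# load the tokenizer and encoder published on the Hugging Face model hub
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

# encode a sentence and inspect the contextual embeddings
inputs = tokenizer("The patient was discharged home in stable condition.", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, number of wordpieces, 768)
```
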
8 | ## Download Clinical BERT
9 |
10 | The Clinical BERT models can also be downloaded [here](https://www.dropbox.com/s/8armk04fu16algz/pretrained_bert_tf.tar.gz?dl=0), or via
11 |
12 | ```
13 | wget -O pretrained_bert_tf.tar.gz https://www.dropbox.com/s/8armk04fu16algz/pretrained_bert_tf.tar.gz?dl=1
14 | ```
15 |
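Once downloaded, the archive can be unpacked with `tar -xzf pretrained_bert_tf.tar.gz`, or equivalently with a couple of lines of Python:

```
import tarfile

# extract the pretrained TensorFlow checkpoints into the current directory
with tarfile.open("pretrained_bert_tf.tar.gz", "r:gz") as archive:
    archive.extractall(".")
```
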
16 | `biobert_pretrain_output_all_notes_150000` corresponds to Bio+Clinical BERT, and `biobert_pretrain_output_disch_100000` corresponds to Bio+Discharge Summary BERT. Both models are finetuned from [BioBERT](https://arxiv.org/abs/1901.08746). We specifically use the [BioBERT-Base v1.0 (+ PubMed 200K + PMC 270K)](https://github.com/naver/biobert-pretrained) version of BioBERT.
17 |
18 | `bert_pretrain_output_all_notes_150000` corresponds to Clinical BERT, and `bert_pretrain_output_disch_100000` corresponds to Discharge Summary BERT. Both models are finetuned from the cased version of BERT, specifically cased_L-12_H-768_A-12.
19 |
20 | ## Reproduce Clinical BERT
21 | #### Pretraining
22 | To reproduce the steps necessary to finetune BERT or BioBERT on MIMIC data, follow these steps:
23 | 1. Run `format_mimic_for_BERT.py` - Note you'll need to change the file paths at the top of the file.
24 | 2. Run `create_pretrain_data.sh`
25 | 3. Run `finetune_lm_tf.sh`
26 |
27 | Note: See issue [#4](https://github.com/EmilyAlsentzer/clinicalBERT/issues/4) for ways to improve the section-splitting code.
28 |
29 | #### Downstream Tasks
30 | To see an example of how to use Clinical BERT for the MedNLI task, see the `run_classifier.sh` script in the `downstream_tasks` folder. For an example of the i2b2 NER tasks, see the `run_i2b2.sh` script.
31 |
32 | ## Contact
33 | Please post a Github issue or contact emilya@mit.edu if you have any questions.
34 |
35 | ## Citation
36 | Please acknowledge the following work in papers or derivative software:
37 |
38 | Emily Alsentzer, John Murphy, William Boag, Wei-Hung Weng, Di Jin, Tristan Naumann, and Matthew McDermott. 2019. Publicly available clinical BERT embeddings. In Proceedings of the 2nd Clinical Natural Language Processing Workshop, pages 72-78, Minneapolis, Minnesota, USA. Association for Computational Linguistics.
39 |
40 | ```
41 | @inproceedings{alsentzer-etal-2019-publicly,
42 | title = "Publicly Available Clinical {BERT} Embeddings",
43 | author = "Alsentzer, Emily and
44 | Murphy, John and
45 | Boag, William and
46 | Weng, Wei-Hung and
47 | Jin, Di and
48 | Naumann, Tristan and
49 | McDermott, Matthew",
50 | booktitle = "Proceedings of the 2nd Clinical Natural Language Processing Workshop",
51 | month = jun,
52 | year = "2019",
53 | address = "Minneapolis, Minnesota, USA",
54 | publisher = "Association for Computational Linguistics",
55 | url = "https://www.aclweb.org/anthology/W19-1909",
56 | doi = "10.18653/v1/W19-1909",
57 | pages = "72--78"
58 | }
59 | ```
60 |
--------------------------------------------------------------------------------
/downstream_tasks/run_i2b2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #update these to the path of the bert model and the path to the NER data
4 | BERT_DIR=/PATH/TO/BERT/MODEL #clinical bert, biobert, or bert
5 |
6 | #path to NER data. Make sure to preprocess data according to BIO format first.
7 | #You could use the scripts in the i2b2_preprocessing folder to preprocess the i2b2 data
8 | NER_DIR=/PATH/TO/NER/DATA/DIRECTORY
9 |
10 |
11 | for EPOCHS in 2 3 4 ; do
12 | for LEARN_RATE in 2e-5 3e-5 5e-5 ; do
13 | for BATCH_SZ in 16 32 ; do
14 |
15 |
16 | OUTPUT_DIR=/PATH/TO/OUTPUT/DIRECTORY #update this to the output directory you want
17 | mkdir -p $OUTPUT_DIR
18 |
19 | # You can change the task_name to 'i2b2_2014', 'i2b2_2010', 'i2b2_2006', or 'i2b2_2012'
20 | # Note that you may need to modify the DataProcessor code in `run_ner.py` to adapt to the format of your input
21 | # If you want to use biobert, change the init_checkpoint to biobert_model.ckpt
22 | # run_ner.py is adapted from kyzhouhzau's BERT-NER github and the BioBERT repo
23 | python run_ner.py \
24 | --do_train=True \
25 | --do_eval=True \
26 | --do_predict=True \
27 | --task_name='i2b2_2006' \
28 | --vocab_file=$BERT_DIR/vocab.txt \
29 | --bert_config_file=$BERT_DIR/bert_config.json \
30 | --init_checkpoint=$BERT_DIR/bert_model.ckpt \
31 | --num_train_epochs=$EPOCHS \
32 | --learning_rate=$LEARN_RATE \
33 | --train_batch_size=$BATCH_SZ \
34 | --max_seq_length=150 \
35 | --data_dir=$NER_DIR \
36 | --output_dir=$OUTPUT_DIR \
37 | --save_checkpoints_steps=2000
38 |
39 |
40 | # Note here we're performing 10 fold CV, but if you want to recover the original train, val, test split, use CV iter = 9
41 | # Also go to run_ner.py & modify line 738-739 so that you only run the last CV iteration.
42 | for CV_ITER in 0 1 2 3 4 5 6 7 8 9 ; do
43 | for MODE in eval test ; do
44 | EVAL_OUTPUT_DIR=$OUTPUT_DIR/$CV_ITER #creates a new folder for each CV iteration
45 | mkdir -p $EVAL_OUTPUT_DIR
46 | mkdir -p $EVAL_OUTPUT_DIR/$MODE
47 | mkdir -p $EVAL_OUTPUT_DIR/$MODE/gold/
48 | mkdir -p $EVAL_OUTPUT_DIR/$MODE/pred/
49 |
50 | OUTPUT_FILE=${EVAL_OUTPUT_DIR}/NER_result_conll_${MODE}.txt
51 |
52 | #convert word-piece BERT NER results to CoNLL eval format
53 | #Code is adapted from the BioBERT github
54 | python ner_eval/ner_detokenize.py \
55 | --token_test_path=${OUTPUT_DIR}/${CV_ITER}_token_${MODE}.txt \
56 | --label_test_path=${OUTPUT_DIR}/${CV_ITER}_label_${MODE}.txt \
57 | --answer_path=${NER_DIR}/${CV_ITER}_${MODE} \
58 | --tok_to_orig_map_path=${OUTPUT_DIR}/${CV_ITER}_tok_to_orig_map_${MODE}.txt \
59 | --output_file=$OUTPUT_FILE
60 |
61 | #convert to i2b2 evaluation format (adapted from Cliner Repo)
62 | python ner_eval/format_for_i2b2_eval.py \
63 | --results_file $OUTPUT_FILE \
64 | --output_gold_dir $EVAL_OUTPUT_DIR/$MODE/gold/ \
65 | --output_pred_dir $EVAL_OUTPUT_DIR/$MODE/pred/
66 |
67 | # evaluate performance on i2b2 tasks
68 | python ner_eval/score_i2b2.py \
69 | --input_gold_dir $EVAL_OUTPUT_DIR/$MODE/gold/ \
70 | --input_pred_dir $EVAL_OUTPUT_DIR/$MODE/pred/ \
71 | --output_dir $EVAL_OUTPUT_DIR/$MODE
72 | done
73 | done
74 | done
75 | done
76 | done
77 |
78 |
79 |
80 |
--------------------------------------------------------------------------------
/lm_pretraining/format_mimic_for_BERT.py:
--------------------------------------------------------------------------------
1 | import psycopg2
2 | import pandas as pd
3 | import sys
4 | import spacy
5 | import re
6 | import stanfordnlp
7 | import time
8 | import scispacy
9 | from tqdm import tqdm
10 | from heuristic_tokenize import sent_tokenize_rules
11 |
12 |
13 | # update these constants to run this script
14 | OUTPUT_DIR = '/PATH/TO/OUTPUT/DIR' #this path will contain tokenized notes. This dir will be the input dir for create_pretrain_data.sh
15 | MIMIC_NOTES_FILE = 'PATH/TO/MIMIC/DATA' #this is the path to mimic data if you're reading from a csv. Else uncomment the code to read from database below
16 |
17 |
18 | #setting sentence boundaries
19 | def sbd_component(doc):
20 | for i, token in enumerate(doc[:-2]):
21 | # define sentence start if period + titlecase token
22 | if token.text == '.' and doc[i+1].is_title:
23 | doc[i+1].sent_start = True
24 | if token.text == '-' and doc[i+1].text != '-':
25 | doc[i+1].sent_start = True
26 | return doc
27 |
28 | #convert de-identification text into one token
29 | def fix_deid_tokens(text, processed_text):
30 | deid_regex = r"\[\*\*.{0,15}.*?\*\*\]"
31 | if text:
32 | indexes = [m.span() for m in re.finditer(deid_regex,text,flags=re.IGNORECASE)]
33 | else:
34 | indexes = []
35 | for start,end in indexes:
36 | processed_text.merge(start_idx=start,end_idx=end)
37 | return processed_text
38 |
39 |
40 | def process_section(section, note, processed_sections):
41 | # perform spacy processing on section
42 | processed_section = nlp(section['sections'])
43 | processed_section = fix_deid_tokens(section['sections'], processed_section)
44 | processed_sections.append(processed_section)
45 |
46 | def process_note_helper(note):
47 | # split note into sections
48 | note_sections = sent_tokenize_rules(note)
49 | processed_sections = []
50 | section_frame = pd.DataFrame({'sections':note_sections})
51 | section_frame.apply(process_section, args=(note,processed_sections,), axis=1)
52 | return(processed_sections)
53 |
54 | def process_text(sent, note):
55 | sent_text = sent['sents'].text
56 | if len(sent_text) > 0 and sent_text.strip() != '\n':
57 | if '\n' in sent_text:
58 | sent_text = sent_text.replace('\n', ' ')
59 | note['text'] += sent_text + '\n'
60 |
61 | def get_sentences(processed_section, note):
62 | # get sentences from spacy processing
63 | sent_frame = pd.DataFrame({'sents': list(processed_section['sections'].sents)})
64 | sent_frame.apply(process_text, args=(note,), axis=1)
65 |
66 | def process_note(note):
67 | try:
68 | note_text = note['text'] #unicode(note['text'])
69 | note['text'] = ''
70 | processed_sections = process_note_helper(note_text)
71 | ps = {'sections': processed_sections}
72 | ps = pd.DataFrame(ps)
73 | ps.apply(get_sentences, args=(note,), axis=1)
74 | return note
75 | except Exception as e:
76 | pass
77 | #print ('error', e)
78 |
79 |
80 |
81 | if len(sys.argv) < 2:
82 | print('Please specify the note category.')
83 | sys.exit()
84 |
85 | category = sys.argv[1]
86 |
87 |
88 | start = time.time()
89 | tqdm.pandas()
90 |
91 | print('Begin reading notes')
92 |
93 |
94 | # Uncomment this to use postgres to query mimic instead of reading from a file
95 | # con = psycopg2.connect(dbname='mimic', host="/var/run/postgresql")
96 | # notes_query = "(select * from mimiciii.noteevents);"
97 | # notes = pd.read_sql_query(notes_query, con)
98 | notes = pd.read_csv(MIMIC_NOTES_FILE, index_col = 0)
99 | #print(set(notes['category'])) # all categories
100 |
101 |
102 | notes = notes[notes['category'] == category]
103 | print('Number of notes: %d' %len(notes.index))
104 | notes['ind'] = list(range(len(notes.index)))
105 |
106 | # NOTE: `disable=['tagger', 'ner']` was added after paper submission to make this process go faster
107 | # our time estimate in the paper did not include the code to skip spacy's NER & tagger
108 | nlp = spacy.load('en_core_sci_md', disable=['tagger','ner'])
109 | nlp.add_pipe(sbd_component, before='parser')
110 |
111 |
112 | formatted_notes = notes.progress_apply(process_note, axis=1)
113 | with open(OUTPUT_DIR + category + '.txt','w') as f:
114 | for text in formatted_notes['text']:
115 | if text != None and len(text) != 0 :
116 | f.write(text)
117 | f.write('\n')
118 |
119 | end = time.time()
120 | print (end-start)
121 | print ("Done formatting notes")
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | absl-py=0.7.0=pypi_0
5 | asn1crypto=0.24.0=py36_0
6 | astor=0.7.1=pypi_0
7 | awscli=1.16.111=pypi_0
8 | backcall=0.1.0=py36_0
9 | blas=1.0=mkl
10 | bleach=3.1.0=py36_0
11 | boto3=1.9.86=pypi_0
12 | botocore=1.12.101=pypi_0
13 | ca-certificates=2018.11.29=ha4d7672_0
14 | certifi=2018.11.29=py36_1000
15 | cffi=1.11.5=py36he75722e_1
16 | chardet=3.0.4=py36_1
17 | colorama=0.3.9=pypi_0
18 | conllu=1.2.2=pypi_0
19 | cryptography=2.4.2=py36h1ba5d50_0
20 | cymem=2.0.2=py36hfd86e86_0
21 | cytoolz=0.9.0.1=py36h14c3975_1
22 | dbus=1.13.6=h746ee38_0
23 | decorator=4.3.0=py36_0
24 | dill=0.2.8.2=py36_0
25 | docutils=0.14=pypi_0
26 | en-core-sci-md=0.1.0=pypi_0
27 | en-core-sci-sm=0.1.0=pypi_0
28 | en-core-web-sm=2.0.0=pypi_0
29 | entrypoints=0.3=py36_0
30 | expat=2.2.6=he6710b0_0
31 | fontconfig=2.13.0=h9420a91_0
32 | freetype=2.9.1=h8a8886c_1
33 | gast=0.2.2=pypi_0
34 | glib=2.56.2=hd408876_0
35 | gmp=6.1.2=h6c8ec71_1
36 | grpcio=1.18.0=pypi_0
37 | gst-plugins-base=1.14.0=hbbd80ab_1
38 | gstreamer=1.14.0=hb453b48_1
39 | h5py=2.9.0=pypi_0
40 | icu=58.2=h9c2bf20_1
41 | idna=2.8=py36_0
42 | intel-openmp=2019.1=144
43 | ipykernel=5.1.0=py36h39e3cac_0
44 | ipython=7.2.0=py36h39e3cac_0
45 | ipython_genutils=0.2.0=py36_0
46 | ipywidgets=7.4.2=py36_0
47 | jedi=0.13.2=py36_0
48 | jinja2=2.10=py36_0
49 | jmespath=0.9.3=pypi_0
50 | jpeg=9b=h024ee3a_2
51 | jsonschema=2.6.0=py36_0
52 | jupyter=1.0.0=py36_7
53 | jupyter-contrib-core=0.3.3=pypi_0
54 | jupyter-highlight-selected-word=0.2.0=pypi_0
55 | jupyter-latex-envs=1.4.6=pypi_0
56 | jupyter-nbextensions-configurator=0.4.1=pypi_0
57 | jupyter_client=5.2.4=py36_0
58 | jupyter_console=6.0.0=py36_0
59 | jupyter_contrib_core=0.3.3=py_2
60 | jupyter_contrib_nbextensions=0.5.1=py36_0
61 | jupyter_core=4.4.0=py36_0
62 | jupyter_highlight_selected_word=0.2.0=py36_1000
63 | jupyter_latex_envs=1.4.4=py36_1000
64 | jupyter_nbextensions_configurator=0.4.1=py36_0
65 | keras-applications=1.0.7=pypi_0
66 | keras-preprocessing=1.0.9=pypi_0
67 | krb5=1.16.1=h173b8e3_7
68 | libedit=3.1.20181209=hc058e9b_0
69 | libffi=3.2.1=hd88cf55_4
70 | libgcc-ng=8.2.0=hdf63c60_1
71 | libgfortran-ng=7.3.0=hdf63c60_0
72 | libpng=1.6.36=hbc83047_0
73 | libpq=11.1=h20c2e04_0
74 | libsodium=1.0.16=h1bed415_0
75 | libstdcxx-ng=8.2.0=hdf63c60_1
76 | libuuid=1.0.3=h1bed415_2
77 | libxcb=1.13=h1bed415_1
78 | libxml2=2.9.9=he19cac6_0
79 | libxslt=1.1.32=h4785a14_1002
80 | lxml=4.3.0=pypi_0
81 | markdown=3.0.1=pypi_0
82 | markupsafe=1.1.0=py36h7b6447c_0
83 | mistune=0.8.4=py36h7b6447c_0
84 | mkl=2018.0.3=1
85 | mkl_fft=1.0.6=py36h7dd41cf_0
86 | mkl_random=1.0.1=py36h4414c95_1
87 | msgpack-numpy=0.4.3.2=py36_0
88 | msgpack-python=0.5.6=py36h6bb024c_1
89 | murmurhash=1.0.1=py36he6710b0_0
90 | nbconvert=5.3.1=py36_0
91 | nbformat=4.4.0=py36_0
92 | ncurses=6.1=he6710b0_1
93 | nltk=3.4=py36_1
94 | notebook=5.7.4=py36_0
95 | numpy=1.15.4=py36h1d66e8a_0
96 | numpy-base=1.15.4=py36h81de0dd_0
97 | openssl=1.1.1a=h14c3975_1000
98 | pandas=0.24.0=py36he6710b0_0
99 | pandoc=2.2.3.2=0
100 | pandocfilters=1.4.2=py36_1
101 | parso=0.3.1=py36_0
102 | pcre=8.42=h439df22_0
103 | pexpect=4.6.0=py36_0
104 | pickleshare=0.7.5=py36_0
105 | pip=18.1=py36_0
106 | plac=0.9.6=py36_0
107 | preshed=2.0.1=py36he6710b0_0
108 | prometheus_client=0.5.0=py36_0
109 | prompt_toolkit=2.0.7=py36_0
110 | protobuf=3.6.1=pypi_0
111 | psutil=5.5.0=pypi_0
112 | psycopg2=2.7.6.1=py36h1ba5d50_0
113 | ptyprocess=0.6.0=py36_0
114 | pyasn1=0.4.5=pypi_0
115 | pycparser=2.19=py36_0
116 | pygments=2.3.1=py36_0
117 | pyopenssl=18.0.0=py36_0
118 | pyqt=5.9.2=py36h05f1152_2
119 | pysocks=1.6.8=py36_0
120 | python=3.6.8=h0371630_0
121 | python-dateutil=2.7.5=py36_0
122 | pytorch-pretrained-bert=0.4.0=pypi_0
123 | pytz=2018.9=py36_0
124 | pyyaml=3.13=pypi_0
125 | pyzmq=17.1.2=py36h14c3975_0
126 | qt=5.9.7=h5867ecd_1
127 | qtconsole=4.4.3=py36_0
128 | readline=7.0=h7b6447c_5
129 | regex=2018.1.10=pypi_0
130 | requests=2.21.0=py36_0
131 | rsa=3.4.2=pypi_0
132 | s3transfer=0.1.13=pypi_0
133 | scipy=1.2.1=pypi_0
134 | scispacy=0.1.0=pypi_0
135 | send2trash=1.5.0=py36_0
136 | setuptools=40.6.3=py36_0
137 | sip=4.19.8=py36hf484d3e_0
138 | six=1.12.0=py36_0
139 | spacy=2.0.18=pypi_0
140 | sqlite=3.26.0=h7b6447c_0
141 | stanfordcorenlp=3.9.1.1=pypi_0
142 | stanfordnlp=0.1.0=pypi_0
143 | statistics=1.0.3.5=pypi_0
144 | tensorboard=1.12.2=pypi_0
145 | tensorflow-gpu=1.12.0=pypi_0
146 | termcolor=1.1.0=pypi_0
147 | terminado=0.8.1=py36_1
148 | testpath=0.4.2=py36_0
149 | thinc=6.12.1=py36h4989274_0
150 | tk=8.6.8=hbc83047_0
151 | toolz=0.9.0=py36_0
152 | torch=1.0.0=pypi_0
153 | tornado=5.1.1=py36h7b6447c_0
154 | tqdm=4.29.1=py_0
155 | traitlets=4.3.2=py36_0
156 | ujson=1.35=py36h14c3975_0
157 | urllib3=1.24.1=py36_0
158 | wcwidth=0.1.7=py36_0
159 | webencodings=0.5.1=py36_1
160 | werkzeug=0.14.1=pypi_0
161 | wheel=0.32.3=py36_0
162 | widgetsnbextension=3.4.2=py36_0
163 | wrapt=1.10.11=py36h14c3975_2
164 | xz=5.2.4=h14c3975_4
165 | yaml=0.1.7=h14c3975_1001
166 | zeromq=4.2.5=hf484d3e_1
167 | zlib=1.2.11=h7b6447c_3
168 |
--------------------------------------------------------------------------------
/downstream_tasks/ner_eval/ner_detokenize.py:
--------------------------------------------------------------------------------
1 | # Note that this code is adapted from the BioBERT github repo:
2 | # https://github.com/guidoajansen/biobert/tree/87e70a4dfb0dcc1e29ef9d6562f87c4854504e97/biobert/biocodes
3 |
4 | import argparse
5 | import itertools
6 |
7 | parser = argparse.ArgumentParser(description='')
8 | parser.add_argument('--token_test_path', type=str, help='')
9 | parser.add_argument('--label_test_path', type=str, help='')
10 | parser.add_argument('--answer_path', type=str, help='')
11 | parser.add_argument('--output_file', type=str, help='')
12 | parser.add_argument('--tok_to_orig_map_path', type=str, help='')
13 |
14 | args = parser.parse_args()
15 |
16 | def detokenize(golden_path, pred_token_test_path, pred_label_test_path, tok_to_orig_map_path, output_file):
17 |
18 | """convert word-piece BERT-NER results to original words (CoNLL eval format)
19 |
20 | Args:
21 | golden_path: path to golden dataset. ex) NCBI-disease/test.tsv
22 | pred_token_test_path: path to token_test.txt from output folder. ex) output/token_test.txt
23 | pred_label_test_path: path to label_test.txt from output folder. ex) output/label_test.txt
24 | output_file: file where result will write to. ex) output/conll_format.txt
25 |
26 | Outs:
27 | NER_result_conll.txt
28 | """
29 | # read golden
30 | ans = dict({'toks':[], 'labels':[]})
31 | with open(golden_path,'r') as in_:
32 | ans_toks, ans_labels = [],[]
33 | for line in in_:
34 | line = line.strip()
35 | if line == '':
36 | if len(ans_toks) == 0: # there must be extra empty lines
37 | continue
38 | #ans['toks'].append('[SEP]')
39 | ans['toks'].append(ans_toks)
40 | ans['labels'].append(ans_labels)
41 | ans_toks =[]
42 | ans_labels=[]
43 | continue
44 | tmp = line.split()
45 | ans_toks.append(tmp[0])
46 | ans_labels.append(tmp[1])
47 | if len(ans_toks) > 0: #don't forget the last sentence if there's no final empty line
48 | ans['toks'].append(ans_toks)
49 | ans['labels'].append(ans_labels)
50 |
51 | # read predicted
52 | pred = dict({'toks':[], 'labels':[], 'tok_to_orig':[]}) # dictionary for predicted tokens and labels.
53 | with open(pred_token_test_path,'r') as in_: #'token_test.txt'
54 | pred_toks = []
55 | for line in in_:
56 | line = line.strip()
57 | if line =='':
58 | pred['toks'].append(pred_toks)
59 | pred_toks = []
60 | continue
61 | pred_toks.append(line)
62 | if len(pred_toks) > 0: #don't forget the last sentence if there's no final empty line
63 | pred['toks'].append(pred_toks)
64 |
65 | with open(tok_to_orig_map_path,'r') as in_: #'tok_to_orig_map_test.txt'
66 | pred_tok_to_orig = []
67 | for line in in_:
68 | line = line.strip()
69 | if line =='':
70 | pred['tok_to_orig'].append(pred_tok_to_orig)
71 | pred_tok_to_orig=[]
72 | continue
73 | pred_tok_to_orig.append(int(line))
74 | if len(pred_tok_to_orig) > 0: #don't forget the last sentence if there's no final empty line
75 | pred['tok_to_orig'].append(pred_tok_to_orig)
76 |
77 |
78 | with open(pred_label_test_path,'r') as in_:
79 | pred_labels = []
80 | for line in in_:
81 | line = line.strip()
82 | if line in ['[CLS]','[SEP]', 'X']: # replace non-text tokens with O. This will not be evaluated.
83 | pred_labels.append('O')
84 | continue
85 | if line == '':
86 | pred['labels'].append(pred_labels)
87 | pred_labels = []
88 | continue
89 | pred_labels.append(line)
90 | if len(pred_labels) > 0: #don't forget the last sentence if there's no final empty line
91 | pred['labels'].append(pred_labels)
92 |
93 |
94 |
95 | print(len(pred['toks']), len(pred['labels']), len(ans['labels']), len(ans['toks']))
96 |
97 |
98 |
99 | if (len(pred['toks']) != len(pred['labels'])): # Sanity check
100 | print("Error! : len(pred['toks']) != len(pred['labels']) : Please report us")
101 | print(len(pred['toks']), len(pred['labels']))
102 | raise
103 |
104 | if (len(ans['labels']) != len(pred['labels'])): # Sanity check
105 | print(len(ans['labels']), len(pred['labels']))
106 | print("Error! : len(ans['labels']) != len(bert_pred['labels']) : Please report us")
107 | raise
108 |
109 | bert_pred = dict({'toks':[], 'labels':[]})
110 | num_too_short = 0
111 | for t, l, tok_to_orig_map, ans_toks in zip(pred['toks'],pred['labels'], pred['tok_to_orig'], ans['toks']):
112 | #remove first and last from each list, which are just buffers
113 | t.pop(0)
114 | t.pop()
115 | l.pop(0)
116 | l.pop()
117 | tok_to_orig_map.pop(0)
118 | tok_to_orig_map.pop()
119 |
120 | if (len(t)!= len(tok_to_orig_map)):
121 | num_too_short += 1
122 | print('Sentence of length %d was truncated' %len(tok_to_orig_map))
123 |
124 |
125 | bert_pred_toks, bert_pred_labs = [],[]
126 | for ind_into_orig in range(int(tok_to_orig_map[len(t)-1]) + 1): #indexing into t here to deal with issue of truncated tokens
127 | tok_indices = [i for i, x in enumerate(tok_to_orig_map) if x == ind_into_orig]
128 | if len(t) in tok_indices: #skip that token and label because part of the word was truncated during eval
129 | continue
130 | wordpiece_toks = [t[ind][2:] if t[ind][:2] == '##' else t[ind] for ind in tok_indices]
131 | wordpiece_labs = [l[ind] for ind in tok_indices]
132 | bert_pred_toks.append(''.join(wordpiece_toks))
133 | bert_pred_labs.append(wordpiece_labs[0])
134 |         if len(ans_toks) != len(bert_pred_toks): # if the sentence was truncated, assume the remaining toks were predicted as 'O'
135 | n_missing_labs = len(ans_toks) - len(bert_pred_toks)
136 | bert_pred_labs.extend(['O'] * n_missing_labs)
137 | bert_pred['toks'].append(bert_pred_toks)
138 | bert_pred['labels'].append(bert_pred_labs)
139 |
140 |
141 | print('Number of sentences that were truncated: %d' %num_too_short)
142 |
143 |
144 | flattened_pred_toks = [item for sublist in bert_pred['toks'] for item in sublist]
145 | flattened_pred_labs = [item for sublist in bert_pred['labels'] for item in sublist]
146 | flattened_ans_labs = [item for sublist in ans['labels'] for item in sublist]
147 | flattened_ans_toks= [item for sublist in ans['toks'] for item in sublist]
148 |
149 | print(len(flattened_pred_toks), len(flattened_pred_labs), len(flattened_ans_labs), len(flattened_ans_toks))
150 |
151 |
152 | if (len(bert_pred['toks']) != len(bert_pred['labels'])): # Sanity check
153 | print("Error! : len(bert_pred['toks']) != len(bert_pred['labels']) : Please report us")
154 | raise
155 |
156 | if (len(ans['labels']) != len(bert_pred['labels'])): # Sanity check
157 | print("Error! : len(ans['labels']) != len(bert_pred['labels']) : Please report us")
158 | raise
159 |
160 | with open(output_file, 'w') as out_:
161 | for ans_toks, ans_labs, pred_labs in zip(ans['toks'], ans['labels'], bert_pred['labels']):
162 | for ans_t, ans_l, pred_l in zip(ans_toks, ans_labs, pred_labs):
163 | out_.write("%s %s %s\n"%(ans_t, ans_l, pred_l))
164 | out_.write('\n')
165 |
166 |
167 | detokenize(args.answer_path, args.token_test_path, args.label_test_path, args.tok_to_orig_map_path, args.output_file)
168 |
--------------------------------------------------------------------------------
/downstream_tasks/ner_eval/format_for_i2b2_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | # Functions are all taken from Willie Boag's cliner code `documents.py` at
5 | # https://github.com/text-machine-lab/CliNER/blob/5e1599fb2a2209fa0183f623308516816a033d4f/code/notes/documents.py
6 |
7 | def convert_to_i2b2_format(tok_sents, pred_labels, mode=None):
8 | """
9 | Purpose: Return the given concept label predictions in i2b2 format
10 |
11 |     @param tokenized_sents. <list-of-lists> of tokenized sentences
12 |     @param pred_labels. <list-of-lists> of predicted_labels
13 |     @return <string> of i2b2-concept-file-formatted data
14 | """
15 |
16 | # Return value
17 | retStr = ''
18 |
19 | concept_tuples = tok_labels_to_concepts(tok_sents, pred_labels, mode)
20 |
21 | # For each classification
22 | for classification in concept_tuples:
23 |
24 | # Ensure 'none' classifications are skipped
25 | if classification[0] == 'none':
26 |             raise Exception('Classification label "none" should never happen')
27 |
28 | concept = classification[0]
29 | lineno = classification[1]
30 | start = classification[2]
31 | end = classification[3]
32 |
33 | # A list of words (corresponding line from the text file)
34 | text = tok_sents[lineno-1]
35 |
36 | #print "\n" + "-" * 80
37 | #print "classification: ", classification
38 | #print "lineno: ", lineno
39 | #print "start: ", start
40 | #print "end ", end
41 | #print "text: ", text
42 | #print 'len(text): ', len(text)
43 | #print "text[start]: ", text[start]
44 | #print "concept: ", concept
45 |
46 | datum = text[start]
47 | for j in range(start, end):
48 | datum += " " + text[j+1]
49 | datum = datum.lower()
50 |
51 | #print 'datum: ', datum
52 |
53 | # Line:TokenNumber of where the concept starts and ends
54 | idx1 = "%d:%d" % (lineno, start)
55 | idx2 = "%d:%d" % (lineno, end)
56 |
57 | # Classification
58 | label = concept.capitalize()
59 |
60 |
61 | # Print format
62 | retStr += "c=\"%s\" %s %s||t=\"%s\"\n" % (datum, idx1, idx2, label)
63 |
64 | # return formatted data
65 | return retStr.strip()
66 |
67 |
68 | def tok_labels_to_concepts(tokenized_sents, tok_labels, mode):
69 |
70 | #print tok_labels
71 | '''
72 | for gold,sent in zip(tok_labels, tokenized_sents):
73 | print gold
74 | print sent
75 | print
76 | '''
77 |
78 | # convert 'B-treatment' into ('B','treatment') and 'O' into ('O',None)
79 | def split_label(label):
80 | if label == 'O':
81 | iob,tag = 'O', None
82 | else:
83 | #print(label)
84 | if 'LOCATION-OTHER' in label:
85 | label = label.replace('LOCATION-OTHER', 'LOCATION_OTHER')
86 | if len(label.split('-')) != 2:
87 | print(label.split('-'))
88 | iob,tag = label.split('-')
89 | return iob, tag
90 |
91 |
92 |
93 | # preprocess predictions to "correct" starting Is into Bs
94 | corrected = []
95 | for lineno,(toks, labels) in enumerate(zip(tokenized_sents, tok_labels)):
96 | corrected_line = []
97 | for i in range(len(labels)):
98 | #'''
99 | # is this a candidate for error?
100 |
101 | iob,tag = split_label(labels[i])
102 | if iob == 'I':
103 | # beginning of line has no previous
104 | if i == 0:
105 | print(mode, 'CORRECTING! A')
106 | if mode == 'gold':
107 | print(toks, labels)
108 | new_label = 'B' + labels[i][1:]
109 | else:
110 | # ensure either its outside OR mismatch type
111 | prev_iob,prev_tag = split_label(labels[i-1])
112 | if prev_iob == 'O' or prev_tag != tag:
113 | print(mode, 'CORRECTING! B')
114 | new_label = 'B' + labels[i][1:]
115 | else:
116 | new_label = labels[i]
117 | else:
118 | new_label = labels[i]
119 | #'''
120 | corrected_line.append(new_label)
121 | corrected.append( corrected_line )
122 |
123 | tok_labels = corrected
124 |
125 | concepts = []
126 | for i,labs in enumerate(tok_labels):
127 |
128 | N = len(labs)
129 | begins = [ j for j,lab in enumerate(labs) if (lab[0] == 'B') ]
130 |
131 | for start in begins:
132 | # "B-test" --> "-test"
133 | label = labs[start][1:]
134 |
135 | # get ending token index
136 | end = start
137 | while (end < N-1) and tok_labels[i][end+1].startswith('I') and tok_labels[i][end+1][1:] == label:
138 | end += 1
139 |
140 | # concept tuple
141 | concept_tuple = (label[1:], i+1, start, end)
142 | concepts.append(concept_tuple)
143 |
144 | '''
145 | # test it out
146 | for i in range(len(tokenized_sents)):
147 | assert len(tokenized_sents[i]) == len(tok_labels[i])
148 | for tok,lab in zip(tokenized_sents[i],tok_labels[i]):
149 | if lab != 'O': print '\t',
150 | print lab, tok
151 | print
152 | exit()
153 | '''
154 |
155 | # test it out
156 | test_tok_labels = tok_concepts_to_labels(tokenized_sents, concepts)
157 | #'''
158 | for lineno,(test,gold,sent) in enumerate(zip(test_tok_labels, tok_labels, tokenized_sents)):
159 | for i,(a,b) in enumerate(zip(test,gold)):
160 | #'''
161 | if not ((a == b)or(a[0]=='B' and b[0]=='I' and a[1:]==b[1:])):
162 | print()
163 | print('lineno: ', lineno)
164 | print()
165 | print('generated: ', test[i-3:i+4])
166 | print('predicted: ', gold[i-3:i+4])
167 | print(sent[i-3:i+4])
168 | print('a[0]: ', a[0])
169 | print('b[0]: ', b[0])
170 | print('a[1:]: ', a[1:])
171 | print('b[1:]: ', b[1:])
172 | print('a[1:] == b[a:]: ', a[1:] == b[1:])
173 | print()
174 | #'''
175 | assert (a == b) or (a[0]=='B' and b[0]=='I' and a[1:]==b[1:])
176 | i += 1
177 | #'''
178 | assert test_tok_labels == tok_labels
179 |
180 | return concepts
181 |
182 |
183 | def tok_concepts_to_labels(tokenized_sents, tok_concepts):
184 | # parallel to tokens
185 | labels = [ ['O' for tok in sent] for sent in tokenized_sents ]
186 |
187 | # fill each concept's tokens appropriately
188 | for concept in tok_concepts:
189 | label,lineno,start_tok,end_tok = concept
190 | labels[lineno-1][start_tok] = 'B-%s' % label
191 | for i in range(start_tok+1,end_tok+1):
192 | labels[lineno-1][i] = 'I-%s' % label
193 |
194 | # test it out
195 | '''
196 | for i in range(len(tokenized_sents)):
197 | assert len(tokenized_sents[i]) == len(labels[i])
198 | for tok,lab in zip(tokenized_sents[i],labels[i]):
199 | if lab != 'O': print '\t',
200 | print lab, tok
201 | print
202 | exit()
203 | '''
204 |
205 | return labels
206 |
207 |
208 |
209 | parser = argparse.ArgumentParser(description='')
210 | parser.add_argument('--results_file', type=str, help='Location of results file in conll format')
211 | parser.add_argument('--output_pred_dir', type=str, help='Location of where to output prediction formatted files')
212 | parser.add_argument('--output_gold_dir', type=str, help='Location of where to output gold formatted files')
213 |
214 |
215 | args = parser.parse_args()
216 |
217 | results_dict = {'tokens': [], 'gold_labels': [], 'predicted_labels': []}
218 | with open(args.results_file, 'r') as results: #+'/NER_result_conll.txt'
219 | toks, gold_labs, pred_labs = [],[],[]
220 | for line in results:
221 | line = line.strip()
222 | if line == '':
223 | results_dict['tokens'].append(toks)
224 | results_dict['gold_labels'].append(gold_labs)
225 | results_dict['predicted_labels'].append(pred_labs)
226 | toks, gold_labs, pred_labs = [],[],[]
227 | else:
228 | tok, gold_lab, pred_lab = line.split()
229 | toks.append(tok)
230 | gold_labs.append(gold_lab)
231 | pred_labs.append(pred_lab)
232 | if len(toks) > 0:
233 | results_dict['tokens'].append(toks)
234 | results_dict['gold_labels'].append(gold_labs)
235 | results_dict['predicted_labels'].append(pred_labs)
236 |
237 | i2b2_format_gold = convert_to_i2b2_format(results_dict['tokens'], results_dict['gold_labels'], 'gold')
238 | i2b2_format_predicted = convert_to_i2b2_format(results_dict['tokens'], results_dict['predicted_labels'], 'pred')
239 |
240 | with open(os.path.join(args.output_pred_dir, "i2b2.con"),'w') as writer:
241 | writer.write(i2b2_format_predicted)
242 |
243 | with open(os.path.join(args.output_gold_dir, "i2b2.con"),'w') as writer:
244 | writer.write(i2b2_format_gold)
245 |
246 |
247 |
--------------------------------------------------------------------------------
/lm_pretraining/heuristic_tokenize.py:
--------------------------------------------------------------------------------
1 | # NOTE: this code is taken directly from Willie Boag's mimic-tokenize github repository
2 | # https://github.com/wboag/mimic-tokenize/blob/master/heuristic-tokenize.py commit e953d271bbb4c53aee5cc9a7b8be870a6b007604
3 | # The code was modified in two ways:
4 | # (1) to make the script compatible with Python 3
5 | # (2) to remove the main and discharge_tokenize methods which we don't directly use
6 |
7 | # There are two known issues with this code. We have not yet fixed them because we want to maintain the reproducibility of
8 | # our code for the work that was published. However, anyone wanting to extend this work should make the following changes:
9 | # (1) fix a bug on line #168 where . should be replaced with \. i.e. should be `while re.search('\n\s*%d\.'%n,segment):`
10 | # (2) add else statement (`else: new_segments.append(segments[i])`) to the if statement at line 287
11 | # `if (i == N-1) or is_title(segments[i+1]):`
12 |
13 |
14 |
15 | import sys
16 | import re, nltk
17 | import os
18 |
19 | def strip(s):
20 | return s.strip()
21 |
22 | def is_inline_title(text):
23 | m = re.search('^([a-zA-Z ]+:) ', text)
24 | if not m:
25 | return False
26 | return is_title(m.groups()[0])
27 |
28 | stopwords = set(['of', 'on', 'or'])
29 | def is_title(text):
30 | if not text.endswith(':'):
31 | return False
32 | text = text[:-1]
33 |
34 | # be a little loose here... can tighten if it causes errors
35 | text = re.sub('(\([^\)]*?\))', '', text)
36 |
37 | # Are all non-stopwords capitalized?
38 | for word in text.split():
39 | if word in stopwords: continue
40 | if not word[0].isupper():
41 | return False
42 |
43 |     # I noticed this is a common issue (non-title appears at beginning of line)
44 | if text == 'Disp':
45 | return False
46 |
47 | # optionally: could assert that it is less than 6 tokens
48 | return True
49 |
50 |
51 |
52 | def sent_tokenize_rules(text):
53 |
54 | # long sections are OBVIOUSLY different sentences
55 | text = re.sub('---+', '\n\n-----\n\n', text)
56 | text = re.sub('___+', '\n\n_____\n\n', text)
57 | text = re.sub('\n\n+', '\n\n', text)
58 |
59 | segments = text.split('\n\n')
60 |
61 | # strategy: break down segments and chip away structure until just prose.
62 | # once you have prose, use nltk.sent_tokenize()
63 |
64 | ### Separate section headers ###
65 | new_segments = []
66 |
67 | # deal with this one edge case (multiple headers per line) up front
68 | m1 = re.match('(Admission Date:) (.*) (Discharge Date:) (.*)', segments[0])
69 | if m1:
70 | new_segments += list(map(strip,m1.groups()))
71 | segments = segments[1:]
72 |
73 | m2 = re.match('(Date of Birth:) (.*) (Sex:) (.*)' , segments[0])
74 | if m2:
75 | new_segments += list(map(strip,m2.groups()))
76 | segments = segments[1:]
77 |
78 | for segment in segments:
79 | # find all section headers
80 | possible_headers = re.findall('\n([A-Z][^\n:]+:)', '\n'+segment)
81 | #assert len(possible_headers) < 2, str(possible_headers)
82 | headers = []
83 | for h in possible_headers:
84 | #print 'cand=[%s]' % h
85 | if is_title(h.strip()):
86 | #print '\tYES=[%s]' % h
87 | headers.append(h.strip())
88 |
89 | # split text into new segments, delimiting on these headers
90 | for h in headers:
91 | h = h.strip()
92 |
93 | # split this segment into 3 smaller segments
94 | ind = segment.index(h)
95 | prefix = segment[:ind].strip()
96 | rest = segment[ ind+len(h):].strip()
97 |
98 | # add the prefix (potentially empty)
99 | if len(prefix) > 0:
100 | new_segments.append(prefix.strip())
101 |
102 | # add the header
103 | new_segments.append(h)
104 |
105 | # remove the prefix from processing (very unlikely to be empty)
106 | segment = rest.strip()
107 |
108 | # add the final piece (aka what comes after all headers are processed)
109 | if len(segment) > 0:
110 | new_segments.append(segment.strip())
111 |
112 | # copy over the new list of segments (further segmented than original segments)
113 | segments = list(new_segments)
114 | new_segments = []
115 |
116 |
117 | ### Low-hanging fruit: "_____" is a delimiter
118 | for segment in segments:
119 | subsections = segment.split('\n_____\n')
120 | new_segments.append(subsections[0])
121 | for ss in subsections[1:]:
122 | new_segments.append('_____')
123 | new_segments.append(ss)
124 |
125 | segments = list(new_segments)
126 | new_segments = []
127 |
128 |
129 | ### Low-hanging fruit: "-----" is a delimiter
130 | for segment in segments:
131 | subsections = segment.split('\n-----\n')
132 | new_segments.append(subsections[0])
133 | for ss in subsections[1:]:
134 | new_segments.append('-----')
135 | new_segments.append(ss)
136 |
137 | segments = list(new_segments)
138 | new_segments = []
139 |
140 | '''
141 | for segment in segments:
142 | print '------------START------------'
143 | print segment
144 | print '-------------END-------------'
145 | print
146 | exit()
147 | '''
148 |
149 | ### Separate enumerated lists ###
150 | for segment in segments:
151 | if not re.search('\n\s*\d+\.', '\n'+segment):
152 | new_segments.append(segment)
153 | continue
154 |
155 | '''
156 | print '------------START------------'
157 | print segment
158 | print '-------------END-------------'
159 | print
160 | '''
161 |
162 | # generalizes in case the list STARTS this section
163 | segment = '\n'+segment
164 |
165 | # determine whether this segment contains a bulleted list (assumes i,i+1,...,n)
166 | start = int(re.search('\n\s*(\d+)\.', '\n'+segment).groups()[0])
167 | n = start
168 | while re.search('\n\s*%d.'%n,segment): # SHOULD CHANGE TO: while re.search('\n\s*%d\.'%n,segment): #(CHANGED . to \.)
169 | n += 1
170 | n -= 1
171 |
172 | # no bulleted list
173 | if n < 1:
174 | new_segments.append(segment)
175 | continue
176 |
177 | '''
178 | print '------------START------------'
179 | print segment
180 | print '-------------END-------------'
181 |
182 | print start,n
183 | print
184 | '''
185 |
186 | # break each list into its own line
187 | # challenge: not clear how to tell when the list ends if more text happens next
188 | for i in range(start,n+1):
189 | matching_text = re.search('(\n\s*\d+\.)',segment).groups()[0]
190 | prefix = segment[:segment.index(matching_text) ].strip()
191 | segment = segment[ segment.index(matching_text):].strip()
192 | if len(prefix)>0:
193 | new_segments.append(prefix)
194 |
195 | if len(segment)>0:
196 | new_segments.append(segment)
197 |
198 | segments = list(new_segments)
199 | new_segments = []
200 |
201 | '''
202 | TODO: Big Challenge
203 |
204 | There is so much variation in what makes a list. Intuitively, I can tell it's a
205 | list because it shows repeated structure (often following a header)
206 |
207 | Examples of some lists (with numbers & symptoms changed around to noise)
208 |
209 | Past Medical History:
210 | -- Hyperlipidemia
211 | -- lactose intolerance
212 | -- Hypertension
213 |
214 |
215 | Physical Exam:
216 | Vitals - T 82.2 BP 123/23 HR 73 R 21 75% on 2L NC
217 | General - well appearing male, sitting up in chair in NAD
218 | Neck - supple, JVP elevated to angle of jaw
219 | CV - distant heart sounds, RRR, faint __PHI_43__ murmur at
220 |
221 |
222 | Labs:
223 | __PHI_10__ 12:00PM BLOOD WBC-8.8 RBC-8.88* Hgb-88.8* Hct-88.8*
224 | MCV-88 MCH-88.8 MCHC-88.8 RDW-88.8* Plt Ct-888
225 | __PHI_14__ 04:54AM BLOOD WBC-8.8 RBC-8.88* Hgb-88.8* Hct-88.8*
226 | MCV-88 MCH-88.8 MCHC-88.8 RDW-88.8* Plt Ct-888
227 | __PHI_23__ 03:33AM BLOOD WBC-8.8 RBC-8.88* Hgb-88.8* Hct-88.8*
228 | MCV-88 MCH-88.8 MCHC-88.8 RDW-88.8* Plt Ct-888
229 | __PHI_109__ 03:06AM BLOOD WBC-8.8 RBC-8.88* Hgb-88.8* Hct-88.8*
230 | MCV-88 MCH-88.8 MCHC-88.8 RDW-88.8* Plt Ct-888
231 | __PHI_1__ 05:09AM BLOOD WBC-8.8 RBC-8.88* Hgb-88.8* Hct-88.8*
232 | MCV-88 MCH-88.8 MCHC-88.8 RDW-88.8* Plt Ct-888
233 | __PHI_26__ 04:53AM BLOOD WBC-8.8 RBC-8.88* Hgb-88.8* Hct-88.8*
234 | MCV-88 MCH-88.8 MCHC-88.8 RDW-88.8* Plt Ct-888
235 | __PHI_301__ 05:30AM BLOOD WBC-8.8 RBC-8.88* Hgb-88.8* Hct-88.8*
236 | MCV-88 MCH-88.8 MCHC-88.8 RDW-88.8* Plt Ct-888
237 |
238 |
239 | Medications on Admission:
240 | Allopurinol 100 mg DAILY
241 | Aspirin 250 mg DAILY
242 | Atorvastatin 10 mg DAILY
243 | Glimepiride 1 mg once a week.
244 | Hexavitamin DAILY
245 | Lasix 50mg M-W-F; 60mg T-Th-Sat-Sun
246 | Metoprolol 12.5mg TID
247 | Prilosec OTC 20 mg once a day
248 | Verapamil 120 mg SR DAILY
249 | '''
250 |
251 | ### Remove lines with inline titles from larger segments (clearly nonprose)
252 | for segment in segments:
253 | '''
254 | With: __PHI_6__, MD __PHI_5__
255 | Building: De __PHI_45__ Building (__PHI_32__ Complex) __PHI_87__
256 | Campus: WEST
257 | '''
258 |
259 | lines = segment.split('\n')
260 |
261 | buf = []
262 | for i in range(len(lines)):
263 | if is_inline_title(lines[i]):
264 | if len(buf) > 0:
265 | new_segments.append('\n'.join(buf))
266 | buf = []
267 | buf.append(lines[i])
268 | if len(buf) > 0:
269 | new_segments.append('\n'.join(buf))
270 |
271 | segments = list(new_segments)
272 | new_segments = []
273 |
274 |
275 | # Going to put one-liner answers with their sections
276 | # (aka A A' B B' C D D' --> AA' BB' C DD' )
277 | N = len(segments)
278 | for i in range(len(segments)):
279 | # avoid segfaults
280 | if i==0:
281 | new_segments.append(segments[i])
282 | continue
283 |
284 | if segments[i].count('\n') == 0 and \
285 | is_title(segments[i-1]) and \
286 | not is_title(segments[i ]):
287 | if (i == N-1) or is_title(segments[i+1]):
288 | new_segments = new_segments[:-1]
289 | new_segments.append(segments[i-1] + ' ' + segments[i])
290 | #else: new_segments.append(segments[i]) #ADD TO FIX BUG
291 | # currently If the code sees a segment that doesn't have any new lines and the prior line is a title
292 | # *but* it is not the last segment and the next segment is not a title then that segment is just dropped
293 | # so lists that have a title header will lose their first entry
294 | else:
295 | new_segments.append(segments[i])
296 |
297 | segments = list(new_segments)
298 | new_segments = []
299 |
300 | '''
301 | Should do some kind of regex to find "TEST: value" in segments?
302 |
303 | Indication: Source of embolism.
304 | BP (mm Hg): 145/89
305 | HR (bpm): 80
306 |
307 | Note: I made a temporary hack that fixes this particular problem.
308 | We'll see how it shakes out
309 | '''
310 |
311 |
312 | '''
313 | Separate ALL CAPS lines (Warning... is there ever prose that can be all caps?)
314 | '''
315 |
316 |
317 |
318 |
319 | return segments
320 |
321 |
322 |
323 |
324 |
--------------------------------------------------------------------------------
/downstream_tasks/ner_eval/conlleval.pl:
--------------------------------------------------------------------------------
1 | #!/usr/bin/perl -w
2 | # conlleval: evaluate result of processing CoNLL-2000 shared task
3 | # usage: conlleval [-l] [-r] [-d delimiterTag] [-o oTag] < file
4 | # README: http://www.clips.uantwerpen.be/conll2000/chunking/output.html
5 | # options: l: generate LaTeX output for tables like in
6 | # http://cnts.uia.ac.be/conll2003/ner/example.tex
7 | # r: accept raw result tags (without B- and I- prefix;
8 | # assumes one word per chunk)
9 | # d: alternative delimiter tag (default is single space)
10 | # o: alternative outside tag (default is O)
11 | # note: the file should contain lines with items separated
12 | # by $delimiter characters (default space). The final
13 | # two items should contain the correct tag and the
14 | # guessed tag in that order. Sentences should be
15 | # separated from each other by empty lines or lines
16 | # with $boundary fields (default -X-).
17 | # url: http://www.clips.uantwerpen.be/conll2000/chunking/
18 | # started: 1998-09-25
19 | # version: 2004-01-26
20 | # author: Erik Tjong Kim Sang
21 |
22 | use strict;
23 |
24 | my $false = 0;
25 | my $true = 42;
26 |
27 | my $boundary = "-X-"; # sentence boundary
28 | my $correct; # current corpus chunk tag (I,O,B)
29 | my $correctChunk = 0; # number of correctly identified chunks
30 | my $correctTags = 0; # number of correct chunk tags
31 | my $correctType; # type of current corpus chunk tag (NP,VP,etc.)
32 | my $delimiter = " "; # field delimiter
33 | my $FB1 = 0.0; # FB1 score (Van Rijsbergen 1979)
34 | my $firstItem; # first feature (for sentence boundary checks)
35 | my $foundCorrect = 0; # number of chunks in corpus
36 | my $foundGuessed = 0; # number of identified chunks
37 | my $guessed; # current guessed chunk tag
38 | my $guessedType; # type of current guessed chunk tag
39 | my $i; # miscellaneous counter
40 | my $inCorrect = $false; # currently processed chunk is correct until now
41 | my $lastCorrect = "O"; # previous chunk tag in corpus
42 | my $latex = 0; # generate LaTeX formatted output
43 | my $lastCorrectType = ""; # type of previously identified chunk tag
44 | my $lastGuessed = "O"; # previously identified chunk tag
45 | my $lastGuessedType = ""; # type of previous chunk tag in corpus
46 | my $lastType; # temporary storage for detecting duplicates
47 | my $line; # line
48 | my $nbrOfFeatures = -1; # number of features per line
49 | my $precision = 0.0; # precision score
50 | my $oTag = "O"; # outside tag, default O
51 | my $raw = 0; # raw input: add B to every token
52 | my $recall = 0.0; # recall score
53 | my $tokenCounter = 0; # token counter (ignores sentence breaks)
54 |
55 | my %correctChunk = (); # number of correctly identified chunks per type
56 | my %foundCorrect = (); # number of chunks in corpus per type
57 | my %foundGuessed = (); # number of identified chunks per type
58 |
59 | my @features; # features on line
60 | my @sortedTypes; # sorted list of chunk type names
61 |
62 | # sanity check
63 | while (@ARGV and $ARGV[0] =~ /^-/) {
64 | if ($ARGV[0] eq "-l") { $latex = 1; shift(@ARGV); }
65 | elsif ($ARGV[0] eq "-r") { $raw = 1; shift(@ARGV); }
66 | elsif ($ARGV[0] eq "-d") {
67 | shift(@ARGV);
68 | if (not defined $ARGV[0]) {
69 | die "conlleval: -d requires delimiter character";
70 | }
71 | $delimiter = shift(@ARGV);
72 | } elsif ($ARGV[0] eq "-o") {
73 | shift(@ARGV);
74 | if (not defined $ARGV[0]) {
75 | die "conlleval: -o requires delimiter character";
76 | }
77 | $oTag = shift(@ARGV);
78 | } else { die "conlleval: unknown argument $ARGV[0]\n"; }
79 | }
80 | if (@ARGV) { die "conlleval: unexpected command line argument\n"; }
81 | # process input
82 | while (<STDIN>) {
83 | chomp($line = $_);
84 | @features = split(/$delimiter/,$line);
85 | if ($nbrOfFeatures < 0) { $nbrOfFeatures = $#features; }
86 | elsif ($nbrOfFeatures != $#features and @features != 0) {
87 | printf STDERR "unexpected number of features: %d (%d)\n",
88 | $#features+1,$nbrOfFeatures+1;
89 | exit(1);
90 | }
91 | if (@features == 0 or
92 | $features[0] eq $boundary) { @features = ($boundary,"O","O"); }
93 | if (@features < 2) {
94 | die "conlleval: unexpected number of features in line $line\n";
95 | }
96 | if ($raw) {
97 | if ($features[$#features] eq $oTag) { $features[$#features] = "O"; }
98 | if ($features[$#features-1] eq $oTag) { $features[$#features-1] = "O"; }
99 | if ($features[$#features] ne "O") {
100 | $features[$#features] = "B-$features[$#features]";
101 | }
102 | if ($features[$#features-1] ne "O") {
103 | $features[$#features-1] = "B-$features[$#features-1]";
104 | }
105 | }
106 | # 20040126 ET code which allows hyphens in the types
107 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
108 | $guessed = $1;
109 | $guessedType = $2;
110 | } else {
111 | $guessed = $features[$#features];
112 | $guessedType = "";
113 | }
114 | pop(@features);
115 | if ($features[$#features] =~ /^([^-]*)-(.*)$/) {
116 | $correct = $1;
117 | $correctType = $2;
118 | } else {
119 | $correct = $features[$#features];
120 | $correctType = "";
121 | }
122 | pop(@features);
123 | # ($guessed,$guessedType) = split(/-/,pop(@features));
124 | # ($correct,$correctType) = split(/-/,pop(@features));
125 | $guessedType = $guessedType ? $guessedType : "";
126 | $correctType = $correctType ? $correctType : "";
127 | $firstItem = shift(@features);
128 |
129 | # 1999-06-26 sentence breaks should always be counted as out of chunk
130 | if ( $firstItem eq $boundary ) { $guessed = "O"; }
131 |
132 | if ($inCorrect) {
133 | if ( &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
134 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
135 | $lastGuessedType eq $lastCorrectType) {
136 | $inCorrect=$false;
137 | $correctChunk++;
138 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
139 | $correctChunk{$lastCorrectType}+1 : 1;
140 | } elsif (
141 | &endOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) !=
142 | &endOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) or
143 | $guessedType ne $correctType ) {
144 | $inCorrect=$false;
145 | }
146 | }
147 |
148 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) and
149 | &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) and
150 | $guessedType eq $correctType) { $inCorrect = $true; }
151 |
152 | if ( &startOfChunk($lastCorrect,$correct,$lastCorrectType,$correctType) ) {
153 | $foundCorrect++;
154 | $foundCorrect{$correctType} = $foundCorrect{$correctType} ?
155 | $foundCorrect{$correctType}+1 : 1;
156 | }
157 | if ( &startOfChunk($lastGuessed,$guessed,$lastGuessedType,$guessedType) ) {
158 | $foundGuessed++;
159 | $foundGuessed{$guessedType} = $foundGuessed{$guessedType} ?
160 | $foundGuessed{$guessedType}+1 : 1;
161 | }
162 | if ( $firstItem ne $boundary ) {
163 | if ( $correct eq $guessed and $guessedType eq $correctType ) {
164 | $correctTags++;
165 | }
166 | $tokenCounter++;
167 | }
168 |
169 | $lastGuessed = $guessed;
170 | $lastCorrect = $correct;
171 | $lastGuessedType = $guessedType;
172 | $lastCorrectType = $correctType;
173 | }
174 | if ($inCorrect) {
175 | $correctChunk++;
176 | $correctChunk{$lastCorrectType} = $correctChunk{$lastCorrectType} ?
177 | $correctChunk{$lastCorrectType}+1 : 1;
178 | }
179 |
180 | if (not $latex) {
181 | # compute overall precision, recall and FB1 (default values are 0.0)
182 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
183 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
184 | $FB1 = 2*$precision*$recall/($precision+$recall)
185 | if ($precision+$recall > 0);
186 |
187 | # print overall performance
188 | printf "processed $tokenCounter tokens with $foundCorrect phrases; ";
189 | printf "found: $foundGuessed phrases; correct: $correctChunk.\n";
190 | if ($tokenCounter>0) {
191 | printf "accuracy: %6.2f%%; ",100*$correctTags/$tokenCounter;
192 | printf "precision: %6.2f%%; ",$precision;
193 | printf "recall: %6.2f%%; ",$recall;
194 | printf "FB1: %6.2f\n",$FB1;
195 | }
196 | }
197 |
198 | # sort chunk type names
199 | undef($lastType);
200 | @sortedTypes = ();
201 | foreach $i (sort (keys %foundCorrect,keys %foundGuessed)) {
202 | if (not($lastType) or $lastType ne $i) {
203 | push(@sortedTypes,($i));
204 | }
205 | $lastType = $i;
206 | }
207 | # print performance per chunk type
208 | if (not $latex) {
209 | for $i (@sortedTypes) {
210 | $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
211 | if (not($foundGuessed{$i})) { $foundGuessed{$i} = 0; $precision = 0.0; }
212 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
213 | if (not($foundCorrect{$i})) { $recall = 0.0; }
214 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
215 | if ($precision+$recall == 0.0) { $FB1 = 0.0; }
216 | else { $FB1 = 2*$precision*$recall/($precision+$recall); }
217 | printf "%17s: ",$i;
218 | printf "precision: %6.2f%%; ",$precision;
219 | printf "recall: %6.2f%%; ",$recall;
220 | printf "FB1: %6.2f %d\n",$FB1,$foundGuessed{$i};
221 | }
222 | } else {
223 | print " & Precision & Recall & F\$_{\\beta=1} \\\\\\hline";
224 | for $i (@sortedTypes) {
225 | $correctChunk{$i} = $correctChunk{$i} ? $correctChunk{$i} : 0;
226 | if (not($foundGuessed{$i})) { $precision = 0.0; }
227 | else { $precision = 100*$correctChunk{$i}/$foundGuessed{$i}; }
228 | if (not($foundCorrect{$i})) { $recall = 0.0; }
229 | else { $recall = 100*$correctChunk{$i}/$foundCorrect{$i}; }
230 | if ($precision+$recall == 0.0) { $FB1 = 0.0; }
231 | else { $FB1 = 2*$precision*$recall/($precision+$recall); }
232 | printf "\n%-7s & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\",
233 | $i,$precision,$recall,$FB1;
234 | }
235 | print "\\hline\n";
236 | $precision = 0.0;
237 | $recall = 0;
238 | $FB1 = 0.0;
239 | $precision = 100*$correctChunk/$foundGuessed if ($foundGuessed > 0);
240 | $recall = 100*$correctChunk/$foundCorrect if ($foundCorrect > 0);
241 | $FB1 = 2*$precision*$recall/($precision+$recall)
242 | if ($precision+$recall > 0);
243 | printf "Overall & %6.2f\\%% & %6.2f\\%% & %6.2f \\\\\\hline\n",
244 | $precision,$recall,$FB1;
245 | }
246 |
247 | exit 0;
248 |
249 | # endOfChunk: checks if a chunk ended between the previous and current word
250 | # arguments: previous and current chunk tags, previous and current types
251 | # note: this code is capable of handling other chunk representations
252 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
253 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006
254 |
255 | sub endOfChunk {
256 | my $prevTag = shift(@_);
257 | my $tag = shift(@_);
258 | my $prevType = shift(@_);
259 | my $type = shift(@_);
260 | my $chunkEnd = $false;
261 |
262 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkEnd = $true; }
263 | if ( $prevTag eq "B" and $tag eq "O" ) { $chunkEnd = $true; }
264 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkEnd = $true; }
265 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; }
266 |
267 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkEnd = $true; }
268 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkEnd = $true; }
269 | if ( $prevTag eq "E" and $tag eq "O" ) { $chunkEnd = $true; }
270 | if ( $prevTag eq "I" and $tag eq "O" ) { $chunkEnd = $true; }
271 |
272 | if ($prevTag ne "O" and $prevTag ne "." and $prevType ne $type) {
273 | $chunkEnd = $true;
274 | }
275 |
276 | # corrected 1998-12-22: these chunks are assumed to have length 1
277 | if ( $prevTag eq "]" ) { $chunkEnd = $true; }
278 | if ( $prevTag eq "[" ) { $chunkEnd = $true; }
279 |
280 | return($chunkEnd);
281 | }
282 |
283 | # startOfChunk: checks if a chunk started between the previous and current word
284 | # arguments: previous and current chunk tags, previous and current types
285 | # note: this code is capable of handling other chunk representations
286 | # than the default CoNLL-2000 ones, see EACL'99 paper of Tjong
287 | # Kim Sang and Veenstra http://xxx.lanl.gov/abs/cs.CL/9907006
288 |
289 | sub startOfChunk {
290 | my $prevTag = shift(@_);
291 | my $tag = shift(@_);
292 | my $prevType = shift(@_);
293 | my $type = shift(@_);
294 | my $chunkStart = $false;
295 |
296 | if ( $prevTag eq "B" and $tag eq "B" ) { $chunkStart = $true; }
297 | if ( $prevTag eq "I" and $tag eq "B" ) { $chunkStart = $true; }
298 | if ( $prevTag eq "O" and $tag eq "B" ) { $chunkStart = $true; }
299 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; }
300 |
301 | if ( $prevTag eq "E" and $tag eq "E" ) { $chunkStart = $true; }
302 | if ( $prevTag eq "E" and $tag eq "I" ) { $chunkStart = $true; }
303 | if ( $prevTag eq "O" and $tag eq "E" ) { $chunkStart = $true; }
304 | if ( $prevTag eq "O" and $tag eq "I" ) { $chunkStart = $true; }
305 |
306 | if ($tag ne "O" and $tag ne "." and $prevType ne $type) {
307 | $chunkStart = $true;
308 | }
309 |
310 | # corrected 1998-12-22: these chunks are assumed to have length 1
311 | if ( $tag eq "[" ) { $chunkStart = $true; }
312 | if ( $tag eq "]" ) { $chunkStart = $true; }
313 |
314 | return($chunkStart);
315 | }
316 |
--------------------------------------------------------------------------------
/lm_pretraining/create_pretraining_data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Create masked LM/next sentence masked_lm TF examples for BERT."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import random
23 | import tokenization
24 | import tensorflow as tf
25 |
26 | flags = tf.flags
27 |
28 | FLAGS = flags.FLAGS
29 |
30 | flags.DEFINE_string("input_file", None,
31 | "Input raw text file (or comma-separated list of files).")
32 |
33 | flags.DEFINE_string(
34 | "output_file", None,
35 | "Output TF example file (or comma-separated list of files).")
36 |
37 | flags.DEFINE_string("vocab_file", None,
38 | "The vocabulary file that the BERT model was trained on.")
39 |
40 | flags.DEFINE_bool(
41 | "do_lower_case", True,
42 | "Whether to lower case the input text. Should be True for uncased "
43 | "models and False for cased models.")
44 |
45 | flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
46 |
47 | flags.DEFINE_integer("max_predictions_per_seq", 20,
48 | "Maximum number of masked LM predictions per sequence.")
49 |
50 | flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
51 |
52 | flags.DEFINE_integer(
53 | "dupe_factor", 10,
54 | "Number of times to duplicate the input data (with different masks).")
55 |
56 | flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
57 |
58 | flags.DEFINE_float(
59 | "short_seq_prob", 0.1,
60 | "Probability of creating sequences which are shorter than the "
61 | "maximum length.")
62 |
63 |
64 | class TrainingInstance(object):
65 | """A single training instance (sentence pair)."""
66 |
67 | def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
68 | is_random_next):
69 | self.tokens = tokens
70 | self.segment_ids = segment_ids
71 | self.is_random_next = is_random_next
72 | self.masked_lm_positions = masked_lm_positions
73 | self.masked_lm_labels = masked_lm_labels
74 |
75 | def __str__(self):
76 | s = ""
77 | s += "tokens: %s\n" % (" ".join(
78 | [tokenization.printable_text(x) for x in self.tokens]))
79 | s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
80 | s += "is_random_next: %s\n" % self.is_random_next
81 | s += "masked_lm_positions: %s\n" % (" ".join(
82 | [str(x) for x in self.masked_lm_positions]))
83 | s += "masked_lm_labels: %s\n" % (" ".join(
84 | [tokenization.printable_text(x) for x in self.masked_lm_labels]))
85 | s += "\n"
86 | return s
87 |
88 | def __repr__(self):
89 | return self.__str__()
90 |
91 |
92 | def write_instance_to_example_files(instances, tokenizer, max_seq_length,
93 | max_predictions_per_seq, output_files):
94 | """Create TF example files from `TrainingInstance`s."""
95 | writers = []
96 | for output_file in output_files:
97 | writers.append(tf.python_io.TFRecordWriter(output_file))
98 |
99 | writer_index = 0
100 |
101 | total_written = 0
102 | for (inst_index, instance) in enumerate(instances):
103 | input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
104 | input_mask = [1] * len(input_ids)
105 | segment_ids = list(instance.segment_ids)
106 | assert len(input_ids) <= max_seq_length
107 |
108 | while len(input_ids) < max_seq_length:
109 | input_ids.append(0)
110 | input_mask.append(0)
111 | segment_ids.append(0)
112 |
113 | assert len(input_ids) == max_seq_length
114 | assert len(input_mask) == max_seq_length
115 | assert len(segment_ids) == max_seq_length
116 |
117 | masked_lm_positions = list(instance.masked_lm_positions)
118 | masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
119 | masked_lm_weights = [1.0] * len(masked_lm_ids)
120 |
121 | while len(masked_lm_positions) < max_predictions_per_seq:
122 | masked_lm_positions.append(0)
123 | masked_lm_ids.append(0)
124 | masked_lm_weights.append(0.0)
125 |
126 | next_sentence_label = 1 if instance.is_random_next else 0
127 |
128 | features = collections.OrderedDict()
129 | features["input_ids"] = create_int_feature(input_ids)
130 | features["input_mask"] = create_int_feature(input_mask)
131 | features["segment_ids"] = create_int_feature(segment_ids)
132 | features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
133 | features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
134 | features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
135 | features["next_sentence_labels"] = create_int_feature([next_sentence_label])
136 |
137 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
138 |
139 | writers[writer_index].write(tf_example.SerializeToString())
140 | writer_index = (writer_index + 1) % len(writers)
141 |
142 | total_written += 1
143 |
144 | if inst_index < 20:
145 | tf.logging.info("*** Example ***")
146 | tf.logging.info("tokens: %s" % " ".join(
147 | [tokenization.printable_text(x) for x in instance.tokens]))
148 |
149 | for feature_name in features.keys():
150 | feature = features[feature_name]
151 | values = []
152 | if feature.int64_list.value:
153 | values = feature.int64_list.value
154 | elif feature.float_list.value:
155 | values = feature.float_list.value
156 | tf.logging.info(
157 | "%s: %s" % (feature_name, " ".join([str(x) for x in values])))
158 |
159 | for writer in writers:
160 | writer.close()
161 |
162 | tf.logging.info("Wrote %d total instances", total_written)
163 |
164 |
165 | def create_int_feature(values):
166 | feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
167 | return feature
168 |
169 |
170 | def create_float_feature(values):
171 | feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
172 | return feature
173 |
174 |
175 | def create_training_instances(input_files, tokenizer, max_seq_length,
176 | dupe_factor, short_seq_prob, masked_lm_prob,
177 | max_predictions_per_seq, rng):
178 | """Create `TrainingInstance`s from raw text."""
179 | all_documents = [[]]
180 |
181 | # Input file format:
182 | # (1) One sentence per line. These should ideally be actual sentences, not
183 | # entire paragraphs or arbitrary spans of text. (Because we use the
184 | # sentence boundaries for the "next sentence prediction" task).
185 | # (2) Blank lines between documents. Document boundaries are needed so
186 | # that the "next sentence prediction" task doesn't span between documents.
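# Illustrative layout (hypothetical sentences, added for clarity):
#
#   The patient was admitted with chest pain .
#   An ECG showed no acute ischemic changes .
#                                              <- blank line marks a document boundary
#   The echocardiogram demonstrated a normal EF .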
187 | for input_file in input_files:
188 | with tf.gfile.GFile(input_file, "r") as reader:
189 | while True:
190 | line = tokenization.convert_to_unicode(reader.readline())
191 | if not line:
192 | break
193 | line = line.strip()
194 |
195 | # Empty lines are used as document delimiters
196 | if not line:
197 | all_documents.append([])
198 | tokens = tokenizer.tokenize(line)
199 | if tokens:
200 | all_documents[-1].append(tokens)
201 |
202 | # Remove empty documents
203 | all_documents = [x for x in all_documents if x]
204 | rng.shuffle(all_documents)
205 |
206 | vocab_words = list(tokenizer.vocab.keys())
207 | instances = []
208 | for _ in range(dupe_factor):
209 | for document_index in range(len(all_documents)):
210 | instances.extend(
211 | create_instances_from_document(
212 | all_documents, document_index, max_seq_length, short_seq_prob,
213 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
214 |
215 | rng.shuffle(instances)
216 | return instances
217 |
218 |
219 | def create_instances_from_document(
220 | all_documents, document_index, max_seq_length, short_seq_prob,
221 | masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
222 | """Creates `TrainingInstance`s for a single document."""
223 | document = all_documents[document_index]
224 |
225 | # Account for [CLS], [SEP], [SEP]
226 | max_num_tokens = max_seq_length - 3
227 |
228 | # We *usually* want to fill up the entire sequence since we are padding
229 | # to `max_seq_length` anyways, so short sequences are generally wasted
230 | # computation. However, we *sometimes*
231 | # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
232 | # sequences to minimize the mismatch between pre-training and fine-tuning.
233 | # The `target_seq_length` is just a rough target however, whereas
234 | # `max_seq_length` is a hard limit.
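# Added note: for example, with max_seq_length=128, max_num_tokens is 125, and roughly 10% of
# the time (short_seq_prob=0.1) target_seq_length is instead drawn uniformly from [2, 125]
# by rng.randint below.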
235 | target_seq_length = max_num_tokens
236 | if rng.random() < short_seq_prob:
237 | target_seq_length = rng.randint(2, max_num_tokens)
238 |
239 | # We DON'T just concatenate all of the tokens from a document into a long
240 | # sequence and choose an arbitrary split point because this would make the
241 | # next sentence prediction task too easy. Instead, we split the input into
242 | # segments "A" and "B" based on the actual "sentences" provided by the user
243 | # input.
244 | instances = []
245 | current_chunk = []
246 | current_length = 0
247 | i = 0
248 | while i < len(document):
249 | segment = document[i]
250 | current_chunk.append(segment)
251 | current_length += len(segment)
252 | if i == len(document) - 1 or current_length >= target_seq_length:
253 | if current_chunk:
254 | # `a_end` is how many segments from `current_chunk` go into the `A`
255 | # (first) sentence.
256 | a_end = 1
257 | if len(current_chunk) >= 2:
258 | a_end = rng.randint(1, len(current_chunk) - 1)
259 |
260 | tokens_a = []
261 | for j in range(a_end):
262 | tokens_a.extend(current_chunk[j])
263 |
264 | tokens_b = []
265 | # Random next
266 | is_random_next = False
267 | if len(current_chunk) == 1 or rng.random() < 0.5:
268 | is_random_next = True
269 | target_b_length = target_seq_length - len(tokens_a)
270 |
271 | # This should rarely go for more than one iteration for large
272 | # corpora. However, just to be careful, we try to make sure that
273 | # the random document is not the same as the document
274 | # we're processing.
275 | for _ in range(10):
276 | random_document_index = rng.randint(0, len(all_documents) - 1)
277 | if random_document_index != document_index:
278 | break
279 |
280 | random_document = all_documents[random_document_index]
281 | random_start = rng.randint(0, len(random_document) - 1)
282 | for j in range(random_start, len(random_document)):
283 | tokens_b.extend(random_document[j])
284 | if len(tokens_b) >= target_b_length:
285 | break
286 | # We didn't actually use these segments so we "put them back" so
287 | # they don't go to waste.
288 | num_unused_segments = len(current_chunk) - a_end
289 | i -= num_unused_segments
290 | # Actual next
291 | else:
292 | is_random_next = False
293 | for j in range(a_end, len(current_chunk)):
294 | tokens_b.extend(current_chunk[j])
295 | truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
296 |
297 | assert len(tokens_a) >= 1
298 | assert len(tokens_b) >= 1
299 |
300 | tokens = []
301 | segment_ids = []
302 | tokens.append("[CLS]")
303 | segment_ids.append(0)
304 | for token in tokens_a:
305 | tokens.append(token)
306 | segment_ids.append(0)
307 |
308 | tokens.append("[SEP]")
309 | segment_ids.append(0)
310 |
311 | for token in tokens_b:
312 | tokens.append(token)
313 | segment_ids.append(1)
314 | tokens.append("[SEP]")
315 | segment_ids.append(1)
316 |
317 | (tokens, masked_lm_positions,
318 | masked_lm_labels) = create_masked_lm_predictions(
319 | tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
320 | instance = TrainingInstance(
321 | tokens=tokens,
322 | segment_ids=segment_ids,
323 | is_random_next=is_random_next,
324 | masked_lm_positions=masked_lm_positions,
325 | masked_lm_labels=masked_lm_labels)
326 | instances.append(instance)
327 | current_chunk = []
328 | current_length = 0
329 | i += 1
330 |
331 | return instances
332 |
333 |
334 | MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
335 | ["index", "label"])
336 |
337 |
338 | def create_masked_lm_predictions(tokens, masked_lm_prob,
339 | max_predictions_per_seq, vocab_words, rng):
340 | """Creates the predictions for the masked LM objective."""
341 |
342 | cand_indexes = []
343 | for (i, token) in enumerate(tokens):
344 | if token == "[CLS]" or token == "[SEP]":
345 | continue
346 | cand_indexes.append(i)
347 |
348 | rng.shuffle(cand_indexes)
349 |
350 | output_tokens = list(tokens)
351 |
352 | num_to_predict = min(max_predictions_per_seq,
353 | max(1, int(round(len(tokens) * masked_lm_prob))))
354 |
355 | masked_lms = []
356 | covered_indexes = set()
357 | for index in cand_indexes:
358 | if len(masked_lms) >= num_to_predict:
359 | break
360 | if index in covered_indexes:
361 | continue
362 | covered_indexes.add(index)
363 |
364 | masked_token = None
365 | # 80% of the time, replace with [MASK]
366 | if rng.random() < 0.8:
367 | masked_token = "[MASK]"
368 | else:
369 | # 10% of the time, keep original
370 | if rng.random() < 0.5:
371 | masked_token = tokens[index]
372 | # 10% of the time, replace with random word
373 | else:
374 | masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
375 |
376 | output_tokens[index] = masked_token
377 |
378 | masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
379 |
380 | masked_lms = sorted(masked_lms, key=lambda x: x.index)
381 |
382 | masked_lm_positions = []
383 | masked_lm_labels = []
384 | for p in masked_lms:
385 | masked_lm_positions.append(p.index)
386 | masked_lm_labels.append(p.label)
387 |
388 | return (output_tokens, masked_lm_positions, masked_lm_labels)
389 |
390 |
391 | def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
392 | """Truncates a pair of sequences to a maximum sequence length."""
393 | while True:
394 | total_length = len(tokens_a) + len(tokens_b)
395 | if total_length <= max_num_tokens:
396 | break
397 |
398 | trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
399 | assert len(trunc_tokens) >= 1
400 |
401 | # We want to sometimes truncate from the front and sometimes from the
402 | # back to add more randomness and avoid biases.
403 | if rng.random() < 0.5:
404 | del trunc_tokens[0]
405 | else:
406 | trunc_tokens.pop()
407 |
408 |
409 | def main(_):
410 | tf.logging.set_verbosity(tf.logging.INFO)
411 |
412 | tokenizer = tokenization.FullTokenizer(
413 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
414 |
415 | input_files = []
416 | for input_pattern in FLAGS.input_file.split(","):
417 | input_files.extend(tf.gfile.Glob(input_pattern))
418 |
419 | tf.logging.info("*** Reading from input files ***")
420 | for input_file in input_files:
421 | tf.logging.info(" %s", input_file)
422 |
423 | rng = random.Random(FLAGS.random_seed)
424 | instances = create_training_instances(
425 | input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
426 | FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
427 | rng)
428 |
429 | output_files = FLAGS.output_file.split(",")
430 | tf.logging.info("*** Writing to output files ***")
431 | for output_file in output_files:
432 | tf.logging.info(" %s", output_file)
433 |
434 | write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
435 | FLAGS.max_predictions_per_seq, output_files)
436 |
437 |
438 | if __name__ == "__main__":
439 | flags.mark_flag_as_required("input_file")
440 | flags.mark_flag_as_required("output_file")
441 | flags.mark_flag_as_required("vocab_file")
442 | tf.app.run()
443 |
--------------------------------------------------------------------------------
/lm_pretraining/run_pretraining.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Run masked LM/next sentence masked_lm pre-training for BERT."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 | import modeling
23 | import optimization
24 | import tensorflow as tf
25 |
26 | flags = tf.flags
27 |
28 | FLAGS = flags.FLAGS
29 |
30 | ## Required parameters
31 | flags.DEFINE_string(
32 | "bert_config_file", None,
33 | "The config json file corresponding to the pre-trained BERT model. "
34 | "This specifies the model architecture.")
35 |
36 | flags.DEFINE_string(
37 | "input_file", None,
38 | "Input TF example files (can be a glob or comma separated).")
39 |
40 | flags.DEFINE_string(
41 | "output_dir", None,
42 | "The output directory where the model checkpoints will be written.")
43 |
44 | ## Other parameters
45 | flags.DEFINE_string(
46 | "init_checkpoint", None,
47 | "Initial checkpoint (usually from a pre-trained BERT model).")
48 |
49 | flags.DEFINE_integer(
50 | "max_seq_length", 128,
51 | "The maximum total input sequence length after WordPiece tokenization. "
52 | "Sequences longer than this will be truncated, and sequences shorter "
53 | "than this will be padded. Must match data generation.")
54 |
55 | flags.DEFINE_integer(
56 | "max_predictions_per_seq", 20,
57 | "Maximum number of masked LM predictions per sequence. "
58 | "Must match data generation.")
59 |
60 | flags.DEFINE_bool("do_train", False, "Whether to run training.")
61 |
62 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
63 |
64 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
65 |
66 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
67 |
68 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
69 |
70 | flags.DEFINE_integer("num_train_steps", 100000, "Number of training steps.")
71 |
72 | flags.DEFINE_integer("num_warmup_steps", 10000, "Number of warmup steps.")
73 |
74 | flags.DEFINE_integer("save_checkpoints_steps", 1000,
75 | "How often to save the model checkpoint.")
76 |
77 | flags.DEFINE_integer("iterations_per_loop", 1000,
78 | "How many steps to make in each estimator call.")
79 |
80 | flags.DEFINE_integer("max_eval_steps", 100, "Maximum number of eval steps.")
81 |
82 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
83 |
84 | tf.flags.DEFINE_string(
85 | "tpu_name", None,
86 | "The Cloud TPU to use for training. This should be either the name "
87 | "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
88 | "url.")
89 |
90 | tf.flags.DEFINE_string(
91 | "tpu_zone", None,
92 | "[Optional] GCE zone where the Cloud TPU is located in. If not "
93 | "specified, we will attempt to automatically detect the GCE project from "
94 | "metadata.")
95 |
96 | tf.flags.DEFINE_string(
97 | "gcp_project", None,
98 | "[Optional] Project name for the Cloud TPU-enabled project. If not "
99 | "specified, we will attempt to automatically detect the GCE project from "
100 | "metadata.")
101 |
102 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
103 |
104 | flags.DEFINE_integer(
105 | "num_tpu_cores", 8,
106 | "Only used if `use_tpu` is True. Total number of TPU cores to use.")
107 |
108 |
109 | def model_fn_builder(bert_config, init_checkpoint, learning_rate,
110 | num_train_steps, num_warmup_steps, use_tpu,
111 | use_one_hot_embeddings):
112 | """Returns `model_fn` closure for TPUEstimator."""
113 |
114 | def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
115 | """The `model_fn` for TPUEstimator."""
116 |
117 | tf.logging.info("*** Features ***")
118 | for name in sorted(features.keys()):
119 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
120 |
121 | input_ids = features["input_ids"]
122 | input_mask = features["input_mask"]
123 | segment_ids = features["segment_ids"]
124 | masked_lm_positions = features["masked_lm_positions"]
125 | masked_lm_ids = features["masked_lm_ids"]
126 | masked_lm_weights = features["masked_lm_weights"]
127 | next_sentence_labels = features["next_sentence_labels"]
128 |
129 | is_training = (mode == tf.estimator.ModeKeys.TRAIN)
130 |
131 | model = modeling.BertModel(
132 | config=bert_config,
133 | is_training=is_training,
134 | input_ids=input_ids,
135 | input_mask=input_mask,
136 | token_type_ids=segment_ids,
137 | use_one_hot_embeddings=use_one_hot_embeddings)
138 |
139 | (masked_lm_loss,
140 | masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
141 | bert_config, model.get_sequence_output(), model.get_embedding_table(),
142 | masked_lm_positions, masked_lm_ids, masked_lm_weights)
143 |
144 | (next_sentence_loss, next_sentence_example_loss,
145 | next_sentence_log_probs) = get_next_sentence_output(
146 | bert_config, model.get_pooled_output(), next_sentence_labels)
147 |
148 | total_loss = masked_lm_loss + next_sentence_loss
149 |
150 | tvars = tf.trainable_variables()
151 |
152 | initialized_variable_names = {}
153 | scaffold_fn = None
154 | if init_checkpoint:
155 | (assignment_map, initialized_variable_names
156 | ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
157 | if use_tpu:
158 |
159 | def tpu_scaffold():
160 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
161 | return tf.train.Scaffold()
162 |
163 | scaffold_fn = tpu_scaffold
164 | else:
165 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
166 |
167 | tf.logging.info("**** Trainable Variables ****")
168 | for var in tvars:
169 | init_string = ""
170 | if var.name in initialized_variable_names:
171 | init_string = ", *INIT_FROM_CKPT*"
172 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
173 | init_string)
174 |
175 | output_spec = None
176 | if mode == tf.estimator.ModeKeys.TRAIN:
177 | train_op = optimization.create_optimizer(
178 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
179 |
180 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
181 | mode=mode,
182 | loss=total_loss,
183 | train_op=train_op,
184 | scaffold_fn=scaffold_fn)
185 | elif mode == tf.estimator.ModeKeys.EVAL:
186 |
187 | def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
188 | masked_lm_weights, next_sentence_example_loss,
189 | next_sentence_log_probs, next_sentence_labels):
190 | """Computes the loss and accuracy of the model."""
191 | masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
192 | [-1, masked_lm_log_probs.shape[-1]])
193 | masked_lm_predictions = tf.argmax(
194 | masked_lm_log_probs, axis=-1, output_type=tf.int32)
195 | masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
196 | masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
197 | masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
198 | masked_lm_accuracy = tf.metrics.accuracy(
199 | labels=masked_lm_ids,
200 | predictions=masked_lm_predictions,
201 | weights=masked_lm_weights)
202 | masked_lm_mean_loss = tf.metrics.mean(
203 | values=masked_lm_example_loss, weights=masked_lm_weights)
204 |
205 | next_sentence_log_probs = tf.reshape(
206 | next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
207 | next_sentence_predictions = tf.argmax(
208 | next_sentence_log_probs, axis=-1, output_type=tf.int32)
209 | next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
210 | next_sentence_accuracy = tf.metrics.accuracy(
211 | labels=next_sentence_labels, predictions=next_sentence_predictions)
212 | next_sentence_mean_loss = tf.metrics.mean(
213 | values=next_sentence_example_loss)
214 |
215 | return {
216 | "masked_lm_accuracy": masked_lm_accuracy,
217 | "masked_lm_loss": masked_lm_mean_loss,
218 | "next_sentence_accuracy": next_sentence_accuracy,
219 | "next_sentence_loss": next_sentence_mean_loss,
220 | }
221 |
222 | eval_metrics = (metric_fn, [
223 | masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
224 | masked_lm_weights, next_sentence_example_loss,
225 | next_sentence_log_probs, next_sentence_labels
226 | ])
227 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
228 | mode=mode,
229 | loss=total_loss,
230 | eval_metrics=eval_metrics,
231 | scaffold_fn=scaffold_fn)
232 | else:
233 | raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))
234 |
235 | return output_spec
236 |
237 | return model_fn
238 |
239 |
240 | def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
241 | label_ids, label_weights):
242 | """Get loss and log probs for the masked LM."""
243 | input_tensor = gather_indexes(input_tensor, positions)
244 |
245 | with tf.variable_scope("cls/predictions"):
246 | # We apply one more non-linear transformation before the output layer.
247 | # This matrix is not used after pre-training.
248 | with tf.variable_scope("transform"):
249 | input_tensor = tf.layers.dense(
250 | input_tensor,
251 | units=bert_config.hidden_size,
252 | activation=modeling.get_activation(bert_config.hidden_act),
253 | kernel_initializer=modeling.create_initializer(
254 | bert_config.initializer_range))
255 | input_tensor = modeling.layer_norm(input_tensor)
256 |
257 | # The output weights are the same as the input embeddings, but there is
258 | # an output-only bias for each token.
259 | output_bias = tf.get_variable(
260 | "output_bias",
261 | shape=[bert_config.vocab_size],
262 | initializer=tf.zeros_initializer())
263 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
264 | logits = tf.nn.bias_add(logits, output_bias)
265 | log_probs = tf.nn.log_softmax(logits, axis=-1)
266 |
267 | label_ids = tf.reshape(label_ids, [-1])
268 | label_weights = tf.reshape(label_weights, [-1])
269 |
270 | one_hot_labels = tf.one_hot(
271 | label_ids, depth=bert_config.vocab_size, dtype=tf.float32)
272 |
273 | # The `positions` tensor might be zero-padded (if the sequence is too
274 | # short to have the maximum number of predictions). The `label_weights`
275 | # tensor has a value of 1.0 for every real prediction and 0.0 for the
276 | # padding predictions.
277 | per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
278 | numerator = tf.reduce_sum(label_weights * per_example_loss)
279 | denominator = tf.reduce_sum(label_weights) + 1e-5
280 | loss = numerator / denominator
281 |
282 | return (loss, per_example_loss, log_probs)
283 |
284 |
285 | def get_next_sentence_output(bert_config, input_tensor, labels):
286 | """Get loss and log probs for the next sentence prediction."""
287 |
288 | # Simple binary classification. Note that 0 is "next sentence" and 1 is
289 | # "random sentence". This weight matrix is not used after pre-training.
290 | with tf.variable_scope("cls/seq_relationship"):
291 | output_weights = tf.get_variable(
292 | "output_weights",
293 | shape=[2, bert_config.hidden_size],
294 | initializer=modeling.create_initializer(bert_config.initializer_range))
295 | output_bias = tf.get_variable(
296 | "output_bias", shape=[2], initializer=tf.zeros_initializer())
297 |
298 | logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
299 | logits = tf.nn.bias_add(logits, output_bias)
300 | log_probs = tf.nn.log_softmax(logits, axis=-1)
301 | labels = tf.reshape(labels, [-1])
302 | one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
303 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
304 | loss = tf.reduce_mean(per_example_loss)
305 | return (loss, per_example_loss, log_probs)
306 |
307 |
308 | def gather_indexes(sequence_tensor, positions):
309 | """Gathers the vectors at the specific positions over a minibatch."""
310 | sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
311 | batch_size = sequence_shape[0]
312 | seq_length = sequence_shape[1]
313 | width = sequence_shape[2]
314 |
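# Worked example (added comment): with batch_size=2, seq_length=4 and
# positions=[[1, 3], [0, 2]], flat_offsets is [[0], [4]], so flat_positions becomes
# [1, 3, 4, 6], which indexes rows of the flattened [batch_size*seq_length, width] tensor.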
315 | flat_offsets = tf.reshape(
316 | tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
317 | flat_positions = tf.reshape(positions + flat_offsets, [-1])
318 | flat_sequence_tensor = tf.reshape(sequence_tensor,
319 | [batch_size * seq_length, width])
320 | output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
321 | return output_tensor
322 |
323 |
324 | def input_fn_builder(input_files,
325 | max_seq_length,
326 | max_predictions_per_seq,
327 | is_training,
328 | num_cpu_threads=4):
329 | """Creates an `input_fn` closure to be passed to TPUEstimator."""
330 |
331 | def input_fn(params):
332 | """The actual input function."""
333 | batch_size = params["batch_size"]
334 |
335 | name_to_features = {
336 | "input_ids":
337 | tf.FixedLenFeature([max_seq_length], tf.int64),
338 | "input_mask":
339 | tf.FixedLenFeature([max_seq_length], tf.int64),
340 | "segment_ids":
341 | tf.FixedLenFeature([max_seq_length], tf.int64),
342 | "masked_lm_positions":
343 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
344 | "masked_lm_ids":
345 | tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
346 | "masked_lm_weights":
347 | tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
348 | "next_sentence_labels":
349 | tf.FixedLenFeature([1], tf.int64),
350 | }
351 |
352 | # For training, we want a lot of parallel reading and shuffling.
353 | # For eval, we want no shuffling and parallel reading doesn't matter.
354 | if is_training:
355 | d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
356 | d = d.repeat()
357 | d = d.shuffle(buffer_size=len(input_files))
358 |
359 | # `cycle_length` is the number of parallel files that get read.
360 | cycle_length = min(num_cpu_threads, len(input_files))
361 |
362 | # `sloppy` mode means that the interleaving is not exact. This adds
363 | # even more randomness to the training pipeline.
364 | d = d.apply(
365 | tf.contrib.data.parallel_interleave(
366 | tf.data.TFRecordDataset,
367 | sloppy=is_training,
368 | cycle_length=cycle_length))
369 | d = d.shuffle(buffer_size=100)
370 | else:
371 | d = tf.data.TFRecordDataset(input_files)
372 | # Since we evaluate for a fixed number of steps we don't want to encounter
373 | # out-of-range exceptions.
374 | d = d.repeat()
375 |
376 | # We must `drop_remainder` on training because the TPU requires fixed
377 | # size dimensions. For eval, we assume we are evaluating on the CPU or GPU
378 | # and we *don't* want to drop the remainder, otherwise we won't cover
379 | # every sample.
380 | d = d.apply(
381 | tf.contrib.data.map_and_batch(
382 | lambda record: _decode_record(record, name_to_features),
383 | batch_size=batch_size,
384 | num_parallel_batches=num_cpu_threads,
385 | drop_remainder=True))
386 | return d
387 |
388 | return input_fn
389 |
390 |
391 | def _decode_record(record, name_to_features):
392 | """Decodes a record to a TensorFlow example."""
393 | example = tf.parse_single_example(record, name_to_features)
394 |
395 | # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
396 | # So cast all int64 to int32.
397 | for name in list(example.keys()):
398 | t = example[name]
399 | if t.dtype == tf.int64:
400 | t = tf.to_int32(t)
401 | example[name] = t
402 |
403 | return example
404 |
405 |
406 | def main(_):
407 | tf.logging.set_verbosity(tf.logging.INFO)
408 |
409 | if not FLAGS.do_train and not FLAGS.do_eval:
410 | raise ValueError("At least one of `do_train` or `do_eval` must be True.")
411 |
412 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
413 |
414 | tf.gfile.MakeDirs(FLAGS.output_dir)
415 |
416 | input_files = []
417 | for input_pattern in FLAGS.input_file.split(","):
418 | input_files.extend(tf.gfile.Glob(input_pattern))
419 |
420 | tf.logging.info("*** Input Files ***")
421 | for input_file in input_files:
422 | tf.logging.info(" %s" % input_file)
423 |
424 | tpu_cluster_resolver = None
425 | if FLAGS.use_tpu and FLAGS.tpu_name:
426 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
427 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
428 |
429 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
430 | run_config = tf.contrib.tpu.RunConfig(
431 | cluster=tpu_cluster_resolver,
432 | master=FLAGS.master,
433 | model_dir=FLAGS.output_dir,
434 | save_checkpoints_steps=FLAGS.save_checkpoints_steps,
435 | tpu_config=tf.contrib.tpu.TPUConfig(
436 | iterations_per_loop=FLAGS.iterations_per_loop,
437 | num_shards=FLAGS.num_tpu_cores,
438 | per_host_input_for_training=is_per_host))
439 |
440 | model_fn = model_fn_builder(
441 | bert_config=bert_config,
442 | init_checkpoint=FLAGS.init_checkpoint,
443 | learning_rate=FLAGS.learning_rate,
444 | num_train_steps=FLAGS.num_train_steps,
445 | num_warmup_steps=FLAGS.num_warmup_steps,
446 | use_tpu=FLAGS.use_tpu,
447 | use_one_hot_embeddings=FLAGS.use_tpu)
448 |
449 | # If TPU is not available, this will fall back to normal Estimator on CPU
450 | # or GPU.
451 | estimator = tf.contrib.tpu.TPUEstimator(
452 | use_tpu=FLAGS.use_tpu,
453 | model_fn=model_fn,
454 | config=run_config,
455 | train_batch_size=FLAGS.train_batch_size,
456 | eval_batch_size=FLAGS.eval_batch_size)
457 |
458 | if FLAGS.do_train:
459 | tf.logging.info("***** Running training *****")
460 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
461 | train_input_fn = input_fn_builder(
462 | input_files=input_files,
463 | max_seq_length=FLAGS.max_seq_length,
464 | max_predictions_per_seq=FLAGS.max_predictions_per_seq,
465 | is_training=True)
466 | estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)
467 |
468 | if FLAGS.do_eval:
469 | tf.logging.info("***** Running evaluation *****")
470 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
471 |
472 | eval_input_fn = input_fn_builder(
473 | input_files=input_files,
474 | max_seq_length=FLAGS.max_seq_length,
475 | max_predictions_per_seq=FLAGS.max_predictions_per_seq,
476 | is_training=False)
477 |
478 | result = estimator.evaluate(
479 | input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)
480 |
481 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
482 | with tf.gfile.GFile(output_eval_file, "w") as writer:
483 | tf.logging.info("***** Eval results *****")
484 | for key in sorted(result.keys()):
485 | tf.logging.info(" %s = %s", key, str(result[key]))
486 | writer.write("%s = %s\n" % (key, str(result[key])))
487 |
488 |
489 | if __name__ == "__main__":
490 | flags.mark_flag_as_required("input_file")
491 | flags.mark_flag_as_required("bert_config_file")
492 | flags.mark_flag_as_required("output_dir")
493 | tf.app.run()
494 |
--------------------------------------------------------------------------------
/downstream_tasks/i2b2_preprocessing/i2b2_2012/Reformat.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 20,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os, xml, xml.etree.ElementTree as ET, numpy as np"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 2,
15 | "metadata": {},
16 | "outputs": [
17 | {
18 | "name": "stdout",
19 | "output_type": "stream",
20 | "text": [
21 | "*******This is a brief decsription of the i2b2 sample data*********\r\n",
22 | "1. Data files:\r\n",
23 | " The group of 3 files with file extentions: txt, extent and tlink\r\n",
24 | "are in the format that are consistent with the previous i2b2 challenges.\r\n",
25 | " In addition, an xml file is provided for each record. The XML \r\n",
26 | "file is browsable using the MAE tool found in the directory.\r\n",
27 | "\r\n",
28 | "2. Usage of MAE tool\r\n",
29 | " - Download and unzip the MAE.zip \r\n",
30 | " - Run the MAE_v0.9.3.jar (please make sure that you have Java installed)\r\n",
31 | " - In menu bar, select File -> Load DTD, and in the popup window,\r\n",
32 | " navigate to the sample folder in the unzipped MAE folder, load the\r\n",
33 | " TemporalAnnotation.dtd\r\n",
34 | " - Then, select File -> Load File, and load one of the records, e.g. \r\n",
35 | " 357.xml\r\n",
36 | " - You will be able to browse the annotations by clicking on different\r\n",
37 | " tabs.\r\n",
38 | "\r\n"
39 | ]
40 | }
41 | ],
42 | "source": [
43 | "cat readme.txt"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 54,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "START_CDATA = \"\"\n",
54 | "\n",
55 | "TAGS = ['MEDICATION', 'OBSEE', 'SMOKER', 'HYPERTENSION', 'event', 'FAMILY_HIST']\n",
56 | "\n",
57 | "def read_xml_file(xml_path, event_tag_type='ALL_CHILDREN', match_text=True):\n",
58 | " with open(xml_path, mode='r') as f:\n",
59 | " lines = f.readlines()\n",
60 | " text, in_text = [], False\n",
61 | " for i, l in enumerate(lines):\n",
62 | " if START_CDATA in l:\n",
63 | " text.append(list(l[l.find(START_CDATA) + len(START_CDATA):]))\n",
64 | " in_text = True\n",
65 | " elif END_CDATA in l:\n",
66 | " text.append(list(l[:l.find(END_CDATA)]))\n",
67 | " break\n",
68 | " elif in_text:\n",
69 | "# if xml_path.endswith('180-03.xml') and '0808' in l and 'Effingham' in l:\n",
70 | "# print(\"Adjusting known error\")\n",
71 | "# l = l[:9] + ' ' * 4 + l[9:]\n",
72 | "# # elif xml_path.endswith('188-05.xml') and 'Johnson & Johnson' in l:\n",
73 | "# # print(\"Adjusting known error\")\n",
74 | "# # l = l.replace('&', 'and')\n",
75 | " text.append(list(l))\n",
76 | " \n",
77 | " pos_transformer = {}\n",
78 | " \n",
79 | " linear_pos = 1\n",
80 | " for line, sentence in enumerate(text):\n",
81 | " for char_pos, char in enumerate(sentence):\n",
82 | " pos_transformer[linear_pos] = (line, char_pos)\n",
83 | " linear_pos += 1\n",
84 | " \n",
85 | " try: xml_parsed = ET.parse(xml_path)\n",
86 | " except:\n",
87 | " print(xml_path)\n",
88 | " raise\n",
89 | " \n",
90 | " tag_containers = xml_parsed.findall('TAGS')\n",
91 | " assert len(tag_containers) == 1, \"Found multiple tag sets!\"\n",
92 | " tag_container = tag_containers[0]\n",
93 | " \n",
94 | "# event_tags = tag_container.getchildren() if event_tag_type == 'ALL_CHILDREN' else tag_container.findall('event')\n",
95 | " event_tags = tag_container.findall('EVENT')\n",
96 | " event_labels = [['O'] * len(sentence) for sentence in text]\n",
97 | " for event_tag in event_tags:\n",
98 | " base_label = event_tag.attrib['type']\n",
99 | " start_pos, end_pos, event_text = event_tag.attrib['start'], event_tag.attrib['end'], event_tag.attrib['text']\n",
100 | " start_pos, end_pos = int(start_pos)+1, int(end_pos)\n",
101 | " event_text = ' '.join(event_text.split())\n",
102 | "# if event_text == \"0808 O’neil’s Court\":\n",
103 | "# print(\"Adjusting known error\")\n",
104 | "# end_pos -= 4\n",
105 | "# if event_text == 'Johnson and Johnson' and xml_path.endswith('188-05.xml'):\n",
106 | "# print(\"Adjusting known error\")\n",
107 | "# event_text = 'Johnson & Johnson'\n",
108 | " \n",
109 | "\n",
110 | " (start_line, start_char), (end_line, end_char) = pos_transformer[start_pos], pos_transformer[end_pos]\n",
111 | " \n",
112 | " obs_text = []\n",
113 | " for line in range(start_line, end_line+1):\n",
114 | " t = text[line]\n",
115 | " s = start_char if line == start_line else 0\n",
116 | " e = end_char if line == end_line else len(t)\n",
117 | " obs_text.append(''.join(t[s:e+1]).strip())\n",
118 | " obs_text = ' '.join(obs_text)\n",
119 | " obs_text = ' '.join(obs_text.split())\n",
120 | " \n",
121 | " if ''' in obs_text and ''' not in event_text: event_text = event_text.replace(\"'\", \"'\")\n",
122 | " if '"' in obs_text and '"' not in event_text: event_text = event_text.replace('\"', '"')\n",
123 | " \n",
124 | " if match_text: assert obs_text == event_text, (\n",
125 | " (\"Texts don't match! %s v %s\" % (event_text, obs_text)) + '\\n' + str((\n",
126 | " start_pos, end_pos, line, s, e, t, xml_path\n",
127 | " ))\n",
128 | " )\n",
129 | " \n",
130 | " if base_label.strip() == '': continue\n",
131 | " \n",
132 | " event_labels[end_line][end_char] = 'I-%s' % base_label\n",
133 | " event_labels[start_line][start_char] = 'B-%s' % base_label\n",
134 | " \n",
135 | " for line in range(start_line, end_line+1):\n",
136 | " t = text[line]\n",
137 | " s = start_char+1 if line == start_line else 0\n",
138 | " e = end_char-1 if line == end_line else len(t)-1\n",
139 | " for i in range(s, e+1): event_labels[line][i] = 'I-%s' % base_label\n",
140 | "\n",
141 | " return text, event_labels\n",
142 | " \n",
143 | "def merge_into_words(text_by_char, all_labels_by_char):\n",
144 | " assert len(text_by_char) == len(all_labels_by_char), \"Incorrect # of sentences!\"\n",
145 | " \n",
146 | " N = len(text_by_char)\n",
147 | " \n",
148 | " text_by_word, all_labels_by_word = [], []\n",
149 | " \n",
150 | " for sentence_num in range(N):\n",
151 | " sentence_by_char = text_by_char[sentence_num]\n",
152 | " labels_by_char = all_labels_by_char[sentence_num]\n",
153 | " \n",
154 | " assert len(sentence_by_char) == len(labels_by_char), \"Incorrect # of chars in sentence!\"\n",
155 | " S = len(sentence_by_char)\n",
156 | " \n",
157 | " if labels_by_char == (['O'] * len(sentence_by_char)):\n",
158 | " sentence_by_word = ''.join(sentence_by_char).split()\n",
159 | " labels_by_word = ['O'] * len(sentence_by_word)\n",
160 | " else: \n",
161 | " sentence_by_word, labels_by_word = [], []\n",
162 | " text_chunks, labels_chunks = [], []\n",
163 | " s = 0\n",
164 | " for i in range(S):\n",
165 | " if i == S-1:\n",
166 | " text_chunks.append(sentence_by_char[s:])\n",
167 | " labels_chunks.append(labels_by_char[s:])\n",
168 | " elif labels_by_char[i] == 'O': continue\n",
169 | " else:\n",
170 | " if i > 0 and labels_by_char[i-1] == 'O':\n",
171 | " text_chunks.append(sentence_by_char[s:i])\n",
172 | " labels_chunks.append(labels_by_char[s:i])\n",
173 | " s = i\n",
174 | " if labels_by_char[i+1] == 'O' or labels_by_char[i+1][2:] != labels_by_char[i][2:]:\n",
175 | " text_chunks.append(sentence_by_char[s:i+1])\n",
176 | " labels_chunks.append(labels_by_char[s:i+1])\n",
177 | " s = i+1\n",
178 | " \n",
179 | " for text_chunk, labels_chunk in zip(text_chunks, labels_chunks):\n",
180 | " assert len(text_chunk) == len(labels_chunk), \"Bad Chunking (len)\"\n",
181 | " assert len(text_chunk) > 0, \"Bad chunking (len 0)\" + str(text_chunks) + str(labels_chunks)\n",
182 | " \n",
183 | " labels_set = set(labels_chunk)\n",
184 | " assert labels_set == set(['O']) or (len(labels_set) <= 3 and 'O' not in labels_set), (\n",
185 | " (\"Bad chunking (contents) %s\" % ', '.join(labels_set))+ str(text_chunks) + str(labels_chunks)\n",
186 | " )\n",
187 | " \n",
188 | " text_chunk_by_word = ''.join(text_chunk).split()\n",
189 | " W = len(text_chunk_by_word)\n",
190 | " if W == 0: \n",
191 | "# assert labels_set == set(['O']), \"0-word chunking and non-0 label!\" + str(\n",
192 | "# text_chunks) + str(labels_chunks\n",
193 | "# )\n",
194 | " continue\n",
195 | " \n",
196 | " if labels_chunk[0] == 'O': labels_chunk_by_word = ['O'] * W\n",
197 | " elif W == 1: labels_chunk_by_word = [labels_chunk[0]]\n",
198 | " elif W == 2: labels_chunk_by_word = [labels_chunk[0], labels_chunk[-1]]\n",
199 | " else: labels_chunk_by_word = [\n",
200 | " labels_chunk[0]\n",
201 | " ] + [labels_chunk[1]] * (W - 2) + [\n",
202 | " labels_chunk[-1]\n",
203 | " ]\n",
204 | " \n",
205 | " sentence_by_word.extend(text_chunk_by_word)\n",
206 | " labels_by_word.extend(labels_chunk_by_word)\n",
207 | "\n",
208 | " assert len(sentence_by_word) == len(labels_by_word), \"Incorrect # of words in sentence!\" \n",
209 | " \n",
210 | " if len(sentence_by_word) == 0: continue\n",
211 | " \n",
212 | " text_by_word.append(sentence_by_word)\n",
213 | " all_labels_by_word.append(labels_by_word)\n",
214 | " return text_by_word, all_labels_by_word\n",
215 | "\n",
216 | "def reprocess_event_labels(folders, base_path='.', event_tag_type='event', match_text=True, dev_set_size=None):\n",
217 | " all_texts_by_patient, all_labels_by_patient = {}, {}\n",
218 | "\n",
219 | " for folder in folders:\n",
220 | " folder_dir = os.path.join(base_path, folder)\n",
221 | " xml_filenames = [x for x in os.listdir(folder_dir) if x.endswith('xml')]\n",
222 | " for xml_filename in xml_filenames:\n",
223 | " patient_num = int(xml_filename[:-4])\n",
224 | " xml_filepath = os.path.join(folder_dir, xml_filename)\n",
225 | " \n",
226 | " text_by_char, labels_by_char = read_xml_file(\n",
227 | " xml_filepath,\n",
228 | " event_tag_type=event_tag_type,\n",
229 | " match_text=match_text\n",
230 | " )\n",
231 | " text_by_word, labels_by_word = merge_into_words(text_by_char, labels_by_char)\n",
232 | " \n",
233 | " if patient_num not in all_texts_by_patient:\n",
234 | " all_texts_by_patient[patient_num] = []\n",
235 | " all_labels_by_patient[patient_num] = []\n",
236 | " \n",
237 | " all_texts_by_patient[patient_num].extend(text_by_word)\n",
238 | " all_labels_by_patient[patient_num].extend(labels_by_word)\n",
239 | " \n",
240 | " patients = set(all_texts_by_patient.keys())\n",
241 | " \n",
242 | " if dev_set_size is None: train_patients, dev_patients = list(patients), []\n",
243 | " else:\n",
244 | " N_train = int(len(patients) * (1-dev_set_size))\n",
245 | " patients_random = np.random.permutation(list(patients))\n",
246 | " train_patients = list(patients_random[:N_train])\n",
247 | " dev_patients = list(patients_random[N_train:])\n",
248 | " \n",
249 | " train_texts, train_labels = [], []\n",
250 | " dev_texts, dev_labels = [], []\n",
251 | " \n",
252 | " for patient_num in train_patients:\n",
253 | " train_texts.extend(all_texts_by_patient[patient_num])\n",
254 | " train_labels.extend(all_labels_by_patient[patient_num])\n",
255 | "\n",
256 | " for patient_num in dev_patients:\n",
257 | " dev_texts.extend(all_texts_by_patient[patient_num])\n",
258 | " dev_labels.extend(all_labels_by_patient[patient_num])\n",
259 | "\n",
260 | "\n",
261 | " train_out_text_by_sentence = []\n",
262 | " for text, labels in zip(train_texts, train_labels):\n",
263 | " train_out_text_by_sentence.append('\\n'.join('%s %s' % x for x in zip(text, labels)))\n",
264 | " dev_out_text_by_sentence = []\n",
265 | " for text, labels in zip(dev_texts, dev_labels):\n",
266 | " dev_out_text_by_sentence.append('\\n'.join('%s %s' % x for x in zip(text, labels)))\n",
267 | "\n",
268 | " return '\\n\\n'.join(train_out_text_by_sentence), '\\n\\n'.join(dev_out_text_by_sentence)"
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "execution_count": 55,
274 | "metadata": {},
275 | "outputs": [],
276 | "source": [
277 | "final_train_text, final_dev_text = reprocess_event_labels(\n",
278 | " ['2012-07-15.original-annotation.release'], dev_set_size=0.1, match_text=True\n",
279 | ")"
280 | ]
281 | },
282 | {
283 | "cell_type": "code",
284 | "execution_count": 56,
285 | "metadata": {},
286 | "outputs": [],
287 | "source": [
288 | "test_text, _ = reprocess_event_labels(\n",
289 | " ['2012-08-08.test-data.event-timex-groundtruth/xml'], match_text=False, dev_set_size=None\n",
290 | ")"
291 | ]
292 | },
293 | {
294 | "cell_type": "code",
295 | "execution_count": 57,
296 | "metadata": {},
297 | "outputs": [
298 | {
299 | "name": "stdout",
300 | "output_type": "stream",
301 | "text": [
302 | "ADMISSION B-OCCURRENCE\n",
303 | "DATE O\n",
304 | ": O\n",
305 | "\n",
306 | "12/13/2002 O\n",
307 | "\n",
308 | "DISCHARGE B-OCCURRENCE\n",
309 | "DATE O\n",
310 | ": O\n",
311 | "\n",
312 | "12/14/2002 O\n",
313 | "\n",
314 | "HISTORY O\n",
315 | "OF O\n",
316 | "PRESENT O\n",
317 | "ILLNESS O\n",
318 | ": O\n",
319 | "\n",
320 | "The O\n",
321 | "patient O\n",
322 | "is O\n",
323 | "a O\n",
324 | "30 O\n",
325 | "-year-old O\n",
326 | "Gravida B-OCCURRENCE\n",
327 | "III O\n",
328 | ", O\n",
329 | "Para B-OCCURRENCE\n",
330 | "II O\n",
331 | "who O\n",
332 | "presented B-EVIDENTIAL\n",
333 | "desiring O\n",
334 | "definitive B-TREATMENT\n",
335 | "surgical I-TREATMENT\n",
336 | "sterilization I-TREATMENT\n",
337 | ". O\n",
338 | "\n",
339 | "She O\n",
340 | "was O\n",
341 | "extensively B-OCCURRENCE\n",
342 | "counseled I-OCCURRENCE\n",
343 | "regarding O\n",
344 | "other B-TREATMENT\n",
345 | "methods I-TREATMENT\n",
346 | "of O\n",
347 | "birth B-TREATMENT\n",
348 | "control I-TREATMEN\n"
349 | ]
350 | }
351 | ],
352 | "source": [
353 | "print(final_train_text[:500])"
354 | ]
355 | },
356 | {
357 | "cell_type": "code",
358 | "execution_count": 58,
359 | "metadata": {},
360 | "outputs": [
361 | {
362 | "name": "stdout",
363 | "output_type": "stream",
364 | "text": [
365 | "Admission B-OCCURRENCE\n",
366 | "Date O\n",
367 | ": O\n",
368 | "\n",
369 | "2010-02-05 O\n",
370 | "\n",
371 | "Discharge B-OCCURRENCE\n",
372 | "Date O\n",
373 | ": O\n",
374 | "\n",
375 | "2010-02-06 O\n",
376 | "\n",
377 | "Service O\n",
378 | ": O\n",
379 | "\n",
380 | "TRAUMA O\n",
381 | "\n",
382 | "HISTORY O\n",
383 | "OF O\n",
384 | "PRESENT O\n",
385 | "ILLNESS O\n",
386 | ": O\n",
387 | "\n",
388 | "The O\n",
389 | "patient O\n",
390 | "is O\n",
391 | "a O\n",
392 | "23 O\n",
393 | "year O\n",
394 | "old O\n",
395 | "female O\n",
396 | ", O\n",
397 | "status O\n",
398 | "post O\n",
399 | "fall B-PROBLEM\n",
400 | "from O\n",
401 | "standing O\n",
402 | "position O\n",
403 | "after O\n",
404 | "slipping O\n",
405 | "on O\n",
406 | "ice O\n",
407 | ". O\n",
408 | "\n",
409 | "She O\n",
410 | "had O\n",
411 | "no O\n",
412 | "loss B-PROBLEM\n",
413 | "of I-PROBLEM\n",
414 | "consciousness I-PROBLEM\n",
415 | ". O\n",
416 | "\n",
417 | "She O\n",
418 | "recall\n"
419 | ]
420 | }
421 | ],
422 | "source": [
423 | "print(final_dev_text[:400])"
424 | ]
425 | },
426 | {
427 | "cell_type": "code",
428 | "execution_count": 59,
429 | "metadata": {},
430 | "outputs": [
431 | {
432 | "name": "stdout",
433 | "output_type": "stream",
434 | "text": [
435 | "ADMISSION B-OCCURRENCE\n",
436 | "DATE O\n",
437 | ": O\n",
438 | "\n",
439 | "10/14/96 O\n",
440 | "\n",
441 | "DISCHARGE B-OCCURRENCE\n",
442 | "DATE O\n",
443 | ": O\n",
444 | "\n",
445 | "10/27/96 O\n",
446 | "date O\n",
447 | "of O\n",
448 | "birth B-OCCURRENCE\n",
449 | "; O\n",
450 | "September O\n",
451 | "30 O\n",
452 | ", O\n",
453 | "1917 O\n",
454 | "\n",
455 | "THER O\n",
456 | "PROCEDURES O\n",
457 | ": O\n",
458 | "\n",
459 | "arterial B-TEST\n",
460 | "catheterization I-TEST\n",
461 | "on O\n",
462 | "10/14/96 O\n",
463 | ", O\n",
464 | "head B-TEST\n",
465 | "CT I-TEST\n",
466 | "scan I-TEST\n",
467 | "on O\n",
468 | "10/14/96 O\n",
469 | "\n",
470 | "HISTORY O\n",
471 | "AND O\n",
472 | "REASON O\n",
473 | "FOR O\n",
474 | "HOSPITALIZATION O\n",
475 | ": O\n",
476 | "\n",
477 | "Granrivern O\n",
478 | "Call O\n",
479 | "is O\n",
480 | "a O\n",
481 | "79-year-old O\n",
482 | "right O\n",
483 | "han\n"
484 | ]
485 | }
486 | ],
487 | "source": [
488 | "print(test_text[:400])"
489 | ]
490 | },
491 | {
492 | "cell_type": "code",
493 | "execution_count": 60,
494 | "metadata": {},
495 | "outputs": [],
496 | "source": [
497 | "labels = {}\n",
498 | "for s in final_train_text, final_dev_text, test_text:\n",
499 | " for line in s.split('\\n'):\n",
500 | " if line == '': continue\n",
501 | " label = line.split()[-1]\n",
502 | " assert label == 'O' or label.startswith('B-') or label.startswith('I-'), \"label wrong! %s\" % label\n",
503 | " if label not in labels: labels[label] = 1\n",
504 | " else: labels[label] += 1"
505 | ]
506 | },
507 | {
508 | "cell_type": "code",
509 | "execution_count": 61,
510 | "metadata": {},
511 | "outputs": [
512 | {
513 | "data": {
514 | "text/plain": [
515 | "{'B-OCCURRENCE': 5774,\n",
516 | " 'O': 114910,\n",
517 | " 'B-EVIDENTIAL': 1334,\n",
518 | " 'B-TREATMENT': 7098,\n",
519 | " 'I-TREATMENT': 6748,\n",
520 | " 'I-OCCURRENCE': 3590,\n",
521 | " 'B-CLINICAL_DEPT': 1724,\n",
522 | " 'I-CLINICAL_DEPT': 3253,\n",
523 | " 'B-PROBLEM': 9319,\n",
524 | " 'I-PROBLEM': 13543,\n",
525 | " 'B-TEST': 4762,\n",
526 | " 'I-TEST': 5931,\n",
527 | " 'I-EVIDENTIAL': 84}"
528 | ]
529 | },
530 | "execution_count": 61,
531 | "metadata": {},
532 | "output_type": "execute_result"
533 | }
534 | ],
535 | "source": [
536 | "labels"
537 | ]
538 | },
539 | {
540 | "cell_type": "code",
541 | "execution_count": 62,
542 | "metadata": {},
543 | "outputs": [],
544 | "source": [
545 | "with open('./processed/train.tsv', mode='w') as f:\n",
546 | " f.write(final_train_text)\n",
547 | "with open('./processed/dev.tsv', mode='w') as f:\n",
548 | " f.write(final_dev_text)\n",
549 | "with open('./processed/test.tsv', mode='w') as f:\n",
550 | " f.write(test_text)"
551 | ]
552 | },
553 | {
554 | "cell_type": "code",
555 | "execution_count": null,
556 | "metadata": {},
557 | "outputs": [],
558 | "source": []
559 | }
560 | ],
561 | "metadata": {
562 | "kernelspec": {
563 | "display_name": "Python 3",
564 | "language": "python",
565 | "name": "python3"
566 | },
567 | "language_info": {
568 | "codemirror_mode": {
569 | "name": "ipython",
570 | "version": 3
571 | },
572 | "file_extension": ".py",
573 | "mimetype": "text/x-python",
574 | "name": "python",
575 | "nbconvert_exporter": "python",
576 | "pygments_lexer": "ipython3",
577 | "version": "3.6.8"
578 | }
579 | },
580 | "nbformat": 4,
581 | "nbformat_minor": 2
582 | }
583 |
--------------------------------------------------------------------------------
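
The notebook above writes `train.tsv`, `dev.tsv`, and `test.tsv` in a CoNLL-style layout: one `token label` pair per line, with a blank line between sentences. The sketch below shows one way to read such a file back into parallel token/label lists; the helper name and path are illustrative, not part of the repository.

    # Minimal sketch (assumptions: one "token label" pair per line,
    # blank lines between sentences, as produced by the notebook above).
    def read_conll_tsv(path):
        sentences, labels = [], []
        with open(path) as f:
            blocks = f.read().strip().split('\n\n')
        for block in blocks:
            tokens, tags = [], []
            for line in block.split('\n'):
                if not line.strip():
                    continue
                token, tag = line.rsplit(' ', 1)
                tokens.append(token)
                tags.append(tag)
            assert len(tokens) == len(tags)
            sentences.append(tokens)
            labels.append(tags)
        return sentences, labels

    # Usage (path is hypothetical):
    # train_sentences, train_labels = read_conll_tsv('./processed/train.tsv')
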
/downstream_tasks/i2b2_preprocessing/i2b2_2014_deid_hf_risk/Reformat.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os, xml.etree.ElementTree as ET, numpy as np"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 2,
15 | "metadata": {},
16 | "outputs": [
17 | {
18 | "name": "stdout",
19 | "output_type": "stream",
20 | "text": [
21 | "This folder contains a sample of the data that will be released as a part \r\n",
22 | "of the two 2014 i2b2 challenge tracks. \r\n",
23 | "\r\n",
24 | "Each track has its own folder of sample data; the contents are outlined \r\n",
25 | "below.\r\n",
26 | "\r\n",
27 | "----------------------------------------------\r\n",
28 | "\r\n",
29 | "Track 1: De-identification \r\n",
30 | "\r\n",
31 | "Folder contents:\r\n",
32 | "\r\n",
33 | "PHI/\r\n",
34 | "\tThis folder contains only PHI-related annotations for the text \r\n",
35 | "in each document. These files are the gold standard for the de-id track.\r\n",
36 | "\r\n",
37 | "De-identification_guidelines_2014_distribution.pdf\r\n",
38 | "\tThis file contains the guidelines for PHI identification\r\n",
39 | "\r\n",
40 | "de-idi2b2_distribution.dtd\r\n",
41 | "\tThis file is the DTD that describes the valid tags for the \r\n",
42 | "de-identification annotation task.\r\n",
43 | "\r\n",
44 | "---------------------------------------------------\r\n",
45 | "\r\n",
46 | "Track 2: Identifying risk factors for heart disease over time\r\n",
47 | "\r\n",
48 | "Folder contents:\r\n",
49 | "\r\n",
50 | "gold/\r\n",
51 | "\tThis folder contains the gold standard annotations that will be used for\r\n",
52 | "evaluation. These are document-level tags which are defined in the annotation\r\n",
53 | "guidelines. Valid tags and attributes are outlined in the provided cardiac risk DTD\r\n",
54 | "\r\n",
55 | "complete/\r\n",
56 | "\tThis folder contains complete annotations for the entire text. It contains \r\n",
57 | "document level annotations. Each document level annotation is supplemented with tags \r\n",
58 | "that show the the evidence found by our annotators for a particular document level \r\n",
59 | "tag. These annotator tags include position information and are included with the hope \r\n",
60 | "that they will ease system development and error analysis.\r\n",
61 | "\r\n",
62 | "i2b2_2014_annotation_guidelines_distribution.pdf\r\n",
63 | "\tThis file contains the guidelines for the risk factor annotation \r\n",
64 | "\r\n",
65 | "cardiacRisk_distribution.dtd\r\n",
66 | "\tThis file is the DTD that describes the valid tags for the \r\n",
67 | "risk factor annotation task\r\n",
68 | "\r\n",
69 | "---------------------------------------------------\r\n",
70 | "\r\n",
71 | "Overview of the sample files:\r\n",
72 | "\r\n",
73 | "The gold/, complete/, and PHI/ folders each contain 8 XML files. These\r\n",
74 | "files contain four discharge summaries each for two different patients. The file\r\n",
75 | "names follow a consistent pattern with the first set of digits identifying the\r\n",
76 | "patient and the last set of digits identifying the sequential record number\r\n",
77 | "\r\n",
78 | "ie: XXX-YY.xml \r\n",
79 | "where XXX is the patient number, and YY is the record number.\r\n",
80 | "\r\n",
81 | "Example: 320-03.xml\r\n",
82 | "This is the third (03) record for patient 320\r\n",
83 | "\r\n",
84 | "Each file has a root level xml node which will contain a\r\n",
85 | " node that holds the medical annotation text and a node containing\r\n",
86 | "annotations for the document text. The specific annotations contained in each file \r\n",
87 | "are described by the accompanying DTDs and annotation guidelines.\r\n",
88 | "\r\n",
89 | "---------------------------------------------------\r\n",
90 | "\r\n"
91 | ]
92 | }
93 | ],
94 | "source": [
95 | "cat readme.txt"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": 3,
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "START_CDATA = \"\"\n",
106 | "\n",
107 | "TAGS = ['MEDICATION', 'OBSEE', 'SMOKER', 'HYPERTENSION', 'PHI', 'FAMILY_HIST']\n",
108 | "\n",
109 | "def read_xml_file(xml_path, PHI_tag_type='ALL_CHILDREN', match_text=True):\n",
110 | " with open(xml_path, mode='r') as f:\n",
111 | " lines = f.readlines()\n",
112 | " text, in_text = [], False\n",
113 | " for i, l in enumerate(lines):\n",
114 | " if START_CDATA in l:\n",
115 | " text.append(list(l[l.find(START_CDATA) + len(START_CDATA):]))\n",
116 | " in_text = True\n",
117 | " elif END_CDATA in l:\n",
118 | " text.append(list(l[:l.find(END_CDATA)]))\n",
119 | " break\n",
120 | " elif in_text:\n",
121 | " if xml_path.endswith('180-03.xml') and '0808' in l and 'Effingham' in l:\n",
122 | " print(\"Adjusting known error\")\n",
123 | " l = l[:9] + ' ' * 4 + l[9:]\n",
124 | "# elif xml_path.endswith('188-05.xml') and 'Johnson & Johnson' in l:\n",
125 | "# print(\"Adjusting known error\")\n",
126 | "# l = l.replace('&', 'and')\n",
127 | " text.append(list(l))\n",
128 | " \n",
129 | " pos_transformer = {}\n",
130 | " \n",
131 | " linear_pos = 1\n",
132 | " for line, sentence in enumerate(text):\n",
133 | " for char_pos, char in enumerate(sentence):\n",
134 | " pos_transformer[linear_pos] = (line, char_pos)\n",
135 | " linear_pos += 1\n",
136 | " \n",
137 | " xml_parsed = ET.parse(xml_path)\n",
138 | " tag_containers = xml_parsed.findall('TAGS')\n",
139 | " assert len(tag_containers) == 1, \"Found multiple tag sets!\"\n",
140 | " tag_container = tag_containers[0]\n",
141 | " \n",
142 | " PHI_tags = tag_container.getchildren() if PHI_tag_type == 'ALL_CHILDREN' else tag_container.findall('PHI')\n",
143 | " PHI_labels = [['O'] * len(sentence) for sentence in text]\n",
144 | " for PHI_tag in PHI_tags:\n",
145 | " base_label = PHI_tag.attrib['TYPE']\n",
146 | " start_pos, end_pos, PHI_text = PHI_tag.attrib['start'], PHI_tag.attrib['end'], PHI_tag.attrib['text']\n",
147 | " start_pos, end_pos = int(start_pos)+1, int(end_pos)\n",
148 | " PHI_text = ' '.join(PHI_text.split())\n",
149 | "# if PHI_text == \"0808 O’neil’s Court\":\n",
150 | "# print(\"Adjusting known error\")\n",
151 | "# end_pos -= 4\n",
152 | " if PHI_text == 'Johnson and Johnson' and xml_path.endswith('188-05.xml'):\n",
153 | " print(\"Adjusting known error\")\n",
154 | " PHI_text = 'Johnson & Johnson'\n",
155 | " \n",
156 | "\n",
157 | " (start_line, start_char), (end_line, end_char) = pos_transformer[start_pos], pos_transformer[end_pos]\n",
158 | " \n",
159 | " obs_text = []\n",
160 | " for line in range(start_line, end_line+1):\n",
161 | " t = text[line]\n",
162 | " s = start_char if line == start_line else 0\n",
163 | " e = end_char if line == end_line else len(t)\n",
164 | " obs_text.append(''.join(t[s:e+1]).strip())\n",
165 | " obs_text = ' '.join(obs_text)\n",
166 | " obs_text = ' '.join(obs_text.split())\n",
167 | " \n",
168 | " if match_text: assert obs_text == PHI_text, (\n",
169 | " (\"Texts don't match! %s v %s\" % (PHI_text, obs_text)) + '\\n' + str((\n",
170 | " start_pos, end_pos, line, s, e, t, xml_path\n",
171 | " ))\n",
172 | " )\n",
173 | " \n",
174 | " PHI_labels[end_line][end_char] = 'I-%s' % base_label\n",
175 | " PHI_labels[start_line][start_char] = 'B-%s' % base_label\n",
176 | " \n",
177 | " for line in range(start_line, end_line+1):\n",
178 | " t = text[line]\n",
179 | " s = start_char+1 if line == start_line else 0\n",
180 | " e = end_char-1 if line == end_line else len(t)-1\n",
181 | " for i in range(s, e+1): PHI_labels[line][i] = 'I-%s' % base_label\n",
182 | "\n",
183 | " return text, PHI_labels\n",
184 | " \n",
185 | "def merge_into_words(text_by_char, all_labels_by_char):\n",
186 | " assert len(text_by_char) == len(all_labels_by_char), \"Incorrect # of sentences!\"\n",
187 | " \n",
188 | " N = len(text_by_char)\n",
189 | " \n",
190 | " text_by_word, all_labels_by_word = [], []\n",
191 | " \n",
192 | " for sentence_num in range(N):\n",
193 | " sentence_by_char = text_by_char[sentence_num]\n",
194 | " labels_by_char = all_labels_by_char[sentence_num]\n",
195 | " \n",
196 | " assert len(sentence_by_char) == len(labels_by_char), \"Incorrect # of chars in sentence!\"\n",
197 | " S = len(sentence_by_char)\n",
198 | " \n",
199 | " if labels_by_char == (['O'] * len(sentence_by_char)):\n",
200 | " sentence_by_word = ''.join(sentence_by_char).split()\n",
201 | " labels_by_word = ['O'] * len(sentence_by_word)\n",
202 | " else: \n",
203 | " sentence_by_word, labels_by_word = [], []\n",
204 | " text_chunks, labels_chunks = [], []\n",
205 | " s = 0\n",
206 | " for i in range(S):\n",
207 | " if i == S-1:\n",
208 | " text_chunks.append(sentence_by_char[s:])\n",
209 | " labels_chunks.append(labels_by_char[s:])\n",
210 | " elif labels_by_char[i] == 'O': continue\n",
211 | " else:\n",
212 | " if i > 0 and labels_by_char[i-1] == 'O':\n",
213 | " text_chunks.append(sentence_by_char[s:i])\n",
214 | " labels_chunks.append(labels_by_char[s:i])\n",
215 | " s = i\n",
216 | " if labels_by_char[i+1] == 'O' or labels_by_char[i+1][2:] != labels_by_char[i][2:]:\n",
217 | " text_chunks.append(sentence_by_char[s:i+1])\n",
218 | " labels_chunks.append(labels_by_char[s:i+1])\n",
219 | " s = i+1\n",
220 | " \n",
221 | " for text_chunk, labels_chunk in zip(text_chunks, labels_chunks):\n",
222 | " assert len(text_chunk) == len(labels_chunk), \"Bad Chunking (len)\"\n",
223 | " assert len(text_chunk) > 0, \"Bad chunking (len 0)\" + str(text_chunks) + str(labels_chunks)\n",
224 | " \n",
225 | " labels_set = set(labels_chunk)\n",
226 | " assert labels_set == set(['O']) or (len(labels_set) <= 3 and 'O' not in labels_set), (\n",
227 | " (\"Bad chunking (contents) %s\" % ', '.join(labels_set))+ str(text_chunks) + str(labels_chunks)\n",
228 | " )\n",
229 | " \n",
230 | " text_chunk_by_word = ''.join(text_chunk).split()\n",
231 | " W = len(text_chunk_by_word)\n",
232 | " if W == 0: \n",
233 | "# assert labels_set == set(['O']), \"0-word chunking and non-0 label!\" + str(\n",
234 | "# text_chunks) + str(labels_chunks\n",
235 | "# )\n",
236 | " continue\n",
237 | " \n",
238 | " if labels_chunk[0] == 'O': labels_chunk_by_word = ['O'] * W\n",
239 | " elif W == 1: labels_chunk_by_word = [labels_chunk[0]]\n",
240 | " elif W == 2: labels_chunk_by_word = [labels_chunk[0], labels_chunk[-1]]\n",
241 | " else: labels_chunk_by_word = [\n",
242 | " labels_chunk[0]\n",
243 | " ] + [labels_chunk[1]] * (W - 2) + [\n",
244 | " labels_chunk[-1]\n",
245 | " ]\n",
246 | " \n",
247 | " sentence_by_word.extend(text_chunk_by_word)\n",
248 | " labels_by_word.extend(labels_chunk_by_word)\n",
249 | "\n",
250 | " assert len(sentence_by_word) == len(labels_by_word), \"Incorrect # of words in sentence!\" \n",
251 | " \n",
252 | " if len(sentence_by_word) == 0: continue\n",
253 | " \n",
254 | " text_by_word.append(sentence_by_word)\n",
255 | " all_labels_by_word.append(labels_by_word)\n",
256 | " return text_by_word, all_labels_by_word\n",
257 | "\n",
258 | "def reprocess_PHI_labels(folders, base_path='.', PHI_tag_type='PHI', match_text=True, dev_set_size=None):\n",
259 | " all_texts_by_patient, all_labels_by_patient = {}, {}\n",
260 | "\n",
261 | " for folder in folders:\n",
262 | " folder_dir = os.path.join(base_path, folder)\n",
263 | " xml_filenames = [x for x in os.listdir(folder_dir) if x.endswith('xml')]\n",
264 | " for xml_filename in xml_filenames:\n",
265 | " patient_num = int(xml_filename[:3])\n",
266 | " xml_filepath = os.path.join(folder_dir, xml_filename)\n",
267 | " \n",
268 | " text_by_char, labels_by_char = read_xml_file(\n",
269 | " xml_filepath,\n",
270 | " PHI_tag_type=PHI_tag_type,\n",
271 | " match_text=match_text\n",
272 | " )\n",
273 | " text_by_word, labels_by_word = merge_into_words(text_by_char, labels_by_char)\n",
274 | " \n",
275 | " if patient_num not in all_texts_by_patient:\n",
276 | " all_texts_by_patient[patient_num] = []\n",
277 | " all_labels_by_patient[patient_num] = []\n",
278 | " \n",
279 | " all_texts_by_patient[patient_num].extend(text_by_word)\n",
280 | " all_labels_by_patient[patient_num].extend(labels_by_word)\n",
281 | " \n",
282 | " patients = set(all_texts_by_patient.keys())\n",
283 | " \n",
284 | " if dev_set_size is None: train_patients, dev_patients = list(patients), []\n",
285 | " else:\n",
286 | " N_train = int(len(patients) * (1-dev_set_size))\n",
287 | " patients_random = np.random.permutation(list(patients))\n",
288 | " train_patients = list(patients_random[:N_train])\n",
289 | " dev_patients = list(patients_random[N_train:])\n",
290 | " \n",
291 | " train_texts, train_labels = [], []\n",
292 | " dev_texts, dev_labels = [], []\n",
293 | " \n",
294 | " for patient_num in train_patients:\n",
295 | " train_texts.extend(all_texts_by_patient[patient_num])\n",
296 | " train_labels.extend(all_labels_by_patient[patient_num])\n",
297 | "\n",
298 | " for patient_num in dev_patients:\n",
299 | " dev_texts.extend(all_texts_by_patient[patient_num])\n",
300 | " dev_labels.extend(all_labels_by_patient[patient_num])\n",
301 | "\n",
302 | "\n",
303 | " train_out_text_by_sentence = []\n",
304 | " for text, labels in zip(train_texts, train_labels):\n",
305 | " train_out_text_by_sentence.append('\\n'.join('%s %s' % x for x in zip(text, labels)))\n",
306 | " dev_out_text_by_sentence = []\n",
307 | " for text, labels in zip(dev_texts, dev_labels):\n",
308 | " dev_out_text_by_sentence.append('\\n'.join('%s %s' % x for x in zip(text, labels)))\n",
309 | "\n",
310 | " return '\\n\\n'.join(train_out_text_by_sentence), '\\n\\n'.join(dev_out_text_by_sentence)"
311 | ]
312 | },
313 | {
314 | "cell_type": "code",
315 | "execution_count": 4,
316 | "metadata": {},
317 | "outputs": [
318 | {
319 | "name": "stdout",
320 | "output_type": "stream",
321 | "text": [
322 | "Adjusting known error\n",
323 | "Adjusting known error\n"
324 | ]
325 | }
326 | ],
327 | "source": [
328 | "final_train_text, final_dev_text = reprocess_PHI_labels(\n",
329 | " ['training-PHI-Gold-Set1/', 'training-PHI-Gold-Set2/'], PHI_tag_type='ALL_CHILDREN',\n",
330 | " dev_set_size=0.1, match_text=True\n",
331 | ")"
332 | ]
333 | },
334 | {
335 | "cell_type": "code",
336 | "execution_count": 5,
337 | "metadata": {},
338 | "outputs": [],
339 | "source": [
340 | "test_text, _ = reprocess_PHI_labels(\n",
341 | " ['testing-PHI-Gold-fixed'], PHI_tag_type='ALL_CHILDREN', match_text=False, dev_set_size=None\n",
342 | ")"
343 | ]
344 | },
345 | {
346 | "cell_type": "code",
347 | "execution_count": 6,
348 | "metadata": {},
349 | "outputs": [
350 | {
351 | "name": "stdout",
352 | "output_type": "stream",
353 | "text": [
354 | "Record O\n",
355 | "date: O\n",
356 | "2067-11-30 B-DATE\n",
357 | "\n",
358 | "CARDIOLOGY O\n",
359 | "\n",
360 | "OCHILTREE B-HOSPITAL\n",
361 | "GENERAL I-HOSPITAL\n",
362 | "HOSPITAL I-HOSPITAL\n",
363 | "\n",
364 | "Reason O\n",
365 | "for O\n",
366 | "visit: O\n",
367 | "\n",
368 | "Follow-up O\n",
369 | "appointment O\n",
370 | "for O\n",
371 | "cardiomyopathy O\n",
372 | "\n",
373 | "Interval O\n",
374 | "History: O\n",
375 | "\n",
376 | "Mr O\n",
377 | "Lara B-PATIENT\n",
378 | "reports O\n",
379 | "no O\n",
380 | "problems O\n",
381 | "since O\n",
382 | "the O\n",
383 | "time O\n",
384 | "of O\n",
385 | "his O\n",
386 | "last O\n",
387 | "visit. O\n",
388 | "He O\n",
389 | "insistists O\n",
390 | "that O\n",
391 | "he O\n",
392 | "has O\n",
393 | "absolutely O\n",
394 | "abstained O\n",
395 | "from O\n",
396 | "etoh O\n",
397 | "since O\n",
398 | "his O\n",
399 | "last O\n",
400 | "visit. O\n",
401 | "\n",
402 | "Denies O\n",
403 | "fevers, O\n",
404 | "chills. O\n",
405 | "denies O\n",
406 | "palpitations O\n",
407 | "or O\n",
408 | "syncope. O\n",
409 | "\n",
410 | "Under O\n",
411 | "a O\n",
412 | "fair O\n",
413 | "am\n"
414 | ]
415 | }
416 | ],
417 | "source": [
418 | "print(final_train_text[:500])"
419 | ]
420 | },
421 | {
422 | "cell_type": "code",
423 | "execution_count": 7,
424 | "metadata": {},
425 | "outputs": [
426 | {
427 | "name": "stdout",
428 | "output_type": "stream",
429 | "text": [
430 | "Record O\n",
431 | "date: O\n",
432 | "2082-02-28 B-DATE\n",
433 | "\n",
434 | "c.c. O\n",
435 | "Preop O\n",
436 | "PE O\n",
437 | "for O\n",
438 | "cataract O\n",
439 | "extraction O\n",
440 | "on O\n",
441 | "3/7 B-DATE\n",
442 | "\n",
443 | "S: O\n",
444 | "ROS: O\n",
445 | "General: O\n",
446 | "no O\n",
447 | "history O\n",
448 | "of O\n",
449 | "fatigue, O\n",
450 | "weight O\n",
451 | "change, O\n",
452 | "loss O\n",
453 | "of O\n",
454 | "appetite, O\n",
455 | "or O\n",
456 | "weakness. O\n",
457 | "\n",
458 | "HEENT: O\n",
459 | "no O\n",
460 | "history O\n",
461 | "of O\n",
462 | "head O\n",
463 | "injury, O\n",
464 | "glaucoma, O\n",
465 | "tinnitus, O\n",
466 | "vertigo, O\n",
467 | "motion O\n",
468 | "sickness, O\n",
469 | "URI, O\n",
470 | "hearing O\n",
471 | "loss. O\n",
472 | "\n",
473 | "Cardiovascular: O\n",
474 | "+ O\n",
475 | "hypertension O\n",
476 | "- O\n",
477 | "adequate O\n",
478 | "co\n"
479 | ]
480 | }
481 | ],
482 | "source": [
483 | "print(final_dev_text[:400])"
484 | ]
485 | },
486 | {
487 | "cell_type": "code",
488 | "execution_count": 8,
489 | "metadata": {},
490 | "outputs": [
491 | {
492 | "name": "stdout",
493 | "output_type": "stream",
494 | "text": [
495 | "Record O\n",
496 | "date: O\n",
497 | "2080-02-18 B-DATE\n",
498 | "\n",
499 | "SDU O\n",
500 | "JAR O\n",
501 | "Admission O\n",
502 | "Note O\n",
503 | "\n",
504 | "Name: O\n",
505 | "Yosef B-PATIENT\n",
506 | "Villegas I-PATIENT\n",
507 | "\n",
508 | "MR: O\n",
509 | "8249813 B-MEDICALRECORD\n",
510 | "\n",
511 | "DOA: O\n",
512 | "2/17/80 B-DATE\n",
513 | "\n",
514 | "PCP: O\n",
515 | "Gilbert B-DOCTOR\n",
516 | "Perez I-DOCTOR\n",
517 | "\n",
518 | "Attending: O\n",
519 | "YBARRA B-DOCTOR\n",
520 | "\n",
521 | "CODE: O\n",
522 | "FULL O\n",
523 | "\n",
524 | "HPI: O\n",
525 | "70 B-AGE\n",
526 | "yo O\n",
527 | "M O\n",
528 | "with O\n",
529 | "NIDDM O\n",
530 | "admitted O\n",
531 | "for O\n",
532 | "cath O\n",
533 | "after O\n",
534 | "positive O\n",
535 | "MIBI. O\n",
536 | "Pt O\n",
537 | "has O\n",
538 | "had O\n",
539 | "increasing O\n",
540 | "CP O\n",
541 | "and O\n",
542 | "SOB O\n",
543 | "on O\n",
544 | "exert\n"
545 | ]
546 | }
547 | ],
548 | "source": [
549 | "print(test_text[:400])"
550 | ]
551 | },
552 | {
553 | "cell_type": "code",
554 | "execution_count": 10,
555 | "metadata": {},
556 | "outputs": [],
557 | "source": [
558 | "labels = {}\n",
559 | "for s in final_train_text, final_dev_text, test_text:\n",
560 | " for line in s.split('\\n'):\n",
561 | " if line == '': continue\n",
562 | " label = line.split()[-1]\n",
563 | " assert label == 'O' or label.startswith('B-') or label.startswith('I-'), \"label wrong! %s\" % label\n",
564 | " if label not in labels: labels[label] = 1\n",
565 | " else: labels[label] += 1"
566 | ]
567 | },
568 | {
569 | "cell_type": "code",
570 | "execution_count": 11,
571 | "metadata": {},
572 | "outputs": [
573 | {
574 | "data": {
575 | "text/plain": [
576 | "{'O': 777860,\n",
577 | " 'B-DATE': 12459,\n",
578 | " 'B-HOSPITAL': 2312,\n",
579 | " 'I-HOSPITAL': 1835,\n",
580 | " 'B-PATIENT': 2195,\n",
581 | " 'B-AGE': 1997,\n",
582 | " 'B-DOCTOR': 4797,\n",
583 | " 'I-DOCTOR': 3482,\n",
584 | " 'I-DATE': 1379,\n",
585 | " 'B-COUNTRY': 183,\n",
586 | " 'B-PROFESSION': 413,\n",
587 | " 'I-PROFESSION': 346,\n",
588 | " 'B-ORGANIZATION': 206,\n",
589 | " 'I-ORGANIZATION': 173,\n",
590 | " 'B-STREET': 352,\n",
591 | " 'I-STREET': 717,\n",
592 | " 'B-CITY': 654,\n",
593 | " 'I-CITY': 171,\n",
594 | " 'B-STATE': 504,\n",
595 | " 'B-ZIP': 352,\n",
596 | " 'I-PATIENT': 1192,\n",
597 | " 'B-MEDICALRECORD': 1033,\n",
598 | " 'I-MEDICALRECORD': 47,\n",
599 | " 'B-IDNUM': 456,\n",
600 | " 'B-USERNAME': 356,\n",
601 | " 'B-PHONE': 524,\n",
602 | " 'I-COUNTRY': 21,\n",
603 | " 'I-PHONE': 100,\n",
604 | " 'I-AGE': 10,\n",
605 | " 'I-IDNUM': 30,\n",
606 | " 'B-FAX': 10,\n",
607 | " 'I-FAX': 2,\n",
608 | " 'B-DEVICE': 15,\n",
609 | " 'B-EMAIL': 5,\n",
610 | " 'B-LOCATION-OTHER': 17,\n",
611 | " 'I-LOCATION-OTHER': 15,\n",
612 | " 'I-STATE': 18,\n",
613 | " 'B-URL': 2,\n",
614 | " 'I-URL': 4,\n",
615 | " 'B-BIOID': 1,\n",
616 | " 'B-HEALTHPLAN': 1,\n",
617 | " 'I-HEALTHPLAN': 1,\n",
618 | " 'I-DEVICE': 2}"
619 | ]
620 | },
621 | "execution_count": 11,
622 | "metadata": {},
623 | "output_type": "execute_result"
624 | }
625 | ],
626 | "source": [
627 | "labels"
628 | ]
629 | },
630 | {
631 | "cell_type": "code",
632 | "execution_count": 28,
633 | "metadata": {},
634 | "outputs": [],
635 | "source": [
636 | "with open('./processed/train.tsv', mode='w') as f:\n",
637 | " f.write(final_train_text)\n",
638 | "with open('./processed/dev.tsv', mode='w') as f:\n",
639 | " f.write(final_dev_text)\n",
640 | "with open('./processed/test.tsv', mode='w') as f:\n",
641 | " f.write(test_text)"
642 | ]
643 | },
644 | {
645 | "cell_type": "code",
646 | "execution_count": null,
647 | "metadata": {},
648 | "outputs": [],
649 | "source": []
650 | }
651 | ],
652 | "metadata": {
653 | "kernelspec": {
654 | "display_name": "Python 3",
655 | "language": "python",
656 | "name": "python3"
657 | },
658 | "language_info": {
659 | "codemirror_mode": {
660 | "name": "ipython",
661 | "version": 3
662 | },
663 | "file_extension": ".py",
664 | "mimetype": "text/x-python",
665 | "name": "python",
666 | "nbconvert_exporter": "python",
667 | "pygments_lexer": "ipython3",
668 | "version": "3.6.8"
669 | }
670 | },
671 | "nbformat": 4,
672 | "nbformat_minor": 2
673 | }
674 |
--------------------------------------------------------------------------------
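
Both reformatting notebooks above only check that every label is O, B-<type>, or I-<type>. A small optional check, sketched below for the same token/label layout, additionally verifies BIO continuity, i.e. that each I-<type> tag continues a chunk of the same type; this is an extra sanity check, not something the notebooks perform.

    # Optional BIO-continuity check (an extra, not part of the notebooks):
    # flag any I-X tag that does not continue a B-X/I-X chunk of the same type X.
    def check_bio_continuity(conll_text):
        problems = []
        for sent_idx, block in enumerate(conll_text.strip().split('\n\n')):
            prev = 'O'
            for line_idx, line in enumerate(block.split('\n')):
                if not line.strip():
                    continue
                tag = line.split()[-1]
                if tag.startswith('I-') and prev[2:] != tag[2:]:
                    problems.append((sent_idx, line_idx, prev, tag))
                prev = tag
        return problems

    # Usage on the strings built in the notebooks (variable name taken from the cells above):
    # print(check_bio_continuity(final_train_text)[:5])
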
/downstream_tasks/run_classifier.py:
--------------------------------------------------------------------------------
1 | # Code is adapted from the PyTorch pretrained BERT repo - See copyright & license below.
2 |
3 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | from __future__ import absolute_import, division, print_function
19 |
20 | import argparse
21 | import csv
22 | import logging
23 | import os
24 | import random
25 | import sys
26 |
27 |
28 | import numpy as np
29 | import torch
30 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
31 | TensorDataset)
32 | from torch.utils.data.distributed import DistributedSampler
33 | from tqdm import tqdm, trange
34 |
35 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
36 | from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
37 | from pytorch_pretrained_bert.tokenization import BertTokenizer
38 | from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
39 |
40 | #added
41 | import json
42 | from random import shuffle
43 | import math
44 |
45 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
46 | datefmt = '%m/%d/%Y %H:%M:%S',
47 | level = logging.INFO)
48 | logger = logging.getLogger(__name__)
49 |
50 |
51 |
52 | class InputFeatures(object):
53 | """A single set of features of data."""
54 |
55 | def __init__(self, input_ids, input_mask, segment_ids, label_id):
56 | self.input_ids = input_ids
57 | self.input_mask = input_mask
58 | self.segment_ids = segment_ids
59 | self.label_id = label_id
60 |
61 |
62 | class InputExample(object):
63 | """A single training/test example for simple sequence classification."""
64 |
65 | def __init__(self, guid, text_a, text_b=None, label=None):
66 | """Constructs a InputExample.
67 |
68 | Args:
69 | guid: Unique id for the example.
70 | text_a: string. The untokenized text of the first sequence. For single
71 | sequence tasks, only this sequence must be specified.
72 | text_b: (Optional) string. The untokenized text of the second sequence.
73 | Must only be specified for sequence pair tasks.
74 | label: (Optional) string. The label of the example. This should be
75 | specified for train and dev examples, but not for test examples.
76 | """
77 | self.guid = guid
78 | self.text_a = text_a
79 | self.text_b = text_b
80 | self.label = label
81 |
82 |
83 |
84 | class DataProcessor(object):
85 | """Base class for data converters for sequence classification data sets."""
86 |
87 | def get_train_examples(self, data_dir):
88 | """Gets a collection of `InputExample`s for the train set."""
89 | raise NotImplementedError()
90 |
91 | def get_dev_examples(self, data_dir):
92 | """Gets a collection of `InputExample`s for the dev set."""
93 | raise NotImplementedError()
94 |
95 | def get_test_examples(self, data_dir):
96 | """Gets a collection of `InputExample`s for the test set."""
97 | raise NotImplementedError()
98 |
99 | def get_labels(self):
100 | """Gets the list of labels for this data set."""
101 | raise NotImplementedError()
102 |
103 | @classmethod
104 | def _read_tsv(cls, input_file, quotechar=None):
105 | """Reads a tab separated value file."""
106 | with open(input_file, "r") as f:
107 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
108 | lines = []
109 | for line in reader:
110 | if sys.version_info[0] == 2:
111 | line = list(unicode(cell, 'utf-8') for cell in line)
112 | lines.append(line)
113 | return lines
114 |
115 |
116 | class MrpcProcessor(DataProcessor):
117 | """Processor for the MRPC data set (GLUE version)."""
118 |
119 | def get_train_examples(self, data_dir):
120 | """See base class."""
121 | logger.info("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
122 | return self._create_examples(
123 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
124 |
125 | def get_dev_examples(self, data_dir):
126 | """See base class."""
127 | return self._create_examples(
128 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
129 |
130 | def get_labels(self):
131 | """See base class."""
132 | return ["0", "1"]
133 |
134 | def _create_examples(self, lines, set_type):
135 | """Creates examples for the training and dev sets."""
136 | examples = []
137 | for (i, line) in enumerate(lines):
138 | if i == 0:
139 | continue
140 | guid = "%s-%s" % (set_type, i)
141 | text_a = line[3]
142 | text_b = line[4]
143 | label = line[0]
144 | examples.append(
145 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
146 | return examples
147 |
148 |
149 | class MnliProcessor(DataProcessor):
150 | """Processor for the MultiNLI data set (GLUE version)."""
151 |
152 | def get_train_examples(self, data_dir):
153 | """See base class."""
154 | return self._create_examples(
155 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
156 |
157 | def get_dev_examples(self, data_dir):
158 | """See base class."""
159 | return self._create_examples(
160 | self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
161 | "dev_matched")
162 |
163 | def get_labels(self):
164 | """See base class."""
165 | return ["contradiction", "entailment", "neutral"]
166 |
167 | def _create_examples(self, lines, set_type):
168 | """Creates examples for the training and dev sets."""
169 | examples = []
170 | for (i, line) in enumerate(lines):
171 | if i == 0:
172 | continue
173 | guid = "%s-%s" % (set_type, line[0])
174 | text_a = line[8]
175 | text_b = line[9]
176 | label = line[-1]
177 | examples.append(
178 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
179 | return examples
180 |
181 |
182 | class ColaProcessor(DataProcessor):
183 | """Processor for the CoLA data set (GLUE version)."""
184 |
185 | def get_train_examples(self, data_dir):
186 | """See base class."""
187 | return self._create_examples(
188 | self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
189 |
190 | def get_dev_examples(self, data_dir):
191 | """See base class."""
192 | return self._create_examples(
193 | self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
194 |
195 | def get_labels(self):
196 | """See base class."""
197 | return ["0", "1"]
198 |
199 | def _create_examples(self, lines, set_type):
200 | """Creates examples for the training and dev sets."""
201 | examples = []
202 | for (i, line) in enumerate(lines):
203 | guid = "%s-%s" % (set_type, i)
204 | text_a = line[3]
205 | label = line[1]
206 | examples.append(
207 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
208 | return examples
209 |
210 | # NEW
211 |
212 |
213 | class MedNLIProcessor(DataProcessor):
214 | def _chunks(self, l, n):
215 | """Yield successive n-sized chunks from l."""
216 | for i in range(0, len(l), n):
217 | yield l[i:i + n]
218 |
219 |
220 | def get_train_examples(self, data_dir):
221 | """Gets a collection of `InputExample`s for the train set."""
222 | file_path = os.path.join(data_dir, "mli_train_v1.jsonl")
223 | return self._create_examples(file_path)
224 |
225 | def get_dev_examples(self, data_dir):
226 | """Gets a collection of `InputExample`s for the dev set."""
227 | file_path = os.path.join(data_dir, "mli_dev_v1.jsonl")
228 | return self._create_examples(file_path)
229 |
230 | def get_test_examples(self, data_dir):
231 | """Gets a collection of `InputExample`s for the test set."""
232 | file_path = os.path.join(data_dir, "mli_test_v1.jsonl")
233 | return self._create_examples(file_path)
234 |
235 |
236 |
237 |
238 | def get_labels(self):
239 | """See base class."""
240 | return ["contradiction", "entailment", "neutral"]
241 |
242 | def _create_examples(self, file_path):
243 | examples = []
244 | with open(file_path, "r") as f:
245 | lines = f.readlines()
246 | for line in lines:
247 | example = json.loads(line)
248 | examples.append(
249 | InputExample(guid=example['pairID'], text_a=example['sentence1'],
250 | text_b=example['sentence2'], label=example['gold_label']))
251 |
252 | return examples
253 |
254 |
255 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
256 | """Loads a data file into a list of `InputBatch`s."""
257 |
258 | label_map = {label : i for i, label in enumerate(label_list)}
259 |
260 | features = []
261 | max_len = 0
262 | for (ex_index, example) in enumerate(examples):
263 | tokens_a = tokenizer.tokenize(example.text_a)
264 |
265 | tokens_b = None
266 | if example.text_b:
267 | tokens_b = tokenizer.tokenize(example.text_b)
268 | seq_len = len(tokens_a) + len(tokens_b)
269 |
270 | # Modifies `tokens_a` and `tokens_b` in place so that the total
271 | # length is less than the specified length.
272 | # Account for [CLS], [SEP], [SEP] with "- 3"
273 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
274 | else:
275 | seq_len = len(tokens_a)
276 | # Account for [CLS] and [SEP] with "- 2"
277 | if len(tokens_a) > max_seq_length - 2:
278 | tokens_a = tokens_a[:(max_seq_length - 2)]
279 |
280 | if seq_len > max_len:
281 | max_len = seq_len
282 | # The convention in BERT is:
283 | # (a) For sequence pairs:
284 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
285 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
286 | # (b) For single sequences:
287 | # tokens: [CLS] the dog is hairy . [SEP]
288 | # type_ids: 0 0 0 0 0 0 0
289 | #
290 | # Where "type_ids" are used to indicate whether this is the first
291 | # sequence or the second sequence. The embedding vectors for `type=0` and
292 | # `type=1` were learned during pre-training and are added to the wordpiece
293 | # embedding vector (and position vector). This is not *strictly* necessary
294 | # since the [SEP] token unambiguously separates the sequences, but it makes
295 | # it easier for the model to learn the concept of sequences.
296 | #
297 | # For classification tasks, the first vector (corresponding to [CLS]) is
298 | # used as the "sentence vector". Note that this only makes sense because
299 | # the entire model is fine-tuned.
300 | tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
301 | segment_ids = [0] * len(tokens)
302 |
303 | if tokens_b:
304 | tokens += tokens_b + ["[SEP]"]
305 | segment_ids += [1] * (len(tokens_b) + 1)
306 |
307 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
308 |
309 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
310 | # tokens are attended to.
311 | input_mask = [1] * len(input_ids)
312 |
313 | # Zero-pad up to the sequence length.
314 | padding = [0] * (max_seq_length - len(input_ids))
315 | input_ids += padding
316 | input_mask += padding
317 | segment_ids += padding
318 |
319 | assert len(input_ids) == max_seq_length
320 | assert len(input_mask) == max_seq_length
321 | assert len(segment_ids) == max_seq_length
322 |
323 | label_id = label_map[example.label]
324 | if ex_index < 3:
325 | logger.info("*** Example ***")
326 | logger.info("guid: %s" % (example.guid))
327 | logger.info("tokens: %s" % " ".join(
328 | [str(x) for x in tokens]))
329 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
330 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
331 | logger.info(
332 | "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
333 | logger.info("label: %s (id = %d)" % (example.label, label_id))
334 |
335 | features.append(
336 | InputFeatures(input_ids=input_ids,
337 | input_mask=input_mask,
338 | segment_ids=segment_ids,
339 | label_id=label_id))
340 |
341 | print('Max Sequence Length: %d' %max_len)
342 |
343 | return features
344 |
345 |
346 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
347 | """Truncates a sequence pair in place to the maximum length."""
348 |
349 | # This is a simple heuristic which will always truncate the longer sequence
350 | # one token at a time. This makes more sense than truncating an equal percent
351 | # of tokens from each, since if one sequence is very short then each token
352 | # that's truncated likely contains more information than a longer sequence.
353 | while True:
354 | total_length = len(tokens_a) + len(tokens_b)
355 | if total_length <= max_length:
356 | break
357 | if len(tokens_a) > len(tokens_b):
358 | tokens_a.pop()
359 | else:
360 | tokens_b.pop()
361 |
362 | def accuracy(out, labels):
363 | outputs = np.argmax(out, axis=1)
364 | return np.sum(outputs == labels)
365 |
366 | def setup_parser():
367 | parser = argparse.ArgumentParser()
368 |
369 | ## Required parameters
370 | parser.add_argument("--data_dir",
371 | default=None,
372 | type=str,
373 | required=True,
374 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
375 | parser.add_argument("--bert_model", default=None, type=str, required=True,
376 | help="Bert pre-trained model selected in the list: bert-base-uncased, "
377 | "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
378 | "bert-base-multilingual-cased, bert-base-chinese, biobert.")
379 | parser.add_argument("--task_name",
380 | default=None,
381 | type=str,
382 | required=True,
383 | help="The name of the task to train.")
384 | parser.add_argument("--output_dir",
385 | default=None,
386 | type=str,
387 | required=True,
388 | help="The output directory where the model predictions and checkpoints will be written.")
389 |
390 | ## Other parameters
391 | parser.add_argument("--cache_dir",
392 | default="",
393 | type=str,
394 | help="Where do you want to store the pre-trained models downloaded from s3")
395 | parser.add_argument("--max_seq_length",
396 | default=128,
397 | type=int,
398 | help="The maximum total input sequence length after WordPiece tokenization. \n"
399 | "Sequences longer than this will be truncated, and sequences shorter \n"
400 | "than this will be padded.")
401 | parser.add_argument("--do_train",
402 | action='store_true',
403 | help="Whether to run training.")
404 | parser.add_argument("--do_eval",
405 | action='store_true',
406 | help="Whether to run eval on the dev set.")
407 | parser.add_argument("--do_test",
408 | action='store_true',
409 | help="Whether to run eval on the test set.")
410 | parser.add_argument("--do_lower_case",
411 | action='store_true',
412 | help="Set this flag if you are using an uncased model.")
413 | parser.add_argument("--train_batch_size",
414 | default=32,
415 | type=int,
416 | help="Total batch size for training.")
417 | parser.add_argument("--eval_batch_size",
418 | default=8,
419 | type=int,
420 | help="Total batch size for eval.")
421 | parser.add_argument("--learning_rate",
422 | default=5e-5,
423 | type=float,
424 | help="The initial learning rate for Adam.")
425 | parser.add_argument("--num_train_epochs",
426 | default=3.0,
427 | type=float,
428 | help="Total number of training epochs to perform.")
429 | parser.add_argument("--warmup_proportion",
430 | default=0.1,
431 | type=float,
432 | help="Proportion of training to perform linear learning rate warmup for. "
433 | "E.g., 0.1 = 10%% of training.")
434 | parser.add_argument("--no_cuda",
435 | action='store_true',
436 | help="Whether not to use CUDA when available")
437 | parser.add_argument("--local_rank",
438 | type=int,
439 | default=-1,
440 | help="local_rank for distributed training on gpus")
441 | parser.add_argument('--seed',
442 | type=int,
443 | default=42,
444 | help="random seed for initialization")
445 | parser.add_argument('--gradient_accumulation_steps',
446 | type=int,
447 | default=1,
448 | help="Number of updates steps to accumulate before performing a backward/update pass.")
449 | parser.add_argument('--fp16',
450 | action='store_true',
451 | help="Whether to use 16-bit float precision instead of 32-bit")
452 | parser.add_argument('--loss_scale',
453 | type=float, default=0,
454 | help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
455 | "0 (default value): dynamic loss scaling.\n"
456 | "Positive power of 2: static loss scaling value.\n")
457 | parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
458 | parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
459 | parser.add_argument('--model_loc', type=str, default='', help="Specify the location of the bio or clinical bert model")
460 | return parser
461 |
462 | def main():
463 | parser = setup_parser()
464 | args = parser.parse_args()
465 |
466 | # specifies the path where the biobert or clinical bert model is saved
467 | if args.bert_model == 'biobert' or args.bert_model == 'clinical_bert':
468 | args.bert_model = args.model_loc
469 |
470 | print(args.bert_model)
471 |
472 | if args.server_ip and args.server_port:
473 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
474 | import ptvsd
475 | print("Waiting for debugger attach")
476 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
477 | ptvsd.wait_for_attach()
478 |
479 | processors = {
480 | "cola": ColaProcessor,
481 | "mnli": MnliProcessor,
482 | "mrpc": MrpcProcessor,
483 | "mednli": MedNLIProcessor
484 | }
485 |
486 | num_labels_task = {
487 | "cola": 2,
488 | "mnli": 3,
489 | "mrpc": 2,
490 | "mednli": 3
491 | }
492 |
493 | if args.local_rank == -1 or args.no_cuda:
494 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
495 | n_gpu = torch.cuda.device_count()
496 | else:
497 | torch.cuda.set_device(args.local_rank)
498 | device = torch.device("cuda", args.local_rank)
499 | n_gpu = 1
500 | # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
501 | torch.distributed.init_process_group(backend='nccl')
502 | logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
503 | device, n_gpu, bool(args.local_rank != -1), args.fp16))
504 |
505 | if args.gradient_accumulation_steps < 1:
506 | raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
507 | args.gradient_accumulation_steps))
508 |
509 | args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
510 |
511 | random.seed(args.seed)
512 | np.random.seed(args.seed)
513 | torch.manual_seed(args.seed)
514 | if n_gpu > 0:
515 | torch.cuda.manual_seed_all(args.seed)
516 |
517 | if not args.do_train and not args.do_eval:
518 | raise ValueError("At least one of `do_train` or `do_eval` must be True.")
519 |
520 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
521 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
522 | if not os.path.exists(args.output_dir):
523 | os.makedirs(args.output_dir)
524 |
525 | task_name = args.task_name.lower()
526 |
527 | if task_name not in processors:
528 | raise ValueError("Task not found: %s" % (task_name))
529 |
530 | processor = processors[task_name]()
531 | num_labels = num_labels_task[task_name]
532 | label_list = processor.get_labels()
533 |
534 |
535 | tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
536 |
537 |
538 | print('TRAIN')
539 | train = processor.get_train_examples(args.data_dir)
540 | print([(train[i].text_a,train[i].text_b, train[i].label) for i in range(3)])
541 | print('DEV')
542 | dev = processor.get_dev_examples(args.data_dir)
543 | print([(dev[i].text_a,dev[i].text_b, dev[i].label) for i in range(3)])
544 | print('TEST')
545 | test = processor.get_test_examples(args.data_dir)
546 | print([(test[i].text_a,test[i].text_b, test[i].label) for i in range(3)])
547 |
548 |
549 |
550 | train_examples = None
551 | num_train_optimization_steps = None
552 | if args.do_train:
553 | train_examples = processor.get_train_examples(args.data_dir)
554 | num_train_optimization_steps = int(
555 | len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
556 | if args.local_rank != -1:
557 | num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
558 |
559 | # Prepare model
560 | cache_dir = args.cache_dir if args.cache_dir else os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank))
561 | model = BertForSequenceClassification.from_pretrained(args.bert_model,
562 | cache_dir=cache_dir,
563 | num_labels = num_labels)
564 | if args.fp16:
565 | model.half()
566 | model.to(device)
567 | if args.local_rank != -1:
568 | try:
569 | from apex.parallel import DistributedDataParallel as DDP
570 | except ImportError:
571 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
572 |
573 | model = DDP(model)
574 | elif n_gpu > 1:
575 | model = torch.nn.DataParallel(model)
576 |
577 | # Prepare optimizer
578 | param_optimizer = list(model.named_parameters())
579 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
580 | optimizer_grouped_parameters = [
581 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
582 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
583 | ]
584 | if args.fp16:
585 | try:
586 | from apex.optimizers import FP16_Optimizer
587 | from apex.optimizers import FusedAdam
588 | except ImportError:
589 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
590 |
591 | optimizer = FusedAdam(optimizer_grouped_parameters,
592 | lr=args.learning_rate,
593 | bias_correction=False,
594 | max_grad_norm=1.0)
595 | if args.loss_scale == 0:
596 | optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
597 | else:
598 | optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
599 |
600 | else:
601 | optimizer = BertAdam(optimizer_grouped_parameters,
602 | lr=args.learning_rate,
603 | warmup=args.warmup_proportion,
604 | t_total=num_train_optimization_steps)
605 |
606 | global_step = 0
607 | nb_tr_steps = 0
608 | tr_loss = 0
609 | if args.do_train:
610 | train_features = convert_examples_to_features(
611 | train_examples, label_list, args.max_seq_length, tokenizer)
612 | logger.info("***** Running training *****")
613 | logger.info(" Num examples = %d", len(train_examples))
614 | logger.info(" Batch size = %d", args.train_batch_size)
615 | logger.info(" Num steps = %d", num_train_optimization_steps)
616 | all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
617 | all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
618 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
619 | all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
620 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
621 | if args.local_rank == -1:
622 | train_sampler = RandomSampler(train_data)
623 | else:
624 | train_sampler = DistributedSampler(train_data)
625 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
626 |
627 | model.train()
628 | for _ in trange(int(args.num_train_epochs), desc="Epoch"):
629 | tr_loss = 0
630 | nb_tr_examples, nb_tr_steps = 0, 0
631 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
632 | batch = tuple(t.to(device) for t in batch)
633 | input_ids, input_mask, segment_ids, label_ids = batch
634 | loss = model(input_ids, segment_ids, input_mask, label_ids)
635 | if n_gpu > 1:
636 | loss = loss.mean() # mean() to average on multi-gpu.
637 | if args.gradient_accumulation_steps > 1:
638 | loss = loss / args.gradient_accumulation_steps
639 |
640 | if args.fp16:
641 | optimizer.backward(loss)
642 | else:
643 | loss.backward()
644 |
645 | tr_loss += loss.item()
646 | nb_tr_examples += input_ids.size(0)
647 | nb_tr_steps += 1
648 | if (step + 1) % args.gradient_accumulation_steps == 0:
649 | if args.fp16:
650 | # modify learning rate with special warm up BERT uses
651 | # if args.fp16 is False, BertAdam is used, which handles this automatically
652 | lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
653 | for param_group in optimizer.param_groups:
654 | param_group['lr'] = lr_this_step
655 | optimizer.step()
656 | optimizer.zero_grad()
657 | global_step += 1
658 |
659 | if args.do_train:
660 | # Save a trained model and the associated configuration
661 | model_to_save = model.module if hasattr(model, 'module') else model # Only save the model itself
662 | output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
663 | torch.save(model_to_save.state_dict(), output_model_file)
664 | output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
665 | with open(output_config_file, 'w') as f:
666 | f.write(model_to_save.config.to_json_string())
667 |
668 | # Load a trained model and config that you have fine-tuned
669 | config = BertConfig(output_config_file)
670 | model = BertForSequenceClassification(config, num_labels=num_labels)
671 | model.load_state_dict(torch.load(output_model_file))
672 | else:
673 | model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
674 | model.to(device)
675 |
676 | if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
677 | eval_examples = processor.get_dev_examples(args.data_dir)
678 | eval_features = convert_examples_to_features(
679 | eval_examples, label_list, args.max_seq_length, tokenizer)
680 | logger.info("***** Running evaluation *****")
681 | logger.info(" Num examples = %d", len(eval_examples))
682 | logger.info(" Batch size = %d", args.eval_batch_size)
683 | all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
684 | all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
685 | all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
686 | all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
687 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
688 | # Run prediction for full data
689 | eval_sampler = SequentialSampler(eval_data)
690 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
691 |
692 | model.eval()
693 | eval_loss, eval_accuracy = 0, 0
694 | nb_eval_steps, nb_eval_examples = 0, 0
695 |
696 | for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
697 | input_ids = input_ids.to(device)
698 | input_mask = input_mask.to(device)
699 | segment_ids = segment_ids.to(device)
700 | label_ids = label_ids.to(device)
701 |
702 | with torch.no_grad():
703 | tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
704 | logits = model(input_ids, segment_ids, input_mask)
705 |
706 | logits = logits.detach().cpu().numpy()
707 | label_ids = label_ids.to('cpu').numpy()
708 | tmp_eval_accuracy = accuracy(logits, label_ids)
709 |
710 | eval_loss += tmp_eval_loss.mean().item()
711 | eval_accuracy += tmp_eval_accuracy
712 |
713 | nb_eval_examples += input_ids.size(0)
714 | nb_eval_steps += 1
715 |
716 | eval_loss = eval_loss / nb_eval_steps
717 | eval_accuracy = eval_accuracy / nb_eval_examples
718 | loss = tr_loss/nb_tr_steps if args.do_train else None
719 | result = {'eval_loss': eval_loss,
720 | 'eval_accuracy': eval_accuracy,
721 | 'global_step': global_step,
722 | 'loss': loss}
723 |
724 | output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
725 | with open(output_eval_file, "w") as writer:
726 | logger.info("***** Eval results *****")
727 | for key in sorted(result.keys()):
728 | logger.info(" %s = %s", key, str(result[key]))
729 | writer.write("%s = %s\n" % (key, str(result[key])))
730 |
731 | if args.do_test and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
732 | test_examples = processor.get_test_examples(args.data_dir)
733 | test_features = convert_examples_to_features(
734 | test_examples, label_list, args.max_seq_length, tokenizer)
735 | logger.info("***** Running testing *****")
736 | logger.info(" Num examples = %d", len(test_examples))
737 | logger.info(" Batch size = %d", args.eval_batch_size)
738 | all_input_ids = torch.tensor([f.input_ids for f in test_features], dtype=torch.long)
739 | all_input_mask = torch.tensor([f.input_mask for f in test_features], dtype=torch.long)
740 | all_segment_ids = torch.tensor([f.segment_ids for f in test_features], dtype=torch.long)
741 | all_label_ids = torch.tensor([f.label_id for f in test_features], dtype=torch.long)
742 | test_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
743 | # Run prediction for full data
744 | test_sampler = SequentialSampler(test_data)
745 | test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=args.eval_batch_size)
746 |
747 | model.eval()
748 | test_loss, test_accuracy = 0, 0
749 | nb_test_steps, nb_test_examples = 0, 0
750 |
751 | for input_ids, input_mask, segment_ids, label_ids in tqdm(test_dataloader, desc="Testing"):
752 | input_ids = input_ids.to(device)
753 | input_mask = input_mask.to(device)
754 | segment_ids = segment_ids.to(device)
755 | label_ids = label_ids.to(device)
756 |
757 | with torch.no_grad():
758 | tmp_test_loss = model(input_ids, segment_ids, input_mask, label_ids)
759 | logits = model(input_ids, segment_ids, input_mask)
760 |
761 | logits = logits.detach().cpu().numpy()
762 | label_ids = label_ids.to('cpu').numpy()
763 | tmp_test_accuracy = accuracy(logits, label_ids)
764 |
765 | test_loss += tmp_test_loss.mean().item()
766 | test_accuracy += tmp_test_accuracy
767 |
768 | nb_test_examples += input_ids.size(0)
769 | nb_test_steps += 1
770 |
771 | test_loss = test_loss / nb_test_steps
772 | test_accuracy = test_accuracy / nb_test_examples
773 | loss = tr_loss/nb_tr_steps if args.do_train else None
774 | result = {'test_loss': test_loss,
775 | 'test_accuracy': test_accuracy,
776 | 'global_step': global_step,
777 | 'loss': loss}
778 |
779 | output_test_file = os.path.join(args.output_dir, "test_results.txt")
780 | with open(output_test_file, "w") as writer:
781 | logger.info("***** Test results *****")
782 | for key in sorted(result.keys()):
783 | logger.info(" %s = %s", key, str(result[key]))
784 | writer.write("%s = %s\n" % (key, str(result[key])))
785 |
786 | if __name__ == "__main__":
787 | main()
788 |
--------------------------------------------------------------------------------
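The save/load block above writes the fine-tuned weights and config to args.output_dir and then rebuilds the classifier from those two files before evaluation. The snippet below is a minimal standalone sketch of the same reload path for later inference, assuming the pytorch-pretrained-bert package this script builds on (the WEIGHTS_NAME/CONFIG_NAME constants and the file_utils import path are from its 0.6.x releases; adjust if a different version is pinned). The paths, label count, and example sentence are placeholders.

import os

import torch
from pytorch_pretrained_bert.file_utils import WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.modeling import BertConfig, BertForSequenceClassification
from pytorch_pretrained_bert.tokenization import BertTokenizer

output_dir = "/path/to/output_dir"   # hypothetical: the args.output_dir used for fine-tuning
bert_model = "/path/to/bert_model"   # hypothetical: directory with vocab.txt, as passed via args.bert_model
num_labels = 2                       # hypothetical: must match the task the model was fine-tuned on
max_seq_length = 128

# Rebuild the model exactly as the script does after training.
config = BertConfig(os.path.join(output_dir, CONFIG_NAME))
model = BertForSequenceClassification(config, num_labels=num_labels)
model.load_state_dict(torch.load(os.path.join(output_dir, WEIGHTS_NAME), map_location="cpu"))
model.eval()

# Tokenize one sentence the same way convert_examples_to_features does ([CLS] ... [SEP]).
tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=False)
tokens = ["[CLS]"] + tokenizer.tokenize("Patient denies chest pain.")[:max_seq_length - 2] + ["[SEP]"]
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)], dtype=torch.long)
segment_ids = torch.zeros_like(input_ids)   # single-sentence input -> all segment 0
input_mask = torch.ones_like(input_ids)     # no padding in this single example

with torch.no_grad():
    logits = model(input_ids, segment_ids, input_mask)  # no labels -> returns logits
print("predicted label index:", logits.argmax(dim=-1).item())
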
/downstream_tasks/run_ner.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | """
4 | This code is adapted from the kyzhouhzau/BERT-NER repo with several modifications.
5 | """
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 |
10 | import sys
11 |
12 | import collections
13 | import os
14 | import math
15 | import modeling
16 | import optimization
17 | import tokenization
18 | import tensorflow as tf
19 | from tensorflow.python.ops import math_ops
20 | import tf_metrics
21 | import pickle
22 | import numpy as np
23 | import itertools
24 | import json
25 | from random import shuffle
26 | import random
27 |
28 | flags = tf.flags
29 |
30 | FLAGS = flags.FLAGS
31 |
32 | flags.DEFINE_string(
33 | "task_name", None, "The name of the task to train."
34 | )
35 |
36 | flags.DEFINE_string(
37 | "data_dir", None,
38 | "The input datadir.",
39 | )
40 |
41 | flags.DEFINE_string(
42 | "output_dir", None,
43 | "The output directory where the model checkpoints will be written."
44 | )
45 |
46 | flags.DEFINE_string(
47 | "bert_config_file", None,
48 | "The config json file corresponding to the pre-trained BERT model."
49 | )
50 |
51 | flags.DEFINE_string("vocab_file", None,
52 | "The vocabulary file that the BERT model was trained on.")
53 |
54 | flags.DEFINE_string(
55 | "init_checkpoint", None,
56 | "Initial checkpoint (usually from a pre-trained BERT model)."
57 | )
58 |
59 | flags.DEFINE_bool(
60 | "do_lower_case", False,
61 | "Whether to lower case the input text."
62 | )
63 |
64 | flags.DEFINE_integer(
65 | "max_seq_length", 128,
66 | "The maximum total input sequence length after WordPiece tokenization."
67 | )
68 |
69 | flags.DEFINE_bool(
70 | "do_train", True,
71 | "Whether to run training."
72 | )
73 | flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
74 |
75 | flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
76 |
77 | flags.DEFINE_bool("do_predict", True, "Whether to run the model in inference mode on the test set.")
78 |
79 | flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
80 |
81 | flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
82 |
83 | flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")
84 |
85 | flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
86 |
87 | flags.DEFINE_float("num_train_epochs", 10.0, "Total number of training epochs to perform.")
88 |
89 | flags.DEFINE_float(
90 | "warmup_proportion", 0.1,
91 | "Proportion of training to perform linear learning rate warmup for. "
92 | "E.g., 0.1 = 10% of training.")
93 |
94 | flags.DEFINE_integer("save_checkpoints_steps", 1000,
95 | "How often to save the model checkpoint.")
96 |
97 | flags.DEFINE_integer("keep_checkpoint_max", 1,
98 | "How many model checkpoints to keep.")
99 |
100 | flags.DEFINE_integer("iterations_per_loop", 1000,
101 | "How many steps to make in each estimator call.")
102 |
103 | flags.DEFINE_integer("cross_val_sz", 10, "Number of cross validation folds")
104 |
105 | tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
106 |
107 | flags.DEFINE_integer(
108 | "num_tpu_cores", 8,
109 | "Only used if `use_tpu` is True. Total number of TPU cores to use.")
110 |
111 | class InputExample(object):
112 | """A single training/test example for simple token classification."""
113 |
114 | def __init__(self, guid, tokens, text=None, labels=None):
115 |         """Constructs an InputExample.
116 |
117 | Args:
118 | guid: Unique id for the example.
119 | text: (Optional) string. The untokenized text of the sequence.
120 | tokens: list of strings. The tokenized sentence. Each token should have a
121 | corresponding label for train and dev samples.
122 |           labels: (Optional) list of strings. The labels of the example. This should be
123 | specified for train and dev examples, but not for test examples.
124 | """
125 | self.guid = guid
126 | self.text = text
127 | self.tokens = tokens
128 | self.labels = labels
129 |
130 | def __str__(self):
131 | return self.__repr__()
132 |
133 | def __repr__(self):
134 | l = [
135 | "id: {}".format(self.guid),
136 | "tokens: {}".format(" ".join(self.tokens)),
137 | ]
138 | if self.text is not None:
139 | l.append("text: {}".format(self.text))
140 |
141 | if self.labels is not None:
142 | l.append("labels: {}".format(" ".join(self.labels)))
143 |
144 | return ", ".join(l)
145 |
146 |
147 | class InputFeatures(object):
148 | """A single set of features of data."""
149 |
150 | def __init__(self, input_ids, input_mask, segment_ids, label_ids,):
151 | self.input_ids = input_ids
152 | self.input_mask = input_mask
153 | self.segment_ids = segment_ids
154 | self.label_ids = label_ids
155 | #self.label_mask = label_mask
156 |
157 | class DataProcessor(object):
158 | """Base class for data converters for sequence classification data sets."""
159 |
160 | def get_train_examples(self, data_dir):
161 | """Gets a collection of `InputExample`s for the train set."""
162 | raise NotImplementedError()
163 |
164 | def get_dev_examples(self, data_dir):
165 | """Gets a collection of `InputExample`s for the dev set."""
166 | raise NotImplementedError()
167 |
168 | def get_labels(self):
169 | """Gets the list of labels for this data set."""
170 | raise NotImplementedError()
171 |
172 | @classmethod
173 | def _read_data(cls, input_file):
174 | """Reads in data where each line has the word and its corresponding
175 | label separated by whitespace. Each sentence is separated by a blank
176 | line. E.g.:
177 |
178 | Identification O
179 | of O
180 | APC2 O
181 | , O
182 | a O
183 | homologue O
184 | of O
185 | the O
186 | adenomatous B-Disease
187 | polyposis I-Disease
188 | coli I-Disease
189 | tumour I-Disease
190 | suppressor O
191 | . O
192 |
193 | The O
194 | adenomatous B-Disease
195 | polyposis I-Disease
196 | ...
197 | """
198 | with open(input_file) as f:
199 | lines = []
200 | words = []
201 | labels = []
202 | for line in f:
203 | line = line.strip()
204 | if len(line) == 0: #i.e. we're in between sentences
205 | assert len(words) == len(labels)
206 | if len(words) == 0:
207 | continue
208 | lines.append([words, labels])
209 | words = []
210 | labels = []
211 | continue
212 |
213 | word = line.split()[0]
214 | label = line.split()[-1]
215 | words.append(word)
216 | labels.append(label)
217 |
218 | #TODO: see if there's an off by one error here
219 | return lines
220 |
221 | @classmethod
222 |     def _create_example(cls, lines, set_type):
223 | examples = []
224 | for (i, line) in enumerate(lines):
225 | guid = "%s-%s" % (set_type, i)
226 | words,labels = line
227 | words = [tokenization.convert_to_unicode(w) for w in words]
228 | labels = [tokenization.convert_to_unicode(l) for l in labels]
229 | examples.append(InputExample(guid=guid, tokens=words, labels=labels))
230 | return examples
231 |
232 | @classmethod
233 |     def _chunks(cls, l, n):
234 | """Yield successive n-sized chunks from l."""
235 | for i in range(0, len(l), n):
236 | yield l[i:i + n]
237 |
238 | def write_cv_to_file(self, evaluation, test, n):
239 | with open(os.path.join(FLAGS.data_dir, str(n) + '_eval'),'w') as w:
240 | for example in evaluation:
241 | for t, l in zip(example.tokens, example.labels):
242 | w.write("%s %s\n" %(t, l))
243 | w.write("\n")
244 |
245 |
246 | with open(os.path.join(FLAGS.data_dir, str(n) + '_test'),'w') as test_w:
247 | for test_example in test:
248 | for t, l in zip(test_example.tokens, test_example.labels):
249 |
250 | test_w.write("%s %s\n" %(t, l))
251 | test_w.write("\n")
252 |
253 |
254 |
255 | def get_cv_examples(self, splits, n, cv_sz=10):
256 | # note, when n=9 (10th split), this recovers the original train, dev, test split
257 |
258 | dev = splits[(n-1)%cv_sz] #4 #0 #3 #1
259 | test = splits[n] #0 #1 #4 #2
260 | # print('train ind: %d-%d' %((n+1)%cv_sz, (n-1)%cv_sz))
261 | # print('dev ind: %d' %((n-1)%cv_sz))
262 | # print('test ind`: %d' %n)
263 | if (n+1)%cv_sz > (n-1)%cv_sz:
264 | train = splits[:(n-1)%cv_sz] + splits[(n+1)%cv_sz:]
265 | else:
266 | train = splits[(n+1)%cv_sz:(n-1)%cv_sz] #1-3 #2-4 #0-2 #3-0s
267 | train = list(itertools.chain.from_iterable(train))
268 | print("Train size: %d, dev size: %d, test size: %d, total: %d" %(len(train), len(dev), len(test), (len(train)+len(dev)+len(test))))
269 | self.write_cv_to_file(dev, test, n)
270 | return(train, dev, test)
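    # Fold layout (assuming the default cv_sz=10): create_cv_examples() below
    # shuffles train+dev together, splits them into cv_sz-1 folds, and appends the
    # untouched test set as the final fold. Here splits[(n-1) % cv_sz] becomes dev,
    # splits[n] becomes test, and the remaining folds are concatenated into train,
    # so n = cv_sz-1 evaluates on the original held-out test set.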
271 |
272 | def create_cv_examples(self, data_dir, cv_sz=10):
273 | train_examples = self.get_train_examples(data_dir)
274 | dev_examples = self.get_dev_examples(data_dir)
275 | test_examples = self.get_test_examples(data_dir)
276 | print('num train examples: %d, num eval examples: %d, num test examples: %d' %(len(train_examples), len(dev_examples), len(test_examples)))
277 | print('Total dataset size: %d' %(len(train_examples) + len(dev_examples) + len(test_examples)))
278 | random.seed(42)
279 | train_dev = train_examples + dev_examples
280 | random.shuffle(train_dev)
281 | split_sz = math.ceil(len(train_dev)/(cv_sz-1))
282 | print('Split size: %d' %split_sz)
283 | splits = list(self._chunks(train_dev, split_sz))
284 | print('Num splits: %d' %(len(splits) + 1))
285 | splits = splits + [test_examples]
286 | print('len splits: ', [len(s) for s in splits])
287 | return splits
288 |
289 |
290 |
291 | class NerProcessor(DataProcessor):
292 | def get_train_examples(self, data_dir):
293 | return self._create_example(
294 | self._read_data(os.path.join(data_dir, "train_dev.tsv")), "train"
295 | )
296 |
297 | def get_dev_examples(self, data_dir):
298 | return self._create_example(
299 | self._read_data(os.path.join(data_dir, "devel.tsv")), "dev"
300 | )
301 |
302 | def get_test_examples(self,data_dir):
303 | test_examples = self._create_example(
304 | self._read_data(os.path.join(data_dir, "test.tsv")), "test")
305 | #print(test_examples)
306 | return test_examples
307 |
308 |
309 | def get_labels(self):
310 | return ["B", "I", "O", "X", "[CLS]", "[SEP]"]
311 |
312 |
313 | class NCBIDiseaseProcessor(DataProcessor):
314 | #https://www.ncbi.nlm.nih.gov/CBBresearch/Dogan/DISEASE/
315 | def get_train_examples(self, data_dir):
316 | return self._create_example(
317 | self._read_data(os.path.join(data_dir, "train.tsv")), "train"
318 | )
319 |
320 | def get_dev_examples(self, data_dir):
321 | return self._create_example(
322 | self._read_data(os.path.join(data_dir, "devel.tsv")), "dev"
323 | )
324 |
325 | def get_test_examples(self,data_dir):
326 | return self._create_example(
327 | self._read_data(os.path.join(data_dir, "test.tsv")), "test")
328 |
329 |
330 | def get_labels(self):
331 | return ["B", "I", "O", "X", "[CLS]", "[SEP]"]
332 |
333 |
334 | class i2b22010Processor(DataProcessor):
335 |
336 | def get_train_examples(self, data_dir):
337 | return self._create_example(
338 | self._read_data(os.path.join(data_dir, "train.tsv")), "train"
339 | )
340 |
341 | def get_dev_examples(self, data_dir):
342 | return self._create_example(
343 | self._read_data(os.path.join(data_dir, "dev.tsv")), "dev"
344 | )
345 |
346 | def get_test_examples(self,data_dir):
347 | print('Path: ', os.path.join(data_dir, "test.tsv"))
348 | test_examples = self._create_example(
349 | self._read_data(os.path.join(data_dir, "test.tsv")), "test")
350 | print(test_examples[-5:])
351 | return test_examples
352 |
353 | def get_labels(self):
354 | return ["B-problem", "I-problem", "B-treatment", "I-treatment", 'B-test', 'I-test', 'O', "X", "[CLS]", "[SEP]"]
355 |
356 | class i2b22006Processor(DataProcessor):
357 |
358 |
359 | def get_train_examples(self, data_dir):
360 | return self._create_example(
361 | self._read_data(os.path.join(data_dir, "train.conll")), "train"
362 | )
363 |
364 | def get_dev_examples(self, data_dir):
365 | return self._create_example(
366 | self._read_data(os.path.join(data_dir, "dev.conll")), "dev"
367 | )
368 |
369 | def get_test_examples(self,data_dir):
370 | return self._create_example(
371 | self._read_data(os.path.join(data_dir, "test.conll")), "test"
372 | )
373 |
374 | def get_labels(self):
375 | return ["B-ID", "I-ID", "B-HOSPITAL", "I-HOSPITAL", 'B-PATIENT', 'I-PATIENT', 'B-PHONE', 'I-PHONE',
376 | 'B-DATE', 'I-DATE', 'B-DOCTOR', 'I-DOCTOR', 'B-LOCATION', 'I-LOCATION', 'B-AGE', 'I-AGE',
377 | 'O', "X", "[CLS]", "[SEP]"]
378 |
379 | class i2b22012Processor(DataProcessor):
380 |
381 | def get_train_examples(self, data_dir):
382 | return self._create_example(
383 | self._read_data(os.path.join(data_dir, "train.tsv")), "train"
384 | )
385 |
386 | def get_dev_examples(self, data_dir):
387 | return self._create_example(
388 | self._read_data(os.path.join(data_dir, "dev.tsv")), "dev"
389 | )
390 |
391 | def get_test_examples(self,data_dir):
392 | return self._create_example(
393 | self._read_data(os.path.join(data_dir, "test.tsv")), "test"
394 | )
395 |
396 | def get_labels(self):
397 | return ['B-OCCURRENCE','I-OCCURRENCE','B-EVIDENTIAL','I-EVIDENTIAL','B-TREATMENT','I-TREATMENT','B-CLINICAL_DEPT',
398 | 'I-CLINICAL_DEPT','B-PROBLEM','I-PROBLEM','B-TEST','I-TEST','O', "X", "[CLS]", "[SEP]"]
399 |
400 | class i2b22014Processor(DataProcessor):
401 |
402 | def get_train_examples(self, data_dir):
403 | return self._create_example(
404 | self._read_data(os.path.join(data_dir, "train.tsv")), "train"
405 | )
406 |
407 | def get_dev_examples(self, data_dir):
408 | return self._create_example(
409 | self._read_data(os.path.join(data_dir, "dev.tsv")), "dev"
410 | )
411 |
412 | def get_test_examples(self,data_dir):
413 | return self._create_example(
414 | self._read_data(os.path.join(data_dir, "test.tsv")), "test"
415 | )
416 |
417 | def get_labels(self):
418 | return ["B-IDNUM", "I-IDNUM", "B-HOSPITAL", "I-HOSPITAL", 'B-PATIENT', 'I-PATIENT', 'B-PHONE', 'I-PHONE',
419 | 'B-DATE', 'I-DATE', 'B-DOCTOR', 'I-DOCTOR', 'B-LOCATION-OTHER', 'I-LOCATION-OTHER', 'B-AGE', 'I-AGE', 'B-BIOID', 'I-BIOID',
420 | 'B-STATE', 'I-STATE','B-ZIP', 'I-ZIP', 'B-HEALTHPLAN', 'I-HEALTHPLAN', 'B-ORGANIZATION', 'I-ORGANIZATION',
421 | 'B-MEDICALRECORD', 'I-MEDICALRECORD', 'B-CITY', 'I-CITY', 'B-STREET', 'I-STREET', 'B-COUNTRY', 'I-COUNTRY',
422 | 'B-URL', 'I-URL',
423 | 'B-USERNAME', 'I-USERNAME', 'B-PROFESSION', 'I-PROFESSION', 'B-FAX', 'I-FAX', 'B-EMAIL', 'I-EMAIL', 'B-DEVICE', 'I-DEVICE',
424 | 'O', "X", "[CLS]", "[SEP]"]
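# Each processor above returns its dataset's BIO tags plus three bookkeeping tags
# used during feature conversion: "X" for non-initial WordPiece sub-tokens and
# "[CLS]"/"[SEP]" for the sequence delimiters. Adding another CoNLL-style dataset
# only requires a new DataProcessor subclass with get_train/dev/test_examples and
# get_labels, registered in the `processors` dict inside main().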
425 |
426 | def write_tokens(tokens,tok_to_orig_map, mode, cv_iter):
427 | #print('MODE: %s' %mode)
428 | if mode == "test" or mode == "eval":
429 | path = os.path.join(FLAGS.output_dir, str(cv_iter) + "_token_"+mode+".txt")
430 | wf = open(path,'a')
431 | for token in tokens:
432 | if token!="**NULL**":
433 | wf.write(token+'\n')
434 | wf.write('\n')
435 | wf.close()
436 | with open(os.path.join(FLAGS.output_dir, str(cv_iter) + "_tok_to_orig_map_"+mode+".txt"),'a') as w:
437 |             w.write("-1\n") # corresponds to [CLS]
438 | for ind in tok_to_orig_map:
439 | w.write(str(ind)+'\n')
440 |             w.write("-1\n") # corresponds to [SEP]
441 | w.write('\n')
442 |
443 |
444 |
445 | def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer,cv_iter, mode):
446 | label_map = {}
447 | for (i, label) in enumerate(label_list,1):
448 | label_map[label] = i
449 | with open(os.path.join(FLAGS.output_dir, 'label2id.pkl'),'wb') as w:
450 | pickle.dump(label_map,w)
451 | orig_tokens = example.tokens
452 | orig_labels = example.labels
453 | tokens = []
454 | labels = []
455 | tok_to_orig_map = []
456 |
457 |
458 | for i, word in enumerate(orig_tokens):
459 | token = tokenizer.tokenize(word)
460 | tokens.extend(token)
461 | orig_label = orig_labels[i]
462 | for m in range(len(token)):
463 | tok_to_orig_map.append(i)
464 | if m == 0:
465 | labels.append(orig_label)
466 | else:
467 | labels.append("X")
468 |
469 |
470 | if len(tokens) >= max_seq_length - 1:
471 | tokens = tokens[0:(max_seq_length - 2)]
472 | labels = labels[0:(max_seq_length - 2)]
473 | ntokens = []
474 | segment_ids = []
475 | label_ids = []
476 | ntokens.append("[CLS]")
477 | segment_ids.append(0)
478 | label_ids.append(label_map["[CLS]"])
479 | for i, token in enumerate(tokens):
480 | ntokens.append(token)
481 | segment_ids.append(0)
482 | label_ids.append(label_map[labels[i]])
483 | ntokens.append("[SEP]")
484 | segment_ids.append(0)
485 | label_ids.append(label_map["[SEP]"])
486 | input_ids = tokenizer.convert_tokens_to_ids(ntokens)
487 | input_mask = [1] * len(input_ids)
488 | while len(input_ids) < max_seq_length:
489 | input_ids.append(0)
490 | input_mask.append(0)
491 | segment_ids.append(0)
492 | label_ids.append(0)
493 | ntokens.append("**NULL**")
494 |
495 | assert len(input_ids) == max_seq_length
496 | assert len(input_mask) == max_seq_length
497 | assert len(segment_ids) == max_seq_length
498 | assert len(label_ids) == max_seq_length
499 |
500 | if ex_index < 3:
501 | tf.logging.info("*** Example ***")
502 | tf.logging.info("guid: %s" % (example.guid))
503 | tf.logging.info("tokens: %s" % " ".join(
504 | [tokenization.printable_text(x) for x in tokens]))
505 | tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
506 | tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
507 | tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
508 | tf.logging.info("label_ids: %s" % " ".join([str(x) for x in label_ids]))
509 |
510 | feature = InputFeatures(
511 | input_ids=input_ids,
512 | input_mask=input_mask,
513 | segment_ids=segment_ids,
514 | label_ids=label_ids,
515 | )
516 | write_tokens(ntokens,tok_to_orig_map, mode, cv_iter)
517 | return feature
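# Labeling scheme used above: only the first WordPiece of each original word keeps
# that word's tag; every continuation piece is labeled "X". For example, a word/tag
# pair like "adenomatous B-Disease" might tokenize (depending on the vocabulary) to
# "aden B-Disease", "##oma X", "##tous X". tok_to_orig_map records the index of the
# source word for every WordPiece so predictions can be projected back to word level.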
518 |
519 |
520 | def filed_based_convert_examples_to_features(
521 | examples, label_list, max_seq_length, tokenizer, output_file,cv_iter, mode=None
522 | ):
523 | writer = tf.python_io.TFRecordWriter(output_file)
524 | for (ex_index, example) in enumerate(examples):
525 | if ex_index % 5000 == 0:
526 | tf.logging.info("Writing example %d of %d" % (ex_index, len(examples)))
527 | feature = convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer, cv_iter, mode )
528 |
529 | def create_int_feature(values):
530 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
531 | return f
532 |
533 | features = collections.OrderedDict()
534 | features["input_ids"] = create_int_feature(feature.input_ids)
535 | features["input_mask"] = create_int_feature(feature.input_mask)
536 | features["segment_ids"] = create_int_feature(feature.segment_ids)
537 | features["label_ids"] = create_int_feature(feature.label_ids)
538 | tf_example = tf.train.Example(features=tf.train.Features(feature=features))
539 | writer.write(tf_example.SerializeToString())
540 |
541 | def file_based_input_fn_builder(input_file, seq_length, is_training, drop_remainder):
542 | name_to_features = {
543 | "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
544 | "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
545 | "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
546 | "label_ids": tf.FixedLenFeature([seq_length], tf.int64),
547 |
548 | }
549 |
550 | def _decode_record(record, name_to_features):
551 | example = tf.parse_single_example(record, name_to_features)
552 | for name in list(example.keys()):
553 | t = example[name]
554 | if t.dtype == tf.int64:
555 | t = tf.to_int32(t)
556 | example[name] = t
557 | return example
558 |
559 | def input_fn(params):
560 | batch_size = params["batch_size"]
561 | d = tf.data.TFRecordDataset(input_file)
562 | if is_training:
563 | d = d.repeat()
564 | d = d.shuffle(buffer_size=100)
565 | d = d.apply(tf.contrib.data.map_and_batch(
566 | lambda record: _decode_record(record, name_to_features),
567 | batch_size=batch_size,
568 | drop_remainder=drop_remainder
569 | ))
570 | return d
571 | return input_fn
572 |
573 |
574 | def create_model(bert_config, is_training, input_ids, input_mask,
575 | segment_ids, labels, num_labels, use_one_hot_embeddings):
576 | model = modeling.BertModel(
577 | config=bert_config,
578 | is_training=is_training,
579 | input_ids=input_ids,
580 | input_mask=input_mask,
581 | token_type_ids=segment_ids,
582 | use_one_hot_embeddings=use_one_hot_embeddings
583 | )
584 |
585 | output_layer = model.get_sequence_output()
586 |
587 | hidden_size = output_layer.shape[-1].value
588 |
589 | output_weight = tf.get_variable(
590 | "output_weights", [num_labels, hidden_size],
591 | initializer=tf.truncated_normal_initializer(stddev=0.02)
592 | )
593 | output_bias = tf.get_variable(
594 | "output_bias", [num_labels], initializer=tf.zeros_initializer()
595 | )
596 | with tf.variable_scope("loss"):
597 | if is_training:
598 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
599 | output_layer = tf.reshape(output_layer, [-1, hidden_size])
600 | logits = tf.matmul(output_layer, output_weight, transpose_b=True)
601 | logits = tf.nn.bias_add(logits, output_bias)
602 | logits = tf.reshape(logits, [-1, FLAGS.max_seq_length, num_labels])
603 |
604 | ##########################################################################
605 | log_probs = tf.nn.log_softmax(logits, axis=-1)
606 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
607 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
608 | loss = tf.reduce_sum(per_example_loss)
609 | probabilities = tf.nn.softmax(logits, axis=-1)
610 | predict = tf.argmax(probabilities,axis=-1)
611 | return (loss, per_example_loss, logits, log_probs, predict)
612 | ##########################################################################
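    # Note on the head above: the [batch, seq_len, hidden] sequence output is
    # flattened to [batch*seq_len, hidden], projected to num_labels classes, and
    # reshaped back, so a tag is predicted for every position. The loss sums the
    # per-position cross entropy over all positions, including padding; this is
    # why main() builds the model with num_labels = len(label_list) + 1, reserving
    # label id 0 for padded positions.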
613 |
614 | def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
615 | num_train_steps, num_warmup_steps, use_tpu,
616 | use_one_hot_embeddings):
617 | def model_fn(features, labels, mode, params):
618 | tf.logging.info("*** Features ***")
619 | for name in sorted(features.keys()):
620 | tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape))
621 | input_ids = features["input_ids"]
622 | input_mask = features["input_mask"]
623 | segment_ids = features["segment_ids"]
624 | label_ids = features["label_ids"]
625 | is_training = (mode == tf.estimator.ModeKeys.TRAIN)
626 |
627 | (total_loss, per_example_loss, logits, log_probs, predicts) = create_model(
628 | bert_config, is_training, input_ids, input_mask,segment_ids, label_ids,
629 | num_labels, use_one_hot_embeddings)
630 | tvars = tf.trainable_variables()
631 | scaffold_fn = None
632 | if init_checkpoint:
633 | print('INIT_CHECKPOINT')
634 | (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars,init_checkpoint)
635 | if use_tpu:
636 | def tpu_scaffold():
637 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
638 | return tf.train.Scaffold()
639 | scaffold_fn = tpu_scaffold
640 | else:
641 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
642 | tf.logging.info("**** Trainable Variables ****")
643 |
644 | for var in tvars:
645 | init_string = ""
646 | if var.name in initialized_variable_names:
647 | init_string = ", *INIT_FROM_CKPT*"
648 | tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape,
649 | init_string)
650 | output_spec = None
651 | if mode == tf.estimator.ModeKeys.TRAIN:
652 | train_op = optimization.create_optimizer(
653 | total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
654 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
655 | mode=mode,
656 | loss=total_loss,
657 | train_op=train_op,
658 | scaffold_fn=scaffold_fn)
659 | elif mode == tf.estimator.ModeKeys.EVAL:
660 | def metric_fn(per_example_loss, label_ids, logits):
661 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
662 | precision = tf_metrics.precision(label_ids,predictions,num_labels,[1,2],average="macro")
663 | recall = tf_metrics.recall(label_ids,predictions,num_labels,[1,2],average="macro")
664 | f = tf_metrics.f1(label_ids,predictions,num_labels,[1,2],average="macro")
665 | #
666 | return {
667 | "eval_precision":precision,
668 | "eval_recall":recall,
669 | "eval_f": f,
670 | }
671 | eval_metrics = (metric_fn, [per_example_loss, label_ids, logits])
672 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
673 | mode=mode,
674 | loss=total_loss,
675 | eval_metrics=eval_metrics,
676 | scaffold_fn=scaffold_fn)
677 |
678 | elif mode == tf.estimator.ModeKeys.PREDICT:
679 | output_spec = tf.contrib.tpu.TPUEstimatorSpec(
680 | mode=mode,
681 | predictions={"prediction": predicts, "log_probs": log_probs},
682 | scaffold_fn=scaffold_fn
683 | )
684 | return output_spec
685 | return model_fn
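# Note on metric_fn above: tf_metrics reports macro-averaged precision/recall/F1
# over label indices [1, 2] only. Label ids are assigned by enumerating
# get_labels() starting at 1 (see convert_single_example), so indices 1 and 2 are
# the first two tags each processor returns ("B"/"I" for the two-tag processors,
# "B-problem"/"I-problem" for i2b2 2010). These are token-level numbers, not
# span-level NER scores.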
686 |
687 | def read_tok_file(token_path):
688 | tokens = list()
689 | with open(token_path, 'r') as reader:
690 | for line in reader:
691 | tok = line.strip()
692 | if tok == '[CLS]':
693 | tmp_toks = [tok]
694 | elif tok == '[SEP]':
695 | tmp_toks.append(tok)
696 | tokens.append(tmp_toks)
697 | elif tok == '':
698 | continue
699 | else:
700 | tmp_toks.append(tok)
701 | return tokens
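# read_tok_file() regroups the flat token stream written by write_tokens() into
# one list per sentence, using [CLS]/[SEP] as delimiters and skipping the blank
# separator lines, so tokens[pidx] has the same length (slen) as the prediction
# kept for example pidx in the loops below.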
702 |
703 | def main(_):
704 | tf.logging.set_verbosity(tf.logging.INFO)
705 |
706 | processors = {
707 | "ncbi": NCBIDiseaseProcessor,
708 | "i2b2_2010": i2b22010Processor,
709 | "i2b2_2006": i2b22006Processor,
710 | "i2b2_2014": i2b22014Processor,
711 | "i2b2_2012": i2b22012Processor
712 | }
713 | if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
714 | raise ValueError("At least one of `do_train` or `do_eval` or `do_predict` must be True.")
715 |
716 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)
717 |
718 | if FLAGS.max_seq_length > bert_config.max_position_embeddings:
719 | raise ValueError(
720 | "Cannot use sequence length %d because the BERT model "
721 | "was only trained up to sequence length %d" %
722 | (FLAGS.max_seq_length, bert_config.max_position_embeddings))
723 |
724 | task_name = FLAGS.task_name.lower()
725 | if task_name not in processors:
726 | raise ValueError("Task not found: %s" % (task_name))
727 | processor = processors[task_name]()
728 |
729 | label_list = processor.get_labels()
730 |
731 |
732 | splits = processor.create_cv_examples(FLAGS.data_dir, cv_sz=FLAGS.cross_val_sz)
733 |
734 | for cv_iter in range(FLAGS.cross_val_sz):
735 | #for cv_iter in [9]:
736 |     # To use only the original train/val/test split, run just the last CV fold (cv_iter = cross_val_sz - 1, as in the commented-out line above).
737 |     # We did not have time to run full cross validation, so our experiments used only that original split.
738 |
739 |
740 | tok_eval = os.path.join(FLAGS.output_dir, str(cv_iter) + "_token_eval.txt")
741 | tok_test = os.path.join(FLAGS.output_dir, str(cv_iter) + "_token_test.txt")
742 |
743 | if os.path.exists(tok_eval):
744 | os.remove(tok_eval)
745 | if os.path.exists(tok_test):
746 | os.remove(tok_test)
747 | tokenizer = tokenization.FullTokenizer(
748 | vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
749 | tpu_cluster_resolver = None
750 | if FLAGS.use_tpu and FLAGS.tpu_name:
751 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
752 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
753 |
754 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
755 |
756 | run_config = tf.contrib.tpu.RunConfig(
757 | cluster=tpu_cluster_resolver,
758 | master=FLAGS.master,
759 | model_dir=FLAGS.output_dir,
760 | save_checkpoints_steps=FLAGS.save_checkpoints_steps,
761 | keep_checkpoint_max=FLAGS.keep_checkpoint_max,
762 | tpu_config=tf.contrib.tpu.TPUConfig(
763 | iterations_per_loop=FLAGS.iterations_per_loop,
764 | num_shards=FLAGS.num_tpu_cores,
765 | per_host_input_for_training=is_per_host))
766 |
767 | train_examples = None
768 | num_train_steps = None
769 | num_warmup_steps = None
770 |
771 | if FLAGS.do_train:
772 | train_examples, eval_examples, test_examples = processor.get_cv_examples(splits, cv_iter, cv_sz=FLAGS.cross_val_sz)
773 | print('train sz: %d, val size: %d, test size: %d' %(len(train_examples), len(eval_examples), len(test_examples)))
774 | num_train_steps = int(
775 | len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
776 | num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)
777 |
778 | model_fn = model_fn_builder(
779 | bert_config=bert_config,
780 | num_labels=len(label_list)+1,
781 | init_checkpoint=FLAGS.init_checkpoint,
782 | learning_rate=FLAGS.learning_rate,
783 | num_train_steps=num_train_steps,
784 | num_warmup_steps=num_warmup_steps,
785 | use_tpu=FLAGS.use_tpu,
786 | use_one_hot_embeddings=FLAGS.use_tpu)
787 |
788 | estimator = tf.contrib.tpu.TPUEstimator(
789 | use_tpu=FLAGS.use_tpu,
790 | model_fn=model_fn,
791 | config=run_config,
792 | train_batch_size=FLAGS.train_batch_size,
793 | eval_batch_size=FLAGS.eval_batch_size,
794 | predict_batch_size=FLAGS.predict_batch_size)
795 |
796 | if FLAGS.do_train:
797 | train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
798 | filed_based_convert_examples_to_features(
799 | train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file, cv_iter)
800 | tf.logging.info("***** Running training *****")
801 | tf.logging.info(" Num examples = %d", len(train_examples))
802 | tf.logging.info(" Batch size = %d", FLAGS.train_batch_size)
803 | tf.logging.info(" Num steps = %d", num_train_steps)
804 | train_input_fn = file_based_input_fn_builder(
805 | input_file=train_file,
806 | seq_length=FLAGS.max_seq_length,
807 | is_training=True,
808 | drop_remainder=True)
809 | estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
810 |
811 | if FLAGS.do_eval:
812 | _, eval_examples, _ = processor.get_cv_examples(splits, cv_iter, cv_sz=FLAGS.cross_val_sz)
813 |
814 | eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
815 | filed_based_convert_examples_to_features(
816 | eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file, cv_iter, mode="eval", )
817 |
818 | tf.logging.info("***** Running evaluation *****")
819 | tf.logging.info(" Num examples = %d", len(eval_examples))
820 | tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size)
821 | eval_steps = None
822 | if FLAGS.use_tpu:
823 | eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size)
824 | eval_drop_remainder = True if FLAGS.use_tpu else False
825 | eval_input_fn = file_based_input_fn_builder(
826 | input_file=eval_file,
827 | seq_length=FLAGS.max_seq_length,
828 | is_training=False,
829 | drop_remainder=eval_drop_remainder)
830 | result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
831 | output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
832 | with open(output_eval_file, "w") as writer:
833 | tf.logging.info("***** Eval results *****")
834 | for key in sorted(result.keys()):
835 | tf.logging.info(" %s = %s", key, str(result[key]))
836 | writer.write("%s = %s\n" % (key, str(result[key])))
837 |
838 | #added
839 | eval_token_path = os.path.join(FLAGS.output_dir, str(cv_iter) + "_token_eval.txt")
840 | eval_tokens = read_tok_file(eval_token_path)
841 |
842 | eval_result = estimator.predict(input_fn=eval_input_fn)
843 | output_predict_file = os.path.join(FLAGS.output_dir, str(cv_iter) + "_label_eval.txt")
844 |
845 | with open(os.path.join(FLAGS.output_dir, 'label2id.pkl'),'rb') as rf:
846 | label2id = pickle.load(rf)
847 | id2label = {value:key for key,value in label2id.items()}
848 |
849 | with open(output_predict_file,'w') as p_writer:
850 | for pidx, prediction in enumerate(eval_result):
851 | slen = len(eval_tokens[pidx])
852 |                 output_line = "\n".join(id2label[id] if id != 0 else "O" for id in prediction['prediction'][:slen]) + "\n" # map padding predictions (id 0) to the O tag
853 | p_writer.write(output_line)
854 | p_writer.write('\n')
855 |
856 |
857 |
858 | if FLAGS.do_predict:
859 |
860 | _,_, predict_examples = processor.get_cv_examples(splits, cv_iter, cv_sz=FLAGS.cross_val_sz)
861 |
862 | predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
863 | filed_based_convert_examples_to_features(predict_examples, label_list,
864 | FLAGS.max_seq_length, tokenizer,
865 | predict_file,cv_iter, mode="test")
866 |
867 | tf.logging.info("***** Running prediction*****")
868 | tf.logging.info(" Num examples = %d", len(predict_examples))
869 | tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size)
870 | if FLAGS.use_tpu:
871 | # Warning: According to tpu_estimator.py Prediction on TPU is an
872 | # experimental feature and hence not supported here
873 | raise ValueError("Prediction in TPU not supported")
874 | predict_drop_remainder = True if FLAGS.use_tpu else False
875 | predict_input_fn = file_based_input_fn_builder(
876 | input_file=predict_file,
877 | seq_length=FLAGS.max_seq_length,
878 | is_training=False,
879 | drop_remainder=predict_drop_remainder)
880 | prf = estimator.evaluate(input_fn=predict_input_fn, steps=None)
881 | tf.logging.info("***** token-level Test evaluation results *****")
882 | for key in sorted(prf.keys()):
883 | tf.logging.info(" %s = %s", key, str(prf[key]))
884 |
885 | test_token_path = os.path.join(FLAGS.output_dir, str(cv_iter) + "_token_test.txt")
886 | test_tokens = read_tok_file(test_token_path)
887 |
888 |
889 | result = estimator.predict(input_fn=predict_input_fn)
890 | output_predict_file = os.path.join(FLAGS.output_dir, str(cv_iter) + "_label_test.txt")
891 |
892 | with open(os.path.join(FLAGS.output_dir, 'label2id.pkl'),'rb') as rf:
893 | label2id = pickle.load(rf)
894 | id2label = {value:key for key,value in label2id.items()}
895 |
896 | with open(output_predict_file,'w') as p_writer:
897 | for pidx, prediction in enumerate(result):
898 | slen = len(test_tokens[pidx])
899 |                 output_line = "\n".join(id2label[id] if id != 0 else "O" for id in prediction['prediction'][:slen]) + "\n" # map padding predictions (id 0) to the O tag
900 | p_writer.write(output_line)
901 | p_writer.write('\n')
902 |
903 |
904 | if __name__ == "__main__":
905 | flags.mark_flag_as_required("data_dir")
906 | flags.mark_flag_as_required("task_name")
907 | flags.mark_flag_as_required("vocab_file")
908 | flags.mark_flag_as_required("bert_config_file")
909 | flags.mark_flag_as_required("output_dir")
910 | tf.app.run()
911 |
--------------------------------------------------------------------------------
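For each cross-validation fold, the prediction block above leaves behind two parallel files in FLAGS.output_dir: {cv_iter}_token_test.txt (one WordPiece token per line, sentences separated by blank lines and bracketed by [CLS]/[SEP]) and {cv_iter}_label_test.txt (one predicted tag per line in the same order). The sketch below pairs them back up for inspection or downstream scoring; the fold index and output directory are placeholders.

import os

output_dir = "/path/to/output_dir"  # hypothetical: the FLAGS.output_dir used above
cv_iter = 9                         # hypothetical fold; the last fold uses the original test set

def read_sentences(path):
    """Split a one-item-per-line file into sentences on blank lines."""
    with open(path) as f:
        sents, cur = [], []
        for line in f:
            line = line.strip()
            if not line:
                if cur:
                    sents.append(cur)
                    cur = []
            else:
                cur.append(line)
        if cur:
            sents.append(cur)
    return sents

tokens = read_sentences(os.path.join(output_dir, "%d_token_test.txt" % cv_iter))
labels = read_sentences(os.path.join(output_dir, "%d_label_test.txt" % cv_iter))

for toks, tags in zip(tokens, labels):
    for tok, tag in zip(toks, tags):
        # [CLS]/[SEP] rows and "X" continuation tags would normally be dropped here
        # (or merged back to word level using the *_tok_to_orig_map_*.txt files).
        print(tok, tag)
    print()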